Arcfaceお試し(pytorch metric learning)
つい数年前は機械学習には手を出さず本業のゴリゴリ数値計算を追求するぜみたいな感じのスタンスを取っていたはずだが、興味本位であれこれ試してみると大変おもしろかったので、今では仕事にも活用する機運が高まっている。
本職ではないことを良いことに、フラフラと面白そうな技術とかを余力があるときに試しているのだが、最近気になっているのは深層距離学習である。Metric learningとか言われると、ぐっと来るものがある。計量とどう関係しているかまでは理解してないのだが、アバウトには似ているものを近く、似ていないものは遠くなるように学習させるという仕組みらしい。
深層学習といえば、数万枚の画像を学習させてクラス分類させるというのが、典型的な例だが、クラス分類というのは問題としてはわかり良いものの実戦で使うとなるとビミョーな時がある。顔の識別とかがそんな例っぽい。学習に使った以外のデータについても、似ているもの同士を似ていると判定させるためには、明確に区分されたクラスに振り分けるというよりはスペクトルのどこに位置するのか当てたいという感じになるんだろう。
なので、距離学習の場合、出力はラベル0,1,2,…という形ではなく、もう少し高次元のベクトルへの埋め込み結果という形になる。その埋め込んだ結果を別のモデルで解釈するなり、t-SNEやUMAPで可視化するとかの用途に使う。近さの尺度を損失関数にして学習させる必要があるので、基本的にはネットワークの上部は通常の畳み込みニューラルネットワーク等を使って、最後の全結合→損失関数の定義のとこで工夫するというのが一般的になるようだ。
近さを使った損失関数というのの定番のテクニックは、近いペア遠いペアなどを定義したり、若干面倒な部分があるぽいのだが、3年ぐらい前(?)から伸してきているArcfaceというテクニックは、追加でなにかデータの処理とかをしなくても、損失関数部分をちょっと工夫するだけで、クラス分類問題を学習させているうちに自然と距離学習にもなっているという優れものらしい。なにそれすごい。
faceってついてるし、顔認識専用なんじゃねと思って最初スルーしていたのだが、先人のまとめたわかりやすい記事とかを読むと特に用途に制限もなく、実装のコストも低く、良さげだったので試してみることにした。
ざっくりとした原理
自然と距離学習になるマジカルな損失関数はどうやって実装されているんだろうか。畳み込み層+全結合ネットワークの構成でクラス分類する場合、全結合層からの出力ベクトルを$${\mathbb{a}=(a_1, a_2, a_3,\cdots a_k)}$$とするとsoftmax関数
$$
y_i= \frac{\exp{a_i}}{\sum_k \exp(a_k)}
$$
が、分類クラスiである確率を与える。
そこで、この結果が正しい場合は小さい値、間違っている場合は大きい値を返すような損失関数として使われるのがCross-entropy loss
$$
L_j= - log \frac{\exp{a_{cj}}}{\sum_k \exp(a_k)}
$$
である。いまはサンプルjのデータについてのlossを計算することを考えていて、ここの$${a_{cj}}$$というのは、そのサンプルの正解クラスに対応する要素の値を取ってくることを指している。
この損失関数の入力は最後の全結合層の出力だが、それは、層の中の重みと入力ベクトルの行列積+バイアスの形になっている。$${W_j^Tx+b_j}$$のように表現されることが多い。重みと入力ベクトルは結局はベクトルの内積のような計算をすることになる。それならば、重みもベクトルも一回単位長に正規化してしまえば、内積と$${\cos}$$の関係を考えると、この行列積部分は高次元の超球面上への射影だと考えられるんじゃないかというのがポイントぽい。損失関数の中に学習可能な重みをいれて、バイアスは省略して、
$$
L_j= - log \frac{\exp{W_{cj}^Tx}}{\sum_k \exp(W_k^Tx)}= - log \frac{\exp(\cos\theta_{cj})}{\sum_k \exp(\cos\theta_k)}
$$
みたいな風に考えているらしい。これだけだと単に三角関数を使って計算しただである。加えて正解クラスに対しては$${\cos(\theta+m)}$$を返すようなべナルティを入れてやるのがポイントだ。そうすると、本来は正解クラスは1に近くなるはずなのだが、余分なペナルティmがあることになる。これを受けて、上部のネットワークが正解クラス内のものは成す角度が小さく、クラス間は角度が離れるように学習してくれるようだ。この正規化→内積を取る操作は、コサイン類似度の計算と同じだから、そうやって学習させた結果が類似度をよく再現しそうなのはなんとなく納得がいく。
実際には$${\cos(\theta+m)}$$の値が小さすぎて学習が滞ったりしないように、スケーリングを入れて
$$
L_j= - log \frac{\exp(s\cos(\theta_{cj}+m))}{ \exp(s\cos(\theta_{cj}+m))+\sum_{k\neq cj}\exp(s\cos\theta_k)}
$$
というのを計算するみたいだ。なので、このスケーリングsとペナルティmの値がハイパーパラメータになる。また、この損失関数はそれ自体が重みを持っているので、ステップごとにその更新処理も必要になってくる。
ひとまず動かしてみる
とりあえずこれを実際に使ってみよう。pytorchだと、pytorch metric learningというライブラリがあって、その中に距離学習で使う損失関数一式が用意されている。ArcFaceLossを使えば上記のArcfaceが使える。よくあるcifer-10でうまく距離学習ぽいことができるかを試してみた。
使ったコード
(2022/3/5:間違ってたので修正)
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.models as models
from pytorch_metric_learning import losses
from torchvision import datasets,transforms
# Set device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load dataset.
data_path='./data_cifer'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True,
transform=transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
#data augumentation
cifar10_aug = datasets.CIFAR10(data_path, train=True, download=True,
transform=transforms.Compose([
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
transforms.RandomHorizontalFlip(0.2),
transforms.RandomPerspective(distortion_scale=0.2, p=0.2),
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True,
transform=transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
train_loader =torch.utils.data.DataLoader(cifar10, batch_size=360, shuffle=True)
train_loader_aug =torch.utils.data.DataLoader(cifar10_aug, batch_size=360, shuffle=True)
val_loader =torch.utils.data.DataLoader(cifar10_val, batch_size=200, shuffle=True)
import torch.optim as optim
# Set a model.
model = models.resnet18()
out_features=512
model.fc = nn.Linear(in_features=512, out_features=out_features)
model = model.to(device)
# Set a metric
metric = losses.ArcFaceLoss(num_classes=10, embedding_size=out_features,scale=10, margin=0.5)
metric.to(device)
optimizer = optim.SGD([{'params': model.parameters()}, {'params': metric.parameters()}],
lr=0.01,
momentum=0.9,
weight_decay=0.001)
import datetime
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader, val_loader):
for epoch in range(1, n_epochs+1):
loss_train=0.0
model.train()
#first 10 epoch: no augumentation
#after 10 epoch start augumentation to supress overfitting
if epoch <10:
for imgs, labels in train_loader:
imgs=imgs.to(device=device)
labels=labels.to(device=device)
outputs=model(imgs)
loss=loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_train+=loss.item()
else:
for imgs, labels in train_loader_aug:
imgs=imgs.to(device=device)
labels=labels.to(device=device)
outputs=model(imgs)
loss=loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_train+=loss.item()
#validation
model.eval()
correct=0
total=0
loss_val=0.0
with torch.no_grad():
for val, val_labels in val_loader:
#val=preprocess_val(val)
val=val.to(device)
val_labels=val_labels.to(device)
outputs=model(val)
loss=loss_fn(outputs, val_labels)
loss_val+=loss.item()
_, predicted=torch.max(outputs, dim=1)
correct += int((predicted==val_labels).sum())
total += val_labels.shape[0]
print('{} Epoch {}, Training loss {}'.format(datetime.datetime.now(),epoch, loss_train/(len(train_loader))))
print('{} Epoch {}, Validation loss {}'.format(datetime.datetime.now(),epoch, loss_val/(len(val_loader))))
# %%
training_loop(60,optimizer=optimizer,model=model,loss_fn=metric,train_loader=train_loader,val_loader=val_loader)
#saving models
torch.save(model.state_dict(), "model_arcface_aug.pth")
補足
cifer-10なので、クラス数は10である。 そして、上部のResNet18モデルについては、最終層の出力が埋め込みの次元(今回は512にした)になるように
out_features=512
model.fc = nn.Linear(in_features=512, out_features=out_features)
と微調整する。この最終層の次元とクラス数の次元で、losses.ArcFaceLossを定義している。やってて引っかかったのがこの損失関数をGPUに送り忘れたり、
optimizer = optim.SGD([{'params': model.parameters()}, {'params': metric.parameters()}],
lr=0.01,
momentum=0.9,
weight_decay=0.001)
とoptimizerに複数のパラメータのセットを送る方法のとこだった。今回は精度を上げるというよりは、埋め込みがどうなっているかを見たいので、作ったモデルが埋め込んだデータを可視化してみる。
UMAPで可視化
cifer-10のvalidationデータを流し込んで出てきた512次元のベクトルをUMAPに突っ込んでそれっぽくなるか見てみた。処理はだいたいこんな感じ。
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.models as models
# Set a model.
model = models.resnet18()
out_features=512
model.fc = nn.Linear(in_features=512, out_features=out_features)
model_path = 'model_arcface_aug.pth'
model.load_state_dict(torch.load(model_path))
model.eval()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
#check embedding in validation data
from torchvision import datasets,transforms
data_path='./data_cifer'
cifar10_val = datasets.CIFAR10(data_path, train=False, download=False,
transform=transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize((0.4915, 0.4823, 0.4468),
(0.2470, 0.2435, 0.2616))
]))
val_loader =torch.utils.data.DataLoader(cifar10_val, batch_size=32, shuffle=True)
#get embedding
from tqdm import tqdm
predict=[]
labels=[]
with torch.no_grad():
for val, val_labels in tqdm(val_loader):
#val=preprocess_val(val)
val=val.to(device)
val_labels=val_labels.to(device)
outputs=model(val).detach().to('cpu').numpy()
labels.extend(val_labels.detach().to('cpu').numpy())
predict.extend(outputs)
#plot umap image
import umap
mapper = umap.UMAP(random_state=0)
embedding = mapper.fit_transform(predict)
import matplotlib.pyplot as plt
plt.figure(figsize=(13, 7))
plt.scatter(embedding[:, 0], embedding[:, 1],
c=labels, cmap='jet',
s=15, alpha=0.5)
plt.axis('off')
plt.colorbar()
plt.savefig('umap_PML.png',dpi=600)
結果はこれ。
なんか抽象絵画みたいになったが、クラスごとに分離ができている。(通常の中間出力に比べてどれぐらい向上しているかはよくわからないが)
損失関数を書き換えるだけで、典型的なモデルと組み合わせて使える&クラス分類というやりやすいタスクで距離学習になるのは便利なので、自分の仕事関係でもなにか使えないか考えてみようかなと思った。
この記事が気に入ったらサポートをしてみませんか?