0
譯者:AI研習社(季一帆)
雙語原文鏈接:Easy Self-Supervised Learning with BYOL
注:本文所有代碼可見Google Colab notebook,你可用Colab的免費GPU運行或改進。
在深度學習中,經(jīng)常遇到的問題是沒有足夠的標記數(shù)據(jù),而手工標記數(shù)據(jù)耗費大量時間且人工成本高昂?;诖耍晕冶O(jiān)督學習成為深度學習的研究熱點,旨在從未標記樣本中進行學習,以緩解數(shù)據(jù)標注困難的問題。子監(jiān)督學習的目標很簡單,即訓練一個模型使得相似的樣本具有相似的表示,然而具體實現(xiàn)卻困難重重。經(jīng)過谷歌這樣的諸多先驅(qū)者若干年的研究,子監(jiān)督學習如今已取得一系列的進步與發(fā)展。
在BYOL之前,多數(shù)自我監(jiān)督學習都可分為對比學習或生成學習,其中,生成學習一般GAN建模完整的數(shù)據(jù)分布,計算成本較高,相比之下,對比學習方法就很少面臨這樣的問題。對此,BYOL的作者這樣說道:
通過對比方法,同一圖像不同視圖的表示更接近(正例),不同圖像視圖的表示相距較遠(負例),通過這樣的方式減少表示的生成成本。
為了實現(xiàn)對比方法,我們必須將每個樣本與其他許多負例樣本進行比較。然而這樣會使訓練很不穩(wěn)定,同時會增大數(shù)據(jù)集的系統(tǒng)偏差。BYOL的作者顯然明白這點:
對比方法對圖像增強的方式非常敏感。例如,當消除圖像增強中的顏色失真時,SimCLR表現(xiàn)不佳??赡艿脑蚴牵粓D像的不同裁切一般會共享顏色直方圖,而不同圖像的顏色直方圖是不同的。因此,在對比任務(wù)中,可以通過關(guān)注顏色直方圖,使用隨機裁切方式實現(xiàn)圖像增強,其結(jié)果表示幾乎無法保留顏色直方圖之外的信息。
不僅僅是顏色失真,其他類型的數(shù)據(jù)轉(zhuǎn)換也是如此。一般來說,對比訓練對數(shù)據(jù)的系統(tǒng)偏差較為敏感。在機器學習中,數(shù)據(jù)偏差是一個廣泛存在的問題(見facial recognition for women and minorities),這對對比方法來說影響更大。不過好在BYOL不依賴負采樣,從而很好的避免了該問題。
BYOL的目標與對比學習相似,但一個很大的區(qū)別是,BYOL不關(guān)心不同樣本是否具有不同的表征(即對比學習中的對比部分),僅僅使相似的樣品表征類似。看上去似乎無關(guān)緊要,但這樣的設(shè)定會顯著改善模型訓練效率和泛化能力:
由于不需要負采樣,BLOY有更高的訓練效率。在訓練中,每次遍歷只需對每個樣本采樣一次,而無需關(guān)注負樣本。
BLOY模型對訓練數(shù)據(jù)的系統(tǒng)偏差不敏感,這意味著模型可以對未見樣本也有較好的適用性。
BYOL最小化樣本表征和該樣本變換之后的表征間的距離。其中,不同變換類型包括0:平移、旋轉(zhuǎn)、模糊、顏色反轉(zhuǎn)、顏色抖動、高斯噪聲等(我在此以圖像操作來舉例說明,但BYOL也可以處理其他數(shù)據(jù)類型)。至于是單一變換還是幾種不同類型的聯(lián)合變換,這取決于你自己,不過我一般會采用聯(lián)合變換。但有一點需要注意,如果你希望訓練的模型能夠應(yīng)對某種變換,那么用該變換處理訓練數(shù)據(jù)時必要的。
手把手教你編碼BYOL
首先是數(shù)據(jù)轉(zhuǎn)換增強的編碼。BYOL的作者定義了一組類似于SimCLR的特殊轉(zhuǎn)換:
import random from typing import Callable, Tuple from kornia import augmentation as aug from kornia import filters from kornia.geometry import transform as tf import torch from torch import nn, Tensor class RandomApply(nn.Module): def __init__(self, fn: Callable, p: float): super().__init__() self.fn = fn self.p = p def forward(self, x: Tensor) -> Tensor: return x if random.random() > self.p else self.fn(x) def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module: return nn.Sequential( tf.Resize(size=image_size), RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), aug.RandomGrayscale(p=0.2), aug.RandomHorizontalFlip(), RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), aug.RandomResizedCrop(size=image_size), aug.Normalize( mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]), ), ) |
上述代碼通過Kornia實現(xiàn)數(shù)據(jù)轉(zhuǎn)換,這是一個基于 PyTorch 的可微分的計算機視覺開源庫。當然,你可以用其他開源庫實現(xiàn)數(shù)據(jù)轉(zhuǎn)換擴充,甚至是自己編寫。實際上,可微分性對BYOL而言并沒有那么必要。
接下來,我們編寫編碼器模塊。該模塊負責從基本模型提取特征,并將這些特征投影到低維隱空間。具體的,我們通過wrapper類實現(xiàn)該模塊,這樣我們可以輕松將BYOL用于任何模型,無需將模型編碼到腳本。該類主要由兩部分組成:
特征抽取,獲取模型最后一層的輸出。
映射,非線性層,將輸出映射到更低維空間。
特征提取通過hooks實現(xiàn)(如果你不了解hooks,推薦閱讀我之前的介紹文章How to Use PyTorch Hooks)。除此之外,代碼其他部分很容易理解。
from typing import Union def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module: return nn.Sequential( nn.Linear(dim, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), nn.Linear(hidden_size, projection_size), ) class EncoderWrapper(nn.Module): def __init__( self, model: nn.Module, projection_size: int = 256, hidden_size: int = 4096, layer: Union[str, int] = -2, ): super().__init__() self.model = model self.projection_size = projection_size self.hidden_size = hidden_size self.layer = layer self._projector = None self._projector_dim = None self._encoded = torch.empty(0) self._register_hook() @property def projector(self): if self._projector is None: self._projector = mlp( self._projector_dim, self.projection_size, self.hidden_size ) return self._projector def _hook(self, _, __, output): output = output.flatten(start_dim=1) if self._projector_dim is None: self._projector_dim = output.shape[-1] self._encoded = self.projector(output) def _register_hook(self): if isinstance(self.layer, str): layer = dict([*self.model.named_modules()])[self.layer] else: layer = list(self.model.children())[self.layer] layer.register_forward_hook(self._hook) def forward(self, x: Tensor) -> Tensor: _ = self.model(x) return self._encoded |
BYOL包含兩個相同的編碼器網(wǎng)絡(luò)。第一個編碼器網(wǎng)絡(luò)的權(quán)重隨著每一訓練批次進行更新,而第二個網(wǎng)絡(luò)(稱為“目標”網(wǎng)絡(luò))使用第一個編碼器權(quán)重均值進行更新。在訓練過程中,目標網(wǎng)絡(luò)接收原始批次訓練數(shù)據(jù),而另一個編碼器則接收相應(yīng)的轉(zhuǎn)換數(shù)據(jù)。兩個編碼器網(wǎng)絡(luò)會分別為相應(yīng)數(shù)據(jù)生成低維表示。然后,我們使用多層感知器預(yù)測目標網(wǎng)絡(luò)的輸出,并最大化該預(yù)測與目標網(wǎng)絡(luò)輸出之間的相似性。
圖源:Bootstrap Your Own Latent, Figure 2
也許有人會想,我們不是應(yīng)該直接比較數(shù)據(jù)轉(zhuǎn)換之前和之后的隱向量表征嗎?為什么還有設(shè)計多層感知機?假設(shè)沒有MLP層的話,網(wǎng)絡(luò)可以通過將權(quán)重降低到零方便的使所有圖像的表示相似化,可這樣模型并沒有學到任何有用的東西,而MLP層可以識別出數(shù)據(jù)轉(zhuǎn)換并預(yù)測目標隱向量。這樣避免了權(quán)重趨零,可以學習更恰當?shù)臄?shù)據(jù)表示!
訓練結(jié)束后,舍棄目標網(wǎng)絡(luò)編碼器,只保留一個編碼器,根據(jù)該編碼器,所有訓練數(shù)據(jù)可生成自洽表示。這正是BYOL能夠進行自監(jiān)督學習的關(guān)鍵!因為學習到的表示具有自洽性,所以經(jīng)不同的數(shù)據(jù)變換后幾乎保持不變。這樣,模型使得相似示例的表示更加接近!
接下來編寫B(tài)YOL的訓練代碼。我選擇使用Pythorch Lightning開源庫,該庫基于PyTorch,對深度學習項目非常友好,能夠進行多GPU培訓、實驗日志記錄、模型斷點檢查和混合精度訓練等,甚至在cloud TPU上也支持基于該庫運行PyTorch模型!
from copy import deepcopy from itertools import chain from typing import Dict, List import pytorch_lightning as pl from torch import optim import torch.nn.functional as f def normalized_mse(x: Tensor, y: Tensor) -> Tensor: x = f.normalize(x, dim=-1) y = f.normalize(y, dim=-1) return 2 - 2 * (x * y).sum(dim=-1) class BYOL(pl.LightningModule): def __init__( self, model: nn.Module, image_size: Tuple[int, int] = (128, 128), hidden_layer: Union[str, int] = -2, projection_size: int = 256, hidden_size: int = 4096, augment_fn: Callable = None, beta: float = 0.99, **hparams, ): super().__init__() self.augment = default_augmentation(image_size) if augment_fn is None else augment_fn self.beta = beta self.encoder = EncoderWrapper( model, projection_size, hidden_size, layer=hidden_layer ) self.predictor = nn.Linear(projection_size, projection_size, hidden_size) self.hparams = hparams self._target = None self.encoder(torch.zeros(2, 3, *image_size)) def forward(self, x: Tensor) -> Tensor: return self.predictor(self.encoder(x)) @property def target(self): if self._target is None: self._target = deepcopy(self.encoder) return self._target def update_target(self): for p, pt in zip(self.encoder.parameters(), self.target.parameters()): pt.data = self.beta * pt.data + (1 - self.beta) * p.data # --- Methods required for PyTorch Lightning only! --- def configure_optimizers(self): optimizer = getattr(optim, self.hparams.get("optimizer", "Adam")) lr = self.hparams.get("lr", 1e-4) weight_decay = self.hparams.get("weight_decay", 1e-6) return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay) def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x = batch[0] with torch.no_grad(): x1, x2 = self.augment(x), self.augment(x) pred1, pred2 = self.forward(x1), self.forward(x2) with torch.no_grad(): targ1, targ2 = self.target(x1), self.target(x2) loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) self.log("train_loss", loss.item()) return {"loss": loss} @torch.no_grad() def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x = batch[0] x1, x2 = self.augment(x), self.augment(x) pred1, pred2 = self.forward(x1), self.forward(x2) targ1, targ2 = self.target(x1), self.target(x2) loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) return {"loss": loss} @torch.no_grad() def validation_epoch_end(self, outputs: List[Dict]) -> Dict: val_loss = sum(x["loss"] for x in outputs) / len(outputs) self.log("val_loss", val_loss.item()) |
上述代碼部分源自Pythorch Lightning提供的示例代碼。這段代碼你尤其需要關(guān)注的是training_step,在此函數(shù)實現(xiàn)模型的數(shù)據(jù)轉(zhuǎn)換、特征投影和相似性損失計算等。
下文我們將在STL10數(shù)據(jù)集上對BYOL進行實驗驗證。因為該數(shù)據(jù)集同時包含大量未標記的圖像以及標記的訓練和測試集,非常適合無監(jiān)督和自監(jiān)督學習實驗。STL10網(wǎng)站這樣描述該數(shù)據(jù)集:
STL-10數(shù)據(jù)集是一個用于研究無監(jiān)督特征學習、深度學習、自學習算法的圖像識別數(shù)據(jù)集。該數(shù)據(jù)集是對CIFAR-10數(shù)據(jù)集的改進,最明顯的便是,每個類的標記訓練數(shù)據(jù)比CIFAR-10中的要少,但在監(jiān)督訓練之前,數(shù)據(jù)集提供大量的未標記樣本訓練模型學習圖像模型。因此,該數(shù)據(jù)集主要的挑戰(zhàn)是利用未標記的數(shù)據(jù)(與標記數(shù)據(jù)相似但分布不同)來構(gòu)建有用的先驗知識。
通過Torchvision可以很方便的加載STL10,因此無需擔心數(shù)據(jù)的下載和預(yù)處理。
from torchvision.datasets import STL10 from torchvision.transforms import ToTensor TRAIN_DATASET = STL10(root="data", split="train", download=True, transform=ToTensor()) TRAIN_UNLABELED_DATASET = STL10( root="data", split="train+unlabeled", download=True, transform=ToTensor() ) TEST_DATASET = STL10(root="data", split="test", download=True, transform=ToTensor()) |
同時,我們使用監(jiān)督學習方法作為基準模型,以此衡量本文模型的準確性?;€模型也可通過Lightning模塊輕易實現(xiàn):
class SupervisedLightningModule(pl.LightningModule): def __init__(self, model: nn.Module, **hparams): super().__init__() self.model = model def forward(self, x: Tensor) -> Tensor: return self.model(x) def configure_optimizers(self): optimizer = getattr(optim, self.hparams.get("optimizer", "Adam")) lr = self.hparams.get("lr", 1e-4) weight_decay = self.hparams.get("weight_decay", 1e-6) return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay) def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x, y = batch loss = f.cross_entropy(self.forward(x), y) self.log("train_loss", loss.item()) return {"loss": loss} @torch.no_grad() def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: x, y = batch loss = f.cross_entropy(self.forward(x), y) return {"loss": loss} @torch.no_grad() def validation_epoch_end(self, outputs: List[Dict]) -> Dict: val_loss = sum(x["loss"] for x in outputs) / len(outputs) self.log("val_loss", val_loss.item()) |
可以看到,使用Pythorch Lightning可以方便的構(gòu)建并訓練模型。只需為訓練集和測試集創(chuàng)建DataLoader
對象,將其導入需要訓練的模型即可。本實驗中,epoch設(shè)置為25,學習率為1e-4。
from os import cpu_count from torch.utils.data import DataLoader from torchvision.models import resnet18 model = resnet18(pretrained=True) supervised = SupervisedLightningModule(model) trainer = pl.Trainer(max_epochs=25, gpus=-1, weights_summary=None) train_loader = DataLoader( TRAIN_DATASET, batch_size=128, shuffle=True, drop_last=True, ) val_loader = DataLoader( TEST_DATASET, batch_size=128, ) trainer.fit(supervised, train_loader, val_loader) |
接下來,我們使用BYOL對ResNet18模型進行預(yù)訓練。在這次實驗中,我選擇epoch為50,學習率依然是1e-4。注:該過程是本文代碼耗時最長的部分,在K80 GPU的標準Colab中大約需要45分鐘。
model = resnet18(pretrained=True) byol = BYOL(model, image_size=(96, 96)) trainer = pl.Trainer( max_epochs=50, gpus=-1, accumulate_grad_batches=2048 // 128, weights_summary=None, ) train_loader = DataLoader( TRAIN_UNLABELED_DATASET, batch_size=128, shuffle=True, drop_last=True, ) trainer.fit(byol, train_loader, val_loader) |
然后,我們使用新的ResNet18模型重新進行監(jiān)督學習。(為徹底清除BYOL中的前向hook,我們實例化一個新模型,在該模型引入經(jīng)過訓練的狀態(tài)字典。)
# Extract the state dictionary, initialize a new ResNet18 model, # and load the state dictionary into the new model. # # This ensures that we remove all hooks from the previous model, # which are automatically implemented by BYOL. state_dict = model.state_dict() model = resnet18() model.load_state_dict(state_dict) supervised = SupervisedLightningModule(model) trainer = pl.Trainer( max_epochs=25, gpus=-1, weights_summary=None, ) train_loader = DataLoader( TRAIN_DATASET, batch_size=128, shuffle=True, drop_last=True, ) trainer.fit(supervised, train_loader, val_loader) |
通過這種方式,模型準確率提高了約2.5%,達到了87.7%!雖然該方法需要更多的代碼(大約300行)以及一些庫的支撐,但相比其他自監(jiān)督方法仍顯得簡潔。作為對比,可以看下官方的SimCLR或SwAV是多么復雜。而且,本文具有更快的訓練速度,即使是Colab的免費GPU,整個實驗也不到一個小時。
本文要點總結(jié)如下。首先也是最重要的,BYOL是一種巧妙的自監(jiān)督學習方法,可以利用未標記的數(shù)據(jù)來最大限度地提高模型性能。此外,由于所有ResNet模型都是使用ImageNet進行預(yù)訓練的,因此BYOL的性能優(yōu)于預(yù)訓練的ResNet18。STL10是ImageNet的一個子集,所有圖像都從224x224像素縮小到96x96像素。雖然分辨率發(fā)生改變,我們希望自監(jiān)督學習能避免這樣的影響,表現(xiàn)出較好性能,而僅僅依靠STL10的小規(guī)模訓練集是不夠的。
類似ResNet這樣的模型中,ML從業(yè)人員過于依賴預(yù)先訓練的權(quán)重。雖然這在一定情況下是很好的選擇,但不一定適合其他數(shù)據(jù),哪怕在STL10這樣與ImageNet高度相似的數(shù)據(jù)中表現(xiàn)也不如人意。因此,我迫切希望將來在深度學習的研究中,自監(jiān)督方法能夠獲得更多的關(guān)注與實踐應(yīng)用。
https://arxiv.org/pdf/2006.07733.pdf
https://arxiv.org/pdf/2006.10029v2.pdf
https://github.com/fkodom/byol
https://github.com/lucidrains/byol-pytorch
https://github.com/google-research/simclr
https://cs.stanford.edu/~acoates/stl10/
AI研習社是AI學術(shù)青年和AI開發(fā)者技術(shù)交流的在線社區(qū)。我們與高校、學術(shù)機構(gòu)和產(chǎn)業(yè)界合作,通過提供學習、實戰(zhàn)和求職服務(wù),為AI學術(shù)青年和開發(fā)者的交流互助和職業(yè)發(fā)展打造一站式平臺,致力成為中國最大的科技創(chuàng)新人才聚集地。
如果,你也是位熱愛分享的AI愛好者。歡迎與譯站一起,學習新知,分享成長。
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。