丁香五月天婷婷久久婷婷色综合91|国产传媒自偷自拍|久久影院亚洲精品|国产欧美VA天堂国产美女自慰视屏|免费黄色av网站|婷婷丁香五月激情四射|日韩AV一区二区中文字幕在线观看|亚洲欧美日本性爱|日日噜噜噜夜夜噜噜噜|中文Av日韩一区二区

您正在使用IE低版瀏覽器,為了您的雷峰網(wǎng)賬號安全和更好的產(chǎn)品體驗(yàn),強(qiáng)烈建議使用更快更安全的瀏覽器
此為臨時鏈接,僅用于文章預(yù)覽,將在時失效
人工智能開發(fā)者 正文
發(fā)私信給AI研習(xí)社
發(fā)送

0

PyTorch 的預(yù)訓(xùn)練,是時候?qū)W習(xí)一下了

本文作者: AI研習(xí)社 編輯:賈智龍 2017-05-02 18:10
導(dǎo)語:PyTorch又簡潔又快,你試過么?

前言

最近使用 PyTorch 感覺妙不可言,有種當(dāng)初使用 Keras 的快感,而且速度還不慢。各種設(shè)計直接簡潔,方便研究,比 tensorflow 的臃腫好多了。今天讓我們來談?wù)?PyTorch 的預(yù)訓(xùn)練,主要是自己寫代碼的經(jīng)驗(yàn)以及論壇 PyTorch Forums上的一些回答的總結(jié)整理。

直接加載預(yù)訓(xùn)練模型

如果我們使用的模型和原模型完全一樣,那么我們可以直接加載別人訓(xùn)練好的模型:

my_resnet = MyResNet(*args, **kwargs)
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

當(dāng)然這樣的加載方法是基于 PyTorch 推薦的存儲模型的方法:

torch.save(my_resnet.state_dict(), "my_resnet.pth")

還有第二種加載方法:

my_resnet = torch.load("my_resnet.pth")

加載部分預(yù)訓(xùn)練模型

其實(shí)大多數(shù)時候我們需要根據(jù)我們的任務(wù)調(diào)節(jié)我們的模型,所以很難保證模型和公開的模型完全一樣,但是預(yù)訓(xùn)練模型的參數(shù)確實(shí)有助于提高訓(xùn)練的準(zhǔn)確率,為了結(jié)合二者的優(yōu)點(diǎn),就需要我們加載部分預(yù)訓(xùn)練模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152'])

model_dict = model.state_dict()

# 將pretrained_dict里不屬于model_dict的鍵剔除掉

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 更新現(xiàn)有的model_dict

model_dict.update(pretrained_dict)

# 加載我們真正需要的state_dict

model.load_state_dict(model_dict)

因?yàn)樾枰蕹P椭胁黄ヅ涞逆I,也就是層的名字,所以我們的新模型改變了的層需要和原模型對應(yīng)層的名字不一樣,比如:resnet 最后一層的名字是 fc(PyTorch 中),那么我們修改過的 resnet 的最后一層就不能取這個名字,可以叫 fc_

微改基礎(chǔ)模型預(yù)訓(xùn)練

對于改動比較大的模型,我們可能需要自己實(shí)現(xiàn)一下再加載別人的預(yù)訓(xùn)練參數(shù)。但是,對于一些基本模型 PyTorch 中已經(jīng)有了,而且我只想進(jìn)行一些小的改動那么怎么辦呢?難道我又去實(shí)現(xiàn)一遍嗎?當(dāng)然不是。

我們首先看看怎么進(jìn)行微改模型。

微改基礎(chǔ)模型

PyTorch 中的 torchvision 里已經(jīng)有很多常用的模型了,可以直接調(diào)用:

  • AlexNet

  • VGG

  • ResNet

  • SqueezeNet

  • DenseNet

import torchvision.models as models

resnet18 = models.resnet18()

alexnet = models.alexnet()

squeezenet = models.squeezenet1_0()

densenet = models.densenet_161()

但是對于我們的任務(wù)而言有些層并不是直接能用,需要我們微微改一下,比如,resnet 最后的全連接層是分 1000 類,而我們只有 21 類;又比如,resnet 第一層卷積接收的通道是 3, 我們可能輸入圖片的通道是 4,那么可以通過以下方法修改:

resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

resnet.fc = nn.Linear(2048, 21)

簡單預(yù)訓(xùn)練

模型已經(jīng)改完了,接下來我們就進(jìn)行簡單預(yù)訓(xùn)練吧。
我們先從 torchvision 中調(diào)用基本模型,加載預(yù)訓(xùn)練模型,然后,重點(diǎn)來了,將其中的層直接替換為我們需要的層即可

resnet = torchvision.models.resnet152(pretrained=True)

# 原本為1000類,改為10類

resnet.fc = torch.nn.Linear(2048, 10)

其中使用了 pretrained 參數(shù),會直接加載預(yù)訓(xùn)練模型,內(nèi)部實(shí)現(xiàn)和前文提到的加載預(yù)訓(xùn)練的方法一樣。因?yàn)槭窍燃虞d的預(yù)訓(xùn)練參數(shù),相當(dāng)于模型中已經(jīng)有參數(shù)了,所以替換掉最后一層即可。OK!

雷鋒網(wǎng)按:本文作者ycszen,文章原載于作者的知乎專欄。


實(shí)戰(zhàn)特訓(xùn):遠(yuǎn)場語音交互技術(shù)  

智能音箱這么火,聽聲智科技CTO教你深入解析AI設(shè)備語音交互關(guān)鍵技術(shù)!

課程鏈接:http://www.mooc.ai/course/80

加入AI慕課學(xué)院人工智能學(xué)習(xí)交流QQ群:624413030,與AI同行一起交流成長

雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。

PyTorch 的預(yù)訓(xùn)練,是時候?qū)W習(xí)一下了

分享:
相關(guān)文章

編輯

聚焦數(shù)據(jù)科學(xué),連接 AI 開發(fā)者。更多精彩內(nèi)容,請?jiān)L問:yanxishe.com
當(dāng)月熱門文章
最新文章
請?zhí)顚懮暾埲速Y料
姓名
電話
郵箱
微信號
作品鏈接
個人簡介
為了您的賬戶安全,請驗(yàn)證郵箱
您的郵箱還未驗(yàn)證,完成可獲20積分喲!
請驗(yàn)證您的郵箱
立即驗(yàn)證
完善賬號信息
您的賬號已經(jīng)綁定,現(xiàn)在您可以設(shè)置密碼以方便用郵箱登錄
立即設(shè)置 以后再說