跳到主要內容

RepVGG 論文解讀,新的模型架構設計,CVPR 2021

簡介

RepVGG: Making VGG-style ConvNets Great Again 將會收錄在 CVPR 2021,在現在深度學習模型越來越複雜的,或是各種NAS(neural architecture serach) 動輒幾千GPU小時找出的模型架構的時代中令人眼睛一亮的論文,在視覺相關任務的模型設計上可以提供很大的啟發。

建議讀者先有下列的知識,在閱讀此文章或論文時會好理解許多:   
 1. 深度學習常見模型的架構: VGGResNet, Densenet, Inception 等,理解 CNN 模型演變史
 2. 熟悉將 BN (batch norm) 的參數合併進 convolution 的方法。因 BN 和 convolution 實際上都是矩陣乘法和加法的運算,因此模型在推論 (inference) 時可以將 BN 合併進 convolution 來減少模型的參數和計算量,詳細的計算可以參考這邊

背景

近年來深度學習模型的發展,多分支(multi-branch) 模型逐漸成為主流,例如 ResNet 的殘差連結(residual connect) 讓訓練較深的模型變容易,Inception 則藉由多分支來獲得不同感受野(receptive fileld)的特徵。多分支模型往往可以達到較高的效益,但也有一個很明顯地缺點: 記憶體使用量,多分支的模型比須保存中間結果直到分支合併,導致推論時速度變慢或是需要更多的記憶體,不利於工業界做模型的部署與加速。

為了解決這個問題,作者設計了 RepVGG,  在多分支的狀況下訓練模型,並用結構重新參數化的方法將分支合併,讓模型在推論時只有單一分支,藉此達到高準確率又能維持記憶體使用量。
 
RepVGG 論文解讀,新的模型架構設計,CVPR 2021
多分支的記憶體使用量較多示意圖


模型設計

論文首先解釋為何模型骨架要使用 VGG,主要因為下列三個原因:
1. 速度快: RepVGG 只使用 3*3 convolution , NVIDIA cuDNN 或是 Intel MKL 都有對 3*3 convolution 做加速,而且未來如果某個硬體想對 RepVGG 的部署做優化,就只需要針對 3*3 convolution 的計算特別優化,不必考慮其他操作。
2. 節省記憶體: RepVGG 只有單一分支,不會遇到多分支模型需要保存中間結果導致記憶體使用量過大的麻煩。
3. 靈活有彈性: 多分支模型在做剪枝(pruning) 的時候會有許多限制,例如兩個需要合併的分支在做剪枝的時候必須移除相同的 channel ,導致許多剪枝的方法做在多分支模型上效果不佳。

RepVGG 的模型其實很好理解,在原始 VGG 的架構中加入 1*1 的 convolution 分支和殘差分支,需要特別注意的地方是這些分支在合併前不會經過 relu ,讓 RepVGG 可以在訓練結束後將不同分支做合併。實驗結果也顯示多分支的效果顯著 ( 論文 Table. 4 ),RepVGG-A0 在不到 VGG16 十分之一的參數量下即可有相近的準確率,而 RepVGG-A2 則不論是準確率、速度和參數量都比 ResNet-50 優異。
RepVGG的模型設計


分支的合併可以說是 RepVGG 的精隨,如下圖所示:  原本模型有 1 個 3*3 convolution 、1 個 1*1 的 convolution 和 1 個 identity 的殘差分支,首先 identity 可以視為 1*1 convolution 的特例,因為只要 1*1 convolution 的 kernel 對應原始通道位置的權重等於 1,而其他通道位置的權重皆等於 0 ,就不會對輸入做任何的改變。而 1*1 的 convolution 又可以視為 3*3 的 convolution中的特例,也就是 kernel 只有中間非零,其餘八個位置的參數都是零的 3*3 convolution 。因此 RepVGG 可以視為有 3 個 3*3 convolution 分支。接下來將 BN 的參數合併到 3*3 convolution 的運算就跟本文開頭提到的部分一樣,而原本沒有 bias 的 convolution 在合併 BN 後會多了 bias 項 。最後將 3 個 3*3 convolution 的參數直接相加就得到 RepVGG 在推論階段時的模型參數。因此 RepVGG 可以在訓練時使用多分支來提高準確率,並且在推論可以只用一個3*3 convolution,達到高準確率又能有較低的記憶體使用量。






實驗

實驗的部分就是展現 RepVGG 的快速和準確。比較有趣的是 ablation study 中的 +ReLU in branch ( 論文 Table. 7 ) 可以讓 RepVGG 得到更高的準確率,但加了 ReLU 這個非線性的操作後就無法將不同分支的 convolution 合併了。 這也是 ResNet 無法合併不同分支的主因,或許如何讓 ResNet 用類似的方法合併分支是不錯的研究方向,還可以結合模型剪枝在有分支的模型上的運用等相關問題....

結論

RepVGG 的模型架構簡潔有效,運用多分支的架構訓練模型來達到高準確率,再將模型轉變成單分支來降低記憶體使用量的想法值得效法,是一篇非常有價值的論文。

延伸閱讀

1. Winograd Convolution: 加速 3*3 convolution 的演算法,開放課程有講解
2. ResNet 真的輸了嗎? 新的訓練 ResNet 的方法: Revisiting ResNets: Improved Training and Scaling Strategies


















留言

張貼留言

這個網誌中的熱門文章

春美harumi,文山區平價日式丼飯

身為文山區在地住民,我竟然不知道興隆路上開了一家平價又好吃的春美日式料理,這家讓我想起公館的靜壽司,兩家店有一點像,都有賣壽司丼飯,但春美除了丼飯還有賣炒飯、煎蛋跟其他小菜,菜單比較多元一點,喜歡平價日式料理的朋友絕對不能錯過。 要提醒大家的是,店內的位子不多、生意很好,若想早點吃到美食建議早一點到場。 春美-鮭魚丼 春美-菜單 本 來想點 google 評論上很推的秋葵(不在菜單上)與水蓮,但很可惜當日沒有賣,想吃這兩樣的朋友可能要碰碰運氣了。 春美-鮭魚炒飯 春美-塔香蝦卵煎蛋 我們各自點了鮭魚炒飯跟鮪魚丼,還有 google 評論很推的塔香蝦卵煎蛋,老闆先送了剛煮好的炒飯跟煎蛋,上來的菜還熱騰騰,煎蛋很濕潤很香,炒飯有我很喜歡的大鍋快炒的一種焦味(?),是真的蠻好吃的。 鮭魚丼上面的鮪魚卵我很喜歡,吃起來很像真的(?),但鮭魚本人我覺得還好,值得一提的是,他給了一大坨"真實"的山葵,吃得出來不是化學做的,飯上的小菜也都很好吃。 我們去的時候老闆只有一人,因此炒飯跟煎蛋上菜會等久一點,老闆都先煮好後才後續上大家丼飯。 [整體心得]: 好吃,適合學生族,吃習慣超頂級生魚片的人可能不適合,例如我爸。 生意很好,不適合趕時間的人。 [春美-harumi] 地址:11649台北市文山區興隆路四段71號 營業時間: 週一~週六 17:00~23:00 週日 17:00~22:30

netron 好用的模型可視化工具

 深度學習模型可視化 深度學習的框架越來越多,從最多人用的 Tensorflow, Pytorch, 已經被併入 Tenserflow 的 Keras, 或是工業部署會用到的 Tensorflow lite 等等,當你今天拿到一個模型,想要知道模型架構是如何設計,使用可視化工具是一種方便且簡單的方式。 然而,不同的框架對模型可視化的支援度不一樣,因此 netron  的出現可以省去許多麻煩。自從我去公司上班後才發現這個好用的工具,真是相見恨晚啊~ netron 簡介 netron 的 github 頁面列出了所有支援的模型檔案,真的是林林總總,學都學不完啊,如果每次拿到一個新框架的模型,為了看懂模型架構而從頭開始學習,實在是太浪費時間了。netron 就可以幫助你快速的理解模型架構,netron對不同框架的支援度不太一樣,有些已經非常完整,有些可能會有一點小問題,大家可以自行嘗試。 netron 可以安裝在 macOS, Windows, Linux, 也可以直接使用網頁板,非常的方便,我平常使用時都直接用網頁版的。netron 的操作非常直觀,滑鼠滾輪放大縮小,按住左鍵可以上下移動。 netron 可視化結果 netron 的 github 已經提供許多範例了,這邊放個比較不一樣的: tensorflow lite 的 quantize model,quantize(量化) 可以有效減少模型大小,是部署深度學習模型常用的方法。這邊從 tensorflow hub 下載的 mobilenet_v1_0.25_128_quantized  ,用 netron 來看看模型的架構和參數。 下圖是選取第一個 Conv2D 的結果,可以看到 netron 把 Conv2D 的參數,input, weight, bias, 和 output 全都包了,提供了許多有用的資訊。比如說這個模型的輸入是大小 128 * 128 的彩色圖片,第一個 Conv2D 的 bias 實際的數值是多少等等。有關 quantize 的參數若想知道可以在底下告訴我。 netron 視覺化範例 Pytorch 呢,怎麼畫出來這麼奇怪 很可惜,如果你今天只有 pytorch 的 model weight 檔 (.pth) ,是沒有辦法知道模型是怎麼相連的,這是因為 pytor...