通過(guò)thop計(jì)算模型的參數(shù)量與運(yùn)算量(FLOPs)
最近想要計(jì)算模型的參數(shù)量和運(yùn)算量(FLOPs),瀏覽了一些帖子,發(fā)現(xiàn)可以通過(guò)Python中的thop工具包來(lái)實(shí)現(xiàn),但查到的資料中僅介紹了對(duì)一些簡(jiǎn)單模型(例如Resnet50)的計(jì)算,而沒(méi)有考慮復(fù)雜模型的情況。研究了一番后找到了一個(gè)通過(guò)thop對(duì)復(fù)雜模型進(jìn)行計(jì)算的方法,雖然不知道是否準(zhǔn)確,但先進(jìn)行一下記錄。
1. 參數(shù)量計(jì)算(通過(guò)pytorch)
在正式介紹thop之前,先講一下torch的內(nèi)置方法。如果僅僅想要知道模型中參數(shù)的數(shù)量,那么無(wú)需安裝thop,通過(guò)調(diào)用torch內(nèi)置的一些方法就可以實(shí)現(xiàn),代碼如下:
其中model是待計(jì)算模型的對(duì)象。
2.thop的安裝與計(jì)算簡(jiǎn)單模型的參數(shù)
(1)安裝:通過(guò)pip安裝即可,命令如下
(2)簡(jiǎn)單模型的參數(shù)量與FLOPs計(jì)算
上面的代碼摘自官方文檔。在調(diào)用thop工具包時(shí),需要將模型的輸入作為參數(shù)傳入。在上面的代碼中,第4行就是隨機(jī)生成了一個(gè)模型的輸入,(1,3,224,224)是一個(gè)四維矩陣,其含義是batch size = 1,通道數(shù)為3,分辨率為224*224。
3.復(fù)雜模型的參數(shù)量與FLIPs計(jì)算
從上面的介紹可以發(fā)現(xiàn),thop只能計(jì)算輸入為一個(gè)矩陣的模型的參數(shù)量和FLOPs,而我們實(shí)際想要計(jì)算的模型可能要復(fù)雜的多。例如我想要計(jì)算的一個(gè)模型需要接收五個(gè)矩陣作為輸入,調(diào)用模型的代碼為:
其中作為模型傳入?yún)?shù)的是五個(gè)矩陣,其尺寸為:
想要計(jì)算這樣一個(gè)模型的參數(shù)量和計(jì)算量,需要考慮一種新的方式。我在這里給出的方案是:構(gòu)建一個(gè)新的,接收單矩陣輸入的類(lèi)。代碼如下:
可以看到,該類(lèi)繼承自nn.Module,以確保該類(lèi)可以被聲明為一個(gè)模型。接著在類(lèi)的初始化函數(shù)中,將待計(jì)算模型的對(duì)象作為參數(shù)傳入,并賦值給self.model;在forward函數(shù)中,規(guī)定其必須接收一個(gè)參數(shù)(實(shí)際上我們并不會(huì)使用這個(gè)接收到的參數(shù)),并通過(guò)torch內(nèi)置的隨機(jī)函數(shù)產(chǎn)生需要形狀的張量(之所以特殊對(duì)待captions是因?yàn)閙odel要求該參數(shù)的元素為整型)。
最后,就可以進(jìn)行參數(shù)量和FLOPs的計(jì)算,代碼如下:
其中,第一行聲明Transformer類(lèi),并將其對(duì)象命名為model;
第二行聲明上面定義的func類(lèi),將model作為聲明類(lèi)時(shí)的傳入?yún)?shù);
第三行隨機(jī)生成一個(gè)矩陣
第四行調(diào)用thop中的profile方法對(duì)func類(lèi)的參數(shù)量和運(yùn)算量進(jìn)行計(jì)算,實(shí)際上等價(jià)于對(duì)Transformer類(lèi)的參數(shù)量和運(yùn)算量計(jì)算。
(完)