GoogleNet
大约 2 分钟
GoogleNet
导入需要的包
import sys
import torch
import torch.nn as nn
from torch.nn import functional as F
sys.path.append(".")
import d2lzh_pytorch as d2l
Inception 块
Inception 块的基础结构为:
Inception 块是GoogleNet
的基本组成单位,分为四条通道,也就是对同一个输入图像做四种处理,四条通道通过对填充与步幅的合理规定,可以使得四条通道得到的图像尺寸相同,最后将四个通道的输出在通道维上连结。
注意这里说的
四条通道
指的是四种处理方式,而非图像尺寸中的通道
概念。最后的通道维
才指的是图像尺寸中的通道
。
Inception 块的超参数是各个卷积/池化层的输出通道数,对于多层的通道,用tuple
结构定义某通道内各层的输出通道数。
定义代码如下:
class inception_block(nn.Module):
def __init__(self, in_c, c1, c2: tuple, c3: tuple, c4):
super(inception_block, self).__init__()
# 左 1 通道
self.p_1 = nn.Conv2d(in_c, c1, kernel_size=1)
# 左 2 通道
self.p_21 = nn.Conv2d(in_c, c2[0], kernel_size=1)
self.p_22 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
# 左 3 通道
self.p_31 = nn.Conv2d(in_c, c3[0], kernel_size=1)
self.p_32 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
# 左 4 通道
self.p_41 = nn.MaxPool2d(kernel_size=3, padding=1)
self.p_42 = nn.Conv2d(in_c, c4, kernel_size=1)
def forward(self, x):
c1 = F.relu(self.p_1(x))
c2 = F.relu(self.p_22(F.relu(self.p_21(x))))
c3 = F.relu(self.p_32(F.relu(self.p_31(x))))
c4 = F.relu(self.p_42(self.p_41(x)))
return torch.cat((c1, c2, c3, c4), dim=1)
GoogleNet 网络
网络结构见原文档
代码如下:
b1 = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
b2 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=1),
nn.Conv2d(64, 192, kernel_size=3, padding=1),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
b3 = nn.Sequential(
Inception(192, 64, (96, 128), (16, 32), 32),
Inception(256, 128, (128, 192), (32, 96), 64),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
b4 = nn.Sequential(
Inception(480, 192, (96, 208), (16, 48), 64),
Inception(512, 160, (112, 224), (24, 64), 64),
Inception(512, 128, (128, 256), (24, 64), 64),
Inception(512, 112, (144, 288), (32, 64), 64),
Inception(528, 256, (160, 320), (32, 128), 128),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
b5 = nn.Sequential(
Inception(832, 256, (160, 320), (32, 128), 128),
Inception(832, 384, (192, 384), (48, 128), 128),
d2l.GlobalAvgPool2d(),
)
net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))