NiN(网络中的网络)
NiN(网络中的网络)
本节学习要点
- 掌握 1x1 卷积层代替全连接层的设计思路
- 学会 1x1 卷积层的定义和使用
- 了解 NiN 结构的定义
What is it NiN?
相较于 LeNet, AlexNet 与 VGG 提供了更为深度的学习网络,能够学习更复杂的特征,但是 AlexNet 与 VGG 的基本结构都是卷积层 + 全连接层,而全连接层作为最后一端,参数过多,计算量大,因此 NiN 提出,将卷积层与全连接层交替排列,减少计算,提高效率。但是问题在于,卷积层的输入输出均为四维数组,可是全连接层的输入输出为二维数组,难以实现,这时候需要用到 1x1 卷积层用来充当全连接层的作用。
本节学习要点
- 掌握 1x1 卷积层代替全连接层的设计思路
- 学习全局平均层的原理和实现
- 掌握 NiN 结构的定义
导入需要的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append(".")
import d2lzh_pytorch as d2l
NiN 块
NiN 块与 VGG 中的 VGG 块 类似,是 NiN 网络中的基本组成单元,它由一个卷积层加两个充当全连接层的 1x1 卷积层串联而成,第一个卷积层的参数可以自定义,而第二和第三个的卷积层参数一般是固定的。
定义 NiN 块的函数如下:
def nin_block(in_channels, out_channels, kernel_size, padding, stride):
net = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1),
)
return net
NiN 网络
NiN 使用卷积窗口形状分别为 11×11、5×5 和 3×3 的卷积层,相应的输出通道数也与 AlexNet 中的一致。每个 NiN 块后接一个步幅为 2、窗口形状为 3×3 的最大池化层。
NiN 去掉了 AlexNet 最后的 3 个全连接层,取而代之地,NiN 使用了输出通道数等于标签类别数的 NiN 块,然后使用全局平均池化层[1]对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN 的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。
但是 pytorch 目前并没有给出全局池化层,需要自行定义,定义全局池化层的函数如下:
class GlobalAveragePooling(nn.Module):
def __init__(self):
super(GlobalAveragePooling, self).__init__()
def forward(self, x):
return F.avg_pool2d(x, kernel_size=x.size()[2:])
对全局平均池化层的测试
由于我看了定义和通俗理解,还是不知道这个 function.avg_pool2d()
,怎么就能变成这个全局平均池化层了,就写了个程序测试了一下,测试代码如下:
x = torch.randn(1, 16, 32, 32)
layer = GlobalAveragePooling()
print(layer(x).size())
输出结果为:torch.Size([1, 16, 1, 1])
,可以发现,这个全局平均池化层确实能够将输入的 32x32
变成 1x1
的输出,并且保证前两维不变。
定义 NiN 网络如下:
net = nn.Sequential(
nin_block(1, 96, kernel_size=11, stride=4, padding=0),
nn.MaxPool2d(kernel_size=3, stride=2),
nin_block(96, 256, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=3, stride=2),
nin_block(256, 384, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Dropout(0.5),
# 标签类别数是 10
nin_block(384, 10, kernel_size=3, stride=1, padding=1),
GlobalAveragePooling2d(),
d2l.FlattenLayer())
为了更直观地看出 NiN 对输入图像维数的操作,我们可以写一个程序测试一下。
x = torch.randn(1, 1, 224, 224)
for i in range(len(net)):
x = net[i](x)
print(i, "out shape: ", x.shape)
输出结果为:
0 out shape: torch.Size([1, 96, 54, 54])
1 out shape: torch.Size([1, 96, 26, 26])
2 out shape: torch.Size([1, 256, 26, 26])
3 out shape: torch.Size([1, 256, 12, 12])
4 out shape: torch.Size([1, 384, 12, 12])
5 out shape: torch.Size([1, 384, 5, 5])
6 out shape: torch.Size([1, 384, 5, 5])
7 out shape: torch.Size([1, 10, 5, 5])
8 out shape: torch.Size([1, 10, 1, 1])
9 out shape: torch.Size([1, 10])
可以看出网络在变的越来越深,有利于更复杂特征的学习,最后FlattenLayer()
层的意义是将四维的输出转为二维的输出,与各类别对应上。