1. Introduction

    As its name suggests, MobileNet V2 is an improvement on MobileNet V1. Like its predecessor, it is a lightweight deep convolutional neural network proposed by Google for mobile and embedded smart devices. The core of the network is the depthwise separable convolution and the inverted residual block.

2. Model Structure

    For depthwise separable convolution, see my previous post, "MobileNets V1神经网络简介与代码实战_天竺街潜水的八角的博客-CSDN博客". This post focuses on the Inverted Residual Block. MobileNet V1 did not use residual connections, and residual connections generally make a network perform better, so MobileNet V2 adds them.

    In the original residual block, a 1x1 convolution first reduces the number of channels and is followed by ReLU, then a 3x3 spatial convolution with ReLU, and finally a 1x1 convolution (with ReLU) restores the channel count before the result is added to the input. The 1x1 channel reduction exists to cut computation; otherwise the middle 3x3 spatial convolution would be too expensive. The residual block is therefore shaped like an hourglass: wide at both ends and narrow in the middle.

    In MobileNet V2, however, the middle 3x3 convolution is depthwise, so its cost is small and the middle can afford more channels, which gives better results. The block therefore first expands the channels with a 1x1 convolution, then applies a depthwise 3x3 spatial convolution, and finally reduces the dimensionality with another 1x1 convolution. The channel counts at both ends are small, so the 1x1 expansion and projection are cheap, and although the middle has many channels, the depthwise convolution there is also cheap. The authors call this the Inverted Residual Block: narrow at both ends and wide in the middle, like a willow leaf, achieving good performance at a small computational cost.
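    To make this computation argument concrete, below is a rough back-of-the-envelope sketch (my own illustration; the 56x56 feature map, 24 input channels and expansion factor t = 6 are assumed values, not taken from the paper). It compares an ordinary 3x3 convolution at the expanded width with the 1x1 expand + 3x3 depthwise + 1x1 project path of an inverted residual block:

# Rough multiply-add (madd) estimate; all sizes here are illustrative assumptions.
def conv_madds(h, w, c_in, c_out, k=3, groups=1):
    # each output position costs k*k*(c_in/groups) madds per output channel
    return h * w * c_out * k * k * (c_in // groups)

h = w = 56            # assumed feature-map size
c_in = 24             # assumed input channels
t = 6                 # expansion factor
c_mid = t * c_in      # expanded width = 144

standard  = conv_madds(h, w, c_mid, c_mid, k=3)                  # plain 3x3 conv at width 144
expand    = conv_madds(h, w, c_in,  c_mid, k=1)                  # 1x1 expansion
depthwise = conv_madds(h, w, c_mid, c_mid, k=3, groups=c_mid)    # 3x3 depthwise
project   = conv_madds(h, w, c_mid, c_in,  k=1)                  # 1x1 projection

print(f"plain 3x3 at width {c_mid}: {standard / 1e6:.1f}M madds")                       # ~585M
print(f"inverted residual path: {(expand + depthwise + project) / 1e6:.1f}M madds")     # ~25.7M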

3. Model Characteristics

    Compared with MobileNet V1, MobileNet V2 has the following two characteristics:

    1. The Inverted Residual Block, which achieves good performance at a relatively small computational cost.

    2. The ReLU6 at the end of the block (used in MobileNet V1 to stay friendly to low-precision float16/int8 arithmetic on mobile devices) is removed after the final 1x1 projection, which outputs linearly instead; this improves the model's accuracy (Xception already showed experimentally that adding ReLU after a depthwise convolution hurts performance). See the sketch after this list.
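    For reference, ReLU6 is simply min(max(x, 0), 6): clipping to [0, 6] keeps activations representable in low-precision arithmetic, while the linear bottleneck just skips this non-linearity after the last 1x1 projection. A minimal sketch of the difference (my own illustration, not part of the original post's code):

import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, 0.5, 7.0])
print(F.relu6(x))   # tensor([0.0000, 0.5000, 6.0000]) -- clipped to [0, 6]
print(x)            # a linear (no activation) output leaves the values untouched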

4. Code Implementation (PyTorch)

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseBlock(nn.Module):
    alpha = 1

    def __init__(self, input_channel, output_channel, t = 6, downsample = False):
        """
        t: expansion factor, t*input_channel is channel of expansion layer
        alpha: width multiplier, to get thinner models
        rho: resolution multiplier, to get reduced representation
        """
        super(BaseBlock, self).__init__()
        self.stride = 2 if downsample else 1
        self.downsample = downsample
        self.shortcut = (not downsample) and (input_channel == output_channel)

        # apply alpha
        input_channel = int(self.alpha * input_channel)
        output_channel = int(self.alpha * output_channel)

        # for main path:
        c = t * input_channel
        # 1x1 point wise conv
        self.conv1 = nn.Conv2d(input_channel, c, kernel_size = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(c)
        # 3x3 depth wise conv
        self.conv2 = nn.Conv2d(c, c, kernel_size = 3, stride = self.stride, padding = 1, groups = c, bias = False)
        self.bn2 = nn.BatchNorm2d(c)
        # 1x1 point wise conv
        self.conv3 = nn.Conv2d(c, output_channel, kernel_size = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(output_channel)

    def forward(self, inputs):
        # main path: 1x1 expand (ReLU6) -> 3x3 depthwise (ReLU6) -> 1x1 project (linear)
        x = F.relu6(self.bn1(self.conv1(inputs)), inplace = True)
        x = F.relu6(self.bn2(self.conv2(x)), inplace = True)
        x = self.bn3(self.conv3(x))
        # shortcut path: add the input only when stride is 1 and in/out channels match
        x = x + inputs if self.shortcut else x
        return x

class MobileNetV2(nn.Module):
    def __init__(self, output_size, alpha = 1):
        super(MobileNetV2, self).__init__()
        self.output_size = output_size

        # first conv layer
        self.conv0 = nn.Conv2d(3, int(32*alpha), kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.bn0 = nn.BatchNorm2d(int(32*alpha))

        # build bottlenecks
        BaseBlock.alpha = alpha
        self.bottlenecks = nn.Sequential(
            BaseBlock(32, 16, t = 1, downsample = False),
            BaseBlock(16, 24, downsample = False),
            BaseBlock(24, 24),
            BaseBlock(24, 32, downsample = False),
            BaseBlock(32, 32),
            BaseBlock(32, 32),
            BaseBlock(32, 64, downsample = True),
            BaseBlock(64, 64),
            BaseBlock(64, 64),
            BaseBlock(64, 64),
            BaseBlock(64, 96, downsample = False),
            BaseBlock(96, 96),
            BaseBlock(96, 96),
            BaseBlock(96, 160, downsample = True),
            BaseBlock(160, 160),
            BaseBlock(160, 160),
            BaseBlock(160, 320, downsample = False))

        # last conv layers and fc layer
        self.conv1 = nn.Conv2d(int(320*alpha), 1280, kernel_size = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(1280)
        self.fc = nn.Linear(1280, output_size)

        # weights init
        self.weights_init()

    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, inputs):
        # first conv layer
        x = F.relu6(self.bn0(self.conv0(inputs)), inplace = True)
        # assert x.shape[1:] == torch.Size([32, 32, 32])

        # bottlenecks
        x = self.bottlenecks(x)
        # assert x.shape[1:] == torch.Size([320, 8, 8])

        # last conv layer
        x = F.relu6(self.bn1(self.conv1(x)), inplace = True)
        # assert x.shape[1:] == torch.Size([1280, 8, 8])

        # global pooling and fc (in place of conv 1x1 in paper)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
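    A quick usage sketch, using the imports and classes above. Note that this implementation keeps stride 1 in the first conv, so it expects 32x32 inputs (matching the shape comments in forward); the 10-class CIFAR-10-style setting is an assumption for illustration:

if __name__ == "__main__":
    model = MobileNetV2(output_size = 10, alpha = 1)   # 10 classes, assumed for illustration
    dummy = torch.randn(2, 3, 32, 32)                  # batch of two 32x32 RGB images
    logits = model(dummy)
    print(logits.shape)                                # torch.Size([2, 10])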

 

 
