package vision
import (
"math"
"torch"
"torch/nn"
"torch/nn/init"
)
func max(x, y int64) int64 { // math.Max works only on float64, so define an int64 max here
if x > y {
return x
}
return y
}
func makeDivisible(value float64, divisor int64, minValue *int64) int64 {
if minValue == nil {
		minValue = &divisor
}
newValue := max(*minValue, (int64(value+float64(divisor)/2)/divisor)*divisor)
	// Make sure that rounding down does not reduce the value by more than 10%.
	if float64(newValue) < 0.9*value {
newValue += divisor
}
return newValue
}
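
// ConvBNReLU is a Conv2d + BatchNorm2d + ReLU block, the counterpart of
// torchvision's ConvBNReLU helper in the Python reference below.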
type ConvBNReLU struct {
nn.Sequential
}
func NewConvBNReLU(inPlanes, outPlanes, kernelSize, stride, groups int64) ConvBNReLU {
	ret := ConvBNReLU{nn.NewSequential()}
	padding := (kernelSize - 1) / 2
	options := nn.Conv2dOptions{inPlanes, outPlanes, kernelSize}
	ret.PushBack(nn.NewConv2d(
		options.stride(stride).padding(padding).groups(groups).bias(false)))
	ret.PushBack(nn.BatchNorm2d{outPlanes})
	ret.PushBack(nn.Functional{nn.ReLU})
	return ret
}
func (net *ConvBNReLU) Forward(x torch.Tensor) torch.Tensor {
return net.Sequential.Forward(x)
}
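
// MobileNetInvertedResidual is the MobileNetV2 building block: an optional
// 1x1 pointwise expansion, a 3x3 depthwise convolution, and a 1x1 linear
// projection, with a residual connection when stride == 1 and the input and
// output channel counts match.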
type MobileNetInvertedResidual struct {
nn.Module
stride int64
useResConnect bool
conv nn.Sequential
}
func NewMobileNetInvertedResidual(
	input, output, stride int64, expandRatio float64) MobileNetInvertedResidual {
	torch.CHECK(stride == 1 || stride == 2)
	net := MobileNetInvertedResidual{
		Module:        nn.NewModule(),
		stride:        stride,
		useResConnect: stride == 1 && input == output,
		conv:          nn.NewSequential()}
	doubleCompare := func(a, b float64) bool {
		return math.Abs(a-b) < 1e-20
	}
	hiddenDim := int64(math.Round(float64(input) * expandRatio))
	if !doubleCompare(expandRatio, 1) {
		// pointwise expansion
		net.conv.PushBack(NewConvBNReLU(input, hiddenDim, 1, 1, 1))
	}
	// depthwise convolution
	net.conv.PushBack(NewConvBNReLU(hiddenDim, hiddenDim, 3, stride, hiddenDim))
	// pointwise linear projection
	options := nn.Conv2dOptions{hiddenDim, output, 1}
	net.conv.PushBack(nn.NewConv2d(options.stride(1).padding(0).bias(false)))
	net.conv.PushBack(nn.BatchNorm2d{output})
	net.RegisterModule("conv", net.conv)
	return net
}
func (net *MobileNetInvertedResidual) Forward(x torch.Tensor) torch.Tensor {
if net.useResConnect {
		return x.Add(net.conv.Forward(x))
}
return net.conv.Forward(x)
}
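
// MobileNetV2 stacks a first ConvBNReLU layer, the inverted residual blocks
// described by the (t, c, n, s) settings, and a final 1x1 ConvBNReLU as the
// feature extractor, followed by a dropout + linear classifier.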
type MobileNetV2 struct {
	nn.Module
lastChannel int64
features, classifier nn.Sequential
}
func NewMobileNetV2(
numClasses int64,
widthMult float64,
invertedResidualSettings [][]int64,
roundNearest int64) MobileNetV2 {
net := MobileNetV2{
Module: nn.NewModule(),
features: nn.NewSequential(),
		classifier: nn.NewSequential()}
var inputChannel int64 = 32
var lastChannel int64 = 1280
	if len(invertedResidualSettings) == 0 {
		invertedResidualSettings = [][]int64{
// t, c, n, s
{1, 16, 1, 1},
{6, 24, 2, 2},
{6, 32, 3, 2},
{6, 64, 4, 2},
{6, 96, 3, 1},
{6, 160, 3, 2},
{6, 320, 1, 1},
}
}
torch.CHECK(
len(invertedResidualSettings[0]) == 4,
"inverted_residual_settings should contain 4-element vectors")
	inputChannel = makeDivisible(float64(inputChannel)*widthMult, roundNearest, nil)
net.lastChannel =
		makeDivisible(float64(lastChannel)*math.Max(1.0, widthMult), roundNearest, nil)
	net.features.PushBack(NewConvBNReLU(3, inputChannel, 3, 2, 1))
	for _, setting := range invertedResidualSettings {
		// setting is {t, c, n, s}: expansion factor, output channels,
		// number of repeats, and stride of the first repeat.
		outputChannel := makeDivisible(float64(setting[1])*widthMult, roundNearest, nil)
		for i := int64(0); i < setting[2]; i++ {
			stride := int64(1)
			if i == 0 {
				stride = setting[3]
			}
			net.features.PushBack(
				NewMobileNetInvertedResidual(
					inputChannel, outputChannel, stride, float64(setting[0])))
			inputChannel = outputChannel
		}
	}
	net.features.PushBack(NewConvBNReLU(inputChannel, net.lastChannel, 1, 1, 1))
	net.classifier.PushBack(nn.Dropout(0.2))
	net.classifier.PushBack(nn.Linear(net.lastChannel, numClasses))
net.RegisterModule("features", net.features)
net.RegisterModule("classifier", net.classifier)
	for _, module := range net.Modules(false) {
		switch m := module.(type) {
		case nn.Conv2d:
			init.KaimingNormal(m.Weight, 0, torch.KFanOut)
			if m.Options.Bias {
				init.Zeros(m.Bias)
			}
		case nn.BatchNorm2d:
			init.Ones(m.Weight)
			init.Zeros(m.Bias)
		case nn.Linear:
			init.Normal(m.Weight, 0, 0.01)
			init.Zeros(m.Bias)
		}
	}
return net
}
func (net *MobileNetV2) Forward(x torch.Tensor) torch.Tensor {
x = net.features.Forward(x)
	// Global average pooling over the spatial dimensions (H, W); this plays
	// the role of adaptive_avg_pool2d(x, 1) + reshape in the Python reference below.
	x = x.Mean([]int64{2, 3}, false)
	x = net.classifier.Forward(x)
return x
}
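
// A hypothetical usage sketch (names and defaults follow the Python reference
// below: 1000 classes, width multiplier 1.0, default settings, channels
// rounded to multiples of 8):
//
//	model := NewMobileNetV2(1000, 1.0, nil, 8)
//	x := torch.RandN([]int64{1, 3, 224, 224}, false)
//	y := model.Forward(x) // y has shape [batch, numClasses]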
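# The reference Python implementation from torchvision that the Go code
# above follows.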
from torch import nn
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
padding = (kernel_size - 1) // 2
if norm_layer is None:
norm_layer = nn.BatchNorm2d
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
norm_layer(out_planes),
nn.ReLU6(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
if norm_layer is None:
norm_layer = nn.BatchNorm2d
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self,
num_classes=1000,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None,
norm_layer=None):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
norm_layer: Module specifying the normalization layer to use
"""
super(MobileNetV2, self).__init__()
if block is None:
block = InvertedResidual
if norm_layer is None:
norm_layer = nn.BatchNorm2d
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
x = self.classifier(x)
return x
def forward(self, x):
return self._forward_impl(x)