Comments (6)
I converted densenet121. If anyone is interested, you can download it from here, but it requires some changes to the `DenseNetEfficient` class:
```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseNetEfficient(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    This model uses shared memory allocations for the outputs of batch norm and
    concat operations, as described in `"Memory-Efficient Implementation of DenseNets"`.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5,
                 num_init_features=24, bn_size=4, drop_rate=0,
                 num_classes=10, cifar=True):
        super(DenseNetEfficient, self).__init__()
        assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1'
        self.avgpool_size = 8 if cifar else 7

        # First convolution: 3x3 for CIFAR-sized inputs, 7x7 + max pool for ImageNet
        if cifar:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3,
                                    stride=1, padding=1, bias=False)),
            ]))
        else:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7,
                                    stride=2, padding=3, bias=False)),
            ]))
            self.features.add_module('norm0', nn.BatchNorm2d(num_init_features))
            self.features.add_module('relu0', nn.ReLU(inplace=True))
            self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2,
                                                           padding=1, ceil_mode=False))

        # Each denseblock (_DenseBlock and _Transition are defined in
        # efficient_densenet_pytorch)
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)

        # Final batch norm
        self.features.add_module('norm_final', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.avg_pool2d(out, kernel_size=self.avgpool_size).view(features.size(0), -1)
        out = self.classifier(out)
        return out
```
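As a sanity check on the constructor's bookkeeping, the loop above can be traced by hand for the densenet121 configuration (growth_rate=32, block_config=(6, 12, 24, 16), compression=0.5, 64 initial features). This sketch replicates just the feature-count arithmetic from `__init__`:

```python
# Trace the channel counts through the dense blocks and transitions,
# mirroring the loop in DenseNetEfficient.__init__ for densenet121.
growth_rate = 32
block_config = (6, 12, 24, 16)
compression = 0.5
num_features = 64  # num_init_features for the ImageNet variant

for i, num_layers in enumerate(block_config):
    num_features += num_layers * growth_rate          # each dense block adds k filters per layer
    if i != len(block_config) - 1:
        num_features = int(num_features * compression)  # transition halves the channels

print(num_features)  # 1024
```

The result, 1024, matches the input dimension of densenet121's classifier layer, which is a quick way to confirm that a given `block_config`/`growth_rate` combination lines up with a pretrained checkpoint.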
Then you can instantiate the densenet121 model:

```python
densenet = DenseNetEfficient(growth_rate=32, block_config=[6, 12, 24, 16],
                             num_classes=1000, cifar=False, num_init_features=64)
densenet.load_state_dict(torch.load(pretrained_model_path))
```
from efficient_densenet_pytorch.
We don't have pre-trained models that were trained using this implementation. It should be possible to convert the existing pre-trained ImageNet models to work with this implementation.
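Converting an existing checkpoint mostly comes down to renaming `state_dict` keys so they match this implementation's module names. The helper below is a generic sketch, not the actual mapping: `remap_keys` and the rename table are hypothetical, and the real pattern/replacement pairs must be derived by comparing the two models' `state_dict().keys()`.

```python
import re
from collections import OrderedDict


def remap_keys(state_dict, renames):
    """Return a new state dict with keys rewritten via (pattern, replacement) pairs.

    `renames` is an ordered list of regex substitutions applied to every key.
    The table used here is purely illustrative.
    """
    out = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for pattern, repl in renames:
            new_key = re.sub(pattern, repl, new_key)
        out[new_key] = value
    return out


# Toy example: rewrite dotted-index keys like 'norm.1' into 'norm1'.
toy = OrderedDict([('features.norm.1.weight', 1), ('classifier.weight', 2)])
remapped = remap_keys(toy, [(r'norm\.(\d)', r'norm\1')])
print(list(remapped))  # ['features.norm1.weight', 'classifier.weight']
```

After remapping, `model.load_state_dict(remapped)` will raise on any key that still mismatches, which makes it easy to iterate on the rename table until the conversion is complete.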
Thanks @ZhengRui! Any chance you can make a PR adding those changes?
@gpleiss Pull request sent, you may take a look. I also made the multi-GPU and single-GPU versions use the same module names, so models can be shared between them.
Merged #19 - closes this issue
I have converted the pretrained models to PyTorch efficient models. You can download them from here, or convert them yourself.