Comments (6)
Could you share the network's code?
from flops-counter.pytorch.
I use the PyTorch version of DnCNN from https://github.com/cszn/DnCNN/tree/master/TrainingCodes/dncnn_pytorch
Here is my customized code:
import argparse
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import data_generator as dg
from data_generator import DenoisingDataset
import torch.nn.functional as F
from ptflops import get_model_complexity_info

# Params
parser = argparse.ArgumentParser(description='PyTorch DnCNN')
parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--train_data', default='data/Train400', type=str, help='path of train data')
parser.add_argument('--sigma', default=25, type=int, help='noise level')
parser.add_argument('--epoch', default=180, type=int, help='number of train epoches')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
parser.add_argument('--depth', default=17, type=int, help='number of convolutions for DnCNN')
parser.add_argument('--n_channels', default=64, type=int, help='number of intermediate channels for DnCNN')
args = parser.parse_args()

batch_size = args.batch_size
cuda = torch.cuda.is_available()
n_epoch = args.epoch
sigma = args.sigma

save_dir = os.path.join('models', args.model + '_' + 'sigma' + str(sigma))
if not os.path.exists(save_dir):
    os.mkdir(save_dir)


class DnCNN(nn.Module):
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = DnCNN()
    model = model.cpu()
    #model.
    with torch.cuda.device(0):
        flops, params = get_model_complexity_info(model, (1, 256, 256), as_strings=True, print_per_layer_stat=False)
        print(' - Flops: ' + flops)
        print(' - Params: ' + params)
I guess the number is too big for the default integer type being used in Python.
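A rough way to sanity-check that guess is to count the convolution multiply-accumulates of the DnCNN above by hand and compare the result with the largest signed 32-bit value. This is only a sketch: the helper names are made up for illustration, only the convolutions are counted (so the total is a bit lower than what ptflops reports for the whole network), and the idea that the counter is held in a signed 32-bit integer is an assumption, not something confirmed in this thread.

```python
def conv_macs(c_in, c_out, k, h, w):
    """MACs of a stride-1 k x k convolution whose output is h x w."""
    return c_in * c_out * k * k * h * w


def dncnn_conv_macs(depth=17, n_channels=64, image_channels=1, k=3, h=256, w=256):
    """Sum the convolution MACs of the DnCNN defined above (hypothetical helper)."""
    total = conv_macs(image_channels, n_channels, k, h, w)             # first conv
    total += (depth - 2) * conv_macs(n_channels, n_channels, k, h, w)  # middle convs
    total += conv_macs(n_channels, image_channels, k, h, w)            # last conv
    return total


def wrap_int32(value):
    """Emulate storing value in a signed 32-bit integer (assumed counter type)."""
    return (value + 2**31) % 2**32 - 2**31


INT32_MAX = 2**31 - 1
total_macs = dncnn_conv_macs()
one_middle_layer = conv_macs(64, 64, 3, 256, 256)

print("conv MACs:          ", total_macs)
print("one 64->64 layer:   ", one_middle_layer)
print("int32 max:          ", INT32_MAX)
print("wrapped to int32:   ", wrap_int32(total_macs))
```

Under these assumptions a single 64-to-64 convolution on a 256x256 input already exceeds the signed 32-bit range, so a 32-bit accumulator would silently report a wrong total.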
That's strange, because it would imply a really huge amount of computation for such a small network.
I've launched your code with the latest version of ptflops (0.3):
===> Building model
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
- Flops: 36.51 GMac
- Params: 556.1 k
Maybe the problem is some strange behavior of Python on Windows. Do you use Python 2? I use Python 3.6 on Ubuntu 16.04.
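The 556.1 k parameter figure in the output can be cross-checked by hand from the DnCNN definition posted earlier in the thread. The sketch below just adds up the weights of each layer; the function name is hypothetical, and it assumes the default arguments used in that code (depth 17, 64 channels, one image channel, 3x3 kernels, bias only on the first convolution).

```python
def dncnn_param_count(depth=17, n_channels=64, image_channels=1, k=3):
    """Count learnable parameters of the DnCNN posted above (hypothetical helper)."""
    first_conv = image_channels * n_channels * k * k + n_channels  # bias=True
    middle_conv = n_channels * n_channels * k * k                  # bias=False
    batch_norm = 2 * n_channels                                    # weight and bias
    last_conv = n_channels * image_channels * k * k                # bias=False
    return first_conv + (depth - 2) * (middle_conv + batch_norm) + last_conv


total_params = dncnn_param_count()
print(f"{total_params} parameters ({total_params / 1e3:.1f} k)")
```

This comes to 556,096 parameters, which rounds to the 556.1 k shown in the log.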
Fixed in #43