Comments (1)
Hey @SixtyTrees
Sequential does basically what its name says: it creates a model that is a sequence of layers.
That is just syntactic sugar for feeding the output of one layer as input to the next one.
That being said, you cannot implement skip connections inside the nn.Sequential. For that you have to subclass nn.Module and define the forward pass yourself.
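To make the "syntactic sugar" point concrete, here is a minimal sketch comparing nn.Sequential with calling the layers by hand. The layer sizes, the ReLU in between, and the variable names are assumptions picked only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical layers, chosen only for illustration.
linear1 = nn.Linear(30, 10)
linear2 = nn.Linear(10, 5)

# nn.Sequential chains the layers in the order given...
model = nn.Sequential(linear1, nn.ReLU(), linear2)

x = torch.randn(4, 30)  # batch of 4 samples with 30 features each

# ...which is equivalent to feeding each output into the next layer by hand.
manual_out = linear2(torch.relu(linear1(x)))
sequential_out = model(x)

# Both paths use the same layer objects, so the results should agree.
same = torch.allclose(manual_out, sequential_out)
```

Because each layer only ever sees the output of the layer right before it, there is no place in this chain to reuse an earlier output, which is why a custom forward pass is needed for a skip connection.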
Note that in order to create a skip connection, the dimension of the combined tensor must match the input dimension that the receiving layer expects. There are 2 ways you can implement skip connections (that I know of): the first one is through concatenation, and the second one (with an example below) is using addition.
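As a rough sketch of how the two options affect shapes, the snippet below combines plain tensors instead of layer outputs. The batch size and feature sizes are assumptions for illustration only.

```python
import torch

# Stand-ins for layer outputs: a batch of 4 samples each.
a = torch.randn(4, 10)
b = torch.randn(4, 10)
c = torch.randn(4, 15)

# Addition needs matching shapes and keeps the feature dim unchanged.
added = a + b  # shape (4, 10)

# Concatenation along the feature dim sums the feature dims.
concatenated = torch.cat([a, c], dim=1)  # shape (4, 25)
```

So with addition the next layer expects 10 input features here, while with concatenation it expects 25.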
In a ResNet-like architecture (which uses addition), you'd need the output dims of linear2 and linear3 to match the input dim of linear4. This means:
nn.Linear(10, 10), # linear2 (out dim 10)
...
nn.Linear(10, 10), # linear3 (in and out dim 10) -- note that you have to change the in dim of linear3 as well
...
nn.Linear(10, 20), # linear4 (in dim 10)
In a UNet-like architecture (which uses concatenation), you'd need the sum of the out dims of linear2 and linear3 to match the input dim of linear4:
nn.Linear(10, 15), # linear2 (out dim 15)
...
nn.Linear(15, 10), # linear3 (out dim 10)
...
nn.Linear(25, 20), # linear4 (in dim 25 = 15 + 10)
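The concatenation variant could look roughly like the sketch below. The class name MySimpleConcatModel, the linear1 layer, and its input size of 30 are assumptions added for illustration; only the linear2, linear3 and linear4 dims come from the lines above.

```python
import torch
from torch import nn


class MySimpleConcatModel(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(30, 10)  # assumed input layer, not from the thread
        self.linear2 = nn.Linear(10, 15)
        self.linear3 = nn.Linear(15, 10)
        self.linear4 = nn.Linear(25, 20)  # 25 = 15 (linear2) + 10 (linear3)

    def forward(self, inputs):
        l1_out = self.linear1(inputs)
        l2_out = self.linear2(l1_out)
        l3_out = self.linear3(l2_out)
        # Skip connection: join linear2 and linear3 outputs along the feature dim.
        skip = torch.cat([l2_out, l3_out], dim=1)
        return self.linear4(skip)


model = MySimpleConcatModel()
out = model(torch.randn(4, 30))  # batch of 4 samples
```

If linear4 were given any input dim other than 25, the forward pass would fail with a shape mismatch, which is the constraint described above.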
Here's a simple example of how to create a skip connection using addition:
import torch
from torch import nn

class MySimpleSkipModel(nn.Module):
    def __init__(self):
        super().__init__()  # must run before submodules are assigned
        self.linear1 = nn.Linear(30, 10)
        self.linear2 = nn.Linear(10, 10)  # out dim equals linear1's out dim so the two can be added
        self.linear3 = nn.Linear(10, 5)

    def forward(self, inputs):
        l1_out = self.linear1(inputs)
        l2_out = self.linear2(l1_out)
        # using the outputs of both linear1 and linear2; use torch.cat for concat
        l3_out = self.linear3(l1_out + l2_out)
        return l3_out
from pytorch-deep-learning.