Comments (15)
Disentanglement - this is a great research direction. I think ALAE did something like that, so take a look there. We do not deal with this in our paper. However, we do experiment with image translation, but it uses a different architecture than the StyleALAE one. There are many papers on disentanglement that involve VAEs. If you find this interesting, you might want to try to combine our training method with their method/architecture.
Line - this follows ALAE. Notice that a couple of lines before, we add a dimension (x[None, ...]) for the forward pass. The transition is done in the style space. You can take a look at the model architecture to see what mapping_fl() returns.
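For reference, x[None, ...] is plain PyTorch (and NumPy) indexing that prepends a batch dimension before the forward pass; a minimal sketch:

```python
import torch

x = torch.randn(3, 256, 256)  # a single image, shape (C, H, W)

# x[None, ...] prepends a batch axis; equivalent to x.unsqueeze(0)
batched = x[None, ...]
print(batched.shape)  # torch.Size([1, 3, 256, 256])
```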
from soft-intro-vae-pytorch.
@taldatech Hi, I'm sorry to bother you. I have a question about interpolation.
Interpolation worked well at size 256 during training, but since the model learns hierarchically, when I loaded the latest weights your code now interpolates at size 512. I changed the image size to the one I wanted, and got an error.
How can I solve this? My latest weights file is lod7.pth, my model's layer count is 8, max resolution level is 8, and mapping layers is 8.
Hi, I'm not sure I completely understood your issue. The interpolation uses the config file, but you can also try to change the lod in the input/output of the decoder. Maybe you meant the latent size? It is the same for z and w, and set to 512, as in the config file.
Oh, my training image size is 512 x 512 and my test image size is 512 x 512. I modified the config file like this.
The resulting image size was 256 x 256 when I interpolated during training. Is this related to lod?
Thank you for your reply.
NAME: ffhq
DATASET:
  PART_COUNT: 16
  SIZE: 85964
  FFHQ_SOURCE: /workspace/data/splitted_train_layer_8_mod/-r%02d.tfrecords
  PATH: /workspace/data/splitted_train_layer_8_mod/-r%02d.tfrecords
  PART_COUNT_TEST: 2
  PATH_TEST: /workspace/data/splitted_test_layer_8_mod/-r%02d.tfrecords
  SAMPLES_PATH: /workspace/data/test/
  STYLE_MIX_PATH: ./style_mixing
  MAX_RESOLUTION_LEVEL: 7
MODEL:
  LATENT_SPACE_SIZE: 512
  LAYER_COUNT: 7
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 64
  DLATENT_AVG_BETA: 0.995
  MAPPING_LAYERS: 7
  BETA_KL: 0.2  # 1.0
  BETA_REC: 0.1  # 1.0
  BETA_NEG: [2048, 2048, 2048, 1024, 512, 512, 512, 512, 512]
  SCALE: 0.000005
OUTPUT_DIR: './output_layer_8/'
TRAIN:
  BASE_LEARNING_RATE: 0.002
  EPOCHS_PER_LOD: 16
  NUM_VAE: 1
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 300
  #######          4     8     16    32   64   128  256  512  1024
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 16, 8]  # If GPU memory ~16GB reduce last number from 32 to 24
  LOD_2_BATCH_4GPU: [4096, 2048, 2048, 768, 192, 96, 48, 10, 12, 16]
  LOD_2_BATCH_2GPU: [4096, 2048, 2048, 768, 192, 96, 48, 10, 32, 32]
  LOD_2_BATCH_1GPU: [512, 512, 64, 64, 16, 8, 4]
  LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003]
Okay, so the latent size is indeed 512 (note that you used 7 mapping layers instead of 8). As for the resolution: for 256x256, MAX_RESOLUTION_LEVEL should be 8 (9 for 512x512), and LAYER_COUNT should be 7 for 256x256 (8 for 512x512).
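Piecing together the numbers above (a hedged sketch, not the repo's exact code): in this progressive-growing setup each layer doubles the resolution starting from a 4x4 base, so the output resolution is 4 * 2^(LAYER_COUNT - 1), which equals 2^MAX_RESOLUTION_LEVEL:

```python
def resolution_for(layer_count: int) -> int:
    """Output resolution implied by LAYER_COUNT, assuming a 4x4 base
    and one doubling per additional layer (inferred from the thread)."""
    return 4 * 2 ** (layer_count - 1)

print(resolution_for(7))  # 256, i.e. MAX_RESOLUTION_LEVEL 8 (2**8)
print(resolution_for(8))  # 512, i.e. MAX_RESOLUTION_LEVEL 9 (2**9)
```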
I thought so too, so I changed LAYER_COUNT to 7, but there was a state_dict loading error:
Failed to load: generator_s
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
2021-03-02 14:37:33,311 logger WARNING:
Failed to load: Error(s) in loading state_dict for Generator:
size mismatch for decode_block.4.noise_weight_1: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.bias_1: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.noise_weight_2: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.bias_2: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.conv_1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
size mismatch for decode_block.4.blur.weight: copying a param with shape torch.Size([512, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1, 3, 3]).
size mismatch for decode_block.4.style_1.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for decode_block.4.style_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for decode_block.4.conv_2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for decode_block.4.style_2.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for decode_block.4.style_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for decode_block.5.noise_weight_1: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.bias_1: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.noise_weight_2: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.bias_2: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.conv_1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
size mismatch for decode_block.5.blur.weight: copying a param with shape torch.Size([256, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
size mismatch for decode_block.5.style_1.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for decode_block.5.style_1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for decode_block.5.conv_2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for decode_block.5.style_2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for decode_block.5.style_2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for decode_block.6.noise_weight_1: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.bias_1: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.noise_weight_2: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.bias_2: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.conv_1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
size mismatch for decode_block.6.blur.weight: copying a param with shape torch.Size([128, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 1, 3, 3]).
size mismatch for decode_block.6.style_1.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for decode_block.6.style_1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for decode_block.6.conv_2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for decode_block.6.style_2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for decode_block.6.style_2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for to_rgb.4.to_rgb.weight: copying a param with shape torch.Size([3, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 256, 1, 1]).
size mismatch for to_rgb.5.to_rgb.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 128, 1, 1]).
size mismatch for to_rgb.6.to_rgb.weight: copying a param with shape torch.Size([3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
2021-03-02 14:37:33,315 logger WARNING: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Failed to load: dlatent_avg
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
2021-03-02 14:37:33,315 logger WARNING:
Failed to load: Error(s) in loading state_dict for DLatent:
size mismatch for buff: copying a param with shape torch.Size([16, 512]) from checkpoint, the shape in current model is torch.Size([14, 512]).
The way I understand it:
- You didn't train with this config, so you are trying to load weights for a different model; you can't, hence this error.
- If you did train at 512x512 and now want to operate at 256x256, just lower the lod in the interpolation code (what you input to the encoder/decoder), but the config must stay the same as the one you trained your model with.
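The arithmetic behind lowering the lod (a sketch inferred from the thread; the function name is hypothetical, not the repo's API): lod 0 corresponds to the 4x4 base, and each higher lod doubles the resolution, so a model trained up to lod 7 (512x512) decodes 256x256 images at lod 6:

```python
def lod_to_resolution(lod: int) -> int:
    """Resolution at a given level of detail, assuming lod 0 = 4x4
    and one doubling per level (inferred from the discussion)."""
    return 4 * 2 ** lod

print(lod_to_resolution(7))  # 512, the trained maximum
print(lod_to_resolution(6))  # 256, one level lower
```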
My issue is the second case. I lowered the lod, and the output image size is still 512 x 512.
Change the level here, and see if that works. Otherwise, you can just resize the output image to 256x256.
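The resize fallback can be done with standard PyTorch interpolation (a minimal sketch; the tensor here is a stand-in for the decoder output, not the repo's actual variable):

```python
import torch
import torch.nn.functional as F

img = torch.randn(1, 3, 512, 512)  # stand-in for a 512x512 decoder output (N, C, H, W)

# Downsample to 256x256 after the fact
img_small = F.interpolate(img, size=(256, 256),
                          mode='bilinear', align_corners=False)
print(img_small.shape)  # torch.Size([1, 3, 256, 256])
```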
Thank you for your kind reply. I solved it.
One more question.
When I test reconstruction on the test data, the output is not identical to the original test data; it creates a slightly different image. Does this mean the model is not learning well? Or is it because the encoder encodes the image generated by the decoder, as in the paper?
I'm not sure I understood. Are you asking about the case where, for example, you train on faces from CelebA and test the model on faces from FFHQ? In that case, I wouldn't expect the model to perform well: even though both include faces, the faces are aligned differently, the datasets include other types of faces (e.g., FFHQ includes faces of children), etc.
@taldatech Actually, my training data is a chest X-ray dataset. I think it's because BETA_REC is 0.1.
I see, I haven't had the chance to work with this type of data. But I assume that for medical data you would want more finesse and better reconstructions, so maybe use a higher beta_rec for that. Also, maybe give the regular ResNet-based architecture a chance.
Very interesting. By the way, isn't it difficult to disentangle because of the residual blocks when using the ResNet-based architecture?
I'm not sure about the answer to this question, but it is definitely interesting and worth more investigation.