
Comments (15)

taldatech commented on June 1, 2024

Disentanglement - this is a great research direction. I think ALAE did something like that, so take a look there. We do not deal with this in our paper. However, we do experiment with Image Translation, but it uses a different architecture than the StyleALAE one. There are many papers on disentanglement that involve VAEs. If you find this interesting, you might want to try and combine our training method with their method/architecture.

Line - this follows ALAE. Notice that a couple of lines before, we add a dimension (x[None, ...]) for the forward pass. The transition is done in the style space. You can take a look at the model architecture to see what mapping_fl() returns.
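To see what the x[None, ...] indexing does: indexing with None (same as np.newaxis, or unsqueeze(0) on a torch tensor) prepends a batch dimension so a single image can go through the forward pass. A minimal NumPy sketch (the array shape here is just for illustration):

```python
import numpy as np

# A single image: channels x height x width.
x = np.zeros((3, 256, 256))

# Prepend a batch dimension -> the model sees a batch of size 1.
x_batched = x[None, ...]  # same effect as x.unsqueeze(0) on a torch.Tensor

print(x_batched.shape)  # (1, 3, 256, 256)
```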

from soft-intro-vae-pytorch.

chokyungjin commented on June 1, 2024

@taldatech Hi, I'm sorry to bother you. I have a question about interpolation.
During training, interpolation worked well at size 256, so I wanted to run the experiment again. Since the model is trained hierarchically, I loaded the latest weights, and now your code interpolates at size 512. When I changed the image size to the one I wanted, I got an error.
How can I solve this? My latest weights file is lod7.pth, my model's layer count is 8, max resolution level is 8, and mapping layers is 8.


taldatech commented on June 1, 2024

Hi, I'm not sure I completely understood your issue. The interpolation uses the config file, but you can also try to change the lod in the input/output of the decoder. Maybe you meant the latent size? They are the same for z and w and set to 512, as in the config file.


chokyungjin commented on June 1, 2024

Oh, my training image size is 512 x 512 and my test image size is 512 x 512. I modified the config file like this.
The resulting image size was 256 x 256 when I interpolated during training. Is this related to the lod?
Thank you for your reply.

NAME: ffhq
DATASET:
  PART_COUNT: 16
  SIZE: 85964
  FFHQ_SOURCE: /workspace/data/splitted_train_layer_8_mod/-r%02d.tfrecords
  PATH: /workspace/data/splitted_train_layer_8_mod/-r%02d.tfrecords

  PART_COUNT_TEST: 2
  PATH_TEST: /workspace/data/splitted_test_layer_8_mod/-r%02d.tfrecords

  SAMPLES_PATH: /workspace/data/test/
  STYLE_MIX_PATH: ./style_mixing
  MAX_RESOLUTION_LEVEL: 7
MODEL:
  LATENT_SPACE_SIZE: 512
  LAYER_COUNT: 7
  MAX_CHANNEL_COUNT: 512
  START_CHANNEL_COUNT: 64
  DLATENT_AVG_BETA: 0.995
  MAPPING_LAYERS: 7
  BETA_KL: 0.2 # 1.0
  BETA_REC: 0.1 # 1.0
  BETA_NEG: [2048, 2048, 2048, 1024, 512, 512, 512, 512, 512]
  SCALE: 0.000005
OUTPUT_DIR: './output_layer_8/'
TRAIN:
  BASE_LEARNING_RATE: 0.002
  EPOCHS_PER_LOD: 16
  NUM_VAE: 1
  LEARNING_DECAY_RATE: 0.1
  LEARNING_DECAY_STEPS: []
  TRAIN_EPOCHS: 300
  ####### 4 8 16 32 64 128 256 512 1024
  LOD_2_BATCH_8GPU: [512, 256, 128, 64, 32, 32, 32, 16, 8] # If GPU memory ~16GB reduce last number from 32 to 24
  LOD_2_BATCH_4GPU: [4096, 2048, 2048, 768, 192, 96, 48, 10, 12, 16]
  LOD_2_BATCH_2GPU: [4096, 2048, 2048, 768, 192, 96, 48, 10, 32, 32]
  LOD_2_BATCH_1GPU: [512, 512, 64, 64, 16, 8, 4]

  LEARNING_RATES: [0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.002, 0.003, 0.003]


taldatech commented on June 1, 2024

Okay, so the latent size is indeed 512 (note that you used 7 mapping layers instead of 8). As for the resolution: for 256x256, MAX_RESOLUTION_LEVEL should be 8 (9 for 512x512), and LAYER_COUNT should be 7 for 256x256 (8 for 512x512).
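The rule above can be sanity-checked numerically. A small sketch, assuming the ALAE-style convention that the generator starts from a 4x4 base and each of the LAYER_COUNT blocks doubles the spatial size (so MAX_RESOLUTION_LEVEL is just log2 of the top resolution):

```python
def max_resolution(layer_count):
    # 4x4 base, doubled once per additional block.
    return 4 * 2 ** (layer_count - 1)

for layer_count in (7, 8):
    res = max_resolution(layer_count)
    level = res.bit_length() - 1  # log2(res), i.e. MAX_RESOLUTION_LEVEL
    print(f"LAYER_COUNT={layer_count} -> {res}x{res}, MAX_RESOLUTION_LEVEL={level}")
```

So LAYER_COUNT=7 gives 256x256 (level 8) and LAYER_COUNT=8 gives 512x512 (level 9), matching the numbers above.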


chokyungjin commented on June 1, 2024

I thought so too, so I changed LAYER_COUNT to 7, but then there was an error loading the state_dict:

Failed to load: generator_s
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
2021-03-02 14:37:33,311 logger WARNING:
Failed to load: Error(s) in loading state_dict for Generator:
size mismatch for decode_block.4.noise_weight_1: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.bias_1: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.noise_weight_2: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.bias_2: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for decode_block.4.conv_1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
size mismatch for decode_block.4.blur.weight: copying a param with shape torch.Size([512, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1, 3, 3]).
size mismatch for decode_block.4.style_1.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for decode_block.4.style_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for decode_block.4.conv_2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for decode_block.4.style_2.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for decode_block.4.style_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for decode_block.5.noise_weight_1: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.bias_1: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.noise_weight_2: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.bias_2: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
size mismatch for decode_block.5.conv_1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
size mismatch for decode_block.5.blur.weight: copying a param with shape torch.Size([256, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
size mismatch for decode_block.5.style_1.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for decode_block.5.style_1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for decode_block.5.conv_2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for decode_block.5.style_2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
size mismatch for decode_block.5.style_2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for decode_block.6.noise_weight_1: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.bias_1: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.noise_weight_2: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.bias_2: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
size mismatch for decode_block.6.conv_1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
size mismatch for decode_block.6.blur.weight: copying a param with shape torch.Size([128, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 1, 3, 3]).
size mismatch for decode_block.6.style_1.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for decode_block.6.style_1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for decode_block.6.conv_2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for decode_block.6.style_2.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([128, 512]).
size mismatch for decode_block.6.style_2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
size mismatch for to_rgb.4.to_rgb.weight: copying a param with shape torch.Size([3, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 256, 1, 1]).
size mismatch for to_rgb.5.to_rgb.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 128, 1, 1]).
size mismatch for to_rgb.6.to_rgb.weight: copying a param with shape torch.Size([3, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
2021-03-02 14:37:33,315 logger WARNING: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Failed to load: dlatent_avg
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
2021-03-02 14:37:33,315 logger WARNING:
Failed to load: Error(s) in loading state_dict for DLatent:
size mismatch for buff: copying a param with shape torch.Size([16, 512]) from checkpoint, the shape in current model is torch.Size([14, 512]).


taldatech commented on June 1, 2024

The way I understand it:

  • You didn't train with this config, so you are trying to load weights for a different model; you can't, hence this error.
  • If you did train at 512x512 and now want to operate at 256x256, just lower the lod in the interpolation code (what you input to the encoder/decoder), but the config must stay the same as the one you trained your model with.
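A size mismatch like the one in the log cannot be fixed by lenient loading alone, since the tensors themselves have different shapes. A quick way to confirm that a checkpoint and the current model were built from different configs is to diff the parameter shapes before loading; a hypothetical, dependency-free sketch (plain dicts of shape tuples stand in for the two state_dicts):

```python
def mismatched_keys(ckpt_shapes, model_shapes):
    """Names present in both dicts whose shapes disagree."""
    return sorted(k for k, s in ckpt_shapes.items()
                  if k in model_shapes and model_shapes[k] != s)

# Toy shapes mirroring the log: checkpoint from LAYER_COUNT=8,
# model built with LAYER_COUNT=7.
ckpt = {"decode_block.4.conv_1.weight": (512, 512, 3, 3),
        "decode_block.4.bias_1": (1, 512, 1, 1),
        "const": (1, 512, 4, 4)}
model = {"decode_block.4.conv_1.weight": (256, 512, 3, 3),
         "decode_block.4.bias_1": (1, 256, 1, 1),
         "const": (1, 512, 4, 4)}

print(mismatched_keys(ckpt, model))
# ['decode_block.4.bias_1', 'decode_block.4.conv_1.weight']
```

With real checkpoints you would build the same shape dicts from `{k: tuple(v.shape) for k, v in state_dict.items()}`.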


chokyungjin commented on June 1, 2024

My issue is the second case. I lowered the lod, but the output image size is still 512 x 512.


taldatech commented on June 1, 2024

Change the level here, and see if that works. Otherwise, you can just resize the output image to 256x256.
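If you take the resizing route, any image library works; as a dependency-light sketch, a 2x block average halves each spatial dimension (in practice you would more likely reach for PIL's Image.resize or torch.nn.functional.interpolate):

```python
import numpy as np

def downsample_2x(img):
    """Average each 2x2 block: (H, W, C) -> (H//2, W//2, C)."""
    h, w, c = img.shape
    return img.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))

out = np.random.rand(512, 512, 3)  # stand-in for a 512x512 RGB output
print(downsample_2x(out).shape)    # (256, 256, 3)
```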


chokyungjin commented on June 1, 2024

Thank you for your kind reply. I solved it.
One more question:
When I test reconstruction on the test data, the result is not identical to the original test images; the model creates slightly different images. Does this mean it is not learning well? Or is it because the encoder encodes the image generated by the decoder, as in the paper?


taldatech commented on June 1, 2024

I'm not sure I understood. Are you asking about the case where, for example, you train on faces from CelebA and test the model on faces from FFHQ? In that case, I wouldn't expect the model to perform well: even though both include faces, the faces are aligned differently, and the datasets include other types of faces (e.g., FFHQ includes faces of children), etc.


chokyungjin commented on June 1, 2024

@taldatech Actually, my training data is a chest X-ray dataset. I think it's because beta_rec is 0.1.


taldatech commented on June 1, 2024

I see, I haven't had the chance to work with this type of data. But I assume that for medical data you would want more finesse and better reconstructions, so maybe use a higher beta_rec. Also, maybe give the regular ResNet-based architecture a chance.
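For what it's worth, trying a higher reconstruction weight is a one-line change in the config pasted above; the value here is only an illustrative guess to experiment with, not a recommendation from the authors:

```yaml
MODEL:
  BETA_KL: 0.2
  BETA_REC: 1.0  # raised from 0.1 to push for sharper reconstructions
```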


chokyungjin commented on June 1, 2024

Very interesting. By the way, isn't it difficult to disentangle because of the residual blocks when using the ResNet-based architecture?


taldatech commented on June 1, 2024

I'm not sure about the answer to this question, but it is definitely interesting and worth more investigation.

