
fine-tune-models's Introduction

Hi there 👋

Jonathan's GitHub stats

fine-tune-models's People

Contributors

cccntu


fine-tune-models's Issues

about dataset format

Hi, I've looked at your run_finetune_vqgan.py code and the Kaggle dataset, but I can't work out the format of "danbooru_image_paths_ds.json". Could you describe its structure?
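In the meantime, the file's structure can be inspected directly. A minimal sketch with the standard library; the helper name `summarize_json` is mine, and nothing here assumes what the actual format is:

```python
import json

def summarize_json(path):
    """Load a JSON file and report its top-level structure,
    so an unknown dataset format can be inspected directly."""
    with open(path) as f:
        data = json.load(f)
    if isinstance(data, list):
        # e.g. a flat list of image paths
        return {"type": "list", "length": len(data),
                "first": data[0] if data else None}
    if isinstance(data, dict):
        # e.g. a mapping of split names or ids to paths
        return {"type": "dict", "keys": sorted(data)[:10]}
    return {"type": type(data).__name__}
```

Running it against "danbooru_image_paths_ds.json" would show whether it is a flat list of paths or a keyed mapping.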

Converting from flax back to pt?

Got the training working, and it works really well for fine-tuning away some common "glitches" and mis-reconstructions in photos and textures. However, I'd really like to convert the weights back into a PyTorch-usable format.

I've looked at the convert_diffusers_to_jax script from patil-suraj's repo, but reverse-engineering it into an inverse function doesn't look feasible: some of the changes it applies to the model appear irreversible without reconstructing the model's key structure from scratch.

Diffusers now has its own pt->flax and flax->pt conversion, but I couldn't get its Flax model to work with the script. There seem to be too many discrepancies in both the methods and the model keys, and my knowledge is too limited to figure out how to adapt the fine-tuning script to it. Diffusers' AutoencoderKL also returns its output wrapped in yet another library-specific class, complicating things even further.

So... bottom line: is there something I'm missing?
Is there some simple solution I've overlooked?
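Not an answer to the key-structure mismatch, but the core of a flax->pt conversion is flattening the params pytree and renaming/transposing per layer type. A minimal sketch over plain nested dicts (no JAX needed), handling only Dense layers; conv kernels would additionally need a 4-axis (H, W, in, out) -> (out, in, H, W) transpose, and all names here are illustrative:

```python
def _transpose(mat):
    """Transpose a 2-D list: Flax Dense kernels are (in, out),
    PyTorch Linear weights are (out, in)."""
    return [list(row) for row in zip(*mat)]

def flax_dense_to_pt(params, prefix=""):
    """Flatten a plain-dict Flax params tree into a PyTorch-style
    state dict, renaming 'kernel' to 'weight' with a transpose.
    Sketch only: conv and norm layers need their own rules."""
    out = {}
    for name, value in params.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            out.update(flax_dense_to_pt(value, key))
        elif name == "kernel":
            out[key[: -len("kernel")] + "weight"] = _transpose(value)
        else:
            out[key] = value
    return out
```

The real difficulty, as you note, is that the converter script also restructured module names, which this kind of mechanical pass cannot undo on its own.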

Could not find parameter named "kernel" in scope "/encoder/mid_block/attentions_0/query"

Hi, upon loading this up with any dataset and the 1.4 VAE converted to JAX, I get this warning on loading:

The checkpoint D:\gitprojects\finetuningvae\backup\stable-diffusion-v1-5-jax\vae\ is missing required keys: {('decoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'key', 'kernel')}. Make sure to call model.init_weights to initialize the missing weights.

(it's the 1.4 VAE; I just copied it over)

which I presume causes the training to fail with this error:

  y = fun(self, *args, **kwargs)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\stable_diffusion_jax\modeling_vae.py", line 287, in __call__
  hidden_states = attn(hidden_states)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 418, in wrapped_module_method
  return self._call_wrapped_method(fun, args, kwargs)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 854, in _call_wrapped_method
  y = fun(self, *args, **kwargs)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\stable_diffusion_jax\modeling_vae.py", line 150, in __call__
  query = self.query(hidden_states)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 418, in wrapped_module_method
  return self._call_wrapped_method(fun, args, kwargs)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 854, in _call_wrapped_method
  y = fun(self, *args, **kwargs)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\linear.py", line 196, in __call__
  kernel = self.param('kernel',
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 1263, in param
  v = self.scope.param(name, init_fn, *init_args, unbox=unbox)
File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\core\scope.py", line 842, in param
  raise errors.ScopeParamNotFoundError(name, self.path_text)
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/encoder/mid_block/attentions_0/query". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)

This happens whatever I start the training with; I'm at a bit of a loss, really.
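The error is consistent with the loading warning: the model expects attention parameters named query/key/value/proj_attn, and the converted checkpoint stores them under different names, so Flax finds nothing in that scope. One plausible repair is a rename pass over the params pytree before loading; the source-side names ("q", "k", "v", "proj_out") below are an assumption about how the checkpoint was converted and should be checked against the actual checkpoint keys:

```python
# Hypothetical rename map; verify the left-hand names against the
# keys actually present in the converted checkpoint.
RENAME = {"q": "query", "k": "key", "v": "value", "proj_out": "proj_attn"}

def rename_attention_params(tree):
    """Recursively rename attention sub-module keys in a params
    pytree (plain nested dicts), leaving leaves untouched."""
    if not isinstance(tree, dict):
        return tree
    return {RENAME.get(k, k): rename_attention_params(v)
            for k, v in tree.items()}
```

After renaming, the missing-keys warning should list only whatever genuinely differs between the two model definitions.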

vae finetuning

Does this work with normal PyTorch SD, and not just stablediffusion-jax?

Are there any "artifact" issues during fine-tuning?

I've implemented PatchGAN and StyleGAN losses for decoder-only fine-tuning, but I still get artifact-ridden reconstructions even after 50k steps.

My loss function setting is: MSE + 0.1 LPIPS + 0.1 PatchGAN (with adaptive weight) or 0.1 StyleGAN (with 1e9 gradient penalty).
The trainable parts are post_quant_conv, the decoder, and the GAN discriminator.
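For reference, that setting combines as a weighted sum, with the VQGAN-style adaptive weight scaling the GAN term by the ratio of gradient norms at the decoder's last layer. A minimal sketch under those assumptions (function names and the clipping value are mine, not from the original code):

```python
def adaptive_gan_weight(rec_grad_norm, gan_grad_norm, eps=1e-4, clip=1e4):
    """VQGAN-style adaptive weight: ratio of the gradient norms of the
    reconstruction loss vs. the GAN loss at the decoder's last layer,
    clipped for numerical stability."""
    w = rec_grad_norm / (gan_grad_norm + eps)
    return max(0.0, min(w, clip))

def total_loss(mse, lpips, gan, rec_grad_norm, gan_grad_norm,
               lpips_w=0.1, gan_w=0.1):
    """MSE + 0.1 * LPIPS + 0.1 * adaptive_weight * GAN, matching the
    setting described above (scalar loss values assumed precomputed)."""
    d_weight = adaptive_gan_weight(rec_grad_norm, gan_grad_norm)
    return mse + lpips_w * lpips + gan_w * d_weight * gan
```

If the adaptive weight saturates at the clip value, the GAN term can dominate and produce exactly this kind of artifact, so logging `d_weight` during training may be worth checking.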

The images are reconstructed without EMA. Here are some failure cases:
(images: reconstructions_00088697280_b-001760, reconstructions_00083321728_b-001480)
