jannerm / ddpo

Code for the paper "Training Diffusion Models with Reinforcement Learning"
Home Page: https://rl-diffusion.github.io
License: MIT License
I wonder what kind of prompts you used to generate the images that are fed into LLaVA to measure the image-text alignment reward.
Are they the same as those used for the other reward functions, like JPEG compressibility or aesthetic quality?
Have you released the dataset, or instructions for how to download it? Thanks.
In ddpo/diffusers_patch/scheduling_ddim_flax.py:359, jnp.mean is used to aggregate the per-dimension log probs into a total log prob for the full latent space. Wouldn't it be more proper to use jnp.sum here? Using jnp.mean means we effectively compute the ratio as

ratio = exp((log p_new - log p_old) / D) = (p_new / p_old)^(1/D),

where D is the dimensionality of the latent space. And in the default config the clip range is set to 1e-4, which means the true clip range on the joint ratio is ((1 - 1e-4)^D, (1 + 1e-4)^D).
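A toy numerical sketch of this effect (D and the per-dimension log-prob gap are made-up numbers, not values from the repo):

```python
# Toy illustration of how averaging vs. summing per-dimension log-probs
# changes the importance ratio. D and the log-prob gap are made up.
import jax.numpy as jnp

D = 4 * 64 * 64                       # e.g. a 4x64x64 Stable Diffusion latent
delta = jnp.full((D,), 1e-4)          # per-dimension log-prob difference

# Summing treats the latent as one joint sample: ratio = p_new / p_old.
ratio_sum = jnp.exp(jnp.sum(delta))   # exp(D * 1e-4), far from 1

# Averaging divides the log-ratio by D: ratio = (p_new / p_old) ** (1/D).
ratio_mean = jnp.exp(jnp.mean(delta)) # exp(1e-4), pinned near 1
```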
Hi,
Thank you for open-sourcing the repo. I am reading the code and want to understand how the loss is computed. It looks like the final loss comes from:
ddpo/ddpo/training/policy_gradient.py
Line 125 in f0b6ca7
ddpo/ddpo/training/policy_gradient.py
Line 123 in f0b6ca7
I guess this loss would then be non-differentiable if the reward were, say, the JPEG encoding length? I must be missing something, am I?
Thanks!
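A toy sketch of the score-function (REINFORCE) trick in general, not the repo's implementation: gradients flow through the log-prob only, so the reward can be any plain number (e.g. a JPEG encoding length). All names below are illustrative.

```python
# Toy score-function surrogate: the reward enters as a constant, so it
# never needs to be differentiable; only log_prob carries gradients.
import jax
import jax.numpy as jnp

def surrogate(params, noise, reward):
    # Sample from a unit-variance Gaussian centered at `params`, then treat
    # the sample as fixed data (stop_gradient) when scoring it.
    sample = jax.lax.stop_gradient(params + noise)
    log_prob = -0.5 * (sample - params) ** 2   # differentiable in params
    return -(reward * log_prob).sum()          # reward is just a scalar here

# Gradient is well-defined even though `reward` is an opaque number.
grad = jax.grad(surrogate)(jnp.array([0.0]), jnp.array([1.0]), 3.0)
```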
Hi,
Wonderful project, congratulations !
Have you tried using a reward function for both objectives ? Because it feels like the aesthetic reward do make the generations look better but also oversimplifies it (posing only one subject in the center with a blur effect around).
Also, do you have any insights on how to use a new reward function ? Should it be normalized between a certain min-max ? How long does it take to train each model ? Does it change between the different reward functions ?
Thanks :)
Is it true that the only difference between the method used in finetune.py and the traditional training of a diffusion model is that the former multiplies batch-level weights (normalized rewards) into the batch-level reconstruction loss provided by the UNet in Stable Diffusion?
And is the following code in diffusion.py the key to the success of the REINFORCE version of your method?
if weights is None:
## average over batch dimension
loss = loss.mean()
else:
## multiply loss by weights
assert loss.size == weights.size
loss = (loss * weights).sum()
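For context, the weighting scheme described above can be sketched end to end as follows. The softmax normalization is one plausible choice for turning rewards into weights, not necessarily the repo's exact scheme, and all names are illustrative.

```python
# Sketch of reward-weighted training: per-sample losses are combined with
# batch weights derived from rewards instead of a plain mean.
import jax
import jax.numpy as jnp

def reward_weighted_loss(per_sample_loss, rewards, temperature=1.0):
    # Normalize rewards into nonnegative batch weights summing to 1
    # (softmax is an assumption here, not necessarily the repo's choice).
    weights = jax.nn.softmax(rewards / temperature)
    assert per_sample_loss.size == weights.size
    # Weighted sum replaces the plain batch mean of standard training.
    return (per_sample_loss * weights).sum()

loss = reward_weighted_loss(jnp.array([0.5, 0.2, 0.9]),
                            jnp.array([1.0, 2.0, 0.0]))
```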
WARNING:jax.experimental.compilation_cache.compilation_cache:Initialized persistent compilation cache at cache
[ utils/logger ] Suppressing most dependency logging
2023-06-11 16:58:17.065316: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.3.2 but source was compiled with: 8.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
File "/home/chaojiewang/Desktop/Logical_Diffusion/Diffusion_RL/pipeline/policy_gradient.py", line 484, in
main()
File "/home/chaojiewang/Desktop/Logical_Diffusion/Diffusion_RL/pipeline/policy_gradient.py", line 51, in main
rng = jax.random.PRNGKey(args.seed)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/random.py", line 136, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 561, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
return seed(seeds)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 813, in threefry_seed
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 519, in compile_or_get_cached
compiled = backend_compile(backend, serialized_computation,
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Thanks for your paper and code. I'm confused because the loss calculation of DDPO_IS seems to differ between the paper and the code: the code uses
unclipped_loss = -advantage * ratio
and I see no explicit log_prob term in the unclipped_loss.
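As a point of reference, a PPO-style clipped importance-sampling surrogate can be sketched as below (illustrative names, not the repo's exact code). Note that the log-probs enter only through the ratio, which would explain the absence of an explicit log_prob term in the loss.

```python
# Sketch of a clipped importance-sampling surrogate: log-probs appear
# only inside the ratio exp(log_prob - old_log_prob).
import jax.numpy as jnp

def clipped_surrogate(log_prob, old_log_prob, advantage, clip_range=1e-4):
    ratio = jnp.exp(log_prob - old_log_prob)
    unclipped = -advantage * ratio
    clipped = -advantage * jnp.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Take the pessimistic (larger) loss per sample, then average.
    return jnp.maximum(unclipped, clipped).mean()

# With identical log-probs the ratio is 1 and the loss is just -advantage.
loss = clipped_surrogate(jnp.array([0.1]), jnp.array([0.1]), jnp.array([2.0]))
```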
Hi,
I am trying to run the code for the LLaVA BERTScore reward, using
python pipeline/policy_gradient.py --dataset llava_bertscore
However, I am getting the following connection error. Is the README missing instructions on hosting the LLaVA server?
I am not too familiar with it.
WARNING:urllib3.connectionpool:Retrying (Retry(total=998, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f3570062e20>: Failed to establish a new connection: [Errno 111] Connection refused')': /
I am using a GPU machine; is there a way to call the LLaVA model directly, in case there is a script for that?
I am trying to work locally, without storing things on GCP, and am therefore using H5Reader/H5Writer. However, that throws errors, in particular when local_size < max_samples; also, during the read process, H5Reader does not have methods such as __len__ and make_weights. Is it possible to release updated code that does not rely on remote read/write?