
ddpo's People

Contributors

anonymized-ddpo, jannerm, kvablack

ddpo's Issues

About prompt for Image-Text Alignment

Which prompts did you use to generate the images that are fed into LLaVA to compute the image-text alignment reward?

Are they the same as those used for the other reward functions, such as JPEG or aesthetic?

About Dataset

Have you released the dataset? If so, how can it be downloaded? Thanks.

Log prob Computation

Line 359 of ddpo/diffusers_patch/scheduling_ddim_flax.py uses jnp.mean to reduce the per-dimension log probabilities to a single log prob over the whole latent space. Would jnp.sum be more appropriate here?

Using jnp.mean means the ratio is computed as
$r = \left(\frac{\prod_{i=1}^N p_i}{\prod_{i=1}^N q_i}\right)^{\frac{1}{N}}$
where $N$ is the total dimension of the latent space, $p_i$ is the new probability, and $q_i$ is the old probability.

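Moreover, the default config sets the clip range to 1e-4. Because this clip acts on the per-dimension ratio $r$, it corresponds to bounding the full likelihood ratio $r^N$ within $[(1-10^{-4})^N, (1+10^{-4})^N]$, and a per-dimension clip range of $10^{-4}$ is very small compared with other reinforcement learning applications.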
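A quick numerical check of the reduction argument, written as a minimal numpy sketch (numpy stands in for jnp, and the latent size 4 × 64 × 64 = 16384 is an assumption corresponding to Stable Diffusion at 512 × 512): averaging log-probs gives the geometric mean of the per-dimension ratios, summing gives the full joint ratio, and a per-dimension clip range c becomes bounds of (1 ± c)^N on the full ratio.

```python
import numpy as np


def mean_reduced_ratio(new_log_probs, old_log_probs):
    # jnp.mean-style reduction: average per-dimension log-probs before exponentiating
    return np.exp(np.mean(new_log_probs) - np.mean(old_log_probs))


def sum_reduced_ratio(new_log_probs, old_log_probs):
    # jnp.sum-style reduction: the joint likelihood ratio over all N dimensions
    return np.exp(np.sum(new_log_probs) - np.sum(old_log_probs))


def full_ratio_bounds(clip_range, n_dims):
    # Clipping r to [1 - c, 1 + c] is the same as clipping r**N to these bounds
    return (1.0 - clip_range) ** n_dims, (1.0 + clip_range) ** n_dims


rng = np.random.default_rng(0)
n = 8
p = rng.uniform(0.1, 1.0, size=n)  # new per-dimension probabilities
q = rng.uniform(0.1, 1.0, size=n)  # old per-dimension probabilities

r_mean = mean_reduced_ratio(np.log(p), np.log(q))
r_sum = sum_reduced_ratio(np.log(p), np.log(q))
print(np.isclose(r_mean, np.prod(p / q) ** (1.0 / n)))  # True: geometric mean of p_i / q_i
print(np.isclose(r_sum, r_mean ** n))                    # True: full ratio is r ** N

# Assumed latent size: 4 x 64 x 64 for 512 x 512 Stable Diffusion images
low, high = full_ratio_bounds(1e-4, 4 * 64 * 64)
print(low, high)
```

Under these assumptions the bounds come out at roughly 0.19 and 5.15, so a per-dimension clip of 1e-4 still lets the full ratio move by a factor of about five in either direction.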

Question on the meaning of 'advantage'

Hi,
Thank you for open sourcing the repo. I am reading the code and want to understand how the loss is computed.

It looks like in the final loss,

loss = jnp.mean(jnp.maximum(unclipped_loss, clipped_loss))

the 'ratio' is just $p_{\theta}/p_{\theta_{\text{old}}}$, meaning that if I want to compute the loss corresponding to the gradient in Eq. (3), the only other variable I need is 'advantages' in
unclipped_loss = -advantages * ratio

which is essentially the Gaussian-normalized score of the original reward value?

I guess, then, this loss would be non-differentiable if the reward is, say, the JPEG encoding length?

I must be missing something, am I?

Thanks!
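To make the gradient path concrete, here is a minimal numpy sketch of the clipped objective quoted in this question, under two assumptions: the advantages are per-batch z-scores of the rewards (as the question suggests), and gradients are taken by finite differences instead of autodiff. In this sketch the rewards enter only through the constant advantages, while the loss depends on log_prob inside the ratio, so it has a nonzero gradient even when the reward itself (for example a JPEG size) is not differentiable.

```python
import numpy as np


def normalize_rewards(rewards, eps=1e-8):
    # Per-batch z-score; treated as a constant (no gradient flows into rewards)
    return (rewards - rewards.mean()) / (rewards.std() + eps)


def clipped_loss_fn(log_prob, old_log_prob, advantages, clip_range):
    # Same structure as the quoted code, with numpy in place of jnp
    ratio = np.exp(log_prob - old_log_prob)
    unclipped_loss = -advantages * ratio
    clipped_loss = -advantages * np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return np.mean(np.maximum(unclipped_loss, clipped_loss))


def grad_wrt_log_prob(log_prob, old_log_prob, advantages, clip_range, h=1e-6):
    # Central finite differences w.r.t. each sample's log_prob
    grads = np.zeros_like(log_prob)
    for i in range(log_prob.size):
        step = np.zeros_like(log_prob)
        step[i] = h
        up = clipped_loss_fn(log_prob + step, old_log_prob, advantages, clip_range)
        down = clipped_loss_fn(log_prob - step, old_log_prob, advantages, clip_range)
        grads[i] = (up - down) / (2 * h)
    return grads


# Hypothetical non-differentiable rewards, e.g. JPEG file sizes in kB
rewards = np.array([3.0, 1.0, 2.0, 0.0])
advantages = normalize_rewards(rewards)
old_log_prob = np.zeros(4)

# At the sampling policy the ratio is 1, so dloss/dlog_prob_i = -advantages_i / batch_size
g = grad_wrt_log_prob(old_log_prob.copy(), old_log_prob, advantages, clip_range=0.2)
print(g)

# Far outside the clip range (ratio = e > 1.2 with a positive advantage), the clipped
# term is the max and the gradient for that sample vanishes
shifted = old_log_prob.copy()
shifted[0] = 1.0
print(grad_wrt_log_prob(shifted, old_log_prob, advantages, clip_range=0.2)[0])
```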

Reward function on both aesthetic and prompt alignment

Hi,

Wonderful project, congratulations!

Have you tried using a reward function that covers both objectives? The aesthetic reward does make the generations look better, but it also seems to oversimplify them (placing a single subject in the center with a blurred background).

Also, do you have any insights on how to use a new reward function? Should it be normalized to a certain min-max range? How long does it take to train each model, and does that change between the different reward functions?

Thanks :)
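One simple way to combine two objectives, sketched here under loud assumptions (the equal weights, the per-batch z-scoring, and the function names are all hypothetical, not taken from the repo): normalize each reward within the batch so that neither dominates because of its raw scale, then take a weighted sum.

```python
import numpy as np


def zscore(x, eps=1e-8):
    # Per-batch normalization: zero mean, unit standard deviation
    return (x - x.mean()) / (x.std() + eps)


def combined_advantages(aesthetic, alignment, w_aesthetic=0.5, w_alignment=0.5):
    # Normalize each reward separately so their raw scales do not matter,
    # then mix with hypothetical weights and renormalize the result
    mixed = w_aesthetic * zscore(aesthetic) + w_alignment * zscore(alignment)
    return zscore(mixed)


aesthetic = np.array([5.1, 6.3, 4.8, 7.0])      # illustrative aesthetic scores
alignment = np.array([0.62, 0.55, 0.71, 0.40])  # illustrative alignment scores
adv = combined_advantages(aesthetic, alignment)
print(adv)
```

Because each term is z-scored first, rescaling one reward (for example multiplying the aesthetic score by 100) leaves the combined advantages unchanged, which is one reason per-batch normalization can stand in for a fixed min-max range.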

Question about finetune.py

The only difference between the method used in finetune.py and traditional diffusion-model training is that the former multiplies the batch-level reconstruction loss provided by the UNet in Stable Diffusion by batch-level weights (normalized rewards). Is that correct?

Is the following code in diffusion.py the key to the success of the REINFORCE version of your method?
if weights is None:
    ## average over batch dimension
    loss = loss.mean()
else:
    ## multiply loss by weights
    assert loss.size == weights.size
    loss = (loss * weights).sum()
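The branch quoted above can be exercised with a small numpy version. numpy stands in for the framework used in diffusion.py, and the weight construction (rewards divided by their sum, assuming non-negative rewards) is a hypothetical choice for illustration. With uniform weights of 1/B the weighted sum reduces to the plain batch mean, so in this sketch the rewards change the objective only through how they redistribute weight across samples.

```python
import numpy as np


def reduce_loss(loss, weights=None):
    if weights is None:
        # average over batch dimension
        return loss.mean()
    # multiply each sample's loss by its weight, then sum over the batch
    assert loss.size == weights.size
    return (loss * weights).sum()


def weights_from_rewards(rewards):
    # Hypothetical normalization: non-negative rewards rescaled to sum to one
    rewards = np.asarray(rewards, dtype=float)
    return rewards / rewards.sum()


per_sample_loss = np.array([0.5, 1.0, 1.5, 2.0])  # e.g. per-sample denoising losses
uniform = np.full(4, 0.25)
print(reduce_loss(per_sample_loss), reduce_loss(per_sample_loss, uniform))  # 1.25 1.25

# All weight on the last sample: only its loss contributes
print(reduce_loss(per_sample_loss, weights_from_rewards([0.0, 0.0, 0.0, 1.0])))  # 2.0
```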

Error on 4090 + CUDA 12.1

WARNING:jax.experimental.compilation_cache.compilation_cache:Initialized persistent compilation cache at cache
[ utils/logger ] Suppressing most dependency logging
2023-06-11 16:58:17.065316: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.3.2 but source was compiled with: 8.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
File "/home/chaojiewang/Desktop/Logical_Diffusion/Diffusion_RL/pipeline/policy_gradient.py", line 484, in
main()
File "/home/chaojiewang/Desktop/Logical_Diffusion/Diffusion_RL/pipeline/policy_gradient.py", line 51, in main
rng = jax.random.PRNGKey(args.seed)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/random.py", line 136, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 561, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
return seed(seeds)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 813, in threefry_seed
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 519, in compile_or_get_cached
compiled = backend_compile(backend, serialized_computation,
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
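The first error line already states the rule that fails here: the cuDNN loaded at runtime must have the same major version as the one jaxlib was compiled with, and an equal or higher minor version. A tiny sketch of that check (the helper names are hypothetical), applied to the two versions from the log:

```python
def major_minor(version):
    # "8.3.2" -> (8, 3); the patch level is ignored by the rule
    major, minor, *_ = (int(part) for part in version.split("."))
    return major, minor


def cudnn_compatible(runtime_version, compiled_version):
    # Rule quoted in the XLA error: same major version, runtime minor >= compiled minor
    rt_major, rt_minor = major_minor(runtime_version)
    build_major, build_minor = major_minor(compiled_version)
    return rt_major == build_major and rt_minor >= build_minor


print(cudnn_compatible("8.3.2", "8.8.0"))  # False: the case in this log
print(cudnn_compatible("8.9.2", "8.8.0"))  # True: newer minor within major 8
```

So, as the message itself suggests, the runtime cuDNN in this environment needs to be 8.8 or a later 8.x release, or jaxlib needs to be built against the installed cuDNN.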

About the equation of DDPO_IS in the paper and the code

Thanks for your paper and code. I'm confused because the DDPO_IS loss calculation seems to differ between the paper and the code:

  • There are two main differences in the code:
    • the parameters are updated at each timestep, instead of once for the sum over timesteps 0 to T;
    • in unclipped_loss = -advantage * ratio, I see no log-probability term $$\log p_\theta(x_{t-1} \mid x_t, c)$$
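A possible reading, sketched as a derivation in the notation of the second bullet and writing $A$ for the advantage: by the likelihood-ratio identity $\nabla_\theta p_\theta = p_\theta \nabla_\theta \log p_\theta$, the gradient of the unclipped term already contains $\nabla_\theta \log p_\theta(x_{t-1} \mid x_t, c)$, weighted by the importance ratio:

$$
\nabla_\theta\left(-A\,\frac{p_\theta(x_{t-1} \mid x_t, c)}{p_{\theta_{\text{old}}}(x_{t-1} \mid x_t, c)}\right)
= -A\,\frac{p_\theta(x_{t-1} \mid x_t, c)}{p_{\theta_{\text{old}}}(x_{t-1} \mid x_t, c)}\,\nabla_\theta \log p_\theta(x_{t-1} \mid x_t, c)
$$

Summed over the timesteps of the trajectory, these per-timestep terms give a sum of ratio-weighted score-function terms; accumulating them before a single update reproduces that sum exactly, while stepping after every timestep approximates it.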

Unable to run LLaVA BERTScore

Hi,
I am trying to run the code for LLaVA BERTScore, using

python pipeline/policy_gradient.py --dataset llava_bertscore

However, I am getting the following connection error. Is the README missing instructions for setting up the hosting server?
I am not too familiar with it.

WARNING:urllib3.connectionpool:Retrying (Retry(total=998, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f3570062e20>: Failed to establish a new connection: [Errno 111] Connection refused')': /

I am using a GPU machine. Is there a script for calling the LLaVA model directly?
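"Connection refused" (Errno 111) means nothing was accepting connections at the address the client kept retrying, which usually indicates the reward server was not running or was listening elsewhere. A minimal standard-library check, where the host and port are placeholders rather than values taken from the repo:

```python
import socket


def server_is_listening(host, port, timeout=1.0):
    """Return True if a TCP connection to (host, port) can be opened."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # ConnectionRefusedError (Errno 111) and timeouts are both OSError subclasses
        return False


# Placeholder address: substitute whatever host/port the reward client is configured to use
print(server_is_listening("localhost", 8000))
```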

Working offline without RemoteWriter/Reader

I am trying to work locally without storing data on GCP, so I am using H5Reader/H5Writer instead. However, this throws errors, in particular when local_size < max_samples; during reading, H5Reader also lacks methods such as __len__ and make_weights. Could you release updated code that does not rely on remote read/write?
