
cql's Introduction

CQL

Code for Conservative Q-Learning for Offline Reinforcement Learning (https://arxiv.org/abs/2006.04779)

In this repository we provide code for the CQL algorithm described in the paper linked above. The code is split into two sub-directories: atari, containing the code for the Atari experiments, and d4rl, containing the code for the D4RL experiments. Due to changes in the D4RL datasets, we expect some differences in CQL performance on the new datasets; we will soon provide a table with updated performance numbers for CQL in this README and will keep it current.

If you find this repository useful for your research, please cite:

@article{kumar2020conservative,
  author       = {Aviral Kumar and Aurick Zhou and George Tucker and Sergey Levine},
  title        = {Conservative Q-Learning for Offline Reinforcement Learning},
  journal      = {arXiv preprint arXiv:2006.04779},
  url          = {https://arxiv.org/abs/2006.04779},
}

Atari Experiments

Our code is built on top of the batch_rl repository. Please run installation instructions from the batch_rl repository. CQL in this case was implemented on top of QR-DQN for which the implementation is present in batch_rl/multi_head/quantile_agent.py.

To run the experiments in the paper, you will have to specify the size of an individual replay buffer so that the 1% and 10% datasets can be used. This is set on line 53 of batch_rl/fixed_replay/replay_memory/fixed_replay_memory.py: for 1% data, set args[2] = 1000, and for 10% data, set args[2] = 10000. Depending on the available RAM, you may be able to raise num_buffers from 10 to 50 (we were able to do this for the 1% datasets) by changing the value in self._load_replay_buffers(num_buffers=<>).
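As a concrete reference, here is a small, hypothetical helper (the function name and arguments are made up for illustration and are not part of batch_rl) that encodes the two settings described above:

    # Hypothetical helper for illustration only -- not part of batch_rl.
    def buffer_settings(data_fraction, enough_ram=False):
        """Return (per-buffer capacity, num_buffers) for the 1% / 10% Atari subsets."""
        capacity = 1000 if data_fraction == 0.01 else 10000   # the value to put in args[2]
        num_buffers = 50 if enough_ram else 10                # passed to self._load_replay_buffers(...)
        return capacity, num_buffers

    print(buffer_settings(0.01, enough_ram=True))  # -> (1000, 50)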

Now, to run CQL, use the following command:

python -um batch_rl.fixed_replay.train \
  --base_dir=/tmp/batch_rl \
  --replay_dir=$DATA_DIR/Pong/1 \
  --agent_name=quantile \
  --gin_files='batch_rl/fixed_replay/configs/quantile_pong.gin' \
  --gin_bindings='FixedReplayRunner.num_iterations=1000' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Pong"' \
  --gin_bindings='FixedReplayQuantileAgent.minq_weight=1.0'

For 1% data, use minq_weight=4.0 and for 10% data, use minq_weight=1.0.

D4RL Experiments

Our code is built on top of rlkit. Please install the conda environment for rlkit, making sure to install torch>=1.1.0, and then install d4rl. The code for the CQL algorithm is in rlkit/torch/sac/cql.py. After this, to run CQL on the MuJoCo environments, use:

python examples/cql_mujoco_new.py --env=<d4rl-mujoco-env-with-version e.g. hopper-medium-v0> \
        --policy_lr=1e-4 --seed=10 --lagrange_thresh=-1.0 \
        --min_q_weight=(5.0 or 10.0) --gpu=<gpu-id> --min_q_version=3

In terms of parameters, we have found min_q_weight=5.0 or min_q_weight=10.0, along with policy_lr=1e-4 or policy_lr=3e-4, to work reasonably well for the Gym MuJoCo tasks. These parameters differ slightly from the paper (which will be updated soon) due to differences in the D4RL datasets. For sample performance numbers (final numbers to be updated soon), hopper-medium achieves ~3000 return, and hopper-medium-expert obtains ~1300 return at the end of 500k gradient steps. To run CQL(\rho) [i.e. without the importance sampling], set min_q_version=2.
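To make the role of min_q_weight concrete, here is a minimal PyTorch-style sketch of the conservative penalty; it is a simplification of rlkit/torch/sac/cql.py rather than a drop-in replacement, and the critic qf and the sampled-action Q-values are stand-ins:

    import torch

    def cql_penalty(qf, obs, actions, sampled_actions_q, min_q_weight=5.0, temp=1.0):
        """Sketch of the CQL regularizer added to the usual Bellman error.

        sampled_actions_q: [batch, num_sampled] Q-values at actions drawn from a broad
        proposal distribution (uniform and/or the current policy), as in CQL(H).
        """
        # Push down a soft maximum of Q over the sampled actions...
        logsumexp_q = temp * torch.logsumexp(sampled_actions_q / temp, dim=1)
        # ...and push up Q at the actions actually observed in the dataset.
        dataset_q = qf(obs, actions).squeeze(-1)
        return min_q_weight * (logsumexp_q - dataset_q).mean()

    # Toy usage with random tensors, just to show the shapes.
    qf = lambda s, a: s.sum(-1, keepdim=True) + a.sum(-1, keepdim=True)
    obs, actions = torch.randn(8, 17), torch.randn(8, 6)
    print(cql_penalty(qf, obs, actions, torch.randn(8, 30)))

min_q_version only changes how the sampled-action Q-values are assembled (with or without the importance-sampling corrections); the shape of the penalty itself stays the same.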

For Ant-Maze tasks, please run:

python examples/cql_antmaze_new.py --env=antmaze-medium-play-v0 --policy_lr=1e-4 --seed=10 \
        --lagrange_thresh=5.0 --min_q_weight=5.0 --gpu=<gpu-id> --min_q_version=3

In case of any questions, bugs, suggestions or improvements, please feel free to contact me at [email protected] or open an issue.


cql's Issues

Examples for D4RL Experiments

Two example files are provided for the D4RL experiments, namely cql_mujoco_new.py and cql_antmaze_new.py. I am a little bit confused about these two files.

In cql_mujoco_new.py, gym is used to create the environment.
In cql_antmaze_new.py, a HalfCheetah-v2 instance is used. Should an antmaze instance be used as the MuJoCo environment instead?
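For what it is worth, the usual D4RL pattern (as we understand it) is that importing d4rl registers the offline environments with gym, so the AntMaze task can be created by name:

    import gym
    import d4rl  # importing d4rl registers the antmaze-* (and other D4RL) environments with gym

    env = gym.make('antmaze-medium-play-v0')
    dataset = env.get_dataset()              # the offline transitions CQL trains on
    print(dataset['observations'].shape)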

About hyperparameters for D4RL Kitchen

Hi! I was unable to reproduce the result on kitchen-mixed-v0 using the same hyperparameters as for the D4RL MuJoCo tasks. Could you please provide the configuration for the Kitchen tasks?

Why subtract entropy from Q-values? ("min_q_version == 3")

Hey,

unlike in the paper, the implementation has this part that subtracts the action log-probabilities from the Q-values:

if self.min_q_version == 3:

My guess is that the effect would be to put less of the loss's weight on a single high-Q action, should the policy concentrate on one. But then we already have the temperature parameter. I'm not sure the author will answer, so if anybody knows, I'd appreciate your insights :)
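For other readers: our reading (not confirmed by the author) is that the subtraction is an importance-sampling correction for estimating the log-sum-exp over a continuous action space, rather than an entropy term. A rough, self-contained sketch of that idea:

    import torch

    def logsumexp_is_estimate(q_rand, q_pi, log_pi, action_dim, temp=1.0):
        """Importance-sampled estimate of temp * log sum_a exp(Q(s, a) / temp).

        q_rand: [batch, N] Q at actions drawn uniformly from [-1, 1]^action_dim
        q_pi:   [batch, N] Q at actions drawn from the current policy
        log_pi: [batch, N] log-probabilities of those policy actions
        """
        log_unif = -action_dim * torch.log(torch.tensor(2.0))   # log density of Uniform([-1, 1]^d)
        corrected = torch.cat([q_rand / temp - log_unif,         # exp(Q/temp) / proposal density,
                               q_pi / temp - log_pi], dim=1)     # expressed in log space
        return temp * torch.logsumexp(corrected, dim=1)          # up to an additive constant in N

    print(logsumexp_is_estimate(torch.randn(4, 10), torch.randn(4, 10),
                                torch.randn(4, 10), action_dim=6))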

ImportError: cannot import name 'elem_or_tuple_to_numpy' from 'rlkit.torch.core'

When I try to run the D4RL MuJoCo benchmark, the following happens.
It looks like there is a discrepancy between the rlkit version expected by the published code and the one actually used for the paper.
If updating the code is too much trouble, could you point me to the rlkit version used in your experiments, or tell me where I can download the rlkit library you used?

cql_mujoco_new.py:124: SyntaxWarning: "is not" with a literal. Did you mean "!="?
  if (gpu_str is not ""):
No personal conf_private.py found.
doodad not detected
Traceback (most recent call last):
  File "cql_mujoco_new.py", line 6, in <module>
    from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
  File "/home/hsinyu/rlkit/rlkit/torch/sac/policies/__init__.py", line 1, in <module>
    from rlkit.torch.sac.policies.base import (
  File "/home/hsinyu/rlkit/rlkit/torch/sac/policies/base.py", line 11, in <module>
    from rlkit.torch.core import torch_ify, elem_or_tuple_to_numpy
ImportError: cannot import name 'elem_or_tuple_to_numpy' from 'rlkit.torch.core' (/home/hsinyu/rlkit/rlkit/torch/core.py)

QF_Loss backprops policy network

In the CQL trainer, policy_loss is formulated before QF_Loss, but QF_Loss backpropagates through the policy network before policy_loss does, which causes a Torch error. Is the intended use to optimize the policy network on policy_loss before formulating QF_Loss (while still optimizing the policy through QF_Loss), or to not reparametrize the policy output when formulating QF_Loss (e.g., line 201)?
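One common way to avoid this kind of double-backward error, offered as a sketch rather than as the author's intended fix, is to step the policy loss first and detach the policy samples when forming the Q loss. A toy illustration with placeholder networks:

    import torch
    import torch.nn as nn

    # Tiny stand-ins, not the repo's networks; only the ordering/detach pattern matters.
    policy, qf = nn.Linear(3, 2), nn.Linear(5, 1)
    policy_opt = torch.optim.Adam(policy.parameters(), lr=1e-3)
    qf_opt = torch.optim.Adam(qf.parameters(), lr=1e-3)

    obs = torch.randn(8, 3)
    acts = policy(obs)                                   # reparameterized policy actions

    # Policy loss: backprop through both networks, but step only the policy.
    policy_loss = -qf(torch.cat([obs, acts], dim=-1)).mean()
    policy_opt.zero_grad()
    policy_loss.backward()
    policy_opt.step()

    # Q loss: detach the policy samples so this backward does not revisit the
    # already-freed policy graph.
    q_at_pi = qf(torch.cat([obs, acts.detach()], dim=-1))
    qf_loss = (q_at_pi - torch.randn_like(q_at_pi)).pow(2).mean()
    qf_opt.zero_grad()
    qf_loss.backward()
    qf_opt.step()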

`rlkit/torch/sac/cql.py` not found

I cannot find the path rlkit/torch/sac/cql.py in the rlkit master branch that I pulled.

Can you let me know which branch you are referring to?

Thanks!

ResolvePackageNotFound

I want to create the conda env from CQL/d4rl/environment/linux-gpu-env.yml.
I ran this command:

conda env create -f linux-gpu-env.yml

and got this error:

ResolvePackageNotFound:

  • matplotlib==2.0.2=np111py35_0
  • python=3.5.2
  • path.py==10.3.1=py35_0
  • mako==1.0.6=py35_0
  • joblib=0.9.4
  • numba==0.35.0=np111py35_0
  • python-dateutil==2.6.1=py35_0

Could you tell me which package versions to use instead?

code bugs

q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)

This line is wrong; it should be:

q1_next_actions = self._get_tensor_values(next_obs, new_curr_actions_tensor, network=self.qf1)

SAC CQL: Potential mismatch between observations and actions fed to the Q network in CQL computation.

Greetings.

Thank you for your amazing work on Offline RL, as well as for open-sourcing the code.

This present issue pertains to the computation for the lower bounding component of the SAC CQL:

        ## add CQL
        random_actions_tensor = torch.FloatTensor(q2_pred.shape[0] * self.num_random, actions.shape[-1]).uniform_(-1, 1) # .cuda()
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(obs, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_obs, num_actions=self.num_random, network=self.policy)
        q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf1)
        q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf2)
        q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
        q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
        q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)
        q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)

Namely, at line 236, the actions new_curr_actions_tensor of the policy for the next states in the batch, next_obs, are computed by feeding the latter to the policy.

new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_obs, num_actions=self.num_random, network=self.policy)

When computing the corresponding Q values, however, the new_curr_actions_tensor are fed to the Q networks together with what seem to be the observations at the current time step, obs:

q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)

Shouldn't it be next_obs instead of obs at those two lines 241 and 242?
Or is there a specific reason we might want to use actions of the next states to compute the Q value for the current observations batch (states)?
(Sampling "incorrect" actions with regard to the current observations (states) on purpose?)

Thank you for your time, and sorry for the inconvenience.

About the readability

Hi Aviral,

In the paper, you claim CQL can be implemented with fewer than 20 lines of code, but it is really difficult to identify those "20 lines of code" in the current version of your project, which is built on top of other projects. Would you please point out which part of the code corresponds to the core of CQL? I really like the idea of CQL, both the theoretical part and its simplicity, but at the moment it is very hard to follow.

Best,
Zhi-Hong

Make checkpoints public

Hi, would it be possible to release the checkpoints for this implementation? I would be very grateful.

Function argument problem about expl_path_collector.collect_new_paths()

Error when creating quantile agent for using CQL for Atari

Hi Aviral,

when I tried to run the 'train.py' script for the Atari games, I noticed an error when creating the quantile agent in 'atari/batch_rl/multi_head/quantile_agent.py': on line 117 you pass the argument "minq_weight" into the init function of the Rainbow agent, but that class has no "minq_weight" argument. After I deleted the argument, the code ran perfectly.
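An alternative to deleting the argument, sketched here under assumptions about the class layout (the class and base names below are placeholders, not batch_rl's actual hierarchy), would be to consume minq_weight before delegating to the parent constructor:

    # Hypothetical sketch -- not the actual batch_rl class hierarchy.
    class RainbowAgentStub:
        def __init__(self, num_actions):
            self.num_actions = num_actions

    class QuantileAgentSketch(RainbowAgentStub):
        def __init__(self, num_actions, minq_weight=1.0, **kwargs):
            self.minq_weight = minq_weight           # keep the CQL penalty weight locally
            super().__init__(num_actions, **kwargs)  # the parent never sees minq_weight

    agent = QuantileAgentSketch(num_actions=6, minq_weight=4.0)
    print(agent.minq_weight)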

Best,
Timo

Logsumexp calculation in CQL(H) for continuous action space

Hi Aviral,

Thanks for sharing your code!

My concern is about the logsumexp calculation in CQL(H) for D4RL. On page 29 of your paper, you mention your technique for computing the logsumexp, which looks fine to me. However, the code seems to differ in a number of ways:

  1. $N=20$ rather than 10 for the current policy, which is ok (i.e., not a bug).
  2. Following my derivation based on your derivation on page 29, you seem to have forgotten to multiply by $2N_1$ and $2N_2$ in your code (cql.py, lines 263 to 271). The code simply plugs in the log probabilities without doing the multiplication first. Is this a bug?

Also, a perhaps unrelated question: How did you come up with this way of computing logsumexp? Splitting the sum into two separate expectations (instead of just using one) w.r.t. two different distributions is not so intuitive for me.
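For other readers, here is my reconstruction of the page-29 estimator (an interpretation, not the authors' exact derivation): with half of the proposals drawn uniformly over the action space and half from the current policy $\pi$,

    \log \int_{\mathcal{A}} \exp Q(s, a)\, da \;\approx\; \log\!\left[
        \frac{1}{2N} \sum_{a_i \sim \mathrm{Unif}(\mathcal{A})} \frac{\exp Q(s, a_i)}{\mathrm{Unif}(a_i)}
      + \frac{1}{2N} \sum_{a_j \sim \pi(\cdot \mid s)} \frac{\exp Q(s, a_j)}{\pi(a_j \mid s)}
    \right]

Written as a single log-sum-exp, this amounts to subtracting the proposal log-densities from the Q-values; when both proposal sets have the same size, the $1/2N$ factor is a single additive constant inside the log and does not change the gradients, which may be why it is dropped in the code. Whether that fully resolves point 2 above, I am not certain.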


Potential mismatch between math and code for CQL(rho)

This is a question regarding how CQL(rho) works in terms of code 😊.

In the CQL section (starting from line 235) within /CQL/d4rl/rlkit/torch/sac/cql.py, we first computed:

cat_q1 = torch.cat(
    [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
)
cat_q2 = torch.cat(
    [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
)

and then used them to compute

min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp

I'm a bit confused about why the Q values of actions drawn from three distinct distributions can be used to compute this quantity:

  • q1_rand: uniform distribution
  • q1_pred: dataset distribution
  • q1_curr_actions and q1_next_actions: last-iteration policy

Here are my questions:

  • In Appendix A section CQL(rho), don't we have that the expectation is with respect to the rho distribution only (which we have chosen to be the last-iteration policy)?
  • Why do we use log-sum-exp here while the corresponding term (the first term) in Equation 7 of the paper does not contain a log at all?

I'm able to completely understand how CQL(H) works in the codebase though.
