coax's People

Contributors

kristianholsheimer, microsoft-github-policy-service[bot]

coax's Issues

Best way to save a model/agent?

Suppose I train a Q-learning agent on FrozenLake as per this tutorial; what is the best way to save this model/agent for use elsewhere?

I'm thinking I just need to save the Q-values and potentially the policy pi? Interestingly, the docs seem to have a Q.save(...) and a pi.save(...), but when I try to use them they don't appear to be implemented in the code yet?
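In the meantime, this is the kind of workaround I'm imagining. It's only a sketch: it assumes the learned weights are exposed as params / function_state attributes and that those attributes are assignable, which I haven't verified against the actual coax API.

import pickle

import coax

# Sketch of a workaround (not an official coax API): pickle the learned
# parameter pytree of the trained Q-function, then load it back into a
# freshly constructed Q-function built from the same forward-pass function.
# ASSUMPTION: `q.params` and `q.function_state` exist and are assignable.

# ... after training q = coax.Q(func_q, env) as in the FrozenLake tutorial ...
with open('frozenlake_q.pkl', 'wb') as f:
    pickle.dump({'params': q.params, 'function_state': q.function_state}, f)

# elsewhere: re-create the same function approximator, then restore the weights
q = coax.Q(func_q, env)
with open('frozenlake_q.pkl', 'rb') as f:
    state = pickle.load(f)
q.params = state['params']
q.function_state = state['function_state']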

Many thanks for any help, and for this fantastic lib! :)

coax.Policy(func_pi, env) gives RuntimeError: Unimplemented: DNN library is not found when running the PPO on Pong example

Describe the bug
When running the PPO on Pong example, the line pi = coax.Policy(func_pi, env) throws the following exception:

--------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-e81d2937c9e6> in <module>
      1 # function approximators
----> 2 pi = coax.Policy(func_pi, env)
      3 v = coax.V(func_v, env)
      4 
      5 # target networks

/opt/conda/lib/python3.8/site-packages/coax/_core/policy.py in __init__(self, func, env, observation_preprocessor, proba_dist, random_seed)
     70             proba_dist = ProbaDist(env.action_space)
     71 
---> 72         super().__init__(
     73             func=func,
     74             observation_space=env.observation_space,

/opt/conda/lib/python3.8/site-packages/coax/_core/base_stochastic_func_type2.py in __init__(self, func, observation_space, action_space, observation_preprocessor, proba_dist, random_seed)
    151 
    152         # note: self._modeltype is set in super().__init__ via self._check_signature
--> 153         super().__init__(
    154             func=func,
    155             observation_space=observation_space,

/opt/conda/lib/python3.8/site-packages/coax/_core/base_func.py in __init__(self, func, observation_space, action_space, random_seed)
     99 
    100         # init function params and state
--> 101         self._params, self._function_state = transformed.init(self.rng, *example_data.inputs.args)
    102 
    103         # check if output has the expected shape etc.

/opt/conda/lib/python3.8/site-packages/haiku/_src/transform.py in init_fn(rng, *args, **kwargs)
    275     rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
    276     with base.new_context(rng=rng) as ctx:
--> 277       f(*args, **kwargs)
    278     return ctx.collect_params(), ctx.collect_initial_state()
    279 

<ipython-input-8-51963c5ee651> in func_pi(S, is_training)
     15         hk.Linear(env.action_space.n, w_init=jnp.zeros),
     16     ))
---> 17     X = shared(S, is_training)
     18     return {'logits': logits(X)}
     19 

<ipython-input-8-51963c5ee651> in shared(S, is_training)
      7     ])
      8     X = jnp.stack(S, axis=-1) / 255.  # stack frames
----> 9     return seq(X)
     10 
     11 

/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in wrapped(self, *args, **kwargs)
    404         f = stateful.named_call(f, name=local_name)
    405 
--> 406       out = f(*args, **kwargs)
    407 
    408       # Notify parent modules about our existence.

/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in run_interceptors(bound_method, method_name, self, *args, **kwargs)
    261   """Runs any method interceptors or the original method."""
    262   if not interceptor_stack:
--> 263     return bound_method(*args, **kwargs)
    264 
    265   ctx = MethodContext(module=self,

/opt/conda/lib/python3.8/site-packages/haiku/_src/basic.py in __call__(self, inputs, *args, **kwargs)
    124         out = layer(out, *args, **kwargs)
    125       else:
--> 126         out = layer(out)
    127     return out
    128 

/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in wrapped(self, *args, **kwargs)
    404         f = stateful.named_call(f, name=local_name)
    405 
--> 406       out = f(*args, **kwargs)
    407 
    408       # Notify parent modules about our existence.

/opt/conda/lib/python3.8/site-packages/haiku/_src/module.py in run_interceptors(bound_method, method_name, self, *args, **kwargs)
    261   """Runs any method interceptors or the original method."""
    262   if not interceptor_stack:
--> 263     return bound_method(*args, **kwargs)
    264 
    265   ctx = MethodContext(module=self,

/opt/conda/lib/python3.8/site-packages/haiku/_src/conv.py in __call__(self, inputs)
    193       w *= self.mask
    194 
--> 195     out = lax.conv_general_dilated(inputs,
    196                                    w,
    197                                    window_strides=self.stride,

/opt/conda/lib/python3.8/site-packages/jax/_src/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision)
    596         np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,
    597         window_strides, padding)
--> 598   return conv_general_dilated_p.bind(
    599       lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
    600       lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),

/opt/conda/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
    280     top_trace = find_top_trace(args)
    281     tracers = map(top_trace.full_raise, args)
--> 282     out = top_trace.process_primitive(self, tracers, params)
    283     return map(full_lower, out) if self.multiple_results else full_lower(out)
    284 

/opt/conda/lib/python3.8/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
    626 
    627   def process_primitive(self, primitive, tracers, params):
--> 628     return primitive.impl(*tracers, **params)
    629 
    630   def process_call(self, primitive, f, tracers, params):

/opt/conda/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    237   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
    238   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
--> 239   return compiled_fun(*args)
    240 
    241 

/opt/conda/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, result_handler, *args)
    355   device, = compiled.local_devices()
    356   input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 357   out_bufs = compiled.execute(input_bufs)
    358   check_special(prim, out_bufs)
    359   return result_handler(*out_bufs)

RuntimeError: Unimplemented: DNN library is not found.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

Expected behavior
The code should run without throwing an exception.

Desktop (please complete the following information):

  • OS: Ubuntu 18.04
  • Python version: 3.8.5
  • CUDA version (if applicable): release 10.1, V10.1.243
  • jax version: 0.2.9
  • jaxlib version: 0.1.60+cuda101
  • coax version: 0.1.6

To Reproduce
Steps to reproduce the behavior:

  1. Create a Docker image based on this dockerfile.
  2. Run Jupyter Lab from the container:
docker build -t rl:gpu -f gpu.Dockerfile . && \
docker run -it \
    --gpus all \
    -u vscode \
    -p 8888:8888 \
    -v $(pwd):/workspaces/rl \
    -w /workspaces/rl \
    --rm \
    --name rl \
    rl:gpu jupyter lab --ip 0.0.0.0 --no-browser
  3. Run the ppo.py script in the PPO on Pong example.

Script to run:

import os

# set some env vars
os.environ.setdefault('JAX_PLATFORM_NAME', 'gpu')     # tell JAX to use GPU
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet

import gym
import jax
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam


# the name of this script
name = 'ppo'

# env with preprocessing
env = gym.make('PongNoFrameskip-v4')
env = gym.wrappers.AtariPreprocessing(env)
env = coax.wrappers.FrameStacking(env, num_frames=3)
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")


def shared(S, is_training):
    seq = hk.Sequential([
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4), jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Flatten(),
    ])
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return seq(X)


def func_pi(S, is_training):
    logits = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = shared(S, is_training)
    return {'logits': logits(X)}


def func_v(S, is_training):
    value = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
    ))
    X = shared(S, is_training)
    return value(X)


# function approximators
pi = coax.Policy(func_pi, env)

Additional context
The DQN on CartPole example runs without problems and utilizes the GPU, so I think all the necessary GPU drivers are installed correctly.
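For reference, a quick way to confirm that JAX sees the GPU inside the container (plain jax calls, nothing coax-specific; note that seeing the device does not by itself guarantee that cuDNN is usable):

import jax

# List the devices JAX has picked up; this only confirms device visibility
# and does not exercise cuDNN-backed ops such as the convolution that fails above.
print(jax.devices())              # expect something like [GpuDevice(id=0)]
print(jax.devices()[0].platform)  # expect 'gpu'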

Parallelizing environments with coax?

Hi there,

Great library! Thanks a ton for doing this work.

I'm working on some environments that require a lot of samples, so I was reading through the documentation looking for a built-in way to run multiple environments in parallel. I didn't find anything, so I'm wondering: is there a utility I'm missing, or perhaps a recommended way to handle this with coax?
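For illustration, here is roughly the kind of parallel rollout I have in mind, sketched with gym's own vector API; this is just to clarify the question, and I don't know whether it composes cleanly with coax's tracers and update objects.

import gym

# Sketch of the kind of parallelism I mean, using gym's built-in vector API.
# AsyncVectorEnv runs each environment copy in its own worker process.
def make_env():
    return gym.make('CartPole-v1')

vec_env = gym.vector.AsyncVectorEnv([make_env for _ in range(8)])

obs = vec_env.reset()                    # batched observations, shape (8, ...)
actions = vec_env.action_space.sample()  # one action per environment copy
obs, rewards, dones, infos = vec_env.step(actions)
vec_env.close()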

If not, are there plans to add support for parallel environments in the future?

Thanks again, so far coax is a delight to use!
