
mesh-transformer-jax's People

Contributors

ablacklama, aeroscripts, cclauss, curtisasmith, djoldman, erichallahan, jacobfnl, kingoflolz, leogao2, linagee, minimaxir, morganmcg1, narphorium, nostalgebraist, reouno, rozanecm, srulikbd, stellaathena, trisongz, versae, vfbd, widiba03304, yaserabdelaziz


mesh-transformer-jax's Issues

GPU support

Hi, would it be possible to run GPT-J on GPU in the future?

TypeError: don't know how to handle uri None

Whenever I try to run device_sample or device_train, I get the following error:

Traceback (most recent call last):
  File "device_sample.py", line 30, in <module>
    params = json.load(open(args.config))
  File "/home/nin/.local/lib/python3.8/site-packages/smart_open/smart_open_lib.py", line 235, in open
    binary = _open_binary_stream(uri, binary_mode, transport_params)
  File "/home/nin/.local/lib/python3.8/site-packages/smart_open/smart_open_lib.py", line 394, in _open_binary_stream
    raise TypeError("don't know how to handle uri %s" % repr(uri))
TypeError: don't know how to handle uri None

I have all the required libraries installed. Is there a config file that I am missing somewhere?
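For reference: the traceback shows json.load(open(args.config)) receiving None, which suggests the --config flag was never passed on the command line. A hedged example invocation, assuming one of the repo's stock config files:

python device_sample.py --config=configs/6B_roto_256.json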

Jax TPU Issue

Hi @kingoflolz, amazing work!!

I am trying to test the model on TPU VM using the step_383500/ data.

After following the steps in jax-quickstart-tpu-vm, running jax.devices() returns the correct output, which is

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

After that, I clone your repo, run pip install -r requirements.txt, and then run pip install jax==0.2.12 as you mentioned in your fine-tuning docs, but it gives this error:

2021-07-18 05:15:21.807145: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
Aborted (core dumped)

So I run this from the jax-quickstart-tpu-vm docs:

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

after which jax.devices() gives this output:

>>> import jax
>>> jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

Somehow jax fails to detect the TPU, and I don't know what the issue is. I am running Google Cloud's TPU VM v3-8 with the v2-alpha software version. Would really appreciate your help.

TPU Requirements

When running device_sample.py on my trained model I got this error

Hello once again,
I have trained the model on my dataset thanks to you. But when I now run device_sample.py with the same config file I used to train, I get the following error.

Traceback (most recent call last):
  File "device_sample.py", line 99, in <module>
    output = network.generate(batched_tokens, length, 512, {"top_p": np.ones(total_batch) * 0.9,
  File "/home/adnanmunye/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 328, in generate
    return self.generate_xmap(self.state,
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 613, in fun_mapped
    out_flat = xmap_p.bind(
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 820, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/core.py", line 1552, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 823, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/core.py", line 607, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 644, in xmap_impl
    xmap_callable = make_xmap_callable(
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 671, in make_xmap_callable
    _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
  File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 1633, in _check_out_avals_vs_out_axes
    raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard

Could you please help again?

RuntimeError: Resource exhausted: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0

Hi,
I am facing this issue while running the "mesh-transformer-jax/resharding_example.py"
RuntimeError: Resource exhausted: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0

2021-07-21 16:15:16.513776: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1981] Execution of replica 0 failed: Resource exhausted: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0
Traceback (most recent call last):
  File "resharding_example.py", line 47, in <module>
    network = CausalTransformer(params)
  File "/drive4/user1/GPT_J/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 272, in __init__
    self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
  File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 516, in fun_mapped
    out_flat = xmap_p.bind(
  File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 652, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 655, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 539, in xmap_impl
    return make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
  File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1130, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Resource exhausted: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0

System config is mentioned below:
GPU: 11177MiB

RAM: 62G
export XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform
Python 3.8.7

Reduce inference time

Hi,
I am using an A100 GPU, and it's taking 10 seconds to generate 150 tokens. How can we reduce the inference time?

ValueError: cannot reshape array of size 1 into shape (0,8)

ValueError                                Traceback (most recent call last)
in <module>()
     24
     25 mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
---> 26 devices = np.array(jax.devices()).reshape(mesh_shape)
     27
     28 maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

ValueError: cannot reshape array of size 1 into shape (0,8)
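A hedged reading of this error: reshaping to (0, 8) means jax.device_count() // cores_per_replica evaluated to 0, i.e. jax found fewer than cores_per_replica devices (typically just a single CPU device when no TPU/GPU backend is detected). A minimal sanity check, using only standard jax calls:

import jax

cores_per_replica = 8  # assumption: matches the notebook's config
assert jax.device_count() >= cores_per_replica, (
    f"need {cores_per_replica} devices, but jax only sees {jax.devices()}")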

Temp 0 doesn't work properly

Using the online demo, a prompt of a few words with a temperature of zero tends to give repeated punctuation symbols as a continuation. Nudge it up to temp = 0.01 and it behaves as GPT-2 and GPT-3 do at zero temperature: usually repetitive, but at the full-sentence level. I think the program may not handle inputs of 0 temperature correctly.
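A plausible mechanism, assuming the sampler divides logits by the temperature as most samplers do (a guess, not a reading of the repo's code):

import numpy as np

logits = np.array([2.0, 1.0, 0.5])
print(logits / 0.0)     # [inf inf inf]: a softmax over this is all-NaN
print(logits.argmax())  # 0: the argmax that "temperature 0" usually means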

Include context prompt when generating text in Colab

It makes it a bit easier to keep track of what the prompt was when sharing screenshots of generated text.


One line change from

samples.append(tokenizer.decode(o))

to

samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

Question about transformer models

Hi, this is a general question, feel free to close this issue if it doesn't apply.
How do I know if I should use T5-11b or GPT-J for a certain task? In this case, it's for a conversational chatbot.
I know that T5 uses an encoder and decoder and GPT models only a decoder.

Import optax on Colab gives: cannot import name 'flags' from 'jax.config'

Did anyone else encounter this error while trying the demo.ipynb on colab?

ImportError                               Traceback (most recent call last)

<ipython-input-9-72cd76e3a907> in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8 

6 frames

/usr/local/lib/python3.7/dist-packages/optax/__init__.py in <module>()
     16 """Optax: composable gradient processing and optimization, in JAX."""
     17 
---> 18 from optax._src.alias import adabelief
     19 from optax._src.alias import adagrad
     20 from optax._src.alias import adam

/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py in <module>()
     20 import jax.numpy as jnp
     21 
---> 22 from optax._src import combine
     23 from optax._src import privacy
     24 from optax._src import schedule

/usr/local/lib/python3.7/dist-packages/optax/_src/combine.py in <module>()
     16 """Flexibly compose gradient transformations."""
     17 
---> 18 from optax._src import transform
     19 GradientTransformation = transform.GradientTransformation
     20 

/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py in <module>()
     18 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
     19 
---> 20 import chex
     21 import jax
     22 import jax.numpy as jnp

/usr/local/lib/python3.7/dist-packages/chex/__init__.py in <module>()
     15 """Chex: Testing made fun, in JAX!"""
     16 
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_gt
     19 from chex._src.asserts import assert_devices_available

/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py in <module>()
     29 import jax
     30 import jax.numpy as jnp
---> 31 import jax.test_util as jax_test
     32 import numpy as np
     33 import tree as dm_tree

/usr/local/lib/python3.7/dist-packages/jax/test_util.py in <module>()
     33 from . import dtypes as _dtypes
     34 from . import lax
---> 35 from .config import flags, bool_env, config
     36 from ._src.util import partial, prod
     37 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce

ImportError: cannot import name 'flags' from 'jax.config' (/usr/local/lib/python3.7/dist-packages/jax/config.py)

Device_train execution halts

I am running device_train.py to train a custom model on a smaller dataset. I have only 4-5 txt files, which I converted into tfrecords using the code given at the link below:

https://github.com/EleutherAI/gpt-neo/blob/master/data/create_tfrecords.py

I get only 2-3 tfrecord files.

Now, when I execute device_train.py on these tfrecord files, it runs for 15-20 minutes and then shows the following error.

Traceback (most recent call last):
  File "device_train.py", line 268, in <module>
    train_step(network, train_dataset.get_samples())
  File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 70, in get_samples
    return self.get_samples()
  File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 70, in get_samples
    return self.get_samples()
  File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 70, in get_samples
    return self.get_samples()
  [Previous line repeated 952 more times]
  File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 66, in get_samples
    return next(self.sample_fn)
  File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 44, in sample_once
    file = tf.data.TFRecordDataset(i, compression_type=compression).map(self.parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1927, in map
    return ParallelMapDataset(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4522, in __init__
    self._map_func = StructuredFunctionWrapper(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3712, in __init__
    self._function = fn_factory()
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3134, in get_concrete_function
    graph_function = self._get_concrete_function_garbage_collected(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3100, in _get_concrete_function_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3687, in wrapped_fn
    ret = wrapper_helper(*args)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3617, in wrapper_helper
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.impl.api.StagingError: in user code:

/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py:85 tf_parse  *
    parsed_features = tf.io.parse_single_example(example_proto, features)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper  **
    return target(*args, **kwargs)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py:452 parse_single_example_v2
    return parse_example_v2(serialized, features, example_names, name)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
    return target(*args, **kwargs)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py:314 parse_example_v2
    outputs = _parse_example_raw(serialized, example_names, params, name=name)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py:350 _parse_example_raw
    outputs = gen_parsing_ops.parse_example_v2(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/gen_parsing_ops.py:748 parse_example_v2
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py:517 _apply_op_helper
    values = ops.convert_to_tensor(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/profiler/trace.py:163 wrapped
    return func(*args, **kwargs)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1566 convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py:339 _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py:264 constant
    return _constant_impl(value, dtype, shape, name, verify_shape=False,
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py:286 _constant_impl
    const_tensor = g._create_op_internal(  # pylint: disable=protected-access
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:599 _create_op_internal
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3557 _create_op_internal
    ret = Operation(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:2057 __init__
    tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output)  # pylint: disable=protected-access
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:405 _create_with_tf_output
    ret = Tensor(op, value_index, dtype)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:381 __init__
    self._dtype = dtypes.as_dtype(dtype)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/dtypes.py:633 as_dtype
    return _ANY_TO_TF[type_value]

RecursionError: maximum recursion depth exceeded while calling a Python object

Memory overflow

Hi, when using a 16GB T4 to load the model, I get a memory overflow. Is there not enough memory?
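A rough back-of-envelope, assuming roughly 6B parameters at 2 bytes each (fp16 slim weights), before any activations or framework overhead:

n_params = 6e9                 # GPT-J is ~6B parameters
print(n_params * 2 / 2**30)    # ~11.2 GiB for the weights alone, tight on a 16GB card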

CPU support

I'm interested in implementing CPU inference and would be happy to submit a PR for it, if I can figure out the xmap internals for model parallelism. I've taken a look at the resharding example, but it looks like that is for a single GPU? Would an 8x or 16x cluster, each node with 8GB of memory, be a viable platform with xmap for distributed inference?

What protocol is xmap using for distributed inference? Is it gRPC-based?

training problems

Hey,
I'm trying to launch fine-tuning of the GPT-J model, but I encounter a data loading problem when running:
python3 device_train.py --config configs/6B_roto_256.json --tune-model-path 'gs://gpt3-srulik/step_383500'
I get the following error:
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Host TPU Interpreter
jax devices: 8
jax runtime initialized in 3.14744s
--tune_model_path passed: we are beginning a fine-tuning run
path to load checkpoint from: gs://gpt3-srulik/step_383500
setting up datasets
Traceback (most recent call last):
  File "device_train.py", line 234, in <module>
    for k, v in params['val_set'].items():
AttributeError: 'str' object has no attribute 'items'
I have created a tfrecord using the standard gpt-neo code, uploaded it to my bucket, and changed the config files accordingly.
What could be the problem?
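The traceback shows params['val_set'].items() failing because val_set is a string, so device_train.py evidently expects a mapping from dataset name to index file. A hedged sketch of the relevant config fragment, with hypothetical paths:

"train_set": "my_dataset.train.index",
"val_set": {"my_dataset": "my_dataset.val.index"}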

Module 'jax' has no attribute 'process_index'

Hi,
I have this problem while running the colab code:

module 'jax' has no attribute 'process_index'

I am using the colab code without modifying anything.
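For reference: jax renamed host_id to process_index at some point, so this usually means the installed jax predates the rename. A hedged compatibility shim (an assumption about the cause, not a confirmed fix):

import jax

# fall back to the old name if this jax build predates process_index
process_index = getattr(jax, "process_index", None) or jax.host_id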


Thank you

Issue while importing trained model

I am using colab_demo to generate content on my v3-8 machine. Earlier, I had trained a custom model with a custom dataset.
I had converted text files to tfrecords and then created a model using device_train.py.

Now, when I execute the code in colab_demo pointing to my trained model repo in the GS bucket, I get the following error:

Traceback (most recent call last):
    network.state = read_ckpt(network.state, "gs://agencyq-gptj6b/trained-model/step_5/", devices.shape[1])
  File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/checkpoint.py", line 147, in read_ckpt
    assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (8,) vs (8, 4096)

Finetuning on colab

Hey there,
I'm trying to piece together a git repo with some one-liners to set up mesh-transformer-jax. Currently it's pretty much a skeleton, but I'm at the stage where I want to provide some functionality to fine-tune and test some few-shot training examples.

I'm finding the source code for the training files difficult to go through; it relies on a lot of built-in methods that I'm unfamiliar with, with little documentation. Can you provide any insight on what you think is the simplest way of going about fine-tuning?

Specifically, some of my questions are:

  1. build_model's .save and .train methods come from a TPUCluster object, which in turn pushes its train down to each individual NetworkRunner object, which pulls its train from a ray.remote decorator? I'm not familiar with ray; can we skip this and pull our train from CausalTransformer.train itself? What are the drawbacks?
  2. Do you have tips on getting started with CausalTransformer's train_xmap? For ctxt and tgt, is it just supplying fill-in-the-blank examples?
  3. What format should the data be in to load into the original script when running GCP TPU nodes?
  4. What is the easiest way to format txt data into that?

If you want to take a look at the project, check out GPT-J-Simple.

Log generated samples during training

Does anyone have any pointers on how to log generated examples during training? I'm planning on logging them to W&B, so I can check in on what the generation quality is like.

I tried naively just grabbing some of the code from device_sample.py and calling it once validation is finished. It generates some sample text, but I get the following error when training resumes:

RuntimeError: Resource exhausted: Attempting to reserve 4.44G at the bottom of memory. That was not possible. There are 9.61G free, 0B reserved, and 2.65G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well)

I guess it's because I'm calling network.move_xmap in order to generate, and need to undo that before resuming training? Any pointers on what I need to do?

Full stack trace: https://gist.github.com/morganmcg1/0e4344df49fe3b43243505992ce998d5

Code used for generation:

import time

import jax
import numpy as np

# tokenizer, seq, per_replica_batch, cores_per_replica, mesh_shape, and
# network all come from the notebook's setup cells
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

context = "EleutherAI is"  # input("Type input:")
tokens = tokenizer.encode(context)

provided_ctx = len(tokens)
pad_amount = seq - provided_ctx

padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
batched_tokens = np.array([padded_tokens] * total_batch)

length = np.ones(total_batch, dtype=np.uint32) * len(tokens)
top_p = 0.9
temp = 0.75
gen_len = 512

start = time.time()

### taken from device_sample
local_shards = max(jax.local_device_count() // mesh_shape[1], 1)
# del network.state["opt_state"]
network.state = network.move_xmap(network.state, np.zeros(local_shards))
###

output = network.generate(batched_tokens, length, gen_len,
                          {"top_p": np.ones(total_batch) * top_p,
                           "temp": np.ones(total_batch) * temp})

Web Demo Source Code

Could the source code for the web demo be pushed into this or another repo? Thanks in advance.

resharding_example.py errors with jax==0.2.13

Running the resharding example gave the following error. Installing jax==0.2.12 fixed the issue.

TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard

Environment:
python 3.8.5
ubuntu 20.04
a single rtx3090

all python lib versions exactly as in requirements.txt except
flask==1.1.4
ftfy==6.0.3
func-timeout==4.3.5
smart-open==5.1.0

Colab notebook demo IndexError

This doesn't run correctly for me anymore. A day or two ago it was working fine.

https://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb

total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
FilteredStackTrace: IndexError: Too many indices for array: 4 non-None/Ellipsis indices for dim 3.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

IndexError                                Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _canonicalize_tuple_index(arr_ndim, idx)
   4860   if len_without_none > arr_ndim:
   4861     msg = "Too many indices for array: {} non-None/Ellipsis indices for dim {}."
-> 4862     raise IndexError(msg.format(len_without_none, arr_ndim))
   4863   ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
   4864   ellipsis_index = next(ellipses, None)

IndexError: Too many indices for array: 4 non-None/Ellipsis indices for dim 3.

Illegal Instruction Core dumped (resharding_example.py)

Whenever I try to run resharding_example.py I keep getting this error:
Illegal instruction (core dumped)

I made sure to untar the weights and put them in the root folder of the project, and I installed the needed dependencies.
Is there still something that I am missing?

Return probabilities from CausalTransformer.generate()

The GPT-3 API allows you to fetch probabilities when you generate text. This is helpful for debugging or when using the probabilities to do few-shot classification. It would be great if this feature was also available for GPT-J.

For each generated token there would also be a full list of probabilities for each item in the vocabulary. Then you could look up the probability at any step given a token.

[screenshot: GPT-3 API showing per-token probabilities]
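A hypothetical sketch of what the lookup could look like if generate() also returned per-step logits of shape (batch, steps, vocab); none of these names are part of the current API:

import jax
import numpy as np

def token_log_probs(logits, token_ids):
    # normalize each step's logits, then gather the chosen token's log-prob
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return np.take_along_axis(np.asarray(log_p), token_ids[..., None], axis=-1)[..., 0]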

Advice on running the model with full weights?

Hello, I've had good experiences running the slimmed-down model. Its functionality is awesome and you guys have made it really easy to implement. I'm very curious to also run the model with the full weights, which you've kindly made available for download.

I've tried to do this in a paid-tier colab notebook, but even with the extra RAM and storage, the notebook crashes with an out-of-RAM error once I've untarred the checkpoints and run

network.state = read_ckpt(network.state, "/content/step_383500/", devices.shape[1])

I appreciate this is a problem with colab that has nothing to do with you. Still, I wonder: do you have any advice on how this model might be run without reserving a TPU instance and refactoring all the code that you've so usefully provided?

Thanks!

ImportError: cannot import name 'OptState' from 'optax._src.transform' (/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py)

running

from mesh_transformer import util

gives

---------------------------------------------------------------------------

ImportError                               Traceback (most recent call last)

<ipython-input-5-8f674c418b34> in <module>()
----> 1 from mesh_transformer import util

/usr/local/lib/python3.7/dist-packages/mesh_transformer/util.py in <module>()
      2 import jax.numpy as jnp
      3 from jax.experimental.pjit import with_sharding_constraint
----> 4 from optax._src.transform import OptState, GradientTransformation, AdditiveWeightDecayState
      5 
      6 

ImportError: cannot import name 'OptState' from 'optax._src.transform' (/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py)

How to configure greedy sampling?

Great work folks! From the blog post https://arankomatsuzaki.wordpress.com/2021/06/04/gpt-j/

Completion on a question from BoolQ (SuperGLUE). While both sampling methods result in the same correct conclusion, the nucleus sampling hallucinates and contains incorrect reasoning, while the greedy sampling answers concisely and reasonably. In general, we observed that greedy sampling is more accurate and contains less hallucinations than nucleus sampling when the output is supposed to be short like this, which is predictable given that classification task is usually done with greedy sampling.

I am using the code from the colab you have published. How do I configure greedy sampling? What should I set top_p to? Or is there some other way?
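A hedged workaround using the existing sampler kwargs from the colab's infer() (the limit values here are assumptions, not a documented greedy mode): with top_p near 0 the nucleus keeps only the single most probable token, and a small positive temp avoids the divide-by-zero of temp=0.

import numpy as np

output = network.generate(batched_tokens, length, gen_len,
                          {"top_p": np.ones(total_batch) * 0.01,
                           "temp": np.ones(total_batch) * 0.01})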

Batch_inference with `resharding_example.py`

Hi!

Thanks for this amazing model! Could you please help with providing batch inference for the infer function in resharding_example.py?

I've tried the following:

infer(["EleutherAI is", "sunny day"]), which gives a single string instead of two strings.

I then tried to modify the function for batches:

import time

import numpy as np

# tokenizer, network, and per_replica_batch come from the script's setup
def infer(context, top_k=40, top_p=0.9, temp=1, gen_len=20):
  start = time.time()
  tokens = tokenizer.batch_encode_plus(context, padding=True).input_ids
  batched_tokens = np.array(tokens).astype(np.uint32)
  length = np.ones(batched_tokens.shape)
  output = network.generate(batched_tokens, length, gen_len,
                            {"top_p": np.ones(per_replica_batch) * top_p,
                             "top_k": top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None,
                             "temp": np.ones(per_replica_batch) * temp})

  samples = []
  decoded_tokens = output[1][0]

  for o in decoded_tokens[:, :, 0]:
    samples.append(tokenizer.decode(o))

  print(f"completion done in {time.time() - start:06}s")
  return samples

But I got this error in jax/experimental/maps.py, in _get_axis_sizes(args_flat, in_axes_flat, global_axis_sizes, axis_resource_count):

The size of axis batch was previously inferred to be 2, but found an argument of shape (1,) with in_axes specification ['batch', ...]. Shape mismatch occurs in dimension 0: 1 != 2.

The output of batched_tokens.shape is (2, 4) and length.shape is (2, 4), so I am not sure what the shape mismatch is referring to. It occurs when the length is shape (1, 4) as well.
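One possibly relevant difference from the colab demo quoted elsewhere in these issues: there, length is built 1-D with one entry per batch element, rather than matching the 2-D shape of batched_tokens. A hedged observation, not a confirmed fix:

# how the colab builds it: shape (total_batch,), not (batch, seq)
length = np.ones(total_batch, dtype=np.uint32) * len(tokens)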



I am using the params from the resharding_example.py:

params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,
  "early_cast": True,
  "seq": 2048,
  "cores_per_replica": 1,  # only running on one GPU
  "per_replica_batch": 1,
}

and I followed the installation from here: https://github.com/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb (jax==0.2.12)

When trying to run using any production WSGI server, it gives '_thread._local' object has no attribute 'env'

I tried deploying it using:

  1. Flask + a systemd daemon service
  2. Building the docker example provided

ERROR:uvicorn.error:Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/uvicorn/protocols/http/h11_impl.py", line 369, in run_asgi
    result = await app(self.scope, self.receive, self.send)
  File "/usr/local/lib/python3.6/dist-packages/uvicorn/middleware/proxy_headers.py", line 59, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.6/dist-packages/fastapi/applications.py", line 201, in __call__
    await super().__call__(scope, receive, send)  # pragma: no cover
  File "/usr/local/lib/python3.6/dist-packages/starlette/applications.py", line 112, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.6/dist-packages/starlette/middleware/errors.py", line 181, in __call__
    raise exc from None
  File "/usr/local/lib/python3.6/dist-packages/starlette/middleware/errors.py", line 159, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.6/dist-packages/starlette/middleware/cors.py", line 78, in __call__
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.6/dist-packages/starlette/exceptions.py", line 82, in __call__
    raise exc from None
  File "/usr/local/lib/python3.6/dist-packages/starlette/exceptions.py", line 71, in __call__
    await self.app(scope, receive, sender)
  File "/usr/local/lib/python3.6/dist-packages/starlette/routing.py", line 580, in __call__
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.6/dist-packages/starlette/routing.py", line 241, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.6/dist-packages/starlette/routing.py", line 52, in app
    response = await func(request)
  File "/usr/local/lib/python3.6/dist-packages/fastapi/routing.py", line 217, in app
    dependant=dependant, values=values, is_coroutine=is_coroutine
  File "/usr/local/lib/python3.6/dist-packages/fastapi/routing.py", line 151, in run_endpoint_function
    return await run_in_threadpool(dependant.call, **values)
  File "/usr/local/lib/python3.6/dist-packages/starlette/concurrency.py", line 40, in run_in_threadpool
    return await loop.run_in_executor(None, func, *args)
  File "/usr/lib/python3.6/concurrent/futures/thread.py", line 56, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.6/dist-packages/contextvars/__init__.py", line 38, in run
    return callable(*args, **kwargs)
  File "./main.py", line 49, in model_prediction
    re = MODEL_API.infer("what is AI?", length=30)
  File "./ops.py", line 115, in infer
    "temp": np.ones(self.total_batch) * temp,
  File "/app/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 314, in generate
    sampler_options)
  File "/usr/local/lib/python3.6/dist-packages/jax/experimental/maps.py", line 568, in fun_mapped
    resource_env = thread_resources.env
AttributeError: '_thread._local' object has no attribute 'env'
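A hedged explanation: jax's thread_resources is thread-local, so a mesh registered on the main thread is invisible to the worker thread a WSGI/ASGI server uses to run the endpoint. A sketch of re-registering it on the serving thread, mirroring the setup line from the colab demo (an assumption about the fix, not a tested patch):

import numpy as np
import jax
from jax.experimental import maps

# run this on the thread that actually calls infer(), with the same
# mesh_shape used at startup
devices = np.array(jax.devices()).reshape(mesh_shape)
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))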

The Colab notebook fails to run

Line 333 of transformer_shard.py has been throwing this error since the last commit:
"TypeError: fun_mapped() got an unexpected keyword argument 'return_logits'"

out_axes specification issue when running the script with sudo; works fine without sudo

I get an out_axes specification issue when I run the script with sudo, but it works fine without sudo. I want to run the script through a systemd service, which is where it gives this error.
Even if I run it as sudo python3 device_serve.py it gives the same error; it works fine with python3 device_serve.py.

Here's the stack trace:

key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
read from disk/gcs in 6.40554s
Traceback (most recent call last):
  File "simple.py", line 149, in <module>
    output = network.generate(batched_tokens, length, gen_length, {"top_p": np.ones(total_batch) * 0.9,
  File "/home/ahmedjawed/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
    return self.generate_xmap(self.state,
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 615, in fun_mapped
    out_flat = xmap_p.bind(
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 818, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 821, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 646, in xmap_impl
    xmap_callable = make_xmap_callable(
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
    _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
    raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard

Usage on GPUs

I think I've got everything installed, but it is failing. It tries and fails to use the TPU rather than the GPU when I run the resharded script.

device_train.py Error: "AssertionError: Incompatible checkpoints"

Hey 👋 first of all, big thanks for the model and the repo! You've done a brilliant job here, and the model generates some AMAZING results for its size.

I'm having an issue when trying to finetune the 6B model on Google Colab. I roughly followed the modifications device_train.py -h suggested, but I'm getting this error: AssertionError: Incompatible checkpoints (1, 4096) vs (1,)

Reproduction

I ran the command with --tune-model-path pointing to the downloaded model and the modified configs/6B_roto_256.json file with the following changes: the train & val sets point to my dataset, I changed the number of steps, changed tpu_size from 256 to 8, and changed cores_per_replica from 8 to 1. I suspect the cores_per_replica change altered the mesh_shape, which in turn affects the shards_in parameter in read_ckpt.

When I printed the in_shards that gets passed to the read_ckpt function, it said 1, although I see in the model files that it has 8 shard_X folders.
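A hedged guess based on the above: the published checkpoint is split into 8 shard_X folders, so cores_per_replica likely has to stay at 8 for shards_in to match (or the checkpoint has to be resharded first, as in resharding_example.py). A sketch of the config fragment, the values being the assumption:

"tpu_size": 8,
"cores_per_replica": 8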

Thanks

Error in colab_demo

I had converted the model into _slim format, not _slim_f16.

Now, when I execute the colab code, I get the error below:

loading network from the Google storage
read from disk/gcs in 107.5s
Traceback (most recent call last):
  File "content_generation.py", line 90, in <module>
    print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])
  File "content_generation.py", line 77, in infer
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})
  File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
    return self.generate_xmap(self.state,
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 615, in fun_mapped
    out_flat = xmap_p.bind(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 818, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 821, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 646, in xmap_impl
    xmap_callable = make_xmap_callable(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
    _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
    raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard

Serverless deployment example

Found you on ProductHunt today! 🎉 I'm very impressed by your work and enjoyed reading the blog posts.

I think there's a great opportunity for a one-command-deploy API via serverless (serverless deploy). This option would make your work very accessible to hobbyists (easy deployment, low cost, low maintenance, scalable).

AWS recently released support for docker lambda functions, which the serverless framework now supports too.

Luckily, the serverless limit with docker is 10GB, so the 8.8GB slim weights for inference should fit snugly.

The serverless.yml file could look like:

service: gpt-j
frameworkVersion: '2'
variablesResolutionMode: 20210326

provider:
  name: aws
  lambdaHashingVersion: 20201221
  runtime: python3.8
  memorySize: 10240
  ecr:
    images:
      gpt-j:
        path: ./
        file: ./Dockerfile
functions:
  gpt-j:
    image:
      name: gpt-j
    events:
      - httpApi:
          method: post
          path: /predict

The dockerfile would look similar to your existing Dockerfile, but use the AWS runtime:

FROM public.ecr.aws/lambda/python:3.8

# install gpt-j
# app.py handles inference

COPY app.py   ./
CMD ["app.handler"]      

There would only be a single API call, since a Lambda services only a single request at a time.

I wish I could be of more help, but my specialty at the moment is JavaScript, not Python. I will gladly test any demos though 😄 Already having a Dockerfile example means half the work is probably done.

Error when running device_sample.py

I have set up a TPU VM and when I try to run device_sample.py I get the following error when calling network.generate:

TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard

Currently, I am only doing inference so I am using the slim weights. Therefore, I had to change the line of code that sets up the optimizer to this:

params['optimizer'] = optax.scale(0)

This is exactly what the colab notebook does, which I can run without any errors. Do I need to use the full weights to run device_sample.py? Thank you, any help is appreciated.

web app for the model

Hi,

I am trying to develop a web app for the model, as it is very good at writing. I used Gradio, as I am a data scientist, but it gives an error about the environment ('env', to be precise). Please share how you developed that app, or guide me; I am looking forward to giving a demo soon. Please help.

JAX or Ray Library issue

I am using this repo to train on a small custom dataset, with jax 0.2.16. However, requirements.txt says jax 0.2.12. I don't really know how ray and jax work internally, but my assumption is this code should work fine on a v3-8 TPU. When I execute train.py, it generates the following error:

(pid=9454, ip=10.164.0.9) jax runtime initialization starting
2021-07-09 12:11:59,794 ERROR worker.py:78 -- Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): ray::NetworkRunner.run() (pid=9454, ip=10.164.0.9)
  File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/train_actor.py", line 24, in run
TypeError: __new__() missing 1 required positional argument: 'loops'

I am not sure what this error is. Can someone help me debug it? Is this a library versioning issue or something else?

Multiple GPU Error

Hello @kingoflolz,
I am using 4 T4 GPUs and am facing the error below:
Traceback (most recent call last):
  File "gpt_j_inference.py", line 59, in <module>
    network.state = read_ckpt(network.state, "step_383500/", 8, shards_out=cores_per_replica)
  File "/home/pragnakalp/Desktop/GPT-J/env_gptj/lib/python3.8/site-packages/mesh_transformer/checkpoint.py", line 145, in read_ckpt
    assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (1, 4096) vs (4, 4096)
I have passed cores_per_replica as 4.
How can I run the code using 4 GPUs, as a single GPU results in a memory error?
