kingoflolz / mesh-transformer-jax
Model parallel transformers in JAX and Haiku
License: Apache License 2.0
Hi, would it be possible to run GPT-J on GPU in the future?
Whenever I try to run device_sample or device_train I get the following error here:
Traceback (most recent call last):
File "device_sample.py", line 30, in <module>
params = json.load(open(args.config))
File "/home/nin/.local/lib/python3.8/site-packages/smart_open/smart_open_lib.py", line 235, in open
binary = _open_binary_stream(uri, binary_mode, transport_params)
File "/home/nin/.local/lib/python3.8/site-packages/smart_open/smart_open_lib.py", line 394, in _open_binary_stream
raise TypeError("don't know how to handle uri %s" % repr(uri))
TypeError: don't know how to handle uri None
I have all the required libraries. Is there a config file that I am missing somewhere?
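The traceback shows open() receiving None, which suggests args.config was never set, i.e. the script was started without a --config argument. A minimal sketch of a guard, assuming the script parses a --config flag with argparse (the helper name load_config is hypothetical and not part of the repo):

```python
import argparse
import json


def load_config(argv):
    # Hypothetical helper: parse --config and fail early with a clear message
    # instead of passing None on to open().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None, help="path to a JSON config file")
    args = parser.parse_args(argv)
    if args.config is None:
        raise SystemExit("missing --config, e.g. --config configs/6B_roto_256.json")
    with open(args.config) as f:
        return json.load(f)
```

Called as load_config(sys.argv[1:]), this would stop with a readable message when the flag is left out.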
Hi @kingoflolz, amazing work!!
I am trying to test the model on TPU VM using the step_383500/ data.
After following the steps in jax-quickstart-tpu-vm, running jax.devices()
returns the correct output:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
After that, I cloned your repo and ran pip install -r requirements.txt
and then pip install jax==0.2.12
as mentioned in your fine-tuning docs, but it gives this error:
2021-07-18 05:15:21.807145: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
Aborted (core dumped)
So I ran this from the jax-quickstart-tpu-vm docs:
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
after which jax.devices()
gives this output
>>> import jax
>>> jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
Somehow jax fails to detect the TPU, and I don't know what the issue is. I am running Google Cloud's TPU VM v3-8 with the v2-alpha software version. I would really appreciate your help.
I know that I can't run this locally without a TPU. However, I'm looking at implementing this on a portable device (a high-end laptop). Would using an Intel Neural Compute Stick work?
If not, are there alternatives, or do I have to use Google Colab's TPUs?
How can I predict the summarized result of a given input text?
Hello once again,
I have trained the model on my dataset, thanks to you. But when I now run device_sample.py with the same config file I used for training, I get the following error:
Traceback (most recent call last):
File "device_sample.py", line 99, in <module>
output = network.generate(batched_tokens, length, 512, {"top_p": np.ones(total_batch) * 0.9,
File "/home/adnanmunye/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 328, in generate
return self.generate_xmap(self.state,
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 613, in fun_mapped
out_flat = xmap_p.bind(
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 820, in bind
return core.call_bind(self, fun, *args, **params) # type: ignore
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/core.py", line 1552, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 823, in process
return trace.process_xmap(self, fun, tracers, params)
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/core.py", line 607, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 644, in xmap_impl
xmap_callable = make_xmap_callable(
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 671, in make_xmap_callable
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
File "/home/adnanmunye/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 1633, in _check_out_avals_vs_out_axes
raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard
Could you please help again?
Hi,
I am facing this issue while running the "mesh-transformer-jax/resharding_example.py"
RuntimeError: Resource exhausted: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0
2021-07-21 16:15:16.513776: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1981] Execution of replica 0 failed: Resource exhausted: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0
Traceback (most recent call last):
File "resharding_example.py", line 47, in <module>
network = CausalTransformer(params)
File "/drive4/user1/GPT_J/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 272, in __init__
self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 516, in fun_mapped
out_flat = xmap_p.bind(
File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 652, in bind
return core.call_bind(self, fun, *args, **params) # type: ignore
File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 655, in process
return trace.process_xmap(self, fun, tracers, params)
File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
return primitive.impl(f, *tracers, **params)
File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 539, in xmap_impl
return make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
File "/drive4/user1/GPT_J/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1130, in execute_replicated
out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Resource exhausted: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0
System config mentioned below:
GPU: 11177MiB
RAM: 62G
export XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform
Python 3.8.7
Hi,
I am using an A100 GPU, and it's taking 10 seconds to generate 150 tokens. How can we reduce the inference time?
ValueError Traceback (most recent call last)
in
24
25 mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
---> 26 devices = np.array(jax.devices()).reshape(mesh_shape)
27
28 maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
ValueError: cannot reshape array of size 1 into shape (0,8)
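The (0,8) in the error comes from integer division: with a single visible device and cores_per_replica = 8, jax.device_count() // cores_per_replica is 0. A small sketch of the arithmetic in plain Python, with the device count passed in explicitly; mesh_shape and check_mesh are hypothetical helpers standing in for the notebook line:

```python
def mesh_shape(device_count, cores_per_replica):
    # Same arithmetic as the notebook line above: (data-parallel, model-parallel).
    return (device_count // cores_per_replica, cores_per_replica)


def check_mesh(device_count, cores_per_replica):
    # Hypothetical guard: the mesh must use every device exactly once.
    dp, mp = mesh_shape(device_count, cores_per_replica)
    if dp * mp != device_count:
        raise ValueError(
            f"{device_count} device(s) cannot fill a mesh of shape {(dp, mp)}; "
            f"cores_per_replica must divide the device count"
        )
    return dp, mp


print(mesh_shape(1, 8))  # one visible device, eight cores requested
print(check_mesh(8, 8))  # eight devices, eight cores per replica
```

A size-1 device array usually means the runtime sees only one device (for example a CPU or single-GPU fallback), so either the accelerator is not attached or cores_per_replica needs to match the devices that are actually present.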
In the online demo, a prompt of a few words with a temperature of zero tends to give repeated punctuation symbols as a continuation. Nudge it up to temp = 0.01 and it behaves as GPT-2 and GPT-3 do at zero temperature: usually repetitive, but on a full sentence level. I think the program may not handle inputs of 0 temperature correctly.
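One plausible cause, stated here as an assumption rather than a diagnosis of the demo code: temperature sampling divides the logits by the temperature, so temp = 0 is a division by zero unless it is special-cased. A minimal pure-Python sketch of a sampler that treats zero temperature as greedy decoding (sample_token is a hypothetical helper, not a function from this repo):

```python
import math
import random


def sample_token(logits, temp, rng=random):
    # Hypothetical helper: pick a token index from raw logits.
    if temp <= 0.0:
        # Zero temperature is the limit of an ever sharper distribution: argmax.
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temp for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

With this guard, temp = 0 and a very small temp such as 0.01 both end up choosing the most likely token almost every time, which matches the behaviour the report expects.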
Hi, this is a general question, feel free to close this issue if it doesn't apply.
How do I know if I should use T5-11b or GPT-J for a certain task? In this case, it is for a conversational chatbot.
I know that T5 uses an encoder and decoder and GPT models only a decoder.
Did anyone else encounter this error while trying the demo.ipynb on colab?
ImportError Traceback (most recent call last)
<ipython-input-9-72cd76e3a907> in <module>()
4 from jax.experimental import maps
5 import numpy as np
----> 6 import optax
7 import transformers
8
6 frames
/usr/local/lib/python3.7/dist-packages/optax/__init__.py in <module>()
16 """Optax: composable gradient processing and optimization, in JAX."""
17
---> 18 from optax._src.alias import adabelief
19 from optax._src.alias import adagrad
20 from optax._src.alias import adam
/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py in <module>()
20 import jax.numpy as jnp
21
---> 22 from optax._src import combine
23 from optax._src import privacy
24 from optax._src import schedule
/usr/local/lib/python3.7/dist-packages/optax/_src/combine.py in <module>()
16 """Flexibly compose gradient transformations."""
17
---> 18 from optax._src import transform
19 GradientTransformation = transform.GradientTransformation
20
/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py in <module>()
18 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
19
---> 20 import chex
21 import jax
22 import jax.numpy as jnp
/usr/local/lib/python3.7/dist-packages/chex/__init__.py in <module>()
15 """Chex: Testing made fun, in JAX!"""
16
---> 17 from chex._src.asserts import assert_axis_dimension
18 from chex._src.asserts import assert_axis_dimension_gt
19 from chex._src.asserts import assert_devices_available
/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py in <module>()
29 import jax
30 import jax.numpy as jnp
---> 31 import jax.test_util as jax_test
32 import numpy as np
33 import tree as dm_tree
/usr/local/lib/python3.7/dist-packages/jax/test_util.py in <module>()
33 from . import dtypes as _dtypes
34 from . import lax
---> 35 from .config import flags, bool_env, config
36 from ._src.util import partial, prod
37 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
ImportError: cannot import name 'flags' from 'jax.config' (/usr/local/lib/python3.7/dist-packages/jax/config.py)
I am running device_train.py to train a custom model on a smaller dataset. I have only 4-5 txt files, which I converted into tfrecords using the code at the link below:
https://github.com/EleutherAI/gpt-neo/blob/master/data/create_tfrecords.py
I get only 2-3 tfrecords files.
Now, when I execute device_train.py on these tfrecords files, it runs for 15-20 minutes and then shows the following error:
Traceback (most recent call last):
File "device_train.py", line 268, in <module>
train_step(network, train_dataset.get_samples())
File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 70, in get_samples
return self.get_samples()
File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 70, in get_samples
return self.get_samples()
File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 70, in get_samples
return self.get_samples()
[Previous line repeated 952 more times]
File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 66, in get_samples
return next(self.sample_fn)
File "/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py", line 44, in sample_once
file = tf.data.TFRecordDataset(i, compression_type=compression).map(self.parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1927, in map
return ParallelMapDataset(
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4522, in __init__
self._map_func = StructuredFunctionWrapper(
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3712, in __init__
self._function = fn_factory()
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3134, in get_concrete_function
graph_function = self._get_concrete_function_garbage_collected(
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3100, in _get_concrete_function_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3687, in wrapped_fn
ret = wrapper_helper(*args)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3617, in wrapper_helper
ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.impl.api.StagingError: in user code:
/home/paramjeetsingh80/mesh-transformer-jax/tfrecord_loader.py:85 tf_parse *
parsed_features = tf.io.parse_single_example(example_proto, features)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper **
return target(*args, **kwargs)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py:452 parse_single_example_v2
return parse_example_v2(serialized, features, example_names, name)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py:314 parse_example_v2
outputs = _parse_example_raw(serialized, example_names, params, name=name)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py:350 _parse_example_raw
outputs = gen_parsing_ops.parse_example_v2(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/ops/gen_parsing_ops.py:748 parse_example_v2
_, _, _op, _outputs = _op_def_library._apply_op_helper(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py:517 _apply_op_helper
values = ops.convert_to_tensor(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/profiler/trace.py:163 wrapped
return func(*args, **kwargs)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1566 convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py:339 _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py:264 constant
return _constant_impl(value, dtype, shape, name, verify_shape=False,
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py:286 _constant_impl
const_tensor = g._create_op_internal( # pylint: disable=protected-access
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:599 _create_op_internal
return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3557 _create_op_internal
ret = Operation(
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:2057 __init__
tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output) # pylint: disable=protected-access
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:405 _create_with_tf_output
ret = Tensor(op, value_index, dtype)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:381 __init__
self._dtype = dtypes.as_dtype(dtype)
/home/paramjeetsingh80/.local/lib/python3.8/site-packages/tensorflow/python/framework/dtypes.py:633 as_dtype
return _ANY_TO_TF[type_value]
RecursionError: maximum recursion depth exceeded while calling a Python object
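The repeated get_samples frames show the loader calling itself again each time its iterator is exhausted. If a freshly reset iterator is exhausted again immediately (for example because the files hold too few records to produce a single batch, which is an assumption here), the recursion never unwinds and eventually hits Python's recursion limit. A sketch of the same restart logic written as a bounded loop; the class below is a hypothetical stand-in, not the repo's tfrecord_loader.py:

```python
class LoopingLoader:
    # Hypothetical stand-in for a loader that restarts when its data runs out.

    def __init__(self, make_iterator, max_restarts=2):
        self.make_iterator = make_iterator  # callable returning a fresh iterator
        self.max_restarts = max_restarts
        self.sample_fn = make_iterator()

    def get_samples(self):
        # Restart with a loop instead of a recursive call, and stop with a clear
        # error if a fresh iterator yields nothing at all.
        for _ in range(self.max_restarts + 1):
            try:
                return next(self.sample_fn)
            except StopIteration:
                self.sample_fn = self.make_iterator()
        raise RuntimeError("dataset produced no samples after restarting")
```

A loader built this way keeps cycling over a small dataset, and an empty one fails with a short message instead of a thousand-frame traceback.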
Hi, loading the model on a 16 GB T4 causes a memory overflow. Is there not enough memory?
I'm interested in implementing CPU inference and would be happy to submit a PR for it, if I can figure out xmesh internals for model parallelism. I've taken a look at the resharding example but it looks like that is for a single GPU? Would an 8X or 16X cluster each with 8GB of memory be a viable platform with xmap for distributed inference?
What protocol is xmap using for distributed inference, is it gRPC based?
Hey.
I'm trying to launch fine-tuning of the GPT-J model, but I'm encountering a data loading problem when running:
python3 device_train.py --config configs/6B_roto_256.json --tune-model-path 'gs://gpt3-srulik/step_383500'
I get the following error:
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Host TPU Interpreter
jax devices: 8
jax runtime initialized in 3.14744s
--tune_model_path passed: we are beginning a fine-tuning run
path to load checkpoint from: gs://gpt3-srulik/step_383500
setting up datasets
Traceback (most recent call last):
File "device_train.py", line 234, in <module>
for k, v in params['val_set'].items():
AttributeError: 'str' object has no attribute 'items'
I have created a tfrecord using the standard gpt-neo code, uploaded it to my bucket, and changed the config files accordingly.
What could be the problem?
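The failing line iterates params['val_set'].items(), so device_train.py expects val_set to be a JSON object (a mapping), not a plain string. A sketch of the difference using only the json module; the key name and file path below are placeholders, not values from the repo:

```python
import json

# val_set given as a plain string: str has no .items() method.
bad = json.loads('{"val_set": "my_data.val.index"}')

# val_set given as an object mapping a name to a path (both are placeholders).
good = json.loads('{"val_set": {"my_val": "my_data.val.index"}}')

for label, params in (("string", bad), ("object", good)):
    try:
        pairs = list(params["val_set"].items())
        print(label, "->", pairs)
    except AttributeError as err:
        print(label, "->", err)
```

The string form reproduces the AttributeError from the log, while the object form iterates cleanly.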
I am using colab_demo to generate content on my v3-8 machine. Earlier, I trained my custom model on a custom dataset.
I converted text files to tfrecords and then created a model using device_train.py.
Now, when I execute the code in colab_demo, which points to my trained model in a GS bucket, I get the following error:
Traceback (most recent call last):
network.state = read_ckpt(network.state, "gs://agencyq-gptj6b/trained-model/step_5/", devices.shape[1])
File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/checkpoint.py", line 147, in read_ckpt
assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (8,) vs (8, 4096)
Hey there,
I'm trying to piece together a git with some one-liners to set up mesh-transformers-jax. Currently it's pretty well a skeleton, but I'm at the stage where I want to provide some functionality to fine-tune and test some few-shot training examples.
I'm finding the source code for the training files difficult to go through. It relies on a lot of built-in methods that I'm unfamiliar with, with little documentation. Can you provide any insight into what you think is the simplest way of going about fine-tuning?
Specifically some of my questions are:
If you want to take a look at the project check out GPT-J-Simple
Does anyone have any pointers of how to log generated examples during training? I'm planning on logging them to W&B, so I can check in on what the generation quality is like.
I tried naively just grabbing some of the code from device_sample.py and calling it once validation is finished. It generates some sample text, but I get the following error when training resumes:
RuntimeError: Resource exhausted: Attempting to reserve 4.44G at the bottom of memory. That was not possible. There are 9.61G free, 0B reserved, and 2.65G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well)
I guess it's because I'm calling network.move_xmap in order to generate, and need to undo that before resuming training? Any pointers on what I need to do?
Full stack trace: https://gist.github.com/morganmcg1/0e4344df49fe3b43243505992ce998d5
Code used for generation:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica
context = "EleutherAI is" #input("Type input:")
tokens = tokenizer.encode(context)
provided_ctx = len(tokens)
pad_amount = seq - provided_ctx
padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
batched_tokens = np.array([padded_tokens] * total_batch)
length = np.ones(total_batch, dtype=np.uint32) * len(tokens)
top_p=0.9
temp=0.75
gen_len=512
start = time.time()
### taken from device_sample
local_shards = max(jax.local_device_count() // mesh_shape[1], 1)
# del network.state["opt_state"]
network.state = network.move_xmap(network.state, np.zeros(local_shards))
###
output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p,
"temp": np.ones(total_batch) * temp})
Could the source code for the web demo be pushed into this or other repo? Thanks in advance.
Running resharding example gave the following error. Installing jax==0.2.12 fixed the issue.
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard
Environment:
python 3.8.5
ubuntu 20.04
a single rtx3090
all python lib versions exactly as in requirements.txt except
flask==1.1.4
ftfy==6.0.3
func-timeout==4.3.5
smart-open==5.1.0
This doesn't run correctly for me anymore. A day or two ago it was working fine.
total_batch = per_replica_batch * jax.device_count() // cores_per_replica
network = CausalTransformer(params)
network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
FilteredStackTrace: IndexError: Too many indices for array: 4 non-None/Ellipsis indices for dim 3.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
IndexError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _canonicalize_tuple_index(arr_ndim, idx)
4860 if len_without_none > arr_ndim:
4861 msg = "Too many indices for array: {} non-None/Ellipsis indices for dim {}."
-> 4862 raise IndexError(msg.format(len_without_none, arr_ndim))
4863 ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
4864 ellipsis_index = next(ellipses, None)
IndexError: Too many indices for array: 4 non-None/Ellipsis indices for dim 3.
I am trying to fine-tune it on my custom txt data, but I don't know how to convert a txt file into tfrecords. Could you guide me, please? I am a learner at this point, and your help would be very useful for me. Thanks!
Whenever I try to run resharding_example.py I keep getting this error:
Illegal instruction (core dumped)
I made sure to tar the weights and throw them in the root folder of the project and I installed the dependencies needed.
Is there still something that I am missing?
Hi, it's a simple fix: both demo links lead to the Colab.
The GPT-3 API allows you to fetch probabilities when you generate text. This is helpful for debugging or when using the probabilities to do few-shot classification. It would be great if this feature was also available for GPT-J.
For each generated token there would also be a full list of probabilities for each item in the vocabulary. Then you could look up the probability at any step given a token.
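As an illustration of what such an output could contain, the per-token probabilities are a softmax over the logits at each generation step. A pure-Python sketch, assuming the logits for one step are available as a list of floats (log_softmax is a hypothetical helper here, and this is not an existing feature of the repo):

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


step_logits = [2.0, 1.0, 0.1]  # made-up logits for a 3-token vocabulary
logprobs = log_softmax(step_logits)
probs = [math.exp(lp) for lp in logprobs]
print(probs, sum(probs))
```

Looking up the probability of a given token at a given step is then an index into that step's list.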
Hello, I've had good experiences running the slimmed down model. Its functionality is awesome and you guys have made it really easy to implement. I'm very curious to also run the model with the full weights, which you've kindly made available for download.
I've tried to do this in a paid-tier Colab notebook, but even with the extra RAM and storage, the notebook crashes with an out-of-memory error when I've untarred the checkpoints and run
network.state = read_ckpt(network.state, "/content/step_383500/", devices.shape[1])
I appreciate this is a problem with colab that has nothing to do with you. Still, I wonder do you have any advice on how this model might be run without reserving a TPU instance and refactoring all the code that you guys have so usefully provided?
Thanks!
running
from mesh_transformer import util
gives
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
<ipython-input-5-8f674c418b34> in <module>()
----> 1 from mesh_transformer import util
/usr/local/lib/python3.7/dist-packages/mesh_transformer/util.py in <module>()
2 import jax.numpy as jnp
3 from jax.experimental.pjit import with_sharding_constraint
----> 4 from optax._src.transform import OptState, GradientTransformation, AdditiveWeightDecayState
5
6
ImportError: cannot import name 'OptState' from 'optax._src.transform' (/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py)
Great work folks! From the blog post https://arankomatsuzaki.wordpress.com/2021/06/04/gpt-j/
Completion on a question from BoolQ (SuperGLUE). While both sampling methods result in the same correct conclusion, the nucleus sampling hallucinates and contains incorrect reasoning, while the greedy sampling answers concisely and reasonably. In general, we observed that greedy sampling is more accurate and contains less hallucinations than nucleus sampling when the output is supposed to be short like this, which is predictable given that classification task is usually done with greedy sampling.
I am using the code from the Colab you have published. How do I configure greedy sampling? What should I set top_p to? Or is there some other way?
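One way to see how top_p relates to greedy decoding: nucleus sampling keeps the smallest set of highest-probability tokens whose cumulative probability reaches top_p, so a very small top_p leaves only the single most likely token. The sketch below is a generic pure-Python version of that filter, not the repo's sampler, and whether the Colab code accepts such a value is not verified here:

```python
def nucleus(probs, top_p):
    # Return the indices kept by nucleus (top-p) filtering, most likely first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, total = [], 0.0
    for i in order:
        kept.append(i)
        total += probs[i]
        if total >= top_p:
            break
    return kept


probs = [0.5, 0.3, 0.15, 0.05]  # made-up next-token probabilities
print(nucleus(probs, 0.9))   # several candidates remain
print(nucleus(probs, 1e-6))  # only the most likely token remains
```

When only one candidate survives the filter, sampling from it is the same as taking the argmax, which is what greedy decoding does.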
Hi!
Thanks for this amazing model! Could you please help provide batch inference for the infer function in resharding_example.py?
I've tried the following:
infer(["EleutherAI is", "sunny day"]), which gives a single string instead of two strings.
I then tried to modify the function for batches:
def infer(context, top_k=40, top_p=0.9, temp=1, gen_len=20):
    tokens = tokenizer.batch_encode_plus(context, padding=True).input_ids
    batched_tokens = np.array(tokens).astype(np.uint32)
    length = np.ones(batched_tokens.shape)
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "top_k": top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None, "temp": np.ones(per_replica_batch) * temp})
    samples = []
    decoded_tokens = output[1][0]
    for o in decoded_tokens[:, :, 0]:
        samples.append(tokenizer.decode(o))
    print(f"completion done in {time.time() - start:06}s")
    return samples
But I got this error in jax/experimental/maps.py, in _get_axis_sizes(args_flat, in_axes_flat, global_axis_sizes, axis_resource_count):
The size of axis batch was previously inferred to be 2, but found an argument of shape (1,) with in_axes specification ['batch', ...]. Shape mismatch occurs in dimension 0: 1 != 2
The output of batched_tokens.shape is (2,4) and length.shape is (2,4), so I am not sure what the shape mismatch is referring to. It occurs when the length is shape (1,4) as well.
Attached screenshot of error below:
I am using the params from resharding_example.py:
params = {
    "layers": 28,
    "d_model": 4096,
    "n_heads": 16,
    "n_vocab": 50400,
    "norm": "layernorm",
    "pe": "rotary",
    "pe_rotary_dims": 64,
    "early_cast": True,
    "seq": 2048,
    "cores_per_replica": 1, # only running on one GPU
    "per_replica_batch": 1,
}
and I followed the installation from here: https://github.com/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb
(jax==0.2.12)
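The mismatch message says one argument has a leading dimension of 1 while the batch axis is 2. In the function above, the sampler options are built with np.ones(per_replica_batch), and per_replica_batch is 1 in these params, which would explain a (1,)-shaped argument; length is also created with the full (2, 4) token shape rather than one entry per prompt. A plain-Python sketch of inputs that share one leading batch dimension, using made-up token ids (build_batch is a hypothetical helper; in the real code these lists would be NumPy arrays, and nothing here is a claim about what network.generate accepts beyond what the error message states):

```python
def build_batch(token_lists, seq, top_p=0.9, temp=1.0, pad_id=0):
    # Hypothetical helper: left-pad each prompt to `seq` tokens and build
    # per-example lengths and sampler options with the same leading dimension.
    batch = len(token_lists)
    if any(len(t) > seq for t in token_lists):
        raise ValueError("a prompt is longer than seq")
    padded = [[pad_id] * (seq - len(t)) + list(t) for t in token_lists]
    length = [len(t) for t in token_lists]  # one entry per prompt: shape (batch,)
    options = {
        "top_p": [top_p] * batch,  # shape (batch,), not (per_replica_batch,)
        "temp": [temp] * batch,
    }
    return padded, length, options


tokens = [[36, 293, 12], [82, 5]]  # made-up token ids for two prompts
padded, length, options = build_batch(tokens, seq=4)
print(len(padded), len(length), len(options["top_p"]), len(options["temp"]))
```

Every array-like input has two entries along its first axis here, so a batch axis inferred from the tokens would agree with the one inferred from the options.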
Hello @trisongz, hope you're doing well. Do you have any plans to implement Logits and RepetitionPenalty,
something similar to what is implemented here?
https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
Thank you.
Tried deploying it using
ERROR:uvicorn.error:Exception in ASGI application
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/uvicorn/protocols/http/h11_impl.py", line 369, in run_asgi
result = await app(self.scope, self.receive, self.send)
File "/usr/local/lib/python3.6/dist-packages/uvicorn/middleware/proxy_headers.py", line 59, in __call__
return await self.app(scope, receive, send)
File "/usr/local/lib/python3.6/dist-packages/fastapi/applications.py", line 201, in __call__
await super().__call__(scope, receive, send) # pragma: no cover
File "/usr/local/lib/python3.6/dist-packages/starlette/applications.py", line 112, in __call__
await self.middleware_stack(scope, receive, send)
File "/usr/local/lib/python3.6/dist-packages/starlette/middleware/errors.py", line 181, in __call__
raise exc from None
File "/usr/local/lib/python3.6/dist-packages/starlette/middleware/errors.py", line 159, in __call__
await self.app(scope, receive, _send)
File "/usr/local/lib/python3.6/dist-packages/starlette/middleware/cors.py", line 78, in __call__
await self.app(scope, receive, send)
File "/usr/local/lib/python3.6/dist-packages/starlette/exceptions.py", line 82, in __call__
raise exc from None
File "/usr/local/lib/python3.6/dist-packages/starlette/exceptions.py", line 71, in __call__
await self.app(scope, receive, sender)
File "/usr/local/lib/python3.6/dist-packages/starlette/routing.py", line 580, in __call__
await route.handle(scope, receive, send)
File "/usr/local/lib/python3.6/dist-packages/starlette/routing.py", line 241, in handle
await self.app(scope, receive, send)
File "/usr/local/lib/python3.6/dist-packages/starlette/routing.py", line 52, in app
response = await func(request)
File "/usr/local/lib/python3.6/dist-packages/fastapi/routing.py", line 217, in app
dependant=dependant, values=values, is_coroutine=is_coroutine
File "/usr/local/lib/python3.6/dist-packages/fastapi/routing.py", line 151, in run_endpoint_function
return await run_in_threadpool(dependant.call, **values)
File "/usr/local/lib/python3.6/dist-packages/starlette/concurrency.py", line 40, in run_in_threadpool
return await loop.run_in_executor(None, func, *args)
File "/usr/lib/python3.6/concurrent/futures/thread.py", line 56, in run
result = self.fn(*self.args, **self.kwargs)
File "/usr/local/lib/python3.6/dist-packages/contextvars/__init__.py", line 38, in run
return callable(*args, **kwargs)
File "./main.py", line 49, in model_prediction
re = MODEL_API.infer("what is AI?", length=30)
File "./ops.py", line 115, in infer
"temp": np.ones(self.total_batch) * temp,
File "/app/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 314, in generate
sampler_options)
File "/usr/local/lib/python3.6/dist-packages/jax/experimental/maps.py", line 568, in fun_mapped
resource_env = thread_resources.env
AttributeError: '_thread._local' object has no attribute 'env'
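The last frames show the request being handled in a worker thread (run_in_threadpool) and failing on thread_resources.env, which the error message identifies as a _thread._local object. Attributes set on a thread-local object in one thread are not visible in other threads, so a mesh environment configured in the main thread would be missing in the worker; that reading of the traceback is an assumption. The standard-library sketch below reproduces the behaviour with a plain threading.local object:

```python
import threading

resources = threading.local()
resources.env = "mesh configured in main thread"  # visible only to this thread

seen = {}


def worker():
    # The worker thread gets its own empty view of the thread-local object.
    seen["has_env"] = hasattr(resources, "env")
    resources.env = "mesh configured inside the worker"  # set it where it is used
    seen["after_set"] = resources.env


t = threading.Thread(target=worker)
t.start()
t.join()
print(seen)
```

If that is the cause, assigning maps.thread_resources.env inside the function that handles the request, rather than only at start-up, is one thing to try.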
Line 333 of transformer_shard.py has been throwing this error since the last commit:
"TypeError: fun_mapped() got an unexpected keyword argument 'return_logits'"
I get the out_axes specification issue when I run the script with sudo; it works fine without sudo. But I want to run the script through a systemd service, which gives this error.
Even if I run it as
sudo python3 device_serve.py
it gives the same error. It works fine with python3 device_serve.py
Here's the stack trace:
key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
read from disk/gcs in 6.40554s
Traceback (most recent call last):
File "simple.py", line 149, in <module>
output = network.generate(batched_tokens, length, gen_length, {"top_p": np.ones(total_batch) * 0.9,
File "/home/ahmedjawed/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
return self.generate_xmap(self.state,
File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 615, in fun_mapped
out_flat = xmap_p.bind(
File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 818, in bind
return core.call_bind(self, fun, *args, **params) # type: ignore
File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1551, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 821, in process
return trace.process_xmap(self, fun, tracers, params)
File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 606, in process_call
return primitive.impl(f, *tracers, **params)
File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 646, in xmap_impl
xmap_callable = make_xmap_callable(
File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard
I think I've got everything installed but it is failing. It tries and fails to use TPU rather than GPU when I run the resharded script.
I am running resharding with a P5000 GPU and 30 GB of RAM, but this is the error I get. Please guide me, thanks.
Generated text on one line can run outside the cell which makes it difficult to read.
The fix in this SO answer helps resolve this.
Hey, first of all, big thanks for the model and the repo! You've done a brilliant job here, and the model generates some AMAZING results for its size.
I'm having an issue when trying to finetune the 6B model on Google Colab. I roughly followed the modifications that device_train.py -h describes, but I'm getting this error: AssertionError: Incompatible checkpoints (1, 4096) vs (1,)
Reproduction
I ran the command with --tune-model-path pointing to the downloaded model and the modified configs/6B_roto_256.json file with the following changes: the train and val sets point to my dataset, I changed the number of steps, changed tpu_size from 256 to 8, and changed cores_per_replica from 8 to 1. I suspect the cores_per_replica change altered the mesh_shape, which in turn affects the shards_in parameter in read_ckpt.
When I printed the in_shards value that reaches the read_ckpt function, it was 1, although I see in the model files that there are 8 shard_X folders.
Thanks
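One way to double-check how many shards a downloaded checkpoint actually contains, before passing a shard count to read_ckpt, is to count the shard_X folders mentioned above. The following is only a minimal sketch: count_checkpoint_shards is a hypothetical helper (not part of mesh-transformer-jax), and it assumes a local copy of the checkpoint whose shards sit in sub-directories named shard_<N>. It does not read from GCS.

```python
import os
import re
import tempfile


def count_checkpoint_shards(ckpt_dir):
    """Count sub-directories named shard_<N> directly under ckpt_dir.

    Hypothetical helper: assumes the local checkpoint layout described
    in the issue (one shard_X folder per saved shard).
    """
    pattern = re.compile(r"^shard_\d+$")
    return sum(
        1
        for name in os.listdir(ckpt_dir)
        if pattern.match(name) and os.path.isdir(os.path.join(ckpt_dir, name))
    )


if __name__ == "__main__":
    # Build a throwaway directory that mimics an 8-shard checkpoint.
    with tempfile.TemporaryDirectory() as tmp:
        for i in range(8):
            os.mkdir(os.path.join(tmp, f"shard_{i}"))
        print("shards found:", count_checkpoint_shards(tmp))
```

If the count differs from the value that reaches read_ckpt, that mismatch is a plausible place to start looking, though the sketch itself does not confirm the cause of the assertion.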
Hello @kingoflolz ,
How can I make the model stop generating further tokens when a stop_sequence is found, as in GPT-3?
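A simple workaround that does not depend on any support inside the model code is to cut the decoded text at the first stop sequence after generation has finished. The sketch below assumes you already have the generated text as a Python string; truncate_at_stop is a hypothetical helper name, not a function from this repo. Note that this only trims the output: the model still generates the full number of tokens.

```python
def truncate_at_stop(text, stop_sequences):
    """Return text cut just before the earliest stop sequence, if any.

    Hypothetical post-processing helper: it does not stop generation
    early, it only trims the already decoded string.
    """
    cut = len(text)
    for stop in stop_sequences:
        if not stop:
            continue  # ignore empty stop strings
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


if __name__ == "__main__":
    generated = "Q: hi\nA: hello\nQ: next question"
    print(truncate_at_stop(generated, ["\nQ:"]))
```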
I converted the model into the _slim format, not _slim_f16.
Now, when I execute the Colab code, I get the error below:
loading netwrok from the Google storage
read from disk/gcs in 107.5s
Traceback (most recent call last):
File "content_generation.py", line 90, in <module>
print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])
File "content_generation.py", line 77, in infer
output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})
File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
return self.generate_xmap(self.state,
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 615, in fun_mapped
out_flat = xmap_p.bind(
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 818, in bind
return core.call_bind(self, fun, *args, **params) # type: ignore
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 821, in process
return trace.process_xmap(self, fun, tracers, params)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 646, in xmap_impl
xmap_callable = make_xmap_callable(
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard
Any update on issue #65?
Even after following @kingoflolz's recommendations in #65, I am not able to move ahead: the library itself breaks, and import jax aborts.
Found you on ProductHunt today! I'm very impressed by your work and enjoyed reading the blog posts.
I think there's a great opportunity for a one-command-deploy API via serverless (serverless deploy). I think this option will make your work very accessible to hobbyists (easy deployment, low cost, low maintenance, scalable).
AWS recently released support for Docker-based Lambda functions, which the Serverless Framework now supports too.
Luckily, the serverless limit with Docker is 10 GB, so the 8.8 GB slim weights for inference should fit snugly.
The serverless.yml file could look like:
service: gpt-j
frameworkVersion: '2'
variablesResolutionMode: 20210326
provider:
  name: aws
  lambdaHashingVersion: 20201221
  runtime: python3.8
  memorySize: 10240
  ecr:
    images:
      gpt-j:
        path: ./
        file: ./Dockerfile
functions:
  gpt-j:
    image:
      name: gpt-j
    events:
      - httpApi:
          method: post
          path: /predict
The Dockerfile would look similar to your existing Dockerfile, but use the AWS runtime:
FROM public.ecr.aws/lambda/python:3.8
# install gpt-j
# app.py handles inference
COPY app.py ./
CMD ["app.handler"]
There would only be a single API call, since a Lambda instance serves only a single request at a time.
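For illustration, the app.py referenced by the Dockerfile could be shaped roughly like the sketch below. This is an assumption-laden outline, not a working deployment: generate_text is a hypothetical placeholder standing in for the actual GPT-J inference code, the request fields (prompt, top_p, temp, gen_len) are invented for the example, and the event is assumed to carry a plain (not base64-encoded) JSON string in its body field.

```python
import json


def generate_text(prompt, top_p=0.9, temp=1.0, gen_len=64):
    """Hypothetical placeholder for the real GPT-J inference call."""
    return prompt + " ..."


def handler(event, context):
    """Lambda entry point for the POST /predict route (sketch only)."""
    try:
        body = json.loads(event.get("body") or "{}")
    except json.JSONDecodeError:
        return {"statusCode": 400, "body": json.dumps({"error": "invalid JSON"})}

    prompt = body.get("prompt") if isinstance(body, dict) else None
    if not isinstance(prompt, str) or not prompt:
        return {"statusCode": 400, "body": json.dumps({"error": "missing prompt"})}

    text = generate_text(
        prompt,
        top_p=float(body.get("top_p", 0.9)),
        temp=float(body.get("temp", 1.0)),
        gen_len=int(body.get("gen_len", 64)),
    )
    return {
        "statusCode": 200,
        "headers": {"Content-Type": "application/json"},
        "body": json.dumps({"text": text}),
    }


if __name__ == "__main__":
    # Local smoke test with a fake event; no AWS services are involved.
    fake_event = {"body": json.dumps({"prompt": "Hello"})}
    print(handler(fake_event, None))
```

Loading the weights once at module import time, outside handler, would be the usual way to avoid reloading them on every request, but whether the model fits Lambda's limits in practice is untested here.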
I wish I could be of more help, but my specialty at the moment is JavaScript, not Python. I will gladly test any demos though. Already having a Dockerfile example means half the work is probably done.
I have set up a TPU VM and when I try to run device_sample.py I get the following error when calling network.generate:
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard
Currently, I am only doing inference so I am using the slim weights. Therefore, I had to change the line of code that sets up the optimizer to this:
params['optimizer'] = optax.scale(0)
This is exactly what the Colab notebook does, which I can run without any errors. Do I need to use the full weights to run device_sample.py? Thank you, any help is appreciated.
Hi,
I am trying to develop a web app for the model, as it is very good at writing. I used Gradio, as I am a data scientist, but it gives an error about the environment (env, to be precise). Please share how you developed that app, or guide me. I am looking forward to giving a demo soon, so please help.
I am using this repo to train on a small custom dataset, with JAX 0.2.16. However, requirements.txt says jax 0.2.12. I don't really know how Ray and JAX work together internally, but my assumption is that this code should work fine on a v3-8 TPU. When I execute train.py, it generates the following error:
(pid=9454, ip=10.164.0.9) jax runtime initialization starting
2021-07-09 12:11:59,794 ERROR worker.py:78 -- Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): ray::NetworkRunner.run() (pid=9454, ip=10.164.0.9)
File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/train_actor.py", line 24, in run
TypeError: new() missing 1 required positional argument: 'loops'
I am not sure what this error is. Can someone help me debug it? Is it a library versioning issue or something else?
Hello @kingoflolz,
I am using 4 - T4 GPUs, and facing the below error:
Traceback (most recent call last):
File "gpt_j_inference.py", line 59, in <module>
network.state = read_ckpt(network.state, "step_383500/", 8, shards_out=cores_per_replica)
File "/home/pragnakalp/Desktop/GPT-J/env_gptj/lib/python3.8/site-packages/mesh_transformer/checkpoint.py", line 145, in read_ckpt
assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (1, 4096) vs (4, 4096)
I have passed cores_per_replica as 4.
How can I run the code using 4 GPUs? A single GPU results in a memory error.