largeworldmodel / lwm
License: Apache License 2.0
My community needs a Gradio demo and a PyTorch version of the LWM-Chat models
so they can use them on Windows for image captioning.
Please provide these two. I have started following the repo.
Thank you so much.
Hello there! Architecture innovator here! Everything preceding my model seems very inefficient.
(lwm) llm@llm-PowerEdge-R730xd:~/projects/LWM-main$ bash scripts/run_vision_chat.sh
I0221 14:02:43.257625 139932541391232 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 14:02:43.260045 139932541391232 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
100%|██████████| 1/1 [00:05<00:00, 5.59s/it]
Traceback (most recent call last):
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 254, in <module>
run(main)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 250, in main
output = sampler(prompts, FLAGS.max_n_frames)[0]
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 230, in __call__
output, self.sharded_rng = self._forward_generate(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
return common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/llm/projects/LWM-main/lwm/vision_chat.py", line 206, in fn
output = self.model.generate(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 429, in generate
return self._sample(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 733, in _sample
state = sample_search_body_fn(state)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/transformers/generation/flax_utils.py", line 704, in sample_search_body_fn
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 232, in __call__
outputs = self.module.apply(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1511, in apply
return apply(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 934, in wrapper
y = fn(root, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 2082, in scope_fn
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 401, in __call__
outputs = self.transformer(
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/projects/LWM-main/lwm/vision_llama.py", line 313, in __call__
input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 836, in _call_wrapped_method
self._try_setup()
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1094, in _try_setup
self.setup()
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 418, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 854, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/linear.py", line 771, in setup
self.embedding = self.param('embedding',
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/module.py", line 1263, in param
v = self.scope.param(name, init_fn, *init_args, unbox=unbox)
File "/home/llm/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/core/scope.py", line 842, in param
raise errors.ScopeParamNotFoundError(name, self.path_text)
flax.errors.ScopeParamNotFoundError: Could not find parameter named "embedding" in scope "/transformer/wte". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)
Please provide an example; both vision and language are needed. Thanks, @wilson1yan
Hi, I hit this error while running the script run_sample_img.sh with mesh dimensions 1,1,3,1, as I am using 3 A100 GPUs.
Here is the script:
export llama_tokenizer_path="LWM-Chat-1M-Jax/tokenizer.model"
export vqgan_checkpoint="LWM-Chat-1M-Jax/vqgan"
export lwm_checkpoint="LWM-Chat-1M-Jax/params"
python3 -u -m lwm.vision_generation \
--prompt='Fireworks over the city' \
--output_file='fireworks.png' \
--temperature_image=1.0 \
--top_k_image=8192 \
--cfg_scale_image=5.0 \
--vqgan_checkpoint="$vqgan_checkpoint" \
--n_frames=1 \
--mesh_dim='1,1,3,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer.vocab_file="$llama_tokenizer_path"
Here is the error info:
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/workspace/lwm/vision_generation.py", line 258, in <module>
run(main)
File "/opt/conda/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/opt/conda/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/workspace/lwm/vision_generation.py", line 94, in main
model = FlaxVideoLLaMAForCausalLM(
File "/workspace/lwm/vision_llama.py", line 145, in __init__
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 213, in __init__
params_shape_tree = jax.eval_shape(init_fn, self.key)
File "/workspace/lwm/vision_llama.py", line 170, in init_weights
random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
File "/workspace/lwm/vision_llama.py", line 401, in __call__
outputs = self.transformer(
File "/workspace/lwm/vision_llama.py", line 320, in __call__
outputs = self.h(
File "/workspace/lwm/llama.py", line 991, in __call__
hidden_states, _ = nn.scan(
File "/opt/conda/lib/python3.10/site-packages/flax/core/axes_scan.py", line 139, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/opt/conda/lib/python3.10/site-packages/flax/core/axes_scan.py", line 115, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/workspace/lwm/llama.py", line 766, in __call__
attn_outputs = self.attention(
File "/workspace/lwm/llama.py", line 657, in __call__
attn_output = ring_attention_sharded(
ValueError: shard_map applied to the function 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f98223cb6d0>, axis_name='sp')' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:
The mesh given has shape (1, 1, 3, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
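The `ValueError` above reports that some array axis is not evenly divisible by the mesh axis it is sharded over. A minimal sketch of that divisibility check in plain Python follows. The helper names (`parse_mesh_dim`, `indivisible_axes`) are hypothetical and not part of LWM, and the assumption that 32 attention heads are split over `tp` is a guess based on the standard LLaMA-7B layout, not something the traceback confirms.

```python
# Mesh axis names as printed in the error message above.
MESH_AXIS_NAMES = ("dp", "fsdp", "tp", "sp")


def parse_mesh_dim(mesh_dim):
    """Turn a string such as '1,1,3,1' into {'dp': 1, 'fsdp': 1, 'tp': 3, 'sp': 1}."""
    sizes = [int(part) for part in mesh_dim.split(",")]
    if len(sizes) != len(MESH_AXIS_NAMES):
        raise ValueError(f"expected {len(MESH_AXIS_NAMES)} sizes, got {len(sizes)}")
    return dict(zip(MESH_AXIS_NAMES, sizes))


def indivisible_axes(mesh_dim, array_axis_sizes):
    """Return {mesh_axis: (array_size, mesh_size)} for axes that do not divide evenly.

    array_axis_sizes maps a mesh axis name to the size of the array axis
    sharded over it (hypothetical input, supplied by the caller).
    """
    mesh = parse_mesh_dim(mesh_dim)
    return {
        name: (size, mesh[name])
        for name, size in array_axis_sizes.items()
        if size % mesh[name] != 0
    }


# Assumption: 32 attention heads are sharded over the 'tp' axis.
print(indivisible_axes("1,1,3,1", {"tp": 32}))  # {'tp': (32, 3)}
print(indivisible_axes("1,1,2,1", {"tp": 32}))  # {}
```

If that assumption holds, a `tp` size of 3 can never divide 32. One option that might work is to expose only two of the three GPUs (for example with `CUDA_VISIBLE_DEVICES=0,1`) and pass `--mesh_dim='1,1,2,1'`; this has not been verified against the repository.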
Detail message:
Traceback (most recent call last):
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/user_work_path/train/codes/LWM/lwm/train.py", line 396, in <module>
run(main)
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/user_work_path/train/codes/LWM/lwm/train.py", line 387, in main
save_checkpoint(train_state, milestone=True)
File "/user_work_path/train/codes/LWM/lwm/train.py", line 325, in save_checkpoint
checkpointer.save_all(
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 102, in save_all
self.save_checkpoint(
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 46, in save_checkpoint
self.save_train_state_to_file(
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 78, in save_train_state_to_file
fout.write(packer.pack((key, to_bytes(value))))
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 826, in pack
self._pack(obj)
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 803, in _pack
self._pack(obj[i], nest_limit - 1)
File "/user_work_path/miniconda3/envs/lwm/lib/python3.10/site-packages/msgpack/fallback.py", line 750, in _pack
raise ValueError("%s is too large" % type(obj).__name__)
ValueError: bytes is too large
train script
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export LLAMA_TOKENIZER_PATH=/user_work_path/tokenizer.model
export DATASET_PATH=/user_work_path/sample.jsonl
export SEED=1025
export PROJECT_ID='lwm'
export EXPERIMENT_NOTE=''
export EXPERIMENT_ID='example-text-train'
export OUTPUT_DIR=${PROJECT_DIR}/output
export COORDINATOR_ADDRESS=localhost:12345
export NUM_PROCESSES=1
export PROCESS_ID=0
export INITIALIZE_JAX_DISTRIBUTED=true
python3 -u -m lwm.train \
--jax_distributed.coordinator_address ${COORDINATOR_ADDRESS} \
--jax_distributed.initialize_jax_distributed ${INITIALIZE_JAX_DISTRIBUTED} \
--jax_distributed.num_processes ${NUM_PROCESSES} \
--jax_distributed.process_id ${PROCESS_ID} \
--modality='text' \
--mesh_dim='1,1,1,8' \
--dtype='bf16' \
--seed=${SEED} \
--total_steps=10 \
--log_freq=1 \
--save_model_freq=0 \
--save_milestone_freq=5 \
--load_llama_config='13b' \
--update_llama_config="dict(theta=10000,max_sequence_length=4096,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
--tokenizer.vocab_file="$LLAMA_TOKENIZER_PATH" \
--optimizer.type='adamw' \
--optimizer.accumulate_gradient_steps=1 \
--optimizer.adamw_optimizer.weight_decay=0.1 \
--optimizer.adamw_optimizer.lr=8e-5 \
--optimizer.adamw_optimizer.end_lr=8e-5 \
--optimizer.adamw_optimizer.lr_warmup_steps=5 \
--optimizer.adamw_optimizer.lr_decay_steps=200 \
--use_data_sharded_loader=True \
--train_dataset.type='json' \
--train_dataset.text_processor.fields='text' \
--train_dataset.json_dataset.path="$DATASET_PATH" \
--train_dataset.json_dataset.seq_length=1024 \
--train_dataset.json_dataset.batch_size=8 \
--train_dataset.json_dataset.tokenizer_processes=4 \
--train_dataset.json_dataset.tokenizer_parallel_chunk_size=2 \
--train_dataset.json_dataset.tokenizer_parallel_batch_size=8 \
--train_dataset.json_dataset.use_data_sharded_loader=True \
--checkpointer.save_optimizer_state=True \
--autoresume=False \
--logger.append_uuid=False \
--logger.online=False \
--logger.project_id="$PROJECT_ID" \
--logger.experiment_id="$EXPERIMENT_ID" \
--logger.experiment_note="$EXPERIMENT_NOTE" \
--logger.output_dir="$OUTPUT_DIR" \
--logger.wandb_dir="$HOME/experiment_output/$PROJECT_ID"
read
environment
Package Version
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cachetools 5.3.3
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.17.1
decorator 5.1.1
decord 0.6.0
dill 0.3.8
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2023.10.0
gcsfs 2023.10.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.1
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib_resources 6.1.2
ipdb 0.13.13
ipython 8.22.1
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.16
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.1
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.54
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.10.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0
Could someone help me?
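For the `ValueError: bytes is too large` in the traceback above: msgpack stores the length of a binary blob in at most 32 bits, so a single serialized value must be smaller than 2**32 bytes. The arithmetic below is a rough sketch of why one tensor could cross that limit. It assumes the standard LLaMA-13B dimensions (40 layers, hidden size 5120, MLP size 13824) and assumes that `scan_layers=True` stacks every layer's weights into one array, so that the whole stack is written as a single value. Neither assumption is confirmed by the traceback, and `stacked_param_bytes` is a hypothetical helper.

```python
# Largest byte string msgpack can encode (its length field is 32 bits wide).
MSGPACK_BIN32_MAX = 2 ** 32 - 1


def stacked_param_bytes(n_layers, shape, bytes_per_element):
    """Size in bytes of one weight matrix stacked across all layers."""
    count = n_layers
    for dim in shape:
        count *= dim
    return count * bytes_per_element


# Assumed LLaMA-13B dimensions; not read from the LWM config.
N_LAYERS, HIDDEN, MLP = 40, 5120, 13824

for label, shape in [("attention projection", (HIDDEN, HIDDEN)), ("MLP projection", (HIDDEN, MLP))]:
    for dtype, width in [("bf16", 2), ("fp32", 4)]:
        size = stacked_param_bytes(N_LAYERS, shape, width)
        verdict = "too large" if size > MSGPACK_BIN32_MAX else "fits"
        print(f"{label:22s} {dtype}: {size:>14,d} bytes ({verdict})")
```

Under these assumptions the stacked MLP matrices exceed the limit in both bf16 and fp32, while the stacked attention matrices stay just below it in fp32. Whether the checkpointer can split large tensors before packing would need to be checked in the `tux` source.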
Hi everyone,
I suspect this might be a stupid question, but why is model_max_length set to 2048
in the tokenizer_config.json https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M/blob/main/tokenizer_config.json?
Thank you!
May I ask what the minimum requirements are for an individual to run the LWM model? I have a 4090 graphics card and I don't know which model I can run.
Thanks for sharing this excellent work. We want to use PyTorch models to try out ring attention. Are there any plans to develop a ring attention implementation in PyTorch?
Could you provide example data for vision-language training (especially the format)? Thank you!
Hi,
LWM is incredible! Any plans to release a Mistral version?
Thanks!
I'm currently able to use run_vision_chat.sh with a limited number of video frames being passed in for a single text query. The text result is output from the model and then the process ends. However, the paper shows examples of a continuous dialogue about a video and I was wondering if it's possible to set this up.
Is it necessary to use float32 in training? Why not use the widely used bf16 type, which saves more GPU memory?
Looking forward to your reply. Thanks!
Hi,
I'm trying to run run_vision_chat.sh
but getting the following error:
(lwm) minyoung@claw2:~/Projects/LWM$ bash scripts/run_vision_chat.sh
I0215 18:19:20.605390 140230836105600 xla_bridge.py:689] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0215 18:19:20.607900 140230836105600 xla_bridge.py:689] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-15 18:19:29.755994: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Traceback (most recent call last):
File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/minyoung/Projects/LWM/lwm/vision_chat.py", line 254, in <module>
run(main)
File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/minyoung/Projects/LWM/lwm/vision_chat.py", line 249, in main
sampler = Sampler()
File "/home/minyoung/Projects/LWM/lwm/vision_chat.py", line 42, in __init__
self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
File "/home/minyoung/Projects/LWM/lwm/llama.py", line 260, in get_jax_mesh
return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'tp', 'sp'))
File "/home/minyoung/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/distributed.py", line 140, in get_jax_mesh
mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
ValueError: cannot reshape array of size 1 into shape (1,newaxis,32,1)
These are the model configs I used.
export llama_tokenizer_path="./LWM-Chat-1M-Jax/tokenizer.model"
export vqgan_checkpoint="./LWM-Chat-1M-Jax/vqgan"
export lwm_checkpoint="./LWM-Chat-1M-Jax/params"
export input_file="./traj0.mp4"
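The traceback above fails in `np.arange(jax.device_count()).reshape(dims)`: one visible device cannot be reshaped into a mesh whose fixed sizes multiply to 32 (the `newaxis` in the message is presumably a `-1` wildcard in `--mesh_dim`). The sketch below mirrors that reshape rule in plain Python so a mesh string can be checked before launching. `resolve_mesh_shape` is a hypothetical helper, not part of LWM or `tux`.

```python
from math import prod


def resolve_mesh_shape(mesh_dim, device_count):
    """Resolve a mesh_dim string against the number of visible devices.

    Follows the NumPy reshape rule: at most one -1 is allowed, and it is
    filled in so that the product of all sizes equals device_count.
    """
    dims = [int(part) for part in mesh_dim.split(",")]
    if dims.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    fixed = prod(d for d in dims if d != -1)
    if -1 in dims:
        if fixed <= 0 or device_count % fixed != 0:
            raise ValueError(f"{device_count} device(s) cannot fill mesh {mesh_dim!r}")
        dims[dims.index(-1)] = device_count // fixed
    elif fixed != device_count:
        raise ValueError(f"mesh {mesh_dim!r} needs {fixed} device(s), found {device_count}")
    return tuple(dims)


print(resolve_mesh_shape("1,-1,1,1", 1))    # (1, 1, 1, 1)
print(resolve_mesh_shape("1,-1,32,1", 64))  # (1, 2, 32, 1)
try:
    resolve_mesh_shape("1,-1,32,1", 1)      # the single-GPU case from the report
except ValueError as err:
    print("rejected:", err)
```

On a single GPU, only a mesh whose sizes multiply to 1 passes this check, which suggests changing the `--mesh_dim` value in the script rather than the checkpoint paths.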
Hi,
Great work on LWM! I noticed the weights are licensed under the Apache license but are derived from Llama 2. Do both the Llama 2 license and the Apache license apply to the weights?
Thanks!
(Related to #10)
We need some samples that can actually run inference on vision/image inputs on a GPU.
(lwm) ➜ LWM git:(main) ./scripts/run_sample_image.sh
WARNING: Logging before InitGoogle() is written to STDERR
I0000 00:00:1707909968.893833 11548 common_lib.cc:148] Failed to fetch URL on try 1 out of 6: Couldn't connect to server
I0000 00:00:1707909972.473827 11548 common_lib.cc:148] Failed to fetch URL on try 2 out of 6: Couldn't connect to server
I0000 00:00:1707909976.025912 11548 common_lib.cc:148] Failed to fetch URL on try 3 out of 6: Couldn't connect to server
^C^CI0000 00:00:1707909979.577878 11548 common_lib.cc:148] Failed to fetch URL on try 4 out of 6: Couldn't connect to server
^CI0000 00:00:1707909983.129666 11548 common_lib.cc:148] Failed to fetch URL on try 5 out of 6: Couldn't connect to server
I0000 00:00:1707909986.681892 11548 common_lib.cc:148] Failed to fetch URL on try 6 out of 6: Couldn't connect to server
Failed to get 'tpu-env' from instance metadata: INTERNAL: Couldn't connect to server
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:145
learning/45eac/tfrc/runtime/common_lib.cc:162
learning/45eac/tfrc/runtime/common_lib.cc:188
I0000 00:00:1707909990.233946 11548 common_lib.cc:148] Failed to fetch URL on try 1 out of 6: Couldn't connect to server
I0000 00:00:1707909993.785913 11548 common_lib.cc:148] Failed to fetch URL on try 2 out of 6: Couldn't connect to server
I0000 00:00:1707909997.337871 11548 common_lib.cc:148] Failed to fetch URL on try 3 out of 6: Couldn't connect to server
I0000 00:00:1707910000.890123 11548 common_lib.cc:148] Failed to fetch URL on try 4 out of 6: Couldn't connect to server
I0000 00:00:1707910004.442720 11548 common_lib.cc:148] Failed to fetch URL on try 5 out of 6: Couldn't connect to server
I0000 00:00:1707910007.994498 11548 common_lib.cc:148] Failed to fetch URL on try 6 out of 6: Couldn't connect to server
Failed to get 'tpu-env' from instance metadata: INTERNAL: Couldn't connect to server
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:145
learning/45eac/tfrc/runtime/common_lib.cc:162
learning/45eac/tfrc/runtime/common_lib.cc:188
We need an API.
I ran bash scripts/run_vision_chat.sh
with the --mesh_dim param removed.
The model is LWM-Chat-32K-Jax.
I get an out-of-memory error. How can I solve it?
My card is an NVIDIA 2080 Super (8 GB).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708500656.672727 10871 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
I0221 15:30:57.202437 140383335174272 xla_bridge.py:513] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 15:30:57.202921 140383335174272 xla_bridge.py:513] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-21 15:36:18.340692: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.00GiB (rounded to 2147483648)requested by op
2024-02-21 15:36:18.340908: W external/tsl/tsl/framework/bfc_allocator.cc:497] *________**********************************************************************_____________________
2024-02-21 15:36:18.340944: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.00GiB
constant allocation: 0B
maybe_live_out allocation: 2.00GiB
preallocated temp allocation: 0B
total allocation: 3.00GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 2.00GiB
Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
XLA Label: fusion
Shape: f32[32,4096,4096]
==========================
Buffer 2:
Size: 1.00GiB
Entry Parameter Subshape: bf16[32,4096,4096]
==========================
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 254, in <module>
run(main)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 249, in main
sampler = Sampler()
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 51, in __init__
self._load_model()
File "/mnt/data/test/LWM/lwm/vision_chat.py", line 199, in _load_model
self.params = tree_apply(shard_fns, self.params)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in tree_apply
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in <lambda>
return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/distributed.py", line 95, in shard_fn
return jax_shard_function(tensor).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.00GiB
constant allocation: 0B
maybe_live_out allocation: 2.00GiB
preallocated temp allocation: 0B
total allocation: 3.00GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 2.00GiB
Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
XLA Label: fusion
Shape: f32[32,4096,4096]
==========================
Buffer 2:
Size: 1.00GiB
Entry Parameter Subshape: bf16[32,4096,4096]
==========================
I0000 00:00:1708500978.900009 10871 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
(lwm) test@test-3:/mnt/data/test/LWM$ nvidia-smi
Wed Feb 21 15:47:00 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 2080 S... Off| 00000000:01:00.0 Off | N/A |
| 0% 40C P0 23W / 250W| 0MiB / 8192MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
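The buffer sizes in the OOM report above can be reproduced by hand, which helps show where the memory goes. The sketch below recomputes the two peak buffers from the shapes in the log (`bf16[32,4096,4096]` cast to `f32[32,4096,4096]`). The last line is a rough estimate that assumes about 7 billion parameters in total, a figure taken from the usual LLaMA-7B size rather than from this log. `buffer_gib` is a hypothetical helper.

```python
GIB = 1024 ** 3


def buffer_gib(shape, bytes_per_element):
    """Size of a dense array with the given shape, in GiB."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element / GIB


stacked = (32, 4096, 4096)     # shape reported in the log
print(buffer_gib(stacked, 2))  # 1.0  (bf16 source, "Buffer 2")
print(buffer_gib(stacked, 4))  # 2.0  (float32 result, "Buffer 1")

# Assumption: roughly 7e9 parameters in total, stored in bf16.
print(round(buffer_gib((7_000_000_000,), 2), 1), "GiB of weights in bf16")
```

The failing 2 GiB request is therefore the float32 copy of a single stacked weight tensor. Under the 7-billion-parameter assumption, the weights alone would exceed the 8 GiB card even without the float32 cast.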
Hello, what is the chat format for the chat models?
Llama?
This paper discusses long context lengths for a language model that is extended into a vision-language model. I wonder why it is called a World Model; this is not obvious from the paper. The paper seems to focus more on long context and the evaluation of related retrieval ability, with little discussion of world modelling.
I wonder whether there is any specific finding about model abilities that improve with long-context training. Does it make the model more robust against prompt variations? More robust at reasoning? Semantically richer in its concept representations? Better at ontological/hierarchical learning of meaning?
I would be curious to hear more about the findings from the authors.
Thanks a lot in advance for any insights : )
I ran bash scripts/run_sample_video.sh using the LWM-Chat-1M-JAX model.
The sh file is:
...
python3 -u -m lwm.vision_generation \
--prompt='A long big pig is walking across the street' \
--output_file='fireworks.mp4' \
--temperature_image=1.0 \
--temperature_video=1.0 \
--top_k_image=8192 \
--top_k_video=1000 \
--cfg_scale_image=5.0 \
--cfg_scale_video=1.0 \
--vqgan_checkpoint="$vqgan_checkpoint" \
--n_frames=8 \
--mesh_dim='!1,1,2,1' \
--dtype='bf16' \
--load_llama_config='7b' \
--update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=256,scan_key_chunk_size=256,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer.vocab_file="$llama_tokenizer_path"
read
After generation, only the first frame of the output video is meaningful; the other frames are all random pixels.
Hello,
This is really great work that contributes a lot to the community!
Do you have any plans to train a smaller version of the Large World Model (e.g., 1~3B), perhaps based on a smaller model like Phi-2? It should be much easier and use fewer computing resources.
Environment
GPUs: 8x4090
Package Version
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cachetools 5.3.2
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.13.0
decorator 5.1.1
decord 0.6.0
dill 0.3.6
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2024.2.0
gcsfs 2024.2.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.0
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib-resources 6.1.1
ipdb 0.13.13
ipython 8.21.0
jax 0.4.23
jaxlib 0.4.23+cuda11.cudnn86
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.14
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.0
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.53
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.9.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0
Error Message
I0223 22:58:05.579038 140312230876992 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0223 22:58:05.579842 140312230876992 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
100%|██████████| 1/1 [00:09<00:00, 9.21s/it]
2024-02-23 23:00:08.992159: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992208: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992237: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992261: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992281: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992348: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992392: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992430: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.
2024-02-23 23:00:08.992459: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992472: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992483: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992499: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992510: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992522: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992536: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
2024-02-23 23:00:08.992551: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/users/yu01.xie/project/mllm/LWM-main/lwm/vision_chat.py", line 254, in
run(main)
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/users/yu01.xie/software/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/users/yu01.xie/project/mllm/LWM-main/lwm/vision_chat.py", line 250, in main
output = sampler(prompts, FLAGS.max_n_frames)[0]
File "/home/users/yu01.xie/project/mllm/LWM-main/lwm/vision_chat.py", line 230, in call
output, self.sharded_rng = self._forward_generate(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:134: NCCL operation ncclGetUniqueId(&id) failed: Unable to load NCCL library. Multi-GPU collectives will not work.. Last NCCL warning(error) log entry (may be unrelated) 'Unable to load NCCL library. Multi-GPU collectives will not work.'.; current tracing scope: reduce-scatter-start.4; current profiling annotation: XlaModule:#hlo_module=pjit_fn,program_id=50#.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
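One thing that stands out in the package list above: there is no nvidia-nccl-cu12 entry, and jaxlib is a cuda11 build (0.4.23+cuda11.cudnn86) while the other nvidia-* wheels are cu12. That may be related to the "Unable to load NCCL library" message, though this is not confirmed. A minimal sketch for checking what is actually installed, using only the standard library (the helper names `installed_version` and `report` are made up for illustration):

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def report(dist_names):
    """Map each distribution name to its installed version (or None)."""
    return {name: installed_version(name) for name in dist_names}


if __name__ == "__main__":
    # Distribution names taken from the pip listings in this thread.
    for name, version in report(["jax", "jaxlib", "nvidia-nccl-cu12"]).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

If nvidia-nccl-cu12 shows up as NOT INSTALLED, comparing against the 4x80G environment listed further down (which does include it, together with a cuda12 jaxlib build) might be a reasonable next step.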
I0228 17:55:33.471474 139972342939648 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0228 17:55:33.473013 139972342939648 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
0%| | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:05<00:00, 5.02s/it]
100%|██████████| 1/1 [00:05<00:00, 5.02s/it]
2024-02-28 17:56:12.936238: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 29.79GiB (rounded to 31985762560)requested by op
2024-02-28 17:56:12.936491: W external/tsl/tsl/framework/bfc_allocator.cc:497] ____************************************************************
Fatal Python error: Segmentation fault
Thread 0x00007f4b5e7fc640 (most recent call first):
File "/usr/lib/python3.10/threading.py", line 324 in wait
File "/usr/lib/python3.10/threading.py", line 607 in wait
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap
Current thread 0x00007f4dd9c78000 (most recent call first):
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/compiler.py", line 256 in backend_compile
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/profiler.py", line 336 in wrapper
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/compiler.py", line 333 in compile_or_get_cached
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2528 in _cached_compilation
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2659 in from_hlo
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2219 in compile
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 1165 in _pjit_call_impl_python
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 1229 in call_impl_cache_miss
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 1245 in _pjit_call_impl
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/core.py", line 935 in process_primitive
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/core.py", line 447 in bind_with_trace
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/core.py", line 2740 in bind
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 168 in _python_pjit_helper
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/pjit.py", line 257 in cache_miss
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179 in reraise_with_filtered_traceback
File "/home/azureuser/LWM1/LWM/lwm/vision_chat.py", line 230 in call
File "/home/azureuser/LWM1/LWM/lwm/vision_chat.py", line 250 in main
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/absl/app.py", line 254 in _run_main
File "/home/azureuser/LWM_env/lib/python3.10/site-packages/absl/app.py", line 308 in run
File "/home/azureuser/LWM1/LWM/lwm/vision_chat.py", line 254 in
File "/usr/lib/python3.10/runpy.py", line 86 in _run_code
File "/usr/lib/python3.10/runpy.py", line 196 in _run_module_as_main
Extension modules: PIL._imaging, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, jaxlib.cpu_feature_guard, charset_normalizer.md, yaml._yaml, msgpack._cmsgpack, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._flinalg, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, 
scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._ansari_swilk_statistics, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, scipy.stats._unuran.unuran_wrapper, multidict._multidict, yarl._quoting_c, aiohttp._helpers, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket, frozenlist._frozenlist, google._upb._message, psutil._psutil_linux, psutil._psutil_posix, sentencepiece._sentencepiece (total: 129)
Hi, thanks for releasing the code. Looks pretty interesting! I noticed that the LWM-Chat (multimodal) model checkpoint is only released in Jax. It would be great if you could release the model in PyTorch as well as you did for other text-only models!
GPUs: 4x80G
Package Version
absl-py 2.1.0
aiohttp 3.9.3
aiosignal 1.3.1
appdirs 1.4.4
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cachetools 5.3.2
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.82
click 8.1.7
cloudpickle 3.0.0
contextlib2 21.6.0
datasets 2.13.0
decorator 5.1.1
decord 0.6.0
dill 0.3.6
docker-pycreds 0.4.0
einops 0.7.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.1
flax 0.7.0
frozenlist 1.4.1
fsspec 2024.2.0
gcsfs 2024.2.0
gitdb 4.0.11
GitPython 3.1.42
google-api-core 2.17.1
google-auth 2.28.1
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.34.0
imageio-ffmpeg 0.4.9
importlib-resources 6.1.1
ipdb 0.13.13
ipython 8.21.0
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89
jedi 0.19.1
markdown-it-py 3.0.0
matplotlib-inline 0.1.6
mdurl 0.1.2
ml-collections 0.1.1
ml-dtypes 0.3.2
msgpack 1.0.7
multidict 6.0.5
multiprocess 0.70.14
nest-asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.3.4.1
nvidia-cuda-cupti-cu12 12.3.101
nvidia-cuda-nvcc-cu12 12.3.107
nvidia-cuda-nvrtc-cu12 12.3.107
nvidia-cuda-runtime-cu12 12.3.101
nvidia-cudnn-cu12 8.9.7.29
nvidia-cufft-cu12 11.0.12.1
nvidia-cusolver-cu12 11.5.4.101
nvidia-cusparse-cu12 12.2.0.103
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
oauthlib 3.2.2
opt-einsum 3.3.0
optax 0.1.7
orbax-checkpoint 0.5.3
packaging 23.2
pandas 2.2.0
parso 0.8.3
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
Pygments 2.17.2
python-dateutil 2.8.2
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rsa 4.9
scipy 1.12.0
sentencepiece 0.2.0
sentry-sdk 1.40.5
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
tensorstore 0.1.53
tiktoken 0.6.0
tokenizers 0.13.3
tomli 2.0.1
toolz 0.12.1
tqdm 4.66.2
traitlets 5.14.1
transformers 4.29.2
tux 0.0.2
typing_extensions 4.9.0
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.3
wcwidth 0.2.13
wheel 0.41.2
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0
`I0222 09:24:21.054814 140683333334848 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0222 09:24:21.056322 140683333334848 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-22 09:24:21.097023: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
0%| | 0/1 [00:00<?, ?it/s]2024-02-22 09:25:32.707642: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.707708: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
2024-02-22 09:25:32.807973: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.808024: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
2024-02-22 09:25:32.821063: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.821116: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
2024-02-22 09:25:32.825532: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.
2024-02-22 09:25:32.825585: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.
0%| | 0/1 [00:07<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/root/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/data/jeff/Git/LLM/src/T2V/LWM/lwm/vision_generation.py", line 258, in
run(main)
File "/root/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/root/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/data/jeff/Git/LLM/src/T2V/LWM/lwm/vision_generation.py", line 184, in main
img_enc, img = generate_first_frame(prompts, max_input_length=128)
File "/data/jeff/Git/LLM/src/T2V/LWM/lwm/vision_generation.py", line 158, in generate_first_frame
output, sharded_rng = _sharded_forward_generate(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.reduce_scatter' failed: external/xla/xla/service/gpu/nccl_utils.cc:305: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: internal error - please report this issue to the NCCL developers. Last NCCL warning(error) log entry (may be unrelated) 'Attribute busid of node nic not found'.; current tracing scope: reduce-scatter-start.5; current profiling annotation: XlaModule:#hlo_module=pjit__forward_generate,program_id=21#.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed`
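The log above only shows the last NCCL warning ('Attribute busid of node nic not found'). NCCL itself can print more detail when the NCCL_DEBUG and NCCL_DEBUG_SUBSYS environment variables are set before the first collective runs. A small sketch of setting them from Python before JAX is imported; the helper name `enable_nccl_debug` and the choice of subsystems are assumptions for illustration, not something from the LWM scripts:

```python
import os


def enable_nccl_debug(env=os.environ, subsystems="INIT,GRAPH"):
    """Turn on verbose NCCL logging unless the caller already configured it."""
    # setdefault keeps any value the user exported in the shell.
    env.setdefault("NCCL_DEBUG", "INFO")
    env.setdefault("NCCL_DEBUG_SUBSYS", subsystems)
    return {key: env[key] for key in ("NCCL_DEBUG", "NCCL_DEBUG_SUBSYS")}


if __name__ == "__main__":
    # Call this before importing jax so the variables are visible to NCCL.
    print(enable_nccl_debug())
```

Exporting the same two variables in the shell before calling the sample script should have the same effect. This only adds diagnostics; it is not expected to fix the failure by itself.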
Hello, I am a big fan of your project and I am interested in using your model for my own data. However, I am new to fine-tuning models and I am not sure how to proceed. Could you please provide some guidance on the steps I need to take to fine-tune your model with my own data?
Specifically, I would like to know:
1. What format my data needs to be in
2. How to preprocess my data
3. How to configure the model for fine-tuning
4. How to train the model on my data
5. How to evaluate the performance of the fine-tuned model
I would greatly appreciate any help you can provide. Thank you!
I hope this helps! Let me know if you have any other questions.
Also, it's worth noting that it's always a good idea to include as much information as possible about your specific use case and any error messages or unexpected behavior you are encountering when you are creating an issue on GitHub. This will help the maintainers of the project to better understand your problem and provide a more accurate solution.
Do you have any plans on creating safetensors for the models?
It would be worthwhile to provide the measured memory requirements for inference with the text models at 32K, 128K, 256K, 512K and 1M token context windows, in both PyTorch and JAX.
Using run_vision_chat.sh with a .PNG image results in
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 254, in <module>
run(main)
File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 250, in main
output = sampler(prompts, FLAGS.max_n_frames)[0]
File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 228, in __call__
batch = self.construct_input(prompts, max_n_frames)
File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 123, in construct_input
vision = self._read_process_vision(prompt['input_path'], max_n_frames)
File "/mnt/vol_f/LWM/lwm/vision_chat.py", line 102, in _read_process_vision
enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
File "/mnt/vol_f/LWM/lwm/vqgan.py", line 53, in encode
return self._encode(pixel_values)
File "/mnt/vol_f/LWM/lwm/vqgan.py", line 35, in fn
return self.model.apply(
File "/mnt/vol_f/LWM/lwm/vqgan.py", line 122, in encode
hidden_states = self.encoder(pixel_values)
File "/mnt/vol_f/LWM/lwm/vqgan.py", line 155, in __call__
hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values)
File "/home/ubuntu/miniconda3/envs/lwm/lib/python3.10/site-packages/flax/linen/linear.py", line 429, in __call__
kernel = self.param('kernel', self.kernel_init, kernel_shape,
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (3, 3, 3, 128) but got shape (3, 3, 4, 128) instead for parameter "kernel" in "/encoder/Conv_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
when number of channels in the input is > 3 (if transparency is present).
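The shape mismatch (3, 3, 4, 128) vs (3, 3, 3, 128) matches a 4-channel RGBA input reaching an encoder that expects 3 channels. A possible workaround is to flatten the alpha channel before the image is passed to the VQGAN. Below is a minimal NumPy sketch, assuming the image is already loaded as an HxWx4 uint8 array; the function name `to_rgb` and the white background are illustrative choices, not part of the LWM code:

```python
import numpy as np


def to_rgb(pixels, background=255):
    """Composite an HxWx4 uint8 RGBA array onto a solid background.

    HxWx3 input is returned unchanged.
    """
    pixels = np.asarray(pixels)
    if pixels.ndim != 3 or pixels.shape[-1] not in (3, 4):
        raise ValueError(f"expected HxWx3 or HxWx4, got {pixels.shape}")
    if pixels.shape[-1] == 3:
        return pixels
    rgb = pixels[..., :3].astype(np.float32)
    # Alpha in [0, 1], kept as a trailing axis so it broadcasts over RGB.
    alpha = pixels[..., 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + background * (1.0 - alpha)
    return np.round(blended).astype(np.uint8)
```

If the file is read with Pillow, calling `.convert("RGB")` on the opened image also drops the alpha channel, although it discards transparency rather than compositing onto a background.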
Thank you for your contribution! Amazing performance! I am curious about the computational requirements for training such world models, e.g., how many GPUs did you use and how long did training take?
Hello, I get an error when running vision_chat.py with jax==0.4.23 and tux==0.0.2.
The full traceback is as follows:
File "/home/LWM-main/lwm/vision_chat.py", line 12, in <module>
from tux import (
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/__init__.py", line 17, in <module>
from .optimizers import (AdamWOptimizerFactory, get_weight_decay_mask, optax_add_scheduled_weight_decay,
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/optimizers.py", line 193, in <module>
class OptaxScheduledWeightDecayState(NamedTuple):
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/tux/optimizers.py", line 194, in OptaxScheduledWeightDecayState
count: jnp.DeviceArray
File "/root/miniconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
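The traceback shows that the installed tux release still annotates a field with `jnp.DeviceArray`, a name that jax 0.4.23 no longer exposes. In current JAX the unified array type is `jax.Array`. The sketch below illustrates one way a type annotation could be made to work with both old and new JAX versions. The helper `resolve_array_type` is hypothetical and is not part of tux or JAX, and the demo uses stand-in namespaces instead of the real modules so that it runs without JAX installed.

```python
from types import SimpleNamespace


def resolve_array_type(jax_module, jnp_module):
    """Pick an array type for annotations that works across JAX versions.

    Prefers `jax_module.Array` when it exists and otherwise falls back to
    the legacy `jnp_module.DeviceArray`.
    """
    array_type = getattr(jax_module, "Array", None)
    if array_type is not None:
        return array_type
    return getattr(jnp_module, "DeviceArray")


# Stand-ins for illustration only; with JAX installed one would call
# resolve_array_type(jax, jax.numpy) instead.
new_jax = SimpleNamespace(Array=int)          # newer JAX: exposes Array
new_jnp = SimpleNamespace()                   # ... and has no DeviceArray
old_jax = SimpleNamespace()                   # older JAX: no Array
old_jnp = SimpleNamespace(DeviceArray=float)  # ... but has DeviceArray

new_type = resolve_array_type(new_jax, new_jnp)
old_type = resolve_array_type(old_jax, old_jnp)
print(new_type, old_type)
```

The alternative that requires no code change is to install jax and tux versions that were released to work together.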
Hi, great work! Could you provide the vision-language models in PyTorch?
Traceback (most recent call last):
File "/output/LWM/scripts/sample_pyt.py", line 8, in <module>
model = LlamaForCausalLM.from_pretrained(args.model)
File "/usr/local/envs/lwm/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1112, in __getattribute__
requires_backends(cls, cls._backends)
File "/usr/local/envs/lwm/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1100, in requires_backends
raise ImportError("".join(failed))
ImportError:
LlamaForCausalLM requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
If possible, please provide a PyTorch version of the vision model.
@wilson1yan It didn't work. More examples are needed, covering both the language and vision versions.
./scripts/run_vision_chat.sh
Traceback (most recent call last):
File "/home/jiapeiyang/anaconda3/envs/nlp/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/jiapeiyang/anaconda3/envs/nlp/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/jiapeiyang/workspace/LWM/lwm/vision_chat.py", line 18, in <module>
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
File "/home/jiapeiyang/workspace/LWM/lwm/vision_llama.py", line 21, in <module>
from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm
File "/home/jiapeiyang/workspace/LWM/lwm/llama.py", line 31, in <module>
from lwm.ring_attention import blockwise_ffn, ring_flash_attention_tpu,
File "/home/jiapeiyang/workspace/LWM/lwm/ring_attention.py", line 557, in <module>
class BlockSizes:
File "/home/jiapeiyang/workspace/LWM/lwm/ring_attention.py", line 563, in BlockSizes
block_q_major_dkv: int | None = None
TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
`
export llama_tokenizer_path="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/tokenizer.model"
export vqgan_checkpoint="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/vqgan"
export lwm_checkpoint="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/params"
export input_file="/home/jiapeiyang/workspace/LWM/models/LWM-Chat-32K-Jax/test_a.jpg"
python3 -u -m lwm.vision_chat \
--prompt="What is the video about?" \
--input_file="$input_file" \
--vqgan_checkpoint="$vqgan_checkpoint" \
--mesh_dim='!1,1,8,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--max_n_frames=8 \
--update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer.vocab_file="$llama_tokenizer_path" \
2>&1 | tee ~/output.log
read
`
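The TypeError in the traceback above comes from the `int | None` annotation, which is evaluated when the class body runs. The `|` operator between types is only supported from Python 3.10 onward, and the failing environment uses Python 3.9. Running under Python 3.10, as the other environments in these reports do, avoids the problem. For illustration, the sketch below shows the equivalent annotation that also works on 3.9. The class `BlockSizesCompat` is a hypothetical stand-in that reproduces only the single field visible in the traceback, not the real `BlockSizes` definition.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class BlockSizesCompat:
    # Hypothetical stand-in for the field shown in the traceback.
    # Optional[int] means the same as `int | None`, but it is also
    # accepted by Python versions older than 3.10.
    block_q_major_dkv: Optional[int] = None


default_sizes = BlockSizesCompat()
custom_sizes = BlockSizesCompat(block_q_major_dkv=128)
print(default_sizes, custom_sizes)
```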
I only got image input with the vision JAX model to work, and even then I had to remove the mesh_grid arg.
Everything else has failed.
E.g. the needle evaluation fails like this:
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
export llama_tokenizer_path="LWM-Chat-1M-Jax/tokenizer.model"
export lwm_text_checkpoint="LWM-Chat-1M-Jax/params"
# jsonl file containing text for haystack. Each line should be a json
# with a single key "text" containing the text.
export haystack_file="../ultrachat_qa_mix_128K/data.jsonl"
export output_file="output"
python3 -u scripts/eval_needle.py \
--mesh_dim='!1,-1,4,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(theta=10000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
--load_checkpoint="params::$lwm_text_checkpoint" \
--tokenizer.vocab_file="$llama_tokenizer_path" \
--max_tokens_per_batch=5000 \
--output_file="$output_file" \
--haystack_file="$haystack_file" \
--context_lengths_min=1000 \
--context_lengths_max=10000 \
--n_context_length_intervals=20 \
--n_document_depth_intervals=20 \
--n_rounds=3
read
(lwm) jon@gpu:~/LWM$ bash scripts/run_eval_needle.sh
I0216 10:25:24.068257 139879088207680 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0216 10:25:24.070914 139879088207680 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
Starting Needle In A Haystack Testing...
- Context Lengths: 20, Min: 1000, Max: 10000
- Document Depths: 20, Min: 0%, Max: 100%
- Needle: The special magic {city} number is: {rnd_number}
W0216 10:26:39.398258 139879088207680 _metadata.py:139] Compute Engine Metadata server unavailable on attempt 1 of 3. Reason: timed out
W0216 10:26:39.447406 139879088207680 _metadata.py:139] Compute Engine Metadata server unavailable on attempt 2 of 3. Reason: [Errno 113] No route to host
W0216 10:26:42.451228 139879088207680 _metadata.py:139] Compute Engine Metadata server unavailable on attempt 3 of 3. Reason: timed out
W0216 10:26:42.451697 139879088207680 _default.py:338] Authentication failed using Compute Engine authentication due to unavailable metadata server.
W0216 10:26:42.530295 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 1 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f4c0>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.607035 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 2 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430efb0>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.686556 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 3 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f130>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.767113 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 4 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f160>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
W0216 10:26:42.851304 139879088207680 _metadata.py:208] Compute Engine Metadata server unavailable on attempt 5 of 5. Reason: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/service-accounts/default/?recursive=true (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7f372430f7f0>: Failed to resolve 'metadata.google.internal' ([Errno -2] Name or service not known)"))
completed 0
Traceback (most recent call last):
File "/home/jon/LWM/scripts/eval_needle.py", line 447, in <module>
run(main)
File "/home/jon/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/jon/miniconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/jon/LWM/scripts/eval_needle.py", line 444, in main
ht.start_test()
File "/home/jon/LWM/scripts/eval_needle.py", line 306, in start_test
self.run_test()
File "/home/jon/LWM/scripts/eval_needle.py", line 230, in run_test
full_contexts = self.read_context_files(FLAGS.n_rounds)
File "/home/jon/LWM/scripts/eval_needle.py", line 129, in read_context_files
text = json.loads(f.readline())['text']
KeyError: 'text'
That is, the script requires specific files that aren't shared, and it tries to contact Google services (the Compute Engine metadata server), which isn't explained.
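The `KeyError: 'text'` above is raised when a line of the haystack file parses as JSON but has no "text" key, which is the format the comment in the script asks for. A small pre-check such as the following sketch can flag offending lines before a long evaluation starts. The function name `find_bad_haystack_lines` and the sample records are hypothetical and used for illustration only.

```python
import json
import os
import tempfile


def find_bad_haystack_lines(path, key="text"):
    """Return 1-based numbers of lines that are not JSON objects with a string `key`."""
    bad = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                bad.append(lineno)
                continue
            if not isinstance(record, dict) or not isinstance(record.get(key), str):
                bad.append(lineno)
    return bad


# Hypothetical sample file: line 1 is valid, line 2 uses the wrong key.
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False,
                                 encoding="utf-8") as tmp:
    tmp.write(json.dumps({"text": "some haystack document"}) + "\n")
    tmp.write(json.dumps({"content": "wrong key"}) + "\n")
    sample_path = tmp.name

bad_lines = find_bad_haystack_lines(sample_path)
os.remove(sample_path)
print(bad_lines)
```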
I am trying to install on an Ubuntu 22.04 system with pip 24.0 and Python 3.11.
pip install -r requirements.txt
yields the error: Could not find a version that satisfies the requirement tensorflow==2.11.0.
The minimum version that shows up for me is 2.12.0rc0.
Hi, when I follow the script run_vision_chat.sh from #13 and remove --mesh_dim, I still get this error:
attn_output = ring_attention_sharded(
ValueError: shard_map applied to the function 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:
The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
args[0] of shape float32[1,2560,32,128], where args[0] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'q', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
args[1] of shape float32[1,2560,32,128], where args[1] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'k', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
args[2] of shape float32[1,2560,32,128], where args[2] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'v', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
args[3] of shape float32[1,1,2560,2560], where args[3] is bound to functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')'s parameter 'attn_mask', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, 'sp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7f9b40199ab0>, axis_name='sp')' appropriately.
Any suggestions for a solution?
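The error message states the rule directly: every array axis must be evenly divisible by the product of the mesh axes it is mapped to. With a mesh of shape (1, 8, 1, 1) the batch axis of size 1 is split over 'dp' and 'fsdp' (total size 8), which cannot work. The sketch below reproduces that check in plain Python so that a mesh layout can be tested before launching. The function `check_sharding` is hypothetical and does not call JAX, and the second mesh is only an example of a layout that passes the check, not a recommendation from the authors.

```python
from math import prod


def check_sharding(shape, spec, mesh):
    """Return (axis, axis_size, mesh_size) for every axis that fails divisibility.

    `spec` mirrors a PartitionSpec: each entry is None, a mesh axis name,
    or a tuple of mesh axis names. `mesh` maps axis names to sizes.
    """
    problems = []
    for axis, names in enumerate(spec):
        if names is None:
            continue  # this array axis is not sharded
        if isinstance(names, str):
            names = (names,)
        mesh_size = prod(mesh[name] for name in names)
        if shape[axis] % mesh_size != 0:
            problems.append((axis, shape[axis], mesh_size))
    return problems


# Shape and spec of the 'q' argument, copied from the error message.
q_shape = (1, 2560, 32, 128)
q_spec = (("dp", "fsdp"), "sp", "tp", None)

# Mesh from the error message: all 8 devices on 'fsdp'.
failing_mesh = {"dp": 1, "fsdp": 8, "tp": 1, "sp": 1}
# Example alternative: all 8 devices on 'tp', which divides the 32 heads.
passing_mesh = {"dp": 1, "fsdp": 1, "tp": 8, "sp": 1}

failing = check_sharding(q_shape, q_spec, failing_mesh)
passing = check_sharding(q_shape, q_spec, passing_mesh)
print(failing)  # [(0, 1, 8)]
print(passing)  # []
```

Passing this check only means the shapes divide evenly; whether a given layout fits in device memory is a separate question.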
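What do the parameters in this file mean? Could you add some comments? If I want to generate higher-resolution videos or longer videos, which parameters do I need to change? Also, are there any limits on video size or resolution?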
Do you have plans for this?
Thank you for publishing such impressive work! Could you also release the LWM-1K/8K versions of LWM?
I am trying to run the run_sample_video.sh file from the scripts folder.
I am running into a lot of dependency issues when running this on an M1 Mac.
Has anyone been successful in running it on M1?
See this Colab notebook: https://gist.github.com/Mistobaan/605e212f951c5ae82ea420765fce381b
What am I doing wrong? The scripts might need better naming to match the released model weights.
Thank you again for this work!