gnobitab / rectifiedflow
Official Implementation of Rectified Flow (ICLR2023 Spotlight)
Thanks for your great work!
If I want to train on LSUN Bedroom and obtain a checkpoint like the LSUN Bedroom one in the RectifiedFlow pre-trained checkpoints, how long will it take?
I am training on 8×A100, and going from step 100 to step 1400 takes 22 minutes. In configs/default_lsun_configs.py, training.n_iters = 2400001, so the full run would take roughly 680 hours (about 28 days). Is about 28 days on 8×A100 similar to your training cost?
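The estimate above is a linear extrapolation from the measured step interval. A minimal sketch of that arithmetic, assuming constant throughput for the whole run (no slowdowns from evaluation or checkpointing); the function name is hypothetical:

```python
def estimate_training_time(step_start, step_end, minutes_elapsed, total_iters):
    """Extrapolate total wall-clock time from a measured interval of steps.

    Assumes every step takes the same time as the measured ones.
    Returns (hours, days).
    """
    minutes_per_step = minutes_elapsed / (step_end - step_start)
    total_minutes = minutes_per_step * total_iters
    return total_minutes / 60, total_minutes / (60 * 24)


# 1300 steps in 22 minutes, extrapolated to training.n_iters = 2400001
hours, days = estimate_training_time(100, 1400, 22, 2_400_001)
print(f"{hours:.0f} hours, {days:.1f} days")
```

With the numbers from this issue it gives about 677 hours, i.e. a little over 28 days.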
Hi~
Excellent work!!
I'd like to ask when you will release the code for Image-to-Image Translation. I'm looking forward to it.
Thanks!!
best,
I ran into a situation where the DNN library is not found, and the system also cannot find the GPU.
Try pip install jax==0.3.25 jaxlib==0.3.25 if you have issues with the default installation.
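After reinstalling, one quick way to see whether JAX actually picked up the GPU is to list the platforms of `jax.devices()`. A minimal diagnostic sketch (the helper name is hypothetical); if only `cpu` is reported, the installed jaxlib build most likely does not match the local CUDA/cuDNN setup:

```python
import importlib.util


def jax_gpu_status() -> str:
    """Report whether JAX is installed and which device platforms it sees."""
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax  # imported lazily so the check works without JAX present

    # Each device reports a platform string such as "cpu" or "gpu".
    platforms = sorted({d.platform for d in jax.devices()})
    return "jax backends: " + ", ".join(platforms)


print(jax_gpu_status())
```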
Hi. Do you have an env.yaml file for us to reproduce your environment? I tried to install the required packages using the requirements.txt provided in this repo, but ran into many errors. A yaml file for creating a new conda env would be much more straightforward for anyone who wants to run your code :) Any other alternative would also be great, as long as it makes reproducing the environment easier!
Can anyone explain this line of code:
RectifiedFlow/ImageGeneration/sampling.py
Line 98 in 5a1fd4d
Excellent work. But I'm confused about why the Eq. (1) loss makes the flows avoid crossing. I couldn't find an explanation in the paper. Thanks.
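A sketch of the usual argument, assuming Eq. (1) is the least-squares objective on the linear interpolation X_t = tX_1 + (1-t)X_0. The minimizer is a conditional expectation, so it is a single-valued function of (x, t). The straight interpolation lines may cross, but the ODE driven by v* cannot have two trajectories through the same point at the same time, provided its solutions are unique (e.g. v* Lipschitz):

```latex
\min_v \int_0^1 \mathbb{E}\big[\,\|(X_1 - X_0) - v(X_t, t)\|^2\,\big]\,dt,
\qquad X_t = t X_1 + (1-t) X_0

v^*(x, t) = \mathbb{E}\big[\,X_1 - X_0 \mid X_t = x\,\big]

\frac{dZ_t}{dt} = v^*(Z_t, t), \qquad Z_0 \sim X_0
```

If two trajectories met at a point x at time t, both would have velocity v*(x, t) there, and uniqueness of ODE solutions would force them to be the same trajectory. Crossing interpolation lines are effectively "rewired" at the crossing point by the averaging in v*.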
finished
Hi, thanks for your great work. I especially appreciate your intuitive blog post. However, I just want to let you know that there's a tiny bug in the attached colab example, Tutorial: Rectified Flow with Neural Network.ipynb.
The selected variable should be diffusion; otherwise it overwrites rectified_flow_1 in the subsequent blocks.
Are training.continuous and get_sde_loss_fn not implemented?
Hello, I followed the commands to install the dependencies, but I have a problem running this code. How can I solve the following issue?
2021-06-02 10:02:23.367252: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at cudnn_rnn_ops.cc:1553 : Unknown: Fail to find the dnn implementation.
2021-06-02 10:02:23.369234: E tensorflow/stream_executor/cuda/cuda_dnn.cc:352] Loaded runtime CuDNN library: 8.0.5 but source was compiled with: 8.1.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
2021-06-02 10:02:23.370402: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at cudnn_rnn_ops.cc:1553 : Unknown: Fail to find the dnn implementation.
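The log itself states the compatibility rule: same cuDNN major version, and a runtime minor version equal to or higher than the one TensorFlow was compiled with. A small sketch encoding that rule (hypothetical helper name) shows why 8.0.5 fails against 8.1.0, so the fix is a cuDNN 8.x runtime of at least 8.1, or a TensorFlow build compiled against the installed cuDNN:

```python
def cudnn_compatible(runtime: str, compiled: str) -> bool:
    """Apply the rule quoted in the TensorFlow log.

    The runtime cuDNN must have the same major version as the build and an
    equal or higher minor version. Patch versions are ignored.
    """
    r_major, r_minor = (int(p) for p in runtime.split(".")[:2])
    c_major, c_minor = (int(p) for p in compiled.split(".")[:2])
    return r_major == c_major and r_minor >= c_minor


print(cudnn_compatible("8.0.5", "8.1.0"))  # the versions from the log above
```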
I tried to fine-tune a pretrained model on a custom dataset,
but a dataset_stats.npz file like assets/stats/celeba_stats.npz is required,
and there is no explanation of how to generate dataset_stats.npz.
So how can I fine-tune a pretrained model on a custom dataset?
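A minimal sketch of producing such a file, under two assumptions not confirmed by this thread: that the stats file stores Inception pool_3 activations under the key `pool_3` (as in score_sde-style evaluation code), and that you have already extracted those activations for your dataset with the same Inception network used at evaluation time. Both function names are hypothetical:

```python
import numpy as np


def save_dataset_stats(pool3_features, path):
    """Save Inception pool_3 activations as an npz with a 'pool_3' key.

    pool3_features: array of shape [num_images, feature_dim]
    (feature_dim is 2048 for Inception pool_3). The key name is an assumption.
    """
    feats = np.asarray(pool3_features, dtype=np.float64)
    if feats.ndim != 2:
        raise ValueError("expected a 2-D array [num_images, feature_dim]")
    np.savez_compressed(path, pool_3=feats)


def fid_moments(path):
    """Load the saved activations and return the mean/covariance FID uses."""
    with np.load(path) as data:
        feats = data["pool_3"]
    return feats.mean(axis=0), np.cov(feats, rowvar=False)
```

Storing the raw activations rather than only their moments keeps the file usable by code that computes FID directly from activations.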
https://github.com/gnobitab/RectifiedFlow/blob/cb2fc1906e80519b1bcd08e4c864519d52b75459/ImageGeneration/run_lib_reflow.py#L102C3-L116C69
No class except rectified flow is found in sde_lib.py.
I0619 12:16:40.841244 140569191580864 resolver.py:419] Downloading TF-Hub Module 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'.
Traceback (most recent call last):
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/urllib/request.py", line 1354, in do_open
h.request(req.get_method(), req.selector, req.data, headers,
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1256, in request
self._send_request(method, url, body, headers, encode_chunked)
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1302, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1251, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1011, in _send_output
self.send(msg)
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 951, in send
self.connect()
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 1418, in connect
super().connect()
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/http/client.py", line 922, in connect
self.sock = self._create_connection(
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/socket.py", line 808, in create_connection
raise err
File "/home/scholar/anaconda3/envs/rectflow/lib/python3.8/socket.py", line 796, in create_connection
sock.connect(sa)
TimeoutError: [Errno 110] Connection timed out
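If the evaluation machine cannot reach tfhub.dev, one workaround is to download the Inception module elsewhere and pre-populate the cache. tfhub.dev serves a module as a .tar.gz when `?tf-hub-format=compressed` is appended, and the cache location can be set with the `TFHUB_CACHE_DIR` environment variable. The sketch below assumes that tensorflow_hub's resolver names each cached module directory after the SHA-1 hex digest of its handle. Check this against your installed version. tfhub.dev URLs may also now redirect to Kaggle Models.

```python
import hashlib
import os

INCEPTION_HANDLE = "https://tfhub.dev/tensorflow/tfgan/eval/inception/1"


def compressed_download_url(handle: str) -> str:
    """URL that serves the module as a .tar.gz archive."""
    return handle + "?tf-hub-format=compressed"


def cache_dir_for(handle: str, cache_root: str) -> str:
    """Directory the resolver is ASSUMED to look in for this handle."""
    return os.path.join(cache_root, hashlib.sha1(handle.encode("utf8")).hexdigest())


print(compressed_download_url(INCEPTION_HANDLE))
print(cache_dir_for(INCEPTION_HANDLE, "/tmp/tfhub_modules"))
```

Extract the downloaded archive into the printed directory on the offline machine, and set `TFHUB_CACHE_DIR` to the same root before running the evaluation.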
The default number of iterations is 1300k, which would take 5 days on a single A100...
Does training really need 5 days?
Traceback (most recent call last):
File "E:\Googledownload\RectifiedFlow-main\ImageGeneration\main.py", line 18, in <module>
import run_lib
File "E:\Googledownload\RectifiedFlow-main\ImageGeneration\run_lib.py", line 25, in <module>
import tensorflow as tf
File "D:\conda\lib\site-packages\tensorflow\__init__.py", line 37, in <module>
from tensorflow.python.tools import module_util as _module_util
File "D:\conda\lib\site-packages\tensorflow\python\__init__.py", line 37, in <module>
from tensorflow.python.eager import context
File "D:\conda\lib\site-packages\tensorflow\python\eager\context.py", line 28, in <module>
from tensorflow.core.framework import function_pb2
File "D:\conda\lib\site-packages\tensorflow\core\framework\function_pb2.py", line 16, in <module>
from tensorflow.core.framework import attr_value_pb2 as tensorflow_dot_core_dot_framework_dot_attr__value__pb2
File "D:\conda\lib\site-packages\tensorflow\core\framework\attr_value_pb2.py", line 16, in <module>
from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
File "D:\conda\lib\site-packages\tensorflow\core\framework\tensor_pb2.py", line 16, in <module>
from tensorflow.core.framework import resource_handle_pb2 as tensorflow_dot_core_dot_framework_dot_resource__handle__pb2
File "D:\conda\lib\site-packages\tensorflow\core\framework\resource_handle_pb2.py", line 16, in <module>
from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2
File "D:\conda\lib\site-packages\tensorflow\core\framework\tensor_shape_pb2.py", line 36, in <module>
_descriptor.FieldDescriptor(
File "D:\conda\lib\site-packages\google\protobuf\descriptor.py", line 561, in __new__
_message.Message._CheckCalledFromGeneratedFile()
TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
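This error is the known incompatibility between older generated `_pb2.py` files and protobuf 4.x. The two workarounds usually suggested are downgrading protobuf to 3.20.x (e.g. `pip install "protobuf==3.20.3"`) or forcing protobuf's pure-Python implementation. A minimal sketch of the second option; it only takes effect if it runs before TensorFlow (or anything else importing protobuf) is imported, and pure-Python parsing is noticeably slower:

```python
import os


def use_pure_python_protobuf() -> None:
    """Force protobuf's pure-Python backend (must run before importing tensorflow)."""
    os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"


use_pure_python_protobuf()
# import tensorflow as tf  # import only after the variable is set
```

Exporting the same variable in the shell before running `python main.py ...` achieves the same effect without touching the code.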
Hi, I ran the inference example python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/1_rectified_flow --config.eval.enable_sampling --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 8
and got an OOM error,
even after setting batch size = 1 and num_samples = 1.
My GPU has 24576 MiB.
Is there any way to avoid the OOM?
I0220 07:32:34.000750 139647609180864 resolver.py:106] Using /tmp/tfhub_modules to cache modules.
I0220 07:32:36.494689 139647609180864 run_lib.py:273] begin checkpoint: 8
Traceback (most recent call last):
File "main.py", line 74, in <module>
app.run(main)
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "main.py", line 66, in main
run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)
File "/root/RectifiedFlow/ImageGeneration/run_lib.py", line 286, in evaluate
state = restore_checkpoint(ckpt_path, state, device=config.device)
File "/root/RectifiedFlow/ImageGeneration/utils.py", line 14, in restore_checkpoint
loaded_state = torch.load(ckpt_dir, map_location=device)
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 712, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 1049, in _load
result = unpickler.load()
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 1019, in persistent_load
load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 1001, in load_tensor
wrap_storage=restore_location(storage, location),
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 973, in restore_location
return default_restore_location(storage, str(map_location))
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
result = fn(storage, location)
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/serialization.py", line 157, in _cuda_deserialize
return obj.cuda(device)
File "/opt/anaconda3/envs/sde/lib/python3.7/site-packages/torch/_utils.py", line 78, in _cuda
return torch._UntypedStorage(self.size(), device=torch.device('cuda')).copy_(self, non_blocking)
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.70 GiB total capacity; 918.46 MiB already allocated; 13.56 MiB free; 968.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
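The numbers in this message suggest the batch size is not the problem: PyTorch has reserved under 1 GiB, yet almost nothing is free. A small sketch that parses the message and computes how much GPU memory is held outside PyTorch's allocator (the regex assumes exactly the units shown in this log). In this codebase a likely holder is another framework in the same process. TensorFlow, for example, maps most GPU memory at startup by default. If that is the cause, setting `TF_FORCE_GPU_ALLOW_GROWTH=true` or calling `tf.config.set_visible_devices([], "GPU")` before TensorFlow touches the GPU may help. Treat this as a hypothesis to verify with `nvidia-smi`.

```python
import re

# Matches the fields of the OOM message above; units are assumed to be
# GiB for total capacity and MiB for the other three values.
_OOM_RE = re.compile(
    r"([\d.]+) GiB total capacity; ([\d.]+) MiB already allocated; "
    r"([\d.]+) MiB free; ([\d.]+) MiB reserved"
)


def memory_held_outside_pytorch(msg: str) -> float:
    """Return MiB in use on the GPU that PyTorch's caching allocator did not reserve."""
    m = _OOM_RE.search(msg)
    if m is None:
        raise ValueError("message does not match the expected OOM format")
    total_gib, _allocated_mib, free_mib, reserved_mib = map(float, m.groups())
    return total_gib * 1024 - reserved_mib - free_mib
```

For the log above this gives about 23,287 MiB (roughly 22.7 GiB) held by something other than PyTorch.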
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/1_rectified_flow
It seems that the training process needs 600k iterations.
The memory usage of each GPU also seems fairly low during training (4.3 GB of 24 GB on an RTX 3090).
Is there any way to use more memory and thereby speed up training?
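One way to use the spare memory is a larger per-step batch. This could be overridden on the command line in the same style as `--config.eval.batch_size`, e.g. `--config.training.batch_size`, assuming the config exposes that field. A rough heuristic for the matching iteration budget is to keep the total number of training samples constant. This is a sketch, not a guarantee: the learning rate may need retuning, and sample quality at a shorter schedule is not assured. The helper name and the batch sizes in the example are hypothetical.

```python
def iters_for_same_samples(base_iters: int, base_batch: int, new_batch: int) -> int:
    """Iterations needed so that new_batch * iters >= base_batch * base_iters."""
    total_samples = base_iters * base_batch
    return -(-total_samples // new_batch)  # ceiling division on integers


# e.g. 600k iterations at a hypothetical batch of 128, moved to a batch of 512
print(iters_for_same_samples(600_000, 128, 512))
```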
Hi @gnobitab, I implemented the feature loss myself, but it did not work properly.
Could you comment on my pseudocode?
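import torch
import torch.nn.functional as F

def get_feature_weight(feature_extractor, S):
    # Jacobian of the batch-summed features w.r.t. the input S.
    # Summing over the batch is only valid if feature_extractor treats samples
    # independently (e.g. no BatchNorm in training mode).
    def _feature_func(x):
        feature = feature_extractor(x)  # shape [batch_size, feature_dim, H', W']
        return feature.sum(dim=(0, 2, 3))  # shape [feature_dim]
    S = S.detach()  # avoid modifying the caller's tensor or its graph
    w = torch.autograd.functional.jacobian(_feature_func, S)  # [feature_dim, batch_size, dim, H, W]
    return w.transpose(0, 1).detach()  # [batch_size, feature_dim, dim, H, W]

def feature_loss(feature_extractor, z_t, target, pred):
    w = get_feature_weight(feature_extractor, z_t)
    # Contracts over channels only; 'bdchw,bchw->bd' would give the full
    # Jacobian-vector product (first-order change of each feature).
    w_target = torch.einsum('bdchw,bchw->bdhw', w, target)
    w_pred = torch.einsum('bdchw,bchw->bdhw', w, pred)
    # Bug fix: the original compared raw target/pred, so w was never used.
    return F.mse_loss(w_pred, w_target)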
RectifiedFlow/ImageGeneration/losses.py
Line 101 in 9df93aa
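During sampling, x_t = (1-t)x + ty is used, with t in the range [1e-3, 1]. During training, t = 1e-3 corresponds to x_0 * (1-eps) + y*eps, but at inference t also starts at 1e-3, and at that point there is no y available to compute x_t. What is the reason for this design?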
Any suggestions?
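A worked step that may help frame the question (a sketch using the notation above, with x the noise sample and y the data, not the authors' stated rationale). Starting inference at t = ε from x instead of the training-time x_ε introduces an initial mismatch that is proportional to ε:

```latex
x_t = (1-t)\,x + t\,y
\quad\Longrightarrow\quad
x_\varepsilon - x = \varepsilon\,(y - x),
\qquad
\|x_\varepsilon - x\| = \varepsilon\,\|y - x\|
```

With ε = 10^{-3}, the offset is a thousandth of the noise-to-data distance, so x is a close approximation of x_ε even though y is unknown at inference time.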