atong01 / conditional-flow-matching
TorchCFM: a Conditional Flow Matching library
Home Page: https://arxiv.org/abs/2302.00482
License: MIT License
Hi, I noticed that you have provided a demo on unconditional generation on the MNIST dataset. Do you plan to provide some demos on conditional generation? That would be of great help to me.
Thanks again for your valuable work and codebase!
Dear experts,
Thank you very much for this nice package! I am trying to use your batch OT version of flow matching to perform the matching between two distributions, but I wonder how the OT mapping should work for conditional distributions.
Suppose I have a function y(x|c) and I want to transport it to z(x|c). When I sample n_events from these distributions, they will have different c values among themselves. How can one perform OT mapping between "events" with different condition values?
I thought of something like this: we split the data into bins of c [0, 0.1, 0.2, ...] where we don't expect the PDFs to change much, and during training, sample only events of y(x|c) and z(x|c) where the events of both are in the same bin.
Does that sound reasonable to you? Or is there a clever and simpler way to do it?
Best,
Caio
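The binning idea can be sketched in a few lines: assign each event to a bin of c, then solve an exact OT pairing separately inside each bin, so that only events with similar condition values are ever matched. This is a minimal sketch under stated assumptions, not the library's method: it assumes squared Euclidean cost, uniform weights, and that leftover events in a bin with unequal counts are simply dropped; scipy's linear_sum_assignment stands in for a POT solver, and binned_ot_pairs is a hypothetical helper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def binned_ot_pairs(x0, c0, x1, c1, bin_edges, rng):
    """Exact OT pairing of source/target events, restricted to events whose c lies in the same bin."""
    bins0 = np.digitize(c0, bin_edges)
    bins1 = np.digitize(c1, bin_edges)
    src, tgt = [], []
    for b in np.intersect1d(bins0, bins1):
        idx0 = rng.permutation(np.flatnonzero(bins0 == b))
        idx1 = rng.permutation(np.flatnonzero(bins1 == b))
        k = min(len(idx0), len(idx1))  # equal counts, so the optimal plan is a permutation
        idx0, idx1 = idx0[:k], idx1[:k]
        # squared Euclidean cost between the events of this bin only
        cost = ((x0[idx0][:, None, :] - x1[idx1][None, :, :]) ** 2).sum(-1)
        rows, cols = linear_sum_assignment(cost)
        src.append(idx0[rows])
        tgt.append(idx1[cols])
    if not src:
        return np.array([], dtype=int), np.array([], dtype=int)
    return np.concatenate(src), np.concatenate(tgt)


rng = np.random.default_rng(0)
n = 200
c0, c1 = rng.uniform(size=n), rng.uniform(size=n)  # condition values of the two samples
x0 = rng.normal(size=(n, 2))  # events drawn from y(x|c)
x1 = rng.normal(size=(n, 2)) + 3.0  # events drawn from z(x|c)
edges = np.linspace(0.1, 0.9, 9)  # bins [0, 0.1), [0.1, 0.2), ..., [0.9, 1]
src, tgt = binned_ot_pairs(x0, c0, x1, c1, edges, rng)
```

The pairs (x0[src], x1[tgt]) can then be fed to the flow matcher as already-coupled samples. The bin width trades off how well the OT coupling is approximated against how much c varies inside a pair.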
Hi again! I've been playing around with fine-tuning SD using your codebase. It has worked pretty well and seems to converge quickly compared to the eps/v objectives. However, I'm seeing noise appear in the final steps of sampling (around t = 0.6 to t = 1.0), although the samples look clean before then. This is pretty surprising, because it should be obvious to the model that there shouldn't be any noise there.
It doesn't happen every time but maybe 20% of the time.
I've tried training with sigma=0.0 and sigma=0.1. The std of my data is around 0.5.
Any thoughts or tips?
Possibly related:
I noticed that this simple implementation:
https://gist.github.com/francois-rozet/fd6a820e052157f8ac6e2aa39e16c1aa
uses this setup instead:
y = (1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z
u = (1 - 1e-4) * z - x
I thought this might be related, since I know that the "eps" objective and the "predict the denoised image" objective each work badly at opposite ends of the timestep range.
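The gist's target u is exactly the time derivative of its interpolant y, which a finite-difference check makes concrete. In this sketch we read x as the data and z as Gaussian noise (our assumption about the gist's naming), so t = 0 is the data end and t = 1 the noise end; SIGMA_MIN = 1e-4 is the constant from the gist. A nonzero SIGMA_MIN leaves a small noise floor at the data end: y at t = 0 equals x + SIGMA_MIN * z rather than x.

```python
import numpy as np

SIGMA_MIN = 1e-4  # constant used in the linked gist


def interpolant(x, z, t, sigma_min=SIGMA_MIN):
    """y_t = (1 - t) x + (sigma_min + (1 - sigma_min) t) z, reading x as data and z as noise."""
    return (1 - t) * x + (sigma_min + (1 - sigma_min) * t) * z


def target_velocity(x, z, sigma_min=SIGMA_MIN):
    """u = d y_t / dt = (1 - sigma_min) z - x; note that it does not depend on t."""
    return (1 - sigma_min) * z - x


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
z = rng.normal(size=(4, 3))
t, h = 0.7, 1e-6
# central finite difference of the interpolant with respect to t
finite_diff = (interpolant(x, z, t + h) - interpolant(x, z, t - h)) / (2 * h)
```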
Hi,
First of all thanks for this amazing repository!
In conditional_mnist.ipynb, the UNetModel is called like this:
model = UNetModel(
    dim=(1, 28, 28), num_channels=32, num_res_blocks=1, num_classes=10, class_cond=True
).to(device)
The above code works fine, but I do not fully understand why, since UNetModel itself does not take these arguments; the wrapper UNetModelWrapper does.
Massive appreciation for the remarkable work on this repository!!
I would like to know how to reproduce the results in Table 5 of your paper, Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport.
I am a bit confused, as the README seems to need updating to be consistent with the v1 codebase.
Thank you in advance!
This issue is opened for users to suggest improvements for the minibatch OT tutorial notebook.
See scverse/scanpy#2411. We recommend downgrading to Matplotlib==3.6.3 as a hotfix.
Hello and thanks for this super interesting work!
I'm curious to try this with custom, image-like data, and just wondering what steps are required to get something going.
Any guidance would be awesome.
Dear all,
Thanks for the great package!
I am training with suggested command:
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
But by a rough estimate, the whole training run will take about 60 hours (roughly 2.5 days), based on this log:
lr, total_steps, ema decay, save_step: 0.0002 400001 0.9999 20000
Files already downloaded and verified
Model params: 34.09 M
0%|▍ | 1255/400001 [11:11<60:19:15, 1.84it/s]
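The estimate follows directly from the throughput in the log: total steps divided by iterations per second. A quick check using the numbers above:

```python
total_steps = 400_001
done_steps = 1_255
its_per_sec = 1.84  # throughput reported by the progress bar

total_hours = total_steps / its_per_sec / 3600
remaining_hours = (total_steps - done_steps) / its_per_sec / 3600
print(f"total: {total_hours:.1f} h, remaining: {remaining_hours:.1f} h")
# prints "total: 60.4 h, remaining: 60.2 h"
```

This is consistent with tqdm's remaining-time estimate of about 60 hours; tqdm uses a smoothed rate, so its figure differs slightly.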
I am using a 32GB V100. I have two questions about training speed:
This issue is opened for users to suggest improvements for the Flow Matching tutorial notebook.
Thanks very much for this code base, it's been a great way to learn about flow matching. I have a question regarding conditional generation with OT-CFM.
When testing different FM approaches on my own data, I noticed that OT-CFM trains significantly slower and tends to perform much worse on tasks with conditioning. In an effort to isolate this problem I tried conditional MNIST, comparing OT-CFM with FM (using the example provided).
After a single epoch of training, I visualized the generations of both approaches with 1 step and dopri5. FM is on the left, OT-CFM is on the right.
One-step generation (Euler with 1 step):
Adaptive generation with dopri5:
After one epoch of training, FM has much nicer generations for both 1 sampling step and with dopri5. Even after a longer training time, FM continues to outperform OT-CFM (converges much faster).
After reading more, I noticed that both OT-CFM and Multisample Flow Matching papers only report results for unconditional generation, while papers doing conditional generation such as Stable Diffusion 3 and Flow Matching in Latent Space seem to use standard flow matching without batch optimal transport.
I wonder if the authors have studied this, and if there are any results for OT-CFM conditional tasks, or perhaps if there is a reason or explanation that OT-CFM should not work in this setting. My intuition was that adding conditioning makes the combinatorial space of the OT plan extremely hard to approximate from the limited samples in the batch, and this would be further exaggerated if the conditioning is not on simple class labels but rather continuous values (for example language embeddings for text to image generation etc).
I would greatly appreciate any insight on this, and if there is an approach that is applicable to conditional generation. Thank you!
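The code tweaks for this were:
from torchcfm.conditional_flow_matching import (
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
)

sigma = 0.0
if args.fm_method == "fm":
    FM = TargetConditionalFlowMatcher(sigma=sigma)
elif args.fm_method == "otcfm":
    FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

# Inside the training loop:
if args.fm_method == "fm":
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
    y1 = y
elif args.fm_method == "otcfm":
    # the OT plan resamples (x0, x1) pairs, so the labels are returned alongside x1
    t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)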
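For intuition about what the guided sampler has to do, the following sketch computes an exact minibatch OT pairing and lets each label travel with its x1 sample. It swaps in scipy's linear_sum_assignment for the POT solver the library uses and assumes uniform weights and equal batch sizes, in which case an exact OT plan can be taken to be a permutation; ot_pair_with_labels is a hypothetical helper. Note that the coupling is computed from the positions alone: the labels play no role in choosing the pairs.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def ot_pair_with_labels(x0, x1, y1):
    """Exact minibatch OT pairing (uniform weights, equal batch sizes), carrying labels with x1."""
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
    rows, cols = linear_sum_assignment(cost)  # optimal permutation
    # For a square cost matrix, rows == arange(n), so x0 keeps its order
    return x0[rows], x1[cols], y1[cols]


rng = np.random.default_rng(0)
x0 = rng.normal(size=(16, 2))  # noise
x1 = rng.normal(size=(16, 2)) + 3.0  # data
y1 = np.arange(16)  # one label per data point
x0_p, x1_p, y1_p = ot_pair_with_labels(x0, x1, y1)
```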
Tests are failing because POT changed the implementation of sinkhorn_unbalanced. The new output is numerically different but close enough for me:
array([[0.51122814, 0.18807032],
[0.18807032, 0.51122814]])
The parameter num_images is currently unused in the cifar10 example.
It would be great to have a generate_cifar10.py script to generate a given number of images, as in the EDM code: https://github.com/NVlabs/edm
Hi, thanks for this very interesting codebase!
I'd like to reproduce the CelebA Image-to-Image example presented at the end of:
https://arxiv.org/pdf/2302.00482.pdf
Are you planning to add example code for this or could you point me in the right direction?
Best regards!
Dear all,
Thanks for the great package!
I am writing to seek guidance on a question. With discrete flows, during training we learn a transformation from data space to, say, a Gaussian space (…).
In the case of this package, it's clear how we can use the models to start from (Gaussian) noise, use it as the initial condition of the ODE and solve with torchdiffeq or torchdyn. The question is: how can I use the model to compute the reverse trajectory, i.e. go from new data back to the (Gaussian) noise space? How do we compute (…)?
Do I have to reverse the time steps order (from 1 to 0) and give the data as the initial conditions of the ODE?
I would be really grateful for any guidance or code example you could provide.
Best regards,
Francesco
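Integrating backwards is the usual way to do this: take the data as the initial condition and integrate the learned vector field from t = 1 to t = 0 (in the examples here, x0 = torch.randn_like(x1) is the source at t = 0 and x1 the data at t = 1). torchdiffeq's odeint accepts a strictly decreasing time grid such as torch.linspace(1, 0, n). The sketch below illustrates the idea with a hand-written RK4 integrator in NumPy and a hypothetical toy_vector_field standing in for a trained model; the round trip noise -> data -> noise recovers the starting point up to integration error.

```python
import numpy as np


def rk4_integrate(field, x, t_grid):
    """Classic fixed-step RK4 for dx/dt = field(t, x); t_grid may increase or decrease."""
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = t1 - t0  # negative when integrating backwards in time
        k1 = field(t0, x)
        k2 = field(t0 + h / 2, x + h / 2 * k1)
        k3 = field(t0 + h / 2, x + h / 2 * k2)
        k4 = field(t1, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


def toy_vector_field(t, x):
    # hypothetical stand-in for a trained model v_theta(t, x)
    return np.sin(3 * t) - 0.5 * x


rng = np.random.default_rng(0)
noise = rng.normal(size=(5, 2))
data = rk4_integrate(toy_vector_field, noise, np.linspace(0.0, 1.0, 101))  # noise -> data
recovered = rk4_integrate(toy_vector_field, data, np.linspace(1.0, 0.0, 101))  # data -> noise
```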
Running
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
fails with:
XXX/python3.10/zipfile.py", line 1336, in _RealGetContents
raise BadZipFile("File is not a zip file")
zipfile.BadZipFile: File is not a zip file
Waiting on DiffEqML/torchdyn#180
Thank you for your GREAT work and for providing the reproduction scripts. However, in my reproduction the FID on CIFAR10 is much worse than the 11 reported in the paper (https://arxiv.org/pdf/2302.00482.pdf): I got 30 (with TargetConditionalFlowMatcher). Any idea what could explain this result?
Hi, thanks for this great implementation!
I was wondering where the variance reduction method described in Appendix C.1 is implemented in this code. Also, is there an easy way to reproduce the variance reduction results of Appendix D.1?
I am opening this issue as there are several warnings in the SDE class.
Should we remove torch.ones_like(t) in https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/utils.py#L94?
Sign problem in the drift function within the https://github.com/atong01/conditional-flow-matching/blob/main/examples/notebooks/mnist_example.ipynb notebook. I think it should be
return self.drift(t, y).flatten(start_dim=1) - self.score(t, y).flatten(start_dim=1)
Note that there is the same problem for https://github.com/atong01/conditional-flow-matching/blob/main/examples/notebooks/conditional_mnist.ipynb.
I would suggest updating these notebooks to use the SDE class from the utils folder, but only after the first issue has been fixed.
We recommend using PyTorch 1.13.1 with Torchdyn 1.0.3, the original setup we used to build the library.
We are working on updating the library to make it compatible with Torch 2.0; this should be done soon.
Hi!
I noticed that in several places in the code, e.g. in the conditional_mnist.ipynb tutorial, there is a slight CPU-GPU overhead. For example, the line
torch.randn(100, 1, 28, 28).to(device)
first creates the tensor on the CPU and then copies it to the GPU (when device is cuda:0).
It is possible to create the tensor directly on the specified device via
torch.randn(100, 1, 28, 28, device=device)
which avoids the intermediate CPU allocation and the host-to-device copy.
Running
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
fails with:
Traceback (most recent call last):
File "/t9k/mnt/pycharm_project/conditional-flow-matching-main/examples/cifar10/train_cifar10.py", line 165, in <module>
app.run(train)
File "/t9k/mnt/.conda/envs/torchcfm/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/t9k/mnt/.conda/envs/torchcfm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/t9k/mnt/pycharm_project/conditional-flow-matching-main/examples/cifar10/train_cifar10.py", line 139, in train
t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) # xt = self.sample_xt(x0, x1, t, eps)
File "/t9k/mnt/pycharm_project/conditional-flow-matching-main/torchcfm/conditional_flow_matching.py", line 185, in sample_location_and_conditional_flow
xt = self.sample_xt(x0, x1, t, eps)
File "/t9k/mnt/pycharm_project/conditional-flow-matching-main/torchcfm/conditional_flow_matching.py", line 123, in sample_xt
mu_t = self.compute_mu_t(x0, x1, t)
File "/t9k/mnt/pycharm_project/conditional-flow-matching-main/torchcfm/conditional_flow_matching.py", line 329, in compute_mu_t
return t * x1
RuntimeError: The size of tensor a (128) must match the size of tensor b (32) at non-singleton dimension 3
May I ask how to solve this bug?
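The traceback suggests that t has shape (128,) while x1 has shape (128, 3, 32, 32): broadcasting aligns t with the last dimension (size 32) instead of the batch dimension, hence the mismatch at dimension 3. A likely fix, which we have not checked against the installed torchcfm version, is to reshape t so that it broadcasts per sample. The NumPy sketch below uses a hypothetical helper expand_t_like_x; in PyTorch the equivalent reshape is t.view(-1, *([1] * (x.dim() - 1))).

```python
import numpy as np


def expand_t_like_x(t, x):
    """Reshape a per-sample time vector t of shape (B,) to (B, 1, ..., 1) so it broadcasts with x."""
    return t.reshape(-1, *([1] * (x.ndim - 1)))


rng = np.random.default_rng(0)
x1 = rng.normal(size=(128, 3, 32, 32))  # a CIFAR-10-sized batch
t = rng.uniform(size=128)

try:
    t * x1  # (128,) is aligned with the trailing dimension of size 32, as in the traceback
    naive_failed = False
except ValueError:
    naive_failed = True

mu_t = expand_t_like_x(t, x1) * x1  # mu_t = t * x1, computed per sample
```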
Thank you for your contribution. I am confused about SB-CFM: when I remove OT, the results in my task are better than with OT. As in the issue mentioned, my task is also conditional generation. Is the method from the original paper still valid after removing OT? I cannot follow the derivation in your paper, especially Equation 19. It seems to remain correct even if OT is removed, but I am not sure, so I am asking for your advice. Looking forward to your reply!
Thanks for this great work.
I have reproduced the claimed FID of 3.5 on CIFAR10. But the default FID computation is quite time-consuming (about 40 minutes), since it requires generating 50,000 images.
When I reduce the num_gen option to 5,000 in compute_fid.py, the resulting FID increases a lot, e.g. from 5.12 to 9.75.
Is this normal? According to FID's definition, FID should not be affected by num_gen.
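FID is computed by fitting Gaussians to finite feature sets, and this plug-in estimate is biased upwards, with a bias that shrinks as the number of samples grows; FID values computed with different num_gen are therefore not directly comparable. A toy check with two samples from the same 64-dimensional Gaussian (true distance 0) illustrates the effect. The dimension and sample sizes are arbitrary choices and the features are synthetic, not Inception features.

```python
import numpy as np
from scipy import linalg


def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two feature sets (the formula behind FID)."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    covmean = linalg.sqrtm(cov_a @ cov_b).real  # drop tiny imaginary parts from numerical error
    return float(((mu_a - mu_b) ** 2).sum() + np.trace(cov_a + cov_b - 2 * covmean))


rng = np.random.default_rng(0)
dim = 64
scores = {}
for n in (500, 5000):
    # both sets come from the same N(0, I) distribution, so the true distance is 0
    a = rng.normal(size=(n, dim))
    b = rng.normal(size=(n, dim))
    scores[n] = frechet_distance(a, b)
print(scores)
```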
Dear all,
Thanks for the great package!
I am working with some data
I implemented a solution for this use case. I would like to get your feedback on whether it is reasonable, or whether something has already been done to address this type of problem.
The main steps are as follows:
1. We define the base model to take an additional input (context), and we create a model wrapper which evolves the data with the base model and assigns 0s to the derivatives of the context (in this way it remains constant along the whole trajectory).
2. During training, the context is passed to the base model as an additional input in the forward call.
3. During sampling, the context is concatenated to the initial conditions for the model wrapper to work directly with ODE solvers such as torchdiffeq.
The code is as follows:
We define the base model to take an additional input (context
), and we create a model wrapper which evolves the data with the base model and assign 0s to the derivatives of the context (in this way it remains constant through all the trajectory).
import torch
import torch.nn as nn


class ModelWrapper(nn.Module):
    def __init__(self, base_model, context_dim=6):
        """
        Wraps a base model so that only the first part of the input is evolved, conditioned on a context.
        Args:
            base_model (nn.Module): The base model to wrap.
            context_dim (int): Number of trailing input features that hold the context.
        """
        super().__init__()
        self.base_model = base_model.eval()
        self.context_dim = context_dim

    def forward(self, t, x, **kwargs):
        """
        Forward pass of the wrapped model.
        Args:
            t (torch.Tensor): The time tensor passed by the ODE solver.
            x (torch.Tensor): The input tensor: concatenation of [actual input, context].
            **kwargs: Additional keyword arguments (unused).
        Returns:
            torch.Tensor: The time derivative of x, with zeros for the context part.
        """
        xt, context = x[:, : -self.context_dim], x[:, -self.context_dim :]
        t_broadcasted = t.expand(x.shape[0], 1)
        # Only evolve xt using the model (notice the additional input in the forward).
        dxt_dt = self.base_model(xt, context=context, flow_time=t_broadcasted)
        # Concatenate the derivatives of xt with zeros for the context to keep its values unchanged
        zeros_for_context = torch.zeros_like(context)
        dx_dt = torch.cat([dxt_dt, zeros_for_context], dim=-1)
        return dx_dt
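The augmented-state trick can be checked in isolation. The NumPy sketch below builds the state [x, context], returns zero derivatives for the context, and confirms that the context is unchanged after integration while x evolves. The hypothetical toy_base_field and the fixed-step Euler integrator stand in for the trained model and torchdiffeq; they are illustrations, not part of the code above.

```python
import numpy as np


def make_augmented_field(base_field, context_dim):
    """Vector field on [x, context] that evolves x with base_field and keeps the context fixed."""
    def field(t, state):
        x, context = state[:, :-context_dim], state[:, -context_dim:]
        dx_dt = base_field(t, x, context)
        return np.concatenate([dx_dt, np.zeros_like(context)], axis=-1)
    return field


def toy_base_field(t, x, context):
    # hypothetical conditional vector field: relaxes x toward its context
    return context - x


def euler(field, state, n_steps=100):
    """Fixed-step explicit Euler from t = 0 to t = 1."""
    dt = 1.0 / n_steps
    for k in range(n_steps):
        state = state + dt * field(k * dt, state)
    return state


rng = np.random.default_rng(0)
x0 = rng.normal(size=(5, 3))
context = rng.normal(size=(5, 3))
state0 = np.concatenate([x0, context], axis=-1)
state1 = euler(make_augmented_field(toy_base_field, context_dim=3), state0)
```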
Then in the training loop we do something like:
for i in range(0, len(X_train), batch_size):
    X_batch = X_train[i : i + batch_size]
    Y_batch = Y_train[i : i + batch_size]  # NOTE: this is the context
    optimizer.zero_grad()
    x0 = noise_dist(X_batch.shape[0], X_batch.shape[1]).to(device)
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, X_batch)
    vt = model(xt, context=Y_batch, flow_time=t[:, None])
    loss = torch.mean((vt - ut) ** 2)
    train_loss += loss.item()
    loss.backward()
    optimizer.step()
    # Update the progress bar
    pbar.update(1)
    pbar.set_postfix({"Batch Loss": loss.item()})
While for sampling new data from noise:
print("Starting sampling")
model.eval()
samples_list = []
# NOTE the call to model wrapper
sampler = ModelWrapper(model, context_dim=context_dim)
t_span = torch.linspace(0, 1, timesteps).to(device)
with torch.no_grad():
    with tqdm(
        total=len(X_test) // test_batch_size,
        desc="Sampling",
        dynamic_ncols=True,
    ) as pbar:
        for i in range(0, len(X_test), test_batch_size):
            Y_batch = Y_test[i : i + test_batch_size, :]
            # protection against underflows in torchdiffeq solver
            while True:
                try:
                    x0_sample = noise_dist(len(Y_batch), X_test.shape[1]).to(device)
                    # NOTE the context is concatenated to the initial conditions
                    # for the wrapper to work
                    initial_conditions = torch.cat([x0_sample, Y_batch], dim=-1)
                    # NOTE we take only the last timestep
                    samples = odeint(
                        sampler,
                        initial_conditions,
                        t_span,
                        atol=1e-5,
                        rtol=1e-5,
                        method="dopri5",
                    )[timesteps - 1, :, : X_test.shape[1]]
This approach works perfectly for our use case. Do you think it's reasonable and efficient enough?
I would appreciate any guidance whatsoever, and if the solution seems interesting, I would be more than happy to work on a pull request!
Best regards
Francesco
Do these both refer to the same thing? You cite Liu but there's no RectifiedFlowMatcher class, for example.
Hi,
I am trying to use your package with SuperResModel (from torchcfm.models.unet.unet import SuperResModel) and other custom models that take kwargs in their forward method, but I think the NeuralODE.trajectory method is not compatible with those models.
Could you please add a model_kwargs parameter to NeuralODE.trajectory, NeuralODE.forward, etc.?
Thanks!
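Until such a parameter exists, one possible workaround is to bind the extra keyword arguments ahead of time so that the solver only has to supply (t, x). The sketch uses a plain Python callable and a hypothetical toy_model; for torchdyn's NeuralODE the same wrapper would presumably need to be a torch.nn.Module with this logic in forward, which is an assumption we have not verified against torchdyn's call convention.

```python
import numpy as np


class KwargsBoundField:
    """Callable with the (t, x) signature ODE solvers use, forwarding fixed kwargs to a model."""

    def __init__(self, model, **model_kwargs):
        self.model = model
        self.model_kwargs = model_kwargs

    def __call__(self, t, x, *args, **kwargs):
        # Extra positional/keyword arguments a solver may pass are ignored here.
        return self.model(t, x, **self.model_kwargs)


def toy_model(t, x, scale=1.0, shift=0.0):
    # hypothetical model whose forward takes keyword arguments
    return scale * x + shift


field = KwargsBoundField(toy_model, scale=2.0, shift=1.0)
out = field(0.0, np.ones(3))  # equivalent to toy_model(0.0, x, scale=2.0, shift=1.0)
```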
Massive thanks in advance!!
Hi, and thanks for releasing the code for this fascinating work! I am getting my hands on this repo and found that I cannot get satisfactory results using the provided configs for sbcfm, i.e.
python src/train.py trainer=cpu model=sbcfm
The generated samples are just noise:
Meanwhile, it is able to generate good results using otcfm and cfm (for cfm it may be necessary to increase max_epoch and the hidden dimensions).
Would you give some suggestions on how to make this work? Thank you!
With POT 0.9.1, GPU memory is allocated on import when POT is used in combination with deep learning packages; see PythonOT/POT#523 and PythonOT/POT#516.
We recommend downgrading to POT < 0.9.1 for now.
I am opening this issue to discuss the road we should follow to the 1.1 version. Here is the list of the changes I would like to see:
Hello, thank you for your enlightening work. Some confusion arose while I was reading the code. I saw that the optimal transport between noise and samples is computed inside the train_step() function in PyTorch Lightning, but as far as I know, inside this function each GPU can only see its local data. This means that when training on more than one GPU, the OT plan is actually computed on only total_batch_size/num_node samples. Will this affect the performance of the algorithm, and is there any way to compute the optimal transport over all total_batch_size samples when using multiple GPUs?
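The effect of sharding can be quantified on a toy problem: per-GPU OT solves a restricted problem, and the concatenation of the per-shard pairings is itself a valid pairing of the full batch, so the global minibatch OT cost can never exceed the summed per-shard cost. The sketch uses scipy's exact assignment with squared Euclidean cost and a hypothetical 4-GPU setup. In a real DDP run, one option (an assumption, not something the repository is known to do) would be to gather x0 and x1 with torch.distributed.all_gather, compute the plan on the full batch, and keep each rank's share of the pairs.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def exact_ot_cost(x0, x1):
    """Total squared-Euclidean cost of the optimal one-to-one pairing of two equal-size batches."""
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)
    rows, cols = linear_sum_assignment(cost)
    return float(cost[rows, cols].sum())


rng = np.random.default_rng(0)
n_gpus, local_batch = 4, 32  # hypothetical setup
x0 = rng.normal(size=(n_gpus * local_batch, 2))
x1 = rng.normal(size=(n_gpus * local_batch, 2)) + 2.0

# Per-GPU OT: each rank only pairs its own shard
shards = [slice(k * local_batch, (k + 1) * local_batch) for k in range(n_gpus)]
per_gpu_cost = sum(exact_ot_cost(x0[s], x1[s]) for s in shards)
# Global OT over the gathered batch
global_cost = exact_ot_cost(x0, x1)
print(per_gpu_cost, global_cost)
```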
Hello, thanks for the wonderful package. I am trying to do time series anomaly detection using normalizing flows/flow matching. My idea is to use a sliding window to obtain multiple sub-sequences of the time series and reconstruct the time series at each step and then use the reconstruction error to detect anomalies (so it's unsupervised). But I am running into an issue trying to do this.
This is what I tried so far:
ConditionalFlowMatcher with a sigma of 0.0:
for epoch in range(20):
    for i, batch_data in enumerate(train_loader):
        optimizer.zero_grad()
        x1 = batch_data[0].unsqueeze(1).to(device)
        y = batch_data[1].to(device)
        # print(y.shape)
        x0 = torch.randn_like(x1)
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
        vt = model(t, xt, y)
        print(vt.shape)
        print(ut.shape)
        loss = torch.mean((vt - ut) ** 2)
        loss.backward()
        optimizer.step()
        print(f"epoch: {epoch}, steps: {i}, loss: {loss.item():.4}", end="\r")
USE_TORCH_DIFFEQ = True
batch_size = 100  # Adjust as needed
time_series_length = 64  # Adjust to the length of your time series
initial_state = torch.randn(batch_size, 1, time_series_length, device=device)
generated_class_list = torch.arange(batch_size, device=device)
with torch.no_grad():
    if USE_TORCH_DIFFEQ:
        traj = torchdiffeq.odeint(
            lambda t, x: model.forward(t, x, generated_class_list),
            initial_state,
            torch.linspace(0, 1, 2, device=device),
            atol=1e-4,
            rtol=1e-4,
            method="dopri5",
        )
    else:
        traj = node.trajectory(
            initial_state,
            t_span=torch.linspace(0, 1, 2, device=device),
        )
The problem is that when I visualize the reconstructed inputs, they do not resemble the input at all. I am not sure what I am doing wrong; I know the problem has nothing to do with the autoencoder model, because I tested it separately. Any help is much appreciated.
plt.plot(traj[-1, :batch_size, 0, :].cpu()[10])
Reconstructed sample:
Input sample:
There is a bug when sigma is set to an integer (e.g. FM = ConditionalFlowMatcher(sigma=0)): it changes the dtype of the time tensor t.
Hi, thanks for your excellent paper and code. It seems that you didn't use the entropy regularisation for sbcfm in the notebook examples. For example, in training-8gaussians-to-moons.ipynb:
FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)
class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):
    """Child class for Schrödinger bridge conditional flow matching method. This class implements
    the SB-CFM methods from [1] and inherits the ConditionalFlowMatcher parent class.
    It overrides the compute_sigma_t, compute_conditional_flow and
    sample_location_and_conditional_flow functions.
    """

    def __init__(self, sigma: float = 1.0, ot_method="exact"):
        r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper-
        parameter $\sigma$ and the entropic OT map.
        Parameters
        ----------
        sigma : float
        ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
        """
        self.sigma = sigma
        self.ot_method = ot_method
        self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)
SchrodingerBridgeConditionalFlowMatcher has ot_method="exact" by default, which means that sbcfm solves OT exactly, without entropy regularisation. Should we use SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method="sinkhorn") rather than SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)?
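For reference, entropic OT with regularisation reg = 2σ² (the value passed to OTPlanSampler in the __init__ quoted above) can be computed with plain Sinkhorn iterations. This NumPy sketch is an illustration, not the library's POT-based code path: it assumes uniform weights and squared Euclidean cost, and it draws training pairs with probability proportional to the plan entries; sinkhorn_plan and sample_pairs are hypothetical helpers. Unlike an exact plan, which for equal-size uniform batches can be taken to be a permutation, the entropic plan spreads mass over many pairs.

```python
import numpy as np


def sinkhorn_plan(x0, x1, reg, n_iter=1000):
    """Entropic OT plan between two uniform empirical measures (plain Sinkhorn iterations)."""
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
    a = np.full(len(x0), 1.0 / len(x0))  # uniform source weights
    b = np.full(len(x1), 1.0 / len(x1))  # uniform target weights
    kernel = np.exp(-cost / reg)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (kernel @ v)  # match row marginals
        v = b / (kernel.T @ u)  # match column marginals
    return u[:, None] * kernel * v[None, :]


def sample_pairs(plan, batch_size, rng):
    """Draw (i, j) index pairs with probability proportional to the plan entries."""
    flat = plan.ravel() / plan.sum()
    choices = rng.choice(flat.size, size=batch_size, p=flat)
    return np.divmod(choices, plan.shape[1])


rng = np.random.default_rng(0)
sigma = 1.0
x0 = 0.5 * rng.normal(size=(8, 2))
x1 = 0.5 * rng.normal(size=(8, 2)) + 0.5
plan = sinkhorn_plan(x0, x1, reg=2 * sigma**2)
i, j = sample_pairs(plan, batch_size=8, rng=rng)
```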