
pzflow's People

Contributors

dependabot[bot], jfcrenshaw, vladdoster, yanzastro

pzflow's Issues

Divergent flows on subsets of photometric redshift data

Description

Training pzflow on combined photometric redshift catalogs works, but training on subsets of those catalogs produces divergent flows. For example, with a catalog of filter magnitudes for Euclid + LSST the loss decreases (nearly) monotonically, while training only on p(z | Euclid filters) diverges.

Screenshots

Output of training:

Screenshot 2023-11-02 at 3 58 53 PM

Catalog definitions for unsuccessful training:

Euclid-only

And superset which is successful:

LSST+Euclid

To reproduce:

Definition of the flow ensembles:

flow ensembles

Investigate optimizing jax code

It might be possible to optimize the JAX code further by using lax.cond in place of Python if statements, and by jitting the bijectors and class methods.
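As a rough illustration of the lax.cond idea, here is a minimal sketch. The function `safe_log` is hypothetical and is not taken from pzflow; it only shows how a data-dependent branch can live inside a jitted function, where a plain Python `if x > 0` on a traced value would fail.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def safe_log(x):
    """Return log(x) for x > 0 and -inf otherwise (hypothetical example)."""
    # Fix the dtype so that both branches return the same type.
    x = jnp.asarray(x, dtype=jnp.float32)
    return lax.cond(
        x > 0,                                   # traced scalar predicate
        lambda v: jnp.log(v),                    # branch taken when x > 0
        lambda v: jnp.full_like(v, -jnp.inf),    # branch taken otherwise
        x,
    )


print(safe_log(1.0), safe_log(-1.0))
```

Both branches must return values of the same shape and dtype, which is why the second branch uses `jnp.full_like` rather than a bare Python float.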

Is it possible to draw conditional samples?

Hello,

I was wondering if the following is possible with pzflow.

Suppose we have a 2D dataset e.g (x, y) and we fit a normalising flow model to it.

Can we draw samples from p(x|y=some value) ?

Thanks!

Batching for posterior

Computing the posteriors quickly is memory intensive. For large data sets, it would be helpful to calculate posteriors in batches.
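A minimal sketch of what batching could look like, written against plain NumPy rather than pzflow's internals. The names `in_batches` and `toy_posterior` are hypothetical stand-ins: `toy_posterior` is not pzflow's posterior, it just returns one normalized curve per input row so the batching logic can be checked.

```python
import numpy as np


def toy_posterior(rows, grid):
    """Stand-in posterior: one normalized Gaussian bump per row (hypothetical)."""
    pdfs = np.exp(-0.5 * (grid[None, :] - rows[:, :1]) ** 2)
    return pdfs / pdfs.sum(axis=1, keepdims=True)


def in_batches(fn, data, batch_size):
    """Apply fn to consecutive row slices of data and stack the results.

    Only one batch of intermediate values is alive at a time, which is the
    point of batching. Assumes data has at least one row.
    """
    outputs = [
        fn(data[start:start + batch_size])
        for start in range(0, len(data), batch_size)
    ]
    return np.concatenate(outputs, axis=0)


grid = np.linspace(0.0, 3.0, 50)
data = np.random.default_rng(0).uniform(0.0, 3.0, size=(1000, 2))
batched = in_batches(lambda rows: toy_posterior(rows, grid), data, batch_size=128)
print(batched.shape)
```

The batch size trades peak memory against the number of function calls; the last batch is simply shorter when the row count is not a multiple of the batch size.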

AttributeError: module 'jax' has no attribute 'Array'

When running the tutorial code, I got the following error and am looking for a solution.
jax version 0.4.12
jaxlib version 0.4.12

AttributeError Traceback (most recent call last)
/tmp/ipykernel_301829/2238758520.py in <cell line: 2>()
1 ## Introuction
----> 2 from pzflow import Flow
3 from pzflow.examples import get_twomoons_data
4 import jax.numpy as jnp
5 import matplotlib.pyplot as plt

~/anaconda3/envs/nf/lib/python3.8/site-packages/pzflow/__init__.py in
1 """Import modules and set version."""
----> 2 from pzflow.flow import Flow
3 from pzflow.flowEnsemble import FlowEnsemble
4
5 __version__ = "3.1.1"

~/anaconda3/envs/nf/lib/python3.8/site-packages/pzflow/flow.py in
5 import jax.numpy as jnp
6 import numpy as np
----> 7 import optax
8 import pandas as pd
9 from jax import grad, jit, random

~/anaconda3/envs/nf/lib/python3.8/site-packages/optax/__init__.py in
15 """Optax: composable gradient processing and optimization, in JAX."""
16
---> 17 from optax import experimental
18 from optax._src.alias import adabelief
19 from optax._src.alias import adafactor

~/anaconda3/envs/nf/lib/python3.8/site-packages/optax/experimental/__init__.py in
18 """
19
---> 20 from optax._src.experimental.complex_valued import split_real_and_imaginary
21 from optax._src.experimental.complex_valued import SplitRealAndImaginaryState
22 from optax._src.experimental.extra_args import GradientTransformationWithExtraArgs

~/anaconda3/envs/nf/lib/python3.8/site-packages/optax/_src/experimental/complex_valued.py in
30 from typing import NamedTuple, Union
31
---> 32 import chex
33 import jax
34 import jax.numpy as jnp

~/anaconda3/envs/nf/lib/python3.8/site-packages/chex/__init__.py in
15 """Chex: Testing made fun, in JAX!"""
16
---> 17 from chex._src.asserts import assert_axis_dimension
18 from chex._src.asserts import assert_axis_dimension_comparator
19 from chex._src.asserts import assert_axis_dimension_gt

~/anaconda3/envs/nf/lib/python3.8/site-packages/chex/_src/asserts.py in
24 from unittest import mock
25
---> 26 from chex._src import asserts_internal as _ai
27 from chex._src import pytypes
28 import jax

~/anaconda3/envs/nf/lib/python3.8/site-packages/chex/_src/asserts_internal.py in
32
33 from absl import logging
---> 34 from chex._src import pytypes
35 import jax
36 from jax.experimental import checkify

~/anaconda3/envs/nf/lib/python3.8/site-packages/chex/_src/pytypes.py in
25
26 # For instance checking, use isinstance(x, jax.Array).
---> 27 ArrayDevice = jax.Array
28
29 # Types for backward compatibility.

AttributeError: module 'jax' has no attribute 'Array'
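One thing that may be worth ruling out is that the notebook kernel resolves different package versions than the ones checked on the command line. The snippet below is only a diagnostic sketch, not a fix: it uses the standard library to print the interpreter path and the installed versions of the packages that appear in the traceback, without importing any of them. The helper name `installed_version` is made up for this example.

```python
import sys
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Run this inside the same kernel that raises the AttributeError.
print("interpreter:", sys.executable)
for name in ("jax", "jaxlib", "chex", "optax", "pzflow"):
    print(f"{name}: {installed_version(name)}")
```

If the versions printed inside the kernel differ from the ones reported above, the mismatch between chex and jax is a likely suspect.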

Refactor `Flow` and `FlowEnsemble`

Some refactoring needs to be done. In particular, look at the __init__ and posterior methods, as these are quite large and unwieldy.

Question about Dequantization

Hi,

I'm experimenting with pzflow on some toy tabular data with 2 continuous columns and 1 categorical column.

I'm comparing your UniformDequantizer with a naive implementation where I do the dequantization and quantization outside of the training process. Specifically, I add uniform noise [U(0, 1)] to the categorical column and feed this to pzflow. Then, when sampling, I just floor the categorical column. I expected this to behave the same way as the UniformDequantizer, but I'm not getting results as good when using the UniformDequantizer.

Do you mind if I share some python code / notebook comparing the two methods for you to take a look?

Many thanks
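For reference, the naive scheme described in this issue can be sketched in a few lines of NumPy. This is a reconstruction from the description only, with made-up data, and it says nothing about how UniformDequantizer is implemented inside pzflow.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up categorical column with integer codes 0..3.
categories = rng.integers(0, 4, size=1000)

# Dequantize before training: add U(0, 1) noise so the column is continuous.
dequantized = categories + rng.uniform(0.0, 1.0, size=categories.shape)

# Quantize after sampling: flooring maps each value back to its integer code.
recovered = np.floor(dequantized).astype(int)

print(dequantized[:3], recovered[:3])
```

Because the noise lies in [0, 1), flooring the dequantized column recovers the original codes, so any difference in results would have to come from the training or sampling steps in between.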

Overflow with >10 variables

Hi,
pzflow seems to work fantastically. I am wondering if there is a hard limit on the number of columns that it can handle. Whenever I feed it a table with 10 or more columns, I get an overflow error when calling flow.train(). This is independent of the dataset and happens even when all parameters, bijectors, etc. are set to their default values. Thanks.


OverflowError Traceback (most recent call last)
in

~/a3/envs/tenf/lib/python3.9/site-packages/pzflow/flow.py in train(self, inputs, epochs, batch_size, optimizer, loss_fn, convolve_errs, patience, seed, verbose, progress_bar)
999 X = jnp.array(inputs[columns].to_numpy())
1000 C = self._get_conditions(inputs)
-> 1001 losses = [loss_fn(model_params, X, C)]
1002 if verbose:
1003 print(f"(0) {losses[-1]:.4f}")

[... skipping hidden 11 frame]

~/a3/envs/tenf/lib/python3.9/site-packages/pzflow/flow.py in loss_fn(params, x, c)
941 @jit
942 def loss_fn(params, x, c):
--> 943 return -jnp.mean(self._log_prob(params, x, c))
944
945 # initialize the optimizer

~/a3/envs/tenf/lib/python3.9/site-packages/pzflow/flow.py in _log_prob(self, params, inputs, conditions)
400 # calculate log_prob
401 u, log_det = self._forward(params[1], inputs, conditions=conditions)
--> 402 log_prob = self.latent.log_prob(params[0], u) + log_det
403 # set NaN's to negative infinity (i.e. zero probability)
404 log_prob = jnp.nan_to_num(log_prob, nan=jnp.NINF)

~/a3/envs/tenf/lib/python3.9/site-packages/pzflow/distributions.py in log_prob(self, params, inputs)
437
438 # calculate log_prob
--> 439 prob = mask / (2 * self.B) ** self.input_dim
440 prob = jnp.where(prob == 0, epsilon, prob)
441 log_prob = jnp.log(prob)

[... skipping hidden 1 frame]

~/a3/envs/tenf/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
4934 args = (other, self) if swap else (self, other)
4935 if isinstance(other, _accepted_binop_types):
-> 4936 return binary_op(*args)
4937 if isinstance(other, _rejected_binop_types):
4938 raise TypeError(f"unsupported operand type(s) for {opchar}: "

[... skipping hidden 6 frame]

~/a3/envs/tenf/lib/python3.9/site-packages/jax/_src/dtypes.py in _scalar_type_to_dtype(typ, value)
168 if typ is int and value is not None:
169 if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
--> 170 raise OverflowError(f"Python int {value} too large to convert to {dtype}")
171 return dtype
172

OverflowError: Python int 10000000000 too large to convert to int32
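The arithmetic behind the error can be reproduced without pzflow. The sketch below assumes B = 5 and input_dim = 10, which is inferred from the value 10000000000 in the error message rather than read from the library. It also shows one possible way to avoid the large integer, by computing the log-density directly in floating point; this is an illustration, not necessarily the fix adopted in pzflow.

```python
import numpy as np

# Assumed values, inferred from the error message: (2 * 5) ** 10 == 10**10.
B, input_dim = 5, 10

# The expression from distributions.py builds a Python int first ...
volume = (2 * B) ** input_dim
# ... which no longer fits into a 32-bit integer once input_dim reaches 10.
print(volume, ">", np.iinfo(np.int32).max)

# With 9 columns the value still fits, matching the reported threshold.
print((2 * B) ** 9 <= np.iinfo(np.int32).max)

# Working in log space avoids the large integer entirely:
# log(1 / (2B)^d) = -d * log(2B)
log_prob_inside = -input_dim * np.log(2.0 * B)
print(log_prob_inside)
```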

Better tests

Simplify/streamline tests, and switch onp/np -> np/jnp

Compilation time as a function of input dimensionality

Hi,

First of all, I would like to say how nice this package is. It has been very nice to use and experiment with using normalising flows for tabular data!

I'm interested in applying normalising flows on datasets that could have O(100) columns so I did a quick timing test and these are my results. The main finding is that the JIT compilation step starts to throw the following error when the number of columns is greater than about 30.

2023-03-10 10:40:40.407123: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 10m16.840481s

********************************
[Compiling module jit_step] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************

I have a minimal working example here if you would like to run something. It creates dummy data with 10 rows and an increasing number of columns (features), from 2 to 100. For each case I instantiate a new Flow object, train it for 1 epoch, and use the elapsed time as an estimate of the compilation time. The whole script takes about 1.5 hours to run on my machine, and the memory usage may have reached 1GB or more (that estimate might be quite inaccurate though!)

So my questions are:

  • do you have any experience or guidance using pzflow on datasets with this many columns?
  • can you think of anything to improve the compilation performance for these cases?
  • do you agree that the number of samples (rows) in the dataset shouldn't impact the compilation time? I was thinking that it's just the number of columns that would be important, but I could be wrong! For example do you think the compilation time and memory usage would be significantly larger if I had a dataset with 1e6 rows? (I could just test this myself though!)

import jax

jax.config.update("jax_enable_x64", True)

from pzflow import Flow

import pandas as pd
import numpy as np
import datetime

def build_data(n_samples=10, n_features=2):
    feat_names = [f'x{i}' for i in range(n_features)]
    X = np.random.normal(size=(n_samples, n_features))
    df_X = pd.DataFrame({k:X[:,i] for i, k in enumerate(feat_names)})
    return df_X, feat_names

n_samples_arr = np.array([10]).astype(int)
n_features_arr = np.array([2, 5, 10, 15, 20, 30, 50, 100]).astype(int)
print(f"samples array: {n_samples_arr}")
print(f"features array: {n_features_arr}")


res=[]
for n_f in n_features_arr:
    for n_s in n_samples_arr:
        print(f"working: n_f: {n_f}, n_s: {n_s}")
        df_X, feat_names = build_data(n_s, n_f)
        flow = Flow(data_columns=feat_names)
        start_time = datetime.datetime.now()
        flow.train(df_X, verbose=False, epochs=1)
        end_time = datetime.datetime.now()
        duration = end_time - start_time
        res.append([n_s, n_f, duration])
df_res = pd.DataFrame(res, columns=['n_samples', 'n_features', 'duration'])

If you run the above code the output dataframe looks like this

Screenshot 2023-03-10 at 12 14 39

Thanks for any help!
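The timing in the script above measures compilation and the first epoch together. A possible way to isolate the compilation cost is JAX's ahead-of-time API, sketched below. The function `step` is a small hypothetical stand-in and not pzflow's jit_step, so the absolute numbers will not resemble the ones in this issue.

```python
import time

import jax
import jax.numpy as jnp


def step(x):
    """Stand-in for a training step (hypothetical, not pzflow's jit_step)."""
    return jnp.sum(jnp.tanh(x @ x.T))


def compile_seconds(fn, example):
    """Time tracing plus XLA compilation only, without running the function."""
    start = time.perf_counter()
    compiled = jax.jit(fn).lower(example).compile()
    return time.perf_counter() - start, compiled


timings = []
for n_features in (2, 8, 32):
    x = jnp.ones((10, n_features))
    seconds, compiled = compile_seconds(step, x)
    out = compiled(x)  # executing the compiled function is timed separately
    timings.append((n_features, seconds, float(out)))
    print(f"n_features={n_features}: compiled in {seconds:.3f}s")
```

Note that a jitted function is compiled for specific input shapes, so a change in the number of rows triggers a new compilation even if the cost per compilation is dominated by the number of columns.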

Thoughts from implementing in Matlab

This is not really an issue. I needed something like this, but in Matlab (without calling out to Python), so I chose to crib this package. I've just published my code at [https://github.com/jeremylea/DLextras]. These are some thoughts from implementing things again in a different language and with a different target problem. I don't want to file each as a separate issue and pollute things, so this will be a mixed bag. But before that, thank you for developing this package. It's been a real help in solving a long-standing problem in my research, and I'll find several other uses for the idea soon. I didn't need the data error model for my work, so I didn't do that, but I might add it for future projects.

The first thing I found was that using (-B,B) as the domain didn't do anything, and using (0,1) everywhere simplified the code. This might

Using 1 for the fixed slopes at the ends of the splines did not work for me when using a uniform input. I needed this to be zero to obtain any reasonable result. If you're using a beta distribution, this would be less of an issue, but that was causing problems for me. I added controls for setting the end slopes to one, zero, or a learnable value, along with periodic. I tried to add a zero-inflated feature but have yet to get that to work.

I found it better to make smaller internal networks and stack more "rolls" on the bijector with fewer knots. That might be my data. However, I did find that scaling the number of nodes in the hidden layers in the internal networks down to input dim (so start with input_dim in the first layer, then linearly scale the number of nodes in each layer up to hidden_dimension in the last layer), seemed to make the networks more trainable. I also pre-seeded the bias in the output layer to generate a "diagonal" spline.
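The layer-width scaling described above can be sketched as follows. The function name `layer_widths` and the rounding rule are assumptions made for this example; the Matlab code may differ in detail.

```python
import numpy as np


def layer_widths(input_dim, hidden_dim, n_layers):
    """Widths growing linearly from input_dim (first) to hidden_dim (last)."""
    widths = np.linspace(input_dim, hidden_dim, n_layers)
    return [int(w) for w in np.round(widths)]


print(layer_widths(4, 128, 5))
```

With 4 inputs, a final width of 128, and 5 hidden layers, the widths increase in equal steps of 31.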

I added a latent layer that is a conditional beta using the same idea of an internal network to get the A and B parameters. This seemed to work well in testing, but for my problem, I found it better in the end to use a uniform latent distribution and a bijector layer with transformed_dim=data_dim, followed by layers with transformed_dim=1. This also required passing the conditions into some latent distribution functions.

I was having some issues with the training and tried a bunch of things... The first was to remove the hard limits on the spacing of the spline knots (I think that's gone in your current code) and add a penalty function for closely spaced knots, along with some seatbelts to prevent division by zero. I also added a penalty for high differences in derivatives at the knots and for having residual correlation in the forward results. These helped, but they caused the solution to find a local minimum. I made these penalties configurable in training, and they seemed to help stabilize the problem in early training, but then I removed them for most of the training.

I also found that learning one less bin width and height and replacing that with zero before the softmax helped to keep the hidden networks stable, and I scaled D by sk so that the learned slopes were proportional to the bin slopes.
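One reading of the first idea is that the network outputs K - 1 logits for K bins and a fixed zero logit is added before the softmax. The sketch below follows that reading; the position of the fixed logit (appended last) and the name `bin_sizes` are assumptions, and the scaling of the slopes is not covered.

```python
import numpy as np


def bin_sizes(free_logits):
    """Softmax over K - 1 learned logits plus one logit pinned to zero."""
    zeros = np.zeros(free_logits.shape[:-1] + (1,))
    logits = np.concatenate([free_logits, zeros], axis=-1)
    # Subtract the maximum for numerical stability before exponentiating.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


sizes = bin_sizes(np.array([[0.5, -1.0, 2.0]]))
print(sizes, sizes.sum())
```

Pinning one logit removes the softmax's invariance to adding a constant to all logits, which may be why it helps keep the hidden networks stable.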

I didn't implement the idea of patience (I think that was added after I started converting), but I'll probably do that. I did make the training loop save the best fit and restore that on exit.
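Saving the best fit and stopping after a number of epochs without improvement can be combined in one loop. The sketch below is generic: `track_best` is a hypothetical helper that consumes (loss, params) pairs, and its patience rule is an assumption that may not match pzflow's exactly.

```python
def track_best(history, patience):
    """Return the best (loss, params) seen, stopping early if patience runs out.

    history yields (loss, params) pairs, one per epoch. Training stops once
    `patience` consecutive epochs fail to improve on the best loss.
    """
    best_loss, best_params, waited = float("inf"), None, 0
    for loss, params in history:
        if loss < best_loss:
            best_loss, best_params, waited = loss, params, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best_loss, best_params


epochs = [(3.0, "a"), (2.0, "b"), (2.5, "c"), (2.4, "d"), (1.0, "e")]
print(track_best(epochs, patience=2))
```

In this toy history the loop stops after the fourth epoch and returns the parameters from the second, so the later improvement is never reached; a larger patience would find it.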

I needed a weighting column, so I added that. I added other debugging features, like capturing the random input for the sampling, so you can run the same sample repeatedly throughout training to see only the learning changes.

There are probably other things, but I can't think of them right now. Thanks again for a great package.

Change jax.numpy import to jnp

Currently I use

import jax.numpy as np
import numpy as onp

I should change this to

import jax.numpy as jnp
import numpy as np
