
undark-lab / swyft


A system for scientific simulation-based inference at scale.

License: Other

Python 0.68% Jupyter Notebook 99.32%
likelihood-free-inference simulation-based-inference marginal-neural-ratio-estimation neural-ratio-estimation machine-learning parameter-estimation python pytorch truncated-neural-ratio-estimation

swyft's People

Contributors

a-e-cole, adam-coogan, anchal-009, bkmi, cweniger, fnattino, meiertgrootes, noemiam, rogerkuou


swyft's Issues

Using different models in the same zarr storage?

This is a follow-up question to issue #30. I created this issue to keep the question documented. In #30, the Zarr storage was started from an incorrect model, and we want to continue using the same storage after correcting the model.

This makes me wonder whether we could also face a case where the original model was correct, and we want to change the model but keep using the same storage.

zarr storage is slow

EDITED:
DirectoryStore is slow, presumably due to saving to disk. However, MemoryStore is also really slow with some parameters, namely when it is used for a dataset without any parallelism or with few workers (nworkers=0 or nworkers=1).

We need to figure out why this happens and how the user can avoid this behavior.

Ask me for proof of this if you need it.

create more flexible data storage (xarray?)

Dictionaries allow for clearly named parameters (useful for analysis). Slicing an array is very natural for the computational side of things (marginalizing a sample is merely a matter of slicing). Both of these methods are less than optimal in the other use case.

It would be helpful to have an array-like object that is easily passed to a simulator while also retaining the named identity of parameters for human use.

A necessary feature is to easily select something like slices of this array-like object (hopefully both by column/row and by parameter name).

One possible solution to this would be to use xarrays which have dask and zarr integration.
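To make the requirement concrete, here is a minimal NumPy-only sketch of such an array-like object. It is not xarray and not part of swyft; the class name `NamedParams` and its methods are hypothetical and only illustrate selecting by parameter name and by row slice on the same object.

```python
import numpy as np


class NamedParams:
    """Hypothetical container: a 2D array of samples with named columns."""

    def __init__(self, values, names):
        values = np.asarray(values)
        if values.ndim != 2 or values.shape[1] != len(names):
            raise ValueError("values must have shape (n_samples, n_params)")
        self.values = values
        self.names = list(names)

    def select(self, *names):
        """Marginalize by keeping only the named parameters."""
        cols = [self.names.index(n) for n in names]
        return NamedParams(self.values[:, cols], names)

    def __getitem__(self, key):
        """A string key returns one column; a row slice keeps the names."""
        if isinstance(key, str):
            return self.values[:, self.names.index(key)]
        return NamedParams(self.values[key], self.names)


params = NamedParams(np.arange(12.0).reshape(4, 3), ["mass", "x", "y"])
marginal = params.select("x", "y")  # selection by parameter name
first_two = params[:2]              # selection by row slice
```

The underlying `values` array can be handed to a simulator directly, while `names` stays available for analysis and plotting.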

Error involving dask and pyro

There's a conflict involving dask and pyro models that use guides. See this gist for a script to reproduce with used_cond_model = True. I've tested this on Lisa.

The errors should look like:

distributed.worker - WARNING - Compute Failed
Function:  execute_task
args:      ((<function load_store_chunk at 0x150830422f70>, (subgraph_callable, 'image', (subgraph_callable, (<function concatenate_axes at 0x150830423c10>, [array([[1.82431417],
       [1.26026485],
       [1.08029839],
       [1.27862527],
       [1.97171492],
       [1.53252534],
       [1.85259017],
       [1.32079712],
       [1.32101153],
       [1.12730499],
       [1.10410938],
       [1.97516082],
       [1.61328824],
       [1.71373738],
       [1.58564648],
       [1.78064157],
       [1.78559028],
       [1.18012141],
       [1.32523062],
       [1.77047335],
       [1.89087196],
       [1.4552674 ],
       [1.94671355],
       [1.92100448],
       [1.73849129],
       [1.36837829],
       [1.15413041],
       [1.28342286],
       [1.47284077],
       [1.97885167],
       [1.51724013],
       [1.22956549],
       [1.74454764],
       [1.02810834],
       [1.19016476],
       [1.58492559],
       [1.39366463],
       [1.07402483],
       [1.70527518],
       [1.9074943 ],
       [1.081
kwargs:    {}
Exception: IndexError('invalid index to scalar variable.')

distributed.worker - WARNING - Compute Failed
Function:  execute_task
args:      ((subgraph_callable, (subgraph_callable, (<function concatenate_axes at 0x150830423c10>, [array([[1.82431417],
       [1.26026485],
       [1.08029839],
       [1.27862527],
       [1.97171492],
       [1.53252534],
       [1.85259017],
       [1.32079712],
       [1.32101153],
       [1.12730499],
       [1.10410938],
       [1.97516082],
       [1.61328824],
       [1.71373738],
       [1.58564648],
       [1.78064157],
       [1.78559028],
       [1.18012141],
       [1.32523062],
       [1.77047335],
       [1.89087196],
       [1.4552674 ],
       [1.94671355],
       [1.92100448],
       [1.73849129],
       [1.36837829],
       [1.15413041],
       [1.28342286],
       [1.47284077],
       [1.97885167],
       [1.51724013],
       [1.22956549],
       [1.74454764],
       [1.02810834],
       [1.19016476],
       [1.58492559],
       [1.39366463],
       [1.07402483],
       [1.70527518],
       [1.9074943 ],
       [1.08190857],
       [1.05591684],
       [1.66436089],

kwargs:    {}
Exception: IndexError('invalid index to scalar variable.')

review checklist + submit paper

General overview

Review criteria

Review checklist

Conflict of interest

Code of Conduct

General checks

  • Repository: Is the source code for this software available at the repository url?

  • License: Does the repository contain a plain-text LICENSE file with the contents of an OSI approved software license?

  • Version: Does the release version given match the GitHub release (v0.8)?

  • Authorship: Has the submitting author (@bbengfort) made major contributions to the software? Does the full list of paper authors seem appropriate and complete?

Functionality

  • Installation: Does installation proceed as outlined in the documentation?

  • Functionality: Have the functional claims of the software been confirmed?

  • Performance: If there are any performance claims of the software, have they been confirmed? (If there are no claims, please check off this item.)

Documentation

  • A statement of need: Do the authors clearly state what problems the software is designed to solve and who the target audience is?

  • Installation instructions: Is there a clearly-stated list of dependencies? Ideally these should be handled with an automated package management solution.

  • Example usage: Do the authors include examples of how to use the software (ideally to solve real-world analysis problems)?

  • Functionality documentation: Is the core functionality of the software documented to a satisfactory level (e.g., API method documentation)?

  • Automated tests: Are there automated tests or manual steps described so that the function of the software can be verified?

  • Community guidelines: Are there clear guidelines for third parties wishing to 1) Contribute to the software 2) Report issues or problems with the software 3) Seek support

Software paper

  • Authors: Does the paper.md file include a list of authors with their affiliations?

  • A statement of need: Do the authors clearly state what problems the software is designed to solve and who the target audience is?

  • References: Do all archival references that should have a DOI list one (e.g., papers, datasets, software)?

Apply inverse transform to posterior samples?

Right now, we generate posterior samples which are on the hypercube. Similarly, our plotting functions expect the truth to be on the hypercube as well.

This is a bit odd since the user expects the parameter space to correspond to their prior, not the hypercube prior. If we wanted, reporting values on the actual prior could be automated when the user gives us u(theta) and v(theta), where u^{-1} = v.
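A minimal sketch of the idea, assuming a uniform prior on an interval so that both maps are affine. The bounds `LOW` and `HIGH` are made up for illustration; only the names `u` and `v` come from the text above.

```python
import numpy as np

LOW, HIGH = -1.0, 1.0  # assumed prior bounds, for illustration only


def u(theta):
    """Map parameters from the prior's support onto the unit hypercube."""
    return (np.asarray(theta) - LOW) / (HIGH - LOW)


def v(cube):
    """Inverse of u: map hypercube values back to the prior's support."""
    return LOW + (HIGH - LOW) * np.asarray(cube)


hypercube_samples = np.array([0.0, 0.25, 0.5, 1.0])
physical_samples = v(hypercube_samples)  # values the user would expect to see
```

With such a pair available, posterior samples could be passed through `v` before being returned, and the truth passed through `u` before plotting on the hypercube.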

introduce credible interval testing function

We should include this as a basic function in swyft. It needs semantically meaningful code too. What's written below is not (easily) understandable.

--

Line 19 in https://github.com/bkmi/tmnre/blob/main/torus/metrics/plot-torus.ipynb

for k, ccount in toplot.items():
    fig, ax = plt.subplots(figsize = (width, height))
    ax.plot(np.linspace(0, 100, len(ccount)), ccount*100)
    #plt.plot(np.linspace(0, 100, len(ccount)), (ccount+ccount_err)*100)
    ax.set_ylim([0, 100])
    f68 = np.interp(0.68, np.linspace(0, 1, len(ccount)), ccount)
    f95 = np.interp(0.95, np.linspace(0, 1, len(ccount)), ccount)
    ax.axvline(68, color='k', ls=':')
    ax.axvline(95, color='k', ls=':')
    ax.axhline(f68*100, color='k', ls=':')
    ax.axhline(f95*100, color='k', ls=':')
    ax.set_xlabel("HPDI Nominal Credibility for " + r"$\theta" + f"_{int(k) + 1}$ [%]")
    ax.plot([0, 100], [0, 100])
    ax.set_ylabel("HDI Empirical Credibility [%]")
    print(f68, f95)
    fig.savefig(f"figures/torus-ci-{int(k)+1}.png", bbox_inches="tight")
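As a starting point for a more readable version, the interpolation step from the snippet above can be pulled into a named function. This is only a sketch: the function name is hypothetical, and it assumes, as the snippet does, that the curve is sampled on an evenly spaced grid of nominal credibilities from 0 to 1.

```python
import numpy as np


def empirical_credibility(coverage_curve, nominal):
    """Interpolate the empirical credibility at a nominal credibility in [0, 1].

    coverage_curve[i] is assumed to be the empirical credibility at the
    nominal level i / (len(coverage_curve) - 1).
    """
    coverage_curve = np.asarray(coverage_curve, dtype=float)
    nominal_grid = np.linspace(0.0, 1.0, len(coverage_curve))
    return float(np.interp(nominal, nominal_grid, coverage_curve))


# A perfectly calibrated estimator has empirical == nominal credibility.
perfect = np.linspace(0.0, 1.0, 101)
f68 = empirical_credibility(perfect, 0.68)
f95 = empirical_credibility(perfect, 0.95)
```

Plotting can then be kept separate from the computation of `f68` and `f95`.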

HDI ValueError in plot.py

First off, this issue only occurs when the model has high certainty (probably due to overfitting, or too simple a problem).

The problem can occur in lines 123-125, because zm[v > levels[i]] can be empty if v has all the mass in one bin. This can be fixed by changing > to >=. It may be even better to use a try/except with a message suggesting to increase the number of bins or check for overtraining.

-Kees
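The failure mode and both suggested remedies can be reproduced in isolation. The sketch below does not use the actual plot.py code; the helper name is hypothetical, and it assumes the level ends up equal to the maximum of v when all mass sits in one bin.

```python
import numpy as np


def region_above_level(zm, v, level):
    """Return the bin centers whose mass is at or above the level."""
    mask = v >= level  # inclusive comparison, as proposed above
    if not mask.any():
        raise ValueError(
            "No bins at or above the requested level; try increasing the "
            "number of bins or check for overtraining."
        )
    return zm[mask]


zm = np.linspace(-1.0, 1.0, 5)            # bin centers
v = np.array([0.0, 0.0, 1.0, 0.0, 0.0])   # all mass in one bin
level = v.max()

strict = zm[v > level]                    # empty selection, as in the issue
inclusive = region_above_level(zm, v, level)
```

Taking the minimum or maximum of the empty `strict` array is what raises the ValueError; the inclusive comparison keeps the single occupied bin.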

RuntimeError in Examples - 1. Head networks

I am running swyft on an environment with pytorch 1.9.0 and swyft 0.2.0 installed. The example Jupyter notebooks all fail when post.train() is called. The notebook Examples - 1. Head networks returns the following stack trace in cell 11:

`---------------------------------------------------------------------------
Empty Traceback (most recent call last)
~\Anaconda3\envs\nre\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
989 try:
--> 990 data = self._data_queue.get(timeout=timeout)
991 return (True, data)

~\Anaconda3\envs\nre\lib\queue.py in get(self, block, timeout)
177 if remaining <= 0.0:
--> 178 raise Empty
179 self.not_empty.wait(remaining)

Empty:

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
in
1 marginals = [0, 1, 2]
2 post.add(marginals, device = DEVICE, head = Head)
----> 3 post.train(marginals, max_epochs = 10)

~\Anaconda3\envs\nre\lib\site-packages\swyft\inference\posteriors.py in train(self, marginals, batch_size, validation_size, early_stopping_patience, max_epochs, optimizer, optimizer_args, scheduler, scheduler_args, nworkers, non_blocking)
149 )
150
--> 151 re.train(self._dataset, trainoptions)
152
153 def train_diagnostics(self, marginals: MarginalsType):

~\Anaconda3\envs\nre\lib\site-packages\swyft\inference\ratios.py in train(self, dataset, trainoptions)
97 self.tail.train()
98
---> 99 diagnostics = trainloop(
100 head=self.head,
101 tail=self.tail,

~\Anaconda3\envs\nre\lib\site-packages\swyft\inference\train.py in trainloop(head, tail, dataset, trainoptions, device)
204 drop_last=True,
205 )
--> 206 tl, vl, sd_head, sd_tail = do_training(
207 head, tail, train_loader, valid_loader, trainoptions, device
208 )

~\Anaconda3\envs\nre\lib\site-packages\swyft\inference\train.py in do_training(head, tail, train_loader, validation_loader, trainoptions, device)
109 head.train()
110 tail.train()
--> 111 train_loss = do_epoch(train_loader, True)
112 train_losses.append(train_loss / n_train_batches)
113

~\Anaconda3\envs\nre\lib\site-packages\swyft\inference\train.py in do_epoch(loader, train)
75 training_context = suppress() if train else torch.no_grad()
76 with training_context:
---> 77 for batch in loader:
78 optimizer.zero_grad()
79 sim, z, _ = batch

~\Anaconda3\envs\nre\lib\site-packages\torch\utils\data\dataloader.py in next(self)
519 if self._sampler_iter is None:
520 self._reset()
--> 521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \

~\Anaconda3\envs\nre\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
1184
1185 assert not self._shutdown and self._tasks_outstanding > 0
-> 1186 idx, data = self._get_data()
1187 self._tasks_outstanding -= 1
1188 if self._dataset_kind == _DatasetKind.Iterable:

~\Anaconda3\envs\nre\lib\site-packages\torch\utils\data\dataloader.py in _get_data(self)
1140 elif self._pin_memory:
1141 while self._pin_memory_thread.is_alive():
-> 1142 success, data = self._try_get_data()
1143 if success:
1144 return data

~\Anaconda3\envs\nre\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
1001 if len(failed_workers) > 0:
1002 pids_str = ', '.join(str(w.pid) for w in failed_workers)
-> 1003 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
1004 if isinstance(e, queue.Empty):
1005 return (False, None)

RuntimeError: DataLoader worker (pid(s) 2540, 17392) exited unexpectedly`

Am I simply running the wrong version of pytorch or swyft? What is the origin of this error?

Problems initializing networks when simhook changes the shape of data

The function self._init_networks in ratios.py takes the first element of the dataset to determine the shapes used in the network. This is a problem when simhook affects the shape of array. This happens in gw inference because we save the frequency domain simulation but we do inference on the time domain version with noise.

I'll try to fix it, although I think this is another piece of evidence that we shouldn't lazy initialize the networks. Perhaps we can instead create a function which does the saving / loading we're trying to achieve here. What do you think @cweniger?
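One possible direction, sketched under the assumption that shapes are inferred from a sample only after the hook has been applied. The function names and the `"strain"` key are hypothetical and this is not how ratios.py currently works.

```python
import numpy as np


def infer_shapes(sample, simhook=None):
    """Infer observation shapes from a sample, after applying the hook if any."""
    hooked = simhook(sample) if simhook is not None else sample
    return {name: np.asarray(value).shape for name, value in hooked.items()}


def to_time_domain(sim):
    # Hypothetical hook: frequency-domain data becomes a longer time series.
    return {"strain": np.fft.irfft(sim["strain"])}


stored = {"strain": np.zeros(5, dtype=complex)}       # what the store holds
shapes_stored = infer_shapes(stored)                  # shapes before the hook
shapes_hooked = infer_shapes(stored, to_time_domain)  # shapes the network sees
```

The two results differ, which is exactly the mismatch described above when the network is initialized from the stored sample.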

Make DirectoryStore & Dataset safe to use

Right now, calling Dataset on the store will write to the store if the wrong priors are specified. I suggest the following changes:

  • Introduce mode = 'r' for the directory store, so that we have a read-only configuration.
  • Make Dataset fail when points are requested that are not there (rather than adding them to the store).
  • Add an additional add method to the store, which explicitly adds points for a given prior.
  • Add a check method that reports whether a specific prior is already covered.
  • Add an info method that provides information about the contained priors and the associated number of samples in the store.
  • Add the ability to remove the latest prior and points.
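The proposed interface could look roughly like the following in-memory sketch. Everything here is hypothetical: the class is a stand-in rather than the real DirectoryStore, and priors are reduced to plain names so that the example stays self-contained.

```python
class ReadOnlyError(RuntimeError):
    """Raised when a write is attempted on a store opened read-only."""


class SketchStore:
    """Hypothetical in-memory stand-in for a store; priors are reduced to names."""

    def __init__(self, mode="a"):
        if mode not in ("r", "a"):
            raise ValueError("mode must be 'r' (read-only) or 'a' (append)")
        self.mode = mode
        self._entries = []  # (prior_name, n_samples), in insertion order

    def _require_writable(self):
        if self.mode == "r":
            raise ReadOnlyError("store was opened with mode='r'")

    def add(self, prior_name, n_samples):
        """Explicitly add points for a given prior."""
        self._require_writable()
        self._entries.append((prior_name, n_samples))

    def check(self, prior_name, n_samples):
        """Report whether enough points for this prior are already stored."""
        return self.info().get(prior_name, 0) >= n_samples

    def info(self):
        """Summarize the number of stored samples per prior."""
        summary = {}
        for prior_name, n_samples in self._entries:
            summary[prior_name] = summary.get(prior_name, 0) + n_samples
        return summary

    def remove_latest(self):
        """Remove the most recently added prior and its points."""
        self._require_writable()
        return self._entries.pop()


store = SketchStore(mode="a")
store.add("wide", 100)
store.add("narrow", 50)
```

A Dataset built on such a store would call `check` and fail loudly, rather than calling `add` behind the user's back.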

Consistency between pre/post refactor

I'm not sure whether I should expect consistent results between runs I made with the "pre-refactor" code and the refactored code, but I'm finding differences (the refactor is less accurate).

Previously I would run

s = swyft.NestedRatios(
        model,
        prior,
        noise=noise,
        obs=obs,
        device='cuda:0',
        Ninit=5000,
        cache=cache,
    )

s.run(
    max_rounds=1,
    train_args={'max_epochs': 50},
    head=Head10
    )

Where Head10 is a linear compression into 10 features. This gives me great results (I rescale the following plot via an expected 1-sigma)
image

Now, I run what I believe to be the equivalent code

micro = swyft.Microscope(partitions, prior, obs, store = store, simhook = noise, device = 'cuda',
                        Ninit=5000,
                        head=Head10)
micro.focus(max_rounds=1)

which yields the analogous plot (same rescaling applied).
image
Note that I am not actually zooming, just training for one round. While the tip of the second plot is perhaps prettier, the 1 and 2 sigma are less accurate. Any ideas as to why this is happening?

safety check on store and simulator consistency

For example, zdim and the pnames should be the same.

swyft/swyft/store/store.py

Lines 466 to 507 in 07a5bff

class DirectoryStore(Store):
    """Instantiate DirectoryStore.

    Args:
        path (PathType): path to storage directory
        simulator (swyft.Simulator): simulator object
        sync_path: path for synchronization via file locks (files will be stored in the given path).
            It must differ from path, it must be accessible to all processes working on the store,
            and the underlying filesystem must support file locking.

    Example::

        >>> store = swyft.DirectoryStore(PATH_TO_STORE)
        >>> print("Number of simulations in store:", len(store))
    """

    def __init__(
        self, path: PathType, simulator=None, sync_path: Optional[PathType] = None
    ):
        zarr_store = zarr.DirectoryStore(path)
        sync_path = sync_path or os.path.splitext(path)[0] + ".sync"
        super().__init__(
            zarr_store=zarr_store, simulator=simulator, sync_path=sync_path
        )


class MemoryStore(Store):
    """Instantiate a new memory store for a given simulator.

    Args:
        simulator (swyft.Simulator): Simulator object

    .. note::
        The swyft.MemoryStore is in general expected to be faster than
        swyft.DirectoryStore, and useful for quick explorations, or for
        loading training data into memory before training.

    Example::

        >>> store = swyft.MemoryStore(simulator)
    """

reduce number of dictionaries returned, replaced with dataclasses

Why?

  1. As a user, you have to know the names of the keys to be able to access the data. This is annoying when you're writing code, especially in a non-interactive setting.

  2. As a code designer, we already know what should be in the dictionary. If there is nothing dynamic about the keys, we should handle it.

Where?
I will have to put together a list of objects for which this would be useful. The first thing that comes to mind is the returned samples with params and weights. This comes up throughout the code, though.
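For the samples case, a dataclass along these lines would replace the dictionary. This is a sketch only: the class name is hypothetical, and the field names simply mirror the params and weights keys mentioned above.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True, eq=False)  # eq=False: array comparison is ambiguous
class WeightedSamples:
    """Hypothetical return type replacing a dict with 'params' and 'weights'."""

    params: np.ndarray
    weights: np.ndarray

    def __post_init__(self):
        if len(self.params) != len(self.weights):
            raise ValueError("params and weights must have the same length")


samples = WeightedSamples(params=np.zeros((3, 2)), weights=np.ones(3))
n_samples = len(samples.weights)  # attribute access instead of string keys
```

Editors and type checkers can then complete and verify `samples.params`, which is not possible with string keys.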

rewrite git history to remove jupyter notebook outputs

Our repo is needlessly large because of the excessive saves of notebooks from previous git commits. This can be undone. We should do so.

A method can be found here:
https://mg.readthedocs.io/git-jupyter.html#cleaning-a-whole-repository

@cweniger are you fine with this? If you want to save the output of the current notebooks we can keep them in the next version release.

In general, we should avoid changing the notebook outputs if possible. One option would be to clear them automatically with pre-commit.
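What clearing amounts to can be sketched with the standard library alone, assuming the nbformat 4 JSON layout in which code cells carry `outputs` and `execution_count` fields. The function name is hypothetical; a dedicated tool such as nbstripout handles more cases and would be the natural choice for a pre-commit hook.

```python
import json


def strip_outputs(notebook):
    """Clear outputs and execution counts of all code cells, in place."""
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") == "code":
            cell["outputs"] = []
            cell["execution_count"] = None
    return notebook


notebook = {
    "nbformat": 4,
    "cells": [
        {
            "cell_type": "code",
            "source": ["1 + 1"],
            "outputs": [{"output_type": "execute_result"}],
            "execution_count": 3,
        },
        {"cell_type": "markdown", "source": ["Some text"]},
    ],
}
cleaned_json = json.dumps(strip_outputs(notebook), indent=1)
```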

LazyLinear etc

We should explore if replacing some of the network structures with Lazy initialization could simplify writing head-networks etc. We might get around having to specify self.n_features in that way.

Clarify types

@cweniger I left a set of characters for you. Please search all documents for # TODO Christoph typing.

This should be done in #41

example for eScience contributions

Create an example highlighting the new features the eScience center added, so that users understand how to use these contributions.

Consider doctest style and/or a notebook.

introduce new numpy random_state, especially in store, for reproducibility

For reproducibility there is the option to split the seed, we should use it.

https://numpy.org/doc/stable/reference/random/index.html?

swyft/swyft/store/store.py

Lines 94 to 131 in 07a5bff

def add(self, N, prior, bound=None):
    """Adds points to the store.

    Args:
        N (int): Number of samples
        prior (swyft.Prior): Prior
        bound (swyft.Bound): Bound object for prior truncation

    .. warning::
        Calling this method will alter the content of the store by adding
        additional points. Currently this cannot be reverted, so use with
        care when applying it to the DirectoryStore.
    """
    pdf = swyft.TruncatedPrior(prior, bound)

    # Lock store while adding new points
    self.lock()
    self._update()

    # Generate new points
    z_prop = pdf.sample(N=np.random.poisson(N))
    log_lambda_target = pdf.log_prob(z_prop) + np.log(N)
    log_lambda_store = self.log_lambda(z_prop)
    log_w = np.log(np.random.rand(len(z_prop))) + log_lambda_target
    accept_new = log_w > log_lambda_store
    z_new = z_prop[accept_new]
    log_w_new = log_w[accept_new]

    # Anything new?
    if sum(accept_new) > 0:
        # Add new entries to store
        self._append_new_points(z_new, log_w_new)
        print("Store: Adding %i new samples to simulator store." % sum(accept_new))
        # Update intensity function
        self.log_lambdas.resize(len(self.log_lambdas) + 1)
        self.log_lambdas[-1] = dict(pdf=pdf.state_dict(), N=N)

    log.debug(f" total size of simulator store {len(self)}.")
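A minimal sketch of the seed-splitting approach, using NumPy's `SeedSequence.spawn` and `default_rng`. The helper `propose` and the split into a store stream and a simulator stream are assumptions for illustration; in the method above, the generator would replace the global `np.random.poisson` and `np.random.rand` calls.

```python
import numpy as np

# One root seed, split into independent streams for different components.
root = np.random.SeedSequence(1234)
store_seed, simulator_seed = root.spawn(2)
store_rng = np.random.default_rng(store_seed)
simulator_rng = np.random.default_rng(simulator_seed)


def propose(rng, n_expected):
    """Draw a Poisson-distributed number of uniform proposals from rng."""
    n = rng.poisson(n_expected)
    return rng.random(n)


proposals = propose(store_rng, 100)
```

Re-creating the generators from the same root seed reproduces the same proposals, independent of any other code that touches the global NumPy random state.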

rename store to cache and fix saving, loading, and copying of cache (store)

We use a cache. It is implemented in a zarr store. Calling them the same thing when they are conceptually different is confusing and leads to longer variable names. E.g. store, zarr_store, zarr_memory_store, memory_store: which is which? Exactly.

What about when instantiating one or the other?

def copy(self, sync_path=None):
    zarr_store = zarr.MemoryStore()
    zarr.convenience.copy_store(source=self.zarr_store, dest=zarr_store)
    return MemoryStore(
        params=self.params,
        zarr_store=zarr_store,
        simulator=self._simulator,
        sync_path=sync_path,
    )

Which one is the zarr store? Which one is our "MemoryStore" aka Cache?

We already had tests for saving and loading the caches. Please reinstate them and make the implementations comply with the spirit of the tests. (that might require rewriting the tests)

Introduce greater modularity into swyft by clarifying the roles of subsections

What do we have now? What do we want?

As it stands, there is essentially only one access point to swyft and that is via NestedRatios. Having this interface is acceptable; however, it would be better if NestedRatios was built out of smaller pieces which clearly delineate which aspect of the inference process they accomplish. Furthermore, users could then define their own version of NestedRatios which suits their needs.

Definitions

I will introduce a few terms to clarify what I mean. The Problem is defined by everything necessary to compute the exact posterior theoretically, namely the prior, the simulator, and the observation. Another structure will be called the Task, i.e. a prior and the simulator. (The emphasis is on "a": it could be another prior, namely a constrained prior.) In order to approximate the posterior for the problem, NestedRatios solves a series of tasks, the previous one defining the next task's prior. Each of those tasks is solved using ratio estimation; abstractly, this could be called an Algorithm. We call NestedRatios, the method used to solve the problem, a Strategy.

The Workflow

From a conceptual point of view, every part (simulator, prior, observation) of a problem is independent and should be separated in code. This is what we are actually trying to solve, our "true beliefs" are reflected in the prior.

An algorithm consumes a task (either conceptually or programmatically) and produces an estimated ratio. In combination with an observation, we can then use that estimated ratio either for inference, to define a posterior, or to define a constrained prior, yielding a new task. The details of when to do what define our strategy.

Technical Aspects / Data Structures

Previously I said,

(a) The prior, constrained prior, and the posterior are all marginalized probability distributions. Their similarity suggests a similar API; however, instantiation must be different for each.

but I update my statement to
(a1) The prior and constrained prior are general (factorizable) distributions. It would be nice to use some common APIs rather than "reinventing the square wheel."
(a2) The posterior is ONLY a marginalized probability distribution. Although it would be good to have a similar API, it will need to be custom since each marginal must be dealt with separately.

(b) In order to reduce calls to the simulator, we utilize the cache. Drawing from the cache is implicitly drawing from the joint distribution p(x, theta). Fundamentally, it is a database where parameter samples and corresponding simulations are stored. Dealing with accessing this should be abstracted from its use. If I have a prior, implemented using a cache, I want to be able to treat the prior just like a normal marginalized probability distribution.

(c) A simulator, a prior, and the neural network define ratio estimation. The output is a sort of "proto-posterior." A trained RatioEstimator, combined with an observation and a prior, should produce a posterior. One ought to be able to define a constrained prior by excluding regions of low probability from this posterior.

(d) This is where it loops around. We can define a constrained prior, given a RatioEstimator, prior, and an observation since we still have the simulator. These two (simulator, constrained prior) produce a new task. At this highest level, utilizing only calls to constituent structures, we can recreate the strategy NestedRatios.
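The loop described in (a) through (d) can be sketched with plain dataclasses. All names below are hypothetical, and the algorithm and constraining step are passed in as callables, so this shows only the shape of the composition and not swyft's implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Task:
    """A prior (possibly constrained) together with the simulator."""

    prior: Any
    simulator: Callable


@dataclass(frozen=True)
class Problem:
    """Everything needed to define the exact posterior."""

    prior: Any
    simulator: Callable
    observation: Any


def run_strategy(problem, algorithm, constrain, rounds):
    """Solve a sequence of tasks; each round's result defines the next prior."""
    task = Task(problem.prior, problem.simulator)
    ratio = None
    for _ in range(rounds):
        ratio = algorithm(task)
        new_prior = constrain(ratio, problem.observation, task.prior)
        task = Task(new_prior, problem.simulator)
    return ratio, task


# Toy stand-ins: the prior is an interval and each round halves it around obs.
problem = Problem(prior=(-4.0, 4.0), simulator=lambda z: z, observation=0.0)
toy_algorithm = lambda task: task.prior
toy_constrain = lambda ratio, obs, prior: (
    obs + (prior[0] - obs) / 2,
    obs + (prior[1] - obs) / 2,
)
ratio, final_task = run_strategy(problem, toy_algorithm, toy_constrain, rounds=2)
```

A user-defined strategy would then be a different `run_strategy`, built from the same constituent pieces.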

Clarification about existing code

This document outlines the workflow and highlights technical challenges. How are these technical challenges overcome in the code as it exists now? I have grouped swyft into subfolders so that each part is somewhat separated from the other parts. Here I enumerate which part is which.

(a) The folder swyft/marginals hopes to accomplish this. Given (a1) and (a2), perhaps this should be divided between the two cases.

(b) This is currently divided between swyft/cache and swyft/ip3. swyft/cache should handle actually dealing with the database while swyft/ip3 deals with the mathematical aspects that allow for the prior to be reformulated as a Poisson Point Process. Ideally, this "underwrites" prior and is mostly invisible to the user.

(c) This is all in swyft/inference, including the subfolders. The subfolders are there to separate out the training / design / hyperparameters of the algorithm and the execution of the algorithm.

(d) Right now this is just the file swyft/nestedratios.py.

The rest, including swyft/nn, swyft/utils, swyft/plot.py are minimal NN pieces, very broad utilities, and the plotting scripts respectively. These could also use a refactor, but I think it only makes sense once we've accomplished the above.

Long term saving and loading standards - using TorchScript?

I propose that we operate under the assumption that all of our users will be able to instantiate the class again, then load the parameters using state_dict, as usual in pytorch. Why? This is how pytorch works and we are not better at coding than the devs of pytorch. If we try to support more than this, we will only be answering corner cases for individuals who should be able to handle this themselves.

This leaves the implicit main question of using swyft in production: what if I want to use the ratio estimator fast and without access to the defining python code? There is an answer for that: TorchScript. I propose that we create a long-term goal of supporting the export of our ratio estimators to TorchScript for production use, i.e., integrating into the gw detection pipeline, etc.

Simulation hangs if it was previously interrupted

If the simulator encounters an exception before it finishes, subsequent simulator runs will hang. For example, consider running this simulation:

import numpy as np
import swyft

prior = swyft.Prior(lambda u: 2 * u - 1, zdim=1)

def broken_model(v):
    raise ValueError("oops!")

sim = swyft.Simulator(broken_model, sim_shapes={"obs": (1,)})
store = swyft.DirectoryStore(["x"], "test.zarr", simulator=sim)
task = swyft.Task(10, prior, store)
task.simulate()

By design, it immediately fails. But then if you fix the model and simulate again with

def model(v):
    return {"obs": v**2}

sim = swyft.Simulator(model, sim_shapes={"obs": (1,)})
store = swyft.DirectoryStore(["x"], "test.zarr", simulator=sim)
task = swyft.Task(20, prior, store)
task.simulate()

the simulation will finish running and hang, with the following stack trace after a keyboard interrupt:

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-2-240aff058111> in <module>
      5 store = swyft.DirectoryStore(["x"], "test.zarr", simulator=sim)
      6 task = swyft.Task(20, prior, store)
----> 7 task.simulate()

~/lensing/swyft/swyft/inference/task.py in simulate(self)
     38 
     39     def simulate(self):
---> 40         self.dataset.simulate()
     41 
     42     def train(self, marginals, train_args={}):

~/lensing/swyft/swyft/store/dataset.py in simulate(self, batch_size, wait_for_results)
     68         if self._no_store():
     69             return
---> 70         self._store.simulate(
     71             self.indices, batch_size=batch_size, wait_for_results=wait_for_results
     72         )

~/lensing/swyft/swyft/store/store.py in simulate(self, indices, batch_size, wait_for_results)
    385 
    386         if wait_for_results:
--> 387             self.wait_for_simulations(indices)
    388 
    389     def wait_for_simulations(self, indices):

~/lensing/swyft/swyft/store/store.py in wait_for_simulations(self, indices)
    395         done = False
    396         while not done:
--> 397             time.sleep(1)
    398             status = self.get_simulation_status(indices)
    399             done = np.isin(status, [SimulationStatus.FINISHED, SimulationStatus.FAILED])

KeyboardInterrupt: 

It seems like this requires deleting the store and restarting the analysis.

I'm guessing this can be fixed by putting a finally statement somewhere to switch PENDING to FAILED when the simulator throws an exception.
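A minimal sketch of that idea, independent of the actual store code. The status names follow the traceback and the text above, but the integer values, the function name, and the list used as a status table are assumptions.

```python
from enum import IntEnum


class SimulationStatus(IntEnum):
    # Names follow the issue text; the integer values are assumed.
    PENDING = 0
    RUNNING = 1
    FINISHED = 2
    FAILED = 3


def run_simulation(model, params, status, index):
    """Run one simulation and always leave a terminal status behind."""
    status[index] = SimulationStatus.RUNNING
    succeeded = False
    try:
        result = model(params)
        succeeded = True
        return result
    finally:
        # Runs on normal return, on exceptions, and on KeyboardInterrupt.
        status[index] = (
            SimulationStatus.FINISHED if succeeded else SimulationStatus.FAILED
        )


def broken_model(v):
    raise ValueError("oops!")


status = [SimulationStatus.PENDING, SimulationStatus.PENDING]
try:
    run_simulation(broken_model, 0.5, status, 0)
except ValueError:
    pass  # the error still propagates, but the status is no longer stuck
run_simulation(lambda v: {"obs": v ** 2}, 0.5, status, 1)
```

Because a failed entry ends up as FAILED, a loop that waits for FINISHED or FAILED can terminate instead of sleeping forever.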

Accelerate sampling in store

The sampling method of the store is very slow, presumably because of inefficient iteration over the model points. Needs to be investigated and potentially fixed.

Make Head and Tail networks simpler.

Right now both head and tail networks have quite a lot of boilerplate code, and inherit from swyft.Module. This needs to change in order to make the code more accessible.

Proposed changes:

  • Inherit from torch.nn.Module
  • n_features and calling super().__init__() will become irrelevant
  • Save initialization parameters together with ratios, instead of as part of the network
  • The connection layer between head and tail should be decoupled from the tail, making it much easier to define custom tails.

Clean up store and simulator

  • DaskSimulator as derived Simulator object
  • check and fix docstrings
  • remove import of non-used object
  • check functionality of from_command method, to run command-line simulators.

Add TransformedDataset class in order to enable inference for stochastic states

Goal: Strategy for generating posteriors for parameters with implicit priors.

  • Parameters are part of the model prediction
  • Original priors and simulations are defined as before
  • Dataset is the thing that provides information about priors, bounds and parameter names
    • We should tweak the dataset to change the expected results
  • We could have some transformed dataset?
    • Overwrite: pnames, v, bound, prior
    • The prior would have to be defined empirically, based on the distribution of the parameter in the training set.

Suggestion: TransformedDataset(dataset, sim_to_v, pnames, eff_prior, eff_bound)

use array_to_tensor everywhere

Right now, everything in our code forces the use of float32. This may not be appealing for some users, and is, in general, bad practice.

If we made more use of torch builtins, instead of numpy, including this change would allow us to generate some tensors directly on the gpu.

--

Update: The goal here is to allow pytorch standard use of device and dtype.

error handling in store

When this function requests too many samples from the store it returns an empty list.

swyft/swyft/store/store.py

Lines 282 to 320 in 0ab5832

def sample(
    self,
    N: int,
    prior: "swyft.Prior",
    bound: Optional["swyft.Bound"] = None,
    check_coverage: bool = True,
    add: bool = False,
) -> np.ndarray:
    """Return samples from store.

    Args:
        N (int): Number of samples
        prior (swyft.Prior): Prior
        bound (swyft.Bound): Bound object for prior truncation
        check_coverage (bool): Check whether requested points are contained in the store.
        add (bool): If necessary, add requested points to the store.

    Returns:
        Indices (list): Index list pointing to the relevant store entries.
    """
    if add:
        if self.coverage(N, prior, bound=bound) < 1:
            self.add(N, prior, bound=bound)
    if check_coverage:
        if self.coverage(N, prior, bound=bound) < 1.0:
            print("WARNING: Store does not contain enough samples.")
            return []
    pdf = swyft.TruncatedPrior(prior, bound)
    self._update()

    # Select points from cache
    z_store = self.v[:]
    log_w_store = self.log_w[:]
    log_lambda_target = pdf.log_prob(z_store) + np.log(N)
    accept_stored = log_w_store <= log_lambda_target
    indices = np.array(range(len(accept_stored)))[accept_stored]
    return indices

The problem is that the calling code, when it receives an empty list, only reports that there are not enough simulations.

self._indices = store.sample(N, prior, bound=bound)
if len(self._indices) == 0:
    raise RuntimeError("Not enough simulations in store")

This can also happen when there is an intensity function mismatch. It's not a very helpful error.
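One way to make the failure more informative is to raise a dedicated exception at the point where coverage is checked, instead of returning an empty list. The sketch below is hypothetical: the exception and helper names are invented, and it assumes the coverage value is a fraction between 0 and 1, as the comparison against 1.0 above suggests.

```python
class StoreCoverageError(RuntimeError):
    """Hypothetical error: the store does not cover the requested prior."""


def require_coverage(coverage, n_requested):
    """Raise a descriptive error instead of returning an empty index list."""
    if coverage < 1.0:
        raise StoreCoverageError(
            f"Store covers only {coverage:.0%} of the {n_requested} requested "
            "samples; add points for this prior or check for an intensity "
            "function mismatch."
        )


try:
    require_coverage(0.4, 1000)
except StoreCoverageError as err:
    message = str(err)
```

The caller no longer has to guess what an empty result means, and a genuinely empty selection can be treated as a separate case.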

Remove Pickle dependence

Right now we use numcodecs.Pickle for encoding the lambda functions in the store. This caused compatibility problems between python versions 3.7.x and 3.8.x because of changes in the associated default pickle version. We should try to get rid of pickle and replace it with a more stable format such as JSON. Simply replacing Pickle() --> JSON(), however, does not work, since there seem to be limitations on the JSON side.
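One limitation that is easy to demonstrate with the standard library is that Python functions are not JSON-serializable. Whether this is the exact limitation hit in the store is an assumption. The sketch also shows a possible alternative, storing a declarative description and rebuilding the function from it; the `spec` layout and `build_transform` are hypothetical.

```python
import json

transform = lambda u: 2 * u - 1  # a function like those currently pickled

# Functions cannot be encoded as JSON.
try:
    json.dumps({"transform": transform})
    json_error = None
except TypeError as err:
    json_error = str(err)


def build_transform(spec):
    """Rebuild a transform from a JSON-friendly description."""
    if spec["kind"] != "affine":
        raise ValueError(f"unknown transform kind: {spec['kind']}")
    scale, shift = spec["scale"], spec["shift"]
    return lambda u: scale * u + shift


# A declarative description survives a JSON round trip.
spec = {"kind": "affine", "scale": 2.0, "shift": -1.0}
restored = build_transform(json.loads(json.dumps(spec)))
```

The trade-off is that only transforms with a registered description can be stored, whereas pickle accepts arbitrary callables.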

Potential corruption of the storage

This issue is related to #30. In #30 we fixed updating the storage status when a simulation fails within the model. However, this does not solve some specific cases. For example, if a keyboard interrupt occurs during the model execution, the try-catch part will not be reached. In this case, the storage status of those simulations will stay Running. This can potentially corrupt the storage.

A possible solution is to add a manual clean-up function to storage to remove the stalled simulations. One can also think of filtering out the stalling simulations by the time stamp of simulations.
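A sketch of the timestamp-based variant, assuming the store keeps a start time per simulation. The status codes, the function name, and the array layout are all hypothetical.

```python
import numpy as np

RUNNING, FINISHED, FAILED = 1, 2, 3  # assumed status codes


def mark_stalled_as_failed(status, started_at, now, timeout):
    """Flag simulations that have been running longer than timeout as failed."""
    status = np.asarray(status).copy()
    elapsed = now - np.asarray(started_at, dtype=float)
    stalled = (status == RUNNING) & (elapsed > timeout)
    status[stalled] = FAILED
    return status, int(stalled.sum())


status = np.array([RUNNING, RUNNING, FINISHED])
started_at = np.array([0.0, 90.0, 0.0])  # seconds, on some common clock
cleaned, n_stalled = mark_stalled_as_failed(
    status, started_at, now=100.0, timeout=60.0
)
```

Choosing the timeout is the delicate part: it has to exceed the longest legitimate simulation time, otherwise healthy runs would be discarded.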
