
aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.

Home Page: https://aesara.readthedocs.io

License: Other

Shell 0.01% Python 96.22% C 1.68% CSS 0.02% HTML 0.04% Makefile 0.03% C++ 1.69% Cython 0.31%
symbolic-computation tensors theano automatic-differentiation transpiler aesara term-rewriting-system optimizing-compiler optimizing-compilers

aesara's People

Contributors

aalmah, abalkin, abergeron, affanv14, amrithasuresh, ballasn, brandonwillard, breuleux, caglar, carriepl, chienlima, dwf, gvtulder, harlouci, hengjean, jaberg, jlowin, khaotik, lamblin, nicolasbouchard, notoraptor, nouiz, pascanur, reyhaneaskari, ricardov94, rlouf, royxue, sentient07, slefrancois, turian

aesara's Issues

Simplify/combine status checks

Our CI workflow sets up a bunch of independent jobs via the matrix settings, and each one of those creates a possible status check under this repo's branch protection settings. It would be better if we had a single job that could serve as a status check for all the matrix jobs, especially as we add/remove jobs via the matrix settings (e.g. new Python versions).

So far, this was the best advice I could find; it says to create a new job that simply checks the matrix build status.

Reconsider the graph object model

The constant caching severely complicates graph manipulation/"optimization" by—for instance—requiring that such constants be cloned in order to be used within more than one FunctionGraph. This cloning requirement ultimately extends to the graphs that contain these cached constants and the result is an overly complicated and wasteful process of object cloning and clone-to-original map management (and maps to and from further clones, in some cases).

First, how much does this caching actually save? Was this ever measured or was it just applied? It's implied that this caching improves the performance of the MergeOptimizer; is that the only thing it helps? Is it simply attempting to save resources by only performing an is check instead of a potentially costly descent into subgraphs for further comparisons? What problem is the FunctionGraph trying to avoid by not allowing cached variables, and is that problem only avoided by completely cloning graphs?

All this touches upon some very fundamental graph object model choices/inconsistencies revolving around object identity and equality. Theano graph objects do not consistently implement __eq__ (well, most simply don't), so there are numerous "local" work-arounds throughout the codebase to determine object equality to varying degrees—including the MergeOptimizer itself. Constant caching looks like just another example, since it's effectively assigning a singleton-like property to a subset of graph objects.

This—and many other issues and complexities—could be better addressed with a more consistent object model, and one that uses the built-in Python OOP features to implement it. (Especially since they were arguably designed for exactly these kinds of things.)

For instance, a base implementation of __eq__ is simple to provide, since, as usual, the core graph types can be mapped directly to S-expressions (see the Theano meta graph support in symbolic-pymc). Furthermore, if immutability is added to the model, hash generation and equality comparisons can be cached (locally, by the objects themselves).
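A minimal sketch of that idea, assuming a hypothetical immutable `Node` type (not Theano's actual `Apply`/`Variable` classes): equality is structural, following the S-expression form `(op, *args)`, and the hash is computed once and cached on the object itself.

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple


@dataclass(frozen=True, eq=False)
class Node:
    """Hypothetical immutable graph node (not Theano's Apply/Variable)."""

    op: str
    args: Tuple[Any, ...] = ()
    # Cached hash; excluded from __init__, repr and comparisons.
    _hash: Optional[int] = field(default=None, init=False, repr=False, compare=False)

    def as_sexp(self):
        """Map the node to a nested-tuple S-expression: (op, *args)."""
        return (self.op,) + tuple(
            a.as_sexp() if isinstance(a, Node) else a for a in self.args
        )

    def __hash__(self):
        if self._hash is None:
            # Computed once; child Nodes reuse their own cached hashes.
            # object.__setattr__ bypasses the frozen-dataclass guard.
            object.__setattr__(self, "_hash", hash((self.op, self.args)))
        return self._hash

    def __eq__(self, other):
        if self is other:
            return True  # cheap identity check first
        if not isinstance(other, Node):
            return NotImplemented
        # Unequal cached hashes rule out equality without descending.
        if hash(self) != hash(other):
            return False
        return self.op == other.op and self.args == other.args


a = Node("add", ("x", 1))
b = Node("add", ("x", 1))  # a distinct object with the same structure
print(a == b, a is b)  # True False
```

With equality defined this way, two structurally identical constants compare equal without being the same cached object, which is the property the constant cache is currently emulating.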

Alternatively, unique IDs/names could be used to simplify identity and equality comparisons, albeit with some severe limitations.

Anyway, we should start considering changes like this now, because, after they're in place, I believe we could start to make vast improvements in performance and general usability. Without them, we could end up facing the same problems over and over again in superficially different forms.

Theano top-level tests module is not resolved when running coverage on a different project that has a "tests" module

I wanted to run the tests on our calibr8 package, but `import theano` failed on this line:
https://github.com/pymc-devs/Theano-PyMC/blob/master/theano/__init__.py#L170

Importing theano from a Jupyter notebook works without problems, so it must be related to changes that coverage makes to the path variables.

However, the problem did not occur with the original Theano.

Explanation of the traceback: calibr8.utils imports theano, but theano's `import tests` resolves to the tests in our project (here tests/tests.py). That script references calibr8.ErrorModel, a class from calibr8's core module that is defined after the `import theano` line, so the partially initialized calibr8 module does not have it yet.

Traceback (most recent call last):
  File "tests/tests.py", line 9, in <module>
    import calibr8
  File "c:\users\osthege\repos\calibr8\calibr8\calibr8\__init__.py", line 1, in <module>
    from . core import  *
  File "c:\users\osthege\repos\calibr8\calibr8\calibr8\core.py", line 10, in <module>
    from . import utils
  File "c:\users\osthege\repos\calibr8\calibr8\calibr8\utils.py", line 10, in <module>
    import theano
  File "c:\users\osthege\repos\theano-pymc\theano\__init__.py", line 170, in <module>
    import tests
  File "C:\Users\osthege\Repos\calibr8\calibr8\tests\tests.py", line 30, in <module>
    class _TestModel(calibr8.ErrorModel):
AttributeError: module 'calibr8' has no attribute 'ErrorModel'
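A bare `import tests` picks up the first `tests` module found on `sys.path`. When a script is run directly (as `tests/tests.py` is in the traceback), Python puts that script's directory at the front of `sys.path`, which appears to be how Theano's import ended up loading the project's script. The sketch below reproduces that lookup with a throwaway directory; `resolve_tests_module` is a hypothetical helper, not part of Theano or coverage.

```python
import importlib
import os
import sys
import tempfile


def resolve_tests_module(script_dir):
    """Import `tests` the way theano/__init__.py does, with `script_dir`
    at the front of sys.path, and return the resulting module."""
    sys.path.insert(0, script_dir)
    saved = sys.modules.pop("tests", None)  # force a fresh lookup
    importlib.invalidate_caches()
    try:
        return importlib.import_module("tests")
    finally:
        # Undo our changes so the rest of the interpreter is unaffected.
        sys.path.remove(script_dir)
        sys.modules.pop("tests", None)
        if saved is not None:
            sys.modules["tests"] = saved


with tempfile.TemporaryDirectory() as script_dir:
    # Stand-in for the project's tests/tests.py script.
    with open(os.path.join(script_dir, "tests.py"), "w") as f:
        f.write('ORIGIN = "project script"\n')
    mod = resolve_tests_module(script_dir)
    print(mod.ORIGIN)  # project script
```

This is why a library-level `import tests` is fragile: the name is generic enough that any project on the path can shadow it.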

Prevent `Elemwise` graphs from violating `Op` arities

The scalar multiplication Op, Mul, has an impl method that actually uses np.product. When used in conjunction with Elemwise, optimizations like local_mul_canonizer construct graphs that essentially have Elemwise(Mul)(a, b, c, ...) nodes.

In other words, it adds nodes to the graph that violate the arity of scalar functions (e.g. by turning the binary operator Mul into a variadic operator) and puts the graph into a representationally invalid state.

As well, Elemwise implements numerous unnecessarily complicated hacks throughout its numerical evaluation stages in order to make these invalid graphs work (e.g. vectorizations of arity-violating impl functions occurring in Op.prepare_node during calls to Op.perform?!).

If anything, to preserve the validity of the graph, any optimization that flattens these should also replace the scalar Ops with arity-appropriate ones (e.g. replace Mul with Product).
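A toy sketch of that rule on a hypothetical tuple-based term representation (not Theano's `Apply`/`Elemwise` classes, and `"product"` is an assumed name for the variadic operator): nested binary `mul` nodes are flattened into a single `product` node, so no node ever receives more inputs than its operator's arity allows.

```python
# Hypothetical term representation: (op_name, *args); leaves are strings.
ARITY = {"mul": 2}  # the scalar Mul is binary; "product" is variadic


def flatten_mul(term):
    """Collect the operands of a tree of nested binary `mul` nodes."""
    if isinstance(term, tuple) and term[0] == "mul":
        assert len(term) - 1 == ARITY["mul"], "mul must stay binary"
        return [x for arg in term[1:] for x in flatten_mul(arg)]
    return [term]


def canonicalize_mul(term):
    """Rewrite nested binary `mul`s into one variadic `product` node,
    instead of stuffing more than two inputs into `mul` itself."""
    operands = flatten_mul(term)
    if len(operands) <= ARITY["mul"]:
        return term  # nothing to flatten; leave the binary node alone
    return ("product", *operands)


print(canonicalize_mul(("mul", ("mul", "a", "b"), "c")))
# ('product', 'a', 'b', 'c')
```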

Refactor Travis Build Script

It looks like some or all of the Travis build failures are due to dependency issues involving old versions of Python (e.g. 2.7, 3.4). We should probably remove all tests involving Python <= 3.5 and cover Python 3.6 through 3.8 instead.

Additionally, the .travis.yml is unnecessarily complex/redundant and has confusing, unexplained combinations of versions, settings, and tests (e.g. why are we only sometimes performing some tests with different Python and NumPy versions?). We should try to simplify it as much as possible.

Remove theano.dot

The __init__.py-level theano.dot function is awkward and—as far as I can tell—completely unnecessary. What it appears to be doing is supposed to be done in class-level implementations of __dot__ and __rdot__ (e.g. an interface/mixin class that checks both objects for the appropriate methods).

This function should be removed and the classes that implement __dot__ and __rdot__ (e.g. theano.tensor.var._tensor_py_operators) should handle this.
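A sketch of the class-level alternative, assuming hypothetical `DotMixin` and `Vec` names (not Theano's `_tensor_py_operators`): each class provides `__dot__`/`__rdot__`, and the mixin routes Python's `@` operator (`__matmul__`/`__rmatmul__`) to them, so the interpreter's standard left-then-right operand fallback does the two-sided dispatch that a module-level `theano.dot` otherwise has to do by hand.

```python
class DotMixin:
    """Hypothetical mixin: classes supply __dot__/__rdot__, and Python's
    built-in `@` operator protocol performs the two-sided dispatch."""

    def __dot__(self, other):
        return NotImplemented

    def __rdot__(self, other):
        return NotImplemented

    def __matmul__(self, other):
        return self.__dot__(other)

    def __rmatmul__(self, other):
        return self.__rdot__(other)


class Vec(DotMixin):
    """Toy vector type used only to exercise the mixin."""

    def __init__(self, *xs):
        self.xs = tuple(xs)

    def __dot__(self, other):
        if isinstance(other, Vec):
            other = other.xs
        if not isinstance(other, (list, tuple)):
            return NotImplemented  # let Python try the other operand
        return sum(a * b for a, b in zip(self.xs, other))

    def __rdot__(self, other):
        # The dot product is symmetric, so reuse __dot__.
        return self.__dot__(other)


print(Vec(1, 2, 3) @ Vec(4, 5, 6))  # 32
print([4, 5, 6] @ Vec(1, 2, 3))  # 32, via Vec.__rmatmul__
```

Returning `NotImplemented` rather than raising keeps the other operand's reflected method in play, and Python raises `TypeError` only when neither side handles the pair.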

No JAX conversion for CumOp

I'm sure you're already working on it, but just for reference, the following code raises an error because there is no JAX conversion for CumOp:

import numpy as np
import pymc3 as pm
import pymc3.sampling_jax

with pm.Model() as model:
    x = pm.Dirichlet("x", a=np.ones(10))

with model:
    trace_jax = pm.sampling_jax.sample_numpyro_nuts(
        500, tune=500
    )

Traceback:

/Users/fb90/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+40.g5de033879-py3.8.egg/theano/gof/cc.py:1029: UserWarning: Your g++ compiler fails to compile OpenMP code. We know this happen with some version of the EPD mingw compiler and LLVM compiler on Mac OS X. We disable openmp everywhere in Theano. To remove this warning set the theano flags `openmp` to False.
  ret += x.c_compile_args()
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-1-b482cf1c4aad> in <module>
      7 
      8 with model:
----> 9     trace_jax = pm.sampling_jax.sample_numpyro_nuts(
     10         500, tune=500
     11     )

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/pymc3-3.9.3-py3.8.egg/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar)
    115 
    116     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
--> 117     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
    118     logp_fn_jax = fns[0]
    119 

~/anaconda3/envs/pymc3jax/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+40.g5de033879-py3.8.egg/theano/sandbox/jaxify.py in jax_funcify_FunctionGraph(fgraph)
    550 
    551     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 552     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    553 
    554     return jax_funcs

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+40.g5de033879-py3.8.egg/theano/sandbox/jaxify.py in <listcomp>(.0)
    550 
    551     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 552     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    553 
    554     return jax_funcs

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+40.g5de033879-py3.8.egg/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    129             input_f = jax_data_func
    130         else:
--> 131             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    132 
    133         input_funcs.append(input_f)

[Previous frame repeated 14 more times]

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+40.g5de033879-py3.8.egg/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    110         return memo[out_node]
    111 
--> 112     jax_return_func = jax_funcify(out_node.op)
    113 
    114     input_funcs = []

~/anaconda3/envs/pymc3jax/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+40.g5de033879-py3.8.egg/theano/sandbox/jaxify.py in jax_funcify(op)
    156 def jax_funcify(op):
    157     """Create a JAX "perform" function for a Theano `Variable` and its `Op`."""
--> 158     raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))
    159 
    160 

NotImplementedError: No JAX conversion for the given `Op`: CumOp{0, mul}

Versions and main components

  • PyMC3 Version: pymc3jax branch
  • Theano Version: Theano-PyMC master branch
  • Python Version: 3.8
  • Operating system: macOS Catalina
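The traceback shows that `jax_funcify` is a `functools.singledispatch` function whose default implementation raises `NotImplementedError`, so adding support amounts to registering a converter for the `Op`'s class. The sketch below does that against a stand-in `CumOp` dataclass, assumed to carry `axis` and `mode` to match the `CumOp{0, mul}` in the error, rather than Theano's real class; it falls back to NumPy when JAX is not installed so it can run anywhere.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Optional

try:
    import jax.numpy as jnp
except ImportError:  # fall back so the sketch runs without JAX installed
    import numpy as jnp


@dataclass(frozen=True)
class CumOp:
    """Stand-in for Theano's CumOp: CumOp{0, mul} is axis=0, mode="mul"."""

    axis: Optional[int] = None
    mode: str = "add"


@singledispatch
def jax_funcify(op):
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")


@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op):
    # Unknown modes fail loudly with a KeyError.
    cumulative = {"add": jnp.cumsum, "mul": jnp.cumprod}[op.mode]

    def cumop(x):
        # axis=None flattens the input first, as in NumPy/JAX.
        return cumulative(x, axis=op.axis)

    return cumop
```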

Fix links to repo

Various links in the READMEs point to the original repo. Change them to point to this repo instead.

Fix constant folding optimization error caused by AdvancedBooleanSubtensor

Theano's boolean indexing does not match NumPy's.

Consider the following:

import numpy as np

import theano
import theano.tensor as tt


theano.config.cxx = ""


test_array = np.array([[np.inf],
                       [10],
                       [np.inf]])

test_array[~np.isinf(test_array)]
# array([10.])

The Theano analog is as follows:

test_array_tt = tt.as_tensor_variable(test_array)

test_array_tt[~tt.isinf(test_array_tt)]
/tmp/user/1000/babel-Hkkg0F/python-ZsmiGy in <module>
      1 test_array_tt = tt.as_tensor_variable(test_array)
      2
----> 3 test_array_tt[~tt.isinf(test_array_tt)]

~/projects/code/python/Theano/theano/tensor/var.py in __getitem__(self, args)
    579
    580         if advanced_boolean:
--> 581             return theano.tensor.subtensor.advanced_boolean_subtensor(self, *args)
    582         elif advanced:
    583             if (

~/projects/code/python/Theano/theano/gof/op.py in __call__(self, *inputs, **kwargs)
    611         """
    612         return_list = kwargs.pop("return_list", False)
--> 613         node = self.make_node(*inputs, **kwargs)
    614
    615         if config.compute_test_value != "off":

~/projects/code/python/Theano/theano/tensor/subtensor.py in make_node(self, x, *index)
   2305
   2306         index = tuple(map(as_index_variable, index))
-> 2307         bcast = adv_index_broadcastable_pattern(x, index)
   2308         return gof.Apply(
   2309             self,

~/projects/code/python/Theano/theano/tensor/subtensor.py in adv_index_broadcastable_pattern(a, idx)
   2239     # 2 - True = 1; 2 - False = 2
   2240     fakeshape = [2 - bc for bc in a.broadcastable]
-> 2241     retshape = np.empty(fakeshape)[newidx].shape
   2242     return tuple([dim == 1 for dim in retshape])
   2243

IndexError: boolean index did not match indexed array along dimension 1; dimension is 1 but corresponding boolean dimension is 2

The code can be made to work after tt.squeeze-ing the boolean index array, but this is a poor solution.
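The failure can be reproduced in plain NumPy. A minimal sketch, assuming (consistently with the error message) that `adv_index_broadcastable_pattern` substitutes a fake boolean mask with extent 2 along every axis, while the fake array uses extent 1 for broadcastable axes:

```python
import numpy as np

# Broadcastable pattern of the (3, 1) test array: the second axis has length 1.
broadcastable = (False, True)

# Fake shape as in adv_index_broadcastable_pattern: 2 - False = 2, 2 - True = 1.
fakeshape = [2 - bc for bc in broadcastable]  # [2, 1]

# Assumption: the boolean index becomes a fake mask of extent 2 on every axis.
fake_mask = np.ones((2,) * len(broadcastable), dtype=bool)  # shape (2, 2)

try:
    np.empty(fakeshape)[fake_mask]
    mismatch = False
except IndexError:
    mismatch = True  # axis 1: array extent 1 vs. mask extent 2

# A fake mask that follows the array's broadcastable pattern indexes cleanly.
ok_shape = np.empty(fakeshape)[np.ones(fakeshape, dtype=bool)].shape  # (2,)
```

Under this assumption, the shape-inference shortcut (not the user's index) is what produces the `IndexError`, which is why squeezing the mask happens to avoid it.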

Fix poorly identified and brittle tests

ImportError with PyMC3

With the latest PyMC3 master and Theano-PyMC master, I still get an import error:

[screenshot of the ImportError]

Importing theano by itself works fine.
Should it be a PyMC3 ticket?

Base core graph classes on attrs or dataclasses

In line with #37, we should consider basing the core graph objects (e.g. theano.gof.graph.[Node, Variable, Apply], theano.gof.op.[PureOp, Op], and all their subclasses) on some form of flexible, (relatively) immutable, easily convertible data type. The packages attrs and the newly built-in dataclasses provide much—if not all—of what we need.

Specifically, both packages provide automatic recursive conversion into tuples and dicts, and easily configurable defaults for things like equality, ordering, repr, hashing, etc. The attrs package goes a little further: it can perform automatic input validation and conversion, as well as use __slots__ for a relative performance gain.

More importantly, both provide convenient mutation functions (i.e. dataclasses.replace and attrs.evolve) that fit well with the existing object identity model and the pervasive requirement for object cloning. In this setting, graph object mutation could be turned into a considerably easier—and universal—process!
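A short sketch of the `dataclasses` side of this, using a hypothetical `Apply` stand-in with string inputs rather than Theano's real classes: a frozen dataclass gets value-based `__eq__`/`__hash__`, `astuple`/`asdict` conversion, and `dataclasses.replace` as the cloning/"mutation" primitive.

```python
from dataclasses import FrozenInstanceError, asdict, astuple, dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Apply:
    """Hypothetical stand-in for a graph node; not Theano's real Apply."""

    op: str
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


node = Apply(op="add", inputs=("x", "y"), outputs=("z",))

# Generated equality and hashing compare field values, not object identity.
same = Apply(op="add", inputs=("x", "y"), outputs=("z",))

# Recursive conversion into plain tuples and dicts.
print(astuple(node))  # ('add', ('x', 'y'), ('z',))

# "Mutation" produces a new object; the original is left untouched.
new_node = replace(node, inputs=("x", "w"))

try:
    node.op = "mul"  # frozen: in-place mutation is rejected
except FrozenInstanceError:
    pass
```

The attrs equivalents are `attrs.evolve` for `replace` and `attrs.astuple`/`attrs.asdict` for the conversions.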

Adding support of FEniCS models

I've made a small package for embedding FEniCS PDE solvers in Theano.
https://github.com/IvanYashchuk/fenics-pymc3
The current API is a single function, create_fenics_theano_op, which turns a regular Python function (one that takes FEniCS inputs and returns the solution to the problem) into a differentiable Theano Op that can be used directly in a PyMC3 model.
This function can also be used as a decorator (with the @create_fenics_theano_op syntax).
@junpenglao suggested that it could be of interest to port it here. What do other people think about this?
I'm not very familiar with the code structure here; what would be a good place for this FEniCS integration functionality?
Also if we get to the point of having FEniCS support here, we should also make it work with Firedrake, which has more or less the same Python interface as FEniCS, but a different backend.

Consider making the `Scan` inner-graph an input to the `Op`

Scan has a member/field that holds the inner-graph (i.e. a graph representing the computation in the Scan's body). This design choice effectively hides the inner-graph from a standard graph traversal and, as a result, forces Scan-specific logic into generic routines so that they work with these Ops.

We should consider changing this.
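A toy illustration of the problem, with hypothetical `Node`/`ancestors` names (not Theano's graph API): a traversal that only follows `inputs` cannot see an inner graph stored as an attribute, but picks it up without any special-casing once the inner graph is an input.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(eq=False)
class Node:
    """Hypothetical graph node: a name, its inputs, and an optional
    inner graph stored as a plain attribute (like Scan's body today)."""

    name: str
    inputs: Tuple["Node", ...] = ()
    inner: Optional["Node"] = None


def ancestors(node, seen=None):
    """Generic traversal: follows `inputs` only and knows nothing about Scan.
    Nodes are tracked by name, which is enough for this toy example."""
    seen = set() if seen is None else seen
    if node.name not in seen:
        seen.add(node.name)
        for i in node.inputs:
            ancestors(i, seen)
    return seen


x = Node("x")
body = Node("body", (Node("inner_x"),))

# Inner graph hidden in an attribute: the generic walk never reaches it.
scan_attr = Node("scan", (x,), inner=body)

# Inner graph passed as an ordinary input: the same walk sees it.
scan_input = Node("scan", (x, body))

print(sorted(ancestors(scan_attr)))   # ['scan', 'x']
print(sorted(ancestors(scan_input)))  # ['body', 'inner_x', 'scan', 'x']
```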

Missing Op: ExtractDiag

Trying to run a GP model fails with the following error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-19-7a2d0b5601fb> in <module>
      1 with model:
----> 2     tr_jax = sample_tfp_nuts(chains=4, target_accept=0.95)

<ipython-input-9-54c111d89074> in sample_tfp_nuts(draws, tune, chains, target_accept, seed, model)
     15 
     16     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
---> 17     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
     18     logp_fn_jax = fns[0]
     19 

~/miniconda3/envs/pymc3theano/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

~/projects/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify_FunctionGraph(fgraph)
    522 
    523     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 524     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    525 
    526     return jax_funcs

~/projects/Theano-PyMC/theano/sandbox/jaxify.py in <listcomp>(.0)
    522 
    523     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 524     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    525 
    526     return jax_funcs

~/projects/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

[Previous frame repeated 8 more times]

~/projects/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
     90         return memo[out_node]
     91 
---> 92     jax_return_func = jax_funcify(out_node.op)
     93 
     94     input_funcs = []

~/miniconda3/envs/pymc3theano/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

~/projects/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify(op)
    136 def jax_funcify(op):
    137     """Create a JAX "perform" function for a Theano `Variable` and its `Op`."""
--> 138     raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))
    139 
    140 

NotImplementedError: No JAX conversion for the given `Op`: ExtractDiag{offset=0, axis1=0, axis2=1, view=False}
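As with other missing conversions, `jax_funcify` dispatches on the `Op`'s class (via `functools.singledispatch`, per the traceback), so a converter only needs to map the `Op`'s parameters onto `jax.numpy.diagonal`. The sketch below uses a stand-in `ExtractDiag` dataclass whose fields are assumed from the parameters printed in the error, not Theano's real class, and falls back to NumPy if JAX is not installed.

```python
from dataclasses import dataclass
from functools import singledispatch

try:
    import jax.numpy as jnp
except ImportError:  # lets the sketch run where JAX is not installed
    import numpy as jnp


@dataclass(frozen=True)
class ExtractDiag:
    """Stand-in carrying the parameters shown in the error message."""

    offset: int = 0
    axis1: int = 0
    axis2: int = 1
    view: bool = False


@singledispatch
def jax_funcify(op):
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")


@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
    def extract_diag(x):
        # Assumption: `view` only affects in-place/aliasing behaviour, which
        # does not apply to immutable JAX arrays, so it is ignored here.
        return jnp.diagonal(x, offset=op.offset, axis1=op.axis1, axis2=op.axis2)

    return extract_diag
```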

Jax 0.2 does not support jax.numpy.reshape with non-constant values in omnistaging mode

import pymc3 as pm
import theano
import numpy as np
import theano.sandbox.jax

theano.compile.mode.predefined_linkers["jax"] = theano.sandbox.jax.JaxLinker()
jax_mode = theano.compile.Mode(linker="jax")

x = np.linspace(0, 1, 10)
y = x * 4. + 1.4 + np.random.randn(10)

with pm.Model() as model:
    beta = pm.Normal("beta", 0., 5., shape=2)
    sigma = pm.HalfNormal("sigma", 2.5)
    obs = pm.Normal("obs", beta[0] + beta[1] * x, sigma, observed=y)
    pm.sample(mode=jax_mode)

Traceback:

Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-20-21adaeaad34c> in <module>
     21 with model:
---> 22     pm.sample(mode=jax_mode)

~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    480             _log.info("Auto-assigning NUTS sampler...")
--> 481             start_, step = init_nuts(
    482                 init=init,

~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
   2133 
-> 2134     step = pm.NUTS(potential=potential, model=model, **kwargs)
   2135 

~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
    167         """
--> 168         super().__init__(vars, **kwargs)
    169 

~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
     92 
---> 93         super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
     94 

~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
    253             q = func.dict_to_array(model.test_point)
--> 254             logp, dlogp = func(q)
    255         except ValueError:

~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
    738 
--> 739         output = self._theano_function(array)
    740         if grad_out is None:

~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    978             outputs = (
--> 979                 self.fn()
    980                 if output_subset is None

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    653                 ):
--> 654                     outputs = [
    655                         jax_impl_jit(*[x[0] for x in thunk_inputs])

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    654                     outputs = [
--> 655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    126             func_args = [fn(*inputs) for fn in input_funcs]
--> 127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    545     def reshape(x, shape):
--> 546         return jnp.reshape(x, shape)
    547 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")

FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

ConcretizationTypeError                   Traceback (most recent call last)
~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    653                 ):
--> 654                     outputs = [
    655                         jax_impl_jit(*[x[0] for x in thunk_inputs])

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    654                     outputs = [
--> 655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    136     try:
--> 137       return fun(*args, **kwargs)
    138     except Exception as e:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    208     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209     out = xla.xla_call(
    210         flat_fun,

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1143   def bind(self, fun, *args, **params):
-> 1144     return call_bind(self, fun, *args, **params)
   1145 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1134   with maybe_new_sublevel(top_trace):
-> 1135     outs = primitive.process(top_trace, fun, tracers, params)
   1136   return map(full_lower, apply_todos(env_trace_todo(), outs))

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1146   def process(self, trace, fun, tracers, params):
-> 1147     return trace.process_call(self, fun, tracers, params)
   1148 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    576   def process_call(self, primitive, f, tracers, params):
--> 577     return primitive.impl(f, *tracers, **params)
    578   process_map = process_call

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    530                                *unsafe_map(arg_spec, args))

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    233     else:
--> 234       ans = call(fun, *args)
    235       cache[key] = (ans, fun.stores)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    594   if config.omnistaging_enabled:
--> 595     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    596     if any(isinstance(c, core.Tracer) for c in consts):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1022     main.jaxpr_stack = ()  # type: ignore
-> 1023     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1024     del main

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1003     in_tracers = map(trace.new_arg, in_avals)
-> 1004     ans = fun.call_wrapped(*in_tracers)
   1005     out_tracers = map(trace.full_raise, ans)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     try:
--> 151       ans = self.f(*args, **dict(self.params, **kwargs))
    152     except:

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    126             func_args = [fn(*inputs) for fn in input_funcs]
--> 127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    545     def reshape(x, shape):
--> 546         return jnp.reshape(x, shape)
    547 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
    873     else:
--> 874       raise_concretization_error(val, context)
    875   else:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
    852           f"Encountered tracer value: {val}")
--> 853   raise ConcretizationTypeError(msg)
    854 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>

During handling of the above exception, another exception occurred:

ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-20-21adaeaad34c> in <module>
     20 
     21 with model:
---> 22     pm.sample(mode=jax_mode)

~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    479             # By default, try to use NUTS
    480             _log.info("Auto-assigning NUTS sampler...")
--> 481             start_, step = init_nuts(
    482                 init=init,
    483                 chains=chains,

~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
   2132         raise ValueError(f"Unknown initializer: {init}.")
   2133 
-> 2134     step = pm.NUTS(potential=potential, model=model, **kwargs)
   2135 
   2136     return start, step

~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
    166         `pm.sample` to the desired number of tuning steps.
    167         """
--> 168         super().__init__(vars, **kwargs)
    169 
    170         self.max_treedepth = max_treedepth

~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
     91         vars = inputvars(vars)
     92 
---> 93         super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
     94 
     95         self.adapt_step_size = adapt_step_size

~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
    252             func.set_extra_values(model.test_point)
    253             q = func.dict_to_array(model.test_point)
--> 254             logp, dlogp = func(q)
    255         except ValueError:
    256             if logp_dlogp_func is not None:

~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
    737             out = grad_out
    738 
--> 739         output = self._theano_function(array)
    740         if grad_out is None:
    741             return output

~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    977         try:
    978             outputs = (
--> 979                 self.fn()
    980                 if output_subset is None
    981                 else self.fn(output_subset=output_subset)

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    705                         old_s[0] = None
    706             except Exception:
--> 707                 raise_with_op(node, thunk)
    708 
    709         f = streamline_default_f

~/projects/Theano-PyMC/theano/gof/link.py in raise_with_op(node, thunk, exc_info, storage_map)
    346         # extra long error message in that case.
    347         pass
--> 348     reraise(exc_type, exc_value, exc_trace)
    349 
    350 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/six.py in reraise(tp, value, tb)
    700                 value = tp()
    701             if value.__traceback__ is not tb:
--> 702                 raise value.with_traceback(tb)
    703             raise value
    704         finally:

~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
    701                     thunks, order, post_thunk_old_storage
    702                 ):
--> 703                     thunk()
    704                     for old_s in old_storage:
    705                         old_s[0] = None

~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
    652                     node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
    653                 ):
--> 654                     outputs = [
    655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    653                 ):
    654                     outputs = [
--> 655                         jax_impl_jit(*[x[0] for x in thunk_inputs])
    656                         for jax_impl_jit in jax_impl_jits
    657                     ]

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    135   def reraise_with_filtered_traceback(*args, **kwargs):
    136     try:
--> 137       return fun(*args, **kwargs)
    138     except Exception as e:
    139       if not is_under_reraiser(e):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    207       _check_arg(arg)
    208     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209     out = xla.xla_call(
    210         flat_fun,
    211         *args_flat,

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1142 
   1143   def bind(self, fun, *args, **params):
-> 1144     return call_bind(self, fun, *args, **params)
   1145 
   1146   def process(self, trace, fun, tracers, params):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1133   tracers = map(top_trace.full_raise, args)
   1134   with maybe_new_sublevel(top_trace):
-> 1135     outs = primitive.process(top_trace, fun, tracers, params)
   1136   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1137 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1145 
   1146   def process(self, trace, fun, tracers, params):
-> 1147     return trace.process_call(self, fun, tracers, params)
   1148 
   1149   def post_process(self, trace, out_tracers, params):

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    575 
    576   def process_call(self, primitive, f, tracers, params):
--> 577     return primitive.impl(f, *tracers, **params)
    578   process_map = process_call
    579 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    527 
    528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    530                                *unsafe_map(arg_spec, args))
    531   try:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    232       fun.populate_stores(stores)
    233     else:
--> 234       ans = call(fun, *args)
    235       cache[key] = (ans, fun.stores)
    236     return ans

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    593   abstract_args, arg_devices = unzip2(arg_specs)
    594   if config.omnistaging_enabled:
--> 595     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    596     if any(isinstance(c, core.Tracer) for c in consts):
    597       raise core.UnexpectedTracerError("Encountered an unexpected tracer.")

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1021     main.source_info = fun_sourceinfo(fun.f)  # type: ignore
   1022     main.jaxpr_stack = ()  # type: ignore
-> 1023     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1024     del main
   1025   return jaxpr, out_avals, consts

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1002     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1003     in_tracers = map(trace.new_arg, in_avals)
-> 1004     ans = fun.call_wrapped(*in_tracers)
   1005     out_tracers = map(trace.full_raise, ans)
   1006   jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    149 
    150     try:
--> 151       ans = self.f(*args, **dict(self.params, **kwargs))
    152     except:
    153       # Some transformations yield from inside context managers, so we have to

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    124 
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
    124 
    125         def jax_func(*inputs):
--> 126             func_args = [fn(*inputs) for fn in input_funcs]
    127             return return_func(*func_args)
    128 

~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
    125         def jax_func(*inputs):
    126             func_args = [fn(*inputs) for fn in input_funcs]
--> 127             return return_func(*func_args)
    128 
    129         jax_funcs.append(update_wrapper(jax_func, return_func))

~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
    544 def jax_funcify_Reshape(op):
    545     def reshape(x, shape):
--> 546         return jnp.reshape(x, shape)
    547 
    548     return reshape

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
   1144 def reshape(a, newshape, order="C"):
   1145   try:
-> 1146     return a.reshape(newshape, order=order)  # forward to method for ndarrays
   1147   except AttributeError:
   1148     return _reshape(a, newshape, order=order)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1190           type(newshape[0]) is not Poly):
   1191     newshape = newshape[0]
-> 1192   return _reshape(a, newshape, order=order)
   1193 
   1194 

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1166 
   1167 def _reshape(a, newshape, order="C"):
-> 1168   computed_newshape = _compute_newshape(a, newshape)
   1169   if order == "C":
   1170     return lax.reshape(a, computed_newshape, None)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
   1161   if newsize < 0:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
-> 1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)
   1160   newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
   1161   if newsize < 0:

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
   1155   else: iterable = True
   1156   def check(size):
-> 1157     return size if type(size) is Poly else core.concrete_or_error(
   1158       int, size, "The error arose in jax.numpy.reshape.")
   1159   newshape = [check(size) for size in newshape] if iterable else check(newshape)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
    872       return force(val.aval.val)
    873     else:
--> 874       raise_concretization_error(val, context)
    875   else:
    876     return force(val)

~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
    851          "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
    852           f"Encountered tracer value: {val}")
--> 853   raise ConcretizationTypeError(msg)
    854 
    855 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:

  operation yn:bool[] = lt yl:int64[] ym:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

  operation yq:int64[] = xla_call[ backend=None
                       call_jaxpr={ lambda  ; a b c.
                                    let d = select a b c
                                    in (d,) }
                       device=None
                       donated_invars=(False, False, False)
                       name=_where ] yn:bool[] yo:int64[] yp:int64[]
    from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
Apply node that caused the error: Sum{acc_dtype=float64}(MakeVector{dtype='float64'}.0)
Toposort index: 46
Inputs types: [TensorType(float64, vector)]
Inputs shapes: [(3,)]
Inputs strides: [(8,)]
Inputs values: [array([0.69049938, 0.        , 0.        ])]
Outputs clients: [['output']]

Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
  File "<ipython-input-20-21adaeaad34c>", line 22, in <module>
    pm.sample(mode=jax_mode)
  File "/Users/twiecki/projects/pymc/pymc3/sampling.py", line 481, in sample
    start_, step = init_nuts(
  File "/Users/twiecki/projects/pymc/pymc3/sampling.py", line 2134, in init_nuts
    step = pm.NUTS(potential=potential, model=model, **kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/hmc/nuts.py", line 168, in __init__
    super().__init__(vars, **kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/hmc/base_hmc.py", line 93, in __init__
    super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
  File "/Users/twiecki/projects/pymc/pymc3/step_methods/arraystep.py", line 245, in __init__
    func = model.logp_dlogp_function(
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1005, in logp_dlogp_function
    costs = [self.logpt]
  File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1015, in logpt
    logp = tt.sum([tt.sum(factor) for factor in factors])

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.

Add new Scan Op utils

Let's consider adding the Scan utilities from pymc-devs/symbolic-pymc#114 (e.g. the improved ScanArgs class and convert_outer_out_to_in function).

These would also help with a much-needed general refactoring and simplification of Scan.

Coveralls build error "No build matching CI build number..."

This build had the following Coveralls error:

> Run exoplanet-dev/coveralls-python-action@develop
/usr/bin/docker run --name e5c3570f78f745ad14bcd878b0e92176cdb26_634b8d --label 1e5c35 --workdir /github/workspace --rm -e INPUT_PARALLEL-FINISHED -e INPUT_GITHUB-TOKEN -e INPUT_PARALLEL -e INPUT_DEBUG -e HOME -e GITHUB_JOB -e GITHUB_REF -e GITHUB_SHA -e GITHUB_REPOSITORY -e GITHUB_REPOSITORY_OWNER -e GITHUB_RUN_ID -e GITHUB_RUN_NUMBER -e GITHUB_RETENTION_DAYS -e GITHUB_ACTOR -e GITHUB_WORKFLOW -e GITHUB_HEAD_REF -e GITHUB_BASE_REF -e GITHUB_EVENT_NAME -e GITHUB_SERVER_URL -e GITHUB_API_URL -e GITHUB_GRAPHQL_URL -e GITHUB_WORKSPACE -e GITHUB_ACTION -e GITHUB_EVENT_PATH -e GITHUB_PATH -e GITHUB_ENV -e RUNNER_OS -e RUNNER_TOOL_CACHE -e RUNNER_TEMP -e RUNNER_WORKSPACE -e ACTIONS_RUNTIME_URL -e ACTIONS_RUNTIME_TOKEN -e ACTIONS_CACHE_URL -e GITHUB_ACTIONS=true -e CI=true -v "/var/run/docker.sock":"/var/run/docker.sock" -v "/home/runner/work/_temp/_github_home":"/github/home" -v "/home/runner/work/_temp/_github_workflow":"/github/workflow" -v "/home/runner/work/_temp/_runner_file_commands":"/github/file_commands" -v "/home/runner/work/Theano-PyMC/Theano-PyMC":"/github/workspace" 1e5c35:70f78f745ad14bcd878b0e92176cdb26  "--github-token" "***" "--parallel" "false" "--parallel-finished" "true" "--debug" "false"
{'error': 'No build matching CI build number 2121162aaa4fc23bf7aedc13a4bbc8afaaac23ee-PR-80 found'}
ExitCode.FAILURE

@dfm, any ideas?

Implement NumPy's broadcast_to

Theano could really benefit from an implementation of NumPy's broadcast_to that uses views and doesn't create an entirely new array. With such an implementation, most of the other broadcast_* functions could likely be implemented as well.
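For reference, the NumPy behavior such an Op would need to reproduce can be sketched as follows. This is a minimal sketch using plain NumPy, not a Theano implementation: `np.broadcast_to` returns a read-only view whose broadcast dimensions have a stride of 0, so no new buffer is allocated. The helper `broadcast_arrays_via_to` is a hypothetical name showing how a `broadcast_arrays`-style function can be derived from `broadcast_to` alone.

```python
import numpy as np


def broadcast_arrays_via_to(*arrays):
    """Hypothetical helper: broadcast_arrays built only from broadcast_to."""
    # np.broadcast computes the common broadcast shape without allocating data.
    shape = np.broadcast(*arrays).shape
    return [np.broadcast_to(a, shape) for a in arrays]


a = np.arange(3, dtype=np.float64)
# A view on `a`: the new leading axis gets stride 0, and the result is read-only.
b = np.broadcast_to(a, (4, 3))
print(b.shape, b.strides, np.shares_memory(a, b), b.flags.writeable)

x, y = broadcast_arrays_via_to(np.ones((2, 1)), np.arange(3))
print(x.shape, y.shape, x.strides)
```

Because the result is a view, writes have to be disallowed (or handled as in-place semantics on the original), which is presumably one of the details a Theano Op would need to settle.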

Rename scalar and tensor maximum and minimum Ops

There are distinct maximum and minimum Ops in both theano.scalar.basic and theano.tensor.basic. Giving them different names would avoid this confusion and a potential source of bugs.

pymc3jax: AttributeError: partially initialized module 'theano' has no attribute 'compile' (most likely due to a circular import)

Description of your problem

I was very excited to try out the new JAX sampler on my models, but I ran into an error when trying to run the example notebook:
https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434

I created a new conda environment and installed all the dependencies mentioned in the notebook, but I get the following error when I run it:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-f103cf5131d4> in <module>
      3 import numpy as np
      4 import pandas as pd
----> 5 import pymc3 as pm
      6 import theano
      7 import pymc3.sampling_jax

~/anaconda3/envs/pymc3jax/lib/python3.8/importlib/_bootstrap.py in _find_and_load(name, import_)

~/anaconda3/envs/pymc3jax/lib/python3.8/importlib/_bootstrap.py in _find_and_load_unlocked(name, import_)

~/anaconda3/envs/pymc3jax/lib/python3.8/importlib/_bootstrap.py in _load_unlocked(spec)

~/anaconda3/envs/pymc3jax/lib/python3.8/importlib/_bootstrap.py in _load_backward_compatible(spec)

<frozen zipimport> in load_module(self, fullname)

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/pymc3-3.9.3-py3.8.egg/pymc3/__init__.py in <module>
     37 
     38 
---> 39 __set_compiler_flags()
     40 
     41 from .blocking import *

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/pymc3-3.9.3-py3.8.egg/pymc3/__init__.py in __set_compiler_flags()
     31 def __set_compiler_flags():
     32     # Workarounds for Theano compiler problems on various platforms
---> 33     import theano
     34 
     35     current = theano.config.gcc.cxxflags

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+28.gcbb4f83ee-py3.8.egg/theano/__init__.py in <module>
     92 
     93 
---> 94 from theano.configdefaults import config
     95 from theano.configparser import change_flags
     96 

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+28.gcbb4f83ee-py3.8.egg/theano/configdefaults.py in <module>
    620 
    621 
--> 622 AddConfigVar(
    623     "mode", "Default compilation mode", ConfigParam("Mode", filter_mode), in_c_key=False
    624 )

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+28.gcbb4f83ee-py3.8.egg/theano/configparser.py in AddConfigVar(name, doc, configparam, root, in_c_key)
    299         # This allow to filter wrong value from the user.
    300         if not callable(configparam.default):
--> 301             configparam.__get__(root, type(root), delete_key=True)
    302         else:
    303             # We do not want to evaluate now the default value

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+28.gcbb4f83ee-py3.8.egg/theano/configparser.py in __get__(self, cls, type_, delete_key)
    345                 else:
    346                     val_str = self.default
--> 347             self.__set__(cls, val_str)
    348         # print "RVAL", self.val
    349         return self.val

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+28.gcbb4f83ee-py3.8.egg/theano/configparser.py in __set__(self, cls, val)
    357         # print "SETTING PARAM", self.fullname,(cls), val
    358         if self.filter:
--> 359             self.val = self.filter(val)
    360         else:
    361             self.val = val

~/anaconda3/envs/pymc3jax/lib/python3.8/site-packages/Theano_PyMC-1.0.5+28.gcbb4f83ee-py3.8.egg/theano/configdefaults.py in filter_mode(val)
    605             "DEBUG_MODE",
    606         ]
--> 607         or val in theano.compile.mode.predefined_modes
    608     ):
    609         return val

AttributeError: partially initialized module 'theano' has no attribute 'compile' (most likely due to a circular import)

Versions and main components

  • PyMC3 Version: pymc3jax branch
  • Theano Version: Theano-Pymc master branch
  • Python Version: 3.8
  • Operating system: Mac OS Catalina
  • How did you install PyMC3: manual installation

Implement JAX conversion for Scan Op

The JAX conversion for the Theano Scan Op is currently incomplete. We need to finish it, especially since Scan is an important bridge to a lot of valuable JAX functionality and future performance enhancements.

There's already a (skipped) test for this conversion that builds a Scan Op and its equivalent jax.lax.scan. This should provide a good start for anyone wanting to get involved. Plus, the current partial implementation already provides the necessary framework for the conversion. All that's needed is for someone to finish bridging the inputs and outputs.
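For anyone picking this up, it may help to keep the target semantics in view. The function below is a pure-NumPy restatement of the reference behaviour documented for jax.lax.scan: f maps (carry, x) to (new_carry, y), and scan returns the final carry plus the stacked ys. It does not import JAX and is not the existing partial conversion. How Scan's inputs would map onto it (roughly, recurrent outputs as the carry and sequences as xs) is an assumption that the conversion still has to work out.

```python
import numpy as np


def scan_reference(f, init, xs):
    """Reference semantics of jax.lax.scan, written with NumPy.

    f(carry, x) -> (new_carry, y); returns (final_carry, stacked ys).
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


# A running sum: the carry is the total so far, and each step also emits it.
total, partial_sums = scan_reference(lambda c, x: (c + x, c + x), 0.0, np.array([1.0, 2.0, 3.0]))
```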

`local_optimizers_map` in `EquilibriumOptimizer` prevents inheritance-based filtering

The OrderedDict member local_optimizers_map in EquilibriumOptimizer is used to filter "tracked" Ops by type; however, local_optimizers_map does not take inheritance into account. In other words, every subclass of a tracked class must have its own entry in order to be considered "tracked" by an optimization.

Given the apparently intended meaning of "tracked" Op classes, we should probably change this so that subclasses are also considered. Otherwise, numerous optimizations will never be applied to new or user-defined Ops.
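To make the difference concrete, here is a small sketch contrasting an exact-type dictionary lookup with one that walks the class's MRO. The Op classes and the "fuse_elemwise" entry are hypothetical stand-ins, not EquilibriumOptimizer's actual internals.

```python
from collections import OrderedDict


class Op:
    pass


class Elemwise(Op):
    pass


class MyElemwise(Elemwise):
    """A user-defined subclass with no entry of its own."""


# Hypothetical map from tracked Op classes to the rewrites that track them.
local_optimizers_map = OrderedDict([(Elemwise, ["fuse_elemwise"])])


def rewrites_exact(op):
    # Exact-type lookup: every subclass needs its own entry.
    return list(local_optimizers_map.get(type(op), []))


def rewrites_mro(op):
    # Walk the method resolution order so entries for base classes also apply.
    found = []
    for klass in type(op).__mro__:
        found.extend(local_optimizers_map.get(klass, []))
    return found
```

With the MRO walk, a new subclass of a tracked class is picked up automatically; the cost is one dictionary lookup per base class.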

Some tests take a long time

There are 920 tests generated and run in test_sort.py, so it takes a while to finish. Are this many tests necessary?

Make CAReduce a singleton/factory

CAReduce should be a singleton, at least with respect to its first constructor argument, scalar_op. Because it's not, newly introduced optimizations could miss out on using the specific Min and Max Ops, for instance.

Many other Ops should probably be singletons, too (conditional on their fields or otherwise).
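One way to get the factory behaviour is to have CAReduce.__new__ swap in a registered subclass based on the type of scalar_op. The sketch below uses hypothetical stand-in classes (not Theano's) to show the idea; caching instances per scalar_op would be a further step not shown here.

```python
class ScalarOp:
    """Hypothetical stand-in for a scalar Op."""


class Maximum(ScalarOp):
    pass


class Minimum(ScalarOp):
    pass


class Add(ScalarOp):
    pass


class CAReduce:
    # Maps a scalar Op class to the specialized CAReduce subclass to build.
    _specializations = {}

    def __new__(cls, scalar_op, axis=None):
        target = cls
        if cls is CAReduce:
            # Swap in the specialized subclass, if one is registered.
            target = CAReduce._specializations.get(type(scalar_op), CAReduce)
        return super().__new__(target)

    def __init__(self, scalar_op, axis=None):
        self.scalar_op = scalar_op
        self.axis = axis

    @classmethod
    def register(cls, scalar_op_type):
        def decorator(subclass):
            cls._specializations[scalar_op_type] = subclass
            return subclass

        return decorator


@CAReduce.register(Maximum)
class Max(CAReduce):
    pass


@CAReduce.register(Minimum)
class Min(CAReduce):
    pass
```

Because Max and Min subclass CAReduce, Python still runs __init__ on the returned instance, so CAReduce(Maximum(), axis=0) yields a fully initialized Max.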

Beta often diverges when logodds is close to 19.0

Description

This simple model never diverges in NUTS:

import numpy as np
import pymc3 as pm

with pm.Model() as good_model:
    good_beta = pm.Beta('good_beta', alpha=0.25, beta=3.0, shape=40)

with good_model:
    good_trace = pm.sample(draws=600, chains=2, tune=500) 

But this very similar model, with alpha and beta swapped, diverges on about half the samples:

import numpy as np
import pymc3 as pm

with pm.Model() as bad_model:
    bad_beta = pm.Beta('bad_beta', alpha=3.0, beta=0.25, shape=40)

with bad_model:
    bad_trace = pm.sample(draws=600, chains=2, tune=500)

Sampling 2 chains for 500 tune and 600 draw iterations (1_000 + 1_200 draws total) took 17 seconds.
There were 328 divergences after tuning. Increase `target_accept` or reparameterize.
There were 282 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.8848443240050202, but should be close to 0.8. Try to increase the number of tuning steps.
The estimated number of effective samples is smaller than 200 for some parameters.

The problem appears to be that in the diverging samples, one element of bad_beta jumps to a logodds value close to (but below) 19.0:

def max_bad_beta(warning):
    return np.amax(warning.extra['bad_beta_logodds__'])

def logodds_and_beta(value):
    return value, pm.transforms.logodds.backward(value).eval()[()]

def maxs_from_warnings(trace, limit=50):
    warnings = trace.report._warnings[:limit]
    return [logodds_and_beta(max_bad_beta(warning)) for warning in warnings]

maxs_from_warnings(bad_trace)
[(18.522285926243683, 0.9999999909661387),
 (18.026951801202443, 0.9999999851750137),
 (18.273787751452954, 0.9999999884176999),
 (15.811521213093858, 0.9999998641237577),
 (18.788894703456123, 0.999999993080309),
 (17.840933189508505, 0.9999999821441242),
 (18.99068873589819, 0.9999999943447908),
 (18.650396555754487, 0.9999999920524058),
 (18.886465782831156, 0.9999999937235782),
 (18.614120566380194, 0.9999999917588059),
 (18.748619470023584, 0.9999999927959284),
 (18.96328594127111, 0.9999999941876794),
 (18.43516905323891, 0.9999999901438387),
 (18.7824011307429, 0.9999999930352292),
 (18.61426914065103, 0.9999999917600301),
 (18.61426914065103, 0.9999999917600301),
 (18.959200871340606, 0.9999999941638871),
 (18.310579256420056, 0.9999999888360865),
 (18.875771090438388, 0.9999999936560935),
 (18.419609630862556, 0.9999999899892832),
 (18.743196196992155, 0.9999999927567527), 
...
]

What is special about 19? Note that good_beta generates logodds close to -19.0 as often as bad_beta generates logodds close to 19.0, but good_beta never diverges.
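This is not a diagnosis, but one asymmetry worth checking: the Beta(alpha, beta) log-density contains the term (beta - 1) * log(1 - x). If that term is evaluated from x = sigmoid(z), then near z = 19 the value x is within about 6e-9 of 1, and forming 1 - x from x throws away roughly eight of float64's ~16 significant digits. Near z = -19, the small quantity x is itself represented accurately. The sketch below compares the naive computation with the identity 1 - sigmoid(z) = sigmoid(-z). Whether PyMC3/Theano actually evaluate the term this way is an assumption.

```python
import numpy as np


def log1m_sigmoid_naive(z):
    # log(1 - sigmoid(z)): the subtraction cancels catastrophically when sigmoid(z) ~ 1.
    s = 1.0 / (1.0 + np.exp(-z))
    with np.errstate(divide="ignore"):
        return np.log1p(-s)


def log1m_sigmoid_stable(z):
    # 1 - sigmoid(z) = sigmoid(-z), so log(1 - sigmoid(z)) = -log(1 + exp(z)).
    return -np.logaddexp(0.0, z)
```

At z = 19 the two still agree to about seven digits; by z = 40, sigmoid(z) rounds to exactly 1.0 and the naive version returns -inf while the stable one returns about -40.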

Versions and main components

  • PyMC3 Version: 3.9.1
  • Theano Version: 1.0.4
  • Python Version: 3.8.3
  • Operating system: macOS 10.15.16
  • How did you install PyMC3: conda

Implement grad for GammaInc

The Gamma logcdf does not have a gradient implemented, because the underlying GammaInc scalar Op does not define grad.

Example:

self.GI_mean = pm.Normal('GI_mean', gi_mean_mean, gi_mean_sd)
self.GI_sd = pm.Normal('GI_sd', gi_sd_mean, gi_sd_sd)

gi_beta = self.GI_mean / self.GI_sd ** 2
gi_alpha = self.GI_mean ** 2 / self.GI_sd ** 2

GI_dist = pm.Gamma.dist(alpha=gi_alpha, beta=gi_beta)

bins = np.zeros(gi_truncation + 1)
bins[1:] = np.arange(gi_truncation)
bins[2:] += 0.5
bins[:2] += 1e-5

cdf_vals = T.exp(GI_dist.logcdf(bins))
pmf = cdf_vals[1:] - cdf_vals[:-1]
GI_rev = T.repeat(T.reshape(pmf[::-1] / T.sum(pmf), (1, 1, gi_truncation)), 2, axis=0)

Trying to sample from this raises an error:

MethodNotDefined: ('grad', <class 'theano.scalar.basic_scipy.GammaInc'>, 'GammaInc')

Note: the derivative of the logcdf with respect to its argument should be available as a function of the logpdf, since d/dx log F(x) = f(x) / F(x). It would be great to have this implemented; for example, it would make my use case above (generating a discretised distribution) possible.
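The derivative with respect to the second argument is the straightforward part. Assuming GammaInc wraps the regularized lower incomplete gamma function P(k, x) (as scipy.special.gammainc does), dP/dx = x^(k-1) e^(-x) / Γ(k), and by the chain rule the Gamma(alpha, rate=beta) CDF P(alpha, beta x) has the familiar Gamma density as its x-derivative. A complete GammaInc.grad would also need dP/dk, which has no elementary closed form; this sketch does not cover it.

```python
import numpy as np
from scipy.special import gammainc, gammaln


def gammainc_grad_x(k, x):
    """d/dx of the regularized lower incomplete gamma P(k, x), for k > 0, x > 0.

    Evaluated in log space so that Gamma(k) does not overflow for large k.
    """
    return np.exp((k - 1.0) * np.log(x) - x - gammaln(k))


def gamma_cdf_grad_x(x, alpha, beta):
    """d/dx of the Gamma(alpha, rate=beta) CDF, P(alpha, beta * x)."""
    return beta * gammainc_grad_x(alpha, beta * x)
```

A central finite difference of gammainc is a quick way to check such a formula.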

Fix/remove confusing __iter__ implementation in theano.tensor.var.TensorVariable

I'm trying to translate Statistical Rethinking from R and RStan to Python and PyMC3.

On page 304, there's a simple logistic regression example. Data being used (publicly available):

	dept	applicant.gender	admit	reject	applications	is_male
1	A	male	512	313	825	1
2	A	female	89	19	108	0
3	B	male	353	207	560	1
4	B	female	17	8	25	0
5	C	male	120	205	325	1
6	C	female	202	391	593	0
7	D	male	138	279	417	1
8	D	female	131	244	375	0
9	E	male	53	138	191	1
10	E	female	94	299	393	0
11	F	male	22	351	373	1
12	F	female	24	317	341	0

The data are stored in a pandas DataFrame object. When I try to fit the model using:

with pm.Model() as m106:
    
    alpha = pm.Normal('alpha', 0, 10)
    beta_m = pm.Normal('beta_m', 0, 10)
    
    lin = alpha + beta_m * data['is_male']
    p = np.exp(lin) / (1 + np.exp(lin))
    
    admit = pm.Binomial('admit', n=data['applications'], p=p, observed=data['admit'])
    
    m106_map = pm.find_MAP()
    m106_traces = pm.sample(1000, start=m106_map)

I get the following error (which seems similar to pymc-devs/pymc#918):

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/type.py in dtype_specs(self)
    266                 'complex64': (complex, 'theano_complex64', 'NPY_COMPLEX64')
--> 267             }[self.dtype]
    268         except KeyError:

KeyError: 'object'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/basic.py in constant_or_value(x, rtype, name, ndim, dtype)
    407             rval = rtype(
--> 408                 TensorType(dtype=x_.dtype, broadcastable=bcastable),
    409                 x_.copy(),

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/type.py in __init__(self, dtype, broadcastable, name, sparse_grad)
     49         self.broadcastable = tuple(bool(b) for b in broadcastable)
---> 50         self.dtype_specs()  # error checking is done there
     51         self.name = name

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/type.py in dtype_specs(self)
    269             raise TypeError("Unsupported dtype for %s: %s"
--> 270                             % (self.__class__.__name__, self.dtype))
    271 

TypeError: Unsupported dtype for TensorType: object

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/basic.py in as_tensor_variable(x, name, ndim)
    201     try:
--> 202         return constant(x, name=name, ndim=ndim)
    203     except TypeError:

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/basic.py in constant(x, name, ndim, dtype)
    421     ret = constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
--> 422                             dtype=dtype)
    423 

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/basic.py in constant_or_value(x, rtype, name, ndim, dtype)
    416     except Exception:
--> 417         raise TypeError("Could not convert %s to TensorType" % x, type(x))
    418 

TypeError: ('Could not convert 1     Elemwise{mul,no_inplace}.0\n2     Elemwise{mul,no_inplace}.0\n3     Elemwise{mul,no_inplace}.0\n4     Elemwise{mul,no_inplace}.0\n5     Elemwise{mul,no_inplace}.0\n6     Elemwise{mul,no_inplace}.0\n7     Elemwise{mul,no_inplace}.0\n8     Elemwise{mul,no_inplace}.0\n9     Elemwise{mul,no_inplace}.0\n10    Elemwise{mul,no_inplace}.0\n11    Elemwise{mul,no_inplace}.0\n12    Elemwise{mul,no_inplace}.0\nName: applications, dtype: object to TensorType', <class 'pandas.core.series.Series'>)

During handling of the above exception, another exception occurred:

AsTensorError                             Traceback (most recent call last)
<ipython-input-144-fb615dfa2e93> in <module>()
      7     p = np.exp(lin) / (1 + np.exp(lin))
      8 
----> 9     admit = pm.Binomial('admit', n=data['applications'], p=p, observed=data['admit'])
     10 
     11     m106_map = pm.find_MAP()

/Users/horatiu/anaconda/lib/python3.5/site-packages/pymc3/distributions/distribution.py in __new__(cls, name, *args, **kwargs)
     24         if isinstance(name, string_types):
     25             data = kwargs.pop('observed', None)
---> 26             dist = cls.dist(*args, **kwargs)
     27             return model.Var(name, dist, data)
     28         elif name is None:

/Users/horatiu/anaconda/lib/python3.5/site-packages/pymc3/distributions/distribution.py in dist(cls, *args, **kwargs)
     37     def dist(cls, *args, **kwargs):
     38         dist = object.__new__(cls)
---> 39         dist.__init__(*args, **kwargs)
     40         return dist
     41 

/Users/horatiu/anaconda/lib/python3.5/site-packages/pymc3/distributions/discrete.py in __init__(self, n, p, *args, **kwargs)
     43         self.n = n
     44         self.p = p
---> 45         self.mode = tt.cast(tt.round(n * p), self.dtype)
     46 
     47     def random(self, point=None, size=None, repeat=None):

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/basic.py in round(a, mode)
   2052     """round_mode(a) with mode in [half_away_from_zero, half_to_even]"""
   2053     if mode == "half_away_from_zero":
-> 2054         return round_half_away_from_zero(a)
   2055     elif mode == "half_to_even":
   2056         return round_half_to_even(a)

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/gof/op.py in __call__(self, *inputs, **kwargs)
    609         """
    610         return_list = kwargs.pop('return_list', False)
--> 611         node = self.make_node(*inputs, **kwargs)
    612 
    613         if config.compute_test_value != 'off':

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/elemwise.py in make_node(self, *inputs)
    541         using DimShuffle.
    542         """
--> 543         inputs = list(map(as_tensor_variable, inputs))
    544         shadow = self.scalar_op.make_node(
    545             *[get_scalar_type(dtype=i.type.dtype).make_variable()

/Users/horatiu/anaconda/lib/python3.5/site-packages/theano/tensor/basic.py in as_tensor_variable(x, name, ndim)
    206         except Exception:
    207             str_x = repr(x)
--> 208         raise AsTensorError("Cannot convert %s to TensorType" % str_x, type(x))
    209 
    210 # this has a different name, because _as_tensor_variable is the

AsTensorError: ('Cannot convert 1     Elemwise{mul,no_inplace}.0\n2     Elemwise{mul,no_inplace}.0\n3     Elemwise{mul,no_inplace}.0\n4     Elemwise{mul,no_inplace}.0\n5     Elemwise{mul,no_inplace}.0\n6     Elemwise{mul,no_inplace}.0\n7     Elemwise{mul,no_inplace}.0\n8     Elemwise{mul,no_inplace}.0\n9     Elemwise{mul,no_inplace}.0\n10    Elemwise{mul,no_inplace}.0\n11    Elemwise{mul,no_inplace}.0\n12    Elemwise{mul,no_inplace}.0\nName: applications, dtype: object to TensorType', <class 'pandas.core.series.Series'>)

The error is not very informative, and doesn't seem to point to my code.
But when using:

with pm.Model() as m106:
    
    alpha = pm.Normal('alpha', 0, 10)
    beta_m = pm.Normal('beta_m', 0, 10)
    
    lin = alpha + beta_m * data['is_male']
    p = np.exp(lin) / (1 + np.exp(lin))
    
    admit = pm.Binomial('admit', n=data['applications'].values, p=p, observed=data['admit'])
    
    m106_map = pm.find_MAP()
    m106_traces = pm.sample(1000, start=m106_map)

That is, when explicitly passing in the NumPy array rather than the pandas Series:

admit = pm.Binomial('admit', n=data['applications'].values, p=p, observed=data['admit'])

Everything works as expected. I'm just trying to figure out why that is. Is there a reference in the documentation for this behavior? What lesson should I take away from this? :)

Using pandas 0.19, numpy 1.11, pymc3 3.0rc2

Merge RandomVariable Ops

The Symbolic PyMC project contains a base RandomVariable Op with advanced shape-handling logic and a more optimization-friendly form, as well as implementations for many common random variable types. We should merge/replace the current RandomFunction Op with RandomVariable (along with the related rng_mrg implementations: https://github.com/pymc-devs/Theano-PyMC/blob/master/theano/sandbox/rng_mrg.py).

There's also the RandomStreams class, which should be considered and made to work with RandomVariable.

Conda-forge build is broken on import test

import: 'theano.compile'
import: 'theano.compile.sandbox'
import: 'theano.compile.tests'
Traceback (most recent call last):
  File "/home/conda/feedstock_root/build_artifacts/theano-pymc_1602124543407/test_tmp/run_test.py", line 11, in <module>
    import theano.compile.tests

from https://dev.azure.com/conda-forge/feedstock-builds/_build/results?buildId=219395&view=logs&j=d0d954b5-f111-5dc4-4d76-03b6c9d0cf7e&t=841356e0-85bb-57d8-dbbc-852e683d1642&l=2075

This happens because the recipe tests all kinds of imports: https://github.com/pymc-devs/Theano-PyMC/blob/master/conda/meta.yaml#L32

Replace unification framework

We should replace the custom Theano unification implementation with an external, more robust library, as demonstrated in Symbolic PyMC. Since there are already at least two Python unification libraries that can easily be made to work with Theano (e.g. unification and its fork used by Symbolic PyMC, logical-unification), there's no point in maintaining an independent implementation here, especially since those capabilities aren't core to this library's offerings.
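For readers unfamiliar with what these libraries provide, here is a minimal sketch of syntactic unification over tuples in plain Python. It is not the API of unification or logical-unification, which are more complete and, per the issue, can be made to work with Theano graphs. There is no occurs check here.

```python
class Var:
    """A logic variable; identity (not name) distinguishes variables."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"~{self.name}"


def walk(term, subs):
    # Follow bindings until we reach a non-variable or an unbound variable.
    while isinstance(term, Var) and term in subs:
        term = subs[term]
    return term


def unify(u, v, subs=None):
    """Return a substitution dict making u and v equal, or None on failure."""
    subs = {} if subs is None else subs
    u, v = walk(u, subs), walk(v, subs)
    if isinstance(u, Var) and u is v:
        return subs
    if isinstance(u, Var):
        return {**subs, u: v}
    if isinstance(v, Var):
        return {**subs, v: u}
    if isinstance(u, tuple) and isinstance(v, tuple):
        if len(u) != len(v):
            return None
        for a, b in zip(u, v):
            subs = unify(a, b, subs)
            if subs is None:
                return None
        return subs
    return subs if u == v else None


def reify(term, subs):
    # Replace bound variables in term with their values.
    term = walk(term, subs)
    if isinstance(term, tuple):
        return tuple(reify(t, subs) for t in term)
    return term
```

For example, unifying ("add", x, 2) with ("add", 1, y) binds x to 1 and y to 2, while ("add", x) and ("mul", 1) fail to unify.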

Implement JAX conversion for CGemv Op

This model runs

data = np.random.normal(0, 1, 10)
with pm.Model() as model:
    a = pm.Normal("a", 0., 1.)
    b = pm.Normal("b", 0, 1.)
    y = pm.Normal("y", a, b, observed=data)
    trace = pm.sample_smc(100, chains=1, parallel=False)

This other model

def two_gaussians(x):
    log_like1 = - 0.5 * n * tt.log(2 * np.pi) \
                - 0.5 * tt.log(dsigma) \
                - 0.5 * (x - mu1).T.dot(isigma).dot(x - mu1)
    log_like2 = - 0.5 * n * tt.log(2 * np.pi) \
                - 0.5 * tt.log(dsigma) \
                - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
    return pm.math.logsumexp([tt.log(w1) + log_like1, tt.log(w2) + log_like2])


with pm.Model() as model:
    X = pm.Uniform('X',
                   shape=n,
                   lower=-2. * np.ones_like(mu1),
                   upper=2. * np.ones_like(mu1),
                   testval=-1. * np.ones_like(mu1))
    llk = pm.Potential('llk', two_gaussians(X))
    trace = pm.sample_smc(2000, chains=1, parallel=False)

throws AttributeError: 'ScalarSigmoid' object has no attribute 'nfunc_spec'

It seems this could be fixed by adding nfunc_spec = ('scipy.special.expit', 1, 1) to the ScalarSigmoid class (I'm not sure this is correct). After doing so, I get the error 'ScalarSoftplus' object has no attribute 'nfunc_spec', which I guess needs a similar fix, but I'm not sure exactly what.

If I add a placeholder just to avoid that error, I then get NotImplementedError: No JAX conversion for the given Op: CGemv{inplace}
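As a starting point, this is a NumPy sketch of what a Gemv-style Op computes, assuming Theano's Gemv/CGemv take inputs in the order (y, alpha, A, x, beta) and return beta * y + alpha * dot(A, x); please confirm against theano/tensor/blas.py. A JAX conversion would presumably return an equivalent function built on jax.numpy, and since JAX arrays are immutable the "inplace" variant would simply return a new array.

```python
import numpy as np


def gemv(y, alpha, A, x, beta):
    """Assumed Gemv semantics: beta * y + alpha * (A @ x).

    Unlike CGemv{inplace}, this returns a new array and leaves y untouched.
    """
    return beta * y + alpha * np.dot(A, x)


A = np.array([[1.0, 2.0], [3.0, 4.0]])
x = np.array([1.0, 1.0])
y = np.array([10.0, 20.0])
out = gemv(y, 2.0, A, x, 0.5)
```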

Use Numba, Cython, and JAX for Python-implemented Ops

Here's an old example of a Numba-enabled Theano Op. We can most likely do something similar for Cython and JAX, as well.

As an extension to that example, it would be nice to have the compiled function be the Op.perform method itself; that way, we could attempt to compile existing Ops with little-to-no changes.

Move Theano Type instance definitions

The modules theano.tensor.basic and theano.scalar.basic contain numerous type definitions (e.g. TensorType("float32", (True, False))). These are more appropriately defined in their respective type.py modules (e.g. where TensorType itself is defined). As it stands, they only create unnecessary dependencies between those types and the very large basic.py modules.

Refactor sub-packages, modules, and imports

The current package/module structure leads to numerous circular dependencies and an overall unwieldy import process. We need to refactor the imports and module/subpackage layouts entirely.

For instance, it should be possible to import the core graph types (e.g. Node, Apply, Variable) without importing any Ops, Linkers, or other compilation-based classes. Furthermore, there's rampant use of module references like theano.* to access objects within subpackages. This obscures the dependency structure when casually reading source file headers, and it balloons the cross-dependencies, since using theano.* alone performs numerous automatic imports.
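The AttributeError reported earlier on this page (theano/__init__.py importing configdefaults, which reaches for theano.compile before it exists) is one symptom of this structure. The self-contained sketch below reproduces the same error shape with a throwaway package; the package and module names are made up for illustration.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path


def demo_circular_import(pkg_name="circ_demo_pkg"):
    """Build a package whose __init__ imports a submodule that reaches back
    into the still-initializing package; return the resulting error message."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg_dir = Path(tmp) / pkg_name
        pkg_dir.mkdir()
        # Like theano/__init__.py: config is imported *before* compile.
        (pkg_dir / "__init__.py").write_text(textwrap.dedent(f"""
            from {pkg_name}.configdefaults import config
            import {pkg_name}.compile
        """))
        # Like configdefaults.py: touches pkg.compile at import time.
        (pkg_dir / "configdefaults.py").write_text(textwrap.dedent(f"""
            import {pkg_name}
            config = {pkg_name}.compile.predefined_modes
        """))
        (pkg_dir / "compile.py").write_text("predefined_modes = ['FAST_RUN']\n")

        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            importlib.import_module(pkg_name)
        except AttributeError as exc:
            return str(exc)
        finally:
            sys.path.remove(tmp)
            for name in [m for m in sys.modules if m == pkg_name or m.startswith(pkg_name + ".")]:
                del sys.modules[name]
    return None


message = demo_circular_import()
```

On Python 3.8+, the message mirrors the one above: the package is "partially initialized" and lacks a compile attribute.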

Use GitHub Actions

If I'm not mistaken, we can run more jobs in parallel using GitHub Actions and speed up these tests. For that reason, and some others (e.g. we can easily perform our deployment automations in Actions as well), let's switch from Travis to GitHub Actions.

pymc3jax: AttributeError: 'Identity' object has no attribute 'nfunc_spec'

Description of your problem

I ran into another problem with the experimental JAX-based sampler on the pymc3jax branch:

I am playing around with a hierarchical model that simulates something like a hierarchical Gaussian mixture process: there are 3 clusters, each with an associated cluster mean and a standard deviation around it, and I simulate a number of instances for each cluster, where each instance has its own mean and a fixed standard deviation.

This is the code to simulate the data

import pandas as pd
import numpy as np
import pymc3 as pm
import arviz as az
import pylab as pl

np.random.seed(123)

N_clusters = 3  # Number of clusters
N_samples = [10, 5, 0]  # Number of samples per cluster
total_samples = sum(N_samples)
N = 100  # Number of observations per instance (sample)
cluster_means = [1., 2., 3.]  # Mean of means within cluster
cluster_means_std = [0.1, 0.1, 0.1]  # Std of means within cluster
std = 0.5

data = []
true_means = []
for i in range(N_clusters):
    if N_samples[i] > 0:
        means = np.random.normal(loc=cluster_means[i], scale=cluster_means_std[i], size=N_samples[i])
        true_means = np.append(true_means, means)
        data.append(np.array([np.random.normal(means[j], std, N) for j in range(N_samples[i])]))
data = np.vstack(data)
clusters = []
for i in range(N_clusters):
    clusters += [i] * N_samples[i]
data = data.reshape(-1)

c = np.repeat(clusters, N).reshape(-1)
sample = np.repeat(np.arange(sum(N_samples)), N)

Using these data, I am creating this model:

with pm.Model() as model:
    a = pm.Normal('a', mu= 0., sigma=3., shape=N_clusters)
    sigma_a = pm.Exponential('sigma_a', 1., shape=N_clusters)
    
    mu_tilde = pm.Normal('mu_t', mu=0., sigma=1., shape=total_samples)
    mu = pm.Deterministic('mu', mu_tilde * sigma_a[clusters] + a[clusters])
    
    sigma = pm.Exponential('sigma', 1., shape=total_samples)
    
    data_obs = pm.Normal('data', mu=mu[sample], sigma=sigma[sample], observed=data)

and then I want to use the sampler to do inference:

import pymc3.sampling_jax


with model:
    trace_jax = pm.sampling_jax.sample_numpyro_nuts(
            2000, tune=2000, target_accept=.9)

Running this code, I am getting this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-6-7bb1ee15df7f> in <module>
      3 
      4 with model:
----> 5     trace_jax = pm.sampling_jax.sample_numpyro_nuts(
      6             2000, tune=2000, target_accept=.9)
      7     idata = trace_jax

/path/to/pymc3/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar)
    114 
    115     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
--> 116     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
    117     logp_fn_jax = fns[0]
    118 

/path/to/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify_FunctionGraph(fgraph)
    523 
    524     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 525     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    526 
    527     return jax_funcs

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in <listcomp>(.0)
    523 
    524     out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
--> 525     jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    526 
    527     return jax_funcs

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
    109             input_f = jax_data_func
    110         else:
--> 111             input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
    112 
    113         input_funcs.append(input_f)

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in compose_jax_funcs(out_node, fgraph_inputs, memo)
     90         return memo[out_node]
     91 
---> 92     jax_return_func = jax_funcify(out_node.op)
     93 
     94     input_funcs = []

/path/to/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify_Elemwise(op)
    320 def jax_funcify_Elemwise(op):
    321     scalar_op = op.scalar_op
--> 322     return jax_funcify(scalar_op)
    323 
    324 

/path/to/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

/path/to/Theano-PyMC/theano/sandbox/jaxify.py in jax_funcify_ScalarOp(op)
    142 def jax_funcify_ScalarOp(op):
    143     print(op)
--> 144     func_name = op.nfunc_spec[0]
    145 
    146     if "." in func_name:

AttributeError: 'Identity' object has no attribute 'nfunc_spec'
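The last frame shows jax_funcify_ScalarOp reading op.nfunc_spec[0] and checking for a dot, which suggests scalar Ops are converted by looking up a named function. The sketch below illustrates that kind of lookup, plus an explicit error for Ops that lack the attribute, such as Identity. The helper names and the fallback to numpy for undotted names are assumptions, not the jaxify source.

```python
import importlib


def resolve_nfunc(func_name, default_module="numpy"):
    """Turn a name like 'scipy.special.expit' or 'exp' into a callable.

    Dotted names are split into module and attribute; bare names are looked
    up in default_module (an assumption made for this sketch).
    """
    if "." in func_name:
        module_name, attr = func_name.rsplit(".", 1)
    else:
        module_name, attr = default_module, func_name
    return getattr(importlib.import_module(module_name), attr)


def funcify_scalar_op(op):
    """Hypothetical converter: raise a descriptive error instead of an AttributeError."""
    spec = getattr(op, "nfunc_spec", None)
    if spec is None:
        raise NotImplementedError(
            f"No conversion for scalar Op {type(op).__name__}: it defines no nfunc_spec"
        )
    return resolve_nfunc(spec[0])


class Identity:
    """Stand-in for a scalar Op without nfunc_spec, as in the traceback above."""


class Sqrt:
    nfunc_spec = ("math.sqrt", 1, 1)
```

Ops like Identity, which have no NumPy/SciPy counterpart to name, would probably need their own dedicated conversion (e.g. returning the input unchanged) rather than an nfunc_spec.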

Versions and main components

  • PyMC3 Version: checkout of pymc3jax branch
  • Theano Version: checkout of Theano-Pymc master branch
  • Python Version: 3.8
  • Operating system: Mac OS
  • How did you install PyMC3: manual installation of the branch

Turn theano.compat into a module

Currently, theano.compat is a sub-package containing only an __init__.py file. The __init__.py file should be renamed to compat.py, moved into the theano directory/package, and the old compat directory should be deleted.
