jcmgray / autoray
Abstract your array operations.
Home Page: https://autoray.readthedocs.io
License: Apache License 2.0
from autoray.lazy import Variable
from autoray import do
a = Variable(shape=(4,))
b = do('random.normal', size=(4,), like='numpy')
c = do('asarray', b, like='torch')
index = [[2, 3]]
print(a[index].shape)
print(b[index].shape)
print(c[index].shape)
print('----------------------------')
index = [[[0, 1, 2, 3]]]
print(a[index].shape)
print(b[index].shape)
print(c[index].shape)
# (1,)
# (1, 2)
# torch.Size([2])
# ----------------------------
# (1,)
# (1, 1, 4)
# torch.Size([1, 4])
In these two situations, the shape of the LazyArray after the getitem operation differs from both the torch.Tensor and the numpy.array results (and torch and numpy also disagree with each other).
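Both reference shapes above can be reproduced with plain NumPy. This suggests (an inference from the printed outputs, not something confirmed here) that torch reads the outer list as a tuple of per-axis indices, while NumPy reads the whole nested list as one integer index array. The sketch assumes only NumPy, and applies np.asarray explicitly so the result does not depend on NumPy's version-specific handling of list indices.

```python
import numpy as np

x = np.arange(4.0)

for index in ([[2, 3]], [[[0, 1, 2, 3]]]):
    # Whole nested list as one integer index array (matches the numpy output above)
    as_array = x[np.asarray(index)]
    # Outer list as a tuple of per-axis indices (matches the torch output above)
    as_tuple = x[tuple(index)]
    print(as_array.shape, as_tuple.shape)
# (1, 2) (2,)
# (1, 1, 4) (1, 4)
```

If that reading is right, a getitem wrapper for torch could convert list indices to a tensor first (for example with torch.as_tensor) before indexing.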
Here are my rough thoughts:
- LazyArray: make its shape consistent with numpy.array.
- torch.Tensor: make its shape consistent with numpy.array (perhaps using a wrapper for getitem when dealing with torch.Tensor).
Hi!
Many thanks for this nice library. I don't understand how to run a lazy einsum. For example, in dask terms it would be:
import dask.array as da
x = [[1,2],[3,4]]
de = da.einsum('ij, nk -> ', x, x)
de.compute()
I might not get it right, but I believe under autoray this should be:
import numpy as np
from autoray import lazy
x = [[1,2],[3,4]]
lx = lazy.array(x)
le = np.einsum('ij, nk -> ', lx, lx)
This is where it fails. The error I get in PyCharm is:
/home/dimitris/miniconda3/envs/autoray/lib/python3.9/site-packages/autoray/lazy/core.py:1414: UserWarning: Could not find a full input parser for einsum expressions. Please install either cotengra or opt_einsum for advanced input formats (interleaved, ellipses, no-output).
warnings.warn(
Traceback (most recent call last):
File "/home/dimitris/dev/python/pciSeq/run_app.py", line 14, in <module>
le = lazy.einsum('ij, nk->', lx, lx)
File "/home/dimitris/miniconda3/envs/autoray/lib/python3.9/site-packages/autoray/lazy/core.py", line 1118, in wrapped
return fn(*args, **kwargs)
File "/home/dimitris/miniconda3/envs/autoray/lib/python3.9/site-packages/autoray/lazy/core.py", line 1431, in einsum
size_dict[char] = max(size_dict.get(char, 1), op_shape[i])
IndexError: tuple index out of range
For some reason I don't understand, the error is different when I run the same commands in a Jupyter notebook:
/home/dimitris/miniconda3/envs/autoray/lib/python3.9/site-packages/autoray/lazy/core.py:774: UserWarning: Iterating over LazyArray to get the computational graph nodes is deprecated - use `LazyArray.descend()` instead. Eventually `iter(lz)` will iterate over first axis slices.
warnings.warn(
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[4], line 1
----> 1 le = np.einsum('ij, nk -> ', lx, lx)
File ~/miniconda3/envs/autoray/lib/python3.9/site-packages/numpy/core/einsumfunc.py:1371, in einsum(out, optimize, *operands, **kwargs)
1369 if specified_out:
1370 kwargs['out'] = out
-> 1371 return c_einsum(*operands, **kwargs)
1373 # Check the kwargs to avoid a more cryptic error later, without having to
1374 # repeat default values here
1375 valid_einsum_kwargs = ['dtype', 'order', 'casting']
ValueError: setting an array element with a sequence. The requested array would exceed the maximum number of dimension of 32.
Am I doing something wrong? Thanks in advance.
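The PyCharm traceback fails inside lazy.einsum at `size_dict[char] = max(size_dict.get(char, 1), op_shape[i])`, i.e. while indexing each operand's shape by the position of each character in its input term. One plausible explanation (an assumption, not confirmed by the maintainer) is that the fallback parser used without cotengra or opt_einsum keeps the spaces in 'ij, nk -> ' as index characters. The sketch below uses a hypothetical `split_einsum_eq` helper, not autoray's actual parser, to show how that would overrun a 2-d shape, and computes NumPy's eager result for comparison.

```python
import numpy as np


def split_einsum_eq(eq):
    """Hypothetical minimal parser: split an explicit einsum equation into
    input terms and the output term, with no whitespace handling."""
    lhs, _, rhs = eq.partition("->")
    return lhs.split(","), rhs


x = np.array([[1, 2], [3, 4]])

# With spaces, the second term is ' nk ' (4 characters) for a 2-d operand,
# so indexing the operand's shape by character position would run off the end
raw_terms, _ = split_einsum_eq("ij, nk -> ")
print([len(t) for t in raw_terms], x.ndim)  # [2, 4] 2

# Stripping whitespace first leaves one character per axis
clean_eq = "ij, nk -> ".replace(" ", "")
clean_terms, clean_out = split_einsum_eq(clean_eq)
print(clean_terms, repr(clean_out))  # ['ij', 'nk'] ''

# Eager reference value: the sum of every product x[i, j] * x[n, k]
print(np.einsum(clean_eq, x, x))  # 100
```

If that is the cause, writing the equation without spaces, e.g. lazy.einsum('ij,nk->', lx, lx), and then calling .compute() on the result would presumably be the lazy equivalent of the dask snippet; installing opt_einsum or cotengra, as the warning suggests, may also help.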
Hi, this is an interesting yet difficult library! While it seems intriguing, I have been wondering about the interface: why is there always a 'do'? Why does the library not just wrap the API directly? Writing do("command") seems quite cumbersome compared to just command().
Is such an API a future plan?
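For what it's worth, another snippet on this page already uses `from autoray import numpy as np`, so a module-like interface does appear to exist alongside do. One plausible reason for keeping a string-based do as the core primitive is that such a namespace then becomes a thin layer of attribute lookups, including for functions nobody listed in advance. Everything below is hypothetical and self-contained (a toy `do` that only dispatches on the first argument's top-level module, plus `NumpyLike` and `_Dispatcher` classes); it is not autoray's implementation.

```python
import importlib

import numpy as np


def do(fn, *args, like=None, **kwargs):
    """Toy stand-in for a string-based ``do``: choose the backend from
    ``like`` or from the top-level module of the first argument's type."""
    backend = like or type(args[0]).__module__.split(".")[0]
    if backend == "builtins":
        backend = "numpy"  # toy fallback for lists, tuples and Python scalars
    obj = importlib.import_module(backend)
    for part in fn.split("."):  # e.g. 'linalg.norm' -> numpy.linalg.norm
        obj = getattr(obj, part)
    return obj(*args, **kwargs)


class _Dispatcher:
    """Callable proxy: ``_Dispatcher('linalg').norm(x)`` calls
    ``do('linalg.norm', x)``."""

    def __init__(self, name):
        self._name = name

    def __call__(self, *args, **kwargs):
        return do(self._name, *args, **kwargs)

    def __getattr__(self, attr):
        # nested access such as ``anp.linalg.norm``
        return _Dispatcher(f"{self._name}.{attr}")


class NumpyLike:
    """Hypothetical module-like namespace: ``anp.fn(...)`` -> ``do('fn', ...)``."""

    def __getattr__(self, name):
        return _Dispatcher(name)


anp = NumpyLike()
print(anp.sum(np.arange(4.0)))                # 6.0
print(anp.linalg.norm(np.array([3.0, 4.0])))  # 5.0
```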
When using concatenate, the result is always converted to a numpy array, e.g.:
import autoray as ar
from autoray import numpy as np
A = np.random.normal(size=(10, 10), like='tensorflow')
B = np.random.normal(size=(10, 10), like='tensorflow')
concat = ar.do('concatenate', (A, B), axis=0)
type(concat)
>> numpy.ndarray
This can be mitigated by instead doing
ar.do('concatenate',(A,B),axis=0,like=ar.infer_backend(A))
but this is a bit unwieldy. The problem is that the argument (A, B) is a tuple, so infer_backend assigns it the backend builtins, which in turn is always dispatched to numpy.
This problem applies to any function whose first argument is a list/tuple of arrays. I know at least that it applies to concatenate, einsum and stack. For einsum I just opted to call opt_einsum directly, which does correctly infer the backend in this case, but that is beside the point.
I can see several possible approaches:
- ar.register_function: the user should also be able to indicate that the function is of this type.
- _infer_class_backend_cached: make a specific check for builtins. We check if the item is iterable, and if so we check the backend of its first element. If that is again builtins, leave it as is, but if it is something else, return that backend instead.
I'm partial to the second option, as I don't expect it to have too many side effects. If you want, I can do a PR.
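A minimal sketch of the second proposal, assuming backend inference keys on the top-level module of an object's type (a simplification of what autoray does). The helpers `infer_backend` and `infer_backend_with_sequences` are hypothetical stand-ins, and fractions.Fraction stands in for a non-numpy array type so the example runs without tensorflow.

```python
from fractions import Fraction

import numpy as np


def infer_backend(x):
    """Toy backend inference: the top-level module of the object's type."""
    return type(x).__module__.split(".")[0]


def infer_backend_with_sequences(x):
    """Hypothetical version of the proposal: if ``x`` is a builtin list or
    tuple, use the backend of its first element unless that is builtins too."""
    backend = infer_backend(x)
    if backend == "builtins" and isinstance(x, (list, tuple)) and len(x) > 0:
        first = infer_backend(x[0])
        if first != "builtins":
            return first
    return backend


print(infer_backend((np.zeros(2), np.ones(2))))                 # builtins
print(infer_backend_with_sequences((np.zeros(2), np.ones(2))))  # numpy
print(infer_backend_with_sequences([Fraction(1), Fraction(2)]))  # fractions
print(infer_backend_with_sequences(([1, 2], [3, 4])))           # builtins
```

Restricting the check to list and tuple, rather than any iterable, avoids surprises with strings (whose first element is again a string) or generators (which the check would consume). As in the proposal, only the first element is inspected, so a mixed sequence is dispatched on whatever comes first.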
from autoray.lazy import Variable
from autoray import numpy as np
a = Variable(shape=(2,))
b = Variable(shape=(2,))
d = np.where(a>1, b, 1)
fn = d.get_function((a, b))
x = np.asarray([0.5, 3])
y = np.asarray([0.2, 1.5])
print(fn((x, y)))
line 44, in
fn = d.get_function((a, b))
AttributeError: 'numpy.ndarray' object has no attribute 'get_function'
How can I apply np.where to a LazyArray and get a function that I can then call on real arrays?
Currently, autoray defaults "jax" to "jax.numpy", with only a few functions explicitly aliased from other submodules. This makes many functions not in jax.numpy, such as jax.scipy.fft.fft(), inaccessible unless registered manually. Adding _SUBMODULE_ALIASES["jax", "scipy.fft.fft"] = "jax.scipy.fft" for each function works, but I wonder if there's a more elegant solution.
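One possibly more general approach is an ordered list of candidate locations per backend, tried in turn until the function is found. The sketch below is a hypothetical `resolve_function` helper, not part of autoray; for jax the candidates might be ['jax.numpy', 'jax'], but the example only uses standard-library modules so that it runs anywhere.

```python
import importlib


def resolve_function(fn, candidates):
    """Hypothetical resolver: return the first ``<candidate>.<fn>`` that
    exists, trying each candidate base module in order."""
    *sub, name = fn.split(".")
    for base in candidates:
        location = ".".join([base, *sub])
        try:
            module = importlib.import_module(location)
        except ImportError:
            continue  # this candidate has no such submodule, try the next
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"could not find {fn!r} under any of {candidates}")


# 'json.path' does not exist, so the lookup falls back to 'os.path'
join = resolve_function("path.join", ["json", "os"])
print(join("a", "b"))
```

With candidates=['jax.numpy', 'jax'], a name such as 'scipy.fft.dct' would first try the non-existent jax.numpy.scipy.fft and then fall back to jax.scipy.fft. The cost is an extra failed import per miss, which autoray's function cache (the `_FUNCS` lookup visible in a traceback further down this page) would presumably absorb after the first call.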
The API for the numpy.ndarray transpose method allows it to permute an arbitrary number of axes into an arbitrary order. However, the torch.Tensor transpose method only swaps two given dimensions and therefore accepts exactly two indices. This means something like the following will fail:
import numpy
import torch
from autoray import do, transpose
Ttorch = torch.zeros([2,3,4,5])
Tnp = numpy.zeros([2,3,4,5])
print(Tnp.transpose([2,1,3,0]).shape) # gives (4,3,5,2), as expected
print(transpose(Tnp, [2,1,3,0]).shape) # also gives (4,3,5,2)
print(Ttorch.transpose([2,1,3,0]).size()) # this fails with a TypeError
print(transpose(Ttorch, [2,1,3,0]).size()) # which means this also fails
The correct torch.Tensor method is permute, which has exactly the same behavior as numpy.ndarray.transpose. This means that something like the following will do what we want:
import numpy
import torch
from autoray import do, transpose
Ttorch = torch.zeros([2,3,4,5])
Tnp = numpy.zeros([2,3,4,5])
print(Tnp.transpose([2,1,3,0]).shape) # gives (4,3,5,2), as expected
print(transpose(Tnp, [2,1,3,0]).shape) # also gives (4,3,5,2)
print(Ttorch.permute(2,1,3,0).size()) # also gives (4,3,5,2)
I'm not sure that there is a way to incorporate this behavior in a clean, non-invasive manner. As far as I understand, the _module_aliases and _func_aliases dictionaries are not applicable, since permute is only an attribute of torch.Tensor (i.e. there is no torch.permute(torch.Tensor, *args)). This therefore seems to necessitate direct modification of the autoray.transpose function (line 308). The following patch works, but it's not very clean:
Current code:
def transpose(x, *args):
    try:
        return x.transpose(*args)
    except AttributeError:
        return do('transpose', x, *args)
Patched code:
def transpose(x, *args):
    backend = infer_backend(x)
    if backend == 'torch':
        return x.permute(*args)
    else:
        try:
            return x.transpose(*args)
        except AttributeError:
            return do('transpose', x, *args)
The inherent challenge is that we need to alias x.transpose() to x.permute() when x is a torch.Tensor. If there is a better way than what I have suggested, let me know!
P.S. I found this problem via an error I got in quimb: I was trying to fuse multiple bonds of a quimb Tensor with PyTorch as the backend, and this problem arose.
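An alternative to checking infer_backend(x) == 'torch' is to dispatch on whether the array type has a permute method at all. The sketch below assumes that any object with a callable permute follows torch's permute(*dims) convention; `transpose` here is a hypothetical replacement with an explicit `axes` argument, not autoray's current signature, and the example exercises only the NumPy path.

```python
import numpy as np


def transpose(x, axes=None):
    """Hypothetical duck-typed transpose: use ``permute`` when the array type
    has one (as torch.Tensor does), else the numpy-style ``transpose``."""
    if axes is None:
        # numpy's default: reverse the order of all axes
        axes = tuple(reversed(range(x.ndim)))
    permute = getattr(x, "permute", None)
    if callable(permute):
        return permute(*axes)  # torch.Tensor.permute takes dims as arguments
    return x.transpose(axes)


T = np.zeros([2, 3, 4, 5])
print(transpose(T, [2, 1, 3, 0]).shape)  # (4, 3, 5, 2)
print(transpose(T).shape)                # (5, 4, 3, 2)
```

The duck-typed check avoids hard-coding a backend name, at the cost of misrouting any other array type whose permute means something different.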
NumPy 1.24 removed some long-deprecated aliases for dtypes, including np.complex.
Numba is currently holding back an update to NumPy 1.24, but we have a patched numba (with numba/numba#8620) and thus see this autoray-generated error in the quimb test suite:
[ 343s] __________________ TestCircuit.test_all_gate_methods[Circuit] __________________
[ 343s]
[ 343s] backend = 'builtins', fn = 'complex'
[ 343s]
[ 343s] def get_lib_fn(backend, fn):
[ 343s] """Cached retrieval of correct function for backend, all the logic for
[ 343s] finding the correct funtion only runs the first time.
[ 343s]
[ 343s] Parameters
[ 343s] ----------
[ 343s] backend : str
[ 343s] The module defining the array class to dispatch on.
[ 343s] fn : str
[ 343s] The function to retrieve.
[ 343s]
[ 343s] Returns
[ 343s] -------
[ 343s] callable
[ 343s] """
[ 343s]
[ 343s] try:
[ 343s] > lib_fn = _FUNCS[backend, fn]
[ 343s] E KeyError: ('builtins', 'complex')
[ 343s]
[ 343s] /usr/lib/python3.8/site-packages/autoray/autoray.py:431: KeyError
[ 343s]
[ 343s] During handling of the above exception, another exception occurred:
[ 343s]
[ 343s] backend = 'builtins', fn = 'complex'
[ 343s]
[ 343s] def import_lib_fn(backend, fn):
[ 343s]
[ 343s] # first check explicitly composed functions -> if the function hasn't been
[ 343s] # called directly yet, it won't have been loaded into the cache, and needs
[ 343s] # generating before e.g. the ``do`` verrsion will work
[ 343s] if fn in _COMPOSED_FUNCTION_GENERATORS:
[ 343s] return _COMPOSED_FUNCTION_GENERATORS[fn](backend)
[ 343s]
[ 343s] try:
[ 343s] # alias for global module,
[ 343s] # e.g. 'decimal' -> 'math'
[ 343s] module = _MODULE_ALIASES.get(backend, backend)
[ 343s]
[ 343s] # submodule where function is found for backend,
[ 343s] # e.g. ['tensorflow', trace'] -> 'tensorflow.linalg'
[ 343s] try:
[ 343s] full_location = _SUBMODULE_ALIASES[backend, fn]
[ 343s]
[ 343s] # if explicit submodule alias given, don't use prepended location
[ 343s] # for example, ('torch', 'linalg.svd') -> torch.svd
[ 343s] only_fn = fn.split(".")[-1]
[ 343s]
[ 343s] except KeyError:
[ 343s] full_location = module
[ 343s]
[ 343s] # move any prepended location into the full module path
[ 343s] # e.g. 'fn=linalg.eigh' -> ['linalg', 'eigh']
[ 343s] split_fn = fn.split(".")
[ 343s] full_location = ".".join([full_location] + split_fn[:-1])
[ 343s] only_fn = split_fn[-1]
[ 343s]
[ 343s] # cached lookup of custom name function might take
[ 343s] # e.g. ['tensorflow', 'sum'] -> 'reduce_sum'
[ 343s] fn_name = _FUNC_ALIASES.get((backend, fn), only_fn)
[ 343s]
[ 343s] # import the function into the cache
[ 343s] try:
[ 343s] lib = importlib.import_module(full_location)
[ 343s] except ImportError:
[ 343s] if "." in full_location:
[ 343s] # sometimes libraries hack an attribute to look like submodule
[ 343s] mod, submod = full_location.split(".")
[ 343s] lib = getattr(importlib.import_module(mod), submod)
[ 343s] else:
[ 343s] # failed to import library at all -> catch + raise ImportError
[ 343s] raise AttributeError
[ 343s]
[ 343s] # check for a custom wrapper but default to identity
[ 343s] wrapper = _CUSTOM_WRAPPERS.get((backend, fn), lambda fn: fn)
[ 343s]
[ 343s] # store the function!
[ 343s] > lib_fn = _FUNCS[backend, fn] = wrapper(getattr(lib, fn_name))
[ 343s]
[ 343s] /usr/lib/python3.8/site-packages/autoray/autoray.py:397:
[ 343s] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[ 343s]
[ 343s] attr = 'complex'
[ 343s]
[ 343s] def __getattr__(attr):
[ 343s] # Warn for expired attributes, and return a dummy function
[ 343s] # that always raises an exception.
[ 343s] import warnings
[ 343s] try:
[ 343s] msg = __expired_functions__[attr]
[ 343s] except KeyError:
[ 343s] pass
[ 343s] else:
[ 343s] warnings.warn(msg, DeprecationWarning, stacklevel=2)
[ 343s]
[ 343s] def _expired(*args, **kwds):
[ 343s] raise RuntimeError(msg)
[ 343s]
[ 343s] return _expired
[ 343s]
[ 343s] # Emit warnings for deprecated attributes
[ 343s] try:
[ 343s] val, msg = __deprecated_attrs__[attr]
[ 343s] except KeyError:
[ 343s] pass
[ 343s] else:
[ 343s] warnings.warn(msg, DeprecationWarning, stacklevel=2)
[ 343s] return val
[ 343s]
[ 343s] if attr in __future_scalars__:
[ 343s] # And future warnings for those that will change, but also give
[ 343s] # the AttributeError
[ 343s] warnings.warn(
[ 343s] f"In the future `np.{attr}` will be defined as the "
[ 343s] "corresponding NumPy scalar. (This may have returned Python "
[ 343s] "scalars in past versions.", FutureWarning, stacklevel=2)
[ 343s]
[ 343s] # Importing Tester requires importing all of UnitTest which is not a
[ 343s] # cheap import Since it is mainly used in test suits, we lazy import it
[ 343s] # here to save on the order of 10 ms of import time for most users
[ 343s] #
[ 343s] # The previous way Tester was imported also had a side effect of adding
[ 343s] # the full `numpy.testing` namespace
[ 343s] if attr == 'testing':
[ 343s] import numpy.testing as testing
[ 343s] return testing
[ 343s] elif attr == 'Tester':
[ 343s] from .testing import Tester
[ 343s] return Tester
[ 343s]
[ 343s] > raise AttributeError("module {!r} has no attribute "
[ 343s] "{!r}".format(__name__, attr))
[ 343s] E AttributeError: module 'numpy' has no attribute 'complex'
[ 343s]
[ 343s] /usr/lib64/python3.8/site-packages/numpy/__init__.py:284: AttributeError
[ 343s]
[ 343s] During handling of the above exception, another exception occurred:
[ 343s]
[ 343s] self = <tests.test_tensor.test_circuit.TestCircuit object at 0x7f0ad3309a00>
[ 343s] Circ = <class 'quimb.tensor.circuit.Circuit'>
[ 343s]
[ 343s] @pytest.mark.parametrize(
[ 343s] 'Circ', [qtn.Circuit, qtn.CircuitMPS, qtn.CircuitDense]
[ 343s] )
[ 343s] def test_all_gate_methods(self, Circ):
[ 343s] import random
[ 343s]
[ 343s] g_nq_np = [
[ 343s] # single qubit
[ 343s] ('x', 1, 0),
[ 343s] ('y', 1, 0),
[ 343s] ('z', 1, 0),
[ 343s] ('s', 1, 0),
[ 343s] ('t', 1, 0),
[ 343s] ('h', 1, 0),
[ 343s] ('iden', 1, 0),
[ 343s] ('x_1_2', 1, 0),
[ 343s] ('y_1_2', 1, 0),
[ 343s] ('z_1_2', 1, 0),
[ 343s] ('w_1_2', 1, 0),
[ 343s] ('hz_1_2', 1, 0),
[ 343s] # single qubit parametrizable
[ 343s] ('rx', 1, 1),
[ 343s] ('ry', 1, 1),
[ 343s] ('rz', 1, 1),
[ 343s] ('u3', 1, 3),
[ 343s] ('u2', 1, 2),
[ 343s] ('u1', 1, 1),
[ 343s] # two qubit
[ 343s] ('cx', 2, 0),
[ 343s] ('cy', 2, 0),
[ 343s] ('cz', 2, 0),
[ 343s] ('cnot', 2, 0),
[ 343s] ('swap', 2, 0),
[ 343s] ('iswap', 2, 0),
[ 343s] # two qubit parametrizable
[ 343s] ('cu3', 2, 3),
[ 343s] ('cu2', 2, 2),
[ 343s] ('cu1', 2, 1),
[ 343s] ('fsim', 2, 2),
[ 343s] ('fsimg', 2, 5),
[ 343s] ('rzz', 2, 1),
[ 343s] ('su4', 2, 15),
[ 343s] ]
[ 343s] random.shuffle(g_nq_np)
[ 343s]
[ 343s] psi0 = qtn.MPS_rand_state(2, 2)
[ 343s] circ = Circ(2, psi0, tags='PSI0')
[ 343s]
[ 343s] for g, n_q, n_p in g_nq_np:
[ 343s] args = [
[ 343s] *np.random.uniform(0, 2 * np.pi, size=n_p),
[ 343s] *np.random.choice([0, 1], replace=False, size=n_q)
[ 343s] ]
[ 343s] > getattr(circ, g)(*args)
[ 343s]
[ 343s] tests/test_tensor/test_circuit.py:185:
[ 343s] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[ 343s] quimb/tensor/circuit.py:1077: in fsimg
[ 343s] self.apply_gate('FSIMG', theta, zeta, chi, gamma, phi, i, j,
[ 343s] quimb/tensor/circuit.py:941: in apply_gate
[ 343s] self._apply_gate(gate, **gate_opts)
[ 343s] quimb/tensor/circuit.py:883: in _apply_gate
[ 343s] self._psi.gate_(gate.array, gate.qubits, tags=tags, **opts)
[ 343s] quimb/tensor/circuit.py:655: in array
[ 343s] self._array = self.build_array()
[ 343s] quimb/tensor/circuit.py:650: in build_array
[ 343s] return _cached_param_gate_build(param_fn, self._params)
[ 343s] quimb/tensor/circuit.py:534: in _cached_param_gate_build
[ 343s] return fn(params)
[ 343s] quimb/tensor/circuit.py:436: in fsimg_param_gen
[ 343s] img = do('complex', img_re, img_im)
[ 343s] /usr/lib/python3.8/site-packages/autoray/autoray.py:79: in do
[ 343s] return get_lib_fn(backend, fn)(*args, **kwargs)
[ 343s] /usr/lib/python3.8/site-packages/autoray/autoray.py:433: in get_lib_fn
[ 343s] lib_fn = import_lib_fn(backend, fn)
[ 343s] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[ 343s]
[ 343s] backend = 'builtins', fn = 'complex'
[ 343s]
[ 343s] def import_lib_fn(backend, fn):
[ 343s]
[ 343s] # first check explicitly composed functions -> if the function hasn't been
[ 343s] # called directly yet, it won't have been loaded into the cache, and needs
[ 343s] # generating before e.g. the ``do`` verrsion will work
[ 343s] if fn in _COMPOSED_FUNCTION_GENERATORS:
[ 343s] return _COMPOSED_FUNCTION_GENERATORS[fn](backend)
[ 343s]
[ 343s] try:
[ 343s] # alias for global module,
[ 343s] # e.g. 'decimal' -> 'math'
[ 343s] module = _MODULE_ALIASES.get(backend, backend)
[ 343s]
[ 343s] # submodule where function is found for backend,
[ 343s] # e.g. ['tensorflow', trace'] -> 'tensorflow.linalg'
[ 343s] try:
[ 343s] full_location = _SUBMODULE_ALIASES[backend, fn]
[ 343s]
[ 343s] # if explicit submodule alias given, don't use prepended location
[ 343s] # for example, ('torch', 'linalg.svd') -> torch.svd
[ 343s] only_fn = fn.split(".")[-1]
[ 343s]
[ 343s] except KeyError:
[ 343s] full_location = module
[ 343s]
[ 343s] # move any prepended location into the full module path
[ 343s] # e.g. 'fn=linalg.eigh' -> ['linalg', 'eigh']
[ 343s] split_fn = fn.split(".")
[ 343s] full_location = ".".join([full_location] + split_fn[:-1])
[ 343s] only_fn = split_fn[-1]
[ 343s]
[ 343s] # cached lookup of custom name function might take
[ 343s] # e.g. ['tensorflow', 'sum'] -> 'reduce_sum'
[ 343s] fn_name = _FUNC_ALIASES.get((backend, fn), only_fn)
[ 343s]
[ 343s] # import the function into the cache
[ 343s] try:
[ 343s] lib = importlib.import_module(full_location)
[ 343s] except ImportError:
[ 343s] if "." in full_location:
[ 343s] # sometimes libraries hack an attribute to look like submodule
[ 343s] mod, submod = full_location.split(".")
[ 343s] lib = getattr(importlib.import_module(mod), submod)
[ 343s] else:
[ 343s] # failed to import library at all -> catch + raise ImportError
[ 343s] raise AttributeError
[ 343s]
[ 343s] # check for a custom wrapper but default to identity
[ 343s] wrapper = _CUSTOM_WRAPPERS.get((backend, fn), lambda fn: fn)
[ 343s]
[ 343s] # store the function!
[ 343s] lib_fn = _FUNCS[backend, fn] = wrapper(getattr(lib, fn_name))
[ 343s]
[ 343s] except AttributeError:
[ 343s]
[ 343s] # check if there is a backup function (e.g. for older library version)
[ 343s] backend_alt = backend + "[alt]"
[ 343s] if backend_alt in _MODULE_ALIASES:
[ 343s] return import_lib_fn(backend_alt, fn)
[ 343s]
[ 343s] > raise ImportError(
[ 343s] f"autoray couldn't find function '{fn}' for "
[ 343s] f"backend '{backend.replace('[alt]', '')}'."
[ 343s] )
[ 343s] E ImportError: autoray couldn't find function 'complex' for backend 'builtins'.
[ 343s]
[ 343s] /usr/lib/python3.8/site-packages/autoray/autoray.py:406: ImportError
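For context, np.complex was only ever an alias of Python's builtin complex. The traceback shows quimb calling do('complex', img_re, img_im) with Python floats, which autoray classifies as the builtins backend and then looks up on numpy. One possible fallback, sketched with a hypothetical `make_complex` function (not autoray's actual fix), is the builtin complex for real scalars and re + 1j * im for arrays.

```python
import numbers

import numpy as np


def make_complex(re, im):
    """Hypothetical replacement for do('complex', re, im)."""
    if isinstance(re, numbers.Real) and isinstance(im, numbers.Real):
        # Python (and numpy) real scalars: the builtin that np.complex aliased
        return complex(re, im)
    # arrays: relies only on + and * with the Python complex scalar 1j
    return re + 1j * im


print(make_complex(1.0, 2.0))                                    # (1+2j)
print(make_complex(np.array([1.0, 3.0]), np.array([2.0, 4.0])))  # [1.+2.j 3.+4.j]
```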
In addition to the take translation I added in my previous PR, there are some more that might be good to add. At least, I am using these myself. I can make a PR.
- split. The syntax is different for numpy and tensorflow/torch. The former wants the number of splits or an array of split locations, whereas tensorflow/torch want either the number of splits or an array of split sizes. We can go from one format to the other using np.diff.
- diff. This is implemented in tensorflow as tf.experimental.numpy.diff, and not implemented at all for torch. This also means I don't know the cleanest way to implement the split mentioned above. Maybe just use np.diff and then convert to an array of the right backend if necessary?
- linalg.norm seems to work with tensorflow, but for torch we need to do _SUBMODULE_ALIASES["torch", "linalg.norm"] = "torch".
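The conversion between the two split formats mentioned above can be sketched with np.diff and its inverse, np.cumsum. Both helpers are hypothetical and assume the split locations are sorted and lie within the axis; NumPy also accepts unsorted or out-of-range locations (producing empty pieces), which this does not reproduce. The resulting size list is the form that torch.split and tf.split accept.

```python
import numpy as np


def split_indices_to_sizes(indices, axis_len):
    """numpy-style split locations -> torch/tensorflow-style split sizes."""
    bounds = np.concatenate([[0], np.asarray(indices, dtype=int), [axis_len]])
    return np.diff(bounds).tolist()


def split_sizes_to_indices(sizes):
    """torch/tensorflow-style split sizes -> numpy-style split locations."""
    return np.cumsum(sizes)[:-1].tolist()


x = np.arange(10)
sizes = split_indices_to_sizes([3, 7], len(x))
print(sizes)                                  # [3, 4, 3]
print(split_sizes_to_indices(sizes))          # [3, 7]
print([len(p) for p in np.split(x, [3, 7])])  # [3, 4, 3]
```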
Maybe a bit of an overly ambitious idea, but have you ever thought about baking in support for JIT? Right now it seems that for TensorFlow everything works with eager execution, and I'm not sure you can compile the computation graphs resulting from a series of ar.do calls.
- PyTorch also supports JIT to some extent, with TorchScript.
- NumPy doesn't have JIT, but there is Numba.
- CuPy has an interface with Numba that does seem to allow JIT.
- JAX has support for JIT.
Another thing is gradients. Several of these libraries have automatic gradients, and having an autoray interface for doing computations with automatic gradients would be fantastic as well (although probably also ambitious).
If you think these things are doable at all, I wouldn't mind spending some time to try to figure out how this could work.
Less ambitiously, you did mention in #3 that something along the lines of
with set_backend(like):
    ...
would be pretty nice. I can try to do this. This probably comes down to checking for a global flag in ar.do after the line
if like is None:
Something I ran into is that different backends prefer either single or double precision. I personally need double precision, or at least prefer to consistently use one precision. This is also much fairer for benchmarking. The main problem is when forming arrays; for example, to generate a (2, 2) random normal array with double precision we have to do:
import autoray as ar
import jax
jax.config.update('jax_enable_x64', True)
for backend in ['numpy', 'tensorflow', 'torch', 'jax', 'dask', 'mars', 'sparse']:
    if backend in ('tensorflow', 'torch'):
        A = ar.do("random.normal", size=(2, 2), like=backend, dtype=ar.to_backend_dtype('float64', backend))
    else:
        A = ar.do("random.normal", size=(2, 2), like=backend)
We can't just always supply the dtype argument, since numpy, dask and sparse throw an error when fed dtype. We could also generate an array of whatever dtype and then convert the result to double precision, but this doesn't really address the problem. This doesn't just hold for random.normal, but for essentially any kind of array-creating function, like zeros or eye, although there supplying dtype does work. (For jax we still need to set jax_enable_x64.)
I can also see from your gen_rand method in test_autoray.py that you encountered similar problems.
Suggested solutions:
- dtype keyword, and then converts the result to the correct dtype after the fact (if dtype is 'float32'). For jax we should maybe throw a warning if trying to generate double-precision random numbers without setting 'jax_enable_x64' to True. In fact, for 'zeros', jax already throws a warning in this situation.
- It might also be worthwhile to add translations for some more standard distributions like binomial or poisson, although I mostly use normal and uniform myself.
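A sketch of the first suggestion for a backend whose random.normal has no dtype argument, using NumPy. `random_normal` is a hypothetical wrapper. Since np.random.normal samples in float64, casting down to float32 afterwards is harmless, whereas casting a float32 sample up to float64 would not add precision; this is presumably why the suggestion only converts after the fact in the 'float32' case.

```python
import numpy as np


def random_normal(size, dtype="float64", loc=0.0, scale=1.0):
    """Hypothetical wrapper giving numpy's random.normal a dtype argument.

    np.random.normal always samples in float64; the result is then cast to
    the requested dtype (a no-op for the default float64).
    """
    return np.random.normal(loc, scale, size=size).astype(dtype, copy=False)


A = random_normal((2, 2))
B = random_normal((2, 2), dtype="float32")
print(A.dtype, B.dtype)  # float64 float32
```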