pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.

Home Page: https://pytorch.org/functorch/

License: BSD 3-Clause "New" or "Revised" License

Languages: Python 1.82%, Batchfile 18.13%, Jupyter Notebook 80.05%
Topics: pytorch, hessians, gradients

functorch's Issues

Batch rule plumbing codegen

Problem: Writing plumbing is repetitive, see link

We should have some way of auto-generating the plumbing and allowing a developer to insert some dispatching logic into the middle of the plumbing.

Proposal 1: Macro our way to victory

Every op gets an OP_PLUMBING_BEGIN and an OP_PLUMBING_END macro. Inside BatchRulesLoss.cpp, here's how we would write the plumbing for nll_loss_forward:

nll_loss_forward_PLUMBING_BEGIN
  if (!self_bdim && !target_bdim && !weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::nll_loss_forward(self_value, target_value, weight_value, reduction, ignore_index);
  }

  if (self_bdim && target_bdim && (!weight || !weight->defined()) && ignore_index < 0) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto results = nll_loss_forward_self_target_batch_rule(
        self_value, self_bdim, target_value, target_bdim, reduction);
    return std::make_tuple(
      makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
      makeBatched(std::get<2>(results), std::get<3>(results), cur_level)
    );
  }
nll_loss_forward_PLUMBING_END


TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("nll_loss_forward", nll_loss_forward_plumbing);
}

Proposal 2: Have a "batch_rules.yaml" file

The .yaml file could handle the registration (e.g. m.impl("nll_loss_forward", nll_loss_forward_plumbing);) as well.

 - name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
   dispatch: |
     if (!self_bdim && !target_bdim && !weight_bdim) {
       c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
       return at::nll_loss_forward(self_value, target_value, weight_value, reduction, ignore_index);
     }

     if (self_bdim && target_bdim && (!weight || !weight->defined()) && ignore_index < 0) {
       c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
       auto results = nll_loss_forward_self_target_batch_rule(
           self_value, self_bdim, target_value, target_bdim, reduction);
       return std::make_tuple(
         makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
         makeBatched(std::get<2>(results), std::get<3>(results), cur_level)
       );
     }
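For illustration only, here is a rough sketch (hypothetical schema and names, not an implemented tool) of how a small Python script could consume such a batch_rules.yaml entry and emit a plumbing skeleton around the developer-written dispatch body:

import yaml

# Hypothetical batch_rules.yaml content, embedded as a string for the sketch.
BATCH_RULES_YAML = """\
- name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
  dispatch: |
    if (!self_bdim && !target_bdim && !weight_bdim) {
      // hand-written dispatching logic goes here
    }
"""

# Skeleton the codegen would wrap around the developer-provided dispatch body.
TEMPLATE = """\
// generated plumbing for: {schema}
std::tuple<at::Tensor, at::Tensor> {op}_plumbing(/* generated unwrapped args */) {{
{body}  // (generated) fall back to the default error path if nothing matched
}}
"""

for entry in yaml.safe_load(BATCH_RULES_YAML):
    op = entry["name"].split("(")[0]
    # Indent the hand-written dispatch body so it sits inside the generated function.
    body = "".join("  " + line + "\n" for line in entry["dispatch"].splitlines())
    print(TEMPLATE.format(schema=entry["name"], op=op, body=body))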

I don't like either solution, so I haven't implemented either of them yet.

Error while installing functorch

I am trying to install functorch to run some tests with vmap, but I am not able to install it following the instructions in the README. I'm just trying to run the Colab, but I'm getting the following error:

Collecting git+https://github.com/zou3519/functorch.git
  Cloning https://github.com/zou3519/functorch.git to /tmp/pip-req-build-57x_64b3
  Running command git clone -q https://github.com/zou3519/functorch.git /tmp/pip-req-build-57x_64b3
Requirement already satisfied: torch>=1.9.0.dev in /usr/local/lib/python3.7/dist-packages (from functorch==0.0.1a0+2890f63) (1.9.0.dev20210429+cpu)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.9.0.dev->functorch==0.0.1a0+2890f63) (3.7.4.3)
Building wheels for collected packages: functorch
  Building wheel for functorch (setup.py) ... error
  ERROR: Failed building wheel for functorch
  Running setup.py clean for functorch
Failed to build functorch
Installing collected packages: functorch
    Running setup.py install for functorch ... error
ERROR: Command errored out with exit status 1: /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-57x_64b3/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-57x_64b3/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-lv6v2_fa/install-record.txt --single-version-externally-managed --compile --user --prefix= Check the logs for full command output.

Any idea how to solve it?

Decompose CompositeImplicitAutograd ops at the FuncTorchBatched key

Background

@ezyang suggested trying this to minimize the number of operators we have to override. More concretely, instead of registering all 2000 operators to FuncTorchBatched, we would only have to register the (insert number here) operators that are not composite w.r.t. autograd.

To be concrete, the suggestion was to add FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32

The experiment

I added FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32, recompiled PyTorch and functorch, and then ran the test suite. This leads to a fun number of failures (see here) that have the same root cause!

The problem is that some CompositeImplicitAutograd ops decompose to in-place operations that are not compatible with vmap (note here).

Can we solve these problems by just registering an override for the vmap key for those operations?

  • That would solve the vmap(blah) case, but I'm not sure it helps vmap(grad(blah)), since grad is always going to decompose blah as it runs through the grad transform.

batching rule for repeat has some cases it fails on

import torch
from functorch import vmap

def f(x):
    return x.repeat(0)

print(vmap(f)(torch.randn(3)).shape) # Returns shape [1,0]

This one is kind of awkward (I'm surprised it's legal), but I think it makes sense to preserve the invariant that the output always has shape B in the batching dim.

def f(x):
    return x.repeat(1)

print(vmap(f)(torch.randn((3))).shape) # Returns shape [1,3], should return [3,1]

Please keep up the great work!!

This is not a bug report or feature request, but more a shout of admiration. I am a big-time PyTorch fan, but I have been looking at JAX because it supports vmap and other awesome functional tools. If PyTorch succeeds in reproducing these features, I think it takes away one major reason to switch to JAX.

Please, please keep up the good work and make this happen! Thanks for listening to your community and taking the initiative. You guys are the best!

grad doesn't work with _VF.frobenius_norm

import torch
from functorch import jacrev, vmap, grad

def f(x):
    return torch.norm(x, dim=1).sum()

print(grad(f)(torch.randn(3, 3)))
Traceback (most recent call last):
  File "t.py", line 39, in <module>
    print(grad(f)(torch.randn(3, 3)))
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 149, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 107, in wrapper
    output = f(*args)
  File "t.py", line 37, in f
    return torch.norm(x, dim=1).sum()
  File "/home/chilli/fb/pytorch/torch/functional.py", line 1441, in norm
    return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
NotImplementedError: Cannot access storage of TensorWrapper

I haven't dug into this issue yet; it might have the same root cause as other issues, e.g. the torch.tensor one.

jacrev(jacrev(f)) fails for matmul

import torch
from functorch import jacrev
N = 5
M = 3
W = torch.randn(N, M)
def f(x):
    return W @ x
inps = (torch.randn(M),)
print(jacrev(jacrev(f))(*inps))
Traceback (most recent call last):
  File "python_key.py", line 70, in <module>
    print(jacrev(jacrev(f))(*inps))
  File "/opt/anaconda/lib/python3.7/site-packages/functorch-0.0.1a0+e27e16f-py3.7-linux-x86_64.egg/functorch/_src/eager_transforms.py", line 87, in wrapper_fn
    result, = vmap(vjp_fn)(basis)
  File "/opt/anaconda/lib/python3.7/site-packages/functorch-0.0.1a0+e27e16f-py3.7-linux-x86_64.egg/functorch/_src/vmap.py", line 258, in wrapped
    batched_outputs = func(*batched_inputs)
  File "/opt/anaconda/lib/python3.7/site-packages/functorch-0.0.1a0+e27e16f-py3.7-linux-x86_64.egg/functorch/_src/eager_transforms.py", line 72, in wrapper
    retain_graph=retain_graph, create_graph=create_graph)
  File "/home/chilli/fb/pytorch/torch/autograd/__init__.py", line 228, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Batching rule not implemented for aten::rsub.Scalar, aten::diag, aten::where.Scalar, aten::allclose.

Hello, and thank you for the great project!
We are using it in a fully differentiable physics code at the moment and found some use cases that are not covered yet.
We are vmapping some functions that use torch.where and/or torch.diag, e.g.:

import torch

def set_diagonal_to_inf(hamiltonian, value=10e9):
    """Args:
        hamiltonian: Matrix of the Hamiltonian (N_k, particle number*N_orbitals, particle number*N_orbitals)
        value: int/float value the zeros are set to
    Returns:
        Hamiltonian with high eigenvalues for non-existing particles/orbitals
    """
    diag = torch.sum(torch.abs(hamiltonian), axis=0)
    diag = torch.where(diag == 0, value, 0.)
    return torch.diag(diag)

and got the warning that they were not implemented yet:
functorch/_src/vmap.py:268: UserWarning: Batching rule not implemented for aten::diag falling back to slow (for loop and stack) implementation (Triggered internally at /tmp/pip-req-build-_exf5lof/functorch/csrc/BatchedFallback.cpp:89.)
batched_outputs = func(*batched_inputs)
UserWarning: Batching rule not implemented for aten::where.Scalar falling back to slow (for loop and stack) implementation (Triggered internally at /tmp/pip-req-build-_exf5lof/functorch/csrc/BatchedFallback.cpp:89.)

Another function constructs a 2D array from an input vector, where each array element is computed with a different formula. That function was performance-relevant, but we were able to vectorize it without vmap; in any case, it led to the following warning:
torch/_tensor.py:544: UserWarning: Batching rule not implemented for aten::rsub.Scalar falling back to slow (for loop and stack) implementation (Triggered internally at /tmp/pip-req-build-_exf5lof/functorch/csrc/BatchedFallback.cpp:89.)
return _C._VariableFunctions.rsub(self, other)

Finally, we would like to solve a lot of eigenvalue problems with the symeig solver from xitorch (https://github.com/xitorch/xitorch):

import torch
import xitorch
from xitorch import linalg as xilinalg
from functorch import vmap

def eig(ham):
    ham = (ham + ham.T) / 2
    ham = xitorch.LinearOperator.m(ham)
    return xilinalg.symeig(ham)[0]

parallel_solve = vmap(eig, 0)
test = torch.rand(7, 10, 10)
parallel_solve(test)

which led to the following error:
File "xitorch/_core/linop.py", line 100, in m
is_hermitian = torch.allclose(mat, mat.transpose(-2, -1))
RuntimeError: Batching rule not implemented for aten::allclose. We could not generate a fallback.

It would be great if you could find the time to add some of these, especially the last one.
P.S.: thank you for using the same vmap syntax as JAX; that saved a lot of time converting the code.

Figure out API for working with an ensemble of modules

Problem: How does one initialize an ensemble of modules?

One possible solution is to offer an API that returns parameters with an additional "ensemble" dimension. For example, if we were trying to ensemble models that contained a single nn.Linear layer, then we'd return a weight and bias each with an extra ensemble dimension.

That leads to something like the following API:

state_dict = functional_init_ensemble(nn.Linear, 3, 3, ensemble_size=5, device='cpu')

This returns a state_dict that has two elements (a sketch of one possible construction follows the list):

  • weight with shape (5, 3, 3)
  • bias with shape (5, 3)
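For concreteness, here is a minimal sketch of one way such an ensembled state dict could be built; functional_init_ensemble is only a hypothetical name used for illustration, not a functorch API:

import torch
import torch.nn as nn

def functional_init_ensemble(module_cls, *args, ensemble_size, device='cpu', **kwargs):
    # Initialize `ensemble_size` independent modules and stack each
    # parameter/buffer along a new leading "ensemble" dimension.
    modules = [module_cls(*args, **kwargs).to(device) for _ in range(ensemble_size)]
    state_dicts = [m.state_dict() for m in modules]
    return {name: torch.stack([sd[name] for sd in state_dicts])
            for name in state_dicts[0]}

state_dict = functional_init_ensemble(nn.Linear, 3, 3, ensemble_size=5, device='cpu')
print(state_dict['weight'].shape)  # torch.Size([5, 3, 3])
print(state_dict['bias'].shape)    # torch.Size([5, 3])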

There are some problems with returning a state dict:

  • need some way of separating out buffers and parameters

Another way to do this is to have all nn.Modules take an additional 'ensemble_size' argument and straight up just return nn.Modules...

module = nn.Linear(3, 3, ensemble_size=1)

Some requirements:

  • this should work on user-defined modules as well

Should matmul have a decomposed batch rule or an actual one?

Right now the batch rule for matmul is decomposed (https://github.com/zou3519/functorch/blob/53144b92d33d6d796359c97764ee68743f5463bf/functorch/csrc/BatchingRegistrations.cpp#L1254).

My worry is that it might be possible for us to transform some code into inefficient code. For example, if B0 and B1 are vmap dimensions and we are matrix-multiplying tensors of size [B0, 5, 5] and [B1, 5, 5], we don't want to multiply tensors of size [B0, 1, 5, 5] and [1, B1, 5, 5]. If that happens, then internally matmul will expand the tensors to [B0, B1, 5, 5] and materialize the full memory, which can be quite slow. (The ideal way to multiply these tensors is to reshape them into [B0 * 5, 5] and [5, B1 * 5] and then multiply them together.)

This issue is probably just a code reading exercise to see if it's possible for the above to happen in the decomposed matmul code. I was in the middle of writing a non-decomposed matmul here: https://gist.github.com/zou3519/ddd4b2d4aacc98bf20d114f26b27b082
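For reference, a minimal sketch of both the memory-hungry broadcast path and the reshape trick described above; B0, B1, and the matrix size are illustrative:

import torch

B0, B1, N = 4, 6, 5
A = torch.randn(B0, N, N)   # batched over one vmap level
B = torch.randn(B1, N, N)   # batched over another vmap level

# Broadcast path: both operands are broadcast to [B0, B1, N, N] before multiplying.
naive = A[:, None] @ B[None, :]

# Reshape trick: a single (B0*N, N) @ (N, B1*N) matmul, then un-flatten and reorder.
lhs = A.reshape(B0 * N, N)
rhs = B.permute(1, 0, 2).reshape(N, B1 * N)
fused = (lhs @ rhs).reshape(B0, N, B1, N).permute(0, 2, 1, 3)

assert torch.allclose(naive, fused, atol=1e-5)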

Improve make_functional*

Things that we can and should do now:

  • make_functional* should not destroy the original model
  • Let func be the function returned by make_functional_with_buffers. func should accept arguments as func(params, buffers, *args, **kwargs) instead of func(params, buffers, args) (what it currently accepts).
  • We can make func into a special FunctionalModule class so it is registered as a subclass of nn.Module. This makes it so that func.eval() and func.train() work; furthermore, func can be registered as a submodule of an owning Module. func is still callable like a function and has no state.
  • We should probably provide a helper function to “stack” weights and buffers to prepare them for vmap. It would probably just accept a list of the “same” module and stack them together.

Things we should consider but are more tricky:

  • Combine the weights and buffers return value. The difficulty around this is the interaction with functorch.grad

Can't call `torch.tensor` within grad

from functorch import grad, vmap
import torch

def f(x):
    t = torch.tensor(0)
    return t + x
inps = (torch.randn([]),)
print(grad(f)(*inps))
Traceback (most recent call last):
  File "t.py", line 8, in <module>
    print(grad(f)(*inps))
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 152, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 110, in wrapper
    output = f(*args)
  File "t.py", line 5, in f
    t = torch.tensor(0)
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

grad doesn't run when under `torch.no_grad()`

from functorch import grad, vmap, pythonkey_trace, wrap_key
import torch
import torch.fx as fx

def f(x):
    return torch.sin(x)
with torch.no_grad():
    print(grad(f)(torch.randn(())))
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Not totally sure what the semantics should be... but I kinda think we should be ignoring torch.no_grad().

PyTorch Lightning Integration

Hey there,

Awesome work with this library !

I am part of the PyTorch Lightning team, and we could integrate functorch pretty easily before it gets upstreamed to PyTorch. In Lightning, it would be added as a new TrainingType plugin, and it should make benchmarking simpler for you since many models are already implemented.

It could look like Trainer(accelerator='pmap').

If you are interested, please join Lightning Slack and PM me 🤗

Best,
T.C

vjp is silently incorrect when copy_ is involved

Repro:

import torch
import functorch
from functorch import vjp

torch.manual_seed(0)

x = torch.randn(3, 5)
gy = torch.randn(3)

ggx = torch.arange(15, dtype=torch.float).view(3, 5)

def gp(x, gy):
    res = torch.zeros(3, 5)
    res.diagonal(2).copy_(gy)
    return res

gx, vjp_fn = vjp(gp, x, gy)
result = vjp_fn(ggx)

expected = torch.diag(ggx, 2)
print(result[1])
print(expected)
assert torch.allclose(result[1], expected)

Handle cases where the gradients are 0 for inputs

from functorch import grad
import torch

def f(x):
    return (x[0]**2.0).sum()
inps = (torch.randn(3), torch.randn(3))
fx_graph = grad(f)(inps)

Error:

Traceback (most recent call last):
  File "t.py", line 7, in <module>
    fx_graph = grad(f)(inps)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 155, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 135, in wrapper
    output, flat_diff_args, create_graph=True)
  File "/home/chilli/fb/pytorch/torch/autograd/__init__.py", line 228, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
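For reference, the behavior we probably want can be sketched with plain torch.autograd.grad: pass allow_unused=True and fill in zeros for inputs that did not participate in the graph. This is only a sketch of the desired semantics, not functorch's implementation:

import torch

def grad_with_zeros(f, inps):
    inps = tuple(x.detach().requires_grad_() for x in inps)
    output = f(inps)
    grads = torch.autograd.grad(output, inps, allow_unused=True)
    # Unused inputs come back as None; replace them with zero gradients.
    return tuple(torch.zeros_like(x) if g is None else g for x, g in zip(inps, grads))

def f(x):
    return (x[0] ** 2.0).sum()

print(grad_with_zeros(f, (torch.randn(3), torch.randn(3))))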

CUBLAS_STATUS_ALLOC_FAILED when jacrev of jacrev of matmul

import torch
from functorch import jacrev

device = 'cuda'

N = 5
M = 3
W = torch.randn(N, M, device=device)

def f(x):
    return W @ x

x = torch.randn(M, device=device)
result = jacrev(jacrev(f))(x)
expected = torch.zeros(N, M, M, device=device)
assert torch.allclose(result, expected)

test/test_eager_transforms.py::TestVmapOfGradCPU::test_log_softmax_cpu is broken

RuntimeError: backward() called inside torch.vmap. This is not supported, please call backward() outside torch.vmap or instead use torch.autograd.grad inside torch.vmap

This is what that test looks like:

def test_log_softmax(self, device):
    x = torch.randn(3, 5)
    v = torch.randn(5)

    def foo(x, v):
        _, vjp_fn = vjp(partial(torch.log_softmax, dim=-1), x)
        return vjp_fn(v)[0]

    result = vmap(foo, (0, None))(x, v)

    v = v.expand_as(x)
    x.requires_grad_()
    output = torch.log_softmax(x, dim=-1)
    output.backward(v)
    self.assertEqual(result, x.grad)

Ops used in torch.distributions to lower

HalfCauchy:

{'div', 'pow', 'sum', 'log', 'lt', 'exp', 'add', 'mul', 'index_put_', 'neg', 'unsqueeze', 'expand', 'sub', '_local_scalar_dense'}

LKJCholesky:

{'ge', 'softplus_backward', 'sum', 'tanh_backward', 'softplus', 'sub', 'tanh', 'add', 'mul', 'neg', '_s_where', 'clamp', 'logical_and', 'le', 'expand', 'index_put_'}

Uniform:

{'div', 'sigmoid', 'sum', 'softplus_backward', 'softplus', 'copy_', 'add', 'mul', 'neg', 'clamp', 'le', 'gt', 'sub'}

Bernoulli:

{'neg', 'binary_cross_entropy_with_logits', 'sub', 'sum'}

Beta:

{'sigmoid_backward', 'sigmoid', 'unbind', 'ge', 'sum', 'softplus', 'mul', 'stack', 'sub', 'log', 'softplus_backward', 'add', 'rsub', 'neg', 'logical_and', 'div', 'clamp', 'le', '_s_where'}

Dirichlet:

{'sigmoid_backward', 'sigmoid', 'log_sigmoid_backward', 'expand', 'ge', 'slice', 'getitem', 'sum', 'mul', 'constant_pad_nd', 'log', 'sub', 'add', 'rsub', 'neg', 'logical_and', 'log_sigmoid_forward', 'div', 'cumprod', 'copy_', 'clamp', 'le', '_s_where'}

HalfNormal:

index_select

grad transform failing tests tracking issue

Currently fails for:

Calls internal_new_from_data: (#65)

  • __getitem__
  • __rpow__ (straight up calls torch.tensor)
  • torch.tensor
  • Tensor.new()

Data pointer accessed by helper function (#65)

  • linalg.cholesky (linalg_cholesky calls linalg_cholesky_ex (prim) and does error checking)
  • linalg.inv (linalg_inv calls linalg_inv_ex (prim) and does error checking)
  • linalg.matrix_power (can call inv)

The norm problem (#14); AKA: CompositeImplicitAutograd op calls an "out= variant" that calls raw native::resize_ on tensors.

  • linalg.matrix_norm
  • linalg.norm
  • nanquantile
  • quantile

Requires an integer tensor for the "splits" argument...

  • tensor_split

Test by uncommenting https://github.com/zou3519/functorch/blob/ae97def8eb8508418053a1a7c81371b9b44dcc3d/test/test_grad.py#L49. I haven't investigated the problems yet.

Miscellaneous non-OpInfo problems (test_torch.py)

  • Tensor.numpy
  • Tensor.tolist
  • copy.copy
  • to_dlpack
  • repeat_interleave
  • Tensor.map_
  • Tensor.map2_
  • pickle.dumps
  • printing
  • torch.sobol_engine_initialize_state
  • assigning to Tensor.data

Miscellaneous non-OpInfo problems (test_nn.py)

  • F.ctc_loss
  • F.max_pool1d (testing artifact)
  • Lazy modules

Miscellaneous non-OpInfo problems (test_linalg.py)

Miscellaneous non-OpInfo problems (test_tensor_creation.py)

Miscellaneous non-OpInfo problems (test_unary_ufuncs.py)

  • conj

https://docs.google.com/spreadsheets/d/18sv-cKBqMGVCNdclFk5jB9LmQJGzb_eNAE9O2-oep3Q/edit?usp=sharing

Figure out how to transform over optimizers

One way to transform over training loops (e.g. to do model ensembling or the inner step of MAML) is to use a function that represents the optimizer step instead of an actual PyTorch optimizer. Right now I think we have the following requirements:

  • There should be a function version of each optimizer (e.g. F.sgd)
  • The function should have an option to not mutate (e.g. F.sgd(..., inplace=False))
  • The function should be differentiable

PyTorch already has some here (in prototype stage): https://github.com/pytorch/pytorch/blob/master/torch/optim/_functional.py, so we should check whether these fit the requirements and, if not, decide whether we should influence the design.
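For illustration, a minimal sketch of what a functional, non-mutating, differentiable SGD step could look like; functional_sgd is a hypothetical name used here, not the torch.optim._functional API:

import torch
from functorch import grad

def functional_sgd(params, grads, lr=0.1):
    # Pure update: return new parameters instead of mutating in place,
    # so the step itself remains differentiable and transformable.
    return [p - lr * g for p, g in zip(params, grads)]

def loss_fn(params, x):
    w, b = params
    return ((x @ w + b) ** 2).sum()

params = [torch.randn(3, 1), torch.randn(1)]
x = torch.randn(5, 3)
grads = grad(loss_fn)(params, x)        # gradients w.r.t. the params list
new_params = functional_sgd(params, grads)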

Transforms don't work with new_ones

import torch
from functorch import vmap

def f(x):
    return x.new_ones(x.shape)

print(vmap(f)(torch.randn(3)))

>>> RuntimeError: DispatchKey FuncTorchBatched doesn't correspond to a device

indexing with a `True` tensor fails under grad

import torch
from functorch import grad

def f(value):
    log_prob = torch.ones(())
    val = (torch.zeros(()) > 0)
    log_prob[val] = 0
    return value

grad(f)(torch.randn(()))

>>> Traceback (most recent call last):
  File "t.py", line 13, in <module>
    grad(f)(torch.randn(()))
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 178, in wrapper
    results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
  File "/home/chilli/fb/functorch/functorch/_src/eager_transforms.py", line 142, in wrapper
    output = f(*args)
  File "t.py", line 10, in f
    log_prob[val] = 0
NotImplementedError: Cannot access storage of TensorWrapper

Set up a CI and figure out how functorch depends on PyTorch master

The current scheme we're working with is "functorch should always work with the existing PyTorch viable/strict branch".

Motivations:

  • If a change to PyTorch core requires a change to functorch, it would be nice to catch it sooner rather than later
  • As the functorch test suite balloons, it is nice to have tests for committed code.

TODO:

  • Initial functorch CI, runs all tests except test_pythonkey.py #53
  • Add a build for LLVM PyTorch that runs test_pythonkey.py

Prototype vmap over data-dependent control flow

There needs to be something to:

  1. capture control flow
  2. represent control flow in a form that is transformable
  3. actually transform the control flow (e.g. a batching rule)

We don't have to worry too much about (1) for now. A way to prototype (2) would be to have something like control-flow operators. These can either be Python-based or go through the PyTorch C++ dispatcher (!!).
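To make (2) and (3) concrete, here is a purely illustrative Python-based sketch; cond and cond_batching_rule are hypothetical names, not functorch APIs:

import torch

def cond(pred, true_fn, false_fn, operand):
    # Eager semantics: pick one branch based on a 0-dim boolean tensor.
    return true_fn(operand) if bool(pred) else false_fn(operand)

def cond_batching_rule(pred, true_fn, false_fn, operand):
    # One possible batching rule: run both branches and select per example,
    # assuming `pred` carries one boolean per batch element.
    t, f = true_fn(operand), false_fn(operand)
    mask = pred.view(-1, *([1] * (t.dim() - 1)))
    return torch.where(mask, t, f)

x = torch.randn(4, 3)
pred = x.sum(dim=1) > 0
print(cond_batching_rule(pred, torch.sin, torch.cos, x))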

CompositeImplicitAutograd ops that call *_like or new_* operators fail under certain transforms

import torch
import functorch
from functorch import vmap, grad

N = 3
C = 5

device = 'cpu'

def foo(x):
    result = x.contiguous()
    return result.sum()

x = torch.randn(C, N, device=device).t()
result = vmap(grad(foo))(x)

fails with

RuntimeError: vmap: aten::copy_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over at level 2. Please try to use out-of-place operators instead of aten::copy_. If said operator is being called inside the PyTorch framework, please file a bug report instead.

Transform testing tracking issue

Add OpInfo-based testing for:

  • vmap
  • grad
  • vjp
  • vjp of vjp
  • vjp of vmap
  • vmap of vjp
  • vmap of vmap

grad is really a special case of vjp. Do we need more tests for it?

  • See if we can implement grad by calling vjp (see the sketch after this list).
  • There are a lot of exceptions for in-place operations. See if we can/should test in-place in OpInfo testing
  • Tests to check which ops have batching rules are useful to make sure we actually register the batching rules correctly
  • Stress testing for TensorWrapper: Wrap all Tensors in TensorWrappers and send them through the PyTorch test suite...
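A minimal sketch of the grad-via-vjp idea, assuming a single-tensor, scalar-output function:

import torch
from functorch import vjp

def grad_via_vjp(f):
    def wrapper(x):
        out, vjp_fn = vjp(f, x)
        # For a scalar output, the gradient is the vjp applied to a cotangent of 1.
        return vjp_fn(torch.ones_like(out))[0]
    return wrapper

print(grad_via_vjp(torch.sin)(torch.randn(())))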

Print vmap warnings by default

  1. torch._C._debug_only_display_vmap_fallback_warnings -> Use a functorch API instead of this
  2. Turn on warnings by default so that it is clear we are not promising good perf

Multiple Inner Loops

Great project --

What would be the most straightforward way to allow for multiple inner training loops?

Specifically, with regard to the MAML example, how could I allow for a user-defined number of inner loops?

def get_loss_for_task(x1, y1, x2, y2):
    def inner_loss(params, x1, y1):
        f = net(params, (x1,))
        loss = mse_loss(f, y1)
        return loss

    grads = grad(inner_loss)(params, x1, y1)
    new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]

    v_f = net(new_params, (x2,))
    return mse_loss(v_f, y2)

As a note, after changing losses.append(loss2) to losses.append(loss2.item()) I was able to plot the results.
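For reference, one straightforward way to support a user-defined number of inner steps is to loop the inner update. Below is a minimal self-contained sketch in which net, mse_loss, params, and alpha are simple stand-ins for the corresponding objects in the MAML example quoted above:

import torch
from functorch import grad

alpha = 0.1
params = [torch.randn(1, 1), torch.randn(1)]

def net(params, inputs):
    (x,) = inputs
    w, b = params
    return x @ w + b

def mse_loss(pred, target):
    return ((pred - target) ** 2).mean()

def get_loss_for_task(params, x1, y1, x2, y2, num_inner_steps=3):
    def inner_loss(params, x, y):
        return mse_loss(net(params, (x,)), y)

    new_params = params
    for _ in range(num_inner_steps):
        # Each iteration is one inner adaptation step.
        grads = grad(inner_loss)(new_params, x1, y1)
        new_params = [p - alpha * g for p, g in zip(new_params, grads)]

    return mse_loss(net(new_params, (x2,)), y2)

x1, y1, x2, y2 = (torch.randn(4, 1) for _ in range(4))
print(get_loss_for_task(params, x1, y1, x2, y2))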

Running list of batching rules that need to be implemented

For any examples we're running, if we see fallback warnings, we can add them to the list here so that we have a list of batching rules we can chip away at. We can put our name next to one if we're planning on adding it.

Batching Rules needed for Omniglot:

  • aten::mkldnn_convolution
  • aten::native_batch_norm
  • aten::nll_loss_forward
  • aten::nll_loss_backward
  • aten::_log_softmax_backward_data
  • aten::max_pool2d_with_indices_backward
  • aten::threshold_backward
  • aten::native_batch_norm_backward
  • aten::mkldnn_convolution_backward
  • aten::conv2d
  • aten::batch_norm
  • aten::linear
  • aten::nll_loss_nd
  • aten::argmax
  • aten::eq.Tensor @zou3519

Parallel Train:

  • aten::nll_loss_forward
  • aten::nll_loss_backward
  • aten::_log_softmax_backward_data
  • aten::threshold_backward @zou3519

DP Cifar10:

  • aten::mkldnn_convolution
  • aten::native_group_norm
  • aten::relu_ @zou3519
  • aten::thnn_conv2d_forward
  • aten::add_.Tensor @zou3519
  • aten::nll_loss_forward
  • aten::nll_loss_backward
  • aten::_log_softmax_backward_data
  • aten::threshold_backward @zou3519
  • aten::reciprocal_ @zou3519
  • aten::clamp_min: Warning: make sure this compiles using clang too
  • aten::thnn_conv2d_backward.output_mask
  • aten::max_pool2d_with_indices_backward
  • aten::mkldnn_convolution_backward
  • aten::cudnn_convolution
  • aten::cudnn_convolution_backward
  • aten::native_group_norm

From #26:

  • aten::rsub.Scalar

  • aten::diag

  • aten::where.Scalar

  • aten::allclose

  • advanced indexing (index, index_put_)

Top 100 torch.foo:

  • t 6837449
  • tensor 585786
  • mode 462182
  • cat 394818
  • max 368038
  • zeros 329495
  • load 327756
  • no_grad 294694
  • save 265130
  • from_numpy 243063
  • manual_seed 165044
  • ones 153696
  • randn 150796
  • stack 133358
  • sum 130772
  • arange 98087
  • rand 94715
  • mean 88546
  • exp 73883
  • zeros_like 72831
  • min 72248
  • sigmoid 66798
  • log 62135
  • matmul 47811
  • clamp 45304
  • sqrt 44911
  • abs 43535
  • tanh 42793
  • empty 40311
  • argmax 38435
  • bmm 33984
  • pow 33571
  • norm 31125 (deprecated?)
  • mm 30995
  • is_tensor 29546
  • ones_like 29512
  • nonzero 28681 (dynamic)
  • full 28373
  • unsqueeze 27911
  • where 26585
  • randperm 26450 (random)
  • eye 24342
  • mul 23236
  • topk 22537
  • as_tensor 21967
  • sort 21412
  • squeeze 20863
  • randint 20771 (random)
  • linspace 20041
  • add 19201
  • transpose 18663
  • split 18325
  • gather 17904
  • set_grad_enabled 16013
  • sin 15669
  • cos 15562
  • div 15513
  • index_select 14866
  • multinomial 14331 (random)
  • flatten 14267
  • isnan 14170
  • randn_like 13096 (random)
  • eq 12680
  • einsum 12480
  • round 12367
  • floor 11628
  • allclose 11000
  • reshape 10605
  • diag 10167
  • chunk 9581
  • std 9379
  • set_default_tensor_type 9281
  • triu 8559
  • meshgrid 8292
  • set_num_threads 8126
  • unique 7964 (dynamic)
  • full_like 7780
  • tril 7538
  • dot 7275
  • sign 6943
  • equal 6916
  • normal 6750 (random)
  • cumsum 6556
  • dist 6058
  • isfinite 6030
  • gt 5935
  • set_printoptions 5888
  • range 5491
  • empty_like 5351
  • flip 5342
  • masked_select 5341 (sometimes dynamic)
  • bernoulli 5262 (random)
  • atan 5253
  • var 5247
  • prod 5200
  • erf 5088
  • inverse 5072
  • addmm 4854
  • logsumexp 4582

Handle namedtuples and PyTorch's special return types correctly

Right now, the vmap and grad transforms ignore namedtuples and pretend they are tuples. This leads to the names getting stripped. We should change pytrees to support namedtuples and PyTorch's special return types. I don't know whether PyTorch's special return types are actually namedtuples, though.

Question: `grad` and static compute graphs?

Hi all,

This looks like an awesome idea — it would be amazing to combine the functional-transformation abilities of the smaller frameworks with the power of PyTorch!

Looking briefly at the code, it looks like your grad generates calls back into the autograd engine. Do you have any plans, along the lines of pytorch/pytorch#35215, to enable generating static compute graphs for the derivative of a function? Or is your plan to always use the dynamic autograd engine?
Thanks!
