Comments (7)

zou3519 commented on August 16, 2024

PyTorch has really weird special cases around scalar tensors. Not sure how easy it will be for us to replicate the semantics.

from functorch.

Chillee commented on August 16, 2024

> Not sure how easy it will be for us to replicate the semantics.

I suspect we're going to need to do so regardless.

laurencer commented on August 16, 2024

@zou3519 - what's the expected behavior in this case? I'm trying to figure out what this should actually do.

Reasoning it out:

  1. x = torch.randn((3)) gives a tensor with shape (3,) (dim=1).
  2. f : Tensor<{arbitrary shape; dim >= 1}> -> Tensor<{resultant shape; dim >= 1}>
  3. vmap(f) : Tensor<{arbitrary shape + 1 dim}> -> Tensor<{resultant shape + 1 dim}>

The types/shapes don't match up for x: x has only a single dimension when passed to vmap(f), so would it fail? Or should x be automatically promoted to (3, 1) (or (1, 3)) and then passed to vmap(f)?

This seems like an issue for the general dispatcher rather than for an individual batching rule. E.g. when you promote a function using vmap, should the resulting function check for inputs that are scalars per example (i.e. dim=1 including the batch dimension) and automatically promote them to dim=2?

zou3519 commented on August 16, 2024

@laurencer good question. For out-of-place operations and ignoring views, vmap(f)(x) should be equivalent to running torch.stack([f(xi) for xi in x.unbind(0)]). This heuristic tells us the following:

  1. x = torch.randn((3)) gives a tensor with shape (3,) (dim=1)
  2. f(x[0]) gives a tensor with shape [] (dim=0), so torch.stack([f(xi) for xi in x.unbind(0)]) gives a tensor of shape [3] (dim=1)
  3. so vmap(f)(x) should give us a tensor with shape (3,) (dim=1)
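The stack heuristic can be checked at the level of shapes alone. Below is a minimal pure-Python sketch that models tensors by their shape tuples, assuming f is squeeze (as the batching-rule discussion suggests); `squeeze_shape` and `reference_vmap_shape` are hypothetical helper names, not functorch APIs. It reproduces steps 1-3, and also shows why squeezing the physical (batched) tensor directly cannot be the batching rule: with a batch size of 1 it would drop the batch dimension.

```python
from typing import Callable, Tuple

Shape = Tuple[int, ...]


def squeeze_shape(shape: Shape) -> Shape:
    """Output shape of x.squeeze(): every size-1 dimension is dropped."""
    return tuple(s for s in shape if s != 1)


def reference_vmap_shape(f_shape: Callable[[Shape], Shape], x_shape: Shape) -> Shape:
    """Output shape of torch.stack([f(xi) for xi in x.unbind(0)]).

    unbind(0) yields x_shape[0] slices, each of shape x_shape[1:];
    f maps every slice to f_shape(x_shape[1:]); stack then prepends a
    new dimension whose size is the number of slices.
    """
    batch_size, example_shape = x_shape[0], x_shape[1:]
    return (batch_size,) + f_shape(example_shape)


# Steps 1-3: x has shape (3,), each x[i] is 0-dim, the result has shape (3,).
print(reference_vmap_shape(squeeze_shape, (3,)))  # (3,)

# Squeezing the physical tensor directly disagrees once the batch size is 1.
print(reference_vmap_shape(squeeze_shape, (1, 4)))  # (1, 4)
print(squeeze_shape((1, 4)))                        # (4,)
```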

I think the batching rule for squeeze should check whether the underlying tensor has dim 1, i.e. it holds only the batch dimension and each example is a scalar (and that the dim argument is equal to 0). If so, it returns an alias of the tensor (via tensor.alias()).
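A shape-only sketch of that check, written with a per-tensor batch-dim index (`None` meaning unbatched) and extended to non-scalar inputs for comparison. `squeeze_dim_batch_rule` and its signature are hypothetical; a real rule would be C++ operating on tensors and returning `tensor.alias()` rather than a shape. The scalar branch assumes PyTorch's behaviour that `squeeze(dim)` on a 0-dim tensor accepts a dim of 0 or -1 and returns the tensor unchanged.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def _squeeze_dim(shape: Shape, d: int) -> Shape:
    """Output shape of torch.squeeze(x, d); assumes d is a valid index."""
    if not shape:
        return shape  # a 0-dim tensor stays 0-dim
    d %= len(shape)  # wrap a negative index
    return shape[:d] + shape[d + 1:] if shape[d] == 1 else shape


def squeeze_dim_batch_rule(
    physical: Shape, bdim: Optional[int], dim: int
) -> Tuple[Shape, Optional[int]]:
    """Hypothetical shape-level batching rule for squeeze.dim.

    physical: shape of the underlying tensor, batch dim included.
    bdim: index of the batch dim in `physical`, or None if unbatched.
    dim: the per-example (logical) dim the user passed.
    Returns the output's physical shape and its bdim.
    """
    if bdim is None:
        # Not batched at this level: behave like the plain operator.
        return _squeeze_dim(physical, dim), None
    # Move the batch dim to the front; logical dim k is then physical k + 1.
    physical = (physical[bdim],) + physical[:bdim] + physical[bdim + 1:]
    logical_rank = len(physical) - 1
    if logical_rank == 0:
        # Per-example scalar: the physical tensor has dim 1. Only dim 0 or -1
        # is valid, and the result is the input unchanged (an alias).
        if dim not in (0, -1):
            raise IndexError(f"dim {dim} out of range for a 0-dim tensor")
        return physical, 0
    if not -logical_rank <= dim < logical_rank:
        raise IndexError(f"dim {dim} out of range for rank {logical_rank}")
    return _squeeze_dim(physical, dim % logical_rank + 1), 0


print(squeeze_dim_batch_rule((3,), 0, 0))        # ((3,), 0)
print(squeeze_dim_batch_rule((1,), 0, 0))        # ((1,), 0): batch dim kept
print(squeeze_dim_batch_rule((4, 3, 1), 1, -1))  # ((3, 4), 0)
```

Returning `None` as the output bdim for unbatched inputs mirrors the "return nullopt" convention discussed further down the thread.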

laurencer commented on August 16, 2024

@zou3519 - is there a description of the new versus old style of batching rules?

I'm having a bit of trouble understanding what the following does and whether it's still needed in the new-style:

if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.squeeze();
}

From the docs - it seems as though it's now a more functional approach where raw/native Torch tensors are passed in (instead of wrappers) along with optional parameters that describe the batch dimension (instead of being carried on the wrappers).

Is it right to assume that the optional<uint_64> parameter passed after every Tensor is the number of batch dimensions (e.g. if it is set to 2 then the first 2 dimensions of the Tensor are batch dimensions?).

If it is empty then there are no batch dimensions?

Also if I return an empty batch dimension - then it appears to assume that there is 1 batch dimension (or this might be the original input number)?

Also there's implementations for squeeze and squeeze.dim - through trial-and-error I figured out that squeeze.dim is invoked when the dim optional parameter is passed. What's the mechanism/way to understand more about this dispatch method (or is this just the PyTorch dispatcher)?

zou3519 commented on August 16, 2024

> is there a description of the new versus old style of batching rules?

Unfortunately no, not yet. This is the doc for the new style, but the old style was me hacking everything to work.

> I'm having a bit of trouble understanding what the following does and whether it's still needed in the new-style:

That is not needed in the new-style if you use the VMAP_SUPPORT macro.

> Is it right to assume that the optional<uint_64> parameter passed after every Tensor is the number of batch dimensions (e.g. if it is set to 2 then the first 2 dimensions of the Tensor are batch dimensions?).

The optional<int64_t> passed after every Tensor is "the index of the batch dimension" (not the number of batch dimensions!) The new-style batching rules assume that there is only a single batch dimension (but there is some magic somewhere else that allows the batching rule to operate on multiple batch dimensions).
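To illustrate the difference between a batch-dim index and a count, here is a minimal shape-only sketch; the helper names are hypothetical and only model the convention described above (a single batch dim per tensor, located by index, or `None` if absent). Moving the batch dim to the front is a common first step in a batching rule, because per-example dim k then lives at physical dim k + 1.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def logical_shape(physical: Shape, bdim: Optional[int]) -> Shape:
    """Per-example shape seen by the user; bdim is an index, not a count."""
    if bdim is None:
        return physical  # unbatched: nothing to hide
    return physical[:bdim] + physical[bdim + 1:]


def move_bdim_to_front(physical: Shape, bdim: int) -> Shape:
    """Physical shape after permuting the batch dim to position 0."""
    return (physical[bdim],) + physical[:bdim] + physical[bdim + 1:]


# bdim=2 means "dimension 2 is the batch dim", not "the first 2 dims are".
print(logical_shape((4, 3, 5), 2))       # (4, 3)
print(move_bdim_to_front((4, 3, 5), 2))  # (5, 4, 3)
print(logical_shape((4, 3, 5), None))    # (4, 3, 5)
```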

> If it is empty then there are no batch dimensions?

Yes

> Also if I return an empty batch dimension - then it appears to assume that there is 1 batch dimension (or this might be the original input number)?

You should be able to return nullopt

> Also there's implementations for squeeze and squeeze.dim - through trial-and-error I figured out that squeeze.dim is invoked when the dim optional parameter is passed. What's the mechanism/way to understand more about this dispatch method (or is this just the PyTorch dispatcher)?

This is just the PyTorch dispatcher, but the summary is:

  • Look in native_functions.yaml
  • Depending on how you invoke squeeze in Python, one of the {squeeze, squeeze.dim} operators gets called. If one just calls x.squeeze(), it'll invoke the squeeze operator; if one calls x.squeeze(1), it'll invoke the squeeze.dim operator.
  • There's some logic in PyTorch's Python bindings that parses the inputs to e.g. squeeze and selects which of {squeeze, squeeze.dim} actually gets called. If you're interested I can point you to that code (it is autogenerated, so I can't link you to it directly on GitHub).
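As a rough illustration of the bullets above, here is a toy Python function (hypothetical, not part of PyTorch) that maps the way `x.squeeze(...)` is called to the overload it reaches. PyTorch versions may define further squeeze overloads in native_functions.yaml, which this ignores. In recent releases the individual overloads can also be called directly as `torch.ops.aten.squeeze.default` and `torch.ops.aten.squeeze.dim`, which is a quick way to confirm which one a batching rule needs to handle.

```python
from typing import Optional


def select_squeeze_overload(*args: int, dim: Optional[int] = None) -> str:
    """Toy model of which operator a Python call x.squeeze(...) reaches.

    Not PyTorch's real argument parser (which is autogenerated); it only
    mirrors the two cases described above.
    """
    if len(args) > 1 or (args and dim is not None):
        raise TypeError("squeeze() takes at most one dim argument")
    if args or dim is not None:
        return "squeeze.dim"  # x.squeeze(1) or x.squeeze(dim=1)
    return "squeeze"          # x.squeeze()


print(select_squeeze_overload())       # squeeze
print(select_squeeze_overload(1))      # squeeze.dim
print(select_squeeze_overload(dim=1))  # squeeze.dim
```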

zou3519 commented on August 16, 2024

Fixed by #81, thank you @laurencer!