
Comments (12)

Chillee commented on August 16, 2024

It sounds like the ask is that you'd like to generate the autograd graph ahead of time (AOT)?

There's currently some prototype functionality (well, arguably the entire repo is prototype :P ) for tracing through the C++ dispatcher (which includes the autograd engine).

See the examples in the nnc folder: https://github.com/zou3519/functorch/blob/main/examples/nnc/simple_function.py

Specifically, make_fx will do this. However, it will 1. specialize on your inputs, and 2. return functions at the aten level.
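A minimal sketch of what this looks like, assuming the prototype `make_fx` and `grad` APIs from this repo are importable (names and import paths are those of the functorch prototype, so treat this as illustrative rather than a stable API):

```python
import torch
from functorch import grad, make_fx  # prototype APIs from this repo

def energy(x):
    return (x ** 2).sum()

# grad(energy) is a pure function computing d(energy)/dx. make_fx traces it
# through the C++ dispatcher (autograd engine included) and returns an
# fx.GraphModule of aten-level ops, specialized to the example input's
# shape, dtype, and device.
force = make_fx(grad(energy))(torch.randn(4))

print(force(torch.ones(4)))  # d/dx sum(x^2) = 2x
```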


from functorch.

Linux-cpp-lisp commented on August 16, 2024

... a very exciting prototype πŸ˜„

Our ideal outcome would be to take a TorchScript model that involves a torch.autograd.grad call and turn it into a TorchScript model that doesn't. The main reason for this is inference β€” being tied to the autograd engine severely limits our options for accelerated inference (TRTorch, for example, but also TorchScript optimizations). It would also be great to be able to train without walking the autograd graph twice, but that's secondary.

From what you're saying, it sounds like tracing through the C++ dispatcher is exactly right for this. To clarify a couple of points:

  1. What, exactly, is NNC? Having a hard time figuring out how it relates to TorchScript, etc.
  2. Would torch.jit.script(make_fx(model)) give me a ScriptModule I can use like any other? (model here would already contain calls to torch.autograd.grad.) In particular, a ScriptModule I can torch::jit::load into a C++ deployment context?
  3. When you say it will "specialize on your inputs," you mean specialize to dtype + device? Or something else?
  4. Do "functions at the aten level" have any particular limitations?

Really appreciate your taking the time!


Chillee commented on August 16, 2024

Our ideal outcome would be to take a TorchScript model that involves a torch.autograd.grad call and turn it into a TorchScript model that doesn't.

I'm not sure that we'll be able to start from a TorchScript model and trace out the autograd graph. I think it's possible in theory, but it would require a bit more infrastructure that I'm not currently working on.

What, exactly, is NNC? Having a hard time figuring out how it relates to TorchScript, etc.

NNC (neural network compiler, also called tensorexpr in our code base) is a codegen compiler (kind of in the vein of TVM or Halide). Currently, in this repo, we're using it primarily for overhead reduction. The idea is that if your operations/tensors are small enough, PyTorch framework overhead is often the dominating factor. Generating a single binary blob that does your computation can lead to significant speedups.

Would torch.jit.script(make_fx(model)) give me a ScriptModule I can use like any other?

Currently, somewhat awkwardly, it doesn't work. This is because I'm tracing out into the torch.ops.aten namespace, and there are some awkward mismatches there. However, torch.jit.trace(make_fx(model)) does work, and will give you a ScriptModule you can use like any other.
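A hedged sketch of that route, again assuming the prototype `make_fx`/`grad` imports (the function and save path below are made up for illustration):

```python
import torch
from functorch import grad, make_fx  # prototype APIs

def f(x):
    return (x.sin() ** 2).sum()

# Trace the gradient computation out of the autograd engine, then wrap the
# resulting aten-level fx.GraphModule with torch.jit.trace to obtain a
# ScriptModule.
example = torch.randn(5)
fx_grad = make_fx(grad(f))(example)
scripted = torch.jit.trace(fx_grad, example)

# scripted can now be saved/loaded like any other ScriptModule, e.g.
# torch.jit.save(scripted, "grad_f.pt") for a C++ deployment context.
```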

When you say it will "specialize on your inputs," you mean specialize to dtype + device?

And shape. Unfortunately, within the autograd engine, there are many instances where the autograd rules depend on the shapes of the tensors.

Do "functions at the aten level" have any particular limitations?

Not really - it's basically just the C++ API instead of the Python API.


Linux-cpp-lisp commented on August 16, 2024

I'm not sure that we'll be able to start from a TorchScript model and trace out the autograd graph. I think it's possible in theory, but it would require a bit more infrastructure that I'm not currently working on.

Would starting with a Python model work?

NNC (neural network compiler)...

Very cool! I'm assuming, then, that normal TorchScript optimization already uses NNC everywhere it can, and that using nnc_jit wouldn't give you any particular speedup?

And shape. Unfortunately, within the autograd engine, there's a lot of instances where the autograd rules will depend on the shapes of the tensor.

Hm, I see. For example, if I have:

def f(x):
    outshape = x.shape[:-1]
    x = x.reshape(-1, x.shape[-1])
    x = 2 * x
    return x.reshape(outshape + (x.shape[-1],))

I will not be able to generalize to different leading shapes on x ((10, 3) vs (13, 3), for example) even if they have the same number of dimensions?


Chillee commented on August 16, 2024

Would starting with a Python model work?

Yes.

Very cool! I'm assuming, then, that normal TorchScript optimization already uses NNC everywhere it can, and that using nnc_jit wouldn't give you any particular speed up?

nnc_jit is currently targeted at overhead-dominated CPU use cases, so I suspect it might not be very useful for you. But in those use cases it does things like lower the entire model to a single binary blob, so it can be substantially faster than TorchScript there.

I will not be able to generalize to different leading shapes on x ((10, 3) vs (13, 3), for example) even if they have the same number of dimensions?

Currently, no. This is something that we're aware of, and we're trying to figure out ways of addressing this.


Linux-cpp-lisp commented on August 16, 2024

Aha β€” can nnc_jit work with models of "real" complexity, inlining more complicated operations? (Say, tensordot, permute, cat, whatever.) Or is it limited to elementwise operations that are already supported for fusion?

Currently, no. This is something that we're aware of, and we're trying to figure out ways of addressing this.

πŸ‘ How strong are the shape dependencies in autograd, in a rough sense? In the example I gave above, for example, if I manually rewrote the fx graph to use the right dynamic shape as the argument to reshape, instead of the traced constant, would you expect it to give the right result for a modified batch dimension? (Same number of dimensions.)


Chillee commented on August 16, 2024

can nnc_jit work with models of "real" complexity, inlining more complicated operations? (Say, tensordot, permute, cat, whatever.)

Yes - we currently have lowerings for permute and cat. Generating fast lowerings/schedules for things like tensordot is significantly harder, so we've usually just been calling the PyTorch C++ implementations for those. In cases where you are primarily overhead-bound, this can still be a significant win. Once again though, this stuff currently only works on CPU.

How strong are the shape dependencies in autograd, in a rough sense? In the example I gave above, for example, if I manually rewrote the fx graph to use the right dynamic shape as the argument to reshape, instead of the traced constant, would you expect it to give the right result for a modified batch dimension?

It's hard to say - there are two sources of shape specialization in autograd. The first is user-facing, and you can generally get around it by re-implementing your code with ops like torch.flatten instead of explicit accesses to shapes. The harder source is shape specialization inside the C++ internals.
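A hedged sketch of that user-facing rewrite, applied to the earlier reshape example (this only removes the Python-level shape accesses; any C++-internal specialization would remain):

```python
import torch

# torch.flatten and Tensor.view_as are recorded as symbolic ops in a trace,
# rather than concrete integers pulled out of x.shape, so the user-facing
# source of shape specialization is avoided.
def f(x):
    flat = torch.flatten(x, start_dim=0, end_dim=-2)  # merge leading dims
    return (2 * flat).view_as(x)                      # restore x's shape

print(f(torch.randn(10, 3)).shape)  # also handles (13, 3), (2, 5, 3), ...
```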

It's possible that if we're just changing the batch dimension we can avoid it, but that requires some investigation that I haven't done.


Linux-cpp-lisp commented on August 16, 2024

...tensordot is significantly easier,...

Do you mean harder?

Yes - we currently have lowerings for permute and cat. Generating fast lowerings/schedules for things like tensordot is significantly easier, so we've usually just been calling PyTorch C++ implementations for those. In cases where you are primarily overhead bound this can still be a significant win. Once again though, this stuff currently only works on CPU.

Awesome!

However, the harder stuff is shape specialization within C++.

Would you guess that this is usually specialization for speed, or does it affect correctness? (Context: for us, we have "batch" dimensions that constantly change during inference, but don't change much.)


Chillee commented on August 16, 2024

Do you mean harder?

whoops

Would you guess that this is usually specialization for speed, or does it affect correctness?

Hmm, it's often things like pulling out the shapes and then doing an explicit reshape using those shapes, or something along those lines.


Linux-cpp-lisp commented on August 16, 2024

Aha ok, so stuff that really is risky for correctness. Still, would be interesting to see if those problems come up for our networks. (I'd mostly worry about matmuls, tensordots, and einsums β€” does that sound right?)


Chillee commented on August 16, 2024

I'm not actually totally sure when it comes up - perhaps it would just work in some cases? I'll make a note to investigate that further at some point.


Linux-cpp-lisp commented on August 16, 2024

πŸ‘ if you ever do end up looking into any of these things, would be very curious to hear what you find β€” I will also probably play around with functorch for this once I have the time.

Thanks very much for the answers + all the great work on compilers for PyTorch!

