
Comments (6)

borisfom commented on June 11, 2024

Here is a repro case with fully dynamic Dims that used to work with last week's PyTorch nightly (05/15) but fails with 05/20. dynamo_export works, but export.export() fails:

import torch
from nemo.core.classes import typecheck
from nemo.utils.export_utils import wrap_forward_method, parse_input_example
from nemo.collections.nlp.models import PunctuationCapitalizationModel
model = PunctuationCapitalizationModel.from_pretrained(model_name="punctuation_en_distilbert")
model.cuda().eval()
wrap_forward_method(model)
model._prepare_for_export()
typecheck.set_typecheck_enabled(enabled=False)

with torch.no_grad():
    input_example = model.input_module.input_example(max_batch=4)
    input_list, input_dict = parse_input_example(input_example)

    print("Running torch.onnx.dynamo_export ...")
    options = torch.onnx.ExportOptions(dynamic_shapes=True)
    ex = torch.onnx.dynamo_export(model, *input_list, **input_dict, export_options=options)

    print("Running torch.export.export ...")
    x1 = torch.export.Dim("x1")
    x2 = torch.export.Dim("x2")
    x3 = torch.export.Dim("x3")
    b1 = torch.export.Dim("b1")
    b2 = torch.export.Dim("b2")
    b3 = torch.export.Dim("b3")
    dynamic_shapes={'input_ids': {0: b1, 1: x1}, 'attention_mask': {0: b2, 1: x2}, 'token_type_ids': {0: b3, 1: x3}}
    ex_model = torch.export.export(
        model,
        tuple(input_list),
        kwargs=input_dict,
        dynamic_shapes=dynamic_shapes,
        strict=False
    )

borisfom commented on June 11, 2024

This is what I get when I run the repro. Specializing some dimensions would result in more efficient code, so it would be nice either to make the guard calculation succeed in this case, or to be able to ignore the guards and treat this error as a warning.

V0518 04:08:06.019000 139943363598144 torch/fx/experimental/symbolic_shapes.py:2289] create_env
I0518 04:08:06.050000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3260] create_symbol s0 = 1024 for L['args'][0][0].size()[0] [2, 9223372036854775806] (_export/non_strict_utils.py:92 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
V0518 04:08:06.051000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4746] eval True == True [statically known]
V0518 04:08:06.052000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4746] eval False == False [statically known]
I0518 04:08:06.173000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4661] eval Ne(s0, 20) [guard added] (_refs/__init__.py:3685 in _reshape_view_helper), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne(s0, 20)"
I0518 04:08:06.177000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4661] eval Ne(Mod(s0, 20), 0) [guard added] (_refs/__init__.py:3694 in _reshape_view_helper), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne(Mod(s0, 20), 0)"
V0518 04:08:06.178000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4746] eval Eq(s0, 1) == False [statically known]
V0518 04:08:06.179000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4746] eval True == True [statically known]
V0518 04:08:06.183000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4746] eval Ne(20*s0, 20) == True [statically known]
V0518 04:08:06.184000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4746] eval False == False [statically known]
V0518 04:08:06.206000 139943363598144 torch/fx/experimental/symbolic_shapes.py:4746] eval Ne(s0, 1) == True [statically known]
I0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3348] produce_guards
V0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3530] track_symint L['args'][0][0].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=ValueRanges(lower=0, upper=oo, is_bool=False))
V0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3530] track_symint L['args'][0][0].size()[1] 20 None
V0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3530] track_symint L['args'][0][0].size()[2] 16 None
V0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3530] track_symint L['args'][0][0].stride()[0] 320 None
V0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3530] track_symint L['args'][0][0].stride()[1] 16 None
V0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3530] track_symint L['args'][0][0].stride()[2] 1 None
V0518 04:08:06.234000 139943363598144 torch/fx/experimental/symbolic_shapes.py:3530] track_symint L['args'][0][0].storage_offset() 0 None
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1088, in _export
    produce_guards_and_solve_constraints(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 270, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 238, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 3853, in produce_guards
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of batch = L['args'][0][0].size()[0] in the specified range satisfy the generated guard Ne(L['args'][0][0].size()[0], 20).
  - Not all values of batch = L['args'][0][0].size()[0] in the specified range satisfy the generated guard Ne(Mod(L['args'][0][0].size()[0], 20), 0).
Suggested fixes:
  batch = Dim('batch')
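One workaround in the spirit of the suggested fix, until guard handling improves, would be to give the offending dimension an explicit range instead of leaving it fully unbounded. A minimal sketch, assuming the Ne(s0, 20) and Ne(Mod(s0, 20), 0) guards really do come from dim 0 of the first input (the bounds below are purely illustrative):

import torch

# Any batch in [2, 19] is neither 20 nor a multiple of 20, so both generated
# guards hold over the whole declared range.
batch = torch.export.Dim("batch", min=2, max=19)
dynamic_shapes = {'input_ids': {0: batch}}  # other axes left static for this experiment

This trades full dynamism for a range the guards can live with, which may or may not be acceptable for deployment.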
 

borisfom commented on June 11, 2024

BTW, the real-world model that I wrote this repro for fails even when all dimensions are specified, or when dynamo_export is called directly:

torch._dynamo.exc.UserError: Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time.  You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs.  Could not guard on data-dependent expression Eq(4*u0**2, 0) (unhinted: Eq(s0*u0**2, 0)).  (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
  File "/usr/local/lib/python3.10/dist-packages/torch/_decomp/decompositions.py", line 1126, in _softmax
    if x.numel() == 0:
...
  File "/git/NeMo/nemo/collections/tts/modules/transformer.py", line 115, in forward
    return self._forward(inp, attn_mask, conditioning)
  File "/git/NeMo/nemo/collections/tts/modules/transformer.py", line 148, in _forward
    attn_prob = F.softmax(attn_score, dim=2)

attn_prob undergoes transformations similar to the view() used in my example. I am not sure what I should mark/check as a size.
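For reference, the torch._check / torch._check_is_size APIs mentioned in the error are meant to be used roughly like this; the names below (attn_mask, seq_len) are hypothetical stand-ins for whatever data-dependent value becomes u0 in the NeMo code:

import torch
import torch.nn.functional as F

def masked_softmax_sketch(attn_score, attn_mask):
    # .item() on a traced tensor produces an unbacked SymInt (the u0 in the error)
    seq_len = attn_mask.sum().item()
    torch._check_is_size(seq_len)   # declare it size-like, i.e. >= 0
    torch._check(seq_len > 0)       # rule out the empty branch behind Eq(4*u0**2, 0)
    return F.softmax(attn_score[..., :seq_len], dim=2)

Whether such a check belongs in the model code, or whether the framework should use guard_size_oblivious as the message suggests, is exactly the open question here.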

borisfom commented on June 11, 2024

With the latest PyTorch nightly (05/20), I am also getting more failures on NeMo unit tests when running export() with all dimensions dynamic and no min/max, compared to running dynamo_export() directly.
That is: when I pass dynamic Dims to export() in a way that should be equivalent to dynamic_shapes=True, I still cannot export the same networks that are exportable by calling dynamo_export(model, ..., dynamic_shapes=True) directly.
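For concreteness, a spec with every axis dynamic and no min/max can be built along these lines (a sketch, reusing the call from the repro above and assuming the example inputs are passed as tensors in input_dict):

dynamic_shapes = {
    name: {i: torch.export.Dim(f"{name}_dim{i}") for i in range(t.dim())}
    for name, t in input_dict.items()
}
ep = torch.export.export(model, tuple(input_list), kwargs=input_dict,
                         dynamic_shapes=dynamic_shapes, strict=False)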

justinchuby commented on June 11, 2024

So torch.export fails and torch.onnx.dynamo_export succeeds?

borisfom commented on June 11, 2024

Correct: torch.onnx.dynamo_export with dynamic_shapes=True succeeds, but when I have to use export.export() first, it fails there even if I specify all axes as dynamic with no bounds (which should be equivalent to dynamic_shapes=True), because it cannot calculate the bounds properly, even with strict=False.
I have to use export.export() first for large models to force the external ONNX format, as direct dynamo_export() produces a huge, non-parseable ONNX file in that case; I filed separate bugs for those issues.
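For context, the two-step flow looks roughly like this (a sketch only; whether a given nightly accepts an ExportedProgram as the model argument of dynamo_export may differ):

# 1) torch.export first, so the ONNX exporter can start from an ExportedProgram
ep = torch.export.export(model, tuple(input_list), kwargs=input_dict,
                         dynamic_shapes=dynamic_shapes, strict=False)
# 2) Convert and save; the intent is that large initializers end up in external
#    data files instead of bloating the ONNX protobuf past its 2 GB limit
onnx_program = torch.onnx.dynamo_export(ep, *input_list, **input_dict)
onnx_program.save("model.onnx")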
