I feel uneasy about touching `__warningregistry__` since it is undocumented and private surface. The only public API hook that doesn't increment warnings version seems to be https://docs.python.org/3/library/warnings.html#warnings.showwarning.
So we could wack a mole all the warnings muters in compile to just not display warnings, and we wouldn't invalidate warnings cache. This PR adds it for torch/_dynamo, and I didn't find any warnings versioning mutation from torch/_inductor.
There is a behavior change if someone calls a compiled graph with simplefilter("error"):
```python
# e.g. test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing
with warnings.catch_warnings():
warnings.simplefilter("error") # turns all warnings into errors
compiled_fn() # will throw if any of the muted warnings fire
```
FIXES https://github.com/pytorch/pytorch/issues/128427
A note for the future: The warnings module doesn't offer a thread safe way of using it. Even regular filters have this problem, directly editing `__warningregistry__` would be very bad, and this PR would mute all threads. Someone will need to build a thread safe warnings interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158520
Approved by: https://github.com/anijain2305, https://github.com/zou3519
This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error.
Implementation details:
- The resume function prologue is surrounded by `LOAD_CONST arg, STORE_FAST __is_tracing_resume_prologue` instructions. The first sequence has `arg=True` and the second sequence has `arg=False`.
- InstructionTranslator will know when it is tracing a resume function prologue when it detects `STORE_FAST __is_tracing_resume_prologue`. The top of stack will be True to mark the start of the prologue, False to mark the end.
- When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we will wrap the exception and hard error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782, #156762, #155166
This PR includes the GBID weblink whenever a user encounters a graph break. I also had to include the JSON file in setup.py, so it can be part of the files that are packaged in during CI. It also fixes the issue of the hardcoded error messages stripping away one of the '/' in 'https'.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156033
Approved by: https://github.com/williamwen42
This PR includes the GBID weblink whenever a user encounters a graph break. I also had to include the JSON file in setup.py, so it can be part of the files that are packaged in during CI. It also fixes the issue of the hardcoded error messages stripping away one of the '/' in 'https'.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156033
Approved by: https://github.com/williamwen42
This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error.
Implementation details:
- The resume function prologue is surrounded by `LOAD_CONST arg, STORE_FAST __is_tracing_resume_prologue` instructions. The first sequence has `arg=True` and the second sequence has `arg=False`.
- InstructionTranslator will know when it is tracing a resume function prologue when it detects `STORE_FAST __is_tracing_resume_prologue`. The top of stack will be True to mark the start of the prologue, False to mark the end.
- When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we will wrap the exception and hard error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782, #155166
This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error.
Implementation details:
- The resume function prologue is surrounded by `LOAD_CONST arg, STORE_FAST __is_tracing_resume_prologue` instructions. The first sequence has `arg=True` and the second sequence has `arg=False`.
- InstructionTranslator will know when it is tracing a resume function prologue when it detects `STORE_FAST __is_tracing_resume_prologue`. The top of stack will be True to mark the start of the prologue, False to mark the end.
- When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we will wrap the exception and hard error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782, #155166
Summary:
Today when guard serialization fails, dynamo will raise an internal error like:
```
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: CLOSURE_MATCH guard cannot be serialized.
```
Adding a dedicated PackageError type to surface the error more clearly.
Test Plan: CI
Differential Revision: D75452124
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154430
Approved by: https://github.com/jamesjwu, https://github.com/jansel
Also show the line of code relevant to a dynamo-compiled frame, instead of just the first line (this was broken for data-dependent jump graph breaks and for 3.11+).
Also collapses resume frames together (use config.verbose to see full stack trace - for developers).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148401
Approved by: https://github.com/zou3519, https://github.com/jansel
Fixes#143406
After this PR the error for missing Triton is:
```py
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 51, in <module>
fp32_compiled = optimized_model(low_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3624, in create_backend
raise TritonMissing(inspect.currentframe())
torch._dynamo.exc.TritonMissing: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
Setting `TORCHDYNAMO_VERBOSE=1` yields something like the old error:
```py
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 51, in <module>
fp32_compiled = optimized_model(low_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 576, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1383, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1167, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 548, in __call__
return _compile(
^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 988, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 716, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 751, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 232, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 663, in transform
tracer.run()
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 2870, in run
super().run()
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 1053, in run
while self.step():
^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 963, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3050, in RETURN_VALUE
self._return(inst)
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3035, in _return
self.output.compile_subgraph(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1102, in compile_subgraph
self.compile_and_call_fx_graph(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1383, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1433, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1463, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/__init__.py", line 2314, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1880, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1145, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 754, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 676, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 489, in __call__
return self.compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1758, in fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 572, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 686, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1044, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1975, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1981, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1916, in codegen
self.scheduler.codegen()
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3667, in codegen
return self._codegen()
^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3761, in _codegen
if device is not None and self.get_backend(device).ready_to_flush():
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3631, in get_backend
self.backends[device] = self.create_backend(device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3624, in create_backend
raise TritonMissing(inspect.currentframe())
torch._dynamo.exc.TritonMissing: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
This PR also strips dynamo stack frames from other types of backend compile errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143552
Approved by: https://github.com/yanboliang
Fixes#130559
* Intro
This PR adds support for `@contextmanager` in Dynamo. We chose to limit the
scope of this work to only `@contextmanager` and plan to handle generators fully
in #141055 (still in draft).
* Motivation
Dynamo lacks support for generator functions. When it encounters one, it traces
it as if it were a regular function. This is problematic because it can lead to
incorrect behavior. To illustrate, consider the test case below:
```python
import torch
import contextlib
@contextlib.contextmanager
def set_default_dtype(dtype):
old_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
yield
finally:
torch.set_default_dtype(old_dtype)
@torch.compile(backend="eager", fullgraph=True)
def fn():
with set_default_dtype(torch.float64):
x = torch.tensor([3.0, 3.0 + 5.0j])
return x
```
Before this work, Dynamo would not stop at the `yield`, and the graph produced
would contain both calls to `set_default_dtype` executed one after the other.
This is incorrect because the context manager should execute code before and
after the `yield`.
* List of changes
`YIELD_VALUE` now raises an exception (`YieldValueOp`) to signal that control
flow must be suspended and returned to the caller. Additionally, `RETURN_VALUE`
behaves differently in a generator function. Unlike regular functions, where
`RETURN_VALUE` indicates the final result, in generators it signifies that the
generator is exhausted and implicitly raises `StopIteration`.
A new `VariableTracker` named `FunctionDecoratedByContextlibContextManagerVariable`
was introduced to handle `@contextmanager`. This variable tracker acts not just
as a wrapper for the original function but also maintains an internal `tx`
(InstructionTranslator) object to suspend and return control flow to the parent
tracer when a `yield` is encountered.
* Corner cases
Returning a context manager from a compiled function is not supported. This
would require PyTorch to synchronize the generator state between Dynamo and the
interpreter. Any attempt to return it will result in an `IncorrectUsage`
exception.
Graph breaks require special handling as well. In the event of a graph break,
the frame associated with the context manager is skipped, and the context
manager runs in eager mode.
* This PR is breaking my code
There is a configuration flag (`enable_trace_contextlib`) that can be set to
`False` to disable tracing context managers. If this still causes crashes,
please revert this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136033
Approved by: https://github.com/zou3519
Fixes a bunch of benchmarks that failed with cudagraph errors including `tlp python benchmarks/dynamo/timm_models.py --device cuda --inductor --accuracy --amp --training --only resmlp_12_224` when `specialize_float=False`
Also brings down number of overall failures (with keep-going) from 108 => 62. I'd estimate >80% of those 62 are wobbly expect tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140346
Approved by: https://github.com/ezyang
ghstack dependencies: #140983, #141003