When we call `torch.compile(mod)`, we eventually return a new
module instance whose `forward` method is the result of
`torch.compile(mod.__call__)`. Because it was compiled from `mod.__call__`,
that `forward` already captures all the extra logic (e.g., hook firing) of
the default `torch.nn.Module.__call__`. As a result, we can't reuse the
inherited default `__call__` as-is, because we'd end up running that logic twice.
This patch makes the returned `OptimizedModule` override the default
`__call__` so that it calls directly into its compiled `forward` method.
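A minimal repro sketch of the behavior this guarantees (the module, hook, and assertion are illustrative, not taken from the PR's test suite):
```python
import torch

class Mod(torch.nn.Module):
    def forward(self, x):
        return x + 1

mod = Mod()
calls = []
mod.register_forward_hook(lambda module, args, output: calls.append("hook"))

opt = torch.compile(mod, backend="eager")
opt(torch.ones(3))

# The compiled forward was built from mod.__call__, so it already fires hooks;
# with OptimizedModule overriding __call__, the hook fires exactly once.
assert len(calls) == 1
```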
Fixes #149502
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305
When `torch.compile` is applied to a module via `mod.compile(...)`, it is equivalent to `torch.compile(mod._call_impl)`, which takes a different path than `OptimizedModule`. This PR ensures that the `wrap_top_frame` config also takes effect for the `torch.compile(mod._call_impl)` use case.
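A hedged sketch of the affected path (the module and backend are illustrative; `wrap_top_frame` is the dynamo config flag named above):
```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()

m = M()
torch._dynamo.config.wrap_top_frame = True  # the config this PR makes effective here

# In-place compilation: routes through torch.compile(m._call_impl) rather
# than OptimizedModule -- the code path this PR fixes.
m.compile(backend="eager")
m(torch.randn(3))
```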
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150209
Approved by: https://github.com/anijain2305
This preserves graph breaks in the case where one graph break directly causes another, e.g. graph breaks in generic context managers:
```python
import torch

class CtxMgr:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass

@torch.compile(backend="eager", fullgraph=True)
def fn():
    with CtxMgr():
        with CtxMgr():
            pass

        with CtxMgr():
            with CtxMgr():
                pass

            torch._dynamo.graph_break()

fn()
```
Output:
```
torch._dynamo.exc.Unsupported: Call to `torch._dynamo.graph_break()`
  Explanation: User-inserted graph break. Message: None
  Hint: Remove the `torch._dynamo.graph_break()` call.

  Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/williamwen/pytorch/playground.py", line 23, in <module>
    fn()
  File "/data/users/williamwen/pytorch/torch/_dynamo/eval_frame.py", line 664, in _fn
    raise e.with_traceback(None) from e.__cause__
torch._dynamo.exc.Unsupported: Graph break under GenericContextWrappingVariable
  Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
  Hint: Move the offending context manager(s) to outside the compiled region.
  Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.

  Developer debug context: Active generic context managers: [GenericContextWrappingVariable(CtxMgr), GenericContextWrappingVariable(CtxMgr)]

from user code:
  File "/data/users/williamwen/pytorch/playground.py", line 20, in fn
    torch._dynamo.graph_break()

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Note in particular that both graph breaks (the `torch._dynamo.graph_break()` call and the graph break in the context manager) are present in the logs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149676
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305
This allows each device type to check the current device for Triton compatibility and ensure that its Triton backend is present.
This PR replaces the global `has_triton()` method previously used for this task, and moves the initial check for each Inductor backend onto its associated `BaseScheduler` subclass. This means that other backends, such as Halide, can also implement their own availability checks.
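For illustration, a sketch of the shape of such a per-backend check (class and method names here are hypothetical, not the exact API added by this PR):
```python
# Hypothetical sketch: each scheduling backend owns its availability check,
# instead of everything funneling through a global has_triton().
class BaseScheduling:
    @classmethod
    def raise_if_unavailable(cls, device: str) -> None:
        """By default, assume the backend can run on `device`."""

class TritonScheduling(BaseScheduling):
    @classmethod
    def raise_if_unavailable(cls, device: str) -> None:
        try:
            import triton  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                f"Triton backend requested for {device!r}, but no working "
                "triton installation was found"
            ) from e
```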
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139171
Approved by: https://github.com/jansel
Sometimes the `eager_then_compile` stance isn't enough: some models are so close to the memory limit that running in eager will OOM, because we don't get the memory reductions from activation checkpointing. This PR introduces `aot_eager_then_compile`, which still avoids the expensive inductor compile but runs `aot_eager` to get the memory-reduction benefits on the first invocation.
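Usage looks like the existing stance API (a sketch; the function and shapes are illustrative):
```python
import torch

@torch.compile
def train_step(x):
    return (x @ x).relu().sum()

with torch.compiler.set_stance("aot_eager_then_compile"):
    # First call: runs via aot_eager -- no expensive inductor compile, but we
    # still get AOTAutograd's memory reductions on the very first invocation.
    train_step(torch.randn(64, 64, requires_grad=True))
    # Later calls compile with the full backend.
    train_step(torch.randn(64, 64, requires_grad=True))
```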
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148509
Approved by: https://github.com/williamwen42
Recently I've been experimenting with introducing new APIs to delay compilation, as a way to reduce compile times while improving the ergonomics of dynamic shapes. The high-level idea is to run the first invocation of a compiled function in eager, save the example inputs, and on the second invocation derive the dynamism from the inputs, so that we don't waste time compiling with static shapes (which is the status quo today with automatic dynamic).
Another benefit is that most users no longer need to annotate their inputs with `mark_dynamic` and `mark_unbacked` calls, since we can derive the dynamism on the very first call. Additionally, we get dynamic ints out of the box in this new regime.
This PR implements this idea through the `set_stance` APIs. In particular, it introduces a new `eager_then_compile` stance.
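A sketch of what that looks like in practice (the function and shapes are illustrative):
```python
import torch

@torch.compile
def f(x):
    return x * 2

with torch.compiler.set_stance("eager_then_compile"):
    f(torch.randn(4))  # first call: runs in eager, example inputs are saved
    f(torch.randn(8))  # second call: compiles, deriving the dynamic dimension
                       # from the observed inputs (no mark_dynamic needed)
```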
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147983
Approved by: https://github.com/williamwen42
Fixes #143406
After this PR, the error for missing Triton is:
```py
Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 51, in <module>
    fp32_compiled = optimized_model(low_input)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
    raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3624, in create_backend
    raise TritonMissing(inspect.currentframe())
torch._dynamo.exc.TritonMissing: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
Setting `TORCHDYNAMO_VERBOSE=1` yields something like the old error:
```py
Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 51, in <module>
    fp32_compiled = optimized_model(low_input)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
    raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 576, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1383, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1167, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 548, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 988, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 716, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 751, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 232, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 663, in transform
    tracer.run()
  File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 2870, in run
    super().run()
  File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 1053, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 963, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3050, in RETURN_VALUE
    self._return(inst)
  File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3035, in _return
    self.output.compile_subgraph(
  File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1102, in compile_subgraph
    self.compile_and_call_fx_graph(
  File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1383, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1433, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1463, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/__init__.py", line 2314, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1880, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1145, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 754, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 676, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 489, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1758, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 572, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 686, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1044, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1975, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1981, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
                                                             ^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1916, in codegen
    self.scheduler.codegen()
  File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3667, in codegen
    return self._codegen()
           ^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3761, in _codegen
    if device is not None and self.get_backend(device).ready_to_flush():
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3631, in get_backend
    self.backends[device] = self.create_backend(device)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 3624, in create_backend
    raise TritonMissing(inspect.currentframe())
torch._dynamo.exc.TritonMissing: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
This PR also strips dynamo stack frames from other types of backend compile errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143552
Approved by: https://github.com/yanboliang