Finally, this PR merges the allow_in_graph/inline/skip trace rules into ```trace_rules.lookup_inner```, where we can define and look up trace rules at both the function level and the file level. Going forward, this is the central place where we define and consult the Dynamo trace rule for any function.
* ```trace_rules.lookup``` is the API that can return allow_in_graph, inline or skip.
* ```skipfiles.check``` is the API that can return inline or skip, since we have multiple places that only do an inline/skip check.
* I'll move ```skipfiles.check``` to ```trace_rules.check``` as one of the follow-ups.
* Both functions consult ```trace_rules.lookup_inner``` to get the tracing rule.
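As a rough illustration of a two-level lookup, here is a minimal sketch. It is not the code in ```trace_rules.py```: the name ```toy_lookup_inner```, both rule tables, keying the file-level table by top-level module name, and checking function-level rules first are all assumptions made for this example.

```python
import json
import math

# Hypothetical rule tables for illustration only; the real tables differ.
TOY_FUNCTION_RULES = {"math.sqrt": "allow_in_graph"}
# Assumption: file-level rules are keyed by top-level module name here.
TOY_FILE_RULES = {"json": "skip"}


def toy_lookup_inner(fn) -> str:
    """Return 'allow_in_graph', 'inline' or 'skip' for fn (toy version)."""
    qualified = f"{fn.__module__}.{fn.__qualname__}"
    # Assumption: a function-level rule takes priority over a file-level rule.
    if qualified in TOY_FUNCTION_RULES:
        return TOY_FUNCTION_RULES[qualified]
    top_level_module = fn.__module__.split(".")[0]
    # Anything without a matching rule is inlined in this sketch.
    return TOY_FILE_RULES.get(top_level_module, "inline")


def user_function():
    return 1


function_level = toy_lookup_inner(math.sqrt)
file_level = toy_lookup_inner(json.dumps)
default_rule = toy_lookup_inner(user_function)
```

The point of the single inner function is that every caller, whether it needs all three outcomes or only inline/skip, reads from the same tables.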
To avoid a single big PR, I left a few items as the follow-ups:
* Remove ```skipfiles.py``` and merge the code into ```trace_rules.py```.
* ```symbolic_convert.check_inlineable``` currently performs a duplicate check; I will refactor and simplify it. We should only do the inline/skip check before generating ```SkipFilesVariable``` and ```UserFunctionVariable```.
* Rename ```SkipFilesVariable``` as ```SkipFunctionVariable```, since we only handle functions.
* The inline/skip reasons are not logged in some cases, since the new lookup framework doesn't always return them. I'll refactor the logging to record the inline/skip reason in the next step.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118971
Approved by: https://github.com/jansel
Before this PR, we had a graph break for code like this:
```python
import torch

def test_get_device_properties_tensor_device(a):
    x = a.to("cuda")
    prop = torch.cuda.get_device_properties(x.device)
    # Branch on a device property that is fixed for a given GPU
    if prop.major == 8:
        return x + prop.multi_processor_count
    return x + prop.max_threads_per_multi_processor
```
This PR constant folds torch.cuda.get_device_properties, and we get the following Dynamo graph:
```python
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] def forward(self, L_a_ : torch.Tensor):
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] l_a_ = L_a_
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:544 in test_get_device_properties_tensor_device, code: x = a.to("cuda")
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] x = l_a_.to('cuda'); l_a_ = None
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:547 in test_get_device_properties_tensor_device, code: return x + prop.multi_processor_count
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] add = x + 108; x = None
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] return (add,)
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
```
The signature of get_device_properties is:
```python
def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
```
I think it's safe to constant fold get_device_properties() for both of its call patterns:
1. torch.cuda.get_device_properties(tensor.device). In this case, tensor.device.index is guarded in _check_tensor.
2. torch.cuda.get_device_properties(device_int_id). We don't expect the GPU properties for a particular index to change during a torch.compile run, so it makes sense to specialize the properties for a concrete device_int_id.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118422
Approved by: https://github.com/yanboliang, https://github.com/jansel
Before this PR, the following code raised an error:
```python
import torch

def k(x):
    with torch.inference_mode():
        x = x + 1
        return x

x = torch.randn(3)  # example input (not in the original snippet)
torch.compile(k, backend="eager", fullgraph=True)(x)
```
error message:
```
Traceback (most recent call last):
....
return InferenceModeVariable.create(tx, args[0].as_python_constant())
torch._dynamo.exc.InternalTorchDynamoError: list index out of range
```
This PR supports the case where torch.inference_mode is called without an argument (the mode then defaults to True).
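The failure mode and the fix can be illustrated with a small stand-alone sketch. This is not the Dynamo code: ```toy_inference_mode_flag``` is a hypothetical helper, and it works on plain Python values, whereas Dynamo works on traced argument objects and calls ```as_python_constant()``` on them.

```python
def toy_inference_mode_flag(args: list) -> bool:
    """Return the requested mode, defaulting to True when no argument is given."""
    # Indexing args[0] unconditionally raises IndexError on an empty list,
    # which matches the "list index out of range" error shown above.
    if not args:
        return True
    return bool(args[0])


no_argument = toy_inference_mode_flag([])
explicit_false = toy_inference_mode_flag([False])
```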
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118427
Approved by: https://github.com/yanboliang, https://github.com/jansel
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.
Perhaps erroneously, when I tee'ed up this PR I elected to delete the `follow_imports = skip` designations in mypy-inductor.ini. This led to a number of extra type error suppressions that I edited manually. You will need to review them.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
Before the PR, we have a graph break for the following test:
```python
import torch

def test_cublas_allow_tf32(x):
    # The flag is global state, so the branch is resolved at trace time
    if torch.backends.cuda.matmul.allow_tf32:
        return x.sin() + 1
    return x.cos() - 1
```
In this PR, we first add "torch.backends.cuda" to MOD_INLINELIST so that Dynamo traces through the Python binding and reaches the actual call, torch._C._get_cublas_allow_tf32, which is already a TorchInGraphVariable. Because _get_cublas_allow_tf32 reads the same variable as at::globalContext().allowTF32CuBLAS(), which Dynamo guards as global state [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp#L443), we can safely assume it returns a ConstantVariable during tracing.
After this PR, we get the following graph:
```python
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] def forward(self, L_x_ : torch.Tensor):
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_x_ = L_x_
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:515 in test_cublas_allow_tf32, code: return x.cos() - 1
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] cos = l_x_.cos(); l_x_ = None
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sub = cos - 1; cos = None
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] return (sub,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118236
Approved by: https://github.com/yanboliang, https://github.com/anijain2305
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
- The only remaining one is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It is called when users put the ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide a function's trace rule, the decorator would only be used on custom functions rather than torch functions. I'll defer this to a separate decorator refactor PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/issues/90552. This is a simpler fix that just detects the situation where AOTAutograd can't create a proper backward graph, and graph breaks. This was technically a silent correctness issue before.
This PR tries to always graph break when we see a factory function that returns a tensor requiring grad. I check this by seeing if the op returned a `TensorVariable` in dynamo, and if one of the input arguments was a `requires_grad=True` kwarg. I think this is high-fidelity enough, and I'm also hoping that this is uncommon enough that a graph break is reasonable here.
The fix to avoid the graph break in user land is also pretty easy - just instantiate your tensor outside of the compiled region and plumb it in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113277
Approved by: https://github.com/eellison
ghstack dependencies: #113267, #113416, #113584
Fixes https://github.com/pytorch/pytorch/issues/113010
In eager mode, when you call an out= op like `add(..., out=out_arg)` with an out argument that is noncontiguous, the noncontiguous out arg will be returned directly. When we functionalize though, functionalization replaces it with a call to `add(...)` which ignores the contiguity of the original out arg.
Instead of trying to support this, this PR detects that situation and graph breaks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113267
Approved by: https://github.com/albanD
Currently custom VariableTrackers exist for classes that live outside of pytorch.
For these cases dynamo currently eagerly imports the module to get the class
object to compare against.
This instead uses `sys.modules.get("module.path")` such that the module is never
imported by dynamo itself, but if the user has imported the module then we will
still access the module and grab the type we need to compare against.
I noticed this issue because importing `KeyedJaggedTensor` fails half-way
through if `fbgemm_gpu` has been built with an incompatible PyTorch version, in
which case it retries the import again each time!
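The lookup pattern described above can be sketched with the standard library alone. The helper names ```get_type_if_imported``` and ```is_instance_of_optional_type``` are hypothetical and are not the names used in Dynamo; ```collections.OrderedDict``` stands in for a third-party class.

```python
import collections
import sys
from typing import Any, Optional


def get_type_if_imported(module_path: str, attr: str) -> Optional[Any]:
    """Return module_path.attr if that module is already imported, else None."""
    # sys.modules.get never triggers an import; it only inspects the cache.
    module = sys.modules.get(module_path)
    if module is None:
        return None
    return getattr(module, attr, None)


def is_instance_of_optional_type(value: Any, module_path: str, attr: str) -> bool:
    """isinstance check against a class that may not have been imported."""
    cls = get_type_if_imported(module_path, attr)
    # If the user never imported the module, no value can be of that type.
    return cls is not None and isinstance(value, cls)


already_imported = is_instance_of_optional_type(
    collections.OrderedDict(), "collections", "OrderedDict"
)
# Hypothetical module name that is assumed not to be installed or imported.
never_imported = is_instance_of_optional_type(
    object(), "hypothetical_missing_pkg.sub", "Thing"
)
```

Because the check only reads the module cache, a package that fails to import is never retried by the checking code itself.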
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112319
Approved by: https://github.com/lezcano, https://github.com/ezyang