Partially addresses https://github.com/pytorch/pytorch/issues/118785
This diff fixes three things:
1. Add `get_function` to `FunctoolsPartialVariable`. Note that it is available only if all args are constant; otherwise, it raises `unimplemented` in the call to `as_python_constant`.
2. `NamedTupleVariable` takes its args dispatched positionally rather than as a list, e.g. NamedTuple(a, b, c) vs. NamedTuple([a, b, c]); fix that by specializing `as_proxy` (see the sketch after this list).
3. A call to `create_arg` from within `create_proxy` changes a Python NamedTuple into a function call node without associating an example value. Updated `get_fake_values_from_nodes` to handle this case.
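For point 2, a quick illustration in plain CPython of why the dispatch shape matters (no Dynamo involved):
```python
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])

p = Point(1, 2)    # ok: namedtuple constructors take positional args
try:
    Point([1, 2])  # not a single list: missing required argument 'y'
except TypeError as e:
    print(e)
```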
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119435
Approved by: https://github.com/jansel, https://github.com/anijain2305
ghstack dependencies: #119314
Fix https://github.com/pytorch/pytorch/issues/118787
In the compiled function, calls to random() are replaced with a single call
to a function that generates all the random values.
The random calls encountered during compilation used to be tracked in a variable
stored inside the instruction translator. When there are nested translators, the tracked
calls used to get lost when the inner instruction translator popped out.
This diff fixes that by moving the tracked calls to the output graph, which is shared across translators that are generating the same function.
More details about the issue and why this solution was picked are in the GitHub issue above.
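A minimal repro sketch of the failure mode (illustrative; assumes `inner` is inlined by a nested instruction translator):
```python
import random

import torch

def inner(x):
    # inlined by a nested instruction translator; the random() call
    # tracked here used to be lost when that translator popped out
    return x * random.random()

@torch.compile
def outer(x):
    return inner(x) + random.random()

outer(torch.ones(3))
```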
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119218
Approved by: https://github.com/jansel, https://github.com/anijain2305
```
def f():
    def g():
        return ()
    print(g.__name__)

f()
```
The script above should print `g` (with or without torch.compile),
but prints `f.<locals>.g` under torch.compile.
The problem is that we use the co_qualname when reconstructing the
NestedUserFunctionVariable. I switched this over to use the co_name.
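For reference, the two code attributes differ exactly as the repro shows (`co_qualname` exists on Python 3.11+):
```python
def f():
    def g():
        return ()
    return g

g = f()
print(g.__code__.co_name)      # g
print(g.__code__.co_qualname)  # f.<locals>.g  (Python 3.11+)
```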
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118768
Approved by: https://github.com/yanboliang, https://github.com/jansel
Before the PR, we have a graph break for code like this,
```python
def test_get_device_properties_tensor_device(a):
    x = a.to("cuda")
    prop = torch.cuda.get_device_properties(x.device)
    if prop.major == 8:
        return x + prop.multi_processor_count
    return x + prop.max_threads_per_multi_processor
```
This PR constant folds torch.cuda.get_device_properties, and we get the following Dynamo graph:
```python
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] def forward(self, L_a_ : torch.Tensor):
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] l_a_ = L_a_
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:544 in test_get_device_properties_tensor_device, code: x = a.to("cuda")
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] x = l_a_.to('cuda'); l_a_ = None
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:547 in test_get_device_properties_tensor_device, code: return x + prop.multi_processor_count
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] add = x + 108; x = None
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] return (add,)
[2024-01-26 13:28:13,253] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]
```
The signature of get_device_properties is:
```python
def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
```
I think it's safe to constant fold get_device_properties():
1. torch.cuda.get_device_properties(tensor.device): in this case, tensor.device.index is guarded in _check_tensor.
2. torch.cuda.get_device_properties(device_int_id): we don't expect the GPU properties for a particular index to change during a torch.compile run, so it makes sense to specialize the properties for a concrete device_int_id (see the sketch below).
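For reference, a minimal sketch of the values being folded (requires a CUDA device):
```python
import torch

if torch.cuda.is_available():
    prop = torch.cuda.get_device_properties(0)
    # Plain Python values that are stable for the lifetime of the
    # process, so baking them into the graph as constants is sound.
    print(prop.major, prop.multi_processor_count)
```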
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118422
Approved by: https://github.com/yanboliang, https://github.com/jansel
Before the PR, we have a graph break for the following test:
```python
def test_cublas_allow_tf32(x):
    if torch.backends.cuda.matmul.allow_tf32:
        return x.sin() + 1
    return x.cos() - 1
```
In this PR, we first add "torch.backends.cuda" to MOD_INLINELIST to trace through the Python binding and reach the actual call, torch._C._get_cublas_allow_tf32, which is already a TorchInGraphVariable. Because _get_cublas_allow_tf32 accesses the same variable as at::globalContext().allowTF32CuBLAS(), which Dynamo guards as global state [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp#L443), we can safely assume it returns a ConstantVariable during tracing.
After this PR, we get the following graph:
```python
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] def forward(self, L_x_ : torch.Tensor):
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_x_ = L_x_
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: /home/yidi/local/pytorch/test/dynamo/test_functions.py:515 in test_cublas_allow_tf32, code: return x.cos() - 1
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] cos = l_x_.cos(); l_x_ = None
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] sub = cos - 1; cos = None
[2024-01-24 15:31:01,501] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] return (sub,)
```
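Since allow_tf32 is part of the guarded global state, flipping it after compilation should trigger a recompile rather than reuse of the stale branch; a sketch of that expectation (assumes a CUDA build):
```python
import torch

@torch.compile
def f(x):
    if torch.backends.cuda.matmul.allow_tf32:
        return x.sin() + 1
    return x.cos() - 1

x = torch.randn(4, device="cuda")
torch.backends.cuda.matmul.allow_tf32 = False
f(x)  # traces the cos branch
torch.backends.cuda.matmul.allow_tf32 = True
f(x)  # global-state guard fails -> recompile, tracing the sin branch
```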
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118236
Approved by: https://github.com/yanboliang, https://github.com/anijain2305
* This is an old builtin function equivalent to the bool constructor. It is easy enough to add support for.
* I also realized the tests were in the wrong class (the one reserved for testing default args), so I moved them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117463
Approved by: https://github.com/jansel
The `initial` argument in `functools.reduce` can be `None`.
```python
initial_missing = object()

def reduce(function, iterable, initial=initial_missing, /):
    it = iter(iterable)
    if initial is initial_missing:
        value = next(it)
    else:
        value = initial
    for element in it:
        value = function(value, element)
    return value
```
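In particular, an explicit `initial=None` is meaningful and distinct from "no initial given", which is why the sentinel must be a private `object()`:
```python
from functools import reduce

# None is a valid starting accumulator, not a "missing" marker
result = reduce(lambda acc, x: (acc or 0) + x, [1, 2, 3], None)
print(result)  # 6
```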
Reference:
- python/cpython#102759
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116398
Approved by: https://github.com/Skylion007
Two recent Triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) changed the interface of triton.compile; this PR adds the necessary changes on the Inductor side to work with both the old and the new compile API.
There is also some simplification between the compilation call in the subprocess and the one in the main process:
- Previously we passed warm_cache_only=True if the compilation happened in a subprocess. But Triton never uses that argument at the currently used pin, so I removed it.
- Previously we only passed compute_capability if compilation happened in a subprocess. This PR changes that to always pass compute_capability to triton.compile, whether the compilation happens in the main process or a subprocess.
Updated:
There are more interface changes on the Triton side, e.g.:
- tl.math.{min, max} now requires a propagate_nan argument.
- JITFunction.run now requires a warmup argument; this affects the benchmarking phase of matmul max-autotune. On the other hand, JITFunction.run now forbids the stream argument. Simply not passing it when benchmarking matmul Triton kernels works for both old and new versions of Triton (see the compatibility sketch below).
- The Triton Autotuner changed attribute names from 'warmup' to 'num_warmup' and from 'rep' to 'num_rep'. This caused Dynamo to fail to handle Triton Autotuner objects, since Dynamo's TritonKernelVariable makes assumptions about attribute names. This comes up in test cases where a model calls the Triton Autotuner directly.
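A hypothetical compatibility shim for the JITFunction.run change (`run_compat` is illustrative only, not the actual Inductor code):
```python
import inspect

def run_compat(jit_fn, grid, *args, **kwargs):
    # Never pass `stream`: newer Triton forbids it, and older Triton
    # falls back to the current stream on its own.
    kwargs.pop("stream", None)
    # Newer Triton requires an explicit `warmup` kwarg on run().
    if "warmup" in inspect.signature(jit_fn.run).parameters:
        kwargs.setdefault("warmup", False)
    return jit_fn.run(*args, grid=grid, **kwargs)
```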
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
Summary: Currently, we [`clone`](19207b9183/torch/_inductor/lowering.py (L5273)) every `TensorBox` argument of custom Triton kernels while lowering them to the Inductor IR, during which the stride information of the kernel inputs is lost. This is problematic in the common case where the strides of a `torch.Tensor` argument are passed as scalars to a custom Triton kernel alongside the tensor itself (because the underlying Triton code interprets the tensor as a raw pointer, the stride semantics contained in the `torch.Tensor` are lost).
In this PR, we add an extended version of the existing [`clone` lowering](19207b9183/torch/_inductor/lowering.py (L2289))---`clone_preserve_reinterpret_view`---which carries over the `ir.ReinterpretView` layers (if any) from the source `TensorBox` to the cloned one. The rationale behind adding a new function (and switching to it only in `triton_kernel_wrap` for now), as opposed to extending the existing `clone`, is to keep the semantics of the latter untouched, as it is a lowering of `torch.clone` (albeit incomplete, as the `memory_format` is currently ignored). Changing the existing `clone` would change its semantics, which is not necessarily desirable in general. Open to suggestions, though.
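A standalone illustration of the stride mismatch the new lowering avoids (pure PyTorch; Inductor's `clone` lowering lays the result out contiguously since `memory_format` is ignored):
```python
import torch

x = torch.randn(4, 4).t()                           # a view; strides (1, 4)
y = x.clone(memory_format=torch.contiguous_format)  # strides (4, 1)
assert x.stride() == (1, 4) and y.stride() == (4, 1)
# A kernel launched as my_kernel[grid](ptr, stride0, stride1, ...) with
# the strides computed from `x` would mis-index `y`'s memory.
```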
Test Plan:
```
$ python test/dynamo/test_functions.py -k test_triton_kernel_strided_input
...
----------------------------------------------------------------------
Ran 1 test in 5.568s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116219
Approved by: https://github.com/jansel
This PR fixes two bugs:
1) Constant folding a Triton kernel results in the kernel's inputs being returned back without any modification. Disable constant folding for Triton kernels for now; this needs more investigation.
2) NoneLayout buffers should not be deleted, as they do not exist.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115908
Approved by: https://github.com/aakhundov, https://github.com/jansel
Because not all tests in the Dynamo shard actually run in CI, this
implementation has started to bitrot. Since our plan is to trace
into the functorch implementations instead of constructing a HOP
(which is what capture_func_transforms=True does), let's turn this
config off by default.
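Assuming the flag keeps its current name, the legacy HOP-constructing behavior can still be opted into explicitly:
```python
import torch

# opt back into the old behavior; defaults to False after this change
torch._dynamo.config.capture_func_transforms = True
```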
Test Plan:
- Tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115267
Approved by: https://github.com/voznesenskym, https://github.com/guilhermeleobas
Fixes #104797
```
File "/home/jansel/pytorch/torch/_dynamo/utils.py", line 1486, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/home/jansel/pytorch/torch/_dynamo/utils.py", line 1591, in run_node
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
File "/home/jansel/pytorch/torch/_dynamo/utils.py", line 1570, in run_node
return node.target(*args, **kwargs)
File "/home/jansel/conda/envs/pytorch/lib/python3.10/site-packages/einops/packing.py", line 153, in unpack
n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function unpack at 0x7f644b962710>(*(FakeTensor(..., device='cuda:0', size=(1, s0*s1, 128)), [(s0, s1)], 'b * c'), **{}):
unsupported operand type(s) for +: 'int' and 'SymBool'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114828
Approved by: https://github.com/lezcano
Prior to this PR, autotuned arguments could only be at the back of the argument list. This is an Inductor limitation, not a Triton limitation. Fixing it allows more MRS kernels to use user-defined Triton kernels.
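A sketch of a kernel this enables, with the autotuned constexpr sitting in the middle of the argument list rather than at the back (illustrative only):
```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({"BLOCK": 64}), triton.Config({"BLOCK": 128})],
    key=["n"],
)
@triton.jit
def fill_kernel(out_ptr, BLOCK: tl.constexpr, n):  # autotuned arg not last
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, offs.to(tl.float32), mask=offs < n)
```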
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114002
Approved by: https://github.com/aakhundov
ghstack dependencies: #113967