## Issue
During autotuning, we're not applying size hints atomically to the example inputs used for benchmarking.
If an unbacked symint shows up in an input's strides, this can lead to a CUDA IMA (illegal memory access).
This is reproduced by the added unittest: with stride `[128 * u0, 128, 1]` and an unbacked fallback of 8192, `benchmark_example_value` returns a tensor with stride `[8192, 128, 1]` instead of `[128 * 8192, 128, 1]`.
## Fix
Use the atomic API when applying size hints to input tensors' strides.
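For illustration, a minimal sketch of the difference using plain sympy as a stand-in for the unbacked-symint machinery (the stride and fallback values mirror the unittest above):
```
import sympy

u0 = sympy.Symbol("u0")
stride = [128 * u0, 128, 1]
fallback = 8192

# Non-atomic: hand any symbolic expression the bare fallback value,
# yielding [8192, 128, 1] -- a too-small stride that can cause the IMA.
bad = [fallback if isinstance(e, sympy.Expr) else e for e in stride]

# Atomic: substitute the fallback for u0 *inside* each expression,
# yielding [128 * 8192, 128, 1], consistent with the symbolic layout.
good = [e.xreplace({u0: fallback}) if isinstance(e, sympy.Expr) else e
        for e in stride]

print(bad)   # [8192, 128, 1]
print(good)  # [1048576, 128, 1]
```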
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163660
Approved by: https://github.com/ColinPeppler
Summary:
As titled.
Without the diff, we got P1963055009
With the diff, which passes in the environment, we can do correct symint deduction:
https://fburl.com/mlhub/p5zy7o28
Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:unbacked_symints -- test_sdfpa_unbacked_strides --print-passing-details --env TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 --env TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(u0, 0)"
```
Without the fix: P1964887260
With the fix: P1964888579
Differential Revision: D83211018
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163925
Approved by: https://github.com/ColinPeppler
An internal user tried enabling combo kernels but ran into "Cannot convert symbols to int". This PR enables combo kernels on inputs with data-dependent shapes.
### Example exception
```
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 4997, in benchmark_combo_kernel
kernel_code_list = self.generate_combo_kernel_code(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/simd.py", line 1849, in generate_combo_kernel_code
src_code = kernel.codegen_kernel()
^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 802, in codegen_kernel
code.splice(self.codegen_kernel_benchmark(num_gb=0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 852, in codegen_kernel_benchmark
var_names.extend(self.kernel_benchmark_extra_args())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 733, in kernel_benchmark_extra_args
extra_args.append(str(V.graph.sizevars.size_hint(tree.numel)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 584, in size_hint
return int(out)
^^^^^^^^
File "/home/colinpeppler/.conda/envs/pytorch/lib/python3.12/site-packages/sympy/core/expr.py", line 307, in __int__
raise TypeError("Cannot convert symbols to int")
torch._inductor.exc.InductorError: TypeError: Cannot convert symbols to int
```
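For context, a hedged user-level sketch of the kind of program that hits this path (assuming the `torch._inductor.config.combo_kernels` flag and `capture_dynamic_output_shape_ops` to let `nonzero` trace):
```
import torch
import torch._inductor.config as inductor_config

inductor_config.combo_kernels = True                          # opt in
torch._dynamo.config.capture_dynamic_output_shape_ops = True  # let nonzero trace

@torch.compile(fullgraph=True)
def f(x):
    nz = torch.nonzero(x).float()  # number of rows is data-dependent (u0)
    # Two independent pointwise ops that combo-kernel grouping can pick up.
    return nz.sin(), nz.cos()

f(torch.tensor([1.0, 0.0, 2.0]))
```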
Differential Revision: [D82042230](https://our.internmc.facebook.com/intern/diff/D82042230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162442
Approved by: https://github.com/jansel
Switch from `guard_size_oblivious` to `guard_or_false` where a DDE can occur; on a data-dependent condition this falls back to computing elementwise strides.
2dccff7dcf/torch/_prims/__init__.py (L1919-L1923)
We think it's safe because Laith tested whether this fallback would fail any tests. It did not.
https://github.com/pytorch/pytorch/pull/158157
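Schematically (a sketch, not the verbatim prims code; `is_dense` stands in for the real non-overlapping-and-dense check):
```
from torch.fx.experimental.symbolic_shapes import guard_or_false

def output_strides(a, is_dense):
    # guard_or_false answers False instead of raising a
    # GuardOnDataDependentSymNode when the condition is data-dependent,
    # so unbacked sizes take the safe branch below.
    if guard_or_false(is_dense(a)):
        return tuple(a.stride())
    # Fallback: compute contiguous elementwise strides from the sizes.
    strides, acc = [], 1
    for s in reversed(a.shape):
        strides.append(acc)
        acc *= s
    return tuple(reversed(strides))
```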
## Data-dependent exceptions (DDE)
```
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 2139, in _to_copy
x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
...
File "/data/users/colinpeppler/pytorch/torch/_prims/__init__.py", line 1920, in _convert_element_type_meta
if torch._prims_common.is_non_overlapping_and_dense(a):
File "/data/users/colinpeppler/pytorch/torch/_prims_common/__init__.py", line 494, in is_non_overlapping_and_dense
if guard_size_oblivious(length == 1):
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0 - 4, 1) (unhinted: Eq(u0 - 4, 1)). (Size-like symbols: u0)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158894
Approved by: https://github.com/pianpwk, https://github.com/laithsakka
### What
- Use `statically_known_true` over `guard_size_oblivious` in cases where we're checking an optimization path. Otherwise, it would raise a DDE and we couldn't take the safe/slower path.
- For broadcast checks, use `fallback=False` if we encounter a DDE. Typically, unbackeds are ≥ 2, which falls in line with size-oblivious reasoning (i.e. when `size_oblivious=True`).
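A sketch of both patterns, assuming the helpers in `torch.fx.experimental.symbolic_shapes` (`guard_or_false` stands in here for "fall back to `False` on a DDE"):
```
from torch.fx.experimental.symbolic_shapes import (
    guard_or_false,
    statically_known_true,
)

def use_fast_path(size):
    # Optimization gate: only take the fast path when provably true; a
    # data-dependent size quietly takes the safe/slower path instead of
    # raising the way guard_size_oblivious would.
    return statically_known_true(size == 1)

def broadcasts_like_one(size):
    # Broadcast check: on a DDE, answer False -- unbacked sizes are
    # treated as >= 2, matching size-oblivious reasoning.
    return guard_or_false(size == 1)
```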
### Example DDE
```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0)
Caused by: (_inductor/lowering.py:488 in broadcast_symbolic_shapes)
```
```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0)
Caused by: (_inductor/ir.py:2797 in create)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155267
Approved by: https://github.com/eellison
Related: #125914 (specifically see [comment](https://github.com/pytorch/pytorch/issues/125914#issuecomment-2513044125))
This PR addresses two broken things involving the usage of unbacked SymInts for calls to `slice()` with data-dependent bounds. These issues are encountered in practice for `narrow()` operating on the batch dim with an NJT input, but apply to other subclasses as well. The test in this PR uses a purpose-built subclass.
There are two different issues here, depending on whether `torch.compile()` is called with `dynamic=True`. In practice, these only occur when the unbacked SymInts are created within the torch_dispatch implementation of a subclass, because the unbacked symbols are considered "freshly created" when the output subclass instance is handled in Dynamo.
**Error 1 (dynamic=False):**
```
LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(-Min(22, Max(0, u0)) + Min(22, Max(u0 + u1, Max(0, u0))), 0) (unhinted: Eq(-Min(s0, Max(0, u0)) + Min(s0, Max(u0 + u1, Max(0, u0))), 0)). (Size-like symbols: u1, u0)
```
The expression comes from the use of `clamp()` logic for `SliceView` in Inductor:
41e59754b4/torch/_inductor/ir.py (L3014)
If the (start, end) bounds for the `slice()` are statically known to be in range for the given dim (e.g. provided via `torch._check()` calls), we can avoid this `clamp()` logic and the error. This PR implements this fix.
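A hedged sketch of the user-side contract that makes the bounds statically known to be in range (function and variable names are illustrative):
```
import torch

def narrow_values(values, start_t, length_t):
    start = start_t.item()    # unbacked u0 under torch.compile
    length = length_t.item()  # unbacked u1
    # With these checks, Inductor can prove 0 <= start and
    # start + length <= dim, so SliceView's clamp() is a no-op.
    torch._check(start >= 0)
    torch._check(start + length <= values.size(0))
    return values[start : start + length]
```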
**Error 2 (dynamic=True):**
```
torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u0} not in returned outputs NestedTensor(size=(2, s16, s1), offsets=FakeTensor(..., device='cuda:0', size=(3,), dtype=torch.int64), grad_fn=<NarrowBackwardAutogradNestedTensor0 object at 0x7f1f8603cfd0>, contiguous=True) ((s1*s16, s1, 1), s1*u0)
```
The storage offset of the values component of the returned NJT is `s1*u0` where `s1` is known to be an integer. This PR expands the special logic handling the `constant * u0` case to handle SymInts as well:
314e08eb52/torch/fx/experimental/symbolic_shapes.py (L1013-L1031)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142062
Approved by: https://github.com/ezyang
ghstack dependencies: #143526
Differential Revision: D61506212
Use `skipCUDAIf` from `torch.testing._internal.common_device_type` when the test class is created with `instantiate_device_type_tests`. `instantiate_device_type_tests` makes sure the class has the `device_type` attribute, which is what `skipCUDAIf` from `torch.testing._internal.common_device_type` relies on.
Also skip test_vertical_pointwise_reduction_fusion for the CPU test class, since the test expects CUDA.
```
FAILED [0.0026s] test/inductor/test_unbacked_symints.py::TestUnbackedSymintsCPU::test_vertical_pointwise_reduction_fusion_cpu - AttributeError: 'TestUnbackedSymintsCPU' object has no attribute 'device'
```
repro:
```
CUDA_VISIBLE_DEVICES="" pytest test/inductor/test_unbacked_symints.py -k cpu -v
```
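A minimal sketch of the intended pattern (test names are illustrative):
```
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    skipCUDAIf,
)
from torch.testing._internal.common_utils import TestCase, run_tests

class TestUnbackedSymints(TestCase):
    # skipCUDAIf only skips the CUDA instantiation; the CPU instantiation
    # produced by instantiate_device_type_tests still runs.
    @skipCUDAIf(not torch.cuda.is_available(), "requires CUDA")
    def test_example(self, device):
        self.assertEqual(torch.ones(2, device=device).sum().item(), 2)

instantiate_device_type_tests(TestUnbackedSymints, globals())

if __name__ == "__main__":
    run_tests()
```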
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133936
Approved by: https://github.com/ColinPeppler, https://github.com/desertfire
```
$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion
File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
for node1, node2 in self.get_possible_fusions():
File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
check_all_pairs(node_grouping)
File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
if self.can_fuse(node1, node2):
File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
return self.get_backend(device).can_fuse_vertical(node1, node2)
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
return self._triton_scheduling.can_fuse_vertical(node1, node2)
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
if not all(
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
cls._split_iteration_ranges(groups, lengths)
File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
return int(out)
File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
```
Where the unbacked symints show up:
```
> /data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py(1506)_split_iteration_ranges()
(Pdb) print(groups)
(1, 512*u0)
(Pdb) print(lengths)
([u0, 32, 16], [])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125982
Approved by: https://github.com/jansel
I want to generate runtime assert nodes during lowering, which means
that I need a finalized list of asserts by the time I start lowering.
This means this runtime assert introduced in
https://github.com/pytorch/pytorch/pull/113839 must go. Fortunately,
this runtime assert was never exercisable, apparently, and the test
still "passes" without it. I replace it with a compile time test. We
can revisit if this assert fails in practice.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124864
Approved by: https://github.com/jansel
## Context
Suppose we have two symbols: `u0` and `s0` where we know that `u0 = s0`. Now, let's say we tried to look up the size hint for `u0 + 1`.
* Before this PR, we would use a fallback hint if one was provided.
3f6acf65fd/torch/_inductor/sizevars.py (L406-L407)
* With this PR, we would try to replace `u0` with `s0` via `simplify()` before using a fallback hint. 3f6acf65fd/torch/_inductor/sizevars.py (L46-L47)
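A stand-in sketch of the ordering in plain sympy (hint values are hypothetical):
```
import sympy

u0, s0 = sympy.symbols("u0 s0")
replacements = {u0: s0}   # the ShapeEnv knows u0 == s0
hints = {s0: 7}           # s0 has a real hint
fallback = 8192

expr = u0 + 1
# Before: u0 has no hint, so the whole expression takes the fallback (8192).
# After: simplify() applies the replacements first, so the real hint is used.
simplified = expr.xreplace(replacements)   # -> s0 + 1
print(int(simplified.xreplace(hints)))     # 8, not 8192
```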
## Concrete Example
A scenario where this is useful is when we're running autotuning benchmarking on bmm with two input nodes: one with `s0` as the batch size and one with `u0` as the batch size. During benchmarking, we'll create two example input tensors, where the input with `u0` has to use a fallback hint for the batch size. This leads to a mismatch.
e3d80f2fa9/torch/_inductor/select_algorithm.py (L991-L997)
Using the fallback hint (i.e. 8192) leads to a batch size mismatch.
```
# Note: s0 = 7 and u0 = 7 and fallback hint is 8192.
LoweringException: ErrorFromChoice: Expected size for first two dimensions of batch2 tensor to be: [7, 30] but got: [8192, 30].
From choice ExternKernelCaller(extern_kernels.bmm)
```
Differential Revision: D55619331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123140
Approved by: https://github.com/aakhundov
Summary: In `torch.inference_mode()`, fake tensors don't have `_version`s. This breaks unbacked SymInt memoization in `torch.nonzero` tracing. Here we disable the latter in inference mode.
Fixes https://github.com/pytorch/pytorch/issues/122127
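A repro-style sketch of the user pattern (assuming `capture_dynamic_output_shape_ops` so `nonzero` can be traced):
```
import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def f(x):
    return torch.nonzero(x)  # the number of output rows is an unbacked SymInt

with torch.inference_mode():
    # Fake tensors have no _version here, so nonzero's unbacked-SymInt
    # memoization must be disabled rather than consulted.
    print(f(torch.tensor([0, 1, 1, 0])))
```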
Test Plan:
```
$ python test/inductor/test_unbacked_symints.py -k test_nonzero_in_inference_mode
...
----------------------------------------------------------------------
Ran 2 tests in 14.060s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122147
Approved by: https://github.com/ezyang
## Problem
A user-defined Triton kernel grid may use a sympy magic method like `Max`. This comes in the form of a `sympy.Expr`, namely `sympy.core.function.FunctionClass`.
Handling this is not trivial since `user_defined_kernel_grid_fn_code` is used in both eager and Inductor. Eager usage below.
## Approach
Pass in the wrapper when Inductor codegens a grid with ints/sympy.Exprs, so we can use wrapper functions such as `codegen_shape_tuple()`.
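A stand-in sketch of the symptom in plain sympy:
```
import sympy

s0, u0 = sympy.symbols("s0 u0", positive=True)
grid_x = sympy.Max(s0, u0)  # what max(dim_a, dim_b) traces to in a grid

# The grid codegen cannot materialize this as a Python int; it must be
# printed as an expression, e.g. via the wrapper's codegen_shape_tuple().
try:
    int(grid_x)
except TypeError as e:
    print(e)  # Cannot convert symbols to int
```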
Differential Revision: D53367012
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119165
Approved by: https://github.com/aakhundov
## Context
This is an example that runs into an AssertionError while lowering in Inductor.
```
# While lowering, b will be expanded because b.size(1) == 1.
a = torch.zeros([u0, 512])
b = torch.ones([u0, 1])
return a * b
```
Below is the tail end of the stack trace. Here are the important bits:
1. In _inductor/sizevars.py, we'll call `self.shape_env.defer_runtime_assert(expr, msg, fx_node=V.graph.current_node)`.
2. This leads to the creation of a `ShapeEnvEvent` with an FX node via `kwargs={"fx_node": V.graph.current_node}` ([see](0c9b513470/torch/fx/experimental/recording.py (L245-L247))).
3. Eventually, we try to call `maybe_convert_node()` but it expects translation validation to be on ([see](0c9b513470/torch/fx/experimental/recording.py (L118-L121))).
```
File "pytorch/torch/_inductor/lowering.py", line 221, in transform_args
for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
File "pytorch/torch/_inductor/lowering.py", line 294, in wrapped
out = decomp_fn(*args, **kwargs)
File "pytorch/torch/_inductor/lowering.py", line 676, in broadcast_tensors
x = expand(x, target)
File "pytorch/torch/_inductor/lowering.py", line 294, in wrapped
out = decomp_fn(*args, **kwargs)
File "pytorch/torch/_inductor/lowering.py", line 793, in expand
return TensorBox(ExpandView.create(x.data, tuple(sizes)))
File "pytorch/torch/_inductor/ir.py", line 1871, in create
new_size = cls._normalize_size(x, new_size)
File "pytorch/torch/_inductor/ir.py", line 1862, in _normalize_size
new_size[i] = V.graph.sizevars.expect_equals(
File "pytorch/torch/_inductor/sizevars.py", line 338, in expect_equals
self.expect_true(sympy.Eq(left, right), msg=msg)
File "pytorch/torch/_inductor/sizevars.py", line 333, in expect_true
self.shape_env.defer_runtime_assert(expr, msg, fx_node=V.graph.current_node) # (1) is here
File "pytorch/torch/fx/experimental/recording.py", line 257, in wrapper
return event.run(self) # (2) happens right before this
File "pytorch/torch/fx/experimental/recording.py", line 155, in run
replacearg(index=3, key="fx_node", fn=maybe_convert_node)
File "pytorch/torch/fx/experimental/recording.py", line 138, in replacearg
kwargs[key] = fn(kwargs[key])
File "pytorch/torch/fx/experimental/recording.py", line 128, in maybe_convert_node
assert hasattr(shape_env, "name_to_node") # (3) is here
```
## Approach
Since [translation validation](c6be5d55a5/torch/fx/experimental/validator.py (L574)) may not be on during Inductor lowering, we can check if that's True and return the FX node's name in this case.
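Schematically, the guard looks something like this (a sketch of the shape of the fix, not the verbatim diff):
```
import torch.fx

def maybe_convert_node(x, shape_env):
    if not isinstance(x, torch.fx.Node):
        return x
    # Translation validation is off when the ShapeEnv has no name_to_node
    # mapping to round-trip through; the node's name is enough for the
    # recorded event in that case, so return it instead of asserting.
    if not hasattr(shape_env, "name_to_node"):
        return x.name
    return shape_env.name_to_node[x.name]
```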
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118066
Approved by: https://github.com/ezyang, https://github.com/peterbell10
# Context
Let's say we do `View.create(x, sizes)` where `x` is a `SliceView` and `sizes` contains unbacked symints, e.g. `sizes = [i14, 256]`. Then we'll run ([this code](7e37f63e5e/torch/_inductor/ir.py (L2058-L2071))) where we:
1. Call `x.realize()` -- SliceView(Pointwise) -> SliceView(ComputedBuffer).
2. Retrieve storage & layout via `as_storage_and_layout(x)`.
3. Calculate `new_layout` based off the layout & `new_sizes`.
4. `return ReinterpretView(storage, new_layout)`
However, (2) will raise `NotImplementedError` ([see](7e37f63e5e/torch/_inductor/ir.py (L1704-L1731))) since `x` is a `SliceView` and that isn't supported.
Thus, I tried adding support for `SliceView` in `as_storage_and_layout`. This worked for my case, but if instead `sizes` had backed symints e.g. `sizes=[s0, 256]` then some benchmarked models lost accuracy.
```
if isinstance(x, SliceView):
return as_storage_and_layout(
x.data,
freeze=freeze,
want_contiguous=want_contiguous,
stride_order=stride_order,
)
```
So instead of the above, I tried unwrapping the `SliceView` via `x = x.unwrap_view()`. This works for my use case and passes CI, but I'm not entirely sure why. If we unwrap our `SliceView` and create a `ReinterpretView`, I'd assume we'd lose the reindexer from `SliceView`. ~~But maybe we can re-create the same indexing from the `ReinterpretView`'s strides?~~ edit: we do lose vital information (like the offset) when we unwrap the `SliceView` and create a `ReinterpretView`, so that's a no-go.
Moving onto the final version of this PR. We call `ExternKernel.realize_input()` (feels a bit weird to use `ExternKernel` but it's exactly what I need). It will go ahead and handle our `SliceView` case ([see](a468b9fbdf/torch/_inductor/ir.py (L3733-L3739))) by converting it to a `ReinterpretView` with the correct offset.
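For context, a user-level sketch of the shape of the failing pattern (sizes are illustrative; `capture_scalar_outputs` lets `.item()` trace):
```
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x, n):
    u0 = n.item()                        # unbacked SymInt
    torch._check_is_size(u0)
    torch._check(u0 * 256 <= x.size(0))  # keep the slice bounds provable
    y = x[: u0 * 256]                    # SliceView with an unbacked extent
    return y.view(u0, 256)               # View.create with sizes [u0, 256]

print(f(torch.randn(1024), torch.tensor(4)).shape)  # torch.Size([4, 256])
```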
# Test
```
$ python test/inductor/test_unbacked_symints.py
..
----------------------------------------------------------------------
Ran 10 tests in 20.813s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117013
Approved by: https://github.com/jansel, https://github.com/ezyang
torch.split(x, l) fails when l contains unbacked symints. E.g. `l = y.tolist()` makes the sizes in l unbacked, because they depend on the data of `y`. The downstream call `SliceView.create()` evaluates the shape even when the input shape is an unbacked symint, which brings up the bug.
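A repro-style sketch of the pattern (assuming `capture_scalar_outputs` so `.tolist()` can be traced):
```
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x, y):
    l = y.tolist()            # unbacked split sizes: [u0, u1]
    return torch.split(x, l)  # each output's dim 0 is unbacked

parts = f(torch.arange(10.0), torch.tensor([3, 7]))
print([p.shape for p in parts])  # [torch.Size([3]), torch.Size([7])]
```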
Test Plan:
```
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
Summary: Unbacked SymInts can't get a `sizevars.size_hint` due to being data-dependent. #109893 added a new `fallback` parameter to `sizevars.size_hint` to specify the fallback value in cases like unbacked SymInts. In this PR we add more of those.
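A stand-in sketch in plain sympy of what `size_hint` with a `fallback` does for unbacked symbols (not the Inductor implementation):
```
import sympy

def size_hint(expr, hints, fallback=None):
    out = sympy.sympify(expr).xreplace(hints)
    if out.free_symbols:          # unbacked symbols have no hint
        if fallback is None:
            return int(out)       # raises: cannot convert symbols to int
        out = out.xreplace({s: fallback for s in out.free_symbols})
    return int(out)

u0, s0 = sympy.symbols("u0 s0")
print(size_hint(4 * s0, {s0: 7}))            # 28
print(size_hint(u0 + 1, {}, fallback=8192))  # 8193
```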
Test Plan: CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110520
Approved by: https://github.com/jansel, https://github.com/ezyang