TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
For a custom op that returns a list of a single tensor with unbacked symint shape:
```python
@torch.library.custom_op(
"aoti_custom_ops::fn_ret_list_of_single_tensor", mutates_args={}
)
def fn_ret_list_of_single_tensor(x: torch.Tensor) -> list[torch.Tensor]:
s = x.sum().to(torch.int64)
return [torch.randn(s.item())]
@fn_ret_list_of_single_tensor.register_fake
def _(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.new_dynamic_size()
return [torch.randn(i0)]
```
Before the fix, we have the following error:
```
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: error: type/value mismatch at argument 1 in template parameter list for ‘template<class _Tp, class ... _Types> constexpr const _Tp& std::get(const std::variant<_Types ...>&)’
456 | auto u0 = std::get<0>(buf1).size(0);
| ~~~~~~~~~~~^~~~~~
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: note: expected a type, got ‘0’
In file included from /data/users/yidi/pytorch/torch/include/c10/util/Exception.h:14,
from /data/users/yidi/pytorch/torch/include/c10/core/ScalarType.h:5,
from /data/users/yidi/pytorch/torch/include/ATen/AccumulateType.h:4,
from /data/users/yidi/pytorch/torch/include/ATen/native/Math.h:3,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec_base.h:31,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec512/vec512.h:8,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/vec.h:4,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/functional_base.h:6,
from /data/users/yidi/pytorch/torch/include/ATen/cpu/vec/functional.h:3,
from /tmp/tmp5iikarn2/3b/c3bi5gk6mslf6u4iaqafhxm64z6u65e3eain4xlary5blqnvv6xx.h:39,
from /tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:366:
/usr/include/c++/11/variant:1145:27: note: candidate: ‘template<class _Tp, class ... _Types> constexpr const _Tp&& std::get(const std::variant<_Types ...>&&)’
1145 | constexpr const _Tp&& get(const variant<_Types...>&& __v)
| ^~~
/usr/include/c++/11/variant:1145:27: note: template argument deduction/substitution failed:
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: error: type/value mismatch at argument 1 in template parameter list for ‘template<class _Tp, class ... _Types> constexpr const _Tp&& std::get(const std::variant<_Types ...>&&)’
456 | auto u0 = std::get<0>(buf1).size(0);
| ~~~~~~~~~~~^~~~~~
/tmp/tmp5iikarn2/cci3ruqb7zdwtl457zo4itspq3sjnqiayhcshp5uaak7ktksckix/cggzqlwf4bmu6tjqodhoto3hhkhgharhwtvw2uxsasqrdipnazrv.cpp:456:26: note: expected a type, got ‘0’
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147649
Approved by: https://github.com/angelayi
ghstack dependencies: #147130
This pull request reverts the changes to `torch/_inductor/ir.py` file that were added in #146917.
Where I tested, there were changes only from `torch/_inductor/codegen/cpp_wrapper_gpu.py`, it turns out that changes in `torch/_inductor/ir.py` file are not really needed. So it's my fault, I didn't sync the environments (between several machines) correctly.
@davidberard98 @YUNQIUGUO maybe that's why the tests on CUDA didn't pass?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147639
Approved by: https://github.com/etaf, https://github.com/davidberard98
As title.
Many changes adapted from https://github.com/pytorch/pytorch/pull/129537.
Also this diff is only for *static* method of torchbind *attributes*. Some case that's not supported/tested:
- dynamic torchbind objects
- torchbind objects as an input to the module.
Note that in JIT Inductor, the attributes are lifted as inputs. So even if we just have torchbind objects as attributes, they will show up as inputs in the graph.
Example generated python code in torch.compile with inductor backend for the test case in `inductor/test_torchbind.py` (P1730554370):
```python
async_compile.wait(globals())
del async_compile
def call(args):
arg1_1, arg2_1, arg3_1 = args
args.clear()
assert_size_stride(arg1_1, (2, 3), (3, 1))
assert_size_stride(arg2_1, (2, 3), (3, 1))
buf2 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
cpp_fused_add_0(arg1_1, arg2_1, buf2)
del arg1_1
del arg2_1
# Topologically Sorted Source Nodes: [x, takes_foo_tuple_return], Original ATen: [aten.add]
buf3 = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(arg3_1, buf2)
buf4 = buf3[0]
assert_size_stride(buf4, (2, 3), (3, 1))
buf5 = buf3[1]
assert_size_stride(buf5, (2, 3), (3, 1))
buf6 = buf4; del buf4 # reuse
cpp_fused_add_1(buf6, buf5)
del buf5
# Topologically Sorted Source Nodes: [y, b], Original ATen: [aten.add]
buf7 = torch.ops._TorchScriptTesting.takes_foo.default(arg3_1, buf6)
del buf3
del buf6
buf8 = buf7
assert_size_stride(buf8, (2, 3), (3, 1))
# Topologically Sorted Source Nodes: [c], Original ATen: []
buf9 = torch.ops.higher_order.call_torchbind(arg3_1, 'add_tensor', buf2)
del arg3_1
del buf7
buf10 = buf9
assert_size_stride(buf10, (2, 3), (3, 1))
del buf9
buf11 = buf2; del buf2 # reuse
cpp_fused_add_2(buf11, buf8, buf10)
return (buf11, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg1_1 = rand_strided((2, 3), (3, 1), device='cpu', dtype=torch.float32)
arg2_1 = rand_strided((2, 3), (3, 1), device='cpu', dtype=torch.float32)
import pickle
global arg3_1
arg3_1 = pickle.loads(b'\x80\x04\x95[\x00\x00\x00\x00\x00\x00\x00\x8c\x05torch\x94\x8c\x0cScriptObject\x94\x93\x94)\x81\x94]\x94(K\nK\x14e\x8c0__torch__.torch.classes._TorchScriptTesting._Foo\x94\x86\x94b.')
fn = lambda: call([arg1_1, arg2_1, arg3_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('None', benchmark_compiled_module)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146927
Approved by: https://github.com/angelayi
With the `_scaled_dot_product_efficient_attention.default`, we have lowering logic to realize the bias to specific alignment constraints. Some of the dims can be expanded, and we need to keep the stride of that dim to 0 to avoid materializing a larger tensor than we need. Previously, we had checked stride of tensor, but if it is not realized, that will not work. so we should check the strides of the meta as well.
Note: getting the exact of realizing/slicing/requiring_exact_strides was a little tricky. I commented to @exclamaforte on an example unable-to-fuse message you get if you do it incorrectly.
Fix for https://github.com/pytorch/pytorch/issues/145760
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146054
Approved by: https://github.com/shunting314
We were codegening intermediary dtype asserts in some places but not all. expands assertions, fixes newly failing assertion in
`TORCHINDUCTOR_COMPILE_THREADS=1 TORCH_LOGS="output_code" PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_logcumsumexp_cuda_float16` for scan.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146067
Approved by: https://github.com/shunting314, https://github.com/jansel
Before the PR, we're getting an undefined symbol error for output code when an unbacked symint is **only** used in the hop because we didn't correctly record the dependency of the unbacked symbols for hops and it gets DCEed accidentally.
This PR adds the symbol arguments to `constant_args`, where the dependencies can be correctly constructed when `get_unbacked_symbol_uses` is called to check constant_args.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143456
Approved by: https://github.com/desertfire
Before the PR, we're getting an undefined symbol error for output code when an unbacked symint is **only** used in the hop because we didn't correctly record the dependency of the unbacked symbols for hops and it gets DCEed accidentally.
This PR adds the symbol arguments to `constant_args`, where the dependencies can be correctly constructed when `get_unbacked_symbol_uses` is called to check constant_args.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143456
Approved by: https://github.com/desertfire
Record input fake tensors at time of tracing and store them in the node meta. Inductor passes have the possibility of changing strides, so it is safer to record the strides of the inputs at tracing. See, https://github.com/pytorch/pytorch/issues/137979 for more context.
We can also extend this to custom ops, and user-visible outputs. If this ends up being compilation time sensitive we can just record strides (and maybe storage offset, per @zou3519) instead of the complete fake tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145448
Approved by: https://github.com/zou3519
ghstack dependencies: #145953
Previous impl would take a size hint, which was failing internally with a
```
strides1 = [V.graph.sizevars.size_hint(strides1[i]) for i in non_1_indices]
File "/dev/shm/uid-30083/6f57b5f9-seed-nspid4026541609_cgpid284393-ns-4026541967/torch/_inductor/sizevars.py", line 554, in size_hint
return int(out)
File "/dev/shm/uid-30083/6f57b5f9-seed-nspid4026541609_cgpid284393-ns-4026541967/sympy/core/expr.py", line 307, in __int__
raise TypeError("Cannot convert symbols to int")
```
There are unbacked tests in test_triton which should exercise this, as well as other tests for these functions when they were added.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145953
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Example failing test:
`pytest -s test_torchinductor_opinfo.py -k test_comprehensive_special_polygamma_special_polygamma_n_0_cpu_float32` when using triton CPU.
Failure:
```shell
triton.compiler.errors.CompilationError: at 10:11:
def triton_poi_fused_polygamma_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 25
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = 1.0
tl.static_assert(tmp1.dtype == tl.float32)
tmp2 = ops.polygamma(tmp1, tmp0)
^
NameError('ops is not defined')
```
This occurs because the registered triton fallbacks are not used during the lowering to inductor IR.
Marked the problematic code in the excerpt below from 6bc17b0725/torch/_inductor/lowering.py (L572)
```python
def make_pointwise(
fn,
override_return_dtype=None,
override_device=None,
override_fn_when_input_bool=None,
override_fn_when_gpu_float64=None,
allow_alpha=False,
triton_fallback=None,
):
def inner(*inputs: TensorBox, alpha=None):
if triton_fallback is not None and any(
isinstance(inp, IRNode) and is_triton(inp) for inp in inputs <--- is_triton should return True when using triton CPU
):
assert not allow_alpha # not implemented
return triton_fallback(*inputs)
inputs = promote_constants(inputs, override_return_dtype)
if allow_alpha:
if alpha is not None and alpha != 1:
inputs = list(inputs)
```
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144389
Approved by: https://github.com/jansel
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This PR fixes user-defined triton kernel handling (in most cases) for these new triton commits.
What this PR fixes:
* in triton_kernel_wrap.py, AST->TTIR parsing was to be updated for the new triton API
* ir.py - don't remove None args when using newer triton versions
* wrapper.py - update signature & constant handling
What this doesn't fix:
* correct None handling - I want to do a closer look at constant handling (including None, equal_to_1, and other constants).
* cpp wrapper (which needs to be fixed for both user-defined triton kernels and inductor-generated kernels)
test/inductor/test_triton_kernels.py passed on triton commit 74de6b46, with the exception of three tests (those shown here: 1374074098)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145348
Approved by: https://github.com/jansel
ghstack dependencies: #145051
Found this bug when debugging a MA issue in CI that can not be repro-ed on devgpu.
On GPU with less than 68 SMs (like NVidia L4 used in CI), running torch compile in max-autotune mode may result in the following confusing error https://gist.github.com/shunting314/370f42f547e3367a3773237942725a86 complaining about layout:
```
torch._inductor.exc.InductorError: LoweringException: AssertionError: convert FlexibleLayout to FixedLayout first
```
The reason is, even if we don't pick Triton template, Inductor still returns a MultiTemplateBuffer for tuned addmm. MultiTemplateBuffer.get_reads called from Reduction.num_splits may indexing a FlexibleLayout which results in the error aforementioned.
The issue does not appear on devgpu because we freeze the layout of addmm inputs when rendering triton templates.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145133
Approved by: https://github.com/jansel
**Summary**
Enable the CPP Grouped GEMM Fusion, lowering and Grouped GEMM Template following the RFC: https://github.com/pytorch/pytorch/issues/144012
- Support flexible number of GEMMs
- Share activation across GEMMs
- The Grouped GEMM Template supports independent activations
- However, the pattern matcher requires an anchor node, which is as the shared activation across GEMMs
- Each GEMM can have a unique weight but same sizes
- Each GEMM can have a unique bias or None
- Current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM have its own epilogues
- Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue fusion PR
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```
**Example**
Here is the example and generated code
```
batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
def forward(self, x):
return self.linear0(x), self.linear1(x)
if __name__ == "__main__":
with torch.no_grad():
input = torch.randn(batch_size, in_features, dtype=dtype)
m = M(bias=bias).to(dtype=dtype).eval()
cm = torch.compile(m)
act_res = cm(input)
```
Generated Code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
**Next Step**
- Support Epilogue fusion
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
When calling a fallback op in cpp_wrapper mode, where any of the inputs are complex numbers, utilize the runtime dispatched fallback mode. This properly handles the Conjugate and Negative dispatch keys, if present, in exchange for a performance pessimization in complex arithmetic.
This PR additionally fixes some cascading failure modes exposed in our `aot_inductor` tests by this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143223
Approved by: https://github.com/desertfire
ghstack dependencies: #141371
Additionally, enable torchinductor opinfo tests exercising all
previously fixed bugs in this stack.
Note: I've manually sharded the cpp_wrapper CI checks into 2 shards.
Once all OpInfo tests are enabled we should switch back to automatic
sharding, but until then the pipeline doesn't have appropriate timing
stats. More shards would be helpful given the compilation slowdown
associated with cpp_wrapper, but 2 will do for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141371
Approved by: https://github.com/desertfire
NonOwningLayout is always constructed to a FixedLayout. We should handle it the same way as FixedLayout. Note - this case is very rare, I added an assertion here and no test/model failed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143315
Approved by: https://github.com/zou3519
This PR extends our ability to fuse pointwise nodes onto triton templates with the ability to fuse pointwise nodes into triton templates - prologue fusion.
Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`
And the modification api:
```
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
out="qk"
) | indent_except_first(1) }}
```
We have:
```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```
Because we are now loading the input with explicit indices and mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration and instead on each iteration compute indices from the the k_idx of each loop. This did not have any perf difference.
There are a couple main use cases for prologue fusion:
- Fusing dequants into a matmul. particularly for more bandwidth bound scenarios.
- Fusing gather into a matmul. This is useful particularly in MOE. See https://github.com/pytorch/pytorch/issues/134535 for more details.
Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting to fuse prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read instead the triton template, multipled by a small factor to allow gathers. This restricts reliably unprofitable fusions like fp32->fp16 inside kernel. In future pr we could potentially have api of being more aggressive if we know we are in a bandwidth bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066
Other notes:
By default we will upcast to fp32 inside every kernel. This matches eager numerics. This is fine enough for epilogue because it is only done once (although it is probably unnecessary for say a relu) but tanks perf for prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.
With prologue fusion, we now have essentially separate kernels for each input, and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also update the fusion logic because the inputs will have a different group than the outputs. Maybe as part of enabling multiple outputs, this could get cleaned up a bit so..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel