Commit Graph

672 Commits

Author SHA1 Message Date
cc8f1cddd4 Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972
Approved by: https://github.com/Skylion007
ghstack dependencies: #136934, #136935
2024-10-01 13:22:10 +00:00
3f457ee1f6 Fix AOT Graph capture not propagating non_blocking copy parameter to … (#136513)
…inductor codegen.

Fixes #136260

**Note**: this is my first code contribution to torch so please let me know if there's anything I need to fix/some other convention I should follow.

Regarding the bug, re-running the issue's reproduction code:
```
import torch

def fn(x):
    return x.to(device="cuda", non_blocking=True)

inp = torch.randn(3, 4)

torch.compile(fn)(inp)
```

We now have the non_blocking being passed on to codegen properly:

```
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4]"):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4][4, 1]cpu"):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.404000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:114] [0/0] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0),subclass_metadata=None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] TRACED GRAPH
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  ===== Forward graph 0 =====
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]     def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0), True);  arg0_1 = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32);  device_put = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         return (convert_element_type,)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1134] [0/0] [__output_code] Output code written to: /tmp/torchinductor_niklasz/ha/chaai264g6ribfw3q2qhl6ayjtaqaavku5wivxtzw4nabgd6htsv.py
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] Output code:
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] # AOT ID: ['0_inference']
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import torch
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import math
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import random
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import os
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import tempfile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from math import inf, nan
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch import device, empty_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] aten = torch.ops.aten
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] _quantized = torch.ops._quantized
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile = AsyncCompile()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile.wait(globals())
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del async_compile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def call(args):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1, = args
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     args.clear()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     assert_size_stride(arg0_1, (3, 4), (4, 1))
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         torch.cuda.set_device(0)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0.copy_(arg0_1, True)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         del arg0_1
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return (buf0, )
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     fn = lambda: call([arg0_1])
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] if __name__ == "__main__":
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
```
See above line `buf0.copy_(arg0_1, True)`. Specific log setting used: `export TORCH_LOGS="graph_code,aot_graphs,output_code"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136513
Approved by: https://github.com/eellison
2024-10-01 00:32:47 +00:00
8982906502 Revert "Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)"
This reverts commit 3ff2d93d9f72fd26503ef0cf5c5956edad4c52e6.

Reverted https://github.com/pytorch/pytorch/pull/136972 on behalf of https://github.com/ezyang due to need to back out for merge conflict ([comment](https://github.com/pytorch/pytorch/pull/136972#issuecomment-2384182244))
2024-09-30 21:35:08 +00:00
3ff2d93d9f Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972
Approved by: https://github.com/Skylion007
ghstack dependencies: #136917, #136934, #136935
2024-09-30 18:04:36 +00:00
dfdda2f6a6 inductor: use previous guards to know if a size is 1 for broadcasting (#136670)
Fixes https://github.com/pytorch/pytorch/issues/136640

Today, inductor has some logic to figure out when it needs to do broadcasting during lowering, which just checks if any of the input shapes have sizes equal to 1.

In particular: we should already have this information by the time we get to inductor, because our FakeTensor compute will have branched/guarded on whether any ops performed broadcasting, appropriately.

In particular, if we have a tensor with a size value of `(64//((2048//(s3*((s2//s3)))))))`, and it happens to be equal to one (and it is used in an op that requires this dim to be broadcasted), FakeTensorProp will have generated a guard:
```
Eq((64//((2048//(s3*((s2//s3))))))), 1)
```

I chose the simplest possible way to beef up inductor's checks to know when a given size is equal to 1: loop over the existing shape env guards, and if our current size is a sympy expression on the LHS of one of our `Eq(LHS, 1)` guards, then return True.

I'm hoping for feedback on whether or not this approach is reasonable. One better option I could imagine is that our symbolic reasoning should have automatically simplified the size of our tensor down to a constant as part of evaluating that guard. I was originally going to try to do this directly in the shape env, but I ran into a few issues:

(1) I wanted to call some version of `set_replacement(expr, 1)`. But `set_replacement()` only accepts plain symbols on the LHS, not expressions

(2) in theory I could get this to work if I could rework the above expression to move everything that is not a free variable to the RHS, e.g. `Eq(s2, 32)`. It looks like our existing  `try_solve()` logic is... [not quite able](https://github.com/pytorch/pytorch/blob/main/torch/utils/_sympy/solve.py#L27) to do this generally though.

Checking the guards feels pretty simple-and-easy. Are we worried that it is too slow to iterate over all the guards? I could also cache the lookup so we only need to iterate over guards that are of the form `Eq(LHS, 1)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136670
Approved by: https://github.com/ezyang
2024-09-30 13:24:57 +00:00
06909803cc Existing mypy issues (#136236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136236
Approved by: https://github.com/bobrenjc93, https://github.com/Skylion007
2024-09-24 01:02:07 +00:00
31c4e0d37d [inductor] Cleanup analysis done at lowering time (#135412)
Before this we would take multiple passes over the body of each IRNode as we did lowering.  This combines most analysis into `OpCounterCSE` so it can be done in a single pass.

Before:
![image](https://github.com/user-attachments/assets/0047db09-4258-4491-a9a6-b078e183092a)

After:
![image](https://github.com/user-attachments/assets/1e03adcb-8303-4bb1-8bbb-cc42dacd44d7)

This stack:
![image](https://github.com/user-attachments/assets/d6b50b24-c30c-4d23-8b1a-344b3ba65d7a)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135412
Approved by: https://github.com/oulgen
ghstack dependencies: #135286, #135306, #135377, #135400
2024-09-08 18:02:36 +00:00
16f5155992 [inductor] Fast path for extract_read_writes without tracing (#135377)
Before (bottom of stack):
![image](https://github.com/user-attachments/assets/13060ff9-b31d-42a9-8e8f-c50b2bf3dc2f)

After (this PR):
![image](https://github.com/user-attachments/assets/7d190821-b614-46b7-9e9e-9087443df654)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135377
Approved by: https://github.com/oulgen
ghstack dependencies: #135286, #135306
2024-09-08 18:02:32 +00:00
3bdc54ed18 [inductor] Refactor LoopBody.memory_usage (#135286)
This is preparing for some other changes where I speed up extract_read_writes tracing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135286
Approved by: https://github.com/oulgen
2024-09-08 18:02:24 +00:00
eac5e12548 [inductor] Move LoopBody to its own file (#135257)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135257
Approved by: https://github.com/oulgen
2024-09-07 16:29:15 +00:00
7ffb3b201c [inductor] Remove LoopBody.reads,writes,other (#135256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135256
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084, #135079, #135235
2024-09-06 06:11:55 +00:00
f946bf88c4 [inductor] Skip retracing an existing LoopBody (#135235)
This is roughly a 7% speedup in inductor compile time for hf_Bert_large.  The time spent in `LoopBody.__init__` improves from 15% to 8% of `fx_codegen_and_compile`.

Before
![image](https://github.com/user-attachments/assets/7de0f28e-35bd-472f-b4be-b52733d2a85c)

After
![image](https://github.com/user-attachments/assets/5f0cf11a-43c5-43ae-b13c-f32383a75a7f)

Overall
![image](https://github.com/user-attachments/assets/6a369d8c-fb5e-4ad2-9504-0fc745ad6568)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135235
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084, #135079
2024-09-06 06:11:55 +00:00
4262755b5a [cond] fix typo in cond codegen (#134708)
As titled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134708
Approved by: https://github.com/jansel
2024-09-05 22:38:24 +00:00
50d1e37079 [AOTI] Fix a unbacked symint retrieve bug (#134670)
Summary: Fix https://github.com/pytorch/pytorch/issues/134081. When a unbacked symint is computed as the shape of a tensor from a tuple, generated C++ code needs to use std::get<> to extract the tensor.

Differential Revision: [D62142113](https://our.internmc.facebook.com/intern/diff/D62142113)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134670
Approved by: https://github.com/angelayi, https://github.com/22quinn, https://github.com/chenyang78
2024-09-05 11:34:14 +00:00
13a4a0c60d [Inductor] Apply loop split optimization in codegen_node (#132389)
This PR applies loop split optimization in codegen_node to avoid non-contiguous load. When the vector is loaded in a non-contiguous manner due to a division in the index, we eliminate the division by splitting the loop to avoid non-contiguous load.

Example:
```
import torch
import torch.nn as nn

class GNReLU(torch.nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GNReLU, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return torch.nn.functional.relu(self.gn(x))

input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
m = GNReLU(32, 960).eval()
compiled_m = torch.compile(m)

with torch.no_grad():
    compiled_m(input)
```

Generated code:

- Before:
```
cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<long>(17280L));
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(9216L); x2+=static_cast<long>(1L))
                        {
                            for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16);
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14);
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(9216L); x1+=static_cast<long>(1L))
                {
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(960L); x2+=static_cast<long>(16L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)), 16);
                        auto tmp1 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr0[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp3 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr1[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x2), 16);
                        auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x2), 16);
                        auto tmp2 = tmp0 - tmp1;
                        auto tmp4 = static_cast<float>(276480.0);
                        auto tmp5 = at::vec::Vectorized<float>(tmp4);
                        auto tmp6 = tmp3 / tmp5;
                        auto tmp7 = static_cast<float>(1e-05);
                        auto tmp8 = at::vec::Vectorized<float>(tmp7);
                        auto tmp9 = tmp6 + tmp8;
                        auto tmp10 = tmp9.rsqrt();
                        auto tmp11 = tmp2 * tmp10;
                        auto tmp13 = tmp11 * tmp12;
                        auto tmp15 = tmp13 + tmp14;
                        auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0));
                        tmp16.store(out_ptr2 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)));
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg2_1, = args
    args.clear()
    assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960))
    buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32)
    cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3)
    del arg2_1
    return (buf3, )
```

- After:
```
cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<long>(17280L));
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(9216L); x2+=static_cast<long>(1L))
                        {
                            for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16);
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14);
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(9216L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(32L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 16);
                            auto tmp1 = out_ptr0[static_cast<long>(x2 + (32L*x0))];
                            auto tmp4 = out_ptr1[static_cast<long>(x2 + (32L*x0))];
                            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x3 + (30L*x2)), 16);
                            auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x3 + (30L*x2)), 16);
                            auto tmp2 = at::vec::Vectorized<float>(tmp1);
                            auto tmp3 = tmp0 - tmp2;
                            auto tmp5 = static_cast<float>(276480.0);
                            auto tmp6 = tmp4 / tmp5;
                            auto tmp7 = static_cast<float>(1e-05);
                            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                            auto tmp9 = 1 / std::sqrt(tmp8);
                            auto tmp10 = at::vec::Vectorized<float>(tmp9);
                            auto tmp11 = tmp3 * tmp10;
                            auto tmp13 = tmp11 * tmp12;
                            auto tmp15 = tmp13 + tmp14;
                            auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0));
                            tmp16.store(out_ptr2 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)));
                        }
                        for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(14L))
                        {
                            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14);
                            auto tmp1 = out_ptr0[static_cast<long>(x2 + (32L*x0))];
                            auto tmp4 = out_ptr1[static_cast<long>(x2 + (32L*x0))];
                            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x3 + (30L*x2)), 14);
                            auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x3 + (30L*x2)), 14);
                            auto tmp2 = at::vec::Vectorized<float>(tmp1);
                            auto tmp3 = tmp0 - tmp2;
                            auto tmp5 = static_cast<float>(276480.0);
                            auto tmp6 = tmp4 / tmp5;
                            auto tmp7 = static_cast<float>(1e-05);
                            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                            auto tmp9 = 1 / std::sqrt(tmp8);
                            auto tmp10 = at::vec::Vectorized<float>(tmp9);
                            auto tmp11 = tmp3 * tmp10;
                            auto tmp13 = tmp11 * tmp12;
                            auto tmp15 = tmp13 + tmp14;
                            auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0));
                            tmp16.store(out_ptr2 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14);
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg2_1, = args
    args.clear()
    assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960))
    buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32)
    cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3)
    del arg2_1
    return (buf3, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132389
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
2024-09-04 22:42:46 +00:00
f927bcb934 Revert "[Inductor] Apply loop split optimization in codegen_node (#132389)"
This reverts commit 3cb5d251224b3fb59b5a10c6fefbb4c84eb565a6.

Reverted https://github.com/pytorch/pytorch/pull/132389 on behalf of https://github.com/ZainRizvi due to Hi, this seems to be breaking in trunk. See test_dataloader.py::TestDataLoader::test_segfault [GH job link](https://github.com/pytorch/pytorch/actions/runs/10660461216/job/29556282081) [HUD commit link](de3a641476) ([comment](https://github.com/pytorch/pytorch/pull/132389#issuecomment-2326843129))
2024-09-03 15:40:45 +00:00
3daca187aa [Inductor] Allow customizing the padding format (#133939)
Based on https://github.com/pytorch/pytorch/pull/130956.

Inductor already supports padding through the `config.comprehensive_padding` option, but the padding format involves a few heuristics that are specific to Nvidia GPUs:
  - When we pad, it is always aligned to the next multiple of 128 bytes.
  - Strides smaller than 1024 are not padded.
  - Only intermediate values are padded, not outputs.

 The last of these is not really GPU-specific, but there are certain cases where we may want to override it. For example, padding outputs is useful on hardware accelerators with specific memory alignment requirements, or for applications where performance is more important than conformity with eager mode.

 This PR surfaces padding parameters up to Inductor's config module, so the user can control them.
   - `config.pad_outputs`: choose whether to pad outputs (default: `False`)
   - `config.padding_alignment_bytes`: choose the alignment size for padding (default: `128`)
   - `config.padding_stride_threshold`:  choose the smallest stride that we will pad. For example, setting this to 0 will pad all unaligned strides. (default: `1024`)

 **Test plan**
 Added a new test in `test_padding.py` which tries various combinations of these options, checking that the output strides match our expectations.

  These changes should not affect perf, because the defaults are identical to Inductor's current behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133939
Approved by: https://github.com/shunting314

Co-authored-by: Yueming Hao <yhao@meta.com>
2024-09-02 05:56:33 +00:00
3cb5d25122 [Inductor] Apply loop split optimization in codegen_node (#132389)
This PR applies loop split optimization in codegen_node to avoid non-contiguous load. When the vector is loaded in a non-contiguous manner due to a division in the index, we eliminate the division by splitting the loop to avoid non-contiguous load.

Example:
```
import torch
import torch.nn as nn

class GNReLU(torch.nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GNReLU, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return torch.nn.functional.relu(self.gn(x))

input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
m = GNReLU(32, 960).eval()
compiled_m = torch.compile(m)

with torch.no_grad():
    compiled_m(input)
```

Generated code:

- Before:
```
cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<long>(17280L));
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(9216L); x2+=static_cast<long>(1L))
                        {
                            for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16);
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14);
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(9216L); x1+=static_cast<long>(1L))
                {
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(960L); x2+=static_cast<long>(16L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)), 16);
                        auto tmp1 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr0[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp3 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr1[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x2), 16);
                        auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x2), 16);
                        auto tmp2 = tmp0 - tmp1;
                        auto tmp4 = static_cast<float>(276480.0);
                        auto tmp5 = at::vec::Vectorized<float>(tmp4);
                        auto tmp6 = tmp3 / tmp5;
                        auto tmp7 = static_cast<float>(1e-05);
                        auto tmp8 = at::vec::Vectorized<float>(tmp7);
                        auto tmp9 = tmp6 + tmp8;
                        auto tmp10 = tmp9.rsqrt();
                        auto tmp11 = tmp2 * tmp10;
                        auto tmp13 = tmp11 * tmp12;
                        auto tmp15 = tmp13 + tmp14;
                        auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0));
                        tmp16.store(out_ptr2 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)));
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg2_1, = args
    args.clear()
    assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960))
    buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32)
    cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3)
    del arg2_1
    return (buf3, )
```

- After:
```
cpp_fused_native_group_norm_relu_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/vu/cvuckxaygqfovv2zu2byqhcmiejbke7mdhf2rpgpr5mlscdev2hg.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<long>(17280L));
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(9216L); x2+=static_cast<long>(1L))
                        {
                            for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16);
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14);
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(9216L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(32L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 16);
                            auto tmp1 = out_ptr0[static_cast<long>(x2 + (32L*x0))];
                            auto tmp4 = out_ptr1[static_cast<long>(x2 + (32L*x0))];
                            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x3 + (30L*x2)), 16);
                            auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x3 + (30L*x2)), 16);
                            auto tmp2 = at::vec::Vectorized<float>(tmp1);
                            auto tmp3 = tmp0 - tmp2;
                            auto tmp5 = static_cast<float>(276480.0);
                            auto tmp6 = tmp4 / tmp5;
                            auto tmp7 = static_cast<float>(1e-05);
                            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                            auto tmp9 = 1 / std::sqrt(tmp8);
                            auto tmp10 = at::vec::Vectorized<float>(tmp9);
                            auto tmp11 = tmp3 * tmp10;
                            auto tmp13 = tmp11 * tmp12;
                            auto tmp15 = tmp13 + tmp14;
                            auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0));
                            tmp16.store(out_ptr2 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)));
                        }
                        for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(14L))
                        {
                            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14);
                            auto tmp1 = out_ptr0[static_cast<long>(x2 + (32L*x0))];
                            auto tmp4 = out_ptr1[static_cast<long>(x2 + (32L*x0))];
                            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x3 + (30L*x2)), 14);
                            auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x3 + (30L*x2)), 14);
                            auto tmp2 = at::vec::Vectorized<float>(tmp1);
                            auto tmp3 = tmp0 - tmp2;
                            auto tmp5 = static_cast<float>(276480.0);
                            auto tmp6 = tmp4 / tmp5;
                            auto tmp7 = static_cast<float>(1e-05);
                            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                            auto tmp9 = 1 / std::sqrt(tmp8);
                            auto tmp10 = at::vec::Vectorized<float>(tmp9);
                            auto tmp11 = tmp3 * tmp10;
                            auto tmp13 = tmp11 * tmp12;
                            auto tmp15 = tmp13 + tmp14;
                            auto tmp16 = at::vec::clamp_min(tmp15, decltype(tmp15)(0));
                            tmp16.store(out_ptr2 + static_cast<long>(x3 + (30L*x2) + (960L*x1) + (8847360L*x0)), 14);
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg2_1, = args
    args.clear()
    assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960))
    buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32)
    cpp_fused_native_group_norm_relu_0(arg2_1, _frozen_param3, _frozen_param2, buf0, buf1, buf3)
    del arg2_1
    return (buf3, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132389
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
2024-09-02 00:28:34 +00:00
86e03a64e1 Revert "[Inductor] Allow customizing the padding format (#133939)"
This reverts commit 8b258b3b14408986a1d4142cff5a153c798ceecc.

Reverted https://github.com/pytorch/pytorch/pull/133939 on behalf of https://github.com/ZainRizvi due to sorry but this PR is causing issues with diff train imports reverting it for now but it can be merged back in as-is ([comment](https://github.com/pytorch/pytorch/pull/133939#issuecomment-2322635388))
2024-08-31 00:38:30 +00:00
8b258b3b14 [Inductor] Allow customizing the padding format (#133939)
Based on https://github.com/pytorch/pytorch/pull/130956.

Inductor already supports padding through the `config.comprehensive_padding` option, but the padding format involves a few heuristics that are specific to Nvidia GPUs:
  - When we pad, it is always aligned to the next multiple of 128 bytes.
  - Strides smaller than 1024 are not padded.
  - Only intermediate values are padded, not outputs.

 The last of these is not really GPU-specific, but there are certain cases where we may want to override it. For example, padding outputs is useful on hardware accelerators with specific memory alignment requirements, or for applications where performance is more important than conformity with eager mode.

 This PR surfaces padding parameters up to Inductor's config module, so the user can control them.
   - `config.pad_outputs`: choose whether to pad outputs (default: `False`)
   - `config.padding_alignment_bytes`: choose the alignment size for padding (default: `128`)
   - `config.padding_stride_threshold`:  choose the smallest stride that we will pad. For example, setting this to 0 will pad all unaligned strides. (default: `1024`)

 **Test plan**
 Added a new test in `test_padding.py` which tries various combinations of these options, checking that the output strides match our expectations.

  These changes should not affect perf, because the defaults are identical to Inductor's current behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133939
Approved by: https://github.com/shunting314

Co-authored-by: Yueming Hao <yhao@meta.com>
2024-08-30 20:34:11 +00:00
1e92d7b688 [inductor] move loop ordering after fusion (#126254)
Restart the work from PR https://github.com/pytorch/pytorch/pull/100331 in this new PR since it's hard to rebase. It would be expected that some code is copy/pasted from the previous PR and main idea is the same.

Previously we see relatively large compilation time increase due to too many loop orders being considered. This PR tries to continue the work by doing pruning and only considering loop orders that we know for sure are relevant (i.e. do it on demand).

Some manually created cases that loop ordering matters are added as unit tests. The PR can make sure inductor does not miss fusion opportunities for them.

This PR should solve the not-able to fusion problem in https://github.com/pytorch/pytorch/issues/130015

Right now there is still significant increase of compilation time. I'll disable the feature by default. Later on after the compilation time issue is resolved, I'll enable it  by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126254
Approved by: https://github.com/jansel
2024-08-29 21:50:07 +00:00
4f9c68454a [inductor]Let output or input_as_strided match exact strides (#130956)
Fixes #130394

TorchInductor doesn't respect original strides of outputs. It opens up optimization opportunities like changing up memory layout. But for some cases, such as the case in https://github.com/pytorch/pytorch/issues/130394, we do need the output match the exact stride as required. The correctness is the first priority goal. So, this PR adds a new API `ir.ExternKernel.require_exact_strides(x, exact_strides, allow_padding=False)` to fix the issue.  This PR enables dense and non-dense outputs' strides follow the strides required by semantics.

The comparison between the original and after this fix for the test is the below.

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8
    x1 = (xindex // 8)
-   x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (16*x1)), xmask)
    tmp1 = tmp0 + tmp0
-   tl.store(out_ptr0 + (x2), tmp1, xmask)
+   tl.store(out_ptr0 + (x0 + (16*x1)), tmp1, xmask)

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (16, 8), (16, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
-       buf1 = empty_strided_cuda((16, 8), (8, 1), torch.float32)
+       buf1 = empty_strided_cuda((16, 8), (16, 1), torch.float32)
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_copy_0.run(arg0_1, buf1, 128, grid=grid(128), stream=stream0)
        del arg0_1
    return (buf1, )
```

The buf1 is created with exact stride required by users, and its values are written in same stride with the input.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130956
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/desertfire
2024-08-29 03:06:58 +00:00
13049cd6e5 [aotinductor][UserDefinedTritonKernel] fix case with non-constexpr params declared after autotuned params (#134520)
## Context
In some user Triton kernels, we have this set-up for whatever reason.
```
@triton.jit
def mykernel(
  param0,
  param1,
  param2,
  param3: tl.constexpr,   # autotuned
  param4,                 # non-constexpr
):
  ...
```

This is an edge case because it's a general practice to declare all constexprs params at the end.

And this will be an issue for AOTI because it fails to codegen all 4 params. That will surface as a device-side error: CUDA IMA, invalid argument...

```
>     void* kernel_args_var_0[] = {&var_0, &var_1, &var_2};
---
<     CUdeviceptr var_3;
<     AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(buf0, reinterpret_cast<void**>(&var_3)));
<     void* kernel_args_var_0[] = {&var_0, &var_1, &var_2, &var_3};
```

## Root-cause
* `kernel.constexpr` from the Kernel side-table contains the indices for all `constexpr` params that includes autotuned params.
* `raw_args`, that gets passed to wrapper codegen, excludes autotuned args.
* In the wrapper codegen, we try to find non-constexpr args using `kernel.constexpr` & `raw_args`. This is okay unless there's a `raw_arg` after an autotuned param in the function signature.

79b7fff188/torch/_inductor/codegen/cpp_wrapper_cuda.py (L118-L126)

## Fix
We try to fix this, by calculating the right constexprs wrt `raw_args`.

An illustration
```
         raw_args: [arg0, arg1, arg2, arg4]
 kernel.arg_names: [param0, param1, param2, param3, param4]
kernel.constexprs: [3]                      # param3 is autotuned; this is correct wrt kernel.arg_names
constexpr_indices: []                       # this is correct wrt raw_args
```

Differential Revision: [D61831625](https://our.internmc.facebook.com/intern/diff/D61831625)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134520
Approved by: https://github.com/oulgen
2024-08-27 17:20:27 +00:00
3322ee236d [aoti] remove c_shim_version v1 logic (#134283)
Summary: Previously, https://github.com/pytorch/pytorch/pull/132750 and https://github.com/pytorch/pytorch/pull/133105 set c_shim_version to 2 for all cases. So removing c_shim_version logic.

Test Plan: ci

Differential Revision: D61574695

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134283
Approved by: https://github.com/desertfire
2024-08-26 18:29:40 +00:00
17e8a51ff2 Revert "[inductor]Let output or input_as_strided match exact strides (#130956)"
This reverts commit a63efee5cd422db0aabe5d02d2fe35fef9be7978.

Reverted https://github.com/pytorch/pytorch/pull/130956 on behalf of https://github.com/ZainRizvi due to sorry but this seems to cause internal tests to fail. Please see D61771533 for details ([comment](https://github.com/pytorch/pytorch/pull/130956#issuecomment-2310490049))
2024-08-26 15:31:23 +00:00
a63efee5cd [inductor]Let output or input_as_strided match exact strides (#130956)
Fixes #130394

TorchInductor doesn't respect original strides of outputs. It opens up optimization opportunities like changing up memory layout. But for some cases, such as the case in https://github.com/pytorch/pytorch/issues/130394, we do need the output match the exact stride as required. The correctness is the first priority goal. So, this PR adds a new API `ir.ExternKernel.require_exact_strides(x, exact_strides, allow_padding=False)` to fix the issue.  This PR enables non-dense outputs' strides follow the strides required by semantics.

The comparison between the original and after this fix for the test is the below.

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8
    x1 = (xindex // 8)
-   x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (16*x1)), xmask)
    tmp1 = tmp0 + tmp0
-   tl.store(out_ptr0 + (x2), tmp1, xmask)
+   tl.store(out_ptr0 + (x0 + (16*x1)), tmp1, xmask)

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (16, 8), (16, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
-       buf1 = empty_strided_cuda((16, 8), (8, 1), torch.float32)
+       buf1 = empty_strided_cuda((16, 8), (16, 1), torch.float32)
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_copy_0.run(arg0_1, buf1, 128, grid=grid(128), stream=stream0)
        del arg0_1
    return (buf1, )
```

The buf1 is created with exact stride required by users, and its values are written in same stride with the input.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130956
Approved by: https://github.com/eellison, https://github.com/blaine-rister
2024-08-24 17:04:05 +00:00
afd081c9d4 [inductor] Fix needs_fixed_stride_order silent incorrectness (#133639)
Fixes #128084

The approach is option 2 of what Elias suggested in the comment
thread:
- We require tensors to have the correct stride at usage. This may
  involve a clone; if there was a clone and then a mutation into it
  then we copy_ back the result of the mutation.

The reason why I went this approach was because it was the easiest and
Inductor already works really hard to remove additional clones/copy_.

There are some cases that this doesn't generate efficient code for; for
example, if the tensor is a view, we don't change the base of the view
to have the right stride order, instead we do a clone.
The view case isn't very common so I'm ignoring it for now but we could
improve this in the future.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133639
Approved by: https://github.com/eellison
2024-08-23 17:07:58 +00:00
09127b096c Revert "[inductor] Fix needs_fixed_stride_order silent incorrectness (#133639)"
This reverts commit 8604c0a150b12e0ba3f9a6faaf52498370f21368.

Reverted https://github.com/pytorch/pytorch/pull/133639 on behalf of https://github.com/jeanschmidt due to Broke internal fbgemm signals, see [D61670495](https://www.internalfb.com/diff/D61670495) ([comment](https://github.com/pytorch/pytorch/pull/133639#issuecomment-2307133060))
2024-08-23 13:48:04 +00:00
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
8604c0a150 [inductor] Fix needs_fixed_stride_order silent incorrectness (#133639)
Fixes #128084

The approach is option 2 of what Elias suggested in the comment
thread:
- We require tensors to have the correct stride at usage. This may
  involve a clone; if there was a clone and then a mutation into it
  then we copy_ back the result of the mutation.

The reason why I went this approach was because it was the easiest and
Inductor already works really hard to remove additional clones/copy_.

There are some cases that this doesn't generate efficient code for; for
example, if the tensor is a view, we don't change the base of the view
to have the right stride order, instead we do a clone.
The view case isn't very common so I'm ignoring it for now but we could
improve this in the future.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133639
Approved by: https://github.com/eellison
2024-08-21 22:54:16 +00:00
21906ddaba [AOTI] Fix complex64 not defined (#132810)
Partially fixes #122980

- change cpp type mapping for complex64 to std::complex<float>
- add `aoti_torch_item_complex64` and `aoti_torch_scalar_to_tensor_complex64`.
- add `expensiveCopyToTensor()` to convert `ArrayRefTensor<T>` type to `AtenTensorHandle` type.

- if we want to fully fix #122980, we still need to let ArrayRef and MiniArrayRef to consider underlying storage number of elements. See more details in https://github.com/pytorch/pytorch/pull/132347 (#132347  broke some internal tests, so we need more work before landing it).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132810
Approved by: https://github.com/desertfire
2024-08-08 18:08:23 +00:00
fd874b799f [AOTI][refactor] Update MKLDNN ops cpp wrapper support (#132367)
Summary: Set op_overload for MKLDNN ops so that cpp_kernel_name and python_kernel_name are constructed from there. This is an important step towards support those MKLDNN ops in the ABI-compatible mode, because we will need to read schema from op_overload for generating correct fallback op call in C++.

Differential Revision: [D60909798](https://our.internmc.facebook.com/intern/diff/D60909798)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132367
Approved by: https://github.com/leslie-fang-intel, https://github.com/angelayi
2024-08-08 03:02:29 +00:00
5cb05a82b4 [BC breaking] move benchmarking + prefer inductor path (#132827)
move benchmarking out of `torch._inductor.runtime.runtime_utils` and into `torch._inductor.runtime.benchmarking`, and prefer this path over directly accessing Triton's benchmarking

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132827
Approved by: https://github.com/eellison
2024-08-08 00:47:45 +00:00
e0514a5b99 [AOTI][refactor] Consolidate how python_kernel_name is set (#132320)
Summary: Similar to the refactoring of set_cpp_kernel, consolidate the ways of setting python_kernel_name

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132320
Approved by: https://github.com/angelayi, https://github.com/chenyang78
ghstack dependencies: #132319
2024-08-02 18:34:25 +00:00
a9e1133faa [AOTI][refactor] Move set_cpp_kernel to base class (#132319)
Summary: Consolidate how cpp_kernel_name is set and make it a method in the base ExternKernel class.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132319
Approved by: https://github.com/angelayi, https://github.com/chenyang78
2024-08-02 18:34:24 +00:00
97eba8e174 [AOTI] Fix a typo in ExternKernel.codegen_const_args (#132191)
Differential Revision: [D60513923](https://our.internmc.facebook.com/intern/diff/D60513923)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132191
Approved by: https://github.com/chenyang78
2024-08-01 17:46:25 +00:00
f32ab3b9e3 Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)
Python's set is non deterministic. There is an internal failure which we recently ran into which did not consistently fail.

See, repro here: P1453035092.

Now, with these changes, it does consistently fail. In follow ups we could also consider adding a lintrule for uses of either set() or set literals.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
2024-08-01 04:37:15 +00:00
f9e4d05c15 Save and run post compilation steps within FXGraphCache (#130572)
This PR mostly refactors by putting code into utils files so that they can be shared between codecache.py and compile_fx.py. Afterwards, it then changes compile_fx so that:
- When saving to FXGraphCache, we save onto the CompiledFXGraph all the necessary metadata for running post compile steps (realigning inputs, cudagraphification).
- When loading from FXGraphCache, we use the saved information directly, instead of calculating them from scratch.

What this does is make it so that `FXGraphCache.load()` is a perfect cache on compile_fx_inner, in that it **returns exactly what compile_fx_inner returns**. This also makes it possible for AOTAutogradCache, given a key to the fx graph cache and example inputs, to get back the full return value of compile_fx_inner.

## What's a post compile step?
We define a **post-compile** to be the set of actions that need to run after FXGraphCache either loads from the cache or misses and runs compilation. These steps include:
- Setting the tracing context's output strides
- Running cudagraphs if enabled
- Maybe realign inputs if cudagraphs didn't run

To run these steps, we save all the necessary metadata in CompiledFxGraph, and use them on a cache hit to reconstruct the object.

## Splitting cudagraphs work into pre/post compile
Cudagraphs does a lot of work on the input graph module to determine if cudagraphs can be enabled. This is the code that involves cudagraph_tests and stack traces. This will work in a world where we have access to the input graph module, but with AOTAutograd warm start, we won't have access to that information anymore. Therefore we can split cudagraphs work into two parts: on a cache miss (and therefore a full compile), we do the cudagraphs testing work, and save cudagraph_fail_reasons into the cache. Then on a cache hit, we know whether or not we can run cudagraphs, and if we can't, we can emit the correct error messages.

Implementation notes:
- We save `fx_kwargs` directly onto the CompiledFXGraph. `fx_kwargs` is already, by definition, part of the cache key, so this is safe to do when it comes to cache correctness.
- ^ Why do we do above even though FXGraphCache.load takes fx_kwargs as an argument? Because AOTAutogradCache **doesn't** have access to fx_kwargs: they're annoyingly encoded in the functools.partial() of the fw_compiler, so *only* inductor knows about these options. They're fully captured by the AOTAutogradCache key (since every key to fx_kwargs is either a global config, or a field that's deterministic based on an input graph module), but their values are still needed to run cudagraphs/postprocessing. Therefore, it's easier/safer to store it on the cached result.
- Willing to hear other approaches here if we think saving these extra fields is not reasonable, though I can't think of another way to do this that's less complicated to explain.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130572
Approved by: https://github.com/eellison
2024-07-31 18:32:40 +00:00
6214b5388b typing ir.py - part 1 (#131845)
See #131852

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131845
Approved by: https://github.com/Skylion007, https://github.com/eellison
2024-07-31 17:37:14 +00:00
784a6ec5a3 Revert "Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)"
This reverts commit 13d744464f10e35c0de50feb4e2340d4dae8e05f.

Reverted https://github.com/pytorch/pytorch/pull/130004 on behalf of https://github.com/clee2000 due to broke lint [GH job link](https://github.com/pytorch/pytorch/actions/runs/10183945999/job/28170099930) [HUD commit link](13d744464f) probably a landrace, the base is 21 hours old ([comment](https://github.com/pytorch/pytorch/pull/130004#issuecomment-2260946562))
2024-07-31 16:49:21 +00:00
13d744464f Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)
Python's set is non deterministic. There is an internal failure which we recently ran into which did not consistently fail.

See, repro here: P1453035092.

Now, with these changes, it does consistently fail. In follow ups we could also consider adding a lintrule for uses of either set() or set literals.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
2024-07-31 16:22:11 +00:00
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde91ee20deb5ddcd9ff4593cd78d74c64.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
9e06572704 [Traceable FSDP2][Inductor] Create grouped nodes for FSDP2 all-gather code block and reduce-scatter code block (after Buffer/Operation split) (#131510)
This PR creates these `GroupedSchedulerNode`s:
- One for each all-gather code block (cast + copy-in + all-gather)
- One for each all-gather-wait code block (all-gather-wait + copy-out)
- One for each reduce-scatter code block (copy-in + reduce-scatter)
- One for each reduce-scatter-wait code block (reduce-scatter-wait)

This serves two goals:
- Prevent outside ops from being fused into these op groups, in order to have more predicable memory usage.
- Make it easier to specify the dependency e.g. from `i+1` all-gather group node to the `i` all-gather-wait group node, to enforce FSDP2 comm ordering (i.e. "serialization of comms").

The actual "reorder-for-FSDP-compute-comm-overlap" PR will come next.

Test commands:
- `pytest -rA  test/distributed/test_compute_comm_reordering.py::TestComputeCommReorderingMultiProc`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131510
Approved by: https://github.com/yifuwang
2024-07-27 08:39:58 +00:00
16cd1aaa1d [inductor] Improve sort kernel perf (#131719)
Closes #129507

This makes two changes to the sort kernel:
1. Use int16 for the indices since we only operate on small dims anyway
2. Instead of passing an explicit mask, we pass the rnumel and imply the
   mask from that which saves an additional reduction in the sort
   kernel's inner loop.

In my benchmarks, this gives enough of a perf improvement to bump up the
max rblock to 512.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131719
Approved by: https://github.com/eellison
2024-07-26 21:56:47 +00:00
1a2edf6dca [AOTI] Fix _mm_plus_mm codegen (#131689)
Summary: Fixes https://github.com/pytorch/pytorch/issues/128474

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131689
Approved by: https://github.com/chenyang78
2024-07-26 16:50:12 +00:00
d57de73fe0 AutoHeuristic: Add support for kernel choice selection (#131610)
This PR enables AutoHeuristic for kernel choice selection, where the feedback can not immediately be provided when AutoHeuristic is called, but only after autotuning has happened. The steps are the following:

When the AutoHeuristic constructor is called, AutoHeuristic registers a function in select_algorithm.py.
After autotuning in select_algorithm.py has happened, and there is an entry in autoheuristic_registry, select_algorithm provides the autotuning results to AutoHeuristic, which stores the results.
I enabled AutoHeuristic for mixed_mm to have an example to test it on. We probably want to add more context, and also add an augment_context function. I will add support for this in another PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131610
Approved by: https://github.com/eellison
2024-07-26 16:35:55 +00:00
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
3ce6f61416 [AOTI] Support fallback ops not in inductor_fallback_ops (#131247)
Summary: For aten ops that are not listed in inductor_fallback_ops, AOTI will use proxy executor to execute them instead of erroring out as missing C shim implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131247
Approved by: https://github.com/angelayi
2024-07-24 19:16:43 +00:00
89d5391bbf [inductor] Kill mark_node_as_mutating (#130834)
Resubmit of #129346

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130834
Approved by: https://github.com/lezcano
ghstack dependencies: #130832, #130833
2024-07-24 11:11:19 +00:00
6415c45da5 [inductor] Use multiple outputs for flex-attention (#130833)
Resubmit of #129344

This fixes the DCE issue for attention output

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130833
Approved by: https://github.com/lezcano
ghstack dependencies: #130832
2024-07-24 11:11:19 +00:00