TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, leaving the fwd/backward rng states out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph-safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward, the rng op is run with `graphsafe_run_with_rng_state`, which takes in an RNG state and hooks it onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value, and we ensure that for any given run of the forward, the corresponding backward run sees the same rng state for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backwards in a different order than they called the forwards. If a user has states fwd_rng_state0 and bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then, naively, when bwd1 is invoked, the bwd rng states would not equal the states that were observed in fwd1. I added handling for this in the aot runtime wrappers: they detect pending backward invocations and the current position of the bwd rng states, and update them when necessary.
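A minimal sketch of the bookkeeping this implies, assuming one generator per rng op (illustrative only, not the actual aot runtime wrapper code): each forward snapshots the generator states it used, and the matching backward restores that snapshot before running.
``` Python
import torch

class RngStateSync:
    """Illustrative only: keep bwd rng states in sync with the forward that produced them."""

    def __init__(self, generators):
        self.generators = generators  # one generator per rng op (assumption)
        self.pending = {}             # forward run id -> snapshot of generator states

    def on_forward(self, run_id):
        # Snapshot the states this forward will run with, so the backward can replay them.
        self.pending[run_id] = [g.get_state() for g in self.generators]

    def on_backward(self, run_id):
        # Rewind each generator to what the corresponding forward observed.
        for g, state in zip(self.generators, self.pending.pop(run_id)):
            g.set_state(state)

# Out-of-order backwards (fwd0, fwd1, bwd1, bwd0) still see matching states:
gens = [torch.Generator() for _ in range(2)]
sync = RngStateSync(gens)
sync.on_forward(0)
sync.on_forward(1)
sync.on_backward(1)
sync.on_backward(0)
```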
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused a single rng state across ops, the forward and backward would run the ops with different rng states, i.e., the states would not be applied in the same order.
Questions for reviewers:
This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically we only need this for cudagraphs, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`.
Edit: decided to apply this to non-cudagraph compilation as well, so long as `fallback_random` is not set.
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, like taking the seed from the graph position, etc. Not sure; let me know your thoughts.
Edit: updated to initialize the rng states from `torch.randint` (sketched below).
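A minimal sketch of what that initialization could look like, assuming a CUDA device is available (the exact generator/device handling in the PR may differ):
``` Python
import torch

# Draw a seed from the default RNG, then seed dedicated per-op generators with it,
# so the fwd and bwd rng states start out identical (per the description above).
# Assumes a CUDA device is available.
seed = int(torch.randint(0, 2**63 - 1, (1,)).item())
fwd_rng_state_0 = torch.Generator(device="cuda")
fwd_rng_state_0.manual_seed(seed)
bwd_rng_state_0 = torch.Generator(device="cuda")
bwd_rng_state_0.manual_seed(seed)
```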
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
Our three main users are OK with this, with two of them (foreach_map, invoke_quant) preferring it like this.
I was originally worried about BC issues (this now means you cannot add any positional args), but I think that's not a concern: one can always add kw-only args, as sketched below.
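For illustration, a hypothetical HOP-style signature with that shape; existing call sites keep working if a keyword-only parameter with a default is appended later (names here are made up):
``` Python
def invoke(subgraph, *operands, scheme=None):
    # Hypothetical entry point: operands are positional-variadic, options are kw-only.
    return subgraph(*operands)

# Appending another kw-only arg later, e.g.
#   def invoke(subgraph, *operands, scheme=None, some_new_option=False)
# does not break this existing call:
out = invoke(lambda x: x + 1, 41, scheme=None)
```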
Test Plan
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146730
Approved by: https://github.com/ydwu4, https://github.com/mlazos
Adds an `invoke_quant` higher-order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).
The primary motivations are
- Unifying scattered reasoning for quant operators throughout the code base
- Ease of pattern matching - see this very large pattern match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426)), compared to the pattern I have in the tests:
```
@register_graph_pattern(
CallFunction(
torch.ops.aten.mm,
CallFunction(
torch.ops.higher_order.invoke_quant,
Ignored(),
Ignored(),
Ignored(),
scheme="nf4",
),
Arg(),
),
pass_dict=test_pass,
)
```
- Ability to specify Inductor-specific logic, like codegen'ing the operators in lower precision or forcing fusion into a matmul.
Example graph:
``` Python
===== AFTER POST GRAD =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
# File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg]
repeated_subgraph0 = self.repeated_subgraph0
invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); repeated_subgraph0 = arg0_1 = arg1_1 = None
return (invoke_quant,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
# File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg]
mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = None
add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1); mul = arg1_1 = None
return add
```
The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)` where the scheme will not always be present.
I wasn't sure exactly how the Inductor-specific configurations like `codegen_low_precision` should be passed through. I didn't want to stuff them all in as kwargs, and I didn't want them to affect pattern matching, so they are stored in the meta of the node itself. Following that, I wanted the invocation of the hop to match how it shows up in the graph, so I decided to make it an object that is then invoked for tracing.
```
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
Todo - stop requiring the packing of args in a tuple; will do this following https://github.com/pytorch/pytorch/pull/139162.
Feedback welcome.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
With `_scaled_dot_product_efficient_attention.default`, we have lowering logic to realize the bias to specific alignment constraints. Some of the dims can be expanded, and we need to keep the stride of those dims at 0 to avoid materializing a larger tensor than we need. Previously we checked the strides of the tensor, but if it is not realized that will not work, so we should check the strides of the meta as well.
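For illustration, the stride-0 property in question (a plain PyTorch example, not the Inductor lowering itself): an expanded dim has stride 0, and keeping it at 0 is what avoids materializing the broadcast bias.
``` Python
import torch

bias = torch.randn(1, 8, 128, 128)
expanded = bias.expand(4, 8, 128, 128)   # broadcast over the batch dim
print(expanded.stride())                 # (0, 16384, 128, 1): the expanded dim keeps stride 0
```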
Note: getting the exact combination of realizing/slicing/requiring_exact_strides right was a little tricky. I commented to @exclamaforte on an example unable-to-fuse message you get if you do it incorrectly.
Fix for https://github.com/pytorch/pytorch/issues/145760
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146054
Approved by: https://github.com/shunting314
This PR implements the user-facing dim change, i.e., that the scan dim provided by the user is always moved to dim 0 and then the associative_scan operation always operates on dim 0.
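A rough sketch of the convention in plain tensor ops (illustrative only; the real logic lives in the associative_scan HOP and its lowering):
``` Python
import torch

def scan_on_dim0_convention(x, dim):
    # Move the user-provided scan dim to dim 0, scan along dim 0, then move it back.
    x0 = torch.movedim(x, dim, 0)
    out0 = torch.cumsum(x0, dim=0)       # stand-in for an associative combine along dim 0
    return torch.movedim(out0, 0, dim)

y = scan_on_dim0_convention(torch.randn(4, 8, 16), dim=2)
```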
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139864
Approved by: https://github.com/ydwu4
Record input fake tensors at time of tracing and store them in the node meta. Inductor passes have the possibility of changing strides, so it is safer to record the strides of the inputs at tracing. See, https://github.com/pytorch/pytorch/issues/137979 for more context.
We can also extend this to custom ops, and user-visible outputs. If this ends up being compilation time sensitive we can just record strides (and maybe storage offset, per @zou3519) instead of the complete fake tensor.
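A hypothetical sketch of the recording step (the key names and helper are illustrative, not the actual implementation): stash each placeholder's fake tensor, or just its strides, into `node.meta` at trace time so later passes can compare against it.
``` Python
import torch

def record_input_metadata(gm: torch.fx.GraphModule, fake_inputs):
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    for node, fake in zip(placeholders, fake_inputs):
        node.meta["recorded_fake"] = fake                     # full fake tensor
        node.meta["recorded_strides"] = tuple(fake.stride())  # or just the strides
```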
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145448
Approved by: https://github.com/zou3519
Summary:
Remove torch.ops.aten._assert_tensor_metadata.default in the post_grad pass because this op is blocking fusion.
This should not have any effect on the result, because the op would not show up in the final AOTI-compiled model anyway (the assertion has no effect).
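A hypothetical sketch of such a removal pass (the actual Inductor post_grad pass may differ): the assert nodes produce no used outputs, so they can simply be erased, unblocking fusion across them.
``` Python
import torch

def remove_assert_tensor_metadata(gm: torch.fx.GraphModule) -> None:
    for node in list(gm.graph.nodes):
        if node.target is torch.ops.aten._assert_tensor_metadata.default:
            gm.graph.erase_node(node)  # safe: the assert has no users
    gm.recompile()
```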
A real example where this improves performance:
In the example below, the post grad graph would contain `torch.ops.aten._assert_tensor_metadata.default`, because of PR https://github.com/pytorch/pytorch/pull/142420. This op is added when functionalizing aten.to.
We want the `add` node from `linear` to be fused with the rest of the pointwise ops, instead of fused with the `mm` from `linear`.
```
import torch
import torch.nn as nn

class Model(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Model, self).__init__()
        self.linear = nn.Linear(input_dim, hidden_dim).half()
        self.rms_norm = nn.RMSNorm(hidden_dim)

    def forward(self, x):
        linear_458 = self.linear(x)  # Linear layer with weights
        # mimic the torchtune rms norm: /torchtune/torchtune/modules/rms_norm.py
        linear_458 = linear_458.to(torch.float32)
        rms_norm_34 = self.rms_norm(linear_458)  # RMS Normalization
        sigmoid_168 = torch.sigmoid(rms_norm_34)  # Sigmoid activation function
        mul_168 = sigmoid_168 * rms_norm_34  # Element-wise multiplication
        return mul_168

def main():
    with torch.no_grad():
        input_dim = 512
        hidden_dim = 256
        batch_size = 32
        model = Model(input_dim, hidden_dim).to("cuda")
        example_inputs = (
            torch.randn(batch_size, input_dim).to("cuda").to(torch.float16),
        )
        ep = torch.export.export(model, example_inputs)
        package_path = torch._inductor.aoti_compile_and_package(ep)
```
Test Plan:
CI
Differential Revision: D68303114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145028
Approved by: https://github.com/angelayi
Requested in #77764
PR is still in draft because it needs some cleanups and optimizations to get to at least CPU performance. Tasks:
- [x] Make `upper=True` work, only `upper=False` works now
- [x] Code cleanup
- [x] Optimizations (though I might need some help on this; tried my best, maybe there is still some more to squeeze out)
- [x] Checks for positive definite input
- [x] Support for (*, N, N) input, currently only supports (B, N, N) input
- [x] Support other dtypes (float16, bfloat16)
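A usage sketch covering the paths listed above (batched `(*, N, N)` input, `upper=True`/`False`); `torch.linalg.cholesky` is the public entry point, and the dtype/device here are illustrative.
``` Python
import torch

a = torch.randn(2, 3, 4, 4)
spd = a @ a.mT + 4 * torch.eye(4)                 # (*, N, N) symmetric positive-definite input
lower = torch.linalg.cholesky(spd)                # upper=False (default)
upper = torch.linalg.cholesky(spd, upper=True)    # upper=True path
```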
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144193
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Fixes #143738. Currently the scale factor used for averaging is rounded to 0 if the dtype is an integer, resulting in an all-zero output. This fix uses `truediv` instead for integer cases.
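A small illustration of the failure mode (plain PyTorch, not the Inductor lowering): an integer scale factor rounds to 0, while true division keeps the average.
``` Python
import torch

pooled_sum = torch.tensor([2, 4, 6], dtype=torch.int64)
kernel_size = 2
print(pooled_sum * (1 // kernel_size))  # tensor([0, 0, 0]): integer scale rounds to 0
print(pooled_sum / kernel_size)         # tensor([1., 2., 3.]): truediv preserves the average
```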
## Test
```bash
pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_avg_pool1d_cpu_int64
pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_avg_pool2d_cpu_int64
pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_avg_pool3d_cpu_int64
pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_local_response_norm_cpu_int64
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144059
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel, https://github.com/jgong5
This is the initial foreach_map HOP for pointwise ops, which will be extended in the future to support grouped GEMMs and other ops.
This PR utilizes the PrimHOPBase class to represent foreach_map as a HOP with a single subgraph. The user API `foreach_map` takes a single pointwise torch op; internally, this function calls a polyfill that has the same semantics as a foreach op (i.e., it iterates over lists of operands, applying the op elementwise). The higher-order op is passed through the stack down to Inductor, where a lowering in essence inlines the subgraph into the main graph: the subgraph is interpreted with a pointwise subgraph lowering, the outputs are grouped by device, and the output buffers are registered as foreach groups where applicable. For testing, I was able to reuse the existing foreach tests by creating a wrapper function that matches the foreach op interfaces and then running all of the existing foreach tests on foreach_map.
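A minimal sketch of the polyfill semantics described above (illustrative only, not the real foreach_map HOP machinery): apply a single pointwise op elementwise across lists of operands, the way a foreach op would.
``` Python
import torch

def foreach_map_polyfill(op, xs, ys):
    # Same semantics as a foreach op: iterate over the operand lists, applying op elementwise.
    return [op(x, y) for x, y in zip(xs, ys)]

outs = foreach_map_polyfill(
    torch.add,
    [torch.randn(3) for _ in range(4)],
    [torch.randn(3) for _ in range(4)],
)
```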
TODO before landing:
* Add tests for general functions
* Test warning if unsupported op will block fusion
Followups:
* I need to add tests for backwards (this will be a followup PR because backwards will require other work as well)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142098
Approved by: https://github.com/eellison