> previous: Originally, the variables `new_eta` and `new_mu` would be constructed `len(grouped_mus)` times, but each of their values is the same and won't be changed. Therefore, it can be simplified using Python list multiplication, which only constructs one tensor.
- [X] Ill assumption that every param will have the same step.
- [x] DIfferent implementation between `foreach=Ture` and `foreach=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125440
Approved by: https://github.com/janeyx99
This commit introduces a meta function for the `channel_shuffle` operation, enabling PyTorch to perform shape inference and optimizations related to this operation without actual computation. The meta function assumes input shape (*, C, H, W) and validates that the number of channels (C) is divisible by the specified number of groups.
Fixes#122771
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123033
Approved by: https://github.com/ezyang, https://github.com/mikaylagawarecki
Given the following code/dynamo graph:
```
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
_print = torch.ops.aten._print('moo')
res = l_x_ + l_x_; l_x_ = None
_print_1 = torch.ops.aten._print('moo')
return (res,)
```
AOTAutograd will trace the following program, threading tokens from the inputs, through the effectful operator calls (torch.ops.aten._print), and as an output:
```
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2, 3]"):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)
```
However when we get to inductor, since we want the inductor generated code to not have any token inputs/outputs for better readability, we want to modify the aten graph by removing the tokens from inputs, and creating them through `torch.ops.aten._make_dep_token`, and sinking them through the `torch.ops.aten._sink_tokens` operators.
This has to be done *after* the partitioner, otherwise the partitioner will add the make_token/sink_token operators to the backwards graph.
```
class <lambda>(torch.nn.Module):
def forward(self, arg1_1: "f32[2, 3]"):
_make_dep_token_default: "f32[0]" = torch.ops.aten._make_dep_token.default()
with_effects = torch._higher_order_ops.effects.with_effects(_make_dep_token_default, torch.ops.aten._print.default, 'moo'); _make_dep_token_default = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
_sink_tokens_default = torch.ops.aten._sink_tokens.default((getitem_2,)); getitem_2 = None
return (add,)
```
When doing inductor lowering, we convert `with_effects` calls to an `EffectfulKernel`, which just a `FallbackKernel` but with a pointer to previous effectful operator's call. During scheduling, we will create a `StarDep` between the EffectfulKernel and its previous EffectfulKernel so that they don't get reordered. The inductor generated python code looks like:
```
def call(args):
arg1_1, = args
args.clear()
assert_size_stride(arg1_1, (2, 3), (3, 1))
# Source Nodes: [_print], Original ATen: []
buf2 = aten._print.default('moo')
# Source Nodes: [_print_1], Original ATen: []
buf3 = aten._print.default('moo')
buf4 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
cpp_fused_add_0(arg1_1, buf4)
del arg1_1
return (buf4, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122347
Approved by: https://github.com/bdhirsh
`linalg_eigvals_out` calls into a dispatch stub, so only supports CPU and CUDA
strided tensors but incorrectly claimed to be a composite op. `linalg_eigvals`
also shouldn't defer to the out variant inside a `CompositeImplicitAutograd` op
as not all types support out variants. Instead, I add a new helper
`_linalg_eigvals` which does the same thing in a non-composite operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121142
Approved by: https://github.com/lezcano
**description**
Enable lowering of dynamic qlinear for X86Inductor. The pattern is `choose_qparams -> getitem -> q -> dq -> linear`. We only fuse `dq -> linear` and get `choose_qparams -> getitem -> q -> onednn.qlinear_pointwise`. So, we treat it as dynamic quantization of activation + static quantized linear.
The previous implementation of `onednn.qlinear_pointwise` is for the case where `x_scale` and `x_zp` are scalars. Since `choose_qparams` returns tensors, we added a variation `onednn.qlinear_pointwise.tensor` to support the case.
This feature is targeting PyTorch 2.3 release.
**Test plan**
```
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_cpu
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_qat_cpu
python inductor/test_cpu_cpp_wrapper.py -k test_dynamic_qlinear
```
**Performance before and after lowering `choose_qparam` to Inductor**
Before
- latency for shape (32, 32) = 0.151 ms
latency for shape (128, 128) = 0.153 ms
latency for shape (1024, 1024) = 0.247 ms
After
- latency for shape (32, 32) = 0.049 ms
- latency for shape (128, 128) = 0.052 ms
- latency for shape (1024, 1024) = 0.133 ms
Test method: A module with a single Linear layer, dynamic-quantize, lower to X86Inductor
Test env & config: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz, single instance, single core, using Intel OpenMP and Tcmalloc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120605
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
This PR is mostly just code movement to make the code review easier - AFAIK it should not change any functionality. The final goal is to remove the xfails for some of the test_fake opinfos for these ops. The opinfos are failing because the outputs can have mixed devices - we need to move them to fake_impls first before we can support mixed device returns.
This PR:
* Move the `_meta_registrations.py` implementations to `fake_impls.py`
* Change the function signature from taking explicit named variables to taking `{args, kwargs}` and normalizing them
* Wrap all the returned tensors in FakeTensors
Tests: relying on opinfos. I also checked `test_fake_*` for these tests (by removing x-fails and patching things until they passed) to verify general correctness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120682
Approved by: https://github.com/drisspg
The first try reused TensorListMetadata, which caused illegal memory access issues when there were too many tensors in the list. We just launch multiple kernels with a simpler version of the struct (to minimize kernels launched).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119927
Approved by: https://github.com/albanD
Meta registration wrongly assumes 4D inputs, while the underlying op allows 3D inputs for the `mha_varlen_fwd()` case.
Testing: I added `detach()`es so the NJT test `test_sdpa_compile()` won't fail for a view-related reason. It should pass now with this fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
This should fix remaining errors with Resize op in torchvision: https://github.com/pytorch/vision/actions/runs/7298953575?pr=8127
```
/opt/conda/envs/ci/lib/python3.8/site-packages/torch/nn/functional.py:4072: in interpolate
return torch._C._nn._upsample_bicubic2d_aa(input, output_size, align_corners, scale_factors)
E torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function interpolate at 0x7f4443fe00d0>(*(FakeTensor(..., size=(1, s0, s1, s2)),), **{'size': [s4, floor(s3*s4/floor(s1*s3/s2))], 'mode': 'bicubic', 'align_corners': False, 'antialias': True}):
E aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:5567: SymIntArrayRef expected to contain only concrete integers
E
E from user code:
E File "/pytorch/vision/torchvision/transforms/v2/functional/_geometry.py", line 260, in resize_image
E image = interpolate(
E
E Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E
E You can suppress this exception and fall back to eager by setting:
E import torch._dynamo
E torch._dynamo.config.suppress_errors = True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117347
Approved by: https://github.com/peterbell10