The original issue is we see accuracy problem in a meta internal model [meta internal link](https://fb.workplace.com/groups/1075192433118967/posts/1567334737238065/). The debugging is hard but the root cause is relatively simple. The root cause is that the model has mix-device inputs for index.Tensor which causes Inductor to fallback. And the meta kernel for index.Tensor returns a tensor with inconsistent strides to the eager kernel.
The following code snippet
```
import torch
from torch._subclasses import FakeTensorMode
device = "cuda"
x = torch.randn((24, 16, 32, 32), device=device).to(memory_format=torch.channels_last)
x = x.view(2, 12, 16, 32, 32)
i1 = torch.arange(2).unsqueeze(-1)
i2 = torch.argsort(torch.rand(2, 12), dim=-1)[:, :3]
print(f"Eager stride: {x[i1, i2].stride()}")
mode = FakeTensorMode()
with mode:
f_x = mode.from_tensor(x)
f_i1 = mode.from_tensor(i1)
f_i2 = mode.from_tensor(i2)
f_out = f_x[f_i1, f_i2]
print(f"Meta stride: {f_out.stride()}")
```
would output:
```
Eager stride: (49152, 16384, 1, 512, 16)
Meta stride: (49152, 16384, 1024, 32, 1)
```
In this PR, I fix the problem to run eager kernel to get the index.Tensor fallback's output layout. A better solution would be to change meta/eager kernel implementation so that their output layout matches. But I'm not sure how to properly do that.
In the index.Tensor meta kernel, we always produce dense output: 6d56277682/torch/_meta_registrations.py (L3184) . While the eager kernel seems to leverage TensorIteratorBase to decide some dimension permutation: 6d56277682/aten/src/ATen/TensorIterator.cpp (L232-L308) . We can duplicate this logic to the meta kernel implementation if we really want meta matches eager. I can follow up on this if people have strong opinion to do this.
And here is an issue https://github.com/pytorch/pytorch/issues/144717 for asserting size/strides for fallback kernels. With that, the issue debugged here would be much easier to root cause.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144736
Approved by: https://github.com/jansel
Summary:
Fix `nonzero is not registered to meta` issue:
```
"NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered".
```
Reviewed By: ezyang
Differential Revision: D66525640
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144727
Approved by: https://github.com/ezyang
Tracking issue: #138399
This PR changes the `pow` C++ implementation, making its C++ meta kernel consistent with
its Python ref implementation. The following example shows the inconsistency between the
two:
```python
def run(device):
S = (5,)
a = torch.rand(S, device=device, dtype=torch.float32)
b = 2
out = torch.empty(S, device=device, dtype=torch.float64)
return torch.pow(a, b, out=out)
>>> run("cpu")
Traceback (most recent call last):
File "test.py", line 34, in run
return torch.pow(a, b, out=out)
RuntimeError: Found dtype Double but expected Float
>>> run("meta")
tensor(..., device='meta', size=(5,), dtype=torch.float64)
```
**~Update:~**
~Note that this happens only for `pow.Tensor_Scalar` overloads. Therefore, this PR needed
further 2 modifications:~
- ~Split the `pow` ref implementation, making `pow.Tensor_Scalar` error on mismatching
output dtypes~
- ~Create a dispatch for `pow` when `_refs.pow()` is called~
**Update:**
Changing the `TensorIteratorConfig` for `pow.Tensor_Scalar` was easier and,
after the discussion below, more correct. The solution was to change the
`TensorIteratorBase::build_output_borrowing_argument_owning_unary_op` function,
setting:
- `cast_common_dtype_to_outputs`; and
- `enforce_safe_casting_to_output`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140287
Approved by: https://github.com/ezyang
This PR replaces the parameter names specified in the `triangular_solve_meta`
function (specifically in its `@out_wrapper(...)` decorator) by those written in the
_native_functions.yaml_ file.
This name mismatch caused the operation to fail when using the meta device (see error
below):
```python
Traceback (most recent call last):
File "examples/test.py", line 23, in <module>
torch.triangular_solve(b.to("meta"), A.to("meta"), out=meta_out)
File "torch/_decomp/__init__.py", line 100, in _fn
return f(*args, **kwargs, out=None if is_none else out_kwargs)
File "torch/_prims_common/wrappers.py", line 289, in _fn
result = fn(*args, **kwargs)
TypeError: triangular_solve_meta() got an unexpected keyword argument 'X'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140186
Approved by: https://github.com/ezyang
This PR resolves several sets of `_scaled_mm` test failures:
- `scale_a` and `scale_b` are now required arguments, so the function `sample_inputs_scaled_mm` must supply them
- `_scaled_mm` does not support `"meta"` device, so it should be skipped in `test_meta.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130897
Approved by: https://github.com/drisspg
Fixes https://github.com/pytorch/pytorch/issues/121085
This PR pretty involved so pay attention to this description. At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.
However, this transformation is not always entirely mechanical. Here
is what you need to pay attention to:
- The memo table from real Tensor -> meta/fake Tensor is now broken
into two memo tables: real Tensor -> stable int id -> meta/fake
Tensor. The stable int id is needed so that when we do serialization,
we know when tensors/storages alias each other and can ensure we preserve
this aliasing upon deserialization.
The way I have implemented changes the weak reference behavior.
Previously, when either the real Tensor OR the meta/fake Tensor went
dead, we would remove the entry from the memo table. Now, this only
removes entries from one of the two memo tables. This semantically
makes sense, because the user may have held on to the stable int id
out of band, and may expect a real Tensor to continue to be numbered
consistently / expect to be able to lookup a meta/fake tensor from
this id. If this is unacceptable, it may be possible to rejigger
the memo tables so that we have real Tensor -> stable int id
and real Tensor -> meta/fake Tensor, but TBH I find the new
implementation a lot simpler, and arranging the memo tables in this
way means that I have to muck around with the real tensor to save
to the memo table; in the current implementation, I never pass the
Tensor to meta_tensor function AT ALL, which means it is impossible
to accidentally depend on it.
- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
to be careful not to poke fields when they are not valid. Previously,
preconditions were implicitly checked via the conditional structure
("is this sparse? is this nested?") that is tested before we start
reading attributes. This structure has to be replicated in
describe_tensor, and I have almost assuredly gotten it wrong on my
first try (I'll be grinding through it on CI; a careful audit will
help too, by auditing that I've tested all the same conditionals that
the original access was guarded by.)
- I originally submitted https://github.com/pytorch/pytorch/pull/121821
for the symbolic shapes change, but it turned out the way I did it
there didn't actually work so well for this PR. I ended up just
inlining the symbolic shapes allocation logic into MetaConverter
(look for calls to maybe_specialize_sym_int_with_hint), maybe there
is a better way to structure it, but what I really want is to
just read sizes/strides/offset directly off of MetaTensorDesc; I
don't want another intermediate data structure.
- Some fields aren't serializable. These are documented as "NOT
serializable". ctx/type should morally be serializable and I just
need to setup a contract with subclasses to let them be serialized.
The fake_mode is used solely to test if we are refakefying with
a pre-existing ShapeEnv and we want to reuse the SymInt
directly--serializing this case is hopeless but I am kind of hoping
after this refactor we do not need this at all. view_func is not
serializable because it's a bound C implemented method. Joel has
promised me that this is not too difficult to actually expose as a
true data structure, but this is the edgiest of edge cases and there
is no reason to deal with it right now.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
Fixes https://github.com/pytorch/pytorch/issues/121085
This PR pretty involved so pay attention to this description. At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.
However, this transformation is not always entirely mechanical. Here
is what you need to pay attention to:
- The memo table from real Tensor -> meta/fake Tensor is now broken
into two memo tables: real Tensor -> stable int id -> meta/fake
Tensor. The stable int id is needed so that when we do serialization,
we know when tensors/storages alias each other and can ensure we preserve
this aliasing upon deserialization.
The way I have implemented changes the weak reference behavior.
Previously, when either the real Tensor OR the meta/fake Tensor went
dead, we would remove the entry from the memo table. Now, this only
removes entries from one of the two memo tables. This semantically
makes sense, because the user may have held on to the stable int id
out of band, and may expect a real Tensor to continue to be numbered
consistently / expect to be able to lookup a meta/fake tensor from
this id. If this is unacceptable, it may be possible to rejigger
the memo tables so that we have real Tensor -> stable int id
and real Tensor -> meta/fake Tensor, but TBH I find the new
implementation a lot simpler, and arranging the memo tables in this
way means that I have to muck around with the real tensor to save
to the memo table; in the current implementation, I never pass the
Tensor to meta_tensor function AT ALL, which means it is impossible
to accidentally depend on it.
- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
to be careful not to poke fields when they are not valid. Previously,
preconditions were implicitly checked via the conditional structure
("is this sparse? is this nested?") that is tested before we start
reading attributes. This structure has to be replicated in
describe_tensor, and I have almost assuredly gotten it wrong on my
first try (I'll be grinding through it on CI; a careful audit will
help too, by auditing that I've tested all the same conditionals that
the original access was guarded by.)
- I originally submitted https://github.com/pytorch/pytorch/pull/121821
for the symbolic shapes change, but it turned out the way I did it
there didn't actually work so well for this PR. I ended up just
inlining the symbolic shapes allocation logic into MetaConverter
(look for calls to maybe_specialize_sym_int_with_hint), maybe there
is a better way to structure it, but what I really want is to
just read sizes/strides/offset directly off of MetaTensorDesc; I
don't want another intermediate data structure.
- Some fields aren't serializable. These are documented as "NOT
serializable". ctx/type should morally be serializable and I just
need to setup a contract with subclasses to let them be serialized.
The fake_mode is used solely to test if we are refakefying with
a pre-existing ShapeEnv and we want to reuse the SymInt
directly--serializing this case is hopeless but I am kind of hoping
after this refactor we do not need this at all. view_func is not
serializable because it's a bound C implemented method. Joel has
promised me that this is not too difficult to actually expose as a
true data structure, but this is the edgiest of edge cases and there
is no reason to deal with it right now.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
ghstack dependencies: #122018
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
`linalg_eigvals_out` calls into a dispatch stub, so only supports CPU and CUDA
strided tensors but incorrectly claimed to be a composite op. `linalg_eigvals`
also shouldn't defer to the out variant inside a `CompositeImplicitAutograd` op
as not all types support out variants. Instead, I add a new helper
`_linalg_eigvals` which does the same thing in a non-composite operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121142
Approved by: https://github.com/lezcano
The first try reused TensorListMetadata, which caused illegal memory access issues when there were too many tensors in the list. We just launch multiple kernels with a simpler version of the struct (to minimize kernels launched).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119927
Approved by: https://github.com/albanD
This should fix remaining errors with Resize op in torchvision: https://github.com/pytorch/vision/actions/runs/7298953575?pr=8127
```
/opt/conda/envs/ci/lib/python3.8/site-packages/torch/nn/functional.py:4072: in interpolate
return torch._C._nn._upsample_bicubic2d_aa(input, output_size, align_corners, scale_factors)
E torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function interpolate at 0x7f4443fe00d0>(*(FakeTensor(..., size=(1, s0, s1, s2)),), **{'size': [s4, floor(s3*s4/floor(s1*s3/s2))], 'mode': 'bicubic', 'align_corners': False, 'antialias': True}):
E aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:5567: SymIntArrayRef expected to contain only concrete integers
E
E from user code:
E File "/pytorch/vision/torchvision/transforms/v2/functional/_geometry.py", line 260, in resize_image
E image = interpolate(
E
E Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E
E You can suppress this exception and fall back to eager by setting:
E import torch._dynamo
E torch._dynamo.config.suppress_errors = True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117347
Approved by: https://github.com/peterbell10
Context: pt2 oncall is revamping its labeling system. One of the guidelines is to remove duplicate labeling in our system. Both primTorch and decomposition labels are referring to the same thing. primTorch was the legacy name (and we no longer have a primTorch project), so using decomposition as the label name makes more sense.
Right now, the only open issues that use "module: primTorch" are the ones generated by the DISABLED bots. Once we replace the label in the bot, we can safely remove the primTorch label.
Here an example of the issue that has primTorch label :
https://github.com/pytorch/pytorch/issues/112719
Torchbot uses following logic to auto extract module owners:
https://github.com/pytorch/test-infra/blob/main/torchci/pages/api/flaky-tests/disable.ts#L391
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114754
Approved by: https://github.com/huydhn