Fixes#156052 and #156444.
This PR setup the privateuseone key in Python to be used as a python backend for pytorch.
Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass to with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that takes that Tensor.
Changes done in this PR:
1. Register an vanilla Device Guard: I extended NoOpDeviceGuard to have allow device index of 0 and to not raise errors when event related functions are accessed. If I don't do those, when calling backward I would get errors. (CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.
2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid segfault when calling into ops that expects a storage. Here we have a different device so will not call into those ops.
3. python function that invokes the other incantations to setup the privateusekey backend.
This took inspiration of https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
Fixes#156052 and #156444.
This PR setup the privateuseone key in Python to be used as a python backend for pytorch.
Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass to with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that takes that Tensor.
Changes done in this PR:
1. Register an vanilla Device Guard: I extended NoOpDeviceGuard to have allow device index of 0 and to not raise errors when event related functions are accessed. If I don't do those, when calling backward I would get errors. (CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.
2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid segfault when calling into ops that expects a storage. Here we have a different device so will not call into those ops.
3. python function that invokes the other incantations to setup the privateusekey backend.
This took inspiration of https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
The motivation for this change can be seen through the following example:
```
import torch
GPU_TYPE = "cuda"
@torch.compile
def no_override(x):
return x.sum(dim=0)
@torch.compile
def override(x):
return x.sum(dim=0)
x_small = torch.randn(4096, 512, device=GPU_TYPE)
no_override(x_small)
torch._dynamo.decorators.mark_dynamic(x_small, 0, hint_override=4096 * 1000)
override(x_small)
```
Previously, when reductions were split, codegen relied only on the first observed shape. With a small input, this resulted in a small split size:
```
def triton_red_fused_sum_0(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 16384
rnumel = r0_numel
```
With the new scheme, inductor honors hint_override during codegen, producing larger and more appropriate split sizes:
```
def triton_red_fused_sum_0(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1024000
rnumel = r0_numel
```
This addresses a broader problem with dynamism: performance and numerics previously depended on whichever shape was seen first. For example:
```
f(s0) -> f(s2)
f(s1) -> f(s2)
```
could generate different kernels. With the new approach, an explicit override pins the chosen configuration:
```
f(s0, hint_override=s0) -> f(s2)
f(s1, hint_override=s0) -> f(s2)
```
ensuring consistent kernel generation regardless of input order.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161007
Approved by: https://github.com/jansel
The motivation for this change can be seen through the following example:
```
import torch
GPU_TYPE = "cuda"
@torch.compile
def no_override(x):
return x.sum(dim=0)
@torch.compile
def override(x):
return x.sum(dim=0)
x_small = torch.randn(4096, 512, device=GPU_TYPE)
no_override(x_small)
torch._dynamo.decorators.mark_dynamic(x_small, 0, hint_override=4096 * 1000)
override(x_small)
```
Previously, when reductions were split, codegen relied only on the first observed shape. With a small input, this resulted in a small split size:
```
def triton_per_fused_sum_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
xnumel = 512
r0_numel = 32
```
With the new scheme, inductor honors hint_override during codegen, producing larger and more appropriate split sizes:
```
def triton_red_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 16384
r0_numel = 128
```
This addresses a broader problem with dynamism: performance and numerics previously depended on whichever shape was seen first. For example:
```
f(s0) -> f(s2)
f(s1) -> f(s2)
```
could generate different kernels. With the new approach, an explicit override pins the chosen configuration:
```
f(s0, hint_override=s0) -> f(s2)
f(s1, hint_override=s0) -> f(s2)
```
ensuring consistent kernel generation regardless of input order.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161007
Approved by: https://github.com/jansel
I feel uneasy about touching `__warningregistry__` since it is undocumented and private surface. The only public API hook that doesn't increment warnings version seems to be https://docs.python.org/3/library/warnings.html#warnings.showwarning.
So we could wack a mole all the warnings muters in compile to just not display warnings, and we wouldn't invalidate warnings cache. This PR adds it for torch/_dynamo, and I didn't find any warnings versioning mutation from torch/_inductor.
There is a behavior change if someone calls a compiled graph with simplefilter("error"):
```python
# e.g. test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing
with warnings.catch_warnings():
warnings.simplefilter("error") # turns all warnings into errors
compiled_fn() # will throw if any of the muted warnings fire
```
FIXES https://github.com/pytorch/pytorch/issues/128427
A note for the future: The warnings module doesn't offer a thread safe way of using it. Even regular filters have this problem, directly editing `__warningregistry__` would be very bad, and this PR would mute all threads. Someone will need to build a thread safe warnings interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158520
Approved by: https://github.com/anijain2305, https://github.com/zou3519
This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944
Approved by: https://github.com/Skylion007
PR does following
* Turns `inference_mode` to False and `no_grad` for `convert_frame`, if the inference_mode is on globally.
* Turns off inference_mode for fake tensor prop. This ensures that converting from real inference tensor to a fake tensor removes the inference-ness.
* Graph breaks on is_inference and is_inference_mode_enabled.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149321
Approved by: https://github.com/jansel, https://github.com/zou3519
Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
A few changes to MetaTensorDesc and friends:
1. Change view_func from a raw method to an ADT where the common case (FakeTensor._view_func_unsafe) is a simple representation instead.
2. (minor) Remove and fix some `type: ignore`s added by #141839
3. (minor) Fix _UNSERIALIZABLE to be a set instead of a dict which is converted into a set each time it's used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141926
Approved by: https://github.com/ezyang
Combines contributions from https://github.com/pytorch/pytorch/pull/130505
Some context can be found in this large comment block:
a5b64d39fd/test/dynamo/test_subclasses.py (L1667-L1681)
Changes in this PR
- For each tensor fakified, check the nested int registry in eager, and eagerly symbolicize if that tensor has already been associated with nested int in eager.
- Adds a separate counter stored on FakeTensorMode as a fake analog to _tensor_id_counter (which keeps track of unique tensors). This counter is initialized to the global eager tensor id counter upon creation of the FakeTensorMode, and needs to be reset when the same FakeTensorMode is reused to trace again (in this PR, we piggyback on the epoch incrementing logic).
- (refactor) Today, we store FakeTensor -> symbolic nested int in the global registry. With this PR, symbolic nested int is stored directly on the FakeTensor. (Eager still caches nested int in the registry, though we should avoid this at some point.)
Basically unchanged, but worth noting:
- `__tensor_unflatten__` is still responsible for determining whether we should cache for now. The logic is somewhat simplified.
- to_copy is still using the trick of updating two different tensors in the registry to point to the same nested int. This is kind of broken, but we try to leave it as is, and plan a better fix with the UnionFind stack.
Differential Revision: [D60406772](https://our.internmc.facebook.com/intern/diff/D60406772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130292
Approved by: https://github.com/bdhirsh
ghstack dependencies: #131916, #131803
Rewrite of original PR in https://github.com/pytorch/pytorch/pull/130291
To answer review comments from https://github.com/pytorch/pytorch/pull/130291#pullrequestreview-2166671953:
> At a higher level, do we need this?
Today, this should not change the behavior of anything. But an invariant of "same tensor always corresponds to the same FakeTensor" is nice (from discussion with @bdhirsh).
> Why does this happen?
Today, both dynamo and meta_utils do some recursion when it comes to FakeTensors. So whenever we fakify a subclass, the process would roughly like:
```
wrap_to_fake (subclass)
meta_utils (subclass)
meta_utils (values) -> not cached because we use callback
meta_utils(offsets) -> not cached because we use callback
wrap_to_fake (values)
wrap_to_fake (offsets) -> cached because we rely on top-level meta_utils
```
However, we know that:
- Caching only occurs at the top-level of meta_utils.
- The return value of the top-level wrap_to_fake is returned.
This means that after all of this:
- The fakified subclass holds inner FakeTensors that are NOT part of the cache
- values/offsets are Fakified a second time, and those instances are cached.
Differential Revision: [D60406773](https://our.internmc.facebook.com/intern/diff/D60406773)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131803
Approved by: https://github.com/ezyang
ghstack dependencies: #131916
Adds support for SymInts in the FakeTensor cache.
A couple notes:
1. When a SymInt is present in the input key for a FakeTensor operation we cache on the ShapeEnv instead of using the FakeTensorMode cache. This is necessary so we don't have to remember and check the guards. It reduces the cache hits but there's diminishing return on how much work we can do before the cache becomes more of a burden than a gain.
2. We need to be careful that when we cache an output SymInt that is a direct copy from the input that when we have a cache-hit we copy the SymNode from the input to the output. This is important because the fx-graph building code actually uses SymNode ids in the process of building the graph so constructing a same-content-but-different-id SymNode will fail.
3. In the cache key we store SymInts as a _PySymInputStub. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt do not).
4. In the cache entry we store SymInts as a _SymIntOutputStub.
Perf example:
```
python benchmarks/dynamo/timm_models.py --ci --accuracy --timing
--explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda
--training --amp --total-partitions 2 --partition-id 0 --output
/tmp/training_timm_models.csv --filter crossvit_9_240
```
fake tensor cache before:
```
INFO: FakeTensor cache stats:
INFO: cache_hits: 68137
INFO: cache_misses: 837
INFO: cache_bypasses:
INFO: symbolic shape: 48224
INFO: CompositeImplicitAutograd: 917
INFO: non-fake tensor: 70
INFO: non-FakeTensor output: 62
INFO: non-builtin: 8
INFO: dynamic output shape: 1
```
and after:
```
INFO: FakeTensor cache stats:
INFO: cache_hits: 88187
INFO: cache_misses: 14233
INFO: cache_bypasses:
INFO: CompositeImplicitAutograd: 1037
INFO: non-FakeTensor output: 602
INFO: non-fake tensor: 70
INFO: unsafe view: 36
INFO: non-builtin: 8
INFO: dynamic output shape: 1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127596
Approved by: https://github.com/eellison
ghstack dependencies: #131014, #129780
When handling an input to dynamo that's a view of a subclass, dynamo does some handling to reconstruct the view. Part of this is to construct symints for the input parameters to the view.
Previously, the code would just call `create_symbol()` which by default specifies a _positive_ symint (>= 0); this fails in the case where you have an aten::view that was called with a -1.
Fix: just specify `positive=None` when calling `create_symbol()`, to avoid restricting the symint to >= 0 or <= 0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128662
Approved by: https://github.com/jbschlosser
This adds dumps of MetaTensorDesc and MetaStorageDesc to structured logs
when they are triggered from Dynamo. The logs look like this:
```
V0522 08:13:25.267000 140224882566144 torch/_subclasses/meta_utils.py:195] {"describe_storage": {"id": 0, "describer_id": 0, "size": 32}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
V0522 08:13:25.267000 140224882566144 torch/_subclasses/meta_utils.py:220] {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [8], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "<built-in method _view_func_unsafe of Tensor object at 0x7f882959e840>", "describer_id": 0}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
V0522 08:13:25.268000 140224882566144 torch/_subclasses/meta_utils.py:1594] {"describe_source": {"describer_id": 0, "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
```
The `describer_id` is used to disambiguate ids. We expect it to be
unique per frame id, but if there is a bug it possibly is not. Note you will get
redundant dumps when evaluation restarts.
tlparse can use this to give a visualization of input tensors to a
model, you could also use this to generate example inputs to run graphs
on.
Some care is taken to avoid redumping the tensor metadata multiple
times, which would happen ordinarily because AOTAutograd refakifies
everything after Dynamo, to deal with metadata mutation.
Partially fixes https://github.com/pytorch/pytorch/issues/126644
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126879
Approved by: https://github.com/jamesjwu
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127122
Approved by: https://github.com/kit1980
Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7228787720582401/
There a few improvements here, which luckily fix some xfails:
* In generally, it can be unsafe to call operations on Tensors under a `no_dispatch()` mode that is purely trying to disable ambient modes, because this ALSO disables tensor subclass handling. So we test to see if there is a tensor subclass and don't propagate real tensors if that's the case. Another acceptable outcome might be to try to only disable the ambient fake tensor mode, this would help us propagate real tensors through more exotic tensor types, but I'm not going to do it until someone asks for it.
* We're graph breaking for wrapped tensors too late. Pull it up earlier so we do it before we try to muck around with the real tensor.
* I noticed that occasionally when I do `storage.copy_(real_storage)`, the sizes mismatch. Careful code reading suggests that I should just copy in the real data when the tensor was initially allocated, so that's what I do now, eliminating the need for a storage copy.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126281
Approved by: https://github.com/Skylion007
Internal xref:
https://fb.workplace.com/groups/6829516587176185/posts/7211398545654652/
Previously I did it in a crappy way using clone_input in the callback,
but this results in tensors that don't have quite the same
size/stride/storage offset and there was an internal test case where
not having completely accurate information was causing a downstream
problem in propagation. So now I make real tensors as similar to their
fake equivalents as much as possible. Though... I don't bother with
autograd lol.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126175
Approved by: https://github.com/albanD
More details further down, but first a more high-level description of "how do we functionalize storage resizing"
Today, dynamo converts `param.untyped_storage().resize_(x)` calls that it sees from fsdp into a custom op, `ops.inductor.resize_storage_bytes_(x)`
So given this setup, there are 3 main cases that I think we want to handle:
(1) graph input starts with a real storage size, gets resized down to zero in the graph
(2) graph input starts with 0 storage size, gets resized up in the graph
(3) graph input starts with 0 storage size, gets resized up and used in some compute, then resized back down to 0
For case (1) we need to emit a `resize_storage_bytes_` at the end of the graph, similar to how we emit `copy_()` for data mutations.
For case (2), we need to emit a `resize_storage_bytes_` in the graph, and we **also** need to emit a `copy_()` (the input had its storage resized up, and filled in with data, which is we need to reflect as an input mutation)
For case (3), the net effect is that the input had no data on entry and exit of the function, so we don't need to emit any mutable ops in the end of the graph.
The main thing to call out is that: we need to write a functionalization rule for `resize_storage_byte_`, (`FunctionalTensorWrapper::storage_resize_()`) and this rule actually does very little. We would like to **not** emit any new ops in the graph (like say, a functional resize op). Instead, we should expect / rely on the fact that any resize up will be immediately followed by a `copy_()`/`foreach_copy_`/`out=` op, that will fill in the data of the tensor. So `FunctionalTensor` can temporarily live in a state where its data is invalid, until the `x.copy_(y)` "updates" its data with the new tensor.
So effectively, all that this rule does is:
(1) it stores metadata on the storage, indicating that the tensor was resized, as well as the updated storage size. We need this info in AOTAutograd, so it knows whether to emit a mutable resize_() op in the graph epilogue
(2) There is also a corner case: if we are resizing down to zero, but our tensor had **previously** had a zero size storage, then we update `value_` to point to the original value of the tensor. The reason this seems safe is because if we have a zero storage sized tensor `x`, and we resize it up, use it in some compute, resize it back down to zero, and use it somewhere, we would want the functional version of this code to use the original `x` after the second resize. For FSDP, this is important because we end up saving parameters (graph inputs) for backward, and we want to make sure that the thing we save (and the output to the forward graph) is the original, zero-storage-sized parameter, and not the "version 2" of the parameter after the first resize_()
I think a good order to look at changes in this PR would be:
(1) `test_aotdispatch.py` shows the 3 main cases I focused on as well as the expected functionalized graphs
(2) In `FunctionalStorageImpl.h/cpp`, I had to add a notion of "original base", and "original/curr_size". The first is so I can re-use the zero-size tensor after multiple resizes, and the second is so I can tell in AOTAutograd whether any resizes canceled each other out into a no-op
(3) FunctionalTensorWrapper.h/cpp has the new resize functionalizion rule + some extra utils
(4) `_functorch/_autograd`: the main changes in this folder were around adding the logic at trace-time to detect when we need to put a resize_() in the graph. I also have some assertions to check that any inputs that experience storage resizing will **always be in the graph** and not the opaque epilogue, and I also limited the resize_() mutation case so that you can only ever start with zero storage, or end with zero storage (you can't do e.g. `torch.ones(2).storage().resize_(3)`), and banned it on tensor subclasses
(5) `fake_tensor.py`/`meta_utils.py`: we now need to be able to fakeify tensors with zero storage, so I added a quick version of it in meta_utils.py. This also.. has ramifications for fake tensor caching that I need to fix (include the storage size on the cache key, maybe?)
------------------
This PR subsumes https://github.com/pytorch/pytorch/pull/120971.
This PR is enough to **almost** get a simple ppFSDP forward pass tracing with a functionalized resize_() properly. It also attempts to do the updated version from @jansel, where we don't have any notion of `resize_()` in the graph at all, post functionalization. It would probably be good to test it with @yf225 's FSDP changes, and see how many of the FX passes it allows us to remove. I think that in theory, it should allow us to remove all FX passes that affect the forward graph / partitioner, **except** the one that forces views to be recomputed in the backward (more details below).
There are a few things worth calling out:
(1) failed attempt at functionalizing `aten.copy_()`. I originally wanted to get a version takes these operations:
```
param.storage().resize_(all_gather_size)
param.copy_(all_gather_buffer)
out = aten.matmul(param, param)
```
and functionalizes them into:
```
out = aten.matmul(all_gather_buffer, all_gather_buffer)
```
This would involve getting functionalization to turn `x.copy_(y)` into a giant no-op that just returns `y`. Unfortunately, we can't actually do this in a reasonable way within functionalization (instead, there's a functional `aten.copy` in the graph - see the test case graph expecttest for details). Why? In order for that transformation to be safe, `x` and `y` need to have the same metadata. However, it's possible for `x` and `y` to be subclasses of different types. This is not something we can easily tell from within functionalization, and would be a layering violation. So for now I'm leaving it to downstream code to optimize away the `aten.copy` (this is already the case today, so I think inductor can handle this)
(2) The forward doesn't **actually** run successfully in this PR (see the `assertRaisesRegex` in the test). Why?
The final forward graph looks like this:
```
def forward(self, primals_1, primals_2):
_foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_2 = None
getitem = _foreach_copy[0]; _foreach_copy = None
mm = torch.ops.aten.mm.default(getitem, getitem); getitem = None
t_1 = torch.ops.aten.t.default(primals_1); primals_1 = None
return [mm, t_1]
```
Where `primals_1` starts out as a secretly-zero-storage-size parameter, and gets resized up and back down within the forward (these are functionalized away).
Importantly, the matmul happy on the result of the `foreach_copy`, **but** the activation that we save for backward (`t_1`) is the result of transposing the **original parameter** (the zero-storage-size param). This is exactly the optimization in fsdp that allows us to have good peak memory usage.
The problem is that the min-cut partitioner decides to save `t_1` for backward. Running this code in eager breaks, because the kernel for `aten.permute(x)` is not happy when `x` has secretly-zero-sized-storage.
The real problem here is that in eager mode the `permute` kernel runs during the backward, after backward hooks have properly resized the saved activation. Here, we are running the transpose in the forward.
One option would be to turn off the checks in our view kernels and allow them to work on zero-storage-sized tensors, which feels pretty bad. Another option is to tweak the partitioner (or use one of Will's FX passes) to force the partitioner to not save views for backward, and allow the views to be recomputed in the backward. This seems kind of silly, but is also probably harmless.
(3) The backward is still broken. To be fair, this issue is pretty separable from "functionalizing storage resize calls", and can be fixed later (either by a real fix to our tracing infra, or via another hacky FX pass). More description of this problem is described at issue (8) of my PR description in https://github.com/pytorch/pytorch/pull/120971
(4) I only added support for "full graph" resizing: basically, the limited case where a param starts with zero storage size, and gets resized up and back down. I think we can add support for the graph break case, but I think we can keep that add-on separate from this PR unless we need it immediately. I also added asserts so we should fail loudly when we hit this case
(5) I have a change to FakeTensor creation when inputs have zero storage size that.. is probably ok. But I also removed FakeTensor caching on view ops, which I probably need to fix before I can land this PR
(6) I added a notion of "original_base" to `FunctionalStorageImpl`. More details are in the comments, but my rational for this was that we basically need it to ensure that autograd saves the **original**, zero-storage-sized param for backward, after resizing up and back down
(7) I had to update our eager kernels for `aten.copy` and `aten._foreach_copy`, to handle the case where the `self` argument has secretly-zero-storage. Inductor can probably generate correct code for this case, but we need these ops to work properly in this situation for the `aot_eager` backend to do the right thing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122434
Approved by: https://github.com/jansel