This is not ready for review; it is just to make sure ASAN is fixed.
I am not yet sure of the most effective way to track down the bad dec_ref within deploy.
The ASAN silencing is done to match this comment:
1c79003b3c/test/test_cpp_extensions_jit.py (L749-L752)
EDIT: since the final failing function is in libtorch_python.so, we would need to skip that whole lib (not OK). So now we're skipping based on the function name, which should be restrictive enough not to hide any real bug.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103989
Approved by: https://github.com/malfet
Summary:
The planned end-to-end flow for quantization in PyTorch 2.0 export is the following:
float_model -> prepare_pt2e -> calibration -> convert_pt2e -> ...
Inside convert_pt2e, we will first produce a q/dq representation of the quantized model, similar to the previous output of
convert_to_reference_fx in FX graph mode quantization:
```
torch.ops.quantized_decomposed.dequantize_per_tensor -> torch.ops.aten.add -> torch.ops.quantized_decomposed.quantize_per_tensor
torch.ops.quantized_decomposed.dequantize_per_tensor /
```
Then we'll rewrite the above into a representation that expresses the intent more precisely: here we actually
want to do int8 addition instead of simulating the int8 addition with fp32 operations. The representation for
quantized add is:
```
def quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point):
    # rescale each int8 input into the output scale
    x = (x_scale / out_scale) * x_i8
    y = (y_scale / out_scale) * y_i8
    out = x + y
    # remove both inputs' zero-point contributions, rescaled to out_scale
    out -= (x_zero_point * x_scale + y_zero_point * y_scale) / out_scale
    out += out_zero_point
    return out
```
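As a quick sanity check (not part of this PR), the rewritten form is algebraically the same as the dequantize -> fp32 add -> quantize pattern above; the sketch below compares the two float results before rounding/clamping, with all scales and zero points picked arbitrarily:
```
import torch

def dq_add_q_fp32(x_i8, x_scale, x_zp, y_i8, y_scale, y_zp, out_scale, out_zp):
    # reference pattern: dequantize both inputs, add in fp32, rescale for re-quantization
    x_fp32 = (x_i8.float() - x_zp) * x_scale
    y_fp32 = (y_i8.float() - y_zp) * y_scale
    return (x_fp32 + y_fp32) / out_scale + out_zp

def quantized_add_fp32(x_i8, x_scale, x_zp, y_i8, y_scale, y_zp, out_scale, out_zp):
    # rewritten representation from above, without the final rounding/clamping
    out = (x_scale / out_scale) * x_i8.float() + (y_scale / out_scale) * y_i8.float()
    out -= (x_zp * x_scale + y_zp * y_scale) / out_scale
    return out + out_zp

x = torch.randint(-128, 128, (16,), dtype=torch.int8)
y = torch.randint(-128, 128, (16,), dtype=torch.int8)
a = dq_add_q_fp32(x, 0.1, 3, y, 0.2, -2, 0.05, 1)
b = quantized_add_fp32(x, 0.1, 3, y, 0.2, -2, 0.05, 1)
assert torch.allclose(a, b, rtol=1e-5, atol=1e-3)
```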
Test Plan:
```
buck2 test caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_representation_add (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
```
Reviewed By: kimishpatel
Differential Revision: D45628032
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104130
Approved by: https://github.com/kimishpatel
This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.
In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.
This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
sparse tensor in compressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
semi-structured sparse bool mask and creates an instance of the
subclass.
**SparseSemiStructuredTensor**
The subclass stores the dense tensor in a contiguous flattened tensor
for future compatibility with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELt bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype coverage, and relaxed shape
constraints.
Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().
Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.
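As a plain-dense illustration (assumed shapes, no subclass involved), the transpose identity that lets a kernel which only supports a sparse first operand also serve the sparse-second-operand case is:
```
import torch

dense = torch.randn(4, 8)
weight = torch.randn(8, 16)  # stands in for the compressed/sparse operand

# (A @ B) == (B.t() @ A.t()).t(), so a "sparse-first" kernel can compute dense @ sparse
out = torch.mm(weight.t(), dense.t()).t()
assert torch.allclose(out, torch.mm(dense, weight), atol=1e-5)
```
This extra outer transpose is roughly why the expected parity of `.t()` calls differs between the two cases.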
**to_sparse_semi_structured**
This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it because there may be
additional zero elements in the dense tensor, so `tensor != 0` is not 2:4
sparse.
Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we will no longer need to pass in the mask. cuSPARSELt has its
own helper functions to create the metadata mask.
**User Details**
We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```
The end-user interface to accelerate an `nn.Linear` module with the
subclass would look like this:
```
import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

# assumed setup: a 2:4 bool mask and a half-precision linear layer on CUDA
mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()
linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask))
```
This also updates tests and the `torch.sparse` module docstring to
reflect these changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
Summary:
## What is this?
This is a giant codemod to migrate all of fbcode from the tp2 version of gtest to the `fbsource/third-party` version.
## Why?
Various parts of the monorepo use different versions of gtest which are incompatible with each other and make maintenance of C++ testing more difficult than it should be. There also doesn't seem to be much reason for this fragmentation. Shifting all `gtest` dependencies towards `fbsource/third-party` is a big step in the right direction towards cleaning this up.
Also -- tp2 is deprecated, so we want to stop using that anyway. If we're going to make improvements to `gtest`, we should get away from tp2 as a first step.
## How?
I used a bash script to perform the majority of the codemod: P777150295
I followed up with `rg` to find additional dependencies, then simply iterated a ton until CI was (mostly) happy.
This diff also includes an update to autodeps to use the `third-party/fbsource` version of gtest rather than the `tp2` version.
#forcetdhashing
Test Plan: CI
Differential Revision: D46961576
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104255
Approved by: https://github.com/huydhn
# Change
This PR adds two classes to DTensor:
1. `CudaRNGStateTracker`: `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).
2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.
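As a rough analogy for the offset-based policy in item 2 (plain CPU generators here, not the actual `CudaRNGStateTracker`/philox-offset machinery), the idea is that every rank starts from the same state, replicated random ops consume the same region of the stream, and sharded ops skip each rank ahead so the shards together reproduce a single-device draw:
```
import torch

# two "ranks" share one seed, as DTensor requires
g0 = torch.Generator().manual_seed(1234)
g1 = torch.Generator().manual_seed(1234)

# replicated random op: both ranks draw the same region of the stream
rep0, rep1 = torch.rand(4, generator=g0), torch.rand(4, generator=g1)
assert torch.equal(rep0, rep1)

# sharded random op: rank 1 skips rank 0's portion so shards don't repeat values
shard0 = torch.rand(4, generator=g0)
_ = torch.rand(4, generator=g1)          # offset rank 1 past rank 0's shard
shard1 = torch.rand(4, generator=g1)

# together the shards continue a single logical stream
g_single = torch.Generator().manual_seed(1234)
_ = torch.rand(4, generator=g_single)                  # the replicated draw
full = torch.cat([torch.rand(4, generator=g_single),   # expected shard of rank 0
                  torch.rand(4, generator=g_single)])  # expected shard of rank 1
assert torch.equal(torch.cat([shard0, shard1]), full)
```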
# Warning
- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads (ranks) and cause issues. We need to figure out a compatible solution for that.
- The RNG state may go out of sync on ranks that do not participate in an op. This is harmless in our current submesh use case, though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
This PR integrates the assertion functionalization logic into the current export logic.
**NOTE:**
I finally decided to do the assertion functionalization after AOT export instead of before, for the following reasons:
* The benefit of doing it after AOT export is that the graph is already functionalized, so things like method calls are already transformed into function calls. If we did it before AOT export, the graph would still be at the torch level and extra logic like bab21d20eb/torch/_export/pass_base.py (L201-L204C17) would need to be implemented.
* The graph signature is currently already somewhat incorrect after adding runtime assertions (this doesn't seem to break anything, since we already depend on positions instead of FQNs of outputs). This PR also fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103887
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
Summary:
Also adds support for backend_config with relu fusion since XNNPACK allows it.
We should revisit the relu fusion once we gain more clarity on quantSrcPartition or some other way to do these fusions without having to add all the combinations.
TODO: we should really rename the backend config to et_xnnpack.py or something similar.
Test Plan: `buck test fbcode//mode/dev-nosan fbcode//executorch/backends/xnnpack/test:`
Differential Revision: D46985169
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104134
Approved by: https://github.com/mcr229, https://github.com/salilsdesai
Dispatch the selection function to avoid using `is_mps()` in `Histogram.cpp`.
### <samp>🤖 Generated by Copilot at b329a02</samp>
This pull request refactors and implements the logic for inferring the bin edges of histograms from the input tensor for different device types. It introduces a dispatch stub `histogram_select_outer_bin_edges_stub` and moves the device-specific code to separate files, such as `HistogramKernel.cpp` and `HistogramKernel.mm`. This improves the modularity and readability of the histogram functions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101792
Approved by: https://github.com/albanD
Apart from introducing MPSProfiler, this PR also
1. removes the synchronization call after all the commands are encoded, since the stream will be synchronized when the next graph op is encountered and run. One can take a look at this [PR](https://github.com/pytorch/pytorch/pull/99810) to get some insight.
2. initializes the offset calculation kernel's thread output to 0 to ensure the subsequent offset accumulation is correct. This change makes the kernel aligned with the `kernel_index_offsets` kernel.
### <samp>🤖 Generated by Copilot at 4094984</samp>
This change enables performance analysis of the `histogram` kernel on MPS devices by using the `MPSProfiler` class to collect and report relevant metrics. It modifies the file `HistogramKernel.mm` to add profiling calls around the kernel execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101692
Approved by: https://github.com/albanD
Prevents the following cryptic error when one attempts to use `run_test.py` on a system that also has torchaudio installed in dev mode (as `tools` from https://github.com/pytorch/audio might take precedence, which is not how the script should behave):
```
Unable to import test_selections from tools/testing. Running without test selection stats.... Reason: No module named 'tools.stats'
Traceback (most recent call last):
File "/Users/nshulga/git/pytorch/pytorch/test/run_test.py", line 1673, in <module>
main()
File "/Users/nshulga/git/pytorch/pytorch/test/run_test.py", line 1604, in main
selected_tests = get_selected_tests(options)
File "/Users/nshulga/git/pytorch/pytorch/test/run_test.py", line 1418, in get_selected_tests
path = os.path.join(str(REPO_ROOT), TEST_TIMES_FILE)
NameError: name 'TEST_TIMES_FILE' is not defined
```
But make sure to remove it in the end; otherwise it will not work if torch is installed from a wheel but the tests are run from a clean repo checkout.
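A minimal sketch of the approach (illustrative only, not the exact `run_test.py` code; `REPO_ROOT` is assumed to point at the checkout):
```
import sys
from pathlib import Path

REPO_ROOT = Path.cwd()  # assumed: the pytorch repo checkout containing tools/

# make the repo's own `tools` package win over any site-packages `tools`
sys.path.insert(0, str(REPO_ROOT))
try:
    from tools.testing import test_selections  # the module named in the error above
except ImportError:
    test_selections = None  # fall back to running without test selection stats
finally:
    # undo the change so later imports behave the same as before
    sys.path.remove(str(REPO_ROOT))
```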
### <samp>🤖 Generated by Copilot at dd52521</samp>
> _Sing, O Muse, of the cunning code review_
> _That fixed the tests of the `tools` module_
> _By adding and removing the root path_
> _As a shepherd guides his flock to and fro._
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104214
Approved by: https://github.com/kit1980
Based on this [code search](https://fburl.com/code/gjcnw8ly) (*.yaml with `dispatch: CPU:`), update all files found to use
```
kernels:
  - arg_meta: None
    kernel_name:
```
instead of
```
dispatch:
  CPU:
```
---
## Code changes:
- `fbcode/executorch/codegen/tools/gen_oplist.py`
  - Strip ET-specific fields prior to calling `parse_native_yaml_struct`
---
## Files edited that are not `*functions.yaml` or `custom_ops.yaml`
- fbcode/executorch/kernels/optimized/optimized.yaml
- fbcode/executorch/kernels/quantized/quantized.yaml
- fbcode/executorch/kernels/test/custom_kernel_example/my_functions.yaml
---
## Found Files that were not edited
**Dispatched to more than just CPU**
- fbcode/caffe2/aten/src/ATen/native/native_functions.yaml
- xplat/caffe2/aten/src/ATen/native/native_functions.yaml
- xros/third-party/caffe2/caffe2/aten/src/ATen/native/native_functions.yaml
**Grouped ops.yaml path**
- fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/ops.yaml
---
**Design Doc:** https://docs.google.com/document/d/1gq4Wz2R6verKJ2EFseLyPdAF0wqomnCrVDDJpRkYsRw/edit?kh_source=GDOCS#heading=h.8raqyft9y50
Differential Revision: [D46952067](https://our.internmc.facebook.com/intern/diff/D46952067/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D46952067/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104070
Approved by: https://github.com/larryliu0820
Note that in general it's not good form to try to make FakePG work with 'real data',
but the reasoning here is that we want FakePG to work with DeviceMesh's init code
that has data validation, which makes it worth the tradeoff.
In general, users should use MTPG or a normal PG for cases where they care about
real data from collectives.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104213
Approved by: https://github.com/wconstab, https://github.com/voznesenskym
Mostly a refactor that moves all the tests from `test_cuda` that benefit from a multi-GPU environment into their own file.
- Add a `TestCudaMallocAsync` class for Async tests (to separate them from `TestCudaComm`)
- Move individual tests from `TestCuda` to `TestCudaMultiGPU`
- Move `_create_scaling_models_optimizers` and `_create_scaling_case` to `torch.testing._internal.common_cuda`
- Add newly created `test_cuda_multigpu` to the multigpu periodic test
### <samp>🤖 Generated by Copilot at f4d46fa</samp>
This pull request fixes a flaky test and improves the testing of gradient scaling on multiple GPUs. It adds verbose output for two CUDA tests, and refactors some common code into helper functions in `torch/testing/_internal/common_cuda.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104059
Approved by: https://github.com/huydhn
https://github.com/pytorch/pytorch/pull/95715 added the functionality to abort `ncclCommInitRankConfig` by specifying `blocking=0` to enable non-blocking behavior.
However, calling the `pg._abort()` didn't recover from a stuck `ncclCommInitRankConfig` since the `_abort` method only looked through `devNCCLCommMap_` map and aborted those communicators. Since `ncclCommInitRankConfig` was stuck, the communicator itself wasn't added to the map and the host thread was stuck on this line: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1171. As a result, `_abort` was a no-op.
To resolve this issue, I added the communicators to `inProgressCommMap_` as soon as they were created and then removed them once added to `devNCCLCommMap_`.
I also added a unit test that fails without the changes to `ProcessGroupNCCL.cpp`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103925
Approved by: https://github.com/osalpekar
Summary:
Details in T133020932
First commit of the collective utils library. Ported over from modelstore; removed Scuba logging, error_trait, and all dependencies on modelstore.
Test Plan: In the following diffs.
Differential Revision: D45545970
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101037
Approved by: https://github.com/H-Huang
Summary:
Trying to get the `__self__` attribute on any `_OpNamespace` object should be an invalid operation. The `__self__` attribute only exists on bound instance method objects, not on class objects.
In [dynamo](a152b3e3b8/torch/_dynamo/variables/torch.py (L164)) there is code that tries to access the `__self__` attribute on `TorchVariable`; this currently results in an expensive call to `torch._C._jit_get_operation` [here](a152b3e3b8/torch/_ops.py (L740)), which ultimately fails and throws an exception. Each such failing lookup turns out to be quite expensive, on the order of ~0.03 s.
For edge use cases, when exporting large models with quantized ops, this exception is thrown hundreds of times, wasting a lot of time. By preventing the call to `torch._C._jit_get_operation` we can return from this function quickly and significantly reduce export times. On a large ASR model, for example, export currently takes **~405 s**; with this change we can reduce it to **~340 s**.
Overall this should also be a harmless change, as hardly anyone should ever try to access the `__self__` attribute on an `_OpNamespace` object.
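The general pattern is sketched below (illustrative only, not the exact `torch._ops` code; the class and resolver names are made up): a lazily resolving namespace can short-circuit attribute names it knows can never be ops, instead of paying for an expensive lookup that is guaranteed to fail.
```
class LazyOpNamespace:
    def __init__(self, name, resolve):
        self.name = name
        self._resolve = resolve  # expensive lookup, e.g. torch._C._jit_get_operation

    def __getattr__(self, op_name):
        # __self__ only exists on bound methods, so bail out cheaply instead of
        # running (and failing) the expensive resolver
        if op_name == "__self__":
            raise AttributeError(f"'{self.name}' namespace has no attribute '__self__'")
        return self._resolve(op_name)
```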
Test Plan: Added test case.
Differential Revision: D46959879
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104096
Approved by: https://github.com/larryliu0820, https://github.com/ezyang, https://github.com/zou3519
This PR combines the C++ code for the AOTInductor's model and interface with Bin Bao's changes to AOTInductor codegen.
It adds a number of AOTInductor C interfaces that can be used by an inference runtime. Under the hood of these interfaces, the model code generated by the AOTInductor codegen is wrapped into a class, AOTInductorModel, which manages tensors and runs the model inference.
On top of AOTInductorModel, we provide one more abstract layer, AOTInductorModelContainer, which allows the user to have multiple inference runs concurrently for the same model.
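Conceptually (a Python sketch of the idea only; the real AOTInductorModelContainer is C++ and its API is not shown here), the container hands out per-run model instances so several inference calls can proceed concurrently over the same compiled model:
```
from queue import Queue

class ModelContainer:
    def __init__(self, make_model, num_instances=2):
        # each instance owns its own intermediate buffers, so runs don't interfere
        self._pool = Queue()
        for _ in range(num_instances):
            self._pool.put(make_model())

    def run(self, *inputs):
        model = self._pool.get()   # block until an instance is free (thread-safe)
        try:
            return model(*inputs)
        finally:
            self._pool.put(model)
```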
This PR also adjusts the compilation options for AOT codegen, particularly some fbcode-related changes such as libs to be linked and header-file search paths.
Note that this is the very first version of the AOTInductor model and interface, so many features (e.g. dynamic shapes) are incomplete. We will support those missing features in future PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104202
Approved by: https://github.com/desertfire
Fixes #104170
As noted in the above issue it seems that the code for randperm basically boils down to:
`torch.argsort(torch.rand(size, device="mps"), dim = 0)`
However, it seems that in the fused(?) PyTorch version, the tensor we were drawing (`torch.rand(size, device="mps")`) was int64 with an inclusive(?) upper bound of 1. This caused everything to be sorted into two groups (depending on whether you drew 0 or 1), each monotonically ascending due to sort tie-breaking.
One way to fix this is to just generate the random tensor as float64s with an upper bound of 1.0 instead of int64s. An alternative is to just set the upper bound to the max int64 value.
~I choose the float64 one basically on a coin flip b/c I couldn't tell the original contributor's intent (due to mixed up upper bounds and type) but would be happy to change to use int64 and max int 64 as an upper bound instead if that's better.~
Edit: on second thought, I don't like using floats from 0.0 to 1.0, as there are fewer of them in that range than there are int64s from 0 to the int64 max value. I also suspect integer math might be faster, but I need to benchmark this tomorrow.
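A CPU-only illustration of the two behaviors described above (no MPS device assumed; `stable=True` is used to make the tie-breaking that causes the bug deterministic):
```
import torch

n = 16

# argsort over float random keys gives a proper random permutation
perm = torch.argsort(torch.rand(n), dim=0)

# integer keys capped at 1 reproduce the reported bug: with ties broken in index
# order, the result is just two ascending runs, not a random permutation
bad_keys = torch.randint(0, 2, (n,))
bad_perm = torch.argsort(bad_keys, dim=0, stable=True)
print(bad_perm)  # e.g. all indices with key 0 in order, then all indices with key 1
```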
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104171
Approved by: https://github.com/malfet
Summary:
According to https://www.internalfb.com/omh/view/ai_infra_mobile_platform/tests these have been failing since July 2022.
Just going to delete these unless someone thinks they actually do matter and should be made green.
https://www.internalfb.com/intern/test/562949996115570/ <- failing test
I ran the tests locally and got errors like:
```
xplat/caffe2/aten/src/ATen/native/quantized/cpu/qnnpack/test/gemm-block-sparse-microkernel-tester.h:483: Failure
Expected equality of these values:
  c[mIndex * cStride() + nIndex]
    Which is: -872.50446
  acc[mIndex * n() + nIndex]
    Which is: -872.50488
at 0, 0: reference = -872.5048828125, optimized = -872.50445556640625, Mr x Nr = 8 x 4, M x N x K = 7 x 1 x 13
xplat/caffe2/aten/src/ATen/native/quantized/cpu/qnnpack/test/gemm-block-sparse-microkernel-tester.h:483: Failure
Expected equality of these values:
  c[mIndex * cStride() + nIndex]
    Which is: -67.246628
  acc[mIndex * n() + nIndex]
    Which is: -67.24707
at 3, 0: reference = -67.2470703125, optimized = -67.246627807617188, Mr x Nr = 8 x 4, M x N x K = 4 x 1 x 15
[ FAILED ] Q8GEMM_8x4c1x4__SSE2.packedA_k_gt_8_subtile (148 ms)
```
Test Plan: ci
Reviewed By: kimishpatel
Differential Revision: D46950966
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104073
Approved by: https://github.com/kimishpatel