This reverts commit 5494b2a8d38c3ddbeb2d96a5ac990e20ec4c48fd.
Need to skip `test_sparse_csr.py::TestSparseCSRCUDA::test_sampled_addmm_zero_sized_cuda_*` again. Tests are failing now with "core dumped" error
```
python test_sparse_csr.py -v -k test_sampled_addmm_zero_sized_cuda_float64
test_sampled_addmm_zero_sized_cuda_float64 (__main__.TestSparseCSRCUDA) ... /tmp/pytorch/test/test_sparse_csr.py:2503: c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr)
GPU core dump created: gpucore.186789
:0:rocdevice.cpp :2992: 4701819131755 us: Callback: Queue 0x760cdcd00000 aborting with error : HSA_STATUS_ERROR_EXCEPTION: An HSAIL operation resulted in a hardware exception. code: 0x1016
Aborted (core dumped)
```
These failures are linked to `test_sparse_csr.py::TestSparseCSRCUDA::test_select_SparseBSC_int32_cuda_*` due to incorrect test log parsing. We will be able to close these issues also:
- Fixes https://github.com/pytorch/pytorch/issues/163663
- Fixes https://github.com/pytorch/pytorch/issues/160786
- Fixes https://github.com/pytorch/pytorch/issues/160785
- Fixes https://github.com/pytorch/pytorch/issues/160784
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163848
Approved by: https://github.com/jeffdaily
As the title states, suffixes like`.dylib` and `lib` can be replaced by `CMAKE_SHARED_LIBRARY_SUFFIX`, and prefixes like `lib` can be replaced by `CMAKE_SHARED_LIBRARY_PREFIX` on Unix or `CMAKE_IMPORT_LIBRARY_PREFIX` on Windows.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163850
Approved by: https://github.com/albanD
As title, in practice we found that sometimes, the dtype of gather does not match when it comes to output among all ranks, which is a undefined behavior. Same with broadcast and scatter. And they are all completed, so we should not think they are errors, we can skip it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163839
Approved by: https://github.com/VieEeEw
```
import torch
import torch.fx.traceback as fx_traceback
import torch.export
class M(torch.nn.Module):
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
with fx_traceback.annotate({"fdsp_bucket": 0}):
x = x + 1
x = x - 2
with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
x = x * 2
x = x / 3
return x
m = M()
with fx_traceback.preserve_node_meta():
ep = torch.export.export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.op == "call_function":
print(f"{node.target}, {node.meta.get("custom", {})}")
```
prints
```
aten.add.Tensor, {'pp_stage': 0, 'fdsp_bucket': 0}
aten.sub.Tensor, {'pp_stage': 0}
aten.mul.Tensor, {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1}
aten.div.Tensor, {}
```
TODOs:
- run_decomposition is failing
- Need to test with the new full graph capture + aot_export_joint apis
- Need to make the annotation propagate through autograd engine to reach the bw nodes. Sample impl here: https://github.com/pytorch/pytorch/pull/83558
- Edward want to restrict the key in custom field to be top-level singleton objects only
- also need to take care of metadata merging when passes are fusing nodes
Thanks @angelayi for contributing the dynamo fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163673
Approved by: https://github.com/albanD, https://github.com/angelayi
Previously, an eval() call before a training step() would not correctly initialize the backward pass of the pipeline stages, leading to errors during the subsequent training step. This PR ensures that the backward stages can still be initialized after an eval() call.
Fixes#162822
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162823
Approved by: https://github.com/dcci, https://github.com/H-Huang
## Overview
This PR allows the profiler users to access `Kineto` and `TorchOp` metadata in JSON string format through a new `metadata_json` attribute in `FunctionEvent` objects, which is triggered through a new `expose_kineto_event_metadata` flag in `ExperimentalConfig`.
## Testing
A unit test was added to validate functionality.
## Documentation
Added/updated function doc strings where appropriate.
## Example output
```python
import torch
from torch.profiler import profile
with profile(experimental_config=torch._C._profiler._ExperimentalConfig(expose_kineto_event_metadata=True)) as prof:
res = torch.mm(torch.rand(1024, 1024), torch.rand(1024, 1024))
for event in prof.events():
print(f'name: {event.key}, metadata: {event.metadata_json}')
```
```
name: aten::rand, metadata: "Ev Idx": 0
name: aten::empty, metadata: "Ev Idx": 1
name: aten::uniform_, metadata: "Ev Idx": 2
name: aten::rand, metadata: "Ev Idx": 3
name: aten::empty, metadata: "Ev Idx": 4
name: aten::uniform_, metadata: "Ev Idx": 5
name: aten::mm, metadata: "Ev Idx": 6
name: aten::resolve_conj, metadata: "Ev Idx": 7
name: aten::resolve_conj, metadata: "Ev Idx": 8
name: aten::resolve_conj, metadata: "Ev Idx": 9
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161624
Approved by: https://github.com/sraikund16
Current state: Shape mismatch failure when mm+rs on the last mm scatter dim.
Adding separate path to handle lastdim for aten.mm, scaled_mm should be handled similarly, but needs additional PR.
So disabling scaled_mm case with filter matmul function.
Adding inductor.config for this change that is True by default for fast debuggability of new path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162794
Approved by: https://github.com/fegin
The version finding logic triggered from `setup.py` generally tries to take the git information into account.
This is fine for most situations where we are building from a checkout, but it creates a problem in the case of sdists, as here the version is determined at the time of sdist creation, taking the git information into account, but then later recalculated when building wheels or installing from the sdist, now with the git information missing.
The solution is to take the version information directly from the sdist, which this PR adds by means of parsing the `PKG-INFO` which marks an unpacked sdist.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160315
Approved by: https://github.com/atalman
ghstack dependencies: #157814
1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.
Fixes#163103
LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```
This PR adds a helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
"""Create a list containing group[key] for each optimizer param_group.
Prevents aliasing when group[key] could be a Tensor.
Raises a KeyError when group[key] does not exist.
"""
return [
group[key].clone() if isinstance(group[key], Tensor) else group[key]
for group in optimizer.param_groups
]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
Fixes misleading warning messages when running on sm12x devices using binaries built with sm120.
PyTorch binary built with sm120 is compatible with e.g. sm121, so no need for the warning of incompatibility.
Also allow the 'matched_cuda_warn' message to show when e.g. the user is running a binary built with only sm90 on sm12x, so that the user would be prompted to get a build which supports e.g. sm120.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161299
Approved by: https://github.com/eqy, https://github.com/atalman
Fixes#160520
Summary:
When running Inductor with cpp_wrapper under a DeviceContext, non-tensor arguments were being wrapped with torch.tensor(arg) without specifying the device.
creating the tensor on the current active device (like CUDA), and later fetching it back to CPU via .item(), causing unnecessary host-device-host memory transfers.
PR fixes issue by explicitly creating scalar tensors on the CPU:
```
input_tensors = [
arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu')
for arg in args
]
```
impact: inductor, codegen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160584
Approved by: https://github.com/benjaminglass1, https://github.com/desertfire, https://github.com/mlazos, https://github.com/jeffdaily
Avoid `at::alias` in the `repeat` op implementation
## Summary
This PR removed the usage of `at::alias` in the implementation and just `permute`+`reshape` the tensor to fit the specs of the result.
This is a less hacky and a more readable way of implementing the op.
All the new ops we are using are view-only ops, which does not introduce overhead of changing the storage.
## Who want this
We are using `PrivateUse1` and accelerator, but this request to avoid `at::alias` in any op should be general enough for any backend who is using XLA, or who do not have explicit control over the memory allocation on the devices.
## Why we/they need this
As we support TPU, we are overriding some ATen ops by binding them to PrivateUse1.
However, it is not recommended to override the `repeat` op directly as we saw the following in `RegistrationDeclaration.h`.
```
at::Tensor repeat(const at::Tensor & self, c10::SymIntArrayRef repeats); // {"schema": "aten::repeat(Tensor self, SymInt[] repeats) -> Tensor", "dispatch": "True", "default": "True"}
```
We had to reuse the existing implementation of `repeat` to decomposite to other ops.
However, we are unable to support the current implementation, which uses `at::alias`.
It have two tensors share the same storage and modify one of them and return the other assuming it is changed, too.
As, we do not have explicit control over the memory allocation of the tensors using XLA/PJRT.
## Alternatives
We are open to alternative solutions that work for us if this PR is not in favor of the PyTorch community.
For example, we may just bind our version of `repeat` op implementation to both `PrivateUse` and `AutogradPrivateUse1`.
However, to my understanding, this would not work well with torch dynamo and `torch.compile`.
Would you mind guiding us on how to solve this?
Thanks!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163455
Approved by: https://github.com/Skylion007