# Description
`.coalesce` cannot handle large inputs on ROCM due to maximal grid size limit.
This PR splits axis `X` into axes `X` and `Y`, and repurposes `Z` for original `Y` on ROCm to avoid such limitation.
Confirmed the new approach can handle large inputs. Correctness needs validation.
# Testing Command
`python torch_spmv.py 22500000 272500000`
## Script `torch_spmv.py`
``` python
import torch
import argparse
def parse_args():
parser = argparse.ArgumentParser(
description="Sparse COO Matrix by Dense Vector Multiplication using PyTorch"
)
parser.add_argument("n", type=int, help="Size of the NxN matrix")
parser.add_argument("nnz", type=int, help="Number of non-zero entries")
return parser.parse_args()
def main():
args = parse_args()
n = args.n
nnz = args.nnz
dtype = torch.float32
device = torch.device('cuda')
# Generate random indices for the sparse matrix in COO format.
torch.manual_seed(42)
rows = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
cols = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
indices = torch.stack([rows, cols], dim=0)
# Generate random values.
values = torch.randn(nnz, dtype=torch.float32, device=device)
# Create the sparse COO matrix and move it to the target device.
sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=torch.float32, device=device)
sparse_matrix = sparse_matrix.coalesce()
# Generate a random dense vector.
dense_vector = torch.randn(n, dtype=torch.float32, device=device)
# Perform sparse matrix - dense vector multiplication.
# Using torch.sparse.mm which expects a 2D tensor for the vector.
result = torch.sparse.mm(sparse_matrix, dense_vector.unsqueeze(1)).squeeze()
# result = torch.mv(sparse_matrix, dense_vector)
# Print the result.
print("Result of the multiplication:")
print(torch.sum(result))
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158281
Approved by: https://github.com/jeffdaily
# Description
`.coalesce` cannot handle large inputs on ROCM due to maximal grid size limit.
This PR splits axis `X` into axes `X` and `Y`, and repurposes `Z` for original `Y` on ROCm to avoid such limitation.
Confirmed the new approach can handle large inputs. Correctness needs validation.
# Testing Command
`python torch_spmv.py 22500000 272500000`
## Script `torch_spmv.py`
``` python
import torch
import argparse
def parse_args():
parser = argparse.ArgumentParser(
description="Sparse COO Matrix by Dense Vector Multiplication using PyTorch"
)
parser.add_argument("n", type=int, help="Size of the NxN matrix")
parser.add_argument("nnz", type=int, help="Number of non-zero entries")
return parser.parse_args()
def main():
args = parse_args()
n = args.n
nnz = args.nnz
dtype = torch.float32
device = torch.device('cuda')
# Generate random indices for the sparse matrix in COO format.
torch.manual_seed(42)
rows = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
cols = torch.randint(0, n, (nnz,), dtype=torch.int64, device=device)
indices = torch.stack([rows, cols], dim=0)
# Generate random values.
values = torch.randn(nnz, dtype=torch.float32, device=device)
# Create the sparse COO matrix and move it to the target device.
sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(n, n), dtype=torch.float32, device=device)
sparse_matrix = sparse_matrix.coalesce()
# Generate a random dense vector.
dense_vector = torch.randn(n, dtype=torch.float32, device=device)
# Perform sparse matrix - dense vector multiplication.
# Using torch.sparse.mm which expects a 2D tensor for the vector.
result = torch.sparse.mm(sparse_matrix, dense_vector.unsqueeze(1)).squeeze()
# result = torch.mv(sparse_matrix, dense_vector)
# Print the result.
print("Result of the multiplication:")
print(torch.sum(result))
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158281
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
Fixes#122757
## Test Result
```python
import torch
model_output = torch.randn(10, 5).cuda()
labels = torch.randint(0, 5, (10,)).cuda()
weights = torch.randn(5)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(input=model_output, target=labels)
print(loss)
Traceback (most recent call last):
File "/home/zong/code/pytorch/../loss2.py", line 17, in <module>
loss = loss_fn(input=model_output, target=labels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/nn/modules/loss.py", line 1297, in forward
return F.cross_entropy(
^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/nn/functional.py", line 3494, in cross_entropy
return torch._C._nn.cross_entropy_loss(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but got weight is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_nll_loss_forward)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150750
Approved by: https://github.com/malfet
As per title.
The following implementation removes the usage of `repeat_interleave, tile` and `full_coo_indices` and replaces them with broadcasting. That way we reduce memory traffic (and are likely to hit cache a lot) and the total number of launched kernels.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142364
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
This PR fixes a bug in `search_end_matrix_indices_cuda_kernel` causing an illegal memory access when calling `bmm_sparse_cuda` on a sparse matrix with no non-zero values in the first batch dimension. Reproducible example:
```py
import torch
ind = torch.tensor([[1], [0], [0]], device="cuda")
val = torch.tensor([1.], device="cuda")
A = torch.sparse_coo_tensor(ind, val, size=(2, 1, 1))
B = torch.zeros((2, 1, 1), device="cuda")
C = torch.bmm(A, B)
```
## Details
In the previous code, we may for example end up with the following situation:
```
i : indices_1D[i]
------------------------------------------
0 : 1 <- start_idx, mid_idx
1 : 1 <- end_idx
...
```
When `target_mat_num = 0`, the next iteration of the while loop will assign `-1` to `end_idx` and thus `(0 + (-1)) >> 1 = -1` to `mid_idx`, causing an access error on line 703. The updated code maintains the invariant `start_idx <= end_idx` and will not go out of bounds.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131977
Approved by: https://github.com/amjames, https://github.com/pearu, https://github.com/nikitaved
Update ruff to 0.4.1 .
This version fixes a lot false negatives/false positives, is 20-40% faster, and has various other bug fixes.
Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0
| Repository | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7 | 251.8 | 351.1 | 274.9 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
This PR enables `test_addmm_sizes_all_sparse_csr_k_*_n_*_m_*_cuda_complex128` for ROCm for trivial cases (m or n or k = 0)
CUSPARSE_SPMM_COMPLEX128_SUPPORTED also used for `test_addmm_all_sparse_csr` and ` test_sparse_matmul` and both of them are skipped for ROCm by `@skipIfRocm` or `@skipCUDAIf(not _check_cusparse_spgemm_available())`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120504
Approved by: https://github.com/jithunnair-amd, https://github.com/ezyang