Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Implements matmuls for sparse tensors on the MPS backend. With this commit, most of the core sparse operations should now be implemented.

Fixes:
- https://github.com/pytorch/pytorch/issues/156540
- https://github.com/pytorch/pytorch/issues/129842

Should be merged after: https://github.com/pytorch/pytorch/pull/165102

To compare MPS and CPU performance, you can use this script:

```python
import torch
import time
import matplotlib.pyplot as plt

B, I, J, K = 8, 20000, 20000, 20000
num_iterations = 500
nnz_values = [10, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 100000]
speedups = []

for nnz in nnz_values:
    indices = torch.stack([
        torch.randint(0, B, (nnz,)),
        torch.randint(0, I, (nnz,)),
        torch.randint(0, J, (nnz,)),
    ])
    values = torch.rand(nnz)

    sparse = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
    dense = torch.randn(B, J, 200, device="mps")

    t1 = time.time()
    for _ in range(num_iterations):
        result = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t2 = time.time()
    mps_time = (t2 - t1) / num_iterations

    sparse_cpu = sparse.cpu()
    dense_cpu = dense.cpu()
    t1 = time.time()
    for _ in range(num_iterations):
        result_cpu = torch.bmm(sparse_cpu, dense_cpu)
    t2 = time.time()
    cpu_time = (t2 - t1) / num_iterations

    speedup = cpu_time / mps_time
    speedups.append(speedup)
    print(f"nnz={nnz}: MPS={mps_time:.6f}s, CPU={cpu_time:.6f}s, Speedup={speedup:.2f}x")

plt.figure(figsize=(10, 6))
plt.plot(nnz_values, speedups, marker='o', linewidth=2, markersize=8)
plt.xlabel('Number of Non-Zero Elements (nnz)', fontsize=12)
plt.ylabel('Speedup (CPU time / MPS time)', fontsize=12)
plt.title('MPS vs CPU Speedup for Sparse-Dense BMM', fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', alpha=0.5)
plt.xscale('log')
plt.tight_layout()
plt.show()
```

## Tested on M1 Pro

<img width="1000" height="600" alt="Figure_1" src="https://github.com/user-attachments/assets/4a2402ec-3dc4-402d-8196-a0426906ca3d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165232
Approved by: https://github.com/malfet