Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[ROCm] Support large inputs for coalesceValuesKernel (#158281)"
This reverts commit a7abf57aabec0ce686092e2d66e53ba185dbc56b.
Reverted https://github.com/pytorch/pytorch/pull/158281 on behalf of https://github.com/clee2000 due to broke windows cuda build? [GH job link](https://github.com/pytorch/pytorch/actions/runs/16915172288/job/47927141460) [HUD commit link](a7abf57aab). Not caught b/c PR didn't have ciflow/trunk ([comment](https://github.com/pytorch/pytorch/pull/158281#issuecomment-3180408766))

The hunks below show the revert: lines marked `-` drop the ROCm segment-batching support added in #158281 (in the sparse coalesce CUDA kernels, their launch code, and the accompanying test), and lines marked `+` restore the previous code.
```diff
@@ -196,17 +196,9 @@ C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void coalesceValuesKernel(
   int64_t *segment_offsets, int64_t *value_indices,
   Dtype *values, Dtype *newValues,
-  int64_t nnz, int64_t newNnz,
-#ifdef USE_ROCM
-  int64_t nsegments,
-#endif
-  int64_t stride) {
+  int64_t nnz, int64_t newNnz, int64_t stride) {
 
-#ifdef USE_ROCM
-  int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y;
-#else
-  int64_t seg = blockIdx.x * 4 + threadIdx.y;
-#endif
+  int seg = blockIdx.x * 4 + threadIdx.y;
 
   // Number of values processed by each thread (grain size)
   const int SZ = 4;
```
```diff
@@ -215,11 +207,7 @@ __global__ void coalesceValuesKernel(
   const int newValueRow = seg * stride;
   const int begin = segment_offsets[seg];
   const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
-#ifdef USE_ROCM
-  const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ;
-#else
   const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
-#endif
   Acctype tmp[SZ];
   #pragma unroll
   for (int ii = 0; ii < SZ; ii++) {
```
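For context on the index computation above: the reverted ROCm branch tiled segment blocks over a two-dimensional grid and rebuilt a flat segment id as `(blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y`, while the restored code derives it from `blockIdx.x` alone. The standalone CUDA sketch below is not part of the patch (the kernel name and launch shape are made up); it only illustrates the two mappings side by side.

```cuda
// Minimal sketch (not from the patch): how a flat segment index can be
// recovered from a 2D grid the way the reverted ROCm branch did, versus the
// 1D grid of the restored code. All names here are illustrative only.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void flat_segment_demo() {
  // Reverted ROCm-style mapping: segment blocks are tiled over
  // (blockIdx.x, blockIdx.y), so the flat block id is
  // blockIdx.x * gridDim.y + blockIdx.y, with 4 segments per block along y.
  int64_t seg_2d =
      (static_cast<int64_t>(blockIdx.x) * gridDim.y + blockIdx.y) * 4 + threadIdx.y;

  // Restored mapping: one block per group of 4 segments along x only,
  // so the two mappings coincide exactly when gridDim.y == 1.
  int64_t seg_1d = static_cast<int64_t>(blockIdx.x) * 4 + threadIdx.y;

  if (threadIdx.x == 0 && seg_2d < 8) {
    printf("block (%u,%u) thread.y %u -> seg_2d %lld, seg_1d %lld\n",
           blockIdx.x, blockIdx.y, threadIdx.y,
           static_cast<long long>(seg_2d), static_cast<long long>(seg_1d));
  }
}

int main() {
  // 2x2 grid of blocks, warp-wide x-dimension, 4 segments per block in y.
  flat_segment_demo<<<dim3(2, 2), dim3(32, 4)>>>();
  cudaDeviceSynchronize();
  return 0;
}
```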
```diff
@@ -262,17 +250,9 @@ C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4)
 __global__ void coalesceValuesKernel(
   int64_t *segment_offsets, int64_t *value_indices,
   bool *values, bool *newValues,
-  int64_t nnz, int64_t newNnz,
-#ifdef USE_ROCM
-  int64_t nsegments,
-#endif
-  int64_t stride) {
+  int64_t nnz, int64_t newNnz, int64_t stride) {
 
-#ifdef USE_ROCM
-  int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y;
-#else
-  int64_t seg = blockIdx.x * 4 + threadIdx.y;
-#endif
+  int seg = blockIdx.x * 4 + threadIdx.y;
 
   // Number of values processed by each thread (grain size)
   const int SZ = 4;
```
```diff
@@ -281,11 +261,7 @@ __global__ void coalesceValuesKernel(
   const int newValueRow = seg * stride;
   const int begin = segment_offsets[seg];
   const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
-#ifdef USE_ROCM
-  const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ;
-#else
   const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
-#endif
   bool tmp[SZ];
   #pragma unroll
   for (int ii = 0; ii < SZ; ii++) {
```
```diff
@@ -106,14 +106,7 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) {
   values = values.contiguous();
   int64_t stride = c10::multiply_integers(values.sizes().slice(1));
   int warp_size = at::cuda::warp_size();
-#ifdef USE_ROCM
-  const int64_t BATCHING_SEGMENT = 4096;
-  int64_t nsegments = ceil_div(newNnz, (int64_t) SZ);
-  int64_t s_batch = ceil_div(nsegments, BATCHING_SEGMENT);
-  dim3 grid(s_batch, (s_batch == 1) ? nsegments : BATCHING_SEGMENT, ceil_div(stride, (int64_t) warp_size*SZ));
-#else
   dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) warp_size*SZ));
-#endif
   dim3 block(warp_size, SZ);
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
     at::ScalarType::ComplexHalf, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
```
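To make the launch-geometry change above easier to follow, here is a small host-side sketch (plain C++, not from the patch) that mirrors the two grid computations: the reverted ROCm branch splits the `newNnz / SZ` segment blocks into x/y batches of at most 4096 and moves the feature dimension to `grid.z`, while the restored code puts all segment blocks on `grid.x`. The `ceil_div` helper and the example sizes are assumptions for illustration only.

```cpp
// Minimal host-side sketch (not from the patch) of the two grid computations.
#include <cstdint>
#include <cstdio>

// Same rounding-up division as at::ceil_div, re-declared for a standalone example.
int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t SZ = 4;               // values per thread (grain size), as in the kernel
  const int64_t warp_size = 64;       // e.g. a ROCm wavefront; 32 on NVIDIA
  const int64_t newNnz = 250000000;   // illustrative large coalesced-nnz count
  const int64_t stride = 1;           // dense row size of the values tensor

  // Restored (non-batched) grid: all segment blocks go on grid.x.
  int64_t grid_x_plain = ceil_div(newNnz, SZ);
  int64_t grid_y_plain = ceil_div(stride, warp_size * SZ);
  printf("plain grid:   x=%lld y=%lld\n",
         (long long)grid_x_plain, (long long)grid_y_plain);

  // Reverted ROCm batching: segment blocks are tiled over grid.x * grid.y,
  // with at most BATCHING_SEGMENT of them per x-slice; stride moves to grid.z.
  const int64_t BATCHING_SEGMENT = 4096;
  int64_t nsegments = ceil_div(newNnz, SZ);
  int64_t s_batch = ceil_div(nsegments, BATCHING_SEGMENT);
  int64_t grid_x = s_batch;
  int64_t grid_y = (s_batch == 1) ? nsegments : BATCHING_SEGMENT;
  int64_t grid_z = ceil_div(stride, warp_size * SZ);
  printf("batched grid: x=%lld y=%lld z=%lld\n",
         (long long)grid_x, (long long)grid_y, (long long)grid_z);
  return 0;
}
```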
```diff
@@ -126,9 +119,6 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) {
         newValues.data_ptr<scalar_t>(),
         nnz,
         newNnz,
-#if USE_ROCM
-        nsegments,
-#endif
         stride
       );
       C10_CUDA_KERNEL_LAUNCH_CHECK();
```
```diff
@@ -21,7 +21,7 @@ from torch.testing._internal.common_cuda import \
     (SM53OrLater, SM80OrLater, TEST_MULTIGPU)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
-     deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf)
+     deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
 from torch.testing._internal.common_methods_invocations import \
     (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
 from torch.testing._internal.common_dtype import (
```
```diff
@@ -367,19 +367,6 @@ class TestSparse(TestSparseBase):
             t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size, dtype, device, coalesced)
             _test_coalesce(t) # this tests correctness
 
-    @onlyCUDA
-    @skipCUDAIf(not SM80OrLater and not TEST_WITH_ROCM, "CUDA capability < SM80 and not ROCM")
-    @dtypes(torch.float)
-    def test_coalesce_accepts_large_tensor(self, device, dtype):
-        N = 22500000
-        NNZ = 272500000
-        rows = torch.randint(0, N, (NNZ,), dtype=torch.int64, device=device)
-        cols = torch.randint(0, N, (NNZ,), dtype=torch.int64, device=device)
-        indices = torch.stack([rows, cols], dim=0)
-        values = torch.randn(NNZ, dtype=dtype, device=device)
-        sparse_matrix = torch.sparse_coo_tensor(indices, values, size=(N, N), dtype=torch.float32, device=device)
-        sparse_matrix = sparse_matrix.coalesce()
-
     @dtypes(torch.double)
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/89395")
     def test_coalesce_reference_cycle(self, device, dtype):
```
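For a sense of scale for the test removed above: with N = 22,500,000 and NNZ = 272,500,000, the uncoalesced COO tensor already carries 2 × NNZ int64 indices and NNZ float32 values. The snippet below is only that back-of-the-envelope arithmetic, not code from the repository.

```cpp
// Back-of-the-envelope sizing for the removed test_coalesce_accepts_large_tensor
// inputs (illustrative only, not from the repository).
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t NNZ = 272500000;                          // nonzeros in the test
  const int64_t index_bytes = 2 * NNZ * sizeof(int64_t);  // 2 x NNZ int64 indices
  const int64_t value_bytes = NNZ * sizeof(float);        // NNZ float32 values
  const double gib = 1024.0 * 1024.0 * 1024.0;
  // Roughly 4.1 GiB of indices plus 1.0 GiB of values before coalescing.
  printf("indices: %.2f GiB, values: %.2f GiB, total: %.2f GiB\n",
         index_bytes / gib, value_bytes / gib, (index_bytes + value_bytes) / gib);
  return 0;
}
```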