Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Include support for the scatter gather cuda kernels to allow for comp… (#124809)"
This reverts commit 9e24c263f998819f849bb8293323213101e9aefc. Reverted https://github.com/pytorch/pytorch/pull/124809 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/124809#issuecomment-2091751002))
@@ -38,11 +38,7 @@ inline C10_HOST_DEVICE bool _isnan(T val) {

template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
#if defined(__CUDACC__) || defined(__HIPCC__)
  return ::isnan(val.real()) || ::isnan(val.imag());
#else
  return std::isnan(val.real()) || std::isnan(val.imag());
#endif
}

template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
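For reference, a minimal host-side sketch of the check this overload performs, with std::complex standing in for c10::complex (the helper name is illustrative, not part of the PyTorch sources): a complex value counts as NaN if either its real or imaginary component is NaN.

#include <cmath>
#include <complex>
#include <cstdio>

template <typename T>
bool complex_isnan(std::complex<T> v) {
  // NaN in either component makes the whole complex value NaN.
  return std::isnan(v.real()) || std::isnan(v.imag());
}

int main() {
  std::printf("%d %d\n",
              (int)complex_isnan(std::complex<float>(1.0f, std::nanf(""))),  // prints 1
              (int)complex_isnan(std::complex<float>(1.0f, 2.0f)));          // prints 0
}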
@@ -35,26 +35,6 @@ struct AtomicFPOp<at::Half> {
  }
};

template <>
struct AtomicFPOp<c10::complex<float>> {
  template <typename func_t>
  inline __device__ c10::complex<float> operator() (c10::complex<float> *address, c10::complex<float> val, const func_t& func) {
    unsigned long long int* addr_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed, new_val;

    c10::complex<float> csum;
    do {
      assumed = old;
      csum = func(csum, val);
      new_val = *reinterpret_cast<unsigned long long*>(&csum);
      old = atomicCAS(addr_as_ull, assumed, new_val);
    } while (assumed != old);

    return *reinterpret_cast<c10::complex<float>*>(&addr_as_ull);
  }
};

template <>
struct AtomicFPOp<at::BFloat16> {
  template <typename func_t>
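The specialization above is built on the standard atomicCAS retry loop for an 8-byte payload. Below is a hedged, stand-alone sketch of that pattern, using CUDA's cuFloatComplex and cuCmulf in place of c10::complex; the function and functor names are made up for illustration. Note that the sketch seeds each iteration from the value just observed in memory and returns the value that was there before the update, which is the usual contract for atomic read-modify-write.

#include <cuComplex.h>

struct ComplexMul {
  __device__ cuFloatComplex operator()(cuFloatComplex a, cuFloatComplex b) const {
    return cuCmulf(a, b);  // (ar*br - ai*bi) + (ar*bi + ai*br)i
  }
};

template <typename func_t>
__device__ cuFloatComplex atomic_complex_update(cuFloatComplex* address,
                                                cuFloatComplex val,
                                                func_t func) {
  unsigned long long int* addr_as_ull = reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *addr_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // Start from the value this thread last observed in memory, apply the op...
    cuFloatComplex current = *reinterpret_cast<cuFloatComplex*>(&assumed);
    cuFloatComplex updated = func(current, val);
    // ...and try to publish the result; retry if another thread intervened.
    old = atomicCAS(addr_as_ull, assumed,
                    *reinterpret_cast<unsigned long long int*>(&updated));
  } while (assumed != old);
  // Return the value that was in memory before this thread's successful update.
  return *reinterpret_cast<cuFloatComplex*>(&old);
}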
@@ -368,14 +348,6 @@ GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val){
  return AtomicFPOp<c10::complex<float>>()(address, val,
                                           [](c10::complex<float> bsum, c10::complex<float> val) {
                                             bsum*=(val);
                                             return bsum;
                                           });
}

inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
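Continuing the sketch above (same translation unit, same hypothetical names), a small kernel and host driver showing how such a functor-driven atomic would be exercised: 256 threads each multiply one shared accumulator by (1 + 0.01i), so the final value is (1 + 0.01i)^256 times the initial value.

#include <cstdio>

__global__ void scale_all(cuFloatComplex* acc) {
  // Every thread contributes one atomic complex multiply.
  atomic_complex_update(acc, make_cuFloatComplex(1.0f, 0.01f), ComplexMul{});
}

int main() {
  cuFloatComplex* acc = nullptr;
  cudaMallocManaged(&acc, sizeof(cuFloatComplex));
  *acc = make_cuFloatComplex(1.0f, 0.0f);
  scale_all<<<1, 256>>>(acc);
  cudaDeviceSynchronize();
  std::printf("%f + %fi\n", cuCrealf(*acc), cuCimagf(*acc));
  cudaFree(acc);
  return 0;
}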
@@ -397,7 +369,7 @@ inline __device__ double gpuAtomicMul(double * address, double val) {
  });
}

// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMul (float * address, float val) {
  unsigned int* address_as_ull = (unsigned int*)address;
  unsigned int old = *address_as_ull;
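The float overload above follows the same idea with a 32-bit CAS loop on an unsigned int alias of the address. A minimal sketch of that pattern, using CUDA's __float_as_uint/__uint_as_float bit-cast intrinsics (the function name is illustrative):

__device__ float atomic_mul_float(float* address, float val) {
  unsigned int* addr_as_ui = reinterpret_cast<unsigned int*>(address);
  unsigned int old = *addr_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    // Compute the product on the value currently observed in memory.
    float updated = __uint_as_float(assumed) * val;
    old = atomicCAS(addr_as_ui, assumed, __float_as_uint(updated));
  } while (assumed != old);
  return __uint_as_float(old);  // value before this thread's update
}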
@@ -430,29 +402,6 @@ __host__ __device__ T safe_max(T a, T b) {
  return max;
}

__inline__ __device__ c10::complex<float> complex_max(c10::complex<float> a, c10::complex<float> b) {
  if(at::_isnan(b)) {
    return b;
  } else {
    // Compute the magnitude of the complex numbers and compare each to see which one is greater.
    float a_magnitude = __fsqrt_rn(
      (
        __fmul_rn(a.real(), a.real()) +
        __fmul_rn(a.imag(),a.imag())
      )
    );
    float b_magnitude = __fsqrt_rn(
      (
        __fmul_rn(b.real(), b.real()) +
        __fmul_rn(b.imag(),b.imag())
      )
    );
    return std::max<float>(a_magnitude, b_magnitude);
  }

}


ATOMIC_INTEGER_IMPL(Max)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
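The complex_max above compares the two operands by modulus, propagates a NaN in b, and returns the larger magnitude converted back to a complex value. Below is a hedged sketch of the same magnitude comparison written against cuFloatComplex with the same intrinsics; as one possible variant it returns the larger operand itself rather than the bare magnitude (illustrative names, not PyTorch API).

#include <cuComplex.h>

__device__ bool complex_is_nan(cuFloatComplex v) {
  return isnan(cuCrealf(v)) || isnan(cuCimagf(v));
}

__device__ cuFloatComplex complex_max_by_magnitude(cuFloatComplex a, cuFloatComplex b) {
  if (complex_is_nan(b)) {
    return b;  // propagate NaN, matching the behaviour shown above
  }
  // |z| = sqrt(re*re + im*im), computed with round-to-nearest intrinsics.
  float a_mag = __fsqrt_rn(__fmul_rn(cuCrealf(a), cuCrealf(a)) +
                           __fmul_rn(cuCimagf(a), cuCimagf(a)));
  float b_mag = __fsqrt_rn(__fmul_rn(cuCrealf(b), cuCrealf(b)) +
                           __fmul_rn(cuCimagf(b), cuCimagf(b)));
  // Return the operand with the larger modulus rather than the modulus itself.
  return a_mag >= b_mag ? a : b;
}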
@@ -467,13 +416,6 @@ inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
  });
}

inline __device__ c10::complex<float> gpuAtomicMax(c10::complex<float> * address, c10::complex<float> val) {
  return AtomicFPOp<c10::complex<float>>()(address, val,
                                           [](c10::complex<float> bsum, c10::complex<float> val) {
                                             return complex_max(bsum, val);
                                           });
}

inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
@@ -520,27 +462,6 @@ __host__ __device__ T safe_min(T a, T b) {
  return min;
}

__inline__ __device__ c10::complex<float> complex_min(c10::complex<float> a, c10::complex<float> b) {
  if(at::_isnan(b)) {
    return b;
  } else {
    // Compute the magnitude of the complex numbers and compare each to see which one is smaller.
    float a_magnitude = __fsqrt_rn(
      (
        __fmul_rn(a.real(), a.real()) +
        __fmul_rn(a.imag(),a.imag())
      )
    );
    float b_magnitude = __fsqrt_rn(
      (
        __fmul_rn(b.real(), b.real()) +
        __fmul_rn(b.imag(),b.imag())
      )
    );
    return std::min<float>(a_magnitude, b_magnitude);
  }
}

ATOMIC_INTEGER_IMPL(Min)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
@@ -555,13 +476,6 @@ inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
  });
}

inline __device__ c10::complex<float> gpuAtomicMin(c10::complex<float> * address, c10::complex<float> val) {
  return AtomicFPOp<c10::complex<float>>()(address, val,
                                           [](c10::complex<float> bsum, c10::complex<float> val) {
                                             return complex_min(bsum, val);
                                           });
}

inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
@@ -4,6 +4,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>

#include <ATen/native/ScatterGatherChecks.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorIterator.h>
@@ -200,6 +201,7 @@ struct cuda_scatter_gather_base_kernel {
    auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
    auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(),
@@ -257,6 +259,7 @@ struct cuda_scatter_gather_base_kernel {
    auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
    auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(),
@@ -315,9 +318,9 @@ struct cuda_scatter_gather_base_kernel {
    auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
    auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;

    AT_DISPATCH_ALL_TYPES_AND3(

    AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      at::ScalarType::ComplexFloat,
      iter.dtype(),
      "cuda_scatter_gather_base_kernel_func", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
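For context, the AT_DISPATCH_* macros switch on a runtime at::ScalarType and instantiate the lambda once per supported type, with scalar_t bound to the matching C++ type; the _AND2/_AND3 suffixes extend the base type list, which is how at::ScalarType::ComplexFloat was added by the original PR and is dropped again here. A hedged sketch of the usage pattern follows; the function and dispatch name are illustrative, and the sketch assumes a CPU tensor so the write can happen on the host.

#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

void fill_first_element(at::Tensor& t, double value) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      t.scalar_type(), "fill_first_element", [&] {
        // scalar_t is defined by the macro for whichever dtype matched at runtime.
        t.data_ptr<scalar_t>()[0] = static_cast<scalar_t>(value);
      });
}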
@@ -447,9 +450,8 @@ struct cuda_scatter_fill_base_kernel {
    auto index_size = ensure_nonempty_size(self, dim);
    auto index_stride = ensure_nonempty_stride(self, dim);

    AT_DISPATCH_ALL_TYPES_AND3(
    AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      at::ScalarType::ComplexFloat,
      iter.dtype(),
      "cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
@@ -221,8 +221,7 @@ class TestScatterGather(TestCase):
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
                                  include_complex=False, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_prod(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
@@ -230,8 +229,7 @@ class TestScatterGather(TestCase):
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
                                  include_complex=False, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_mean(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
@@ -241,8 +239,7 @@ class TestScatterGather(TestCase):
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
                                  include_complex=False, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amax(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
@@ -261,8 +258,7 @@ class TestScatterGather(TestCase):


    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
                                  include_complex=False, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amin(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
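The tests above drive torch.Tensor.scatter_reduce_ across several reductions and dtype lists. For illustration only, here is a hedged C++-frontend sketch of the prod reduction those tests cover; the exact C++ overload is assumed to mirror the Python signature (dim, index, src, reduce, include_self), and the expected output is worked out by hand.

#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::ones({4});                          // [1, 1, 1, 1]
  auto src = torch::tensor({2.0f, 3.0f, 4.0f});
  auto index = torch::tensor({0, 1, 1}, torch::kLong);
  // Multiply src into self at the positions given by index along dim 0.
  self.scatter_reduce_(/*dim=*/0, index, src, /*reduce=*/"prod", /*include_self=*/true);
  std::cout << self << std::endl;                        // expected: [2, 12, 1, 1]
  return 0;
}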
@@ -57,8 +57,8 @@ from torch.testing._internal.common_cuda import (
    _create_scaling_case, _create_scaling_models_optimizers)
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.testing._internal.common_dtype import (
    floating_types_and, get_all_math_dtypes, all_types_and_complex_and, all_types_and, floating_types,
    floating_and_complex_types, integral_types_and,
    floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
    all_types_and, floating_types, floating_and_complex_types, integral_types_and,
    get_all_qint_dtypes,
)
from torch.testing._internal.two_tensor import TwoTensor
@@ -3837,7 +3837,7 @@ else:
            self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")

    @onlyCUDA
    @dtypes(torch.cdouble)
    @dtypes(*complex_types())
    def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
        height = 2
        width = 2