Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Add support for CPU scalar in addcmul (#143264)
Step required for performance in #143122. Adds support for a CPU scalar `tensor2` in `addcmul`. For example:

```
import torch
a = torch.rand(2, 2, device="cuda")
b = torch.tensor(1e-3)
torch.add(a, b)
torch.addcmul(a, a, b)  # used to fail, now works
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143264
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
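To make the semantics concrete: `addcmul` computes `input + value * tensor1 * tensor2`, and after this change `tensor2` may be a 0-dim CPU tensor even when the other operands live on CUDA. A minimal sketch (not part of the commit, assuming a CUDA device is available) checking the CPU-scalar form against the explicit expression:

```python
import torch

a = torch.rand(2, 2, device="cuda")
b = torch.rand(2, 2, device="cuda")
c = torch.tensor(1e-3)  # 0-dim CPU tensor, now accepted as tensor2

out = torch.addcmul(a, b, c, value=0.5)
# addcmul(input, tensor1, tensor2, value) == input + value * tensor1 * tensor2
assert torch.allclose(out, a + 0.5 * b * c.item())
```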
Commit 576789197a (parent 859be14c4e), committed by PyTorch MergeBot.
```diff
@@ -19,7 +19,15 @@ TORCH_META_FUNC(addcmul)
     const Tensor& tensor1,
     const Tensor& tensor2,
     const Scalar& value) {
-  build_ternary_op(maybe_get_output(), self, tensor1, tensor2);
+  build(TensorIteratorConfig()
+            .allow_cpu_scalars(true)
+            .promote_inputs_to_common_dtype(true)
+            .cast_common_dtype_to_outputs(true)
+            .enforce_safe_casting_to_output(true)
+            .add_owned_output(maybe_get_output())
+            .add_owned_const_input(self)
+            .add_owned_const_input(tensor1)
+            .add_owned_const_input(tensor2));
 }
 
 TORCH_META_FUNC(addcdiv)
```
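The rebuilt iterator config sets `allow_cpu_scalars(true)`, which only admits 0-dim CPU tensors into an otherwise-CUDA operation; a CPU tensor with a real shape should still be rejected as a device mismatch. A hedged sketch of that boundary (assumes CUDA is available and that the usual device-mismatch error is raised for the non-0-dim case):

```python
import torch

a = torch.rand(2, 2, device="cuda")

# 0-dim CPU tensor: treated as a scalar operand, accepted after this change.
torch.addcmul(a, a, torch.tensor(2.0))

# 1-element but 1-dimensional CPU tensor: expected to still fail the device check.
try:
    torch.addcmul(a, a, torch.full((1,), 2.0))
except RuntimeError as e:
    print("rejected as expected:", e)
```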
```diff
@@ -11,10 +11,30 @@
 namespace at::native {
 
+void addcmul_cuda_scalar_tensor2_kernel(
+    TensorIteratorBase& iter,
+    const Scalar& scalar_tensor2,
+    const Scalar& value
+);
+
 #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 constexpr char addcmul_name[] = "addcmul";
 #endif
 void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
+  TORCH_CHECK(
+      !iter.is_cpu_scalar(1),
+      "CPU Scalar support for self argument is not supported when "
+      "calling addcmul on CUDA tensors."
+  );
+
+  TORCH_CHECK(
+      !iter.is_cpu_scalar(2),
+      "CPU Scalar support for tensor1 argument is not supported when "
+      "calling addcmul on CUDA tensors. "
+      "However, CPU Scalar support for tensor2 is supported, "
+      "please swap your tensor1 and tensor2 terms."
+  );
+
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
     // When using Jiterator, addcmul and addcdiv kernels get stuck during a
```
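The second `TORCH_CHECK` message points at the workaround: multiplication commutes, so a CPU scalar passed in the `tensor1` slot can simply be moved to the `tensor2` slot. A small sketch (not from the commit, assumes CUDA is available):

```python
import torch

a = torch.rand(2, 2, device="cuda")
b = torch.rand(2, 2, device="cuda")
s = torch.tensor(0.25)  # 0-dim CPU scalar

# torch.addcmul(a, s, b, value=2.0) raises the "tensor1 argument" error above;
# swapping the multiplicands is mathematically identical and is supported.
out = torch.addcmul(a, b, s, value=2.0)
assert torch.allclose(out, a + 2.0 * b * s.item())
```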
```diff
@@ -25,6 +45,11 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
       auto alpha = value.to<scalar_t>();
       static const auto addcmul_string = jiterator_stringify(
           template <typename T> T addcmul(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
+      if (iter.is_cpu_scalar(3)) {
+        auto tensor2_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
+      }
       jitted_gpu_kernel<
           /*name=*/addcmul_name,
           /*return_dtype=*/scalar_t,
```
```diff
@@ -38,6 +63,12 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
     });
 #else
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      if (iter.is_cpu_scalar(3)) {
+        auto tensor2_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
+      }
+
       auto alpha = value.to<scalar_t>();
       gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
         return a + alpha * b * c;
```
```diff
@@ -46,6 +77,11 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
 #endif
   } else {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
+      if (iter.is_cpu_scalar(3)) {
+        auto tensor2_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
+      }
       // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
       // and do math in fp32 for better accuracy.
       using accscalar_t = at::acc_type<scalar_t, true>;
```
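The note about doing the math in fp32 for fp16/bf16 inputs can be sanity-checked by comparing a half-precision result against a float32 reference. A rough sketch (not from the commit; tolerances are illustrative assumptions and CUDA is assumed available):

```python
import torch

a = torch.rand(64, 64, device="cuda", dtype=torch.float16)
b = torch.rand(64, 64, device="cuda", dtype=torch.float16)
c = torch.tensor(3.0)  # CPU scalar tensor2

out_half = torch.addcmul(a, b, c, value=0.5)
ref = torch.addcmul(a.float(), b.float(), c, value=0.5)

# Accumulation happens in fp32, so the half result should track the fp32
# reference to within output rounding error.
assert torch.allclose(out_half.float(), ref, atol=1e-2, rtol=1e-2)
```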
```diff
@@ -57,6 +93,58 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   }
 }
 
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
+constexpr char addcmul_scalar_tensor2_name[] = "addcmul_scalar_tensor2";
+#endif
+void addcmul_cuda_scalar_tensor2_kernel(TensorIteratorBase& iter, const Scalar& scalar_tensor2, const Scalar& value) {
+  auto dtype = iter.common_dtype();
+
+  if (at::isComplexType(dtype)) {
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      auto c = scalar_tensor2.to<scalar_t>();
+      auto alpha = value.to<scalar_t>();
+
+      static const auto addcmul_scalar_tensor2_string = jiterator_stringify(
+          template <typename T> T addcmul_scalar_tensor2(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
+
+      jitted_gpu_kernel<
+          /*name=*/addcmul_scalar_tensor2_name,
+          /*return_dtype=*/scalar_t,
+          /*common_dtype=*/scalar_t,
+          /*arity=*/2>(
+          iter,
+          addcmul_scalar_tensor2_string,
+          /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
+          /*scalar_val=*/0,
+          /*extra_args=*/std::make_tuple(c, alpha));
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      auto c = scalar_tensor2.to<scalar_t>();
+      auto alpha = value.to<scalar_t>();
+      gpu_kernel(iter, [alpha, c]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+        return a + alpha * (b * c);
+      });
+    });
+#endif
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
+      // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
+      // and do math in fp32 for better accuracy.
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      auto c = scalar_tensor2.to<accscalar_t>();
+      auto alpha = value.to<accscalar_t>();
+      gpu_kernel(iter, [alpha, c]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+        return a + alpha * (static_cast<accscalar_t>(b) * c);
+      });
+    });
+  }
+}
+
 #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 // return a + alpha * (b / static_cast<accscalar_t>(c));
 constexpr char addcdiv_name[] = "addcdiv";
```
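For the complex branch of the new kernel, a hedged sanity check (not from the commit; assumes CUDA is available, and the values and dtype are illustrative):

```python
import torch

a = torch.rand(2, 2, dtype=torch.complex64, device="cuda")
b = torch.rand(2, 2, dtype=torch.complex64, device="cuda")
c = torch.tensor(1.0 + 2.0j)  # 0-dim CPU complex scalar

out = torch.addcmul(a, b, c, value=0.5)
assert torch.allclose(out, a + 0.5 * (b * c.item()))
```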
```diff
@@ -3289,9 +3289,10 @@ else:
             self.assertTrue(y.stride() == (1, 4))
 
     # FIXME: move to elementwise ternary test suite
+    @parametrize("use_cpu_scalar", [True, False])
     @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
     @dtypes(*set(get_all_math_dtypes('cpu')))
-    def test_addcmul(self, device, dtype):
+    def test_addcmul(self, device, dtype, use_cpu_scalar):
         # Returns floating or integral scalar corresponding to dtype
         def _number(floating, integer, dtype):
             if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
```
```diff
@@ -3311,7 +3312,10 @@ else:
 
         a = rand_tensor((2, 2), dtype=dtype, device=device)
         b = rand_tensor((2, 2), dtype=dtype, device=device)
-        c = rand_tensor((2, 2), dtype=dtype, device=device)
+        if use_cpu_scalar:
+            c = rand_tensor([], device="cpu", dtype=dtype)
+        else:
+            c = rand_tensor((2, 2), dtype=dtype, device=device)
 
         alpha = _number(0.5, 3, dtype)
 
```
```diff
@@ -3331,6 +3335,21 @@ else:
             out = torch.addcmul(a, b, c, value=-1)
             self.assertTrue(not (out.isnan() or out.isinf()))
 
+    @onlyCUDA
+    def test_addcmul_cuda_errors_with_cpu_scalars(self, device):
+        # Logic is dtype agnostic, so dtype isn't tested
+        alpha = 0.5
+
+        a = torch.rand((2, 2), device=device)
+        b = torch.rand((2, 2), device=device)
+        c = torch.rand((2, 2), device=device)
+        scalar = torch.rand([], device="cpu")
+
+        with self.assertRaisesRegex(RuntimeError, r'CPU Scalar support for tensor1 argument'):
+            torch.addcmul(a, scalar, c, value=alpha)
+        with self.assertRaisesRegex(RuntimeError, r'CPU Scalar support for self argument'):
+            torch.addcmul(scalar, b, c, value=alpha)
+
     # FIXME: move to shape ops test suite
     def test_narrow_empty(self, device):
         x = torch.randn(2, 3, 4, device=device)
```