Add support for CPU scalar in addcmul (#143264)

This is a step required for the performance work in #143122.

Adds support for passing a CPU scalar tensor as the tensor2 argument of addcmul. For example:
```
import torch
a = torch.rand(2, 2, device="cuda")
b = torch.tensor(1e-3)

torch.add(a, b)  # already works: binary ops accept a CPU scalar tensor
torch.addcmul(a, a, b)  # used to fail, now works
```
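Note that on CUDA the CPU scalar is only accepted in the tensor2 position; passing it as self or tensor1 still raises, and the tensor1 error message suggests swapping the two terms. A minimal sketch of the resulting behavior, assuming a CUDA-capable build:
```
import torch

a = torch.rand(2, 2, device="cuda")
b = torch.tensor(1e-3)  # 0-dim CPU tensor ("CPU scalar")

torch.addcmul(a, a, b)    # works: CPU scalar in the tensor2 position
# torch.addcmul(a, b, a)  # still raises: CPU scalar is not supported for tensor1;
#                         # the error message suggests swapping tensor1 and tensor2
```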
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143264
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Author: emmettbicker
Date: 2024-12-18 04:43:26 +00:00
Committed by: PyTorch MergeBot
Parent: 859be14c4e
Commit: 576789197a

3 changed files with 118 additions and 3 deletions


@@ -19,7 +19,15 @@ TORCH_META_FUNC(addcmul)
const Tensor& tensor1,
const Tensor& tensor2,
const Scalar& value) {
build_ternary_op(maybe_get_output(), self, tensor1, tensor2);
build(TensorIteratorConfig()
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true)
.add_owned_output(maybe_get_output())
.add_owned_const_input(self)
.add_owned_const_input(tensor1)
.add_owned_const_input(tensor2));
}
TORCH_META_FUNC(addcdiv)
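
The meta function no longer goes through build_ternary_op; it builds the TensorIterator with an inline config that mirrors the ternary setup (input type promotion plus safe casting to the output) and additionally sets allow_cpu_scalars(true), which is what permits a 0-dim CPU tensor among CUDA operands. A small illustrative sketch of what these flags mean from Python, assuming a CUDA build (the out= dtypes here are only examples, not from this PR):
```
import torch

a = torch.rand(2, 2, device="cuda")
b = torch.rand(2, 2, device="cuda")
c = torch.tensor(2.0)  # 0-dim CPU tensor, permitted by allow_cpu_scalars(true)

# cast_common_dtype_to_outputs / enforce_safe_casting_to_output allow the
# float32 computation to be written into a float64 out tensor...
out = torch.empty(2, 2, device="cuda", dtype=torch.float64)
torch.addcmul(a, b, c, value=0.5, out=out)

# ...but an integral out tensor would be rejected as an unsafe cast:
# torch.addcmul(a, b, c, out=torch.empty(2, 2, device="cuda", dtype=torch.long))
```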


@@ -11,10 +11,30 @@
namespace at::native {
void addcmul_cuda_scalar_tensor2_kernel(
TensorIteratorBase& iter,
const Scalar& scalar_tensor2,
const Scalar& value
);
#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
constexpr char addcmul_name[] = "addcmul";
#endif
void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
TORCH_CHECK(
!iter.is_cpu_scalar(1),
"CPU Scalar support for self argument is not supported when "
"calling addcmul on CUDA tensors."
);
TORCH_CHECK(
!iter.is_cpu_scalar(2),
"CPU Scalar support for tensor1 argument is not supported when "
"calling addcmul on CUDA tensors. "
"However, CPU Scalar support for tensor2 is supported, "
"please swap your tensor1 and tensor2 terms."
);
auto dtype = iter.common_dtype();
if (at::isComplexType(dtype)) {
// When using Jiterator, addcmul and addcdiv kernels get stuck during a
@@ -25,6 +45,11 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
auto alpha = value.to<scalar_t>();
static const auto addcmul_string = jiterator_stringify(
template <typename T> T addcmul(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
if (iter.is_cpu_scalar(3)) {
auto tensor2_val = iter.scalar_value<scalar_t>(3);
iter.remove_operand(3);
return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
}
jitted_gpu_kernel<
/*name=*/addcmul_name,
/*return_dtype=*/scalar_t,
@@ -38,6 +63,12 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
});
#else
AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
if (iter.is_cpu_scalar(3)) {
auto tensor2_val = iter.scalar_value<scalar_t>(3);
iter.remove_operand(3);
return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
}
auto alpha = value.to<scalar_t>();
gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
return a + alpha * b * c;
@@ -46,6 +77,11 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
#endif
} else {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
if (iter.is_cpu_scalar(3)) {
auto tensor2_val = iter.scalar_value<scalar_t>(3);
iter.remove_operand(3);
return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
}
// note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
// and do math in fp32 for better accuracy.
using accscalar_t = at::acc_type<scalar_t, true>;
@@ -57,6 +93,58 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
}
}
#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
constexpr char addcmul_scalar_tensor2_name[] = "addcmul_scalar_tensor2";
#endif
void addcmul_cuda_scalar_tensor2_kernel(TensorIteratorBase& iter, const Scalar& scalar_tensor2, const Scalar& value) {
auto dtype = iter.common_dtype();
if (at::isComplexType(dtype)) {
// When using Jiterator, addcmul and addcdiv kernels get stuck during a
// promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
// https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
auto c = scalar_tensor2.to<scalar_t>();
auto alpha = value.to<scalar_t>();
static const auto addcmul_scalar_tensor2_string = jiterator_stringify(
template <typename T> T addcmul_scalar_tensor2(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
jitted_gpu_kernel<
/*name=*/addcmul_scalar_tensor2_name,
/*return_dtype=*/scalar_t,
/*common_dtype=*/scalar_t,
/*arity=*/2>(
iter,
addcmul_scalar_tensor2_string,
/*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
/*scalar_val=*/0,
/*extra_args=*/std::make_tuple(c, alpha));
});
#else
AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
auto c = scalar_tensor2.to<scalar_t>();
auto alpha = value.to<scalar_t>();
gpu_kernel(iter, [alpha, c]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a + alpha * (b * c);
});
});
#endif
} else {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
// note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
// and do math in fp32 for better accuracy.
using accscalar_t = at::acc_type<scalar_t, true>;
auto c = scalar_tensor2.to<accscalar_t>();
auto alpha = value.to<accscalar_t>();
gpu_kernel(iter, [alpha, c]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a + alpha * (static_cast<accscalar_t>(b) * c);
});
});
}
}
#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
// return a + alpha * (b / static_cast<accscalar_t>(c));
constexpr char addcdiv_name[] = "addcdiv";
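
In the CUDA dispatch, each branch now checks iter.is_cpu_scalar(3) (the tensor2 operand), reads the scalar value on the host, removes that operand from the iterator, and forwards to addcmul_cuda_scalar_tensor2_kernel, which folds the scalar and alpha into a binary element-wise kernel. A quick numerical sanity check of the two paths, assuming a CUDA build:
```
import torch

a = torch.rand(2, 2, device="cuda")
b = torch.rand(2, 2, device="cuda")
c_cpu = torch.tensor(0.25)   # 0-dim CPU tensor: takes the new binary-kernel path
c_cuda = c_cpu.to("cuda")    # 0-dim CUDA tensor: takes the existing ternary path

out_scalar = torch.addcmul(a, b, c_cpu, value=0.5)
out_tensor = torch.addcmul(a, b, c_cuda, value=0.5)
torch.testing.assert_close(out_scalar, out_tensor)
```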


@@ -3289,9 +3289,10 @@ else:
self.assertTrue(y.stride() == (1, 4))
# FIXME: move to elementwise ternary test suite
@parametrize("use_cpu_scalar", [True, False])
@dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
@dtypes(*set(get_all_math_dtypes('cpu')))
def test_addcmul(self, device, dtype):
def test_addcmul(self, device, dtype, use_cpu_scalar):
# Returns floating or integral scalar corresponding to dtype
def _number(floating, integer, dtype):
if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
@@ -3311,7 +3312,10 @@ else:
a = rand_tensor((2, 2), dtype=dtype, device=device)
b = rand_tensor((2, 2), dtype=dtype, device=device)
c = rand_tensor((2, 2), dtype=dtype, device=device)
if use_cpu_scalar:
c = rand_tensor([], device="cpu", dtype=dtype)
else:
c = rand_tensor((2, 2), dtype=dtype, device=device)
alpha = _number(0.5, 3, dtype)
@@ -3331,6 +3335,21 @@ else:
out = torch.addcmul(a, b, c, value=-1)
self.assertTrue(not (out.isnan() or out.isinf()))
@onlyCUDA
def test_addcmul_cuda_errors_with_cpu_scalars(self, device):
# Logic is dtype agnostic, so dtype isn't tested
alpha = 0.5
a = torch.rand((2, 2), device=device)
b = torch.rand((2, 2), device=device)
c = torch.rand((2, 2), device=device)
scalar = torch.rand([], device="cpu")
with self.assertRaisesRegex(RuntimeError, r'CPU Scalar support for tensor1 argument'):
torch.addcmul(a, scalar, c, value=alpha)
with self.assertRaisesRegex(RuntimeError, r'CPU Scalar support for self argument'):
torch.addcmul(scalar, b, c, value=alpha)
# FIXME: move to shape ops test suite
def test_narrow_empty(self, device):
x = torch.randn(2, 3, 4, device=device)