diff --git a/aten/src/ATen/native/PointwiseOps.cpp b/aten/src/ATen/native/PointwiseOps.cpp
index f5235a8e1770..ed63b86c85e6 100644
--- a/aten/src/ATen/native/PointwiseOps.cpp
+++ b/aten/src/ATen/native/PointwiseOps.cpp
@@ -19,7 +19,15 @@ TORCH_META_FUNC(addcmul)
  const Tensor& tensor1,
  const Tensor& tensor2,
  const Scalar& value) {
-  build_ternary_op(maybe_get_output(), self, tensor1, tensor2);
+  build(TensorIteratorConfig()
+      .allow_cpu_scalars(true)
+      .promote_inputs_to_common_dtype(true)
+      .cast_common_dtype_to_outputs(true)
+      .enforce_safe_casting_to_output(true)
+      .add_owned_output(maybe_get_output())
+      .add_owned_const_input(self)
+      .add_owned_const_input(tensor1)
+      .add_owned_const_input(tensor2));
 }
 
 TORCH_META_FUNC(addcdiv)
diff --git a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
index 45b0d01ceebb..14807c0200ec 100644
--- a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
@@ -11,10 +11,30 @@
 
 namespace at::native {
 
+void addcmul_cuda_scalar_tensor2_kernel(
+    TensorIteratorBase& iter,
+    const Scalar& scalar_tensor2,
+    const Scalar& value
+);
+
 #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 constexpr char addcmul_name[] = "addcmul";
 #endif
 void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
+  TORCH_CHECK(
+      !iter.is_cpu_scalar(1),
+      "CPU Scalar support for self argument is not supported when "
+      "calling addcmul on CUDA tensors."
+  );
+
+  TORCH_CHECK(
+      !iter.is_cpu_scalar(2),
+      "CPU Scalar support for tensor1 argument is not supported when "
+      "calling addcmul on CUDA tensors. "
+      "However, CPU Scalar support for tensor2 is supported, "
+      "please swap your tensor1 and tensor2 terms."
+  );
+
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
     // When using Jiterator, addcmul and addcdiv kernels get stuck during a
@@ -25,6 +45,11 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
       auto alpha = value.to<scalar_t>();
       static const auto addcmul_string = jiterator_stringify(
           template <typename T> T addcmul(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
+      if (iter.is_cpu_scalar(3)) {
+        auto tensor2_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
+      }
       jitted_gpu_kernel<
           /*name=*/addcmul_name,
           /*return_dtype=*/scalar_t,
@@ -38,6 +63,12 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
     });
     #else
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      if (iter.is_cpu_scalar(3)) {
+        auto tensor2_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
+      }
+
       auto alpha = value.to<scalar_t>();
       gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
         return a + alpha * b * c;
@@ -46,6 +77,11 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
     #endif
   } else {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
+      if (iter.is_cpu_scalar(3)) {
+        auto tensor2_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return addcmul_cuda_scalar_tensor2_kernel(iter, tensor2_val, value);
+      }
       // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
       // and do math in fp32 for better accuracy.
       using accscalar_t = at::acc_type<scalar_t, true>;
@@ -57,6 +93,58 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   }
 }
 
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
+constexpr char addcmul_scalar_tensor2_name[] = "addcmul_scalar_tensor2";
+#endif
+void addcmul_cuda_scalar_tensor2_kernel(TensorIteratorBase& iter, const Scalar& scalar_tensor2, const Scalar& value) {
+  auto dtype = iter.common_dtype();
+
+  if (at::isComplexType(dtype)) {
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+    #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      auto c = scalar_tensor2.to<scalar_t>();
+      auto alpha = value.to<scalar_t>();
+
+      static const auto addcmul_scalar_tensor2_string = jiterator_stringify(
+          template <typename T> T addcmul_scalar_tensor2(T a, T b, T c, T alpha) { return a + alpha * (b * c); });
+
+      jitted_gpu_kernel<
+          /*name=*/addcmul_scalar_tensor2_name,
+          /*return_dtype=*/scalar_t,
+          /*common_dtype=*/scalar_t,
+          /*arity=*/2>(
+          iter,
+          addcmul_scalar_tensor2_string,
+          /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
+          /*scalar_val=*/0,
+          /*extra_args=*/std::make_tuple(c, alpha));
+    });
+    #else
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
+      auto c = scalar_tensor2.to<scalar_t>();
+      auto alpha = value.to<scalar_t>();
+      gpu_kernel(iter, [alpha, c]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+        return a + alpha * (b * c);
+      });
+    });
+    #endif
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "addcmul_cuda", [&]() {
+      // note(mkozuki): If scalar_t is fp16 or bfloat16, cast scalar to float
+      // and do math in fp32 for better accuracy.
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      auto c = scalar_tensor2.to<accscalar_t>();
+      auto alpha = value.to<accscalar_t>();
+      gpu_kernel(iter, [alpha, c]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+        return a + alpha * (static_cast<accscalar_t>(b) * c);
+      });
+    });
+  }
+}
+
 #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 // return a + alpha * (b / static_cast<accscalar_t>(c));
 constexpr char addcdiv_name[] = "addcdiv";
diff --git a/test/test_torch.py b/test/test_torch.py
index af78be2cd64f..bb4f2669e561 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3289,9 +3289,10 @@ else:
         self.assertTrue(y.stride() == (1, 4))
 
     # FIXME: move to elementwise ternary test suite
+    @parametrize("use_cpu_scalar", [True, False])
     @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
     @dtypes(*set(get_all_math_dtypes('cpu')))
-    def test_addcmul(self, device, dtype):
+    def test_addcmul(self, device, dtype, use_cpu_scalar):
         # Returns floating or integral scalar corresponding to dtype
         def _number(floating, integer, dtype):
             if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]:
@@ -3311,7 +3312,10 @@ else:
 
         a = rand_tensor((2, 2), dtype=dtype, device=device)
         b = rand_tensor((2, 2), dtype=dtype, device=device)
-        c = rand_tensor((2, 2), dtype=dtype, device=device)
+        if use_cpu_scalar:
+            c = rand_tensor([], device="cpu", dtype=dtype)
+        else:
+            c = rand_tensor((2, 2), dtype=dtype, device=device)
 
         alpha = _number(0.5, 3, dtype)
 
@@ -3331,6 +3335,21 @@ else:
             out = torch.addcmul(a, b, c, value=-1)
             self.assertTrue(not (out.isnan() or out.isinf()))
 
+    @onlyCUDA
+    def test_addcmul_cuda_errors_with_cpu_scalars(self, device):
+        # Logic is dtype-agnostic, so dtype isn't tested
+        alpha = 0.5
+
+        a = torch.rand((2, 2), device=device)
+        b = torch.rand((2, 2), device=device)
+        c = torch.rand((2, 2), device=device)
+        scalar = torch.rand([], device="cpu")
+
+        with self.assertRaisesRegex(RuntimeError, r'CPU Scalar support for tensor1 argument'):
+            torch.addcmul(a, scalar, c, value=alpha)
+        with self.assertRaisesRegex(RuntimeError, r'CPU Scalar support for self argument'):
+            torch.addcmul(scalar, b, c, value=alpha)
+
     # FIXME: move to shape ops test suite
     def test_narrow_empty(self, device):
         x = torch.randn(2, 3, 4, device=device)
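Usage sketch (not part of the patch): assuming the diff above is applied and PyTorch is built with CUDA, tensor2 may now be a 0-dim CPU tensor while self and tensor1 remain CUDA tensors; putting the CPU scalar in the self or tensor1 slot trips the new TORCH_CHECKs instead. The tensor names below are illustrative only.

import torch

# a and b live on the GPU; c is a 0-dim ("CPU scalar") tensor.
a = torch.rand((2, 2), device="cuda")
b = torch.rand((2, 2), device="cuda")
c = torch.rand([], device="cpu")

# Newly supported: CPU scalar in the tensor2 slot.
out = torch.addcmul(a, b, c, value=0.5)  # out = a + 0.5 * (b * c)

# Still rejected: CPU scalar as tensor1 (or self); the error message
# suggests swapping tensor1 and tensor2.
try:
    torch.addcmul(a, c, b, value=0.5)
except RuntimeError as err:
    print(err)  # "CPU Scalar support for tensor1 argument is not supported ..."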