port at::pow to structured (#53669)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53669

This PR does two things:
* Ports `pow` to be structured
* Fixes a bug in how `pow` handles mixed CPU and CUDA tensors

**bug fix**
Pow is a binary op, and all binary ops that use TensorIterator are currently written to handle the case where one of the inputs is a CUDA tensor and the other is a zero-dimensional CPU tensor.
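For reference, here is a minimal sketch of that behavior for other TensorIterator binary ops (purely illustrative; it assumes a CUDA device is available):

```python
import torch

a = torch.ones(3, device='cuda')
b = torch.tensor(2.0)  # zero-dimensional tensor on the CPU

# TensorIterator lifts the zero-dim CPU tensor to a scalar, so both
# orderings are accepted even though the devices differ.
print(torch.add(a, b))
print(torch.mul(b, a))
```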

`pow` incidentally only handles one of the two cases: it fails when the CUDA tensor is passed as the exponent, e.g. `torch.pow(torch.tensor(2.0, device='cpu'), torch.tensor([2, 2], device='cuda'))`. Porting `pow` to structured happened to change the error that was raised from a `TORCH_CHECK` in TensorIterator to an `INTERNAL_ASSERT` in `Loops.cuh`, so I ended up fixing the behavior and updating the tests. I added more details in a comment on the PR.
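To make the asymmetry concrete, a rough sketch (again assuming a CUDA device; before this PR the second call errored out, while afterwards both succeed):

```python
import torch

cuda_vec = torch.tensor([2.0, 2.0], device='cuda')
cpu_scalar = torch.tensor(2.0)  # zero-dimensional CPU tensor

torch.pow(cuda_vec, cpu_scalar)  # already worked: the exponent is treated as a scalar
torch.pow(cpu_scalar, cuda_vec)  # previously failed; handled after this PR
```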

**notes on the structured port**
Pow is a little weird, so I wrote down a few issues I noticed during the port:
* Multiple independent overloads. `pow` has two overloads that have their own cpu/cuda kernels, meaning one doesn't call the other. I had to update the names of the kernel overloads to keep the compiler happy, since the codegen would otherwise try to generate two classes with the same name. `pow` actually has 3 overloads that all have `out` variants, so I ported all 3 to structured; one of them just happens to redispatch to one of the others in most cases.
* Name propagation. Is name propagation implemented per operator, or is it expected to work for most/all ops by default? Right now it looks like it happens for TensorIterator ops by default; for ops that don't use TensorIterator, we need to explicitly pass the names through to the `set_output()` call in the meta function. This mattered for `pow` because it has 3 overloads but only 2 of them directly use TensorIterator, so I had to pass names directly to `set_output` in the 3rd overload to make tests happy (see the sketch after this list).
* Lack of `const Tensor &` in the C++ API. It's a goal to slowly make all `Tensor &` arguments const as part of the structured port, but in this case I needed to explicitly cast constness away because one structured kernel calls back into the C++ API, which still takes ordinary `Tensor &` arguments. This probably isn't something we'll fix soon, since we have boxing logic that actually relies on the `Tensor &` / `const Tensor &` distinction in some places.
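To make the name-propagation point concrete, a small named-tensor sketch (the shapes and names are made up for illustration; the Scalar overload is expected to forward the exponent's names through `set_output`):

```python
import torch

exp = torch.rand(2, 3, names=('N', 'C'))
out = torch.pow(2.0, exp)   # dispatches to the pow.Scalar overload
print(out.names)            # expected: ('N', 'C')
```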

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D27029821

Pulled By: bdhirsh

fbshipit-source-id: c1786e770de6e6c2474b9a48210b88057ab1018e
Brian Hirsh
2021-03-19 14:28:14 -07:00
committed by Facebook GitHub Bot
parent 454dd7ba86
commit 779cae9e42
8 changed files with 78 additions and 93 deletions

View File

@@ -6,85 +6,64 @@
#include <ATen/ScalarOps.h>
#include <ATen/native/Resize.h>
namespace at { namespace native {
namespace at {
namespace meta {
DEFINE_DISPATCH(pow_tensor_tensor_stub);
DEFINE_DISPATCH(pow_tensor_scalar_stub);
Tensor& pow_out(Tensor& result, const Tensor& base, const Tensor& exp) {
if (exp.dim() == 0 && exp.device().is_cpu() && base.is_cuda()) {
return native::pow_out(result, base, exp.item());
}
auto iter = TensorIterator::binary_op(result, base, exp);
pow_tensor_tensor_stub(iter.device_type(), iter);
return result;
TORCH_META_FUNC2(pow, Tensor_Tensor) (const Tensor& base, const Tensor& exp) {
build_binary_op(maybe_get_output(), base, exp);
}
Tensor& pow_out(Tensor& result, const Tensor& base, const Scalar& exp) {
TORCH_META_FUNC2(pow, Tensor_Scalar) (const Tensor& base, const Scalar& exp) {
// Numpy compatibility check:
TORCH_CHECK(!(isIntegralType(base.scalar_type(), true) &&
exp.isIntegral(true) && exp.toLong() < 0),
"Integers to negative integer powers are not allowed.");
auto common_dtype = at::result_type(base, exp);
TORCH_CHECK(at::can_cast(common_dtype, result.scalar_type()),
"result type ", common_dtype, " can't be cast to the desired output type ",
result.scalar_type());
build_unary_op(maybe_get_output(), base.to(common_dtype));
}
TORCH_META_FUNC2(pow, Scalar) (const Scalar& base, const Tensor& exp) {
// This overload doesn't directly use TensorIterator. It attempts to short-circuit,
// but otherwise redispatches to the Tensor_Tensor overload.
auto dtype = at::result_type(base, exp);
set_output(0, exp.sizes(), {}, exp.options().dtype(dtype), exp.names());
}
} // namespace meta
namespace native {
DEFINE_DISPATCH(pow_tensor_tensor_stub);
DEFINE_DISPATCH(pow_tensor_scalar_stub);
TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, const Tensor& out) {
if (exp.dim() == 0 && exp.device().is_cpu() && base.is_cuda()) {
at::pow_out(const_cast<Tensor&>(out), base, exp.item()); // redispatch!
} else {
pow_tensor_tensor_stub(device_type(), *this);
}
}
TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
auto common_dtype = at::result_type(base, exp);
if (exp.equal(0.0)) {
resize_output(result, base.sizes());
result.fill_(1);
namedinference::propagate_names(result, base);
out.fill_(1);
} else if (exp.equal(1.0)) {
resize_output(result, base.sizes());
result.copy_(base);
namedinference::propagate_names(result, base);
out.copy_(base);
} else {
auto iter = TensorIterator::unary_op(result, base.to(common_dtype));
pow_tensor_scalar_stub(iter.device_type(), iter, exp);
pow_tensor_scalar_stub(device_type(), *this, exp);
}
return result;
}
Tensor& pow_out(Tensor& result, const Scalar& base, const Tensor& exp) {
TORCH_IMPL_FUNC(pow_Scalar_out) (const Scalar& base, const Tensor& exp, const Tensor& out) {
if (base.isComplex() && base.toComplexDouble() == 1.0) {
resize_output(result, exp.sizes());
result.fill_(1);
namedinference::propagate_names(result, exp);
out.fill_(1);
} else if (!base.isComplex() && base.toDouble() == 1.0) {
resize_output(result, exp.sizes());
result.fill_(1);
namedinference::propagate_names(result, exp);
out.fill_(1);
} else {
native::pow_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
at::pow_out(const_cast<Tensor&>(out), c10::scalar_to_tensor(base, exp.device()), exp); // redispatch!
}
return result;
}
Tensor& pow_(Tensor& base, const Tensor& other) {
return native::pow_out(base, base, other);
}
Tensor& pow_(Tensor& base, const Scalar& alpha) {
return native::pow_out(base, base, alpha);
}
Tensor pow(const Tensor& base, const Tensor& exp) {
auto dtype = at::result_type(base, exp);
Tensor result = at::empty({0}, base.options().dtype(dtype));
return native::pow_out(result, base, exp);
}
Tensor pow(const Tensor& base, const Scalar& exp) {
auto dtype = at::result_type(base, exp);
Tensor result = at::empty_like(base, base.options().dtype(dtype), MemoryFormat::Preserve);
return native::pow_out(result, base, exp);
}
Tensor pow(const Scalar& base, const Tensor& exp) {
auto dtype = at::result_type(base, exp);
Tensor result = at::empty_like(exp, exp.options().dtype(dtype), MemoryFormat::Preserve);
return native::pow_out(result, base, exp);
}
Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {

View File

@@ -54,8 +54,8 @@ static inline HOST_DEVICE T powi(T a, T b) {
return powi_impl(a, b);
}
using pow_tensor_tensor_fn = void (*)(TensorIterator&);
using pow_tensor_scalar_fn = void (*)(TensorIterator&, const Scalar&);
using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const Scalar&);
DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);

View File

@@ -10,7 +10,7 @@ namespace at { namespace native {
namespace {
void pow_tensor_tensor_kernel(TensorIterator& iter) {
void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
if (isFloatingType(iter.dtype()) || isComplexType(iter.dtype())) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "pow", [&]() {
using Vec = Vec256<scalar_t>;
@@ -34,7 +34,7 @@ void pow_tensor_tensor_kernel(TensorIterator& iter) {
}
}
void pow_tensor_scalar_kernel(TensorIterator& iter, const Scalar& exp_scalar) {
void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar) {
if (isFloatingType(iter.dtype())) {
const auto exp = exp_scalar.to<double>();
// Floating types allow AVX2 vector optimizations for pow/sqrt/rsqrt:

View File

@@ -101,22 +101,22 @@ static inline __host__ __device__ B complex_pow_(B base, E exp) {
}
#endif
void pow_tensor_tensor_kernel(TensorIterator& iter) {
void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
if (isComplexType(iter.dtype())) {
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "pow_cuda", [&]() {
gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
gpu_kernel_with_scalars(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
return complex_pow_(base, exp);
});
});
} else if (isFloatingType(iter.dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "pow_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
return pow_(base, exp);
});
});
} else {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
return native::powi(base, exp);
});
});
@@ -124,7 +124,7 @@ void pow_tensor_tensor_kernel(TensorIterator& iter) {
}
template<typename Base_type, typename Exp_type>
void pow_tensor_scalar_kernel_impl(TensorIterator& iter,
void pow_tensor_scalar_kernel_impl(TensorIteratorBase& iter,
Exp_type exp) {
const auto d_exp = static_cast<double>(exp);
if (d_exp == 0.5) {
@@ -158,7 +158,7 @@ void pow_tensor_scalar_kernel_impl(TensorIterator& iter,
}
}
void pow_tensor_scalar_kernel(TensorIterator& iter, const Scalar& exp_scalar) {
void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar) {
if (isComplexType(iter.dtype()) || exp_scalar.isComplex()) {
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "pow_cuda", [&]() {
const auto exp = exp_scalar.to<scalar_t>();

View File

@@ -6418,45 +6418,43 @@
QuantizedCPU: equal_quantized_cpu
- func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: pow_out
CPU, CUDA: pow_Tensor_Tensor_out
- func: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
structured_delegate: pow.Tensor_Tensor_out
variants: method, function
dispatch:
CPU, CUDA: pow
- func: pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
structured: True
dispatch:
CPU, CUDA: pow_out
CPU, CUDA: pow_Scalar_out
- func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor
dispatch:
CPU, CUDA: pow
structured_delegate: pow.Scalar_out
- func: pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: pow_out
CPU, CUDA: pow_Tensor_Scalar_out
SparseCPU, SparseCUDA: pow_out_sparse_scalar
- func: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
structured_delegate: pow.Tensor_Scalar_out
variants: function, method
dispatch:
CPU, CUDA: pow
SparseCPU, SparseCUDA: pow_sparse_scalar
- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
structured_delegate: pow.Tensor_Scalar_out
variants: method
dispatch:
CPU, CUDA: pow_
- func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
structured_delegate: pow.Tensor_Tensor_out
variants: method
dispatch:
CPU, CUDA: pow_
- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures

View File

@@ -200,7 +200,7 @@ SparseTensor sqrt_sparse(const SparseTensor& t) {
// TODO: add in-place variant
SparseTensor& pow_out_sparse_scalar(SparseTensor& r, const SparseTensor& t_, const Scalar& value) {
SparseTensor& pow_out_sparse_scalar(const SparseTensor& t_, const Scalar& value, SparseTensor& r) {
AT_ASSERT(r.is_sparse());
AT_ASSERT(t_.is_sparse());
TORCH_CHECK(value.toDouble() != 0, "pow: cannot raise to zeroth power on sparse tensor; it would make the result tensor dense");
@@ -220,7 +220,7 @@ SparseTensor& pow_out_sparse_scalar(SparseTensor& r, const SparseTensor& t_, con
SparseTensor pow_sparse_scalar(const SparseTensor& t, const Scalar& value) {
SparseTensor r = at::empty({0}, t.options());
pow_out_sparse_scalar(r, t, value);
pow_out_sparse_scalar(t, value, r);
return r;
}

View File

@@ -654,12 +654,16 @@ class TestBinaryUfuncs(TestCase):
actual = base.pow(exponent)
self.assertEqual(actual, expected.to(actual))
actual = base.clone()
if torch.can_cast(torch.result_type(base, exponent), base.dtype):
# When base is a 0-dim cpu tensor and exp is a cuda tensor, we expect `pow` to work but `pow_` to fail, since
# `pow` will try to create the output tensor on a cuda device, but `pow_` needs to use the cpu tensor as the output
if base.dim() == 0 and base.device.type == 'cpu' and exponent.device.type == 'cuda':
regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!'
elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
actual2 = actual.pow_(exponent)
self.assertEqual(actual, expected)
self.assertEqual(actual2, expected)
else:
self.assertRaisesRegex(RuntimeError, "can't be cast", lambda: actual.pow_(exponent))
self.assertRaisesRegex(RuntimeError, "Found dtype \\w+ but expected \\w+", lambda: actual.pow_(exponent))
actual = torch.pow(base, exponent)
self.assertEqual(actual, expected.to(actual))
@@ -715,11 +719,15 @@ class TestBinaryUfuncs(TestCase):
@onlyCUDA
def test_cpu_tensor_pow_cuda_scalar_tensor(self, device):
cpu_tensors = [torch.randn((3, 3), device='cpu'), torch.tensor(3.0, device='cpu')]
cuda_tensors = [torch.tensor(5.0, device='cuda'), torch.tensor(-3, device='cuda')]
for base, exp in product(cpu_tensors, cuda_tensors):
for exp in cuda_tensors:
base = torch.randn((3, 3), device='cpu')
regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!'
self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp)
for exp in cuda_tensors:
# Binary ops with a cpu + cuda tensor are allowed if the cpu tensor has 0 dimension
base = torch.tensor(3.0, device='cpu')
self._test_pow(base, exp)
@onlyOnCPUAndCUDA
@dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False)))

View File

@@ -739,14 +739,14 @@ REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
if (p_node->Input(0).isTensor()) {
if (p_node->Input(1).isTensor()) {
at::native::pow_out(
at::cpu::pow_out(
out_t, p_node->Input(0).toTensor(), p_node->Input(1).toTensor());
} else {
at::native::pow_out(
at::cpu::pow_out(
out_t, p_node->Input(0).toTensor(), p_node->Input(1).toScalar());
}
} else {
at::native::pow_out(
at::cpu::pow_out(
out_t, p_node->Input(0).toScalar(), p_node->Input(1).toTensor());
}
};