torch.prod backward for complex types. (#48125)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/53511
torch.det's backward depends on torch.prod, which in turn depends on several other functions that themselves rely on torch.prod. Because of this circular relationship, this PR enables complex backward support for several functions at once.
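
For reviewers who want a quick sanity check, a gradcheck along these lines (an illustrative sketch, not taken from this patch; shapes and the lambda wrappers are arbitrary) exercises the newly enabled complex backward for prod and cumprod:

```python
import torch
from torch.autograd import gradcheck

# Illustrative only: small complex-double inputs keep gradcheck's
# finite differences well conditioned.
x = torch.randn(3, 3, dtype=torch.cdouble, requires_grad=True)

# prod / cumprod backward now accept complex tensors.
assert gradcheck(lambda t: t.prod(), (x,))
assert gradcheck(lambda t: t.cumprod(0), (x,))
```

Functions whose backward is built on top of prod (e.g. torch.det) pick up complex support through the same change.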

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48125

Reviewed By: pbelevich

Differential Revision: D27188589

Pulled By: anjali411

fbshipit-source-id: bbb80f8ecb83a0c3bea2b917627d3cd3b84eb09a
Author: Nikita Vedeneev
Committed by: Facebook GitHub Bot on 2021-03-19 09:42:53 -07:00
Parent: cc7a28d727
Commit: 61b074581c
7 changed files with 124 additions and 48 deletions


@@ -273,10 +273,16 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     return grad;
   }
-  const auto w = output * grad;
+  // To enable complex support.
+  // From this line on `input_conj` and `output_conj`
+  // are interchangeable with `input` and `output`.
+  auto input_conj = input.conj();
+  auto output_conj = output.conj();
+
+  const auto w = output_conj * grad;
   const auto is_zero = input == 0;
   if (!(is_zero.any().item<uint8_t>())) {
-    return reversed_cumsum(w, dim).div(input);
+    return reversed_cumsum(w, dim).div(input_conj);
   }
   // If we are not computing a second order gradient, we can use an
@@ -309,7 +315,7 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     auto mask = cumsum == 0;
     // equiv to grad_input[mask] = deriv[grad]
     grad_input.masked_scatter_(mask,
-        reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input).masked_select(mask));
+        reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input_conj).masked_select(mask));
     // select everything from the first zero to the second zero [z1, z2)
     mask = cumsum == 1;
@@ -332,10 +338,10 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     // relu_() necessary as gather does not support negative indices
     // finally, we do grad_input[z1] = dy_j / dx_z1
     grad_input.masked_scatter_(first_zero_mask,
-        input.masked_fill(~mask, 1.).cumprod(dim)
+        input_conj.masked_fill(~mask, 1.).cumprod(dim)
             .mul_(grad.masked_fill(cumsum != 1, 0.))
             .sum(dim, /*keepdim*/true)
-            .mul_(at::gather(output, dim, (first_zero_index - 1).relu_())
+            .mul_(at::gather(output_conj, dim, (first_zero_index - 1).relu_())
                 .masked_fill_(first_zero_index == 0, 1.))
             .masked_select(first_zero_mask));
   } else { // GradMode::enabled()
@@ -367,14 +373,14 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     Tensor omitted_products;
     for (int k = 0; k < dim_size; ++k) {
       if (k == 0) {
-        prods_from_k_plus_1 = at::cumprod(input.slice(dim, k + 1), dim);
+        prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k + 1), dim);
         omitted_products = at::cat({ones, prods_from_k_plus_1}, dim);
       } else if (k == dim_size - 1) {
-        const Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
+        const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
         omitted_products = prods_until_k;
       } else {
-        const Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
-        prods_from_k_plus_1 = at::cumprod(input.slice(dim, k+1), dim);
+        const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
+        prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k+1), dim);
         omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
         omitted_products = at::cat({prods_until_k, omitted_products}, dim);
       }
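
For context (reviewer note, not part of the diff): under PyTorch's convention of propagating conjugate Wirtinger derivatives, and since cumprod is holomorphic, the no-zeros branch above follows from the chain rule sketched below, where y is the output and g the incoming gradient:

```latex
% Sketch of the no-zeros case, assuming the conjugate-Wirtinger convention
% for holomorphic ops (grad_in = conj(Jacobian) applied to grad_out).
\[
  y_j = \prod_{k \le j} x_k, \qquad
  \frac{\partial y_j}{\partial x_i} = \frac{y_j}{x_i} \;\; (i \le j,\; x_i \ne 0),
\]
\[
  \mathrm{grad}_{x_i}
    = \sum_{j \ge i} g_j \,\overline{\left(\frac{y_j}{x_i}\right)}
    = \frac{1}{x_i^{*}} \sum_{j \ge i} y_j^{*}\, g_j ,
\]
% i.e. reversed_cumsum(output_conj * grad, dim).div(input_conj).
```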


@@ -939,7 +939,7 @@ namespace {
 template <typename mask_t>
 void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND3(
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       kBool, kHalf, kBFloat16, iter.common_dtype(), "masked_fill_", [&]() {
         const auto value_ = value.to<scalar_t>();
         gpu_kernel(


@@ -1416,6 +1416,11 @@ class TestCuda(TestCase):
         x = torch.ones(240000, device='cuda', dtype=torch.float32)
         self.assertEqual(x.prod(), 1)
 
+        # test for complex types. Note 240k is divisible by 4
+        for dtype in [torch.cfloat, torch.cdouble]:
+            x = torch.ones(240000, device='cuda', dtype=dtype) * (0 + 1j)
+            self.assertEqual(x.prod(), 1)
+
     def test_multinomial_ext(self):
         # Test two corner cases from older PyTorch (Issue #4858)
         freqs = torch.cuda.FloatTensor([
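
Side note on the new CUDA test (my gloss, not in the diff): each element is 0 + 1j, so the expected result relies on i^240000 = (i^4)^60000 = 1^60000 = 1, which is why the comment stresses that 240k is divisible by 4.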


@@ -91,7 +91,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'replication_pad1d', 'replication_pad2d', 'replication_pad3d',
     'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum',
-    'eig', 'lerp', 'linalg_vector_norm'
+    'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod'
 }
 
 # Some operators invalidate the grad_accumulator. Let's reset it.


@@ -466,7 +466,7 @@ Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t d
   Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim);
   Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim);
 
-  return grad * (exclusive_normal * exclusive_reverse);
+  return grad * (exclusive_normal * exclusive_reverse).conj();
 }
 
 // note that the gradient for prod is equivalent to:
@@ -482,7 +482,7 @@ Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& resu
   }
   Tensor zero_idx = (input == 0).nonzero();
   if (zero_idx.numel() == 0) {
-    return (grad * result) / input;
+    return grad * (result / input).conj();
   } else if (zero_idx.size(0) > 1) {
     return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
@@ -504,7 +504,7 @@ Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t di
   Tensor slice_zero_count = zero_mask.sum(dim, true);
   int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
   if (total_zeros == 0) {
-    return (grad * result) / input;
+    return grad * (result / input).conj();
   } else {
     return prod_safe_zeros_backward(grad, input, dim);
   }
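
Reviewer note (not part of the diff): the conj() in the zero-free branches is the holomorphic case of the same convention. For a product with no zero factors,

```latex
% Why grad * (result / input).conj(): prod is holomorphic, so
\[
  y = \prod_i x_i, \qquad
  \frac{\partial y}{\partial x_i} = \frac{y}{x_i}, \qquad
  \mathrm{grad}_{x_i} = g \,\overline{\left(\frac{y}{x_i}\right)} ,
\]
% which matches `return grad * (result / input).conj();` above.
```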


@@ -23,7 +23,7 @@ from torch.testing._internal.common_device_type import \
     skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride,)
 from torch.testing._internal.common_cuda import CUDA11OrLater
 from torch.testing._internal.common_utils import \
-    (prod_single_zero, random_square_matrix_of_rank,
+    (random_square_matrix_of_rank,
      random_symmetric_matrix, random_symmetric_psd_matrix,
      random_symmetric_pd_matrix, make_nonzero_det,
      random_fullrank_matrix_distinct_singular_value, set_rng_seed, SEED,
@@ -1478,6 +1478,69 @@ def sample_inputs_clamp(op_info, device, dtype, requires_grad):
     output += [SampleInput(empty_tensor, args=(0.0, 1.0)), ]
     return output
 
+def sample_inputs_cumprod(op_info, device, dtype, requires_grad):
+    def make_arg(shape):
+        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
+        return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)
+
+    def prod_zeros(dim_select):
+        assert len(dim_select) == 2
+        result = make_arg(3 * (S,))
+        with torch.no_grad():
+            result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
+            result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
+            result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
+        return result
+
+    # will not be needed once OpInfo tests support Iterables
+    def sample_generator():
+        for dim in range(3):
+            yield SampleInput((make_arg((S, S, S)), dim))
+
+        # Scalar tensors and empty tensor
+        for size in [(), (1,), (0,)]:
+            yield SampleInput((make_arg(size), 0))
+
+        yield SampleInput((prod_zeros([0, 1]), 1))
+        yield SampleInput((prod_zeros([0, 2]), 1))
+        yield SampleInput((prod_zeros([1, 2]), 1))
+
+        # test dtype kwarg
+        yield SampleInput((prod_zeros([1, 2]), 1), kwargs={'dtype': dtype})
+
+    return list(sample_generator())
+
+
+def sample_inputs_prod(op_info, device, dtype, requires_grad):
+    def make_arg(shape):
+        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
+        return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)
+
+    def prod_single_zero():
+        result = make_arg(2 * (S,))
+        with torch.no_grad():
+            result[0, 1] = 0
+        return result
+
+    # will not be needed once OpInfo tests support Iterables
+    def sample_generator():
+        for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
+            yield SampleInput(sample.input[0])  # only Tensor, ignore other inputs
+            yield sample
+            sample.kwargs['keepdim'] = True
+            yield sample
+
+        yield SampleInput(prod_single_zero())
+        yield SampleInput((make_arg((3, 3, 3)), 1))
+        yield SampleInput((make_arg((3, 3, 3)), 1), kwargs={'keepdim': True})
+
+        # test zero scalar tensor
+        zero = make_arg(())
+        with torch.no_grad():
+            zero.zero_()
+        yield SampleInput(zero)
+        yield SampleInput((zero, 0))
+        yield SampleInput((zero, 0), kwargs={'keepdim': True})
+
+    return list(sample_generator())
+
+
 def sample_inputs_diag(op_info, device, dtype, requires_grad):
     vec_sample = SampleInput(make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad))
@@ -2126,6 +2189,29 @@ op_db: List[OpInfo] = [
                SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
            ),
            sample_inputs_func=sample_inputs_cumsum),
+    OpInfo('cumprod',
+           dtypes=all_types_and_complex_and(torch.bool),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16),
+           test_inplace_grad=False,
+           skips=(
+               # Reference: https://github.com/pytorch/pytorch/issues/53360
+               # For integer inputs,
+               # inplace variant preserves dtype of `self` while method variant
+               # always promotes it to torch.long.
+               # >>> t = torch.randint(2, 10, (3, 2), dtype=torch.int8)
+               # >>> t.cumprod(0).dtype
+               # torch.int64
+               # >>> t.cumprod_(0).dtype
+               # torch.int8
+               SkipInfo('TestCommon', 'test_variant_consistency_eager',
+                        dtypes=[torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32]),
+               SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                        dtypes=[torch.bool, torch.float16]),
+               # cumprod does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out',
+                        dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_cumprod),
     UnaryUfuncInfo('deg2rad',
                    ref=np.radians,
                    decorators=(precisionOverride({torch.bfloat16: 7e-1,
@@ -2661,6 +2747,18 @@ op_db: List[OpInfo] = [
            dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
            assert_autodiffed=True,),
+    OpInfo('prod',
+           dtypes=all_types_and_complex_and(torch.bool),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           test_inplace_grad=False,
+           skips=(
+               SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                        dtypes=[torch.float16, torch.bfloat16]),
+               # prod does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out',
+                        dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_prod),
     OpInfo('qr',
            op=torch.qr,
            dtypes=floating_and_complex_types(),
@@ -3728,25 +3826,6 @@ def method_tests():
         ('nansum', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
         ('nansum', (S, S, S), ([1, 2],), 'multi_dim'),
         ('nansum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'),
-        ('prod', (S, S, S), NO_ARGS),
-        ('prod', (S, S, S), (1,), 'dim', (), [0]),
-        ('prod', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
-        ('prod', (), NO_ARGS, 'scalar'),
-        ('prod', (), (0,), 'scalar_dim', (), [0]),
-        ('prod', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
-        ('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'),
-        ('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'),
-        ('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'),
-        ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', (), [0]),
-        ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', (), [0]),
-        ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', (), [0]),
-        ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', (), [0]),
-        ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', (), [0]),
-        ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', (), [0]),
-        ('prod', prod_single_zero(S), NO_ARGS, 'single_zero'),
-        ('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'),
-        ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', (), [0]),
-        ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', (), [0]),
         ('var_mean', (S, S, S), NO_ARGS, ''),
         ('var_mean', (S, S, S), (1,), 'dim', [0]),
         ('var_mean', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
@@ -3770,14 +3849,6 @@ def method_tests():
         ('cummin', (S, S, S), (1,), 'dim1', (), [0]),
         ('cummin', (), (0,), 'dim0_scalar', (), [0]),
         ('cumsum', (S, S, S), (1,), 'dim1_cast', (), [0], (), ident, {'dtype': torch.float64}),
-        ('cumprod', (S, S, S), (0,)),
-        ('cumprod', (S, S, S), (1,), 'dim1', (), [0]),
-        ('cumprod', (), (0,), 'scalar'),
-        ('cumprod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_zeros'),
-        ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', (), [0]),
-        ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', (), [0]),
-        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', (), [0]),
-        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', (), [0], (), ident, {'dtype': torch.float64}),
         ('log_softmax', (S, S, S), (1, torch.float64,), 'kwarg_dtype_would_break_jit_loader', (True,)),
         ('unfold', (), (0, 1, 1), 'scalar', (), [0]),
         ('unfold', (S, S, S, S), (0, 3, 1), '4d_dim0_step1', (), [0]),


@@ -1644,12 +1644,6 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig
     return result
 
-
-def prod_single_zero(dim_size):
-    result = torch.randn(dim_size, dim_size)
-    result[0, 1] = 0
-    return result
-
 def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
     assert rank <= l
     A = torch.randn(l, l, dtype=dtype, device=device)