Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-07 10:01:39 +08:00)
torch.prod backward for complex types. (#48125)
Summary: Fixes https://github.com/pytorch/pytorch/issues/53511. torch.det depends on torch.prod, which in turn depends on several other functions that themselves depend on torch.prod; because of this circular relationship, this PR enables complex backward support for several functions at once.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48125
Reviewed By: pbelevich
Differential Revision: D27188589
Pulled By: anjali411
fbshipit-source-id: bbb80f8ecb83a0c3bea2b917627d3cd3b84eb09a
Committed by: Facebook GitHub Bot
Parent: cc7a28d727
Commit: 61b074581c
@@ -273,10 +273,16 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     return grad;
   }
 
-  const auto w = output * grad;
+  // To enable complex support.
+  // From this line on `input_conj` and `output_conj`
+  // are interchangeable with `input` and `output`.
+  auto input_conj = input.conj();
+  auto output_conj = output.conj();
+
+  const auto w = output_conj * grad;
   const auto is_zero = input == 0;
   if (!(is_zero.any().item<uint8_t>())) {
-    return reversed_cumsum(w, dim).div(input);
+    return reversed_cumsum(w, dim).div(input_conj);
   }
 
   // If we are not computing a second order gradient, we can use an
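The conjugations here follow PyTorch's complex autograd convention: for a holomorphic op the backward multiplies the incoming gradient by the conjugate of the Jacobian, so for y = cumprod(x) with no zeros the returned value is grad_x[i] = sum over j >= i of conj(y_j / x_i) * grad_y[j]. A minimal numerical sketch of that identity (illustrative only, not part of the diff; assumes a PyTorch build that includes this change):

    import torch

    # y_j = x_0 * ... * x_j, so dy_j/dx_i = y_j / x_i for j >= i when there are no zeros;
    # the complex convention conjugates that factor in the backward.
    x = torch.randn(5, dtype=torch.cdouble, requires_grad=True)
    y = torch.cumprod(x, dim=0)
    grad_y = torch.randn(5, dtype=torch.cdouble)
    (grad_x,) = torch.autograd.grad(y, x, grad_outputs=grad_y)

    xd, yd = x.detach(), y.detach()
    expected = torch.stack([
        sum((yd[j] / xd[i]).conj() * grad_y[j] for j in range(i, 5))
        for i in range(5)
    ])
    assert torch.allclose(grad_x, expected)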
@@ -309,7 +315,7 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     auto mask = cumsum == 0;
     // equiv to grad_input[mask] = deriv[grad]
     grad_input.masked_scatter_(mask,
-        reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input).masked_select(mask));
+        reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input_conj).masked_select(mask));
     // select everything from the first zero to the second zero [z1, z2)
     mask = cumsum == 1;
 
@@ -332,10 +338,10 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     // relu_() necessary as gather does not support negative indices
     // finally, we do grad_input[z1] = dy_j / dx_z1
     grad_input.masked_scatter_(first_zero_mask,
-      input.masked_fill(~mask, 1.).cumprod(dim)
+      input_conj.masked_fill(~mask, 1.).cumprod(dim)
           .mul_(grad.masked_fill(cumsum != 1, 0.))
           .sum(dim, /*keepdim*/true)
-          .mul_(at::gather(output, dim, (first_zero_index - 1).relu_())
+          .mul_(at::gather(output_conj, dim, (first_zero_index - 1).relu_())
             .masked_fill_(first_zero_index == 0, 1.))
           .masked_select(first_zero_mask));
   } else { // GradMode::enabled()
@@ -367,14 +373,14 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
     Tensor omitted_products;
     for (int k = 0; k < dim_size; ++k) {
       if (k == 0) {
-        prods_from_k_plus_1 = at::cumprod(input.slice(dim, k + 1), dim);
+        prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k + 1), dim);
         omitted_products = at::cat({ones, prods_from_k_plus_1}, dim);
       } else if (k == dim_size - 1) {
-        const Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
+        const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
         omitted_products = prods_until_k;
       } else {
-        const Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
+        const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
-        prods_from_k_plus_1 = at::cumprod(input.slice(dim, k+1), dim);
+        prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k+1), dim);
         omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
         omitted_products = at::cat({prods_until_k, omitted_products}, dim);
       }
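This is the path taken when GradMode is enabled: the backward is built entirely from differentiable ops on the conjugated input, so second derivatives work as well. A short sketch of the checks this enables (illustrative only; complex gradcheck/gradgradcheck support is assumed to be available in the build):

    import torch
    from torch.autograd import gradcheck, gradgradcheck

    # first- and second-order numerical checks against the formulas above
    x = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
    assert gradcheck(lambda t: torch.cumprod(t, 0), (x,))
    assert gradgradcheck(lambda t: torch.cumprod(t, 0), (x,))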
@@ -939,7 +939,7 @@ namespace {
 
 template <typename mask_t>
 void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND3(
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       kBool, kHalf, kBFloat16, iter.common_dtype(), "masked_fill_", [&]() {
         const auto value_ = value.to<scalar_t>();
         gpu_kernel(
@@ -1416,6 +1416,11 @@ class TestCuda(TestCase):
         x = torch.ones(240000, device='cuda', dtype=torch.float32)
         self.assertEqual(x.prod(), 1)
 
+        # test for complex types. Note 240k is divisible by 4
+        for dtype in [torch.cfloat, torch.cdouble]:
+            x = torch.ones(240000, device='cuda', dtype=dtype) * (0 + 1j)
+            self.assertEqual(x.prod(), 1)
+
     def test_multinomial_ext(self):
         # Test two corner cases from older PyTorch (Issue #4858)
         freqs = torch.cuda.FloatTensor([
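The new assertion relies on simple arithmetic: every element equals i, i**4 == 1, and 240000 is divisible by 4, so the full product is exactly 1; every partial product lies in {1, i, -1, -i}, so the reduction is exact in floating point regardless of order. A CPU-only sketch of the same arithmetic (illustrative, not part of the test; assumes complex prod support in the build):

    import torch

    x = torch.ones(240000, dtype=torch.cfloat) * 1j
    # all partial products are powers of i, so the reduction is exact
    assert x.prod() == 1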
@@ -91,7 +91,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'replication_pad1d', 'replication_pad2d', 'replication_pad3d',
     'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum',
-    'eig', 'lerp', 'linalg_vector_norm'
+    'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod'
 }
 
 # Some operators invalidate the grad_accumulator. Let's reset it.
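GRADIENT_IMPLEMENTED_FOR_COMPLEX is the allowlist the autograd codegen consults before differentiating an op whose output is complex; adding 'cumprod' and 'prod' is what lets the C++ formulas above actually run. A tiny sketch of the user-visible effect (illustrative only; assumes a build with this change):

    import torch

    x = torch.randn(3, dtype=torch.cdouble, requires_grad=True)
    y = torch.cumprod(x, 0)
    # before this change, backward through a complex cumprod was rejected here
    y.backward(torch.ones_like(y))
    assert x.grad is not None and x.grad.dtype == torch.complex128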
@@ -466,7 +466,7 @@ Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t d
   Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim);
   Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim);
 
-  return grad * (exclusive_normal * exclusive_reverse);
+  return grad * (exclusive_normal * exclusive_reverse).conj();
 }
 
 // note that the gradient for prod is equivalent to:
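prod_safe_zeros_backward forms, for each position, the product of all the other elements out of exclusive cumulative products; the new .conj() applies the same complex convention as above. A small sketch of the single-zero case it exists for (illustrative only; assumes a build with this change):

    import torch

    # with exactly one zero in the reduced slice, only the gradient at the zero's
    # position is nonzero, and it equals grad_out * conj(product of the other elements)
    x = torch.randn(4, dtype=torch.cdouble)
    x[1] = 0
    x.requires_grad_()
    y = x.prod()
    grad_out = torch.tensor(1 + 1j, dtype=torch.cdouble)
    y.backward(grad_out)

    others = torch.cat([x.detach()[:1], x.detach()[2:]]).prod()
    assert torch.allclose(x.grad[1], grad_out * others.conj())
    assert torch.allclose(x.grad[[0, 2, 3]], torch.zeros(3, dtype=torch.cdouble))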
@@ -482,7 +482,7 @@ Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& resu
   }
   Tensor zero_idx = (input == 0).nonzero();
   if (zero_idx.numel() == 0) {
-    return (grad * result) / input;
+    return grad * (result / input).conj();
   } else if (zero_idx.size(0) > 1) {
     return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
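In the zero-free branch, (grad * result) / input becomes grad * (result / input).conj(): result / input is the Jacobian entry, i.e. the product of all elements except the current one, and the complex convention conjugates it. A minimal check of that identity (illustrative only; assumes a build with this change):

    import torch

    x = torch.randn(4, dtype=torch.cdouble, requires_grad=True)  # exact zeros have probability 0
    y = x.prod()
    grad_out = torch.randn((), dtype=torch.cdouble)
    (grad_x,) = torch.autograd.grad(y, x, grad_outputs=grad_out)

    expected = grad_out * (y.detach() / x.detach()).conj()
    assert torch.allclose(grad_x, expected)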
@@ -504,7 +504,7 @@ Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t di
   Tensor slice_zero_count = zero_mask.sum(dim, true);
   int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
   if (total_zeros == 0) {
-    return (grad * result) / input;
+    return grad * (result / input).conj();
   } else {
     return prod_safe_zeros_backward(grad, input, dim);
   }
@@ -23,7 +23,7 @@ from torch.testing._internal.common_device_type import \
     skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride,)
 from torch.testing._internal.common_cuda import CUDA11OrLater
 from torch.testing._internal.common_utils import \
-    (prod_single_zero, random_square_matrix_of_rank,
+    (random_square_matrix_of_rank,
      random_symmetric_matrix, random_symmetric_psd_matrix,
      random_symmetric_pd_matrix, make_nonzero_det,
      random_fullrank_matrix_distinct_singular_value, set_rng_seed, SEED,
@@ -1478,6 +1478,69 @@ def sample_inputs_clamp(op_info, device, dtype, requires_grad):
     output += [SampleInput(empty_tensor, args=(0.0, 1.0)), ]
     return output
 
+def sample_inputs_cumprod(op_info, device, dtype, requires_grad):
+    def make_arg(shape):
+        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
+        return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)
+
+    def prod_zeros(dim_select):
+        assert len(dim_select) == 2
+        result = make_arg(3 * (S,))
+        with torch.no_grad():
+            result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
+            result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
+            result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
+        return result
+
+    # will not be needed once OpInfo tests support Iterables
+    def sample_generator():
+        for dim in range(3):
+            yield SampleInput((make_arg((S, S, S)), dim))
+        # Scalar tensors and empty tensor
+        for size in [(), (1,), (0,)]:
+            yield SampleInput((make_arg(size), 0))
+
+        yield SampleInput((prod_zeros([0, 1]), 1))
+        yield SampleInput((prod_zeros([0, 2]), 1))
+        yield SampleInput((prod_zeros([1, 2]), 1))
+
+        # test dtype kwarg
+        yield SampleInput((prod_zeros([1, 2]), 1), kwargs={'dtype': dtype})
+
+    return list(sample_generator())
+
+def sample_inputs_prod(op_info, device, dtype, requires_grad):
+    def make_arg(shape):
+        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
+        return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)
+
+    def prod_single_zero():
+        result = make_arg(2 * (S,))
+        with torch.no_grad():
+            result[0, 1] = 0
+        return result
+
+    # will not be needed once OpInfo tests support Iterables
+    def sample_generator():
+        for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
+            yield SampleInput(sample.input[0])  # only Tensor, ignore other inputs
+            yield sample
+            sample.kwargs['keepdim'] = True
+            yield sample
+        yield SampleInput(prod_single_zero())
+        yield SampleInput((make_arg((3, 3, 3)), 1))
+        yield SampleInput((make_arg((3, 3, 3)), 1), kwargs={'keepdim': True})
+
+        # test zero scalar tensor
+        zero = make_arg(())
+        with torch.no_grad():
+            zero.zero_()
+        yield SampleInput(zero)
+        yield SampleInput((zero, 0))
+        yield SampleInput((zero, 0), kwargs={'keepdim': True})
+
+    return list(sample_generator())
+
 def sample_inputs_diag(op_info, device, dtype, requires_grad):
     vec_sample = SampleInput(make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad))
 
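prod_zeros zeroes three fibers of an (S, S, S) tensor so that, depending on which pair of dimensions is selected and which dimension the op later reduces over, slices contain zero, one, or many zeros, exercising the different branches of the backward above. A quick sketch of the resulting pattern, assuming S = 5 as in the test suite (illustrative only):

    import torch

    S = 5
    t = torch.rand(S, S, S) + 0.5  # strictly positive, so the only zeros are the ones we place
    # same placement as prod_zeros([0, 1]): each zeroed view is a full fiber along dim 2
    t.narrow(0, 0, 1).narrow(1, 1, 1).zero_()
    t.narrow(0, 2, 1).narrow(1, 3, 1).zero_()
    t.narrow(0, 4, 1).narrow(1, 3, 1).zero_()
    assert (t == 0).sum() == 3 * S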
@@ -2126,6 +2189,29 @@ op_db: List[OpInfo] = [
                SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
            ),
            sample_inputs_func=sample_inputs_cumsum),
+    OpInfo('cumprod',
+           dtypes=all_types_and_complex_and(torch.bool),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16),
+           test_inplace_grad=False,
+           skips=(
+               # Reference: https://github.com/pytorch/pytorch/issues/53360
+               # For integer inputs,
+               # inplace variant preserves dtype of `self` while method variant
+               # always promotes it to torch.long.
+               # >>> t = torch.randint(2, 10, (3, 2), dtype=torch.int8)
+               # >>> t.cumprod(0).dtype
+               # torch.int64
+               # >>> t.cumprod_(0).dtype
+               # torch.int8
+               SkipInfo('TestCommon', 'test_variant_consistency_eager',
+                        dtypes=[torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32]),
+               SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                        dtypes=[torch.bool, torch.float16]),
+               # cumprod does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out',
+                        dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_cumprod),
     UnaryUfuncInfo('deg2rad',
                    ref=np.radians,
                    decorators=(precisionOverride({torch.bfloat16: 7e-1,
@@ -2661,6 +2747,18 @@ op_db: List[OpInfo] = [
            dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
            assert_autodiffed=True,),
+    OpInfo('prod',
+           dtypes=all_types_and_complex_and(torch.bool),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           test_inplace_grad=False,
+           skips=(
+               SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                        dtypes=[torch.float16, torch.bfloat16]),
+               # prod does not correctly warn when resizing out= inputs
+               SkipInfo('TestCommon', 'test_out',
+                        dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_prod),
     OpInfo('qr',
            op=torch.qr,
            dtypes=floating_and_complex_types(),
@@ -3728,25 +3826,6 @@ def method_tests():
         ('nansum', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
         ('nansum', (S, S, S), ([1, 2],), 'multi_dim'),
         ('nansum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'),
-        ('prod', (S, S, S), NO_ARGS),
-        ('prod', (S, S, S), (1,), 'dim', (), [0]),
-        ('prod', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
-        ('prod', (), NO_ARGS, 'scalar'),
-        ('prod', (), (0,), 'scalar_dim', (), [0]),
-        ('prod', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
-        ('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'),
-        ('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'),
-        ('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'),
-        ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', (), [0]),
-        ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', (), [0]),
-        ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', (), [0]),
-        ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', (), [0]),
-        ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', (), [0]),
-        ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', (), [0]),
-        ('prod', prod_single_zero(S), NO_ARGS, 'single_zero'),
-        ('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'),
-        ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', (), [0]),
-        ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', (), [0]),
         ('var_mean', (S, S, S), NO_ARGS, ''),
         ('var_mean', (S, S, S), (1,), 'dim', [0]),
         ('var_mean', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
@@ -3770,14 +3849,6 @@ def method_tests():
         ('cummin', (S, S, S), (1,), 'dim1', (), [0]),
         ('cummin', (), (0,), 'dim0_scalar', (), [0]),
         ('cumsum', (S, S, S), (1,), 'dim1_cast', (), [0], (), ident, {'dtype': torch.float64}),
-        ('cumprod', (S, S, S), (0,)),
-        ('cumprod', (S, S, S), (1,), 'dim1', (), [0]),
-        ('cumprod', (), (0,), 'scalar'),
-        ('cumprod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_zeros'),
-        ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', (), [0]),
-        ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', (), [0]),
-        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', (), [0]),
-        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', (), [0], (), ident, {'dtype': torch.float64}),
         ('log_softmax', (S, S, S), (1, torch.float64,), 'kwarg_dtype_would_break_jit_loader', (True,)),
         ('unfold', (), (0, 1, 1), 'scalar', (), [0]),
         ('unfold', (S, S, S, S), (0, 3, 1), '4d_dim0_step1', (), [0]),
@@ -1644,12 +1644,6 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig
 
     return result
 
-def prod_single_zero(dim_size):
-    result = torch.randn(dim_size, dim_size)
-    result[0, 1] = 0
-    return result
-
-
 def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
     assert rank <= l
     A = torch.randn(l, l, dtype=dtype, device=device)