as_strided support for functionalization; introduce as_strided_scatter

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77128

Approved by: https://github.com/ezyang
Brian Hirsh
2022-05-24 08:30:35 -07:00
committed by PyTorch MergeBot
parent 7ddc1425ff
commit 3a921f2d26
12 changed files with 147 additions and 3 deletions

View File

@ -128,8 +128,8 @@ Tensor FunctionalInverses::_neg_view_copy_inverse(const Tensor& base, const Tens
}
Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
-  TORCH_INTERNAL_ASSERT(false, "as_strided has not been implemented in the functionalization pass yet");
-  return Tensor();
+  // Pessimism: we can't reapply views for as_strided_scatter, so the mutated values are always scattered back through a copy.
+  return base.as_strided_scatter(mutated_view, size, stride, storage_offset);
}
Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t offset, int64_t dim1, int64_t dim2) {

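This inverse is the piece that lets functionalization handle as_strided: a mutation made through an as_strided view is replayed onto the base by scattering the mutated values back. A minimal Python sketch of that rewrite, using only the public API added in this commit (the helper name is illustrative, not part of the pass):

    import torch

    def replay_onto_base(base, mutated_view, size, stride, storage_offset):
        # as_strided_scatter returns a copy of `base` whose strided slice
        # holds `mutated_view`'s values; `base` itself is never mutated.
        return torch.as_strided_scatter(base, mutated_view, size, stride, storage_offset)

    base = torch.zeros(9)
    view = base.as_strided((2,), (2,), 1)             # aliases storage offsets 1 and 3
    new_base = replay_onto_base(base, view + 1, (2,), (2,), 1)
    assert new_base[1] == 1 and new_base[3] == 1      # mutation reflected in the new base
    assert base.sum() == 0                            # original base untouched
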
View File

@ -3318,6 +3318,15 @@ at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64
slice.copy_(src);
return output;
}
at::Tensor as_strided_scatter(const at::Tensor& self, const at::Tensor& src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
// See Note [as_strided_scatter backward support]
TORCH_INTERNAL_ASSERT(!self.requires_grad() || self.is_contiguous(), "as_strided_scatter is currently only supported for contiguous inputs");
auto output = self.clone();
auto slice = output.as_strided(size, stride, storage_offset);
TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
slice.copy_(src);
return output;
}
// The default implementation of lift is a no-op.
// If TLS is set appropriately (for wrapper-tensor keys like Functionalize or functorch transforms),

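The kernel above is a straightforward clone-view-copy decomposition. A Python rendering of the same semantics, usable as a reference point (the function name is mine, not part of the API):

    import torch

    def as_strided_scatter_reference(self, src, size, stride, storage_offset=None):
        # Mirrors the C++ kernel: copy `self`, view the copy with the requested
        # geometry, and overwrite that view with `src`.
        output = self.clone()
        slice_ = output.as_strided(size, stride, storage_offset)
        assert slice_.shape == src.shape, "src must match the strided slice's shape"
        slice_.copy_(src)
        return output

    b, a = torch.zeros(3, 3), torch.ones(2, 2)
    expected = torch.as_strided_scatter(b, a, (2, 2), (1, 2))
    assert torch.equal(as_strided_scatter_reference(b, a, (2, 2), (1, 2)), expected)
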
View File

@ -4320,6 +4320,13 @@
dispatch:
CompositeExplicitAutograd: diagonal_scatter
- func: as_strided_scatter(Tensor self, Tensor src, int[] size, int[] stride, int? storage_offset=None) -> Tensor
variants: function, method
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: as_strided_scatter
- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
variants: function, method
dispatch:

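Because the schema declares `variants: function, method`, both spellings are generated:

    import torch

    b, a = torch.zeros(3, 3), torch.ones(2, 2)
    out_fn = torch.as_strided_scatter(b, a, (2, 2), (1, 2))   # function variant
    out_m = b.as_strided_scatter(a, (2, 2), (1, 2))           # method variant
    assert torch.equal(out_fn, out_m)
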
View File

@ -186,6 +186,18 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
$0 = input('input')
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""")
def test_as_strided(self):
def f(x):
y = x.as_strided((2,), (2,), 1)
y.add_(1)
return x
self.assert_functionalization(f, torch.ones(9))
logs = self.get_logs(f, torch.ones(9))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.as_strided_copy.default($0, [2], [2], 1)
$2 = torch._ops.aten.add.Tensor($1, 1)""")
def test_tensor_list_composite(self):
def f(x):
# Test an op with TensorList input

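The expected trace shows the mutation made pure: the as_strided view becomes an as_strided_copy, and add_ becomes an out-of-place add; writing the result back to the base is where as_strided_scatter comes in. An eager-mode sketch of what the functionalized program computes (a clone'd view stands in for aten.as_strided_copy):

    import torch

    def f_functionalized(x):
        y = x.as_strided((2,), (2,), 1).clone()      # $1: as_strided_copy
        y = torch.add(y, 1)                          # $2: out-of-place add
        # Mutation replayed onto the base, instead of y.add_(1) aliasing x:
        return torch.as_strided_scatter(x, y, (2,), (2,), 1)

    x = torch.ones(9)
    x_ref = x.clone()
    x_ref.as_strided((2,), (2,), 1).add_(1)          # what the original f does
    assert torch.equal(f_functionalized(x), x_ref)
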
View File

@ -1412,6 +1412,12 @@
src: grad.diagonal(offset, dim1, dim2)
result: auto_linear
- name: as_strided_scatter(Tensor self, Tensor src, int[] size, int[] stride, int? storage_offset=None) -> Tensor
self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset)
# See Note [as_strided_scatter backward support]
src: grad.contiguous().as_strided(size, stride, storage_offset)
result: auto_linear
- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
self: slogdet_backward(grad, self, sign, logabsdet)
output_differentiability: [false, true]

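In Python terms, the `src` formula is simply the strided slice of the (contiguified) incoming gradient; the `self` formula defers to as_strided_scatter_backward, implemented in a later hunk below. A sketch of the `src` entry exactly as written, under the contiguity restriction from the Note:

    import torch

    def grad_src(grad, size, stride, storage_offset=None):
        # `src: grad.contiguous().as_strided(size, stride, storage_offset)`
        return grad.contiguous().as_strided(size, stride, storage_offset)

    g = torch.arange(9.).reshape(3, 3)
    print(grad_src(g, (2, 2), (1, 2)))   # picks flat offsets [[0, 2], [1, 3]]
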
View File

@ -197,6 +197,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"select",
"where",
"as_strided",
"as_strided_scatter",
"slice",
"constant_pad_nd",
"unbind",

View File

@ -1201,11 +1201,18 @@ See :func:`torch.diagonal`
add_docstr_all('diagonal_scatter',
r"""
-diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor
+diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor
See :func:`torch.diagonal_scatter`
""")
add_docstr_all('as_strided_scatter',
r"""
as_strided_scatter(src, size, stride, storage_offset=None) -> Tensor
See :func:`torch.as_strided_scatter`
""")
add_docstr_all('fill_diagonal_',
r"""
fill_diagonal_(fill_value, wrap=False) -> Tensor

View File

@ -3291,6 +3291,47 @@ Examples::
[0., 0., 0.]])
""".format(**common_args))
add_docstr(torch.as_strided_scatter,
r"""
as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor
Embeds the values of the :attr:`src` tensor into :attr:`input` at the elements
selected by `input.as_strided(size, stride, storage_offset)`.
This function returns a tensor with fresh storage; it does not
return a view.
Args:
{input}
src (Tensor): the tensor to embed into :attr:`input`
size (tuple or ints): the shape of the strided slice that :attr:`src` is scattered into; must match :attr:`src`'s shape
stride (tuple or ints): the stride of the strided slice
storage_offset (int, optional): the offset into the underlying storage of :attr:`input` at which the slice starts
.. note::
:attr:`src` must be of the proper size in order to be embedded
into :attr:`input`. Specifically, it should have the same shape as
`torch.as_strided(input, size, stride, storage_offset)`
Example::
>>> a = torch.arange(4).reshape(2, 2) + 1
>>> a
tensor([[1, 2],
[3, 4]])
>>> b = torch.zeros(3, 3)
>>> b
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
>>> torch.as_strided_scatter(b, a, (2, 2), (1, 2))
tensor([[1., 3., 2.],
[4., 0., 0.],
[0., 0., 0.]])
""".format(**common_args))
add_docstr(torch.diff, r"""
diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor

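To follow the example's arithmetic: with size (2, 2), stride (1, 2), and storage offset 0, slice element (i, j) lands at flat position i*1 + j*2 of b's storage, so a = [[1, 2], [3, 4]] scatters to flat positions [[0, 2], [1, 3]] and b.flatten() becomes [1, 3, 2, 4, 0, ...]. A quick check:

    import torch

    a = torch.arange(4).reshape(2, 2) + 1
    b = torch.zeros(3, 3)
    out = torch.as_strided_scatter(b, a, (2, 2), (1, 2))

    manual = torch.zeros(9)
    for i in range(2):
        for j in range(2):
            manual[i * 1 + j * 2] = a[i, j]   # element (i, j) -> flat offset i*1 + j*2
    assert torch.equal(out, manual.reshape(3, 3))
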
View File

@ -2373,6 +2373,22 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayR
return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}
Tensor as_strided_scatter_backward(Tensor grad, TensorGeometry input_geometry, TensorGeometry src_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset) {
// Note [as_strided_scatter backward support]
// A fully general as_strided_scatter backward is a beast, and is non-trivial to implement for arbitrarily strided inputs.
// Most uses of as_strided under functionalization only involve contiguous inputs anyway,
// so the general case is not implemented yet: when autograd is in use, non-contiguous inputs are banned
// (see the assert in as_strided_scatter), and we can therefore assume here that the input was contiguous.
// We also take the perf hit of contiguifying grad for now.
auto grad_ = grad.contiguous();
auto grad_slice = grad_.as_strided(sizes, strides, storage_offset);
auto result = grad_.new_empty_strided(input_geometry.sizes(), input_geometry.strides());
auto result_slice = result.as_strided(sizes, strides, storage_offset);
result_slice.copy_(grad_slice);
return result;
}
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
if (!grad.defined()) {
return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};

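A line-by-line Python rendering of the routine above, to make the data flow visible (a sketch under the Note's contiguity assumption; note that new_empty_strided returns uninitialized storage, so only the slice region written here is defined on return):

    import torch

    def as_strided_scatter_backward_sketch(grad, input_sizes, input_strides,
                                           sizes, strides, storage_offset=None):
        grad = grad.contiguous()                     # take the perf hit, per the Note
        grad_slice = grad.as_strided(sizes, strides, storage_offset)
        result = grad.new_empty_strided(input_sizes, input_strides)  # uninitialized
        result.as_strided(sizes, strides, storage_offset).copy_(grad_slice)
        return result
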
View File

@ -319,6 +319,7 @@ Tensor gelu_double_backward(
const Tensor & input,
c10::string_view approximate);
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
Tensor as_strided_scatter_backward(Tensor grad, TensorGeometry input_geometry, TensorGeometry src_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset);
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
const Tensor & input,

View File

@ -473,6 +473,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1,
torch.digamma: lambda input, out=None: -1,
torch.dist: lambda input, other, p=2: -1,
torch.div: lambda input, other, rounding_mode=None, out=None: -1,

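With this entry registered, the new op participates in the __torch_function__ override machinery, so Tensor subclasses can intercept it. A minimal sketch:

    import torch

    class Intercepting(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if func is torch.as_strided_scatter:
                print("as_strided_scatter intercepted")
            return super().__torch_function__(func, types, args, kwargs or {})

    x = torch.zeros(9).as_subclass(Intercepting)
    torch.as_strided_scatter(x, torch.ones(2), (2,), (2,), storage_offset=1)
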
View File

@ -1713,6 +1713,28 @@ def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
# yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)))
# yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)), kwargs={'storage_offset': 0})
def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
# input shape, output shape, output stride, output storage offset
test_cases = [
((1,), (1,), (1,), 0),
((3, 3), (2, 2), (1, 2), 0),
((3, 3), (2, 2), (1, 2), 1),
((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
]
samples = []
for input_shape, output_shape, stride, storage_offset in test_cases:
input_t = make_arg(input_shape)
input_src = make_arg(output_shape)
kwargs = dict(storage_offset=storage_offset)
samples.append(SampleInput(input_t, args=(input_src, output_shape, stride), kwargs=kwargs))
return samples
def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
inputs = (
(0,),
@ -12882,6 +12904,27 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Numerous errors"), 'TestGradients'))),
OpInfo('as_strided_scatter',
op=lambda x, src, size, stride, storage_offset=0:
torch.as_strided_scatter(x, src, size, stride, storage_offset=storage_offset),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# vmap does not support inplace views
check_inplace_batched_forward_grad=False,
sample_inputs_func=sample_inputs_as_strided_scatter,
skips=(
DecorateInfo(unittest.skip('Works only for CPU complex64'), 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.skip('Works for float64, fails for everything else'), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950
DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950
DecorateInfo(unittest.skip('Only fails for LAZY, passes on everything else'), 'TestCompositeCompliance', 'test_backward'), # noqa: B950
DecorateInfo(unittest.skip('Passes on complex64 and float32 only'), 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestCommonCUDA', 'test_complex_half_reference_testing'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestGradients', 'test_fn_fwgrad_bwgrad'),)),
OpInfo('nn.functional.cosine_similarity',
aten_name="cosine_similarity",
dtypes=floating_types_and(torch.bfloat16),