as_strided support for functionalization; introduce as_strided_scatter
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77128
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 7ddc1425ff
Commit: 3a921f2d26
@@ -128,8 +128,8 @@ Tensor FunctionalInverses::_neg_view_copy_inverse(const Tensor& base, const Tens
 }

 Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
-  TORCH_INTERNAL_ASSERT(false, "as_strided has not been implemented in the functionalization pass yet");
-  return Tensor();
+  // Pessimism: we can't reapply views for as_strided_scatter.
+  return base.as_strided_scatter(mutated_view, size, stride, storage_offset);
 }

 Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t offset, int64_t dim1, int64_t dim2) {
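The hunk above is the core of the functionalization change: instead of asserting, the inverse of an as_strided view is now recovered with the new as_strided_scatter op. A minimal Python sketch of the relationship it relies on (illustration only, not part of the diff)::

    import torch

    # Functionalization rewrites "mutate an as_strided view" into "recompute the
    # base tensor": scattering the updated view back into the base must reproduce
    # exactly what an in-place mutation of the view would have done.
    base = torch.zeros(9)
    updated_view = base.as_strided((2,), (2,), 1) + 5   # out-of-place "mutation"
    new_base = base.as_strided_scatter(updated_view, (2,), (2,), 1)

    # Reference: apply the same mutation in place on a clone of base.
    ref = base.clone()
    ref.as_strided((2,), (2,), 1).add_(5)
    assert torch.equal(new_base, ref)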
@@ -3318,6 +3318,15 @@ at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64
   slice.copy_(src);
   return output;
 }

+at::Tensor as_strided_scatter(const at::Tensor& self, const at::Tensor& src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
+  // See Note [as_strided_scatter backward support]
+  TORCH_INTERNAL_ASSERT(!self.requires_grad() || self.is_contiguous(), "as_strided_scatter is currently only supported for contiguous inputs");
+  auto output = self.clone();
+  auto slice = output.as_strided(size, stride, storage_offset);
+  TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
+  slice.copy_(src);
+  return output;
+}
+
 // The default implementation of lift is a no-op.
 // If TLS is set appropriately (for wrapper-tensor keys like Functionalize or functorch transforms),
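The semantics of the new kernel can be restated in a few lines of Python. This is an illustrative sketch, not part of the diff; the helper name as_strided_scatter_reference is made up here::

    import torch

    def as_strided_scatter_reference(inp, src, size, stride, storage_offset=0):
        # Mirror of the kernel above: clone the input, take an as_strided view
        # of the clone, and copy src into that view.
        output = inp.clone()
        window = output.as_strided(size, stride, storage_offset)
        assert window.shape == src.shape, "src must match the strided window"
        window.copy_(src)
        return output

    x = torch.zeros(3, 3)
    src = torch.arange(4.).reshape(2, 2)
    out = as_strided_scatter_reference(x, src, (2, 2), (1, 2))
    assert torch.equal(out, torch.as_strided_scatter(x, src, (2, 2), (1, 2)))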
@@ -4320,6 +4320,13 @@
   dispatch:
     CompositeExplicitAutograd: diagonal_scatter

+- func: as_strided_scatter(Tensor self, Tensor src, int[] size, int[] stride, int? storage_offset=None) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: as_strided_scatter
+
 - func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
   variants: function, method
   dispatch:
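Because the schema declares `variants: function, method`, the op is reachable both as torch.as_strided_scatter(...) and as a Tensor method. A small usage sketch (illustration only, not part of the diff)::

    import torch

    x = torch.zeros(4)
    src = torch.ones(2)

    # function variant
    a = torch.as_strided_scatter(x, src, (2,), (2,), 0)
    # method variant
    b = x.as_strided_scatter(src, (2,), (2,), 0)

    assert torch.equal(a, b)
    assert a.tolist() == [1.0, 0.0, 1.0, 0.0]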
@@ -186,6 +186,18 @@ $2 = torch._ops.aten.add.Tensor($0, tensor([[1., 1.],
 $0 = input('input')
 $1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 0, 1, 0)""")

+    def test_as_strided(self):
+        def f(x):
+            y = x.as_strided((2,), (2,), 1)
+            y.add_(1)
+            return x
+        self.assert_functionalization(f, torch.ones(9))
+        logs = self.get_logs(f, torch.ones(9))
+        self.assertExpectedInline('\n'.join(logs), """\
+$0 = input('input')
+$1 = torch._ops.aten.as_strided_copy.default($0, [2], [2], 1)
+$2 = torch._ops.aten.add.Tensor($1, 1)""")
+
     def test_tensor_list_composite(self):
         def f(x):
             # Test an op with TensorList input
@@ -1412,6 +1412,12 @@
   src: grad.diagonal(offset, dim1, dim2)
   result: auto_linear

+- name: as_strided_scatter(Tensor self, Tensor src, int[] size, int[] stride, int? storage_offset=None) -> Tensor
+  self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset)
+  # See Note [as_strided_scatter backward support]
+  src: grad.contiguous().as_strided(size, stride, storage_offset)
+  result: auto_linear
+
 - name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
   self: slogdet_backward(grad, self, sign, logabsdet)
   output_differentiability: [false, true]
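The src formula is the straightforward half: every element of src lands in exactly one output position, so the incoming gradient is simply read back through the same strided window. A sketch of that relationship (illustration only, assuming the formulas above are in effect)::

    import torch

    src = torch.randn(2, 2, requires_grad=True)
    base = torch.zeros(3, 3)

    out = torch.as_strided_scatter(base, src, (2, 2), (1, 2), 0)
    grad_out = torch.arange(9.).reshape(3, 3)
    out.backward(grad_out)

    # Matches the derivatives.yaml entry:
    #   src: grad.contiguous().as_strided(size, stride, storage_offset)
    expected = grad_out.contiguous().as_strided((2, 2), (1, 2), 0)
    assert torch.equal(src.grad, expected)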
@@ -197,6 +197,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     "select",
     "where",
     "as_strided",
+    "as_strided_scatter",
     "slice",
     "constant_pad_nd",
     "unbind",
@@ -1201,11 +1201,18 @@ See :func:`torch.diagonal`

 add_docstr_all('diagonal_scatter',
                r"""
-diagonal(src, offset=0, dim1=0, dim2=1) -> Tensor
+diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor

 See :func:`torch.diagonal_scatter`
 """)

+add_docstr_all('as_strided_scatter',
+               r"""
+as_strided_scatter(src, size, stride, storage_offset=0) -> Tensor
+
+See :func:`torch.as_strided_scatter`
+""")
+
 add_docstr_all('fill_diagonal_',
                r"""
 fill_diagonal_(fill_value, wrap=False) -> Tensor
@@ -3291,6 +3291,47 @@ Examples::
            [0., 0., 0.]])
 """.format(**common_args))

+add_docstr(torch.as_strided_scatter,
+           r"""
+as_strided_scatter(input, src, size, stride, storage_offset=0) -> Tensor
+
+Embeds the values of the :attr:`src` tensor into :attr:`input` along
+the elements corresponding to the result of calling
+input.as_strided(size, stride, storage_offset).
+
+This function returns a tensor with fresh storage; it does not
+return a view.
+
+Args:
+    {input}
+    size (tuple or ints): the shape of the output tensor
+    stride (tuple or ints): the stride of the output tensor
+    storage_offset (int, optional): the offset in the underlying storage of the output tensor
+
+.. note::
+
+    :attr:`src` must be of the proper size in order to be embedded
+    into :attr:`input`. Specifically, it should have the same shape as
+    `torch.as_strided(input, size, stride, storage_offset)`
+
+Example::
+
+    >>> a = torch.arange(4).reshape(2, 2) + 1
+    >>> a
+    tensor([[1, 2],
+            [3, 4]])
+    >>> b = torch.zeros(3, 3)
+    >>> b
+    tensor([[0., 0., 0.],
+            [0., 0., 0.],
+            [0., 0., 0.]])
+    >>> torch.as_strided_scatter(b, a, (2, 2), (1, 2))
+    tensor([[1., 3., 2.],
+            [4., 0., 0.],
+            [0., 0., 0.]])
+
+""".format(**common_args))
+
 add_docstr(torch.diff, r"""
 diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor
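The docstring example leaves storage_offset at its default; the following sketch (illustration only, not part of the docstring) shows how a non-zero offset shifts where the window lands in the output's storage::

    import torch

    a = torch.arange(4).reshape(2, 2) + 1
    b = torch.zeros(3, 3)

    # Same window shape and strides as the docstring example, but the window
    # now starts at flat index 3 of b's storage instead of 0.
    out = torch.as_strided_scatter(b, a, (2, 2), (1, 2), storage_offset=3)
    # tensor([[0., 0., 0.],
    #         [1., 3., 2.],
    #         [4., 0., 0.]])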
@@ -2373,6 +2373,22 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayR
   return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
 }

+Tensor as_strided_scatter_backward(Tensor grad, TensorGeometry input_geometry, TensorGeometry src_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset) {
+  // Note [as_strided_scatter backward support]
+  // as_strided_scatter handling for autograd is a beast, and is non-trivial to implement for arbitrarily strided inputs.
+  // Most uses for as_strided with functionalization only care about the contiguous case anyway,
+  // So for now this is not implemented.
+  // When autograd is being used, we ban non-contiguous inputs.
+  // We can assume that the input was a contiguous tensor.
+  // Also, we'll take the perf hit and contiguify grad for now.
+  auto grad_ = grad.contiguous();
+  auto grad_slice = grad_.as_strided(sizes, strides, storage_offset);
+  auto result = grad_.new_empty_strided(input_geometry.sizes(), input_geometry.strides());
+  auto result_slice = result.as_strided(sizes, strides, storage_offset);
+  result_slice.copy_(grad_slice);
+  return result;
+}
+
 std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
   if (!grad.defined()) {
     return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
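For the gradient with respect to self, the note above spells out the restriction: the formula is only wired up for contiguous inputs, and grad is made contiguous up front. A rough Python transliteration of the routine (illustration only; the helper name is invented here)::

    import torch

    def as_strided_scatter_backward_sketch(grad, input_sizes, input_strides,
                                           sizes, strides, storage_offset):
        # Transliteration of the C++ above: allocate a tensor with the input's
        # geometry and copy the strided window of a contiguous copy of grad
        # into the corresponding window of the result.
        grad_ = grad.contiguous()
        grad_slice = grad_.as_strided(sizes, strides, storage_offset)
        result = grad_.new_empty_strided(input_sizes, input_strides)
        result_slice = result.as_strided(sizes, strides, storage_offset)
        result_slice.copy_(grad_slice)
        return result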
@@ -319,6 +319,7 @@ Tensor gelu_double_backward(
     const Tensor & input,
     c10::string_view approximate);
 Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
+Tensor as_strided_scatter_backward(Tensor grad, TensorGeometry input_geometry, TensorGeometry src_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset);
 std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
     const Tensor & input,
@@ -473,6 +473,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
         torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
         torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
+        torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1,
         torch.digamma: lambda input, out=None: -1,
         torch.dist: lambda input, other, p=2: -1,
         torch.div: lambda input, other, rounding_mode=None, out=None: -1,
@@ -1713,6 +1713,28 @@ def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
     # yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)))
     # yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)), kwargs={'storage_offset': 0})

+def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # input shape, output shape, output stride, output storage offset
+    test_cases = [
+        ((1,), (1,), (1,), 0),
+        ((3, 3), (2, 2), (1, 2), 0),
+        ((3, 3), (2, 2), (1, 2), 1),
+        ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
+        ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
+    ]
+
+    samples = []
+
+    for input_shape, output_shape, stride, storage_offset in test_cases:
+        input_t = make_arg(input_shape)
+        input_src = make_arg(output_shape)
+        kwargs = dict(storage_offset=storage_offset)
+        samples.append(SampleInput(input_t, args=(input_src, output_shape, stride), kwargs=kwargs))
+
+    return samples
+
 def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
     inputs = (
         (0,),
@@ -12882,6 +12904,27 @@ op_db: List[OpInfo] = [
               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
               DecorateInfo(unittest.skip("Numerous errors"), 'TestGradients'))),
+    OpInfo('as_strided_scatter',
+           op=lambda x, src, size, stride, storage_offset=0:
+               torch.as_strided_scatter(x, src, size, stride, storage_offset=storage_offset),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_as_strided_scatter,
+           skips=(
+               DecorateInfo(unittest.skip('Works only for CPU complex64'), 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip('Works for float64, fails for everything else'), 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'),  # noqa: B950
+               DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'),  # noqa: B950
+               DecorateInfo(unittest.skip('Only fails for LAZY, passes on everything else'), 'TestCompositeCompliance', 'test_backward'),  # noqa: B950
+               DecorateInfo(unittest.skip('Passes on complex64 and float32 only'), 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.expectedFailure, 'TestCommonCUDA', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
+               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
+               DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestGradients', 'test_fn_fwgrad_bwgrad'),)),
     OpInfo('nn.functional.cosine_similarity',
            aten_name="cosine_similarity",
            dtypes=floating_types_and(torch.bfloat16),