From 01b0f09931d47bd2716398a0c335b2807dc3074d Mon Sep 17 00:00:00 2001
From: Isuru Fernando
Date: Tue, 1 Jul 2025 15:07:32 +0000
Subject: [PATCH] Fix full_like decomposition to preserve strides (#144765)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144765
Approved by: https://github.com/amjames, https://github.com/jansel
---
 ...asDecompTest.test_has_decomposition.expect |  2 -
 test/export/test_experimental.py              |  8 +--
 test/inductor/test_torchinductor.py           | 17 ++++++
 test/test_decomp.py                           | 16 ++++-
 test/test_nestedtensor.py                     | 14 -----
 test/test_ops.py                              |  2 -
 torch/_decomp/__init__.py                     |  1 +
 torch/_inductor/decomposition.py              | 22 -------
 torch/_inductor/lowering.py                   | 12 +++-
 torch/_refs/__init__.py                       | 59 +++++++++++++++----
 .../_internal/common_methods_invocations.py   |  6 +-
 11 files changed, 101 insertions(+), 58 deletions(-)

diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 042959c22cd4..74ead8e17739 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -823,8 +823,6 @@ aten::from_file
 aten::from_file.out
 aten::full.names
 aten::full.names_out
-aten::full_like
-aten::full_like.out
 aten::gather
 aten::gather.out
 aten::geqrf
diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py
index 641dd586edb5..168a58463380 100644
--- a/test/export/test_experimental.py
+++ b/test/export/test_experimental.py
@@ -52,8 +52,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
     sum_1 = torch.ops.aten.sum.dim_IntList(mul, []);  mul = None
     neg = torch.ops.aten.neg.default(sum_1);  sum_1 = None
     div = torch.ops.aten.div.Scalar(neg, 1);  neg = None
-    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
-    div_1 = torch.ops.aten.div.Scalar(full_like, 1);  full_like = None
+    full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    div_1 = torch.ops.aten.div.Scalar(full, 1);  full = None
     neg_1 = torch.ops.aten.neg.default(div_1);  div_1 = None
     expand = torch.ops.aten.expand.default(neg_1, [3]);  neg_1 = None
     mul_1 = torch.ops.aten.mul.Tensor(expand, clone);  expand = clone = None
@@ -98,8 +98,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
     sum_1 = torch.ops.aten.sum.dim_IntList(mul, []);  mul = None
     neg = torch.ops.aten.neg.default(sum_1);  sum_1 = None
     div = torch.ops.aten.div.Scalar(neg, 1);  neg = None
-    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
-    div_1 = torch.ops.aten.div.Scalar(full_like, 1);  full_like = None
+    full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    div_1 = torch.ops.aten.div.Scalar(full, 1);  full = None
     neg_1 = torch.ops.aten.neg.default(div_1);  div_1 = None
     expand = torch.ops.aten.expand.default(neg_1, [3]);  neg_1 = None
     mul_1 = torch.ops.aten.mul.Tensor(expand, clone);  expand = clone = None
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 17a0da43a2aa..9e761e82ed82 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -432,6 +432,8 @@ def check_model(
     check_gradient=False,
     check_has_compiled=True,
     output_process_fn_grad=lambda x: x,
+    # TODO: enable this for all tests
+    exact_stride=False,
 ):
     kwargs = kwargs or {}
     torch._dynamo.reset()
@@ -544,6 +546,7 @@ def check_model(
             rtol=rtol,
             equal_nan=True,
             exact_dtype=exact_dtype,
+            exact_stride=exact_stride,
         )
         # In case of input mutations, check that inputs are the same
         self.assertEqual(
@@ -554,6 +557,7 @@ def check_model(
             equal_nan=True,
            # our testing sometimes uses higher precision inputs for the reference
             exact_dtype=False,
+            exact_stride=exact_stride,
         )
     else:
         for correct_val, actual_val in zip(correct_flat, actual_flat):
@@ -567,6 +571,8 @@ def check_model(
                 assert correct_val.layout == actual_val.layout
                 if exact_dtype:
                     assert correct_val.dtype == actual_val.dtype
+                if exact_stride:
+                    assert correct_val.stride() == actual_val.stride()
 
     if check_gradient:
         actual = output_process_fn_grad(actual)
@@ -620,6 +626,7 @@ def check_model(
             rtol=grad_rtol or rtol,
             equal_nan=True,
             exact_dtype=exact_dtype,
+            exact_stride=exact_stride,
         )
 
     torch._dynamo.reset()
@@ -645,6 +652,8 @@ def check_model_gpu(
     check_gradient=False,
     check_has_compiled=True,
     output_process_fn_grad=lambda x: x,
+    # TODO: enable this for all tests
+    exact_stride=False,
 ):
     kwargs = kwargs or {}
     if hasattr(model, "to"):
@@ -671,6 +680,7 @@ def check_model_gpu(
         check_gradient=check_gradient,
         check_has_compiled=check_has_compiled,
         output_process_fn_grad=output_process_fn_grad,
+        exact_stride=exact_stride,
     )
 
     if check_lowp:
@@ -703,6 +713,7 @@ def check_model_gpu(
             check_gradient=check_gradient,
             check_has_compiled=check_has_compiled,
             output_process_fn_grad=output_process_fn_grad,
+            exact_stride=exact_stride,
         )
 
 
@@ -6960,6 +6971,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
 
         self.common(fn, (torch.randn(8),))
 
+    def test_full_like_stride(self):
+        def fn(a):
+            return torch.full_like(a, 3)
+
+        self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
+
     def test_full_truncation(self):
         def fn(a):
             return a + torch.full_like(a, 7.777)
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 07dcd8252c5b..53ef92dba61d 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -545,6 +545,11 @@ comprehensive_failures = {
     xfail(
         "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
     ),  # off by one error
+    skip(
+        "nn.functional.nll_loss",
+        "",
+        dtypes=(torch.float64, torch.float32, torch.bfloat16, torch.float16),
+    ),  # non-deterministic
 }
 
 
@@ -861,7 +866,16 @@ def forward(self, scores_1, mask_1, value_1):
             assert len(real_out) == len(decomp_out)
 
             if do_relative_check:
-                upcast = partial(upcast_tensor, dtype=torch.float64)
+                device_arg = kwargs.get("device", None)
+
+                def upcast(x):
+                    if (isinstance(x, Tensor) and x.device.type == "mps") or (
+                        device_arg and torch.device(device_arg).type == "mps"
+                    ):
+                        return upcast_tensor(x, dtype=torch.float32)
+                    else:
+                        return upcast_tensor(x, dtype=torch.float64)
+
                 real_out_double, _ = tree_flatten(
                     func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                 )
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index f53268cb24d3..f3ea420c814c 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -8530,14 +8530,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
 
 COMPILE_FORWARD_SKIPS_AND_XFAILS = [
     *FORWARD_SKIPS_AND_XFAILS,
-    # Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
-    # e.g. Expected 5 == 4
-    XFailRule(
-        error_type=AssertionError,
-        op_match_fn=lambda device, op: (op.full_name == "fill"),
-        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
-        name="fill_aot_autograd_bug_with_transposed_input",
-    ),
     # Bug: cross-device conversions with to() result in new nested ints within compile only
     XFailRule(
         error_type=AssertionError,
@@ -8581,12 +8573,6 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
         name="crazy_aot_autograd_bug1",
     ),
-    # Bug: also no idea what's going on here: needs investigation within AOTAutograd
-    XFailRule(
-        op_match_fn=lambda device, op: (op.full_name == "isreal"),
-        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
-        name="crazy_aot_autograd_bug2",
-    ),
 ]
 
 COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
diff --git a/test/test_ops.py b/test/test_ops.py
index 0f079e5c45ee..c4b257ef138d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2294,7 +2294,6 @@ class TestRefsOpsInfo(TestCase):
         "_refs.empty_strided",
         "_refs.equal",
         "_refs.full",
-        "_refs.full_like",
         "_refs.is_complex",
         "_refs.to",
         "_refs.mvlgamma",
@@ -2409,7 +2408,6 @@ class TestRefsOpsInfo(TestCase):
         "_refs.unflatten",
         "_refs.sum_to_size",
         # ref implementation missing kwargs
-        "_refs.full_like",  # missing "layout"
         "_refs.scalar_tensor",  # missing "layout"
         # other
         "_refs.block_diag",  # only refs._block_diag_iterable is in decomposition table
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index abb94b109cc0..8f61fa15f9bc 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -346,6 +346,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
             aten.floor_divide,
             aten.frac,
             aten.frac_,
+            aten.full_like,
             aten._fused_moving_avg_obs_fq_helper,
             aten.gelu_,
             aten.gelu_backward,
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 08c3abc9f23f..3f75a7ab6a97 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -625,28 +625,6 @@ def randn_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition(aten.full_like)
-def full_like(
-    self: torch.Tensor,
-    fill_value: Union[int, float],
-    *,
-    dtype: Optional[torch.dtype] = None,
-    layout: Optional[torch.layout] = None,
-    device: Optional[torch.device] = None,
-    pin_memory: bool = False,
-    requires_grad: bool = False,
-    memory_format: torch.memory_format = torch.preserve_format,
-) -> torch.Tensor:
-    return torch.full(
-        [*self.size()],
-        fill_value,
-        dtype=dtype or self.dtype,
-        layout=layout or self.layout,
-        device=device or self.device,
-        requires_grad=requires_grad,
-    ).to(memory_format=get_like_layout(self, memory_format))
-
-
 @register_decomposition(aten.randint_like.default)
 def randint_like(
     self: torch.Tensor,
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 5db712372a11..600df3233e7a 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3177,7 +3177,6 @@ def _full(fill_value, device, dtype, size):
     )
 
 
-@register_lowering(aten.full_like, type_promotion_kind=None)
 def full_like(x, fill_value, **kwargs):
     return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
 
@@ -6121,6 +6120,17 @@ def fill_(x, fill_value):
     return mutate_to(x, full_like(x, fill_value))
 
 
+@register_lowering(prims.fill, type_promotion_kind=None)
+def prims_fill(x, fill_value):
+    dtype = x.get_dtype()
+    return Pointwise.create(
+        device=x.get_device(),
+        dtype=dtype,
+        inner_fn=lambda _: ops.constant(fill_value, dtype),
+        ranges=list(x.get_size()),
+    )
+
+
 @register_lowering(aten.copy_, type_promotion_kind=None)
 def copy_(dst, src, non_blocking=False):
     if dst is src:
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 8fd234c8a0e9..d627aab58277 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -5588,9 +5588,26 @@ def full(
         pin_memory=pin_memory,
         requires_grad=requires_grad,
     )
-    return torch.fill(e, fill_value)  # type: ignore[arg-type]
+    return prims.fill(e, fill_value)  # type: ignore[arg-type]
 
 
+def _get_shape_permutation_like(
+    a: TensorLikeType, layout: torch.layout
+) -> tuple[ShapeType, StrideType]:
+    assert layout == torch.strided
+
+    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(a)
+    shape = [a.shape[l] for l in physical_layout]
+
+    permutation = [0] * len(shape)
+    for p, l in enumerate(physical_layout):
+        permutation[l] = p
+
+    return (shape, permutation)
+
+
+@register_decomposition(aten.full_like)
+@out_wrapper()
 def full_like(
     a: TensorLikeType,
     fill_value: NumberType,
@@ -5602,16 +5619,36 @@ def full_like(
     requires_grad: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> TensorLikeType:
-    e = torch.empty_like(
-        a,
-        dtype=dtype,
-        layout=layout,
-        device=device,
-        pin_memory=pin_memory,
-        requires_grad=requires_grad,
-        memory_format=memory_format,
-    )
-    return fill(e, fill_value)
+    dtype = a.dtype if dtype is None else dtype
+    layout = a.layout if layout is None else layout
+    device = a.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        result = torch.full(
+            a.shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        return result.to(memory_format=memory_format)
+
+    else:
+        shape, permutation = _get_shape_permutation_like(a, layout)
+        result = torch.full(
+            shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        if permutation == list(range(len(permutation))):
+            return result
+        return result.permute(permutation).clone()
 
 
 @register_decomposition(aten.zeros_like)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index daf42f4bba59..55c1780961cc 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1923,7 +1923,7 @@ def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
     def get_val(dtype):
         return make_tensor([], dtype=dtype, device="cpu").item()
 
-    double_dtype = torch.double if device != "mps:0" else torch.float
+    double_dtype = torch.double if torch.device(device).type != "mps" else torch.float
     inputs = [
         ((), get_val(dtype), {}),
         ((S, S), get_val(dtype), {}),
@@ -24603,6 +24603,10 @@ python_ref_db = [
             DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
         ),
     ),
+    PythonRefInfo(
+        "_refs.full_like",
+        torch_opinfo_name="full_like",
+    ),
     PythonRefInfo(
         "_refs.randn",
         torch_opinfo_name="randn",
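
A minimal sketch (not part of the diff above) of the behavior the new test_full_like_stride test exercises, assuming a PyTorch build where torch.compile lowers through inductor:

import torch

def fn(a):
    # full_like defaults to memory_format=torch.preserve_format
    return torch.full_like(a, 3)

# Non-contiguous input: shape (4, 6, 5) with transposed strides (30, 1, 6).
x = torch.randn(4, 5, 6).transpose(1, -1)

eager = fn(x)
compiled = torch.compile(fn)(x)

# With the stride-preserving decomposition, the compiled output keeps the
# transposed layout instead of being densified to a contiguous tensor.
print(x.stride(), eager.stride(), compiled.stride())
assert eager.stride() == compiled.stride()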