Fix full_like decomposition to preserve strides (#158898)

Summary:
See original PR at: https://github.com/pytorch/pytorch/pull/144765, which landed internally but was reverted due to test failures. Addressing reviewer comments and trying again.

Rollback Plan:

Differential hack Revision: D78783627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158898
Approved by: https://github.com/eellison
This commit is contained in:
Sam Larsen
2025-07-25 20:21:36 +00:00
committed by PyTorch MergeBot
parent 28ee8be5bf
commit fa0355c18d
6 changed files with 99 additions and 26 deletions

View File

@ -432,6 +432,8 @@ def check_model(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
# TODO: enable this for all tests
exact_stride=False,
):
kwargs = kwargs or {}
torch._dynamo.reset()
@ -465,7 +467,12 @@ def check_model(
x.dtype == torch.float16 or x.dtype == torch.bfloat16
):
has_lowp_args = True
return x.float()
# Preserve strides when casting
result = torch.empty_strided(
x.size(), x.stride(), device=x.device, dtype=torch.float
)
result.copy_(x)
return result
else:
return x
@ -555,6 +562,7 @@ def check_model(
rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
exact_stride=exact_stride,
)
# In case of input mutations, check that inputs are the same
# (This never uses a custom assert_equal fn.)
@ -566,6 +574,7 @@ def check_model(
equal_nan=True,
# our testing sometimes uses higher precision inputs for the reference
exact_dtype=False,
exact_stride=exact_stride,
)
else:
for correct_val, actual_val in zip(correct_flat, actual_flat):
@ -579,6 +588,8 @@ def check_model(
assert correct_val.layout == actual_val.layout
if exact_dtype:
assert correct_val.dtype == actual_val.dtype
if exact_stride:
assert correct_val.stride() == actual_val.stride()
if check_gradient:
actual = output_process_fn_grad(actual)
@ -632,6 +643,7 @@ def check_model(
rtol=grad_rtol or rtol,
equal_nan=True,
exact_dtype=exact_dtype,
exact_stride=exact_stride,
)
torch._dynamo.reset()
@ -657,6 +669,8 @@ def check_model_gpu(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
# TODO: enable this for all tests
exact_stride=False,
):
kwargs = kwargs or {}
if hasattr(model, "to"):
@ -683,6 +697,7 @@ def check_model_gpu(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
exact_stride=exact_stride,
)
if check_lowp:
@ -715,6 +730,7 @@ def check_model_gpu(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
exact_stride=exact_stride,
)
@ -6979,6 +6995,18 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.common(fn, (torch.randn(8),))
def test_full_like_transposed(self):
def fn(a):
return torch.full_like(a, 3)
self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
def test_full_like_sliced(self):
def fn(a):
return torch.full_like(a, 3)
self.common(fn, (torch.rand(3, 4)[:, ::2],), exact_stride=True)
def test_full_truncation(self):
def fn(a):
return a + torch.full_like(a, 7.777)

View File

@ -985,7 +985,15 @@ def get_sort_argsort_assert_equal_fn(is_argsort, args, kwargs):
dim = kwargs["dim"]
def argsort_sort_assert_equal(
test_case_inst, x, y, *, atol=None, rtol=None, equal_nan=True, exact_dtype=True
test_case_inst,
x,
y,
*,
atol=None,
rtol=None,
equal_nan=True,
exact_dtype=True,
exact_stride=False,
):
if is_argsort:
assert isinstance(x, torch.Tensor)
@ -1004,6 +1012,7 @@ def get_sort_argsort_assert_equal_fn(is_argsort, args, kwargs):
rtol=rtol,
equal_nan=equal_nan,
exact_dtype=exact_dtype,
exact_stride=exact_stride,
)
# The second tensor is the same result as an argsort.
@ -1015,6 +1024,11 @@ def get_sort_argsort_assert_equal_fn(is_argsort, args, kwargs):
assert x.shape == y.shape
if exact_stride and (x.stride() != y.stride()):
raise AssertionError(
f"The strides do not match: {x.stride()} != {y.stride()}."
)
def el_to_indices(el):
"""Turn an element number into a list of indices"""
indices = [None] * x.dim()

View File

@ -861,7 +861,16 @@ def forward(self, scores_1, mask_1, value_1):
assert len(real_out) == len(decomp_out)
if do_relative_check:
upcast = partial(upcast_tensor, dtype=torch.float64)
device_arg = kwargs.get("device", None)
def upcast(x):
if (isinstance(x, Tensor) and x.device.type == "mps") or (
device_arg and torch.device(device_arg).type == "mps"
):
return upcast_tensor(x, dtype=torch.float32)
else:
return upcast_tensor(x, dtype=torch.float64)
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)

View File

@ -8536,14 +8536,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
COMPILE_FORWARD_SKIPS_AND_XFAILS = [
*FORWARD_SKIPS_AND_XFAILS,
# Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
# e.g. Expected 5 == 4
XFailRule(
error_type=AssertionError,
op_match_fn=lambda device, op: (op.full_name == "fill"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="fill_aot_autograd_bug_with_transposed_input",
),
# Bug: cross-device conversions with to() result in new nested ints within compile only
XFailRule(
error_type=AssertionError,
@ -8587,12 +8579,6 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug1",
),
# Bug: also no idea what's going on here: needs investigation within AOTAutograd
XFailRule(
op_match_fn=lambda device, op: (op.full_name == "isreal"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug2",
),
]
COMPILE_BACKWARD_SKIPS_AND_XFAILS = [

View File

@ -625,6 +625,21 @@ def randn_like(
).to(memory_format=get_like_layout(self, memory_format))
def _get_shape_permutation_like(
self: torch.Tensor, layout: torch.layout
) -> tuple[utils.ShapeType, utils.StrideType]:
assert layout == torch.strided
physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
shape = [self.shape[l] for l in physical_layout]
permutation = [0] * len(shape)
for p, l in enumerate(physical_layout):
permutation[l] = p
return (shape, permutation)
@register_decomposition(aten.full_like)
def full_like(
self: torch.Tensor,
@ -637,14 +652,36 @@ def full_like(
requires_grad: bool = False,
memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
return torch.full(
[*self.size()],
fill_value,
dtype=dtype or self.dtype,
layout=layout or self.layout,
device=device or self.device,
requires_grad=requires_grad,
).to(memory_format=get_like_layout(self, memory_format))
dtype = self.dtype if dtype is None else dtype
layout = self.layout if layout is None else layout
device = self.device if device is None else device
if memory_format != torch.preserve_format:
result = torch.full(
self.shape,
fill_value,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
)
return result.to(memory_format=memory_format)
else:
shape, permutation = _get_shape_permutation_like(self, layout)
result = torch.full(
shape,
fill_value,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
)
if permutation == list(range(len(permutation))):
return result
return result.permute(permutation).clone()
@register_decomposition(aten.randint_like.default)

View File

@ -3217,7 +3217,6 @@ def _full(fill_value, device, dtype, size):
)
@register_lowering(aten.full_like, type_promotion_kind=None)
def full_like(x, fill_value, **kwargs):
return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)