Revert "Fix full_like decomposition to preserve strides (#144765)"

This reverts commit 01b0f09931d47bd2716398a0c335b2807dc3074d.

Reverted https://github.com/pytorch/pytorch/pull/144765 on behalf of https://github.com/jeanschmidt because it seems to be breaking internal tests; see [D77652778](https://www.internalfb.com/diff/D77652778). @jansel, could you help get this PR merged? ([comment](https://github.com/pytorch/pytorch/pull/144765#issuecomment-3027975098))
PyTorch MergeBot
2025-07-02 13:56:03 +00:00
parent d5a89178b0
commit c553c55be7
11 changed files with 58 additions and 101 deletions
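
For context, a minimal sketch of the stride behavior the reverted change targeted (illustration only, not part of the commit): eager torch.full_like with the default memory_format=torch.preserve_format keeps the strides of a dense, non-contiguous input, whereas a decomposition that builds a fresh tensor with torch.full and only adjusts the memory format comes back contiguous.

import torch

a = torch.randn(4, 5, 6).transpose(1, -1)   # shape (4, 6, 5), strides (30, 1, 6)

# Eager full_like preserves the transposed strides.
print(torch.full_like(a, 3).stride())       # expected: (30, 1, 6)

# Hypothetical full-based reimplementation, for illustration only;
# it loses the input's stride order.
def full_like_via_full(x, fill_value):
    return torch.full(x.shape, fill_value, dtype=x.dtype, device=x.device)

print(full_like_via_full(a, 3).stride())    # expected: (30, 5, 1), i.e. contiguous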

View File

@@ -823,6 +823,8 @@ aten::from_file
aten::from_file.out
aten::full.names
aten::full.names_out
aten::full_like
aten::full_like.out
aten::gather
aten::gather.out
aten::geqrf

View File

@@ -52,8 +52,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
div = torch.ops.aten.div.Scalar(neg, 1); neg = None
full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
div_1 = torch.ops.aten.div.Scalar(full, 1); full = None
full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
@@ -98,8 +98,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
div = torch.ops.aten.div.Scalar(neg, 1); neg = None
full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
div_1 = torch.ops.aten.div.Scalar(full, 1); full = None
full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None
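
The two hunks above show the expected decomposed export graph differing only in whether an aten.full or aten.full_like node appears. A rough sketch (hypothetical toy module, for illustration) of how one might check which node a decomposed graph ends up with; the answer depends on the decomposition table in use:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.full_like(x, 1).div(1)

ep = torch.export.export(M(), (torch.randn(3),)).run_decompositions()
# Look for aten.full_like.default vs aten.full.default in the printed graph.
print(ep.graph_module.code)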

View File

@@ -432,8 +432,6 @@ def check_model(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
# TODO: enable this for all tests
exact_stride=False,
):
kwargs = kwargs or {}
torch._dynamo.reset()
@@ -546,7 +544,6 @@ def check_model(
rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
exact_stride=exact_stride,
)
# In case of input mutations, check that inputs are the same
self.assertEqual(
@@ -557,7 +554,6 @@ def check_model(
equal_nan=True,
# our testing sometimes uses higher precision inputs for the reference
exact_dtype=False,
exact_stride=exact_stride,
)
else:
for correct_val, actual_val in zip(correct_flat, actual_flat):
@@ -571,8 +567,6 @@ def check_model(
assert correct_val.layout == actual_val.layout
if exact_dtype:
assert correct_val.dtype == actual_val.dtype
if exact_stride:
assert correct_val.stride() == actual_val.stride()
if check_gradient:
actual = output_process_fn_grad(actual)
@@ -626,7 +620,6 @@ def check_model(
rtol=grad_rtol or rtol,
equal_nan=True,
exact_dtype=exact_dtype,
exact_stride=exact_stride,
)
torch._dynamo.reset()
@@ -652,8 +645,6 @@ def check_model_gpu(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
# TODO: enable this for all tests
exact_stride=False,
):
kwargs = kwargs or {}
if hasattr(model, "to"):
@@ -680,7 +671,6 @@ def check_model_gpu(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
exact_stride=exact_stride,
)
if check_lowp:
@@ -713,7 +703,6 @@ def check_model_gpu(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
exact_stride=exact_stride,
)
@@ -6971,12 +6960,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.common(fn, (torch.randn(8),))
def test_full_like_stride(self):
def fn(a):
return torch.full_like(a, 3)
self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
def test_full_truncation(self):
def fn(a):
return a + torch.full_like(a, 7.777)
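
The removed exact_stride option (and the removed test_full_like_stride test) compared output strides between the eager and compiled runs. A standalone sketch of that comparison, outside the test harness:

import torch

def fn(a):
    return torch.full_like(a, 3)

a = torch.randn(4, 5, 6).transpose(1, -1)
eager = fn(a)
compiled = torch.compile(fn)(a)
# exact_stride=True asserted that these match; without the stride-preserving
# decomposition the compiled output may come back contiguous instead.
print(eager.stride(), compiled.stride())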

View File

@@ -545,11 +545,6 @@ comprehensive_failures = {
xfail(
"nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
), # off by one error
skip(
"nn.functional.nll_loss",
"",
dtypes=(torch.float64, torch.float32, torch.bfloat16, torch.float16),
), # non-deterministic
}
@@ -866,16 +861,7 @@ def forward(self, scores_1, mask_1, value_1):
assert len(real_out) == len(decomp_out)
if do_relative_check:
device_arg = kwargs.get("device", None)
def upcast(x):
if (isinstance(x, Tensor) and x.device.type == "mps") or (
device_arg and torch.device(device_arg).type == "mps"
):
return upcast_tensor(x, dtype=torch.float32)
else:
return upcast_tensor(x, dtype=torch.float64)
upcast = partial(upcast_tensor, dtype=torch.float64)
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)

View File

@@ -8530,6 +8530,14 @@ BACKWARD_SKIPS_AND_XFAILS = [
COMPILE_FORWARD_SKIPS_AND_XFAILS = [
*FORWARD_SKIPS_AND_XFAILS,
# Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
# e.g. Expected 5 == 4
XFailRule(
error_type=AssertionError,
op_match_fn=lambda device, op: (op.full_name == "fill"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="fill_aot_autograd_bug_with_transposed_input",
),
# Bug: cross-device conversions with to() result in new nested ints within compile only
XFailRule(
error_type=AssertionError,
@@ -8573,6 +8581,12 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug1",
),
# Bug: also no idea what's going on here: needs investigation within AOTAutograd
XFailRule(
op_match_fn=lambda device, op: (op.full_name == "isreal"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug2",
),
]
COMPILE_BACKWARD_SKIPS_AND_XFAILS = [

View File

@@ -2294,6 +2294,7 @@ class TestRefsOpsInfo(TestCase):
"_refs.empty_strided",
"_refs.equal",
"_refs.full",
"_refs.full_like",
"_refs.is_complex",
"_refs.to",
"_refs.mvlgamma",
@@ -2408,6 +2409,7 @@ class TestRefsOpsInfo(TestCase):
"_refs.unflatten",
"_refs.sum_to_size",
# ref implementation missing kwargs
"_refs.full_like", # missing "layout"
"_refs.scalar_tensor", # missing "layout"
# other
"_refs.block_diag", # only refs._block_diag_iterable is in decomposition table

View File

@@ -346,7 +346,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.floor_divide,
aten.frac,
aten.frac_,
aten.full_like,
aten._fused_moving_avg_obs_fq_helper,
aten.gelu_,
aten.gelu_backward,

View File

@@ -625,6 +625,28 @@ def randn_like(
).to(memory_format=get_like_layout(self, memory_format))
@register_decomposition(aten.full_like)
def full_like(
self: torch.Tensor,
fill_value: Union[int, float],
*,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
device: Optional[torch.device] = None,
pin_memory: bool = False,
requires_grad: bool = False,
memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
return torch.full(
[*self.size()],
fill_value,
dtype=dtype or self.dtype,
layout=layout or self.layout,
device=device or self.device,
requires_grad=requires_grad,
).to(memory_format=get_like_layout(self, memory_format))
@register_decomposition(aten.randint_like.default)
def randint_like(
self: torch.Tensor,

View File

@@ -3177,6 +3177,7 @@ def _full(fill_value, device, dtype, size):
)
@register_lowering(aten.full_like, type_promotion_kind=None)
def full_like(x, fill_value, **kwargs):
return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
@@ -6120,17 +6121,6 @@ def fill_(x, fill_value):
return mutate_to(x, full_like(x, fill_value))
@register_lowering(prims.fill, type_promotion_kind=None)
def prims_fill(x, fill_value):
dtype = x.get_dtype()
return Pointwise.create(
device=x.get_device(),
dtype=dtype,
inner_fn=lambda _: ops.constant(fill_value, dtype),
ranges=list(x.get_size()),
)
@register_lowering(aten.copy_, type_promotion_kind=None)
def copy_(dst, src, non_blocking=False):
if dst is src:

View File

@@ -5588,26 +5588,9 @@ def full(
pin_memory=pin_memory,
requires_grad=requires_grad,
)
return prims.fill(e, fill_value) # type: ignore[arg-type]
return torch.fill(e, fill_value) # type: ignore[arg-type]
def _get_shape_permutation_like(
a: TensorLikeType, layout: torch.layout
) -> tuple[ShapeType, StrideType]:
assert layout == torch.strided
physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(a)
shape = [a.shape[l] for l in physical_layout]
permutation = [0] * len(shape)
for p, l in enumerate(physical_layout):
permutation[l] = p
return (shape, permutation)
@register_decomposition(aten.full_like)
@out_wrapper()
def full_like(
a: TensorLikeType,
fill_value: NumberType,
@@ -5619,36 +5602,16 @@ def full_like(
requires_grad: bool = False,
memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
dtype = a.dtype if dtype is None else dtype
layout = a.layout if layout is None else layout
device = a.device if device is None else device
if memory_format != torch.preserve_format:
result = torch.full(
a.shape,
fill_value,
e = torch.empty_like(
a,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
memory_format=memory_format,
)
return result.to(memory_format=memory_format)
else:
shape, permutation = _get_shape_permutation_like(a, layout)
result = torch.full(
shape,
fill_value,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
)
if permutation == list(range(len(permutation))):
return result
return result.permute(permutation).clone()
return fill(e, fill_value)
@register_decomposition(aten.zeros_like)
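
The removed _get_shape_permutation_like path preserved strides by building the tensor in physical (stride-descending) order and permuting back to the logical order. A standalone sketch of that idea, with a hypothetical helper and a simplified stride sort in place of the internal permutation utility:

import torch

def shape_and_permutation(a):
    # Physical order: dimensions sorted from largest to smallest stride.
    physical = sorted(range(a.dim()), key=lambda d: a.stride(d), reverse=True)
    shape = [a.shape[d] for d in physical]
    permutation = [0] * len(shape)
    for p, l in enumerate(physical):
        permutation[l] = p
    return shape, permutation

a = torch.randn(4, 5, 6).transpose(1, -1)     # shape (4, 6, 5), strides (30, 1, 6)
shape, perm = shape_and_permutation(a)        # shape [4, 5, 6], perm [0, 2, 1]
result = torch.full(shape, 3.0).permute(perm)
print(result.shape, result.stride())          # expected: torch.Size([4, 6, 5]) (30, 1, 6)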

View File

@@ -1923,7 +1923,7 @@ def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
def get_val(dtype):
return make_tensor([], dtype=dtype, device="cpu").item()
double_dtype = torch.double if torch.device(device).type != "mps" else torch.float
double_dtype = torch.double if device != "mps:0" else torch.float
inputs = [
((), get_val(dtype), {}),
((S, S), get_val(dtype), {}),
@@ -24603,10 +24603,6 @@ python_ref_db = [
DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
),
),
PythonRefInfo(
"_refs.full_like",
torch_opinfo_name="full_like",
),
PythonRefInfo(
"_refs.randn",
torch_opinfo_name="randn",