Fix full_like decomposition to preserve strides (#144765)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144765
Approved by: https://github.com/amjames, https://github.com/jansel
This commit is contained in:
Isuru Fernando
2025-07-01 15:07:32 +00:00
committed by PyTorch MergeBot
parent 6401d1d53d
commit 01b0f09931
11 changed files with 101 additions and 58 deletions

View File

@ -545,6 +545,11 @@ comprehensive_failures = {
xfail(
"nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
), # off by one error
skip(
"nn.functional.nll_loss",
"",
dtypes=(torch.float64, torch.float32, torch.bfloat16, torch.float16),
), # non-deterministic
}
@ -861,7 +866,16 @@ def forward(self, scores_1, mask_1, value_1):
assert len(real_out) == len(decomp_out)
if do_relative_check:
upcast = partial(upcast_tensor, dtype=torch.float64)
device_arg = kwargs.get("device", None)
def upcast(x):
if (isinstance(x, Tensor) and x.device.type == "mps") or (
device_arg and torch.device(device_arg).type == "mps"
):
return upcast_tensor(x, dtype=torch.float32)
else:
return upcast_tensor(x, dtype=torch.float64)
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)