mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix full_like decomposition to preserve strides (#144765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144765 Approved by: https://github.com/amjames, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
6401d1d53d
commit
01b0f09931
@ -545,6 +545,11 @@ comprehensive_failures = {
|
||||
xfail(
|
||||
"nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
|
||||
), # off by one error
|
||||
skip(
|
||||
"nn.functional.nll_loss",
|
||||
"",
|
||||
dtypes=(torch.float64, torch.float32, torch.bfloat16, torch.float16),
|
||||
), # non-deterministic
|
||||
}
|
||||
|
||||
|
||||
@ -861,7 +866,16 @@ def forward(self, scores_1, mask_1, value_1):
|
||||
assert len(real_out) == len(decomp_out)
|
||||
|
||||
if do_relative_check:
|
||||
upcast = partial(upcast_tensor, dtype=torch.float64)
|
||||
device_arg = kwargs.get("device", None)
|
||||
|
||||
def upcast(x):
|
||||
if (isinstance(x, Tensor) and x.device.type == "mps") or (
|
||||
device_arg and torch.device(device_arg).type == "mps"
|
||||
):
|
||||
return upcast_tensor(x, dtype=torch.float32)
|
||||
else:
|
||||
return upcast_tensor(x, dtype=torch.float64)
|
||||
|
||||
real_out_double, _ = tree_flatten(
|
||||
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
||||
)
|
||||
|
Reference in New Issue
Block a user