mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix full_like decomposition to preserve strides (#158898)
Summary: See original PR at: https://github.com/pytorch/pytorch/pull/144765, which landed internally but was reverted due to test failures. Addressing reviewer comments and trying again. Rollback Plan: Differential hack Revision: D78783627 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158898 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
28ee8be5bf
commit
fa0355c18d
@ -432,6 +432,8 @@ def check_model(
|
||||
check_gradient=False,
|
||||
check_has_compiled=True,
|
||||
output_process_fn_grad=lambda x: x,
|
||||
# TODO: enable this for all tests
|
||||
exact_stride=False,
|
||||
):
|
||||
kwargs = kwargs or {}
|
||||
torch._dynamo.reset()
|
||||
@ -465,7 +467,12 @@ def check_model(
|
||||
x.dtype == torch.float16 or x.dtype == torch.bfloat16
|
||||
):
|
||||
has_lowp_args = True
|
||||
return x.float()
|
||||
# Preserve strides when casting
|
||||
result = torch.empty_strided(
|
||||
x.size(), x.stride(), device=x.device, dtype=torch.float
|
||||
)
|
||||
result.copy_(x)
|
||||
return result
|
||||
else:
|
||||
return x
|
||||
|
||||
@ -555,6 +562,7 @@ def check_model(
|
||||
rtol=rtol,
|
||||
equal_nan=True,
|
||||
exact_dtype=exact_dtype,
|
||||
exact_stride=exact_stride,
|
||||
)
|
||||
# In case of input mutations, check that inputs are the same
|
||||
# (This never uses a custom assert_equal fn.)
|
||||
@ -566,6 +574,7 @@ def check_model(
|
||||
equal_nan=True,
|
||||
# our testing sometimes uses higher precision inputs for the reference
|
||||
exact_dtype=False,
|
||||
exact_stride=exact_stride,
|
||||
)
|
||||
else:
|
||||
for correct_val, actual_val in zip(correct_flat, actual_flat):
|
||||
@ -579,6 +588,8 @@ def check_model(
|
||||
assert correct_val.layout == actual_val.layout
|
||||
if exact_dtype:
|
||||
assert correct_val.dtype == actual_val.dtype
|
||||
if exact_stride:
|
||||
assert correct_val.stride() == actual_val.stride()
|
||||
|
||||
if check_gradient:
|
||||
actual = output_process_fn_grad(actual)
|
||||
@ -632,6 +643,7 @@ def check_model(
|
||||
rtol=grad_rtol or rtol,
|
||||
equal_nan=True,
|
||||
exact_dtype=exact_dtype,
|
||||
exact_stride=exact_stride,
|
||||
)
|
||||
|
||||
torch._dynamo.reset()
|
||||
@ -657,6 +669,8 @@ def check_model_gpu(
|
||||
check_gradient=False,
|
||||
check_has_compiled=True,
|
||||
output_process_fn_grad=lambda x: x,
|
||||
# TODO: enable this for all tests
|
||||
exact_stride=False,
|
||||
):
|
||||
kwargs = kwargs or {}
|
||||
if hasattr(model, "to"):
|
||||
@ -683,6 +697,7 @@ def check_model_gpu(
|
||||
check_gradient=check_gradient,
|
||||
check_has_compiled=check_has_compiled,
|
||||
output_process_fn_grad=output_process_fn_grad,
|
||||
exact_stride=exact_stride,
|
||||
)
|
||||
|
||||
if check_lowp:
|
||||
@ -715,6 +730,7 @@ def check_model_gpu(
|
||||
check_gradient=check_gradient,
|
||||
check_has_compiled=check_has_compiled,
|
||||
output_process_fn_grad=output_process_fn_grad,
|
||||
exact_stride=exact_stride,
|
||||
)
|
||||
|
||||
|
||||
@ -6979,6 +6995,18 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
|
||||
self.common(fn, (torch.randn(8),))
|
||||
|
||||
def test_full_like_transposed(self):
|
||||
def fn(a):
|
||||
return torch.full_like(a, 3)
|
||||
|
||||
self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
|
||||
|
||||
def test_full_like_sliced(self):
|
||||
def fn(a):
|
||||
return torch.full_like(a, 3)
|
||||
|
||||
self.common(fn, (torch.rand(3, 4)[:, ::2],), exact_stride=True)
|
||||
|
||||
def test_full_truncation(self):
|
||||
def fn(a):
|
||||
return a + torch.full_like(a, 7.777)
|
||||
|
@ -985,7 +985,15 @@ def get_sort_argsort_assert_equal_fn(is_argsort, args, kwargs):
|
||||
dim = kwargs["dim"]
|
||||
|
||||
def argsort_sort_assert_equal(
|
||||
test_case_inst, x, y, *, atol=None, rtol=None, equal_nan=True, exact_dtype=True
|
||||
test_case_inst,
|
||||
x,
|
||||
y,
|
||||
*,
|
||||
atol=None,
|
||||
rtol=None,
|
||||
equal_nan=True,
|
||||
exact_dtype=True,
|
||||
exact_stride=False,
|
||||
):
|
||||
if is_argsort:
|
||||
assert isinstance(x, torch.Tensor)
|
||||
@ -1004,6 +1012,7 @@ def get_sort_argsort_assert_equal_fn(is_argsort, args, kwargs):
|
||||
rtol=rtol,
|
||||
equal_nan=equal_nan,
|
||||
exact_dtype=exact_dtype,
|
||||
exact_stride=exact_stride,
|
||||
)
|
||||
|
||||
# The second tensor is the same result as an argsort.
|
||||
@ -1015,6 +1024,11 @@ def get_sort_argsort_assert_equal_fn(is_argsort, args, kwargs):
|
||||
|
||||
assert x.shape == y.shape
|
||||
|
||||
if exact_stride and (x.stride() != y.stride()):
|
||||
raise AssertionError(
|
||||
f"The strides do not match: {x.stride()} != {y.stride()}."
|
||||
)
|
||||
|
||||
def el_to_indices(el):
|
||||
"""Turn an element number into a list of indices"""
|
||||
indices = [None] * x.dim()
|
||||
|
@ -861,7 +861,16 @@ def forward(self, scores_1, mask_1, value_1):
|
||||
assert len(real_out) == len(decomp_out)
|
||||
|
||||
if do_relative_check:
|
||||
upcast = partial(upcast_tensor, dtype=torch.float64)
|
||||
device_arg = kwargs.get("device", None)
|
||||
|
||||
def upcast(x):
|
||||
if (isinstance(x, Tensor) and x.device.type == "mps") or (
|
||||
device_arg and torch.device(device_arg).type == "mps"
|
||||
):
|
||||
return upcast_tensor(x, dtype=torch.float32)
|
||||
else:
|
||||
return upcast_tensor(x, dtype=torch.float64)
|
||||
|
||||
real_out_double, _ = tree_flatten(
|
||||
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
||||
)
|
||||
|
@ -8536,14 +8536,6 @@ BACKWARD_SKIPS_AND_XFAILS = [
|
||||
|
||||
COMPILE_FORWARD_SKIPS_AND_XFAILS = [
|
||||
*FORWARD_SKIPS_AND_XFAILS,
|
||||
# Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
|
||||
# e.g. Expected 5 == 4
|
||||
XFailRule(
|
||||
error_type=AssertionError,
|
||||
op_match_fn=lambda device, op: (op.full_name == "fill"),
|
||||
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
|
||||
name="fill_aot_autograd_bug_with_transposed_input",
|
||||
),
|
||||
# Bug: cross-device conversions with to() result in new nested ints within compile only
|
||||
XFailRule(
|
||||
error_type=AssertionError,
|
||||
@ -8587,12 +8579,6 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
|
||||
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
|
||||
name="crazy_aot_autograd_bug1",
|
||||
),
|
||||
# Bug: also no idea what's going on here: needs investigation within AOTAutograd
|
||||
XFailRule(
|
||||
op_match_fn=lambda device, op: (op.full_name == "isreal"),
|
||||
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
|
||||
name="crazy_aot_autograd_bug2",
|
||||
),
|
||||
]
|
||||
|
||||
COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
|
||||
|
@ -625,6 +625,21 @@ def randn_like(
|
||||
).to(memory_format=get_like_layout(self, memory_format))
|
||||
|
||||
|
||||
def _get_shape_permutation_like(
|
||||
self: torch.Tensor, layout: torch.layout
|
||||
) -> tuple[utils.ShapeType, utils.StrideType]:
|
||||
assert layout == torch.strided
|
||||
|
||||
physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
|
||||
shape = [self.shape[l] for l in physical_layout]
|
||||
|
||||
permutation = [0] * len(shape)
|
||||
for p, l in enumerate(physical_layout):
|
||||
permutation[l] = p
|
||||
|
||||
return (shape, permutation)
|
||||
|
||||
|
||||
@register_decomposition(aten.full_like)
|
||||
def full_like(
|
||||
self: torch.Tensor,
|
||||
@ -637,14 +652,36 @@ def full_like(
|
||||
requires_grad: bool = False,
|
||||
memory_format: torch.memory_format = torch.preserve_format,
|
||||
) -> torch.Tensor:
|
||||
return torch.full(
|
||||
[*self.size()],
|
||||
fill_value,
|
||||
dtype=dtype or self.dtype,
|
||||
layout=layout or self.layout,
|
||||
device=device or self.device,
|
||||
requires_grad=requires_grad,
|
||||
).to(memory_format=get_like_layout(self, memory_format))
|
||||
dtype = self.dtype if dtype is None else dtype
|
||||
layout = self.layout if layout is None else layout
|
||||
device = self.device if device is None else device
|
||||
|
||||
if memory_format != torch.preserve_format:
|
||||
result = torch.full(
|
||||
self.shape,
|
||||
fill_value,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
return result.to(memory_format=memory_format)
|
||||
|
||||
else:
|
||||
shape, permutation = _get_shape_permutation_like(self, layout)
|
||||
result = torch.full(
|
||||
shape,
|
||||
fill_value,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
if permutation == list(range(len(permutation))):
|
||||
return result
|
||||
return result.permute(permutation).clone()
|
||||
|
||||
|
||||
@register_decomposition(aten.randint_like.default)
|
||||
|
@ -3217,7 +3217,6 @@ def _full(fill_value, device, dtype, size):
|
||||
)
|
||||
|
||||
|
||||
@register_lowering(aten.full_like, type_promotion_kind=None)
|
||||
def full_like(x, fill_value, **kwargs):
|
||||
return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user