Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix full_like decomposition to preserve strides (#144765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144765
Approved by: https://github.com/amjames, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 6401d1d53d
Commit: 01b0f09931
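A rough sketch of the user-visible effect of this change (assumes a build that includes this patch; `fn` mirrors the new test added below). In eager mode, torch.full_like already preserves the strides of a dense, non-contiguous input, and the decomposition used by export/Inductor is now expected to match that:

    import torch

    def fn(a):
        return torch.full_like(a, 3)

    a = torch.randn(4, 5, 6).transpose(1, -1)   # non-contiguous, strides (30, 1, 6)
    eager = fn(a)
    compiled = torch.compile(fn)(a)
    # With this patch, all three are expected to print (30, 1, 6).
    print(a.stride(), eager.stride(), compiled.stride())
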
@@ -823,8 +823,6 @@ aten::from_file
aten::from_file.out
aten::full.names
aten::full.names_out
aten::full_like
aten::full_like.out
aten::gather
aten::gather.out
aten::geqrf

@@ -52,8 +52,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
div = torch.ops.aten.div.Scalar(neg, 1); neg = None
full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
div_1 = torch.ops.aten.div.Scalar(full, 1); full = None
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None

@@ -98,8 +98,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
div = torch.ops.aten.div.Scalar(neg, 1); neg = None
full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
div_1 = torch.ops.aten.div.Scalar(full, 1); full = None
neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None

@@ -432,6 +432,8 @@ def check_model(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
# TODO: enable this for all tests
exact_stride=False,
):
kwargs = kwargs or {}
torch._dynamo.reset()

@@ -544,6 +546,7 @@ def check_model(
rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
exact_stride=exact_stride,
)
# In case of input mutations, check that inputs are the same
self.assertEqual(

@@ -554,6 +557,7 @@ def check_model(
equal_nan=True,
# our testing sometimes uses higher precision inputs for the reference
exact_dtype=False,
exact_stride=exact_stride,
)
else:
for correct_val, actual_val in zip(correct_flat, actual_flat):

@@ -567,6 +571,8 @@ def check_model(
assert correct_val.layout == actual_val.layout
if exact_dtype:
assert correct_val.dtype == actual_val.dtype
if exact_stride:
assert correct_val.stride() == actual_val.stride()

if check_gradient:
actual = output_process_fn_grad(actual)

@@ -620,6 +626,7 @@ def check_model(
rtol=grad_rtol or rtol,
equal_nan=True,
exact_dtype=exact_dtype,
exact_stride=exact_stride,
)

torch._dynamo.reset()

@@ -645,6 +652,8 @@ def check_model_gpu(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
# TODO: enable this for all tests
exact_stride=False,
):
kwargs = kwargs or {}
if hasattr(model, "to"):

@@ -671,6 +680,7 @@ def check_model_gpu(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
exact_stride=exact_stride,
)

if check_lowp:

@@ -703,6 +713,7 @@ def check_model_gpu(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
exact_stride=exact_stride,
)

@@ -6960,6 +6971,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar

self.common(fn, (torch.randn(8),))

def test_full_like_stride(self):
def fn(a):
return torch.full_like(a, 3)

self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)

def test_full_truncation(self):
def fn(a):
return a + torch.full_like(a, 7.777)

@@ -545,6 +545,11 @@ comprehensive_failures = {
xfail(
"nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
), # off by one error
skip(
"nn.functional.nll_loss",
"",
dtypes=(torch.float64, torch.float32, torch.bfloat16, torch.float16),
), # non-deterministic
}

@@ -861,7 +866,16 @@ def forward(self, scores_1, mask_1, value_1):
assert len(real_out) == len(decomp_out)

if do_relative_check:
upcast = partial(upcast_tensor, dtype=torch.float64)
device_arg = kwargs.get("device", None)

def upcast(x):
if (isinstance(x, Tensor) and x.device.type == "mps") or (
device_arg and torch.device(device_arg).type == "mps"
):
return upcast_tensor(x, dtype=torch.float32)
else:
return upcast_tensor(x, dtype=torch.float64)

real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)

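The branch above caps the reference computation at float32 on MPS because the MPS backend has no float64 support; everywhere else the reference is still upcast to float64. A hypothetical one-line helper expressing the same rule (not part of the patch):

    import torch

    def reference_dtype(device: torch.device) -> torch.dtype:
        # MPS cannot allocate float64 tensors, so cap the reference precision there.
        return torch.float32 if device.type == "mps" else torch.float64
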
@@ -8530,14 +8530,6 @@ BACKWARD_SKIPS_AND_XFAILS = [

COMPILE_FORWARD_SKIPS_AND_XFAILS = [
*FORWARD_SKIPS_AND_XFAILS,
# Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
# e.g. Expected 5 == 4
XFailRule(
error_type=AssertionError,
op_match_fn=lambda device, op: (op.full_name == "fill"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="fill_aot_autograd_bug_with_transposed_input",
),
# Bug: cross-device conversions with to() result in new nested ints within compile only
XFailRule(
error_type=AssertionError,

@@ -8581,12 +8573,6 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug1",
),
# Bug: also no idea what's going on here: needs investigation within AOTAutograd
XFailRule(
op_match_fn=lambda device, op: (op.full_name == "isreal"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug2",
),
]

COMPILE_BACKWARD_SKIPS_AND_XFAILS = [

@@ -2294,7 +2294,6 @@ class TestRefsOpsInfo(TestCase):
"_refs.empty_strided",
"_refs.equal",
"_refs.full",
"_refs.full_like",
"_refs.is_complex",
"_refs.to",
"_refs.mvlgamma",

@@ -2409,7 +2408,6 @@ class TestRefsOpsInfo(TestCase):
"_refs.unflatten",
"_refs.sum_to_size",
# ref implementation missing kwargs
"_refs.full_like", # missing "layout"
"_refs.scalar_tensor", # missing "layout"
# other
"_refs.block_diag", # only refs._block_diag_iterable is in decomposition table

@@ -346,6 +346,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.floor_divide,
aten.frac,
aten.frac_,
aten.full_like,
aten._fused_moving_avg_obs_fq_helper,
aten.gelu_,
aten.gelu_backward,

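The line added above registers aten.full_like in the core ATen (post-autograd) decomposition table. A hedged way to check the registration from Python (torch._decomp.core_aten_decompositions is the public accessor; treating the .default overload as the table key is an assumption):

    import torch
    from torch._decomp import core_aten_decompositions

    table = core_aten_decompositions()
    print(torch.ops.aten.full_like.default in table)  # expected True with this patch
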
@@ -625,28 +625,6 @@ def randn_like(
).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.full_like)
def full_like(
self: torch.Tensor,
fill_value: Union[int, float],
*,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
device: Optional[torch.device] = None,
pin_memory: bool = False,
requires_grad: bool = False,
memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
return torch.full(
[*self.size()],
fill_value,
dtype=dtype or self.dtype,
layout=layout or self.layout,
device=device or self.device,
requires_grad=requires_grad,
).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.default)
def randint_like(
self: torch.Tensor,

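The Inductor decomposition removed above built the result with torch.full and then applied a memory-format conversion. A small illustration (hypothetical repro, not part of the patch) of why that path cannot reproduce the input's layout: torch.full always returns a contiguous tensor, and preserve_format then keeps that contiguous layout rather than the layout of the source tensor:

    import torch

    src = torch.randn(4, 5, 6).transpose(1, -1)   # strides (30, 1, 6)
    via_full = torch.full(src.shape, 3.0).to(memory_format=torch.preserve_format)
    print(src.stride())       # (30, 1, 6)
    print(via_full.stride())  # (30, 5, 1) -- the source strides are lost
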
@@ -3177,7 +3177,6 @@ def _full(fill_value, device, dtype, size):
)


@register_lowering(aten.full_like, type_promotion_kind=None)
def full_like(x, fill_value, **kwargs):
return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)

@@ -6121,6 +6120,17 @@ def fill_(x, fill_value):
return mutate_to(x, full_like(x, fill_value))


@register_lowering(prims.fill, type_promotion_kind=None)
def prims_fill(x, fill_value):
dtype = x.get_dtype()
return Pointwise.create(
device=x.get_device(),
dtype=dtype,
inner_fn=lambda _: ops.constant(fill_value, dtype),
ranges=list(x.get_size()),
)


@register_lowering(aten.copy_, type_promotion_kind=None)
def copy_(dst, src, non_blocking=False):
if dst is src:

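The new lowering above handles prims.fill, which the full() hunk below switches to in place of the torch.fill call. The prim takes a tensor and a scalar and returns a new tensor of the same shape and dtype filled with that value; a minimal sketch of calling it directly (assuming the same torch._prims import convention the reference code uses):

    import torch
    import torch._prims as prims

    t = torch.empty(2, 3, dtype=torch.float32)
    out = prims.fill(t, 1.0)   # new (2, 3) float32 tensor, every element 1.0
    print(out)
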
@@ -5588,9 +5588,26 @@ def full(
pin_memory=pin_memory,
requires_grad=requires_grad,
)
return torch.fill(e, fill_value) # type: ignore[arg-type]
return prims.fill(e, fill_value) # type: ignore[arg-type]


def _get_shape_permutation_like(
a: TensorLikeType, layout: torch.layout
) -> tuple[ShapeType, StrideType]:
assert layout == torch.strided

physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(a)
shape = [a.shape[l] for l in physical_layout]

permutation = [0] * len(shape)
for p, l in enumerate(physical_layout):
permutation[l] = p

return (shape, permutation)


@register_decomposition(aten.full_like)
@out_wrapper()
def full_like(
a: TensorLikeType,
fill_value: NumberType,

@@ -5602,16 +5619,36 @@ def full_like(
requires_grad: bool = False,
memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
e = torch.empty_like(
a,
dtype = a.dtype if dtype is None else dtype
layout = a.layout if layout is None else layout
device = a.device if device is None else device

if memory_format != torch.preserve_format:
result = torch.full(
a.shape,
fill_value,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
memory_format=memory_format,
)
return fill(e, fill_value)
return result.to(memory_format=memory_format)

else:
shape, permutation = _get_shape_permutation_like(a, layout)
result = torch.full(
shape,
fill_value,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
)
if permutation == list(range(len(permutation))):
return result
return result.permute(permutation).clone()


@register_decomposition(aten.zeros_like)

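A standalone sketch of the stride-preservation trick in the preserve_format branch above. It is simplified (it orders dimensions by decreasing stride instead of calling utils.compute_elementwise_output_logical_to_physical_perm, and it ignores ties), so treat it as illustrative rather than equivalent:

    import torch

    src = torch.randn(4, 5, 6).transpose(1, -1)    # shape (4, 6, 5), strides (30, 1, 6)

    # Logical dims ordered from largest to smallest stride ("physical" order).
    physical = sorted(range(src.dim()), key=lambda d: -src.stride(d))   # [0, 2, 1]
    shape = [src.shape[d] for d in physical]                            # [4, 5, 6]
    perm = [0] * src.dim()
    for p, l in enumerate(physical):
        perm[l] = p                                                     # inverse permutation

    # Build contiguously in physical order, then permute back to logical order.
    result = torch.full(shape, 3.0).permute(perm).clone()
    assert result.stride() == src.stride()                              # (30, 1, 6)
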
@@ -1923,7 +1923,7 @@ def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
def get_val(dtype):
return make_tensor([], dtype=dtype, device="cpu").item()

double_dtype = torch.double if device != "mps:0" else torch.float
double_dtype = torch.double if torch.device(device).type != "mps" else torch.float
inputs = [
((), get_val(dtype), {}),
((S, S), get_val(dtype), {}),

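The one-line change above generalizes the MPS check: the old comparison only matched the literal device string "mps:0", while torch.device(device).type matches any MPS device string. For example:

    import torch

    print(torch.device("mps:0").type)  # "mps"
    print(torch.device("mps").type)    # "mps"
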
@@ -24603,6 +24603,10 @@ python_ref_db = [
DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
),
),
PythonRefInfo(
"_refs.full_like",
torch_opinfo_name="full_like",
),
PythonRefInfo(
"_refs.randn",
torch_opinfo_name="randn",