Fix full_like decomposition to preserve strides (#144765)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144765
Approved by: https://github.com/amjames, https://github.com/jansel

Committed by: PyTorch MergeBot
Parent: 6401d1d53d
Commit: 01b0f09931

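Note (context added here, not part of the diff): in eager mode, torch.full_like with the
default memory_format=torch.preserve_format follows the strides of its input, and this
change makes the decomposed path used by export/Inductor match that behavior. A minimal
sketch, with illustrative shapes and strides:

import torch

# A non-contiguous (transposed) input; the strides shown are illustrative.
a = torch.randn(4, 5, 6).transpose(1, -1)
print(a.stride())    # (30, 1, 6)

# Eager full_like preserves those strides under torch.preserve_format; after this
# fix the decomposition is expected to produce the same strides.
out = torch.full_like(a, 3)
print(out.stride())  # matches a.stride()
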
@@ -823,8 +823,6 @@ aten::from_file
 aten::from_file.out
 aten::full.names
 aten::full.names_out
-aten::full_like
-aten::full_like.out
 aten::gather
 aten::gather.out
 aten::geqrf

@@ -52,8 +52,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
     sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
     neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
     div = torch.ops.aten.div.Scalar(neg, 1); neg = None
-    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
-    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
+    full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    div_1 = torch.ops.aten.div.Scalar(full, 1); full = None
     neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
     expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
     mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None

@@ -98,8 +98,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
     sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None
     neg = torch.ops.aten.neg.default(sum_1); sum_1 = None
     div = torch.ops.aten.div.Scalar(neg, 1); neg = None
-    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
-    div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None
+    full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    div_1 = torch.ops.aten.div.Scalar(full, 1); full = None
     neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None
     expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
     mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None

@@ -432,6 +432,8 @@ def check_model(
     check_gradient=False,
     check_has_compiled=True,
     output_process_fn_grad=lambda x: x,
+    # TODO: enable this for all tests
+    exact_stride=False,
 ):
     kwargs = kwargs or {}
     torch._dynamo.reset()

@@ -544,6 +546,7 @@ def check_model(
             rtol=rtol,
             equal_nan=True,
             exact_dtype=exact_dtype,
+            exact_stride=exact_stride,
         )
         # In case of input mutations, check that inputs are the same
         self.assertEqual(

@@ -554,6 +557,7 @@ def check_model(
             equal_nan=True,
             # our testing sometimes uses higher precision inputs for the reference
             exact_dtype=False,
+            exact_stride=exact_stride,
         )
     else:
         for correct_val, actual_val in zip(correct_flat, actual_flat):

@@ -567,6 +571,8 @@ def check_model(
                 assert correct_val.layout == actual_val.layout
                 if exact_dtype:
                     assert correct_val.dtype == actual_val.dtype
+                if exact_stride:
+                    assert correct_val.stride() == actual_val.stride()

     if check_gradient:
         actual = output_process_fn_grad(actual)

@@ -620,6 +626,7 @@ def check_model(
                 rtol=grad_rtol or rtol,
                 equal_nan=True,
                 exact_dtype=exact_dtype,
+                exact_stride=exact_stride,
             )

     torch._dynamo.reset()

@@ -645,6 +652,8 @@ def check_model_gpu(
     check_gradient=False,
     check_has_compiled=True,
     output_process_fn_grad=lambda x: x,
+    # TODO: enable this for all tests
+    exact_stride=False,
 ):
     kwargs = kwargs or {}
     if hasattr(model, "to"):

@@ -671,6 +680,7 @@ def check_model_gpu(
        check_gradient=check_gradient,
        check_has_compiled=check_has_compiled,
        output_process_fn_grad=output_process_fn_grad,
+       exact_stride=exact_stride,
    )

    if check_lowp:

@@ -703,6 +713,7 @@ def check_model_gpu(
            check_gradient=check_gradient,
            check_has_compiled=check_has_compiled,
            output_process_fn_grad=output_process_fn_grad,
+           exact_stride=exact_stride,
        )


@@ -6960,6 +6971,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar

         self.common(fn, (torch.randn(8),))

+    def test_full_like_stride(self):
+        def fn(a):
+            return torch.full_like(a, 3)
+
+        self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
+
     def test_full_truncation(self):
         def fn(a):
             return a + torch.full_like(a, 7.777)

@@ -545,6 +545,11 @@ comprehensive_failures = {
     xfail(
         "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
     ),  # off by one error
+    skip(
+        "nn.functional.nll_loss",
+        "",
+        dtypes=(torch.float64, torch.float32, torch.bfloat16, torch.float16),
+    ),  # non-deterministic
 }


|
|||||||
assert len(real_out) == len(decomp_out)
|
assert len(real_out) == len(decomp_out)
|
||||||
|
|
||||||
if do_relative_check:
|
if do_relative_check:
|
||||||
upcast = partial(upcast_tensor, dtype=torch.float64)
|
device_arg = kwargs.get("device", None)
|
||||||
|
|
||||||
|
def upcast(x):
|
||||||
|
if (isinstance(x, Tensor) and x.device.type == "mps") or (
|
||||||
|
device_arg and torch.device(device_arg).type == "mps"
|
||||||
|
):
|
||||||
|
return upcast_tensor(x, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
return upcast_tensor(x, dtype=torch.float64)
|
||||||
|
|
||||||
real_out_double, _ = tree_flatten(
|
real_out_double, _ = tree_flatten(
|
||||||
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
||||||
)
|
)
|
||||||
|
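Note (illustration, not part of the diff): the branch above exists because the MPS backend
has no float64 support, so the reference computation upcasts to float32 on MPS and to
float64 everywhere else. A small self-contained sketch of the same dtype choice (the
function name is hypothetical):

import torch

def reference_upcast_dtype(x: torch.Tensor, device_arg=None) -> torch.dtype:
    # float32 on MPS (which lacks float64), float64 otherwise -- mirroring the test logic.
    if x.device.type == "mps" or (
        device_arg is not None and torch.device(device_arg).type == "mps"
    ):
        return torch.float32
    return torch.float64

print(reference_upcast_dtype(torch.randn(3)))  # torch.float64 on CPU
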
@@ -8530,14 +8530,6 @@ BACKWARD_SKIPS_AND_XFAILS = [

 COMPILE_FORWARD_SKIPS_AND_XFAILS = [
     *FORWARD_SKIPS_AND_XFAILS,
-    # Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
-    # e.g. Expected 5 == 4
-    XFailRule(
-        error_type=AssertionError,
-        op_match_fn=lambda device, op: (op.full_name == "fill"),
-        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
-        name="fill_aot_autograd_bug_with_transposed_input",
-    ),
     # Bug: cross-device conversions with to() result in new nested ints within compile only
     XFailRule(
         error_type=AssertionError,

@@ -8581,12 +8573,6 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
         name="crazy_aot_autograd_bug1",
     ),
-    # Bug: also no idea what's going on here: needs investigation within AOTAutograd
-    XFailRule(
-        op_match_fn=lambda device, op: (op.full_name == "isreal"),
-        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
-        name="crazy_aot_autograd_bug2",
-    ),
 ]

 COMPILE_BACKWARD_SKIPS_AND_XFAILS = [

@@ -2294,7 +2294,6 @@ class TestRefsOpsInfo(TestCase):
         "_refs.empty_strided",
         "_refs.equal",
         "_refs.full",
-        "_refs.full_like",
         "_refs.is_complex",
         "_refs.to",
         "_refs.mvlgamma",

@@ -2409,7 +2408,6 @@ class TestRefsOpsInfo(TestCase):
         "_refs.unflatten",
         "_refs.sum_to_size",
         # ref implementation missing kwargs
-        "_refs.full_like",  # missing "layout"
         "_refs.scalar_tensor",  # missing "layout"
         # other
         "_refs.block_diag",  # only refs._block_diag_iterable is in decomposition table

@@ -346,6 +346,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
         aten.floor_divide,
         aten.frac,
         aten.frac_,
+        aten.full_like,
         aten._fused_moving_avg_obs_fq_helper,
         aten.gelu_,
         aten.gelu_backward,

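Note (verification sketch, not part of the diff): with aten.full_like added to the core
ATen decomposition table, its presence can be checked roughly as follows:

import torch
from torch._decomp import core_aten_decompositions

# full_like should now appear in the core ATen decomposition table.
decomp_table = core_aten_decompositions()
print(torch.ops.aten.full_like.default in decomp_table)  # expected: True
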
@@ -625,28 +625,6 @@ def randn_like(
     ).to(memory_format=get_like_layout(self, memory_format))


-@register_decomposition(aten.full_like)
-def full_like(
-    self: torch.Tensor,
-    fill_value: Union[int, float],
-    *,
-    dtype: Optional[torch.dtype] = None,
-    layout: Optional[torch.layout] = None,
-    device: Optional[torch.device] = None,
-    pin_memory: bool = False,
-    requires_grad: bool = False,
-    memory_format: torch.memory_format = torch.preserve_format,
-) -> torch.Tensor:
-    return torch.full(
-        [*self.size()],
-        fill_value,
-        dtype=dtype or self.dtype,
-        layout=layout or self.layout,
-        device=device or self.device,
-        requires_grad=requires_grad,
-    ).to(memory_format=get_like_layout(self, memory_format))
-
-
 @register_decomposition(aten.randint_like.default)
 def randint_like(
     self: torch.Tensor,

@@ -3177,7 +3177,6 @@ def _full(fill_value, device, dtype, size):
     )


-@register_lowering(aten.full_like, type_promotion_kind=None)
 def full_like(x, fill_value, **kwargs):
     return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)


@@ -6121,6 +6120,17 @@ def fill_(x, fill_value):
     return mutate_to(x, full_like(x, fill_value))


+@register_lowering(prims.fill, type_promotion_kind=None)
+def prims_fill(x, fill_value):
+    dtype = x.get_dtype()
+    return Pointwise.create(
+        device=x.get_device(),
+        dtype=dtype,
+        inner_fn=lambda _: ops.constant(fill_value, dtype),
+        ranges=list(x.get_size()),
+    )
+
+
 @register_lowering(aten.copy_, type_promotion_kind=None)
 def copy_(dst, src, non_blocking=False):
     if dst is src:

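Note (context, not part of the diff): prims.fill(x, v) produces a tensor with x's shape,
dtype and device in which every element equals v; the lowering above expresses that as a
constant-valued pointwise kernel over x's size. A rough eager-level reference (the helper
name is hypothetical, for illustration only):

import torch

def fill_reference(x: torch.Tensor, fill_value) -> torch.Tensor:
    # Same semantics as the pointwise lowering: a constant tensor with x's metadata.
    return torch.full(x.shape, fill_value, dtype=x.dtype, device=x.device)

x = torch.empty(2, 3, dtype=torch.float32)
print(fill_reference(x, 7.0))
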
@@ -5588,9 +5588,26 @@ def full(
         pin_memory=pin_memory,
         requires_grad=requires_grad,
     )
-    return torch.fill(e, fill_value)  # type: ignore[arg-type]
+    return prims.fill(e, fill_value)  # type: ignore[arg-type]


+def _get_shape_permutation_like(
+    a: TensorLikeType, layout: torch.layout
+) -> tuple[ShapeType, StrideType]:
+    assert layout == torch.strided
+
+    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(a)
+    shape = [a.shape[l] for l in physical_layout]
+
+    permutation = [0] * len(shape)
+    for p, l in enumerate(physical_layout):
+        permutation[l] = p
+
+    return (shape, permutation)
+
+
+@register_decomposition(aten.full_like)
+@out_wrapper()
 def full_like(
     a: TensorLikeType,
     fill_value: NumberType,

@@ -5602,16 +5619,36 @@ def full_like(
     requires_grad: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> TensorLikeType:
-    e = torch.empty_like(
-        a,
-        dtype=dtype,
-        layout=layout,
-        device=device,
-        pin_memory=pin_memory,
-        requires_grad=requires_grad,
-        memory_format=memory_format,
-    )
-    return fill(e, fill_value)
+    dtype = a.dtype if dtype is None else dtype
+    layout = a.layout if layout is None else layout
+    device = a.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        result = torch.full(
+            a.shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        return result.to(memory_format=memory_format)
+
+    else:
+        shape, permutation = _get_shape_permutation_like(a, layout)
+        result = torch.full(
+            shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        if permutation == list(range(len(permutation))):
+            return result
+        return result.permute(permutation).clone()


 @register_decomposition(aten.zeros_like)

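Note (worked example, not part of the diff): in the preserve_format branch above, the tensor
is built in physical (memory) order and then permuted back to logical order, so the result
picks up the input's strides. The sketch below uses only public ops; the helper is local to
this example, while the real decomposition derives the permutation via
utils.compute_elementwise_output_logical_to_physical_perm.

import torch

def full_like_preserving_strides(a: torch.Tensor, fill_value):
    # Order dims from outermost to innermost in memory (an illustrative stand-in
    # for the logical-to-physical permutation used by the decomposition).
    physical_layout = sorted(range(a.dim()), key=lambda d: a.stride(d), reverse=True)
    shape = [a.shape[l] for l in physical_layout]

    # Inverse permutation: logical dim l ends up at physical position p.
    permutation = [0] * len(shape)
    for p, l in enumerate(physical_layout):
        permutation[l] = p

    result = torch.full(shape, fill_value, dtype=a.dtype, device=a.device)
    return result.permute(permutation).clone()  # clone() keeps the permuted strides

a = torch.randn(4, 5, 6).transpose(1, -1)  # shape (4, 6, 5), strides (30, 1, 6)
out = full_like_preserving_strides(a, 3)
assert out.shape == a.shape and out.stride() == a.stride()
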
@@ -1923,7 +1923,7 @@ def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
     def get_val(dtype):
         return make_tensor([], dtype=dtype, device="cpu").item()

-    double_dtype = torch.double if device != "mps:0" else torch.float
+    double_dtype = torch.double if torch.device(device).type != "mps" else torch.float
     inputs = [
         ((), get_val(dtype), {}),
         ((S, S), get_val(dtype), {}),

@@ -24603,6 +24603,10 @@ python_ref_db = [
             DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
         ),
     ),
+    PythonRefInfo(
+        "_refs.full_like",
+        torch_opinfo_name="full_like",
+    ),
     PythonRefInfo(
         "_refs.randn",
         torch_opinfo_name="randn",