Relax aten.to restriction (#142420)

Summary: if we have a.to(b) and b has a different dtype than a, the result must be a copy rather than an alias. In that case we do not need to freeze the input tensor. Instead, we insert torch.ops.aten._assert_tensor_metadata.default to check at runtime that a still has its export-time dtype, which guarantees it does not share b's dtype.
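
A minimal sketch of the user-facing behavior (it mirrors the `test_float_conversion_from_int` test added below; the module name `ToFloat` is illustrative): exporting an int32-to-float32 conversion now emits a dtype assert next to the copy instead of freezing the input, and calling the exported module with an input that already has the target dtype raises at runtime.

```
import torch
from torch.export import export


class ToFloat(torch.nn.Module):
    def forward(self, x):
        # int32 -> float32: the result is necessarily a copy of the input.
        return x.float()


ep = export(ToFloat(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions({})
# The decomposed graph now contains torch.ops.aten._assert_tensor_metadata.default
# (guarding the input dtype) in addition to torch.ops.aten._to_copy.default.
print(ep.graph)

ep.module()(torch.tensor(1, dtype=torch.int32))  # runs fine
# Passing an input whose dtype differs from the export-time dtype trips the assert:
#   RuntimeError: Tensor dtype mismatch!
# ep.module()(torch.tensor(1.0, dtype=torch.float32))
```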

Fixes https://github.com/pytorch/pytorch/issues/139718

Update executorch pin to include https://github.com/pytorch/executorch/pull/7277.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_float_conversion
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_device_to_mutation_float
```

Differential Revision: D66988295

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142420
Approved by: https://github.com/bdhirsh
Author: Shangdi Yu
Date: 2025-01-08 18:11:31 +00:00
Committed by: PyTorch MergeBot
Parent: 768d73f692
Commit: 0e1675a89b
3 changed files with 45 additions and 2 deletions

@@ -1 +1 @@
-6f638937d64e3396793956d75ee3e14802022745
+a29b208a06ab378bb29ab1aa68932e412f8e09f1

@@ -5108,6 +5108,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         for op in ops:
             self.assertIn(op, (torch.ops.aten._to_copy.default,))
 
+    def test_float_conversion_from_int(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return x.float()
+
+        ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions(
+            {}
+        )
+
+        ops = []
+        for node in ep.graph.nodes:
+            if node.op == "call_function":
+                ops.append(node.target)
+        self.assertGreater(len(ops), 0)
+        self.assertIn(torch.ops.aten._to_copy.default, ops)
+        self.assertIn(torch.ops.aten._assert_tensor_metadata.default, ops)
+        self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1)
+        # Raises because the runtime input dtype does not match the dtype of
+        # the example input used when exporting.
+        with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"):
+            ep.module()(torch.tensor(1, dtype=torch.float32))
+
     def test_device_to_mutation_float(self):
         class Module(torch.nn.Module):
             def forward(self, x):

@@ -535,7 +535,27 @@ class FunctionalTensorMode(TorchDispatchMode):
                 torch.ops.aten.dropout.default,
                 torch.ops.aten._to_copy.default,
             ):
-                torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
+
+                def must_copy():
+                    """
+                    Return True if the output of the op must be a copy, not an alias.
+                    """
+                    # The output dtype differs from the input dtype.
+                    return (
+                        func == torch.ops.aten._to_copy.default
+                        and "dtype" in kwargs
+                        and kwargs["dtype"] != args_unwrapped[0].dtype
+                    )
+
+                if must_copy():
+                    # This could be relaxed further to only require
+                    # args_unwrapped[0].dtype != kwargs["dtype"], but there is no aten op for that check.
+                    torch.ops.aten._assert_tensor_metadata.default(
+                        torch._from_functional_tensor(args_unwrapped[0]),
+                        dtype=args_unwrapped[0].dtype,
+                    )
+                else:
+                    torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
             outs_wrapped = pytree.tree_map_only(
                 torch.Tensor, wrap, outs_unwrapped
             )