mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Relax aten.to restriction (#142420)
Summary: if we have a.to(b), and b has a different dtype with a, then it must be a copy. In this case, we do not need to freeze the tensor. Instead, we use torch.ops.aten._assert_tensor_metadata.default to ensure that a must not have the same dtype as b. Fixes https://github.com/pytorch/pytorch/issues/139718 Update executorch pin to include https://github.com/pytorch/executorch/pull/7277. Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_float_conversion buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_device_to_mutation_float ``` Differential Revision: D66988295 Pull Request resolved: https://github.com/pytorch/pytorch/pull/142420 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
768d73f692
commit
0e1675a89b
@ -1 +1 @@
|
||||
6f638937d64e3396793956d75ee3e14802022745
|
||||
a29b208a06ab378bb29ab1aa68932e412f8e09f1
|
||||
|
||||
@ -5108,6 +5108,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
for op in ops:
|
||||
self.assertIn(op, (torch.ops.aten._to_copy.default,))
|
||||
|
||||
def test_float_conversion_from_int(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.float()
|
||||
|
||||
ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions(
|
||||
{}
|
||||
)
|
||||
ops = []
|
||||
for node in ep.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
ops.append(node.target)
|
||||
self.assertGreater(len(ops), 0)
|
||||
self.assertIn(torch.ops.aten._to_copy.default, ops)
|
||||
self.assertIn(torch.ops.aten._assert_tensor_metadata.default, ops)
|
||||
|
||||
self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1)
|
||||
|
||||
# Raises error because the input dtype is not the same as the input
|
||||
# tensor when exporting.
|
||||
with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"):
|
||||
ep.module()(torch.tensor(1, dtype=torch.float32))
|
||||
|
||||
def test_device_to_mutation_float(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
||||
@ -535,7 +535,27 @@ class FunctionalTensorMode(TorchDispatchMode):
|
||||
torch.ops.aten.dropout.default,
|
||||
torch.ops.aten._to_copy.default,
|
||||
):
|
||||
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
|
||||
|
||||
def must_copy():
|
||||
"""
|
||||
Return True if the output of the op must be copied, not an alias
|
||||
"""
|
||||
# output dtype is different from input
|
||||
return (
|
||||
func == torch.ops.aten._to_copy.default
|
||||
and "dtype" in kwargs
|
||||
and kwargs["dtype"] != args_unwrapped[0].dtype
|
||||
)
|
||||
|
||||
if must_copy():
|
||||
# We can further relax to args_unwrapped[0] != kwargs["dtype"], but I don't think
|
||||
# we have an aten op for that.
|
||||
torch.ops.aten._assert_tensor_metadata.default(
|
||||
torch._from_functional_tensor(args_unwrapped[0]),
|
||||
dtype=args_unwrapped[0].dtype,
|
||||
)
|
||||
else:
|
||||
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
|
||||
outs_wrapped = pytree.tree_map_only(
|
||||
torch.Tensor, wrap, outs_unwrapped
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user