diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp
index bb9ab85c202d..3aa981a7883b 100644
--- a/test/cpp/jit/test_custom_class_registrations.cpp
+++ b/test/cpp/jit/test_custom_class_registrations.cpp
@@ -421,7 +421,9 @@ struct FlattenWithTensorOp : public torch::CustomClassHolder {
   explicit FlattenWithTensorOp(at::Tensor t) : t_(t) {}
 
   at::Tensor get() {
-    return t_;
+    // Need to return a copy of the tensor, otherwise the tensor will be
+    // aliased with a tensor that may be modified by the user or backend.
+    return t_.clone();
   }
 
   std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
@@ -437,7 +439,9 @@ struct ContainsTensor : public torch::CustomClassHolder {
   explicit ContainsTensor(at::Tensor t) : t_(t) {}
 
   at::Tensor get() {
-    return t_;
+    // Need to return a copy of the tensor, otherwise the tensor will be
+    // aliased with a tensor that may be modified by the user or backend.
+    return t_.clone();
   }
 
   std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py
index 735e83f22c0c..a18dd87245da 100644
--- a/test/export/test_torchbind.py
+++ b/test/export/test_torchbind.py
@@ -1201,7 +1201,7 @@ class TestCompileTorchbind(TestCase):
     return (x_sin,)""",
         )
 
-    @parametrize("backend", ["eager", "aot_eager"])
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_compile_script_object_input_guards(self, backend):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1370,7 +1370,7 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
    return (sub, add)""",
         )
 
-    @parametrize("backend", ["eager", "aot_eager"])
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_compile_tensor_op_in_tensor_flatten(self, backend):
         test_obj = torch.classes._TorchScriptTesting._FlattenWithTensorOp(
             torch.randn(3, 2)
@@ -1378,24 +1378,31 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
 
         class TestMod(torch.nn.Module):
             def forward(self, obj, x):
-                return obj.get() + x
+                return obj.get() + x + obj.get().size(0)
 
         mod = TestMod()
 
-        torch.compile(mod, backend=backend, fullgraph=True)(test_obj, torch.randn(3, 1))
+        x = torch.randn(3, 1)
+        eager_out = mod(test_obj, x)
+        compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(test_obj, x)
         ep = torch.export.export_for_training(
-            mod, (test_obj, torch.randn(3, 1)), strict=False
+            mod, (test_obj, x), strict=False
         ).run_decompositions({})
         self.assertExpectedInline(
             ep.graph_module.code.strip(),
             """\
 def forward(self, token, obj, x):
-    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get');  token = obj = None
+    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get');  token = None
     getitem = with_effects[0]
     getitem_1 = with_effects[1];  with_effects = None
-    add_3 = torch.ops.aten.add.Tensor(getitem_1, x);  getitem_1 = x = None
-    return (getitem, add_3)""",  # noqa: B950
+    add = torch.ops.aten.add.Tensor(getitem_1, x);  getitem_1 = x = None
+    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.call_torchbind, obj, 'get');  getitem = obj = None
+    getitem_2 = with_effects_1[0];  with_effects_1 = None
+    add_1 = torch.ops.aten.add.Tensor(add, 3);  add = None
+    return (getitem_2, add_1)""",  # noqa: B950
         )
+        self.assertEqual(eager_out, compiled_out)
+        self.assertEqual(eager_out, ep.module()(test_obj, x))
 
     @parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_compile_error_on_non_fakified_method(self, backend):
diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py
index 1f7428e53819..25ef972864d5 100644
--- a/torch/_higher_order_ops/utils.py
+++ b/torch/_higher_order_ops/utils.py
@@ -795,6 +795,10 @@ def check_input_alias_and_mutation(
     return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
 
 
+def _tensor_storage(t) -> StorageWeakRef:
+    return StorageWeakRef(t._typed_storage())
+
+
 def check_input_alias_and_mutation_return_outputs(
     gm: torch.fx.GraphModule,
     fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]],
@@ -844,9 +848,6 @@ def check_input_alias_and_mutation_return_outputs(
             return t._version
         return None
 
-    def _tensor_storage(t) -> StorageWeakRef:
-        return StorageWeakRef(t._typed_storage())
-
     def _get_shape_env(
         fake_args,
     ) -> Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]:
diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py
index 4cb79ae48725..68208d0be4a8 100644
--- a/torch/_library/fake_class_registry.py
+++ b/torch/_library/fake_class_registry.py
@@ -143,11 +143,44 @@ def maybe_to_fake_obj(
 
     _check_valid_flat_script_obj(flat_x)
 
-    fake_flattened = pytree.tree_map_only(
-        torch.Tensor,
-        lambda t: fake_mode.from_tensor(t),
-        flat_x,
-    )
+    with fake_mode:
+        from torch._higher_order_ops.utils import _tensor_storage
+
+        storage_map = {
+            _tensor_storage(inp): i
+            for i, inp in enumerate(flat_x)
+            if isinstance(inp, torch.Tensor)
+        }
+        alias_map = {
+            i: storage_map[_tensor_storage(inp)]
+            for i, inp in enumerate(flat_x)
+            if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
+        }
+        if len(alias_map) > 0:
+            log.warning(
+                "Detected script object %s has aliasing relationship among its tensors. "
+                "Flattened obj: %s. Aliasing tensor indices: %s. "
+                "This is not supported and may cause unexpected behavior.",
+                x,
+                flat_x,
+                alias_map,
+            )
+
+        # This breaks the aliasing relationship among the tensors inside the torchbind
+        # object. This is bad, but since we don't need to preserve the aliasing
+        # relationship anyway, and the docs state clearly that it is not preserved, this should be OK.
+        fake_flattened = pytree.tree_map_only(
+            torch.Tensor,
+            lambda t: torch.empty_strided(
+                t.size(),
+                t.stride(),
+                device=t.device,
+                dtype=t.dtype,
+                requires_grad=t.requires_grad,
+                layout=t.layout,
+            ),
+            flat_x,
+        )
 
     fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
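
For reviewers: the aliasing check added to `maybe_to_fake_obj` boils down to keying the flattened tensors by their storage (via the `_tensor_storage` helper this PR hoists to module scope in `torch/_higher_order_ops/utils.py`) and reporting any index whose tensor shares a storage with another entry. Below is a small standalone sketch of that idea on plain (non-fake) tensors; `find_aliases` is a hypothetical name used only for illustration, not an API added by this PR. It also shows why the C++ `get()` methods now return `t_.clone()`: a cloned tensor gets fresh storage, so it no longer registers as an alias.

```python
import torch
from torch.multiprocessing.reductions import StorageWeakRef


def _tensor_storage(t: torch.Tensor) -> StorageWeakRef:
    # Same helper the PR moves to module scope in torch/_higher_order_ops/utils.py.
    return StorageWeakRef(t._typed_storage())


def find_aliases(flat_x):
    # Map each storage to the last index seen with it; every other index whose
    # tensor shares that storage is reported as an alias, mirroring the
    # storage_map / alias_map logic in maybe_to_fake_obj.
    storage_map = {
        _tensor_storage(inp): i
        for i, inp in enumerate(flat_x)
        if isinstance(inp, torch.Tensor)
    }
    return {
        i: storage_map[_tensor_storage(inp)]
        for i, inp in enumerate(flat_x)
        if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
    }


base = torch.randn(3, 2)
view = base[0]  # a view, so it shares storage with `base`
print(find_aliases([base, view]))          # {0: 1} -> index 0 aliases index 1
print(find_aliases([base, base.clone()]))  # {}     -> clone() breaks the aliasing
```

With the new fakification path, `maybe_to_fake_obj` then materializes unrelated fake tensors via `torch.empty_strided` instead of `fake_mode.from_tensor`, so any such aliasing inside a torchbind object is deliberately dropped (and warned about) rather than carried into the traced graph.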