[torchbind] fix fakifying a static tensor accidentally returning a dynamic one (#158607)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158607
Approved by: https://github.com/zou3519
ghstack dependencies: #158583, #158606
Author: Yidi Wu
Date: 2025-07-25 10:33:39 -07:00
Committed by: PyTorch MergeBot
parent 0427e439aa
commit 0f31e9a656
4 changed files with 63 additions and 18 deletions
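
Note on the fix: the title refers to how tensors held inside a torchbind object are fakified during compilation. A minimal sketch of the relevant distinction, assuming a fresh FakeTensorMode (an illustration of the two code paths, not the exact repro from the PR):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

t = torch.randn(3, 2)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())

# fake_mode.from_tensor(t) consults the ShapeEnv and can hand back symbolic
# (dynamic) sizes for a tensor that should stay static. Allocating a fresh
# fake tensor under the mode, as this PR does, always yields plain int sizes.
with fake_mode:
    static_fake = torch.empty_strided(
        t.size(),
        t.stride(),
        dtype=t.dtype,
        device=t.device,
        requires_grad=t.requires_grad,
        layout=t.layout,
    )

print(static_fake.shape)  # torch.Size([3, 2]) -- static ints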


@@ -421,7 +421,9 @@ struct FlattenWithTensorOp : public torch::CustomClassHolder {
   explicit FlattenWithTensorOp(at::Tensor t) : t_(t) {}

   at::Tensor get() {
-    return t_;
+    // Need to return a copy of the tensor, otherwise the tensor will be
+    // aliased with a tensor that may be modified by the user or backend.
+    return t_.clone();
   }

   std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
@@ -437,7 +439,9 @@ struct ContainsTensor : public torch::CustomClassHolder {
   explicit ContainsTensor(at::Tensor t) : t_(t) {}

   at::Tensor get() {
-    return t_;
+    // Need to return a copy of the tensor, otherwise the tensor will be
+    // aliased with a tensor that may be modified by the user or backend.
+    return t_.clone();
   }

   std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
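
Both hunks above make the same fix: returning t_ directly hands the caller a tensor that aliases the object's internal storage, so any in-place update by the user or the backend silently mutates the object's state. A Python sketch of the hazard that clone() avoids (illustrative Holder class, not the C++ test fixture):

import torch

class Holder:
    def __init__(self, t: torch.Tensor) -> None:
        self._t = t

    def get_aliased(self) -> torch.Tensor:
        return self._t  # aliases internal storage

    def get_cloned(self) -> torch.Tensor:
        return self._t.clone()  # fresh storage, safe for callers to mutate

h = Holder(torch.zeros(3))
h.get_aliased().add_(1)  # silently mutates h's internal tensor
print(h.get_cloned().sum().item())  # 3.0 -- the in-place add leaked in

h2 = Holder(torch.zeros(3))
h2.get_cloned().add_(1)  # mutates only the copy
print(h2.get_cloned().sum().item())  # 0.0 -- internal state untouched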


@@ -1201,7 +1201,7 @@ class TestCompileTorchbind(TestCase):
 return (x_sin,)""",
         )

-    @parametrize("backend", ["eager", "aot_eager"])
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_compile_script_object_input_guards(self, backend):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
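
For reference, parametrize from torch.testing._internal.common_utils stamps out one test variant per value, so adding "inductor" here also exercises the Inductor backend. A self-contained sketch of the pattern (the test body is illustrative):

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class BackendTests(TestCase):
    @parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_compile_add(self, backend):
        fn = torch.compile(lambda x: x + 1, backend=backend, fullgraph=True)
        x = torch.randn(3)
        self.assertEqual(fn(x), x + 1)

# Expands the class with one method per value,
# e.g. test_compile_add_backend_inductor.
instantiate_parametrized_tests(BackendTests)

if __name__ == "__main__":
    run_tests()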
@@ -1370,7 +1370,7 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
 return (sub, add)""",
         )

-    @parametrize("backend", ["eager", "aot_eager"])
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_compile_tensor_op_in_tensor_flatten(self, backend):
         test_obj = torch.classes._TorchScriptTesting._FlattenWithTensorOp(
             torch.randn(3, 2)
@@ -1378,24 +1378,31 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):

         class TestMod(torch.nn.Module):
             def forward(self, obj, x):
-                return obj.get() + x
+                return obj.get() + x + obj.get().size(0)

         mod = TestMod()

-        torch.compile(mod, backend=backend, fullgraph=True)(test_obj, torch.randn(3, 1))
+        x = torch.randn(3, 1)
+        eager_out = mod(test_obj, x)
+        compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(test_obj, x)
         ep = torch.export.export_for_training(
-            mod, (test_obj, torch.randn(3, 1)), strict=False
+            mod, (test_obj, x), strict=False
         ).run_decompositions({})
         self.assertExpectedInline(
             ep.graph_module.code.strip(),
             """\
 def forward(self, token, obj, x):
-    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None
+    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = None
     getitem = with_effects[0]
     getitem_1 = with_effects[1]; with_effects = None
-    add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
-    return (getitem, add_3)""",  # noqa: B950
+    add = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
+    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.call_torchbind, obj, 'get'); getitem = obj = None
+    getitem_2 = with_effects_1[0]; with_effects_1 = None
+    add_1 = torch.ops.aten.add.Tensor(add, 3); add = None
+    return (getitem_2, add_1)""",  # noqa: B950
         )
+        self.assertEqual(eager_out, compiled_out)
+        self.assertEqual(eager_out, ep.module()(test_obj, x))

     @parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_compile_error_on_non_fakified_method(self, backend):
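
The expected graph now holds two with_effects nodes because forward() calls obj.get() twice, and effectful torchbind calls are ordered by threading a token from one call into the next; note also that obj.get().size(0) folds to the constant 3, confirming the fakified tensor kept its static shape. A toy, runnable model of the token threading (make_token/with_effects below are illustrative stand-ins, not the real torch.ops.higher_order.with_effects):

import torch

def make_token() -> object:
    return object()

def with_effects(token: object, fn, *args):
    # Run the effectful call and return a fresh token; any later call that
    # takes this token is thereby ordered after this one.
    result = fn(*args)
    return object(), result

class Obj:
    def __init__(self, t: torch.Tensor) -> None:
        self._t = t

    def get(self) -> torch.Tensor:
        return self._t.clone()

obj = Obj(torch.randn(3, 1))
x = torch.randn(3, 1)

token = make_token()
token, a = with_effects(token, obj.get)  # first call_torchbind(obj, 'get')
token, b = with_effects(token, obj.get)  # second call, ordered after the first
out = a + x + b.size(0)  # mirrors the test's forward()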


@@ -795,6 +795,10 @@ def check_input_alias_and_mutation(
     return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs


+def _tensor_storage(t) -> StorageWeakRef:
+    return StorageWeakRef(t._typed_storage())
+
+
 def check_input_alias_and_mutation_return_outputs(
     gm: torch.fx.GraphModule,
     fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]],
@@ -844,9 +848,6 @@ def check_input_alias_and_mutation_return_outputs(
             return t._version
         return None

-    def _tensor_storage(t) -> StorageWeakRef:
-        return StorageWeakRef(t._typed_storage())
-
     def _get_shape_env(
         fake_args,
     ) -> Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]:
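
These two hunks hoist _tensor_storage from a closure inside check_input_alias_and_mutation_return_outputs to module scope so that other code (the torchbind fakification below) can import it. Comparing StorageWeakRefs is how tensors that share storage are detected; a small demonstration:

import torch
from torch.multiprocessing.reductions import StorageWeakRef

def _tensor_storage(t) -> StorageWeakRef:
    return StorageWeakRef(t._typed_storage())

base = torch.randn(4)
view = base.view(2, 2)  # shares base's storage
copy = base.clone()     # owns fresh storage

assert _tensor_storage(base) == _tensor_storage(view)
assert _tensor_storage(base) != _tensor_storage(copy)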


@@ -143,9 +143,42 @@ def maybe_to_fake_obj(
     _check_valid_flat_script_obj(flat_x)

     with fake_mode:
+        from torch._higher_order_ops.utils import _tensor_storage
+
+        storage_map = {
+            _tensor_storage(inp): i
+            for i, inp in enumerate(flat_x)
+            if isinstance(inp, torch.Tensor)
+        }
+        alias_map = {
+            i: storage_map[_tensor_storage(inp)]
+            for i, inp in enumerate(flat_x)
+            if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
+        }
+        if len(alias_map) > 0:
+            log.warning(
+                "Detected script object %s has aliasing relationship among its tensors. "
+                "Flattened obj: %s. Aliasing tensor indices: %s. "
+                "This is not supported and may cause unexpected behavior.",
+                x,
+                flat_x,
+                alias_map,
+            )
+
+        # This breaks any aliasing relationship among the tensors inside the
+        # torchbind object. That's not ideal, but we don't need to preserve
+        # aliasing here, and the docs state clearly that it isn't preserved,
+        # so this should be OK.
         fake_flattened = pytree.tree_map_only(
             torch.Tensor,
-            lambda t: fake_mode.from_tensor(t),
+            lambda t: torch.empty_strided(
+                t.size(),
+                t.stride(),
+                device=t.device,
+                dtype=t.dtype,
+                requires_grad=t.requires_grad,
+                layout=t.layout,
+            ),
             flat_x,
         )
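
Putting the new detection logic together: storage_map records, for each storage, the last input index that uses it, and alias_map then flags every index whose storage is also claimed by a different index. A standalone sketch of the same pass run on a plain list of tensors rather than a flattened script object:

import torch
from torch.multiprocessing.reductions import StorageWeakRef

def _tensor_storage(t) -> StorageWeakRef:
    return StorageWeakRef(t._typed_storage())

base = torch.randn(3, 2)
flat_x = [base, base.t(), torch.randn(3, 2)]  # index 1 aliases index 0

storage_map = {
    _tensor_storage(inp): i
    for i, inp in enumerate(flat_x)
    if isinstance(inp, torch.Tensor)
}
alias_map = {
    i: storage_map[_tensor_storage(inp)]
    for i, inp in enumerate(flat_x)
    if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
}
print(alias_map)  # {0: 1} -- index 0's storage is also claimed by index 1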