Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[torchbind] fix fakifying a static tensor accidentally returning dynamic (#158607)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158607
Approved by: https://github.com/zou3519
ghstack dependencies: #158583, #158606
commit 0f31e9a656 (parent 0427e439aa), committed by PyTorch MergeBot
@@ -421,7 +421,9 @@ struct FlattenWithTensorOp : public torch::CustomClassHolder {
   explicit FlattenWithTensorOp(at::Tensor t) : t_(t) {}
 
   at::Tensor get() {
-    return t_;
+    // Need to return a copy of the tensor, otherwise the tensor will be
+    // aliased with a tensor that may be modified by the user or backend.
+    return t_.clone();
   }
 
   std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
@@ -437,7 +439,9 @@ struct ContainsTensor : public torch::CustomClassHolder {
   explicit ContainsTensor(at::Tensor t) : t_(t) {}
 
   at::Tensor get() {
-    return t_;
+    // Need to return a copy of the tensor, otherwise the tensor will be
+    // aliased with a tensor that may be modified by the user or backend.
+    return t_.clone();
   }
 
   std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
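The two hunks above make the test classes' get() hand out a defensive copy. A minimal sketch, with a plain tensor standing in for the object's internal t_ (names illustrative), of the aliasing hazard being avoided:

import torch

# What "return t_;" hands out: an alias of the object's internal tensor.
t_ = torch.randn(3, 2)
alias = t_
alias.add_(1)            # mutates the object's internal state too
assert torch.equal(alias, t_)

# What "return t_.clone();" hands out: an independent copy.
copy = t_.clone()
copy.add_(1)             # leaves the internal tensor untouched
assert not torch.equal(copy, t_)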
@@ -1201,7 +1201,7 @@ class TestCompileTorchbind(TestCase):
             return (x_sin,)""",
         )
 
-    @parametrize("backend", ["eager", "aot_eager"])
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_compile_script_object_input_guards(self, backend):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
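This hunk and the next one extend the backend matrix to also cover inductor. A minimal sketch of the parametrization pattern, using torch's own test helpers (the test body here is illustrative, not from this commit):

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class MyCompileTests(TestCase):
    # Each backend value becomes its own generated test case.
    @parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_roundtrip(self, backend):
        fn = torch.compile(lambda x: x + 1, backend=backend, fullgraph=True)
        x = torch.randn(3)
        self.assertEqual(fn(x), x + 1)

instantiate_parametrized_tests(MyCompileTests)

if __name__ == "__main__":
    run_tests()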
@@ -1370,7 +1370,7 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
             return (sub, add)""",
         )
 
-    @parametrize("backend", ["eager", "aot_eager"])
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_compile_tensor_op_in_tensor_flatten(self, backend):
        test_obj = torch.classes._TorchScriptTesting._FlattenWithTensorOp(
            torch.randn(3, 2)
@@ -1378,24 +1378,31 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
 
         class TestMod(torch.nn.Module):
             def forward(self, obj, x):
-                return obj.get() + x
+                return obj.get() + x + obj.get().size(0)
 
         mod = TestMod()
 
-        torch.compile(mod, backend=backend, fullgraph=True)(test_obj, torch.randn(3, 1))
+        x = torch.randn(3, 1)
+        eager_out = mod(test_obj, x)
+        compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(test_obj, x)
         ep = torch.export.export_for_training(
-            mod, (test_obj, torch.randn(3, 1)), strict=False
+            mod, (test_obj, x), strict=False
         ).run_decompositions({})
         self.assertExpectedInline(
             ep.graph_module.code.strip(),
             """\
 def forward(self, token, obj, x):
-    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None
+    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = None
     getitem = with_effects[0]
     getitem_1 = with_effects[1]; with_effects = None
-    add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
-    return (getitem, add_3)""",  # noqa: B950
+    add = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None
+    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.call_torchbind, obj, 'get'); getitem = obj = None
+    getitem_2 = with_effects_1[0]; with_effects_1 = None
+    add_1 = torch.ops.aten.add.Tensor(add, 3); add = None
+    return (getitem_2, add_1)""",  # noqa: B950
         )
+        self.assertEqual(eager_out, compiled_out)
+        self.assertEqual(eager_out, ep.module()(test_obj, x))
 
     @parametrize("backend", ["eager", "aot_eager", "inductor"])
     def test_compile_error_on_non_fakified_method(self, backend):
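The updated expected graph reflects the fix: obj.get() is now called twice, so two effectful call_torchbind nodes are threaded through the effect token, and because the fakified tensor keeps its static shape, obj.get().size(0) folds to the constant 3 (the add.Tensor(add, 3) node) instead of staying symbolic. A minimal sketch (assuming the _TorchScriptTesting._FlattenWithTensorOp test class is built and registered) of the eager/compiled agreement the test now asserts:

import torch

test_obj = torch.classes._TorchScriptTesting._FlattenWithTensorOp(torch.randn(3, 2))

def fn(obj, x):
    # size(0) is static (3) after the fix, so it compiles to a constant add
    return obj.get() + x + obj.get().size(0)

x = torch.randn(3, 1)
eager_out = fn(test_obj, x)
compiled_out = torch.compile(fn, backend="eager", fullgraph=True)(test_obj, x)
assert torch.equal(eager_out, compiled_out)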
@@ -795,6 +795,10 @@ def check_input_alias_and_mutation(
     return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
 
 
+def _tensor_storage(t) -> StorageWeakRef:
+    return StorageWeakRef(t._typed_storage())
+
+
 def check_input_alias_and_mutation_return_outputs(
     gm: torch.fx.GraphModule,
     fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]],
@@ -844,9 +848,6 @@ def check_input_alias_and_mutation_return_outputs(
             return t._version
         return None
 
-    def _tensor_storage(t) -> StorageWeakRef:
-        return StorageWeakRef(t._typed_storage())
-
     def _get_shape_env(
         fake_args,
     ) -> Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]:
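These two hunks hoist _tensor_storage out of check_input_alias_and_mutation_return_outputs to module level so that the torchbind fakification code below can import it. A minimal sketch of what the helper reports (assuming the import path added in this commit):

import torch
from torch._higher_order_ops.utils import _tensor_storage

a = torch.randn(4)
b = a.view(2, 2)          # a view shares a's storage
c = torch.randn(4)

# StorageWeakRef compares by the underlying storage, so aliases match.
assert _tensor_storage(a) == _tensor_storage(b)
assert _tensor_storage(a) != _tensor_storage(c)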
@@ -143,9 +143,42 @@ def maybe_to_fake_obj(
 
     _check_valid_flat_script_obj(flat_x)
 
     with fake_mode:
+        from torch._higher_order_ops.utils import _tensor_storage
+
+        storage_map = {
+            _tensor_storage(inp): i
+            for i, inp in enumerate(flat_x)
+            if isinstance(inp, torch.Tensor)
+        }
+        alias_map = {
+            i: storage_map[_tensor_storage(inp)]
+            for i, inp in enumerate(flat_x)
+            if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
+        }
+        if len(alias_map) > 0:
+            log.warning(
+                "Detected script object %s has aliasing relationship among its tensors. "
+                "Flattened obj: %s. Aliasing tensor indices: %s. "
+                "This is not supported and may cause unexpected behavior.",
+                x,
+                flat_x,
+                alias_map,
+            )
+
+        # This breaks the aliasing relationship among the tensors inside the torchbind object.
+        # This is bad, but since we don't need to preserve the aliasing relationship anyway, and
+        # the doc states clearly that the aliasing relationship is not preserved, this might be OK.
         fake_flattened = pytree.tree_map_only(
             torch.Tensor,
-            lambda t: fake_mode.from_tensor(t),
+            lambda t: torch.empty_strided(
+                t.size(),
+                t.stride(),
+                device=t.device,
+                dtype=t.dtype,
+                requires_grad=t.requires_grad,
+                layout=t.layout,
+            ),
             flat_x,
         )
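This hunk is the core fix. fake_mode.from_tensor() consults the tracing context's shape hints (e.g. automatic dynamic shapes), so a tensor living inside a torchbind object could accidentally be fakified with symbolic sizes; allocating with torch.empty_strided under the fake mode instead pins the exact static sizes and strides, at the documented cost of dropping any aliasing among the flattened tensors. A minimal sketch of the construction (a bare FakeTensorMode here stands in for the caller's fake_mode):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

t = torch.randn(3, 2)  # a tensor as it might appear in a flattened script object

with FakeTensorMode():
    # Freshly allocating under the fake mode yields a FakeTensor whose sizes
    # and strides are exactly the static values requested, with no data copy.
    fake = torch.empty_strided(
        t.size(),
        t.stride(),
        device=t.device,
        dtype=t.dtype,
        requires_grad=t.requires_grad,
        layout=t.layout,
    )

assert tuple(fake.shape) == (3, 2)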