[dynamo][invoke_subgraph] Input aliasing and mutation check in Dynamo (#148953)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148953
Approved by: https://github.com/zou3519
ghstack dependencies: #149087, #149667, #150036
Author: Animesh Jain
Date: 2025-03-27 13:17:16 -07:00
Committed by: PyTorch MergeBot
parent c18e2ce53b
commit c9ebf517c2
4 changed files with 292 additions and 27 deletions
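For reference, a minimal sketch of the kind of user code this commit now rejects with a clear graph break, based on the tests added below. The import path for mark_compile_region is an assumption taken from the test file and may differ:

    import torch
    from torch._higher_order_ops.invoke_subgraph import mark_compile_region  # assumed import path

    @mark_compile_region
    def gn(x, y):
        x.add_(1)              # in-place mutation of a HOP input
        return torch.mul(x, y)

    def fn(x, y):
        return gn(x, y)

    opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
    x = torch.randn(8)
    y = torch.randn(8)
    # Per the tests below, this raises torch._dynamo.exc.Unsupported with
    # "Encountered input mutation during higher order op tracing for HOP - invoke_subgraph"
    opt_fn(x, y)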


@@ -159,10 +159,15 @@ class GraphModule(torch.nn.Module):
def f(inner, x, y):
return invoke_quant_test(inner, x, y, scheme="nf4")
with self.assertRaisesRegex(RuntimeError, "aliases of the inputs"):
with self.assertRaisesRegex(
RuntimeError, "Encountered aliasing during higher order op tracing for HOP"
):
f(inner, x, y)
with self.assertRaisesRegex(RuntimeError, "inputs are mutated"):
with self.assertRaisesRegex(
RuntimeError,
"Encountered input mutation during higher order op tracing for HOP",
):
f(inner2, x, y)
def test_eager_call(self):


@@ -115,7 +115,58 @@ class TestInvokeSubgraphCompile(TestCase):
x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
ref = gn(x, y)
ref = fn(x, y)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)
# Run backward
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
self.assertEqual(y.grad, y_clone.grad)
def test_list(self):
@mark_compile_region
def gn(x, y):
return [torch.mul(x, y), torch.add(x, y)]
def fn(x, y):
lst = gn(x, y)
lst.append(torch.sin(x))
return lst[0] + lst[1] + lst[2]
x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
ref = fn(x, y)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)
# Run backward
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
self.assertEqual(y.grad, y_clone.grad)
def test_tuple_of_tuple(self):
@mark_compile_region
def gn(x, y):
return ((torch.mul(x, y),), torch.add(x, y))
def fn(x, y):
tup = gn(x, y)
return tup[0][0] + tup[1]
x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
ref = fn(x, y)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
@@ -477,7 +528,29 @@ class GraphModule(torch.nn.Module):
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing"
torch._dynamo.exc.Unsupported,
"Encountered input mutation during higher order op tracing for HOP - invoke_subgraph",
):
opt_fn(x, y)
def test_input_mutation_inference_mode(self):
@mark_compile_region
def gn(x, y):
x.add_(1)
return torch.mul(x, y)
def fn(x, y):
z = torch.cos(x)
with torch.inference_mode():
return gn(torch.cos(z), y)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
x = torch.randn(8, requires_grad=False)
y = torch.randn(8, requires_grad=False)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"Encountered input mutation during higher order op tracing",
):
opt_fn(x, y)
@@ -520,7 +593,7 @@ class GraphModule(torch.nn.Module):
):
opt_fn(x)
def test_input_aliasing(self):
def test_input_output_aliasing(self):
@mark_compile_region
def gn(x, y):
return (x, torch.mul(x, y))
@@ -534,7 +607,73 @@ class GraphModule(torch.nn.Module):
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing"
torch._dynamo.exc.Unsupported,
"Encountered aliasing during higher order op tracing",
):
opt_fn(x, y)
def test_input_input_aliasing(self):
@mark_compile_region
def gn(x, y):
return torch.mul(x, y)
def fn(x):
return gn(x, x.view(1, 8))
x = torch.randn(8, requires_grad=False)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"Encountered aliasing during higher order op tracing",
):
opt_fn(x)
def test_output_output_aliasing(self):
@mark_compile_region
def gn(x):
z = torch.cos(x)
return z, z.view(1, 8)
def fn(x):
return gn(x)
x = torch.randn(8, requires_grad=False)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"Encountered aliasing during higher order op tracing",
):
opt_fn(x)
def test_mod_attr_aliasing(self):
class MutateParam(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.ones(8)
def forward(self, x):
self.a.add_(1)
return torch.mul(x, self.a)
@mark_compile_region
def gn(x):
return mod(x)
def fn(x, y):
return gn(x) * y
mod = MutateParam()
x = torch.randn(8, requires_grad=False)
y = torch.randn(8, requires_grad=False)
fn(x, y)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"Encountered input mutation during higher order op tracing",
):
opt_fn(x, y)


@@ -63,6 +63,7 @@ from torch.fx.experimental.symbolic_shapes import (
ShapeEnv,
)
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -165,6 +166,18 @@ class VariableTrackerCacheKey:
source: Source
@dataclass(frozen=True)
class AliasingInfo:
has_aliasing: bool
msg: str
@dataclass(frozen=True)
class MutationInfo:
has_mutation: bool
msg: str
class VariableTrackerCache:
def __init__(self):
self.cache = {}
@@ -2023,6 +2036,13 @@ class SubgraphTracer(fx.Tracer):
# This is used to create a unique name for the placeholder
self._used_names: OrderedSet[str] = OrderedSet()
# Stores the versions of the input tensors at the time they are inserted
# as placeholders in the graph. This is used to track input mutation.
self._input_versions_at_beginning: list[int] = []
if torch.is_inference_mode_enabled():
raise RuntimeError(
"Inference mode is supposed to be disabled during compilation. Please open an issue."
)
# preserve original meta if it is available
def _maybe_preserve_original_meta(self, tx, node):
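For context on the version tracking added above: every torch.Tensor carries a private `_version` counter that is bumped by in-place operations, and that counter is what `_input_versions_at_beginning` snapshots. A minimal sketch outside of Dynamo (uses the private `_version` attribute, as the diff itself does):

    import torch

    t = torch.randn(4)
    v_before = t._version  # private version counter, incremented by in-place ops
    t.mul(2)               # out-of-place op: counter unchanged
    assert t._version == v_before
    t.add_(1)              # in-place op: counter is bumped
    assert t._version > v_before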
@@ -2273,6 +2293,8 @@ class SubgraphTracer(fx.Tracer):
def create_graph_input(
self, name, type_expr, example_value, before=False, source=None
):
if isinstance(example_value, torch.Tensor):
self._input_versions_at_beginning.append(example_value._version)
log.debug(
"create_graph_input %s %s %s at debug_level %s before=%s",
name,
@@ -2690,6 +2712,77 @@ class SubgraphTracer(fx.Tracer):
# Sort the symbols so that we can have a deterministic lifting order
return sorted(to_be_bound, key=lambda s: s.name)
def has_input_mutation(self):
input_versions_at_beginning = self._input_versions_at_beginning
input_nodes = []
input_versions_at_end = []
for node in self.graph.nodes:
if node.op == "placeholder":
example_value = node.meta["example_value"]
if isinstance(example_value, torch.Tensor):
input_versions_at_end.append(example_value._version)
input_nodes.append(node)
else:
break
mutated_inputs = [
i
for i, (v1, v2) in enumerate(
zip(input_versions_at_beginning, input_versions_at_end)
)
if v1 != v2
]
if len(mutated_inputs):
mutated_nodes = [input_nodes[i] for i in mutated_inputs]
msg = f"Input mutation detected at {mutated_nodes}"
return MutationInfo(True, msg)
return MutationInfo(False, "")
def has_aliasing(self):
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
for node in self.graph.nodes:
if node.op == "placeholder":
example_value = node.meta["example_value"]
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
if storage in input_storages:
# input-input aliasing
msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}"
return AliasingInfo(True, msg)
input_storages[storage] = node
else:
break
output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
out_nodes = self.graph.find_nodes(op="output")[0]
for out_node in out_nodes.args[0]:
if out_node:
example_value = out_node.meta["example_value"]
assert not isinstance(example_value, list)
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
if storage in output_storages:
# output-output aliasing
msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}"
return AliasingInfo(True, msg)
output_storages[storage] = out_node
intersected_storages = input_storages.keys() & output_storages.keys()
if len(intersected_storages) > 0:
# input-output aliasing
aliased = [
(input_storages[s], output_storages[s]) for s in intersected_storages
]
aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
msg = f"Input-to-output aliasing detected at nodes {aliased}"
return AliasingInfo(True, msg)
return AliasingInfo(False, "")
# NOTE: [HigherOrderOperator tracing design]
# Ignoring HigherOrderOperators for a moment,
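For context on has_aliasing above: StorageWeakRef hashes and compares by the underlying storage, so tensors that share a storage (for example, a tensor and one of its views) collide in the input/output dictionaries. A minimal sketch of the same storage-level check in isolation; note it is conservative, since sharing a storage does not necessarily mean the tensors overlap in memory:

    import torch
    from torch.multiprocessing.reductions import StorageWeakRef

    x = torch.randn(8)
    y = x.view(2, 4)    # a view shares x's storage
    z = torch.randn(8)  # fresh, independent storage

    seen = {StorageWeakRef(x._typed_storage()): "x"}
    assert StorageWeakRef(y._typed_storage()) in seen      # alias detected
    assert StorageWeakRef(z._typed_storage()) not in seen  # no alias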


@@ -50,6 +50,7 @@ from ..exc import (
IncorrectUsage,
UncapturedHigherOrderOpError,
unimplemented,
unimplemented_v2,
Unsupported,
)
from ..source import AttrSource, DictGetItemSource
@@ -506,6 +507,9 @@ def speculate_subgraph(
restore_side_effects=True,
should_flatten_outputs=False,
under_activation_checkpoint=False,
# TODO - supports input_mutation and aliasing should be False by default for strictness
supports_input_mutation=True,
supports_aliasing=True,
# Pass in an originating tracer - this is needed for preserving context
# across fwd-bwd for autograd.Function
tracer=None,
@@ -694,6 +698,34 @@ def speculate_subgraph(
if len(lifted_freevars) > 0:
move_lifted_freevars_phs_to_end(graph, lifted_freevars)
if not supports_input_mutation:
mutation_info = subtracer.has_input_mutation()
if mutation_info.has_mutation:
context = f"{mutation_info.msg} in\n {graph}"
unimplemented_v2(
gb_type=f"Encountered input mutation during higher order op tracing for HOP - {source_target.name()}",
context=context,
explanation="Higher order ops do not support input mutation",
hints=[
"Consider using the debug context to change user code to avoid mutation.",
"Please open an issue.",
],
)
if not supports_aliasing:
aliasing_info = subtracer.has_aliasing()
if aliasing_info.has_aliasing:
context = f"{aliasing_info.msg} in\n {graph}"
unimplemented_v2(
gb_type=f"Encountered aliasing during higher order op tracing for HOP - {source_target.name()}",
context=context,
explanation="Higher order ops do not support aliasing",
hints=[
"Consider using the debug context to change user code to avoid aliasing.",
"Please open an issue.",
],
)
return (
(output, treespec),
graph,
@@ -1794,6 +1826,11 @@ class FunctionalCallVariable(FunctorchHigherOrderVariable):
class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.supports_input_mutation = True
self.supports_aliasing = True
def install_subgraph_in_output_graph(
self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body"
):
@@ -1828,6 +1865,8 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
source_target=self.value,
should_flatten_outputs=True,
under_activation_checkpoint=under_activation_checkpoint,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
@@ -3039,6 +3078,11 @@ def hash_graph_and_inputs(tx, gmod, fake_inputs):
class BaseHOPVariable(WrapHigherOrderVariable):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.supports_input_mutation = False
self.supports_aliasing = False
def python_type(self):
return type(self.value)
@@ -3061,20 +3105,6 @@ class BaseHOPVariable(WrapHigherOrderVariable):
)
assert len(p_kwargs) == 0
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
fake_inputs = [
node.meta["example_value"]
for node in body_gmod.graph.nodes
if node.op == "placeholder"
]
if has_potential_input_alias_or_mutation(body_gmod, fake_inputs):
raise RuntimeError(
f"{self.value._name} where the inputs are mutated or the "
f"outputs are aliases of the inputs. Please ensure that this doesn't happen."
)
flat_example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
@@ -3087,6 +3117,11 @@ class BaseHOPVariable(WrapHigherOrderVariable):
class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.supports_input_mutation = False
self.supports_aliasing = False
def install_subgraph_in_output_graph(
self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name
):
@@ -3094,19 +3129,12 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
# inputs have already been seen before. If yes, the subgraph is already
# installed in the output graph and we can just access the subgraph
# using the saved attr name.
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
fake_inputs = [
node.meta["example_value"]
for node in body_gmod.graph.nodes
if node.op == "placeholder"
]
# TODO(anijain2305) - This might be too big of a limitation. Consider
# supporting mutation/aliasing in HOP itself to remove this restriction.
if has_potential_input_alias_or_mutation(body_gmod, fake_inputs):
unimplemented("NYI: invoke_subgraph with aliasing/mutation")
key = hash_graph_and_inputs(tx, body_gmod, fake_inputs)
invoke_subgraph_cache = (