[dynamo][invoke_subgraph] Input aliasing and mutation check in Dynamo (#148953)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148953
Approved by: https://github.com/zou3519
ghstack dependencies: #149087, #149667, #150036
Committed by: PyTorch MergeBot
Parent: c18e2ce53b
Commit: c9ebf517c2
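For orientation, a minimal sketch of the user-visible behavior this change introduces (assuming mark_compile_region is importable from torch._higher_order_ops.invoke_subgraph, as the tests below use it): mutating an input inside a marked region now surfaces as a Dynamo graph-break error mentioning "Encountered input mutation during higher order op tracing", instead of the previous "NYI: invoke_subgraph with aliasing/mutation" message.

import torch
import torch._dynamo
from torch._higher_order_ops.invoke_subgraph import mark_compile_region  # assumed import path

@mark_compile_region
def gn(x, y):
    x.add_(1)  # in-place mutation of a region input
    return torch.mul(x, y)

def fn(x, y):
    return gn(x, y)

x = torch.randn(8)
y = torch.randn(8)

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
try:
    opt_fn(x, y)
except torch._dynamo.exc.Unsupported as e:
    # Expected to mention "Encountered input mutation during higher order op tracing ..."
    print(e)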
@@ -159,10 +159,15 @@ class GraphModule(torch.nn.Module):
         def f(inner, x, y):
             return invoke_quant_test(inner, x, y, scheme="nf4")

-        with self.assertRaisesRegex(RuntimeError, "aliases of the inputs"):
+        with self.assertRaisesRegex(
+            RuntimeError, "Encountered aliasing during higher order op tracing for HOP"
+        ):
             f(inner, x, y)

-        with self.assertRaisesRegex(RuntimeError, "inputs are mutated"):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Encountered input mutation during higher order op tracing for HOP",
+        ):
             f(inner2, x, y)

     def test_eager_call(self):
@@ -115,7 +115,58 @@ class TestInvokeSubgraphCompile(TestCase):

         x = torch.randn(8, requires_grad=True)
         y = torch.randn(8, requires_grad=True)
-        ref = gn(x, y)
+        ref = fn(x, y)
+
+        x_clone = x.detach().clone().requires_grad_(True)
+        y_clone = y.detach().clone().requires_grad_(True)
+        res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)
+
+        # Run backward
+        ref.sum().backward()
+        res.sum().backward()
+
+        self.assertEqual(ref, res)
+        self.assertEqual(x.grad, x_clone.grad)
+        self.assertEqual(y.grad, y_clone.grad)
+
+    def test_list(self):
+        @mark_compile_region
+        def gn(x, y):
+            return [torch.mul(x, y), torch.add(x, y)]
+
+        def fn(x, y):
+            lst = gn(x, y)
+            lst.append(torch.sin(x))
+            return lst[0] + lst[1] + lst[2]
+
+        x = torch.randn(8, requires_grad=True)
+        y = torch.randn(8, requires_grad=True)
+        ref = fn(x, y)
+
+        x_clone = x.detach().clone().requires_grad_(True)
+        y_clone = y.detach().clone().requires_grad_(True)
+        res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)
+
+        # Run backward
+        ref.sum().backward()
+        res.sum().backward()
+
+        self.assertEqual(ref, res)
+        self.assertEqual(x.grad, x_clone.grad)
+        self.assertEqual(y.grad, y_clone.grad)
+
+    def test_tuple_of_tuple(self):
+        @mark_compile_region
+        def gn(x, y):
+            return ((torch.mul(x, y),), torch.add(x, y))
+
+        def fn(x, y):
+            tup = gn(x, y)
+            return tup[0][0] + tup[1]
+
+        x = torch.randn(8, requires_grad=True)
+        y = torch.randn(8, requires_grad=True)
+        ref = fn(x, y)

         x_clone = x.detach().clone().requires_grad_(True)
         y_clone = y.detach().clone().requires_grad_(True)
@@ -477,7 +528,29 @@ class GraphModule(torch.nn.Module):

         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing"
+            torch._dynamo.exc.Unsupported,
+            "Encountered input mutation during higher order op tracing for HOP - invoke_subgraph",
+        ):
+            opt_fn(x, y)
+
+    def test_input_mutation_inference_mode(self):
+        @mark_compile_region
+        def gn(x, y):
+            x.add_(1)
+            return torch.mul(x, y)
+
+        def fn(x, y):
+            z = torch.cos(x)
+            with torch.inference_mode():
+                return gn(torch.cos(z), y)
+
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+        x = torch.randn(8, requires_grad=False)
+        y = torch.randn(8, requires_grad=False)
+
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported,
+            "Encountered input mutation during higher order op tracing",
         ):
             opt_fn(x, y)

@@ -520,7 +593,7 @@ class GraphModule(torch.nn.Module):
         ):
             opt_fn(x)

-    def test_input_aliasing(self):
+    def test_input_output_aliasing(self):
         @mark_compile_region
         def gn(x, y):
             return (x, torch.mul(x, y))
@@ -534,7 +607,73 @@ class GraphModule(torch.nn.Module):

         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing"
+            torch._dynamo.exc.Unsupported,
+            "Encountered aliasing during higher order op tracing",
+        ):
+            opt_fn(x, y)
+
+    def test_input_input_aliasing(self):
+        @mark_compile_region
+        def gn(x, y):
+            return torch.mul(x, y)
+
+        def fn(x):
+            return gn(x, x.view(1, 8))
+
+        x = torch.randn(8, requires_grad=False)
+
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported,
+            "Encountered aliasing during higher order op tracing",
+        ):
+            opt_fn(x)
+
+    def test_output_output_aliasing(self):
+        @mark_compile_region
+        def gn(x):
+            z = torch.cos(x)
+            return z, z.view(1, 8)
+
+        def fn(x):
+            return gn(x)
+
+        x = torch.randn(8, requires_grad=False)
+
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported,
+            "Encountered aliasing during higher order op tracing",
+        ):
+            opt_fn(x)
+
+    def test_mod_attr_aliasing(self):
+        class MutateParam(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.ones(8)
+
+            def forward(self, x):
+                self.a.add_(1)
+                return torch.mul(x, self.a)
+
+        @mark_compile_region
+        def gn(x):
+            return mod(x)
+
+        def fn(x, y):
+            return gn(x) * y
+
+        mod = MutateParam()
+        x = torch.randn(8, requires_grad=False)
+        y = torch.randn(8, requires_grad=False)
+
+        fn(x, y)
+
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported,
+            "Encountered input mutation during higher order op tracing",
+        ):
             opt_fn(x, y)

@@ -63,6 +63,7 @@ from torch.fx.experimental.symbolic_shapes import (
     ShapeEnv,
 )
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
+from torch.multiprocessing.reductions import StorageWeakRef
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass

@@ -165,6 +166,18 @@ class VariableTrackerCacheKey:
     source: Source


+@dataclass(frozen=True)
+class AliasingInfo:
+    has_aliasing: bool
+    msg: str
+
+
+@dataclass(frozen=True)
+class MutationInfo:
+    has_mutation: bool
+    msg: str
+
+
 class VariableTrackerCache:
     def __init__(self):
         self.cache = {}
@@ -2023,6 +2036,13 @@ class SubgraphTracer(fx.Tracer):

         # This is used to create a unique name for the placeholder
         self._used_names: OrderedSet[str] = OrderedSet()
+        # Stores the versions of the input tensors at the time they are inserted
+        # as placeholders in the graph. This is used to track input mutation.
+        self._input_versions_at_beginning: list[int] = []
+        if torch.is_inference_mode_enabled():
+            raise RuntimeError(
+                "Inference mode is supposed to be disabled during compilation. Please open an issue."
+            )

     # preserve original meta if it is available
     def _maybe_preserve_original_meta(self, tx, node):
@@ -2273,6 +2293,8 @@ class SubgraphTracer(fx.Tracer):
     def create_graph_input(
         self, name, type_expr, example_value, before=False, source=None
     ):
+        if isinstance(example_value, torch.Tensor):
+            self._input_versions_at_beginning.append(example_value._version)
         log.debug(
             "create_graph_input %s %s %s at debug_level %s before=%s",
             name,
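The version recorded here feeds the mutation check below: every in-place op bumps Tensor._version, so comparing the version captured when a placeholder is created against the version at the end of tracing reveals whether the subgraph mutated that input. A standalone illustration of the counter (not part of the diff; _version is a private attribute):

import torch

t = torch.randn(4)
v_before = t._version  # version recorded at placeholder-creation time
t.add_(1)              # any in-place op bumps the counter
assert t._version > v_before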
@@ -2690,6 +2712,77 @@ class SubgraphTracer(fx.Tracer):
         # Sort the symbols so that we can have a deterministic lifting order
         return sorted(to_be_bound, key=lambda s: s.name)

+    def has_input_mutation(self):
+        input_versions_at_beginning = self._input_versions_at_beginning
+        input_nodes = []
+
+        input_versions_at_end = []
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                example_value = node.meta["example_value"]
+                if isinstance(example_value, torch.Tensor):
+                    input_versions_at_end.append(example_value._version)
+                    input_nodes.append(node)
+            else:
+                break
+
+        mutated_inputs = [
+            i
+            for i, (v1, v2) in enumerate(
+                zip(input_versions_at_beginning, input_versions_at_end)
+            )
+            if v1 != v2
+        ]
+
+        if len(mutated_inputs):
+            mutated_nodes = [input_nodes[i] for i in mutated_inputs]
+            msg = f"Input mutation detected at {mutated_nodes}"
+            return MutationInfo(True, msg)
+
+        return MutationInfo(False, "")
+
+    def has_aliasing(self):
+        input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
+
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                example_value = node.meta["example_value"]
+                if isinstance(example_value, torch.Tensor):
+                    storage = StorageWeakRef(example_value._typed_storage())
+                    if storage in input_storages:
+                        # input-input aliasing
+                        msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}"
+                        return AliasingInfo(True, msg)
+                    input_storages[storage] = node
+            else:
+                break
+
+        output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
+        out_nodes = self.graph.find_nodes(op="output")[0]
+        for out_node in out_nodes.args[0]:
+            if out_node:
+                example_value = out_node.meta["example_value"]
+                assert not isinstance(example_value, list)
+                if isinstance(example_value, torch.Tensor):
+                    storage = StorageWeakRef(example_value._typed_storage())
+                    if storage in output_storages:
+                        # output-output aliasing
+                        msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}"
+                        return AliasingInfo(True, msg)
+                    output_storages[storage] = out_node
+
+        intersected_storages = input_storages.keys() & output_storages.keys()
+        if len(intersected_storages) > 0:
+            # input-output aliasing
+            aliased = [
+                (input_storages[s], output_storages[s]) for s in intersected_storages
+            ]
+            aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
+            msg = f"Input-to-output aliasing detected at nodes {aliased}"
+            return AliasingInfo(True, msg)
+
+        return AliasingInfo(False, "")
+
+
 # NOTE: [HigherOrderOperator tracing design]
 # Ignoring HigherOrderOperators for a moment,
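has_aliasing above treats two tensors as aliases when their example values share a storage, using StorageWeakRef as a dict key. A standalone illustration of that storage-identity idea (not part of the diff; _typed_storage is a private API):

import torch
from torch.multiprocessing.reductions import StorageWeakRef

x = torch.randn(8)
view = x.view(1, 8)  # a view shares x's storage, so it counts as an alias
fresh = x.clone()    # a clone owns its own storage

assert StorageWeakRef(x._typed_storage()) == StorageWeakRef(view._typed_storage())
assert StorageWeakRef(x._typed_storage()) != StorageWeakRef(fresh._typed_storage())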
@@ -50,6 +50,7 @@ from ..exc import (
     IncorrectUsage,
     UncapturedHigherOrderOpError,
     unimplemented,
+    unimplemented_v2,
     Unsupported,
 )
 from ..source import AttrSource, DictGetItemSource
@@ -506,6 +507,9 @@ def speculate_subgraph(
     restore_side_effects=True,
     should_flatten_outputs=False,
     under_activation_checkpoint=False,
+    # TODO - supports input_mutation and aliasing should be False by default for strictness
+    supports_input_mutation=True,
+    supports_aliasing=True,
     # Pass in an originating tracer - this is needed for preserving context
     # across fwd-bwd for autograd.Function
     tracer=None,
@@ -694,6 +698,34 @@ def speculate_subgraph(
             if len(lifted_freevars) > 0:
                 move_lifted_freevars_phs_to_end(graph, lifted_freevars)

+            if not supports_input_mutation:
+                mutation_info = subtracer.has_input_mutation()
+                if mutation_info.has_mutation:
+                    context = f"{mutation_info.msg} in\n {graph}"
+                    unimplemented_v2(
+                        gb_type=f"Encountered input mutation during higher order op tracing for HOP - {source_target.name()}",
+                        context=context,
+                        explanation="Higher order ops do not support input mutation",
+                        hints=[
+                            "Consider using the debug context to change user code to avoid mutation.",
+                            "Please open an issue.",
+                        ],
+                    )
+
+            if not supports_aliasing:
+                aliasing_info = subtracer.has_aliasing()
+                if aliasing_info.has_aliasing:
+                    context = f"{aliasing_info.msg} in\n {graph}"
+                    unimplemented_v2(
+                        gb_type=f"Encountered aliasing during higher order op tracing for HOP - {source_target.name()}",
+                        context=context,
+                        explanation="Higher order ops do not support aliasing",
+                        hints=[
+                            "Consider using the debug context to change user code to avoid aliasing.",
+                            "Please open an issue.",
+                        ],
+                    )
+
             return (
                 (output, treespec),
                 graph,
@@ -1794,6 +1826,11 @@ class FunctionalCallVariable(FunctorchHigherOrderVariable):


 class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.supports_input_mutation = True
+        self.supports_aliasing = True
+
     def install_subgraph_in_output_graph(
         self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body"
     ):
@@ -1828,6 +1865,8 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
             source_target=self.value,
             should_flatten_outputs=True,
             under_activation_checkpoint=under_activation_checkpoint,
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )

         body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
@@ -3039,6 +3078,11 @@ def hash_graph_and_inputs(tx, gmod, fake_inputs):


 class BaseHOPVariable(WrapHigherOrderVariable):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.supports_input_mutation = False
+        self.supports_aliasing = False
+
     def python_type(self):
         return type(self.value)

@@ -3061,20 +3105,6 @@ class BaseHOPVariable(WrapHigherOrderVariable):
         )
         assert len(p_kwargs) == 0

-        from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
-
-        fake_inputs = [
-            node.meta["example_value"]
-            for node in body_gmod.graph.nodes
-            if node.op == "placeholder"
-        ]
-
-        if has_potential_input_alias_or_mutation(body_gmod, fake_inputs):
-            raise RuntimeError(
-                f"{self.value._name} where the inputs are mutated or the "
-                f"outputs are aliases of the inputs. Please ensure that this doesn't happen."
-            )
-
         flat_example_value = pytree.tree_map_only(
             torch.fx.Proxy,
             lambda a: a.node.meta["example_value"],
@@ -3087,6 +3117,11 @@ class BaseHOPVariable(WrapHigherOrderVariable):


 class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.supports_input_mutation = False
+        self.supports_aliasing = False
+
     def install_subgraph_in_output_graph(
         self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name
     ):
@@ -3094,19 +3129,12 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
         # inputs have already been seen before. If yes, the subgraph is already
         # installed in the output graph and we can just access the subgraph
         # using the saved attr name.
-        from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
-
         fake_inputs = [
             node.meta["example_value"]
             for node in body_gmod.graph.nodes
             if node.op == "placeholder"
         ]

-        # TODO(anijain2305) - This might be too big of a limitation. Consider
-        # supporting mutation/aliasing in HOP itself to remove this restriction.
-        if has_potential_input_alias_or_mutation(body_gmod, fake_inputs):
-            unimplemented("NYI: invoke_subgraph with aliasing/mutation")
-
         key = hash_graph_and_inputs(tx, body_gmod, fake_inputs)

         invoke_subgraph_cache = (