mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[dynamo][invoke_subgraph] Input aliasing and mutation check in Dynamo (#148953)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148953 Approved by: https://github.com/zou3519 ghstack dependencies: #149087, #149667, #150036
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							c18e2ce53b
						
					
				
				
					commit
					c9ebf517c2
				
			| @ -159,10 +159,15 @@ class GraphModule(torch.nn.Module): | ||||
|         def f(inner, x, y): | ||||
|             return invoke_quant_test(inner, x, y, scheme="nf4") | ||||
|  | ||||
|         with self.assertRaisesRegex(RuntimeError, "aliases of the inputs"): | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, "Encountered aliasing during higher order op tracing for HOP" | ||||
|         ): | ||||
|             f(inner, x, y) | ||||
|  | ||||
|         with self.assertRaisesRegex(RuntimeError, "inputs are mutated"): | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, | ||||
|             "Encountered input mutation during higher order op tracing for HOP", | ||||
|         ): | ||||
|             f(inner2, x, y) | ||||
|  | ||||
|     def test_eager_call(self): | ||||
|  | ||||
| @ -115,7 +115,58 @@ class TestInvokeSubgraphCompile(TestCase): | ||||
|  | ||||
|         x = torch.randn(8, requires_grad=True) | ||||
|         y = torch.randn(8, requires_grad=True) | ||||
|         ref = gn(x, y) | ||||
|         ref = fn(x, y) | ||||
|  | ||||
|         x_clone = x.detach().clone().requires_grad_(True) | ||||
|         y_clone = y.detach().clone().requires_grad_(True) | ||||
|         res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone) | ||||
|  | ||||
|         # Run backward | ||||
|         ref.sum().backward() | ||||
|         res.sum().backward() | ||||
|  | ||||
|         self.assertEqual(ref, res) | ||||
|         self.assertEqual(x.grad, x_clone.grad) | ||||
|         self.assertEqual(y.grad, y_clone.grad) | ||||
|  | ||||
|     def test_list(self): | ||||
|         @mark_compile_region | ||||
|         def gn(x, y): | ||||
|             return [torch.mul(x, y), torch.add(x, y)] | ||||
|  | ||||
|         def fn(x, y): | ||||
|             lst = gn(x, y) | ||||
|             lst.append(torch.sin(x)) | ||||
|             return lst[0] + lst[1] + lst[2] | ||||
|  | ||||
|         x = torch.randn(8, requires_grad=True) | ||||
|         y = torch.randn(8, requires_grad=True) | ||||
|         ref = fn(x, y) | ||||
|  | ||||
|         x_clone = x.detach().clone().requires_grad_(True) | ||||
|         y_clone = y.detach().clone().requires_grad_(True) | ||||
|         res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone) | ||||
|  | ||||
|         # Run backward | ||||
|         ref.sum().backward() | ||||
|         res.sum().backward() | ||||
|  | ||||
|         self.assertEqual(ref, res) | ||||
|         self.assertEqual(x.grad, x_clone.grad) | ||||
|         self.assertEqual(y.grad, y_clone.grad) | ||||
|  | ||||
|     def test_tuple_of_tuple(self): | ||||
|         @mark_compile_region | ||||
|         def gn(x, y): | ||||
|             return ((torch.mul(x, y),), torch.add(x, y)) | ||||
|  | ||||
|         def fn(x, y): | ||||
|             tup = gn(x, y) | ||||
|             return tup[0][0] + tup[1] | ||||
|  | ||||
|         x = torch.randn(8, requires_grad=True) | ||||
|         y = torch.randn(8, requires_grad=True) | ||||
|         ref = fn(x, y) | ||||
|  | ||||
|         x_clone = x.detach().clone().requires_grad_(True) | ||||
|         y_clone = y.detach().clone().requires_grad_(True) | ||||
| @ -477,7 +528,29 @@ class GraphModule(torch.nn.Module): | ||||
|  | ||||
|         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) | ||||
|         with self.assertRaisesRegex( | ||||
|             torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing" | ||||
|             torch._dynamo.exc.Unsupported, | ||||
|             "Encountered input mutation during higher order op tracing for HOP - invoke_subgraph", | ||||
|         ): | ||||
|             opt_fn(x, y) | ||||
|  | ||||
|     def test_input_mutation_inference_mode(self): | ||||
|         @mark_compile_region | ||||
|         def gn(x, y): | ||||
|             x.add_(1) | ||||
|             return torch.mul(x, y) | ||||
|  | ||||
|         def fn(x, y): | ||||
|             z = torch.cos(x) | ||||
|             with torch.inference_mode(): | ||||
|                 return gn(torch.cos(z), y) | ||||
|  | ||||
|         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) | ||||
|         x = torch.randn(8, requires_grad=False) | ||||
|         y = torch.randn(8, requires_grad=False) | ||||
|  | ||||
|         with self.assertRaisesRegex( | ||||
|             torch._dynamo.exc.Unsupported, | ||||
|             "Encountered input mutation during higher order op tracing", | ||||
|         ): | ||||
|             opt_fn(x, y) | ||||
|  | ||||
| @ -520,7 +593,7 @@ class GraphModule(torch.nn.Module): | ||||
|         ): | ||||
|             opt_fn(x) | ||||
|  | ||||
|     def test_input_aliasing(self): | ||||
|     def test_input_output_aliasing(self): | ||||
|         @mark_compile_region | ||||
|         def gn(x, y): | ||||
|             return (x, torch.mul(x, y)) | ||||
| @ -534,7 +607,73 @@ class GraphModule(torch.nn.Module): | ||||
|  | ||||
|         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) | ||||
|         with self.assertRaisesRegex( | ||||
|             torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing" | ||||
|             torch._dynamo.exc.Unsupported, | ||||
|             "Encountered aliasing during higher order op tracing", | ||||
|         ): | ||||
|             opt_fn(x, y) | ||||
|  | ||||
|     def test_input_input_aliasing(self): | ||||
|         @mark_compile_region | ||||
|         def gn(x, y): | ||||
|             return torch.mul(x, y) | ||||
|  | ||||
|         def fn(x): | ||||
|             return gn(x, x.view(1, 8)) | ||||
|  | ||||
|         x = torch.randn(8, requires_grad=False) | ||||
|  | ||||
|         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) | ||||
|         with self.assertRaisesRegex( | ||||
|             torch._dynamo.exc.Unsupported, | ||||
|             "Encountered aliasing during higher order op tracing", | ||||
|         ): | ||||
|             opt_fn(x) | ||||
|  | ||||
|     def test_output_output_aliasing(self): | ||||
|         @mark_compile_region | ||||
|         def gn(x): | ||||
|             z = torch.cos(x) | ||||
|             return z, z.view(1, 8) | ||||
|  | ||||
|         def fn(x): | ||||
|             return gn(x) | ||||
|  | ||||
|         x = torch.randn(8, requires_grad=False) | ||||
|  | ||||
|         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) | ||||
|         with self.assertRaisesRegex( | ||||
|             torch._dynamo.exc.Unsupported, | ||||
|             "Encountered aliasing during higher order op tracing", | ||||
|         ): | ||||
|             opt_fn(x) | ||||
|  | ||||
|     def test_mod_attr_aliasing(self): | ||||
|         class MutateParam(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.a = torch.ones(8) | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 self.a.add_(1) | ||||
|                 return torch.mul(x, self.a) | ||||
|  | ||||
|         @mark_compile_region | ||||
|         def gn(x): | ||||
|             return mod(x) | ||||
|  | ||||
|         def fn(x, y): | ||||
|             return gn(x) * y | ||||
|  | ||||
|         mod = MutateParam() | ||||
|         x = torch.randn(8, requires_grad=False) | ||||
|         y = torch.randn(8, requires_grad=False) | ||||
|  | ||||
|         fn(x, y) | ||||
|  | ||||
|         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) | ||||
|         with self.assertRaisesRegex( | ||||
|             torch._dynamo.exc.Unsupported, | ||||
|             "Encountered input mutation during higher order op tracing", | ||||
|         ): | ||||
|             opt_fn(x, y) | ||||
|  | ||||
|  | ||||
| @ -63,6 +63,7 @@ from torch.fx.experimental.symbolic_shapes import ( | ||||
|     ShapeEnv, | ||||
| ) | ||||
| from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts | ||||
| from torch.multiprocessing.reductions import StorageWeakRef | ||||
| from torch.utils._ordered_set import OrderedSet | ||||
| from torch.utils._python_dispatch import is_traceable_wrapper_subclass | ||||
|  | ||||
| @ -165,6 +166,18 @@ class VariableTrackerCacheKey: | ||||
|     source: Source | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class AliasingInfo: | ||||
|     has_aliasing: bool | ||||
|     msg: str | ||||
|  | ||||
|  | ||||
| @dataclass(frozen=True) | ||||
| class MutationInfo: | ||||
|     has_mutation: bool | ||||
|     msg: str | ||||
|  | ||||
|  | ||||
| class VariableTrackerCache: | ||||
|     def __init__(self): | ||||
|         self.cache = {} | ||||
| @ -2023,6 +2036,13 @@ class SubgraphTracer(fx.Tracer): | ||||
|  | ||||
|         # This is used to create a unique name for the placeholder | ||||
|         self._used_names: OrderedSet[str] = OrderedSet() | ||||
|         # Stores the versions of the input tensors at the time they are inserted | ||||
|         # as placeholders in the graph. This is used to track input mutation. | ||||
|         self._input_versions_at_beginning: list[int] = [] | ||||
|         if torch.is_inference_mode_enabled(): | ||||
|             raise RuntimeError( | ||||
|                 "Inference mode is supposed to be disabled during compilation. Please open an issue." | ||||
|             ) | ||||
|  | ||||
|     # preserve original meta if it is available | ||||
|     def _maybe_preserve_original_meta(self, tx, node): | ||||
| @ -2273,6 +2293,8 @@ class SubgraphTracer(fx.Tracer): | ||||
|     def create_graph_input( | ||||
|         self, name, type_expr, example_value, before=False, source=None | ||||
|     ): | ||||
|         if isinstance(example_value, torch.Tensor): | ||||
|             self._input_versions_at_beginning.append(example_value._version) | ||||
|         log.debug( | ||||
|             "create_graph_input %s %s %s at debug_level %s before=%s", | ||||
|             name, | ||||
| @ -2690,6 +2712,77 @@ class SubgraphTracer(fx.Tracer): | ||||
|         # Sort the symbols so that we can have a deterministic lifting order | ||||
|         return sorted(to_be_bound, key=lambda s: s.name) | ||||
|  | ||||
|     def has_input_mutation(self): | ||||
|         input_versions_at_beginning = self._input_versions_at_beginning | ||||
|         input_nodes = [] | ||||
|  | ||||
|         input_versions_at_end = [] | ||||
|         for node in self.graph.nodes: | ||||
|             if node.op == "placeholder": | ||||
|                 example_value = node.meta["example_value"] | ||||
|                 if isinstance(example_value, torch.Tensor): | ||||
|                     input_versions_at_end.append(example_value._version) | ||||
|                     input_nodes.append(node) | ||||
|             else: | ||||
|                 break | ||||
|  | ||||
|         mutated_inputs = [ | ||||
|             i | ||||
|             for i, (v1, v2) in enumerate( | ||||
|                 zip(input_versions_at_beginning, input_versions_at_end) | ||||
|             ) | ||||
|             if v1 != v2 | ||||
|         ] | ||||
|  | ||||
|         if len(mutated_inputs): | ||||
|             mutated_nodes = [input_nodes[i] for i in mutated_inputs] | ||||
|             msg = f"Input mutation detected at {mutated_nodes}" | ||||
|             return MutationInfo(True, msg) | ||||
|  | ||||
|         return MutationInfo(False, "") | ||||
|  | ||||
|     def has_aliasing(self): | ||||
|         input_storages: dict[StorageWeakRef, torch.fx.Node] = dict() | ||||
|  | ||||
|         for node in self.graph.nodes: | ||||
|             if node.op == "placeholder": | ||||
|                 example_value = node.meta["example_value"] | ||||
|                 if isinstance(example_value, torch.Tensor): | ||||
|                     storage = StorageWeakRef(example_value._typed_storage()) | ||||
|                     if storage in input_storages: | ||||
|                         # input-input aliasing | ||||
|                         msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}" | ||||
|                         return AliasingInfo(True, msg) | ||||
|                     input_storages[storage] = node | ||||
|             else: | ||||
|                 break | ||||
|  | ||||
|         output_storages: dict[StorageWeakRef, torch.fx.Node] = dict() | ||||
|         out_nodes = self.graph.find_nodes(op="output")[0] | ||||
|         for out_node in out_nodes.args[0]: | ||||
|             if out_node: | ||||
|                 example_value = out_node.meta["example_value"] | ||||
|                 assert not isinstance(example_value, list) | ||||
|                 if isinstance(example_value, torch.Tensor): | ||||
|                     storage = StorageWeakRef(example_value._typed_storage()) | ||||
|                     if storage in output_storages: | ||||
|                         # output-output aliasing | ||||
|                         msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}" | ||||
|                         return AliasingInfo(True, msg) | ||||
|                     output_storages[storage] = out_node | ||||
|  | ||||
|         intersected_storages = input_storages.keys() & output_storages.keys() | ||||
|         if len(intersected_storages) > 0: | ||||
|             # input-output aliasing | ||||
|             aliased = [ | ||||
|                 (input_storages[s], output_storages[s]) for s in intersected_storages | ||||
|             ] | ||||
|             aliased = ", ".join([f"{i} and {o}" for i, o in aliased]) | ||||
|             msg = f"Input-to-output aliasing detected at nodes {aliased}" | ||||
|             return AliasingInfo(True, msg) | ||||
|  | ||||
|         return AliasingInfo(False, "") | ||||
|  | ||||
|  | ||||
| # NOTE: [HigherOrderOperator tracing design] | ||||
| # Ignoring HigherOrderOperators for a moment, | ||||
|  | ||||
| @ -50,6 +50,7 @@ from ..exc import ( | ||||
|     IncorrectUsage, | ||||
|     UncapturedHigherOrderOpError, | ||||
|     unimplemented, | ||||
|     unimplemented_v2, | ||||
|     Unsupported, | ||||
| ) | ||||
| from ..source import AttrSource, DictGetItemSource | ||||
| @ -506,6 +507,9 @@ def speculate_subgraph( | ||||
|     restore_side_effects=True, | ||||
|     should_flatten_outputs=False, | ||||
|     under_activation_checkpoint=False, | ||||
|     # TODO - supports_input_mutation and supports_aliasing should be False by default for strictness | ||||
|     supports_input_mutation=True, | ||||
|     supports_aliasing=True, | ||||
|     # Pass in an originating tracer - this is needed for preserving context | ||||
|     # across fwd-bwd for autograd.Function | ||||
|     tracer=None, | ||||
| @ -694,6 +698,34 @@ def speculate_subgraph( | ||||
|                 if len(lifted_freevars) > 0: | ||||
|                     move_lifted_freevars_phs_to_end(graph, lifted_freevars) | ||||
|  | ||||
|                 if not supports_input_mutation: | ||||
|                     mutation_info = subtracer.has_input_mutation() | ||||
|                     if mutation_info.has_mutation: | ||||
|                         context = f"{mutation_info.msg} in\n {graph}" | ||||
|                         unimplemented_v2( | ||||
|                             gb_type=f"Encountered input mutation during higher order op tracing for HOP - {source_target.name()}", | ||||
|                             context=context, | ||||
|                             explanation="Higher order ops do not support input mutation", | ||||
|                             hints=[ | ||||
|                                 "Consider using the debug context to change user code to avoid mutation.", | ||||
|                                 "Please open an issue.", | ||||
|                             ], | ||||
|                         ) | ||||
|  | ||||
|                 if not supports_aliasing: | ||||
|                     aliasing_info = subtracer.has_aliasing() | ||||
|                     if aliasing_info.has_aliasing: | ||||
|                         context = f"{aliasing_info.msg} in\n {graph}" | ||||
|                         unimplemented_v2( | ||||
|                             gb_type=f"Encountered aliasing during higher order op tracing for HOP - {source_target.name()}", | ||||
|                             context=context, | ||||
|                             explanation="Higher order ops do not support aliasing", | ||||
|                             hints=[ | ||||
|                                 "Consider using the debug context to change user code to avoid aliasing.", | ||||
|                                 "Please open an issue.", | ||||
|                             ], | ||||
|                         ) | ||||
|  | ||||
|                 return ( | ||||
|                     (output, treespec), | ||||
|                     graph, | ||||
| @ -1794,6 +1826,11 @@ class FunctionalCallVariable(FunctorchHigherOrderVariable): | ||||
|  | ||||
|  | ||||
| class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.supports_input_mutation = True | ||||
|         self.supports_aliasing = True | ||||
|  | ||||
|     def install_subgraph_in_output_graph( | ||||
|         self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" | ||||
|     ): | ||||
| @ -1828,6 +1865,8 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): | ||||
|             source_target=self.value, | ||||
|             should_flatten_outputs=True, | ||||
|             under_activation_checkpoint=under_activation_checkpoint, | ||||
|             supports_input_mutation=self.supports_input_mutation, | ||||
|             supports_aliasing=self.supports_aliasing, | ||||
|         ) | ||||
|  | ||||
|         body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) | ||||
| @ -3039,6 +3078,11 @@ def hash_graph_and_inputs(tx, gmod, fake_inputs): | ||||
|  | ||||
|  | ||||
| class BaseHOPVariable(WrapHigherOrderVariable): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.supports_input_mutation = False | ||||
|         self.supports_aliasing = False | ||||
|  | ||||
|     def python_type(self): | ||||
|         return type(self.value) | ||||
|  | ||||
| @ -3061,20 +3105,6 @@ class BaseHOPVariable(WrapHigherOrderVariable): | ||||
|         ) | ||||
|         assert len(p_kwargs) == 0 | ||||
|  | ||||
|         from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation | ||||
|  | ||||
|         fake_inputs = [ | ||||
|             node.meta["example_value"] | ||||
|             for node in body_gmod.graph.nodes | ||||
|             if node.op == "placeholder" | ||||
|         ] | ||||
|  | ||||
|         if has_potential_input_alias_or_mutation(body_gmod, fake_inputs): | ||||
|             raise RuntimeError( | ||||
|                 f"{self.value._name} where the inputs are mutated or the " | ||||
|                 f"outputs are aliases of the inputs. Please ensure that this doesn't happen." | ||||
|             ) | ||||
|  | ||||
|         flat_example_value = pytree.tree_map_only( | ||||
|             torch.fx.Proxy, | ||||
|             lambda a: a.node.meta["example_value"], | ||||
| @ -3087,6 +3117,11 @@ class BaseHOPVariable(WrapHigherOrderVariable): | ||||
|  | ||||
|  | ||||
| class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.supports_input_mutation = False | ||||
|         self.supports_aliasing = False | ||||
|  | ||||
|     def install_subgraph_in_output_graph( | ||||
|         self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name | ||||
|     ): | ||||
| @ -3094,19 +3129,12 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): | ||||
|         # inputs have already been seen before. If yes, the subgraph is already | ||||
|         # installed in the output graph and we can just access the subgraph | ||||
|         # using the saved attr name. | ||||
|         from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation | ||||
|  | ||||
|         fake_inputs = [ | ||||
|             node.meta["example_value"] | ||||
|             for node in body_gmod.graph.nodes | ||||
|             if node.op == "placeholder" | ||||
|         ] | ||||
|  | ||||
|         # TODO(anijain2305) - This might be too big of a limitation. Consider | ||||
|         # supporting mutation/aliasing in HOP itself to remove this restriction. | ||||
|         if has_potential_input_alias_or_mutation(body_gmod, fake_inputs): | ||||
|             unimplemented("NYI: invoke_subgraph with aliasing/mutation") | ||||
|  | ||||
|         key = hash_graph_and_inputs(tx, body_gmod, fake_inputs) | ||||
|  | ||||
|         invoke_subgraph_cache = ( | ||||
|  | ||||
		Reference in New Issue
	
	Block a user