diff --git a/docs/source/torch.compiler.config.rst b/docs/source/torch.compiler.config.md
similarity index 61%
rename from docs/source/torch.compiler.config.rst
rename to docs/source/torch.compiler.config.md
index c40b41fdb5d3..66059f07ea5b 100644
--- a/docs/source/torch.compiler.config.rst
+++ b/docs/source/torch.compiler.config.md
@@ -1,9 +1,14 @@
+```{eval-rst}
 .. currentmodule:: torch.compiler.config
+```
 
-torch.compiler.config
-=====================
+# torch.compiler.config
 
+```{eval-rst}
 .. automodule:: torch.compiler.config
+```
 
+```{eval-rst}
 .. autodata:: torch.compiler.config.job_id
+```
diff --git a/docs/source/torch.compiler_transformations.md b/docs/source/torch.compiler_transformations.md
new file mode 100644
index 000000000000..7291df298f37
--- /dev/null
+++ b/docs/source/torch.compiler_transformations.md
@@ -0,0 +1,424 @@
# Writing Graph Transformations on ATen IR

## Passes

Since the ATen IR sits at the FX Graph/GraphModule level, any
transformations written for FX Graphs can be easily applied to the
ATen IR. If you’re familiar with writing FX graph transformations,
this will feel the same.

The most direct way of writing transformations is by looping through the
given graph and directly manipulating the nodes within it.

For example, let’s say we want to replace
`torch.ops.aten.add.Tensor()` calls with
`torch.ops.aten.mul.Tensor()` calls:

```python
import torch

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
    # Regenerate the GraphModule's code after mutating the graph
    gm.recompile()
    return gm
```

We can also delete and append new nodes through FX utility functions
that can be found in the
[Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph)
documentation. For example, if we want to insert a
`torch.ops.aten.relu.default()` after the `add` call:

```python
import torch

def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:

            # Specifies the insertion point. Any nodes added to the graph within
            # this scope will be inserted after `node`
            with gm.graph.inserting_after(node):
                # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                # Replace all the places that use `node` to now use `new_relu_node`,
                # skipping `new_relu_node` itself so that it keeps `node` as its input
                node.replace_all_uses_with(
                    new_relu_node,
                    delete_user_cb=lambda user: user is not new_relu_node,
                )
    gm.recompile()
    return gm
```

In general, transformations can be roughly categorized along a few axes:

Axis A:
1. Creating one-to-X mappings (e.g. decomposition)
2. Creating many-to-one mappings (e.g. fusion)

Axis B:
1. Doing forwards iteration (e.g. shape propagation)
2. Doing backwards iteration (e.g. dead code elimination)

Axis C:
1. Dependent on local node information (e.g. out-variant conversion)
2. Dependent on global graph information (e.g. memory planning)

Our projection of how frequently these use cases occur, from most to
least common:

1. A.1, B.1, C.1
2. A.2
3. B.2, C.2

Although we can write any graph transformation by directly manipulating
the graph, we also provide helper utilities that make the level 1 and
level 2 use cases easier to write.
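To make the direct-manipulation workflow concrete before moving on to those
utilities, here is a minimal end-to-end sketch. It assumes `make_fx` from
`torch.fx.experimental.proxy_tensor` as one convenient way to obtain an
ATen-level graph (any ATen-level `GraphModule`, e.g. one produced by
`torch.export`, works the same way), and the function `f` is purely
illustrative:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return x + y

# make_fx traces f down to ATen operators, so the graph contains
# torch.ops.aten.add.Tensor rather than the Python-level operator.add
gm = make_fx(f)(torch.randn(3), torch.randn(3))

gm = replace_add_with_mul(gm)
gm.graph.lint()  # sanity-check the mutated graph's invariants
print(gm(torch.tensor([2.0]), torch.tensor([4.0])))  # tensor([8.]): now a multiply
```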
### Transformer

For level 1 use cases (creating one-to-X mappings, doing forwards
iterations, and looking at local node information), we can utilize the
[Transformer](https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer)
class to execute each node and recreate a graph, except with the
transformations specified.

#### One-to-One Pass

For a one-to-one mapping, if we want to replace op A with another op B,
we can run the GraphModule and, every time we see op A, return op B
instead.

An example is:

```python
class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
```

The `super().call_function(target, args, kwargs)` call creates a
`call_function` FX node and returns the result of running the
operator with the given arguments.

#### One-to-X Pass

If we want to do a one-to-X mapping, like replacing op A with two other
ops B and C, we make two calls to `super().call_function` to
create two FX nodes, one with op B and another with op C, and return the
result of running op C.

For example:

```python
class ReplaceAddWithMulSub(torch.fx.Transformer):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)

        x, y = args

        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
```

#### One-to-None Pass

If we want to remove an op, we can just return the value passed into
the function:

```python
class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs)

        assert len(args) == 1
        return args[0]

transformed_graph_module = RemoveDetachPass(graph_module).transform()
```
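As a quick sanity check of the pass above, the following sketch (again
assuming `make_fx` as one way to obtain an ATen-level graph; the traced
function is illustrative) confirms that the detach node disappears:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # x.detach() is recorded as torch.ops.aten.detach.default in the ATen graph
    return x.detach() + 1

gm = make_fx(f)(torch.randn(3))
transformed = RemoveDetachPass(gm).transform()

# No detach nodes remain; the placeholder now feeds the add directly
assert all(
    n.target != torch.ops.aten.detach.default for n in transformed.graph.nodes
)
```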
#### Utilizing Local Information

An example of utilizing local node information: if we want to convert
all the scalars within the graph to tensors, we can run the given
`fx.GraphModule` and, for every argument that contains a scalar,
convert it to a tensor. It might look something like:

```python
def args_map(target, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(target._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)
    return tuple(args), kwargs

class ScalarToTensorPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(target, try_coerce, args, kwargs)
        return super().call_function(target, args, kwargs)

transformed_graph_module = ScalarToTensorPass(graph_module).transform()
```

### Subgraph Rewriter

For creating many-to-one mappings, we can utilize FX’s [subgraph
rewriter](https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py).
Given a `pattern`, it finds the subgraphs of operators matching the
pattern and replaces each matched subgraph with the `replacement`.

Note:

```
This is an in-place operation.
```

The `pattern` and `replacement` inputs must be callable functions or
GraphModules containing the same operators that are used within the
graph (ATen ops) so that the subgraph rewriter can find the correct
pattern in the graph. Inputs to the pattern/replacement callables will
be treated as wildcards when matching.

An example:

```python
from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

    return subgraph_rewriter.replace_pattern_with_filters(
        graph_module, pattern, replacement
    )
```

The subgraph rewriter returns a list of `ReplacedPatterns`:

```python
@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]
```

Note:

```
The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.
```
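Putting the pieces together, here is a sketch of running
`replace_patterns` from above on a freshly traced graph and inspecting what
was rewritten. As before, `make_fx` is just one convenient way to get an
ATen-level graph, and the traced function is illustrative; `(x + y) * y` is
written so that it contains the add-then-mul pattern:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return (x + y) * y  # contains the add -> mul pattern

gm = make_fx(f)(torch.randn(3), torch.randn(3))

# Rewrites each matched add -> mul chain into a single sub, in place
matches = replace_patterns(gm)
gm.recompile()

for m in matches:
    print(m.anchor, m.replacements)
```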
## Pass Manager

The
[PassManager](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py)
is a class used to run multiple passes on a given graph module. When
initializing a `PassManager` instance, we pass in a list of passes
that we want to run and set a couple of flags. To run the collection of
passes on a graph module, we call the `PassManager` instance directly
with the graph module.

An example:

```python
from torch.fx.passes.infra.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
```

To add a common set of checks that are run after each pass, we can call
the function `add_checks(check: Callable)`, which takes a callable
function as input. If the `run_checks_after_each_pass` flag is set,
the `check` will be called after each pass is run on the graph module.

An example:

```python
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)  # raises ValueError after replace_div_with_mul pass
```
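Pass functions used with this infra-level `PassManager` conventionally
return a `PassResult` (from `torch.fx.passes.infra.pass_base`) carrying the
resulting graph module and whether anything changed. As a sketch under that
convention, here is what `replace_add_with_div` from the examples above
might look like; the pass body itself is an assumption, mirroring the
earlier add-to-mul pass:

```python
import torch
from torch.fx.passes.infra.pass_base import PassResult

def replace_add_with_div(graph_module: torch.fx.GraphModule) -> PassResult:
    modified = False
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.div.Tensor
            modified = True
    if modified:
        graph_module.recompile()
    # PassResult tells the PassManager which graph module to hand to the
    # next pass and whether this pass modified it
    return PassResult(graph_module, modified)
```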
## Partitioner

There are a couple of common FX-graph-based partitioners we can use to
partition the graph.

### Subgraph Matcher

For finding subgraphs within a graph that match a specific pattern, we
can utilize FX’s
[`SubgraphMatcher`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py).

Class Attributes:

- `pattern (Graph)`: The targeted matching pattern. Placeholder nodes
  in the graph will be treated as wildcards when matching.
- `match_output (bool)`: If True, the output node in the pattern graph
  will be treated as a part of the targeted pattern. If False, the output
  node is ignored during the match.
- `match_placeholder (bool)`: If True, placeholder nodes in the
  pattern graph will be treated as a part of the targeted pattern. If
  False, placeholder nodes will be used as wildcards.
- `remove_overlapping_matches (bool)`: If True, in the case of
  overlapping matches, only the first match will be returned.
- `ignore_literals (bool)`: If True, will not check if literals are
  equal and will instead treat them as wildcards.

An example:

```python
import torch
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

inputs = (torch.ones(3, 3),)  # example inputs for export
large_model_graph = torch.export.export(LargeModel(), inputs).graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_inputs = (torch.ones(5, 5),)  # example inputs for export
pattern_graph = torch.export.export(PatternModel(), pattern_inputs).graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
```

The `match` function returns a list of `InternalMatch`:

```python
@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)
```

### Capability Based Partitioner

To find the largest subgraphs of nodes that support a specific
invariant, we can utilize FX’s
[`CapabilityBasedPartitioner`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34).

Class Attributes:

- `graph_module (torch.fx.GraphModule)`: The graph module we are
  partitioning.
- `operator_support (OperatorSupportBase)`: The object used to
  determine if a node in the graph is supported in the partition.
- `allows_single_node_partition (bool)`: If True, allows single-node
  partitions to be formed.
- `non_compute_ops (Optional[Sequence[str]])`: A set of ops that are
  considered to be “non-compute” (for example, `torch.ops.aten.view` and
  `_operator.getitem`), so that the partitioner will not create graphs
  that only contain these non-compute ops.
- `allowed_single_node_partition_ops (Optional[Sequence[str]])`: A
  set of ops that are allowed to be in a single-node partition.

The
[`OperatorSupportBase`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1)
class is used by the partitioner to determine if a specific node in the
graph belongs in the partition. This is done by overriding the
`is_node_supported` function. You can combine multiple
`OperatorSupportBase` instances by using
[`chain`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150)
(which returns False if any of the `OperatorSupportBase` instances returns False) and
[`any_chain`](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164)
(which returns True if any of the `OperatorSupportBase` instances returns True),
as shown in the sketch below.
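For instance, a sketch of combining two single-op support classes with
`any_chain` (the class names here are illustrative, not part of the API):

```python
import torch
from torch.fx.passes.operator_support import OperatorSupportBase, any_chain

class AddSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target == torch.ops.aten.add.Tensor

class MulSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor

# Supports a node if either class supports it; behaves like the combined
# AddMulOperatorSupport in the example below
add_or_mul_support = any_chain(AddSupport(), MulSupport())
```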
An example of using the partitioner:

```python
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

op_support = AddMulOperatorSupport()

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)
```
diff --git a/docs/source/torch.compiler_transformations.rst b/docs/source/torch.compiler_transformations.rst
deleted file mode 100644
index f83abd9412d4..000000000000
--- a/docs/source/torch.compiler_transformations.rst
+++ /dev/null