Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Inductor-FX] Support torch.cond (#163234)
# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually identical to what would come from `torch.export`:

- Submodules are stored as attributes and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes the predicate, the true/false subgraphs, and the subgraph inputs.

# Implementation overview

The FX backend generates code for subgraphs using the following steps:

1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by `torch.export`.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond` node.

# Implementation details

This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs. Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines, including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning. In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new Wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and its associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `codegen_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, not for subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.
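To make the "conceptually identical to `torch.export`" claim above concrete, the sketch below (an editor's illustration, not part of this PR) exports a tiny `torch.cond` model and prints its graph. The structure mirrors what the FX converter now emits, and previews the example graph shown in the test plan.

```
# Illustrative only: inspect the export-style IR that the FX converter now mirrors.
import torch


class M(torch.nn.Module):
    def forward(self, pred, x):
        # Branches become submodules (true_graph_0 / false_graph_0) in the exported IR.
        return torch.cond(pred, torch.cos, torch.sin, [x])


pred = torch.tensor(True)
x = torch.randn(8)
ep = torch.export.export(M(), (pred, x))

# Expect placeholder nodes, get_attr nodes for the two branch submodules, and a
# call_function on torch.ops.higher_order.cond, much like the graph shown below.
print(ep.graph_module.graph)
```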
# Test plan

This PR adds a couple of tests using `torch.cond`. Here's an example graph generated by one of them:

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes the existing negative test (`test_subgraph_raises`), which checked that an error was raised when subgraphs were encountered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163234
Approved by: https://github.com/angelayi, https://github.com/jansel
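For readers who want to try the new path, here is an illustrative, hedged sketch modeled on the new tests (not taken verbatim from the PR): it registers the FX wrapper codegen for a Triton-capable GPU device and compiles a `torch.cond` model through Inductor. The module paths and the use of `torch.compile` in place of the test helper `_compile_and_check` are assumptions.

```
import torch
from torch._inductor import config
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

# Route Inductor's wrapper codegen for CUDA through the FX backend, as the tests do.
register_backend_for_device("cuda", TritonScheduling, WrapperFxCodegen)


def foo(pred, x):
    # Same model shape as the new test_cond_subgraph test in the diff below.
    return torch.cond(pred, torch.cos, torch.sin, [x]) + 1


pred = torch.tensor([True], device="cuda")
x = torch.randn(2, 3, device="cuda")

# The tests disable a few debug asserts and force single-threaded compilation.
with config.patch(compile_threads=1, size_asserts=False, nan_asserts=False):
    out = torch.compile(foo)(pred, x)

torch.testing.assert_close(out, foo(pred, x))
```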
Committed by: PyTorch MergeBot
Parent: a31acf32bd
Commit: e56dd5d770
@@ -17,7 +17,6 @@ from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.utils import same
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._inductor import config
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
@@ -43,15 +42,17 @@ if HAS_GPU:

    from torch.testing._internal.triton_utils import add_kernel_2d_autotuned

test_config = {
    "compile_threads": 1,
    "alignment_asserts": False,
    "size_asserts": False,
    "scalar_asserts": False,
    "nan_asserts": False,
}


@requires_gpu()
@config.patch(
    compile_threads=1,
    alignment_asserts=False,
    size_asserts=False,
    scalar_asserts=False,
    nan_asserts=False,
)
@config.patch(test_config)
@instantiate_parametrized_tests
class FxirTestCase(InductorTestCase):
    device = GPU_TYPE
@@ -116,8 +117,19 @@ class FxirTestCase(InductorTestCase):
    def setUpClass(cls):
        super().setUpClass()

        # Register the FX backend.
        register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen)
        # Register the FX backend, storing the default for later.
        common.init_backend_registration()
        cls._default_backend = common.device_codegens[cls.device]
        common.register_backend_for_device(
            cls.device, TritonScheduling, WrapperFxCodegen
        )

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Restore the default backend.
        common.device_codegens[cls.device] = cls._default_backend

    def test_basic(self):
        args = [torch.randn(8, device=self.device) for _ in range(2)]
@@ -630,21 +642,47 @@ class FxirTestCase(InductorTestCase):
        # Check for the fallback.
        self.assertEqual(self._count_ops(gm, fallback_op), 1)

    @torch._inductor.config.patch("graph_partition", True)
    def test_subgraph_raises(self):
    @parametrize("pred", (False, True))
    def test_cond_subgraph(self, pred: bool):
        """
        Test a model with subgraphs. This is not yet supported, so check that we get the
        expected exception.
        Test a model with subgraphs.
        """

        def foo(cond, x):
            return torch.cond(cond, torch.cos, torch.sin, [x])
        def foo(pred, x):
            return torch.cond(pred, torch.cos, torch.sin, [x]) + 1

        cond = torch.tensor([True], device=self.device)
        x = torch.ones([2, 3], device=self.device)
        x = torch.randn((2, 3), device=self.device)
        pred_tensor = torch.tensor([pred], device=self.device)
        gm = self._compile_and_check(
            foo, [pred_tensor, x], expected_num_triton_kernels=3
        )[-1]

        with self.assertRaisesRegex(BackendCompilerFailed, "Subgraph"):
            self._compile_and_check(foo, [cond, x])
        # Check for subgraphs.
        subgm_getattrs = list(gm.graph.find_nodes(op="get_attr"))
        self.assertEqual(len(subgm_getattrs), 2)
        for subgm_getattr in subgm_getattrs:
            target = subgm_getattr.name
            self.assertTrue(isinstance(getattr(gm, target), torch.fx.GraphModule))

    @parametrize("pred", (False, True))
    def test_cond_no_operands(self, pred: bool):
        """
        Test torch.cond when the subgraphs take no inputs.
        """

        length = 8

        def true_fn():
            return torch.zeros(length, device=self.device)

        def false_fn():
            return true_fn() + 5

        def foo(pred):
            return torch.cond(pred, true_fn, false_fn, ())

        pred_tensor = torch.tensor([pred], device=self.device)
        self._compile_and_check(foo, [pred_tensor], expected_num_triton_kernels=2)

    def test_cpp_raises(self):
        """
@@ -759,9 +797,9 @@ class AOTFxirTestCase(InductorTestCase):
            model, inp, dynamic_shapes=dynamic_shapes, strict=strict
        )
        gm = torch._inductor.aot_compile(
            ep.module(), inp, options={"fx_wrapper": True, "compile_threads": 1}
            ep.module(), inp, options={"fx_wrapper": True, **test_config}
        )
        self.assertTrue(torch.allclose(model(*inp), gm(*inp)))
        self.assertTrue(same(model(*inp), gm(*inp)))

        for node in gm.graph.nodes:
            if (
@@ -919,6 +957,39 @@ class AOTFxirTestCase(InductorTestCase):
            1,
        )

    @parametrize("pred", (False, True))
    def test_cond_multi_inputs_and_outputs(self, pred):
        """
        Test torch.cond and check the output graphs.
        """

        class M(torch.nn.Module):
            def forward(self, pred, x, y):
                def true_fn(x, y):
                    return torch.tanh(x), torch.relu(y)

                def false_fn(x, y):
                    return tuple(t / 2 for t in true_fn(x, y))

                return torch.cond(pred, true_fn, false_fn, (x, y))

        pred = torch.tensor([True], device=self.device)
        (x, y) = [torch.randn(8, device=self.device) for _ in range(2)]
        gm = self.check(M(), (pred, x, y))

        # Check the graph.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, (arg1_1, arg2_1)); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = None
    buf1 = cond[0]
    buf2 = cond[1]; cond = None
    return [buf1, buf2]""", # noqa: B950
        )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
@@ -390,6 +390,19 @@ class EnterSubgraphLine(WrapperLine):
        return converter._generate_enter_subgraph


@dataclasses.dataclass
class ConditionalLine(WrapperLine):
    wrapper: PythonWrapperCodegen
    node: ir.Conditional

    def codegen(self, code: IndentedBuffer) -> None:
        raise NotImplementedError("Only supports FX codegen")

    @staticmethod
    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
        return converter._generate_conditional


@dataclasses.dataclass
class CommentLine(WrapperLine):
    line: LineContext
@@ -2169,8 +2182,8 @@ class PythonWrapperCodegen(CodeGen):
        )
        self.header.splice(body)

    def define_subgraph_launcher_fn(self, fn_code: str):
        self.subgraph_definitions.splice(fn_code)
    def define_subgraph_launcher_fn(self, name: str, subgraph_code):
        self.subgraph_definitions.splice(subgraph_code.value)

    def define_user_defined_triton_kernel(
        self,
@@ -3362,11 +3375,12 @@ class PythonWrapperCodegen(CodeGen):

    def codegen_subgraph_common(self, subgraph):
        self.push_codegened_graph(subgraph.graph)
        self.writeline("")
        self.writeline(f"{self.comment} subgraph: {subgraph.name}")
        self.make_comment("")
        self.make_comment(f"{self.comment} subgraph: {subgraph.name}")

        parent_graph = V.graph
        subgraph.graph.cpp_wrapper = parent_graph.cpp_wrapper
        subgraph.graph.fx_wrapper = parent_graph.fx_wrapper

        if subgraph.graph.name not in self.already_codegened_subgraphs:
            # If it is already codegened, the parent wrapper already has
@@ -3376,8 +3390,9 @@ class PythonWrapperCodegen(CodeGen):
            with config.patch("graph_partition", False):
                # Call the codegen of subgraph recursively
                subgraph_code, _ = subgraph.graph.codegen()
            self.already_codegened_subgraphs.add(subgraph.graph.name)
            self.define_subgraph_launcher_fn(subgraph_code.value)
            subgraph_name = subgraph.graph.name
            self.already_codegened_subgraphs.add(subgraph_name)
            self.define_subgraph_launcher_fn(subgraph_name, subgraph_code)

    def codegen_subgraph_with_flattened_outputs(
        self, subgraph, outer_inputs, outer_flattened_outputs
@@ -3409,7 +3424,7 @@ class PythonWrapperCodegen(CodeGen):
        else:
            self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name)

    def codegen_conditional(self, conditional):
    def codegen_conditional(self, conditional) -> None:
        name = conditional.get_name()

        outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
@@ -3644,7 +3659,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):

    def get_graph_inputs(
        self,
    ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]:
    ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]:
        if signature := self.partition_signatures:
            inputs = signature.input_nodes | {
                str(s): s for s in signature.symbol_inputs
@@ -35,7 +35,7 @@ from torch.utils._sympy.solve import try_solve

from .. import config, ir
from ..runtime.triton_compat import Config
from ..utils import LineContext
from ..utils import cache_property_on_self, LineContext, ValueWithLineMap
from .common import (
    CodegenSymbol,
    FileBackedGraphModule,
@@ -48,6 +48,7 @@ from .wrapper import (
    CommBufferAllocateLine,
    CommBufferFreeLine,
    CommentLine,
    ConditionalLine,
    EnterDeviceContextManagerLine,
    EnterSubgraphLine,
    ExitDeviceContextManagerLine,
@@ -66,6 +67,7 @@ from .wrapper import (
    ReinterpretLine,
    ReuseLine,
    ScatterFallbackLine,
    SubgraphPythonWrapperCodegen,
    SymbolicCallArg,
    SymbolicCallArgLine,
    WrapperLine,
@@ -147,12 +149,54 @@ class WrapperFxCodegen(PythonWrapperCodegen):

    supports_caching = False

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.subgms: dict[str, torch.fx.GraphModule] = {}

    def codegen_inputs(self) -> None:
        """
        This would generate code for symbolic input shapes, strides, etc.
        Since the FX converter handles this, do nothing here.
        """

    def codegen_conditional(self, conditional: ir.Conditional) -> None:
        """
        Conditional codegen normally emits a number of different wrapper lines.
        Instead, FX conversion uses a dedicated line for the whole conditional.
        """
        self.writeline(ConditionalLine(self, conditional))
        for subgraph in (conditional.true_subgraph, conditional.false_subgraph):
            self.codegen_subgraph_common(subgraph)

    def define_subgraph_launcher_fn(
        self, name: str, subgraph_code: Union[ValueWithLineMap, FileBackedGraphModule]
    ) -> None:
        """
        Record subgms as they're generated.
        """
        assert isinstance(subgraph_code, FileBackedGraphModule)
        self.subgms[name] = subgraph_code.gm

    @property
    @cache_property_on_self
    def is_subgraph(self) -> bool:
        return isinstance(self, SubgraphPythonWrapperCodegen)

    def get_fx_graph_inputs(
        self,
    ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]:
        """
        Get the input nodes corresponding to FX graph placeholders.
        """
        if V.aot_compilation and not self.is_subgraph:
            # AOT graphs must match the signature of the input module.
            return {
                node.name: V.graph.graph_inputs.get(node.name)
                for node in V.graph.module.graph.find_nodes(op="placeholder") # type: ignore[operator, union-attr]
            }

        return self.get_graph_inputs()

    def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
        self.run_wrapper_ir_passes(is_inference)

@@ -162,7 +206,15 @@ class WrapperFxCodegen(PythonWrapperCodegen):
                self.header.getvalue(),
            ]
        )
        gm = FxConverter(lines=self.lines, prologue=prologue).generate()
        gm = FxConverter(
            lines=self.lines,
            prologue=prologue,
            graph_inputs=self.get_fx_graph_inputs(),
            graph_outputs=self.get_graph_outputs(),
            subgms=self.subgms,
            is_subgraph=self.is_subgraph,
        ).generate()

        compiled_fn = self.compile_graph(gm)

        return FileBackedGraphModule(gm, compiled_fn), None
@@ -175,20 +227,43 @@ class WrapperFxCodegen(PythonWrapperCodegen):
        """
        return gm.forward

    def write_header(self) -> None:
        """
        Python subgraphs normally lack headers.
        Override this behavior to generate prologues for FX subgraphs.
        """
        PythonWrapperCodegen.write_header(self)

    @classmethod
    def create(
        cls,
        cls: type["WrapperFxCodegen"],
        is_subgraph: bool,
        subgraph_name: Optional[str],
        parent_wrapper: Optional[PythonWrapperCodegen],
        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
    ) -> "WrapperFxCodegen":
        if is_subgraph:
            raise NotImplementedError(
                "Subgraphs are not yet supported by FX conversion"
            assert subgraph_name is not None
            assert parent_wrapper is not None

            # Subgraphs override some methods of PythonWrapperCodegen.
            # Apply these overrides to the user-provided class, with priority given to
            # user-provided methods.
            class SubgraphFxWrapperCodegen(cls, SubgraphPythonWrapperCodegen): # type: ignore[misc,valid-type]
                def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
                    """
                    Skip graph compilation for subgraphs.
                    """

                    def crash_if_run(*args: Any) -> None:
                        raise NotImplementedError("Cannot run a subgraph in isolation!")

                    return crash_if_run

            return SubgraphFxWrapperCodegen(
                subgraph_name, parent_wrapper, partition_signatures
            )

        # For derived backends, this could be a subclass.
        return cls()


@@ -200,7 +275,11 @@ class FxConverter:
    """

    lines: list[Line]
    prologue: str = ""
    prologue: str
    graph_inputs: dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]
    graph_outputs: list[ir.IRNode]
    subgms: dict[str, torch.fx.GraphModule]
    is_subgraph: bool

    def __post_init__(self) -> None:
        graph = torch.fx.Graph()
@@ -312,10 +391,11 @@ class FxConverter:
        Converts graph inputs to FX placeholders.
        """

        for node in V.graph.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
            name = node.name
            if name in V.graph.graph_inputs:
                ir_node = V.graph.graph_inputs[name]
        for name, ir_node in self.graph_inputs.items():
            if ir_node is None:
                # Create dummy input nodes to match the input signature
                self.gm.graph.placeholder(name)
                continue

            # Introduce a new symbol for constant inputs.
            buffer = (
@@ -327,10 +407,6 @@ class FxConverter:
            placeholder_node.meta["val"] = buffer.get_example()
            self._record_allocation(buffer, placeholder_node)

            elif V.aot_compilation:
                # Create dummy input nodes to match the input signature
                self.gm.graph.placeholder(name)

    def _generate_graph_input_shapes(self) -> None:
        """
        Generate nodes creating symints that are part of graph input
@@ -401,8 +477,7 @@ class FxConverter:

        for node in V.graph.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
            name = node.name
            if name in V.graph.graph_inputs:
                ir_node = V.graph.graph_inputs[name]
            ir_node = self.graph_inputs.get(name)
            if isinstance(ir_node, ir.TensorBox):
                buffer = self._get_buffer(ir_node)
                placeholder_node = self.buffer_to_node[buffer.get_name()]
@@ -469,21 +544,48 @@ class FxConverter:
        Generate FX IR for graph outputs.
        """
        output_nodes = [
            self._generate_buffer(node)
            for idx, node in enumerate(V.graph.graph_outputs)
            self._generate_buffer(node) for idx, node in enumerate(self.graph_outputs)
        ]

        # Single return elements don't use a tuple.
        output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes
        # Parent graphs with single return elements don't use a tuple.
        output_value = (
            output_nodes[0]
            if len(output_nodes) == 1 and not self.is_subgraph
            else output_nodes
        )

        self.gm.graph.output(output_value)

    def _generate_subgm_getattrs(self) -> None:
        """
        Generate getattr nodes for subgms.
        """

        def generate_getattr(name: str, subgm: torch.fx.GraphModule) -> torch.fx.Node:
            self.gm.add_submodule(name, subgm)
            node = self.gm.graph.get_attr(name)
            node.meta["val"] = subgm
            return node

        self.subgm_getattrs = {
            name: generate_getattr(name, subgm) for name, subgm in self.subgms.items()
        }

    def _get_subgm_attr(self, subgraph: ir.Subgraph) -> torch.fx.Node:
        """
        Look up the getattr node for a subgraph.
        """
        graph = subgraph.graph
        assert graph is not None
        return self.subgm_getattrs[graph.name]

    def generate(self) -> torch.fx.GraphModule:
        """
        Main entrypoint for FX codegen.
        """
        self._generate_graph_inputs()
        self._generate_graph_constants()
        self._generate_subgm_getattrs()

        fake_mode = _detect_fake_mode_from_gm(self.gm)

@@ -586,6 +688,33 @@ class FxConverter:
            node.name = name
        self._record_allocation(buffer, node)

    def _generate_conditional(self, line: WrapperLine) -> None:
        assert isinstance(line, ConditionalLine)

        def get_subgm_attr(subgraph: Optional[ir.Subgraph]) -> torch.fx.Node:
            assert subgraph is not None
            return self._get_subgm_attr(subgraph)

        # Access the subgraphs as getattrs.
        ir_node = line.node
        (true_subgm, false_subgm) = [
            get_subgm_attr(subgraph)
            for subgraph in (ir_node.true_subgraph, ir_node.false_subgraph)
        ]

        def generate_buffer(node: Optional[ir.IRNode]) -> Optional[torch.fx.Node]:
            assert node is not None
            return self._generate_buffer(node)

        predicate = generate_buffer(ir_node.predicate)
        assert ir_node.operands is not None
        operands = tuple(generate_buffer(arg) for arg in ir_node.operands)
        fx_node = self.gm.graph.call_function(
            torch.ops.higher_order.cond,
            args=(predicate, true_subgm, false_subgm, operands),
        )
        self._record_allocation(ir_node, fx_node)

    def _generate_comment(self, line: WrapperLine) -> None:
        assert isinstance(line, CommentLine)
        # We ignore comments in FX IR.
@@ -600,11 +729,11 @@ class FxConverter:

    def _generate_enter_subgraph(self, line: WrapperLine) -> None:
        assert isinstance(line, EnterSubgraphLine)
        raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
        # We ignore memory planning lines in FX IR.

    def _generate_exit_subgraph(self, line: WrapperLine) -> None:
        assert isinstance(line, ExitSubgraphLine)
        raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
        # We ignore memory planning lines in FX IR.

    def _generate_free(self, line: WrapperLine) -> None:
        assert isinstance(line, FreeLine)
@@ -5245,9 +5245,10 @@ class Scheduler:
            V.graph.wrapper_code.partition_signatures = signature
            V.graph.wrapper_code.write_prefix()

            graph_name = V.graph.name
            partition_code, _ = V.graph.wrapper_code.generate(V.graph.is_inference)

            V.graph.wrapper_code.define_subgraph_launcher_fn(partition_code.value)
            V.graph.wrapper_code.define_subgraph_launcher_fn(graph_name, partition_code)

            V.graph.wrapper_code.codegen_partition_call(graph_partition_id, signature)
            V.graph.wrapper_code.allocated.update( # type: ignore[has-type]
@@ -666,6 +666,13 @@ def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
    return wrapper # type: ignore[return-value]


def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
    """
    Variant of cache_on_self for properties. The only difference is the type signature.
    """
    return cache_on_self(fn)


def aggregate_origins(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
) -> OrderedSet[Node]: