[Inductor-FX] Support torch.cond (#163234)

# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually identical to what would come from `torch.export`:
- Submodules are stored as attributes and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes a predicate, the true/false subgraphs, and a tuple of submodule inputs.
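
For reference, here is a minimal model exercising this path, adapted from the tests added in this PR. The snippet assumes a GPU device and that the FX wrapper backend has been registered for it, as the test suite does:

```python
import torch

def foo(pred, x):
    # Dispatches to torch.cos or torch.sin at runtime based on the predicate.
    return torch.cond(pred, torch.cos, torch.sin, [x])

# Assumption: "cuda" is available and the FX backend was registered, e.g. via
# register_backend_for_device("cuda", TritonScheduling, WrapperFxCodegen).
pred = torch.tensor([True], device="cuda")
x = torch.randn(2, 3, device="cuda")
out = torch.compile(foo)(pred, x)
```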

# Implementation overview

The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine` (see the sketch after this list).
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook; this also matches the convention followed by `torch.export`.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond` node.
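
Step 1 corresponds to the following method, condensed from this PR's `WrapperFxCodegen` changes:

```python
def codegen_conditional(self, conditional: ir.Conditional) -> None:
    # Emit a single wrapper IR line covering the whole conditional, instead
    # of the usual EnterSubgraphLine/ExitSubgraphLine sequence.
    self.writeline(ConditionalLine(self, conditional))

    # Codegen both subgraphs now; define_subgraph_launcher_fn records the
    # resulting GraphModules in self.subgms for step 3.
    for subgraph in (conditional.true_subgraph, conditional.false_subgraph):
        self.codegen_subgraph_common(subgraph)
```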

# Implementation details
This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.

Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.
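
Concretely, `WrapperFxCodegen.create` builds the subgraph wrapper class on the fly (condensed from the diff below):

```python
@classmethod
def create(cls, is_subgraph, subgraph_name, parent_wrapper, partition_signatures=None):
    if is_subgraph:
        # Mix the subgraph overrides into the (possibly user-derived) wrapper
        # class; listing `cls` first gives its methods priority in the MRO.
        class SubgraphFxWrapperCodegen(cls, SubgraphPythonWrapperCodegen):
            def compile_graph(self, gm):
                # Subgraphs are only invoked through the parent graph, so
                # return a stub rather than compiling them standalone.
                def crash_if_run(*args):
                    raise NotImplementedError("Cannot run a subgraph in isolation!")

                return crash_if_run

        return SubgraphFxWrapperCodegen(subgraph_name, parent_wrapper, partition_signatures)

    # For derived backends, this could be a subclass.
    return cls()
```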

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.

In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)
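
The new line itself is minimal, as added to `wrapper.py` in this PR:

```python
@dataclasses.dataclass
class ConditionalLine(WrapperLine):
    wrapper: PythonWrapperCodegen
    node: ir.Conditional

    def codegen(self, code: IndentedBuffer) -> None:
        # The Python and C++ wrapper backends never emit this line.
        raise NotImplementedError("Only supports FX codegen")

    @staticmethod
    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
        return converter._generate_conditional
```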

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.
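
The resulting input lookup, condensed from `get_fx_graph_inputs` in the diff below:

```python
def get_fx_graph_inputs(self):
    if V.aot_compilation and not self.is_subgraph:
        # AOT parent graphs must match the signature of the input module.
        return {
            node.name: V.graph.graph_inputs.get(node.name)
            for node in V.graph.module.graph.find_nodes(op="placeholder")
        }
    # Subgraphs take their IO signature from the wrapper, not V.graph.module.
    return self.get_graph_inputs()
```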

# Test plan
This PR adds a couple of tests using `torch.cond`. Here's an example graph generated by one of them:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes an existing negative test, which checked that an error was raised when subgraphs were encountered; subgraphs are now supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163234
Approved by: https://github.com/angelayi, https://github.com/jansel
Commit: e56dd5d770 (parent a31acf32bd)
Author: Blaine Burton Rister
Date: 2025-09-20 03:52:29 +00:00
Committed by: PyTorch MergeBot
5 changed files with 296 additions and 73 deletions

## test/inductor/test_fxir_backend.py

```
@@ -17,7 +17,6 @@ from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.utils import same
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._inductor import config
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
@@ -43,15 +42,17 @@ if HAS_GPU:
from torch.testing._internal.triton_utils import add_kernel_2d_autotuned
test_config = {
"compile_threads": 1,
"alignment_asserts": False,
"size_asserts": False,
"scalar_asserts": False,
"nan_asserts": False,
}
@requires_gpu()
@config.patch(
compile_threads=1,
alignment_asserts=False,
size_asserts=False,
scalar_asserts=False,
nan_asserts=False,
)
@config.patch(test_config)
@instantiate_parametrized_tests
class FxirTestCase(InductorTestCase):
device = GPU_TYPE
@@ -116,8 +117,19 @@ class FxirTestCase(InductorTestCase):
def setUpClass(cls):
super().setUpClass()
# Register the FX backend.
register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen)
# Register the FX backend, storing the default for later.
common.init_backend_registration()
cls._default_backend = common.device_codegens[cls.device]
common.register_backend_for_device(
cls.device, TritonScheduling, WrapperFxCodegen
)
@classmethod
def tearDownClass(cls):
super().tearDownClass()
# Restore the default backend.
common.device_codegens[cls.device] = cls._default_backend
def test_basic(self):
args = [torch.randn(8, device=self.device) for _ in range(2)]
@@ -630,21 +642,47 @@ class FxirTestCase(InductorTestCase):
# Check for the fallback.
self.assertEqual(self._count_ops(gm, fallback_op), 1)
@torch._inductor.config.patch("graph_partition", True)
def test_subgraph_raises(self):
@parametrize("pred", (False, True))
def test_cond_subgraph(self, pred: bool):
"""
Test a model with subgraphs. This is not yet supported, so check that we get the
expected exception.
Test a model with subgraphs.
"""
def foo(cond, x):
return torch.cond(cond, torch.cos, torch.sin, [x])
def foo(pred, x):
return torch.cond(pred, torch.cos, torch.sin, [x]) + 1
cond = torch.tensor([True], device=self.device)
x = torch.ones([2, 3], device=self.device)
x = torch.randn((2, 3), device=self.device)
pred_tensor = torch.tensor([pred], device=self.device)
gm = self._compile_and_check(
foo, [pred_tensor, x], expected_num_triton_kernels=3
)[-1]
with self.assertRaisesRegex(BackendCompilerFailed, "Subgraph"):
self._compile_and_check(foo, [cond, x])
# Check for subgraphs.
subgm_getattrs = list(gm.graph.find_nodes(op="get_attr"))
self.assertEqual(len(subgm_getattrs), 2)
for subgm_getattr in subgm_getattrs:
target = subgm_getattr.name
self.assertTrue(isinstance(getattr(gm, target), torch.fx.GraphModule))
@parametrize("pred", (False, True))
def test_cond_no_operands(self, pred: bool):
"""
Test torch.cond when the subgraphs take no inputs.
"""
length = 8
def true_fn():
return torch.zeros(length, device=self.device)
def false_fn():
return true_fn() + 5
def foo(pred):
return torch.cond(pred, true_fn, false_fn, ())
pred_tensor = torch.tensor([pred], device=self.device)
self._compile_and_check(foo, [pred_tensor], expected_num_triton_kernels=2)
def test_cpp_raises(self):
"""
@@ -759,9 +797,9 @@ class AOTFxirTestCase(InductorTestCase):
model, inp, dynamic_shapes=dynamic_shapes, strict=strict
)
gm = torch._inductor.aot_compile(
ep.module(), inp, options={"fx_wrapper": True, "compile_threads": 1}
ep.module(), inp, options={"fx_wrapper": True, **test_config}
)
self.assertTrue(torch.allclose(model(*inp), gm(*inp)))
self.assertTrue(same(model(*inp), gm(*inp)))
for node in gm.graph.nodes:
if (
@@ -919,6 +957,39 @@ class AOTFxirTestCase(InductorTestCase):
1,
)
@parametrize("pred", (False, True))
def test_cond_multi_inputs_and_outputs(self, pred):
"""
Test torch.cond and check the output graphs.
"""
class M(torch.nn.Module):
def forward(self, pred, x, y):
def true_fn(x, y):
return torch.tanh(x), torch.relu(y)
def false_fn(x, y):
return tuple(t / 2 for t in true_fn(x, y))
return torch.cond(pred, true_fn, false_fn, (x, y))
pred = torch.tensor([True], device=self.device)
(x, y) = [torch.randn(8, device=self.device) for _ in range(2)]
gm = self.check(M(), (pred, x, y))
# Check the graph.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, (arg1_1, arg2_1)); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = None
buf1 = cond[0]
buf2 = cond[1]; cond = None
return [buf1, buf2]""", # noqa: B950
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

```

## torch/_inductor/codegen/wrapper.py

```
@@ -390,6 +390,19 @@ class EnterSubgraphLine(WrapperLine):
return converter._generate_enter_subgraph
@dataclasses.dataclass
class ConditionalLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.Conditional
def codegen(self, code: IndentedBuffer) -> None:
raise NotImplementedError("Only supports FX codegen")
@staticmethod
def codegen_fx(converter: FxConverter) -> FxConversionFunc:
return converter._generate_conditional
@dataclasses.dataclass
class CommentLine(WrapperLine):
line: LineContext
@@ -2169,8 +2182,8 @@ class PythonWrapperCodegen(CodeGen):
)
self.header.splice(body)
def define_subgraph_launcher_fn(self, fn_code: str):
self.subgraph_definitions.splice(fn_code)
def define_subgraph_launcher_fn(self, name: str, subgraph_code):
self.subgraph_definitions.splice(subgraph_code.value)
def define_user_defined_triton_kernel(
self,
@@ -3362,11 +3375,12 @@
def codegen_subgraph_common(self, subgraph):
self.push_codegened_graph(subgraph.graph)
self.writeline("")
self.writeline(f"{self.comment} subgraph: {subgraph.name}")
self.make_comment("")
self.make_comment(f"{self.comment} subgraph: {subgraph.name}")
parent_graph = V.graph
subgraph.graph.cpp_wrapper = parent_graph.cpp_wrapper
subgraph.graph.fx_wrapper = parent_graph.fx_wrapper
if subgraph.graph.name not in self.already_codegened_subgraphs:
# If it is already codegened, the parent wrapper already has
@@ -3376,8 +3390,9 @@
with config.patch("graph_partition", False):
# Call the codegen of subgraph recursively
subgraph_code, _ = subgraph.graph.codegen()
self.already_codegened_subgraphs.add(subgraph.graph.name)
self.define_subgraph_launcher_fn(subgraph_code.value)
subgraph_name = subgraph.graph.name
self.already_codegened_subgraphs.add(subgraph_name)
self.define_subgraph_launcher_fn(subgraph_name, subgraph_code)
def codegen_subgraph_with_flattened_outputs(
self, subgraph, outer_inputs, outer_flattened_outputs
@@ -3409,7 +3424,7 @@
else:
self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name)
def codegen_conditional(self, conditional):
def codegen_conditional(self, conditional) -> None:
name = conditional.get_name()
outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
@@ -3644,7 +3659,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
def get_graph_inputs(
self,
) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]:
) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]:
if signature := self.partition_signatures:
inputs = signature.input_nodes | {
str(s): s for s in signature.symbol_inputs

```

## torch/_inductor/codegen/wrapper_fxir.py

```
@@ -35,7 +35,7 @@ from torch.utils._sympy.solve import try_solve
from .. import config, ir
from ..runtime.triton_compat import Config
from ..utils import LineContext
from ..utils import cache_property_on_self, LineContext, ValueWithLineMap
from .common import (
CodegenSymbol,
FileBackedGraphModule,
@@ -48,6 +48,7 @@ from .wrapper import (
CommBufferAllocateLine,
CommBufferFreeLine,
CommentLine,
ConditionalLine,
EnterDeviceContextManagerLine,
EnterSubgraphLine,
ExitDeviceContextManagerLine,
@@ -66,6 +67,7 @@ from .wrapper import (
ReinterpretLine,
ReuseLine,
ScatterFallbackLine,
SubgraphPythonWrapperCodegen,
SymbolicCallArg,
SymbolicCallArgLine,
WrapperLine,
@@ -147,12 +149,54 @@ class WrapperFxCodegen(PythonWrapperCodegen):
supports_caching = False
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.subgms: dict[str, torch.fx.GraphModule] = {}
def codegen_inputs(self) -> None:
"""
This would generate code for symbolic input shapes, strides, etc.
Since the FX converter handles this, do nothing here.
"""
def codegen_conditional(self, conditional: ir.Conditional) -> None:
"""
Conditional codegen normally emits a number of different wrapper lines.
Instead, FX conversion uses a dedicated line for the whole conditional.
"""
self.writeline(ConditionalLine(self, conditional))
for subgraph in (conditional.true_subgraph, conditional.false_subgraph):
self.codegen_subgraph_common(subgraph)
def define_subgraph_launcher_fn(
self, name: str, subgraph_code: Union[ValueWithLineMap, FileBackedGraphModule]
) -> None:
"""
Record subgms as they're generated.
"""
assert isinstance(subgraph_code, FileBackedGraphModule)
self.subgms[name] = subgraph_code.gm
@property
@cache_property_on_self
def is_subgraph(self) -> bool:
return isinstance(self, SubgraphPythonWrapperCodegen)
def get_fx_graph_inputs(
self,
) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]:
"""
Get the input nodes corresponding to FX graph placeholders.
"""
if V.aot_compilation and not self.is_subgraph:
# AOT graphs must match the signature of the input module.
return {
node.name: V.graph.graph_inputs.get(node.name)
for node in V.graph.module.graph.find_nodes(op="placeholder") # type: ignore[operator, union-attr]
}
return self.get_graph_inputs()
def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
self.run_wrapper_ir_passes(is_inference)
@@ -162,7 +206,15 @@ class WrapperFxCodegen(PythonWrapperCodegen):
self.header.getvalue(),
]
)
gm = FxConverter(lines=self.lines, prologue=prologue).generate()
gm = FxConverter(
lines=self.lines,
prologue=prologue,
graph_inputs=self.get_fx_graph_inputs(),
graph_outputs=self.get_graph_outputs(),
subgms=self.subgms,
is_subgraph=self.is_subgraph,
).generate()
compiled_fn = self.compile_graph(gm)
return FileBackedGraphModule(gm, compiled_fn), None
@@ -175,20 +227,43 @@
"""
return gm.forward
def write_header(self) -> None:
"""
Python subgraphs normally lack headers.
Override this behavior to generate prologues for FX subgraphs.
"""
PythonWrapperCodegen.write_header(self)
@classmethod
def create(
cls,
cls: type["WrapperFxCodegen"],
is_subgraph: bool,
subgraph_name: Optional[str],
parent_wrapper: Optional[PythonWrapperCodegen],
partition_signatures: Optional[ir.GraphPartitionSignature] = None,
) -> "WrapperFxCodegen":
if is_subgraph:
raise NotImplementedError(
"Subgraphs are not yet supported by FX conversion"
assert subgraph_name is not None
assert parent_wrapper is not None
# Subgraphs override some methods of PythonWrapperCodegen.
# Apply these overrides to the user-provided class, with priority given to
# user-provided methods.
class SubgraphFxWrapperCodegen(cls, SubgraphPythonWrapperCodegen): # type: ignore[misc,valid-type]
def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
"""
Skip graph compilation for subgraphs.
"""
def crash_if_run(*args: Any) -> None:
raise NotImplementedError("Cannot run a subgraph in isolation!")
return crash_if_run
return SubgraphFxWrapperCodegen(
subgraph_name, parent_wrapper, partition_signatures
)
# For derived backends, this could be a subclass.
return cls()
@@ -200,7 +275,11 @@ class FxConverter:
"""
lines: list[Line]
prologue: str = ""
prologue: str
graph_inputs: dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr, None]]
graph_outputs: list[ir.IRNode]
subgms: dict[str, torch.fx.GraphModule]
is_subgraph: bool
def __post_init__(self) -> None:
graph = torch.fx.Graph()
@@ -312,24 +391,21 @@ class FxConverter:
Converts graph inputs to FX placeholders.
"""
for node in V.graph.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
name = node.name
if name in V.graph.graph_inputs:
ir_node = V.graph.graph_inputs[name]
# Introduce a new symbol for constant inputs.
buffer = (
SymbolBuffer(sympy.Symbol(name, is_integer=True))
if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
else self._get_buffer(ir_node)
)
placeholder_node = self.gm.graph.placeholder(buffer.get_name())
placeholder_node.meta["val"] = buffer.get_example()
self._record_allocation(buffer, placeholder_node)
elif V.aot_compilation:
for name, ir_node in self.graph_inputs.items():
if ir_node is None:
# Create dummy input nodes to match the input signature
self.gm.graph.placeholder(name)
continue
# Introduce a new symbol for constant inputs.
buffer = (
SymbolBuffer(sympy.Symbol(name, is_integer=True))
if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
else self._get_buffer(ir_node)
)
placeholder_node = self.gm.graph.placeholder(buffer.get_name())
placeholder_node.meta["val"] = buffer.get_example()
self._record_allocation(buffer, placeholder_node)
def _generate_graph_input_shapes(self) -> None:
"""
@@ -401,20 +477,19 @@ class FxConverter:
for node in V.graph.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
name = node.name
if name in V.graph.graph_inputs:
ir_node = V.graph.graph_inputs[name]
if isinstance(ir_node, ir.TensorBox):
buffer = self._get_buffer(ir_node)
placeholder_node = self.buffer_to_node[buffer.get_name()]
ir_node = self.graph_inputs.get(name)
if isinstance(ir_node, ir.TensorBox):
buffer = self._get_buffer(ir_node)
placeholder_node = self.buffer_to_node[buffer.get_name()]
for dim, size in enumerate(ir_node.get_size()):
_codegen_symbol(
size, placeholder_node, torch.ops.aten.sym_size.int, dim
)
for dim, stride in enumerate(ir_node.get_stride()):
_codegen_symbol(
stride, placeholder_node, torch.ops.aten.sym_stride.int, dim
)
for dim, size in enumerate(ir_node.get_size()):
_codegen_symbol(
size, placeholder_node, torch.ops.aten.sym_size.int, dim
)
for dim, stride in enumerate(ir_node.get_stride()):
_codegen_symbol(
stride, placeholder_node, torch.ops.aten.sym_stride.int, dim
)
def _generate_graph_constants(self) -> None:
for name, value in V.graph.constants.items():
@@ -469,21 +544,48 @@
Generate FX IR for graph outputs.
"""
output_nodes = [
self._generate_buffer(node)
for idx, node in enumerate(V.graph.graph_outputs)
self._generate_buffer(node) for idx, node in enumerate(self.graph_outputs)
]
# Single return elements don't use a tuple.
output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes
# Parent graphs with single return elements don't use a tuple.
output_value = (
output_nodes[0]
if len(output_nodes) == 1 and not self.is_subgraph
else output_nodes
)
self.gm.graph.output(output_value)
def _generate_subgm_getattrs(self) -> None:
"""
Generate getattr nodes for subgms.
"""
def generate_getattr(name: str, subgm: torch.fx.GraphModule) -> torch.fx.Node:
self.gm.add_submodule(name, subgm)
node = self.gm.graph.get_attr(name)
node.meta["val"] = subgm
return node
self.subgm_getattrs = {
name: generate_getattr(name, subgm) for name, subgm in self.subgms.items()
}
def _get_subgm_attr(self, subgraph: ir.Subgraph) -> torch.fx.Node:
"""
Look up the getattr node for a subgraph.
"""
graph = subgraph.graph
assert graph is not None
return self.subgm_getattrs[graph.name]
def generate(self) -> torch.fx.GraphModule:
"""
Main entrypoint for FX codegen.
"""
self._generate_graph_inputs()
self._generate_graph_constants()
self._generate_subgm_getattrs()
fake_mode = _detect_fake_mode_from_gm(self.gm)
@@ -586,6 +688,33 @@
node.name = name
self._record_allocation(buffer, node)
def _generate_conditional(self, line: WrapperLine) -> None:
assert isinstance(line, ConditionalLine)
def get_subgm_attr(subgraph: Optional[ir.Subgraph]) -> torch.fx.Node:
assert subgraph is not None
return self._get_subgm_attr(subgraph)
# Access the subgraphs as getattrs.
ir_node = line.node
(true_subgm, false_subgm) = [
get_subgm_attr(subgraph)
for subgraph in (ir_node.true_subgraph, ir_node.false_subgraph)
]
def generate_buffer(node: Optional[ir.IRNode]) -> Optional[torch.fx.Node]:
assert node is not None
return self._generate_buffer(node)
predicate = generate_buffer(ir_node.predicate)
assert ir_node.operands is not None
operands = tuple(generate_buffer(arg) for arg in ir_node.operands)
fx_node = self.gm.graph.call_function(
torch.ops.higher_order.cond,
args=(predicate, true_subgm, false_subgm, operands),
)
self._record_allocation(ir_node, fx_node)
def _generate_comment(self, line: WrapperLine) -> None:
assert isinstance(line, CommentLine)
# We ignore comments in FX IR.
@@ -600,11 +729,11 @@
def _generate_enter_subgraph(self, line: WrapperLine) -> None:
assert isinstance(line, EnterSubgraphLine)
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
# We ignore memory planning lines in FX IR.
def _generate_exit_subgraph(self, line: WrapperLine) -> None:
assert isinstance(line, ExitSubgraphLine)
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
# We ignore memory planning lines in FX IR.
def _generate_free(self, line: WrapperLine) -> None:
assert isinstance(line, FreeLine)

```

## torch/_inductor/scheduler.py

```
@@ -5245,9 +5245,10 @@ class Scheduler:
V.graph.wrapper_code.partition_signatures = signature
V.graph.wrapper_code.write_prefix()
graph_name = V.graph.name
partition_code, _ = V.graph.wrapper_code.generate(V.graph.is_inference)
V.graph.wrapper_code.define_subgraph_launcher_fn(partition_code.value)
V.graph.wrapper_code.define_subgraph_launcher_fn(graph_name, partition_code)
V.graph.wrapper_code.codegen_partition_call(graph_partition_id, signature)
V.graph.wrapper_code.allocated.update( # type: ignore[has-type]

```

## torch/_inductor/utils.py

```
@@ -666,6 +666,13 @@ def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
return wrapper # type: ignore[return-value]
def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
"""
Variant of cache_on_self for properties. The only difference is the type signature.
"""
return cache_on_self(fn)
def aggregate_origins(
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
) -> OrderedSet[Node]:
```