Compare commits

...

28 Commits

SHA1 Message Date
8690e80b9d Merge branch 'main' into bf/lite 2025-11-11 10:02:15 -08:00
24edea44b3 nit 2025-11-11 10:01:20 -08:00
8f0fa2e52d skip a mps test since mul2_kernel not exist 2025-11-10 23:05:11 -08:00
0f120be0b1 minor doc/format improve 2025-11-10 17:41:42 -08:00
fee30c60a8 lint 2025-11-10 15:50:22 -08:00
5a8affeba1 add docs 2025-11-10 14:54:09 -08:00
0932f1ff21 Merge branch 'main' into bf/lite 2025-11-09 22:43:20 -08:00
814cb7024b support triton_kernel_wrapper_functional 2025-11-09 22:41:47 -08:00
f2a7f85f11 also fallback by default for hop 2025-11-09 16:06:14 -08:00
308414c1b6 nit 2025-11-07 17:52:20 -08:00
5c9a710c99 make _SelectiveDecomposeInterpreter private 2025-11-07 10:09:03 -08:00
a24161475f Merge branch 'main' into bf/lite 2025-11-07 10:07:03 -08:00
f79e116bfa fix more cpp_wrapper 2025-11-06 23:39:44 -08:00
db30ee1a88 check different string for cpp_wrapper tests 2025-11-06 21:54:42 -08:00
5c17c94af8 nit 2025-11-06 19:28:12 -08:00
1d56ee10de nit 2025-11-06 18:01:30 -08:00
ccfc98f8a0 support dynamic shape assertion 2025-11-06 17:35:39 -08:00
7dddcb24c9 Merge branch 'main' into bf/lite 2025-11-06 17:15:58 -08:00
cd453f5e7c more tests 2025-11-06 17:09:36 -08:00
16d170b0ef add inductor selective_decompose config 2025-11-06 17:01:40 -08:00
8c23bb9fef support both joint graph and fwd only graph 2025-11-06 16:47:56 -08:00
08517c8556 nit 2025-11-06 13:10:37 -08:00
b42d8bdc76 Merge branch 'main' into bf/lite 2025-11-06 11:15:47 -08:00
a68e77bdea selective decompose for regional compile 2025-11-05 22:47:25 -08:00
9d977f1f68 nit 2025-11-05 13:35:04 -08:00
0bba0cdb9c Merge branch 'main' into bf/lite 2025-11-05 11:47:25 -08:00
7dba73a0a7 nit 2025-11-05 11:46:18 -08:00
8ec9f7de82 init 2025-11-05 11:37:54 -08:00
11 changed files with 486 additions and 18 deletions
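
For orientation, here is a minimal usage sketch of what this branch appears to enable, based on the tests added below; it assumes a build with this branch's "lite" compile mode and the fx_traceback.annotate API used in those tests:

import torch
import torch.fx.traceback as fx_traceback

def fn(x):
    x = torch.sin(x)
    # Only the annotated region is lowered/fused by inductor; in lite mode
    # everything else falls back to ATen ops by default.
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        x = x * x + 1
    return torch.cos(x)

opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
out = opt_fn(torch.randn(10, requires_grad=True))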

View File

@@ -30,6 +30,7 @@ import numpy as np
import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.aoti_eager
import torch.fx.traceback as fx_traceback
import torch.nn as nn
from torch._C._dynamo.guards import assert_alignment, assert_size_stride
from torch._dispatch.python import enable_python_dispatcher
@@ -13651,6 +13652,224 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
FileCheck().check_regex(size_assert_pattern).run(code)
def test_lite_mode_fallback(self):
def f(x):
z = x.sin()
return z.cos()
f = torch.compile(f, mode="lite")
_, code = run_and_get_code(f, torch.randn(2, device=self.device))
# Checks that aten ops are kept and run
if config.cpp_wrapper:
FileCheck().check("aoti_torch_call_dispatcher(").check("aten::sin").check(
"aoti_torch_call_dispatcher("
).check("aten::cos").run(code[0])
else:
FileCheck().check("torch.ops.aten.sin.default(").check(
"torch.ops.aten.cos.default("
).run(code[0])
# Checks that no triton code runs in the generated code
self.assertFalse(".run(" in code[0])
# skip cpu test since rms norm is always decomposed on cpu
def test_lite_mode_not_decompose(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
def f(x, shape):
y = x + 1
z = torch.ops.aten._fused_rms_norm(y, shape, None, None)
return z[0] + z[1]
f = torch.compile(f, mode="lite")
x = torch.randn(2, 3, device=self.device)
_, code = run_and_get_code(f, x, [2, 3])
if config.cpp_wrapper:
FileCheck().check(
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda__fused_rms_norm("
).run(code[0])
else:
FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])
if config.cpp_wrapper:
# arg type List[int] is not yet supported by custom_op_wrapper
pass
else:
x = torch.randn(2, 3, device=self.device, requires_grad=True)
_, codes = run_fw_bw_and_get_code(lambda: f(x, [2, 3]))
self.assertEqual(len(codes), 2)
FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])
def test_lite_regional_compile_flex_attention(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
def _squared(score, b, h, m, n):
return score * score
def mask_mod(b, h, q, k):
return q >= 0
a = 12
b = 64
block_mask = create_block_mask(
mask_mod, None, None, a * b, a * b, device=self.device
)
def fn(x):
x = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
return torch.cos(x)
x = torch.randn(
1,
1,
a * b,
b,
dtype=torch.bfloat16,
device=self.device,
requires_grad=True,
)
opt_fn = torch.compile(
fn,
mode="lite",
fullgraph=True,
)
# Check that inductor compilation is called twice
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)
@unittest.skipIf(
config.cpp_wrapper,
"codegen invoke_subgraph is not implemented for cpp wrapper",
)
def test_lite_regional_compile_invoke_subgraph(self):
# Checks that custom metadata on get_attr nodes is propagated
@torch.compiler.nested_compile_region
def gn(x):
return torch.sin(x)
def fn(x):
x = x + 1
with fx_traceback.annotate({"compile_with_inductor": 0}):
z = gn(x)
return torch.sigmoid(z)
opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
x = torch.randn(10, requires_grad=True)
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)
@unittest.skipIf(
config.cpp_wrapper,
"codegen triton_kernel_wrapper_functional is not implemented for cpp wrapper",
)
def test_lite_triton_kernel_wrapper_functional(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
from torch._higher_order_ops.triton_kernel_wrap import (
kernel_side_table,
triton_kernel_wrapper_functional,
)
from torch.testing._internal.triton_utils import mul2_kernel
kernel_side_table.reset_table()
def f(x, output):
out = triton_kernel_wrapper_functional(
kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
constant_args_idx=kernel_side_table.add_constant_args(
{"n_elements": output.numel(), "BLOCK_SIZE": 16}
),
grid=[(x.numel(),)],
tma_descriptor_metadata={},
kwargs={
"in_ptr0": x,
"out_ptr": output,
},
tensors_to_clone=["in_ptr0", "out_ptr"],
)
return out["out_ptr"]
t1 = torch.rand(5, device=self.device)
t2 = torch.rand(5, device=self.device)
compiled_f = torch.compile(f, mode="lite")
out = compiled_f(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(out, t2)
def test_lite_regional_compile_repeated_blocks(self):
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
a = fn(x, y)
return fn(a, y)
mod = Mod()
opt_mod = torch.compile(
mod,
mode="lite",
fullgraph=True,
)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
_, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
self.assertEqual(len(codes), 2)
def test_lite_dynamic_shape_assertion(self):
class Model(torch.nn.Module):
def forward(self, c):
d = torch.concat([c, c], dim=0)
with fx_traceback.annotate({"compile_with_inductor": "my_region"}):
d = d + 1
return d
model = Model()
model = torch.compile(
model,
mode="lite",
fullgraph=True,
)
c = torch.randn((64, 32), device=self.device)
torch._dynamo.decorators.mark_unbacked(c, 0)
_, code = run_and_get_code(model, c)
# Checks that unbacked symint assertions are kept
if config.cpp_wrapper:
FileCheck().check_regex(r"if \(!\(u.* >= 0L\)\)").check_regex(
"Expected u.* >= 0 but receive"
).run(code[0])
else:
FileCheck().check_regex(r"if not \(u.* >= 0\):").check_regex(
r"raise RuntimeError\('u.* >= 0'\)"
).run(code[0])
@lowering.force_fallback(aten.sort.default)
@unittest.skipIf(
config.cpp_wrapper,

View File

@@ -103,6 +103,17 @@ from .utils import (
_thread_local = threading.local()
@contextmanager
def maybe_skip_decompose(aot_config: AOTConfig):
old_decomp = aot_config.decompositions
try:
if config.selective_decompose:
aot_config.decompositions = {}
yield
finally:
aot_config.decompositions = old_decomp
# Saved tensor hooks context
# Compiled saved tensor hooks are a convenient way to inline some logic in the graphs
# for saved nodes from forward to backward. (E.g. activations quantization)
@@ -196,11 +207,28 @@ def aot_stage1_graph_capture(
# deterministic TLS can be different
aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]]
if aot_state.needs_autograd and not aot_config.pre_dispatch:
# FYI: this being moved to trigger in export is new, seems fine!
with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
with maybe_skip_decompose(aot_config):
# If config.selective_decompose is set, skip decomposition here and apply selective
# decomposition after we get the joint graph. See [Note: Selective Decomposition] for details.
if aot_state.needs_autograd and not aot_config.pre_dispatch:
# FYI: this being moved to trigger in export is new, seems fine!
with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
(
graph,
updated_flat_args,
updated_flat_args_descs,
maybe_subclass_meta,
) = aot_dispatch_autograd_graph(
flat_fn,
aot_state.flat_args,
aot_state.flat_args_descs,
aot_config,
fw_metadata=aot_state.fw_metadata,
)
else:
graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
aot_dispatch_autograd_graph(
aot_dispatch_base_graph(
flat_fn,
aot_state.flat_args,
aot_state.flat_args_descs,
@@ -208,15 +236,17 @@
fw_metadata=aot_state.fw_metadata,
)
)
else:
graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
aot_dispatch_base_graph( # type: ignore[assignment]
flat_fn,
aot_state.flat_args,
aot_state.flat_args_descs,
aot_config,
fw_metadata=aot_state.fw_metadata,
)
if config.selective_decompose:
from torch.fx.experimental.proxy_tensor import selective_decompose
from torch.fx.passes.regional_inductor import _needs_inductor_compile
graph = selective_decompose(
graph,
*updated_flat_args,
decomposition=aot_config.decompositions,
should_decompose=_needs_inductor_compile,
trace_joint_graph=aot_state.needs_autograd and not aot_config.pre_dispatch,
)
return AOTGraphCapture(

View File

@@ -374,6 +374,13 @@ saved_tensors_hooks_filtering_mode = "donated"
# This callback is invoked on the joint graph before partitioning
joint_custom_pass: Callable = None # type: ignore[assignment]
# Note [Selective Decomposition]
# This config allows selective decomposition of certain operators in the graph.
# When True, no nodes are decomposed, except those that users explicitly
# annotated for regional inductor compile. See torch.fx.passes.regional_inductor
# for how to explicitly annotate. This is currently only used by inductor lite mode.
selective_decompose: bool = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
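
As a hedged sketch of how this flag is driven in practice (the inductor side patches it around AOTAutograd tracing, as shown in the compile_fx changes below; the flag only exists with this change):

import torch._functorch.config as functorch_config

# Mirror the patching pattern used by compile_fx for inductor lite mode.
with functorch_config.patch(selective_decompose=True):
    assert functorch_config.selective_decompose
    # ... AOTAutograd tracing would happen here ...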

View File

@@ -315,6 +315,25 @@ def aot_compile(
)
lite_mode_options = {
# Fall back by default unless users explicitly annotated nodes for
# regional inductor compile.
"fallback_by_default": True,
"selective_decompose": True,
# Disable reorder optimizations
"reorder_for_peak_memory": False,
"reorder_for_compute_comm_overlap": False,
"triton.reorder_for_reducing_graph_partitions": False,
# Disable pre-, joint-, post-grad passes
"use_pre_grad_passes": False,
"use_joint_graph_passes": False,
"use_post_grad_passes": False,
# Disable dead code elimination (dce) and buffer reuse
"use_dce": False,
"allow_buffer_reuse": False,
}
def list_mode_options(
mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> dict[str, Any]:
@@ -332,6 +351,8 @@ def list_mode_options(
mode_options: dict[str, dict[str, bool]] = {
"default": {},
# lite backend for opt-in optimizations
"lite": lite_mode_options,
# enable cudagraphs
"reduce-overhead": {
"triton.cudagraphs": True,

View File

@@ -508,6 +508,9 @@ def _recursive_pre_grad_passes(
log_pt2_compile_event=True,
dynamo_compile_column_us="pre_grad_pass_time_us",
):
if not config.use_pre_grad_passes:
return gm
add_passes = config.add_pre_grad_passes
remove_passes = config.remove_pre_grad_passes
for subgraph_name in _get_subgraph_names(gm):
@@ -526,6 +529,9 @@ def _recursive_joint_graph_passes(
log_pt2_compile_event=True,
dynamo_compile_column_us="joint_graph_pass_time_us",
):
if not config.use_joint_graph_passes:
return
# invoke_subgraph already runs the _recursive_joint_graph_passes. In
# AOTAutograd, `run_joint_graph_passes_on_hops` partitions the
# invoke_subgraph HOP before calling the partitioner on the outer graph.
@@ -544,6 +550,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
log_pt2_compile_event=True,
dynamo_compile_column_us="post_grad_pass_time_us",
):
if not config.use_post_grad_passes:
return
for subgraph_name in _get_subgraph_names(gm):
subgraph = getattr(gm, subgraph_name)
_recursive_post_grad_passes(subgraph, is_inference)
@@ -2708,7 +2717,10 @@ def _compile_fx_main(
is_valid_aoti_model_name()
with functorch_config.patch(unlift_effect_tokens=True):
with functorch_config.patch(
unlift_effect_tokens=True,
selective_decompose=config.selective_decompose,
):
gm, graph_signature = aot_export_module(
model_,
example_inputs_,
@@ -2768,7 +2780,10 @@ def _compile_fx_main(
V.set_fake_mode(fake_mode),
torch._guards.tracing(tracing_context),
compiled_autograd._disable(),
functorch_config.patch(unlift_effect_tokens=True),
functorch_config.patch(
unlift_effect_tokens=True,
selective_decompose=config.selective_decompose,
),
):
try:
return aot_autograd(

View File

@@ -550,6 +550,32 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
# Fall back to ATen for all ops by default, except those nodes that users explicitly
# annotated for regional inductor compile. See torch.fx.passes.regional_inductor
# for how to explicitly annotate. This is currently only used by inductor lite mode.
# Unlike the default inductor mode, which fuses all nodes, this config enables an
# opt-in mode that only fuses user-specified nodes. The motivation is to provide
# guaranteed numeric correctness and full control to users.
fallback_by_default: bool = False
# This config allows selective decomposition of certain operators in the graph.
# Currently the only use case is to patch the config of the same name in functorch,
# for inductor lite mode. See [Note: Selective Decomposition] for more details.
selective_decompose: bool = False
# Use dead code elimination
use_dce: bool = True
# Use fx graph passes
use_pre_grad_passes: bool = True
use_joint_graph_passes: bool = True
use_post_grad_passes: bool = True
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
@@ -1373,6 +1399,10 @@ class triton:
default=False,
)
# reorder nodes to minimize the number of graph partitions while
# not incurring large memory overhead
reorder_for_reducing_graph_partitions: bool = True
# assertions on the fast path
fast_path_cudagraph_asserts = False
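
For experimentation, the knobs added above could also be toggled individually through the inductor config patcher; a hedged sketch (these flags exist only with this change, and disabling passes or DCE may hurt performance):

import torch
import torch._inductor.config as inductor_config

def f(x):
    return x.sin().cos()

# Keep default lowering but skip post-grad passes and scheduler DCE.
with inductor_config.patch(use_post_grad_passes=False, use_dce=False):
    compiled = torch.compile(f)
    compiled(torch.randn(8))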

View File

@@ -110,6 +110,7 @@ from .utils import (
maybe_get_suppress_shape_guards_ctx,
normalize_name,
should_assume_input_aligned,
should_fallback_by_default,
SUPPORTED_MKLDNN_DEVICES,
ValueWithLineMap,
)
@@ -1634,6 +1635,20 @@ class GraphLowering(torch.fx.Interpreter):
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
and isinstance(
n.target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
)
and should_fallback_by_default(n)
):
# This path handles fallbacks for inductor lite mode. It supports
# both OpOverloads and HOPs (e.g., triton_kernel_wrapper_functional).
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation

View File

@@ -2742,8 +2742,9 @@ class Scheduler:
self.process_grouped_nodes()
if (
torch._inductor.config.graph_partition
and torch._inductor.config.triton.cudagraphs
config.graph_partition
and config.triton.cudagraphs
and config.triton.reorder_for_reducing_graph_partitions
):
self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
@@ -3191,6 +3192,9 @@ class Scheduler:
"""
Remove any nodes without users
"""
if not config.use_dce:
return
# self.nodes is in topological order, so by iterating in reverse order
# we have visited (and potentially removed) all users before visiting a
# given node.

View File

@@ -58,6 +58,7 @@ import torch
import torch.utils._pytree as pytree
from torch._inductor.analysis.device_info import datasheet_tops
from torch._inductor.runtime.hints import DeviceProperties
from torch.fx.passes.regional_inductor import _needs_inductor_compile
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten, tree_map_only
@@ -4036,3 +4037,40 @@ def load_template(name: str, template_dir: Path) -> str:
"""Load a template file and return its content."""
with open(template_dir / f"{name}.py.jinja") as f:
return f.read()
def should_fallback_by_default(node: torch.fx.Node) -> bool:
"""Decide whether fallback for a node. This is only used in inductor lite mode."""
target = node.target
assert isinstance(
target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
), f"Expected OpOverload or HigherOrderOperator, but found {type(target)}"
if not config.fallback_by_default:
return False
# Some ops need special handling due to dynamic shapes. We can avoid
# falling back for them since they do not impact numerics.
skip_fallback_due_to_dynamic_shape = OrderedSet(
[
torch.ops.aten._assert_scalar.default,
torch.ops.aten.lift_fresh_copy.default,
]
)
if target in skip_fallback_due_to_dynamic_shape:
return False
# Most HOPs have a registered lowering. We should follow the lowering and not fall back.
# However, in rare cases a HOP may not register a lowering, such as
# torch.ops.higher_order.triton_kernel_wrapper_functional. We should fall back for
# these HOPs.
fallback_hops = OrderedSet(
[torch.ops.higher_order.triton_kernel_wrapper_functional]
)
if isinstance(target, torch._ops.HigherOrderOperator):
return target in fallback_hops
return not _needs_inductor_compile(node)
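
To make the decision concrete, here is a hedged sketch of how this helper behaves on an ATen-level graph produced by make_fx (assumes this change; with fallback_by_default enabled and no regional-inductor annotation, every OpOverload node is reported as a fallback):

import torch
import torch._inductor.config as inductor_config
from torch._inductor.utils import should_fallback_by_default  # added by this change
from torch.fx.experimental.proxy_tensor import make_fx

gm = make_fx(lambda x: x.sin().cos())(torch.randn(4))

with inductor_config.patch(fallback_by_default=True):
    for node in gm.graph.nodes:
        if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload):
            # No annotation is present, so each aten op should fall back.
            assert should_fallback_by_default(node)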

View File

@@ -42,6 +42,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._logging import trace_structured
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_impls import fast_detach
from torch._subclasses.fake_tensor import (
FakeTensor,
@@ -90,6 +91,7 @@ __all__ = [
"dispatch_trace",
"make_fx",
"DecompositionInterpreter",
"selective_decompose",
"py_sym_types",
"get_innermost_proxy_mode",
"get_proxy_mode",
@@ -1881,6 +1883,93 @@ class DecompositionInterpreter(fx.Interpreter):
return super().run(*args, **kwargs) # type: ignore[arg-type]
class _SelectiveDecomposeInterpreter(fx.Interpreter):
def __init__(
self,
module: fx.GraphModule,
should_decompose: Callable[[fx.Node], bool],
decomposition_table: Mapping[OpOverload, Callable],
**kwargs: object,
) -> None:
"""
For each node in `module`, selectively decompose it if `should_decompose(node)`
is True, following the given `decomposition_table`.
"""
super().__init__(module, **kwargs) # type: ignore[arg-type]
self.should_decompose = should_decompose
self.decomposition_table = decomposition_table
@staticmethod
def recursive_wrap(
gm: fx.GraphModule,
should_decompose: Callable[[fx.Node], bool],
decomposition_table: Mapping[OpOverload, Callable],
**kwargs: object,
) -> _SelectiveDecomposeInterpreter:
"""
Recursively wrap gm and its sub graph modules. Specifically, a HOP takes
sub graph modules as args, and we may not want to decompose all nodes within
these sub graph modules, so we also need to wrap them.
As a result:
- if should_decompose(hop) is True, we decompose all nodes within the hop.
- if should_decompose(hop) is False, we check each node within the hop
and decide whether to decompose it or not.
"""
for node in gm.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, HigherOrderOperator
):
new_args = []
for arg in node.args:
if isinstance(arg, fx.GraphModule):
new_arg = _SelectiveDecomposeInterpreter.recursive_wrap(
arg, should_decompose, decomposition_table, **kwargs
)
else:
new_arg = arg
new_args.append(new_arg)
node.args = tuple(new_args)
return _SelectiveDecomposeInterpreter(
gm, should_decompose, decomposition_table, **kwargs
)
def run_node(self, n):
if self.should_decompose(n):
with decompose(self.decomposition_table):
result = super().run_node(n)
else:
result = super().run_node(n)
return result
def selective_decompose(
joint_gm: fx.GraphModule,
*args,
decomposition,
should_decompose,
trace_joint_graph: bool,
) -> fx.GraphModule:
"""Retrace a joint graph module and selectively apply decomposition."""
if trace_joint_graph:
# The arg names, primals and tangents, are important:
# make_fx keeps the names in the traced graph, and the partitioner later relies
# on the names to partition the joint graph correctly.
def wrap_fn(primals: list[Any], tangents: list[Any]):
return _SelectiveDecomposeInterpreter.recursive_wrap(
joint_gm, should_decompose, decomposition
).run(*args)
else:
def wrap_fn(*args):
return _SelectiveDecomposeInterpreter.recursive_wrap(
joint_gm, should_decompose, decomposition
).run(*args)
return make_fx(wrap_fn, decomposition_table={})(*args)
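
A hedged usage sketch of the forward-only path (assumes this change; the predicate here selects every call_function node purely for illustration, whereas lite mode passes _needs_inductor_compile instead):

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx, selective_decompose

a, b, c = torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4)
gm = make_fx(lambda a, b, c: torch.addmm(a, b, c).relu())(a, b, c)

regraphed = selective_decompose(
    gm,
    a, b, c,
    decomposition=get_decompositions([torch.ops.aten.addmm]),
    should_decompose=lambda n: n.op == "call_function",
    trace_joint_graph=False,
)
print(regraphed.graph)  # inspect which nodes were decomposed
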
def wrapper_and_args_for_make_fx(
func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
) -> tuple[Callable[[list[object]], R], list[object]]:

View File

@@ -112,7 +112,7 @@ def _compile_submod(gm, prefix):
return gm
def _needs_inductor_compile(node):
def _needs_inductor_compile(node: torch.fx.Node):
return (
node.op not in ("placeholder", "output")
and hasattr(node, "meta")
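
Finally, a hedged sketch of the metadata this predicate keys on (the check is truncated above; this assumes the annotation lands in node.meta["custom"] under the "compile_with_inductor" key, matching the fx_traceback.annotate calls in the tests):

import torch
from torch.fx.passes.regional_inductor import _needs_inductor_compile

gm = torch.fx.symbolic_trace(lambda x: torch.mul(torch.sin(x), 2))

for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.mul:
        # Simulate what fx_traceback.annotate({"compile_with_inductor": ...})
        # attaches during tracing (assumed layout).
        node.meta["custom"] = {"compile_with_inductor": "my_region"}

print({n.name: _needs_inductor_compile(n) for n in gm.graph.nodes})
# Expected: only the annotated mul node reports True.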