Files
pytorch/torch/_inductor/subgraph_lowering.py
Angel Li 3a4140bf8e [FlexAttention] fixing learnable bias assertion error in inductor (#161170)
Users encountered unexpected behaviour, including assertion errors, when using FlexAttention with learnable biases (#157677).

We traced the root cause to how subgraph buffers were registered: this caused naming inconsistencies and, ultimately, incorrect buffer retrieval later on. The problem only arose when the model was compiled as a whole (i.e. using @torch.compile), since only then could buffer names conflict.

In this PR, we register the buffers with the base graph to solve this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161170
Approved by: https://github.com/drisspg
2025-08-23 06:24:22 +00:00

210 lines
7.1 KiB
Python

"""Utilities for lowering subgraphs used by higher order operators"""
import functools
import operator
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
from torch.utils._ordered_set import OrderedSet
from . import ir
from .exc import SubgraphLoweringException
from .graph import GraphLowering
from .ops_handler import SimpleCSEHandler
from .virtualized import ops, V, WrapperHandler
T = TypeVar("T")
_P = ParamSpec("_P")
OpOverload = torch._ops.OpOverload
LoweringDict = dict[Union[OpOverload, str], Callable[..., Any]]
TargetType = Union[Callable[..., Any], str]
class PointwiseSubgraphLowering(torch.fx.Interpreter):
    """
    Lowers a pointwise subgraph to a single set of buffers with a separate
    lowering object. Errors if buffers are created unexpectedly.

    While the subgraph runs, this interpreter is meant to be installed as the
    active graph handler (``lower_pointwise_subgraph`` does so via
    ``V.set_graph_handler``). Buffer creation and mutation are then permitted
    only while lowering an op listed in ``allowed_mutations``; otherwise a
    ``SubgraphLoweringException`` is raised. Any attribute not defined on this
    object is looked up on ``root_graph``.
    """

    # Outputs of the interpreted subgraph; populated by ``output``.
    graph_outputs: Optional[list[ir.IRNode]]
    # Lowering of the enclosing (root) graph. Approved buffers are registered
    # there so their names stay consistent with the root graph.
    root_graph: GraphLowering
    # Target currently being lowered by ``call_function`` (None outside it).
    _current_op: Optional[TargetType]
    # For backwards of buffer_grads with scatters we allow mutations
    allowed_mutations: Optional[OrderedSet[OpOverload]]
    # Lowerings consulted before the global ``lowerings`` table.
    additional_lowerings: Optional[LoweringDict]
    # NOTE(review): never appended to within this class; verify who fills it.
    buffers: list[ir.Buffer]
    # Names of buffers mutated by approved mutation ops.
    mutated_buffers: OrderedSet[str]

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        root_graph_lowering: GraphLowering,
        allowed_mutations: Optional[OrderedSet[OpOverload]] = None,
        additional_lowerings: Optional[LoweringDict] = None,
    ) -> None:
        """
        Args:
            gm: The subgraph module to interpret.
            root_graph_lowering: Lowering of the enclosing graph. Approved
                buffers are registered on it, and unknown attributes are
                delegated to it.
            allowed_mutations: Ops that may create or mutate buffers.
            additional_lowerings: Lowerings that take precedence over the
                global ``lowerings`` table.
        """
        super().__init__(gm)
        self.graph_outputs = None
        self.root_graph = root_graph_lowering
        self.allowed_mutations = allowed_mutations
        self.additional_lowerings = additional_lowerings
        self._current_op = None
        # Tracks names of buffers mutated by approved ops during lowering
        self.mutated_buffers = OrderedSet()
        self.buffers = []

    @contextmanager
    def _op_context(self, op: TargetType) -> Generator[None, None, None]:
        """Set which op is being processed in call function to know if we can mutate buffers"""
        previous = self._current_op
        self._current_op = op
        try:
            yield
        finally:
            # Restore even on error so nested/subsequent ops see the right op.
            self._current_op = previous

    def _approved_mutator(self) -> bool:
        """Return True if the op currently being lowered may mutate/create buffers."""
        return (
            self.allowed_mutations is not None
            and self._current_op in self.allowed_mutations
        )

    def mark_buffer_mutated(self, name: str) -> None:
        """Record that buffer ``name`` is mutated.

        Raises:
            SubgraphLoweringException: if the current op is not an approved
                mutator.
        """
        if self._approved_mutator():
            self.mutated_buffers.add(name)
        else:
            raise SubgraphLoweringException(
                f"Buffer mutation detected during lowering of {self._current_op}. "
                "Buffer mutations are only allowed in approved mutation ops. "
                "This is an error in the lowering of the subgraph, please file a bug report."
            )

    def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
        """Register ``buffer`` on the root graph and return its name.

        Registering on the root graph (rather than locally) keeps buffer names
        consistent with the enclosing graph.

        Raises:
            SubgraphLoweringException: if the current op is not an approved
                mutator.
        """
        if self._approved_mutator():
            name = self.root_graph.register_buffer(buffer, set_name=set_name)
            return name
        else:
            raise SubgraphLoweringException(
                "Buffers cannot be created while lowering a pointwise subgraph. "
                "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), "
                "but it could also be a bug. Please file a bug report if you think this should be supportable."
            )

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on this object to ``root_graph``."""
        # __getattr__ only runs after normal lookup fails. If ``root_graph``
        # itself is missing (e.g. an instance created via __new__ without
        # __init__, as copy/pickle do), ``self.root_graph`` below would re-enter
        # __getattr__ forever. Raise AttributeError so hasattr()/getattr()
        # with a default behave normally.
        if name == "root_graph":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.root_graph, name)

    def call_function(
        self,
        target: TargetType,
        args: Any,
        kwargs: dict[str, Any],
    ) -> Any:
        """Lower a single ``call_function`` node.

        ``operator.getitem`` on Python containers is evaluated directly.
        Otherwise ``additional_lowerings`` is consulted first, then the global
        ``lowerings`` table. ``target`` is recorded as the current op for the
        duration, so buffer creation/mutation checks know which op is running.

        Raises:
            SubgraphLoweringException: if no lowering exists for ``target``.
        """
        # Imported lazily, presumably to avoid an import cycle with .lowering.
        from .lowering import lowerings

        with self._op_context(target):
            if target is operator.getitem and isinstance(
                args[0], (list, tuple, dict)
            ):
                return super().call_function(target, args, kwargs)

            # These take precedence over the main lowerings
            if self.additional_lowerings is not None:
                if target in self.additional_lowerings:
                    assert isinstance(target, OpOverload)
                    return self.additional_lowerings[target](*args, **kwargs)

            if target not in lowerings:
                raise SubgraphLoweringException(
                    f"{target} not supported in subgraph, (missing lowering)"
                )

            return lowerings[target](*args, **kwargs)

    def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None:  # type: ignore[override]
        """Store the output node's single argument as ``graph_outputs``."""
        assert len(args) == 1
        self.graph_outputs = args[0]
@dataclass
class InputDescriptor:
    """Dtype and device of one input to a pointwise subgraph.

    ``lower_pointwise_subgraph`` builds one zero-dim placeholder node per
    descriptor from these two fields.
    """

    # Element dtype of the input.
    dtype: torch.dtype
    # Device the input lives on.
    device: torch.device
class TracingOpsHandler(WrapperHandler):
    """Ops handler that records ``ops.*`` calls into ``tracer``'s FX graph.

    The handler it wraps is an FX placeholder proxy named ``ops``. It is
    created first, so it is the graph's first input. It is followed by one
    placeholder ``input{i}`` per subgraph input. Ops not overridden here are
    delegated by ``WrapperHandler`` to that proxy, presumably recording them
    as method calls on the ``ops`` placeholder (confirm in WrapperHandler).
    """

    def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None:
        """
        Args:
            tracer: Tracer whose ``graph`` receives the recorded nodes.
            num_inputs: Number of ``input{i}`` placeholders to create.
        """
        parent = tracer.create_proxy("placeholder", "ops", (), {})
        super().__init__(parent)
        self.tracer = tracer
        # Created after the "ops" placeholder, so graph inputs are ordered
        # (ops, input0, input1, ...).
        self.placeholders = [
            self.tracer.create_proxy("placeholder", f"input{i}", (), {})
            for i in range(num_inputs)
        ]

    def placeholder(self, idx: int) -> torch.fx.Proxy:
        """Return the placeholder proxy for subgraph input ``idx``."""
        return self.placeholders[idx]

    def output(self, *args: object) -> None:
        """Emit the graph's output node, which returns ``args`` as a tuple."""
        # Each positional argument is one traced output value (not a tuple),
        # so the varargs are annotated as ``object``.
        self.tracer.create_node(
            "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
        )
def lower_pointwise_subgraph(
    subgraph: ir.Subgraph, inputs: list[InputDescriptor]
) -> Callable[_P, Any]:
    """Compile a pointwise ``subgraph`` into a single ops-level callable.

    Each entry of ``inputs`` becomes a zero-dim ``ir.Pointwise`` whose body
    just reads ``ops.placeholder(i)``. The subgraph is interpreted with
    ``PointwiseSubgraphLowering``. Then the body of every (zero-dim,
    pointwise) output is replayed under a tracing ops handler with CSE. The
    result is one FX graph that computes all outputs together.

    Args:
        subgraph: Subgraph whose ``graph_module`` is lowered.
        inputs: Dtype/device of each subgraph input, in placeholder order.

    Returns:
        A function that runs the traced graph against the ops handler that is
        active at call time, passing its own arguments after the handler.

    Raises:
        AssertionError: if an output is not a zero-dim pointwise ``TensorBox``.
    """

    def read_placeholder(
        loop_idx: int, input_idx: int
    ) -> Union[ir.Expr, ir.TensorBox, None]:
        # The loop index is ignored: every input is a zero-dim placeholder.
        return ops.placeholder(input_idx)

    placeholder_nodes = [
        ir.Pointwise.create(
            device=descriptor.device,
            dtype=descriptor.dtype,
            inner_fn=functools.partial(read_placeholder, input_idx=position),
            ranges=[],
        )
        for position, descriptor in enumerate(inputs)
    ]

    # Step 1: lower the FX subgraph to Inductor IR with a restricted lowering
    # that forbids unexpected buffer creation.
    interpreter = PointwiseSubgraphLowering(
        subgraph.graph_module, root_graph_lowering=V.graph
    )
    with V.set_graph_handler(interpreter):  # type: ignore[arg-type]
        interpreter.run(*placeholder_nodes)

    # Step 2: fuse all outputs into one FX graph. Each output's inner_fn is
    # replayed under a tracing handler, and SimpleCSEHandler performs CSE.
    tracer = torch.fx.Tracer()
    tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
    recorder = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs)))
    results = interpreter.graph_outputs
    assert results is not None

    def replay_scalar(node: Any) -> Any:
        # Only zero-dim pointwise outputs can be replayed at the empty index.
        assert isinstance(node, ir.TensorBox), type(node)
        assert node.get_size() == []
        storage = node.data
        assert isinstance(storage, ir.StorageBox)
        pointwise = storage.data
        assert isinstance(pointwise, ir.Pointwise)
        return pointwise.inner_fn(())

    # ops.output must run while the recorder is active so it emits the
    # output node into the traced graph.
    with V.set_ops_handler(recorder):
        traced_outputs = [replay_scalar(node) for node in results]
        ops.output(*traced_outputs)

    lowered_gm = torch.fx.GraphModule({}, tracer.graph)

    def inner_fn(*args: _P.args, **kwargs: _P.kwargs) -> Any:
        # The traced graph's first placeholder ("ops") is the handler itself.
        active_handler = V.get_ops_handler()
        return lowered_gm(active_handler, *args, **kwargs)

    return inner_fn