[Inductor-FX] Support IndexPutFallback (#162863)

# Feature

This PR supports lowering `IndexPutFallback` through Inductor's FX converter. The approach is very similar to the one taken in https://github.com/pytorch/pytorch/pull/162686.

Compared to `ScatterFallback`, this required one additional change: the value of `self.op_overload` for `IndexPutFallback` was inaccurate. Previously, it used `aten.index_put`, which would result in unsound FX IR. The existing Python/C++ codegen use `aten.index_put_`, since the fallback mutates its input. This PR changes `self.op_overload` to match that.

# Test plan
Added a CI test lowering deterministic index put via the FX converter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162863
Approved by: https://github.com/angelayi
This commit is contained in:
Blaine Burton Rister
2025-09-16 08:52:43 +00:00
committed by PyTorch MergeBot
parent de143bf79b
commit 9aca0ba027
7 changed files with 102 additions and 25 deletions

View File

@@ -591,6 +591,24 @@ class FxirTestCase(InductorTestCase):
num_fallback = self._count_ops(gm, torch.ops.aten.scatter_.value)
self.assertEqual(num_fallback, 1)
def test_index_put_fallback(self):
"""
Test the deterministic fallback for index_put.
"""
length = 8
out, values = [torch.randn(length, device=self.device) for _ in range(2)]
indices = (torch.randint(length, (length,), device=self.device),)
accumulate = True
with DeterministicGuard(True):
(gm,) = self._compile_and_check(
torch.index_put,
(out, indices, values, accumulate),
expected_num_triton_kernels=1,
)
# Check for the fallback op.
self.assertEqual(self._count_ops(gm, torch.ops.aten.index_put_.default), 1)
def test_scatter_reduce_fallback(self):
"""
Test the customized wrapper codegen for ScatterFallback ops.

View File

@@ -1439,7 +1439,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
line += ");"
self.writeline(line)
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
# TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
# See the comment in codegen_reinterpret_view about why having something like
# RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding

View File

@@ -735,10 +735,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
line += ");"
self.writeline(line)
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
# No stack allocation when there is a fallback op
self.allow_stack_allocation = False
super().generate_index_put_fallback(node)
def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
# TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
# See the comment in codegen_reinterpret_view about why having something like

View File

@@ -908,6 +908,29 @@ class MultiOutputLine(WrapperLine):
return converter._generate_multi_output
@dataclasses.dataclass
class IndexPutFallbackLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.IndexPutFallback
indices: list[Optional[ir.IRNode]]
def codegen(self, code: IndentedBuffer) -> None:
node = self.node
assert ir.is_node_sequence(node.inputs)
(x, values) = (t.codegen_reference() for t in node.inputs[:2])
indices = [
idx.codegen_reference() if idx else self.wrapper.none_str
for idx in self.indices
]
self.wrapper._generate_index_put_fallback(
node.get_kernel_name(), x, indices, values, *node.codegen_const_args()
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_index_put_fallback
@dataclasses.dataclass
class ScatterFallbackLine(WrapperLine):
wrapper: PythonWrapperCodegen
@@ -1560,7 +1583,22 @@ class PythonWrapperCodegen(CodeGen):
line += ")"
self.writeline(line)
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
# Collect index tensors into a list.
indices: list[Optional[ir.IRNode]] = []
valid_indices = node.inputs[2:]
iter_valid_indices = iter(valid_indices)
for i, _ in enumerate(node.indices):
if node.indices[i] is not None:
index = next(iter_valid_indices)
assert isinstance(index, ir.IRNode)
indices.append(index)
else:
indices.append(None)
self.writeline(IndexPutFallbackLine(self, node, indices))
def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
indices_str = f"[{', '.join(indices)}]"
args = [x, indices_str, values, accumulate]
self.writeline(self.wrap_kernel_call(kernel, args))

View File

@@ -55,6 +55,7 @@ from .wrapper import (
ExternKernelOutLine,
FreeIfNotReusedLine,
FreeLine,
IndexPutFallbackLine,
KernelCallLine,
KernelDefinitionLine,
Line,
@@ -654,6 +655,42 @@ class FxConverter:
node.name = line.result_name
self.buffer_to_node[line.result_name] = node
def _generate_fallback_call(
self,
ir_node: ir.ExternKernel,
args: Optional[tuple[Any, ...]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> None:
fx_node = self.gm.graph.call_function(
ir_node.op_overload, # type: ignore[arg-type]
args=args,
kwargs=kwargs,
)
result_buffer = ir_node.codegen_reference()
self.buffer_to_node[result_buffer] = fx_node
def _generate_index_put_fallback(self, line: WrapperLine) -> None:
assert isinstance(line, IndexPutFallbackLine)
ir_node = line.node
def generate_buffer_or_none(
x: Union[ir.IRNode, Sequence[ir.IRNode], None],
) -> Optional[torch.fx.Node]:
"""
Handles None before calling _generate_buffer.
"""
if x is None:
return None
assert isinstance(x, ir.IRNode)
return self._generate_buffer(x)
(x, values) = [generate_buffer_or_none(t) for t in ir_node.inputs[:2]]
indices = tuple(generate_buffer_or_none(t) for t in line.indices)
accumulate = ir_node.constant_args[0]
args = (x, indices, values, accumulate)
self._generate_fallback_call(ir_node, args)
def _generate_scatter_fallback(self, line: WrapperLine) -> None:
assert isinstance(line, ScatterFallbackLine)
ir_node = line.node
@@ -666,13 +703,7 @@
if reduce := ir_node.kwargs.get("reduce"):
kwargs["reduce"] = reduce
fx_node = self.gm.graph.call_function(
ir_node.op_overload, # type: ignore[arg-type]
args=args,
kwargs=kwargs,
)
result_buffer = ir_node.codegen_reference()
self.buffer_to_node[result_buffer] = fx_node
self._generate_fallback_call(ir_node, args, kwargs)
def _generate_null(self, line: WrapperLine) -> None:
assert isinstance(line, NullLine)

View File

@@ -7115,19 +7115,7 @@ class IndexPutFallback(ExternKernel):
"""
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
assert is_node_sequence(self.inputs)
(x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs)
indices = []
iter_valid_indices = iter(valid_indices)
for i, _ in enumerate(self.indices):
if self.indices[i] is not None:
indices.append(next(iter_valid_indices))
else:
indices.append(V.graph.wrapper_code.none_str)
wrapper.generate_index_put_fallback(
self.get_kernel_name(), x, indices, values, *self.codegen_const_args()
)
wrapper.generate_index_put_fallback(self)
def should_allocate(self) -> bool:
return False

View File

@@ -3717,8 +3717,8 @@ def index_put_as_masked_fill(self, indices, value, accumulate):
def index_put_fallback(self, indices, values, accumulate):
assert isinstance(V.graph.current_node.target, torch._ops.OpOverload)
ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate)
op_overload = getattr(aten.index_put_, V.graph.current_node.target._overloadname) # type: ignore[union-attr]
ir.IndexPutFallback(op_overload, self, indices, values, accumulate)
return self