# Reconstructed from a whitespace-mangled patch hunk of
# test/inductor/test_fxir_backend.py (method of FxirTestCase).
def test_index_put_fallback(self):
    """
    Test the deterministic fallback for index_put.

    Under DeterministicGuard, index_put with accumulate=True must not be
    lowered to a (non-deterministic) scatter kernel; instead the lowering
    emits an aten.index_put_ fallback node in the FX graph.
    """
    length = 8
    # Two same-shaped 1-D tensors: the destination and the values to write.
    out, values = [torch.randn(length, device=self.device) for _ in range(2)]
    # A single index tensor; duplicates are likely, which is what makes
    # accumulation order-sensitive and forces the deterministic fallback.
    indices = (torch.randint(length, (length,), device=self.device),)
    accumulate = True
    with DeterministicGuard(True):
        (gm,) = self._compile_and_check(
            torch.index_put,
            (out, indices, values, accumulate),
            expected_num_triton_kernels=1,
        )

    # Check for the fallback op. The lowering maps the target onto the
    # in-place overload, hence aten.index_put_ (not aten.index_put).
    self.assertEqual(self._count_ops(gm, torch.ops.aten.index_put_.default), 1)
# Reconstructed from a whitespace-mangled patch hunk of
# torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py
# (method of CppWrapperCpuArrayRef; the sibling _generate_index_put_fallback
# body is not fully visible in this chunk and is left untouched).
def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
    """
    Handle an index_put fallback node for the ArrayRef C++ wrapper.

    Disables stack allocation, then defers to the parent class, which
    collects the index tensors and emits the actual kernel call line.
    """
    # No stack allocation when there is a fallback op.
    self.allow_stack_allocation = False

    super().generate_index_put_fallback(node)
# Reconstructed from whitespace-mangled patch hunks; grouped by source file.
# Only definitions fully visible in this chunk are included
# (FxConverter._generate_scatter_fallback is cut by a hunk break and omitted).

# --- torch/_inductor/codegen/wrapper.py ---

@dataclasses.dataclass
class IndexPutFallbackLine(WrapperLine):
    """
    Memory-planning line for a deterministic aten.index_put_ fallback.

    Holds the IR node plus the per-dimension index list, where None marks a
    dimension that had no index tensor (padded back in at codegen time).
    """

    wrapper: PythonWrapperCodegen
    node: ir.IndexPutFallback
    indices: list[Optional[ir.IRNode]]

    def codegen(self, code: IndentedBuffer) -> None:
        """Emit the Python/C++ wrapper call for the fallback."""
        node = self.node
        assert ir.is_node_sequence(node.inputs)
        # inputs[0] is the destination tensor, inputs[1] the values tensor.
        (x, values) = (t.codegen_reference() for t in node.inputs[:2])
        # Missing per-dim indices are rendered as the wrapper's None literal.
        indices = [
            idx.codegen_reference() if idx else self.wrapper.none_str
            for idx in self.indices
        ]

        self.wrapper._generate_index_put_fallback(
            node.get_kernel_name(), x, indices, values, *node.codegen_const_args()
        )

    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
        return converter._generate_index_put_fallback


# Method of PythonWrapperCodegen (reconstructed at original indentation's unit):
def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
    """
    Collect index tensors into a per-dimension list and defer codegen.

    node.indices has one entry per indexed dimension; None entries mark
    dimensions without an index tensor, while the actual index tensors are
    stored densely in node.inputs[2:]. Re-align them here so downstream
    codegen sees one (possibly None) entry per dimension.
    """
    indices: list[Optional[ir.IRNode]] = []
    iter_valid_indices = iter(node.inputs[2:])
    # Direct iteration instead of the original enumerate-and-reindex loop.
    for placeholder in node.indices:
        if placeholder is not None:
            index = next(iter_valid_indices)
            assert isinstance(index, ir.IRNode)
            indices.append(index)
        else:
            indices.append(None)

    self.writeline(IndexPutFallbackLine(self, node, indices))


# Method of PythonWrapperCodegen:
def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
    """Render the actual kernel call: kernel(x, [i0, i1, ...], values, accumulate)."""
    indices_str = f"[{', '.join(indices)}]"
    args = [x, indices_str, values, accumulate]
    self.writeline(self.wrap_kernel_call(kernel, args))


# --- torch/_inductor/codegen/wrapper_fxir.py (methods of FxConverter) ---

def _generate_fallback_call(
    self,
    ir_node: ir.ExternKernel,
    args: Optional[tuple[Any, ...]] = None,
    kwargs: Optional[dict[str, Any]] = None,
) -> None:
    """
    Shared helper: emit a call_function FX node for a fallback kernel and
    register the result buffer so later lines can reference it.
    """
    fx_node = self.gm.graph.call_function(
        ir_node.op_overload,  # type: ignore[arg-type]
        args=args,
        kwargs=kwargs,
    )
    result_buffer = ir_node.codegen_reference()
    self.buffer_to_node[result_buffer] = fx_node


def _generate_index_put_fallback(self, line: WrapperLine) -> None:
    """Convert an IndexPutFallbackLine into an aten.index_put_ FX call."""
    assert isinstance(line, IndexPutFallbackLine)
    ir_node = line.node

    def generate_buffer_or_none(
        x: Union[ir.IRNode, Sequence[ir.IRNode], None],
    ) -> Optional[torch.fx.Node]:
        """
        Handles None before calling _generate_buffer.
        """
        if x is None:
            return None

        assert isinstance(x, ir.IRNode)
        return self._generate_buffer(x)

    (x, values) = [generate_buffer_or_none(t) for t in ir_node.inputs[:2]]
    # line.indices keeps None for dimensions without an index tensor, matching
    # the aten.index_put_ signature which takes an optional-tensor list.
    indices = tuple(generate_buffer_or_none(t) for t in line.indices)
    accumulate = ir_node.constant_args[0]
    args = (x, indices, values, accumulate)
    self._generate_fallback_call(ir_node, args)


# --- torch/_inductor/ir.py (methods of IndexPutFallback) ---

def codegen(self, wrapper: PythonWrapperCodegen) -> None:
    """Index collection and line emission now live on the wrapper codegen."""
    wrapper.generate_index_put_fallback(self)


def should_allocate(self) -> bool:
    return False


# --- torch/_inductor/lowering.py ---

def index_put_fallback(self, indices, values, accumulate):
    """
    Deterministic fallback lowering for index_put.

    Always uses the in-place aten.index_put_ overload matching the current
    node's overload name, so the FX converter emits a well-defined op even
    when the traced target was the out-of-place variant.
    """
    op_overload = getattr(aten.index_put_, V.graph.current_node.target._overloadname)  # type: ignore[union-attr]
    ir.IndexPutFallback(op_overload, self, indices, values, accumulate)
    return self