mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Inductor-FX] Support IndexPutFallback (#162863)
# Feature

This PR supports lowering `IndexPutFallback` through Inductor's FX converter. The approach is very similar to the one taken in https://github.com/pytorch/pytorch/pull/162686.

Compared to `ScatterFallback`, this required one additional change: the value of `self.op_overload` for `IndexPutFallback` was inaccurate. Previously, it used `aten.index_put`, which would result in unsound FX IR. The existing Python/C++ codegen uses `aten.index_put_`, since the fallback mutates its input. This PR changes `self.op_overload` to match that.

# Test plan

Added a CI test lowering deterministic index put via the FX converter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162863
Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
de143bf79b
commit
9aca0ba027
@ -591,6 +591,24 @@ class FxirTestCase(InductorTestCase):
|
||||
num_fallback = self._count_ops(gm, torch.ops.aten.scatter_.value)
|
||||
self.assertEqual(num_fallback, 1)
|
||||
|
||||
def test_index_put_fallback(self):
    """
    Lower ``torch.index_put`` with deterministic mode enabled and verify
    that the resulting graph contains exactly one ``aten.index_put_`` call.
    """
    size = 8
    out = torch.randn(size, device=self.device)
    values = torch.randn(size, device=self.device)
    indices = (torch.randint(size, (size,), device=self.device),)
    accumulate = True

    with DeterministicGuard(True):
        (gm,) = self._compile_and_check(
            torch.index_put,
            (out, indices, values, accumulate),
            expected_num_triton_kernels=1,
        )

    # Check for the fallback op.
    num_fallback = self._count_ops(gm, torch.ops.aten.index_put_.default)
    self.assertEqual(num_fallback, 1)
|
||||
|
||||
def test_scatter_reduce_fallback(self):
|
||||
"""
|
||||
Test the customized wrapper codegen for ScatterFallback ops.
|
||||
|
@ -1439,7 +1439,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
line += ");"
|
||||
self.writeline(line)
|
||||
|
||||
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
|
||||
def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
|
||||
# TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
|
||||
# See the comment in codegen_reinterpret_view about why having something like
|
||||
# RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
|
||||
|
@ -735,10 +735,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
||||
line += ");"
|
||||
self.writeline(line)
|
||||
|
||||
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
|
||||
def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
    """
    Disable stack allocation, then hand ``node`` to the parent class's
    ``generate_index_put_fallback``.
    """
    # A fallback op is present, so stack allocation must be turned off.
    self.allow_stack_allocation = False
    super().generate_index_put_fallback(node)
|
||||
|
||||
def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
|
||||
self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
|
||||
# TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
|
||||
# See the comment in codegen_reinterpret_view about why having something like
|
||||
|
@ -908,6 +908,29 @@ class MultiOutputLine(WrapperLine):
|
||||
return converter._generate_multi_output
|
||||
|
||||
|
||||
@dataclasses.dataclass
class IndexPutFallbackLine(WrapperLine):
    """
    Wrapper line pairing an ``IndexPutFallback`` node with its index operands.
    """

    wrapper: PythonWrapperCodegen
    node: ir.IndexPutFallback
    # One entry per index operand; falsy entries are rendered as the
    # wrapper's ``none_str`` during codegen.
    indices: list[Optional[ir.IRNode]]

    def codegen(self, code: IndentedBuffer) -> None:
        """
        Resolve the operand references and forward them to the wrapper's
        ``_generate_index_put_fallback``.
        """
        fallback = self.node
        assert ir.is_node_sequence(fallback.inputs)
        x, values = [inp.codegen_reference() for inp in fallback.inputs[:2]]

        index_refs = []
        for idx in self.indices:
            if idx:
                index_refs.append(idx.codegen_reference())
            else:
                index_refs.append(self.wrapper.none_str)

        self.wrapper._generate_index_put_fallback(
            fallback.get_kernel_name(),
            x,
            index_refs,
            values,
            *fallback.codegen_const_args(),
        )

    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
        """Return the converter method that handles this line type."""
        return converter._generate_index_put_fallback
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ScatterFallbackLine(WrapperLine):
|
||||
wrapper: PythonWrapperCodegen
|
||||
@ -1560,7 +1583,22 @@ class PythonWrapperCodegen(CodeGen):
|
||||
line += ")"
|
||||
self.writeline(line)
|
||||
|
||||
def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
|
||||
def generate_index_put_fallback(self, node: ir.IndexPutFallback) -> None:
    """
    Write an ``IndexPutFallbackLine`` for ``node``.

    ``node.indices`` has one entry per index position. For every entry that
    is not ``None``, the next element of ``node.inputs[2:]`` is taken, in
    order; ``None`` entries are kept as ``None``.
    """
    present = iter(node.inputs[2:])
    indices: list[Optional[ir.IRNode]] = []
    for slot in node.indices:
        if slot is None:
            indices.append(None)
            continue
        index = next(present)
        assert isinstance(index, ir.IRNode)
        indices.append(index)

    self.writeline(IndexPutFallbackLine(self, node, indices))
|
||||
|
||||
def _generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
    """
    Write a call to ``kernel`` whose second argument is the index
    references joined into a bracketed, comma-separated list.
    """
    joined = ", ".join(indices)
    call_args = [x, f"[{joined}]", values, accumulate]
    self.writeline(self.wrap_kernel_call(kernel, call_args))
|
||||
|
@ -55,6 +55,7 @@ from .wrapper import (
|
||||
ExternKernelOutLine,
|
||||
FreeIfNotReusedLine,
|
||||
FreeLine,
|
||||
IndexPutFallbackLine,
|
||||
KernelCallLine,
|
||||
KernelDefinitionLine,
|
||||
Line,
|
||||
@ -654,6 +655,42 @@ class FxConverter:
|
||||
node.name = line.result_name
|
||||
self.buffer_to_node[line.result_name] = node
|
||||
|
||||
def _generate_fallback_call(
    self,
    ir_node: ir.ExternKernel,
    args: Optional[tuple[Any, ...]] = None,
    kwargs: Optional[dict[str, Any]] = None,
) -> None:
    """
    Add a ``call_function`` node targeting ``ir_node.op_overload`` to the
    graph and register it in ``buffer_to_node`` under the name returned by
    ``ir_node.codegen_reference()``.
    """
    call_node = self.gm.graph.call_function(
        ir_node.op_overload,  # type: ignore[arg-type]
        args=args,
        kwargs=kwargs,
    )
    buffer_name = ir_node.codegen_reference()
    self.buffer_to_node[buffer_name] = call_node
|
||||
|
||||
def _generate_index_put_fallback(self, line: WrapperLine) -> None:
    """
    Convert an ``IndexPutFallbackLine`` into an FX fallback call with the
    arguments ``(x, indices, values, accumulate)``.
    """
    assert isinstance(line, IndexPutFallbackLine)
    ir_node = line.node

    def generate_buffer_or_none(
        x: Union[ir.IRNode, Sequence[ir.IRNode], None],
    ) -> Optional[torch.fx.Node]:
        """
        Pass ``None`` through unchanged; otherwise defer to ``_generate_buffer``.
        """
        if x is not None:
            assert isinstance(x, ir.IRNode)
            return self._generate_buffer(x)
        return None

    # The first two inputs are the target tensor and the values.
    x, values = [generate_buffer_or_none(operand) for operand in ir_node.inputs[:2]]
    indices = tuple(map(generate_buffer_or_none, line.indices))
    accumulate = ir_node.constant_args[0]
    self._generate_fallback_call(ir_node, (x, indices, values, accumulate))
|
||||
|
||||
def _generate_scatter_fallback(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, ScatterFallbackLine)
|
||||
ir_node = line.node
|
||||
@ -666,13 +703,7 @@ class FxConverter:
|
||||
if reduce := ir_node.kwargs.get("reduce"):
|
||||
kwargs["reduce"] = reduce
|
||||
|
||||
fx_node = self.gm.graph.call_function(
|
||||
ir_node.op_overload, # type: ignore[arg-type]
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
result_buffer = ir_node.codegen_reference()
|
||||
self.buffer_to_node[result_buffer] = fx_node
|
||||
self._generate_fallback_call(ir_node, args, kwargs)
|
||||
|
||||
def _generate_null(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, NullLine)
|
||||
|
@ -7115,19 +7115,7 @@ class IndexPutFallback(ExternKernel):
|
||||
"""
|
||||
|
||||
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
|
||||
assert is_node_sequence(self.inputs)
|
||||
(x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs)
|
||||
indices = []
|
||||
iter_valid_indices = iter(valid_indices)
|
||||
for i, _ in enumerate(self.indices):
|
||||
if self.indices[i] is not None:
|
||||
indices.append(next(iter_valid_indices))
|
||||
else:
|
||||
indices.append(V.graph.wrapper_code.none_str)
|
||||
|
||||
wrapper.generate_index_put_fallback(
|
||||
self.get_kernel_name(), x, indices, values, *self.codegen_const_args()
|
||||
)
|
||||
wrapper.generate_index_put_fallback(self)
|
||||
|
||||
def should_allocate(self) -> bool:
    """Always return ``False``."""
    return False
|
||||
|
@ -3717,8 +3717,8 @@ def index_put_as_masked_fill(self, indices, value, accumulate):
|
||||
|
||||
|
||||
def index_put_fallback(self, indices, values, accumulate):
|
||||
assert isinstance(V.graph.current_node.target, torch._ops.OpOverload)
|
||||
ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate)
|
||||
op_overload = getattr(aten.index_put_, V.graph.current_node.target._overloadname) # type: ignore[union-attr]
|
||||
ir.IndexPutFallback(op_overload, self, indices, values, accumulate)
|
||||
return self
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user