[Inductor] Allow passing in custom lowering dict to register_lowering() (#154344)

This PR adds support for passing in custom lowering dict to `register_lowering()`, which allows systems (e.g. Helion, https://github.com/pytorch-labs/helion/pull/80) that uses Inductor to maintain their own lowering dict instead of using the Inductor global `lowerings` dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154344
Approved by: https://github.com/jansel
This commit is contained in:
Will Feng
2025-05-26 22:38:43 +00:00
committed by PyTorch MergeBot
parent 3936e6141c
commit 100ec0b34a
2 changed files with 22 additions and 1 deletions

View File

@ -140,6 +140,24 @@ class TestCustomLowering(InductorTestCase):
torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None
)(add_custom_lowering)
def test_register_lowering_custom_dict(self):
custom_lowering_dict = {}
from torch._inductor.lowering import register_lowering
@torch.library.custom_op("helion_test::foo", mutates_args={})
def foo(x: torch.Tensor) -> torch.Tensor:
return x
@register_lowering(
torch.ops.helion_test.foo, lowering_dict=custom_lowering_dict
)
def foo_lowering(x):
return x
assert torch.ops.helion_test.foo in custom_lowering_dict
assert torch.ops.helion_test.foo not in torch._inductor.lowering.lowerings
@requires_gpu()
@skipIf(GPU_TYPE == "mps", "Not applicable to MPS")
def test_jagged_to_padded_dense_sanity_cuda(self):

View File

@ -405,6 +405,7 @@ def _register_lowering(
broadcast,
type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
convert_input_to_bool,
lowering_dict,
):
"""
Add a lowering to lowerings dict
@ -449,7 +450,7 @@ def _register_lowering(
aten_fn = get_overloads(aten_fn)
lowerings.update(dict.fromkeys(aten_fn, wrapped))
lowering_dict.update(dict.fromkeys(aten_fn, wrapped))
return wrapped
@ -460,6 +461,7 @@ def register_lowering(
ELEMENTWISE_TYPE_PROMOTION_KIND
] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
convert_input_to_bool=False,
lowering_dict=lowerings,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""
Shim to support decorator syntax.
@ -470,6 +472,7 @@ def register_lowering(
broadcast=broadcast,
type_promotion_kind=type_promotion_kind,
convert_input_to_bool=convert_input_to_bool,
lowering_dict=lowering_dict,
)