Compare commits

...

1 Commits

Author SHA1 Message Date
42942afc86 [Inductor] Allow passing in custom lowering dict to register_lowering()
ghstack-source-id: 1f5e1a4b23d9ef1c752432ae142c46517639cf7e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154344
2025-05-26 09:02:37 -07:00
2 changed files with 20 additions and 1 deletions

View File

@ -140,6 +140,22 @@ class TestCustomLowering(InductorTestCase):
torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None
)(add_custom_lowering)
def test_register_lowering_custom_dict(self):
custom_lowering_dict = {}
from torch._inductor.lowering import register_lowering
@torch.library.custom_op("helion_test::foo", mutates_args={})
def foo(x: torch.Tensor) -> torch.Tensor:
return x
@register_lowering(torch.ops.helion_test.foo, lowering_dict=custom_lowering_dict)
def foo_lowering(x):
return x
assert torch.ops.helion_test.foo in custom_lowering_dict, f"Expected custom lowering for helion_test::foo to be registered, but it was not found in {custom_lowering_dict.keys()}"
assert torch.ops.helion_test.foo not in torch._inductor.lowering.lowerings
@requires_gpu()
@skipIf(GPU_TYPE == "mps", "Not applicable to MPS")
def test_jagged_to_padded_dense_sanity_cuda(self):

View File

@ -405,6 +405,7 @@ def _register_lowering(
broadcast,
type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
convert_input_to_bool,
lowering_dict,
):
"""
Add a lowering to lowerings dict
@ -449,7 +450,7 @@ def _register_lowering(
aten_fn = get_overloads(aten_fn)
lowerings.update(dict.fromkeys(aten_fn, wrapped))
lowering_dict.update(dict.fromkeys(aten_fn, wrapped))
return wrapped
@ -460,6 +461,7 @@ def register_lowering(
ELEMENTWISE_TYPE_PROMOTION_KIND
] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
convert_input_to_bool=False,
lowering_dict=lowerings,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""
Shim to support decorator syntax.
@ -470,6 +472,7 @@ def register_lowering(
broadcast=broadcast,
type_promotion_kind=type_promotion_kind,
convert_input_to_bool=convert_input_to_bool,
lowering_dict=lowerings,
)