diff --git a/test/inductor/test_custom_lowering.py b/test/inductor/test_custom_lowering.py index 4786a97429eb..1a91cc50e4ee 100644 --- a/test/inductor/test_custom_lowering.py +++ b/test/inductor/test_custom_lowering.py @@ -140,6 +140,24 @@ class TestCustomLowering(InductorTestCase): torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None )(add_custom_lowering) + def test_register_lowering_custom_dict(self): + custom_lowering_dict = {} + + from torch._inductor.lowering import register_lowering + + @torch.library.custom_op("helion_test::foo", mutates_args={}) + def foo(x: torch.Tensor) -> torch.Tensor: + return x + + @register_lowering( + torch.ops.helion_test.foo, lowering_dict=custom_lowering_dict + ) + def foo_lowering(x): + return x + + assert torch.ops.helion_test.foo in custom_lowering_dict + assert torch.ops.helion_test.foo not in torch._inductor.lowering.lowerings + @requires_gpu() @skipIf(GPU_TYPE == "mps", "Not applicable to MPS") def test_jagged_to_padded_dense_sanity_cuda(self): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 4648bfe96fc5..87b50c6bd0f2 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -405,6 +405,7 @@ def _register_lowering( broadcast, type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], convert_input_to_bool, + lowering_dict, ): """ Add a lowering to lowerings dict @@ -449,7 +450,7 @@ def _register_lowering( aten_fn = get_overloads(aten_fn) - lowerings.update(dict.fromkeys(aten_fn, wrapped)) + lowering_dict.update(dict.fromkeys(aten_fn, wrapped)) return wrapped @@ -460,6 +461,7 @@ def register_lowering( ELEMENTWISE_TYPE_PROMOTION_KIND ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, convert_input_to_bool=False, + lowering_dict=lowerings, ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: """ Shim to support decorator syntax. @@ -470,6 +472,7 @@ def register_lowering( broadcast=broadcast, type_promotion_kind=type_promotion_kind, convert_input_to_bool=convert_input_to_bool, + lowering_dict=lowering_dict, )