mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Inductor] Allow passing in custom lowering dict to register_lowering() (#154344)
This PR adds support for passing in custom lowering dict to `register_lowering()`, which allows systems (e.g. Helion, https://github.com/pytorch-labs/helion/pull/80) that uses Inductor to maintain their own lowering dict instead of using the Inductor global `lowerings` dict. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154344 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
3936e6141c
commit
100ec0b34a
@ -140,6 +140,24 @@ class TestCustomLowering(InductorTestCase):
|
||||
torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None
|
||||
)(add_custom_lowering)
|
||||
|
||||
def test_register_lowering_custom_dict(self):
|
||||
custom_lowering_dict = {}
|
||||
|
||||
from torch._inductor.lowering import register_lowering
|
||||
|
||||
@torch.library.custom_op("helion_test::foo", mutates_args={})
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
@register_lowering(
|
||||
torch.ops.helion_test.foo, lowering_dict=custom_lowering_dict
|
||||
)
|
||||
def foo_lowering(x):
|
||||
return x
|
||||
|
||||
assert torch.ops.helion_test.foo in custom_lowering_dict
|
||||
assert torch.ops.helion_test.foo not in torch._inductor.lowering.lowerings
|
||||
|
||||
@requires_gpu()
|
||||
@skipIf(GPU_TYPE == "mps", "Not applicable to MPS")
|
||||
def test_jagged_to_padded_dense_sanity_cuda(self):
|
||||
|
@ -405,6 +405,7 @@ def _register_lowering(
|
||||
broadcast,
|
||||
type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
|
||||
convert_input_to_bool,
|
||||
lowering_dict,
|
||||
):
|
||||
"""
|
||||
Add a lowering to lowerings dict
|
||||
@ -449,7 +450,7 @@ def _register_lowering(
|
||||
|
||||
aten_fn = get_overloads(aten_fn)
|
||||
|
||||
lowerings.update(dict.fromkeys(aten_fn, wrapped))
|
||||
lowering_dict.update(dict.fromkeys(aten_fn, wrapped))
|
||||
return wrapped
|
||||
|
||||
|
||||
@ -460,6 +461,7 @@ def register_lowering(
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
||||
convert_input_to_bool=False,
|
||||
lowering_dict=lowerings,
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
"""
|
||||
Shim to support decorator syntax.
|
||||
@ -470,6 +472,7 @@ def register_lowering(
|
||||
broadcast=broadcast,
|
||||
type_promotion_kind=type_promotion_kind,
|
||||
convert_input_to_bool=convert_input_to_bool,
|
||||
lowering_dict=lowering_dict,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user