[hop] require hops to override __call__. (#134352)

Fixes https://github.com/pytorch/pytorch/issues/133719 by making `__call__` of hops an abstractmethod.
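A minimal sketch of the new contract (the hop name "my_hop" and both class names below are illustrative, not part of this PR): subclassing `HigherOrderOperator` without overriding `__call__` now fails at instantiation time, and a typical override simply delegates to `super().__call__`.

import torch

class MyHop(torch._ops.HigherOrderOperator):
    def __init__(self):
        super().__init__("my_hop")

    # Required after this change: HigherOrderOperator.__call__ is an
    # abstractmethod, so every concrete hop must provide its own __call__.
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)

class BrokenHop(torch._ops.HigherOrderOperator):
    pass  # no __call__ override

my_hop = MyHop()  # OK
# BrokenHop("broken_hop")  # TypeError: Can't instantiate abstract class BrokenHop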

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134352
Approved by: https://github.com/zou3519
Author: Yidi Wu
Date: 2024-08-27 16:24:42 -07:00
Committed by: PyTorch MergeBot
Parent: 66c33d5989
Commit: b07d0a22f5
13 changed files with 64 additions and 5 deletions


@@ -6441,6 +6441,9 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
             def __init__(self):
                 super().__init__("_fallthrough_test_only")
 
+            def __call__(self, *args, **kwargs):
+                return super().__call__(*args, **kwargs)
+
         test_op = _FallthroughTestOnly()
         default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
         self.assertTrue(


@@ -3710,6 +3710,13 @@ def forward(self, l_inp_, l_tmp_):
         torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),))
         self.assertEqual(cnt.frame_count, 3)
 
+    def test_hop_raises_if_not_overriding_call(self):
+        class WrongHop(torch._ops.HigherOrderOperator):
+            pass
+
+        with self.assertRaisesRegex(TypeError, "WrongHop"):
+            wrong_hop = WrongHop("wrong_hop")
+
 
 _hop_schema_test_schema_types = [
     "bool",


@@ -4993,6 +4993,9 @@ def construct_sum_pyop():
         def __init__(self):
             super().__init__("mysum")
 
+        def __call__(self, *args, **kwargs):
+            return super().__call__(*args, **kwargs)
+
     mysum = MySum()
 
     @mysum.py_impl(torch._C._functorch.TransformType.Vmap)


@@ -52,6 +52,9 @@ class TraceWrapped(HigherOrderOperator):
     def __init__(self):
         super().__init__("trace_wrapped")
 
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
+
 
 # TODO(jansel): need to ensure this does not get DCEed
 _trace_wrapped_op = TraceWrapped()


@@ -16,6 +16,9 @@ class ExportTracepoint(HigherOrderOperator):
     def __init__(self):
         super().__init__("_export_tracepoint")
 
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
+
 
 _export_tracepoint = ExportTracepoint()


@@ -29,6 +29,9 @@ class ExecutorchCallDelegate(HigherOrderOperator):
     def __init__(self):
         super().__init__("executorch_call_delegate")
 
+    def __call__(self, lowered_module, *args):
+        return super().__call__(lowered_module, *args)
+
 
 executorch_call_delegate = ExecutorchCallDelegate()
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)


@@ -32,6 +32,9 @@ from .utils import (
 # TODO: We add this to prevent dymamo from tracing into map_wrapper,
 # remove the wrapper call when it's ready.
 class MapWrapper(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("map")
+
     def __call__(self, xs, *args):
         return map_wrapper(xs, *args)
 
@@ -40,8 +43,11 @@ class MapImpl(HigherOrderOperator):
     def __init__(self):
         super().__init__("map_impl")
 
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
+
 
-map = MapWrapper("map")
+map = MapWrapper()
 map_impl = MapImpl()


@@ -12,6 +12,9 @@ class RunConstGraph(HigherOrderOperator):
     def __init__(self):
         super().__init__("run_const_graph")
 
+    def __call__(self, *args):
+        return super().__call__(*args)
+
 
 run_const_graph = RunConstGraph()


@@ -32,6 +32,9 @@ class StrictMode(HigherOrderOperator):
     def __init__(self):
         super().__init__("strict_mode")
 
+    def __call__(self, callable, operands):
+        return super().__call__(callable, operands)
+
 
 strict_mode_op = StrictMode()


@@ -26,6 +26,9 @@ class CallTorchBind(HigherOrderOperator):
     def __init__(self):
         super().__init__("call_torchbind")
 
+    def __call__(self, obj, method, *args, **kwargs):
+        return super().__call__(obj, method, *args, **kwargs)
+
 
 call_torchbind = CallTorchBind()


@@ -522,6 +522,14 @@ class TritonKernelWrapperMutation(HigherOrderOperator):
     def __init__(self) -> None:
         super().__init__("triton_kernel_wrapper_mutation")
 
+    def __call__(self, kernel_idx, constant_args_idx, grid, kwargs):
+        return super().__call__(
+            kernel_idx=kernel_idx,
+            constant_args_idx=constant_args_idx,
+            grid=grid,
+            kwargs=kwargs,
+        )
+
 
 triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
@@ -531,6 +539,15 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
     def __init__(self) -> None:
         super().__init__("triton_kernel_wrapper_functional")
 
+    def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone):
+        return super().__call__(
+            kernel_idx=kernel_idx,
+            constant_args_idx=constant_args_idx,
+            grid=grid,
+            kwargs=kwargs,
+            tensors_to_clone=tensors_to_clone,
+        )
+
 
 triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()


@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import abc
 import contextlib
 import ctypes
 import importlib
@@ -238,7 +239,7 @@ _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
 ]
 
 
-class HigherOrderOperator(OperatorBase):
+class HigherOrderOperator(OperatorBase, abc.ABC):
     # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
     #
     # If you're creating a new HigherOrderOperator, please do not change the
@@ -410,6 +411,7 @@ class HigherOrderOperator(OperatorBase):
             assert not isinstance(kernel, DispatchKey)
             return kernel(*args, **kwargs)
 
+    @abc.abstractmethod
     def __call__(self, /, *args, **kwargs):
         # Dynamo already traces the body of HigherOrderOp beforehand when it
         # so no need to trace into it.
@@ -433,9 +435,6 @@
     def __str__(self):
         return f"{self.name()}"
 
-    # def __repr__(self):
-    #     return f"torch.ops._higher_order_ops.{self._name}"
-
     def name(self):
         return self._name
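For reference, a self-contained sketch of the standard-library mechanism the torch/_ops.py change above leans on (the class names here are illustrative, not from the PR): inheriting from abc.ABC and marking `__call__` with @abc.abstractmethod blocks instantiation of subclasses that do not override it, while the decorated method keeps its body so overrides can still delegate to it via super().

import abc

class Base(abc.ABC):
    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        # An abstractmethod may still have a body; concrete overrides
        # reach it through super().__call__(...).
        return ("base dispatch", args, kwargs)

class Concrete(Base):
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)

class Incomplete(Base):
    pass

print(Concrete()(1, x=2))  # ('base dispatch', (1,), {'x': 2})
# Incomplete()  # TypeError: Can't instantiate abstract class Incomplete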


@@ -156,6 +156,9 @@ def register_run_and_save_rng_state_op():
         def __init__(self):
             super().__init__("run_and_save_rng_state")
 
+        def __call__(self, op, *args, **kwargs):
+            return super().__call__(op, *args, **kwargs)
+
     run_and_save_rng_state = RunAndSaveRngState()
 
     run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
@@ -217,6 +220,9 @@ def register_run_with_rng_state_op():
         def __init__(self):
             super().__init__("run_with_rng_state")
 
+        def __call__(self, rng_state, op, *args, **kwargs):
+            return super().__call__(rng_state, op, *args, **kwargs)
+
     run_with_rng_state = RunWithRngState()
 
     run_with_rng_state.py_impl(DispatchKey.Autograd)(