mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[hop] require hops to override __call__. (#134352)
Fixes https://github.com/pytorch/pytorch/issues/133719 by making `__call__` of hops an abstractmethod. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134352 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
66c33d5989
commit
b07d0a22f5
@ -6441,6 +6441,9 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
||||
def __init__(self):
|
||||
super().__init__("_fallthrough_test_only")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
test_op = _FallthroughTestOnly()
|
||||
default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
|
||||
self.assertTrue(
|
||||
|
@ -3710,6 +3710,13 @@ def forward(self, l_inp_, l_tmp_):
|
||||
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),))
|
||||
self.assertEqual(cnt.frame_count, 3)
|
||||
|
||||
def test_hop_raises_if_not_overriding_call(self):
|
||||
class WrongHop(torch._ops.HigherOrderOperator):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "WrongHop"):
|
||||
wrong_hop = WrongHop("wrong_hop")
|
||||
|
||||
|
||||
_hop_schema_test_schema_types = [
|
||||
"bool",
|
||||
|
@ -4993,6 +4993,9 @@ def construct_sum_pyop():
|
||||
def __init__(self):
|
||||
super().__init__("mysum")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
mysum = MySum()
|
||||
|
||||
@mysum.py_impl(torch._C._functorch.TransformType.Vmap)
|
||||
|
@ -52,6 +52,9 @@ class TraceWrapped(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("trace_wrapped")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
|
||||
# TODO(jansel): need to ensure this does not get DCEed
|
||||
_trace_wrapped_op = TraceWrapped()
|
||||
|
@ -16,6 +16,9 @@ class ExportTracepoint(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("_export_tracepoint")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
|
||||
_export_tracepoint = ExportTracepoint()
|
||||
|
||||
|
@ -29,6 +29,9 @@ class ExecutorchCallDelegate(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("executorch_call_delegate")
|
||||
|
||||
def __call__(self, lowered_module, *args):
|
||||
return super().__call__(lowered_module, *args)
|
||||
|
||||
|
||||
executorch_call_delegate = ExecutorchCallDelegate()
|
||||
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
|
||||
|
@ -32,6 +32,9 @@ from .utils import (
|
||||
# TODO: We add this to prevent dymamo from tracing into map_wrapper,
|
||||
# remove the wrapper call when it's ready.
|
||||
class MapWrapper(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("map")
|
||||
|
||||
def __call__(self, xs, *args):
|
||||
return map_wrapper(xs, *args)
|
||||
|
||||
@ -40,8 +43,11 @@ class MapImpl(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("map_impl")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
map = MapWrapper("map")
|
||||
|
||||
map = MapWrapper()
|
||||
|
||||
map_impl = MapImpl()
|
||||
|
||||
|
@ -12,6 +12,9 @@ class RunConstGraph(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("run_const_graph")
|
||||
|
||||
def __call__(self, *args):
|
||||
return super().__call__(*args)
|
||||
|
||||
|
||||
run_const_graph = RunConstGraph()
|
||||
|
||||
|
@ -32,6 +32,9 @@ class StrictMode(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("strict_mode")
|
||||
|
||||
def __call__(self, callable, operands):
|
||||
return super().__call__(callable, operands)
|
||||
|
||||
|
||||
strict_mode_op = StrictMode()
|
||||
|
||||
|
@ -26,6 +26,9 @@ class CallTorchBind(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("call_torchbind")
|
||||
|
||||
def __call__(self, obj, method, *args, **kwargs):
|
||||
return super().__call__(obj, method, *args, **kwargs)
|
||||
|
||||
|
||||
call_torchbind = CallTorchBind()
|
||||
|
||||
|
@ -522,6 +522,14 @@ class TritonKernelWrapperMutation(HigherOrderOperator):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("triton_kernel_wrapper_mutation")
|
||||
|
||||
def __call__(self, kernel_idx, constant_args_idx, grid, kwargs):
|
||||
return super().__call__(
|
||||
kernel_idx=kernel_idx,
|
||||
constant_args_idx=constant_args_idx,
|
||||
grid=grid,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
|
||||
|
||||
@ -531,6 +539,15 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("triton_kernel_wrapper_functional")
|
||||
|
||||
def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone):
|
||||
return super().__call__(
|
||||
kernel_idx=kernel_idx,
|
||||
constant_args_idx=constant_args_idx,
|
||||
grid=grid,
|
||||
kwargs=kwargs,
|
||||
tensors_to_clone=tensors_to_clone,
|
||||
)
|
||||
|
||||
|
||||
triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import abc
|
||||
import contextlib
|
||||
import ctypes
|
||||
import importlib
|
||||
@ -238,7 +239,7 @@ _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
|
||||
]
|
||||
|
||||
|
||||
class HigherOrderOperator(OperatorBase):
|
||||
class HigherOrderOperator(OperatorBase, abc.ABC):
|
||||
# The HigherOrderOperator will appear as torch.ops.higher_order.{name}
|
||||
#
|
||||
# If you're creating a new HigherOrderOperator, please do not change the
|
||||
@ -410,6 +411,7 @@ class HigherOrderOperator(OperatorBase):
|
||||
assert not isinstance(kernel, DispatchKey)
|
||||
return kernel(*args, **kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, /, *args, **kwargs):
|
||||
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
||||
# so no need to trace into it.
|
||||
@ -433,9 +435,6 @@ class HigherOrderOperator(OperatorBase):
|
||||
def __str__(self):
|
||||
return f"{self.name()}"
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"torch.ops._higher_order_ops.{self._name}"
|
||||
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
|
@ -156,6 +156,9 @@ def register_run_and_save_rng_state_op():
|
||||
def __init__(self):
|
||||
super().__init__("run_and_save_rng_state")
|
||||
|
||||
def __call__(self, op, *args, **kwargs):
|
||||
return super().__call__(op, *args, **kwargs)
|
||||
|
||||
run_and_save_rng_state = RunAndSaveRngState()
|
||||
|
||||
run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
|
||||
@ -217,6 +220,9 @@ def register_run_with_rng_state_op():
|
||||
def __init__(self):
|
||||
super().__init__("run_with_rng_state")
|
||||
|
||||
def __call__(self, rng_state, op, *args, **kwargs):
|
||||
return super().__call__(rng_state, op, *args, **kwargs)
|
||||
|
||||
run_with_rng_state = RunWithRngState()
|
||||
|
||||
run_with_rng_state.py_impl(DispatchKey.Autograd)(
|
||||
|
Reference in New Issue
Block a user