[hop] ban creating hop by directly instantiating HigherOrderOperator. (#133645)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133645
Approved by: https://github.com/zou3519
ghstack dependencies: #133521
This commit is contained in:
Yidi Wu
2024-08-20 15:52:46 -07:00
committed by PyTorch MergeBot
parent 6835f20d20
commit 696107efcb
11 changed files with 68 additions and 11 deletions

View File

@ -6299,7 +6299,11 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
self._validate(fn, backend, x)
def test_override_fallthrough_dispatch_key(self):
test_op = torch._ops.HigherOrderOperator("_fallthrough_test_only")
class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
def __init__(self):
super().__init__("_fallthrough_test_only")
test_op = _FallthroughTestOnly()
default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
self.assertTrue(
not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)

View File

@ -4989,7 +4989,11 @@ def forward(self, x_1):
def construct_sum_pyop():
mysum = HigherOrderOperator("mysum")
class MySum(HigherOrderOperator):
def __init__(self):
super().__init__("mysum")
mysum = MySum()
@mysum.py_impl(torch._C._functorch.TransformType.Vmap)
def mysum_batch_rule(interpreter, x, dim):

View File

@ -48,8 +48,13 @@ def trace_wrapped(*args, **kwargs):
return _trace_wrapped_op(*args, **kwargs)
class TraceWrapped(HigherOrderOperator):
def __init__(self):
super().__init__("trace_wrapped")
# TODO(jansel): need to ensure this does not get DCEed
_trace_wrapped_op = HigherOrderOperator("trace_wrapped")
_trace_wrapped_op = TraceWrapped()
def _assert_meta(grad, size, stride, dtype):

View File

@ -12,7 +12,12 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
from torch.utils import _pytree as pytree
_export_tracepoint = HigherOrderOperator("_export_tracepoint")
class ExportTracepoint(HigherOrderOperator):
def __init__(self):
super().__init__("export_tracepoint")
_export_tracepoint = ExportTracepoint()
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)

View File

@ -25,7 +25,12 @@ from torch.fx.experimental.proxy_tensor import (
from torch.utils._pytree import tree_flatten
executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
class ExecutorchCallDelegate(HigherOrderOperator):
def __init__(self):
super().__init__("executorch_call_delegate")
executorch_call_delegate = ExecutorchCallDelegate()
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)

View File

@ -36,8 +36,14 @@ class MapWrapper(HigherOrderOperator):
return map_wrapper(xs, *args)
class MapImpl(HigherOrderOperator):
def __init__(self):
super().__init__("map_impl")
map = MapWrapper("map")
map_impl = HigherOrderOperator("map_impl")
map_impl = MapImpl()
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]

View File

@ -8,7 +8,12 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
from torch.utils import _pytree as pytree
run_const_graph = HigherOrderOperator("run_const_graph")
class RunConstGraph(HigherOrderOperator):
def __init__(self):
super().__init__("run_const_graph")
run_const_graph = RunConstGraph()
@run_const_graph.py_impl(ProxyTorchDispatchMode)

View File

@ -28,7 +28,12 @@ def strict_mode(callable, operands):
)
strict_mode_op = HigherOrderOperator("strict_mode")
class StrictMode(HigherOrderOperator):
def __init__(self):
super().__init__("strict_mode")
strict_mode_op = StrictMode()
@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)

View File

@ -16,12 +16,18 @@ from torch.utils import _pytree as pytree
log = logging.getLogger(__name__)
# The call_torchbind operator represents a method invocation on a torchbind
# object. The calling convention is:
# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
# We do not expect users to write this operator directly. Instead it will be
# emitted by Dynamo when tracing encounters a torchbind object.
call_torchbind = HigherOrderOperator("call_torchbind")
class CallTorchBind(HigherOrderOperator):
def __init__(self):
super().__init__("call_torchbind")
call_torchbind = CallTorchBind()
# Register this operator as side-effectful with FX.
# TODO: this is not really sufficient. While passes (hopefully) check

View File

@ -246,6 +246,10 @@ class HigherOrderOperator(OperatorBase):
# practice due to name collisions.
def __init__(self, name):
super().__init__()
if type(self) is HigherOrderOperator:
raise RuntimeError(
"Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
)
self._name = name
# Make _OPNamespace not scream, this whole name based association needs a good hard look

View File

@ -148,7 +148,11 @@ def get_device(args, kwargs):
def register_run_and_save_rng_state_op():
run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
class RunAndSaveRngState(HigherOrderOperator):
def __init__(self):
super().__init__("run_and_save_rng_state")
run_and_save_rng_state = RunAndSaveRngState()
run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
@ -190,7 +194,11 @@ def register_run_and_save_rng_state_op():
def register_run_with_rng_state_op():
run_with_rng_state = HigherOrderOperator("run_with_rng_state")
class RunWithRngState(HigherOrderOperator):
def __init__(self):
super().__init__("run_with_rng_state")
run_with_rng_state = RunWithRngState()
run_with_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_with_rng_state, deferred_error=True)