[hop] ban creating hop by directly instantiating HigherOrderOperator. (#133645)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133645
Approved by: https://github.com/zou3519
ghstack dependencies: #133521
committed by PyTorch MergeBot
parent 6835f20d20
commit 696107efcb
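The pattern applied throughout this diff: every call site that previously created a HOP with a bare HigherOrderOperator("name") call now defines a small one-off subclass and instantiates that instead. A minimal sketch of the migration, assuming a hypothetical op name my_pyop / MyPyOp that is not part of this PR:

from torch._ops import HigherOrderOperator

# Before this change (now rejected, see the HigherOrderOperator.__init__ hunk below):
# my_pyop = HigherOrderOperator("my_pyop")


# After this change: a trivial one-off subclass per operator.
class MyPyOp(HigherOrderOperator):  # illustrative name only
    def __init__(self):
        super().__init__("my_pyop")


my_pyop = MyPyOp()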
@@ -6299,7 +6299,11 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
         self._validate(fn, backend, x)
 
     def test_override_fallthrough_dispatch_key(self):
-        test_op = torch._ops.HigherOrderOperator("_fallthrough_test_only")
+        class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
+            def __init__(self):
+                super().__init__("_fallthrough_test_only")
+
+        test_op = _FallthroughTestOnly()
         default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
         self.assertTrue(
             not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)
@@ -4989,7 +4989,11 @@ def forward(self, x_1):
 
 
 def construct_sum_pyop():
-    mysum = HigherOrderOperator("mysum")
+    class MySum(HigherOrderOperator):
+        def __init__(self):
+            super().__init__("mysum")
+
+    mysum = MySum()
 
     @mysum.py_impl(torch._C._functorch.TransformType.Vmap)
     def mysum_batch_rule(interpreter, x, dim):
@@ -48,8 +48,13 @@ def trace_wrapped(*args, **kwargs):
     return _trace_wrapped_op(*args, **kwargs)
 
 
+class TraceWrapped(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("trace_wrapped")
+
+
 # TODO(jansel): need to ensure this does not get DCEed
-_trace_wrapped_op = HigherOrderOperator("trace_wrapped")
+_trace_wrapped_op = TraceWrapped()
 
 
 def _assert_meta(grad, size, stride, dtype):
@@ -12,7 +12,12 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
 from torch.utils import _pytree as pytree
 
 
-_export_tracepoint = HigherOrderOperator("_export_tracepoint")
+class ExportTracepoint(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("export_tracepoint")
+
+
+_export_tracepoint = ExportTracepoint()
 
 
 @_export_tracepoint.py_impl(ProxyTorchDispatchMode)
@@ -25,7 +25,12 @@ from torch.fx.experimental.proxy_tensor import (
 from torch.utils._pytree import tree_flatten
 
 
-executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
+class ExecutorchCallDelegate(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("executorch_call_delegate")
+
+
+executorch_call_delegate = ExecutorchCallDelegate()
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
@@ -36,8 +36,14 @@ class MapWrapper(HigherOrderOperator):
         return map_wrapper(xs, *args)
 
 
+class MapImpl(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("map_impl")
+
+
 map = MapWrapper("map")
-map_impl = HigherOrderOperator("map_impl")
+map_impl = MapImpl()
 
 dummy_aot_config = AOTConfig(
     fw_compiler=None,  # type: ignore[arg-type]
@@ -8,7 +8,12 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
 from torch.utils import _pytree as pytree
 
 
-run_const_graph = HigherOrderOperator("run_const_graph")
+class RunConstGraph(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("run_const_graph")
+
+
+run_const_graph = RunConstGraph()
 
 
 @run_const_graph.py_impl(ProxyTorchDispatchMode)
@@ -28,7 +28,12 @@ def strict_mode(callable, operands):
         )
 
 
-strict_mode_op = HigherOrderOperator("strict_mode")
+class StrictMode(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("strict_mode")
+
+
+strict_mode_op = StrictMode()
 
 
 @strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)
@@ -16,12 +16,18 @@ from torch.utils import _pytree as pytree
 log = logging.getLogger(__name__)
 
 
 # The call_torchbind operator represents a method invocation on a torchbind
 # object. The calling convention is:
 # call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
 # We do not expect users to write this operator directly. Instead it will be
 # emitted by Dynamo when tracing encounters a torchbind object.
-call_torchbind = HigherOrderOperator("call_torchbind")
+class CallTorchBind(HigherOrderOperator):
+    def __init__(self):
+        super().__init__("call_torchbind")
+
+
+call_torchbind = CallTorchBind()
 
 # Register this operator as side-effectful with FX.
 # TODO: this is not really sufficient. While passes (hopefully) check
@@ -246,6 +246,10 @@ class HigherOrderOperator(OperatorBase):
     # practice due to name collisions.
     def __init__(self, name):
         super().__init__()
+        if type(self) is HigherOrderOperator:
+            raise RuntimeError(
+                "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
+            )
         self._name = name
 
         # Make _OPNamespace not scream, this whole name based association needs a good hard look
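To illustrate what the new guard in HigherOrderOperator.__init__ enforces, a small sketch for torch builds that include this change (the _demo_op / _DemoOp names are made up for illustration and do not appear in the PR):

from torch._ops import HigherOrderOperator

# Direct instantiation now fails the type(self) check and raises.
try:
    HigherOrderOperator("_demo_op")
except RuntimeError as err:
    print(err)  # Direct instantiation of HigherOrderOperator is not allowed. Please subclass it.


# A subclass passes the check, because type(self) is the subclass,
# not HigherOrderOperator itself.
class _DemoOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("_demo_op")


demo_op = _DemoOp()  # succeeds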
@@ -148,7 +148,11 @@ def get_device(args, kwargs):
 
 
 def register_run_and_save_rng_state_op():
-    run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
+    class RunAndSaveRngState(HigherOrderOperator):
+        def __init__(self):
+            super().__init__("run_and_save_rng_state")
+
+    run_and_save_rng_state = RunAndSaveRngState()
 
     run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
         autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
@@ -190,7 +194,11 @@ def register_run_and_save_rng_state_op():
 
 
 def register_run_with_rng_state_op():
-    run_with_rng_state = HigherOrderOperator("run_with_rng_state")
+    class RunWithRngState(HigherOrderOperator):
+        def __init__(self):
+            super().__init__("run_with_rng_state")
+
+    run_with_rng_state = RunWithRngState()
 
     run_with_rng_state.py_impl(DispatchKey.Autograd)(
         autograd_not_implemented(run_with_rng_state, deferred_error=True)