Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
Summary: Title
Test Plan: CI
Reviewed By: zhxchen17
Differential Revision: D52889236
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117796
Approved by: https://github.com/angelayi
114 lines
3.9 KiB
Python
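# Export tracepoint machinery: a no-op higher-order operator plus helpers that
# torch.export uses to preserve the input/output pytree structure of selected
# submodule calls in an exported graph.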
from contextlib import contextmanager

import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils import _pytree as pytree

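# A higher-order operator that behaves as an identity on its inputs; its only
# purpose is to be recorded into traced graphs as a marker for module-call
# input/output boundaries (see `kind="module_call_inputs"` below).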
_export_tracepoint = HigherOrderOperator("_export_tracepoint")


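# ProxyTorchDispatchMode implementation: while FX tracing is active, record the
# tracepoint as a call_function node in the traced graph and return the original
# args, tracked against the new proxy.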
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    if not mode.enable_tracing:
        return _export_tracepoint(*args, **kwargs)
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, p_args, p_kwargs
    )
    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)


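# FakeTensorMode implementation: the tracepoint is a no-op, so fake-tensor
# propagation simply passes the inputs through unchanged.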
@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    with mode:
        return args


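# Functionalization implementation: unwrap functional tensors, redispatch the
# tracepoint to the next dispatch key, and re-wrap the result.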
@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

    with ctx.redispatch_to_next():
        out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
        return ctx.wrap_tensors(out)


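# Autograd is not implemented for the tracepoint; raise a deferred error if it
# is ever reached with inputs that require grad.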
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)


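# CPU (eager) implementation: an identity function on its inputs.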
@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    return args


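# Patch the forward of the submodule at `path` inside `mod` so that each call
# flattens its inputs/outputs, surrounds them with export tracepoints, and
# records the pytree specs in `module_call_specs`.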
def _wrap_submodule(mod, path, module_call_specs):
    assert isinstance(mod, torch.nn.Module)
    assert path != ""
    submodule = mod
    for name in path.split("."):
        if not hasattr(submodule, name):
            raise RuntimeError(f"Couldn't find submodule at path {path}")
        submodule = getattr(submodule, name)

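    # Record the input/output pytree specs for this path; if specs were already
    # recorded (e.g. the submodule was called before), they must match.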
    def update_module_call_signatures(path, in_spec, out_spec):
        if path in module_call_specs:
            assert module_call_specs[path]["in_spec"] == in_spec
            assert module_call_specs[path]["out_spec"] == out_spec
        module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

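    # Only wrap a forward that has not already been patched on this instance;
    # keep a reference to the original so the wrapper can delegate to it.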
assert "forward" not in submodule.__dict__
|
|
wrapped_forward = submodule.forward
|
|
|
|
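    # Tracepoint arguments must be flat Tensors or scalars; reject anything
    # else that pytree flattening leaves behind.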
    def check_flattened(flat_args):
        for a in flat_args:
            if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
                raise AssertionError(
                    f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
                )

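    # Replacement forward: flatten args/kwargs, mark the input and output
    # boundaries with tracepoints, call the original forward in between, and
    # record the specs so the exported graph can reconstruct the call signature.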
    def wrapper(self, *args, **kwargs):
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
        check_flattened(flat_args)
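        # Mark the module-call input boundary in the graph being traced.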
        flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
        args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
        res = wrapped_forward(*args, **kwargs)
        flat_res, out_spec = pytree.tree_flatten(res)
        check_flattened(flat_res)
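        # Mark the module-call output boundary and record both specs.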
        flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
        update_module_call_signatures(path, in_spec, out_spec)
        return pytree.tree_unflatten(flat_res, out_spec)

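    # Bind the wrapper as a method on the instance so `self` resolves correctly.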
    submodule.forward = wrapper.__get__(submodule, type(submodule))
    return submodule


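# Context manager that wraps every submodule listed in `preserve_signature`
# (dot-separated paths into `f`) for the duration of tracing, then restores the
# original forwards on exit, even if tracing raises.
#
# Usage sketch (hypothetical; the real call sites live in torch.export):
#
#     specs = {}
#     with _wrap_submodules(model, ["encoder.layer0"], specs):
#         trace(model, example_inputs)  # any tracer that invokes forward()
#     # specs now maps "encoder.layer0" -> {"in_spec": ..., "out_spec": ...}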
@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
    tasks = []

    try:
        for path in preserve_signature:
            tasks.append(_wrap_submodule(f, path, module_call_signatures))
        yield
    finally:
        for submodule in tasks:
            del submodule.__dict__["forward"]


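# Class decorator that routes `__call__` through the strict_mode higher-order
# op, so calls to the module are traced under strict-mode semantics.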
def _mark_strict_experimental(cls):
    def call(self, *args):
        return strict_mode(self, args)

    cls.__call__ = call
    return cls