This is the `__torch_dispatch__` subclass used for tracing by AOTAutograd (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py). Given that a couple of folks are now interested in using this infra, it seems like a good idea to put it in core and focus our efforts on a single implementation. I put this up as a WIP, just for discussion, but here are some questions off the top of my head:

1. What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy-paste and modify? If we do define extension points, which ones should we define?
2. There are some open questions about the way we're overriding FX to resolve some lingering issues (i.e. dealing with `nn.Parameter` and `call_module` calls). @ezyang implemented an alternate version of this tensor in https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py, but he appears to have run into some issues with it, which led me to submit this implementation. That being said, I think some of the things over there should still be ported.
3. Given that this is going to be shared infra, what other features should we put in here? One that comes to mind is allowing meta-tensor tracing (perhaps by default?), with a more solid fallback.

Some of the other implementations (for reference on requirements):

1. FX2TRT: D34868356 (internal only)
2. Edge's? @gmagogsfm

cc: @ezyang, @jamesr66a, @zou3519, @gmagogsfm, @842974287

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74360
Approved by: https://github.com/ezyang
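For orientation, here is a minimal usage sketch of the `make_fx` entry point defined in the file below. The import path assumes the module lands at `torch.fx.experimental.proxy_tensor`; the function `f` and its inputs are illustrative, not part of this PR.

import torch
from torch.fx.experimental.proxy_tensor import make_fx  # assumed final module path

def f(x, y):
    # An ordinary eager function; it is traced by actually running it on
    # concrete inputs, with recording done at the __torch_dispatch__ level.
    return torch.relu(x + y).sum()

# Inputs are wrapped as ProxyTensors, and every aten op they hit while f
# runs is recorded into an FX graph.
gm = make_fx(f)(torch.randn(3), torch.randn(3))
print(gm.graph)  # expected to show aten-level nodes (add, relu, sum)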
185 lines · 6.6 KiB · Python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Any, Dict, Optional, Tuple, Callable, Union
import torch
from torch._C import _disabled_torch_function_impl
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager

__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]
aten = torch.ops.aten

CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}


@contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield
    finally:
        del guard


@contextmanager
def decompose(decomposition_table):
    global CURRENT_DECOMPOSITION_TABLE
    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table


class ProxyTensor(torch.Tensor):
    proxy: fx.Proxy

    @staticmethod
    def __new__(cls, elem, proxy):
        # Hack to deal with super().__new__ not working for sparse tensors
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
            r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        else:
            r = super().__new__(cls, elem)  # type: ignore[call-arg]
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
        r.proxy = proxy  # type: ignore[attr-defined]

        return r

    def __repr__(self):
        with no_dispatch():
            return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
        if func_overload in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
        if func_overload == aten._local_scalar_dense.default:
            raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                               "It's likely that this is caused by data-dependent control flow or similar.")

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, ProxyTensor) else e

        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            args[0].proxy = proxy_out
            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        def wrap_with_proxy(e, proxy):
            if type(e) == torch.Tensor:
                return ProxyTensor(e, proxy)
            else:
                return e

        # Unfortunately, tree_map cannot directly be used here. As the resulting
        # object may be a proxy that represents a tuple, we may need to
        # explicitly unwrap the proxy by simulating the flattening operations.
        if isinstance(real_out, tuple):
            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
        elif isinstance(real_out, torch.Tensor):
            return wrap_with_proxy(real_out, proxy_out)
        else:
            return real_out


class PythonKeyTracer(Tracer):
    def __init__(self):
        super().__init__()

    # In general, we don't want to make modules leaves. In principle, users of
    # this tracer might want to override this in order to turn a couple specific
    # modules into leaves in the traced graph.
    def call_module(
            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        return forward(*args, **kwargs)

    def create_arg(self, a: Any):
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f'_param_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})
        return super().create_arg(a)


def dispatch_trace(
        root: Union[torch.nn.Module, Callable], concrete_args: Optional[Tuple[Any, ...]] = None
) -> GraphModule:
    tracer = PythonKeyTracer()
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)


def wrap_key(f, inps):
    flat_inps, _ = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert(len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped


def make_fx(f, decomposition_table=None):
    if decomposition_table is None:
        decomposition_table = {}

    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)  # type: ignore[attr-defined]
        with decompose(decomposition_table):
            t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs))
        return t

    return wrapped
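As a sketch of the decomposition hook above (`decompose` and the `decomposition_table` argument to `make_fx`): entries map an `OpOverload` to a Python implementation, which itself runs under `ProxyTensor` dispatch, so the graph records the decomposed ops instead of the original call. The `addmm` decomposition and the import path below are illustrative, not part of this file.

import torch
from torch.fx.experimental.proxy_tensor import make_fx  # assumed module path

def addmm_decomp(bias, mat1, mat2, beta=1, alpha=1):
    # Same math as aten.addmm; the mul/mm/add calls here are what get
    # recorded into the graph when this runs under __torch_dispatch__.
    return beta * bias + alpha * (mat1 @ mat2)

decomposition_table = {torch.ops.aten.addmm.default: addmm_decomp}

def f(x, w, b):
    return torch.addmm(b, x, w)

gm = make_fx(f, decomposition_table)(
    torch.randn(2, 3), torch.randn(3, 4), torch.randn(4)
)
print(gm.graph)  # expected to show mm/mul/add rather than addmm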