pytorch/torch/fx/experimental/proxy_tensor.py
Horace He fc95eda285 Added proxy tensor
This is the `__torch_dispatch__` subclass used for tracing by AOTAutograd (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py).
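For context, here is a minimal sketch of how this is meant to be driven through the `make_fx` entry point in this file (illustrative only; the toy function `f` is just an example):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # plain tensor code; each aten op that actually executes is recorded
    return torch.sin(x) * 2

# f runs on ProxyTensors under __torch_dispatch__, yielding an aten-level FX graph
gm = make_fx(f)(torch.randn(3))
print(gm.graph)
```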

Given that a couple of folks are now interested in using this infra, it seems like a good idea to put it in core, and focus our efforts on a single implementation.

I put this up as a WIP, just for discussion; here are some questions off the top of my head.

1. What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy-paste and modify? If we do define extension points, which ones should we define?
2. There are some open questions about how we're overriding FX to resolve some lingering issues (i.e., dealing with `nn.Parameter` and `call_module` calls). @ezyang implemented an alternate version of this tensor in https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py, but he ran into some issues with it, which led me to submit this implementation. That said, I think some of the things over there should still be ported.
3. Given that this is going to be shared infra, what other features should we put in here? One that comes to mind is to allow for meta-tensor tracing (perhaps by default?), with a more solid fallback.

Some of the other implementations, for reference on requirements:

1. FX2TRT: D34868356 (internal only)
2. Edge's? @gmagogsfm

cc: @ezyang, @jamesr66a, @zou3519, @gmagogsfm, @842974287
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74360
Approved by: https://github.com/ezyang
2022-05-03 22:46:30 +00:00


# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Any, Dict, Optional, Tuple, Callable, Union
import torch
from torch._C import _disabled_torch_function_impl
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
import torch.fx as fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]

aten = torch.ops.aten

CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}


@contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield
    finally:
        del guard


@contextmanager
def decompose(decomposition_table):
    global CURRENT_DECOMPOSITION_TABLE
    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table


class ProxyTensor(torch.Tensor):
    proxy: fx.Proxy

    @staticmethod
    def __new__(cls, elem, proxy):
        # Hack to deal with super().__new__ not working for sparse tensors
        if elem.is_sparse:
            proxy.node.meta['tensor_meta'] = {}
            r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        else:
            r = super().__new__(cls, elem)  # type: ignore[call-arg]
            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
        r.proxy = proxy  # type: ignore[attr-defined]
        return r

    def __repr__(self):
        with no_dispatch():
            return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
        if func_overload in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
        if func_overload == aten._local_scalar_dense.default:
            raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                               "It's likely that this is caused by data-dependent control flow or similar.")

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, ProxyTensor) else e

        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            args[0].proxy = proxy_out
            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        def wrap_with_proxy(e, proxy):
            if type(e) == torch.Tensor:
                return ProxyTensor(e, proxy)
            else:
                return e

        # Unfortunately, tree_map cannot directly be used here. As the resulting
        # object may be a proxy that represents a tuple, we may need to
        # explicitly unwrap the proxy by simulating the flattening operations.
        if isinstance(real_out, tuple):
            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
        elif isinstance(real_out, torch.Tensor):
            return wrap_with_proxy(real_out, proxy_out)
        else:
            return real_out
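

# For illustration, a walk-through of one dispatch: given ProxyTensors `a` and `b`,
# `torch.add(a, b)` reaches __torch_dispatch__ above with func_overload = aten.add.Tensor.
# The proxies are unwrapped, fx records a call_function node targeting the aten.add
# overload packet, the real kernel runs under no_dispatch(), and the real result is
# wrapped back into a ProxyTensor carrying the new proxy.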


class PythonKeyTracer(Tracer):
    def __init__(self):
        super().__init__()

    # In general, we don't want to make modules leaves. In principle, users of
    # this tracer might want to override this in order to turn a couple specific
    # modules into leaves in the traced graph.
    def call_module(
            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        return forward(*args, **kwargs)

    def create_arg(self, a: Any):
        # Parameters become get_attr nodes: reuse the existing attribute if the
        # parameter is already registered on root, otherwise install it on root
        # under a fresh `_param_constant{i}` name.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node('get_attr', n, (), {})
            qualname: Optional[str] = None

            if not qualname:
                i = 0
                while True:
                    qualname = f'_param_constant{i}'
                    if not hasattr(self.root, qualname):
                        break
                    i += 1
                setattr(self.root, qualname, a)

            return self.create_node('get_attr', qualname, (), {})
        return super().create_arg(a)


def dispatch_trace(
        root: Union[torch.nn.Module, Callable], concrete_args: Optional[Tuple[Any, ...]] = None
) -> GraphModule:
    tracer = PythonKeyTracer()
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)


def wrap_key(f, inps):
    # Wrap each flat tensor input in a ProxyTensor tied to the proxy of the
    # corresponding placeholder, run f, then unwrap any ProxyTensor outputs back
    # to their proxies so the tracer can record them as graph outputs.
    flat_inps, _ = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert len(flat_args) == len(flat_inps)
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]
        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped


def make_fx(f, decomposition_table=None):
    # Trace f at the aten level (via ProxyTensor.__torch_dispatch__) and return a
    # callable that produces an fx.GraphModule; `decomposition_table` maps
    # OpOverloads to Python decompositions applied during tracing.
    if decomposition_table is None:
        decomposition_table = {}

    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)  # type: ignore[attr-defined]
        with decompose(decomposition_table):
            t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs))
        return t

    return wrapped
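

# Example usage sketch; the silu entry below is a hypothetical decomposition-table
# entry for illustration, not one shipped with this file:
#
#   def silu_decomp(x):
#       return x * torch.sigmoid(x)
#
#   gm = make_fx(torch.nn.functional.silu,
#                decomposition_table={aten.silu.default: silu_decomp})(torch.randn(4))
#   print(gm.graph)  # graph contains sigmoid/mul nodes instead of a single silu call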