mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch dispatch mode to ProxyTensor tracing (#77174)
Uses a mode for ProxyTensor tracing so that it traces factory functions as well cc @dhruvbird Pull Request resolved: https://github.com/pytorch/pytorch/pull/77174 Approved by: https://github.com/ezyang
This commit is contained in:
@ -710,6 +710,30 @@ class TestFXExperimental(JitTestCase):
|
|||||||
inp = torch.randn(3, requires_grad=True)
|
inp = torch.randn(3, requires_grad=True)
|
||||||
torch.testing.assert_close(traced_graph(inp), f(inp))
|
torch.testing.assert_close(traced_graph(inp), f(inp))
|
||||||
|
|
||||||
|
def test_mode_tracing_factory_function(self):
|
||||||
|
def f(x):
|
||||||
|
return x + torch.randn(x.shape)
|
||||||
|
|
||||||
|
traced = make_fx(f, trace_factory_functions=True)(torch.randn(3))
|
||||||
|
self.assertTrue(
|
||||||
|
any(
|
||||||
|
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
|
||||||
|
for node in traced.graph.nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_mode_tracing_factory_function_default_behavior(self):
|
||||||
|
def f(x):
|
||||||
|
return x + torch.randn(x.shape)
|
||||||
|
|
||||||
|
traced = make_fx(f)(torch.randn(3)) # default behavior should not trace factory functions
|
||||||
|
self.assertFalse(
|
||||||
|
any(
|
||||||
|
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
|
||||||
|
for node in traced.graph.nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_call_to_assert_with_msg(self):
|
def test_call_to_assert_with_msg(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
|
@ -13,6 +13,8 @@ import torch.fx as fx
|
|||||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
|
||||||
|
|
||||||
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]
|
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]
|
||||||
aten = torch.ops.aten
|
aten = torch.ops.aten
|
||||||
|
|
||||||
@ -39,6 +41,53 @@ def decompose(decomposition_table):
|
|||||||
CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
|
CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_output(real_out, proxy_out):
|
||||||
|
def wrap_with_proxy(e, proxy):
|
||||||
|
if type(e) == torch.Tensor:
|
||||||
|
with no_dispatch():
|
||||||
|
return ProxyTensor(e, proxy)
|
||||||
|
else:
|
||||||
|
return e
|
||||||
|
|
||||||
|
# Unfortunately, tree_map cannot directly be used here. As the resulting
|
||||||
|
# object may be a proxy that represents a tuple, we may need to
|
||||||
|
# explicitly unwrap the proxy by simulating the flattening operations.
|
||||||
|
if isinstance(real_out, tuple):
|
||||||
|
return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
|
||||||
|
elif isinstance(real_out, list):
|
||||||
|
return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
|
||||||
|
elif isinstance(real_out, torch.Tensor):
|
||||||
|
return wrap_with_proxy(real_out, proxy_out)
|
||||||
|
else:
|
||||||
|
return real_out
|
||||||
|
|
||||||
|
|
||||||
|
def proxy_call(func_overload, args, kwargs=None):
|
||||||
|
func = func_overload.overloadpacket
|
||||||
|
if func_overload in CURRENT_DECOMPOSITION_TABLE:
|
||||||
|
return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
|
||||||
|
if func_overload == aten._local_scalar_dense.default:
|
||||||
|
raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
|
||||||
|
"It's likely that this is caused by data-dependent control flow or similar.")
|
||||||
|
|
||||||
|
def unwrap_proxy(e):
|
||||||
|
return e.proxy if isinstance(e, ProxyTensor) else e
|
||||||
|
|
||||||
|
proxy_args = pytree.tree_map(unwrap_proxy, args)
|
||||||
|
proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
|
||||||
|
|
||||||
|
proxy_out = func(*proxy_args, **proxy_kwargs)
|
||||||
|
|
||||||
|
# Kind of a hacky way to test if an op is in-place or not
|
||||||
|
if func.__name__[-1] == "_" and func.__name__[0] != "_":
|
||||||
|
args[0].proxy = proxy_out
|
||||||
|
proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
|
||||||
|
|
||||||
|
with no_dispatch():
|
||||||
|
real_out = func_overload(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrap_output(real_out, proxy_out)
|
||||||
|
|
||||||
class ProxyTensor(torch.Tensor):
|
class ProxyTensor(torch.Tensor):
|
||||||
proxy: fx.Proxy
|
proxy: fx.Proxy
|
||||||
|
|
||||||
@ -63,46 +112,7 @@ class ProxyTensor(torch.Tensor):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
|
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
|
||||||
func = func_overload.overloadpacket
|
return proxy_call(func_overload, args, kwargs)
|
||||||
if func_overload in CURRENT_DECOMPOSITION_TABLE:
|
|
||||||
return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
|
|
||||||
if func_overload == aten._local_scalar_dense.default:
|
|
||||||
raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
|
|
||||||
"It's likely that this is caused by data-dependent control flow or similar.")
|
|
||||||
|
|
||||||
def unwrap_proxy(e):
|
|
||||||
return e.proxy if isinstance(e, ProxyTensor) else e
|
|
||||||
|
|
||||||
proxy_args = pytree.tree_map(unwrap_proxy, args)
|
|
||||||
proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
|
|
||||||
|
|
||||||
proxy_out = func(*proxy_args, **proxy_kwargs)
|
|
||||||
|
|
||||||
# Kind of a hacky way to test if an op is in-place or not
|
|
||||||
if func.__name__[-1] == "_" and func.__name__[0] != "_":
|
|
||||||
args[0].proxy = proxy_out
|
|
||||||
proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
|
|
||||||
|
|
||||||
with no_dispatch():
|
|
||||||
real_out = func_overload(*args, **kwargs)
|
|
||||||
|
|
||||||
def wrap_with_proxy(e, proxy):
|
|
||||||
if type(e) == torch.Tensor:
|
|
||||||
return ProxyTensor(e, proxy)
|
|
||||||
else:
|
|
||||||
return e
|
|
||||||
|
|
||||||
# Unfortunately, tree_map cannot directly be used here. As the resulting
|
|
||||||
# object may be a proxy that represents a tuple, we may need to
|
|
||||||
# explicitly unwrap the proxy by simulating the flattening operations.
|
|
||||||
if isinstance(real_out, tuple):
|
|
||||||
return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
|
|
||||||
elif isinstance(real_out, list):
|
|
||||||
return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
|
|
||||||
elif isinstance(real_out, torch.Tensor):
|
|
||||||
return wrap_with_proxy(real_out, proxy_out)
|
|
||||||
else:
|
|
||||||
return real_out
|
|
||||||
|
|
||||||
|
|
||||||
class PythonKeyTracer(Tracer):
|
class PythonKeyTracer(Tracer):
|
||||||
@ -113,7 +123,7 @@ class PythonKeyTracer(Tracer):
|
|||||||
# this tracer might want to override this in order to turn a couple specific
|
# this tracer might want to override this in order to turn a couple specific
|
||||||
# modules into leaves in the traced graph.
|
# modules into leaves in the traced graph.
|
||||||
def call_module(
|
def call_module(
|
||||||
self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
return forward(*args, **kwargs)
|
return forward(*args, **kwargs)
|
||||||
|
|
||||||
@ -138,10 +148,16 @@ class PythonKeyTracer(Tracer):
|
|||||||
|
|
||||||
|
|
||||||
def dispatch_trace(
|
def dispatch_trace(
|
||||||
root: Union[torch.nn.Module, Callable], concrete_args: Optional[Tuple[Any, ...]] = None
|
root: Union[torch.nn.Module, Callable],
|
||||||
|
concrete_args: Optional[Tuple[Any, ...]] = None,
|
||||||
|
trace_factory_functions: bool = False,
|
||||||
) -> GraphModule:
|
) -> GraphModule:
|
||||||
tracer = PythonKeyTracer()
|
tracer = PythonKeyTracer()
|
||||||
graph = tracer.trace(root, concrete_args)
|
if trace_factory_functions:
|
||||||
|
with push_torch_dispatch_mode(functools.partial(ProxyTorchDispatchMode, tracer)):
|
||||||
|
graph = tracer.trace(root, concrete_args)
|
||||||
|
else:
|
||||||
|
graph = tracer.trace(root, concrete_args)
|
||||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||||
return GraphModule(tracer.root, graph, name)
|
return GraphModule(tracer.root, graph, name)
|
||||||
|
|
||||||
@ -152,10 +168,11 @@ def wrap_key(f, inps):
|
|||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapped(*args):
|
def wrapped(*args):
|
||||||
flat_args, args_spec = pytree.tree_flatten(args)
|
flat_args, args_spec = pytree.tree_flatten(args)
|
||||||
assert(len(flat_args) == len(flat_inps))
|
assert (len(flat_args) == len(flat_inps))
|
||||||
for idx, arg in enumerate(flat_args):
|
for idx, arg in enumerate(flat_args):
|
||||||
if isinstance(flat_inps[idx], torch.Tensor):
|
if isinstance(flat_inps[idx], torch.Tensor):
|
||||||
flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
|
with no_dispatch():
|
||||||
|
flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
|
||||||
else:
|
else:
|
||||||
flat_args[idx] = flat_inps[idx]
|
flat_args[idx] = flat_inps[idx]
|
||||||
|
|
||||||
@ -170,7 +187,25 @@ def wrap_key(f, inps):
|
|||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def make_fx(f, decomposition_table=None):
|
class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||||
|
def __init__(self, tracer):
|
||||||
|
self.tracer = tracer
|
||||||
|
|
||||||
|
def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
|
||||||
|
func = func_overload.overloadpacket
|
||||||
|
if any(tuple(isinstance(arg, ProxyTensor) for arg in args)):
|
||||||
|
return proxy_call(func_overload, args, kwargs)
|
||||||
|
else:
|
||||||
|
proxy_out = self.tracer.create_proxy('call_function', func, args, kwargs,
|
||||||
|
name=self.tracer.graph._target_to_str(func.__name__))
|
||||||
|
|
||||||
|
with no_dispatch():
|
||||||
|
real_out = func_overload(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrap_output(real_out, proxy_out)
|
||||||
|
|
||||||
|
|
||||||
|
def make_fx(f, decomposition_table=None, trace_factory_functions=False):
|
||||||
if decomposition_table is None:
|
if decomposition_table is None:
|
||||||
decomposition_table = {}
|
decomposition_table = {}
|
||||||
|
|
||||||
@ -178,7 +213,8 @@ def make_fx(f, decomposition_table=None):
|
|||||||
def wrapped(*args):
|
def wrapped(*args):
|
||||||
phs = pytree.tree_map(lambda x: fx.PH, args) # type: ignore[attr-defined]
|
phs = pytree.tree_map(lambda x: fx.PH, args) # type: ignore[attr-defined]
|
||||||
with decompose(decomposition_table):
|
with decompose(decomposition_table):
|
||||||
t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs))
|
t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs),
|
||||||
|
trace_factory_functions=trace_factory_functions)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
Reference in New Issue
Block a user