Added proxy tensor

This is the `__torch_dispatch__` subclass used for tracing by AOTAutograd (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py).

Given that a couple of folks are now interested in using this infra, it seems like a good idea to put it in core, and focus our efforts on a single implementation.

I put this up as a WIP, just for discussion, but some questions off the top of my head.

1. What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy paste and modify? If we do define extension points, what are the extension points we should define?
2. There are some open questions about the way we're overriding FX to resolve some lingering issues (i.e. dealing with `nn.Parameter` and `call_module` calls). @ezyang implemented an alternate version of this tensor in https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py, but it appears he ran into some issues with it that led to me submitting this implementation. That being said, I think some of the things over there should still be ported.
3. Given that this is going to be shared infra, what other features should we put in here? One that comes to mind is to allow for meta-tensor tracing (perhaps by default?), with a more solid fallback.

Some of the other implementations (for reference on requirements).

1. FX2TRT: D34868356 (internal only)
2. Edge's? @gmagogsfm

cc: @ezyang , @jamesr66a , @zou3519 , @gmagogsfm, @842974287
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74360
Approved by: https://github.com/ezyang
This commit is contained in:
Horace He
2022-05-03 22:46:30 +00:00
committed by PyTorch MergeBot
parent 4fbbbed674
commit fc95eda285
2 changed files with 193 additions and 0 deletions

View File

@ -27,6 +27,7 @@ from torch.fx.experimental.partitioner_utils import (
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.experimental.meta_tracer import MetaTracer
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
@ -689,6 +690,14 @@ class TestFXExperimental(JitTestCase):
gm = torch.fx.GraphModule(mttm, graph)
torch.testing.assert_close(gm(x), mttm(x))
def test_proxy_tensor(self):
def f(x):
val = x.cos().cos().sum()
return torch.autograd.grad(val, x)
traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
inp = torch.randn(3, requires_grad=True)
torch.testing.assert_close(traced_graph(inp), f(inp))
def test_call_to_assert_with_msg(self):
class M(torch.nn.Module):