mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Added proxy tensor
This is the `__torch_dispatch__` subclass used for tracing by AOTAutograd (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py). Given that a couple of folks are now interested in using this infra, it seems like a good idea to put it in core and focus our efforts on a single implementation. I put this up as a WIP, just for discussion, but some questions off the top of my head:

1. What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy-paste and modify? If we do define extension points, which ones should we define?
2. There are some open questions about the way we're overriding FX to resolve some lingering issues (i.e. dealing with `nn.Parameter` and `call_module` calls). @ezyang implemented an alternate version of this tensor in https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py, but it appears he ran into some issues with it that led to me submitting this implementation. That being said, I think some of the things over there should still be ported.
3. Given that this is going to be shared infra, what other features should we put in here? One that comes to mind is allowing meta-tensor tracing (perhaps by default?), with a more solid fallback.

Some of the other implementations (for reference on requirements):
1. FX2TRT: D34868356 (internal only)
2. Edge's? @gmagogsfm

cc: @ezyang, @jamesr66a, @zou3519, @gmagogsfm, @842974287

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74360
Approved by: https://github.com/ezyang
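For readers unfamiliar with the mechanism the PR builds on: a `__torch_dispatch__` wrapper subclass intercepts every ATen op that runs on it, which is what makes op-by-op tracing possible. The sketch below is illustrative only — `LoggingTensor` is a hypothetical minimal tracer that just records op names, not the actual proxy_tensor implementation being merged here.

```python
import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    # Wrapper subclass: holds the real tensor in `elem` and records every
    # ATen op that the dispatcher routes through it.
    log = []

    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        cls.log.append(str(func))  # op name, e.g. "aten.cos.default"
        unwrap = lambda t: t.elem if isinstance(t, LoggingTensor) else t
        wrap = lambda t: LoggingTensor(t) if isinstance(t, torch.Tensor) else t
        # Run the real op on the unwrapped tensors, then re-wrap the outputs
        # so subsequent ops keep hitting this subclass.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

x = LoggingTensor(torch.randn(3))
x.cos().sum()
print(LoggingTensor.log)  # records the cos and sum ops seen by dispatch
```

A real tracer would emit FX proxy nodes instead of log entries, which is roughly the role proxy_tensor plays for AOTAutograd.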
Committed by: PyTorch MergeBot
Parent: 4fbbbed674
Commit: fc95eda285
@@ -27,6 +27,7 @@ from torch.fx.experimental.partitioner_utils import (
 from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
 from torch.fx.experimental.meta_tracer import MetaTracer
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
 from torch.fx.operator_schemas import (
@@ -689,6 +690,14 @@ class TestFXExperimental(JitTestCase):
         gm = torch.fx.GraphModule(mttm, graph)
         torch.testing.assert_close(gm(x), mttm(x))
 
+    def test_proxy_tensor(self):
+        def f(x):
+            val = x.cos().cos().sum()
+            return torch.autograd.grad(val, x)
+
+        traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
+        inp = torch.randn(3, requires_grad=True)
+        torch.testing.assert_close(traced_graph(inp), f(inp))
+
     def test_call_to_assert_with_msg(self):
         class M(torch.nn.Module):
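Taken out of the test suite, the new test exercises `make_fx` roughly as follows (a standalone sketch; `traced.code` shows the captured ATen-level graph):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    val = x.cos().cos().sum()
    return torch.autograd.grad(val, x)

# make_fx runs f on a real tensor under the tracing subclass; autograd
# executes eagerly during tracing, so the backward ops are baked into
# the captured graph rather than left as an autograd.grad call.
traced = make_fx(f)(torch.randn(3, requires_grad=True))
print(traced.code)

# The traced GraphModule should agree with eager execution on fresh input.
inp = torch.randn(3, requires_grad=True)
torch.testing.assert_close(traced(inp), f(inp))
```

Note that the grad computation is flattened into the graph at trace time, which is exactly what AOTAutograd relies on.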