mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
https://github.com/pytorch/pytorch/issues/148222 Goal: At the moment, autograd saved tensors hooks are run in eager mode after the compiled forward. They are executed at the same time for all saved tensors. Hooks can be used to reduce the amount of memory used for saved tensors, e.g. by doing quantization or offloading to CPU. Running them all at once is suboptimal for peak memory. A better solution is to put the hooks in the graph, as close as possible to the last usage of each tensor. This change gets user-specified autograd saved tensors hooks into the graph. Logic: UX: If the user specifies `with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm)`, where pack_gm and unpack_gm are torch.fx.GraphModule instances, then AotAutograd will retrace those graph modules (doing decompositions and functionalization in aot_autograd) and inline the resulting graphs into the forward epilogue and backward prologue. Users may want to use control logic in the hooks, for example applying quantization only for specific dtypes and sizes. This is also possible: the user can put that logic into a torch.fx.wrap function and use symbolic tracing to make a GraphModule. In that case, AotAutograd caching will work only if the user explicitly sets "user_cache_hash" metadata on the torch.fx.wrap call_function node. If this metadata is set, the aot_autograd cache can use the saved cache artifact; if it is not set, the cache is bypassed. Dynamo: Dynamo traces the pack and unpack hooks, installs them as subgraphs, and explicitly adds them to the output_graph (because those subgraphs are not used, they would not be copied into the result by default). The complication here is that at this point we do not have example inputs for the hooks, so we trace pack_hook with some Tensor from the inputs. The resulting subgraphs are added to the AotAutograd cache hash. In AotAutograd, we retrace the graphs with the true saved tensors coming from the partitioner.
Backwards Compatibility: Because current hooks are executed in eager mode and not all of them are traceable, we only try to put into the graph those hooks that the user explicitly marks with the annotation (@_inlineable_saved_tensors_hooks). For other hooks, or if compiled autograd is enabled, the existing logic is kept. Recompilations: Hooks are guarded with a lambda guard matching the function id, causing recompilation if the user reruns the compiled function. Aot_autograd: After the partitioner has prepared the forward and backward modules, we trace the pack and unpack hook graphs prepared by Dynamo and inline them into the forward epilogue and backward prologue. Forward outputs and backward inputs are changed transparently to the user. We do not try to place the hooks close to the last usage etc., relying on Inductor to do this optimization. ``` INFO: TRACED GRAPH ===== Forward graph pre saved_tensors_hooks inlining 3 ===== /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"): # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1 add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x) view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]) return (view, add, primals_1, primals_2) INFO: TRACED GRAPH ===== Backward graph pre saved_tensors_hooks inlining 3 ===== /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"): # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1 add: "f32[s0, s1][s1, 1]cuda:0" = 
torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x) view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]) return (view, add, primals_1, primals_2) INFO: TRACED GRAPH ===== saved_tensors_pack_hook add 3 ===== /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class pack_float8(torch.nn.Module): def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"): # No stacktrace found for following nodes _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn); x_1 = None return (torch.float32, _to_copy) INFO: TRACED GRAPH ===== saved_tensors_unpack_hook add 3 ===== <eval_with_key>.22 from /data/users/ivankobzarev/a/pytorch/torch/fx/experimental/proxy_tensor.py:1225 in wrapped class pack_float8(torch.nn.Module): def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"): # No stacktrace found for following nodes _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn); x_1 = None return (torch.float32, _to_copy) INFO: TRACED GRAPH ===== Forward graph 3 ===== /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"): # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1 add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None # No stacktrace found for following nodes _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add, dtype = torch.float8_e4m3fn) # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x) view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, 
primals_2]); add = None return (view, _to_copy, primals_1, primals_2) INFO: TRACED GRAPH ===== Backward graph 3 ===== <eval_with_key>.21 class GraphModule(torch.nn.Module): def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", add_packed_2: "f8e4m3fn[s0, s1][s1, 1]cuda:0", tangents_1: "f32[s0, s1][s1, 1]cuda:0"): # No stacktrace found for following nodes _to_copy: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add_packed_2, dtype = torch.float32); add_packed_2 = None # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x) add_7: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_1, _to_copy); tangents_1 = _to_copy = None return (None, None, add_7) ``` Differential Revision: [D72187044](https://our.internmc.facebook.com/intern/diff/D72187044) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150032 Approved by: https://github.com/bdhirsh
434 lines
11 KiB
Python
434 lines
11 KiB
Python
# Owner(s): ["module: dynamo"]

import collections
import re
import sys
import time
from io import StringIO

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.comptime import comptime


# comptime callbacks cannot close over free variables yet, so the tests hand
# state to them through these module-level globals instead:
#   FILE -- the StringIO buffer that ctx.print_* calls write into
#   SELF -- the running TestCase, so callbacks can make assertions
# Since this state is shared module-wide, the tests here must not be run
# concurrently within a single process.
FILE = SELF = None


class ComptimeTests(torch._dynamo.test_case.TestCase):
    """Tests for ``torch._dynamo.comptime``.

    Each test compiles a small function with ``torch.compile`` using a
    ``CompileCounter`` backend and calls into ``comptime``, in one of two ways:
    through the ``@comptime`` decorator, whose callback receives a ``ctx``
    object, or through a ``comptime.<helper>()`` shorthand. The tests then
    check what was written to the module-level ``FILE`` buffer and/or how many
    frames were compiled.
    """

    def test_print_single(self):
        """ctx.print renders tensors, constants, containers and SymInts.

        ``f`` is compiled with ``dynamic=True``, so the tensor size appears
        symbolically (``s77``) in the expected output.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def comptime_print(e):
            @comptime
            def _(ctx):
                ctx.print(ctx.get_local("e"), file=FILE)

        Employee = collections.namedtuple("Employee", ["name", "id"])

        class mylist(list):
            pass

        @torch.compile(backend=cnt, dynamic=True)
        def f(x):
            y = x * 2
            comptime_print(y)
            comptime_print(2)
            comptime_print([y, 2])
            comptime_print((y, 2))
            comptime_print({"foo": y})
            comptime_print(range(1, 3))
            comptime_print(Employee("foo", 2))
            comptime_print(mylist([1, 2]))
            comptime_print(collections.defaultdict(lambda: None))
            comptime_print(set())
            comptime_print({"a", "b"})
            comptime_print(x.size(0))
            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
FakeTensor(..., size=(s77,))
2
[FakeTensor(..., size=(s77,)), 2]
(FakeTensor(..., size=(s77,)), 2)
{'foo': FakeTensor(..., size=(s77,))}
range(1, 3, 1)
Employee(name='foo', id=2)
UserDefinedListVariable(mylist)
defaultdict(NestedUserFunctionVariable(), {})
set()
{'a','b'}
s77""",
        )

    def test_print_graph(self):
        """ctx.print_graph prints the FX graph traced so far.

        The ``comptime.print_graph()`` shorthand must also run without error
        and without causing a graph break (``frame_count`` stays 1).
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_graph(verbose=False, file=FILE)

            # Test the compact notation doesn't error or graph break;
            # you'll have to visually inspect to see that it printed
            comptime.print_graph()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
y = l_x_ * 2; l_x_ = y = None""",
        )

    def test_print_disas(self):
        """ctx.print_disas dumps the bytecode of the frame being traced.

        The output must contain the ``-->`` instruction-offset marker and the
        expected opcodes. The multiply opcode name differs before and after
        Python 3.11.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_disas(file=FILE)

            comptime.print_disas()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        out = FILE.getvalue()
        # Check that the instruction offset is working
        self.assertIn("-->", out)
        # Check that the bytecode resembles what we expect
        self.assertIn("STORE_FAST", out)
        if sys.version_info < (3, 11):
            self.assertIn("BINARY_MULTIPLY", out)
        else:
            self.assertIn("BINARY_OP", out)

    def test_print_value_stack(self):
        """ctx.print_value_stack(stacklevel=1) works from an inlined helper.

        ``g`` is called from ``f`` in the middle of ``x + g(x)``. The expected
        dump is a single FakeTensor entry. The
        ``comptime.print_value_stack_and_return`` shorthand must pass its
        argument through unchanged and must not graph break.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                ctx.print_value_stack(file=FILE, stacklevel=1)

            return x

        @torch.compile(backend=cnt)
        def f(x):
            y = x + g(x)

            return y + comptime.print_value_stack_and_return(y * 2)

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
- FakeTensor(..., size=(2,))
""",
        )

    def test_print_locals(self):
        """ctx.print_locals lists each local of the frame with its fake value."""
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_locals(file=FILE)

            comptime.print_locals()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
x = FakeTensor(..., size=(2,))
y = FakeTensor(..., size=(2,))
""",
        )

    def test_print_direct(self):
        """comptime.print on a closed-over argument must not crash."""
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x, z):
            y = x * 2
            # Referencing z from a nested lambda turns it into a cell
            # (closure) variable of f. This presumably exercises
            # comptime.print on a cell -- TODO confirm intent.
            lambda: z
            comptime.print(z)
            return y + 3

        f(torch.randn(2), torch.randn(2))

    def test_sleep(self):
        """A call that reaches comptime.sleep takes at least sleep_time longer.

        The overhead of the sleeping call over the non-sleeping call, beyond
        ``sleep_time`` itself, must stay within a 3-second tolerance.
        """
        sleep_time = 5
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x, z, should_sleep):
            if should_sleep:
                comptime.sleep(sleep_time)
            y = x * 2
            return y + 3

        # perf_counter is monotonic, so these intervals are not skewed by
        # wall-clock adjustments the way time.time() deltas can be.
        start = time.perf_counter()
        f(torch.randn(2), torch.randn(2), False)
        total_no_sleep = time.perf_counter() - start

        start = time.perf_counter()
        f(torch.randn(2), torch.randn(2), True)
        total_with_sleep = time.perf_counter() - start

        self.assertGreater(total_with_sleep, sleep_time)
        # Hopefully this won't be flaky
        self.assertLess(abs(total_with_sleep - sleep_time - total_no_sleep), 3)

    def test_get_local_closure_variable(self):
        """ctx.get_local can read z, defined in f, from inside the nested g."""
        global SELF
        SELF = self
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            z = 3

            def g():
                @comptime
                def _(ctx):
                    r = ctx.get_local("z")
                    SELF.assertEqual(repr(r), "3")

                comptime.print(z)
                return 2

            y = x * g()
            return y + 3

        f(torch.randn(2))

    def test_print_bt(self):
        """ctx.print_bt prints a backtrace that includes the calling line in f."""
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                ctx.print_bt(file=FILE)

            comptime.print_bt()

            return x + 3

        @torch.compile(backend=cnt)
        def f(x):
            y = x * 2
            y = g(y)
            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        bt = FILE.getvalue()
        self.assertIn("y = g(y)", bt)

    def test_print_guards(self):
        """ctx.print_guards dumps the guards installed so far.

        Trailing whitespace is stripped from every line before comparison.
        """
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_guards(file=FILE)

            comptime.print_guards()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
            """\

local "L['x']" TENSOR_MATCH
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' AUTOGRAD_SAVED_TENSORS_HOOKS
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' GRAD_MODE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' DETERMINISTIC_ALGORITHMS
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' TORCH_FUNCTION_STATE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' DEFAULT_DEVICE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
shape_env '' SHAPE_ENV
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}""",
        )

    def test_graph_break(self):
        """Check graph-break behavior of comptime callbacks.

        A no-op callback does not graph break. ``ctx.graph_break()`` and
        ``comptime.graph_break()`` each force a break, so g is split into
        three compiled frames.
        """
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                pass

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        # Reuse the counter for the second function.
        cnt.frame_count = 0

        @torch.compile(backend=cnt)
        def g(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.graph_break()

            y = y + 2

            comptime.graph_break()

            return y * 3

        g(torch.randn(2))
        self.assertEqual(cnt.frame_count, 3)

    def test_get_local(self):
        """ctx.get_local returns trackers whose fake/proxy/constant views work.

        Using ``as_proxy()`` inside the callback writes a node into the graph
        being traced, which shows up in the printed graph.
        """
        global SELF, FILE
        SELF = self
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            y = x * 2
            lit = 2  # noqa: F841

            @comptime
            def _(ctx):
                y = ctx.get_local("y")
                SELF.assertEqual(y.as_fake().size(0), 2)
                SELF.assertEqual(y.size(0), 2)
                # Trigger a graph write (TODO: this is not so
                # useful right now as there's no way to make use
                # of the output proxy; maybe it's useful for inserting
                # side-effectful operations into the graph)
                y.as_proxy() + 4
                ctx.print_graph(verbose=False, file=FILE)
                SELF.assertIs(y.python_type(), torch.Tensor)
                lit = ctx.get_local("lit")
                SELF.assertEqual(lit.as_python_constant(), 2)

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
y = l_x_ * 2; l_x_ = None
add = y + 4; y = add = None""",
        )


if __name__ == "__main__":
    # Executed as a script: hand control to Dynamo's test runner.
    from torch._dynamo import test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()