# Owner(s): ["module: dynamo"]
import enum
import functools
import pprint
import re
import unittest
import warnings
from copy import deepcopy

import functorch.experimental.control_flow as control_flow
import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
    check_dynamic_shape_capture,
    CompileCounter,
    CompileCounterWithBackend,
    EagerAndRecordGraphs,
    empty_line_normalizer,
    normalize_gm,
)
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_utils import (
    munge_exc,
    parametrize,
    TEST_WITH_TORCHDYNAMO,
    xfailIfTorchDynamo,
)
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch.testing._internal.triton_utils import requires_cuda_and_triton


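# Asserts that `op` appears exactly `freq` times in the captured graph and then
# returns the graph unchanged, so it can be curried (e.g. via functools.partial)
# and used as a compiler callback in tests.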
def count_ops(gm, args, freq, op):
    actual = [node.target for node in gm.graph.nodes].count(op)
    assert actual == freq, f"expected={freq}, actual={actual}"
    return gm


class Obj:
    pass


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.existing = torch.nn.Parameter(torch.ones([]))

    def forward(self, x):
        return self.existing * x


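# Module-level objects that the side-effect tests below mutate from inside
# wrapped functions.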
global_obj = Obj()
global_module = MyModule()
global_var = torch.randn(3)
global_num = 3.14
global_list = []


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def op_count(gm):
    result = 0
    for node in gm.graph.nodes:
        if "call" in node.op:
            result += 1
    return result


# Checks that a dict matches a dict with "regex keys". That is,
# the keys are regex expressions.
def assert_dict_matches_regex(self, dct, dct_with_regex_keys):
    regex_keys = dct_with_regex_keys.keys()
    regex_key_to_actual_key = {}
    for regex_key in regex_keys:
        for key in dct:
            if re.match(regex_key, key):
                if regex_key in regex_key_to_actual_key:
                    raise AssertionError(
                        f"Single key regex mapped to multiple keys. Please improve your "
                        f"regex. Got: regex='{regex_key}' "
                        f"keys='{regex_key_to_actual_key[regex_key]}',"
                        f"'{key}'"
                    )
                regex_key_to_actual_key[regex_key] = key
    new_dct = {}
    for regex_key in regex_keys:
        if regex_key not in regex_key_to_actual_key:
            raise AssertionError(
                f"Got regex '{regex_key}' but could not match any key in dict with "
                f"keys {dct.keys()}"
            )
        new_dct[regex_key_to_actual_key[regex_key]] = dct_with_regex_keys[regex_key]
    self.assertEqual(dct, new_dct)


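# Yields three variants of `seed_value`, nudging tensor/int/float leaves by a
# small per-iteration delta, so tests can re-run a compiled function on slightly
# different inputs (e.g. to exercise automatic dynamic shapes) without rebuilding
# the pytree structure by hand.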
def default_args_generator(seed_value):
    flat_args, args_spec = pytree.tree_flatten(seed_value)
    for i in range(3):
        new_flat_arg = []
        for val in flat_args:
            if isinstance(val, torch.Tensor):
                new_val = val + 0.1 * i
            elif isinstance(val, int):
                new_val = val + 1 * i
            elif isinstance(val, float):
                new_val = val + 0.1 * i
            elif isinstance(val, enum.Enum):
                new_val = val
            else:
                raise AssertionError("unexpected arg type")

            new_flat_arg.append(new_val)
        new_args = pytree.tree_unflatten(new_flat_arg, args_spec)
        yield new_args


class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
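    # Checks that compiling `func` falls back (at least one graph break) and that
    # no `wrap` node makes it into any captured graph, while eager and compiled
    # results still match.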
    def _assert_wrap_fallback(self, func, args, setup=lambda: None):
        counters.clear()
        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        setup()
        expected = func(*args)
        setup()
        result = torch.compile(func, backend=cnt, fullgraph=False)(*args)
        num_graph_breaks = len(counters["graph_break"].keys())
        self.assertGreater(num_graph_breaks, 0)

        for gm in backend.graphs:
            for node in gm.graph.nodes:
                self.assertFalse(node.target is wrap)

        self.assertEqual(result, expected)

    def _test_wrap_simple(
        self,
        func,
        args_generator,
        expected_num_wrap_args,
        expected_opcount=2,
        return_graph=False,
    ):
        # Given a `func` that has a single call to `wrap`,
        # we check that:
        # - there are no graph breaks
        # - eager vs torch.compile has the same result (correctness)
        # - other compilation metrics, e.g., # of ops in the dynamo captured graph,
        #   the wrap has the expected number of args, etc
        #
        # we have one or multiple runs through with each of the args from args_generator,
        # and we will check:
        # - correctness and no graph breaks for every run
        # - other compilation metrics only for the first run, since automatic_dynamic_shapes
        #   may compile another dynamic version graph for the later runs
        graph = None
        for i, args in enumerate(args_generator):
            backend = EagerAndRecordGraphs()
            cnt = CompileCounterWithBackend(backend)
            expected = func(*args)
            result = torch.compile(func, fullgraph=True, backend=cnt)(*args)
            # check correctness and no graph breaks
            self.assertEqual(result, expected)
            self.assertEqual(cnt.frame_count, 1)
            self.assertEqual(len(backend.graphs), 1)
            # check other compilation metrics
            if i == 0:
                self.assertEqual(cnt.op_count, expected_opcount)
                graph = backend.graphs[0]
                wrap_node = find_first_node(graph, wrap)
                self.assertEqual(len(wrap_node.args), expected_num_wrap_args)
        # We always return/check the graph from the first run if return_graph = True
        if return_graph:
            return normalize_gm(graph.print_readable(print_output=False))

    def test_error_message_sane(self):
        foo = []

        def inner(x):
            foo.append(x)
            return x.clone()

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return wrap(inner, x)

        x = torch.randn(3)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)",
        ):
            f(x)

    def test_no_freevars(self):
        def f(x):
            return wrap(lambda x: torch.sin(x), x)

        x = torch.randn(3)
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(f, default_args_generator((x,)), arg_count)

    def test_enum_arg(self):
        class SomeEnum(enum.Enum):
            A = 0
            B = 1

        def g(x, val):
            if val == SomeEnum.A:
                return torch.sin(x)
            return torch.cos(x)

        def f(x, val):
            return wrap(g, x, val)

        x = torch.randn(3)
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), arg_count)

    def test_return_captured_var(self):
        freevar = torch.randn(3)

        def test(x):
            return freevar

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)

        # Since `x` is unused, we don't lift it to
        # be the input.

        # when testing with dynamic shape, symbols are lifted as input
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)

    def test_return_captured_vars(self):
        freevar1 = torch.randn(3)
        freevar2 = torch.randn(3)

        def test(x):
            return freevar1, freevar2, freevar1

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)

        # Since `x` is unused, we don't lift it to
        # be the input.
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)

    def test_return_captured_var_used_multiple_times(self):
        freevar = torch.randn(3)

        def test(x):
            y = x + freevar
            return y, freevar

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)

    def test_capture_untracked_global(self):
        def f(x):
            return wrap(lambda x: x + global_var, x)

        x = torch.randn(3)
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(f, default_args_generator((x,)), arg_count)

    def test_allow_python_side_effects_utility(self):
        from torch._dynamo.utils import (
            _disable_side_effect_safety_checks_for_current_subtracer,
        )
        from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper

        def wrapper(fn):
            return fn

        count = 0

        def does_side_effect(x):
            nonlocal count
            count += 1
            return x.sin()

        def does_side_effect_wrapped(*args, **kwargs):
            return _disable_side_effect_safety_checks_for_current_subtracer(
                does_side_effect, *args, **kwargs
            )

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return dynamo_bypassing_wrapper(wrapper, does_side_effect_wrapped, x)

        x = torch.tensor(1.0)
        fn(x)

        def inner_does_side_effect(x):
            nonlocal count
            count += 1
            return x

        # Test that any nested HOPs are unaffected
        def outer(x):
            return dynamo_bypassing_wrapper(wrapper, inner_does_side_effect, x)

        def outer_wrapped(*args, **kwargs):
            return _disable_side_effect_safety_checks_for_current_subtracer(
                outer, *args, **kwargs
            )

        @torch.compile(backend="eager", fullgraph=True)
        def fn_nested(x):
            return dynamo_bypassing_wrapper(wrapper, outer_wrapped, x)

        x = torch.tensor(1.0)
        with self.assertRaisesRegex(
            RuntimeError, "Mutating a variable not in the current scope"
        ):
            fn_nested(x)

    def test_symint_input(self):
        def f(x):
            i = x.size(0)
            return wrap(lambda x, i: x.view(i), x, i)

        x = torch.randn(3, 1)
        self._test_wrap_simple(
            f,
            default_args_generator((x,)),
            ifdynstaticdefault(2, 3),
            expected_opcount=2,
        )

    def test_symint_in_slice(self):
        def f(x):
            i = x.size(0) - 2
            j = x.size(1) - 3
            k = x.size(2)
            return wrap(lambda x: x[:i, :j, k:], x)

        x = torch.randn(3, 4, 5)
        self._test_wrap_simple(
            f,
            default_args_generator((x,)),
            # 3 basic symbols and 2 compound symbols
            ifdynstaticdefault(2, 7),
            # 2 more sym expression computation
            expected_opcount=ifdynstaticdefault(2, 4),
        )

    def test_wrap_pytree_args_nested(self):
        def f(x, y, z):
            def fn(d):
                return d["x"].sin() + d["y"][0].cos() - d["y"][1][2].sin()

            return wrap(fn, d)

        x = torch.tensor(1.5)
        y = torch.tensor(2.0)
        z = torch.tensor(3.0)
        d = {"x": x, "y": (y, [x, y, z])}

        def my_args_generator(t):
            yield t
            yield t[0] + 0.1, t[1], t[2]
            yield t[0], t[1] + 0.1, t[2]

        actual_graph = self._test_wrap_simple(
            f,
            my_args_generator((x, y, z)),
            4,
            return_graph=True,
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_d_x_: "f32[]", L_d_y_0_: "f32[]", L_d_y_1_2_: "f32[]"):
        l_d_x_ = L_d_x_
        l_d_y_0_ = L_d_y_0_
        l_d_y_1_2_ = L_d_y_1_2_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
        getitem: "f32[]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"):
            sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None
            cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None
            add: "f32[]" = sin + cos; sin = cos = None
            sin_1: "f32[]" = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
            sub: "f32[]" = add - sin_1; add = sin_1 = None
            return (sub,)
""",  # NOQA: B950
        )

    def test_wrap_pytree_args_with_symint_constant(self):
        def f(x, y):
            i = x.size(0)
            return wrap(lambda t: t[0].view(t[2]) + t[1], (x, y, i))

        x = torch.randn(3, 1)
        y = 0.5
        actual_graph = self._test_wrap_simple(
            f,
            default_args_generator((x, y)),
            ifdynstaticdefault(2, 3),
            expected_opcount=2,
            return_graph=True,
        )
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(
                actual_graph,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 1]"):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
        getitem: "f32[3]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 1]"):
            view: "f32[3]" = l_x_.view(3); l_x_ = None
            add: "f32[3]" = view + 0.5; view = None
            return (add,)
""",
            )
        else:
            self.assertExpectedInline(
                actual_graph,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 1]"):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_); wrap_body_0 = s77 = l_x_ = None
        getitem: "f32[s77]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_0(torch.nn.Module):
        def forward(self, s77: "Sym(s77)", l_x_: "f32[s77, 1]"):
            view: "f32[s77]" = l_x_.view(s77); l_x_ = s77 = None
            add: "f32[s77]" = view + 0.5; view = None
            return (add,)
""",
            )

    def test_wrap_pytree_kwargs(self):
        def f(x, y, z):
            def fn(*, x, y, z):
                z1, _ = z
                return (x * 2) + y + z1

            return wrap(fn, x=x, y=y, z=z)

        x = torch.randn(3)
        y = torch.randn(3, 3)

        def my_args_generator(t):
            yield t
            x1 = t[0] + 0.1
            y1 = t[1] + 0.1
            yield (x1, y1, (x1, y1))
            x2 = t[0] + 0.2
            y2 = t[0] + 0.2
            yield (x2, y2, (x2, y2))

        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), arg_count)

    def test_wrap_pytree_args_not_const_symint_tensor(self):
        class MyClass:
            def __init__(self, x):
                self.val = x

        def f(x, y):
            return wrap(lambda z: z[0].sin() * z[1].val.cos(), (x, y))

        x = torch.tensor(1.2)
        y = MyClass(torch.tensor(3.4))
        self._test_wrap_simple(f, [(x, y)], 3)

    def test_capture_constants(self):
        x = torch.randn(3, 3)

        def fn(x, y, z):
            if z:
                return x + y
            return x * y

        def f(x, y, z):
            return wrap(fn, x, y, z)

        args = (x, 4.0, None)
        opt_f = torch.compile(f, fullgraph=True, backend=CompileCounter())
        expected = f(*args)
        result = opt_f(*args)
        self.assertEqual(result, expected)

        # Ensure that we recompile here
        args = (x, 5.0, None)
        expected = f(*args)
        result = opt_f(*args)
        self.assertEqual(result, expected)

    def test_capture_untracked_global_nested(self):
        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: wrap(lambda x: x + global_var, x), x)

        x = torch.randn(3)
        result = f(x)

        self.assertEqual(result, x + global_var)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

        self.assertEqual(len(backend.graphs), 1)
        wrap_node = find_first_node(backend.graphs[0], wrap)
        self.assertTrue(len(wrap_node.args), 3)

        body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
        self.assertEqual(op_count(body_function), 2)
        inner_wrap_node = find_first_node(body_function, wrap)
        self.assertTrue(len(inner_wrap_node.args), 3)

    def test_capture_untracked_nonlocal(self):
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            def g(x):
                return wrap(lambda x: x + y, x)

            # when testing with dynamic shape, a symbol is lifted as input
            arg_count = ifdynstaticdefault(3, 4)
            self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
            return g(x)

        f(x, y)

    def test_capture_tracked(self):
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            return wrap(lambda x: x + y, x)

        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)

    def test_capture_tracked_nested(self):
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            return wrap(lambda x: wrap(lambda x: x + y, x), x)

        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)

    def test_inlined_functions(self):
        def g(x, y):
            return x + y

        def f(x, y):
            return wrap(lambda x: g(x, y), x)

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)

    def test_same_freevar_twice(self):
        free = torch.randn(3)

        def g(x):
            y = free.sin()
            z = free.cos()
            return y, z

        def f(x):
            return wrap(g, x)

        x = torch.randn(3)

        # Since `x` is unused, we don't lift it to
        # be the input.
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(f, default_args_generator((x,)), arg_count, 3)

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True,
    )
    def test_unbacked_symbol_closure(self):
        def f(x):
            c = x.sum().item()

            def g(x):
                def k(x):
                    return x + c

                return wrap(k, x)

            return wrap(g, x)

        x = torch.randn(3)
        arg_count = ifdynstaticdefault(3, 4)
        out_graph = self._test_wrap_simple(
            f, default_args_generator((x,)), arg_count, 4, return_graph=True
        )

        if check_dynamic_shape_capture():
            self.assertExpectedInline(
                out_graph,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
        l_x_ = L_x_

        sum_1: "f32[]" = l_x_.sum()
        item: "Sym(zuf0)" = sum_1.item(); sum_1 = None

        wrap_body_1 = self.wrap_body_1
        wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, item); wrap_body_1 = s77 = l_x_ = item = None
        getitem: "f32[s77]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_1(torch.nn.Module):
        def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
            wrap_body_0 = self.wrap_body_0
            wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, item); wrap_body_0 = s77 = l_x_ = item = None
            getitem: "f32[s77]" = wrap[0]; wrap = None
            return (getitem,)

        class wrap_body_0(torch.nn.Module):
            def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", item: "Sym(zuf0)"):
                add: "f32[s77]" = l_x_ + item; l_x_ = item = None
                return (add,)
""",
            )
        else:
            self.assertExpectedInline(
                out_graph,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3]"):
        l_x_ = L_x_

        sum_1: "f32[]" = l_x_.sum()
        item: "Sym(zuf0)" = sum_1.item(); sum_1 = None

        wrap_body_1 = self.wrap_body_1
        wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, item); wrap_body_1 = l_x_ = item = None
        getitem: "f32[3]" = wrap[0]; wrap = None
        return (getitem,)

    class wrap_body_1(torch.nn.Module):
        def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
            wrap_body_0 = self.wrap_body_0
            wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, item); wrap_body_0 = l_x_ = item = None
            getitem: "f32[3]" = wrap[0]; wrap = None
            return (getitem,)

        class wrap_body_0(torch.nn.Module):
            def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
                add: "f32[3]" = l_x_ + item; l_x_ = item = None
                return (add,)
""",
            )

@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_tensor_with_unbacked_shape_closure(self):
|
|
def f(x):
|
|
c = x.nonzero()
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return x.sin(), c.sin()
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(4, 5)
|
|
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
|
|
expected_op_count = ifdynstaticdefault(10, 8)
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x,)),
|
|
arg_count,
|
|
expected_op_count,
|
|
return_graph=True,
|
|
)
|
|
|
|
if check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
|
|
l_x_ = L_x_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, l_x_, sym_size_int_1, c); wrap_body_1 = s77 = l_x_ = sym_size_int_1 = c = None
|
|
getitem: "f32[s77]" = wrap[0]
|
|
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
|
|
child: "f32[s77]" = wrap[0]
|
|
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (child, child_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
child: "f32[s77]" = l_x_.sin(); l_x_ = None
|
|
child_1: "f32[u0, 1]" = c.sin(); c = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
le: "Sym(u0 <= 3)" = sym_size_int_1 <= 3
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int_1, c); wrap_body_1 = l_x_ = sym_size_int_1 = c = None
|
|
getitem: "f32[3]" = wrap[0]
|
|
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
|
|
child: "f32[3]" = wrap[0]
|
|
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (child, child_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
child: "f32[3]" = l_x_.sin(); l_x_ = None
|
|
child_1: "f32[u0, 1]" = c.sin(); c = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_tensor_to_list_closure(self):
|
|
def f(x):
|
|
li = x.tolist()
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return li[0] + x
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.tensor([1, 2, 3], dtype=torch.int16)
|
|
arg_count = ifdynstaticdefault(3, 3)
|
|
out_graph = self._test_wrap_simple(f, ((x,),), arg_count, 4, return_graph=True)
|
|
|
|
# tolist will specialize on input shapes, so dynamic and static tests
|
|
# have the same graph
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "i16[3]"):
|
|
l_x_ = L_x_
|
|
|
|
getitem = l_x_[0]
|
|
item: "Sym(u0)" = getitem.item(); getitem = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, item, l_x_); wrap_body_1 = item = l_x_ = None
|
|
getitem_3: "i16[3]" = wrap[0]; wrap = None
|
|
return (getitem_3,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, item, l_x_); wrap_body_0 = item = l_x_ = None
|
|
getitem: "i16[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
|
|
add: "i16[3]" = item + l_x_; item = l_x_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_tensor_and_unbacked_symbol_closure(self):
|
|
def f(x):
|
|
c = x.nonzero()
|
|
sz = c.size(0)
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return x.sin() + sz, c.sin()
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(4, 5)
|
|
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
|
|
expected_op_count = ifdynstaticdefault(10, 8)
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x,)),
|
|
arg_count,
|
|
expected_op_count,
|
|
return_graph=True,
|
|
)
|
|
|
|
# Note that u0 is accessed from sz and the shape of c
|
|
# We cached via the symbol u0 and de-duplicate them.
|
|
if not check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
le: "Sym(u0 <= 3)" = sym_size_int <= 3
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int, c); wrap_body_1 = l_x_ = sym_size_int = c = None
|
|
getitem: "f32[3]" = wrap[0]
|
|
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
|
|
child: "f32[3]" = wrap[0]
|
|
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (child, child_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
|
sin: "f32[3]" = l_x_.sin(); l_x_ = None
|
|
child: "f32[3]" = sin + size; sin = size = None
|
|
child_1: "f32[u0, 1]" = c.sin(); c = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_concat_unbacked_shape_tensor(self):
|
|
def f(x, y):
|
|
c = x.nonzero()
|
|
d = y.nonzero()
|
|
cat = torch.cat((c, d))
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return cat.sum() + x
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(5, 6)
|
|
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
|
|
expected_op_count = ifdynstaticdefault(17, 13)
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x, y)),
|
|
arg_count,
|
|
expected_op_count,
|
|
return_graph=True,
|
|
)
|
|
|
|
if not check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]", L_y_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
le: "Sym(u0 <= 3)" = sym_size_int_2 <= 3
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
|
|
|
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None
|
|
|
|
sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
|
|
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None
|
|
|
|
ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
|
|
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None
|
|
le_1: "Sym(u1 <= 3)" = sym_size_int_3 <= 3
|
|
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u1 <= 3 on node 'le_1'"); le_1 = _assert_scalar_default_3 = None
|
|
|
|
cat: "i64[u0 + u1, 1]" = torch.cat((c, d)); c = d = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_2, sym_size_int_3, cat, l_x_); wrap_body_1 = sym_size_int_2 = sym_size_int_3 = cat = l_x_ = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, u1, cat, l_x_); wrap_body_0 = u0 = u1 = cat = l_x_ = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
|
|
sum_1: "i64[]" = cat.sum(); cat = None
|
|
add: "f32[3]" = sum_1 + l_x_; sum_1 = l_x_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
assume_static_by_default=False,
|
|
dynamic_shapes=True,
|
|
)
|
|
def test_lift_tensors_with_shared_symbols(self):
|
|
def f(x, y):
|
|
def g(x):
|
|
def k(x):
|
|
return x @ y
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(3, 4)
|
|
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x, y)),
|
|
6,
|
|
2,
|
|
return_graph=True,
|
|
)
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]", s94: "Sym(s94)", L_y_: "f32[s27, s94]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, s77, s27, l_x_, s94, l_y_); wrap_body_1 = s77 = s27 = l_x_ = s94 = l_y_ = None
|
|
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, s27, l_x_, s94, l_y_); wrap_body_0 = s77 = s27 = l_x_ = s94 = l_y_ = None
|
|
getitem: "f32[s77, s94]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", l_x_: "f32[s77, s27]", s94: "Sym(s94)", l_y_: "f32[s27, s94]"):
|
|
matmul: "f32[s77, s94]" = l_x_ @ l_y_; l_x_ = l_y_ = None
|
|
return (matmul,)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
assume_static_by_default=False,
|
|
dynamic_shapes=True,
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_lift_tensors_with_compound_expressions(self):
|
|
def f(x, y):
|
|
x = x.view(-1, 2)
|
|
c = y.nonzero()
|
|
d = torch.concat((x, c))
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return d.sum() + x
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(3, 4)
|
|
|
|
f(x, y)
|
|
|
|
if not check_dynamic_shape_capture():
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x, y)),
|
|
6,
|
|
9,
|
|
return_graph=True,
|
|
)
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = l_x_.view(-1, 2); l_x_ = None
|
|
|
|
c: "i64[u0, 2]" = l_y_.nonzero(); l_y_ = None
|
|
|
|
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = torch.concat((x, c)); c = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_1, s1, s0, d, x); wrap_body_1 = sym_size_int_1 = s1 = s0 = d = x = None
|
|
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, s1, s0, d, x); wrap_body_0 = u0 = s1 = s0 = d = x = None
|
|
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
|
|
sum_1: "f32[]" = d.sum(); d = None
|
|
add: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = sum_1 + x; sum_1 = x = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
def test_register_subclass(self):
|
|
from torch._higher_order_ops.cond import cond_op
|
|
from torch.testing._internal.two_tensor import TwoTensor
|
|
|
|
a = torch.tensor([1.0, 0.0, 1.0])
|
|
b = torch.randn(3)
|
|
t = TwoTensor(a, b)
|
|
with self.assertRaisesRegex(
|
|
NotImplementedError,
|
|
"no rule registered for HOP cond and subclass .*TwoTensor'>",
|
|
):
|
|
res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
|
|
|
|
called = 0
|
|
|
|
# Using cond.py_impl
|
|
@cond_op.py_impl(TwoTensor)
|
|
def _(pred, true_fn, false_fn, operands):
|
|
nonlocal called
|
|
called += 1
|
|
assert len(operands) == 1
|
|
a = cond_op(pred, true_fn, false_fn, (operands[0].a,))
|
|
b = cond_op(pred, true_fn, false_fn, (operands[0].b,))
|
|
return TwoTensor(a, b)
|
|
|
|
res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
|
|
self.assertEqual(res.a, torch.sin(a))
|
|
self.assertEqual(res.b, torch.sin(b))
|
|
self.assertEqual(called, 1)
|
|
|
|
def test_register_mode(self):
|
|
from torch._higher_order_ops.cond import cond_op
|
|
|
|
torch_dispatch_called = 0
|
|
|
|
class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
nonlocal torch_dispatch_called
|
|
torch_dispatch_called += 1
|
|
return func(*args, **kwargs)
|
|
|
|
a = torch.tensor([1.0, 0.1, 1.0])
|
|
pred = a.sum() > 0
|
|
with self.assertRaisesRegex(
|
|
NotImplementedError,
|
|
"no rule registered for HigherOrderOperator cond and mode .*MyMode",
|
|
):
|
|
with MyMode():
|
|
res = cond_op(pred, torch.sin, torch.cos, (a,))
|
|
|
|
py_impl_called = 0
|
|
|
|
# Using cond.py_impl
|
|
@cond_op.py_impl(MyMode)
|
|
def _(mode, pred, true_fn, false_fn, operands):
|
|
nonlocal py_impl_called
|
|
py_impl_called += 1
|
|
return cond_op(pred, true_fn, false_fn, operands)
|
|
|
|
a = torch.tensor([1.0, 0.1, 1.0])
|
|
pred = a.sum() > 0
|
|
with MyMode():
|
|
res = cond_op(pred, torch.sin, torch.cos, (a,))
|
|
self.assertEqual(res, a.sin())
|
|
|
|
def test_capture_value_created_in_subgraph(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def inner(x, y):
|
|
z = x + y
|
|
return wrap(lambda x: wrap(lambda x: x + z, x), x)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x, y):
|
|
return wrap(inner, x, y)
|
|
|
|
result = f(x, y)
|
|
|
|
self.assertEqual(result, x + y + x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
|
|
# No changes to args of outer wrap
|
|
gm = backend.graphs[0]
|
|
wrap_node = find_first_node(gm, wrap)
|
|
self.assertTrue(len(wrap_node.args), 3)
|
|
|
|
# z was lifted to arg of inner wrap
|
|
body_function = getattr(gm, wrap_node.args[0].name)
|
|
# addition + wrap + getitem
|
|
self.assertEqual(op_count(body_function), 3)
|
|
inner_wrap_node = find_first_node(body_function, wrap)
|
|
self.assertTrue(len(inner_wrap_node.args), 3)
|
|
|
|
# Innermost body function: z was also lifted to arg
|
|
body_function = getattr(body_function, inner_wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 2)
|
|
inner_wrap_node = find_first_node(body_function, wrap)
|
|
self.assertTrue(len(inner_wrap_node.args), 3)
|
|
|
|
def test_side_effect_set_new_attr_global_obj(self):
|
|
def setup():
|
|
global global_obj
|
|
global_obj = Obj()
|
|
|
|
def f(x):
|
|
def h(x):
|
|
def g(x):
|
|
global_obj.foo = x + 1
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y + global_obj.foo
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_set_existing_attr_global_obj(self):
|
|
def setup():
|
|
global global_obj
|
|
global_obj = Obj()
|
|
global_obj.foo = nn.Parameter(torch.tensor(4.0))
|
|
|
|
def f(x):
|
|
def h(x):
|
|
def g(x):
|
|
global_obj.foo = x + 1
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y + global_obj.foo
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_del_existing_attr_global_obj(self):
|
|
def setup():
|
|
global global_obj
|
|
global_obj = Obj()
|
|
global_obj.foo = torch.tensor(4.0)
|
|
|
|
def f(x):
|
|
def h(x):
|
|
def g(x):
|
|
del global_obj.foo
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_set_new_attr_global_module(self):
|
|
def setup():
|
|
global global_module
|
|
global_module = MyModule()
|
|
|
|
def h(x):
|
|
def g(x):
|
|
global_module.foo = nn.Parameter(x + 1)
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y + global_module.foo
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,), setup=setup)
|
|
|
|
def test_side_effect_set_existing_attr_global_module(self):
|
|
def setup():
|
|
global global_module
|
|
global_module = MyModule()
|
|
|
|
def h(x):
|
|
def g(x):
|
|
global_module.existing = nn.Parameter(torch.tensor(4.0))
|
|
return global_module(x)
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,), setup=setup)
|
|
|
|
def test_side_effect_del_existing_attr_global_module(self):
|
|
def setup():
|
|
global global_module
|
|
global_module = MyModule()
|
|
|
|
def h(x):
|
|
def g(x):
|
|
del global_module.existing
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_num(self):
|
|
def setup():
|
|
global global_num
|
|
global_num = 3.14
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_num
|
|
global_num = global_num + 1
|
|
return x + global_num
|
|
|
|
y = wrap(g, x)
|
|
return y + global_num
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_num_builtin(self):
|
|
def setup():
|
|
global global_num
|
|
global_num = 3.14
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_num
|
|
global_num += 1
|
|
return x + global_num
|
|
|
|
y = wrap(g, x)
|
|
return y + global_num
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_tensor(self):
|
|
def setup():
|
|
global global_var
|
|
global_var = torch.ones(3)
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_var
|
|
global_var = global_var + 1
|
|
return x + global_var
|
|
|
|
y = wrap(g, x)
|
|
return y + global_var
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_tensor_builtin(self):
|
|
def setup():
|
|
global global_var
|
|
global_var = torch.ones(3)
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_var
|
|
global_var += 1
|
|
return x + global_var
|
|
|
|
y = wrap(g, x)
|
|
return y + global_var
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_list(self):
|
|
def setup():
|
|
global global_list
|
|
global_list = []
|
|
|
|
def f(x):
|
|
def g(x):
|
|
val = x + 1
|
|
global_list.append(val)
|
|
return global_list[-1]
|
|
|
|
y = wrap(g, x)
|
|
z = y + global_list[-1]
|
|
return z
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_nonlocal_num(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = 1
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val = val + 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_set_new_attr_nonlocal_obj(self):
|
|
def f(x):
|
|
def h(x):
|
|
obj = Obj()
|
|
|
|
def g(x):
|
|
obj.val = x.dim()
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
z = y + obj.val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_set_existing_attr_nonlocal_obj(self):
|
|
def f(x):
|
|
def h(x):
|
|
obj = Obj()
|
|
obj.val = 3
|
|
|
|
def g(x):
|
|
obj.val = x.dim()
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
z = y + obj.val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_del_existing_attr_nonlocal_obj(self):
|
|
def f(x):
|
|
def h(x):
|
|
obj = Obj()
|
|
obj.val = 3
|
|
|
|
def g(x):
|
|
del obj.val
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_set_new_attr_nonlocal_module(self):
|
|
def h(x):
|
|
obj = MyModule()
|
|
|
|
def g(x):
|
|
obj.val = x.dim()
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
z = y + obj.val
|
|
return z
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,))
|
|
|
|
def test_side_effect_set_existing_attr_nonlocal_module(self):
|
|
def h(x):
|
|
obj = MyModule()
|
|
|
|
def g(x):
|
|
obj.existing = nn.Parameter(torch.tensor(3.14))
|
|
return obj(x)
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,))
|
|
|
|
def test_side_effect_del_existing_attr_nonlocal_module(self):
|
|
def h(x):
|
|
obj = MyModule()
|
|
|
|
def g(x):
|
|
del obj.existing
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,))
|
|
|
|
def test_side_effect_mutate_nonlocal_tensor(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = torch.tensor(1.0)
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val = val + 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_mutate_nonlocal_num_builtin(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = 1
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val += 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_mutate_nonlocal_tensor_builtin(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = torch.tensor(1.0)
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val += 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_nonlocal_list_append_graph_break(self):
|
|
def g(x):
|
|
y = []
|
|
|
|
def f(k):
|
|
m = k + 1
|
|
y.append(m)
|
|
return k
|
|
|
|
wrap(f, x)
|
|
return y[0]
|
|
|
|
x = torch.randn(3, 3)
|
|
self._assert_wrap_fallback(g, (x,))
|
|
|
|
def test_side_effect_nested_nonlocal_list_append_graph_break(self):
|
|
def g(x):
|
|
def h(x):
|
|
y = []
|
|
|
|
def f(k):
|
|
m = k + 1
|
|
y.append(m)
|
|
return k
|
|
|
|
wrap(f, x)
|
|
return y[0]
|
|
|
|
return h(x)
|
|
|
|
x = torch.randn(3, 3)
|
|
self._assert_wrap_fallback(g, (x,))
|
|
|
|
def test_side_effect_local_list_append_no_graph_break(self):
|
|
def g(x):
|
|
def f(k):
|
|
y = []
|
|
y.append(k + 1)
|
|
return y[0]
|
|
|
|
return wrap(f, x)
|
|
|
|
x = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
|
|
|
|
def test_wrap_kwarg(self):
|
|
def f(x, y):
|
|
return wrap(lambda x, y: x + y, x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_int(self):
|
|
def f(x, y):
|
|
return wrap(lambda x, y: x + y, x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = 8
|
|
|
|
arg_count = (
|
|
ifdynstaticdefault(2, 3) + 1
|
|
if check_dynamic_shape_capture()
|
|
else ifdynstaticdefault(2, 3)
|
|
)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_all_kwarg(self):
|
|
def f(y, x):
|
|
return wrap(lambda x, y: (x * 2) + y, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_only(self):
|
|
def f(x, y):
|
|
def fn(*, x, y):
|
|
return (x * 2) + y
|
|
|
|
return wrap(fn, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_default(self):
|
|
def f(x, y):
|
|
def fn(*, x, y, z=8):
|
|
return (x * 2) + y + z
|
|
|
|
return wrap(fn, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_default_if_branch(self):
|
|
def f(x, y):
|
|
def fn(*, x, y, z=None):
|
|
if z is None:
|
|
return (x * 2) + y
|
|
else:
|
|
return 2 * x
|
|
|
|
return wrap(fn, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_recompile(self):
|
|
def f(x, y, z=None):
|
|
def fn(*, x, y, z=None):
|
|
if z is None:
|
|
return (x * 2) + y
|
|
else:
|
|
return 2 * x
|
|
|
|
return wrap(fn, x=x, y=y, z=z)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
counters.clear()
|
|
opt = torch.compile(f, backend="eager", fullgraph=True)
|
|
opt(x, y)
|
|
self.assertEqual(counters["stats"]["calls_captured"], 2)
|
|
|
|
# verify that we `don't` recompile
|
|
opt(x, y)
|
|
self.assertEqual(counters["stats"]["calls_captured"], 2)
|
|
|
|
output = opt(x, y, 8)
|
|
self.assertEqual(counters["stats"]["calls_captured"], 4)
|
|
self.assertEqual(output, 2 * x)
|
|
|
|
def test_wrap_kwarg_default_else_branch(self):
|
|
def f(x, y, z):
|
|
def fn(*, x, y, z=None):
|
|
if z is None:
|
|
return (x * 2) + y
|
|
else:
|
|
return 2 * x
|
|
|
|
return wrap(fn, x=x, y=y, z=z)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(f, default_args_generator((x, y, 8)), arg_count)
|
|
|
|
def test_map_subgraph_name_is_valid(self):
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3)
|
|
|
|
def map_f(xs, y):
|
|
def inner(x, y):
|
|
def inner2(x, y):
|
|
return x + y
|
|
|
|
return control_flow.map(inner2, x, y)
|
|
|
|
return control_flow.map(inner, xs, y)
|
|
|
|
graphs = self._check_map_graph_and_extract(map_f, (xs, y))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
|
|
l_xs_ = L_xs_
|
|
l_y_ = L_y_
|
|
map_body_1 = self.map_body_1
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
|
|
getitem = map_impl[0]; map_impl = None
|
|
return (getitem,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor, l_y_ : torch.Tensor):
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]); map_body_0 = child = l_y_ = None
|
|
getitem = map_impl[0]; map_impl = None
|
|
return (getitem,)""",
|
|
)
|
|
|
|
def test_map_multi_return(self):
|
|
def f(x):
|
|
return control_flow.map(lambda x: (x.sin(), x.sin()), x)
|
|
|
|
x = torch.randn(3)
|
|
graphs = self._check_map_graph_and_extract(f, (x,))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
|
|
getitem = map_impl[0]
|
|
getitem_1 = map_impl[1]; map_impl = None
|
|
return (getitem, getitem_1)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor):
|
|
child_1 = child.sin()
|
|
child_2 = child.sin(); child = None
|
|
return (child_1, child_2)""",
|
|
)
|
|
|
|
def test_map_pytree_return(self):
|
|
def _construct_pytree(a):
|
|
return (
|
|
a.clone(),
|
|
[[[a.clone()]]],
|
|
a.clone(),
|
|
(a.clone(), (a.clone(),), a.clone()),
|
|
{"a": a.clone()},
|
|
)
|
|
|
|
def f(x):
|
|
def inner_f(xs):
|
|
return _construct_pytree(xs)
|
|
|
|
return control_flow.map(inner_f, x)
|
|
|
|
x = torch.randn(3)
|
|
graphs = self._check_map_graph_and_extract(f, (x,))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
|
|
getitem = map_impl[0]
|
|
getitem_1 = map_impl[1]
|
|
getitem_2 = map_impl[2]
|
|
getitem_3 = map_impl[3]
|
|
getitem_4 = map_impl[4]
|
|
getitem_5 = map_impl[5]
|
|
value = map_impl[6]; map_impl = None
|
|
return (getitem, getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, value)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor):
|
|
child_1 = child.clone()
|
|
child_2 = child.clone()
|
|
child_3 = child.clone()
|
|
child_4 = child.clone()
|
|
child_5 = child.clone()
|
|
child_6 = child.clone()
|
|
child_7 = child.clone(); child = None
|
|
return (child_1, child_2, child_3, child_4, child_5, child_6, child_7)""",
|
|
)
|
|
|
|
def test_map_kwargs(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt)
|
|
def f(x):
|
|
return control_flow.map(lambda x: x.sin(), x=x)
|
|
|
|
x = torch.randn(3)
|
|
self.assertRaises(TypeError, lambda: f(x))
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
|
|
def test_map_symint_input(self):
|
|
def fn(x, y):
|
|
def inner(x, y):
|
|
return torch.sin(x + y)
|
|
|
|
return control_flow.map(inner, x, y.size(0))
|
|
|
|
x = torch.randn(3, 1)
|
|
y = torch.randn(3, 1)
|
|
graphs = self._check_map_graph_and_extract(fn, (x, y))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
|
|
getitem = map_impl[0]; map_impl = None
|
|
return (getitem,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor, const_unused : int):
|
|
add = child + 3; child = None
|
|
sin = torch.sin(add); add = None
|
|
return (sin,)""",
|
|
)
|
|
|
|
def test_map_lowers_to_graph(self):
|
|
def fn(x, y):
|
|
def inner(x, y):
|
|
return torch.sin(x + y)
|
|
|
|
return control_flow.map(inner, x, y.size(0))
|
|
|
|
x = torch.randn(3, 1)
|
|
y = torch.randn(3, 1)
|
|
graphs = self._check_map_graph_and_extract(fn, (x, y))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
|
|
getitem = map_impl[0]; map_impl = None
|
|
return (getitem,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor, const_unused : int):
|
|
add = child + 3; child = None
|
|
sin = torch.sin(add); add = None
|
|
return (sin,)""",
|
|
)
|
|
|
|
def test_map_example_value_metadata_consistent_with_eager(self):
|
|
from torch._higher_order_ops.map import map_dense
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
def inner(x):
|
|
return x.sin(), x.cos().T, x.sin().view(-1)
|
|
|
|
rand_44 = torch.randn(4, 4)
|
|
inps = [
|
|
torch.randn(3),
|
|
torch.randn(3, 4),
|
|
torch.randn(3, 4, 5, requires_grad=True),
|
|
torch.randn(3, 4, 5, requires_grad=True).permute((2, 0, 1)),
|
|
torch.randn(3, 4, 5, requires_grad=True).detach(),
|
|
torch.randn(3, 4, 5, requires_grad=True).narrow(1, 1, 2),
|
|
rand_44.T,
|
|
rand_44[::2],
|
|
rand_44[::2, ::2],
|
|
rand_44[1::3, 1::3],
|
|
rand_44[1::3, 1::2].T,
|
|
rand_44.unsqueeze(1),
|
|
rand_44.squeeze(0),
|
|
rand_44.reshape(2, 8),
|
|
]
|
|
for x in inps:
|
|
compiled_ret = torch.compile( # noqa: F841
|
|
control_flow.map, backend=backend, fullgraph=True
|
|
)(inner, x)
|
|
eager_sin, eager_transpose, eager_view = map_dense(inner, (x,), ())
|
|
|
|
map_node = next(
|
|
node
|
|
for node in backend.graphs[0].graph.nodes
|
|
if node.op == "call_function" and "map" in node.name
|
|
)
|
|
|
|
fake_sin, fake_transpose, fake_view = map_node.meta["example_value"]
|
|
|
|
def _check_size_stride_contiguous(x, y):
|
|
self.assertEqual(y.size(), x.size())
|
|
self.assertEqual(y.stride(), x.stride())
|
|
self.assertEqual(y.requires_grad, x.requires_grad)
|
|
self.assertEqual(x.is_contiguous(), True)
|
|
self.assertEqual(y.is_contiguous(), True)
|
|
|
|
_check_size_stride_contiguous(eager_sin, fake_sin)
|
|
_check_size_stride_contiguous(eager_transpose, fake_transpose)
|
|
_check_size_stride_contiguous(eager_view, fake_view)
|
|
|
|
torch._dynamo.reset()
|
|
backend.graphs.clear()
|
|
|
|
def test_cond_subgraph_name_is_valid(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
pred = torch.tensor(True)
|
|
pred2 = torch.tensor(False)
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def cond_f(pred, pred2, x, y):
|
|
def true_fn(pred2, x, y):
|
|
return x + y
|
|
|
|
def false_fn(pred2, x, y):
|
|
def true_fn2(x, y):
|
|
return x.sin() - y.cos()
|
|
|
|
def false_fn2(x, y):
|
|
return x.cos() - y.sin()
|
|
|
|
return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
|
|
|
|
return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
|
|
|
|
result = cond_f(pred, pred2, xs, y)
|
|
self.assertEqual(result, xs + y)
|
|
|
|
cond_gm = backend.graphs[0]
|
|
name_set = set()
|
|
name_set.update(name for name, _ in cond_gm.named_modules())
|
|
self.assertEqual(
|
|
name_set,
|
|
{
|
|
"",
|
|
"cond_true_1",
|
|
"cond_false_1",
|
|
"cond_false_1.cond_false_0",
|
|
"cond_false_1.cond_true_0",
|
|
},
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
assume_static_by_default=True,
|
|
dynamic_shapes=True,
|
|
)
|
|
def test_cond_graph_break_in_one_branch(self):
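# true_fn mutates module state (self.buffer += 1), so the cond cannot be
# captured completely; both the eager module and the compiled module are
# expected to raise UncapturedHigherOrderOpError below.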
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.buffer = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, x):
|
|
def true_fn(x):
|
|
self.buffer += 1
|
|
return self.buffer.sum() + x.sum()
|
|
|
|
def false_fn(x):
|
|
return (x - 1).sum()
|
|
|
|
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])
|
|
|
|
mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
|
|
mod_for_eager = Foo()
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_eager(torch.ones(6, 4))
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_compile(torch.ones(3, 4))
|
|
|
|
def test_cond_free_variable_in_both_branches(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
z = torch.ones(4, 4)
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.buffer = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, x, y):
|
|
def true_fn(x):
|
|
return x.sum() + self.buffer.sum() + z.sum()
|
|
|
|
def false_fn(x):
|
|
return x.sum() - z.sum() - self.buffer.sum()
|
|
|
|
return control_flow.cond(y, true_fn, false_fn, [x])
|
|
|
|
mod_for_compile = torch.compile(
|
|
Foo(), backend=cnt, dynamic=True, fullgraph=True
|
|
)
|
|
mod_for_eager = Foo()
|
|
|
|
self.assertEqual(
|
|
mod_for_compile(torch.tensor(True), torch.tensor(5)),
|
|
mod_for_eager(torch.tensor(True), torch.tensor(5)),
|
|
)
|
|
|
|
for node in backend.graphs[0].graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.higher_order.cond
|
|
):
|
|
_, _, _, operands = node.args
|
|
# Since we compile with dynamic, each branch takes 4 inputs (buffer, x, z, s1)
|
|
self.assertEqual(len(operands), 4)
|
|
if node.op == "get_attr":
|
|
if str(node.target) in ("cond_true_0", "cond_false_0"):
|
|
num_placeholders = len(
|
|
[
|
|
node
|
|
for node in getattr(
|
|
backend.graphs[0], str(node.target)
|
|
).graph.nodes
|
|
if node.op == "placeholder"
|
|
]
|
|
)
|
|
self.assertEqual(num_placeholders, 4)
|
|
|
|
def _check_cond_graph_and_extract(self, fn, args):
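# Helper: compiles `fn` with fullgraph=True, checks parity with eager, and
# returns the stripped code of the outer graph plus the true/false branch
# subgraphs. Returns None when dynamic-shape capture is enabled, since the
# captured graph differs slightly in that mode.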
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
out = torch.compile(fn, backend=cnt, fullgraph=True)(*args)
|
|
self.assertEqual(out, fn(*args))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
gm = backend.graphs[0]
|
|
graph = gm.code.strip()
|
|
true_graph = gm.cond_true_0.code.strip()
|
|
false_graph = gm.cond_false_0.code.strip()
|
|
return (graph, true_graph, false_graph)
|
|
|
|
def _check_map_graph_and_extract(self, fn, args):
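# Same idea as _check_cond_graph_and_extract, but returns the outer graph
# followed by every registered subgraph (the map body), or None under
# dynamic-shape capture.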
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
out = torch.compile(fn, backend=cnt, fullgraph=True)(*args)
|
|
self.assertEqual(out, fn(*args))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
gm = backend.graphs[0]
|
|
graph = gm.code.strip()
|
|
subgraphs = []
|
|
for module_name in gm._modules.keys():
|
|
subgraphs.append(getattr(gm, module_name).code.strip())
|
|
return (graph, *subgraphs)
|
|
|
|
def test_cond_branches_no_arguments(self):
|
|
def fn(x):
|
|
def true_fn():
|
|
return torch.sin(x)
|
|
|
|
def false_fn():
|
|
return torch.cos(x)
|
|
|
|
return control_flow.cond(x.sum() > 0, true_fn, false_fn, ())
|
|
|
|
graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),))
|
|
if graphs is not None:
|
|
graph, true_graph, false_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
sum_1 = l_x_.sum()
|
|
gt = sum_1 > 0; sum_1 = None
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_,)); gt = cond_true_0 = cond_false_0 = l_x_ = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
true_graph,
|
|
"""\
|
|
def forward(self, l_x_):
|
|
l_x__1 = l_x_
|
|
sin = torch.sin(l_x__1); l_x__1 = None
|
|
return (sin,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
false_graph,
|
|
"""\
|
|
def forward(self, l_x_):
|
|
l_x__1 = l_x_
|
|
cos = torch.cos(l_x__1); l_x__1 = None
|
|
return (cos,)""",
|
|
)
|
|
|
|
def test_cond_branches_no_arguments_no_closure(self):
|
|
def fn(x):
|
|
def true_fn():
|
|
return torch.ones(3, 4)
|
|
|
|
def false_fn():
|
|
return torch.ones(3, 4).sin()
|
|
|
|
return control_flow.cond(x.sum() > 0, true_fn, false_fn, ())
|
|
|
|
graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),))
|
|
if graphs is not None:
|
|
graph, true_graph, false_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
sum_1 = l_x_.sum(); l_x_ = None
|
|
gt = sum_1 > 0; sum_1 = None
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, ()); gt = cond_true_0 = cond_false_0 = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
true_graph,
|
|
"""\
|
|
def forward(self):
|
|
ones = torch.ones(3, 4)
|
|
return (ones,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
false_graph,
|
|
"""\
|
|
def forward(self):
|
|
ones = torch.ones(3, 4)
|
|
sin = ones.sin(); ones = None
|
|
return (sin,)""",
|
|
)
|
|
|
|
def test_cond_side_effect_in_one_branches(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
z = [torch.ones(4, 4)]
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, y, x):
|
|
def true_fn(x):
|
|
z.append(x)
|
|
z.append(x)
|
|
z.pop()
|
|
return x.sum() + z[-1].sum()
|
|
|
|
def false_fn(x):
|
|
return x.sum() - z[0].sum()
|
|
|
|
return control_flow.cond(y, true_fn, false_fn, [x])
|
|
|
|
mod_for_eager = Foo()
|
|
mod_for_compile = torch.compile(
|
|
Foo(), backend=cnt, dynamic=True, fullgraph=False
|
|
)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_eager(torch.tensor(True), torch.tensor(5))
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_compile(torch.tensor(True), torch.tensor(5))
|
|
|
|
def test_cond_with_constant_pred(self):
|
|
def test(pred, x):
|
|
def true_fn(x):
|
|
return x
|
|
|
|
def false_fn(x):
|
|
return -x
|
|
|
|
return control_flow.cond(pred, true_fn, false_fn, [x])
|
|
|
|
opt_test = torch.compile(test, backend="eager")
|
|
inp = torch.ones(3, 3)
|
|
self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
|
|
self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
|
|
|
|
def test_map_graph_break(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
class Module(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, xs):
|
|
def body(x):
|
|
self.w += 1
|
|
return x
|
|
|
|
return control_flow.map(body, xs)
|
|
|
|
mod = Module()
|
|
|
|
mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"map doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
|
|
def test_map_side_effect(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
z = [torch.ones(6, 4)]
|
|
|
|
class Module(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, xs):
|
|
def body(x):
|
|
z.append(x)
|
|
z.append(x)
|
|
z.pop()
|
|
return x + z[-1].sum()
|
|
|
|
return control_flow.map(body, xs)
|
|
|
|
mod = Module()
|
|
|
|
mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"map doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
|
|
def test_wrap_subgraph_name_is_valid(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def inner(x, y):
|
|
z = x + y
|
|
return wrap(lambda x: wrap(lambda x: x + z, x), x)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x, y):
|
|
return wrap(inner, x, y)
|
|
|
|
result = f(x, y)
|
|
|
|
self.assertEqual(result, x + y + x)
|
|
wrap_gm = backend.graphs[0]
|
|
names = set()
|
|
names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
|
|
self.assertEqual(
|
|
names,
|
|
{
|
|
"",
|
|
"wrap_body_2",
|
|
"wrap_body_2.wrap_body_1",
|
|
"wrap_body_2.wrap_body_1.wrap_body_0",
|
|
},
|
|
)
|
|
|
|
def test_wrap_allow_local_assign_in_body_fn(self):
|
|
def f(arg1, arg2):
|
|
def inner_f(arg1, arg2):
|
|
a = arg1
|
|
b = arg2
|
|
ret = []
|
|
for x in a:
|
|
ret.append(x + 1)
|
|
for x in b:
|
|
ret.append(x + 1)
|
|
return ret
|
|
|
|
return wrap(inner_f, arg1, arg2)
|
|
|
|
x = torch.ones(3)
|
|
|
|
def my_args_generator():
|
|
yield [x], [x.sin()]
|
|
yield (x,), (x.sin(),)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
actual_graph = self._test_wrap_simple(
|
|
f,
|
|
my_args_generator(),
|
|
arg_count,
|
|
3,
|
|
return_graph=True,
|
|
)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
actual_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_arg1_0_: "f32[3]", L_arg2_0_: "f32[3]"):
|
|
l_arg1_0_ = L_arg1_0_
|
|
l_arg2_0_ = L_arg2_0_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_arg1_0_, l_arg2_0_); wrap_body_0 = l_arg1_0_ = l_arg2_0_ = None
|
|
getitem: "f32[3]" = wrap[0]
|
|
getitem_1: "f32[3]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
|
|
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
|
|
|
|
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
def test_capture_global_num(self):
|
|
def f(x):
|
|
return wrap(lambda x: x + global_num, x)
|
|
|
|
x = torch.zeros([])
|
|
# Numbers don't get lifted, so args is still 2.
|
|
self._test_wrap_simple(f, default_args_generator((x,)), 2)
|
|
|
|
def test_capture_global_num_adds_guard(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f(x):
|
|
return wrap(lambda x: x + global_num, x)
|
|
|
|
global global_num
|
|
x = torch.zeros([])
|
|
result = f(x)
|
|
self.assertEqual(result, x + global_num)
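# Rebinding the captured global below should be caught by a guard, so the
# recompiled function must pick up the new value.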
|
|
|
|
global_num = torch.randn([]).item()
|
|
result = f(x)
|
|
self.assertEqual(result, x + global_num)
|
|
|
|
def test_capture_input_num(self):
|
|
def f(x, y):
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
x = torch.zeros([])
|
|
y = 3.14
|
|
# Numbers don't get lifted, so args is still 2.
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), 2)
|
|
|
|
def test_side_effect_in_body(self):
|
|
counters.clear()
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
x = torch.randn([])
|
|
y = torch.randn([])
|
|
|
|
def inner(x):
|
|
nonlocal y
|
|
y = x
|
|
return x.clone()
|
|
|
|
@torch.compile(backend=backend)
|
|
def f(x):
|
|
return wrap(inner, x)
|
|
|
|
f(x)
|
|
self.assertEqual(y, x)
|
|
assert_dict_matches_regex(
|
|
self,
|
|
dict(counters["graph_break"]),
|
|
{
|
|
r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 1
|
|
},
|
|
)
|
|
|
|
def test_fallback_on_graph_break_simple(self):
|
|
# In the future, there should be a per-HigherOrderOperator switch
|
|
# on whether to fall back or raise a loud error.
|
|
# For now we just fall back by default.
|
|
cnt = CompileCounter()
|
|
x = torch.randn([])
|
|
|
|
def inner(x):
|
|
y = x.sin()
|
|
torch._dynamo.graph_break()
|
|
z = y.sin()
|
|
return z
|
|
|
|
@torch.compile(backend=cnt)
|
|
def f(x):
|
|
return wrap(inner, x)
|
|
|
|
result = f(x)
|
|
self.assertEqual(result, inner(x))
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
|
|
def test_fallback_on_graph_break_complicated(self):
|
|
cnt = CompileCounter()
|
|
x = torch.randn([])
|
|
|
|
def inner(x):
|
|
y = x.sin()
|
|
y = y * global_var
|
|
torch._dynamo.graph_break()
|
|
z = y.sin()
|
|
return z
|
|
|
|
@torch.compile(backend=cnt)
|
|
def f(x):
|
|
x = x.clone()
|
|
result = wrap(inner, x)
|
|
return result.clone()
|
|
|
|
result = f(x)
|
|
self.assertEqual(result, inner(x))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_modules(self):
|
|
counters.clear()
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
mod = torch.nn.Linear(3, 3)
|
|
x = torch.randn(3, 3)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
return wrap(lambda x: mod(x), x)
|
|
|
|
result = f(x)
|
|
|
|
self.assertEqual(result, mod(x))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
wrap_node = find_first_node(backend.graphs[0], wrap)
|
|
# 3 args - 1 for the input and the other 2 for the weight and bias
|
|
self.assertTrue(len(wrap_node.args), 3)
|
|
|
|
# Check that the linear bias and weight are getattr in the outer graph
|
|
if not torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2)
|
|
|
|
# Check that the inner function has one op and it's a linear op
|
|
body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 1)
|
|
linear_node = find_first_node(body_function, torch._C._nn.linear)
|
|
self.assertTrue(linear_node is not None)
|
|
|
|
# Check that the innermost graph does not have any params
|
|
self.assertTrue(len(dict(body_function.named_parameters())) == 0)
|
|
self.assertTrue(len(dict(body_function.named_children())) == 0)
|
|
|
|
def test_flat_list_output(self):
|
|
def f(x):
|
|
return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, expected_opcount=3
|
|
)
|
|
|
|
def test_support_float_in_output(self):
|
|
counters.clear()
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
return wrap(lambda x: [1, torch.sin(x), 2.0], x)
|
|
|
|
x = torch.randn(3)
|
|
result = f(x)
|
|
self.assertEqual(result, [1, torch.sin(x), 2.0])
|
|
|
|
def test_nested_tuple_output(self):
|
|
def f(x):
|
|
((a, b),) = wrap(lambda x: ((x.sin(), x.cos()),), x)
|
|
return a + b
|
|
|
|
x = torch.randn(2, 3)
|
|
|
|
counters.clear()
|
|
arg_count = ifdynstaticdefault(2, 4)
|
|
graph = self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, 4, return_graph=True
|
|
)
|
|
self.assertEqual(len(counters["graph_break"]), 0)
|
|
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[2, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
a: "f32[2, 3]" = wrap[0]
|
|
b: "f32[2, 3]" = wrap[1]; wrap = None
|
|
|
|
add: "f32[2, 3]" = a + b; a = b = None
|
|
return (add,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[2, 3]"):
|
|
child: "f32[2, 3]" = l_x_.sin()
|
|
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
def test_output_with_dict(self):
|
|
def f(x):
|
|
return wrap(lambda x: [{"a": -x}], x)
|
|
|
|
x = torch.randn(3)
|
|
|
|
counters.clear()
|
|
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
graph = self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, 2, return_graph=True
|
|
)
|
|
self.assertEqual(len(counters["graph_break"]), 0)
|
|
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
value: "f32[3]" = wrap[0]; wrap = None
|
|
return (value,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]"):
|
|
child: "f32[3]" = -l_x_; l_x_ = None
|
|
return (child,)
|
|
""",
|
|
)
|
|
|
|
def test_access_module_attr(self):
|
|
counters.clear()
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
mod = torch.nn.Linear(3, 3)
|
|
x = torch.randn(3, 3)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
y = mod(x)
|
|
return wrap(lambda y: y - mod.bias, y)
|
|
|
|
result = f(x)
|
|
self.assertEqual(result, mod(x) - mod.bias)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
wrap_node = find_first_node(backend.graphs[0], wrap)
|
|
self.assertTrue(len(wrap_node.args), 3)
|
|
|
|
# Check that the linear bias and weight are getattr in the outer graph
|
|
if not torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2)
|
|
|
|
# Check that the inner function has exactly one op (the subtraction in the lambda)
|
|
body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 1)
|
|
|
|
# Check that the innermost graph does not have any params
|
|
self.assertTrue(len(dict(body_function.named_parameters())) == 0)
|
|
self.assertTrue(len(dict(body_function.named_children())) == 0)
|
|
|
|
def test_make_closure(self):
|
|
def f(x, y):
|
|
def g(x):
|
|
return x + y
|
|
|
|
return g(x)
|
|
|
|
def h(x, y):
|
|
return wrap(f, x, y)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_internal_nonlocal(self):
|
|
def f(x, y):
|
|
w = 1
|
|
|
|
def g(x):
|
|
nonlocal w
|
|
w = x
|
|
return x
|
|
|
|
def h(x):
|
|
nonlocal w
|
|
w = w + 1
|
|
return x
|
|
|
|
g(x)
|
|
h(x)
|
|
return w + y
|
|
|
|
def h(x, y):
|
|
return wrap(f, x, y)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_capture_numpy_number(self):
|
|
import numpy as np
|
|
|
|
y = np.float32(1.0)
|
|
|
|
def f(x):
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
x = torch.randn(3)
|
|
# np.number are lifted to graph inputs
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
|
|
|
def test_freevars_as_inputs_to_wrap(self):
|
|
y = torch.randn(3)
|
|
|
|
def f(x):
|
|
return wrap(lambda x, y: x + y, x, y)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
|
|
|
def test_lift_tensor_constant(self):
|
|
def f(x):
|
|
y = torch.tensor(1.0)
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, expected_opcount=3
|
|
)
|
|
|
|
def test_nested_wrap(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
mod = MockModule()
|
|
|
|
# Two levels of wrap ops
|
|
def gn(x):
|
|
return torch.cos(x) + wrap(mod, x)
|
|
|
|
def fn(x):
|
|
return wrap(gn, x)
|
|
|
|
arg_count = ifdynstaticdefault(4, 5)
|
|
self._test_wrap_simple(
|
|
fn, default_args_generator((torch.randn(10, 10),)), arg_count
|
|
)
|
|
|
|
def test_fn_with_kwargs_in_torch_ops(self):
|
|
def fn(x):
|
|
return wrap(lambda z: torch.cos(input=z), x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
|
|
|
|
def test_hooks(self):
|
|
class ToyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.net = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
return self.net(x)
|
|
|
|
model = ToyModel()
|
|
forward_handles = {}
|
|
activations = {}
|
|
|
|
def save_activations(mod, inp, out):
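# Note: `name` is a free variable taken from the registration loop below, so
# it is resolved when the hook actually fires; ToyModel has a single child
# ("net"), so every firing records under that key.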
|
|
activations[name] = inp
|
|
|
|
for name, module in model.named_children():
|
|
forward_handles[name] = module.register_forward_hook(save_activations)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return wrap(lambda x: model(x), x)
|
|
|
|
for i in range(2):
|
|
# The second iteration is key; hooks would have fired during the AOT trace
|
|
# on the first iteration.
|
|
activations.clear()
|
|
x = torch.randn((10, 10))
|
|
pred = fn(x)
|
|
loss = pred.sum()
|
|
loss.backward()
|
|
|
|
self.assertTrue(activations.keys() == forward_handles.keys())
|
|
|
|
def _get_source_fn_stack(self, gm, node_names):
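# Collects the "source_fn_stack" metadata (names only) for every node whose
# name appears in `node_names`, searching all submodules of the GraphModule.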
|
|
ret = {}
|
|
for mod in gm.modules():
|
|
for node in mod.graph.nodes:
|
|
if node.name in node_names:
|
|
actual_stack = [
|
|
name for name, _ in node.meta.get("source_fn_stack", [])
|
|
]
|
|
ret[node.name] = actual_stack
|
|
return ret
|
|
|
|
def test_wrap_source_fn_stack(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
mod = MockModule()
|
|
|
|
def gn(x):
|
|
return torch.cos(x) + wrap(mod, x)
|
|
|
|
def fn(x):
|
|
return wrap(gn, x)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
inp = torch.randn((4, 4))
|
|
torch.compile(fn, backend=backend, fullgraph=True)(inp)
|
|
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "linear"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""\
|
|
{'add': ['wrap', 'add'],
|
|
'cos': ['wrap', 'cos'],
|
|
'linear': ['wrap', 'wrap', 'linear']}""",
|
|
)
|
|
|
|
def test_cond_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=backend, fullgraph=True)
|
|
def cond_f(pred, pred2, x, y):
|
|
def true_fn(pred2, x, y):
|
|
return x + y
|
|
|
|
def false_fn(pred2, x, y):
|
|
def true_fn2(x, y):
|
|
return x.sin() - y.cos()
|
|
|
|
def false_fn2(x, y):
|
|
return x.cos() - y.sin()
|
|
|
|
return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
|
|
|
|
return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
|
|
|
|
pred = torch.tensor(True)
|
|
pred2 = torch.tensor(False)
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
cond_f(pred, pred2, xs, y)
|
|
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin", "sub"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""\
|
|
{'add': ['cond', 'add'],
|
|
'cos': ['cond', 'cond', 'cos'],
|
|
'sin': ['cond', 'cond', 'sin'],
|
|
'sub': ['cond', 'cond', 'sub']}""",
|
|
)
|
|
|
|
def test_map_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3)
|
|
|
|
@torch.compile(backend=backend, fullgraph=True)
|
|
def map_f(xs, y):
|
|
def inner(x, y):
|
|
def inner2(x, y):
|
|
return x + y
|
|
|
|
return control_flow.map(inner2, x, y) * y.cos()
|
|
|
|
return control_flow.map(inner, xs, y).sin()
|
|
|
|
map_f(xs, y)
|
|
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""\
|
|
{'add': ['map_impl', 'map_impl', 'add'],
|
|
'cos': ['map_impl', 'cos'],
|
|
'sin': ['sin']}""",
|
|
)
|
|
|
|
def test_grad_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
@torch.compile(backend=backend, fullgraph=False)
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(torch.func.grad(fn))(x)
|
|
|
|
x = torch.randn(())
|
|
|
|
wrapper_fn(x)
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""{'sin': ['sin']}""",
|
|
)
|
|
|
|
def test_vmap_multiply_scalar(self):
|
|
@torch.compile(backend="inductor", fullgraph=True)
|
|
def g(x):
|
|
return torch.vmap(torch.mul, in_dims=(0, None))(x, 3.14)
|
|
|
|
x = torch.randn(3)
|
|
y = g(x)
|
|
self.assertEqual(y, x * 3.14)
|
|
|
|
@torch.compile(backend="inductor", fullgraph=True)
|
|
def f(x):
|
|
return torch.vmap(torch.mul, in_dims=(0, None))(x, 314)
|
|
|
|
x = torch.randn(3)
|
|
y = f(x)
|
|
self.assertEqual(y, x * 314)
|
|
|
|
def test_vmap_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
def inner_fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
|
|
|
|
@torch.compile(backend=backend, fullgraph=True)
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: inner_fn(x.cos()))(x)
|
|
|
|
x = torch.randn(3, 3, 3, 3)
|
|
fn(x)
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(
|
|
gm, {"sum_1", "sum_2", "batched_output"}
|
|
)
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""{'sum_1': ['sum_1'], 'sum_2': ['sum_2']}""",
|
|
)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/137061
|
|
def test_dynamic_shapes_over_vmap_batch_size(self):
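# Regression test for the issue linked above: the vmap batch sizes come from
# torch.arange over plain Python ints, so with dynamic shapes enabled the
# second, larger call must not trigger a recompile (frame_count stays at 1).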
|
|
def gn(a, b, c, d):
|
|
return a + b + c + d
|
|
|
|
def fn(func, a, b, c, d):
|
|
a = torch.arange(a)
|
|
b = torch.arange(b)
|
|
c = torch.arange(c)
|
|
d = torch.arange(d)
|
|
func = torch.vmap(func, in_dims=(0, None, None, None))
|
|
func = torch.vmap(func, in_dims=(None, 0, None, None))
|
|
func = torch.vmap(func, in_dims=(None, None, 0, None))
|
|
func = torch.vmap(func, in_dims=(None, None, None, 0))
|
|
return func(a, b, c, d)
|
|
|
|
cnt = CompileCounterWithBackend("eager")
|
|
# We generate corresponding dynamic shapes test case at
|
|
# `test/dynamo/test_dynamic_shapes.py` automatically.
|
|
compiled_fn = torch.compile(fn, backend=cnt)
|
|
a, b, c, d = 2, 4, 8, 8
|
|
self.assertEqual(fn(gn, a, b, c, d), compiled_fn(gn, a, b, c, d))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
a, b, c, d = 4, 8, 16, 16
|
|
self.assertEqual(fn(gn, a, b, c, d), compiled_fn(gn, a, b, c, d))
|
|
# Ensure no recompile if dynamic shapes enabled.
|
|
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
|
|
graph = cnt.graphs[0]
|
|
|
|
# Check dynamic shapes generates correct graph.
|
|
if check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, L_a_ : torch.SymInt, L_b_ : torch.SymInt, L_c_ : torch.SymInt, L_d_ : torch.SymInt):
|
|
l_a_ = L_a_
|
|
l_b_ = L_b_
|
|
l_c_ = L_c_
|
|
l_d_ = L_d_
|
|
a = torch.arange(l_a_)
|
|
b = torch.arange(l_b_)
|
|
c = torch.arange(l_c_)
|
|
d = torch.arange(l_d_)
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(l_d_, 'error'); _vmap_increment_nesting = None
|
|
child = torch._functorch.predispatch._add_batch_dim(d, 0, 1); d = None
|
|
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(l_c_, 'error'); _vmap_increment_nesting_1 = None
|
|
child_1 = torch._functorch.predispatch._add_batch_dim(c, 0, 2); c = None
|
|
lazy_load_decompositions_2 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_2 = None
|
|
_vmap_increment_nesting_2 = torch._functorch.predispatch._vmap_increment_nesting(l_b_, 'error'); _vmap_increment_nesting_2 = None
|
|
child_2 = torch._functorch.predispatch._add_batch_dim(b, 0, 3); b = None
|
|
lazy_load_decompositions_3 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_3 = None
|
|
_vmap_increment_nesting_3 = torch._functorch.predispatch._vmap_increment_nesting(l_a_, 'error'); _vmap_increment_nesting_3 = None
|
|
_add_batch_dim_3 = torch._functorch.predispatch._add_batch_dim(a, 0, 4); a = None
|
|
add = _add_batch_dim_3 + child_2; _add_batch_dim_3 = child_2 = None
|
|
add_1 = add + child_1; add = child_1 = None
|
|
batched_outputs = add_1 + child; add_1 = child = None
|
|
batched_outputs_1 = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 4, l_a_, 0); batched_outputs = l_a_ = None
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
batched_outputs_2 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 3, l_b_, 0); batched_outputs_1 = l_b_ = None
|
|
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
batched_outputs_3 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_2, 2, l_c_, 0); batched_outputs_2 = l_c_ = None
|
|
_vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None
|
|
_remove_batch_dim_3 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_3, 1, l_d_, 0); batched_outputs_3 = l_d_ = None
|
|
_vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None
|
|
return (_remove_batch_dim_3,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_pytree_operands(self):
|
|
def _construct_pytree():
|
|
a = torch.randn(3, 3)
|
|
b = torch.randn(3, 3)
|
|
c = torch.randn(3, 3)
|
|
d = torch.randn(3, 3)
|
|
e = torch.randn(3, 3)
|
|
f = torch.randn(3, 3)
|
|
g = torch.randn(3, 3)
|
|
return (a, [[[b]]], c, (d, (e,), f), {"g": g})
|
|
|
|
pred = torch.tensor(True)
|
|
inp = _construct_pytree()
|
|
|
|
def _reduce_sum(flattened):
|
|
init = 0
|
|
for val in flattened:
|
|
init += val
|
|
return init
|
|
|
|
def _reduce_max(flattened):
|
|
init = flattened[0]
|
|
for val in flattened:
|
|
init = max(val, init)
|
|
return init
|
|
|
|
def true_fn(pytree_in):
|
|
flattened, spec = pytree.tree_flatten(pytree_in)
|
|
return _reduce_sum(flattened)
|
|
|
|
def false_fn(pytree_in):
|
|
flattened, spec = pytree.tree_flatten(pytree_in)
|
|
return _reduce_max(flattened)
|
|
|
|
def fn(pred, pytree_in):
|
|
return torch.cond(pred, true_fn, false_fn, [pytree_in])
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
compiled_res = torch.compile(fn, backend=backend)(pred, inp)
|
|
eager_res = fn(pred, inp)
|
|
self.assertEqual(compiled_res, eager_res)
|
|
graph = backend.graphs[0]
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
|
|
l_pred_ = L_pred_
|
|
l_pytree_in_0_ = L_pytree_in_0_
|
|
l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
|
|
l_pytree_in_2_ = L_pytree_in_2_
|
|
l_pytree_in_3_0_ = L_pytree_in_3_0_
|
|
l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
|
|
l_pytree_in_3_2_ = L_pytree_in_3_2_
|
|
l_pytree_in_4_g_ = L_pytree_in_4_g_
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, (l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_)); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_pytree_operands_with_non_tensor_leaves(self):
|
|
def fn(pred, pytree_in):
|
|
return torch.cond(
|
|
pred, lambda x: x[0] + 1, lambda x: x[0] * 2, (pytree_in,)
|
|
)
|
|
|
|
pred = torch.tensor(True)
|
|
for pytree_in in [("string",), (1.0,)]:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Expect operands to be a tuple of possibly nested dict/list/tuple",
|
|
):
|
|
fn(pred, pytree_in)
|
|
|
|
for pytree_in in [("string",), (1.0,)]:
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
torch.compile(fn, backend="eager")(pred, pytree_in)
|
|
|
|
def test_cond_with_empty_operands(self):
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x, y, z):
|
|
def true_fn():
|
|
return y + 2
|
|
|
|
def false_fn():
|
|
return z + 1
|
|
|
|
return torch.cond(x, true_fn, false_fn)
|
|
|
|
zeros = torch.zeros(1)
|
|
ones = torch.ones(1)
|
|
self.assertEqual(fn(zeros, ones, ones), torch.tensor([2.0]))
|
|
self.assertEqual(fn(ones, ones, ones), torch.tensor([3.0]))
|
|
|
|
def test_hopify_generic_wrap(self):
|
|
from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper
|
|
|
|
def my_hop_fn_impl(fn, *args, k=1, **kwargs):
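# my_hop_fn_impl builds a wrapper that calls `fn` and adds `k` to its (first)
# output; my_hop_fn below feeds it through dynamo_bypassing_wrapper so the
# custom wrapper still compiles with fullgraph=True.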
|
|
def wrapper(*args, **kwargs):
|
|
out = fn(*args, **kwargs)
|
|
if isinstance(out, tuple):
|
|
return (out[0] + k,)
|
|
return out + k
|
|
|
|
return wrapper
|
|
|
|
def my_hop_fn(fn, *args, k=1, **kwargs):
|
|
return dynamo_bypassing_wrapper(
|
|
functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs
|
|
)
|
|
|
|
def my_hop_fn_2_impl(fn, *args, g=None):
|
|
def wrapper(*args, **kwargs):
|
|
assert g is not None
|
|
out = fn(*args)
|
|
if isinstance(out, tuple):
|
|
return (g(out[0]),)
|
|
return g(out)
|
|
|
|
return wrapper
|
|
|
|
def my_hop_fn_2(fn, *args, g=None, **kwargs):
|
|
return dynamo_bypassing_wrapper(
|
|
functools.partial(my_hop_fn_2_impl, g=g), fn, *args, **kwargs
|
|
)
|
|
|
|
def gn(x, h=1):
|
|
return x.sin() + h
|
|
|
|
def fn(x, b):
|
|
out = my_hop_fn(gn, x, h=b, k=2)
|
|
return out
|
|
|
|
a = torch.rand((4, 4), requires_grad=True)
|
|
b = torch.rand((4, 4))
|
|
compiled_fn = torch.compile(
|
|
fn, backend="aot_eager_decomp_partition", fullgraph=True
|
|
)
|
|
self.assertEqual(compiled_fn(a, b), fn(a, b))
|
|
|
|
def g(x):
|
|
return x.cos()
|
|
|
|
def fn_2(x, b):
|
|
out = my_hop_fn_2(fn, x, b, g=g)
|
|
return out
|
|
|
|
a = torch.rand((4, 4), requires_grad=True)
|
|
compiled_fn_2 = torch.compile(
|
|
fn_2, backend="aot_eager_decomp_partition", fullgraph=True
|
|
)
|
|
self.assertEqual(compiled_fn_2(a, b), fn_2(a, b))
|
|
|
|
def test_hints_wrapper(self):
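# The hints_wrapper calls should show up as nested
# torch.ops.higher_order.hints_wrapper nodes carrying the `hints` dicts in the
# captured graph, and the numeric result must match the plain ref_fn.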
|
|
def ref_fn(x, y):
|
|
x = x + y
|
|
x = torch.relu(x)
|
|
x = x + y
|
|
return torch.abs(x)
|
|
|
|
def fn_with_hints(x, y):
|
|
x = x + y
|
|
|
|
def inner_body_fn(x, y):
|
|
x = torch.relu(x)
|
|
x = x + y
|
|
return x
|
|
|
|
def outer_body_fn(x, y):
|
|
x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True})
|
|
x = torch.abs(x)
|
|
return x
|
|
|
|
res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True})
|
|
return res
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
|
|
|
|
eager_res = fn_with_hints(x, y)
|
|
compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
|
|
ref_res = ref_fn(x, y)
|
|
self.assertEqual(eager_res, ref_res)
|
|
self.assertEqual(compiled_res, ref_res)
|
|
self.assertEqual(len(cnt.graphs), 1)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
graph = backend.graphs[0]
|
|
self.assertExpectedInline(
|
|
normalize_gm(graph.print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None
|
|
|
|
hints_wrapper_body_1 = self.hints_wrapper_body_1
|
|
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
|
|
res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
|
return (res,)
|
|
|
|
class hints_wrapper_body_1(torch.nn.Module):
|
|
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
|
|
hints_wrapper_body_0 = self.hints_wrapper_body_0
|
|
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
|
|
x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
|
|
|
x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
|
|
return (x_2,)
|
|
|
|
class hints_wrapper_body_0(torch.nn.Module):
|
|
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
|
|
x_1: "f32[2, 4]" = torch.relu(x); x = None
|
|
|
|
x_2: "f32[2, 4]" = x_1 + l_y_; x_1 = l_y_ = None
|
|
return (x_2,)
|
|
""",
|
|
)
|
|
|
|
def test_hints_wrapper_no_hints(self):
|
|
def fn_with_hints(x, y):
|
|
def outer_body_fn(x, y):
|
|
x = torch.add(x, y)
|
|
return x
|
|
|
|
res = hints_wrapper(outer_body_fn, (x, y), {})
|
|
return res
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
|
|
|
|
msg = "hints_wrapper - key hints not provided"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.compile(fn_with_hints, backend=cnt)(x, y)
|
|
|
|
def test_hints_wrapper_incorrect_type(self):
|
|
def fn_with_hints(x, y):
|
|
def outer_body_fn(x, y):
|
|
x = torch.add(x, y)
|
|
return x
|
|
|
|
res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)})
|
|
return res
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
|
|
|
|
msg = r"hints must be a dict containing int, float, bool or str value,"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.compile(fn_with_hints, backend=cnt)(x, y)
|
|
|
|
def test_hints_wrapper_pytree_inputs(self):
|
|
def fn_with_hints(x, y):
|
|
def outer_body_fn(x):
|
|
res = torch.add(x[0], x[1]["test"])
|
|
return res
|
|
|
|
res = hints_wrapper(
|
|
outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True}
|
|
)
|
|
return res
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
|
|
|
|
msg = r"args must be a tuple of tensors, ints, floats, or bools,"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
fn_with_hints(x, y)
|
|
|
|
|
|
class HigherOrderOpVmapGuardTests(LoggingTestCase):
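# These tests rely on the `recompiles` (and `guards`) logging artifacts: a
# non-empty `records` list means a recompile happened, and the logged guard
# string is inspected to confirm it came from the functorch interpreter-stack
# state (compare_functorch_state / forward_ad level).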
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_grad_guard_ok(self, records):
|
|
vmap = torch.vmap
|
|
grad = torch.func.grad
|
|
|
|
def g(x):
|
|
return vmap(grad(torch.sin))(x)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return vmap(g)(x)
|
|
|
|
x = torch.randn(4, 5)
|
|
y = fn(x)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(x.cos(), y)
|
|
|
|
# Calling the same function again won't have any effect on guards
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_grad_guard_fail(self, records):
|
|
grad = torch.func.grad
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return grad(torch.sin)(x.sum())
|
|
|
|
x = torch.randn([])
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling again should not invalidate the graph
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling grad should retrigger compilation
|
|
x = torch.randn(3)
|
|
grad(fn)(x)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_dual_level_guard(self, records):
|
|
fwAD = torch.autograd.forward_ad
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(foo, tangent):
|
|
with fwAD.dual_level():
|
|
dual = fwAD.make_dual(foo, tangent[1:])
|
|
return dual
|
|
|
|
foo = torch.rand(2)
|
|
tangent = torch.rand(3)
|
|
fn(foo, tangent)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling again should not invalidate the graph
|
|
fn(foo, tangent)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# assertRaises is only here because nested forward-mode AD is not supported
|
|
with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError):
|
|
with fwAD.dual_level():
|
|
fn(foo, tangent)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "forward_ad")
|
|
self.assertIn(
|
|
"""torch.autograd.forward_ad._current_level == -1""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_jvp_guard_fail(self, records):
|
|
jvp = torch.func.jvp
|
|
vmap = torch.func.vmap
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return jvp(torch.sin, (x,), (x,))
|
|
|
|
x = torch.randn(3, 4)
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling again should not invalidate the graph
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling jvp should retrigger compilation
|
|
x = torch.randn(3, 4, 5)
|
|
jvp(vmap(fn), (x,), (x,))
|
|
|
|
self.assertGreater(len(records), 0)
|
|
if self.hasRecord(records, "pyfunctorch"):
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
elif self.hasRecord(records, "forward_ad"):
|
|
record = self.getRecord(records, "forward_ad")
|
|
self.assertIn(
|
|
"""torch.autograd.forward_ad._current_level == -1""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_guard_ok(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.randn(3, 3, 4, 5)
|
|
y = fn(x)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(x.sin(), y)
|
|
|
|
# Calling the same function again won't have any effect on guards
|
|
z = fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(x.sin(), z)
|
|
|
|
# calling with a different object will also not affect guards
|
|
w = fn(z)
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(z.sin(), w)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_guard_fail_different_state(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.zeros(3, 4)
|
|
y = torch.vmap(fn, randomness="same")(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling vmap with a different randomness setting ("different" vs "same") should retrigger compilation
|
|
y = torch.vmap(fn, randomness="different")(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
|
|
record.getMessage(),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_guard_fail(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.zeros(3, 3, 4, 5)
|
|
y = torch.vmap(fn)(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# call vmap(vmap(fn))(x) should retrigger compilation as
|
|
# _functorch.current_level() is not the same
|
|
x = torch.zeros(3, 3, 3, 4, 5)
|
|
y = torch.vmap(torch.vmap(fn))(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
|
|
record.getMessage(),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_grad_vmap_guard_fail(self, records):
|
|
vmap = torch.vmap
|
|
grad = torch.func.grad
|
|
|
|
def g(x):
|
|
y = vmap(torch.sin, randomness="same")(x)
|
|
return y.sum(0)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return grad(g)(x)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = vmap(fn, randomness="error")(x)
|
|
self.assertEqual(x.cos(), y)
|
|
|
|
# previous FX graph should be invalidated
|
|
x = torch.randn(3, 3, 4)
|
|
y = vmap(vmap(fn, randomness="different"))(x)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_recompile_different_states(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.zeros(3, 3, 4, 5)
|
|
torch.vmap(fn, randomness="same")(x)
|
|
self.assertEqual(len(records), 0) # sanity check
|
|
|
|
torch.vmap(fn, randomness="different")(x)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(guards=True)
|
|
def test_emit_functorch_guard_if_active(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.sin(x)
|
|
|
|
x = torch.randn(3, 4)
|
|
_ = fn(x)
|
|
self.assertFalse(self.hasRecord(records, "pyfunctorch")) # sanity check
|
|
|
|
_ = torch.vmap(fn)(x)
|
|
self.assertTrue(self.hasRecord(records, "pyfunctorch"))
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_linearize_recompiles(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
out, jvp_fn = torch.func.linearize(torch.sin, x)
|
|
return out, jvp_fn(x)
|
|
|
|
x = torch.randn(2, 3)
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
z = torch.randn(2, 3)
|
|
fn(z)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
y = torch.randn(3, 4)
|
|
fn(y)
|
|
self.assertGreater(len(records), 0)
|
|
|
|
|
|
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
|
def tearDown(self):
|
|
# Ensure that in the case of a test failure, the next test won't fail
|
|
# because of a previous call to _vmap_increment_nesting that wasn't undone
|
|
# e.g. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1
|
|
# and the call to increment nesting is not undone
|
|
if not TEST_WITH_TORCHDYNAMO:
|
|
return
|
|
|
|
warn = False
|
|
while ci := torch._C._functorch.peek_interpreter_stack():
|
|
if ci.key() == torch._C._functorch.TransformType.Vmap:
|
|
warn = True
|
|
torch._C._functorch._vmap_decrement_nesting()
|
|
else:
|
|
break
|
|
|
|
if warn:
|
|
msg = (
|
|
"Interpreter stack is not empty. Test should have called "
|
|
"'torch._C._functorch._vmap_decrement_nesting()'"
|
|
)
|
|
warnings.warn(msg)
|
|
|
|
def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0):
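# Runs `fn` eagerly and under torch.compile with the recording backend,
# asserts the two results match, and returns the captured graph at `graph_idx`.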
|
|
backend = EagerAndRecordGraphs()
|
|
actual = fn(*inputs)
|
|
expected = torch.compile(fn, backend=backend, fullgraph=fullgraph)(*inputs)
|
|
|
|
self.assertEqual(actual, expected)
|
|
|
|
wrapped_gm = backend.graphs[graph_idx]
|
|
return wrapped_gm
|
|
|
|
def test_hessian(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.hessian(torch.sin)(x)
|
|
|
|
x = torch.randn(4, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
child_2: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
primals_out: "f32[4, 3]" = torch.sin(diff_primals)
|
|
|
|
results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primals_out, 3)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor_1: "i64[1]" = torch.tensor((12,))
|
|
cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None
|
|
getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
|
|
neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None
|
|
unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None
|
|
|
|
chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal_1: "f32[12]" = chunk_1.diagonal(0)
|
|
fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
|
|
|
|
basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
|
|
|
|
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
|
|
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
|
|
|
|
_add_batch_dim_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 3); basis = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim_1 = None
|
|
batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
chunked_result: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
|
split_1: "f32[12, 4, 3]" = split[0]; split = None
|
|
|
|
output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None
|
|
|
|
_unpack_dual = torch._unpack_dual(output_input, level = 0); output_input = None
|
|
primal: "f32[4, 3, 4, 3]" = _unpack_dual[0]
|
|
dual: "f32[4, 3, 4, 3]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
|
|
|
|
tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
results_1: "f32[12, 4, 3, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
|
|
|
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
|
|
movedim: "f32[4, 3, 4, 3, 12]" = results_1.movedim(0, -1); results_1 = None
|
|
split_2 = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0]; split_2 = None
|
|
|
|
unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
|
|
return (unflatten,)
|
|
""",
|
|
)
|
|
|
|
def test_hessian_argnums(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return x.sin()
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.hessian(fn, argnums=(1,))(x, y)
x = torch.randn(4, 3)
y = torch.randn(3, 4)
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
# Dynamic shapes produce a slightly different graph.
if check_dynamic_shape_capture():
return
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
self.assertExpectedInline(
"\n".join(actual.split("\n")[:-2]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
l_x_ = L_x_
l_y_ = L_y_
tensor: "i64[1]" = torch.tensor((12,))
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
neg: "i64[0]" = getitem.neg(); getitem = None
unbind = neg.unbind(); neg = unbind = None
chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
diagonal: "f32[12]" = chunk.diagonal(0)
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
child_3: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
child_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
_wrap_for_grad_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
child_4: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
_set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
primals_out: "f32[4, 3]" = _wrap_for_grad_2.sin(); _wrap_for_grad_2 = None
results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primals_out, 3)
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
tensor_1: "i64[1]" = torch.tensor((12,))
cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None
getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None
unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None
chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None
diagonal_1: "f32[12]" = chunk_1.diagonal(0)
fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
_add_batch_dim_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 3); basis = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); primals_out = child_4 = _add_batch_dim_1 = None
child_5: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
child_6: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
split = child_6.split((12,), dim = 0); child_6 = None
split_1: "f32[12, 3, 4]" = split[0]; split = None
child_7: "f32[4, 3, 3, 4]" = split_1.view((4, 3, 3, 4)); split_1 = None
_unpack_dual = torch._unpack_dual(child_7, level = 0); child_7 = None
primal: "f32[4, 3, 3, 4]" = _unpack_dual[0]; _unpack_dual = None
tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal)
child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None
child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
child_10: "f32[12, 4, 3, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
movedim: "f32[4, 3, 3, 4, 12]" = child_10.movedim(0, -1); child_10 = None
split_2 = movedim.split((12,), dim = -1); movedim = None
jac_out_in: "f32[4, 3, 3, 4, 12]" = split_2[0]; split_2 = None
unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None""",
)
self.assertExpectedInline(
actual.split("\n")[-2],
""" return (unflatten,)""",
)
def test_jacrev(self):
counters.clear()
def wrapper_fn(x):
return torch.func.jacrev(torch.sin)(x)
x = torch.randn(4, 3)
wrapped_gm = self._compile_check(wrapper_fn, (x,))
# Dynamic shapes produce a slightly different graph.
if check_dynamic_shape_capture():
return
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4, 3]"):
l_x_ = L_x_
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
_set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
primals_out: "f32[4, 3]" = torch.sin(diff_primals)
results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primals_out, 1)
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
tensor: "i64[1]" = torch.tensor((12,))
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
neg: "i64[0]" = getitem.neg(); getitem = None
unbind = neg.unbind(); neg = unbind = None
chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None
diagonal: "f32[12]" = chunk.diagonal(0)
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
basis: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
_add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None
batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
chunked_result: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
split = chunked_result.split((12,), dim = 0); chunked_result = None
split_1: "f32[12, 4, 3]" = split[0]; split = None
output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None
return (output_input,)
""",
)
def test_jacrev_two_tensors_argnums(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin()
|
|
|
|
def wrapper_fn(x, y):
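            # argnums=1 takes the jacobian with respect to the second argument `y`.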
|
|
return torch.func.jacrev(fn, argnums=1)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = _wrap_for_grad = None
|
|
diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
primals_out: "f32[3, 4]" = diff_primals.sin()
|
|
|
|
results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primals_out, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None
|
|
batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
chunked_result: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
|
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
|
|
|
output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
|
|
return (output_input,)
|
|
""",
|
|
)
|
|
|
|
def test_jacrev_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin(), x
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacrev(fn, argnums=1, has_aux=True)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
primals_out: "f32[3, 4]" = diff_primals.sin()
|
|
|
|
aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primals_out, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None
|
|
batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
chunked_result: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
|
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
|
|
|
output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
|
|
return (output_input, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_vjp(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x, v):
|
|
(out, vjpfunc) = torch.func.vjp(fn, x)
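            # vjpfunc is intentionally unused; this test only checks the primal output.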
|
|
return out
|
|
|
|
x = torch.randn([5])
|
|
v = torch.randn(5)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[5]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[5]" = child.sin(); child = None
|
|
primals_out: "f32[]" = sin.sum(); sin = None
|
|
|
|
results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (results,)
|
|
""",
|
|
)
|
|
|
|
def test_vjp_multiple_outputs(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x, v):
|
|
fn = lambda x: (x.sin(), x.cos()) # noqa: E731
|
|
(out, vjpfunc) = torch.func.vjp(fn, x)
|
|
vjps = vjpfunc((v, v))
|
|
return out, vjps
|
|
|
|
x = torch.randn([5])
|
|
v = torch.randn(5)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
child_1: "f32[5]" = child.sin()
|
|
child_2: "f32[5]" = child.cos(); child = None
|
|
|
|
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
|
|
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, l_v_], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = None
|
|
getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None
|
|
return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
|
|
""",
|
|
)
|
|
|
|
def test_vjp_multiple_outputs_python_struct(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x, v):
|
|
fn = lambda x: {"first": x.sin(), "second": x.cos()} # noqa: E731
|
|
(out, vjpfunc) = torch.func.vjp(fn, x)
|
|
vjps = vjpfunc({"first": v, "second": v.sin()})
|
|
return out, vjps
|
|
|
|
x = torch.randn([5])
|
|
v = torch.randn(5)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
child_1: "f32[5]" = child.sin()
|
|
child_2: "f32[5]" = child.cos(); child = None
|
|
|
|
value: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
|
|
value_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
child_4: "f32[5]" = l_v_.sin()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None
|
|
getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None
|
|
return (value, value_1, getitem)
|
|
""",
|
|
)
|
|
|
|
def test_vjp_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum(), x
|
|
|
|
def wrapper_fn(x, v):
|
|
(out, vjpfunc, _) = torch.func.vjp(fn, x, has_aux=True)
|
|
return out
|
|
|
|
x = torch.randn([5])
|
|
v = torch.randn(5)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[5]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[5]" = child.sin()
|
|
primals_out: "f32[]" = sin.sum(); sin = None
|
|
|
|
aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None
|
|
|
|
results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (results,)
|
|
""",
|
|
)
|
|
|
|
@config.patch(inline_inbuilt_nn_modules=True)
|
|
def test_functional_call(self):
|
|
def wrapper_fn(model, params, inputs, targets):
|
|
prediction = torch.func.functional_call(model, params, (inputs,))
|
|
return torch.nn.functional.mse_loss(prediction, targets)
|
|
|
|
model = torch.nn.Linear(3, 3)
|
|
params = dict(model.named_parameters())
|
|
inputs = torch.randn(64, 3)
|
|
targets = torch.randn(64, 3)
|
|
|
|
wrapped_gm = self._compile_check(wrapper_fn, (model, params, inputs, targets))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
if torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bias_: "f32[3]", L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
|
|
l_model_parameters_weight_ = L_model_parameters_weight_
|
|
l_model_parameters_bias_ = L_model_parameters_bias_
|
|
l_inputs_ = L_inputs_
|
|
l_targets_ = L_targets_
|
|
|
|
prediction: "f32[64, 3]" = torch._C._nn.linear(l_inputs_, l_model_parameters_weight_, l_model_parameters_bias_); l_inputs_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None
|
|
|
|
mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None
|
|
return (mse_loss,)
|
|
""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
|
|
l_inputs_ = L_inputs_
|
|
l_targets_ = L_targets_
|
|
|
|
prediction: "f32[64, 3]" = self.model(l_inputs_); l_inputs_ = None
|
|
|
|
mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None
|
|
return (mse_loss,)
|
|
""",
|
|
)
|
|
|
|
@config.patch(inline_inbuilt_nn_modules=True)
|
|
def test_functional_call_sequential_params_and_buffers(self):
|
|
# copied from test/test_stateless.py
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.l1 = torch.nn.Linear(1, 1)
|
|
self.register_buffer("buffer", torch.ones(1))
|
|
self.foo = 0.0
|
|
|
|
def forward(self, x):
|
|
return self.l1(x) + self.buffer
|
|
|
|
def wrapper_fn(model, params, buffers, inputs):
|
|
# two separate dictionaries
|
|
return torch.func.functional_call(model, (params, buffers), inputs)
|
|
|
|
model = MockModule()
|
|
params = dict(model.named_parameters())
|
|
buffers = dict(model.named_buffers())
|
|
inputs = torch.tensor([[1.5]])
|
|
|
|
wrapped_gm = self._compile_check(
|
|
wrapper_fn, (model, params, buffers, inputs), fullgraph=False
|
|
)
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
if torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
expected = """\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_inputs_: "f32[1, 1]", L_model_modules_l1_parameters_weight_: "f32[1, 1]", L_model_modules_l1_parameters_bias_: "f32[1]", L_model_buffers_buffer_: "f32[1]"):
|
|
l_inputs_ = L_inputs_
|
|
l_model_modules_l1_parameters_weight_ = L_model_modules_l1_parameters_weight_
|
|
l_model_modules_l1_parameters_bias_ = L_model_modules_l1_parameters_bias_
|
|
l_model_buffers_buffer_ = L_model_buffers_buffer_
|
|
linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_model_modules_l1_parameters_weight_, l_model_modules_l1_parameters_bias_); l_inputs_ = l_model_modules_l1_parameters_weight_ = l_model_modules_l1_parameters_bias_ = None
|
|
add: "f32[1, 1]" = linear + l_model_buffers_buffer_; linear = l_model_buffers_buffer_ = None
|
|
return (add,)
|
|
"""
|
|
            # Windows and Linux builds can differ by empty lines; empty_line_normalizer fixes that up.
|
|
self.assertExpectedInline(
|
|
empty_line_normalizer(actual),
|
|
empty_line_normalizer(normalize_gm(expected)),
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[1, 1]"):
|
|
l_x_ = L_x_
|
|
|
|
l__self___l1: "f32[1, 1]" = self.L__self___l1(l_x_); l_x_ = None
|
|
l__self___buffer: "f32[1]" = self.L__self___buffer
|
|
add: "f32[1, 1]" = l__self___l1 + l__self___buffer; l__self___l1 = l__self___buffer = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
@config.patch(inline_inbuilt_nn_modules=False)
|
|
def test_functional_call_disable_inline_nn_module(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(model, params, inputs, targets):
|
|
prediction = torch.func.functional_call(model, params, (inputs,))
|
|
return torch.nn.functional.mse_loss(prediction, targets)
|
|
|
|
model = torch.nn.Linear(3, 3)
|
|
params = dict(model.named_parameters())
|
|
inputs = torch.randn(64, 3)
|
|
targets = torch.randn(64, 3)
|
|
|
|
actual = wrapper_fn(model, params, inputs, targets)
|
|
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
|
|
model, params, inputs, targets
|
|
)
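        # With inline_inbuilt_nn_modules disabled, functional_call capture is off,
        # so exactly one graph break with the message below is expected.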
|
|
self.assertEqual(len(counters["graph_break"]), 1)
|
|
self.assertEqual(
|
|
{
|
|
"torch.func.functional_call capture is disabled, it can be "
|
|
"turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1,
|
|
},
|
|
dict(counters["graph_break"]),
|
|
)
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_grad(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(fn)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
output: "f32[]" = sin.sum(); sin = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1,)
|
|
""",
|
|
)
|
|
|
|
def test_grad_freevar_tensor(self):
|
|
counters.clear()
|
|
y = torch.randn(3, 3)
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(fn)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
expected = wrapper_fn(x)
|
|
actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_grad_freevar_python_scalar(self):
|
|
counters.clear()
|
|
y = 3
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(fn)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + 3; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1,)
|
|
""",
|
|
)
|
|
|
|
def test_grad_capture_tensor(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
y = torch.randn(3)
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
return torch.func.grad(fn)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
y: "f32[3]" = torch.randn(3)
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + y; sin = y = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1,)
|
|
""",
|
|
)
|
|
|
|
def test_grad_closure_scalar(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
y = 3.14
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
return torch.func.grad(fn)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
|
|
        # Graph break: dynamo is unable to get the source of `fn`, and
        # functools.wraps in `grad` leads to a graph break
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + 3.14; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1,)
|
|
""",
|
|
)
|
|
|
|
def test_grad_has_aux(self):
|
|
counters.clear()
|
|
|
|
y = 3.14
|
|
|
|
def fn(x):
|
|
return ((x.sin() + y).sum(), x.cos())
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(fn, has_aux=True)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + 3.14; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = diff_args.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_grad_two_tensor_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return ((x.sin() + y).sum(), x.cos())
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.grad(fn, has_aux=True)(x, y)
|
|
|
|
y = torch.randn(3, 3, 3)
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
_wrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = diff_args.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_grad_two_tensor_all_grad_has_aux(self):
|
|
counters.clear()
|
|
|
|
nums = (0, 1)
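        # argnums is passed both as a literal tuple and via this captured variable;
        # both wrappers are expected to trace to the same graph.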
|
|
|
|
def fn(x, y):
|
|
return ((x.sin() + y).sum(), x.cos())
|
|
|
|
def wrapper_fn_const_var(x, y):
|
|
return torch.func.grad(fn, argnums=(0, 1), has_aux=True)(x, y)
|
|
|
|
def wrapper_fn_tuple_var(x, y):
|
|
return torch.func.grad(fn, argnums=nums, has_aux=True)(x, y)
|
|
|
|
y = torch.randn(3, 3, 3)
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm_const_var = self._compile_check(wrapper_fn_const_var, (x, y))
|
|
wrapped_gm_tuple_var = self._compile_check(wrapper_fn_tuple_var, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual_const_var = normalize_gm(
|
|
wrapped_gm_const_var.print_readable(print_output=False)
|
|
)
|
|
actual_tuple_var = normalize_gm(
|
|
wrapped_gm_tuple_var.print_readable(print_output=False)
|
|
)
|
|
self.assertExpectedInline(
|
|
actual_const_var,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
|
|
|
|
_set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
|
|
|
|
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
|
|
|
|
sin: "f32[3, 3, 3]" = child.sin()
|
|
add: "f32[3, 3, 3]" = sin + child_1; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = child.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None
|
|
child_2: "f32[3, 3, 3]" = _autograd_grad[0]
|
|
child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None
|
|
|
|
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
|
|
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1)
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
actual_tuple_var,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
|
|
|
|
_set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
|
|
|
|
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
|
|
|
|
sin: "f32[3, 3, 3]" = child.sin()
|
|
add: "f32[3, 3, 3]" = sin + child_1; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = child.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None
|
|
child_2: "f32[3, 3, 3]" = _autograd_grad[0]
|
|
child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None
|
|
|
|
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
|
|
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1)
""",
        )

    def test_grad_over_grad(self):
        counters.clear()

        def fn(x):
            return x.sin().sum()

        def wrapper_fn(x):
            return torch.func.grad(torch.func.grad(fn))(x)

        x = torch.randn(())
        wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False)

        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]"):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None

        diff_args: "f32[]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None

        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None

        _set_tensor_requires_grad: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None

        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_1 = None
        _grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting_1 = None

        diff_args_1: "f32[]" = torch._C._functorch._wrap_for_grad(diff_args, 2)

        set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None

        _set_tensor_requires_grad_1: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1); _set_tensor_requires_grad_1 = None

        set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None

        sin: "f32[]" = diff_args_1.sin()
        output: "f32[]" = sin.sum(); sin = None

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None
        grad_input: "f32[]" = _autograd_grad[0]; _autograd_grad = None

        grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None

        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None

        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_2 = None

        _autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True); diff_args = None
        grad_input_2: "f32[]" = _autograd_grad_1[0]; _autograd_grad_1 = None

        grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None

        output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None

        _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting_1 = None
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (grad_input_3,)
""",
        )

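    # A graph break inside the function passed to torch.func.grad should fall back
    # cleanly: exactly one graph break is recorded and the compiled result matches eager.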
    def test_grad_with_graph_break(self):
        counters.clear()

        def fn(x):
            torch._dynamo.graph_break()
            return x.sin().sum()

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x = torch.randn(3, 3, 3)
        actual = wrapper_fn(x)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(actual, expected)

    def test_grad_with_side_effect(self):
        counters.clear()

        foo = [1, 2]

        def fn(x):
            foo.append(3)
            return x.sin().sum()

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x = torch.randn(3, 3, 3)
        actual = wrapper_fn(x)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

    def test_grad_pytree(self):
        counters.clear()

        def fn(x):
            x1, x2 = x
            return x1.sin().sum() + x2

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x1 = torch.randn(3, 3, 3)
        x2 = torch.randn(())
        actual = wrapper_fn((x1, x2))
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
            (x1, x2)
        )
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

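    # Non-tensor (Python float) arguments to the grad-wrapped function are expected
    # to be inlined as constants in the captured graph (see the `sum_1 + 3.0` node below).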
    def test_grad_non_tensor_input(self):
        counters.clear()

        def fn(x, y):
            return x.sin().sum() + y

        def wrapper_fn(x, y):
            return torch.func.grad(fn)(x, y)

        x = torch.randn(3, 3, 3)
        y = 3.0
        wrapped_gm = self._compile_check(wrapper_fn, (x, y))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3, 3]"):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None

        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None

        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None

        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None

        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None

        sin: "f32[3, 3, 3]" = diff_args.sin()
        sum_1: "f32[]" = sin.sum(); sin = None
        output: "f32[]" = sum_1 + 3.0; sum_1 = None

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
        grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None

        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None

        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None

        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (grad_input_1,)
""",
        )

    def test_grad_fn_with_kwargs(self):
        def fn(x, y):
            return (x + y).sum()

        def wrapper_fn(x, y):
            return torch.func.grad(fn)(x, y=y)

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)
        actual = wrapper_fn(x, y)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

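    # jacfwd is captured as vmap over jvp: the expected graph below first builds the
    # identity basis (new_zeros + diagonal fill_), then traces the forward-mode pass.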
    def test_jacfwd(self):
        counters.clear()

        def wrapper_fn(x):
            return torch.func.jacfwd(torch.sin)(x)

        x = torch.randn(4, 3)
        wrapped_gm = self._compile_check(wrapper_fn, (x,))
        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[4, 3]"):
        l_x_ = L_x_

        tensor: "i64[1]" = torch.tensor((12,))
        cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
        getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
        neg: "i64[0]" = getitem.neg(); getitem = None
        unbind = neg.unbind(); neg = unbind = None

        chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)

        diagonal: "f32[12]" = chunk.diagonal(0)
        fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None

        child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None

        lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None

        child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None

        _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
        _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
        _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None

        _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None

        _make_dual: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None

        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None

        result_duals: "f32[4, 3]" = torch.sin(_make_dual); _make_dual = None

        _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
        primal: "f32[4, 3]" = _unpack_dual[0]
        dual: "f32[4, 3]" = _unpack_dual[1]; _unpack_dual = None

        primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None

        tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None

        _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
        _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
        _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None

        results: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None

        _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None

        movedim: "f32[4, 3, 12]" = results.movedim(0, -1); results = None
        split = movedim.split((12,), dim = -1); movedim = None
        jac_out_in: "f32[4, 3, 12]" = split[0]; split = None

        unflatten: "f32[4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
        return (unflatten,)
""",
        )

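    # Same jacfwd capture as above, but differentiating w.r.t. the second input via argnums=1.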
def test_jacfwd_two_tensors_argnums(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin()
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacfwd(fn, argnums=1)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
|
|
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
|
|
|
|
result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 4]" = _unpack_dual[0]
|
|
dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
|
|
|
|
tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None
|
|
split = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[3, 4, 12]" = split[0]; split = None
|
|
|
|
unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None
|
|
return (unflatten,)
|
|
""",
|
|
)
|
|
|
|
def test_jacfwd_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin(), x
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacfwd(fn, argnums=1, has_aux=True)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
|
|
|
|
aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
|
|
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
|
|
|
|
result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None
|
|
|
|
aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 2); aux = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 4]" = _unpack_dual[0]
|
|
dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
|
|
|
|
tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
|
aux_2: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
aux_3: "f32[4, 3]" = aux_2[0]; aux_2 = None
|
|
|
|
movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None
|
|
split = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[3, 4, 12]" = split[0]; split = None
|
|
|
|
unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None
|
|
return (unflatten, aux_3)
|
|
""",
|
|
)
|
|
|
|
def test_jacfwd_randomness(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin(), x
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacfwd(fn, randomness="same")(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
child_3: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
|
|
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None
|
|
|
|
child_2: "f32[3, 4]" = _wrap_for_grad_1.sin(); _wrap_for_grad_1 = None
|
|
|
|
_unpack_dual = torch._unpack_dual(child_2, level = 0); child_2 = None
|
|
primal: "f32[3, 4]" = _unpack_dual[0]; _unpack_dual = None
|
|
|
|
tangent: "f32[3, 4]" = torch.zeros_like(primal)
|
|
|
|
_unpack_dual_1 = torch._unpack_dual(child_3, level = 0); child_3 = None
|
|
primal_1: "f32[4, 3]" = _unpack_dual_1[0]
|
|
dual: "f32[4, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None
|
|
|
|
child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None
|
|
child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None
|
|
|
|
child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None
|
|
child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
child_8: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None
|
|
child_9: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
movedim: "f32[3, 4, 12]" = child_8.movedim(0, -1); child_8 = None
|
|
split = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[3, 4, 12]" = split[0]; split = None
|
|
|
|
unflatten: "f32[3, 4, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
|
|
|
|
movedim_1: "f32[4, 3, 12]" = child_9.movedim(0, -1); child_9 = None
|
|
split_1 = movedim_1.split((12,), dim = -1); movedim_1 = None
|
|
jac_out_in_1: "f32[4, 3, 12]" = split_1[0]; split_1 = None
|
|
|
|
unflatten_1: "f32[4, 3, 4, 3]" = jac_out_in_1.unflatten(-1, (4, 3)); jac_out_in_1 = None
|
|
return (unflatten, unflatten_1)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_simple(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x, v):
|
|
return torch.func.jvp(fn, (x,), (v,))
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
return (primals_out_unflatten, tangents_out_unflatten)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum(), x
|
|
|
|
def wrapper_fn(x, v):
|
|
return torch.func.jvp(fn, (x,), (v,), has_aux=True)
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = aux.sin()
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
return (primals_out_unflatten, tangents_out_unflatten, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_two_tensors_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return (x.sin().sum() + y.cos()), x
|
|
|
|
def wrapper_fn(x, y, v):
|
|
return torch.func.jvp(fn, (x, y), (v, v), has_aux=True)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
l_v_ = L_v_
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = None
|
|
|
|
_maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None
|
|
|
|
_make_dual_1: "f32[3, 3]" = torch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = aux.sin()
|
|
sum_1: "f32[]" = sin.sum(); sin = None
|
|
cos: "f32[3, 3]" = _make_dual_1.cos(); _make_dual_1 = None
|
|
result_duals: "f32[3, 3]" = sum_1 + cos; sum_1 = cos = None
|
|
|
|
aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 3]" = _unpack_dual[0]
|
|
dual: "f32[3, 3]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
return (primals_out_unflatten, tangents_out_unflatten, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_two_tensors_disable_grad(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x, v):
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(False):
|
|
return torch.func.jvp(fn, (x,), (v,))
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
_set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None
|
|
return (primals_out_unflatten, tangents_out_unflatten)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_two_tensors_disable_enable_disable_grad(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x, v):
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(False): # (1)
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(True): # (2)
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(False): # (3)
|
|
return torch.func.jvp(fn, (x,), (v,)) # (4)
|
|
|
|
# Start True
|
|
# False (1)
|
|
# True (2)
|
|
# False (3)
|
|
# True (4)
|
|
# True (undo 3)
|
|
# False (undo 2)
|
|
# True (undo 1)
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_4 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_4 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
_set_fwd_grad_enabled_5 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_5 = None
|
|
_set_fwd_grad_enabled_6 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_6 = None
|
|
_set_fwd_grad_enabled_7 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_7 = None
|
|
return (primals_out_unflatten, tangents_out_unflatten)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_freevar_tensor(self):
|
|
counters.clear()
|
|
y = torch.randn(3, 3)
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.jvp(fn, (x,), (x,))
|
|
|
|
x = torch.randn(3, 3)
|
|
expected = wrapper_fn(x)
|
|
actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_jvp_jvp(self):
|
|
counters.clear()
|
|
|
|
if check_dynamic_shape_capture():
|
|
self.skipTest("test fails with dynamic shapes")
|
|
|
|
def fn(x):
|
|
return torch.func.jvp(torch.sin, (x,), (x,))
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.jvp(fn, (x,), (x,))
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
child: "f32[3, 3, 3]" = torch._make_dual(l_x_, l_x_, level = 0); l_x_ = None
|
|
|
|
_jvp_increment_nesting_1 = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting_1 = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
|
|
_maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None
|
|
|
|
_make_dual_1: "f32[3, 3, 3]" = torch._make_dual(child, child, level = 0); child = None
|
|
|
|
result_duals: "f32[3, 3, 3]" = torch.sin(_make_dual_1); _make_dual_1 = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 3, 3]" = _unpack_dual[0]
|
|
dual: "f32[3, 3, 3]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None
|
|
|
|
tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
_unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0); primals_out_unflatten = None
|
|
primal_1: "f32[3, 3, 3]" = _unpack_dual_1[0]
|
|
dual_1: "f32[3, 3, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None
|
|
_unpack_dual_2 = torch._unpack_dual(tangents_out_unflatten, level = 0); tangents_out_unflatten = None
|
|
primal_2: "f32[3, 3, 3]" = _unpack_dual_2[0]
|
|
dual_2: "f32[3, 3, 3]" = _unpack_dual_2[1]; _unpack_dual_2 = None
|
|
|
|
_unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1); primal_1 = None
|
|
_unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1); primal_2 = None
|
|
|
|
_unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1); dual_1 = None
|
|
_unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1); dual_2 = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None
|
|
_jvp_decrement_nesting_1 = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting_1 = None
|
|
return (_unwrap_for_grad_2, _unwrap_for_grad_3, _unwrap_for_grad_4, _unwrap_for_grad_5)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_freevar_python_scalar(self):
|
|
counters.clear()
|
|
y = 3
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.jvp(fn, (x,), (x,))
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
expected = wrapper_fn(x)
|
|
actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_linearize_jvp_fn(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
output, jvp_fn = torch.func.linearize(torch.sin, x)
|
|
return output, jvp_fn(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=0)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
|
|
l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
|
|
|
|
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None
|
|
|
|
sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)
|
|
|
|
alias_default_1: "f32[3, 3, 3]" = torch.ops.aten.alias.default(alias_default)
|
|
|
|
cos_default: "f32[3, 3, 3]" = torch.ops.aten.cos.default(alias_default_1); alias_default_1 = None
|
|
|
|
alias_default_2: "f32[3, 3, 3]" = torch.ops.aten.alias.default(sin_default); alias_default_2 = None
|
|
return (alias_default, cos_default, sin_default)
|
|
""",
|
|
)
|
|
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=1)
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
|
|
l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
|
|
l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
|
|
l_flat_tangents_1_ = L_flat_tangents_1_
|
|
|
|
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None
|
|
|
|
copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None
|
|
|
|
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
|
|
return (mul_tensor,)
|
|
""",
|
|
)
|
|
|
|
    @config.patch(error_on_recompile=True)
    def test_vmap_recompile(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        torch.vmap(fn)(x)
        # should not recompile on second call. See Pytorch issue #118493
        torch.vmap(fn)(x)

    @xfailIfTorchDynamo
    @config.patch(error_on_recompile=True)
    def test_vmap_recompile_different_config(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        torch.vmap(fn)(x)
        with self.assertRaises(torch._dynamo.exc.RecompileError):
            fn(x)

    @config.patch(error_on_recompile=True)
    def test_vmap_recompile_same_config(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        torch.vmap(torch.vmap(fn, randomness="same"), randomness="same")(x)
        with self.assertRaises(torch._dynamo.exc.RecompileError):
            torch.vmap(torch.vmap(fn, randomness="same"), randomness="error")(x)

    @config.patch(error_on_recompile=True)
    def test_vmap_recompile_with_randomness(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        torch.vmap(fn, randomness="same")(x)
        with self.assertRaises(torch._dynamo.exc.RecompileError):
            torch.vmap(fn, randomness="different")(x)

    def test_vmap_call_torch_compile_fn(self):
        def wrapped_fn(x):
            return x.sin()

        x = torch.randn(3, 4)
        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Calling torch.func.vmap\\(compiled_fn\\) function from eager mode is not supported",
        ):
            torch.func.vmap(fn)(x)

    def test_vmap_call_compiled_backward_fn(self):
        # See PyTorch issue #138422
        @torch.compile
        def f(x):
            return x**2

        x = torch.randn(2, requires_grad=True)
        y = f(x)

        def get_vjp(v):
            return torch.autograd.grad(y, x, v)

        with self.assertRaisesRegex(
            RuntimeError,
            "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, which isn't supported",
        ):
            torch.func.vjp(get_vjp, x)

    def test_vjp_call_compiled_backward_fn(self):
        # See PyTorch issue #138422
        @torch.compile
        def f(x):
            return x**2

        x = torch.randn(2, requires_grad=True)
        y = f(x)

        def get_vjp(v):
            return torch.autograd.grad(y, x, v)

        with self.assertRaisesRegex(
            RuntimeError,
            "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, which isn't supported",
        ):
            torch.func.vjp(get_vjp, x)

    def test_grad_call_compiled_backward_fn(self):
        # See PyTorch issue #138422
        @torch.compile
        def f(x):
            return x**2

        x = torch.randn(2, requires_grad=True)
        y = f(x)

        def get_vjp(v):
            return torch.autograd.grad(y, x, v)

        with self.assertRaisesRegex(
            RuntimeError,
            "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, which isn't supported",
        ):
            torch.func.grad(get_vjp)(x)

    def test_grad_call_torch_compile_fn(self):
        def wrapped_fn(x):
            return x.sin().sum()

        x = torch.randn(3, 4)
        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Calling torch.func.grad\\(compiled_fn\\) function from eager mode is not supported",
        ):
            torch.func.grad(fn)(x)

    def test_jvp_call_torch_compile_fn(self):
        def wrapped_fn(x):
            return x.sin().sum()

        x = torch.randn(3, 4)
        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Calling torch.func.jvp\\(compiled_fn\\) function from eager mode is not supported",
        ):
            torch.func.jvp(fn, (x,), (x,))

    @config.patch(error_on_recompile=True)
    def test_grad_recompile(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.func.grad(torch.sin)(x)

        x = torch.randn([])
        torch.func.grad(fn)(x)
        # should not recompile on second call
        torch.func.grad(fn)(x)

    def test_vmap_get_wrapped(self):
        counters.clear()

        def g(x):
            return x.sin()

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn():
            return torch.vmap(g)

        x = torch.randn(3, 4)
        expected = torch.vmap(g)(x)
        wrapper = fn()
        got = wrapper(x)
        self.assertEqual(expected, got)

    def test_vmap_with_conditional_graph_break(self):
        def g(x):
            if len(x.shape) < 2:
                torch._dynamo.graph_break()
                return x.sin()
            else:
                return x.cos()

        @torch.compile(backend="aot_eager")
        def fn(x):
            return torch.vmap(g)(x)

        counters.clear()
        x = torch.randn(2, 3)
        expected = x.sin()
        got = fn(x)
        self.assertEqual(expected, got)
        self.assertEqual(len(counters["graph_break"]), 1)

        counters.clear()
        y = torch.randn(2, 3, 4)
        expected = y.cos()
        got = fn(y)
        self.assertEqual(expected, got)
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_vmap_with_graph_break(self):
        counters.clear()

        def g(x):
            y = x.cos()
            print("hi")
            return y.sin()

        def fn(x):
            return torch.vmap(g)(x)

        x = torch.randn(3, 4)
        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
        expected = fn(x)
        got = opt(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(expected, got)

    def test_vmap_with_graph_break_2(self):
        counters.clear()

        def cos(x):
            print("cos")
            return x.cos()

        def sin(x):
            print("sin")
            return x.sin()

        def g(x):
            y = cos(x)
            return sin(y)

        def fn(x):
            return torch.vmap(g, randomness="same")(x)

        x = torch.randn(3, 4)
        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
        expected = fn(x)
        got = opt(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(expected, got)

    def test_vmap_with_graph_break_lambda(self):
        counters.clear()

        def sin(x):
            print("sin")
            return x.sin()

        def fn(x):
            return torch.vmap(lambda x: sin(x))(x)

        x = torch.randn(3, 4)
        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
        expected = fn(x)
        got = opt(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(expected, got)

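    # The vmap capture tests below check that the mapped body is traced between
    # _vmap_increment_nesting/_vmap_decrement_nesting, with inputs batched via _add_batch_dim.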
def test_vmap(self):
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
batched_outputs: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_free_const(self):
|
|
y = 3
|
|
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3]" = add + 3; add = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_free_tensor(self):
|
|
y = torch.randn(3, 3)
|
|
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3, 3]" = add + l_y_; add = l_y_ = None
|
|
|
|
_remove_batch_dim: "f32[3, 3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_two_inputs(self):
|
|
def fn(x, y):
|
|
return torch.func.vmap(
|
|
lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=(0, 1)
|
|
)(x, y)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
_add_batch_dim_1: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_two_inputs_tuple_in_dims(self):
|
|
in_dims = (0, 1)
|
|
|
|
def fn(x, y):
|
|
return torch.func.vmap(
|
|
lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=in_dims
|
|
)(x, y)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
_add_batch_dim_1: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_over_vmap_two_inputs(self):
|
|
def fn(x, y):
|
|
return torch.func.vmap(torch.func.vmap(lambda x, y: x + y, in_dims=1))(x, y)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
y = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
child: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
child_1: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 0, 1); l_y_ = None
|
|
|
|
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
|
|
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None
|
|
|
|
_add_batch_dim_2: "f32[3]" = torch._functorch.predispatch._add_batch_dim(child, 1, 2); child = None
|
|
_add_batch_dim_3: "f32[3]" = torch._functorch.predispatch._add_batch_dim(child_1, 1, 2); child_1 = None
|
|
|
|
batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None
|
|
|
|
batched_outputs_1: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
_remove_batch_dim_1: "f32[3, 3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None
|
|
|
|
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
return (_remove_batch_dim_1,)
|
|
""",
|
|
)
|
|
|
|
    def test_vmap_over_vmap_captured(self):
        x = torch.ones(2, 3)
        y = torch.ones(5, 3)

        def fn(x):
            return torch.func.vmap(torch.func.vmap(lambda y: x * y))(y)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_: "f32[5, 3]", L_x_: "f32[2, 3]"):
        l_y_ = L_y_
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None

        child: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 0, 1); l_y_ = None

        lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None

        _vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None

        _add_batch_dim_1: "f32[]" = torch._functorch.predispatch._add_batch_dim(child, 0, 2); child = None

        batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None

        batched_outputs_1: "f32[3, 2, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None

        _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None

        _remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None

        _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
        return (_remove_batch_dim_1,)
""",
        )

    def test_vmap_multiple_outputs(self):
        x = torch.ones(2, 4, 3)

        def fn(x):
            return torch.vmap(lambda x: (x.sum(0), x.sum(1)))(x)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4, 3]"):
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

        _add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None

        child: "f32[3]" = _add_batch_dim.sum(0)
        child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None

        _remove_batch_dim: "f32[2, 3]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 0); child = None
        _remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None

        _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
        return (_remove_batch_dim, _remove_batch_dim_1)
""",
        )

    def test_vmap_multiple_outputs_diff_dims(self):
        x = torch.ones(2, 4, 3)

        def fn(x):
            return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=(1, 0))(x)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4, 3]"):
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

        _add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None

        child: "f32[3]" = _add_batch_dim.sum(0)
        child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None

        _remove_batch_dim: "f32[3, 2]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 1); child = None
        _remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None

        _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
        return (_remove_batch_dim, _remove_batch_dim_1)
""",
        )

    def test_vmap_multiple_outputs_out_dims_tuple(self):
        x = torch.ones(2, 4, 3)
        out_dims = (1, 0)

        def fn(x):
            return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=out_dims)(x)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4, 3]"):
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

        _add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None

        child: "f32[3]" = _add_batch_dim.sum(0)
        child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None

        _remove_batch_dim: "f32[3, 2]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 1); child = None
        _remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None

        _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
        return (_remove_batch_dim, _remove_batch_dim_1)
""",
        )

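    # The remaining vmap tests compare eager and compiled results (and, where
    # noted, the graph-break counters) instead of asserting on the captured
    # graph.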
    def test_vmap_kwargs(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        def fn(x, y):
            return torch.func.vmap(lambda x, y: x + y)(x, y=y)

        actual = fn(x, y)
        expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

    def test_vmap_pytree_inputs(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        def vmap_fn(inps):
            x = inps["x"]
            y = inps["y"]
            return x + y

        def fn(x, y):
            return torch.func.vmap(vmap_fn)({"x": x, "y": y})

        actual = fn(x, y)
        expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

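    # vmap traces `f` once over the batched inputs, so the append below runs
    # once per wrapper_fn call: once in eager and once under compile -> [1, 1].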
    def test_vmap_side_effects(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        some_list = []

        def f(x, y):
            some_list.append(1)
            return x + y

        def wrapper_fn(x, y):
            return torch.func.vmap(f)(x, y)

        actual = wrapper_fn(x, y)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)
        self.assertEqual(some_list, [1, 1])

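    # Mutating a Python list with the batched tensor input (rather than a plain
    # constant, as above) is presumably what breaks compilation here, hence the
    # expected failure.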
    @unittest.expectedFailure
    def test_vmap_side_effects_append_input(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        some_list = []

        def f(x, y):
            some_list.append(x)
            return x + y

        def wrapper_fn(x, y):
            return torch.func.vmap(f)(x, y)

        actual = wrapper_fn(x, y)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

    def test_vmap_previous_illegal_op_no_graph_break(self):
        counters.clear()

        # calling .stride() would previously graph break
        def bad_fn(x):
            y = x.view((4, 3))
            y.stride()
            return y

        def wrapper_fn(x):
            return torch.func.vmap(bad_fn)(x)

        x = torch.randn(2, 3, 4)
        actual = wrapper_fn(x)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

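    # in_dims/out_dims are plain Python ints here, so changing them across
    # calls triggers recompiles even with dynamic=True; the counters below pin
    # the expected number of frames and ops.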
    def test_vmap_multiple_invocation_in_dims(self):
        counters.clear()

        def wrapper_fn(x, in_dims):
            return torch.func.vmap(torch.sum, in_dims)(x)

        x = torch.randn(3, 3, 3, 3)
        cnt = CompileCounter()
        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
        # Third invocation of `opt` makes `in_dims` a SymInt.
        actual = opt(x, 0), opt(x, 1), opt(x, 2)
        self.assertEqual(expected, actual)
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 18)

    def test_vmap_multiple_invocation_out_dims(self):
        counters.clear()

        def wrapper_fn(x, out_dims):
            return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)

        x = torch.randn(3, 3, 3, 3)
        cnt = CompileCounter()
        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
        # Third invocation of `opt` makes `out_dims` a SymInt.
        actual = opt(x, 0), opt(x, 1), opt(x, 2)
        self.assertEqual(expected, actual)
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 18)

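    # in_dims/out_dims entries of None leave the corresponding argument/output
    # unbatched; regression test for the issue linked below.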
    def test_vmap_out_dims_None(self):
        # issue https://github.com/pytorch/pytorch/issues/149509
        def fn(x, y):
            return x, y * 2

        def wrapper_fn(x, y):
            return torch.func.vmap(fn, in_dims=(None, 0), out_dims=(None, 0))(x, y)

        x, y = torch.randn(4), torch.randn(3, 4)
        expected = wrapper_fn(x, y)
        got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x, y)
        self.assertEqual(expected, got)

    def test_vmap_new_tensor_in_body(self):
        def fn(x):
            return x + torch.ones(3)

        def wrapper_fn(x):
            return torch.func.vmap(fn)(x)

        x = torch.randn(
            3,
        )
        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
        expected = wrapper_fn(x)
        actual = opt(x)
        self.assertEqual(expected, actual)

    def test_vmap_new_tensor_unused_in_body(self):
        def fn(x):
            return torch.tensor(0.5)

        def wrapper_fn(x):
            return torch.func.vmap(fn)(x)

        x = torch.randn(3)
        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
        expected = wrapper_fn(x)
        actual = opt(x)
        self.assertEqual(expected, actual)

    def test_vmap_new_tensor_implicit_via_op(self):
        def wrapper_fn(x):
            return torch.func.vmap(lambda t: torch.add(t, 0.5))(x)

        x = torch.randn(3)
        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
        expected = wrapper_fn(x)
        actual = opt(x)
        self.assertEqual(expected, actual)


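# Tests for torch.utils.checkpoint (activation checkpointing) under
# torch.compile. Several cases patch functionalize_rng_ops so that RNG ops are
# functionalized and checkpointed dropout can be replayed consistently.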
class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
    def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
        cloned_args = []
        for arg in args:
            cloned_args.append(arg.detach().clone().requires_grad_(arg.requires_grad))

        torch.manual_seed(0)
        expected = fn(*args)
        expected.sum().backward()

        opt_fn = torch.compile(fn, fullgraph=fullgraph, backend=backend)
        torch.manual_seed(0)
        result = opt_fn(*cloned_args)
        result.sum().backward()

        if not skip_check:
            self.assertEqual(result, expected)
            for arg, cloned_arg in zip(args, cloned_args):
                self.assertEqual(arg.grad, cloned_arg.grad)

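    # count_ops (used as the fw/bw compiler) pins how many times a given aten
    # op appears in the forward vs. backward graphs produced for the
    # checkpointed region.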
    @requires_cuda_and_triton
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_function(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_function_with_kwargs(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                torch.sin(x),
                y,
                use_reentrant=True,
                preserve_rng_state=False,
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_dropout(self):
        def gn(x, y):
            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
        )
        # philox_rand is passed from fwd
        bw_compiler = functools.partial(
            count_ops, freq=0, op=torch.ops.rngprims.philox_rand.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(
            fn, backend, x, y, skip_check=True
        )  # dropout decomp is known to diverge with eager

    @requires_cuda_and_triton
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_dropout_inductor(self):
        def gn(x, y):
            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        backend = "inductor"
        self._validate(
            fn, backend, x, y, skip_check=True
        )  # dropout decomp is known to diverge with eager

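    # A graph break inside the checkpointed function should make the checkpoint
    # region fall back to eager, leaving only the surrounding sin/cos compiled.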
    @requires_cuda_and_triton
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_fallback(self):
        def gn(x, y):
            torch._dynamo.graph_break()
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.cos(
                torch.utils.checkpoint.checkpoint(
                    gn, torch.sin(x), y, use_reentrant=True
                ),
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        args = (x, y)

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        expected = fn(*args)
        result = torch.compile(fn, backend=cnt)(*args)

        self.assertEqual(result, expected)

        # One graph for torch.sin on the input, and the other for torch.cos.
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(len(backend.graphs), 2)

    @requires_cuda_and_triton
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_module(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return torch.sigmoid(self.linear(x))

        mod = MockModule()

        def fn(x):
            return torch.utils.checkpoint.checkpoint(
                mod, torch.sin(x), use_reentrant=True
            )

        x = torch.randn(10, 10, requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        # sigmoid passed from fwd
        bw_compiler = functools.partial(
            count_ops, freq=0, op=torch.ops.aten.sigmoid.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

    def test_override_fallthrough_dispatch_key(self):
        class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
            def __init__(self):
                super().__init__("_fallthrough_test_only")

            def __call__(self, *args, **kwargs):
                return super().__call__(*args, **kwargs)

        test_op = _FallthroughTestOnly()
        default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
        self.assertTrue(
            not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)
        )

        foos = [lambda x=i: x for i, k in enumerate(default_keys)]
        for foo, fallthrough_key in zip(foos, default_keys):
            test_op.py_impl(fallthrough_key)(foo)

        self.assertTrue(
            all(test_op.non_fallthrough_keys.has(key) for key in default_keys)
        )
        self.assertEqual(
            list(range(len(default_keys))),
            [test_op.py_kernels[key]() for key in default_keys],
        )

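    # cond_op accepts pred/true_fn/false_fn/operands as keyword arguments; both
    # predicate values should be served by the single compiled frame.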
    def test_cond_with_kwargs(self):
        from torch._higher_order_ops.cond import cond_op

        def test(pred, x):
            def true_fn(x):
                return x.clone()

            def false_fn(x):
                return -x

            return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])

        cnt = CompileCounter()
        opt_test = torch.compile(test, backend=cnt, fullgraph=True)
        inp = torch.ones(3, 3)
        true_pred = torch.Tensor([True])
        false_pred = torch.Tensor([False])
        self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
        self.assertEqual(cnt.frame_count, 1)
        self.assertTrue(
            torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
        )
        self.assertEqual(cnt.frame_count, 1)

    def test_cond_with_invalid_kwargs(self):
        from torch._higher_order_ops.cond import cond_op

        def test(pred, mode, x):
            def true_fn(x):
                return x.clone()

            def false_fn(x):
                return -x

            if mode:
                return cond_op(
                    pred=pred,
                    true_fn=true_fn,
                    false_fn=false_fn,
                    operands=[x],
                    invalid=True,
                )
            else:
                return cond_op(
                    pred,
                    pred=pred,
                    true_fn=true_fn,
                    false_fn=false_fn,
                    operands=[x],
                )

        cnt = CompileCounter()
        opt_test = torch.compile(test, backend=cnt)
        inp = torch.ones(3, 3)
        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
            opt_test(True, True, inp)

        with self.assertRaises(AssertionError):
            opt_test(True, False, inp)

    def test_cond_with_mismatched_output(self):
        def output_mismatch_test(x):
            def true_fn():
                return torch.concat([x, x])

            def false_fn():
                return x.sin()

            return torch.cond(x.sum() > 0, true_fn, false_fn)

        x = torch.randn(2, 3)
        output_mismatch_test(x)

        torch.compile(output_mismatch_test, backend="eager")(x)

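    # _assert_tensors_nonaliasing checks that tensors in one pytree do not
    # alias tensors in the other; passing the same structure for both sides
    # should trip the assertion exercised below.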
    def test_non_aliasing_util(self):
        from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing

        a = [torch.tensor(1), {"a": torch.tensor(1)}]
        b = (torch.tensor(1),)
        _assert_tensors_nonaliasing(a, b)

        with self.assertRaisesRegex(
            AssertionError, "inputs to function body cannot alias outputs"
        ):
            _assert_tensors_nonaliasing(a, a)

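    # The next three tests pin the addmm counts FlopCounterMode reports when
    # torch.cond appears in the module: plain, nested, and unbalanced branches.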
    def test_flop_counter_for_cond(self):
        from torch.utils.flop_counter import FlopCounterMode

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return torch.cond(
                    torch.tensor(True),
                    lambda x: self.linear(x),
                    lambda x: self.linear(self.linear(x)),
                    (x,),
                )

        mod = Mod()
        with FlopCounterMode(mod, display=False) as mode:
            mod(torch.randn(4, 4))

        self.assertEqual(
            mode.get_flop_counts(),
            {
                "Global": {torch.ops.aten.addmm: 256},
                "Mod": {torch.ops.aten.addmm: 256},
                "Mod.linear": {torch.ops.aten.addmm: 256},
            },
        )

    def test_flop_counter_for_nested_cond(self):
        from torch.utils.flop_counter import FlopCounterMode

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(4, 4)
                self.linear2 = torch.nn.Linear(4, 4)

            def forward(self, x):
                def true_branch(x):
                    # Nested cond inside true branch
                    return torch.cond(
                        torch.tensor(True),
                        lambda x: self.linear1(x),
                        lambda x: self.linear2(x),
                        (x,),
                    )

                def false_branch(x):
                    return self.linear1(self.linear2(x))

                return torch.cond(torch.tensor(True), true_branch, false_branch, (x,))

        mod = Mod()
        with FlopCounterMode(mod, display=False) as mode:
            mod(torch.randn(4, 4))

        self.assertEqual(
            mode.get_flop_counts(),
            {
                "Global": {torch.ops.aten.addmm: 256},
                "Mod": {torch.ops.aten.addmm: 256},
                "Mod.linear1": {torch.ops.aten.addmm: 128},
                "Mod.linear2": {torch.ops.aten.addmm: 128},
            },
        )

    def test_flop_counter_for_cond_unbalanced_branches(self):
        from torch.utils.flop_counter import FlopCounterMode

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                def true_branch(x):
                    return self.linear(x)

                def false_branch(x):
                    return x.clone()

                return torch.cond(torch.tensor(True), true_branch, false_branch, (x,))

        mod = Mod()
        with FlopCounterMode(mod, display=False) as mode:
            mod(torch.randn(4, 4))

        self.assertEqual(
            mode.get_flop_counts(),
            {
                "Global": {torch.ops.aten.addmm: 128},
                "Mod": {torch.ops.aten.addmm: 128},
                "Mod.linear": {torch.ops.aten.addmm: 128},
            },
        )


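# HOPs from hop_db that are currently expected to fail under torch.compile with
# the listed backend; each entry records the observed failure.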
xfail_hops_compile = {
    # aot_eager
    "map",  # assert type(args[1].realize()) is TensorVariable
    "scan",  # scan is not an OpOverload
    # inductor
    "while_loop",  # LoweringException: AssertionError
    "flex_attention",  # LoweringException: AssertionError
    "flex_attention_backward",  # AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 16
}


class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase):
    @requires_cuda_and_triton
    @parametrize("backend", ("aot_eager", "inductor"))
    @ops(
        list(filter(lambda op: op.name not in xfail_hops_compile, hop_db)),
        allowed_dtypes=(torch.float,),
    )
    def test_hops_compile(self, device, dtype, op, backend):
        # Ensure HOPs can be compiled

        if backend == "aot_eager" and op.name == "invoke_quant":
            raise unittest.SkipTest(
                "TODO: partitioner fails. migrate canonicalization to aot eager backend"
            )

        sample_inputs_itr = op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
        )
        for inp in sample_inputs_itr:
            input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
            eager_args = (*input, *inp.args)
            eager_kwargs = inp.kwargs
            compiled_args = deepcopy(eager_args)
            compiled_kwargs = deepcopy(eager_kwargs)

            def fn(args, kwargs):
                return op.op(*args, **(kwargs))

            compiled_fn = torch.compile(fn, backend=backend, fullgraph=True)

            eager_out = fn(eager_args, eager_kwargs)
            compiled_out = compiled_fn(compiled_args, compiled_kwargs)
            self.assertEqual(eager_out, compiled_out)


instantiate_device_type_tests(TestHigherOrderOpsOpInfo, globals(), only_for=("cuda",))

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()