pytorch/test/dynamo/test_decorators.py

# Owner(s): ["module: dynamo"]
import functools
import operator
import os
import unittest.mock as mock
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.exc import IncorrectUsage, Unsupported
from torch._dynamo.utils import counters
def my_custom_function(x):
return x + 1
class DecoratorTests(torch._dynamo.test_case.TestCase):
def test_disallow_in_graph(self):
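# disallow_in_graph(torch.sub) should force a graph break at the sub call,
# splitting fn into two compiled frames with four ops captured in total.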
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(a):
x = torch.add(a, 1)
x = torch.add(x, 1)
x = torch.sub(x, 1)
x = torch.add(x, 1)
x = torch.add(x, 1)
return x
torch._dynamo.disallow_in_graph(torch.sub)
fn(torch.randn(10))
torch._dynamo.allow_in_graph(torch.sub)
# check for graph break on sub
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 4)
def test_disable_for_custom_op(self):
import torch.library
from torch.library import Library
foo = Library("foo", "DEF") # noqa: TOR901
foo.define("custom(Tensor self) -> Tensor")
# A data-dependent, dynamic-shape operator. For static-shape compilation,
# Dynamo should graph break on it, but the meta kernel is not implemented properly.
@torch.library.impl(foo, "custom", "CPU")
def foo_cpu(x):
return x.nonzero()
# disallow_in_graph does not work here because of the extra Python frames introduced by the torch.library Python API
torch.ops.foo.custom = torch._dynamo.disable(torch.ops.foo.custom)
def fn(x):
a = torch.nn.functional.relu(x)
b = torch.ops.foo.custom(a)
c = torch.cos(b)
return c
x = torch.randint(2, (100,))
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
res = opt_fn(x)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(ref, res)
def test_disable_ignores_outer_wraps(self):
def orig_inner():
pass
def inner():
pass
inner._torchdynamo_orig_callable = orig_inner
@functools.wraps(inner)
def wrapper():
raise AssertionError("wrapper called")
# This behavior is not ideal, but supporting it would add overhead
# to callsites of eval_frame.innermost_fn. A warning would also be very noisy.
torch._dynamo.disable(fn=wrapper, recursive=True)
def test_disable_nn_modules_forward_hook(self):
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)
def forward(self, inp):
return self.layer0(torch.sigmoid(inp))
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)
def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)
def hook(module, args):
inp = args[0].sigmoid()
return (inp,)
model = SimpleModel()
model.layer0.register_forward_pre_hook(hook)
# Disable layer0 by monkeypatching it with a disabled wrapper
model.layer0 = torch._dynamo.disable(model.layer0)
cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
opt_model = torch.compile(model, backend=cnts)
opt_model(torch.randn(4))
# the disabled layer0 call splits compilation into two frames
self.assertEqual(cnts.frame_count, 2)
gm0 = cnts.graphs[0]
# Check that the first graph has sin node, and no sigmoid
self.assertTrue(any(node.target is torch.sin for node in gm0.graph.nodes))
self.assertTrue(
all(node.target is not torch.sigmoid for node in gm0.graph.nodes)
)
gm1 = cnts.graphs[1]
# Check that the second graph does not have sigmoid. sigmoid is used in
# both the hook and the disabled module.
self.assertTrue(
all(node.target is not torch.sigmoid for node in gm1.graph.nodes)
)
def test_disable_nn_module_with_class_decorator(self):
cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
@torch._dynamo.disable
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)
def forward(self, inp):
return self.layer0(torch.sigmoid(inp))
@torch.compile(backend=cnts)
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)
def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)
def hook(module, args):
inp = args[0].sigmoid()
return (inp,)
model = SimpleModel()
model.layer0.register_forward_pre_hook(hook)
model(torch.randn(4))
# the disabled SimpleLinear call splits compilation into two frames
self.assertEqual(cnts.frame_count, 2)
gm0 = cnts.graphs[0]
# Check that the first graph has sin node, and no sigmoid
self.assertTrue(any(node.target is torch.sin for node in gm0.graph.nodes))
self.assertTrue(
all(node.target is not torch.sigmoid for node in gm0.graph.nodes)
)
gm1 = cnts.graphs[1]
# Check that the second graph does not have sigmoid. sigmoid is used in
# both the hook and the disabled module.
self.assertTrue(
all(node.target is not torch.sigmoid for node in gm1.graph.nodes)
)
def test_allow_in_graph(self):
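# allow_in_graph should let my_custom_function be captured as part of the
# graph, so the whole function compiles as a single frame with five ops.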
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(a):
x = torch.add(a, 1)
x = torch.add(x, 1)
x = my_custom_function(x)
x = torch.add(x, 1)
x = torch.add(x, 1)
return x
torch._dynamo.allow_in_graph(my_custom_function)
fn(torch.randn(10))
torch._dynamo.disallow_in_graph(my_custom_function)
# check for no graph break
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 5)
def test_allow_in_graph_no_id_reuse(self):
cnts = torch._dynamo.testing.CompileCounter()
def do_allow_in_graph(x):
return x + 1
torch._dynamo.allow_in_graph(do_allow_in_graph)
del do_allow_in_graph
# `id(dont_allow_in_graph)` would likely match `id(do_allow_in_graph)`.
# We want to make sure Dynamo always traces through `dont_allow_in_graph`,
# which we check via the explicit graph break.
def dont_allow_in_graph(x):
torch._dynamo.graph_break()
return x + 1
@torch.compile(backend=cnts)
def fn(a):
x = torch.add(a, 1)
x = torch.add(x, 1)
x = dont_allow_in_graph(x)
x = torch.add(x, 1)
x = torch.add(x, 1)
return x
fn(torch.randn(10))
# Check for graph break
self.assertEqual(cnts.frame_count, 3)
def test_incorrect_usage_disallow_in_graph(self):
with self.assertRaises(IncorrectUsage):
@torch._dynamo.disallow_in_graph
def fn1(x):
return x.cos()
def test_nonstrict_trace_tensor_args(self):
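# trace_me is traced non-strictly, so the graph_break() inside it must not
# split the graph; fullgraph=True would fail otherwise.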
@torch._dynamo.nonstrict_trace
def trace_me(x, y, z):
torch._dynamo.graph_break()
return x * y + z
def fn(x, y):
t0 = x + 1
t1 = trace_me(x, y, t0)
t2 = t1 + y
return t0 * t2
x, y = torch.randn(10), torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_dict(self):
@torch._dynamo.nonstrict_trace
def trace_me(x, d):
torch._dynamo.graph_break()
return x * d["a"]
def fn(x, d):
t0 = trace_me(x, d)
return t0 + 1
x = torch.randn(10)
d = {"a": 2}
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, d)
res = opt_fn(x, d)
self.assertEqual(ref, res)
def test_nonstrict_trace_newly_constructed_dict_with_side_effects(self):
@torch._dynamo.nonstrict_trace
def trace_me(x, d):
torch._dynamo.graph_break()
return x * d["a"]
def fn(x):
d = {}
d["a"] = 2
t0 = trace_me(x, d)
return t0 + 1
x = torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_dict_with_side_effects(self):
@torch._dynamo.nonstrict_trace
def trace_me(x, d):
torch._dynamo.graph_break()
return x * d["a"]
def fn(x, d):
d["a"] = x + 1
t0 = trace_me(x, d)
return t0 + 2
x = torch.randn(10)
d0 = {"a": 0}
d1 = dict(d0)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, d0)
res = opt_fn(x, d1)
self.assertEqual(ref, res)
self.assertEqual(d0, d1)
def test_nonstrict_trace_pre_existing_custom_class(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
def fn(p):
res = trace_me(p)
return res, p.x, p.y
p = Point(torch.ones(10), torch.ones(1))
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(p)
res = opt_fn(p)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_custom_class_with_side_effects(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
def fn(p):
p.x = p.x + 1
p.y = p.y + 2
res = trace_me(p)
return res, p.x, p.y
p1 = Point(torch.ones(10), torch.ones(1))
p2 = Point(torch.ones(10), torch.ones(1))
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(p1)
res = opt_fn(p2)
self.assertEqual(ref, res)
self.assertEqual(p1.x, p2.x)
self.assertEqual(p1.y, p2.y)
def test_nonstrict_trace_newly_constructed_custom_class_with_side_effects(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
def fn(x, y):
p = Point(x, y)
p.x = p.x + 1
p.y = p.y + 2
res = trace_me(p)
return res, p.x, p.y
x, y = torch.ones(10), torch.ones(1)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_nested_custom_class(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
class PointTensor:
p: Point
t: torch.Tensor
def __init__(self, p, t):
self.p = p
self.t = t
torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.p, pt.t), ()),
lambda pt, _: PointTensor(pt[0], pt[1]),
)
torch.utils._pytree.register_pytree_node(
Point,
lambda p: ((p.x, p.y), ()),
lambda xy, _: Point(xy[0], xy[1]),
)
def trace_point(p):
torch._dynamo.graph_break()
return p.x * p.y
@torch._dynamo.nonstrict_trace
def trace_point_tensor(pt):
torch._dynamo.graph_break()
return pt.t + trace_point(pt.p)
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_point_tensor(pt)
return res
x, y = torch.ones(10), torch.ones(1)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_pre_existing_register_constant_type_guard(self):
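# register_constant inputs are guarded on equality: an equal State(42) reuses
# the compiled graph, while State(41) forces a recompile.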
class State:
def __init__(self, n):
self.n = n
def get_num(self):
torch._dynamo.graph_break()
return self.n
def __eq__(self, other):
return isinstance(other, State) and self.n == other.n
def __hash__(self):
return hash(self.n)
# Assume `State` is implemented in C, and the author didn't bother to
# provide a pytree decomposition for it, and its instances are safe to
# treat as a constant by `torch.compile`.
torch.utils._pytree.register_constant(State)
@torch._dynamo.nonstrict_trace
def trace_me(x, s):
return x * s.get_num()
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
@torch.compile(fullgraph=True, backend=cnts)
def fn(x, s):
res = trace_me(x, s)
return res
x = torch.ones(10)
self.assertEqual(cnts.frame_count, 0)
fn(x, State(42))
self.assertEqual(cnts.frame_count, 1)
# Make sure recompilation didn't happen for an equal State instance.
fn(x, State(42))
self.assertEqual(cnts.frame_count, 1)
# Make sure recompilation did happen.
fn(x, State(41))
self.assertEqual(cnts.frame_count, 2)
def test_nonstrict_trace_int_and_float_output(self):
@torch._dynamo.nonstrict_trace
def trace_me(x):
torch._dynamo.graph_break()
return len(x.shape), 0.42
def fn(x):
n1, n2 = trace_me(x)
return x * n1 + n2
x = torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_tuple_and_sym_int_output(self):
@torch._dynamo.nonstrict_trace
def trace_me(x):
torch._dynamo.graph_break()
return x + 1, x.size(0)
def fn(x):
t0, n = trace_me(x)
return t0 * n
x = torch.randn(10)
opt_fn = torch.compile(fn, dynamic=True, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_inside_compiled_function(self):
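# Applying nonstrict_trace inline inside the compiled region is fine as long
# as the target function itself is defined outside the torch.compile region.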
def trace_me(x):
torch._dynamo.graph_break()
return x + 42
def fn(x):
res = torch._dynamo.nonstrict_trace(trace_me)(x)
return res + 1
x = torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_inside_compiled_function_kwarg(self):
def trace_me(x):
torch._dynamo.graph_break()
return x + 42
def fn(x):
res = torch._dynamo.nonstrict_trace(traceable_fn=trace_me)(x)
return res + 1
x = torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_nonstrict_trace_on_method(self):
class Num:
def __init__(self, n):
self.n = n
@torch._dynamo.nonstrict_trace
def trace_me(self, t):
torch._dynamo.graph_break()
return t + self.n
torch.utils._pytree.register_pytree_node(
Num,
lambda num: ((num.n,), ()),
lambda n, _: Num(n[0]),
)
def fn(x, n):
num = Num(n)
return num.trace_me(x)
x, n = torch.randn(10), 42
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, n)
res = opt_fn(x, n)
self.assertEqual(ref, res)
def test_nonstrict_trace_captured_external_tensor(self):
cst = torch.ones(1)
@torch._dynamo.nonstrict_trace
def trace_me(x, y):
torch._dynamo.graph_break()
return x * y + cst
def fn(x, y):
return trace_me(x, y)
x, y = torch.randn(10), torch.randn(10)
opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager")
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nonstrict_trace_no_action_at_a_distance(self):
def trace_me(x):
torch._dynamo.graph_break()
return x + 42
# No effect on traceability of `trace_me`
torch._dynamo.nonstrict_trace(trace_me)
def fn(x):
res = trace_me(x)
return res + 1
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
opt_fn = torch.compile(fn, backend=cnts)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
# There should be 1 graph break
self.assertEqual(cnts.frame_count, 2)
def test_nonstrict_trace_inside_compiled_function_error(self):
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, y):
def trace_me(x, y):
torch._dynamo.graph_break()
return x * y
res = torch._dynamo.nonstrict_trace(trace_me)(x, y)
return res + 1
try:
fn(torch.ones(10), torch.ones(1))
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
msg = "Applying `nonstrict_trace` to function <trace_me>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # NOQA: B950
self.assertIn(msg, str(e))
def test_nonstrict_trace_custom_class_error(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
@torch._dynamo.nonstrict_trace
def trace_me(p):
torch._dynamo.graph_break()
return p.x * p.y
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(p):
res = trace_me(p)
return res + 1
try:
p = Point(torch.ones(10), torch.ones(1))
fn(p)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e))
def test_nonstrict_trace_nested_custom_class_error(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
class PointTensor:
p: Point
t: torch.Tensor
def __init__(self, p, t):
self.p = p
self.t = t
torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.p, pt.t), ()),
lambda pt, _: PointTensor(pt[0], pt[1]),
)
def trace_point(p):
torch._dynamo.graph_break()
return p.x * p.y
@torch._dynamo.nonstrict_trace
def trace_point_tensor(pt):
torch._dynamo.graph_break()
return pt.t + trace_point(pt.p)
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_point_tensor(pt)
return res
try:
fn(torch.ones(10), torch.ones(1))
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e))
def test_nonstrict_trace_custom_class_output_error(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
@torch._dynamo.nonstrict_trace
def trace_me(x):
torch._dynamo.graph_break()
return Point(x, x + 1)
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x):
p = trace_me(x)
return p.x * p.y
try:
x = torch.ones(10)
fn(x)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
self.assertIn(
"Unsupported output type for nonstrict_trace-ed function", str(e)
)
def test_nonstrict_newly_constructed_trace_register_constant_type_error(self):
class State:
def __init__(self, n):
self.n = n
def get_num(self):
torch._dynamo.graph_break()
return self.n
def __eq__(self, other):
return isinstance(other, State) and self.n == other.n
def __hash__(self):
return hash(self.n)
# Assume `State` is implemented in C, and the author didn't bother to
# provide a pytree decomposition for it, and its instances are safe to
# treat as a constant by `torch.compile`.
torch.utils._pytree.register_constant(State)
@torch._dynamo.nonstrict_trace
def trace_me(x, s):
return x * s.get_num()
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x):
s = State(10)
res = trace_me(x, s)
return res
try:
x = torch.ones(10)
fn(x)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
self.assertIn(
"Input marked with `pytree.register_constant` constructed in the `torch.compile` region",
str(e),
)
def test_nonstrict_trace_object_in_context_error(self):
class Point:
x: torch.Tensor
y: torch.Tensor
def __init__(self, x, y):
self.x = x
self.y = y
class PointTensor:
p: Point
t: torch.Tensor
def __init__(self, p, t):
self.p = p
self.t = t
torch.utils._pytree.register_pytree_node(
PointTensor,
lambda pt: ((pt.t,), pt.p),
lambda ts, p: PointTensor(p, ts[0]),
)
@torch._dynamo.nonstrict_trace
def trace_me(pt):
torch._dynamo.graph_break()
return pt.t + pt.p.x * pt.p.y
@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, y):
p = Point(x, y)
t = x + y
pt = PointTensor(p, t)
res = trace_me(pt)
return res
try:
x, y = torch.ones(10), torch.ones(1)
fn(x, y)
self.assertFalse(True) # must raise error before this
except torch._dynamo.exc.Unsupported as e:
self.assertIn(
"Invalid use of pytree_flatten with nonstrict_trace-ed function", str(e)
)
def test_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x):
x = torch.cos(x)
x = torch.cos(x)
torch._dynamo.graph_break()
x = torch.cos(x)
x = torch.cos(x)
torch._dynamo.graph_break()
x = torch.cos(x)
x = torch.cos(x)
return x
fn(torch.randn(4, 5))
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 6)
def test_skip_frame(self):
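# skip_frame() makes Dynamo run the current frame eagerly: fn compiles
# nothing, and for gn only the portion before the explicit graph break is compiled.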
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x):
x = x + 1
torch._dynamo.skip_frame()
return x + 1
inp = torch.ones(3, 3)
self.assertEqual(fn(inp), inp + 2)
self.assertEqual(cnts.frame_count, 0)
@torch.compile(backend=cnts)
def gn(x):
x = x + 1
torch._dynamo.graph_break()
x = x + 1
torch._dynamo.skip_frame()
return x + 1
self.assertEqual(gn(inp), inp + 3)
self.assertEqual(cnts.frame_count, 1)
def test_disable_recursive_false(self):
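# disable(recursive=False) skips compiling fn1 itself, but functions it calls
# (fn2) can still be compiled; it also must not mutate the original function.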
def fn2(x):
return x + 1
@torch._dynamo.disable(recursive=False)
def fn1(x):
if torch.compiler.is_compiling():
raise RuntimeError("bad")
x = x.sigmoid()
return fn2(x.cos())
def fn(x):
return fn1(x.tan())
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
opt_fn(torch.randn(4))
self.assertEqual(cnts.frame_count, 2)
# test that applying disable nonrecursive doesn't modify the original function
def fn3(x):
if torch.compiler.is_compiling():
return x - 1
return fn2(x) + 2
@torch.compile(backend=cnts)
def outer(f, x):
return f(x)
inp = torch.ones(3)
fn3_disabled = torch._dynamo.disable(fn3, recursive=False)
torch._dynamo.reset()
cnts.clear()
res = outer(fn3, inp)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(res, inp - 1)
cnts.clear()
res = outer(fn3_disabled, inp)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(res, inp + 3)
torch._dynamo.reset()
cnts.clear()
res = outer(fn3_disabled, inp)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(res, inp + 3)
cnts.clear()
res = outer(fn3, inp)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(res, inp - 1)
# directly compiling a disabled function should result in a compile
torch._dynamo.reset()
cnts.clear()
res = torch.compile(fn3_disabled, backend=cnts)(inp)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(res, inp - 1)
def test_disable_recursive_false_weird(self):
from torch._dynamo.types import FrameAction, FrameExecStrategy
# test the case where the next invocation of the function is
# manually skipped
def fn(x):
if torch.compiler.is_compiling():
return x - 1
return x + 1
fn_disabled = torch._dynamo.disable(fn, recursive=False)
torch._dynamo.eval_frame.set_code_exec_strategy(
fn.__code__, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
)
@torch.compile(backend="eager")
def outer(fn, x):
return fn(x)
inp = torch.ones(3)
self.assertEqual(outer(fn_disabled, inp), inp + 1)
torch._dynamo.eval_frame.set_code_exec_strategy(
fn.__code__, FrameExecStrategy(FrameAction.DEFAULT, FrameAction.DEFAULT)
)
self.assertEqual(torch.compile(fn, backend="eager")(inp), inp - 1)
def test_substitute_in_graph(self):
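# substitute_in_graph registers a Python polyfill for a C function
# (operator.indexOf) so it can be traced without a graph break; the polyfill's
# signature must match the original.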
counters.clear()
# NB: Choose another C function for this test once operator.indexOf is
# supported out of the box
cnts = torch._dynamo.testing.CompileCounter()
fn = operator.indexOf
opt_fn = torch.compile(fn, backend=cnts)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
self.assertEqual(cnts.frame_count, 0)
self.assertEqual(len(counters["graph_break"]), 1)
torch._dynamo.reset()
counters.clear()
with self.assertRaisesRegex(TypeError, "Signature mismatch"):
@torch._dynamo.substitute_in_graph(operator.indexOf)
def _(sequence, x):
for i, item in enumerate(sequence):
if item is x or item == x:
return i
raise ValueError("sequence.index(x): x not in sequence")
@torch._dynamo.substitute_in_graph(operator.indexOf)
def polyfill(a, b):
for i, item in enumerate(a):
if item is b or item == b:
return i
raise ValueError("sequence.index(x): x not in sequence")
cnts = torch._dynamo.testing.CompileCounter()
fn = operator.indexOf
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
self.assertEqual(cnts.frame_count, 0)
self.assertEqual(len(counters["graph_break"]), 0)
torch._dynamo.reset()
counters.clear()
cnts = torch._dynamo.testing.CompileCounter()
fn = polyfill
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
self.assertEqual(cnts.frame_count, 0)
self.assertEqual(len(counters["graph_break"]), 0)
@patch.object(torch._dynamo.config, "suppress_errors", True)
def test_nested_disable_decorator(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.disable()
def fn1(x):
return torch.sin(x) * 10
@torch.compile(backend=cnts)
def fn2(x):
x = x + 1
x = x + 1
x = fn1(x) # graph break
x = x + 1
x = x + 1
return x
@torch.compile(backend=cnts, fullgraph=True)
def fn3(x):
return fn2(x)
fn2(torch.randn(4, 5))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 4)
cnts.clear()
torch._dynamo.reset()
fn3(torch.randn(4, 5))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 4)
def test_disable_optimize(self):
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt, disable=True)
def f1(x):
return x + 1
f1(torch.ones(6))
self.assertEqual(cnt.frame_count, 0)
@torch.compile(backend=cnt, disable=True)
def f2(x):
return x + 1
f2(torch.ones(6))
self.assertEqual(cnt.frame_count, 0)
with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}):
@torch.compile(backend=cnt)
def f3(x):
return x + 1
f3(torch.ones(6))
self.assertEqual(cnt.frame_count, 0)
def test_torch_guards_stack_frame_register_inlining_disable(self):
x = torch.tensor([0.5, 0.5])
class encoder(torch.nn.Module):
def __init__(self, y):
super().__init__()
self.a = y
@torch._dynamo.disable
def helper(self, x, y):
return x * y
def forward(self, a, *args):
x = a + a
return self.helper(x, self.a)
e = encoder(2.0)
seen_frames = []
import contextlib
@contextlib.contextmanager
def global_context_capture_fn(frame_summary):
if frame_summary is not None:
seen_frames.append(frame_summary)
yield
with mock.patch(
"torch._guards.TracingContext.current_frame",
side_effect=global_context_capture_fn,
):
torch.compile(e, backend="eager")(x)
self.assertEqual(len(seen_frames), 0)
def test_torch_guards_stack_frame_register_inlining_partially_disable(self):
y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
x = torch.tensor([0.5, 0.5])
class encoder(torch.nn.Module):
def __init__(self, y):
super().__init__()
self.register_parameter("param", y)
@torch._dynamo.disable
def helper_disabled(self, x, y):
return x.sin() * y.cos()
def helper(self, x, y):
return x * y
def forward(self, a, *args):
x = a + a
return self.helper(x, self.param) + self.helper_disabled(x, self.param)
e = encoder(y)
cnt = torch._dynamo.testing.CompileCounter()
torch.compile(e, backend=cnt)(x)
# first frame is before disable, second frame is after disable
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 3)
def _test_mark_static_address(self, guarded):
# This test verifies that dynamo properly marks inputs as static
# when using the mark_static_address API.
# For both inline_inbuilt_nn_modules True and False, we expect the
# tensor to be present in the buffers attribute of the graph.
compiles_with_buffers = 0
compiles = 0
def debug_compiler(gm, _):
nonlocal compiles_with_buffers
nonlocal compiles
compiles_with_buffers += len(gm._buffers) > 0
compiles += 1
return gm
@torch.compile(backend=debug_compiler)
def fn(x):
return x + 1
inp = torch.ones(2)
torch._dynamo.mark_static_address(inp, guard=guarded)
fn(inp)
if guarded:
self.assertEqual(compiles_with_buffers, 1)
inp2 = torch.ones(2)
# If guarded, this should trigger another recompile. Since inp2 was not
# marked static, compiles_with_buffers should not be incremented.
fn(inp2)
if guarded:
self.assertEqual(compiles_with_buffers, 1)
self.assertEqual(compiles, 2 if guarded else 1)
def test_mark_static_address_guarded(self):
with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True):
self._test_mark_static_address(guarded=True)
self._test_mark_static_address(guarded=True)
def test_mark_static_address_unguarded(self):
with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True):
self._test_mark_static_address(guarded=False)
self._test_mark_static_address(guarded=False)
def test_class_methods(self):
class A:
@classmethod
def my_class_method(cls, arg1):
return cls, arg1
@staticmethod
def my_static_method(arg1):
return None, arg1
def my_regular_method(self, arg1):
return self, arg1
class B(A):
def my_class_method(self, arg1):
return super().my_class_method(arg1)
def my_static_method(self, arg1):
return super().my_static_method(arg1)
class C(A):
@classmethod
def my_class_method(cls, arg1):
return super().my_class_method(arg1)
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt)
def fn(a, b, c):
# We want a function that does not graph break but
# does generate custom bytecode
v1 = a.my_class_method(1)
v2 = A.my_class_method(2)
v3 = a.my_static_method(3)
v4 = A.my_static_method(4)
v5 = a.my_regular_method(5)
v6 = b.my_class_method(6)
v7 = b.my_static_method(7)
v8 = c.my_class_method(8)
v9 = C.my_class_method(9)
torch.rand(2)
return v1, v2, v3, v4, v5, v6, v7, v8, v9
a, b, c = A(), B(), C()
v1, v2, v3, v4, v5, _, v7, v8, v9 = fn(a, b, c)
self.assertEqual(v1, (A, 1))
self.assertEqual(v2, (A, 2))
self.assertEqual(v3, (None, 3))
self.assertEqual(v4, (None, 4))
self.assertEqual(v5, (a, 5))
# TODO fix me: we do not resolve classmethods properly
# from a regular method
# self.assertEqual(v6, (B, 6))
self.assertEqual(v7, (None, 7))
self.assertEqual(v8, (C, 8))
self.assertEqual(v9, (C, 9))
self.assertEqual(cnt.frame_count, 1)
def test_assume_constant_result_on_user_defined_fn(self):
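# assume_constant_result makes Dynamo treat const_fn's result as a constant
# instead of tracing through the call.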
@torch._dynamo.assume_constant_result
def const_fn(n, s):
return torch.full([n], s)
def fn(B):
B = const_fn(B.size(0), 13)
X = B * 2
return X.tolist()
B_list = [8] * 32
B = torch.tensor(B_list, dtype=torch.int32)
torch._dynamo.decorators.mark_static(B, 0)
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
self.assertEqual(
fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)
)
def test_assume_constant_result_on_computation_with_graph_input(self):
@torch._dynamo.assume_constant_result
def check(y):
return y[0].item() == 1
def fn(x, y):
if check(y):
return x + 2
else:
return x + 1
y = torch.tensor([1])
x = torch.tensor(1)
self.assertEqual(fn(x, y), torch.compile(fn)(x, y))
def test_set_stance_aot_eager_then_compile(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x, y, z):
return x * y * z[0]
with torch.compiler.set_stance("aot_eager_then_compile"):
fn(2, torch.randn(2), {0: torch.randn(2)})
fn(3, torch.randn(3), {0: torch.randn(3)})
fn(4, torch.randn(4), {0: torch.randn(4)})
# Would have been 4 without stance
self.assertEqual(cnts.op_count, 2)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_mark_static_nn_module(self):
@torch._dynamo.mark_static
class Mock(torch.nn.Module):
def __init__(self, c):
super().__init__()
self.c = c
def forward(self, x):
return x * self.c
cnts = torch._dynamo.testing.CompileCounter()
mod1 = Mock(10)
mod2 = Mock(20)
mod3 = Mock(30)
opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True)
opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True)
opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True)
x = torch.randn(4, 4)
opt_mod1(x)
opt_mod2(x)
opt_mod3(x)
# Must be 3 compilations. If not marked static there would be 2, because self.c would be converted to symints.
self.assertEqual(cnts.frame_count, 3)
def test_set_stance_eager_then_compile(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x, y, z):
return x * y * z[0]
with torch.compiler.set_stance("eager_then_compile"):
fn(1, torch.randn(1), {0: torch.randn(1)})
fn(2, torch.randn(2), {0: torch.randn(2)})
fn(3, torch.randn(3), {0: torch.randn(3)})
self.assertEqual(cnts.frame_count, 1)
def test_set_stance_eager_then_compile_with_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x, y, z):
y = torch.sin(y)
torch._dynamo.graph_break()
y = torch.cos(y)
return x * y * z[0]
with torch.compiler.set_stance("eager_then_compile"):
fn(1, torch.randn(1), {0: torch.randn(1)})
fn(2, torch.randn(2), {0: torch.randn(2)})
fn(3, torch.randn(3), {0: torch.randn(3)})
# frame count 2 since we added a graph break
self.assertEqual(cnts.frame_count, 2)
def test_set_stance_force_eager(self):
@torch.compile(backend="eager")
def a(x):
if torch._dynamo.is_compiling():
return x + 1
return x + 2
@torch.compiler.set_stance("force_eager")
def b(x):
return a(x)
def c(x):
out0 = a(x)
with torch.compiler.set_stance("force_eager"):
out1 = a(x)
return out0, out1, a(x)
inp = torch.ones(3)
# test that decorating b has no overall side effect
self.assertEqual(a(inp), inp + 1)
self.assertEqual(b(inp), inp + 2)
self.assertEqual(c(inp), (inp + 1, inp + 2, inp + 1))
torch.compiler.set_stance("force_eager")
self.assertEqual(a(inp), inp + 2)
torch.compiler.set_stance("default")
self.assertEqual(a(inp), inp + 1)
def test_set_stance_eager_on_recompile(self):
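# eager_on_recompile: calls that hit an existing compiled graph still use it
# (out2), while calls that would require a recompile run eagerly instead (out3).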
@torch.compile(backend="eager", dynamic=False)
def a(x, n):
if torch._dynamo.is_compiling():
return x + n + 1
return x + n + 2
inp = torch.ones(3)
out1 = a(inp, 1)
with torch.compiler.set_stance("eager_on_recompile"):
out2 = a(inp, 1)
out3 = a(inp, 2)
self.assertEqual(out1, inp + 2)
self.assertEqual(out2, inp + 2)
self.assertEqual(out3, inp + 4)
def test_set_stance_fail_on_recompile(self):
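# fail_on_recompile: reusing the existing compiled graph is allowed, but any
# call that would trigger a recompile raises a RuntimeError.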
@torch.compile(backend="eager", dynamic=False)
def a(x, n):
if torch._dynamo.is_compiling():
return x + n + 1
return x + n + 2
inp = torch.ones(3)
out1 = a(inp, 1)
with torch.compiler.set_stance("fail_on_recompile"):
out2 = a(inp, 1)
with self.assertRaisesRegex(RuntimeError, "fail_on_recompile"):
a(inp, 2)
self.assertEqual(out1, inp + 2)
self.assertEqual(out2, inp + 2)
def test_set_stance_fail_on_recompile_with_disable(self):
@torch.compiler.disable
def inner(x):
return x
@torch.compile(backend="eager")
def f(x):
return inner(x)
f(torch.randn(3, 3))
# should not raise error
with torch.compiler.set_stance("fail_on_recompile"):
f(torch.randn(3, 3))
def test_set_stance_forbid_in_graph(self):
@torch.compiler.set_stance("force_eager")
def a(x):
return x + 1
@torch.compile(backend="eager")
def b(x):
return a(x)
with self.assertRaisesRegex(
AssertionError, "Attempt to trace forbidden callable"
):
b(torch.ones(3))
@torch.compile(backend="eager")
def c(x):
with torch.compiler.set_stance("force_eager"):
return x + 1
with self.assertRaisesRegex(
AssertionError, "Attempt to trace forbidden callable"
):
c(torch.ones(3))
@torch.compile(backend="eager")
@torch.compiler.set_stance("force_eager")
def d(x):
return x + 1
with self.assertRaisesRegex(
AssertionError, "Attempt to trace forbidden callable"
):
d(torch.ones(3))
@torch.compile(backend="eager")
def e(x):
with torch._dynamo.set_stance("force_eager"):
return x + 1
with self.assertRaisesRegex(
AssertionError, "Attempt to trace forbidden callable"
):
e(torch.ones(3))
@torch.compile(backend="eager")
def f(x):
torch._dynamo.eval_frame._set_stance("force_eager")
return x + 1
with self.assertRaisesRegex(
AssertionError, "Attempt to trace forbidden callable"
):
f(torch.ones(3))
@torch.compile(backend="eager")
def g(x):
torch._dynamo.skip_frame()
# NOTE: torch._dynamo.is_compiling() will get traced
# and return true. torch.compiler.is_compiling() is skipped
# and will return false.
if torch.compiler.is_compiling():
raise RuntimeError("Expect this frame to be skipped")
# should not be traced, but eval frame callback is still set
with torch.compiler.set_stance("force_eager"):
return x + 1
with self.assertRaisesRegex(RuntimeError, "set_stance in a torch.compile"):
g(torch.ones(3))
def test_set_stance_force_backend(self):
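# force_backend overrides the backend picked by torch.compile; combining it
# with force_eager is rejected.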
@torch.compile
def a(x):
return x + 1
cnts = torch._dynamo.testing.CompileCounter()
@torch.compiler.set_stance("default", force_backend=cnts)
def b(x):
return a(x)
b(torch.ones(3))
self.assertEqual(cnts.frame_count, 1)
@torch.compiler.set_stance("default", force_backend="eager")
def c(x):
return a(x)
# just make sure this doesn't crash
c(torch.ones(3))
with self.assertRaisesRegex(RuntimeError, "force_backend"):
@torch.compiler.set_stance("force_eager", force_backend="eager")
def d(x):
pass
def test_set_stance_force_backend_with_disable(self):
@torch.compiler.disable
def inner(x):
return x
@torch.compile(backend="eager")
def f(x):
return inner(x)
f(torch.randn(3, 3))
def fail_backend(gm, ex):
raise RuntimeError("fail!")
# should not raise error
with torch.compiler.set_stance("default", force_backend=fail_backend):
f(torch.randn(3, 3))
# also tests a lot of torch._dynamo.patch_dynamo_config functionality
def test_dont_skip_tracing(self):
from torch._dynamo.test_dont_skip_tracing_functions import f1, f3, f4, f5, f6
cnts = torch._dynamo.testing.CompileCounter()
# make sure test_dont_skip_tracing_functions is actually skipped by trace rules
torch.compile(f1, backend=cnts)(torch.randn(3))
self.assertEqual(cnts.frame_count, 0)
f1_unskip = torch._dynamo.dont_skip_tracing(f1)
# basic test
def g1(x):
return f1_unskip(x)
cnts.clear()
torch.compile(g1, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test that dont_skip_tracing is traceable
def g2(x):
return torch._dynamo.dont_skip_tracing(f1)(x)
cnts.clear()
torch.compile(g2, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test that dont_skip_tracing is recursive, applied to non-skipped function
@torch._dynamo.dont_skip_tracing
def g3(x):
return f1(x)
cnts.clear()
torch.compile(g3, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test that dont_skip_tracing is recursive, applied to skipped function
f3_unskip = torch._dynamo.dont_skip_tracing(f3)
cnts.clear()
torch.compile(f3_unskip, backend=cnts, fullgraph=True)(torch.randn(3))
self.assertEqual(cnts.frame_count, 1)
# test dont_skip_tracing with graph breaks
inp = torch.ones(3)
res = torch.compile(f4, backend=cnts)(inp)
self.assertEqual(res, inp + 6)
@torch.compile(backend=cnts)
def g4(x):
x = f5(x, 1)
x = torch._dynamo.dont_skip_tracing(f6)(x)
x = f5(x, 8)
return x
res = g4(inp)
self.assertEqual(res, inp + 6)
# test nested dont_skip_tracing
# this also happens to test if a previously skipped frame (f4)
# can actually be compiled if called as a top-level function (in the case of a graph break)
# TODO the reset is necessary for now since attempting to trace f4 previously
# resulted in an unconditional skip
torch._dynamo.reset()
f4_unskip = torch._dynamo.dont_skip_tracing(f4)
res = torch.compile(f4_unskip, backend=cnts)(inp)
self.assertEqual(res, inp + 15)
# test dont_skip_tracing that is activated outside torch.compile
f4_unskip2 = torch._dynamo.dont_skip_tracing(torch.compile(f4, backend=cnts))
res = f4_unskip2(inp)
self.assertEqual(res, inp + 15)
# test context manager from inside
@torch.compile(backend=cnts)
def g5(x):
x = f5(x, 1)
with torch._dynamo.dont_skip_tracing():
x = f5(x, 2)
torch._dynamo.graph_break()
x = f5(x, 4)
x = f5(x, 8)
return x
res = g5(inp)
self.assertEqual(res, inp + 6)
# test context manager from outside
with torch._dynamo.dont_skip_tracing():
res = torch.compile(f4, backend=cnts)(inp)
self.assertEqual(res, inp + 15)
# test skipped function from different dont_skip_tracing regions
@torch.compile(backend=cnts)
def g6(x):
fn1 = f5
with torch._dynamo.dont_skip_tracing():
fn2 = f5
x = fn1(x, 1)
x = fn2(x, 2)
return x
res = g6(inp)
self.assertEqual(res, inp + 1)
def test_patch_dynamo_config_errors(self):
@torch.compile(backend="eager")
def f1(x):
with torch._dynamo.patch_dynamo_config(nonexistent=False):
return x + 1
with self.assertRaisesRegex(Exception, "patch_dynamo_config does not support"):
f1(torch.randn(3))
@torch.compile(backend="eager")
def f2(x):
with torch._dynamo.patch_dynamo_config("verbose", {"a": 1}):
return x + 1
with self.assertRaisesRegex(
Exception, "patch_dynamo_config does not support .* with non-safe-constant"
):
f2(torch.randn(3))
@torch.compile(backend="eager")
def f3(x):
with torch._dynamo.patch_dynamo_config({"recompile_limit": 1}):
return x + 1
with self.assertRaisesRegex(Exception, "patch_dynamo_config does not support"):
f3(torch.randn(3))
@torch.compile(backend="eager")
def f4(x):
with torch._dynamo.patch_dynamo_config(verbose=object()):
return x + 1
with self.assertRaisesRegex(
Exception, "Cannot convert patch_dynamo_config args/kwargs to constants."
):
f4(torch.randn(3))
def test_set_fullgraph(self):
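# set_fullgraph(False) locally allows graph breaks under fullgraph=True, while
# set_fullgraph(True) makes a graph break raise even under fullgraph=False.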
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
def f1(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
torch._dynamo.graph_break()
return x + 2
inp = torch.ones(3)
self.assertEqual(f1(inp), inp + 3)
self.assertEqual(cnts.frame_count, 2)
@torch.compile(backend=cnts)
def f2(x):
x = x + 1
with torch._dynamo.set_fullgraph(True):
torch._dynamo.graph_break()
return x + 2
with self.assertRaises(Unsupported):
f2(inp)
@torch.compile(backend=cnts, fullgraph=True)
def f3(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
torch._dynamo.graph_break()
x = x + 2
torch._dynamo.graph_break()
return x + 4
cnts.clear()
self.assertEqual(f3(inp), inp + 7)
self.assertEqual(cnts.frame_count, 3)
def inner_f4(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
def f4(x):
x = x + 1
with torch._dynamo.set_fullgraph(False):
torch._dynamo.skip_frame()
return inner_f4(x)
cnts.clear()
self.assertEqual(f4(inp), inp + 7)
self.assertEqual(cnts.frame_count, 2)
def test_set_fullgraph_nested(self):
# set_fullgraph in a nested frame
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
def inner_f5(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
def f5(x):
x = x + 1
return inner_f5(x)
inp = torch.ones(3)
self.assertEqual(f5(inp), inp + 7)
self.assertEqual(cnts.frame_count, 4)
def inner_f6(x):
x = x + 2
with torch._dynamo.set_fullgraph(False):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
def f6(x):
x = x + 1
return inner_f6(x)
cnts.clear()
self.assertEqual(f6(inp), inp + 7)
self.assertEqual(cnts.frame_count, 3)
def inner_f7(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=False)
def f7(x):
x = x + 1
return inner_f7(x)
with self.assertRaises(Unsupported):
f7(inp)
def test_set_fullgraph_nested_with_skip(self):
# set_fullgraph in a nested frame with a skipped frame in between
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.set_fullgraph(False)
def inner2_f8(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
def inner1_f8(x):
with torch._dynamo.set_fullgraph(False):
torch._dynamo.skip_frame()
return inner2_f8(x)
@torch.compile(backend=cnts, fullgraph=True)
def f8(x):
x = x + 1
return inner1_f8(x)
inp = torch.ones(3)
self.assertEqual(f8(inp), inp + 7)
self.assertEqual(cnts.frame_count, 4)
def inner2_f9(x):
x = x + 2
with torch._dynamo.set_fullgraph(True):
torch._dynamo.graph_break()
return x + 4
@torch._dynamo.disable(recursive=False)
def inner1_f9(x):
return inner2_f9(x)
@torch.compile(backend=cnts, fullgraph=False)
def f9(x):
x = x + 1
return inner1_f9(x)
with self.assertRaises(Unsupported):
f9(inp)
# test export with set_fullgraph(False) still errors
def test_set_fullgraph_export(self):
@torch._dynamo.set_fullgraph(False)
def inner(x):
x = x + 2
torch._dynamo.graph_break()
return x + 4
def f(x):
x = x + 1
return inner(x)
with self.assertRaises(Unsupported):
torch._dynamo.export(f)(torch.ones(3))
def test_set_fullgraph_nested_deep(self):
cnts = torch._dynamo.testing.CompileCounter()
def inner1_f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def inner2_f1(x):
return inner1_f1(x)
def inner3_f1(x):
with torch._dynamo.set_fullgraph(False):
return inner2_f1(x)
def inner4_f1(x):
return inner3_f1(x)
@torch.compile(backend=cnts, fullgraph=True)
def f1(x):
x = x + 4
return inner4_f1(x)
inp = torch.ones(3)
self.assertEqual(f1(inp), inp + 7)
self.assertEqual(cnts.frame_count, 4)
def inner1_f2(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def inner2_f2(x):
return inner1_f2(x)
def inner3_f2(x):
with torch._dynamo.set_fullgraph(True):
return inner2_f2(x)
def inner4_f2(x):
return inner3_f2(x)
@torch.compile(backend=cnts, fullgraph=False)
def f2(x):
x = x + 4
return inner4_f2(x)
with self.assertRaises(Unsupported):
f2(inp)
def test_set_fullgraph_error(self):
@torch.compile(backend="eager")
def f1():
with torch._dynamo.set_fullgraph(foo="bar"):
pass
@torch.compile(backend="eager")
def f2():
with torch._dynamo.set_fullgraph():
pass
@torch.compile(backend="eager")
def f3():
with torch._dynamo.set_fullgraph("foo"):
pass
with self.assertRaises(Exception):
f1()
with self.assertRaises(Exception):
f2()
with self.assertRaises(Exception):
f3()
def test_nested_compile_fullgraph(self):
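# A nested torch.compile call keeps its own fullgraph setting: an inner
# fullgraph=True still errors on a graph break, and an inner fullgraph=False
# still tolerates one, regardless of the outer setting.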
inp = torch.ones(3)
@torch.compile(backend="eager", fullgraph=True)
def inner_f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(backend="eager", fullgraph=False)
def f1(x):
return inner_f1(x)
with self.assertRaises(Unsupported):
f1(inp)
@torch.compile(backend="eager", fullgraph=False)
def inner_f2(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(backend="eager", fullgraph=True)
def f2(x):
return inner_f2(x)
self.assertEqual(f2(inp), inp + 3)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()