Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
In-place mutation can create inter-iteration dependencies that break the parallelism associative_scan relies on, so input mutations are banned.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158883
Approved by: https://github.com/zou3519
ghstack dependencies: #154193, #158965, #158863, #158864
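As an illustration of that restriction (this snippet is not part of the test file below): a minimal sketch, assuming the call patterns of associative_scan and _fake_associative_scan that the file itself imports; how exactly a mutating combine function gets rejected is an assumption based on the commit message above, so the offending call is only shown in a comment.

import torch
from torch._higher_order_ops.associative_scan import (
    _fake_associative_scan,
    associative_scan,
)


def add(x, y):
    # Pure, associative combine function: it only reads its inputs, so the
    # per-slice combines can be reordered and evaluated in parallel.
    return x + y


def add_in_place(x, y):
    # Writes into one of its inputs. Per the note above, such combine
    # functions are banned: the in-place update can make one loop iteration
    # depend on another and defeat the parallel scan.
    return x.add_(y)


xs = torch.randn(8, 4)

# Reference semantics via the eager helper used throughout the tests below:
# scan `add` across dim 0 of xs.
expected = _fake_associative_scan(add, xs, 0)

# Assumed call pattern for the real HOP with a pure combine function:
# out = associative_scan(add, xs, 0)
#
# Swapping `add_in_place` in for `add` is expected to be rejected because it
# mutates an input.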
8922 lines, 339 KiB, Python
# Owner(s): ["module: functorch"]
import contextlib
import functools
import unittest

import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._higher_order_ops.associative_scan import (
    _fake_associative_scan,
    associative_scan,
)
from torch._higher_order_ops.map import _fake_map
from torch._higher_order_ops.scan import _fake_scan, scan
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.while_loop import while_loop
from torch._subclasses.functional_tensor import (
    CppFunctionalizeAPI,
    FunctionalTensor,
    FunctionalTensorMode,
    PythonFunctionalizeAPI,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    decorateIf,
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    requires_cuda,
    run_tests,
    skipIfCrossRef,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)


# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
    if isinstance(t, torch.Tensor):
        return FunctionalTensor.to_functional(t)
    return t


def from_fun(t):
    if not isinstance(t, FunctionalTensor):
        # quick sanity assert
        if isinstance(t, torch.Tensor):
            assert not torch._is_functional_tensor(t)
        return t
    torch._sync(t)
    return torch._from_functional_tensor(t.elem)


def to_fun_old(t):
    if isinstance(t, torch.Tensor) and not torch._is_functional_tensor(t):
        out = torch._to_functional_tensor(t)
        torch._mirror_autograd_meta_to(t, out)
        return out
    return t


def from_fun_old(t):
    # quick sanity assert
    if isinstance(t, torch.Tensor):
        assert torch._is_functional_tensor(t)
        torch._sync(t)
        return torch._from_functional_tensor(t)
    return t


def _fake_while_loop(cond_fn, body_fn, operands):
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands


def compile_mode_helper(fct, compile_mode):
    if compile_mode == "compile":
        return torch.compile(fct, fullgraph=True, dynamic=False)
    elif compile_mode == "compile_dynamic_shape":
        return torch.compile(fct, fullgraph=True, dynamic=True)
    elif compile_mode == "eager":
        return torch.compile(fct, fullgraph=True, backend="eager")
    else:
        return fct


ALIAS_FN = [
    lambda x: x,
    lambda x: x.view(-1),
    lambda x: x.reshape(-1),
    lambda x: x.squeeze(0),
    lambda x: x.unsqueeze(0),
    lambda x: x.transpose(0, 1),
    lambda x: x.flatten(),
    lambda x: x.expand(1, *x.size()),
]


def get_scan_combine_fn(name, associative=True, parameters=None):
    def add(x: torch.Tensor, y: torch.Tensor):
        return x + y

    def adds(x: torch.Tensor, y: torch.Tensor):
        return x + x, y + y

    def mul(x: torch.Tensor, y: torch.Tensor):
        return x * y

    def div(x: torch.Tensor, y: torch.Tensor):
        return x / y

    def s5_operator(x: torch.Tensor, y: torch.Tensor):
        A_i, Bu_i = x
        A_j, Bu_j = y
        return A_j * A_i, A_j * Bu_i + Bu_j

    def different_input_size_operator(x: torch.Tensor, y: torch.Tensor):
        x_o, dA_o, dB_o, C_o, y_o = x
        x_n, dA_n, dB_n, C_n, y_n = y

        x_new = x_n + x_o
        y_new = torch.einsum("bdn,bn->bd", x_new, C_n)

        return x_new, dA_n + 0.0, dB_n + 0.0, C_n + 0.0, y_new

    def tuple_fct(x, y):
        return (x[0] + y[0], x[1] * y[1])

    def complex_pointwise(x, y):
        return {
            "i": x["i"] * y["i"],
            "j": (
                [x["j"][0][0] * y["j"][0][0]],
                [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
            ),
        }

    def non_pointwise(x: torch.Tensor, y: torch.Tensor):
        W = torch.diag(torch.ones(2, device=x.device))
        return x @ W + y @ W

    def RNN(x: torch.Tensor, y: torch.Tensor):
        c_new = y @ parameters[0] + parameters[1]
        h_new = torch.tanh(c_new + x @ parameters[2] + parameters[3])
        return h_new, h_new.clone()

    def fct_c1_no_grad(x: torch.Tensor, y: torch.Tensor):
        h_new = torch.tanh(x[0] + x[1] + y)
        c2 = x[1] + y
        with torch.no_grad():
            c1 = x[0] + y
        return (c1, c2), h_new

    if name == "add":
        fct = add
    elif name == "adds":
        fct = adds
    elif name == "mul":
        fct = mul
    elif name == "div":
        fct = div
    elif name == "s5_operator":
        fct = s5_operator
    elif name == "different_input_size_operator":
        fct = different_input_size_operator
    elif name == "tuple_fct":
        fct = tuple_fct
    elif name == "complex_pointwise":
        fct = complex_pointwise
    elif name == "non_pointwise":
        fct = non_pointwise
    elif name == "RNN":
        fct = RNN
    elif name == "fct_c1_no_grad":
        fct = fct_c1_no_grad
    else:
        raise ValueError("Combine_fn name unknown!")

    if not associative:
        return lambda x, y: (fct(x, y), fct(x, y))
    else:
        return fct


def _while_loop_tests():
    def simple(x):
        def cond_fn(x):
            return x.sum() < 10

        def body_fn(x):
            return (x + 1,)

        return while_loop(cond_fn, body_fn, (x,))

    def simple_with_mutation(x):
        def cond_fn(x):
            y = x.clone().add_(1).add_(-1)
            return y.sum() < 10

        def body_fn(x):
            y = x.clone().add_(1).add_(-1)
            return (y + 1,)

        return while_loop(cond_fn, body_fn, (x,))

    def nested(out_iter, it, y):
        def cond_fn(out_iter, it, y):
            return it.sum() < 10

        def body_fn(out_iter, it, y):
            return (out_iter.clone(), it + y, y + 1)

        def outer_cond_fn(out_iter, it, y):
            return out_iter.sum() < 2

        def outer_body_fn(out_iter, it, y):
            out_iter, it, y = while_loop(cond_fn, body_fn, (out_iter, it, y))
            return (out_iter + 1, it, y)

        return while_loop(outer_cond_fn, outer_body_fn, (out_iter, it, y))

    class Nested(torch.nn.Module):
        def forward(self, ci, cj, a, b):
            def cond_fn(i1, j1, x1, y1):
                return i1 > 0

            def body_fn(i1, j1, x1, y1):
                def cond_fn_nested(i2, j2, x2, y2):
                    return j2 > 0

                def body_fn_nested(i2, j2, x2, y2):
                    return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71

                i1, j1, x1, y1 = while_loop(
                    cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
                )
                return i1 - 1, j1.clone(), x1 * 2, y1 / 2

            return while_loop(cond_fn, body_fn, (ci, cj, a, b))

    class SimpleWithLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)
            self.dec = torch.nn.Buffer(torch.tensor(1))

        def forward(self, iter, x):
            def cond_fn(it, x):
                return it - self.dec > 0

            def body_fn(it, x):
                return it - 1, self.linear(x)

            return while_loop(cond_fn, body_fn, (iter, x))

    class SimpleWithPytreeCarry(torch.nn.Module):
        def forward(self, it, pytree_input):
            def cond_fn(it, pytree_input):
                return it > 0

            def body_fn(it, pytree_input):
                x = pytree_input[0][0]
                y = pytree_input[1]["x"]
                z = pytree_input[1]["y"]
                new_x = y.sin()
                new_y = z.cos()
                new_z = x + 1
                return it - 1, ([new_x], {"x": new_y, "y": new_z})

            return while_loop(cond_fn, body_fn, (it, pytree_input))

    class NestedWithLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SimpleWithLinear()
            self.outer_linear = torch.nn.Linear(2, 2)
            self.dec = torch.nn.Buffer(torch.tensor(1))

        def forward(self, iter, x):
            def cond_fn(it, x):
                return it - self.dec > 0

            def body_fn(it, x):
                return it - 1, self.outer_linear(self.mod(it, x)[1])

            return while_loop(cond_fn, body_fn, (iter, x))

    class PytreeIntCarry(torch.nn.Module):
        def forward(self, x):
            a = x.shape[0]
            b = x.shape[1]

            def cond_fn(shapes, const_int_dict, x):
                a, b = shapes
                c1, c2, c3 = const_int_dict["int_carry"]
                return c1 * c2 * c3 < a * b

            def body_fn(shapes, const_int_dict, x):
                a, b = shapes
                c1, c2, c3 = const_int_dict["int_carry"]
                return (
                    [a + 1, b + 1],
                    {"int_carry": (c1 + 1, c2 + 1, c3 + 1)},
                    x + 1,
                )

            carry = ([a, b], {"int_carry": (2, 2, 3)}, x.sin())
            out_shapes, out_it, out_x = while_loop(cond_fn, body_fn, carry)
            out_inc = pytree.tree_map(lambda x: x + 1, out_it)
            out_add = pytree.tree_map(lambda x: x + out_x, out_it)
            return (out_shapes, out_inc, out_add, out_x)

    class IntCarry(torch.nn.Module):
        def forward(self, x):
            def cond_fn(it, x):
                return it < x.shape[0]

            def body_fn(it, x):
                x_clone = x.clone()
                # Need these checks to select from x
                torch._check(it >= 0)
                torch._check(it < x.shape[0])
                x_clone.select(0, it).copy_(x_clone.select(0, it) + it)
                return it + 1, x_clone

            # We invoke the hop directly to avoid triggering dynamo tracing
            out_it, out_x = torch.ops.higher_order.while_loop(
                cond_fn, body_fn, (0, x), tuple()
            )
            # We need torch._check so that out_it can be used in the torch.ones call
            torch._check(out_it > 0)
            return (
                out_it + 1,
                out_it + out_x,
                out_it < x.shape[0],
                torch.ones(out_it * 2),
            )

    class ConstAndSymIntOutput(torch.nn.Module):
        def forward(self, t):
            a = t.shape[0]
            b = t.shape[1]

            def cond_fn(a, b, c1, c2, c3, c0, u0, x):
                return c1 * c2 * c3 < a * b

            def body_fn(a, b, c1, c2, c3, c0, u0, x):
                return b, c1, c2, c3, a, 0, u0 + 1, x + 1

            carry = (a, b, 1, 1, 1, a + 1, t.sum().to(torch.int64).item(), t.sin())
            out_it = torch.ops.higher_order.while_loop(cond_fn, body_fn, carry, tuple())
            out_inc = pytree.tree_map(lambda x: x + 1, out_it)
            out_add = pytree.tree_map(lambda x: x + t, out_it)
            return out_inc, out_add

    nested2 = Nested()
    simple_with_linear = SimpleWithLinear()
    simple_with_pytree_carry = SimpleWithPytreeCarry()
    nested_with_linear = NestedWithLinear()
    int_carry = IntCarry()
    pytree_int_carry = PytreeIntCarry()
    const_and_symint_output = ConstAndSymIntOutput()

    x = torch.zeros(1)
    y = torch.zeros(1)
    z = torch.zeros(1)
    return {
        "simple": (simple, (x,)),
        "nested": (nested, (x, y, z)),
        "nested2": (
            nested2,
            (torch.tensor(2), torch.tensor(2), torch.ones(2, 2), torch.ones(2, 2)),
        ),
        "simple_with_mutation": (simple_with_mutation, (x,)),
        "simple_with_linear": (
            simple_with_linear,
            (torch.tensor(3), torch.randn(2, 2)),
        ),
        "nested_with_linear": (
            nested_with_linear,
            (torch.tensor(3), torch.randn(2, 2)),
        ),
        "simple_with_pytree_carry": (
            simple_with_pytree_carry,
            (
                torch.tensor(3),
                ([torch.randn(3, 3)], {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}),
            ),
        ),
        "int_carry": (int_carry, (torch.randn(2, 3, requires_grad=True),)),
        "pytree_int_carry": (
            pytree_int_carry,
            (torch.randn(2, 3, requires_grad=True),),
        ),
        "const_and_symint_output": (
            const_and_symint_output,
            (torch.randn(2, 3, requires_grad=True),),
        ),
    }


WHILE_LOOP_TESTS = _while_loop_tests()


def collect_meta_for_filtered_nodes(
    gm: torch.fx.GraphModule, node_names, meta_field_name
):
    ret = []
    for mod in gm.modules():
        for node in mod.graph.nodes:
            if node.name in node_names:
                for field_name in meta_field_name:
                    ret.append(node.meta.get(field_name))
    return ret


def reduce_func(*operands):
    acc = 0
    for operand in operands:
        acc += operand
    return acc


class ReduceObj:
    def __call__(self, *operands):
        return reduce_func(*operands)


class ReduceMod(torch.nn.Module):
    def _reduce(self, *operands):
        return reduce_func(*operands)

    def forward(self, *operands):
        return self._reduce(*operands)


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@skipIfNoDynamoSupport
|
|
class TestControlFlow(TestCase):
|
|
def setUp(self):
|
|
torch._dynamo.reset()
|
|
super().setUp()
|
|
|
|
def check_autograd(self, result, result_exp, params):
|
|
params_flatten = pytree.tree_leaves(params)
|
|
result_flatten = pytree.tree_leaves(result)
|
|
result_exp_flatten = pytree.tree_leaves(result_exp)
|
|
grad_exp_init = [torch.ones_like(el) for el in result_exp_flatten]
|
|
expected_grads = torch.autograd.grad(
|
|
result_exp_flatten, params_flatten, grad_exp_init
|
|
)
|
|
grad_init = [torch.ones_like(el) for el in result_flatten]
|
|
grads = torch.autograd.grad(result_flatten, params_flatten, grad_init)
|
|
self.assertEqual(grads, expected_grads, atol=6e-05, rtol=6e-06)
|
|
|
|
def test_cond_no_trace(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
x = torch.randn(4)
|
|
result = cond(False, true_fn, false_fn, [x])
|
|
self.assertEqual(result, torch.cos(x))
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_cond_gpu(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
x = torch.randn(4, device="cuda")
|
|
pred = torch.tensor(False, device="cuda")
|
|
result = cond(pred, true_fn, false_fn, [x])
|
|
self.assertEqual(result, torch.cos(x))
|
|
|
|
def test_cond_autograd_simple(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_complex(self):
|
|
def true_fn(x):
|
|
return torch.abs((x**2).sin())
|
|
|
|
def false_fn(x):
|
|
return (x + 42).cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_nested(self):
|
|
class Nested(torch.nn.Module):
|
|
def forward(self, p0, p1, p2, a, b, c):
|
|
def true_fn(x0, y0, z0):
|
|
def true_true_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) * 3.14
|
|
|
|
def true_false_fn(x1, y1, z1):
|
|
def true_false_true_fn(x2, y2, z2):
|
|
return (x2 * y2 * z2) / 2.71
|
|
|
|
def true_false_false_fn(x2, y2, z2):
|
|
return (x2 + y2 + z2) * 1.23
|
|
|
|
return torch.cond(
|
|
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
|
|
|
|
def false_fn(x0, y0, z0):
|
|
def false_true_fn(x1, y1, z1):
|
|
def false_true_true_fn(x2, y2, z2):
|
|
return (x2 - y2 - z2) + 1.23
|
|
|
|
def false_true_false_fn(x2, y2, z2):
|
|
return (x2 / y2 / z2) - 3.14
|
|
|
|
return torch.cond(
|
|
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
def false_false_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) / 2.71
|
|
|
|
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
|
|
|
|
return torch.cond(p0, true_fn, false_fn, [a, b, c])
|
|
|
|
nn_module = Nested()
|
|
|
|
def true_fn(x):
|
|
return nn_module(
|
|
torch.tensor(False), torch.tensor(True), torch.tensor(False), x, x, x
|
|
)
|
|
|
|
def false_fn(x):
|
|
return nn_module(
|
|
torch.tensor(True), torch.tensor(False), torch.tensor(True), x, x, x
|
|
)
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_mixed_require_grad(self):
|
|
def true_fn(x, y, z):
|
|
return x * y * z
|
|
|
|
def false_fn(x, y, z):
|
|
return x + y + z
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
y = torch.randn(4, requires_grad=False)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, (x, y, x))
|
|
self.assertEqual(result, fn(x, y, x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x, y, x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x, y, z):
|
|
result = cond(pred, true_fn, false_fn, (x, y, z))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x, y, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1, y_1, z_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (z_1, y_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = z_1 = y_1 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_grad_through_cond(self):
|
|
nn_module = torch.nn.Linear(4, 4)
|
|
|
|
def true_fn(x):
|
|
return nn_module(x)
|
|
|
|
def false_fn(x):
|
|
return x * nn_module(x)
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (nn_module.weight,), grad_out)
|
|
expected_grads = torch.autograd.grad(
|
|
fn(
|
|
x,
|
|
),
|
|
(nn_module.weight,),
|
|
grad_out,
|
|
)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (nn_module.weight,), grad_out)
|
|
|
|
# need to set _allow_non_fake_inputs = True because model parameters don't
|
|
# get fakified.
|
|
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_param_constant0 = self._param_constant0
|
|
_param_constant1 = self._param_constant1
|
|
_tensor_constant0 = self._tensor_constant0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, sym_size_int, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
_param_constant0_1 = self._param_constant0
|
|
_param_constant1_1 = self._param_constant1
|
|
_tensor_constant0_1 = self._tensor_constant0
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; getitem_1 = None
|
|
getitem_2 = cond_1[1]
|
|
getitem_3 = cond_1[2]; getitem_3 = None
|
|
getitem_4 = cond_1[3]; getitem_4 = None
|
|
getitem_5 = cond_1[4]; cond_1 = getitem_5 = None
|
|
return (getitem_2,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_in_forloop(self):
|
|
def for_loop_fake(x):
|
|
for i in range(3):
|
|
x = x * x + 1
|
|
return x
|
|
|
|
def for_loop_test(x):
|
|
for i in range(3):
|
|
pred = i < 3
|
|
|
|
def true_fn(x):
|
|
return x * x + 1
|
|
|
|
def false_fn(x):
|
|
return x
|
|
|
|
x = cond(pred, true_fn, false_fn, (x,))
|
|
|
|
return x
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
x_new = for_loop_test(x)
|
|
x_exp = for_loop_fake(x)
|
|
|
|
self.assertEqual(x_new, x_exp)
|
|
|
|
grad_out = torch.ones_like(x_new)
|
|
grads = torch.autograd.grad(x_new, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(x_exp, (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(x):
|
|
x_new = for_loop_test(x)
|
|
grad_out = torch.ones_like(x_new)
|
|
return torch.autograd.grad(x_new, (x,), grad_out)
|
|
|
|
gm = make_fx(f, tracing_mode="symbolic")(x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
mul = torch.ops.aten.mul.Tensor(x_1, x_1)
|
|
add = torch.ops.aten.add.Tensor(mul, 1); mul = None
|
|
mul_1 = torch.ops.aten.mul.Tensor(add, add)
|
|
add_1 = torch.ops.aten.add.Tensor(mul_1, 1); mul_1 = None
|
|
mul_2 = torch.ops.aten.mul.Tensor(add_1, add_1)
|
|
add_2 = torch.ops.aten.add.Tensor(mul_2, 1); mul_2 = None
|
|
ones_like = torch.ops.aten.ones_like.default(add_2, pin_memory = False); add_2 = None
|
|
mul_3 = torch.ops.aten.mul.Tensor(ones_like, add_1)
|
|
mul_4 = torch.ops.aten.mul.Tensor(ones_like, add_1); ones_like = add_1 = None
|
|
add_3 = torch.ops.aten.add.Tensor(mul_4, mul_3); mul_4 = mul_3 = None
|
|
mul_5 = torch.ops.aten.mul.Tensor(add_3, add)
|
|
mul_6 = torch.ops.aten.mul.Tensor(add_3, add); add_3 = add = None
|
|
add_4 = torch.ops.aten.add.Tensor(mul_6, mul_5); mul_6 = mul_5 = None
|
|
mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1)
|
|
mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1); add_4 = x_1 = None
|
|
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7); mul_8 = mul_7 = None
|
|
return (add_5,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_pytree_not_all_inputs_used(self):
|
|
def true_fn(x):
|
|
return x["t"][0] + x["t"][1]["b"]
|
|
|
|
def false_fn(x):
|
|
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
if pred:
|
|
with self.assertRaisesRegex(Exception, r"."):
|
|
grads = torch.autograd.grad(result, (a, b, c), grad_out)
|
|
expected_grads = torch.autograd.grad(
|
|
fn({"t": [a, {"b": b}, (c,)]}), (a, b, c), grad_out
|
|
)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, a, b, c):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (a, b), grad_out)
|
|
|
|
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(
|
|
pred, a, b, c
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, a_1, b_1, c_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(c_1, 0)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]
|
|
getitem_3 = cond_1[2]; getitem_3 = None
|
|
getitem_4 = cond_1[3]; getitem_4 = None
|
|
getitem_5 = cond_1[4]; getitem_5 = None
|
|
getitem_6 = cond_1[5]; cond_1 = getitem_6 = None
|
|
return (getitem_1, getitem_2)""", # noqa: B950
|
|
)
|
|
# Forward
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (add,)""",
|
|
)
|
|
# Backward
|
|
self.assertExpectedInline(
|
|
gm.true_graph_1.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = add = None
|
|
clone = torch.ops.aten.clone.default(arg6_1)
|
|
clone_1 = torch.ops.aten.clone.default(arg6_1); arg6_1 = None
|
|
zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
|
|
return [clone, clone_1, None, None, zeros_like, None]""",
|
|
)
|
|
|
|
def test_cond_autograd_pytree_input(self):
|
|
def true_fn(x):
|
|
return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
|
|
|
|
def false_fn(x):
|
|
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (a, b), grad_out)
|
|
expected_grads = torch.autograd.grad(
|
|
fn({"t": [a, {"b": b}, (c,)]}), (a, b), grad_out
|
|
)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (a, b), grad_out)
|
|
|
|
# need to set _allow_non_fake_inputs = True because model parameters don't
|
|
# get fakified.
|
|
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_tensor_constant0 = self._tensor_constant0
|
|
_tensor_constant1 = self._tensor_constant1
|
|
_tensor_constant2 = self._tensor_constant2
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
_tensor_constant0_1 = self._tensor_constant0
|
|
_tensor_constant1_1 = self._tensor_constant1
|
|
_tensor_constant2_1 = self._tensor_constant2
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]
|
|
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
|
|
return (getitem_1, getitem_2)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_different_pytree_output(self):
|
|
def true_fn(x):
|
|
return x["t"][0], {"r": x["t"][2][0] / x["t"][1]["b"]}, [x["t"][2][0]]
|
|
|
|
def false_fn(x):
|
|
return {"res": [x["t"][0] * x["t"][1]["b"], x["t"][2][0]]}
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_same_pytree_output(self):
|
|
def true_fn(x):
|
|
return {"res": [x["t"][0].clone(), (x["t"][2][0].clone(),)]}
|
|
|
|
def false_fn(x):
|
|
return {"res": [x["t"][1]["b"].clone(), (x["t"][2][0].clone(),)]}
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
result_exp = fn({"t": [a, {"b": b}, (c,)]})
|
|
self.assertEqual(result, result_exp)
|
|
|
|
result_flat, _ = pytree.tree_flatten(result)
|
|
result_exp_flat, _ = pytree.tree_flatten(result_exp)
|
|
|
|
grad_out = [torch.ones_like(g) for g in result_flat]
|
|
expected_grads = torch.autograd.grad(result_exp_flat, (c,), grad_out)
|
|
grads = torch.autograd.grad(result_flat, (c,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
return result
|
|
|
|
gm = make_fx(f, tracing_mode="real", _allow_non_fake_inputs=True)(pred)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_tensor_constant0 = self._tensor_constant0
|
|
_tensor_constant1 = self._tensor_constant1
|
|
_tensor_constant2 = self._tensor_constant2
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
|
getitem = cond[0]
|
|
getitem_1 = cond[1]; cond = None
|
|
return {'res': [getitem, (getitem_1,)]}""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_torch_nn_module(self):
|
|
nn_module_true = torch.nn.Linear(4, 4)
|
|
|
|
def true_fn(x):
|
|
return nn_module_true(torch.abs((x**2).sin()))
|
|
|
|
nn_module_false = torch.nn.GRUCell(4, 4)
|
|
|
|
def false_fn(x):
|
|
return nn_module_false((x + 42).cos())
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_param_constant0 = self._param_constant0
|
|
_param_constant1 = self._param_constant1
|
|
_param_constant2 = self._param_constant2
|
|
_param_constant3 = self._param_constant3
|
|
_param_constant4 = self._param_constant4
|
|
_param_constant5 = self._param_constant5
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, _param_constant0, _param_constant1, _param_constant2, _param_constant3, _param_constant4, _param_constant5)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _param_constant2 = _param_constant3 = _param_constant4 = _param_constant5 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
_param_constant0_1 = self._param_constant0
|
|
_param_constant1_1 = self._param_constant1
|
|
_param_constant2_1 = self._param_constant2
|
|
_param_constant3_1 = self._param_constant3
|
|
_param_constant4_1 = self._param_constant4
|
|
_param_constant5_1 = self._param_constant5
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]; getitem_2 = None
|
|
getitem_3 = cond_1[2]; getitem_3 = None
|
|
getitem_4 = cond_1[3]; getitem_4 = None
|
|
getitem_5 = cond_1[4]; getitem_5 = None
|
|
getitem_6 = cond_1[5]; getitem_6 = None
|
|
getitem_7 = cond_1[6]; cond_1 = getitem_7 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_user_nn_module(self):
|
|
class User_nn_module(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, input):
|
|
return input * input
|
|
|
|
nn_module_true = User_nn_module()
|
|
|
|
def true_fn(x):
|
|
return nn_module_true(torch.abs((x**2).sin()))
|
|
|
|
nn_module_false = torch.nn.ReLU(inplace=False)
|
|
|
|
def false_fn(x):
|
|
return nn_module_false((x + 42).cos())
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_inner_fn(self):
|
|
def true_fn(x):
|
|
return torch.abs((x**2).sin())
|
|
|
|
def false_fn(x):
|
|
def inner_fn(x):
|
|
return x**2
|
|
|
|
return torch.abs(inner_fn(x).sin())
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
pred = torch.tensor(False)
|
|
fn = false_fn
|
|
result_false = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result_false, fn(x))
|
|
|
|
grad_out = torch.ones_like(result_false)
|
|
grads_false = torch.autograd.grad(result_false, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads_false)
|
|
|
|
pred = torch.tensor(True)
|
|
fn = true_fn
|
|
result_true = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result_true, fn(x))
|
|
self.assertEqual(result_false, result_true)
|
|
|
|
grad_out = torch.ones_like(result_true)
|
|
grads_true = torch.autograd.grad(result_true, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads_true)
|
|
self.assertEqual(grads_false, grads_true)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_inner_tensor(self):
|
|
def true_fn(x):
|
|
return torch.abs((x**2).sin())
|
|
|
|
def false_fn(x):
|
|
y = torch.ones(4, requires_grad=False) * 42
|
|
return (x * y).cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_cond_autograd_gpu(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False, device="cuda"), torch.tensor(True, device="cuda")],
|
|
[false_fn, true_fn],
|
|
):
|
|
x = torch.randn(4, requires_grad=True, device="cuda")
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def _test_cond_autograd(self, cond_fct, pred_fn, true_fn, false_fn, operands):
|
|
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
|
|
|
# This is a helper function that extracts the metadata from the tensor and
|
|
# sets the requires_grad flag to False. This is needed because we compare the
|
|
# metadata of the operands and the gradients
|
|
def _extract_tensor_metadata_except_requires_grad(arg):
|
|
metadata = _extract_tensor_metadata(arg)
|
|
metadata = TensorMetadata(
|
|
metadata.shape,
|
|
metadata.dtype,
|
|
False,
|
|
metadata.stride,
|
|
metadata.memory_format,
|
|
metadata.is_quantized,
|
|
metadata.qparams,
|
|
)
|
|
return metadata
|
|
|
|
# Comparison of FWD path
|
|
cond_outputs = cond_fct(pred_fn(*operands), true_fn, false_fn, operands)
|
|
operands_forced_grad = [o.clone().detach() for o in operands]
|
|
for o in operands_forced_grad:
|
|
o.requires_grad = True
|
|
cond_outputs_exp = (
|
|
true_fn(*operands_forced_grad)
|
|
if pred_fn(*operands_forced_grad)
|
|
else false_fn(*operands_forced_grad)
|
|
)
|
|
self.assertEqual(cond_outputs, cond_outputs_exp)
|
|
|
|
# Comparison of BWD path
|
|
cond_inputs = [o for o in operands if o.requires_grad]
|
|
cond_inputs_exp = [o for o in operands_forced_grad if o.requires_grad]
|
|
|
|
# Check if at least some operands require grads
|
|
if len(cond_inputs) > 0:
|
|
grad_inputs = torch.autograd.grad(
|
|
cond_outputs, cond_inputs, allow_unused=True, retain_graph=True
|
|
)
|
|
grad_inputs_exp = torch.autograd.grad(
|
|
cond_outputs_exp,
|
|
cond_inputs_exp,
|
|
allow_unused=True,
|
|
materialize_grads=True,
|
|
)
|
|
|
|
grad_exp_masked = [
|
|
g for g, o in zip(grad_inputs_exp, operands) if o.requires_grad
|
|
]
|
|
self.assertEqual(grad_exp_masked, grad_inputs)
|
|
|
|
# Extraction and comparison of Metadata of operands and gradients
|
|
operands_metadata = [
|
|
_extract_tensor_metadata_except_requires_grad(o) for o in cond_inputs
|
|
]
|
|
grad_metadata = [
|
|
_extract_tensor_metadata_except_requires_grad(o) for o in grad_inputs
|
|
]
|
|
self.assertTrue(
|
|
all(op == g for op, g in zip(operands_metadata, grad_metadata))
|
|
)
|
|
|
|
return cond_outputs, cond_inputs
|
|
|
|
# TODO: The compile_mode = `compile_dynamic_shape` raises the Error
|
|
# torch._inductor.exc.LoweringException: NotImplementedError: get_size() is not
|
|
# implemented by <class 'torch._inductor.ir.NoneAsConstantBuffer'>!
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["compile_dynamic_shape"])
|
|
@parametrize("scalar", [False])
|
|
@unittest.expectedFailure
|
|
def test_cond_autograd_zeros_unused_branch_complex_compile_fail(
|
|
self, compile_mode, scalar
|
|
):
|
|
device = torch.device("cuda")
|
|
cond_fct = compile_mode_helper(torch.cond, compile_mode)
|
|
|
|
autograd = [False, True, True, True, True]
|
|
|
|
if scalar:
|
|
# These operands work
|
|
x = torch.randn((), device=device, requires_grad=bool(autograd[0]))
|
|
w1 = torch.randn((), device=device, requires_grad=bool(autograd[1]))
|
|
b1 = torch.randn((), device=device, requires_grad=bool(autograd[2]))
|
|
w2 = torch.randn((), device=device, requires_grad=bool(autograd[3]))
|
|
b2 = torch.randn((), device=device, requires_grad=bool(autograd[4]))
|
|
else:
|
|
# These operands do not work
|
|
x = torch.randn(4, 5, device=device, requires_grad=bool(autograd[0]))
|
|
w1 = torch.randn(2, 4, device=device, requires_grad=bool(autograd[1]))
|
|
b1 = torch.randn(2, 1, device=device, requires_grad=bool(autograd[2]))
|
|
w2 = torch.randn(2, 4, device=device, requires_grad=bool(autograd[3]))
|
|
b2 = torch.randn(1, 5, device=device, requires_grad=bool(autograd[4]))
|
|
|
|
operands = [x, w1, b1, w2, b2]
|
|
|
|
def true_fn(x, w1, b1, w2, b2):
|
|
if scalar:
|
|
# This works
|
|
return ((w1 * x + b1),)
|
|
else:
|
|
# This does not work
|
|
return ((w1 @ x + b1).sum(),)
|
|
|
|
def false_fn(x, w1, b1, w2, b2):
|
|
if scalar:
|
|
# This works
|
|
return ((w2 * x + b2),)
|
|
else:
|
|
# This does not work
|
|
return ((w2 @ x + b2).sum(),)
|
|
|
|
def pred_fn(x, w1, b1, w2, b2):
|
|
return x.mean() > 0
|
|
|
|
cond_outputs, cond_inputs = self._test_cond_autograd(
|
|
cond_fct, pred_fn, true_fn, false_fn, operands
|
|
)
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_map_gpu(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
xs = torch.ones(3, 2, 2, device="cuda")
|
|
y = torch.ones(2, device="cuda")
|
|
res = control_flow.map(f, xs, y)
|
|
expected = _fake_map(f, xs, y)
|
|
self.assertEqual(expected, res)
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_while_loop_gpu(self):
|
|
def cond_fn(x):
|
|
return x.sum() < 10
|
|
|
|
def body_fn(x):
|
|
return (x + 1,)
|
|
|
|
x = torch.zeros(1, device="cuda")
|
|
res = while_loop(cond_fn, body_fn, (x,))
|
|
expected = _fake_while_loop(cond_fn, body_fn, (x,))
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_map_illegal_inputs(self):
|
|
def f(x, y):
|
|
return x[0] + x[1] + y
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Mapped xs can only consist of tensors\. Got xs \[3, tensor\(\[1\., 1\.\]\)\]\.",
|
|
):
|
|
_ = control_flow.map(f, (3, torch.ones(2)), torch.ones(2))
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, r"Leading dimensions of mapped xs cannot be 0\."
|
|
):
|
|
_ = control_flow.map(
|
|
f, (torch.ones(0, 1, 2), torch.ones(0, 1, 2)), torch.ones(2)
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Leading dimensions of mapped xs must be consistent\. "
|
|
r"Got shapes \[torch\.Size\(\[3, 4, 5\]\), torch\.Size\(\[4, 4, 5\]\)\]\.",
|
|
):
|
|
_ = control_flow.map(
|
|
f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5)
|
|
)
|
|
|
|
def test_map_illegal_outputs(self):
|
|
def f(x, y):
|
|
return x.item()
|
|
|
|
def f1(x, y):
|
|
return y.size()
|
|
|
|
def f2(x, y):
|
|
return None
|
|
|
|
x = torch.ones([3])
|
|
y = torch.ones([1, 2, 3])
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "map doesn't work unless it is captured completely"
|
|
):
|
|
control_flow.map(f, x, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
# "Expected all leaves to be of torch.Tensor type.*",
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"map doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
control_flow.map(f1, x, y)
|
|
|
|
# return None is OK
|
|
control_flow.map(f2, x, y)
|
|
|
|
def test_map_list_in_out(self):
|
|
def f(x, y):
|
|
return [[x[0][0] + y]]
|
|
|
|
xs = [[torch.ones(3, 2, 2)]]
|
|
y = torch.ones(2)
|
|
res = control_flow.map(f, xs, y)
|
|
expected = _fake_map(f, xs, y)
|
|
self.assertEqual(len(res), 1)
|
|
self.assertEqual(len(res[0]), 1)
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_map_dict_in_out(self):
|
|
def f(x, y):
|
|
return {"c": x["a"]["b"] + y}
|
|
|
|
xs = {"a": {"b": torch.ones(3, 2, 2)}}
|
|
y = torch.ones(2)
|
|
res = control_flow.map(f, xs, y)
|
|
expected = _fake_map(f, xs, y)
|
|
self.assertEqual(len(res), 1)
|
|
self.assertTrue("c" in res)
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_map_autograd_simple(self):
|
|
def f(x, y):
|
|
return x.sin().cos() * y.cos().sin()
|
|
|
|
xs = torch.ones(3, 2, 2, requires_grad=True)
|
|
y = torch.ones(2, requires_grad=True)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
grad_out = torch.ones_like(res)
|
|
grads = torch.autograd.grad(res, (xs, y), grad_out)
|
|
expected_grads = torch.autograd.grad(expected_res, (xs, y), grad_out)
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_map_autograd_simple_partial_grad(self):
|
|
def f(x, y):
|
|
return x.sin().cos() * y.cos().sin()
|
|
|
|
xs = torch.ones(3, 2, 2, requires_grad=True)
|
|
# Disable the gradient computation for y
|
|
y = torch.ones(2, requires_grad=False)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
grad_out = torch.ones_like(res)
|
|
grads = torch.autograd.grad(res, (xs,), grad_out)
|
|
expected_grads = torch.autograd.grad(expected_res, (xs,), grad_out)
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_map_autograd_no_grad_output(self):
|
|
def f(x, y):
|
|
return x[0].sin().cos() + y, y.cos().sin()
|
|
|
|
xs = [torch.ones(3, 2, 2, requires_grad=True), torch.ones(3, 3)]
|
|
# Disable the gradient computation for y
|
|
y = torch.ones(2, requires_grad=False)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
grad_out = torch.ones_like(res[0])
|
|
grads = torch.autograd.grad(res[0], (xs[0],), grad_out)
|
|
expected_grads = torch.autograd.grad(expected_res[0], (xs[0],), grad_out)
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_map_autograd_nested_list(self):
|
|
import torch.utils._pytree as pytree
|
|
|
|
def f(x, y):
|
|
a, b = x
|
|
c, d = a
|
|
return [[b.sin() * c.cos()], d.sin() * y.cos()]
|
|
|
|
def fwbw(map_op, f, x, y):
|
|
z = map_op(f, x, y)
|
|
flat_x = pytree.tree_leaves(x)
|
|
flat_z = pytree.tree_leaves(z)
|
|
grads = torch.autograd.grad(
|
|
flat_z, flat_x, [torch.ones_like(z) for z in flat_z]
|
|
)
|
|
return z, grads
|
|
|
|
x = [
|
|
[
|
|
torch.randn(3, 2, 2, requires_grad=True),
|
|
torch.randn(3, 2, 1, requires_grad=True),
|
|
],
|
|
torch.ones(3, 1, 2, requires_grad=True),
|
|
]
|
|
y = torch.ones(1, requires_grad=True)
|
|
true_outs = fwbw(control_flow.map, f, x, y)
|
|
fake_outs = fwbw(_fake_map, f, x, y)
|
|
self.assertEqual(true_outs, fake_outs)
|
|
|
|
def test_map_autograd_higher_order(self):
|
|
from torch.autograd.functional import hessian as hes, jacobian as jac
|
|
|
|
def f(x, y):
|
|
return x.sin().cos() + y
|
|
|
|
def wrapper_jac(x, y):
|
|
return control_flow.map(f, x, y)
|
|
|
|
def wrapper_jac_fake(x, y):
|
|
return _fake_map(f, x, y)
|
|
|
|
def wrapper_hes(x, y):
|
|
return control_flow.map(f, x, y).sum()
|
|
|
|
def wrapper_hes_fake(x, y):
|
|
return _fake_map(f, x, y).sum()
|
|
|
|
for g_fct, (wrap, wrap_fake) in [
|
|
(jac, [wrapper_jac, wrapper_jac_fake]),
|
|
(hes, [wrapper_hes, wrapper_hes_fake]),
|
|
]:
|
|
xs = torch.ones(3, 2, 2, requires_grad=True)
|
|
# Disable the gradient computation for y
|
|
y = torch.ones(2, requires_grad=False)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
self.assertEqual(expected_res, res)
|
|
|
|
expected_grads = g_fct(wrap_fake, (xs, y))
|
|
grads = g_fct(wrap, (xs, y))
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_scan_y_less_ndim_then_dim(self):
|
|
def combine_fn(carry, x):
|
|
return carry @ x, (carry @ x).sum()
|
|
|
|
init = torch.randn(4, 3)
|
|
xs = torch.randn(3, 3, 2)
|
|
dim = 2
|
|
out = scan(combine_fn, init, xs, dim=dim)
|
|
exp_out = _fake_scan(combine_fn, init, xs, dim=dim)
|
|
self.assertEqual(out, exp_out)
|
|
|
|
# TODO: provide an implementation for all compile modes and re-enable all tests
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_compile(self, reverse, compile_mode, device, autograd):
|
|
def add2(x: torch.Tensor, y: torch.Tensor):
|
|
return x * y, x + y
|
|
|
|
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
for op, op_pt, init in [
|
|
(
|
|
get_scan_combine_fn("add", False),
|
|
torch.cumsum,
|
|
torch.zeros(10, 2, device=device, requires_grad=autograd),
|
|
),
|
|
(
|
|
get_scan_combine_fn("mul", False),
|
|
torch.cumprod,
|
|
torch.ones(10, 2, device=device, requires_grad=autograd),
|
|
),
|
|
]:
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
if not reverse:
|
|
result_exp_PT = op_pt(x, 0)
|
|
self.assertEqual(result[1], result_exp_PT)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
# Jax Examples
|
|
x = torch.arange(0, 4, device=device, dtype=torch.int64)
|
|
init = torch.zeros(1, device=device, dtype=torch.int64)
|
|
cumsum1 = scan_fct(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
cumsum_exp = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init,
|
|
xs=x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
if not reverse:
|
|
self.assertEqual(
|
|
cumsum1[1],
|
|
torch.tensor([[0.0], [1.0], [3.0], [6.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64))
|
|
else:
|
|
self.assertEqual(
|
|
cumsum1[1],
|
|
torch.tensor([[6.0], [6.0], [5.0], [3.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64))
|
|
self.assertEqual(cumsum1, cumsum_exp)
|
|
|
|
# Carry computation differs from the output computation
|
|
x = torch.arange(1, 5, device=device, dtype=torch.int64)
|
|
init = torch.ones(1, device=device, dtype=torch.int64)
|
|
result = scan_fct(add2, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(add2, init=init, xs=x, dim=0, reverse=reverse)
|
|
if not reverse:
|
|
self.assertEqual(
|
|
result[1],
|
|
torch.tensor([[2.0], [3.0], [5.0], [10.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64))
|
|
else:
|
|
self.assertEqual(
|
|
result[1],
|
|
torch.tensor([[25.0], [14.0], [7.0], [5.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64))
|
|
self.assertEqual(result, result_exp)
|
|
|
|
        # Non-associative operation
|
|
x = torch.arange(
|
|
0, 5, device=device, dtype=torch.float32, requires_grad=autograd
|
|
)
|
|
init = torch.ones(1, device=device, dtype=torch.float32, requires_grad=autograd)
|
|
result = scan_fct(
|
|
get_scan_combine_fn("div", False),
|
|
init,
|
|
x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
result_exp = _fake_scan(
|
|
get_scan_combine_fn("div", False),
|
|
init=init,
|
|
xs=x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
    # TODO: provide an implementation for all compile modes and re-enable all tests
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize(
|
|
"dtype",
|
|
[
|
|
torch.float16,
|
|
torch.float32,
|
|
torch.int32,
|
|
torch.int64,
|
|
torch.complex64,
|
|
],
|
|
)
|
|
def test_scan_dtype(self, reverse, compile_mode, device, dtype):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
        # Check that all outputs and carries are on the correct device and have the parametrized dtype
|
|
x = torch.randn(3, 10, 2, device=device).to(dtype=dtype)
|
|
op, init = (
|
|
get_scan_combine_fn("adds"),
|
|
torch.zeros(10, 2, device=device, dtype=dtype),
|
|
)
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(
|
|
[[r.device.type for r in res] for res in result],
|
|
[[device.type for _ in res] for res in result],
|
|
)
|
|
self.assertEqual(
|
|
[[r.dtype for r in res] for res in result],
|
|
[[dtype for _ in res] for res in result],
|
|
)
|
|
|
|
        # Check that all outputs and carries are on the correct device, with the
        # carries in torch.float32 and the outputs in the parametrized dtype
|
|
x = torch.randn(3, 10, 2, device=device).to(dtype=dtype)
|
|
op, init = (
|
|
get_scan_combine_fn("adds"),
|
|
torch.zeros(10, 2, device=device, dtype=torch.float32),
|
|
)
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(
|
|
[[r.dtype for r in res] for res in result],
|
|
[
|
|
[torch.float32 for _ in range(len(result[0]))],
|
|
[dtype for _ in range(len(result[1]))],
|
|
],
|
|
)
|
|
|
|
        # Check that all outputs and carries are on the correct device, with the
        # carries in the parametrized dtype and the outputs in torch.float32
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
op, init = (
|
|
get_scan_combine_fn("adds"),
|
|
torch.zeros(10, 2, device=device, dtype=dtype),
|
|
)
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(
|
|
[[r.dtype for r in res] for res in result],
|
|
[
|
|
[dtype for _ in range(len(result[0]))],
|
|
[torch.float32 for _ in range(len(result[1]))],
|
|
],
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_dim(self, reverse, compile_mode, device, autograd):
|
|
import random
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
num_dims = [random.randint(2, 5) for _ in range(5)]
|
|
for num_dim in num_dims:
|
|
shapes = [random.randint(1, 10) for _ in range(num_dim)]
|
|
rnd_scan_dim = random.randint(0, num_dim - 1)
|
|
x = torch.randn(*shapes, device=device, requires_grad=autograd)
|
|
init_shapes = shapes[:rnd_scan_dim] + shapes[rnd_scan_dim + 1 :]
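            # init has the shape of a single slice of x, i.e. x's shape with the
            # scan dim removed.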
|
|
|
|
for op, op_pt, init in [
|
|
(
|
|
get_scan_combine_fn("add", False),
|
|
torch.cumsum,
|
|
torch.zeros(*init_shapes, device=device, requires_grad=autograd),
|
|
),
|
|
(
|
|
get_scan_combine_fn("mul", False),
|
|
torch.cumprod,
|
|
torch.ones(*init_shapes, device=device, requires_grad=autograd),
|
|
),
|
|
]:
|
|
result = scan_fct(op, init, x, dim=rnd_scan_dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
op, init=init, xs=x, dim=rnd_scan_dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
if not reverse:
|
|
result_exp_PT = op_pt(x, rnd_scan_dim)
|
|
res_list = list(result)
|
|
res_list[1] = res_list[1].movedim(0, rnd_scan_dim)
|
|
self.assertEqual(res_list[1], result_exp_PT)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_binary_operator(self, reverse, compile_mode, device, autograd):
|
|
state_dim = 20
|
|
timesteps = 10
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
projected_inputs = torch.randn(
|
|
timesteps, state_dim, requires_grad=autograd, device=device
|
|
)
|
|
A = torch.randn(state_dim, requires_grad=autograd, device=device)
|
|
elements = (A.repeat((timesteps, 1)), projected_inputs)
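        # Each scanned element pairs a (repeated) state-transition vector A with a
        # projected input, the setup used by S5-style state-space recurrences.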
|
|
init = tuple(
|
|
[
|
|
torch.ones_like(
|
|
torch._ops.ops.aten.slice(elements[0], 0, 0, 1, 1),
|
|
requires_grad=autograd,
|
|
)
|
|
]
|
|
+ [
|
|
torch.zeros_like(
|
|
torch._ops.ops.aten.slice(projected_inputs, 0, 0, 1, 1),
|
|
requires_grad=autograd,
|
|
)
|
|
]
|
|
)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("s5_operator", False),
|
|
init,
|
|
elements,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("s5_operator", False),
|
|
init=init,
|
|
xs=elements,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
init_flatten, _ = pytree.tree_flatten(init)
|
|
elements_flatten, _ = pytree.tree_flatten(elements)
|
|
|
|
result_flatten, _ = pytree.tree_flatten(result)
|
|
result_exp_flatten, _ = pytree.tree_flatten(expected_result)
|
|
grad_out = [torch.ones_like(el) for el in result_exp_flatten]
|
|
expected_grads = torch.autograd.grad(
|
|
result_exp_flatten, (*init_flatten, *elements_flatten), grad_out
|
|
)
|
|
grads = torch.autograd.grad(
|
|
result_flatten, (*init_flatten, *elements_flatten), grad_out
|
|
)
|
|
self.assertEqual(grads, expected_grads)
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_tuple(self, reverse, compile_mode, device, autograd):
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
inp = (x, y)
|
|
init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp)
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
result_same = scan_fct(
|
|
get_scan_combine_fn("tuple_fct", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("tuple_fct", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result_same, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result_same, expected_result, (init, inp))
|
|
|
|
def fct_different_output_tuple(x, y):
|
|
return ((x[0] + y[0], x[1] * y[1]), (x[1] * y[1]))
|
|
|
|
inp = (x, y)
|
|
init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp)
|
|
|
|
result_diff = scan(
|
|
fct_different_output_tuple, init, inp, dim=0, reverse=reverse
|
|
)
|
|
expected_result = _fake_scan(
|
|
fct_different_output_tuple, init=init, xs=inp, dim=0, reverse=reverse
|
|
)
|
|
self.assertEqual(result_diff, expected_result)
|
|
self.assertEqual(result_diff[1], result_same[1][1])
|
|
|
|
if autograd:
|
|
self.check_autograd(result_diff, expected_result, (init, inp))
|
|
|
|
def test_scan_wrong_pytree(self):
|
|
# Init and input have same pytree
|
|
def fct_wrong_pytree(x, y):
|
|
return (
|
|
{
|
|
"i": x["i"] * y["j"][0][0],
|
|
"k": torch.tensor(0.0),
|
|
"j": (
|
|
[x["j"][1][0]["o"].clone()],
|
|
[{"o": torch.sin(x["i"])}],
|
|
),
|
|
},
|
|
{
|
|
"i": x["i"] * y["j"][0][0],
|
|
"k": torch.tensor(0.0),
|
|
"j": ([x["j"][1][0]["o"].clone()], [{"o": torch.sin(x["i"])}]),
|
|
},
|
|
)
|
|
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(3, 2, 2)
|
|
z = torch.randn(3, 2, 2)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
inp_flat, inp_spec = pytree.tree_flatten(inp)
|
|
init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat]
|
|
init = pytree.tree_unflatten(init_flat, inp_spec)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
# r"The tree structure of the inits and the carries are not identical.*",
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected init and carry to have same number of outputs but got lhs.*",
|
|
):
|
|
scan(fct_wrong_pytree, init, inp, dim=0)
|
|
|
|
def test_scan_float_output(self):
|
|
# Init and input have same pytree
|
|
def fct_float_output(x, y):
|
|
return 0.0, x + y
|
|
|
|
x = torch.randn(3, 2, 2)
|
|
init = torch._ops.ops.aten.slice(x, 0, 0, 1, 1)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be:
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "HigherOrderOperator body's output must consist of tensors or ints only"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely.*",
|
|
):
|
|
scan(fct_float_output, init, x, dim=0)
|
|
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_complex_pytree(self, reverse, compile_mode, device, autograd):
|
|
# Init and input have same pytree
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
inp_flat, inp_spec = pytree.tree_flatten(inp)
|
|
init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat]
|
|
init = pytree.tree_unflatten(init_flat, inp_spec)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, inp))
|
|
|
|
    # TODO: Does not work because of the usage of vmap within associative_scan
    # The parameterization (T206899919) is commented out for the moment and the test is marked with expected fail
    # Fails with: AssertionError: scan is not an OpOverload
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_scan_associative_scan(self):
|
|
combine_mode = "generic"
|
|
compile_mode_scan = "compile"
|
|
compile_mode_associative_scan = "none"
|
|
reverse = True
|
|
reverse_associative_scan = True
|
|
device = torch.device("cuda")
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode_scan)
|
|
associative_scan_fct = compile_mode_helper(
|
|
associative_scan, compile_mode_associative_scan
|
|
)
|
|
init = torch.randn(10, 5, device=device)
|
|
inp = torch.randn(3, 10, 5, device=device)
|
|
|
|
def body(x, y):
|
|
val = associative_scan_fct(
|
|
get_scan_combine_fn("add", True),
|
|
y,
|
|
0,
|
|
reverse=reverse_associative_scan,
|
|
combine_mode=combine_mode,
|
|
)
|
|
return x + y, x + val
|
|
|
|
result = scan_fct(body, init, inp, dim=0, reverse=reverse)
|
|
expected_result = _fake_scan(
|
|
body,
|
|
init,
|
|
inp,
|
|
0,
|
|
reverse=reverse,
|
|
)
|
|
|
|
self.assertEqual(result, expected_result)
|
|
|
|
    # TODO: provide an implementation for all compile modes and re-enable all tests
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_downstream_scan_matmul(self, compile_mode, reverse, device, autograd):
|
|
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
init = torch.randn(3, 2, device=device, requires_grad=autograd)
|
|
|
|
for ind in range(2):
|
|
# Chain with matmul
|
|
def chain_fct(inp):
|
|
W = torch.ones(2, 5, device=device)
|
|
o = scan(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
return o[ind] @ W
|
|
|
|
fct_cmp = compile_mode_helper(chain_fct, compile_mode)
|
|
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)[ind] @ torch.ones(2, 5, device=device)
|
|
result = fct_cmp(inp)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, inp))
|
|
|
|
    # TODO: provide an implementation for all compile modes and re-enable all tests
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_downstream_scan_scan_dim(
|
|
self, compile_mode, reverse, device, autograd
|
|
):
|
|
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
init = torch.randn(3, 2, device=device, requires_grad=autograd)
|
|
|
|
# Chain with scan on different dim
|
|
init2 = torch.randn(1, 10, 2, device=device, requires_grad=autograd)
|
|
|
|
def chain_fct_different_dim(inp):
|
|
o1 = scan(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
o1 = pytree.tree_map(lambda t: t.movedim(0, 1), o1)
|
|
o2 = scan(
|
|
get_scan_combine_fn("add", False),
|
|
init2,
|
|
o1[1],
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
return o2
|
|
|
|
fct_cmp = compile_mode_helper(chain_fct_different_dim, compile_mode)
|
|
|
|
xs = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)[1]
|
|
xs = pytree.tree_map(lambda t: t.movedim(0, 1), xs)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init2,
|
|
xs=xs,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
result = fct_cmp(inp)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, init2, inp))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_non_pointwise(self, reverse, compile_mode, device, autograd):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
init = torch.randn(10, 2, device=device, requires_grad=autograd)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("non_pointwise", False),
|
|
init=init,
|
|
xs=x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("non_pointwise", False),
|
|
init,
|
|
x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, x))
|
|
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_scan_compile_cnt(self, reverse, device):
|
|
dim = 1
|
|
|
|
from torch._dynamo.testing import CompileCounter
|
|
|
|
# Tests rely on automatic_dynamic = True
|
|
with torch._dynamo.config.patch(automatic_dynamic_shapes=True):
|
|
cnt = CompileCounter()
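            # frame_count increments each time Dynamo (re)compiles; the checks below
            # verify when changes in size, dim, or reverse trigger a recompilation and
            # when automatic dynamic shapes avoid it.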
|
|
x = torch.randn(3, 2, 5, device=device)
|
|
init = torch.randn(3, 5, device=device)
|
|
# First compilation step
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
x = torch.randn(3, 20, 5, device=device)
|
|
init = torch.randn(3, 5, device=device)
|
|
# Recompilation due to first different size
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
x = torch.randn(3, 40, 5, device=device)
|
|
init = torch.randn(3, 5, device=device)
|
|
# No recompilation, because of dynamic shape
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
x = torch.randn(3, 40, 5, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation because of dim change
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=2,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
x = torch.randn(3, 40, 20, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation due to first different size on new dim
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=2,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
x = torch.randn(3, 40, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# No recompilation, because of dynamic shape on new dim
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=2,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
x = torch.randn(3, 60, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation because of dim change
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 5)
|
|
|
|
x = torch.randn(3, 60, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation because of reverse change
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=not reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 6)
|
|
|
|
x = torch.randn(3, 60, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# No recompilation, as nothing changed
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=not reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 6)
|
|
|
|
x = torch.randn(3, 120, 80, device=device)
|
|
init = torch.randn(3, 80, device=device)
|
|
# No recompilation, final test
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 6)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
def test_scan_init_scanned_0(self):
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2, device=torch.device("cpu"))
|
|
init = torch.randn(3, 2, device=torch.device("cpu"))
|
|
dim = 1
|
|
|
|
# Scan dimension is 0
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
|
|
inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"All xs leaves must at least have.*",
|
|
):
|
|
scan(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
inp,
|
|
dim=dim,
|
|
)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
def test_scan_init_non_tensor(self):
|
|
x = torch.randn(3, 1, 2, device=torch.device("cpu"))
|
|
dim = 1
|
|
|
|
# Init is a float and not a tensor
|
|
init = 1.0
|
|
with self.assertRaisesRegex(RuntimeError, "All init leaves must be a Tensor.*"):
|
|
scan(get_scan_combine_fn("add", False), init, x, dim=dim, reverse=False)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
def test_scan_init_wrong_shape(self):
|
|
scan_fct = compile_mode_helper(scan, "none")
|
|
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2)
|
|
dim = 1
|
|
|
|
# Init wrong shape (Other dim different)
|
|
init = torch.randn(1, 2)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected init and carry to have same metadata.*",
|
|
):
|
|
scan_fct(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=dim,
|
|
)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
def test_scan_init_wrong_pytree_init_longer_carry(self):
|
|
def init_longer_carry(x: torch.Tensor, y: torch.Tensor):
|
|
return x[0] + 1.0, y + 1.0
|
|
|
|
scan_fct = compile_mode_helper(scan, "none")
|
|
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2)
|
|
dim = 1
|
|
|
|
# Init wrong pytree
|
|
init = (
|
|
torch._ops.ops.aten.slice(x, dim, 0, 1, 1),
|
|
torch._ops.ops.aten.slice(x, dim, 0, 1, 1),
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected init and carry to have same number of outputs but got lhs.*",
|
|
):
|
|
scan_fct(init_longer_carry, init, x, dim=dim)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
def test_scan_init_wrong_pytree_init_shorter_carry(self):
|
|
def init_shorter_carry(x: torch.Tensor, y: torch.Tensor):
|
|
return (x + 1, x + 2), x + 3
|
|
|
|
scan_fct = compile_mode_helper(scan, "none")
|
|
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2)
|
|
dim = 1
|
|
|
|
# Init wrong pytree
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# The tree structure of the inits and the carries are not identical!
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected init and carry to have same number of outputs but got lhs.*",
|
|
):
|
|
scan_fct(init_shorter_carry, init, x, dim=dim)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
def test_scan_init_wrong_pytree_carry_shape(self):
|
|
def wrong_carry_shape(x: torch.Tensor, y: torch.Tensor):
|
|
return x[0, :], x + 3
|
|
|
|
scan_fct = compile_mode_helper(scan, "none")
|
|
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2)
|
|
dim = 1
|
|
|
|
# Init wrong pytree
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan_fct(wrong_carry_shape, init, x, dim=dim)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
def test_scan_one_return(self):
|
|
def no_carry(x: torch.Tensor, y: torch.Tensor):
|
|
return x + 3
|
|
|
|
scan_fct = compile_mode_helper(scan, "none")
|
|
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2)
|
|
dim = 1
|
|
|
|
# Init wrong pytree
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# combine_fn needs to produce two pytrees, one for the carries and one for the outputs.
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with.*",
|
|
):
|
|
scan_fct(no_carry, init, x, dim=dim)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_init(self, reverse, compile_mode, device, autograd):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2, device=device, requires_grad=autograd)
|
|
dim = 1
|
|
op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum)
|
|
|
|
# Only init given
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
|
|
result = scan_fct(op, init, [], dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=[], dim=dim, reverse=reverse)
|
|
result_init = scan_fct(op, init, [], dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(result_init, result_exp)
|
|
self.assertEqual(result_init[0], init)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init,))
|
|
|
|
x = torch.randn(3, 5, 2, device=device, requires_grad=autograd)
|
|
dim = 0
|
|
|
|
op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum)
|
|
inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1)
|
|
|
|
# Init tensor scalar
|
|
init = torch.ones(1, device=device, requires_grad=autograd)
|
|
|
|
def add_scalar_carry(x: torch.Tensor, y: torch.Tensor):
|
|
return x + 1.0, x + y
|
|
|
|
result_init = scan_fct(add_scalar_carry, init, inp, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
add_scalar_carry, init=init, xs=inp, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result_init, result_exp)
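        # add_scalar_carry increments the carry by 1.0 per scanned element; inp has
        # 2 elements along dim 0, so the final carry is 1.0 + 2 = 3.0.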
self.assertEqual(result_init[0], torch.tensor([3.0], device=device))
|
|
|
|
if autograd:
|
|
self.check_autograd(result_init, result_exp, (init, inp))
|
|
|
|
# Init tensor entirely different shape than inp
|
|
init = torch.randn(7, 8, device=device, requires_grad=autograd)
|
|
|
|
def add_scalar_carry2(x: torch.Tensor, y: torch.Tensor):
|
|
return x + 1.0, x[: y.shape[0], : y.shape[1]] + y
|
|
|
|
result_init = scan_fct(add_scalar_carry2, init, inp, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
add_scalar_carry2, init=init, xs=inp, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result_init, result_exp)
|
|
|
|
        # Init with two timesteps on the dim axis. This should work because y always has
        # size 1 on the dim axis and hence automatic broadcasting applies.
        # I.e., the input shape is 2x5x2, the carry at each iteration is 2x5x2,
        # and thus the output of each iteration is 2x5x2, which results in a total
        # output of 4x5x2
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 2, 1)
|
|
result_init = scan_fct(op, init, inp, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=inp, dim=dim, reverse=reverse)
|
|
self.assertEqual(result_init, result_exp)
|
|
self.assertEqual(result_init[0].shape, torch.Size([2, 5, 2]))
|
|
|
|
if autograd:
|
|
self.check_autograd(result_init, result_exp, (init, inp))
|
|
|
|
init = torch.tile(init, (1, 2, 1))
|
|
|
|
def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor):
|
|
return x + 1.0, x[:, :1, :] + y
|
|
|
|
result_init = scan_fct(
|
|
add_scalar_carry_sliced_out, init, inp, dim=dim, reverse=reverse
|
|
)
|
|
result_exp = _fake_scan(
|
|
add_scalar_carry_sliced_out, init=init, xs=inp, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result_init, result_exp)
|
|
self.assertEqual(result_init[0].shape, torch.Size([2, 10, 2]))
|
|
self.assertEqual(result_init[1].shape, torch.Size([2, 2, 5, 2]))
|
|
|
|
if autograd:
|
|
self.check_autograd(result_init, result_exp, (init, inp))
|
|
|
|
# Correct case
|
|
op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum)
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
init = torch.zeros(3, 2, device=device, requires_grad=autograd)
|
|
dim = 2
|
|
|
|
result = scan_fct(op, init, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=dim, reverse=reverse)
|
|
|
|
self.assertEqual(result, result_exp)
|
|
if not reverse:
|
|
result_exp_PT = op_pt(x, dim)
|
|
result = list(result)
|
|
result[1] = pytree.tree_map(lambda t: torch.movedim(t, 0, dim), result[1])
|
|
self.assertEqual(result[1], result_exp_PT)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_scan_init_wrong_pytree_complex(self, reverse, device):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
z = torch.randn(3, 2, 2, device=device)
|
|
|
|
# Wrong pytree fed to the function
|
|
init = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, 0, 1, 1),
|
|
"j": (
|
|
{"a": torch._ops.ops.aten.slice(x, 0, 0, 1, 1)},
|
|
[torch._ops.ops.aten.slice(y, 0, 0, 1, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, 0, 1, 1)}],
|
|
),
|
|
}
|
|
inp = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, 0, None, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, 0, None, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, 0, None, 1)}],
|
|
),
|
|
}
|
|
with self.assertRaisesRegex(
|
|
Exception,
|
|
".*",
|
|
):
|
|
scan(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_init_pytree_complex(self, reverse, compile_mode, device, autograd):
|
|
def fct_pointwise_different_output(x, y):
|
|
return (
|
|
{
|
|
"i": x["i"] * y["i"],
|
|
"j": (
|
|
[x["j"][0][0] * y["j"][0][0]],
|
|
[{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
(
|
|
y["i"] * 2,
|
|
{
|
|
"o": x["i"] * y["i"],
|
|
"j": (
|
|
[x["j"][0][0] * y["j"][0][0]],
|
|
[{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
),
|
|
)
|
|
|
|
def fct_pointwise_different_carry(x, y):
|
|
return (
|
|
{
|
|
"i": x["i"] * y["i"],
|
|
"j": (
|
|
x["i"] * 2,
|
|
[x["j"][1][0] * y["j"][0][0]],
|
|
[{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
(
|
|
y["i"] * 2,
|
|
{
|
|
"o": x["i"] * y["i"] + x["j"][0][0],
|
|
"j": (
|
|
[x["j"][1][0] * y["j"][0][0]],
|
|
[{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
),
|
|
)
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
|
|
if reverse:
|
|
init_start, init_end = -1, None
|
|
inp_start, inp_end = 0, -1
|
|
else:
|
|
init_start, init_end = 0, 1
|
|
inp_start, inp_end = 1, None
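        # For reverse scans the init is taken from the last time step and the xs
        # exclude it; otherwise the init is the first time step and the xs are the rest.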
|
|
|
|
# Regular case
|
|
init = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}],
|
|
),
|
|
}
|
|
inp = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}],
|
|
),
|
|
}
|
|
result = scan_fct(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
init_flat = pytree.tree_leaves(init)
|
|
inp_flat = pytree.tree_leaves(inp)
|
|
self.check_autograd(result, expected_result, (*init_flat, *inp_flat))
|
|
|
|
# Pytree of output is different
|
|
result = scan_fct(
|
|
fct_pointwise_different_output, init, inp, dim=0, reverse=reverse
|
|
)
|
|
expected_result = _fake_scan(
|
|
fct_pointwise_different_output, init=init, xs=inp, dim=0, reverse=reverse
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
# Pytree of carry is different
|
|
init = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1),
|
|
"j": (
|
|
torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1),
|
|
[torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}],
|
|
),
|
|
}
|
|
inp = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}],
|
|
),
|
|
}
|
|
result = scan_fct(
|
|
fct_pointwise_different_carry, init, inp, dim=0, reverse=reverse
|
|
)
|
|
expected_result = _fake_scan(
|
|
fct_pointwise_different_carry, init=init, xs=inp, dim=0, reverse=reverse
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
init_flat = pytree.tree_leaves(init)
|
|
inp_flat = pytree.tree_leaves(inp)
|
|
self.check_autograd(result, expected_result, (*init_flat, *inp_flat))
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@skipIfNoDynamoSupport
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_scan_pytree_output(self):
|
|
x = torch.randn(3, 10, 2, device=torch.device("cpu"))
|
|
init = torch.randn(1, 10, 2, device=torch.device("cpu"))
|
|
|
|
def f(fct, init, xs):
|
|
return scan(fct, init, xs, dim=0, reverse=True)
|
|
|
|
def combine_fn(init, x):
|
|
a, b = (init[0] + x, init[1] - x)
|
|
return (a, b), a - b
|
|
|
|
# Check graph
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(combine_fn, (init, init.clone()), x)
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(gm.print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_init_0_: "f32[1, 10, 2]", L_init_1_: "f32[1, 10, 2]", L_xs_: "f32[3, 10, 2]"):
|
|
l_init_0_ = L_init_0_
|
|
l_init_1_ = L_init_1_
|
|
l_xs_ = L_xs_
|
|
|
|
elem: "f32[3, 10, 2]" = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
|
|
flip: "f32[3, 10, 2]" = torch.flip(elem, [0]); elem = None
|
|
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_0_, l_init_1_], [flip], []); scan_combine_fn_0 = l_init_0_ = l_init_1_ = flip = None
|
|
getitem: "f32[1, 10, 2]" = scan[0]
|
|
getitem_1: "f32[1, 10, 2]" = scan[1]
|
|
out: "f32[3, 1, 10, 2]" = scan[2]; scan = None
|
|
|
|
out_1: "f32[3, 1, 10, 2]" = out.flip([0]); out = None
|
|
return (getitem, getitem_1, out_1)
|
|
|
|
class scan_combine_fn_0(torch.nn.Module):
|
|
def forward(self, child: "f32[1, 10, 2]", child_1: "f32[1, 10, 2]", child_2: "f32[10, 2]"):
|
|
a: "f32[1, 10, 2]" = child + child_2; child = None
|
|
b: "f32[1, 10, 2]" = child_1 - child_2; child_1 = child_2 = None
|
|
|
|
child_3: "f32[1, 10, 2]" = a - b
|
|
return [a, b, child_3]
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_RNN(self, compile_mode, autograd):
|
|
dim = 1
|
|
device = torch.device("cpu")
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
rnn = torch.nn.RNN(
|
|
input_size=5,
|
|
hidden_size=7,
|
|
batch_first=True,
|
|
)
|
|
rnn = rnn.to(device=device)
|
|
x = torch.randn(3, 10, 5, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
W_ih = rnn.weight_ih_l0.T.clone()
|
|
b_ih = rnn.bias_ih_l0.clone()
|
|
W_hh = rnn.weight_hh_l0.T.clone()
|
|
b_hh = rnn.bias_hh_l0.clone()
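        # The RNN cell weights are captured by the combine_fn as closure variables;
        # the scan carry plays the role of the hidden state and the result is compared
        # against torch.nn.RNN below.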
|
|
|
|
if not autograd:
|
|
W_ih = W_ih.detach()
|
|
b_ih = b_ih.detach()
|
|
W_hh = W_hh.detach()
|
|
b_hh = b_hh.detach()
|
|
|
|
expected_result = rnn(x, torch.unsqueeze(h, 0))
|
|
expected_result_out = expected_result[0]
|
|
expected_result_state = expected_result[1][0, :]
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("RNN", True, parameters=[W_ih, b_ih, W_hh, b_hh]),
|
|
h,
|
|
x,
|
|
dim=dim,
|
|
reverse=False,
|
|
)
|
|
result_cmp = [result[0], torch.movedim(result[1], 0, dim)]
|
|
self.assertEqual(result_cmp[0], expected_result_state)
|
|
self.assertEqual(result_cmp[1], expected_result_out)
|
|
|
|
if autograd:
|
|
result_flat = pytree.tree_leaves(result)
|
|
result_exp_flat = [expected_result_state, expected_result_out]
|
|
|
|
grad_out_expected = [torch.ones_like(r) for r in result_exp_flat]
|
|
expected_grads = torch.autograd.grad(
|
|
result_exp_flat,
|
|
(
|
|
h,
|
|
x,
|
|
rnn.weight_ih_l0,
|
|
rnn.bias_ih_l0,
|
|
rnn.weight_hh_l0,
|
|
rnn.bias_hh_l0,
|
|
),
|
|
grad_out_expected,
|
|
)
|
|
expected_add_input_grads = list(expected_grads[2:])
|
|
expected_grads = expected_grads[:2]
|
|
|
|
grad_out = [torch.ones_like(r) for r in result]
|
|
grads = torch.autograd.grad(
|
|
result_flat, (h, x, W_ih, b_ih, W_hh, b_hh), grad_out
|
|
)
|
|
add_input_grads = list(grads[2:])
|
|
add_input_grads[0] = add_input_grads[0].T
|
|
add_input_grads[2] = add_input_grads[2].T
|
|
grads = grads[:2]
|
|
self.assertEqual(grads, expected_grads)
|
|
self.assertEqual(add_input_grads, expected_add_input_grads)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize(
|
|
"partial_grad", ["xs", "init", "additional_inputs", "complex", "random"]
|
|
)
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_scan_closure_RNN_partial_autograd(
|
|
self, reverse, compile_mode, partial_grad, device
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
# The first two booleans are the xs
|
|
# The second two are the inits
|
|
# The last four are the additional_inputs
|
|
autograds = []
|
|
|
|
if partial_grad == "xs":
|
|
# xs tests
|
|
autograds.append([True, False, True, True, True, True, True, True])
|
|
autograds.append([False, False, True, True, True, True, True, True])
|
|
elif partial_grad == "init":
|
|
# init tests
|
|
autograds.append([True, True, False, True, True, True, True, True])
|
|
autograds.append([True, True, False, False, True, True, True, True])
|
|
elif partial_grad == "additional_inputs":
|
|
# additional input tests
|
|
autograds.append([True, True, True, True, False, True, False, True])
|
|
autograds.append([True, True, True, True, False, False, False, False])
|
|
elif partial_grad == "complex":
|
|
# complex cases
|
|
autograds.append([True, False, False, False, False, False, False, True])
|
|
autograds.append([False, False, True, True, False, False, False, True])
|
|
elif partial_grad == "random":
|
|
# random tests
|
|
import random
|
|
|
|
for _ in range(5):
|
|
autograds.append([bool(random.randint(0, 1)) for _ in range(8)])
|
|
|
|
for autograd in autograds:
|
|
x = torch.randn(3, 10, 5, device=device, requires_grad=autograd[0])
|
|
x1 = torch.randn(3, 10, 5, device=device, requires_grad=autograd[1])
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd[2])
|
|
h_1 = torch.randn(3, 7, device=device, requires_grad=autograd[3])
|
|
W_ih = torch.randn(5, 7, device=device, requires_grad=autograd[4])
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd[5])
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd[6])
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd[7])
|
|
|
|
params = [
|
|
p
|
|
for p, a in zip([x, x1, h, h_1, W_ih, b_ih, W_hh, b_hh], autograd)
|
|
if a
|
|
]
|
|
|
|
def RNN(x: torch.Tensor, y: torch.Tensor):
|
|
c_new_0 = x[0] + 1
|
|
c_new_1 = x[1] + 1
|
|
h_new = (
|
|
torch.tanh(c_new_1 + x[0] @ W_hh + b_hh)
|
|
+ y[0] @ W_ih
|
|
+ y[1] @ W_ih
|
|
+ b_ih
|
|
+ x[1]
|
|
)
|
|
return (c_new_0, c_new_1), h_new
|
|
|
|
inits = (h, h_1)
|
|
result = scan_fct(RNN, inits, (x, x1), dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(RNN, (h, h_1), (x, x1), dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
result_flat = pytree.tree_leaves(result)
|
|
result_exp_flat = pytree.tree_leaves(result_exp)
|
|
exp_grad_mask = [
|
|
True if r.requires_grad else False for r in result_exp_flat
|
|
]
|
|
self.check_autograd(
|
|
[r for r, m in zip(result_flat, exp_grad_mask) if m],
|
|
[r for r, m in zip(result_exp_flat, exp_grad_mask) if m],
|
|
params,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_init_carries_unequal_grad(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
h2 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
result_exp = _fake_scan(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
# TODO: Ideally we should be able to select the results that require gradients like this
|
|
# [leaf for leaf in pytree.tree_leaves(result) if leaf.requires_grad == True]
|
|
# However, for the scan operator this does not work, as all outputs always have
|
|
# grad_fn=<ScanAutogradOpBackward>
|
|
res_req_grad_flat = pytree.tree_leaves(result)[1:]
|
|
res_exp_req_grad_flat = pytree.tree_leaves(result_exp)[1:]
|
|
self.check_autograd(res_req_grad_flat, res_exp_req_grad_flat, (x, h2))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_init_carries_equal_grad(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 7, device=device, requires_grad=False)
|
|
h2 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
result_exp = _fake_scan(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
# TODO: Ideally we should be able to select the results that require gradients like this
|
|
# [leaf for leaf in pytree.tree_leaves(result) if leaf.requires_grad == True]
|
|
# However, for the scan operator this does not work, as all outputs always have
|
|
# grad_fn=<ScanAutogradOpBackward>
|
|
res_req_grad_flat = pytree.tree_leaves(result)[1:]
|
|
res_exp_req_grad_flat = pytree.tree_leaves(result_exp)[1:]
|
|
self.check_autograd(res_req_grad_flat, res_exp_req_grad_flat, (x, h2))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_for_out(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
h2 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
def fct_ys_no_grad(x: torch.Tensor, y: torch.Tensor):
|
|
c1 = x[0] + y
|
|
c2 = x[1] + y
|
|
with torch.no_grad():
|
|
h_new = torch.tanh(x[0] + x[1] + y)
|
|
return (c1, c2), h_new
|
|
|
|
result = scan_fct(fct_ys_no_grad, (h1, h2), x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(fct_ys_no_grad, (h1, h2), x, dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
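        # The ys are produced under torch.no_grad, so only the carries require grad;
        # the autograd check below therefore compares result[0] only.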
|
|
|
|
if autograd:
|
|
self.check_autograd(result[0], result_exp[0], (x, h1, h2))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_additional_inputs_partial(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W_ih = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def fct_no_grad_bhh_Whh(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_ih + b_ih + x
|
|
|
|
h_new = c_new + 1
|
|
with torch.no_grad():
|
|
h_new_no_grad = torch.tanh(x @ W_hh + b_hh)
|
|
h_new2 = h_new + h_new_no_grad
|
|
|
|
return c_new, h_new2
|
|
|
|
result = scan_fct(fct_no_grad_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(fct_no_grad_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
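        # W_hh and b_hh only enter the computation inside the no_grad block, so they
        # receive no gradients; the check below differentiates the ys w.r.t.
        # (h, x, W_ih, b_ih) only.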
|
|
|
|
if autograd:
|
|
self.check_autograd(result[1], result_exp[1], (h, x, W_ih, b_ih))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_additional_inputs_all(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W_ih = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def fct_no_grad_bih_Wih_bhh_Whh(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = x + y
|
|
h_new = c_new + x
|
|
with torch.no_grad():
|
|
c_new_no_grad = y @ W_ih + b_ih
|
|
h_new_no_grad = torch.tanh(x @ W_hh + b_hh)
|
|
c_new2 = c_new + c_new_no_grad
|
|
h_new2 = h_new + h_new_no_grad
|
|
return c_new2, h_new2
|
|
|
|
result = scan_fct(fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result[1], result_exp[1], (h, x))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_carries_ys_same_grad(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W_ih = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def fct_no_grad_bih_Wih_bhh_Whh(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = x + y
|
|
h_new = c_new + 1
|
|
with torch.no_grad():
|
|
c_new_no_grad = y @ W_ih + b_ih
|
|
h_new_no_grad = torch.tanh(x @ W_hh + b_hh)
|
|
c_new2 = c_new + c_new_no_grad
|
|
h_new2 = h_new + h_new_no_grad
|
|
return c_new2, h_new2
|
|
|
|
result = scan_fct(fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result[1], result_exp[1], (h, x))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_nested(self, reverse, compile_mode, device, autograd):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
# Simple non-nested case
|
|
x = torch.randn(3, 20, 5, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W = torch.randn(5, 7, device=device, requires_grad=autograd)
|
|
b = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def f1(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W + b
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
result = scan_fct(f1, h, x, dim=1, reverse=reverse)
|
|
result_exp = _fake_scan(f1, h, x, dim=1, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (h, x, W, b))
|
|
|
|
# Nested case
|
|
def chain_fct(fct, f_1, f_2, xs, h_1, h_2):
|
|
o1 = fct(
|
|
f_1,
|
|
h_1,
|
|
xs,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
o2 = fct(
|
|
f_2,
|
|
h_2,
|
|
o1[1],
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
return o2
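        # Chain two scans: the stacked outputs of the first scan (o1[1]) become the
        # xs of the second scan, which runs along a different dim.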
|
|
|
|
x1 = torch.ones(3, 20, 5, device=device, requires_grad=autograd)
|
|
h1 = torch.zeros(3, 7, device=device, requires_grad=autograd)
|
|
h2 = torch.zeros(3, 3, device=device, requires_grad=autograd)
|
|
W_1 = torch.randn(5, 7, device=device, requires_grad=autograd)
|
|
b_1 = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_2 = torch.randn(7, 3, device=device, requires_grad=autograd)
|
|
b_2 = torch.randn(3, device=device, requires_grad=autograd)
|
|
|
|
def f1(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_1 + b_1
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
def f2(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_2 + b_2
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
result1 = chain_fct(scan_fct, f1, f2, x1, h1, h2)
|
|
expected_result = chain_fct(_fake_scan, f1, f2, x1, h1, h2)
|
|
self.assertEqual(result1, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result1, expected_result, (h1, h2, x1, W_1, b_1))
|
|
|
|
# Complex case
|
|
x1 = torch.randn(3, 20, 3, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
h2 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
W_1 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
b_1 = torch.randn(3, device=device, requires_grad=autograd)
|
|
W_2 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
b_2 = torch.randn(3, device=device, requires_grad=autograd)
|
|
|
|
def f1(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_1 + b_1
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
def f2(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_2 + b_2 * b_1 + y @ W_1
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
result1 = chain_fct(scan_fct, f1, f2, x1, h1, h2)
|
|
expected_result = chain_fct(_fake_scan, f1, f2, x1, h1, h2)
|
|
self.assertEqual(result1, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(
|
|
result1, expected_result, (h1, h2, x1, W_1, b_1, W_2, b_2)
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
def test_scan_simple_graph_wrong_dtype(self):
|
|
def add_wrong_dtype(x: torch.Tensor, y: torch.Tensor):
|
|
return torch.ones_like(x + y, dtype=torch.int64), x + y
|
|
|
|
x = torch.randn(3, 10, 2, device=torch.device("cpu"))
|
|
init = torch.randn(1, 10, 2, device=torch.device("cpu"))
|
|
|
|
def f(fct, init, xs):
|
|
return scan(fct, init, xs, dim=0, reverse=True)
|
|
|
|
# Wrong dtype
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected init and carry to have same metadata.*",
|
|
):
|
|
f(add_wrong_dtype, init, x)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_scan_simple_graph(self):
|
|
x = torch.randn(3, 10, 2, device=torch.device("cpu"))
|
|
init = torch.randn(1, 10, 2, device=torch.device("cpu"))
|
|
|
|
def f(fct, init, xs):
|
|
return scan(fct, init, xs, dim=0, reverse=True)
|
|
|
|
# Correct case
|
|
gm = make_fx(f, tracing_mode="symbolic")(
|
|
get_scan_combine_fn("add", False), init, x
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, fct_1, init_1, xs_1):
|
|
permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2])
|
|
flip = torch.ops.aten.flip.default(permute, [0]); permute = None
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
|
|
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
|
|
sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2); xs_1 = None
|
|
scan_combine_graph_0 = self.scan_combine_graph_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [flip], (sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4)); scan_combine_graph_0 = init_1 = flip = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
|
|
getitem = scan[0]
|
|
getitem_1 = scan[1]; scan = None
|
|
flip_1 = torch.ops.aten.flip.default(getitem_1, [0]); getitem_1 = None
|
|
return (getitem, flip_1)""", # noqa: B950
|
|
)
|
|
|
|
# Check graph
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(get_scan_combine_fn("add", False), init, x)
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
|
|
l_init_ = L_init_
|
|
l_xs_ = L_xs_
|
|
elem = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
flip = torch.flip(elem, [0]); elem = None
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [flip], []); scan_combine_fn_0 = l_init_ = flip = None
|
|
carry = scan[0]
|
|
out = scan[1]; scan = None
|
|
out_1 = out.flip([0]); out = None
|
|
return (carry, out_1)""", # noqa: B950
|
|
)
|
|
|
|
@requires_cuda
|
|
def test_scan_input_mutation(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_mutation(x, y):
|
|
x.add_(1)
|
|
return x + y, x + y + 2
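        # The combine_fn above mutates its input in place (x.add_(1)), which scan
        # rejects; the error below is expected during capture.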
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
init = torch.randn(2, 2, device=device)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_input_mutation, init, x, dim=0)
|
|
|
|
@requires_cuda
|
|
def test_scan_input_carry_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_output_alias(x, y):
|
|
return (x[0], x[1] + y[1]), (x[1] + y[1] + 1, x[1] + y[1] + 2)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_input_output_alias, init, inp, dim=0)
|
|
|
|
@requires_cuda
|
|
def test_scan_input_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_output_alias(x, y):
|
|
return (x[0] + 1, x[1] + y[1]), (x[1], x[1] + y[1] + 2)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_input_output_alias, init, inp, dim=0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_scan_carry_carry_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_carry_carry_alias(x, y):
|
|
c = x[0] + y[1]
|
|
return (c, c), (x[0] + y[1], x[0] + y[1] + 1)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_carry_carry_alias, init, inp, dim=0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_scan_carry_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_carry_output_alias(x, y):
|
|
c = x[0] + y[1]
|
|
return (x[0] + y[1], c), (c, x[0] + y[1] + 1)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_carry_output_alias, init, inp, dim=0)
|
|
|
|
|
|
class AssociativeScanModels:
|
|
@staticmethod
|
|
def get_scan_fct(compile_mode, combine_mode):
|
|
# Compile the associative_scan according to the provided compile_mode
|
|
if compile_mode != "fake":
|
|
assoc_scan_comp = compile_mode_helper(associative_scan, compile_mode)
|
|
|
|
def scan_fct(combine_fn, xs, dim, reverse):
|
|
return assoc_scan_comp(combine_fn, xs, dim, reverse, combine_mode)
|
|
|
|
else:
|
|
scan_fct = _fake_associative_scan
|
|
return scan_fct
|
|
|
|
class CombineFn(torch.nn.Module):
|
|
def __init__(self, combine_fn, dim, reverse, combine_mode, compile_mode):
|
|
super().__init__()
|
|
|
|
self.scan_fct = AssociativeScanModels.get_scan_fct(
|
|
compile_mode, combine_mode
|
|
)
|
|
self.combine_fn = combine_fn
|
|
self.dim = dim
|
|
self.reverse = reverse
|
|
|
|
def forward(self, inputs):
|
|
results = self.scan_fct(self.combine_fn, inputs, self.dim, self.reverse)
|
|
return results
|
|
|
|
class Simple(torch.nn.Module):
|
|
def __init__(self, dim, reverse, combine_mode, compile_mode):
|
|
super().__init__()
|
|
|
|
kwargs = {
|
|
"dim": dim,
|
|
"reverse": reverse,
|
|
"combine_mode": combine_mode,
|
|
"compile_mode": compile_mode,
|
|
}
|
|
self.combine_fns = [
|
|
AssociativeScanModels.CombineFn(
|
|
get_scan_combine_fn("add", True), **kwargs
|
|
),
|
|
AssociativeScanModels.CombineFn(
|
|
get_scan_combine_fn("mul", True), **kwargs
|
|
),
|
|
]
|
|
|
|
def forward(self, inputs):
|
|
results = []
|
|
for combine_fn in self.combine_fns:
|
|
results.append(combine_fn(inputs))
|
|
return results
|
|
|
|
class ChainFn(torch.nn.Module):
|
|
def __init__(self, combine_fn, dim, reverse, combine_mode, compile_mode):
|
|
super().__init__()
|
|
|
|
chain_len = len(combine_fn)
|
|
kwargs = {
|
|
"combine_fn": combine_fn,
|
|
"dim": dim,
|
|
"reverse": reverse,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
|
|
# Prepare the kwargs as a list.
|
|
self.nested_tuple = []
|
|
for ind in range(chain_len):
|
|
kwargs_el = {}
|
|
for key, val in kwargs.items():
|
|
# Check if val is a list and if it has the same length as combine_fn
|
|
# If so, then use the individual elements.
|
|
                    # If not, use the same value for every element (see the example below).
|
|
                    if isinstance(val, list) and len(val) == chain_len:
|
|
kwargs_el[key] = val[ind]
|
|
else:
|
|
kwargs_el[key] = val
|
|
|
|
scan_fct = AssociativeScanModels.get_scan_fct(
|
|
compile_mode, kwargs_el["combine_mode"]
|
|
)
|
|
combine_fn = kwargs_el["combine_fn"]
|
|
del kwargs_el["combine_fn"]
|
|
del kwargs_el["combine_mode"]
|
|
self.nested_tuple.append((combine_fn, scan_fct, kwargs_el))
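            # Example of the per-element resolution above (hypothetical values):
            #   kwargs = {"combine_fn": [f, g], "dim": [1, 0],
            #             "reverse": True, "combine_mode": "generic"}
            # yields
            #   nested_tuple[0] == (f, scan_fct, {"dim": 1, "reverse": True})
            #   nested_tuple[1] == (g, scan_fct, {"dim": 0, "reverse": True})
            # since combine_fn and combine_mode are popped before storing.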
|
|
|
|
def forward(self, inputs):
|
|
results = inputs
|
|
for combine_fn, scan_fct, kwargs in self.nested_tuple:
|
|
results = combine_fn(scan_fct, results, **kwargs)
|
|
return results
|
|
|
|
class NestedFn(torch.nn.Module):
|
|
def forward(self, scan_fct, inputs, **kwargs):
|
|
combine_fn = kwargs["combine_fn"]
|
|
|
|
# Remove combine_fn from kwargs
|
|
del kwargs["combine_fn"]
|
|
|
|
results = scan_fct(combine_fn, inputs, **kwargs)
|
|
|
|
return results
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@skipIfNoDynamoSupport
|
|
class AssociativeScanTests(TestCase):
|
|
def setUp(self):
|
|
torch._dynamo.reset()
|
|
super().setUp()
|
|
|
|
def _run_test(self, model, model_fake, inputs):
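        # model_fake is constructed with compile_mode="fake" and therefore runs the
        # eager reference implementation (_fake_associative_scan), so every test
        # compares the higher-order op against a plain Python reference.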
|
|
result = model(inputs)
|
|
result_exp = model_fake(inputs)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
        # Return the result of the functions under test for further investigation
|
|
return result
|
|
|
|
def _prepare_fake_kwargs(self, original_kwargs):
|
|
kwargs_fake = original_kwargs.copy()
|
|
kwargs_fake["compile_mode"] = "fake"
|
|
return kwargs_fake
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_compile(
|
|
self, combine_mode, reverse, compile_mode, device
|
|
):
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
results = self._run_test(
|
|
model=AssociativeScanModels.Simple(**kwargs),
|
|
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
if not reverse:
|
|
results_torch = []
|
|
for op_pt in [torch.cumsum, torch.cumprod]:
|
|
results_torch.append(op_pt(x, 0))
|
|
self.assertEqual(results, results_torch)
|
|
|
|
# Jax Examples
|
|
x = torch.arange(0, 4, device=device)
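        # For xs = [0, 1, 2, 3] an inclusive additive scan gives [0, 1, 3, 6];
        # with reverse=True the scan runs right to left and gives [6, 6, 5, 3].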
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("add", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
result = self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
if not reverse:
|
|
results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
|
|
else:
|
|
results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
|
|
|
|
self.assertEqual(result, results_torch)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_dim(self, combine_mode, compile_mode, reverse, device):
|
|
import random
|
|
|
|
random.seed(1234)
|
|
|
|
num_dims = [random.randint(2, 5) for _ in range(4)]
|
|
for num_dim in num_dims:
|
|
# To avoid triggering automatic dynamic shape
|
|
torch._dynamo.reset()
|
|
shapes = [random.randint(1, 9) for _ in range(num_dim)]
|
|
rnd_scan_dim = random.randint(0, num_dim - 1)
|
|
x = torch.randn(*shapes, device=device)
|
|
|
|
kwargs = {
|
|
"dim": rnd_scan_dim,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
results = self._run_test(
|
|
model=AssociativeScanModels.Simple(**kwargs),
|
|
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
if not reverse:
|
|
results_torch = []
|
|
for op_pt in [torch.cumsum, torch.cumprod]:
|
|
                    results_torch.append(op_pt(x, rnd_scan_dim))
|
|
self.assertEqual(results, results_torch)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_dim_shape_failure(self, compile_mode, combine_mode):
|
|
num_dims = [2]
|
|
for num_dim in num_dims:
|
|
shapes = [9 for _ in range(num_dim)]
|
|
rnd_scan_dim = 0
|
|
x = torch.randn(*shapes, device=torch.device("cuda"))
|
|
|
|
kwargs = {
|
|
"dim": rnd_scan_dim,
|
|
"reverse": True,
|
|
"compile_mode": "compile",
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.Simple(**kwargs),
|
|
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_tuple(self, compile_mode, combine_mode, reverse, device):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("tuple_fct", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_expand_in_combine_fn(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
|
|
def combine_fn(x, y):
|
|
return x * torch.sum(y, -1).expand(x.shape)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_non_contiguous_tensor(
|
|
self, compile_mode, reverse, device
|
|
):
|
|
x = torch.arange(30, device=device).view(10, 3).t()
|
|
assert not x.is_contiguous()
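        # .t() swaps the strides without copying, so the scanned input is a
        # non-contiguous view of the original arange buffer.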
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("add", True),
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_complex_pytree(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
z = torch.randn(3, 2, 2, device=device)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("complex_pointwise", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@skipIfNoDynamoSupport
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_associative_scan_pytree_output(self):
|
|
x = (
|
|
(
|
|
torch.randn(3, 10, 2, device=torch.device("cpu")),
|
|
(torch.randn(3, 10, 2, device=torch.device("cpu")),),
|
|
),
|
|
torch.randn(3, 10, 2, device=torch.device("cpu")),
|
|
)
|
|
|
|
def f(fct, xs):
|
|
return associative_scan(
|
|
fct, xs, dim=0, reverse=True, combine_mode="generic"
|
|
)
|
|
|
|
def combine_fn(x: torch.Tensor, y: torch.Tensor):
|
|
a, b = (x[0][0] + y[1], x[0][1][0] - y[1])
|
|
return (a, (b,)), a - b
|
|
|
|
# Check graph
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(combine_fn, x)
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(gm.print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs_1_: "f32[3, 10, 2]"):
|
|
l_xs_0_0_ = L_xs_0_0_
|
|
l_xs_0_1_0_ = L_xs_0_1_0_
|
|
l_xs_1_ = L_xs_1_
|
|
|
|
elem: "f32[3, 10, 2]" = torch.movedim(l_xs_0_0_, 0, 0); l_xs_0_0_ = None
|
|
elem_1: "f32[3, 10, 2]" = torch.movedim(l_xs_0_1_0_, 0, 0); l_xs_0_1_0_ = None
|
|
elem_2: "f32[3, 10, 2]" = torch.movedim(l_xs_1_, 0, 0); l_xs_1_ = None
|
|
|
|
elem_3: "f32[3, 10, 2]" = torch.flip(elem, [0]); elem = None
|
|
elem_4: "f32[3, 10, 2]" = torch.flip(elem_1, [0]); elem_1 = None
|
|
elem_5: "f32[3, 10, 2]" = torch.flip(elem_2, [0]); elem_2 = None
|
|
|
|
child: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 0, -1, 2)
|
|
child_1: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 0, -1, 2)
|
|
child_2: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 0, -1, 2)
|
|
|
|
child_3: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 1, None, 2)
|
|
child_4: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 1, None, 2)
|
|
child_5: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 1, None, 2)
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
|
_add_batch_dim_1: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_1, 0, 1); child_1 = None
|
|
_add_batch_dim_2: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_2, 0, 1); child_2 = _add_batch_dim_2 = None
|
|
_add_batch_dim_3: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_3, 0, 1); child_3 = _add_batch_dim_3 = None
|
|
_add_batch_dim_4: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_4, 0, 1); child_4 = _add_batch_dim_4 = None
|
|
_add_batch_dim_5: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_5, 0, 1); child_5 = None
|
|
|
|
a: "f32[10, 2]" = _add_batch_dim + _add_batch_dim_5; _add_batch_dim = None
|
|
b: "f32[10, 2]" = _add_batch_dim_1 - _add_batch_dim_5; _add_batch_dim_1 = _add_batch_dim_5 = None
|
|
|
|
child_6: "f32[10, 2]" = a - b
|
|
|
|
child_7: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(a, 1, 1, 0); a = None
|
|
child_8: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(b, 1, 1, 0); b = None
|
|
child_9: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(child_6, 1, 1, 0); child_6 = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
child_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 2, None, 2)
|
|
child_11: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 2, None, 2)
|
|
child_12: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 2, None, 2)
|
|
|
|
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
|
|
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None
|
|
|
|
_add_batch_dim_6: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_7, 0, 1)
|
|
_add_batch_dim_7: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_8, 0, 1)
|
|
_add_batch_dim_8: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_9, 0, 1); _add_batch_dim_8 = None
|
|
_add_batch_dim_9: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_10, 0, 1); child_10 = _add_batch_dim_9 = None
|
|
_add_batch_dim_10: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_11, 0, 1); child_11 = _add_batch_dim_10 = None
|
|
_add_batch_dim_11: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_12, 0, 1); child_12 = None
|
|
|
|
a_1: "f32[10, 2]" = _add_batch_dim_6 + _add_batch_dim_11; _add_batch_dim_6 = None
|
|
b_1: "f32[10, 2]" = _add_batch_dim_7 - _add_batch_dim_11; _add_batch_dim_7 = _add_batch_dim_11 = None
|
|
|
|
child_13: "f32[10, 2]" = a_1 - b_1
|
|
|
|
child_14: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(a_1, 1, 1, 0); a_1 = None
|
|
child_15: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(b_1, 1, 1, 0); b_1 = None
|
|
child_16: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(child_13, 1, 1, 0); child_13 = None
|
|
|
|
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
|
|
slice_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 0, 1); elem_3 = None
|
|
cat: "f32[2, 10, 2]" = torch.cat([slice_10, child_14], dim = 0); slice_10 = child_14 = None
|
|
slice_11: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 0, 1); elem_4 = None
|
|
cat_1: "f32[2, 10, 2]" = torch.cat([slice_11, child_15], dim = 0); slice_11 = child_15 = None
|
|
slice_12: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 0, 1); elem_5 = None
|
|
cat_2: "f32[2, 10, 2]" = torch.cat([slice_12, child_16], dim = 0); slice_12 = child_16 = None
|
|
|
|
b_2: "f32[2, 10, 2]" = torch._C._nn.pad(child_7, [0, 0, 0, 0, 0, 1], 'constant', None); child_7 = None
|
|
|
|
stacked: "f32[2, 2, 10, 2]" = torch.stack([cat, b_2], dim = 1); cat = b_2 = None
|
|
|
|
interleaved: "f32[4, 10, 2]" = torch.flatten(stacked, start_dim = 0, end_dim = 1); stacked = None
|
|
|
|
interleaved_1: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved, 0, 0, 3); interleaved = None
|
|
|
|
b_3: "f32[2, 10, 2]" = torch._C._nn.pad(child_8, [0, 0, 0, 0, 0, 1], 'constant', None); child_8 = None
|
|
|
|
stacked_1: "f32[2, 2, 10, 2]" = torch.stack([cat_1, b_3], dim = 1); cat_1 = b_3 = None
|
|
|
|
interleaved_2: "f32[4, 10, 2]" = torch.flatten(stacked_1, start_dim = 0, end_dim = 1); stacked_1 = None
|
|
|
|
interleaved_3: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved_2, 0, 0, 3); interleaved_2 = None
|
|
|
|
b_4: "f32[2, 10, 2]" = torch._C._nn.pad(child_9, [0, 0, 0, 0, 0, 1], 'constant', None); child_9 = None
|
|
|
|
stacked_2: "f32[2, 2, 10, 2]" = torch.stack([cat_2, b_4], dim = 1); cat_2 = b_4 = None
|
|
|
|
interleaved_4: "f32[4, 10, 2]" = torch.flatten(stacked_2, start_dim = 0, end_dim = 1); stacked_2 = None
|
|
|
|
interleaved_5: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved_4, 0, 0, 3); interleaved_4 = None
|
|
|
|
child_17: "f32[3, 10, 2]" = interleaved_1.flip([0]); interleaved_1 = None
|
|
child_18: "f32[3, 10, 2]" = interleaved_3.flip([0]); interleaved_3 = None
|
|
child_19: "f32[3, 10, 2]" = interleaved_5.flip([0]); interleaved_5 = None
|
|
|
|
movedim_3: "f32[3, 10, 2]" = torch.movedim(child_17, 0, 0); child_17 = None
|
|
movedim_4: "f32[3, 10, 2]" = torch.movedim(child_18, 0, 0); child_18 = None
|
|
movedim_5: "f32[3, 10, 2]" = torch.movedim(child_19, 0, 0); child_19 = None
|
|
return (movedim_3, movedim_4, movedim_5)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_downstream_scan_matmul(
|
|
self, combine_mode, compile_mode, reverse, device
|
|
):
|
|
def first_chain_fct(scan_fct, inp, **kwargs):
|
|
o = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o
|
|
|
|
def second_chain_fct(scan_fct, inp, **kwargs):
|
|
W = torch.ones(2, 5, device=device)
|
|
return inp @ W
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
kwargs = {
|
|
"dim": 1,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": [first_chain_fct, second_chain_fct],
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.ChainFn(**kwargs),
|
|
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_downstream_scan_scan(
|
|
self, combine_mode, compile_mode, reverse, device
|
|
):
|
|
def first_chain_fct(scan_fct, inp, **kwargs):
|
|
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o1
|
|
|
|
def second_chain_fct(scan_fct, inp, **kwargs):
|
|
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o2
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 1,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": [first_chain_fct, second_chain_fct],
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.ChainFn(**kwargs),
|
|
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse_first", [False, True])
|
|
@parametrize("same_direction", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_downstream_scan_scan_different_dim(
|
|
self, combine_mode, compile_mode, reverse_first, same_direction, device
|
|
):
|
|
reverse_second = reverse_first if same_direction else not reverse_first
|
|
|
|
def first_chain_fct(scan_fct, inp, **kwargs):
|
|
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o1
|
|
|
|
def second_chain_fct(scan_fct, inp, **kwargs):
|
|
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o2
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": [1, 0],
|
|
"reverse": [reverse_first, reverse_second],
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": [first_chain_fct, second_chain_fct],
|
|
"combine_mode": [combine_mode, combine_mode],
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.ChainFn(**kwargs),
|
|
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
# TODO: Does not work because of the usage of vmap within associative_scan
|
|
    # TODO: Re-enable the additional parameters once this issue has been resolved
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_nested(self):
|
|
combine_mode = "pointwise"
|
|
compile_mode = "eager"
|
|
reverse_first = False
|
|
same_direction = False
|
|
device = torch.device("cuda")
|
|
|
|
reverse_second = reverse_first if same_direction else not reverse_first
|
|
|
|
def first_nested_fct(x, y):
|
|
y_new = associative_scan(
|
|
second_nested_fct,
|
|
y,
|
|
0,
|
|
reverse=reverse_second,
|
|
combine_mode=combine_mode,
|
|
)
|
|
return x + y_new
|
|
|
|
def first_nested_fct_fake(x, y):
|
|
y_new = _fake_associative_scan(
|
|
second_nested_fct, y, 0, reverse=reverse_second
|
|
)
|
|
return x + y_new
|
|
|
|
def second_nested_fct(x, y):
|
|
return x * y
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse_first,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": first_nested_fct,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
kwargs_fake["combine_fn"] = first_nested_fct_fake
|
|
self._run_test(
|
|
model=AssociativeScanModels.NestedFn(**kwargs),
|
|
model_fake=AssociativeScanModels.NestedFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("loop_type", ["for"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_loop_in_combine_fn(
|
|
self, compile_mode, loop_type, reverse, device
|
|
):
|
|
def combine_fn(x, y):
|
|
cnt = torch.zeros_like(y[0, :])
|
|
if loop_type == "while":
|
|
|
|
def cond_fn(ind, loop_val):
|
|
return (loop_val < 5)[0]
|
|
|
|
def body_fn(ind, loop_val):
|
|
return ind + 1, loop_val + torch.abs(ind)
|
|
|
|
new_ind, cnt = torch.while_loop(
|
|
cond_fn=cond_fn,
|
|
body_fn=body_fn,
|
|
carried_inputs=(
|
|
torch.zeros(1, dtype=torch.int32, device=cnt.device),
|
|
cnt,
|
|
),
|
|
)
|
|
else:
|
|
for ind in range(10):
|
|
cnt += torch.abs(y[ind])
|
|
return x * cnt
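        # With loop_type="for" the Python loop above is simply unrolled at trace
        # time; the "while" branch (not covered by the parametrization above)
        # would instead keep torch.while_loop as a nested HOP inside the combine_fn.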
|
|
|
|
inp = torch.randn(3, 10, 1, device=device) * 2
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
# TODO: Does not work because of the usage of vmap within associative_scan
|
|
    # TODO: Re-enable the additional parameters once this issue has been resolved
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_loop_in_combine_fn_failure(self):
|
|
compile_mode = "none"
|
|
loop_type = "while"
|
|
reverse = False
|
|
device = torch.device("cuda")
|
|
|
|
def combine_fn(x, y):
|
|
_cnt = torch.zeros_like(y[0, :])
|
|
if loop_type == "while":
|
|
|
|
def cond_fn(ind, loop_val):
|
|
return (loop_val < 5)[0]
|
|
|
|
def body_fn(ind, loop_val):
|
|
return ind + 1, loop_val + torch.abs(ind)
|
|
|
|
inp = torch.randn(3, 10, 1, device=device) * 2
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
),
|
|
)
|
|
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
|
|
def combine_fn(x, y):
|
|
val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,))
|
|
return x * val
|
|
|
|
inp = torch.randn(3, 10, 1, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
# TODO: Does not work because of the usage of vmap within associative_scan
|
|
    # TODO: Re-enable the additional parameters once this issue has been resolved
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_map_in_combine_fn(self):
|
|
compile_mode = "none"
|
|
reverse = False
|
|
device = torch.device("cuda")
|
|
|
|
def combine_fn(x, y):
|
|
def body(x, y):
|
|
return x + y
|
|
|
|
y_init = y[0]
|
|
y_new = control_flow.map(body, y, y_init)
|
|
return x * y_new
|
|
|
|
inp = torch.randn(3, 10, 1, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_vmap_in_combine_fn(self, compile_mode, reverse, device):
|
|
def combine_fn(x, y):
|
|
def body(x):
|
|
return x**2
|
|
|
|
mapped_body = torch.vmap(body, 0, 0)
|
|
y_new = mapped_body(y)
|
|
return x + y_new
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of associative_scan and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["device"] == torch.device("cpu")),
|
|
)
|
|
def test_associative_scan_non_pointwise_generic(
|
|
self, reverse, compile_mode, device
|
|
):
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("non_pointwise", True),
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_binary_operator(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
state_dim = 20
|
|
timesteps = 10
|
|
projected_inputs = torch.randn(
|
|
timesteps, state_dim, requires_grad=True, device=device
|
|
)
|
|
A = torch.randn(state_dim, requires_grad=True, device=device)
|
|
elements = (A.repeat((timesteps, 1)), projected_inputs)
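        # The "s5_operator" combine_fn (defined elsewhere in this file) is assumed
        # to implement the standard linear-recurrence combine
        #   (A_i, Bu_i) . (A_j, Bu_j) = (A_j * A_i, A_j * Bu_i + Bu_j),
        # so scanning it over `elements` evaluates h_t = A_t * h_{t-1} + Bu_t.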
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("s5_operator", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=elements,
|
|
)
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_different_input_size(self, compile_mode, reverse, device):
|
|
batch = 5
|
|
hidden_dim = 3
|
|
length = 10
|
|
dstate = 7
|
|
|
|
deltaA = torch.randn(
|
|
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
|
|
)
|
|
deltaB_u = torch.randn(
|
|
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
|
|
)
|
|
C = torch.randn((batch, dstate, length), requires_grad=True, device=device)
|
|
x = torch.randn(
|
|
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
|
|
)
|
|
y = torch.randn((batch, hidden_dim, length), requires_grad=True, device=device)
|
|
elements = (x, deltaA, deltaB_u, C, y)
|
|
|
|
kwargs = {
|
|
"dim": 2,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("different_input_size_operator", True),
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=elements,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_different_input_size_wrong_dim(self):
|
|
batch = 5
|
|
hidden_dim = 3
|
|
length = 10
|
|
dstate = 7
|
|
|
|
deltaA = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
deltaB_u = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
C = torch.randn((batch, dstate, length), device=torch.device("cuda"))
|
|
x = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
y = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
elements = (x, deltaA, deltaB_u, C, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"All xs leaves must at least have.*",
|
|
):
|
|
associative_scan(
|
|
get_scan_combine_fn("different_input_size_operator", True),
|
|
elements,
|
|
3,
|
|
combine_mode="pointwise",
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_simple(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
H = torch.rand(2, device=device)
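        # H (and H1/H2 below) are free variables: they are closed over by the
        # combine_fns rather than scanned, so tracing has to lift them as extra
        # inputs to the combine graph.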
|
|
|
|
def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
|
|
return x * H + y * 2
|
|
|
|
def fct_freevars2(x: torch.Tensor, y: torch.Tensor):
|
|
return x * H + y * H
|
|
|
|
H1 = torch.rand(1, device=device)
|
|
H2 = torch.rand(1, device=device)
|
|
|
|
def fct_freevars3(x: torch.Tensor, y: torch.Tensor):
|
|
return x * H1 + y * H2
|
|
|
|
inp = torch.randn(3, 2, 2, device=device)
|
|
|
|
for fct, param in [
|
|
(fct_freevars1, (H,)),
|
|
(fct_freevars2, (H,)),
|
|
(fct_freevars3, (H1, H2)),
|
|
]:
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_nested(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
H1 = torch.rand(4, 5, device=device)
|
|
H2 = torch.rand(4, 1, device=device)
|
|
|
|
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
|
|
def inner(xi):
|
|
return xi * H2
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor):
|
|
def inner(xi):
|
|
return xi * H2
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
H1_i = torch.rand(4, 5, device=device)
|
|
|
|
# TODO: Using random tensors in the `combine_fn` triggers the vmap randomness error:
|
|
# RuntimeError: vmap: called random operation while in randomness error mode.
|
|
# Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
|
|
def fct_nested_inside(x: torch.Tensor, y: torch.Tensor):
|
|
# H2_i = torch.rand(4, 1, device=device)
|
|
H2_i = torch.ones(4, 1, device=device) * 42
|
|
|
|
def inner(xi):
|
|
return xi * H2_i
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
def fct_nested_inside_fake(x: torch.Tensor, y: torch.Tensor):
|
|
# H2_i = torch.rand(4, 1, device=device)
|
|
H2_i = torch.ones(4, 1, device=device) * 42
|
|
|
|
def inner(xi):
|
|
return xi * H2_i
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
inp = torch.randn(3, 4, 5, device=device)
|
|
|
|
for fct, fct_fake, param in [
|
|
(fct_nested_outside, fct_nested_outside_fake, (H1, H2)),
|
|
(fct_nested_inside, fct_nested_inside_fake, (H1_i,)),
|
|
]:
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
kwargs_fake["combine_fn"] = fct_fake
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_fct(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
def additional_fct_no_add_inp(x, y):
|
|
return x * y
|
|
|
|
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
|
|
ret = additional_fct_no_add_inp(y, y)
|
|
return x + ret
|
|
|
|
inp = torch.randn(3, 4, 5, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_nested_outside,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_freevars_fct_generic(self, compile_mode, reverse, device):
|
|
def additional_fct_no_add_inp(x, y):
|
|
return x * y
|
|
|
|
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
|
|
ret = associative_scan(
|
|
additional_fct_no_add_inp, y, 1, combine_mode="generic"
|
|
)
|
|
return x + ret
|
|
|
|
def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor):
|
|
ret = _fake_associative_scan(additional_fct_no_add_inp, y, 1)
|
|
return x + ret
|
|
|
|
inp = torch.randn(3, 4, 5, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_nested_outside,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
kwargs_fake["combine_fn"] = fct_nested_outside_fake
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_shape_check(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
H = torch.eye(2, device=device, requires_grad=True)
|
|
|
|
def fct_freevars(x: torch.Tensor, y: torch.Tensor):
|
|
return x @ H + y
|
|
|
|
inp = torch.randn(2, 2, 3, device=device, requires_grad=True)
|
|
|
|
kwargs = {
|
|
"dim": 2,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_freevars,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_pytree(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
xf = torch.randn(2, 2, device=device, requires_grad=True)
|
|
yf = torch.randn(2, 2, device=device, requires_grad=True)
|
|
zf = torch.randn(2, 2, device=device, requires_grad=True)
|
|
inpf = {"i": xf, "j": ([yf], [{"o": zf}])}
|
|
|
|
def fct_pointwise(x, y):
|
|
return {
|
|
"i": (x["i"] * y["i"]) + inpf["i"],
|
|
"j": (
|
|
[(x["j"][0][0] * y["j"][0][0]) + inpf["j"][0][0]],
|
|
[
|
|
{
|
|
"o": (x["j"][1][0]["o"] + y["j"][1][0]["o"])
|
|
+ inpf["j"][1][0]["o"]
|
|
}
|
|
],
|
|
),
|
|
}
|
|
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
|
z = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_pointwise,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
def test_associative_scan_sparse_tensor(self):
|
|
x = torch.tensor(
|
|
[[[0.0, 0], [1.0, 2.0]], [[0.0, 0], [3.0, 4.0]], [[0.0, 0], [5.0, 6.0]]]
|
|
).to_sparse()
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"xs leaves must dense Tensors.*",
|
|
):
|
|
associative_scan(
|
|
get_scan_combine_fn("add", True), x, 0, combine_mode="generic"
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_combine_fn_wrong_meta_in_combine_fn(self):
|
|
device = torch.device("cuda")
|
|
B, N, C, H, W = 3, 3, 2, 3, 3
|
|
x = torch.randn(B, N, C, H, W, device=device)
|
|
|
|
def fct_wrong_dtype(x, y):
|
|
return (x + y).to(torch.int64)
|
|
|
|
def fct_wrong_device(x, y):
|
|
return (x + y).to(
|
|
torch.device("cpu") if device.type == "cuda" else torch.device("cuda")
|
|
)
|
|
|
|
def fct_wrong_stride(x, y):
|
|
return (x + y).to(memory_format=torch.channels_last)
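        # Each of the combine_fns above changes the dtype, device, or strides of
        # the result, so the metadata check against the initial xs is expected to
        # fail for all three.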
|
|
|
|
for fct in [fct_wrong_dtype, fct_wrong_device, fct_wrong_stride]:
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected initial_xs and combine_fn_output to have same metadata.*",
|
|
):
|
|
associative_scan(fct, x, 0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
def test_associative_scan_wrong_pytree(self):
|
|
def fct_wrong_pytree(x, y):
|
|
return {
|
|
"i": x["i"] * y["j"][0][0],
|
|
"k": torch.tensor(0.0),
|
|
"j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
|
|
}
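        # The returned dict has an extra "k" key, so its pytree structure no longer
        # matches that of the inputs, which associative_scan rejects.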
|
|
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(3, 2, 2)
|
|
z = torch.randn(3, 2, 2)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError,
|
|
"Combin_fn received wrong number of arguments.*",
|
|
):
|
|
associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic")
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_non_pointwise(self):
|
|
device = torch.device("cuda")
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
with self.assertRaisesRegex(
|
|
# Should be:
|
|
RuntimeError,
|
|
r"For combine_mode='pointwise', the combine_fn needs to be pointwise",
|
|
):
|
|
associative_scan(
|
|
get_scan_combine_fn("non_pointwise", True),
|
|
x,
|
|
0,
|
|
combine_mode="pointwise",
|
|
)
|
|
|
|
@requires_cuda
|
|
def test_associative_scan_input_mutation(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_mutation(x, y):
|
|
x.add_(1)
|
|
return x + y
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"associative_scan must be captured completely with torch.compile.*",
|
|
):
|
|
associative_scan(fct_input_mutation, x, 0)
|
|
|
|
@requires_cuda
|
|
def test_associative_scan_input_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_output_alias(x, y):
|
|
return x[0], x[1] + y[1]
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"associative_scan must be captured completely with torch.compile.*",
|
|
):
|
|
associative_scan(fct_input_output_alias, inp, 0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_output_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_output_output_alias(x, y):
|
|
c = x[0] + y[1]
|
|
return c, c
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"associative_scan must be captured completely with torch.compile.*",
|
|
):
|
|
associative_scan(fct_output_output_alias, inp, 0)
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@skipIfNoDynamoSupport
|
|
class TestControlFlowTraced(TestCase):
|
|
def setUp(self):
|
|
torch._dynamo.reset()
|
|
super().setUp()
|
|
|
|
def _check_tracing(self, fn, args, allow_non_fake_inputs=False):
|
|
graphs = {}
|
|
eager_res = fn(*args)
|
|
for tracing_mode in ["symbolic", "real", "fake"]:
|
|
graph = make_fx(
|
|
fn,
|
|
tracing_mode=tracing_mode,
|
|
_allow_non_fake_inputs=allow_non_fake_inputs,
|
|
)(*args)
|
|
graphs[tracing_mode] = graph
|
|
self.assertEqual(graph(*args), eager_res)
|
|
return graphs
|
|
|
|
def _check_compile(self, fn, args, *, dynamic=False, backend="eager"):
|
|
eager_res = fn(*args)
|
|
compiled_fn = torch.compile(fn, backend=backend, dynamic=dynamic)
|
|
self.assertEqual(compiled_fn(*args), eager_res)
|
|
|
|
def _check_export(self, fn, args, *, strict=False, dynamic_shapes=None):
|
|
eg_out = fn(*args)
|
|
ep = torch.export.export(fn, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
ep_out = ep.module()(*args)
|
|
self.assertEqual(eg_out, ep_out)
|
|
return ep
|
|
|
|
def test_cond_traced_not_nested(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f)(x, torch.tensor(False))
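# The traced graph keeps both branches as submodules, so it can be re-run with either predicate value.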
|
|
result_true = graph.forward(x, torch.tensor(True))
|
|
result_false = graph.forward(x, torch.tensor(False))
|
|
self.assertFalse(torch.allclose(result_true, result_false))
|
|
self.assertEqual(result_true, torch.sin(x))
|
|
self.assertEqual(result_false, torch.cos(x))
|
|
|
|
graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False))
|
|
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
|
|
|
|
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_cond_simple_with_linear_compile_check_graph(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(torch.tensor(False), x)
|
|
self.assertEqual(len(backend.graphs), 2)
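# Two graphs are captured: one for the forward cond call and one for the backward triggered by torch.autograd.grad.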
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
|
|
l_pred_ = L_pred_
|
|
l_x_ = L_x_
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, (l_x_,)); l_pred_ = cond_true_0 = cond_false_0 = l_x_ = None
|
|
result = cond[0]; cond = None
|
|
grad_out = torch.ones_like(result)
|
|
return (result, grad_out)""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_ctx_saved_tensors_0_: "f32[4]", L_ctx_pred: "b8[]", L_args_1_: "f32[4]"):
|
|
l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
|
|
l_ctx_pred = L_ctx_pred
|
|
l_args_1_ = L_args_1_
|
|
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, (l_args_1_, l_ctx_saved_tensors_0_)); l_ctx_pred = cond_true_0 = cond_false_0 = l_args_1_ = l_ctx_saved_tensors_0_ = None
|
|
getitem: "f32[4]" = cond[0]; cond = None
|
|
return (getitem,)
|
|
|
|
class cond_true_0(torch.nn.Module):
|
|
def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"):
|
|
l_args_1__1 = l_args_1_
|
|
l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
|
|
|
|
sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); sin = None
|
|
|
|
cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
|
|
|
|
mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, cos); l_args_1__1 = cos = None
|
|
return (mul,)
|
|
|
|
class cond_false_0(torch.nn.Module):
|
|
def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"):
|
|
l_args_1__1 = l_args_1_
|
|
l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
|
|
|
|
cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); cos = None
|
|
|
|
sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
|
|
|
|
neg: "f32[4]" = torch.ops.aten.neg.default(sin); sin = None
|
|
|
|
mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, neg); l_args_1__1 = neg = None
|
|
return (mul,)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_while_loop_op_mismatch_in_meta(self):
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
def cond_fn(c, a, b):
|
|
return c > 0
|
|
|
|
def body_fn(c, a, b):
|
|
return c - 1, a.nonzero(), b.nonzero()
|
|
|
|
return torch.ops.higher_order.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
(c, a, b),
|
|
tuple(),
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected carried_inputs and body_output to have same metadata but found",
|
|
):
|
|
make_fx(Mod(), tracing_mode="fake")(
|
|
torch.tensor(
|
|
0,
|
|
),
|
|
torch.randn(2, 3),
|
|
torch.randn(2, 3),
|
|
)
|
|
|
|
def test_while_loop_nested_traced(self):
|
|
fn, inp = WHILE_LOOP_TESTS["nested"]
|
|
graphs = self._check_tracing(fn, inp)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, out_iter_1, it_1, y_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (out_iter_1, it_1, y_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = out_iter_1 = it_1 = y_1 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]; while_loop = None
|
|
return (getitem, getitem_1, getitem_2)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 2); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]; while_loop = None
|
|
add = torch.ops.aten.add.Tensor(getitem, 1); getitem = None
|
|
return (add, getitem_1, getitem_2)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_while_loop_pytree_carry(self):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_pytree_carry"]
|
|
backend = EagerAndRecordGraphs()
|
|
expected_res = fn(*inp)
|
|
compiled_res = torch.compile(fn, backend=backend)(*inp)
|
|
self.assertEqual(expected_res, compiled_res)
|
|
# When testing with torch dynamo, the graph is not captured because
|
|
# it's traced together with the code before torch.compile
|
|
if not TEST_WITH_TORCHDYNAMO:
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_pytree_input_1_x_ : torch.Tensor, L_pytree_input_1_y_ : torch.Tensor):
|
|
l_it_ = L_it_
|
|
l_pytree_input_0_0_ = L_pytree_input_0_0_
|
|
l_pytree_input_1_x_ = L_pytree_input_1_x_
|
|
l_pytree_input_1_y_ = L_pytree_input_1_y_
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_it_, l_pytree_input_0_0_, l_pytree_input_1_x_, l_pytree_input_1_y_), ()); cond_fn_0 = body_fn_0 = l_it_ = l_pytree_input_0_0_ = l_pytree_input_1_x_ = l_pytree_input_1_y_ = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
value = while_loop[2]
|
|
value_1 = while_loop[3]; while_loop = None
|
|
return (getitem, getitem_1, value, value_1)""", # noqa: B950
|
|
)
|
|
|
|
def _wrap_with_functionalize(self, fn, func_type):
|
|
mode = None
|
|
if func_type == "cpp":
|
|
fn = CppFunctionalizeAPI().functionalize(fn)
|
|
elif func_type == "python":
|
|
fn = PythonFunctionalizeAPI().functionalize(fn)
|
|
mode = FunctionalTensorMode()
|
|
elif func_type == "functorch":
|
|
fn = torch.func.functionalize(fn)
|
|
else:
|
|
assert func_type == "no"
|
|
return fn, mode
|
|
|
|
@parametrize("func_type", ["no", "cpp", "python", "functorch"])
|
|
def test_while_loop_simple_functionalize_check_graph(self, func_type):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_mutation"]
|
|
fn, mode = self._wrap_with_functionalize(fn, func_type)
|
|
mode = mode if mode is not None else contextlib.nullcontext()
|
|
with mode:
|
|
graphs = self._check_tracing(fn, inp)
|
|
if func_type == "no":
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, x_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
|
|
getitem = while_loop[0]; while_loop = None
|
|
return (getitem,)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None
|
|
add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None
|
|
sum_1 = torch.ops.aten.sum.default(add__1); add__1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None
|
|
add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None
|
|
add = torch.ops.aten.add.Tensor(add__1, 1); add__1 = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
elif func_type == "python":
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = None
|
|
getitem = while_loop[0]; while_loop = None
|
|
return (getitem,)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None
|
|
return (add_2,)
|
|
""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, x_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
|
|
getitem = while_loop[0]; while_loop = None
|
|
return (getitem,)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None
|
|
return (add_2,)
|
|
""",
|
|
)
|
|
|
|
@parametrize("func_type", ["no", "cpp", "python", "functorch"])
|
|
# - "simple_with_linear" and "nested_with_linear" doesn't work because parameters and buffers
|
|
# are not inputs so they're not wrapped by functionalization and tracing.
|
|
#
|
|
# - make_fx tracing mode "real" fails for "int_carry", "pytree_int_carry" and "const_and_symint_output"
|
|
# because the tensors are real but we unspecialize the ints with unbacked symints, causing
|
|
# data dependent errors.
|
|
# Since this is not the common use path, we skip them for now.
|
|
@parametrize(
|
|
"while_loop_test",
|
|
set(WHILE_LOOP_TESTS.keys())
|
|
- {
|
|
"simple_with_linear",
|
|
"nested_with_linear",
|
|
"int_carry",
|
|
"pytree_int_carry",
|
|
"const_and_symint_output",
|
|
},
|
|
)
|
|
def test_while_loop_functionalize(self, func_type, while_loop_test):
|
|
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
|
|
fn, mode = self._wrap_with_functionalize(fn, func_type)
|
|
mode = mode if mode is not None else contextlib.nullcontext()
|
|
with mode:
|
|
self._check_tracing(fn, inp)
|
|
|
|
# - make_fx tracing mode "real" fails for "int_carry", "pytree_int_carry" and "const_and_symint_output"
|
|
# because the tensors are real but we unspecialize the ints with unbacked symints, causing
|
|
# data dependent errors.
|
|
# Since this is not the common use path, we skip them for now.
|
|
@parametrize(
|
|
"while_loop_test",
|
|
set(WHILE_LOOP_TESTS.keys())
|
|
- {"int_carry", "pytree_int_carry", "const_and_symint_output"},
|
|
)
|
|
def test_while_loop_tracing(self, while_loop_test):
|
|
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
|
|
allow_non_fake_inputs = while_loop_test in ("simple_with_linear", "nested_with_linear")
|
|
self._check_tracing(fn, inp, allow_non_fake_inputs)
|
|
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
@parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
|
|
def test_while_loop_compile(self, backend, while_loop_test):
|
|
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
|
|
self._check_compile(fn, inp, backend=backend)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
|
@skipIfCrossRef # Arg order changes with cross ref
|
|
def test_while_loop_simple_with_linear_compile_check_graph(self):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(fn, backend=backend)(*inp)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
gm = backend.graphs[0]
|
|
if torch._dynamo.config.inline_inbuilt_nn_modules:
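# With inlined nn modules the buffer and linear parameters are lifted to graph inputs; otherwise they are read as attributes of the graph module.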
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter):
|
|
l_iter_ = L_iter_
|
|
l_x_ = L_x_
|
|
l_self_buffers_dec_ = L_self_buffers_dec_
|
|
l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
|
|
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]; while_loop = None
|
|
return (getitem, getitem_1)""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.cond_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, child : torch.Tensor, child_1 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
|
|
sub = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None
|
|
gt = sub > 0; sub = None
|
|
return gt""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.body_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, child_2 : torch.Tensor, child_3 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
|
|
child = child_2 - 1; child_2 = None
|
|
child_4 = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
|
|
return (child, child_4)""", # noqa: B950
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor):
|
|
l_iter_ = L_iter_
|
|
l_x_ = L_x_
|
|
l__self___dec = self.L__self___dec
|
|
l__self___linear_weight = self.L__self___linear_weight
|
|
l__self___linear_bias = self.L__self___linear_bias
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]; while_loop = None
|
|
return (getitem, getitem_1)""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.cond_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
|
|
sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None
|
|
gt = sub > 0; sub = None
|
|
return gt""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.body_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
|
|
child = l_iter_ - 1; l_iter_ = None
|
|
child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None
|
|
return (child, child_1)""", # noqa: B950
|
|
)
|
|
|
|
def test_while_loop_nested2_traced(self):
|
|
fn, inp = WHILE_LOOP_TESTS["nested2"]
|
|
graphs = self._check_tracing(fn, inp)
|
|
gm = graphs["symbolic"]
|
|
outer_body = gm.while_loop_body_graph_0
|
|
inner_body = outer_body.while_loop_body_graph_0
|
|
inner_cond = outer_body.while_loop_cond_graph_0
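# The nested while_loop appears as cond/body submodules attached to the outer body graph.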
|
|
self.assertExpectedInline(
|
|
gm.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(arg3_1, 1)
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(arg2_1, 0)
|
|
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 0)
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]
|
|
getitem_3 = while_loop[3]; while_loop = None
|
|
return (getitem, getitem_1, getitem_2, getitem_3)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
outer_body.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (arg7_1, arg7_1, arg7_1, arg7_1)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = arg7_1 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]
|
|
getitem_3 = while_loop[3]; while_loop = None
|
|
sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None
|
|
clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None
|
|
mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None
|
|
div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None
|
|
return (sub, clone, mul, div)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
inner_body.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
|
|
add = torch.ops.aten.add.Tensor(arg2_1, 3.14); arg2_1 = None
|
|
sub_1 = torch.ops.aten.sub.Tensor(arg3_1, 2.71); arg3_1 = None
|
|
return (clone, sub, add, sub_1)
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
inner_cond.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
|
gt = torch.ops.aten.gt.Scalar(arg1_1, 0); arg1_1 = None
|
|
return gt
|
|
""",
|
|
)
|
|
|
|
def test_cond_nested_traced(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(x, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [x])
|
|
return x + z
|
|
|
|
def false_fn(x, _):
|
|
return x.cos()
|
|
|
|
def f(x, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [x, pred2])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
|
|
|
|
result_true_true = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(True)
|
|
) # True + True -> x * x
|
|
result_true_false = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(False)
|
|
) # True + False -> x + x
|
|
result_false_true = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(True)
|
|
) # False + either -> cos
|
|
result_false_false = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
) # False + either -> cos
|
|
|
|
self.assertNotEqual(result_true_true, result_true_false)
|
|
self.assertFalse(torch.allclose(result_false_true, result_true_true))
|
|
|
|
self.assertEqual(result_false_true, result_false_false)
|
|
|
|
self.assertEqual(result_true_true, (x * x) + x)
|
|
self.assertEqual(result_true_false, x + x + x)
|
|
|
|
self.assertEqual(result_false_true, torch.cos(x))
|
|
|
|
graph = make_fx(f, tracing_mode="symbolic")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
self.assertEqual(
|
|
graph(x, torch.tensor(True), torch.tensor(True)),
|
|
f(x, torch.tensor(True), torch.tensor(True)),
|
|
)
|
|
|
|
def test_cond_functionalized(self):
|
|
def true_fn(x):
|
|
y = x.sin()
|
|
y.add_(4)
|
|
return x.sin().max() + y.sum()
|
|
|
|
def false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
all_ops_in_true_branch = []
|
|
for node in graph_module.true_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
all_ops_in_true_branch.append(node.target)
|
|
|
|
self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
|
|
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
def test_cond_accepts_torch_function_as_inputs(self):
|
|
a = torch.randn(3, 4)
|
|
b = torch.randn(3, 4)
|
|
|
|
def f(a, b):
|
|
return cond(a.sum() > 0, torch.add, torch.mul, (a, b))
|
|
|
|
gm = self._check_tracing(f, (a, b))["symbolic"]
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, a_1, b_1):
|
|
sum_1 = torch.ops.aten.sum.default(a_1)
|
|
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
|
|
sym_size_int = torch.ops.aten.sym_size.int(a_1, 1)
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 1)
|
|
sym_size_int_3 = torch.ops.aten.sym_size.int(a_1, 0)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (add,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.false_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (mul,)""",
|
|
)
|
|
|
|
def test_cond_retrace_functionalized(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
def f(x):
|
|
return cond(x.all(), true_fn, false_fn, (x,))
|
|
|
|
inp = torch.ones(1, 2)
|
|
gm_non_functional = make_fx(f, tracing_mode="real")(inp)
|
|
gm_functional = make_fx(
|
|
torch.func.functionalize(gm_non_functional), tracing_mode="real"
|
|
)(inp)
|
|
self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))
|
|
|
|
def test_cond_subgraph_same_shape_env_as_parent(self):
|
|
def true_fn(x):
|
|
return x.sin() + 10
|
|
|
|
def false_fn(x):
|
|
return x.cos() - 20
|
|
|
|
def f(x, pred):
|
|
y = cond(pred, true_fn, false_fn, [x])
|
|
z = torch.add(y, y)
|
|
return z
|
|
|
|
symbolic_traced_graph = self._check_tracing(
|
|
f, (torch.ones(4), torch.Tensor([True]))
|
|
)["symbolic"]
|
|
graph_shape_env = symbolic_traced_graph.shape_env
|
|
|
|
def _node_shape_env_iter(gm):
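# Yield the ShapeEnv recorded on the fake value of every call_function node in gm.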
|
|
for node in gm.graph.nodes:
|
|
if node.op == "call_function":
|
|
val = node.meta.get("val")
|
|
if isinstance(val, tuple):
|
|
for v in val:
|
|
yield v.fake_mode.shape_env
|
|
elif isinstance(val, torch.SymInt):
|
|
yield val.node.shape_env
|
|
else:
|
|
yield val.fake_mode.shape_env
|
|
|
|
for shape_env in _node_shape_env_iter(symbolic_traced_graph):
|
|
self.assertTrue(shape_env is graph_shape_env)
|
|
|
|
for shape_env in _node_shape_env_iter(symbolic_traced_graph.true_graph_0):
|
|
self.assertTrue(shape_env is graph_shape_env)
|
|
|
|
for shape_env in _node_shape_env_iter(symbolic_traced_graph.false_graph_0):
|
|
self.assertTrue(shape_env is graph_shape_env)
|
|
|
|
def test_cond_functionalized_nested(self):
|
|
def true_true_fn(x):
|
|
y = x.cos()
|
|
y.add_(4)
|
|
return x.sin().max() + y.sin().max()
|
|
|
|
def true_false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def true_fn(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_true_fn, true_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return x.sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
gm_true_true_branch = graph_module.true_graph_0.true_graph_0
|
|
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
all_ops = []
|
|
for node in gm_true_true_branch.graph.nodes:
|
|
if node.op == "call_function":
|
|
all_ops.append(node.target)
|
|
|
|
self.assertFalse(any(op._schema.is_mutable for op in all_ops))
|
|
|
|
def test_cond_functionalized_data_dependent_pred(self):
|
|
def true_fn(x):
|
|
return x.sin().sum()
|
|
|
|
def false_fn(x):
|
|
return x.cos().sum()
|
|
|
|
def f(x):
|
|
pred = x.nonzero().shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
def test_cond_functionalized_input_mutation_on_true_branch(self):
|
|
def true_fn(x):
|
|
view_x = x.view(x.shape)
|
|
view_x.add_(1)
|
|
return view_x.sin().sum()
|
|
|
|
def false_fn(x):
|
|
return x.cos().sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
# torch.cond inlines into one of the branches because the predicate
|
|
# is a constant.
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
view = torch.ops.aten.view.default(x_1, [4, 5])
|
|
add = torch.ops.aten.add.Tensor(view, 1); view = None
|
|
view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [4, 5])
|
|
sin = torch.ops.aten.sin.default(view_2); view_2 = None
|
|
sum_1 = torch.ops.aten.sum.default(sin); sin = None
|
|
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None
|
|
return sum_1""",
|
|
)
|
|
|
|
# torch.cond triggers the check of the branches because the predicate
|
|
# is a SymBool.
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"cond_true might be modifying the input!",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_input_mutation_on_false_branch(self):
|
|
def true_fn(x):
|
|
return x.sin().sum()
|
|
|
|
def false_fn(x):
|
|
view_x = x.view(x.shape)
|
|
view_x.add_(1)
|
|
return view_x.cos().sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(5, 5),)
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
# torch.cond inlines into one of the branches because the predicate
|
|
# is a constant.
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
view = torch.ops.aten.view.default(x_1, [5, 5])
|
|
add = torch.ops.aten.add.Tensor(view, 1); view = None
|
|
view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [5, 5])
|
|
cos = torch.ops.aten.cos.default(view_2); view_2 = None
|
|
sum_1 = torch.ops.aten.sum.default(cos); cos = None
|
|
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None
|
|
return sum_1""",
|
|
)
|
|
|
|
# torch.cond triggers the check of the branches because the predicate
|
|
# is a SymBool.
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"cond_false might be modifying the input!",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_output_alias_input(self):
|
|
def true_fn(x):
|
|
return x.clone()
|
|
|
|
def false_fn(x):
|
|
view_x = x.view(x.shape)
|
|
return view_x
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(5, 5),)
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
# torch.cond inlines into one of the branches because the predicate
|
|
# is a constant.
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
|
|
return view""",
|
|
)
|
|
|
|
# torch.cond triggers the check of the branches because the predicate
|
|
# is a SymBool.
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_nested_input_mutation(self):
|
|
def true_true_fn(x):
|
|
x.add_(4)
|
|
return x.sin().max()
|
|
|
|
def true_false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def true_fn(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_true_fn, true_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return x.sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"cond_true might be modifying the input!",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
|
|
def true_true_fn(x):
|
|
x.add_(4)
|
|
return x.sin().max()
|
|
|
|
def true_false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def true_fn(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_true_fn, true_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return x.sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_input = torch.ones(4, 5)
|
|
try:
|
|
example_input_func = to_fun_old(example_input)
|
|
torch._enable_functionalization(reapply_views=False)
|
|
f(example_input_func)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f, tracing_mode="symbolic")(example_input_func)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
def f_wrapper(func):
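# Run func with functionalization enabled for the duration of the call.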
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
return func(*args, **kwargs)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
|
|
|
|
def test_cond_functionalized_input_aliasing_with_aot_func(self):
|
|
def true_fn(x):
|
|
return x
|
|
|
|
def false_fn(x):
|
|
view_x = x.view(x.shape)
|
|
return view_x
|
|
|
|
def f(x):
|
|
pred = x.sum() > 0
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_input = torch.ones(5, 5)
|
|
try:
|
|
example_input_func = to_fun_old(example_input)
|
|
torch._enable_functionalization(reapply_views=False)
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
f(example_input_func)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
func_args = pytree.tree_map(
|
|
lambda x: torch._to_functional_tensor(x)
|
|
if isinstance(x, torch.Tensor)
|
|
else x,
|
|
args,
|
|
)
|
|
func_kwargs = pytree.tree_map(
|
|
lambda x: torch._to_functional_tensor(x)
|
|
if isinstance(x, torch.Tensor)
|
|
else x,
|
|
kwargs,
|
|
)
|
|
return func(*func_args, **func_kwargs)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
|
|
|
def test_cond_functionalized_aot_func_check_functional(self):
|
|
def true_fn(x):
|
|
return x.cos()
|
|
|
|
def false_fn(x):
|
|
y = x.sin()
|
|
y.add_(5)
|
|
return y
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_input = torch.ones(5, 5)
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
func_args = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
args,
|
|
)
|
|
func_kwargs = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
kwargs,
|
|
)
|
|
return pytree.tree_map(
|
|
from_fun_old, func(*func_args, **func_kwargs)
|
|
)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
|
for node in result_gm.true_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
|
|
for node in result_gm.false_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
|
|
self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5)))
|
|
|
|
def test_cond_nested_traced_other_inputs(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(k, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [k])
|
|
return torch.add(torch.tensor([0.25, 0.25]), z)
|
|
|
|
def false_fn(k, _):
|
|
return k.cos()
|
|
|
|
def f(k, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [k, pred2])
|
|
|
|
x = torch.tensor([0.5, 0.5])
|
|
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
|
|
|
|
a = torch.tensor([1.0, 1.0])
|
|
result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
|
|
|
|
b = torch.tensor([2.0, 2.0])
|
|
result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
|
|
|
|
def test_cond_nested_traced_multi(self):
|
|
def true_a(y):
|
|
return y * y
|
|
|
|
def false_a(y):
|
|
return y + y
|
|
|
|
def true_b(y, z):
|
|
return y + z
|
|
|
|
def false_b(y, z):
|
|
return y * z
|
|
|
|
def f(x, pred, pred2):
|
|
a_out = cond(pred, true_a, false_a, [x])
|
|
b_out = cond(pred2, true_b, false_b, [x, x])
|
|
return a_out + b_out
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
|
|
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, x_1, pred_1, pred2_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
|
return add""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graph.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
return (mul,)""",
|
|
)
|
|
|
|
def test_raise_error_on_mismatch_type_size(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return (x, x)
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f)(x, torch.tensor(False))
|
|
|
|
def test_raise_error_on_mismatch_tensor_size(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return torch.zeros([10, 10])
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"When merging two branches' output in torch.cond",
|
|
):
|
|
make_fx(f)(x, torch.tensor(False))
|
|
|
|
def test_cond_traced_not_nested_fake_tensor(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
|
result_true = graph.forward(x, torch.tensor(True))
|
|
result_false = graph.forward(x, torch.tensor(False))
|
|
self.assertFalse(torch.allclose(result_true, result_false))
|
|
self.assertEqual(result_true, torch.sin(x))
|
|
self.assertEqual(result_false, torch.cos(x))
|
|
|
|
def test_cond_nested_traced_fake_tensor(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(x, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [x])
|
|
return x + z
|
|
|
|
def false_fn(x, _):
|
|
return x.cos()
|
|
|
|
def f(x, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [x, pred2])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f, tracing_mode="fake")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
|
|
result_true_true = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(True)
|
|
) # True + True -> x * x
|
|
result_true_false = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(False)
|
|
) # True + False -> x + x
|
|
result_false_true = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(True)
|
|
) # False + either -> cos
|
|
result_false_false = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
) # False + either -> cos
|
|
|
|
self.assertNotEqual(result_true_true, result_true_false)
|
|
self.assertFalse(torch.allclose(result_false_true, result_true_true))
|
|
|
|
self.assertEqual(result_false_true, result_false_false)
|
|
|
|
self.assertEqual(result_true_true, (x * x) + x)
|
|
self.assertEqual(result_true_false, x + x + x)
|
|
|
|
self.assertEqual(result_false_true, torch.cos(x))
|
|
|
|
def test_cond_nested_traced_other_inputs_fake_tensor(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(k, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [k])
|
|
return torch.add(torch.tensor([0.25, 0.25]), z)
|
|
|
|
def false_fn(k, _):
|
|
return k.cos()
|
|
|
|
def f(k, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [k, pred2])
|
|
|
|
x = torch.tensor([0.5, 0.5])
|
|
graph = make_fx(f, tracing_mode="fake")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
|
|
a = torch.tensor([1.0, 1.0])
|
|
result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
|
|
|
|
b = torch.tensor([2.0, 2.0])
|
|
result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
|
|
|
|
def test_cond_nested_traced_multi_fake_tensor(self):
|
|
def true_a(y):
|
|
return y * y
|
|
|
|
def false_a(y):
|
|
return y + y
|
|
|
|
def true_b(y, z):
|
|
return y + z
|
|
|
|
def false_b(y, z):
|
|
return y * z
|
|
|
|
def f(x, pred, pred2):
|
|
a_out = cond(pred, true_a, false_a, [x])
|
|
b_out = cond(pred2, true_b, false_b, [x, x])
|
|
return a_out + b_out
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f, tracing_mode="fake")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, x_1, pred_1, pred2_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
|
return add""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graph.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
return (mul,)""",
|
|
)
|
|
|
|
def test_raise_error_on_mismatch_type_size_fake_tensor(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return (x, x)
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
|
|
|
def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return torch.zeros([10, 10])
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"When merging two branches' output in torch.cond",
|
|
):
|
|
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
|
|
|
def check_map_count(self, gm, op_count):
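# Count the map_impl call_function nodes across all submodules of gm and compare with the expected count.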
|
|
i = 0
|
|
for m in gm.modules():
|
|
for node in m.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.higher_order.map_impl
|
|
):
|
|
i += 1
|
|
self.assertEqual(i, op_count)
|
|
|
|
def test_tracing_map_real(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
def g(xs, y):
|
|
return control_flow.map(f, xs, y)
|
|
|
|
gm = make_fx(g, tracing_mode="real")(torch.ones(3, 2, 2), torch.ones(2))
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(2)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_symbolic_simple(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
def g(xs, y):
|
|
return control_flow.map(f, xs, y)
|
|
|
|
gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 2, 4), torch.ones(4))
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(2)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_symbolic_list(self):
|
|
def f(x, y):
|
|
return [x[0][0] + y, x[1] * y]
|
|
|
|
def g(xs, y, z):
|
|
out = control_flow.map(f, xs, y)
|
|
return out[0] + z, out[1] * z
|
|
|
|
example_x = [[torch.ones(3, 4, 5)], torch.ones(3, 4, 5)]
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
example_x, torch.ones(5), torch.ones(5)
|
|
)
|
|
x = [[torch.randn(4, 5, 6)], torch.ones(4, 5, 6)]
|
|
y = torch.randn(6)
|
|
z = torch.ones(6)
|
|
res = gm(x, y, z)
|
|
self.assertEqual(res, g(x, y, z))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_symbolic_dict(self):
|
|
def f(x, y):
|
|
return {"d": x["b"]["a"] + y, "e": x["c"] * y}
|
|
|
|
def g(xs, y, z):
|
|
out = control_flow.map(f, xs, y)
|
|
return {"f": out["d"] + z, "g": out["e"] * z}
|
|
|
|
example_x = {"b": {"a": torch.ones(3, 4, 5)}, "c": torch.ones(3, 4, 5)}
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
example_x, torch.ones(5), torch.ones(5)
|
|
)
|
|
x = {"b": {"a": torch.randn(4, 5, 6)}, "c": torch.ones(4, 5, 6)}
|
|
y = torch.randn(6)
|
|
z = torch.ones(6)
|
|
res = gm(x, y, z)
|
|
self.assertEqual(res, g(x, y, z))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_autograd_symbolic_simple(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
def g(xs, y):
|
|
out = control_flow.map(f, xs, y)
|
|
return torch.autograd.grad(out, (xs, y), torch.ones_like(out))
|
|
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True)
|
|
)
|
|
x = torch.randn(4, 5, 6, requires_grad=True)
|
|
y = torch.randn(6, requires_grad=True)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 2)
|
|
|
|
def test_tracing_map_autograd_symbolic_list(self):
|
|
import torch.utils._pytree as pytree
|
|
|
|
def f(x, y):
|
|
return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]
|
|
|
|
def g(xs, y):
|
|
out = control_flow.map(f, xs, y)
|
|
flat_out = pytree.tree_leaves(out)
|
|
flat_inp = pytree.tree_leaves((xs, y))
|
|
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
|
|
return torch.autograd.grad(
|
|
flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
|
|
)
|
|
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
[torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
|
|
torch.ones(5, requires_grad=True),
|
|
)
|
|
x = [torch.randn(4, 5, 6), torch.ones(4, 5, 6, requires_grad=True)]
|
|
y = torch.randn(6, requires_grad=True)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 2)
|
|
|
|
def test_tracing_map_autograd_symbolic_dict(self):
|
|
def f(x, y):
|
|
return [x["a"] + y, x["b"] * y]
|
|
|
|
def g(xs, y):
|
|
out = control_flow.map(f, xs, y)
|
|
flat_out = pytree.tree_leaves(out)
|
|
flat_inp = pytree.tree_leaves((xs, y))
|
|
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
|
|
return torch.autograd.grad(
|
|
flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
|
|
)
|
|
|
|
traced_x = {
|
|
"a": torch.ones(3, 4, 5, requires_grad=True),
|
|
"b": torch.ones(3, 4, 5, requires_grad=True),
|
|
}
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
traced_x, torch.ones(5, requires_grad=True)
|
|
)
|
|
x = {
|
|
"a": torch.randn(4, 5, 6, requires_grad=True),
|
|
"b": torch.ones(4, 5, 6, requires_grad=True),
|
|
}
|
|
y = torch.randn(6, requires_grad=True)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 2)
|
|
|
|
def test_tracing_map_autograd_aot_functionalized(self):
|
|
def inner(x, y):
|
|
z = x - 1
|
|
z.add_(1)
|
|
return z * y
|
|
|
|
def f(xs, y):
|
|
res = control_flow.map(inner, xs, y)
|
|
grads = torch.autograd.grad(res, (xs, y), torch.ones_like(res))
|
|
return grads
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
return pytree.tree_map(from_fun_old, func(*args, **kwargs))
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
example_inputs = (
|
|
torch.ones(3, 2, 4, requires_grad=True),
|
|
torch.ones(2, 4, requires_grad=True),
|
|
)
|
|
gm = make_fx(f, tracing_mode="symbolic")(*example_inputs)
|
|
fgm = make_fx(f_wrapper(f), tracing_mode="symbolic")(*example_inputs)
|
|
xs = torch.ones(3, 4, 5, requires_grad=True)
|
|
y = torch.ones(4, 5, requires_grad=True)
|
|
|
|
self.assertEqual(gm(xs, y), f(xs, y))
|
|
|
|
def count_mutable(gm):
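# Recursively count ops with mutable schemas, descending into the body graph of every map_impl node.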
|
|
c = 0
|
|
for node in gm.graph.nodes:
|
|
if node.op == "call_function":
|
|
if node.target == torch.ops.higher_order.map_impl:
|
|
c += count_mutable(getattr(gm, str(node.args[0])))
|
|
elif schema := getattr(node.target, "_schema", None):
|
|
c += int(schema.is_mutable)
|
|
return c
|
|
|
|
self.assertEqual(count_mutable(fgm), 0)
|
|
# One for forward, one for recomputation logic in backward
|
|
self.assertEqual(count_mutable(gm), 2)
|
|
|
|
def test_map_functionalized(self):
|
|
def map_fn(x, y):
|
|
z = x + y
|
|
z.add_(4)
|
|
return z
|
|
|
|
def f(xs, y):
|
|
return control_flow.map(map_fn, xs, y)
|
|
|
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
self.assertEqual(gm(*example_inputs), f(*example_inputs))
|
|
|
|
gm = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
self.assertEqual(gm(*example_inputs), f(*example_inputs))
|
|
|
|
for node in gm.body_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_map_functionalized_aot_func(self):
|
|
def map_fn(x, y):
|
|
z = x + y
|
|
z.add_(4)
|
|
return z
|
|
|
|
def f(xs, y):
|
|
return control_flow.map(map_fn, xs, y)
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
return pytree.tree_map(from_fun_old, func(*args, **kwargs))
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
|
|
|
gm = make_fx(f_wrapper(f))(*example_inputs)
|
|
|
|
for node in gm.body_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
|
|
self.assertEqual(gm(*example_inputs), f(*example_inputs))
|
|
|
|
def test_map_functionalized_arg_mutation(self):
|
|
def map_fn(x, y):
|
|
y.add_(4)
|
|
return x + y
|
|
|
|
def f(xs, y):
|
|
return control_flow.map(map_fn, xs, y)
|
|
|
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"map might be modifying the input!",
|
|
):
|
|
functional_f(*example_inputs)
|
|
|
|
def test_map_functionalized_elem_mutation(self):
|
|
def map_fn(x, y):
|
|
x.add_(4)
|
|
return x + y
|
|
|
|
def f(xs, y):
|
|
return control_flow.map(map_fn, xs, y)
|
|
|
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError, "map might be modifying the input!"
|
|
):
|
|
functional_f(*example_inputs)
|
|
|
|
    def test_cond_autograd_backward(self):
        def true_fn(x):
            return x.cos()

        def false_fn(x):
            return x.sin()

        def f(x, y):
            return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])

        example_inputs = (
            torch.ones(3, 2, 4, requires_grad=True),
            torch.ones(4, requires_grad=True),
        )
        f(*example_inputs).sum().backward()

        # Ensure no error is thrown when not running backward
        res = f(*example_inputs)

        # Ensure no error is thrown when not running backward
        res_compiled = torch.compile(f)(*example_inputs)
        self.assertEqual(res, res_compiled)

    def test_map_functionalized_elem_alias(self):
        def map_fn(x):
            x.view(x.shape)
            return x

        def f(xs):
            return control_flow.map(map_fn, xs)

        example_inputs = (torch.ones(3, 2, 4),)
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(
            # Should be
            # torch._dynamo.exc.Unsupported,
            # "Encountered aliasing during higher order op tracing.*"
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "map doesn't work unless it is captured completely with torch.compile.*",
        ):
            functional_f(*example_inputs)

    def test_nested_map_cond_real(self):
        def true_fn(x, y):
            return x * y

        def false_fn(x, y):
            return x + y

        def f(x, pred, y):
            return cond(pred, true_fn, false_fn, [x, y])

        def g(pred, xs, y):
            return control_flow.map(f, xs, pred, y)

        gm = make_fx(g, tracing_mode="real")(
            torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
        )
        pred = torch.tensor(False)
        x = torch.randn(3, 2, 4)
        y = torch.randn(4)
        res = gm(pred, x, y)
        self.assertEqual(res, g(pred, x, y))
        self.check_map_count(gm, 1)

    def test_nested_map_cond_symbolic(self):
        def true_fn(x, y):
            return x * y

        def false_fn(x, y):
            return x + y

        def f(x, pred, y):
            return cond(pred, true_fn, false_fn, [x, y])

        def g(pred, xs, y):
            return control_flow.map(f, xs, pred, y)

        gm = make_fx(g, tracing_mode="symbolic")(
            torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
        )
        pred = torch.tensor(False)
        x = torch.randn(3, 2, 2)
        y = torch.randn(2)
        res = gm(pred, x, y)
        self.assertEqual(res, g(pred, x, y))
        self.check_map_count(gm, 1)

    def test_nested_cond_map_cond_symbolic(self):
        def true_fn(x, y):
            return x * y

        def false_fn(x, y):
            return x + y

        def f(x, pred, y):
            return cond(pred, true_fn, false_fn, [x, y])

        def g(pred, xs, y):
            return control_flow.map(f, xs, pred, y)

        def main_true_fn(pred, xs, y):
            return g(pred, xs, y) * 2

        def main_false_fn(pred, xs, y):
            return g(pred, xs, y) + 1

        def main(p, pred, xs, y):
            return cond(p, main_true_fn, main_false_fn, [pred, xs, y])

        gm = make_fx(main, tracing_mode="symbolic")(
            torch.tensor(True), torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
        )
        p = torch.tensor(False)
        pred = torch.tensor(False)
        xs = torch.randn(3, 2, 2)
        y = torch.randn(2)
        res = gm(p, pred, xs, y)
        self.assertEqual(res, main(p, pred, xs, y))
        self.check_map_count(gm, 2)

    def test_cond_with_sym_pred(self):
        def true_fn(x):
            return x + x

        def false_fn(x):
            return x * x

        def foo(x):
            return cond(x.shape[0] == 4, true_fn, false_fn, [x])

        gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
        # The symbols in make_fx's shape_env should not be specialized.
        self.assertEqual(len(gm.shape_env.guards), 0)

        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    eq = sym_size_int == 4
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
    getitem = cond[0]; cond = None
    return getitem""",  # noqa: B950
        )

        # We expect the traced graph module to work even if input size changes.
        x = torch.ones(4, 3, 2)
        self.assertEqual(gm(x), true_fn(x))
        self.assertEqual(foo(x), true_fn(x))

    def test_cond_with_unbacked_sym_pred(self):
        def foo(x):
            def true_fn(x):
                return x + x

            def false_fn(x):
                return x * x

            az = x.nonzero()
            return cond(az.shape[0] > 3, true_fn, false_fn, (x,))

        gm = make_fx(foo, tracing_mode="symbolic")(torch.randn(7))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    nonzero = torch.ops.aten.nonzero.default(x_1)
    sym_size_int = torch.ops.aten.sym_size.int(nonzero, 0); nonzero = None
    gt = sym_size_int > 3; sym_size_int = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 0)
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x_1, sym_size_int_1)); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
    getitem = cond[0]; cond = None
    return getitem""",  # noqa: B950
        )

    def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
        assert isinstance(args, (tuple, list))
        self.assertEqual(f(*args), exp_res)
        gm = make_fx(f)(*args)
        self.assertEqual(gm(*args), exp_res)

        def cnt_placeholder(gm):
            return len([node for node in gm.graph.nodes if node.op == "placeholder"])

        placeholder_cnts = [cnt_placeholder(mod) for mod in gm.children()]
        self.assertTrue(all(cnt == exp_arg_num for cnt in placeholder_cnts))

    def _check_closure_correctly_lifted_with_mutation(
        self, f, closures_to_be_mutated, *, args, exp_arg_num
    ):
        exp_res = f(*args)
        self._check_closure_correctly_lifted(
            f, args=args, exp_res=exp_res, exp_arg_num=exp_arg_num
        )

        for closure in closures_to_be_mutated:
            closure.add(-1)
        new_exp_res = f(*args)

        self._check_closure_correctly_lifted(
            f, args=args, exp_res=new_exp_res, exp_arg_num=exp_arg_num
        )

    def test_cond_with_tensor_closure(self):
        a = torch.ones(2, 3)
        b = torch.ones(2, 3) + 1

        def true_fn(x):
            return x + a

        def false_fn(x):
            return x + b

        def foo(x):
            return cond(x.shape[0] == 4, true_fn, false_fn, [x])

        # the expected branches take [x, a, b] as input
        inp = torch.randn(2, 3)
        self._check_closure_correctly_lifted_with_mutation(
            foo, (a, b), args=(inp,), exp_arg_num=3
        )

def test_cond_with_tensor_closure_graph_module(self):
|
|
a = torch.ones(2, 3)
|
|
b = torch.ones(2, 3) + 1
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
        # the expected branches take [x, a, b] as input
|
|
inp = torch.randn(2, 3)
|
|
|
|
gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
eq = sym_size_int == 4
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_tensor_constant0 = self._tensor_constant0
|
|
_tensor_constant1 = self._tensor_constant1
|
|
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (add,)""",
|
|
)
|
|
|
|
def test_cond_with_module_param_closure(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.register_parameter(
|
|
"param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)
|
|
)
|
|
self.buffer = torch.nn.Buffer(torch.ones(2, 3) + 1)
|
|
|
|
my_mode = Mod()
|
|
|
|
def true_fn(x):
|
|
return x + my_mode.param
|
|
|
|
def false_fn(x):
|
|
return x + my_mode.buffer
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inp = torch.ones(2, 3)
|
|
        # both branches are expected to take (x, param, buffer)
|
|
self._check_closure_correctly_lifted_with_mutation(
|
|
foo, (my_mode.param, my_mode.buffer), args=(inp,), exp_arg_num=3
|
|
)
|
|
|
|
def test_cond_with_module_python_scalar_closure(self):
|
|
def foo(x):
|
|
a = torch.ones(1, 1)
|
|
b = 1
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b
|
|
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inp = torch.ones(2, 3)
|
|
res = inp + 1
|
|
# python scalar b is not lifted as input, so both branches take (x, a)
|
|
self._check_closure_correctly_lifted(
|
|
foo, args=(inp,), exp_res=res, exp_arg_num=2
|
|
)
|
|
|
|
def test_cond_nested_with_closure(self):
|
|
a = torch.ones(1, 1)
|
|
b = torch.ones(1, 1) + 1
|
|
|
|
def inner_true_fn(x):
|
|
return x + a
|
|
|
|
def inner_false_fn(x):
|
|
return x + b
|
|
|
|
def foo(x):
|
|
def true_fn(x):
|
|
return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
|
|
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inp = torch.ones(2, 3)
|
|
        # For top-level cond, it takes 3 arguments (x, a, b). Dynamo should
        # realize that the nonlocal variables are the same for the true and false
        # branches, so it should de-dupe them.
        # For second-level conds, it takes (x, a, b)
|
|
self._check_closure_correctly_lifted_with_mutation(
|
|
foo, (a, b), args=(inp,), exp_arg_num=3
|
|
)
|
|
|
|
def test_cond_nested_with_closure_graph_module(self):
|
|
a = torch.ones(1, 1)
|
|
b = torch.ones(1, 1) + 1
|
|
|
|
def inner_true_fn(x):
|
|
return x + a
|
|
|
|
def inner_false_fn(x):
|
|
return x + b
|
|
|
|
def foo(x):
|
|
def true_fn(x):
|
|
return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
|
|
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
def test_map_unfunc_boolean_tensor_for_nested_map_cond(self):
|
|
def map_fn(pred, x):
|
|
def fn(x, pred):
|
|
return control_flow.cond(pred, lambda x: x * 2, lambda x: x / 2, (x,))
|
|
|
|
return control_flow.map(fn, x, pred)
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
func_args = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
args,
|
|
)
|
|
func_kwargs = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
kwargs,
|
|
)
|
|
return pytree.tree_map(
|
|
from_fun_old, func(*func_args, **func_kwargs)
|
|
)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
gm = make_fx(f_wrapper(map_fn))(
|
|
torch.tensor(True), torch.ones([2, 3], requires_grad=False)
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
unbind = torch.ops.aten.unbind.int(x_1)
|
|
getitem = unbind[0]; getitem = None
|
|
getitem_1 = unbind[1]; unbind = getitem_1 = None
|
|
body_graph_0 = self.body_graph_0
|
|
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
|
|
getitem_2 = map_impl[0]; map_impl = None
|
|
return getitem_2""",
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.body_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, (arg0_1,)); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
|
|
def true_fn(x):
|
|
return x + x.cos()
|
|
|
|
def false_fn(x):
|
|
return x * x.sin()
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
|
|
|
|
inp = torch.randn([4, 3])
|
|
gm, _ = torch._dynamo.export(foo)(inp)
|
|
|
|
def run_with_interpreter(*args):
|
|
with torch.fx.traceback.preserve_node_meta():
|
|
return torch.fx.Interpreter(gm).run(*args)
|
|
|
|
new_gm = make_fx(run_with_interpreter)(inp)
|
|
|
|
checked_ops = {"add", "mul", "sin", "cos"}
|
|
checked_meta = ["source_fn_stack", "stack_trace"]
|
|
all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
|
|
new_source_fns = collect_meta_for_filtered_nodes(
|
|
new_gm, checked_ops, checked_meta
|
|
)
|
|
self.assertEqual(all_source_fns, new_source_fns)
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TORCHDYNAMO,
|
|
"triggers cache limit for foo and changes unique_graphs count.",
|
|
)
|
|
def test_cond_no_dynamo_cache_limit(self):
|
|
torch._dynamo.reset()
|
|
counters = torch._dynamo.utils.counters
|
|
counters.clear()
|
|
|
|
def foo(x, true_fn, false_fn):
|
|
return cond(x.sum() < 0, true_fn, false_fn, (x,))
|
|
|
|
inp = torch.ones(3, 4)
|
|
exp_out = inp.sin()
|
|
iter_n = torch._dynamo.config.recompile_limit + 1
|
|
|
|
# Need functions that cause recompilations
|
|
def get_dummy_fns(str):
|
|
def dummy_cos(x):
|
|
return x.cos() + len(str) - len(str)
|
|
|
|
def dummy_sin(x):
|
|
return x.sin() + len(str) - len(str)
|
|
|
|
return dummy_cos, dummy_sin
|
|
|
|
for i in range(iter_n):
|
|
# we fail guards each iter because `str(i)` is different
|
|
self.assertEqual(foo(inp, *get_dummy_fns(str(i))), exp_out)
|
|
|
|
# each iteration captures a cond and a getitem from the tuple output
|
|
self.assertEqual(counters["stats"]["calls_captured"], iter_n * 2)
|
|
self.assertEqual(counters["stats"]["unique_graphs"], iter_n)
|
|
|
|
def test_cond_with_consecutive_make_fx_symbolic(self):
|
|
def true_fn(x):
|
|
return x - x.cos()
|
|
|
|
def false_fn(x):
|
|
return x + x.sin()
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3))
|
|
for inp in inps:
|
|
gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 4))
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
eq = sym_size_int == 4
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
cos = torch.ops.aten.cos.default(arg0_1)
|
|
sub = torch.ops.aten.sub.Tensor(arg0_1, cos); arg0_1 = cos = None
|
|
return (sub,)""",
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
gm.false_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
sin = torch.ops.aten.sin.default(arg0_1)
|
|
add = torch.ops.aten.add.Tensor(arg0_1, sin); arg0_1 = sin = None
|
|
return (add,)""",
|
|
)
|
|
|
|
def _create_test_fns_for_cond(
|
|
self, pred, inner_most_fn, operands, closure_list, nested_level
|
|
):
|
|
if nested_level == 0:
|
|
if len(closure_list) > 0:
|
|
|
|
def true_fn(*operands):
|
|
return inner_most_fn(*operands) + inner_most_fn(*closure_list)
|
|
|
|
def false_fn(*operands):
|
|
return inner_most_fn(*operands) - inner_most_fn(*closure_list)
|
|
|
|
else:
|
|
|
|
def true_fn(*operands):
|
|
return inner_most_fn(*operands)
|
|
|
|
def false_fn(*operands):
|
|
return inner_most_fn(*operands)
|
|
|
|
def fn(*operands):
|
|
if len(operands) == 0 and len(closure_list) == 0:
|
|
return torch.zeros(1)
|
|
return cond(pred, true_fn, false_fn, operands)
|
|
|
|
return operands, fn
|
|
else:
|
|
args, inner_fn = self._create_test_fns_for_cond(
|
|
pred <= 0, inner_most_fn, operands, closure_list, nested_level - 1
|
|
)
|
|
|
|
def true_fn(*operands):
|
|
return inner_most_fn(*operands) + inner_fn(*args)
|
|
|
|
def false_fn(*operands):
|
|
return inner_most_fn(*operands) - inner_fn(*args)
|
|
|
|
def fn(*operands):
|
|
if len(operands) == 0 and len(closure_list) == 0:
|
|
return torch.ones(1)
|
|
return cond(pred, true_fn, false_fn, operands)
|
|
|
|
return operands, fn
|
|
|
|
def _init_predicate(self, pred_type):
|
|
if pred_type == "bool":
|
|
return True
|
|
elif pred_type == "intTensor":
|
|
return torch.tensor(1)
|
|
elif pred_type == "floatTensor":
|
|
return torch.tensor(1.0)
|
|
elif pred_type == "boolTensor":
|
|
return torch.tensor(False)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
def _init_fn(self, inner_fn_type):
|
|
if inner_fn_type == "function":
|
|
return reduce_func
|
|
elif inner_fn_type == "module":
|
|
return ReduceMod()
|
|
elif inner_fn_type == "object":
|
|
return ReduceObj()
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
@parametrize("predType", ["bool", "intTensor", "floatTensor", "boolTensor"])
|
|
@parametrize("innerFnType", ["function", "module", "object"])
|
|
@parametrize("nOperands", [0, 1])
|
|
@parametrize("nClosure", [0, 1])
|
|
@parametrize("nesting", [0, 2])
|
|
def test_cond_tracing_with_valid_inputs(
|
|
self, predType, innerFnType, nOperands, nClosure, nesting
|
|
):
|
|
pred = self._init_predicate(predType)
|
|
inner_fn = self._init_fn(innerFnType)
|
|
operands = [torch.ones(2, 3) + i for i in range(nOperands)]
|
|
closure = [torch.ones(2, 3) - i for i in range(nClosure)]
|
|
args, fn = self._create_test_fns_for_cond(
|
|
pred, inner_fn, operands, closure, nesting
|
|
)
|
|
eager_res = fn(*args)
|
|
for tracing_mode in ["symbolic", "fake", "real"]:
|
|
# set _allow_non_fake_inputs = True to allow fake prop through closures
|
|
with self.subTest(tracing_mode=tracing_mode):
|
|
gm = make_fx(
|
|
fn, tracing_mode=tracing_mode, _allow_non_fake_inputs=True
|
|
)(*args)
|
|
self.assertEqual(gm(*args), eager_res)
|
|
|
|
@parametrize("predType", ["boolTensor"])
|
|
@parametrize("innerFnType", ["function", "module", "object"])
|
|
@parametrize("nOperands", [1, 2])
|
|
@parametrize("nClosure", [0, 1])
|
|
@parametrize("nesting", [0])
|
|
def test_cond_vmap(self, predType, innerFnType, nOperands, nClosure, nesting):
|
|
pred = self._init_predicate(predType)
|
|
inner_fn = self._init_fn(innerFnType)
|
|
operands = [torch.ones(2, 3) + i for i in range(nOperands)]
|
|
closure = [torch.ones(2, 3) - i for i in range(nClosure)]
|
|
args, fn = self._create_test_fns_for_cond(
|
|
pred, inner_fn, operands, closure, nesting
|
|
)
|
|
eager_res = fn(*args)
|
|
out = torch.vmap(fn)(*args)
|
|
if nClosure == 0:
|
|
self.assertEqual(eager_res, out)
|
|
else:
|
|
self.assertEqual(eager_res, out[0])
|
|
self.assertEqual(eager_res, out[1])
|
|
|
|
def test_cond_vmap_simple(self):
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: x + 100,
|
|
false_fn=lambda x: x.clone(),
|
|
operands=(x,),
|
|
)
|
|
|
|
a = torch.arange(15).reshape((3, 5))
|
|
res = torch.vmap(fn, in_dims=(0,))(a)
|
|
self.assertEqual(res.shape, (3, 5))
|
|
self.assertEqual(res, a + 100)
|
|
|
|
def test_cond_vmap_multiple_inputs(self):
|
|
def fn(x, y):
|
|
return torch.cond(
|
|
pred=x.sum() < y.sum(),
|
|
true_fn=lambda x, y: x + 100,
|
|
false_fn=lambda x, y: y.clone(),
|
|
operands=(x, y),
|
|
)
|
|
|
|
a = torch.arange(15).reshape(3, 5)
|
|
b = torch.ones_like(a) + 3
|
|
res = torch.vmap(fn, in_dims=(0, 0))(a, b)
|
|
expected = torch.tensor(
|
|
[[100, 101, 102, 103, 104], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]]
|
|
)
|
|
self.assertEqual(res.shape, (3, 5))
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_cond_vmap_single_input_with_closure(self):
|
|
a = torch.ones((3, 5)) + 3
|
|
c = torch.arange(5)
|
|
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: x + c,
|
|
false_fn=lambda x: x - c,
|
|
operands=(x,),
|
|
)
|
|
|
|
res = torch.vmap(fn, in_dims=(0,))(
|
|
a,
|
|
)
|
|
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
|
res = torch.vmap(fn, in_dims=(0,))(
|
|
a,
|
|
)
|
|
self.assertEqual(a + c, res)
|
|
|
|
def test_cond_vmap_multiple_args_with_closure(self):
|
|
a = torch.ones((3, 5), dtype=torch.int64) + 3
|
|
b = torch.arange(15).reshape(3, 5)
|
|
c = torch.arange(5)
|
|
|
|
def fn(x, y):
|
|
return torch.cond(
|
|
pred=torch.tensor([False]),
|
|
true_fn=lambda x, y: x + c,
|
|
false_fn=lambda x, y: y - c,
|
|
operands=(x, y),
|
|
)
|
|
|
|
res = torch.vmap(fn)(a, b)
|
|
self.assertEqual(b - c, res)
|
|
|
|
@parametrize("nClosure", [0, 1])
|
|
def test_cond_vmap_multiple_outputs(self, nClosure):
|
|
if nClosure:
|
|
c = torch.ones(5, dtype=torch.int64) + 5
|
|
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: (x + c, x - c),
|
|
false_fn=lambda x: (x.clone(), x.clone()),
|
|
operands=(x,),
|
|
)
|
|
|
|
else:
|
|
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: (x + 1, x - 1),
|
|
false_fn=lambda x: (x.clone(), x.clone()),
|
|
operands=(x,),
|
|
)
|
|
|
|
a = torch.arange(15).reshape(3, 5)
|
|
res = torch.vmap(fn)(
|
|
a,
|
|
)
|
|
self.assertEqual(len(res), 2)
|
|
if nClosure:
|
|
self.assertEqual(res, (a + c, a - c))
|
|
else:
|
|
self.assertEqual(res, (a + 1, a - 1))
|
|
|
|
@parametrize("boolcond", [True, False])
|
|
def test_vmap_vmap(self, boolcond):
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]) if not boolcond else True,
|
|
true_fn=lambda x: x + 1,
|
|
false_fn=lambda x: x - 1,
|
|
operands=(x,),
|
|
)
|
|
|
|
def wrapper(x):
|
|
return torch.vmap(fn)(x)
|
|
|
|
a = torch.ones((3, 4, 5))
|
|
res = torch.vmap(wrapper)(a)
|
|
self.assertEqual(res, a + 1)
|
|
|
|
def test_cond_trace_set__and_mutate_input(self):
|
|
def f(a, tmp):
|
|
a_view = a.view(-1)
|
|
with torch.no_grad():
|
|
a.set_(tmp)
|
|
a_view.mul_(2)
|
|
return a + tmp
|
|
|
|
inp = torch.ones(3, 3, requires_grad=True)
|
|
tmp = torch.ones(3, 3, requires_grad=True)
|
|
# graph break: torch._dynamo.exc.Unsupported: call_function DelayGraphBreakVariable() [TensorVariable()] {}
|
|
# due to set_
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
torch.cond(inp.sum() > 0, f, f, (inp, tmp))
|
|
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_cond_trace_set__and_mutate_intermediate(self):
|
|
def f(a, tmp):
|
|
a = a.clone()
|
|
a_view = a.view(-1)
|
|
tmp = tmp.clone()
|
|
with torch.no_grad():
|
|
a.set_(tmp)
|
|
a_view.mul_(2)
|
|
return a + tmp
|
|
|
|
inp = torch.ones(3, 3, requires_grad=True)
|
|
tmp = torch.ones(3, 3, requires_grad=True)
|
|
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
|
|
return torch.cond(inp.sum() > 0, f, f, (inp, tmp))
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "cannot mutate tensors with frozen storage"
|
|
):
|
|
out = torch.compile(Mod(), backend="aot_eager")(inp, tmp)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "cannot mutate tensors with frozen storage"
|
|
):
|
|
out = torch.compile(Mod(), backend="inductor")(inp, tmp)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
out = torch.compile(Mod(), backend=backend)(inp, tmp)
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].cond_true_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, l_inp_, l_tmp_):
|
|
l_inp__1 = l_inp_
|
|
l_tmp__1 = l_tmp_
|
|
a = l_inp__1.clone(); l_inp__1 = None
|
|
a_view = a.view(-1)
|
|
tmp = l_tmp__1.clone(); l_tmp__1 = None
|
|
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
|
set_ = a.set_(tmp); set_ = None
|
|
mul_ = a_view.mul_(2); a_view = mul_ = None
|
|
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
|
add = a + tmp; a = tmp = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
self.assertEqual(out, f(inp, tmp))
|
|
|
|
@skipIfCrossRef # Args get renamed to r in crossref mode
|
|
@parametrize("requires_grad", [True, False])
|
|
def test_cond_symint_operands(self, requires_grad):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.num = 3
|
|
|
|
def forward(self, a, b):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda a, b: a + b + self.num,
|
|
false_fn=lambda a, b: a - b - self.num,
|
|
operands=(a, b),
|
|
)
|
|
|
|
a = torch.ones(3, 3, requires_grad=requires_grad)
|
|
b = torch.ones(3, 3, requires_grad=requires_grad)
|
|
out = torch.compile(Mod(), backend=backend, dynamic=True)(a, b)
|
|
self.assertEqual(out, Mod()(a, b))
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
|
|
l_a_ = L_a_
|
|
l_b_ = L_b_
|
|
tensor = torch.tensor([True])
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""", # noqa: B950
|
|
)
|
|
|
|
def test_two_hops_not_sharing_code_obj(self):
|
|
pred, args = torch.tensor(True), (torch.ones(3, 3),)
|
|
|
|
def fn1(x):
|
|
return x + 1
|
|
|
|
def fn2(x):
|
|
return x - 1
|
|
|
|
from torch._dynamo.testing import CompileCounter
|
|
|
|
# Tests rely on automatic_dynamic = True
|
|
with torch._dynamo.config.patch(automatic_dynamic_shapes=True):
|
|
cnt = CompileCounter()
|
|
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
args = (torch.randn(3, 3),)
|
|
# No recompilation
|
|
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def cond_fn(x):
|
|
return x.sum() > 0
|
|
|
|
args = (torch.randn(4, 4),)
|
|
torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
|
|
# recompilation
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
args = (torch.randn(4, 4),)
|
|
torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
# With recompilation due to automatic dynamic
|
|
# This also proves that while_loop doesn't share code obj with cond
|
|
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),))
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
def test_hop_raises_if_not_overriding_call(self):
|
|
class WrongHop(torch._ops.HigherOrderOperator):
|
|
pass
|
|
|
|
with self.assertRaisesRegex(TypeError, "WrongHop"):
|
|
WrongHop("wrong_hop")
|
|
|
|
def test_scan_functionalized(self):
|
|
def f(init, xs):
|
|
return scan(get_scan_combine_fn("add", False), init, xs, dim=1)
|
|
|
|
example_inputs = torch.ones(5, 7, 4)
|
|
example_init = torch.ones(5, 4)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(
|
|
functional_f(example_init, example_inputs), f(example_init, example_inputs)
|
|
)
|
|
|
|
def test_scan_functionalized_elem_mutation(self):
|
|
def add1(x, y):
|
|
x.add_(4)
|
|
return x + y, x + y
|
|
|
|
def f(init, xs):
|
|
return scan(add1, init, xs, dim=1)
|
|
|
|
example_inputs = torch.ones(5, 7, 4)
|
|
example_init = torch.ones(5, 4)
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
|
# RuntimeError,
|
|
# "torch.scan might be modifying the input!",
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
|
# torch._dynamo.exc.TorchDynamoException,
|
|
# "Unexpected exception when running generated GraphModule.*"
|
|
Exception,
|
|
".*",
|
|
):
|
|
functional_f(example_init, example_inputs)
|
|
|
|
def add2(x, y):
|
|
y.add_(4)
|
|
return x + y, x + y
|
|
|
|
def f(init, xs):
|
|
return scan(add2, init, xs, dim=1)
|
|
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
|
# Should be
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
|
# RuntimeError,
|
|
# "torch.scan might be modifying the input!",
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
|
# torch._dynamo.exc.TorchDynamoException,
|
|
# "Unexpected exception when running generated GraphModule.*"
|
|
Exception,
|
|
".*",
|
|
):
|
|
functional_f(example_init, example_inputs)
|
|
|
|
def test_scan_functionalized_elem_alias(self):
|
|
def add(x, y):
|
|
return x, x
|
|
|
|
def f(init, xs):
|
|
return scan(add, init, xs, dim=1)
|
|
|
|
example_inputs = torch.ones(5, 7, 4)
|
|
example_init = torch.ones(5, 4)
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
|
# Should be
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
|
# torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
# "scan must be captured completely with torch.compile.*",
|
|
Exception,
|
|
".*",
|
|
):
|
|
functional_f(example_init, example_inputs)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
|
def test_scan_pytree_closure(self):
|
|
param_buffer = ({"param": torch.randn(3, 3)}, (torch.randn(3),))
|
|
|
|
def add(carry, x):
|
|
ret = (carry @ param_buffer[0]["param"]) @ x + param_buffer[1][0]
|
|
return ret, ret.sum()
|
|
|
|
def f(init, xs):
|
|
return scan(add, init, xs)
|
|
|
|
init = torch.randn(4, 3)
|
|
xs = torch.randn(3, 3, 3)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
eager_out = f(init, xs)
|
|
compiled_out = torch.compile(f, backend=backend)(init, xs)
|
|
exp_out = _fake_scan(add, init, xs)
|
|
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
if TEST_WITH_CROSSREF:
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_cell_contents_0_param_ : torch.Tensor, L_add_closure_0_cell_contents_1_0_ : torch.Tensor):
|
|
l_init_ = L_init_
|
|
l_xs_ = L_xs_
|
|
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
|
|
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
|
|
r = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [r], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = r = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
|
carry = scan[0]
|
|
out = scan[1]; scan = None
|
|
return (carry, out)""", # noqa: B950
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_cell_contents_0_param_ : torch.Tensor, L_add_closure_0_cell_contents_1_0_ : torch.Tensor):
|
|
l_init_ = L_init_
|
|
l_xs_ = L_xs_
|
|
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
|
|
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
|
|
movedim = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [movedim], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = movedim = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
|
carry = scan[0]
|
|
out = scan[1]; scan = None
|
|
return (carry, out)""", # noqa: B950
|
|
)
|
|
self.assertEqual(eager_out, exp_out)
|
|
self.assertEqual(compiled_out, exp_out)
|
|
|
|
@skipIfTorchDynamo("Skip because we're testing export")
|
|
@parametrize("strict", [True, False])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_op_int_carry_export(self, strict, dynamic):
|
|
m, args = WHILE_LOOP_TESTS["int_carry"]
|
|
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}} if dynamic else None
|
|
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
if not strict and dynamic:
|
|
self.assertExpectedInline(
|
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, x):
|
|
x: "f32[s77, 3]";
|
|
|
|
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
|
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
|
|
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (0, x), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x = None
|
|
|
|
getitem_2: "Sym(u1)" = while_loop[0]
|
|
|
|
ge: "Sym(u1 >= 1)" = getitem_2 >= 1
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 1 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
gt_1: "Sym(u1 > 0)" = getitem_2 > 0
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
|
|
|
|
getitem_1: "f32[s77, 3]" = while_loop[1]; while_loop = None
|
|
|
|
add: "Sym(u1 + 1)" = getitem_2 + 1
|
|
|
|
add_1: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
|
|
|
|
lt: "Sym(u1 < s77)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
|
|
|
|
mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None
|
|
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
|
|
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
|
|
|
|
class while_loop_cond_graph_0(torch.nn.Module):
|
|
def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"):
|
|
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
|
|
|
lt: "Sym(u0 < s77)" = it_1 < sym_size_int_1; it_1 = sym_size_int_1 = None
|
|
return lt
|
|
|
|
class while_loop_body_graph_0(torch.nn.Module):
|
|
def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"):
|
|
clone: "f32[s77, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
|
|
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
|
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
|
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
|
|
copy_: "f32[3]" = torch.ops.aten.copy_.default(select, add); select = add = copy_ = None
|
|
add_1: "Sym(u0 + 1)" = it_1 + 1; it_1 = None
|
|
return (add_1, clone)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_while_loop_op_int_carry_compile(self, dynamic, backend):
|
|
m, args = WHILE_LOOP_TESTS["int_carry"]
|
|
if backend == "eager":
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=dynamic, backend=backend)
|
|
if (
|
|
isinstance(backend, EagerAndRecordGraphs)
|
|
and dynamic
|
|
and not TEST_WITH_CROSSREF
|
|
):
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
|
|
l_x_ = L_x_
|
|
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None
|
|
|
|
getitem_4: "Sym(u2)" = while_loop[0]
|
|
|
|
ge: "Sym(u2 >= 1)" = getitem_4 >= 1
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 1 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
gt_1: "Sym(u2 > 0)" = getitem_4 > 0
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u2 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
|
|
|
|
out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None
|
|
|
|
gt: "Sym(u2 > 0)" = getitem_4 > 0
|
|
_check = torch._check(gt); gt = _check = None
|
|
|
|
add: "Sym(u2 + 1)" = getitem_4 + 1
|
|
|
|
add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None
|
|
|
|
lt: "Sym(u2 < s77)" = getitem_4 < s77; s77 = None
|
|
|
|
mul: "Sym(2*u2)" = getitem_4 * 2; getitem_4 = None
|
|
ones: "f32[2*u2]" = torch.ones(mul); mul = None
|
|
return (add, add_1, lt, ones)
|
|
|
|
class cond_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint: "Sym(u0)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
size = child.size(); child = None
|
|
getitem: "Sym(s77)" = size[0]
|
|
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
|
|
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None
|
|
return lt
|
|
|
|
class body_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint_0: "Sym(u1)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
x_clone: "f32[s77, s27]" = child_1.clone()
|
|
|
|
ge: "Sym(u1 >= 0)" = unbacked_symint_0 >= 0
|
|
_check = torch._check(ge); ge = _check = None
|
|
|
|
size = child_1.size(); child_1 = None
|
|
getitem: "Sym(s77)" = size[0]
|
|
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
|
|
lt: "Sym(u1 < s77)" = unbacked_symint_0 < getitem; getitem = None
|
|
_check_1 = torch._check(lt); lt = _check_1 = None
|
|
|
|
select: "f32[s27]" = x_clone.select(0, unbacked_symint_0)
|
|
select_1: "f32[s27]" = x_clone.select(0, unbacked_symint_0)
|
|
add: "f32[s27]" = select_1 + unbacked_symint_0; select_1 = None
|
|
copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None
|
|
|
|
add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None
|
|
return (add_1, x_clone)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip because we're testing export")
|
|
@parametrize("strict", [True, False])
|
|
@parametrize("dynamic", [True, False])
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_while_loop_op_constant_and_symint_output_export(self, strict, dynamic):
|
|
m, args = WHILE_LOOP_TESTS["const_and_symint_output"]
|
|
dynamic_shapes = {"t": {0: torch.export.Dim("dim_t")}} if dynamic else None
|
|
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
# strict or dynamic gives a slightly different graph
|
|
if not strict and not dynamic:
|
|
self.assertExpectedInline(
|
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, t):
|
|
t: "f32[2, 3]";
|
|
|
|
t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec)
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(t)
|
|
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(sum_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
|
|
to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None
|
|
item: "Sym(u0)" = torch.ops.aten.item.default(to); to = None
|
|
sin: "f32[2, 3]" = torch.ops.aten.sin.default(t)
|
|
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (2, 3, 1, 1, 1, 3, item, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = item = sin = None
|
|
|
|
getitem_8: "Sym(u8)" = while_loop[0]
|
|
getitem_9: "Sym(u9)" = while_loop[1]
|
|
getitem_10: "Sym(u10)" = while_loop[2]
|
|
getitem_11: "Sym(u11)" = while_loop[3]
|
|
getitem_12: "Sym(u12)" = while_loop[4]
|
|
getitem_13: "Sym(u13)" = while_loop[5]
|
|
getitem_14: "Sym(u14)" = while_loop[6]
|
|
|
|
getitem_7: "f32[2, 3]" = while_loop[7]; while_loop = None
|
|
|
|
add: "Sym(u8 + 1)" = getitem_8 + 1
|
|
add_1: "Sym(u9 + 1)" = getitem_9 + 1
|
|
add_2: "Sym(u10 + 1)" = getitem_10 + 1
|
|
add_3: "Sym(u11 + 1)" = getitem_11 + 1
|
|
add_4: "Sym(u12 + 1)" = getitem_12 + 1
|
|
add_5: "Sym(u13 + 1)" = getitem_13 + 1
|
|
add_6: "Sym(u14 + 1)" = getitem_14 + 1
|
|
add_7: "f32[2, 3]" = torch.ops.aten.add.Tensor(getitem_7, 1)
|
|
|
|
add_8: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_8); getitem_8 = None
|
|
add_9: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_9); getitem_9 = None
|
|
add_10: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_10); getitem_10 = None
|
|
add_11: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_11); getitem_11 = None
|
|
add_12: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_12); getitem_12 = None
|
|
add_13: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_13); getitem_13 = None
|
|
add_14: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_14); getitem_14 = None
|
|
add_15: "f32[2, 3]" = torch.ops.aten.add.Tensor(getitem_7, t); getitem_7 = t = None
|
|
return pytree.tree_unflatten((add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15), self._out_spec)
|
|
|
|
class while_loop_cond_graph_0(torch.nn.Module):
|
|
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[2, 3]"):
|
|
mul: "Sym(u3*u4)" = c1_1 * c2_1; c1_1 = c2_1 = None
|
|
mul_1: "Sym(u3*u4*u5)" = mul * c3_1; mul = c3_1 = None
|
|
mul_2: "Sym(u1*u2)" = a_1 * b_1; a_1 = b_1 = None
|
|
lt: "Sym(u3*u4*u5 < u1*u2)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class while_loop_body_graph_0(torch.nn.Module):
|
|
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[2, 3]"):
|
|
add: "Sym(u7 + 1)" = u0_1 + 1; u0_1 = None
|
|
add_1: "f32[2, 3]" = torch.ops.aten.add.Tensor(x_1, 1); x_1 = None
|
|
return (b_1, c1_1, c2_1, c3_1, a_1, 0, add, add_1)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_while_loop_op_constant_and_symint_output_compile(self, dynamic, backend):
|
|
m, args = WHILE_LOOP_TESTS["const_and_symint_output"]
|
|
if backend == "eager":
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=dynamic, backend=backend)
|
|
if (
|
|
isinstance(backend, EagerAndRecordGraphs)
|
|
# cross ref or dynamic gives a slightly different graph
|
|
and not dynamic
|
|
and not TEST_WITH_CROSSREF
|
|
):
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_t_: "f32[2, 3]"):
|
|
l_t_ = L_t_
|
|
|
|
sum_1: "f32[]" = l_t_.sum()
|
|
to: "i64[]" = sum_1.to(torch.int64); sum_1 = None
|
|
item: "Sym(u0)" = to.item(); to = None
|
|
sin: "f32[2, 3]" = l_t_.sin()
|
|
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (2, 3, 1, 1, 1, 3, item, sin), ()); cond_fn_0 = body_fn_0 = item = sin = None
|
|
|
|
getitem_8: "Sym(u15)" = while_loop[0]
|
|
getitem_9: "Sym(u16)" = while_loop[1]
|
|
getitem_10: "Sym(u17)" = while_loop[2]
|
|
getitem_11: "Sym(u18)" = while_loop[3]
|
|
getitem_12: "Sym(u19)" = while_loop[4]
|
|
getitem_13: "Sym(u20)" = while_loop[5]
|
|
getitem_14: "Sym(u21)" = while_loop[6]
|
|
|
|
child: "f32[2, 3]" = while_loop[7]; while_loop = None
|
|
|
|
add: "Sym(u15 + 1)" = getitem_8 + 1
|
|
add_1: "Sym(u16 + 1)" = getitem_9 + 1
|
|
add_2: "Sym(u17 + 1)" = getitem_10 + 1
|
|
add_3: "Sym(u18 + 1)" = getitem_11 + 1
|
|
add_4: "Sym(u19 + 1)" = getitem_12 + 1
|
|
add_5: "Sym(u20 + 1)" = getitem_13 + 1
|
|
add_6: "Sym(u21 + 1)" = getitem_14 + 1
|
|
add_7: "f32[2, 3]" = child + 1
|
|
|
|
add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None
|
|
add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None
|
|
add_10: "f32[2, 3]" = getitem_10 + l_t_; getitem_10 = None
|
|
add_11: "f32[2, 3]" = getitem_11 + l_t_; getitem_11 = None
|
|
add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None
|
|
add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None
|
|
add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None
|
|
add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None
|
|
return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15)
|
|
|
|
class cond_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unbacked_symint_1: "Sym(u3)", unbacked_symint_2: "Sym(u4)", unbacked_symint_3: "Sym(u5)", unbacked_symint_4: "Sym(u6)", unbacked_symint_5: "Sym(u7)", child: "f32[2, 3]"):
|
|
mul: "Sym(u3*u4)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
|
|
mul_1: "Sym(u3*u4*u5)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
|
|
mul_2: "Sym(u1*u2)" = unbacked_symint * unbacked_symint_0; unbacked_symint = unbacked_symint_0 = None
|
|
lt: "Sym(u3*u4*u5 < u1*u2)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class body_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint_6: "Sym(u8)", unbacked_symint_7: "Sym(u9)", unbacked_symint_8: "Sym(u10)", unbacked_symint_9: "Sym(u11)", unbacked_symint_10: "Sym(u12)", unbacked_symint_11: "Sym(u13)", unbacked_symint_12: "Sym(u14)", child_1: "f32[2, 3]"):
|
|
add: "Sym(u14 + 1)" = unbacked_symint_12 + 1; unbacked_symint_12 = None
|
|
child: "f32[2, 3]" = child_1 + 1; child_1 = None
|
|
return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, child)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip because we're testing export")
|
|
@parametrize("strict", [True, False])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_op_pytree_int_carry_export(self, strict, dynamic):
|
|
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
|
|
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}} if dynamic else None
|
|
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
if strict and dynamic:
|
|
self.assertExpectedInline(
|
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, x):
|
|
x: "f32[s77, 3]";
|
|
|
|
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
|
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
|
|
|
|
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
|
|
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (sym_size_int_1, 3, 2, 2, 3, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = sym_size_int_1 = sin = None
|
|
|
|
getitem_6: "Sym(u10)" = while_loop[0]
|
|
getitem_7: "Sym(u11)" = while_loop[1]
|
|
getitem_8: "Sym(u12)" = while_loop[2]
|
|
getitem_9: "Sym(u13)" = while_loop[3]
|
|
getitem_10: "Sym(u14)" = while_loop[4]
|
|
|
|
getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
|
|
|
|
add: "Sym(u12 + 1)" = getitem_8 + 1
|
|
add_1: "Sym(u13 + 1)" = getitem_9 + 1
|
|
add_2: "Sym(u14 + 1)" = getitem_10 + 1
|
|
|
|
add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
|
|
add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
|
|
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
|
|
return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
|
|
|
|
class while_loop_cond_graph_0(torch.nn.Module):
|
|
def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
|
|
mul: "Sym(u22*u23)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
|
|
mul_1: "Sym(u22*u23*u24)" = mul * arg4_1; mul = arg4_1 = None
|
|
mul_2: "Sym(u20*u21)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
|
|
lt: "Sym(u22*u23*u24 < u20*u21)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class while_loop_body_graph_0(torch.nn.Module):
|
|
def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
|
|
add: "Sym(u20 + 1)" = arg0_1 + 1; arg0_1 = None
|
|
add_1: "Sym(u21 + 1)" = arg1_1 + 1; arg1_1 = None
|
|
|
|
add_2: "Sym(u22 + 1)" = arg2_1 + 1; arg2_1 = None
|
|
add_3: "Sym(u23 + 1)" = arg3_1 + 1; arg3_1 = None
|
|
add_4: "Sym(u24 + 1)" = arg4_1 + 1; arg4_1 = None
|
|
|
|
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
|
|
return (add, add_1, add_2, add_3, add_4, add_5)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_while_loop_op_pytree_int_carry_compile(self, dynamic, backend):
|
|
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
|
|
if backend == "eager":
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=dynamic, backend=backend)
|
|
if (
|
|
isinstance(backend, EagerAndRecordGraphs)
|
|
and dynamic
|
|
and not TEST_WITH_CROSSREF
|
|
):
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
|
|
l_x_ = L_x_
|
|
|
|
child: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
|
|
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None
|
|
|
|
getitem_10: "Sym(u10)" = while_loop[0]
|
|
getitem_11: "Sym(u11)" = while_loop[1]
|
|
getitem_12: "Sym(u12)" = while_loop[2]
|
|
getitem_13: "Sym(u13)" = while_loop[3]
|
|
getitem_14: "Sym(u14)" = while_loop[4]
|
|
|
|
out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None
|
|
|
|
add: "Sym(u12 + 1)" = getitem_12 + 1
|
|
add_1: "Sym(u13 + 1)" = getitem_13 + 1
|
|
add_2: "Sym(u14 + 1)" = getitem_14 + 1
|
|
|
|
add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None
|
|
add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None
|
|
add_5: "f32[s77, s27]" = getitem_14 + out_x; getitem_14 = None
|
|
return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x)
|
|
|
|
class cond_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
|
|
mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
|
|
mul_2: "Sym(u0*u1)" = unbacked_symint * unbacked_symint_0; unbacked_symint = unbacked_symint_0 = None
|
|
lt: "Sym(u2*u3*u4 < u0*u1)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class body_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint_4: "Sym(u5)", unbacked_symint_5: "Sym(u6)", unbacked_symint_6: "Sym(u7)", unbacked_symint_7: "Sym(u8)", unbacked_symint_8: "Sym(u9)", child_2: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
add: "Sym(u5 + 1)" = unbacked_symint_4 + 1; unbacked_symint_4 = None
|
|
add_1: "Sym(u6 + 1)" = unbacked_symint_5 + 1; unbacked_symint_5 = None
|
|
|
|
add_2: "Sym(u7 + 1)" = unbacked_symint_6 + 1; unbacked_symint_6 = None
|
|
add_3: "Sym(u8 + 1)" = unbacked_symint_7 + 1; unbacked_symint_7 = None
|
|
add_4: "Sym(u9 + 1)" = unbacked_symint_8 + 1; unbacked_symint_8 = None
|
|
|
|
child: "f32[s77, s27]" = child_2 + 1; child_2 = None
|
|
return (add, add_1, add_2, add_3, add_4, child)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_input_output_alias(self):
|
|
def fn(f, *args):
|
|
return torch.cond(args[0].sum() > 0, f, f, args)
|
|
|
|
x = torch.randn(2, 2)
|
|
for f in ALIAS_FN:
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
torch.compile(fn)(f, x)
|
|
|
|
def test_input_input_alias(self):
|
|
def fn(view_f, arg):
|
|
def f(arg1, arg2):
|
|
return arg1.cos(), arg2.sin()
|
|
|
|
return torch.cond(arg.sum() > 0, f, f, (arg, view_f(arg)))
|
|
|
|
x = torch.randn(2, 2)
|
|
        # ALIAS_FN[0] is the identity function; cond optimizes away the duplication
        # as a result of auto lifting.
for view_f in ALIAS_FN[1:]:
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
torch.compile(fn)(view_f, x)
|
|
|
|
@parametrize("inference_mode", [True, False])
|
|
def test_input_mutation(self, inference_mode):
|
|
def fn(view_f, *args):
|
|
def mutate_f(x):
|
|
v = view_f(x)
|
|
v.add_(1)
|
|
return v.sin()
|
|
|
|
return torch.cond(args[0].sum() > 0, mutate_f, mutate_f, args)
|
|
|
|
x = torch.randn(2, 2)
|
|
for f in ALIAS_FN:
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
torch.compile(fn)(f, x)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
with torch.inference_mode(inference_mode):
|
|
torch.compile(fn)(f, x)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
def test_while_loop_unbacked_bindings(self):
|
|
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=True, backend=backend)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
while_loop_nodes = backend.graphs[0].graph.find_nodes(
|
|
op="call_function", target=torch.ops.higher_order.while_loop
|
|
)
|
|
self.assertEqual(len(while_loop_nodes), 1)
|
|
self.assertEqual(len(while_loop_nodes[0].meta.get("unbacked_bindings")), 5)
|
|
|
|
    # Runs both strict and non-strict export, checks both against eager, and
    # returns the readable graph string of the non-strict exported .module().
    def _check_export_ret_graph_str(self, fn, args, dynamic_shapes=None) -> str:
        strict_ep = torch.export.export(
            fn, args, dynamic_shapes=dynamic_shapes, strict=True
        )
        non_strict_ep = torch.export.export(
            fn, args, dynamic_shapes=dynamic_shapes, strict=False
        )
        eager_res = fn(*args)
        self.assertEqual(strict_ep.module()(*args), eager_res)
        self.assertEqual(non_strict_ep.module()(*args), eager_res)
        return normalize_gm(non_strict_ep.module().print_readable(print_output=False))

@skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
|
|
def test_cond_eager_run_with_item(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, a, b1, b2, c):
|
|
def true_fn(x):
|
|
return x * b1.item()
|
|
|
|
def false_fn(x):
|
|
return x * b2.item()
|
|
|
|
r = torch.cond(a, true_fn, false_fn, (c,))
|
|
return r * 2
|
|
|
|
x = torch.randn(10, requires_grad=True)
|
|
args = (
|
|
torch.tensor(True),
|
|
torch.tensor([3]),
|
|
torch.tensor([4]),
|
|
x,
|
|
)
|
|
model = M()
|
|
torch.export.export(model, args, strict=True)
|
|
graph_str = self._check_export_ret_graph_str(model, args, None)
|
|
self.assertExpectedInline(
|
|
graph_str,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, a, b1, b2, c):
|
|
a: "b8[]"; b1: "i64[1]"; b2: "i64[1]"; c: "f32[10]";
|
|
|
|
a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, (c, b1, b2)); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
|
|
getitem: "f32[10]" = cond[0]; cond = None
|
|
|
|
mul: "f32[10]" = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None
|
|
return pytree.tree_unflatten((mul,), self._out_spec)
|
|
|
|
class true_graph_0(torch.nn.Module):
|
|
def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
|
|
item: "Sym(u0)" = torch.ops.aten.item.default(b1); b1 = None
|
|
|
|
mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
|
|
return (mul,)
|
|
|
|
class false_graph_0(torch.nn.Module):
|
|
def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
|
|
item: "Sym(u1)" = torch.ops.aten.item.default(b2); b2 = None
|
|
|
|
mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
|
|
return (mul,)
|
|
""", # noqa: B950
|
|
        )

    def test_cond_merge_graph_preserves_ph_meta(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                a = y.shape[0]
                b = z.shape[0]

                def true_fn(x):
                    return x + a

                def false_fn(x):
                    return x + b * z

                return torch.cond(x.sum() > 5, true_fn, false_fn, (x,))

        backend = EagerAndRecordGraphs()
        _ = torch.compile(M(), backend=backend)(
            torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)
        )
        self.assertEqual(len(backend.graphs), 1)
        gm = backend.graphs[0]
        subgraph_attr = gm.graph.find_nodes(op="get_attr")[0]
        subgm = getattr(gm, subgraph_attr.target)
        for ph in subgm.graph.find_nodes(op="placeholder"):
            self.assertTrue("example_value" in ph.meta)

@skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
|
|
def test_cond_symint_closure(self):
|
|
from torch.export import Dim
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b * z
|
|
|
|
                # When exporting with non-strict mode, a and b are symints,
                # so torch.compile needs to wrap and trace the symint inputs.
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
|
|
|
args = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
|
|
model = M()
|
|
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
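        # Note: x and z share Dim("d"), so their first dimensions resolve to a
        # single symbol in the exported graph, while y gets an independent one.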
|
|
non_strict_graph_str = self._check_export_ret_graph_str(
|
|
model, args, dynamic_shapes
|
|
)
|
|
self.assertExpectedInline(
|
|
non_strict_graph_str,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
x: "f32[s68, 3]"; y: "f32[s17]"; z: "f32[s68, 3]";
|
|
|
|
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
|
|
sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None
|
|
sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0)
|
|
|
|
gt: "Sym(s68 > 5)" = sym_size_int_5 > 5
|
|
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_5, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_5 = z = None
|
|
getitem: "f32[s68, 3]" = cond[0]; cond = None
|
|
return pytree.tree_unflatten((getitem,), self._out_spec)
|
|
|
|
class true_graph_0(torch.nn.Module):
|
|
def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: "Sym(s68)", z: "f32[s68, 3]"):
|
|
add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
|
|
return (add,)
|
|
|
|
class false_graph_0(torch.nn.Module):
|
|
def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: "Sym(s68)", z: "f32[s68, 3]"):
|
|
mul: "f32[s68, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_5); z = sym_size_int_5 = None
|
|
|
|
add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
|
|
return (add,)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
# unbacked symint inputs are created during non-strict export,
|
|
# which causes a graph break
|
|
@unittest.expectedFailure
|
|
def test_cond_unbacked_symint_closure(self):
|
|
from torch.export import Dim
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
# c is an unbacked symint in non-strict export
|
|
c = y.sum().item()
|
|
|
|
def true_fn(x):
|
|
return x + a + c
|
|
|
|
def false_fn(x):
|
|
return x + b * z * c
|
|
|
|
                # When exporting with non-strict mode, a and b are symints,
                # so torch.compile needs to wrap and trace the symint inputs.
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
|
|
|
args = (torch.ones(3, 3), torch.ones(5, dtype=torch.int32), torch.ones(3, 3))
|
|
model = M()
|
|
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
|
|
_ = self._check_export_ret_graph_str(model, args, dynamic_shapes)
|
|
|
|
@skipIfTorchDynamo(
|
|
"Skip because _merge_output is not intended for dynamo to compile"
|
|
)
|
|
def test_merge_output(self):
|
|
from torch._higher_order_ops.cond import _merge_output
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
        # The shapes and strides come from randomly generated pairs of tensors followed by swapaxes.
|
|
valid_test_cases = [
|
|
            # [(size1, stride1), (size2, stride2), (expected_size, expected_stride)]
            # (see the note after this list for how the expected values are derived)
|
|
[((3,), (1,)), ((4,), (1,)), ("(u0,)", "(1,)")],
|
|
[((1, 3), (3, 1)), ((3, 2), (2, 1)), ("(u0, u1)", "(u1, 1)")],
|
|
[((2, 1), (1, 1)), ((7, 3), (3, 1)), ("(u0, u1)", "(u1, 1)")],
|
|
[((5, 5), (1, 5)), ((4, 5), (1, 4)), ("(u0, 5)", "(1, u0)")],
|
|
[
|
|
((7, 3, 1), (1, 7, 1)),
|
|
((4, 3, 3), (3, 12, 1)),
|
|
("(u0, 3, u1)", "(u1, u0*u1, 1)"),
|
|
],
|
|
[
|
|
((5, 7, 4), (7, 1, 35)),
|
|
((7, 4, 4), (4, 1, 28)),
|
|
("(u0, u1, 4)", "(u1, 1, u0*u1)"),
|
|
],
|
|
[
|
|
((1, 6, 3, 2), (36, 1, 6, 18)),
|
|
((4, 2, 2, 6), (24, 1, 2, 4)),
|
|
("(u0, u1, u2, u3)", "(u1*u2*u3, 1, u1, u1*u2)"),
|
|
],
|
|
[
|
|
((6, 1, 6, 3), (18, 1, 1, 6)),
|
|
((2, 1, 3, 4), (12, 1, 1, 3)),
|
|
("(u0, 1, u1, u2)", "(u1*u2, 1, 1, u1)"),
|
|
],
|
|
[
|
|
((3, 1, 2, 4, 1), (8, 8, 4, 1, 1)),
|
|
((2, 4, 1, 4, 1), (16, 4, 4, 1, 1)),
|
|
("(u0, u1, u2, 4, 1)", "(4*u1*u2, 4*u2, 4, 1, 1)"),
|
|
],
|
|
]
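        # Reading the first case: sizes (3,) and (4,) disagree in dim 0, so the
        # merged size becomes the unbacked (u0,); dimensions that agree across
        # both tensors (e.g. the 5 in the ((5, 5), (4, 5)) case) stay concrete.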
|
|
|
|
def _inner(case):
|
|
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
|
|
|
(size1, stride1), (size2, stride2), (merged_size, merged_stride) = case
|
|
with fake_mode:
|
|
t1 = torch.empty_strided(size1, stride1)
|
|
t2 = torch.empty_strided(size2, stride2)
|
|
out = _merge_output(t1, t2, fake_mode)
|
|
self.assertEqual(str(tuple(out.size())), merged_size)
|
|
self.assertEqual(str(tuple(out.stride())), merged_stride)
|
|
|
|
for case in valid_test_cases:
|
|
_inner(case)
|
|
|
|
        # The shapes and strides come from randomly generated pairs of tensors followed by swapaxes.
|
|
invalid_test_cases = [
|
|
# [(size1, stride1), (size2, stride2)]
|
|
[((1,), (1,)), ((1,), (0,))],
|
|
[
|
|
((1, 3), (1, 1)),
|
|
((5, 6), (6, 1)),
|
|
], # t1 is not contiguous, t2 is contiguous
|
|
[
|
|
((2, 1), (1, 1)),
|
|
((7, 3), (1, 3)),
|
|
], # t1 is contiguous, t2 is not contiguous
|
|
[
|
|
((5, 4), (4, 1)),
|
|
((5, 5), (1, 5)),
|
|
], # t1 is contiguous, t2 is not contiguous
|
|
[((7, 3, 1), (1, 7, 1)), ((4, 3, 3), (9, 1, 3))], # layout is different
|
|
[((5, 7, 4), (7, 1, 35)), ((7, 4, 4), (4, 28, 1))], # layout is different
|
|
[
|
|
((1, 6, 3, 2), (36, 1, 6, 18)),
|
|
((4, 1, 1, 6), (1, 4, 4, 4)),
|
|
], # layout is different
|
|
[
|
|
((6, 1, 6, 3), (18, 1, 1, 6)),
|
|
((1, 1, 1, 1), (1, 1, 1, 1)),
|
|
], # layout is different
|
|
[
|
|
((6, 1, 1, 6, 3), (3, 18, 18, 18, 1)),
|
|
((5, 1, 2, 1, 1), (2, 10, 1, 10, 1)),
|
|
], # layout is different
|
|
]
|
|
for case in invalid_test_cases:
|
|
with self.assertRaisesRegex(Exception, r"."):
|
|
_inner(case)
|
|
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_cond_mismatched_branch_output(self, dynamic, backend):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
# clone the outputs so branches have the same storage_offset
|
|
return (x + a)[2:].clone()
|
|
|
|
def false_fn(x):
|
|
# clone the outputs so branches have the same storage_offset
|
|
return (x + b * z)[:2].clone()
|
|
|
|
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
|
|
return y.sum() - ret
|
|
|
|
m = M()
|
|
x, y, z = torch.randn(5, 4), torch.randn(5, 4), torch.randn(5, 4)
|
|
out = m(x, y, z)
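        # The recorded graph is only compared for the eager backend with dynamic
        # shapes (and without crossref); the other parametrizations just check
        # that the compiled output matches eager numerically.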
|
|
if not (backend == "eager" and dynamic and not TEST_WITH_CROSSREF):
|
|
compiled_out = torch.compile(
|
|
m, backend=backend, dynamic=dynamic, fullgraph=True
|
|
)(x, y, z)
|
|
self.assertEqual(compiled_out, out)
|
|
else:
|
|
bk = EagerAndRecordGraphs()
|
|
compiled_out = torch.compile(
|
|
m, backend=bk, dynamic=dynamic, fullgraph=True
|
|
)(x, y, z)
|
|
self.assertEqual(compiled_out, out)
|
|
self.assertExpectedInline(
|
|
normalize_gm(bk.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: "f32[s17, s94]", L_x_: "f32[s17, s94]"):
|
|
l_y_ = L_y_
|
|
l_z_ = L_z_
|
|
l_x_ = L_x_
|
|
|
|
sum_1: "f32[]" = l_x_.sum()
|
|
gt: "b8[]" = sum_1 > 0; sum_1 = None
|
|
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s94, s17, s17, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s94 = s17 = l_z_ = None
|
|
|
|
getitem_5: "f32[u0, s94]" = cond[0]
|
|
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
|
|
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
ret: "f32[u0, s94]" = cond[0]; cond = None
|
|
|
|
sum_2: "f32[]" = l_y_.sum(); l_y_ = None
|
|
sub: "f32[u0, s94]" = sum_2 - ret; sum_2 = ret = None
|
|
return (sub,)
|
|
|
|
class cond_true_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"):
|
|
l_x__1 = l_x_
|
|
s94_1 = s94
|
|
|
|
add: "f32[s17, s94]" = l_x__1 + s17_true_branch; l_x__1 = s17_true_branch = None
|
|
getitem: "f32[s17 - 2, s94]" = add[slice(2, None, None)]; add = None
|
|
clone: "f32[s17 - 2, s94]" = getitem.clone(); getitem = None
|
|
return (clone,)
|
|
|
|
class cond_false_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"):
|
|
l_x__1 = l_x_
|
|
s94_1 = s94
|
|
|
|
mul: "f32[s17, s94]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
|
|
add: "f32[s17, s94]" = l_x__1 + mul; l_x__1 = mul = None
|
|
getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None
|
|
clone: "f32[2, s94]" = getitem.clone(); getitem = None
|
|
return (clone,)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_cond_mismatched_branch_strided_output(self, dynamic, backend):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
def true_fn(x, y):
|
|
return (
|
|
(x.swapaxes(-1, 0) + 1)
|
|
.unsqueeze(1)
|
|
.expand(-1, 5, -1, -1, -1, -1, -1),
|
|
torch.empty_strided((3, 3), (0, 1)),
|
|
)
|
|
|
|
def false_fn(x, y):
|
|
return (
|
|
(y.swapaxes(-1, 0) + 1)
|
|
.unsqueeze(1)
|
|
.expand(-1, 4, -1, -1, -1, -1, -1),
|
|
torch.empty_strided((4, 5), (0, 1)),
|
|
)
|
|
|
|
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x, y))
|
|
return y.sum() + ret[0]
|
|
|
|
m = M()
|
|
x, y = torch.randn(1, 6, 1, 5, 4, 3), torch.randn(1, 4, 5, 1, 3, 8)
|
|
out = m(x, y)
|
|
compiled_out = torch.compile(
|
|
m, backend=backend, dynamic=dynamic, fullgraph=True
|
|
)(x, y)
|
|
self.assertEqual(compiled_out, out)
|
|
|
|
|
|
_hop_schema_test_schema_types = [
|
|
"bool",
|
|
"int",
|
|
"float",
|
|
"str",
|
|
"Tensor",
|
|
"SymInt",
|
|
"SymBool",
|
|
"GraphModule",
|
|
"ScriptObj",
|
|
]
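# The schema-generation tests below are parametrized over these example types;
# _get_example_val produces a representative value for each entry.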
|
|
|
|
|
|
@skipIfTorchDynamo("We don't expect users to torch.compile hop schema generation.")
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
class TestHopSchema(TestCase):
|
|
def _get_example_val(self, ty: str):
|
|
from torch.fx.experimental.sym_node import SymNode
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
def create_symtype(cls, pytype, shape_env, val):
|
|
from torch._dynamo.source import ConstantSource
|
|
|
|
symbol = shape_env.create_symbol(
|
|
val,
|
|
source=ConstantSource(
|
|
f"__testing_hop_schema{len(shape_env.var_to_val)}"
|
|
),
|
|
)
|
|
return cls(SymNode(symbol, shape_env, pytype, hint=val))
|
|
|
|
if ty == "bool":
|
|
return True
|
|
elif ty == "int":
|
|
return 1
|
|
elif ty == "float":
|
|
return 1.0
|
|
elif ty == "str":
|
|
return "foo"
|
|
elif ty == "Tensor":
|
|
return torch.tensor(1)
|
|
elif ty == "SymInt":
|
|
shape_env = ShapeEnv()
|
|
return create_symtype(torch.SymInt, int, shape_env, 1)
|
|
elif ty == "SymBool":
|
|
shape_env = ShapeEnv()
|
|
return create_symtype(torch.SymBool, bool, shape_env, True)
|
|
elif ty == "GraphModule":
|
|
|
|
def f(x):
|
|
return x.sin()
|
|
|
|
return make_fx(f)(torch.ones(1))
|
|
elif ty == "ScriptObj":
|
|
from torch.testing._internal.torchbind_impls import (
|
|
init_torchbind_implementations,
|
|
)
|
|
|
|
init_torchbind_implementations()
|
|
foo = torch.classes._TorchScriptTesting._Foo(3, 4)
|
|
return foo
|
|
else:
|
|
raise NotImplementedError(ty)
|
|
|
|
@parametrize("schema_type", _hop_schema_test_schema_types)
|
|
def test_type_gen(self, schema_type):
|
|
from torchgen.gen_schema_utils import TypeGen
|
|
|
|
example_val = self._get_example_val(schema_type)
|
|
ty = TypeGen.from_example(example_val)
|
|
# Test the generated type can be parsed
|
|
self.assertEqual(ty.parse(str(ty)), ty)
|
|
|
|
@parametrize("schema_type", _hop_schema_test_schema_types)
|
|
def test_list_gen(self, schema_type):
|
|
from torchgen.gen_schema_utils import TypeGen
|
|
|
|
example_val = self._get_example_val(schema_type)
|
|
li1 = [example_val]
|
|
ty1 = TypeGen.from_example(li1)
|
|
ty2 = TypeGen.from_example(li1)
|
|
self.assertEqual(ty1.parse(str(ty1)), ty1)
|
|
self.assertEqual(ty2.parse(str(ty2)), ty2)
|
|
|
|
def test_function_schema_gen(self):
|
|
from torchgen.gen_schema_utils import FunctionSchemaGen
|
|
|
|
inps = [
|
|
(schema_type + "_v", self._get_example_val(schema_type))
|
|
for schema_type in _hop_schema_test_schema_types
|
|
]
|
|
schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
|
|
schema2 = FunctionSchemaGen.from_example(
|
|
"test_op2",
|
|
inps,
|
|
[
|
|
torch.ones(1),
|
|
],
|
|
)
|
|
schema3 = FunctionSchemaGen.from_example(
|
|
"test_op3", inps, [torch.ones(1), torch.ones(1)]
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema1),
|
|
"""test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema2),
|
|
"""test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema3),
|
|
"""test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""", # noqa: B950,
|
|
)
|
|
self.assertEqual(schema1.parse(str(schema1)), schema1)
|
|
self.assertEqual(schema2.parse(str(schema2)), schema2)
|
|
self.assertEqual(schema3.parse(str(schema3)), schema3)
|
|
|
|
def test_while_loop_schema_gen(self):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
|
|
graph = make_fx(fn)(*inp).graph
|
|
while_loop_node = next(
|
|
node
|
|
for node in graph.nodes
|
|
if node.op == "call_function"
|
|
and node.target is torch.ops.higher_order.while_loop
|
|
)
|
|
schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""", # noqa: B950
|
|
)
|
|
self.assertEqual(schema.parse(str(schema)), schema)
|
|
|
|
def test_schema_tree_spec(self):
|
|
schema_gen = HopSchemaGenerator(torch.ops.higher_order.cond)
|
|
args = (torch.randn(3, 4), torch.randn(2, 3))
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "Please only add flattened inputs to the hop schema"
|
|
):
|
|
schema_gen.add_arg("tuple_args", args)
|
|
|
|
for i, arg in enumerate(args):
|
|
schema_gen.add_arg(f"tuple_args{i}", arg)
|
|
schema_gen.add_schema_tree_spec(pytree.tree_flatten(args)[1])
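        # The tree spec records how the flattened tensors unflatten back into the
        # original tuple; the generated schema itself only lists the flat args.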
|
|
flat_schema = schema_gen.gen_schema()
|
|
self.assertExpectedInline(
|
|
str(flat_schema), """cond(Tensor tuple_args0, Tensor tuple_args1) -> ()"""
|
|
)
|
|
|
|
def test_cond_gen_schema_tensor_inputs(self):
|
|
schema = torch.ops.higher_order.cond.gen_schema(
|
|
torch.tensor(True),
|
|
lambda x: x.sin(),
|
|
lambda x: x.cos(),
|
|
(torch.randn(3, 4),),
|
|
)
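        # Each flattened operand becomes a positional "Tensor operandN" argument,
        # and the single tensor output is reported as a one-element tuple.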
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""cond(Tensor pred, Any true_fn, Any false_fn, Tensor operand0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_cond_gen_schema_symbool_inputs(self):
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
|
with fake_mode, fake_mode.shape_env.ignore_fresh_unbacked_symbols():
|
|
sym_bool = torch.randn(3, 4).nonzero().size(0) == 0
|
|
|
|
schema = torch.ops.higher_order.cond.gen_schema(
|
|
sym_bool,
|
|
lambda x: x.sin(),
|
|
lambda x: x.cos(),
|
|
(torch.randn(3, 4),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""cond(SymBool pred, Any true_fn, Any false_fn, Tensor operand0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_while_loop_gen_schema_tensor_inputs(self):
|
|
def cond_fn(x, y):
|
|
return x.sum() < 10
|
|
|
|
def body_fn(x, y):
|
|
return x + 1, y.sin()
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(torch.randn(3, 4), torch.randn(2, 3)),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, Tensor carried_input0, Tensor carried_input1) -> (Tensor, Tensor)""",
|
|
)
|
|
|
|
def test_while_loop_gen_schema_with_additional_inputs(self):
|
|
def cond_fn(x, y, z):
|
|
return x.sum() < z
|
|
|
|
def body_fn(x, y, z):
|
|
return x + 1, y.sin()
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(torch.randn(3, 4), torch.randn(2, 3)),
|
|
(torch.tensor(10),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, Tensor carried_input0, Tensor carried_input1, Tensor additional_input0) -> (Tensor, Tensor)""", # noqa: B950
|
|
)
|
|
|
|
def test_scan_gen_schema_tensor_inputs(self):
|
|
def combine_fn(carry, x):
|
|
return carry + x, carry * x
|
|
|
|
schema = torch.ops.higher_order.scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(3, 4),),
|
|
(torch.randn(5, 3, 4),),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""scan(Any combine_fn, Tensor init0, Tensor xs0) -> (Tensor, Tensor)""",
|
|
)
|
|
|
|
def test_scan_gen_schema_with_additional_inputs(self):
|
|
def combine_fn(carry, x, scale):
|
|
return carry + x * scale, carry * x
|
|
|
|
schema = torch.ops.higher_order.scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(3, 4),),
|
|
(torch.randn(5, 3, 4),),
|
|
(torch.tensor(2.0),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""scan(Any combine_fn, Tensor init0, Tensor xs0, Tensor additional_input0) -> (Tensor, Tensor)""", # noqa: B950
|
|
)
|
|
|
|
def test_scan_gen_schema_multiple_inputs(self):
|
|
def combine_fn(carry1, carry2, x1, x2):
|
|
return carry1 + x1, carry2 * x2, carry1 - x1, carry2 + x2
|
|
|
|
schema = torch.ops.higher_order.scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(3, 4), torch.randn(2, 3)),
|
|
(torch.randn(5, 3, 4), torch.randn(5, 2, 3)),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""scan(Any combine_fn, Tensor init0, Tensor init1, Tensor xs0, Tensor xs1) -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950
|
|
)
|
|
|
|
def test_associative_scan_gen_schema_tensor_inputs(self):
|
|
def combine_fn(x, y):
|
|
return x + y
|
|
|
|
schema = torch.ops.higher_order.associative_scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(5, 3, 4),),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""associative_scan(Any combine_fn, Tensor xs0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_associative_scan_gen_schema_with_additional_inputs(self):
|
|
def combine_fn(x, y, scale):
|
|
return x * y * scale
|
|
|
|
schema = torch.ops.higher_order.associative_scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(5, 3, 4),),
|
|
(torch.tensor(2.0),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""associative_scan(Any combine_fn, Tensor xs0, Tensor additional_input0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_associative_scan_gen_schema_multiple_inputs(self):
|
|
def combine_fn(x1, x2, y1, y2):
|
|
return x1 + y1, x2 * y2
|
|
|
|
schema = torch.ops.higher_order.associative_scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(5, 3, 4), torch.randn(5, 2, 3)),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""associative_scan(Any combine_fn, Tensor xs0, Tensor xs1) -> (Tensor, Tensor)""",
|
|
)
|
|
|
|
def test_while_loop_gen_schema_with_int_carries(self):
|
|
def cond_fn(x, y, z, c):
|
|
return x < y
|
|
|
|
def body_fn(x, y, z, c):
|
|
return x + 1, y - 1, z.sin(), c + x
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(2, 10, torch.randn(2, 3)),
|
|
(torch.tensor(10),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, int carried_input0, int carried_input1, Tensor carried_input2, Tensor additional_input0) -> (int, int, Tensor, Tensor)""", # noqa: B950
|
|
)
|
|
|
|
def test_while_loop_gen_schema_with_input_mutation(self):
|
|
def cond_fn(x, y, z, c):
|
|
return x < y
|
|
|
|
def body_fn(x, y, z, c):
|
|
x.add_(1)
|
|
y.sub_(1)
|
|
z.sin_()
|
|
c.add_(x)
|
|
return x, y, z
|
|
|
|
c = torch.randn(3, 3)
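        # body_fn mutates every carried input and the additional input in place,
        # so the generated schema marks them with alias-and-mutation annotations
        # such as Tensor(a2!).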
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)),
|
|
(c,),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, Tensor(a2!) carried_input0, Tensor(a3!) carried_input1, Tensor(a4!) carried_input2, Tensor(a5!) additional_input0) -> (Tensor, Tensor, Tensor)""", # noqa: B950
|
|
        )


instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)

instantiate_parametrized_tests(TestControlFlow)
instantiate_parametrized_tests(AssociativeScanTests)

if __name__ == "__main__":
    run_tests()