# pytorch/test/inductor/test_control_flow.py
# Owner(s): ["module: inductor"]
import itertools
import unittest
import torch
import torch._dynamo.testing
import torch.utils._pytree as pytree
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.map import _fake_map
from torch._higher_order_ops.scan import _fake_scan, scan
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import (
decorateIf,
instantiate_parametrized_tests,
parametrize,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu
def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
result = []
device = inputs[0].device
# iterate over the cartesian product of predicate values
for values in itertools.product(*([possible_values] * num_to_prepend)):
prepended = [torch.tensor(v, device=device) for v in values]
result.append((*prepended, *inputs))
return result
def prepend_predicates(inputs, num_predicates=1):
return _prepend_product_of_values(inputs, [False, True], num_predicates)
def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
return _prepend_product_of_values(inputs, counter_values, num_counters)
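# For example, prepend_predicates([x]) yields
# [(tensor(False), x), (tensor(True), x)]: one input tuple per combination
# of predicate values, so both branches of a cond get exercised. Likewise,
# prepend_counters([x]) yields one tuple per counter value in (0, 1, 5),
# covering zero, one and several loop iterations.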
# A simple loss function used by the autograd checks in the tests below.
def loss_fn(result) -> torch.Tensor:
flat_results, _ = pytree.tree_flatten(result)
total_loss = torch.tensor(
0.0, device=flat_results[0].device if flat_results else torch.device("cpu")
)
for res in flat_results:
# Convert to float if integer tensor to avoid numerical issues
if not res.dtype.is_floating_point:
res = res.float()
# Bounded robust loss: abs(res) / (1 + abs(res)) lies in [0, 1), avoiding inf/nan
total_loss = total_loss + (torch.abs(res) / (1.0 + torch.abs(res))).sum()
return total_loss
class CondModels:
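# Small torch.cond programs exercised by CondTests below. Most forwards take
# the branch predicate(s) first, followed by the tensor operands, matching
# the argument tuples produced by prepend_predicates.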
class Simple(torch.nn.Module):
def forward(self, p, a, b):
def true_fn(x, y):
return x + y
def false_fn(x, y):
return x - y
return torch.cond(p, true_fn, false_fn, [a, b])
class SimpleWithIntClosure(torch.nn.Module):
def __init__(self):
super().__init__()
self.num = 3
def forward(self, p, a, b):
return torch.cond(
pred=p,
true_fn=lambda a, b: [a + b + self.num],
false_fn=lambda a, b: [a - b - self.num],
operands=(a, b),
)
class Nested(torch.nn.Module):
def forward(self, p0, p1, p2, a, b, c):
def true_fn(x0, y0, z0):
def true_true_fn(x1, y1, z1):
return (x1 - y1 * z1) * 3.14
def true_false_fn(x1, y1, z1):
def true_false_true_fn(x2, y2, z2):
return (x2 * y2 * z2) / 2.71
def true_false_false_fn(x2, y2, z2):
return (x2 + y2 + z2) * 1.23
return torch.cond(
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
)
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
def false_fn(x0, y0, z0):
def false_true_fn(x1, y1, z1):
def false_true_true_fn(x2, y2, z2):
return (x2 - y2 - z2) + 1.23
def false_true_false_fn(x2, y2, z2):
return (x2 / y2 / z2) - 3.14
return torch.cond(
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
)
def false_false_fn(x1, y1, z1):
return (x1 - y1 * z1) / 2.71
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
return torch.cond(p0, true_fn, false_fn, [a, b, c])
class Parameters(torch.nn.Module):
class InnerModel1(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.layer = torch.nn.Linear(20, 30, device=device)
def forward(self, x):
return self.layer(x + 1) * 3.14
class InnerModel2(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.layer1 = torch.nn.Linear(20, 10, device=device)
self.layer2 = torch.nn.Linear(10, 30, device=device)
def forward(self, x):
return self.layer2(self.layer1(x - 2)) * 3.14
def __init__(self, device):
super().__init__()
self.true_fn = self.InnerModel1(device)
self.false_fn = self.InnerModel2(device)
def forward(self, p, a):
return torch.cond(p, self.true_fn, self.false_fn, [a])
class ReinterpretView(torch.nn.Module):
def forward(self, p, a, b):
def true_fn(x, y):
z1 = x + y
z2 = x - y
return z1[2:], z2[:, 4:].contiguous()
def false_fn(x, y):
z1 = x - y
z2 = x + y
return z1[2:], z2[:, 4:].contiguous()
return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])
class MultipleOutputs(torch.nn.Module):
def forward(self, p, a, b, c):
def true_fn(x, y, z):
return x * y, z / 2.71, (y - x).sum(dim=1)
def false_fn(x, y, z):
return y / x, z * 3.14, (x + y).mean(dim=1)
return torch.cond(p, true_fn, false_fn, [a, b, c])
class OuterCode(torch.nn.Module):
def forward(self, p, a, b):
c = a * b + 3.14
d = a / b - 2.71
def true_fn(x, y):
return x + y
def false_fn(x, y):
return x - y
e = torch.cond(p, true_fn, false_fn, [c, d])
return e * e / 1.41
class OuterBuffers(torch.nn.Module):
def forward(self, p, a, b, c):
d = a * 2
e = b / 2
def true_fn(x):
return x + d
def false_fn(x):
return x - e
return torch.cond(p, true_fn, false_fn, [c])
class WithNonTensorPredicate(torch.nn.Module):
def forward(self, a, b):
def true_fn(x, y):
return x.sum(0) / 3.14
def false_fn(x, y):
return y.sum(0) * 2.71
return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])
class UnbackedSymIntClosure(torch.nn.Module):
def forward(self, p, x, y, z):
a = y.shape[0]
b = z.sum().to(torch.int64).item()
def true_fn(x):
return x + a
def false_fn(x):
return x + b * z
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
class MismatchedOutputSize(torch.nn.Module):
def forward(self, p, x, y, z):
a = y.shape[0]
b = z.shape[0]
def true_fn(x):
return (x + a)[2:].sin()
def false_fn(x):
return (x + b * z)[:2].cos()
return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
class FunctionalCall(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, p, x):
true_new_weight = torch.ones(x.size(0), x.size(0), device=x.device)
false_new_weight = torch.zeros(x.size(0), x.size(0), device=x.device)
true_new_bias = torch.ones(x.size(0), device=x.device)
false_new_bias = torch.zeros(x.size(0), device=x.device)
x = x.reshape(-1, x.size(0))
def true_fn(x):
return torch.func.functional_call(
self.linear,
{
"weight": true_new_weight,
"bias": true_new_bias,
},
x,
)
def false_fn(x):
return torch.func.functional_call(
self.linear,
{
"weight": false_new_weight,
"bias": false_new_bias,
},
x,
)
return torch.cond(p, true_fn, false_fn, (x,))
class SelectWithInputIdx(torch.nn.Module):
def forward(self, p, x, idx):
u0 = idx.item()
x0 = x.select(0, u0)
def fn():
return x0.sin()
return torch.cond(x0.sum() > 0, fn, fn)
class CondTests(TestCase):
def _run_test(
self,
model,
inputs,
device,
dynamic=False,
num_predicates=1,
):
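# Compiles `model` with the Inductor backend (fullgraph=True) and runs the
# eager and compiled versions over every combination of prepended predicate
# values (plus an additional 5x-tiled input set with dim 0 marked dynamic
# when dynamic=True), checking that outputs match, that inputs are not
# mutated, and that exactly one compilation frame is produced.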
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)
inputs = [inp.to(device=device) for inp in inputs]
input_sets = [inputs]
if dynamic:
larger_inputs = []
for inp in inputs:
# only tile non-scalar tensor inputs
if inp.ndim > 0:
# tile every first dim 5x
tiling = [5] + [1] * (inp.ndim - 1)
larger_inputs.append(torch.tile(inp, tiling))
else:
larger_inputs.append(inp)
input_sets.append(larger_inputs)
for inputs in input_sets:
for inp in inputs:
# mark every first dim as dynamic
torch._dynamo.mark_dynamic(inp, 0)
for inputs in input_sets:
for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
result = model(*inputs_with_predicates)
result_compiled = compiled_model(*inputs_with_predicates)
# inputs must not be mutated
torch.testing.assert_close(cloned_inputs, inputs_with_predicates)
torch.testing.assert_close(result, result_compiled)
self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
def test_cond_simple_control_flow(self, device, dynamic):
# cond control flow without nesting
self._run_test(
model=CondModels.Simple(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
def test_cond_simple_with_int_closure(self, device):
self._run_test(
model=torch.compile(CondModels.SimpleWithIntClosure(), dynamic=True),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_cond_unbacked_symint_closure(self, device, dynamic):
self._run_test(
model=CondModels.UnbackedSymIntClosure(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
)
@skipIfXpu(msg="Remove this skip after issue #154949 is resolved.")
@requires_gpu
def test_cond_control_flow_with_precomputed_size(self):
class TestModel(torch.nn.Module):
def __init__(
self,
):
super().__init__()
self.conv2d = torch.nn.Conv2d(
512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
)
self.threshold = 20
def forward(self, x: torch.Tensor, index) -> torch.Tensor:
def true_fn(x: torch.Tensor):
return self.conv2d(x)
def false_fn(x: torch.Tensor):
return self.conv2d(x)
return torch.cond(
index < self.threshold and index >= 0, true_fn, false_fn, (x,)
)
main_model = TestModel().to(GPU_TYPE)
x1 = torch.rand(2, 512, 128, 72).to(GPU_TYPE)
x2 = torch.rand(2, 512, 96, 96).to(GPU_TYPE)
opt_model = torch.compile(main_model)
out1 = main_model(x1, 1)
opt_out1 = opt_model(x1, 1)
self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5))
out2 = main_model(x2, 30)
opt_out2 = opt_model(x2, 30)
self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5))
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
def test_cond_nested_control_flow(self, device, dynamic):
# cond control flow with nesting
self._run_test(
model=CondModels.Nested(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
num_predicates=3,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
def test_cond_outer_code_before_after(self, device, dynamic):
# some code before and after the conditional
self._run_test(
model=CondModels.OuterCode(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
def test_cond_multiple_outputs(self, device, dynamic):
# multiple outputs with different shapes
self._run_test(
model=CondModels.MultipleOutputs(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
torch.randn(30, 40),
),
device=device,
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
def test_cond_advanced_dynamic_shapes(self, device):
# subgraphs input shapes include symbolic expressions
class Model(torch.nn.Module):
def forward(self, p, a, b):
def true_fn(x, y):
return torch.cat([x - 3, y * 3], dim=1)
def false_fn(x, y):
return torch.cat([x / 3, y - 3], dim=1)
c = torch.cat([a, b], dim=0)
d = c * 2
e = c / 2
return torch.cond(p, true_fn, false_fn, [d, e])
self._run_test(
model=Model(),
inputs=(
torch.randn(2, 3, 3),
torch.randn(4, 3, 3),
),
device=device,
dynamic=True,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
def test_cond_unbacked_symint_outer_to_inner(self, device):
class Model(torch.nn.Module):
def forward(self, p, a):
def true_fn(x):
return torch.cos(x)
def false_fn(x):
return torch.sin(x)
nz = torch.nonzero(a)
b = torch.ones([nz.size(0), 8], device=nz.device)
return torch.cond(p, true_fn, false_fn, [b])
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
}
):
self._run_test(
model=Model(),
inputs=(torch.randn(2, 3, 3),),
device=device,
dynamic=True,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@torch._inductor.config.patch(size_asserts=False)
# TODO: graph partition does not yet support creating tensors
# with dynamic shapes inside a conditional subgraph
@torch._inductor.config.patch(graph_partition=False)
def test_cond_unbacked_symint_inner(self, device):
class Model(torch.nn.Module):
def forward(self, p, a):
def true_fn(x):
nz = torch.nonzero(x)
b = torch.ones([nz.size(0), 8], device=nz.device)
return torch.cos(b)
def false_fn(x):
nz = torch.nonzero(x)
b = torch.ones([nz.size(0), 8], device=nz.device)
return torch.sin(b)
b = torch.sin(a)
return torch.cond(p, true_fn, false_fn, [b])
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
}
):
self._run_test(
model=Model(),
inputs=(torch.randn(2, 3, 3),),
device=device,
dynamic=True,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
def test_cond_unbacked_symint_inner_to_outer(self, device):
class Model(torch.nn.Module):
def forward(self, p, a):
def true_fn(x):
nz = torch.nonzero(x)
b = torch.ones([nz.size(0), 8], device=nz.device)
return torch.cos(b)
def false_fn(x):
nz = torch.nonzero(x)
b = torch.ones([nz.size(0), 8], device=nz.device)
return torch.sin(b)
b = torch.sin(a)
y = torch.cond(p, true_fn, false_fn, [b])
return torch.sin(y)
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
}
):
self._run_test(
model=Model(),
inputs=(torch.randn(2, 3, 3),),
device=device,
dynamic=True,
)
@requires_gpu
def test_cond_use_buffers_from_outer_scope(self):
# buffers from the outer scope are used inside the subgraphs
self._run_test(
model=CondModels.OuterBuffers(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
torch.randn(10, 20),
),
device=GPU_TYPE,
dynamic=False,
)
@requires_gpu
def test_cond_reinterpret_view_inputs_outputs(self):
# ReinterpretView in inputs and outputs of the subgraphs
self._run_test(
model=CondModels.ReinterpretView(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=GPU_TYPE,
dynamic=True,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
def test_cond_subgraphs_with_parameters(self, device, dynamic):
# nested Modules with parameters
self._run_test(
model=CondModels.Parameters(device),
inputs=(torch.randn(10, 20),),
device=device,
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
def test_cond_non_tensor_predicates(self, device, dynamic):
# model with a boolean predicate
for b_size_0 in [5, 15]:
torch._dynamo.reset()
self._run_test(
model=CondModels.WithNonTensorPredicate(),
inputs=(
torch.randn(10, 20),
torch.randn(b_size_0, 20),
),
device=device,
dynamic=dynamic,
num_predicates=0,
)
@requires_gpu
def test_cond_aliasing_outputs(self):
# output aliasing in subgraphs: not supported
class Model(torch.nn.Module):
def forward(self, p, a, b):
def true_fn(x, y):
z = x + y
return z, z[1:]
def false_fn(x, y):
z = x - y
return z, z[1:]
return torch.cond(p, true_fn, false_fn, [a, b])
# AssertionError: Output aliasing is currently not supported...
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
torch.compile(Model())(
torch.tensor(True),
torch.randn(10, 20),
torch.randn(10, 20),
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
def test_cond_decompose_ops_in_subgraph(self, device):
class Model(torch.nn.Module):
def forward(self, p, a):
def true_fn(x):
return torch.zeros_like(x)
def false_fn(x):
return torch.ones_like(x)
b = torch.ones_like(a)
c = torch.cond(p, true_fn, false_fn, [b])
return c
self._run_test(
model=Model(),
inputs=(torch.rand(10, 20),),
device=device,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
def test_cond_decompose_ops_in_subgraph_recursive(self, device):
def inner_fn1(x):
return torch.zeros_like(x)
def inner_fn2(x):
return torch.ones_like(x)
class Model(torch.nn.Module):
def forward(self, p, a):
def true_fn(x):
return torch.cond(p, inner_fn2, inner_fn1, [x])
def false_fn(x):
return torch.cond(p, inner_fn1, inner_fn2, [x])
b = torch.ones_like(a)
c = torch.cond(p, true_fn, false_fn, [b])
return c
self._run_test(
model=Model(),
inputs=(torch.rand(10, 20),),
device=device,
)
@requires_gpu
def test_cond_inductor_fx_passes_recursively_applied(self):
counters = {"pre_grad": 0, "post_grad": 0}
def pre_grad_pass_counter(gm):
counters["pre_grad"] += 1
def post_grad_pass_counter(gm):
counters["post_grad"] += 1
with torch._inductor.config.patch(
{
"pre_grad_custom_pass": pre_grad_pass_counter,
"post_grad_custom_pre_pass": post_grad_pass_counter,
# The above patches don't pickle
"fx_graph_cache": False,
}
):
self._run_test(
model=CondModels.Nested(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
torch.randn(10, 20),
),
device=GPU_TYPE,
dynamic=True,
num_predicates=3,
)
self.assertEqual(counters["pre_grad"], 11)
self.assertEqual(counters["post_grad"], 11)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
def test_cond_mismatched_branch_output_size(self, device, dynamic):
self._run_test(
model=CondModels.MismatchedOutputSize(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
def test_cond_functional_call(self, device, dynamic):
self._run_test(
model=CondModels.FunctionalCall(),
inputs=(torch.randn(10, 20),),
device=device,
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_cond_select_with_input_idx(self, device, dynamic):
self._run_test(
model=CondModels.SelectWithInputIdx(),
inputs=(torch.randn(10, 20), torch.tensor(0, dtype=torch.int64)),
device=device,
dynamic=dynamic,
)
class WhileLoopModels:
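# torch._higher_order_ops.while_loop programs exercised by WhileLoopTests
# below. Most forwards take the loop counter(s) first, matching the tuples
# produced by prepend_counters, with cond_fn/body_fn carrying or closing
# over the remaining tensors.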
class Simple(torch.nn.Module):
def forward(self, ci, a, b):
def cond_fn(i, x, y):
return i > 0
def body_fn(i, x, y):
return i - 1, x + y, y - x
return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])
class Nested(torch.nn.Module):
def forward(self, ci, cj, a, b):
def cond_fn(i1, j1, x1, y1):
return i1 > 0
def body_fn(i1, j1, x1, y1):
def cond_fn_nested(i2, j2, x2, y2):
return j2 > 0
def body_fn_nested(i2, j2, x2, y2):
return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71
i1, j1, x1, y1 = torch._higher_order_ops.while_loop(
cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
)
return i1 - 1, j1.clone(), x1 * 2, y1 / 2
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, cj, a, b))
class Parameters(torch.nn.Module):
class InnerModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.layer1 = torch.nn.Linear(
20, 30, device=device, dtype=torch.float64
)
self.layer2 = torch.nn.Linear(
30, 20, device=device, dtype=torch.float64
)
def forward(self, c, x):
return c - 1, self.layer2(self.layer1(x - 2)) * 3.14
def __init__(self, device):
super().__init__()
self.body_fn = self.InnerModel(device)
self.cond_fn = lambda c, x: c > 0
def forward(self, c, a):
return torch._higher_order_ops.while_loop(
self.cond_fn, self.body_fn, [c, a]
)
class OuterCode(torch.nn.Module):
def forward(self, c, a, b):
d = a * b + 3.14
e = a / b - 2.71
def cond_fn(c, x, y):
return c > 0
def body_fn(c, x, y):
return c - 1, y - x, x + y
_, f, g = torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, d, e])
return f * g / 1.41
# TODO(aakhundov): add while_loop test with outer buffers
# with dynamic=True once dynamo / export allows while_loop
# closure capture with mark_dynamic:
# https://github.com/pytorch/pytorch/issues/123596
class OuterBuffers(torch.nn.Module):
def forward(self, c, a, b):
d = a * 2
e = b / 2
def cond_fn(c, x, y):
return c > 0
def body_fn(c, x, y):
return c - 1, x + d, y - e
return torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, a, b])
class PytreeCarry(torch.nn.Module):
def forward(self, it, pytree_input):
def cond_fn(it, pytree_input):
return it > 0
def body_fn(it, pytree_input):
x = pytree_input[0][0]
y = pytree_input[1]["x"]
z = pytree_input[1]["y"]
new_x = y.sin()
new_y = z.cos()
new_z = x + 1
return it - 1, ([new_x], {"x": new_y, "y": new_z})
return torch._higher_order_ops.while_loop(
cond_fn, body_fn, (it, pytree_input)
)
class DataDependentOpInSubgraph(torch.nn.Module):
def forward(self, c, a, b):
def cond_fn(c, reduced_carry):
return c > 0
def body_fn(c, reduced_carry):
k = torch.masked_select(a, b)
d = torch.concat([k, k * 2])
return c - 1, torch.min(d).unsqueeze(0) + reduced_carry
return torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, torch.zeros([1], dtype=torch.int64, device=c.device)],
)
class DataDependentInOut(torch.nn.Module):
def forward(self, c, a, b):
inp = torch.zeros(
a.sum().to(torch.int64).item(), 3, device=a.device, dtype=torch.int64
)
def cond_fn(c, inp):
return c > 0
def body_fn(c, inp):
return c - 1, (inp.sin() + 1).to(torch.int64)
return torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, inp],
)
class DataDependentInOutMismatch(torch.nn.Module):
def forward(self, c, a, b):
def cond_fn(c, a, b):
return c > 0
def body_fn(c, a, b):
return c - 1, a.nonzero(), b.nonzero()
return torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a, b],
)
class InfiniteLoop(torch.nn.Module):
def forward(self, c, a):
a_view = a.view(-1, 1)
def cond_fn(c, a_view):
return a_view.size(-1) > 0
def body_fn(c, a_view):
return c - 1, a_view + 1
return torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a_view],
)
class ZeroLoop(torch.nn.Module):
def forward(self, c, a):
a_view = torch.sin(a.view(-1, 1))
def cond_fn(c, a_view):
return a_view.size(-1) == 0
def body_fn(c, a_view):
return c - 1, a_view + 1
out1, out2 = torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a_view],
)
return out1 + 1, out2 + 2
class ZeroLoop2(torch.nn.Module):
def forward(self, c, a):
a_view = torch.sin(a.view(-1, 1))
def cond_fn(c, a_view):
return False
def body_fn(c, a_view):
return c - 1, a_view + 1
out1, out2 = torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a_view],
)
return out1 + 1, out2 + 2
class ZeroLoop3(torch.nn.Module):
def forward(self, c, a):
a_view = torch.sin(a.view(-1, 1))
def cond_fn(c, a_view):
return 0
def body_fn(c, a_view):
return c - 1, a_view + 1
out1, out2 = torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a_view],
)
return out1 + 1, out2 + 2
class ZeroLoop4(torch.nn.Module):
def forward(self, c, a):
a_view = torch.sin(a.view(-1, 1))
def cond_fn(c, a_view):
return torch.clip(a_view.sum(), 0, 1) < 0
def body_fn(c, a_view):
return c - 1, a_view + 1
out1, out2 = torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a_view],
)
return out2.sin_(), a_view.cos_()
class UnbackedSymIntClosure(torch.nn.Module):
def forward(self, c, a, b):
d = a.sum().to(torch.int64).item()
e = torch.nonzero(b).size(0)
def cond_fn(c, a, b):
return c > d + e + a.shape[0] - b.shape[0]
def body_fn(c, a, b):
return c - 1, a + e, b + d
return torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a, b],
)
class SymExprCond(torch.nn.Module):
def forward(self, c, a, b):
d = a.sum().to(torch.int64).item()
e = torch.nonzero(b).size(0)
def cond_fn(c, a, b):
return c + d + e + a.shape[0] - b.shape[0] < 10
def body_fn(c, a, b):
return c + 1, a + e, b + d
return torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
[c, a, b],
)
class MixedDevice(torch.nn.Module):
def forward(self, c, a, b):
# Force the loop idx on cpu
c = c.to(torch.device("cpu"))
def cond_fn(loop_idx, a, b):
return loop_idx < a.shape[0]
def body_fn(loop_idx, a, b):
return loop_idx + 1, a + b, a - b
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, a, b))
class MixedDevice2(torch.nn.Module):
def forward(self, c, a, b):
# Force the loop idx on cpu
c = c.to(torch.device("cpu"))
def cond_fn(loop_idx, a, b):
return loop_idx < a.shape[0]
def body_fn(loop_idx, a, b):
return loop_idx + a.sum(), a + b, a - b
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, a, b))
class Conv(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.conv2d = torch.nn.Conv2d(
4,
4,
(3, 3),
stride=(1, 1),
padding=(1, 1),
device=device,
dtype=torch.float64,
)
def forward(self, c, x):
def cond_fn(loop_idx, x):
return loop_idx < x.size(0)
def body_fn(loop_idx, x):
return loop_idx + 1, self.conv2d(x) + 1
return torch._higher_order_ops.while_loop(
cond_fn,
body_fn,
(c, x),
)
class WhileLoopStackOutputSimple(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.linear = torch.nn.Linear(3, 3, device=device)
def forward(self, c, x):
def cond_fn(c, x):
return c < x.size(0)
def body_fn(c, x):
return c + 1, self.linear(x)
stacked_c, stacked_x = torch.ops.higher_order.while_loop_stack_output(
cond_fn, body_fn, (c, x), tuple()
)
return stacked_c, stacked_x
class WhileLoopTests(TestCase):
def _run_test(
self, model, inputs, device, dynamic=False, num_counters=1, autograd=False
):
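# Same idea as CondTests._run_test, but prepends loop counters (0, 1 and 5
# iterations) instead of predicates and supports pytree-structured inputs.
# When autograd=True it additionally compares losses, parameter gradients
# and input gradients between the eager and compiled models.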
import torch.utils._pytree as pytree
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
import copy
if not autograd:
for p in model.parameters():
p.requires_grad_(False)
compiled_model = copy.deepcopy(model)
compiled_fn = torch.compile(backend=cnt, fullgraph=True)(compiled_model)
inputs = pytree.tree_map(lambda t: t.to(device=device), inputs)
input_sets = [inputs]
def mark_first_dim_dyn(inp):
torch._dynamo.mark_dynamic(inp, 0)
if dynamic:
def tile_fn(inp):
# tile every first dim 5x
tiling = [5] + [1] * (inp.ndim - 1)
t = torch.tile(inp, tiling)
return t
larger_inputs = pytree.tree_map(tile_fn, inputs)
input_sets.append(larger_inputs)
for inputs in input_sets:
flat_inputs, inp_spec = pytree.tree_flatten(inputs)
for flat_inputs_with_counters in prepend_counters(
flat_inputs, num_counters
):
counters, flat = (
flat_inputs_with_counters[:num_counters],
flat_inputs_with_counters[num_counters:],
)
unflat_inputs = pytree.tree_unflatten(flat, inp_spec)
inputs_with_counters = counters + unflat_inputs
def process_inputs(inp):
inp = inp.clone()
if dynamic:
mark_first_dim_dyn(inp)
if autograd and inp.dtype.is_floating_point:
inp.requires_grad_(True)
return inp
cloned_inputs = pytree.tree_map(process_inputs, inputs_with_counters)
cloned_inputs2 = pytree.tree_map(process_inputs, inputs_with_counters)
result = model(*cloned_inputs)
result_compiled = compiled_fn(*cloned_inputs2)
# inputs must not be mutated
torch.testing.assert_close(cloned_inputs, inputs_with_counters)
torch.testing.assert_close(
result, result_compiled, atol=1e-4, rtol=1e-4
)
if autograd and any(
pytree.tree_map_only(
torch.Tensor, lambda t: t.requires_grad, cloned_inputs
)
):
result_loss = loss_fn(pytree.tree_flatten(result)[0])
compiled_loss = loss_fn(pytree.tree_flatten(result_compiled)[0])
self.assertTrue(
not torch.isnan(result_loss) and not torch.isinf(result_loss)
)
self.assertTrue(
not torch.isnan(compiled_loss)
and not torch.isinf(compiled_loss)
)
self.assertEqual(result_loss, compiled_loss)
result_loss.backward()
compiled_loss.backward()
model_parameters = dict(model.named_parameters())
compiled_parameters = dict(compiled_model.named_parameters())
for name, param in model_parameters.items():
self.assertEqual(param, compiled_parameters[name])
self.assertEqual(
param.grad,
compiled_parameters[name].grad,
atol=1e-4,
rtol=1e-4,
)
for inp1, inp2 in zip(
pytree.tree_flatten(cloned_inputs)[0],
pytree.tree_flatten(cloned_inputs2)[0],
):
if inp1.requires_grad:
self.assertEqual(
inp1.grad,
inp2.grad,
atol=1e-4,
rtol=1e-4,
)
self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_simple_control_flow(self, device, dynamic, autograd):
# while_loop control flow without nesting
self._run_test(
model=WhileLoopModels.Simple(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_nested_control_flow(self, device, dynamic, autograd):
# while_loop control flow with nesting
self._run_test(
model=WhileLoopModels.Nested(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
num_counters=2,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_outer_code(self, device, dynamic, autograd):
# while_loop control flow with outer code
self._run_test(
model=WhileLoopModels.OuterCode(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_parameters(self, device, dynamic, autograd):
# while_loop control flow with parameters
self._run_test(
model=WhileLoopModels.Parameters(device),
inputs=(torch.randn(10, 20, dtype=torch.float64),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
# dynamic=True doesn't work now due to
# https://github.com/pytorch/pytorch/issues/123596
@parametrize("dynamic", [False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_outer_buffers(self, device, dynamic, autograd):
# while_loop control flow with buffers from the outer scope
self._run_test(
model=WhileLoopModels.OuterBuffers(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_pytree_inputs(self, device, dynamic, autograd):
self._run_test(
model=WhileLoopModels.PytreeCarry(),
inputs=(
(
[torch.randn(10, 20)],
{"x": torch.randn(10, 20), "y": torch.randn(10, 20)},
),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_data_dependent_ops(self, device, dynamic, autograd):
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
}
):
self._run_test(
model=WhileLoopModels.DataDependentOpInSubgraph(),
inputs=(
torch.tensor([1, 2, 3, 4, 5]),
torch.tensor(
[True, True, True, True, True],
),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_data_dependent_in_out(self, device, dynamic, autograd):
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
"capture_scalar_outputs": True,
}
):
self._run_test(
model=WhileLoopModels.DataDependentInOut(),
inputs=(
torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
torch.tensor(
[True, True, True, True, True],
),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@parametrize("dynamic", [True, False])
def test_while_loop_with_data_dependent_in_out_mismatch(self, dynamic):
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Expected body_fn_output and carried_inputs to have same metadata but found",
):
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
}
):
self._run_test(
model=WhileLoopModels.DataDependentInOutMismatch(),
inputs=(
torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
torch.tensor(
[True, True, True, True, True],
),
),
device="cpu",
dynamic=dynamic,
)
def test_while_loop_infinite_loop_error(self):
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"while_loop doesn't work unless it is captured completely",
):
self._run_test(
model=WhileLoopModels.InfiniteLoop(),
inputs=(torch.tensor([1, 2, 3, 4, 5]),),
device="cpu",
dynamic=False,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
def test_while_loop_zero_loop(self, device, dynamic):
for model in [
WhileLoopModels.ZeroLoop(),
WhileLoopModels.ZeroLoop2(),
WhileLoopModels.ZeroLoop3(),
WhileLoopModels.ZeroLoop4(),
]:
self._run_test(
model=model,
inputs=(torch.tensor([1, 2, 3, 4, 5]),),
device=device,
dynamic=dynamic,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@torch._dynamo.config.patch(
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
)
@parametrize("autograd", [False, True])
def test_while_loop_with_unbacked_symint_closure(self, device, dynamic, autograd):
self._run_test(
model=WhileLoopModels.UnbackedSymIntClosure(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", [GPU_TYPE])
def test_while_loop_models_with_mixed_device(self, device):
self._run_test(
model=WhileLoopModels.MixedDevice(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=True,
)
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Expected body_fn_output and carried_inputs to have same metadata but found",
):
# Errors at the front end because the device of the carried value is
# promoted to a different one after the first iteration
self._run_test(
model=WhileLoopModels.MixedDevice2(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=True,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch(
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
)
def test_while_loop_with_sym_expr_cond(self, device, dynamic, autograd):
self._run_test(
model=WhileLoopModels.SymExprCond(),
inputs=(
torch.randn(10, 20),
torch.randn(10, 20),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_conv(self, device, dynamic, autograd):
self._run_test(
model=WhileLoopModels.Conv(device),
inputs=(torch.randn(2, 4, 4, 4, dtype=torch.float64),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_stack_output_simple(self, device, dynamic):
self._run_test(
model=WhileLoopModels.WhileLoopStackOutputSimple(device),
inputs=(torch.randn(3, 3, dtype=torch.float32),),
device=device,
dynamic=dynamic,
)
class AssociativeScanTests(TestCase):
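# Compares compiled and eager associative_scan against torch.cumsum for
# different combinations of reverse=True/False and explicit input flips.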
@requires_gpu
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("backend", ["inductor"])
@parametrize("device", [torch.device("cpu"), GPU_TYPE])
# This test will fail because flip, in combination with particular input
# lengths, produces incorrect results.
# This is under investigation in
# https://github.com/pytorch/pytorch/issues/131805
@decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE)
def test_associative_scan_CUDA_flip(self, combine_mode, backend, device):
def fct(x: torch.Tensor, y: torch.Tensor):
return x + y
# for n in range(10):
for n in [9]:
x = torch.arange(n, device=device)
torch.compiler.reset()
associative_scan1 = torch.compile(
associative_scan, backend=backend, fullgraph=True
)
associative_scan2 = associative_scan
if combine_mode == "pointwise" and device == torch.device("cpu"):
with self.assertRaisesRegex(Exception, r"."):
associative_scan1(
fct, x, 0, reverse=False, combine_mode=combine_mode
)
# Skipping test because combine_mode currently only supports CUDA tensors
return
result1 = associative_scan1(
fct, x, 0, reverse=False, combine_mode=combine_mode
)
result2 = associative_scan2(
fct, x, 0, reverse=False, combine_mode=combine_mode
)
result3 = torch.cumsum(x, 0)
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Flip only non-compiled and compare with compiled reverse=True
result1 = associative_scan1(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result2 = torch.flip(
associative_scan2(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Flip only compiled and compare with non-compiled reverse=True
result1 = torch.flip(
associative_scan1(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result2 = associative_scan2(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Use reverse=False, but flip both results before and after
result1 = torch.flip(
associative_scan1(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result2 = torch.flip(
associative_scan2(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Reverse=True
result1 = associative_scan1(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result2 = associative_scan2(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
class ScanModels:
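# scan programs exercised by ScanTests below. Except for SimpleScan, the
# modules take the scan implementation as their first forward argument so
# the same module can be run with both scan and _fake_scan.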
class SimpleScan(torch.nn.Module):
def __init__(self, reverse, dim):
super().__init__()
self.reverse = reverse
self.dim = dim
def forward(self, _input, weight, bias):
def combine_fn(carry, x):
from torch.utils import _pytree as pytree
new_carry = {
"param": carry["param"] @ x + carry["bias"],
"bias": carry["bias"].sin(),
}
return new_carry, (
pytree.tree_map(lambda x: x.clone(), new_carry),
{"dummy": x.sin()},
)
return scan(
combine_fn,
{"param": weight, "bias": bias},
_input,
reverse=self.reverse,
dim=self.dim,
)
class ScanLinearWithView(torch.nn.Module):
def __init__(self, reverse, dim):
super().__init__()
self.reverse = reverse
self.dim = dim
self.linear = torch.nn.Linear(4, 4, dtype=torch.float64)
def forward(self, scan_op, init, xs):
def combine_fn(carry, x):
prev_sz = x.size()
x = self.linear(x.view(-1, x.size(-1)))
x_view = x.view(*prev_sz)
return x_view, x_view.clone()
return scan_op(combine_fn, init, xs, dim=self.dim, reverse=self.reverse)
class ScanConv(torch.nn.Module):
def __init__(self, reverse, dim):
super().__init__()
self.reverse = reverse
self.dim = dim
self.conv2d = torch.nn.Conv2d(
4, 4, (3, 3), stride=(1, 1), padding=(1, 1), dtype=torch.float64
)
# init = torch.randn(2, 4, 4, 4)
# xs = torch.randn(scan_dim, 2, 4, 4, 4)
def forward(self, scan_op, init, xs):
def combine_fn(carry, x):
x = self.conv2d(x)
return x, x.clone()
return scan_op(combine_fn, init, xs, dim=self.dim, reverse=self.reverse)
class ScanInCond(torch.nn.Module):
def __init__(self, reverse, dim):
super().__init__()
self.true_scan_linear = ScanModels.ScanLinearWithView(reverse, dim)
self.false_scan_linear = ScanModels.ScanLinearWithView(not reverse, dim)
def forward(self, scan_op, pred, init, xs):
def true_fn():
last_carry, y = self.true_scan_linear(scan_op, init, xs)
return last_carry.sum(), y.sin()
def false_fn():
last_carry, y = self.false_scan_linear(scan_op, init, xs)
return -last_carry.sum(), y.cos()
return torch.cond(pred, true_fn, false_fn, tuple())
class CondInScan(torch.nn.Module):
def __init__(self, reverse, dim):
super().__init__()
self.reverse = reverse
self.dim = dim
self.true_linear = torch.nn.Linear(4, 4)
self.false_linear = torch.nn.Linear(4, 4)
def forward(self, scan_op, init, xs):
def combine_fn(carry, x):
old_sizes = carry.size()
carry_view = carry.view(-1, carry.size()[-1])
new_carry_out = torch.cond(
torch.all(carry_view > 1),
lambda: self.true_linear(carry_view).sin(),
lambda: self.false_linear(carry_view).cos(),
tuple(),
)
return carry + new_carry_out.view(*old_sizes), new_carry_out
return scan_op(
combine_fn,
init,
xs,
dim=self.dim,
reverse=self.reverse,
)
class SimpleWithPytreeInOuts(torch.nn.Module):
def __init__(self, reverse, dim):
super().__init__()
self.reverse = reverse
self.dim = dim
def forward(self, scan_op, _input, weight, bias):
def combine_fn(carry, x):
new_carry = {
"param": carry["param"] @ x + carry["bias"],
"bias": carry["bias"].sin(),
}
return new_carry, (
pytree.tree_map(lambda x: x.clone(), new_carry),
{"dummy": x.sin()},
)
return scan_op(
combine_fn,
{"param": weight, "bias": bias},
_input,
reverse=self.reverse,
dim=self.dim,
)
class ChunkedCE(torch.nn.Module):
def __init__(self, chunk_size):
super().__init__()
self.chunk_size = chunk_size
self.ce = lambda logits, target: torch.abs(target - logits).sum()
def forward(self, scan_op, _input, weight, target, bias):
CHUNK_SIZE = self.chunk_size
def compute_loss(input_chunk, weight, bias, target):
logits = torch.addmm(bias, input_chunk, weight.t())
logits = logits.float()
loss = self.ce(logits, target)
return loss
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
loss_acc = torch.zeros((), device=_input.device)
chunks = _input.shape[0] // CHUNK_SIZE
_input_chunks = _input.view(CHUNK_SIZE, chunks, *_input.shape[1:])
target_chunks = target.view(CHUNK_SIZE, chunks, *target.shape[1:])
def combine_fn(carry, xs):
grad_weight, grad_bias, loss_acc = carry
input_chunk, target_chunk = xs
(
(
chunk_grad_input,
chunk_grad_weight,
chunk_grad_bias,
),
chunk_loss,
) = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))(
input_chunk, weight, bias, target_chunk
)
return (
(
grad_weight + chunk_grad_weight,
grad_bias + chunk_grad_bias,
loss_acc + chunk_loss,
),
chunk_grad_input,
)
(grad_weight, grad_bias, loss_acc), grad_inputs = scan_op(
combine_fn,
(grad_weight, grad_bias, loss_acc),
(_input_chunks, target_chunks),
)
return (
grad_weight / chunks,
grad_bias / chunks,
loss_acc / chunks,
grad_inputs.view(-1, *_input.shape[1:]) / chunks,
)
class ChunkedCENoScan(torch.nn.Module):
def __init__(self, chunk_size):
super().__init__()
self.chunk_size = chunk_size
self.ce = lambda logits, target: torch.abs(target - logits).sum()
def forward(self, scan_op, _input, weight, target, bias):
CHUNK_SIZE = self.chunk_size
def compute_loss(input_chunk, weight, bias, target):
logits = torch.addmm(bias, input_chunk, weight.t())
logits = logits.float()
loss = self.ce(logits, target)
return loss
grad_weight = torch.zeros_like(weight)
grad_inputs = []
grad_bias = torch.zeros_like(bias)
loss_acc = torch.zeros((), device=_input.device)
chunks = _input.shape[0] // CHUNK_SIZE
def accumulate_chunk(input_chunk, target_chunk):
(
(
chunk_grad_input,
chunk_grad_weight,
chunk_grad_bias,
),
chunk_loss,
) = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))(
input_chunk, weight, bias, target_chunk
)
grad_weight.add_(chunk_grad_weight)
grad_bias.add_(chunk_grad_bias)
loss_acc.add_(chunk_loss)
return chunk_grad_input
accumulate_chunk = torch.compile(accumulate_chunk)
input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
target_chunks = torch.chunk(target, chunks=chunks, dim=0)
for input_chunk, target_chunk in zip(input_chunks, target_chunks):
grad_inputs.append(accumulate_chunk(input_chunk, target_chunk))
return (
grad_weight / chunks,
grad_bias / chunks,
loss_acc / chunks,
torch.cat(grad_inputs, dim=0) / chunks,
)
class ScanWithClamp(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, scan_op, initial, xs):
def step(h_prev, x_t):
h_next = (h_prev + x_t).clamp(min=0.1)
return h_next, h_next.clone()
final, ys = scan_op(step, initial, xs)
return final, ys
class ScanTests(TestCase):
def _run_test(
self,
model,
inputs,
device,
dynamic,
autograd=False,
):
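# Runs four copies of the model: eager with _fake_scan, eager with scan,
# compiled with scan and compiled with _fake_scan, and checks that all four
# agree on outputs (and on losses and gradients when autograd=True).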
import copy
inputs = [
inp.requires_grad_(autograd) if inp.dtype.is_floating_point else inp
for inp in inputs
]
inputs = [inp.to(device=device) for inp in inputs]
model = model.to(device=device)
for p in model.parameters():
p.requires_grad_(autograd)
model1 = copy.deepcopy(model)
model2 = copy.deepcopy(model)
model3 = copy.deepcopy(model)
model4 = copy.deepcopy(model)
model3.compile(fullgraph=True, dynamic=dynamic)
model4.compile(fullgraph=True, dynamic=dynamic)
def _run_model(model, inputs):
cloned_inputs = [
inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
]
fw_result = model(*cloned_inputs)
loss = loss_fn(fw_result)
if autograd:
loss.backward()
return (
fw_result,
loss,
[
inp.grad
for inp in cloned_inputs
if isinstance(inp, torch.Tensor)
],
{n: p.grad for n, p in model.named_parameters()},
)
else:
return fw_result, loss
result_exp = _run_model(model1, [_fake_scan] + inputs)
result_eager = _run_model(model2, [scan] + inputs)
result_compiled = _run_model(model3, [scan] + inputs)
result_compiled_exp = _run_model(
model4,
[_fake_scan] + inputs,
)
self.assertEqual(result_exp, result_eager)
self.assertEqual(result_exp, result_compiled)
self.assertEqual(result_exp, result_compiled_exp)
def _compare_result(
self,
model1,
model2,
inputs,
device,
):
inp_on_device = [elem.to(device=device) for elem in inputs]
cloned_inputs = [arg.clone() for arg in inp_on_device]
model1_out = model1(scan, *cloned_inputs)
model2_out = model2(scan, *cloned_inputs)
self.assertEqual(model1_out, model2_out)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 2])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_pytree_in_out(self, device, dynamic, reverse, dim, autograd):
self._run_test(
model=ScanModels.SimpleWithPytreeInOuts(reverse=reverse, dim=dim),
inputs=(
torch.ones(2, 2, 2),
torch.ones(2, 2),
torch.ones(2),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 3])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_nn_modules(
self, device, dynamic, reverse, dim, scan_length, autograd
):
init = torch.randn(20, 16, 4, 4, dtype=torch.float64)
xs = torch.randn(scan_length, 20, 16, 4, 4, dtype=torch.float64)
xs = xs.movedim(0, dim)
self._run_test(
model=ScanModels.ScanLinearWithView(reverse=reverse, dim=dim),
inputs=(
init,
xs,
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 3])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_conv(self, device, dynamic, reverse, dim, scan_length, autograd):
init = torch.randn(2, 4, 4, 4, dtype=torch.float64)
xs = torch.randn(scan_length, 2, 4, 4, 4, dtype=torch.float64)
xs = xs.movedim(0, dim)
self._run_test(
model=ScanModels.ScanConv(reverse=reverse, dim=dim),
inputs=(
init,
xs,
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 3])
@parametrize("pred", [True, False])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_in_cond(
self, device, dynamic, reverse, dim, pred, scan_length, autograd
):
init = torch.randn(4, 4, 4, dtype=torch.float64)
xs = torch.randn(scan_length, 4, 4, 4, dtype=torch.float64)
xs = xs.movedim(0, dim)
self._run_test(
model=ScanModels.ScanInCond(reverse=reverse, dim=dim),
inputs=(
torch.tensor(pred),
init,
xs,
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 3])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_cond_in_scan(self, device, dynamic, reverse, dim, scan_length, autograd):
init = torch.randn(2, 4, 4, 4)
xs = torch.randn(scan_length, 4, 4, 4)
xs = xs.movedim(0, dim)
self._run_test(
model=ScanModels.CondInScan(reverse=reverse, dim=dim),
inputs=(
init,
xs,
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_chunked_ce(self, device, dynamic, autograd):
self._run_test(
model=ScanModels.ChunkedCE(10),
inputs=(
torch.randn(100, 20),
torch.randn(20, 20),
torch.randn(100, 20),
torch.randn(20),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_compare_chunked_ce_with_no_scan(self, device, dynamic):
for chunk_size, B, T in zip([10, 20], [10, 100], [20, 40]):
self._compare_result(
model1=torch.compile(ScanModels.ChunkedCE(chunk_size), dynamic=dynamic),
model2=ScanModels.ChunkedCENoScan(chunk_size),
inputs=(
torch.randn(B, T),
torch.randn(T, T),
torch.randn(B, T),
torch.randn(T),
),
device=device,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_with_clamp(self, device, dynamic, autograd):
B = 4
T = 8
H = 16
self._run_test(
model=ScanModels.ScanWithClamp(),
inputs=(
torch.randn((B, H)),
torch.randn((T, B, H)),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
class MapModels:
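# map programs exercised by MapTests below. Each forward takes the map
# implementation as its first argument so the same module can be run with
# both torch._higher_order_ops.map and _fake_map.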
class Simple(torch.nn.Module):
def forward(self, map_op, x):
a = torch.ones(3, 4, device=x.device)
def f(x):
return x.sin() + a
return map_op(f, x)
class SimpleWithLinearWithView(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 5)
def forward(self, map_op, x):
def f(x):
return self.linear(x).sin()
return map_op(f, x.view(4, 3))
class PytreeInOut(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 5)
def forward(self, map_op, x, y, z):
def f(x_y_z):
x = x_y_z["x"]
y, (z,) = x_y_z["y_z"]
return self.linear(x).sin(), (self.linear(y), z.cos())
return map_op(f, {"x": x, "y_z": (y, (z,))})
class ReinterpretView(torch.nn.Module):
def forward(self, map_op, x, y, z):
def f(xyz):
x, y, z = xyz
return x.sin()[:2], y.cos()[:2] + z[-2:].clone()
return map_op(f, (x, y, z))
class NestedWithCond(torch.nn.Module):
def forward(self, map_op, x, y, z):
def true_fn(x, y, z):
def inner_f(yz):
y, z = yz
return y + z
return map_op(inner_f, (y, z))
def false_fn(x, y, z):
def inner_f(yz):
y, z = yz
return y - z
return map_op(inner_f, (y, z))
return torch._higher_order_ops.cond(
x.sum() > 0, true_fn, false_fn, (x, y, z)
)
class MapTests(TestCase):
def _run_test(
self,
model,
inputs,
device,
dynamic=False,
autograd=False,
):
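# Runs the model eagerly with torch._higher_order_ops.map, eagerly with
# _fake_map, and compiled with map, then checks that the results (and, when
# autograd=True, the parameter gradients) agree across all three.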
import copy
inputs = [inp.to(device=device) for inp in inputs]
model = model.to(device=device)
model_eager = copy.deepcopy(model)
model_compiled = copy.deepcopy(model)
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
model_compiled
)
if autograd:
pytree.tree_map_only(torch.Tensor, lambda t: t.requires_grad_(True), inputs)
cloned_inputs = [inp.clone() for inp in inputs]
result = model(torch._higher_order_ops.map, *cloned_inputs)
result_exp = model_eager(_fake_map, *cloned_inputs)
result_compiled = compiled_model(torch._higher_order_ops.map, *cloned_inputs)
self.assertEqual(result, result_exp)
self.assertEqual(result, result_compiled)
if autograd:
loss_fn(result).backward()
loss_fn(result_exp).backward()
loss_fn(result_compiled).backward()
model_params = dict(model.named_parameters())
model_eager_params = dict(model_eager.named_parameters())
model_compiled_params = dict(model_compiled.named_parameters())
for name, param in model_eager_params.items():
self.assertEqual(param, model_params[name])
self.assertEqual(param, model_compiled_params[name])
self.assertEqual(param.grad, model_params[name].grad)
self.assertEqual(param.grad, model_compiled_params[name].grad)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_simple(self, device, dynamic, autograd):
self._run_test(
model=MapModels.Simple(),
inputs=(torch.randn(3, 4),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_simple_linear_with_view(self, device, dynamic, autograd):
self._run_test(
model=MapModels.SimpleWithLinearWithView(),
inputs=(torch.randn(3, 4),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_pytree_in_out(self, device, dynamic, autograd):
self._run_test(
model=MapModels.PytreeInOut(),
inputs=(
torch.randn(2, 5, 3),
torch.randn(2, 5, 3),
torch.randn(2, 4, 3),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_nested_with_cond(self, device, dynamic, autograd):
self._run_test(
model=MapModels.NestedWithCond(),
inputs=(
torch.randn(3, 2),
torch.randn(3, 10, 5),
torch.randn(3, 10, 5),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
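# Generate concrete test methods for every @parametrize combination above.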
instantiate_parametrized_tests(CondTests)
instantiate_parametrized_tests(WhileLoopTests)
instantiate_parametrized_tests(AssociativeScanTests)
instantiate_parametrized_tests(ScanTests)
instantiate_parametrized_tests(MapTests)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_CPU or HAS_GPU:
run_tests(needs="filelock")