# Owner(s): ["module: inductor"]
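#
# Inductor tests for control-flow higher-order ops (torch.cond, while_loop,
# associative_scan, scan, and map). Each *Models class below defines small
# nn.Modules exercising one op, and the matching *Tests class compiles them
# with the inductor backend and compares the results against eager execution.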

import itertools
import unittest

import torch
import torch._dynamo.testing
import torch.utils._pytree as pytree
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.map import _fake_map
from torch._higher_order_ops.scan import _fake_scan, scan
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import (
    decorateIf,
    instantiate_parametrized_tests,
    parametrize,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu


def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
    result = []
    device = inputs[0].device
    # iterate over the cartesian product of predicate values
    for values in itertools.product(*([possible_values] * num_to_prepend)):
        prepended = [torch.tensor(v, device=device) for v in values]
        result.append((*prepended, *inputs))
    return result


def prepend_predicates(inputs, num_predicates=1):
    return _prepend_product_of_values(inputs, [False, True], num_predicates)


def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
    return _prepend_product_of_values(inputs, counter_values, num_counters)


# a testing loss_fn
def loss_fn(result) -> torch.Tensor:
    flat_results, _ = pytree.tree_flatten(result)
    total_loss = torch.tensor(
        0.0, device=flat_results[0].device if flat_results else torch.device("cpu")
    )

    for res in flat_results:
        # Convert to float if integer tensor to avoid numerical issues
        if not res.dtype.is_floating_point:
            res = res.float()

        # Simple robust loss: abs values + small constant to avoid inf/nan
        total_loss = total_loss + (torch.abs(res) / (1.0 + torch.abs(res))).sum()

    return total_loss


class CondModels:
    class Simple(torch.nn.Module):
        def forward(self, p, a, b):
            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            return torch.cond(p, true_fn, false_fn, [a, b])
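    # Note: torch.cond traces both branches, so true_fn and false_fn must return
    # tensors with matching metadata; with p = torch.tensor(True) Simple returns
    # a + b, otherwise a - b.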
|
|
|
|
class SimpleWithIntClosure(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.num = 3
|
|
|
|
def forward(self, p, a, b):
|
|
return torch.cond(
|
|
pred=p,
|
|
true_fn=lambda a, b: [a + b + self.num],
|
|
false_fn=lambda a, b: [a - b - self.num],
|
|
operands=(a, b),
|
|
)
|
|
|
|
class Nested(torch.nn.Module):
|
|
def forward(self, p0, p1, p2, a, b, c):
|
|
def true_fn(x0, y0, z0):
|
|
def true_true_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) * 3.14
|
|
|
|
def true_false_fn(x1, y1, z1):
|
|
def true_false_true_fn(x2, y2, z2):
|
|
return (x2 * y2 * z2) / 2.71
|
|
|
|
def true_false_false_fn(x2, y2, z2):
|
|
return (x2 + y2 + z2) * 1.23
|
|
|
|
return torch.cond(
|
|
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
|
|
|
|
def false_fn(x0, y0, z0):
|
|
def false_true_fn(x1, y1, z1):
|
|
def false_true_true_fn(x2, y2, z2):
|
|
return (x2 - y2 - z2) + 1.23
|
|
|
|
def false_true_false_fn(x2, y2, z2):
|
|
return (x2 / y2 / z2) - 3.14
|
|
|
|
return torch.cond(
|
|
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
def false_false_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) / 2.71
|
|
|
|
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
|
|
|
|
return torch.cond(p0, true_fn, false_fn, [a, b, c])
|
|
|
|
class Parameters(torch.nn.Module):
|
|
class InnerModel1(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.layer = torch.nn.Linear(20, 30, device=device)
|
|
|
|
def forward(self, x):
|
|
return self.layer(x + 1) * 3.14
|
|
|
|
class InnerModel2(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.layer1 = torch.nn.Linear(20, 10, device=device)
|
|
self.layer2 = torch.nn.Linear(10, 30, device=device)
|
|
|
|
def forward(self, x):
|
|
return self.layer2(self.layer1(x - 2)) * 3.14
|
|
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.true_fn = self.InnerModel1(device)
|
|
self.false_fn = self.InnerModel2(device)
|
|
|
|
def forward(self, p, a):
|
|
return torch.cond(p, self.true_fn, self.false_fn, [a])
|
|
|
|
class ReinterpretView(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
def true_fn(x, y):
|
|
z1 = x + y
|
|
z2 = x - y
|
|
return z1[2:], z2[:, 4:].contiguous()
|
|
|
|
def false_fn(x, y):
|
|
z1 = x - y
|
|
z2 = x + y
|
|
return z1[2:], z2[:, 4:].contiguous()
|
|
|
|
return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])
|
|
|
|
class MultipleOutputs(torch.nn.Module):
|
|
def forward(self, p, a, b, c):
|
|
def true_fn(x, y, z):
|
|
return x * y, z / 2.71, (y - x).sum(dim=1)
|
|
|
|
def false_fn(x, y, z):
|
|
return y / x, z * 3.14, (x + y).mean(dim=1)
|
|
|
|
return torch.cond(p, true_fn, false_fn, [a, b, c])
|
|
|
|
class OuterCode(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
c = a * b + 3.14
|
|
d = a / b - 2.71
|
|
|
|
def true_fn(x, y):
|
|
return x + y
|
|
|
|
def false_fn(x, y):
|
|
return x - y
|
|
|
|
e = torch.cond(p, true_fn, false_fn, [c, d])
|
|
|
|
return e * e / 1.41
|
|
|
|
class OuterBuffers(torch.nn.Module):
|
|
def forward(self, p, a, b, c):
|
|
d = a * 2
|
|
e = b / 2
|
|
|
|
def true_fn(x):
|
|
return x + d
|
|
|
|
def false_fn(x):
|
|
return x - e
|
|
|
|
return torch.cond(p, true_fn, false_fn, [c])
|
|
|
|
class WithNonTensorPredicate(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
def true_fn(x, y):
|
|
return x.sum(0) / 3.14
|
|
|
|
def false_fn(x, y):
|
|
return y.sum(0) * 2.71
|
|
|
|
return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])
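    # Here the predicate is a plain (Sym)bool computed from tensor sizes rather
    # than a tensor, which is why the corresponding test prepends no predicate
    # tensors (num_predicates=0).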
|
|
|
|
class UnbackedSymIntClosure(torch.nn.Module):
|
|
def forward(self, p, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.sum().to(torch.int64).item()
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b * z
|
|
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
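    # b = z.sum().item() produces an unbacked SymInt, so the corresponding test
    # runs under torch._dynamo.config.capture_scalar_outputs=True.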
|
|
|
|
class MismatchedOutputSize(torch.nn.Module):
|
|
def forward(self, p, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
return (x + a)[2:].sin()
|
|
|
|
def false_fn(x):
|
|
return (x + b * z)[:2].cos()
|
|
|
|
return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
|
|
|
|
class FunctionalCall(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, p, x):
|
|
true_new_weight = torch.ones(x.size(0), x.size(0), device=x.device)
|
|
false_new_weight = torch.zeros(x.size(0), x.size(0), device=x.device)
|
|
true_new_bias = torch.ones(x.size(0), device=x.device)
|
|
false_new_bias = torch.zeros(x.size(0), device=x.device)
|
|
x = x.reshape(-1, x.size(0))
|
|
|
|
def true_fn(x):
|
|
return torch.func.functional_call(
|
|
self.linear,
|
|
{
|
|
"weight": true_new_weight,
|
|
"bias": true_new_bias,
|
|
},
|
|
x,
|
|
)
|
|
|
|
def false_fn(x):
|
|
return torch.func.functional_call(
|
|
self.linear,
|
|
{
|
|
"weight": false_new_weight,
|
|
"bias": false_new_bias,
|
|
},
|
|
x,
|
|
)
|
|
|
|
return torch.cond(p, true_fn, false_fn, (x,))
|
|
|
|
class SelectWithInputIdx(torch.nn.Module):
|
|
def forward(self, p, x, idx):
|
|
u0 = idx.item()
|
|
x0 = x.select(0, u0)
|
|
|
|
def fn():
|
|
return x0.sin()
|
|
|
|
return torch.cond(x0.sum() > 0, fn, fn)
|
|
|
|
|
|
class CondTests(TestCase):
    def _run_test(
        self,
        model,
        inputs,
        device,
        dynamic=False,
        num_predicates=1,
    ):
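        """Compile `model` with inductor and compare against eager.

        Runs the model over the cartesian product of predicate values
        (via prepend_predicates), optionally with a second, tiled set of
        inputs marked dynamic, and checks that inputs are not mutated,
        outputs match, and only a single frame is compiled.
        """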
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)

        inputs = [inp.to(device=device) for inp in inputs]
        input_sets = [inputs]
        if dynamic:
            larger_inputs = []
            for inp in inputs:
                # only tile non-scalar tensor inputs
                if inp.ndim > 0:
                    # tile every first dim 5x
                    tiling = [5] + [1] * (inp.ndim - 1)
                    larger_inputs.append(torch.tile(inp, tiling))
                else:
                    larger_inputs.append(inp)
            input_sets.append(larger_inputs)
            for inputs in input_sets:
                for inp in inputs:
                    # mark every first dim as dynamic
                    torch._dynamo.mark_dynamic(inp, 0)

        for inputs in input_sets:
            for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
                cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
                result = model(*inputs_with_predicates)
                result_compiled = compiled_model(*inputs_with_predicates)
                # inputs must not be mutated
                torch.testing.assert_close(cloned_inputs, inputs_with_predicates)
                torch.testing.assert_close(result, result_compiled)

        self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_simple_control_flow(self, device, dynamic):
|
|
# cond control flow without nesting
|
|
self._run_test(
|
|
model=CondModels.Simple(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_simple_with_int_closure(self, device):
|
|
self._run_test(
|
|
model=torch.compile(CondModels.SimpleWithIntClosure(), dynamic=True),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_cond_unbacked_symint_closure(self, device, dynamic):
|
|
self._run_test(
|
|
model=CondModels.UnbackedSymIntClosure(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
    @skipIfXpu(msg="Remove this skip after issue #154949 is resolved.")
|
|
@requires_gpu
|
|
def test_cond_control_flow_with_precomputed_size(self):
|
|
class TestModel(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
self.conv2d = torch.nn.Conv2d(
|
|
512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
|
|
)
|
|
self.threshold = 20
|
|
|
|
def forward(self, x: torch.Tensor, index) -> torch.Tensor:
|
|
def true_fn(x: torch.Tensor):
|
|
return self.conv2d(x)
|
|
|
|
def false_fn(x: torch.Tensor):
|
|
return self.conv2d(x)
|
|
|
|
return torch.cond(
|
|
index < self.threshold and index >= 0, true_fn, false_fn, (x,)
|
|
)
|
|
|
|
main_model = TestModel().to(GPU_TYPE)
|
|
x1 = torch.rand(2, 512, 128, 72).to(GPU_TYPE)
|
|
x2 = torch.rand(2, 512, 96, 96).to(GPU_TYPE)
|
|
|
|
opt_model = torch.compile(main_model)
|
|
out1 = main_model(x1, 1)
|
|
opt_out1 = opt_model(x1, 1)
|
|
self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5))
|
|
|
|
out2 = main_model(x2, 30)
|
|
opt_out2 = opt_model(x2, 30)
|
|
self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5))
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_nested_control_flow(self, device, dynamic):
|
|
# cond control flow with nesting
|
|
self._run_test(
|
|
model=CondModels.Nested(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
num_predicates=3,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_outer_code_before_after(self, device, dynamic):
|
|
# some code before and after the conditional
|
|
self._run_test(
|
|
model=CondModels.OuterCode(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_multiple_outputs(self, device, dynamic):
|
|
# multiple outputs with different shapes
|
|
self._run_test(
|
|
model=CondModels.MultipleOutputs(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(30, 40),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_advanced_dynamic_shapes(self, device):
|
|
# subgraphs input shapes include symbolic expressions
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
def true_fn(x, y):
|
|
return torch.cat([x - 3, y * 3], dim=1)
|
|
|
|
def false_fn(x, y):
|
|
return torch.cat([x / 3, y - 3], dim=1)
|
|
|
|
c = torch.cat([a, b], dim=0)
|
|
d = c * 2
|
|
e = c / 2
|
|
|
|
return torch.cond(p, true_fn, false_fn, [d, e])
|
|
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(
|
|
torch.randn(2, 3, 3),
|
|
torch.randn(4, 3, 3),
|
|
),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_unbacked_symint_outer_to_inner(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
return torch.cos(x)
|
|
|
|
def false_fn(x):
|
|
return torch.sin(x)
|
|
|
|
nz = torch.nonzero(a)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
|
|
return torch.cond(p, true_fn, false_fn, [b])
|
|
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.randn(2, 3, 3),),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@torch._inductor.config.patch(size_asserts=False)
|
|
# TODO: graph partition does not support creating tensor
|
|
# with dynamic shape in conditional subgraph yet
|
|
@torch._inductor.config.patch(graph_partition=False)
|
|
def test_cond_unbacked_symint_inner(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.cos(b)
|
|
|
|
def false_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.sin(b)
|
|
|
|
b = torch.sin(a)
|
|
|
|
return torch.cond(p, true_fn, false_fn, [b])
|
|
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.randn(2, 3, 3),),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_unbacked_symint_inner_to_outer(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.cos(b)
|
|
|
|
def false_fn(x):
|
|
nz = torch.nonzero(x)
|
|
b = torch.ones([nz.size(0), 8], device=nz.device)
|
|
return torch.sin(b)
|
|
|
|
b = torch.sin(a)
|
|
|
|
y = torch.cond(p, true_fn, false_fn, [b])
|
|
return torch.sin(y)
|
|
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.randn(2, 3, 3),),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_use_buffers_from_outer_scope(self):
|
|
        # subgraphs use buffers defined in the outer scope
|
|
self._run_test(
|
|
model=CondModels.OuterBuffers(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=GPU_TYPE,
|
|
dynamic=False,
|
|
)
|
|
|
|
@requires_gpu
|
|
    def test_cond_reinterpret_view_inputs_outputs(self):
|
|
# ReinterpretView in inputs and outputs of the subgraphs
|
|
self._run_test(
|
|
model=CondModels.ReinterpretView(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=GPU_TYPE,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_subgraphs_with_parameters(self, device, dynamic):
|
|
# nested Modules with parameters
|
|
self._run_test(
|
|
model=CondModels.Parameters(device),
|
|
inputs=(torch.randn(10, 20),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
def test_cond_non_tensor_predicates(self, device, dynamic):
|
|
# model with a boolean predicate
|
|
for b_size_0 in [5, 15]:
|
|
torch._dynamo.reset()
|
|
self._run_test(
|
|
model=CondModels.WithNonTensorPredicate(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(b_size_0, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
num_predicates=0,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_aliasing_outputs(self):
|
|
# output aliasing in subgraphs: not supported
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a, b):
|
|
def true_fn(x, y):
|
|
z = x + y
|
|
return z, z[1:]
|
|
|
|
def false_fn(x, y):
|
|
z = x - y
|
|
return z, z[1:]
|
|
|
|
return torch.cond(p, true_fn, false_fn, [a, b])
|
|
|
|
# AssertionError: Output aliasing is currently not supported...
|
|
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
|
torch.compile(Model())(
|
|
torch.tensor(True),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_decompose_ops_in_subgraph(self, device):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
return torch.zeros_like(x)
|
|
|
|
def false_fn(x):
|
|
return torch.ones_like(x)
|
|
|
|
b = torch.ones_like(a)
|
|
c = torch.cond(p, true_fn, false_fn, [b])
|
|
return c
|
|
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.rand(10, 20),),
|
|
device=device,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
def test_cond_decompose_ops_in_subgraph_recursive(self, device):
|
|
def inner_fn1(x):
|
|
return torch.zeros_like(x)
|
|
|
|
def inner_fn2(x):
|
|
return torch.ones_like(x)
|
|
|
|
class Model(torch.nn.Module):
|
|
def forward(self, p, a):
|
|
def true_fn(x):
|
|
return torch.cond(p, inner_fn2, inner_fn1, [x])
|
|
|
|
def false_fn(x):
|
|
return torch.cond(p, inner_fn1, inner_fn2, [x])
|
|
|
|
b = torch.ones_like(a)
|
|
c = torch.cond(p, true_fn, false_fn, [b])
|
|
return c
|
|
|
|
self._run_test(
|
|
model=Model(),
|
|
inputs=(torch.rand(10, 20),),
|
|
device=device,
|
|
)
|
|
|
|
@requires_gpu
|
|
def test_cond_inductor_fx_passes_recursively_applied(self):
|
|
counters = {"pre_grad": 0, "post_grad": 0}
|
|
|
|
def pre_grad_pass_counter(gm):
|
|
counters["pre_grad"] += 1
|
|
|
|
def post_grad_pass_counter(gm):
|
|
counters["post_grad"] += 1
|
|
|
|
with torch._inductor.config.patch(
|
|
{
|
|
"pre_grad_custom_pass": pre_grad_pass_counter,
|
|
"post_grad_custom_pre_pass": post_grad_pass_counter,
|
|
# The above patches don't pickle
|
|
"fx_graph_cache": False,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=CondModels.Nested(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=GPU_TYPE,
|
|
dynamic=True,
|
|
num_predicates=3,
|
|
)
|
|
|
|
self.assertEqual(counters["pre_grad"], 11)
|
|
self.assertEqual(counters["post_grad"], 11)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_cond_mismatched_branch_output_size(self, device, dynamic):
|
|
self._run_test(
|
|
model=CondModels.MismatchedOutputSize(),
|
|
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_cond_functional_call(self, device, dynamic):
|
|
self._run_test(
|
|
model=CondModels.FunctionalCall(),
|
|
inputs=(torch.randn(10, 20),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_cond_select_with_input_idx(self, device, dynamic):
|
|
self._run_test(
|
|
model=CondModels.SelectWithInputIdx(),
|
|
inputs=(torch.randn(10, 20), torch.tensor(0, dtype=torch.int64)),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
|
|
class WhileLoopModels:
|
|
    class Simple(torch.nn.Module):
        def forward(self, ci, a, b):
            def cond_fn(i, x, y):
                return i > 0

            def body_fn(i, x, y):
                return i - 1, x + y, y - x

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])
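    # while_loop repeatedly applies body_fn to the carried values while cond_fn
    # returns True; here the first carry acts as a countdown counter, so the
    # loop runs `ci` iterations (the tests prepend counter values 0, 1 and 5).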
|
|
|
|
class Nested(torch.nn.Module):
|
|
def forward(self, ci, cj, a, b):
|
|
def cond_fn(i1, j1, x1, y1):
|
|
return i1 > 0
|
|
|
|
def body_fn(i1, j1, x1, y1):
|
|
def cond_fn_nested(i2, j2, x2, y2):
|
|
return j2 > 0
|
|
|
|
def body_fn_nested(i2, j2, x2, y2):
|
|
return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71
|
|
|
|
i1, j1, x1, y1 = torch._higher_order_ops.while_loop(
|
|
cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
|
|
)
|
|
|
|
return i1 - 1, j1.clone(), x1 * 2, y1 / 2
|
|
|
|
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, cj, a, b))
|
|
|
|
class Parameters(torch.nn.Module):
|
|
class InnerModel(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.layer1 = torch.nn.Linear(
|
|
20, 30, device=device, dtype=torch.float64
|
|
)
|
|
self.layer2 = torch.nn.Linear(
|
|
30, 20, device=device, dtype=torch.float64
|
|
)
|
|
|
|
def forward(self, c, x):
|
|
return c - 1, self.layer2(self.layer1(x - 2)) * 3.14
|
|
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.body_fn = self.InnerModel(device)
|
|
self.cond_fn = lambda c, x: c > 0
|
|
|
|
def forward(self, c, a):
|
|
return torch._higher_order_ops.while_loop(
|
|
self.cond_fn, self.body_fn, [c, a]
|
|
)
|
|
|
|
class OuterCode(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
d = a * b + 3.14
|
|
e = a / b - 2.71
|
|
|
|
def cond_fn(c, x, y):
|
|
return c > 0
|
|
|
|
def body_fn(c, x, y):
|
|
return c - 1, y - x, x + y
|
|
|
|
_, f, g = torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, d, e])
|
|
|
|
return f * g / 1.41
|
|
|
|
# TODO(aakhundov): add while_loop test with outer buffers
|
|
# with dynamic=True once dynamo / export allows while_loop
|
|
# closure capture with mark_dynamic:
|
|
# https://github.com/pytorch/pytorch/issues/123596
|
|
class OuterBuffers(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
d = a * 2
|
|
e = b / 2
|
|
|
|
def cond_fn(c, x, y):
|
|
return c > 0
|
|
|
|
def body_fn(c, x, y):
|
|
return c - 1, x + d, y - e
|
|
|
|
return torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, a, b])
|
|
|
|
class PytreeCarry(torch.nn.Module):
|
|
def forward(self, it, pytree_input):
|
|
def cond_fn(it, pytree_input):
|
|
return it > 0
|
|
|
|
def body_fn(it, pytree_input):
|
|
x = pytree_input[0][0]
|
|
y = pytree_input[1]["x"]
|
|
z = pytree_input[1]["y"]
|
|
new_x = y.sin()
|
|
new_y = z.cos()
|
|
new_z = x + 1
|
|
return it - 1, ([new_x], {"x": new_y, "y": new_z})
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn, body_fn, (it, pytree_input)
|
|
)
|
|
|
|
class DataDependentOpInSubgraph(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
def cond_fn(c, reduced_carry):
|
|
return c > 0
|
|
|
|
def body_fn(c, reduced_carry):
|
|
k = torch.masked_select(a, b)
|
|
d = torch.concat([k, k * 2])
|
|
return c - 1, torch.min(d).unsqueeze(0) + reduced_carry
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, torch.zeros([1], dtype=torch.int64, device=c.device)],
|
|
)
|
|
|
|
class DataDependentInOut(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
inp = torch.zeros(
|
|
a.sum().to(torch.int64).item(), 3, device=a.device, dtype=torch.int64
|
|
)
|
|
|
|
def cond_fn(c, inp):
|
|
return c > 0
|
|
|
|
def body_fn(c, inp):
|
|
return c - 1, (inp.sin() + 1).to(torch.int64)
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, inp],
|
|
)
|
|
|
|
class DataDependentInOutMismatch(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
def cond_fn(c, a, b):
|
|
return c > 0
|
|
|
|
def body_fn(c, a, b):
|
|
return c - 1, a.nonzero(), b.nonzero()
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a, b],
|
|
)
|
|
|
|
class InfiniteLoop(torch.nn.Module):
|
|
def forward(self, c, a):
|
|
a_view = a.view(-1, 1)
|
|
|
|
def cond_fn(c, a_view):
|
|
return a_view.size(-1) > 0
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
|
|
class ZeroLoop(torch.nn.Module):
|
|
def forward(self, c, a):
|
|
a_view = torch.sin(a.view(-1, 1))
|
|
|
|
def cond_fn(c, a_view):
|
|
return a_view.size(-1) == 0
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
out1, out2 = torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
return out1 + 1, out2 + 2
|
|
|
|
class ZeroLoop2(torch.nn.Module):
|
|
def forward(self, c, a):
|
|
a_view = torch.sin(a.view(-1, 1))
|
|
|
|
def cond_fn(c, a_view):
|
|
return False
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
out1, out2 = torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
return out1 + 1, out2 + 2
|
|
|
|
class ZeroLoop3(torch.nn.Module):
|
|
def forward(self, c, a):
|
|
a_view = torch.sin(a.view(-1, 1))
|
|
|
|
def cond_fn(c, a_view):
|
|
return 0
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
out1, out2 = torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
return out1 + 1, out2 + 2
|
|
|
|
class ZeroLoop4(torch.nn.Module):
|
|
def forward(self, c, a):
|
|
a_view = torch.sin(a.view(-1, 1))
|
|
|
|
def cond_fn(c, a_view):
|
|
return torch.clip(a_view.sum(), 0, 1) < 0
|
|
|
|
def body_fn(c, a_view):
|
|
return c - 1, a_view + 1
|
|
|
|
out1, out2 = torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a_view],
|
|
)
|
|
return out2.sin_(), a_view.cos_()
|
|
|
|
class UnbackedSymIntClosure(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
d = a.sum().to(torch.int64).item()
|
|
e = torch.nonzero(b).size(0)
|
|
|
|
def cond_fn(c, a, b):
|
|
return c > d + e + a.shape[0] - b.shape[0]
|
|
|
|
def body_fn(c, a, b):
|
|
return c - 1, a + e, b + d
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a, b],
|
|
)
|
|
|
|
class SymExprCond(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
d = a.sum().to(torch.int64).item()
|
|
e = torch.nonzero(b).size(0)
|
|
|
|
def cond_fn(c, a, b):
|
|
return c + d + e + a.shape[0] - b.shape[0] < 10
|
|
|
|
def body_fn(c, a, b):
|
|
return c + 1, a + e, b + d
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
[c, a, b],
|
|
)
|
|
|
|
class MixedDevice(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
# Force the loop idx on cpu
|
|
c = c.to(torch.device("cpu"))
|
|
|
|
def cond_fn(loop_idx, a, b):
|
|
return loop_idx < a.shape[0]
|
|
|
|
def body_fn(loop_idx, a, b):
|
|
return loop_idx + 1, a + b, a - b
|
|
|
|
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, a, b))
|
|
|
|
class MixedDevice2(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
# Force the loop idx on cpu
|
|
c.to(torch.device("cpu"))
|
|
|
|
def cond_fn(loop_idx, a, b):
|
|
return loop_idx < a.shape[0]
|
|
|
|
def body_fn(loop_idx, a, b):
|
|
return loop_idx + a.sum(), a + b, a - b
|
|
|
|
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, a, b))
|
|
|
|
class Conv(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.conv2d = torch.nn.Conv2d(
|
|
4,
|
|
4,
|
|
(3, 3),
|
|
stride=(1, 1),
|
|
padding=(1, 1),
|
|
device=device,
|
|
dtype=torch.float64,
|
|
)
|
|
|
|
def forward(self, c, x):
|
|
def cond_fn(loop_idx, x):
|
|
return loop_idx < x.size(0)
|
|
|
|
def body_fn(loop_idx, x):
|
|
return loop_idx + 1, self.conv2d(x) + 1
|
|
|
|
return torch._higher_order_ops.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
(c, x),
|
|
)
|
|
|
|
class WhileLoopStackOutputSimple(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3, device=device)
|
|
|
|
def forward(self, c, x):
|
|
def cond_fn(c, x):
|
|
return c < x.size(0)
|
|
|
|
def body_fn(c, x):
|
|
return c + 1, self.linear(x)
|
|
|
|
stacked_c, stacked_x = torch.ops.higher_order.while_loop_stack_output(
|
|
cond_fn, body_fn, (c, x), tuple()
|
|
)
|
|
return stacked_c, stacked_x
|
|
|
|
|
|
class WhileLoopTests(TestCase):
|
|
def _run_test(
|
|
self, model, inputs, device, dynamic=False, num_counters=1, autograd=False
|
|
):
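        """Compile `model` with inductor and compare against eager.

        Prepends loop-counter tensors (0, 1, 5) to the (pytree) inputs,
        optionally tiles and marks the inputs dynamic, checks that inputs
        are not mutated and outputs match, and, when autograd is enabled,
        also compares gradients of the inputs and parameters.
        """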
|
|
import torch.utils._pytree as pytree
|
|
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
|
|
import copy
|
|
|
|
if not autograd:
|
|
for p in model.parameters():
|
|
p.requires_grad_(False)
|
|
|
|
compiled_model = copy.deepcopy(model)
|
|
compiled_fn = torch.compile(backend=cnt, fullgraph=True)(compiled_model)
|
|
|
|
inputs = pytree.tree_map(lambda t: t.to(device=device), inputs)
|
|
input_sets = [inputs]
|
|
|
|
def mark_first_dim_dyn(inp):
|
|
torch._dynamo.mark_dynamic(inp, 0)
|
|
|
|
if dynamic:
|
|
|
|
def tile_fn(inp):
|
|
# tile every first dim 5x
|
|
tiling = [5] + [1] * (inp.ndim - 1)
|
|
t = torch.tile(inp, tiling)
|
|
return t
|
|
|
|
larger_inputs = pytree.tree_map(tile_fn, inputs)
|
|
input_sets.append(larger_inputs)
|
|
|
|
for inputs in input_sets:
|
|
flat_inputs, inp_spec = pytree.tree_flatten(inputs)
|
|
for flat_inputs_with_counters in prepend_counters(
|
|
flat_inputs, num_counters
|
|
):
|
|
counters, flat = (
|
|
flat_inputs_with_counters[:num_counters],
|
|
flat_inputs_with_counters[num_counters:],
|
|
)
|
|
unflat_inputs = pytree.tree_unflatten(flat, inp_spec)
|
|
inputs_with_counters = counters + unflat_inputs
|
|
|
|
def process_inputs(inp):
|
|
inp = inp.clone()
|
|
if dynamic:
|
|
mark_first_dim_dyn(inp)
|
|
|
|
if autograd and inp.dtype.is_floating_point:
|
|
inp.requires_grad_(True)
|
|
return inp
|
|
|
|
cloned_inputs = pytree.tree_map(process_inputs, inputs_with_counters)
|
|
cloned_inputs2 = pytree.tree_map(process_inputs, inputs_with_counters)
|
|
|
|
result = model(*cloned_inputs)
|
|
result_compiled = compiled_fn(*cloned_inputs2)
|
|
# inputs must not be mutated
|
|
torch.testing.assert_close(cloned_inputs, inputs_with_counters)
|
|
torch.testing.assert_close(
|
|
result, result_compiled, atol=1e-4, rtol=1e-4
|
|
)
|
|
|
|
if autograd and any(
|
|
pytree.tree_map_only(
|
|
torch.Tensor, lambda t: t.requires_grad, cloned_inputs
|
|
)
|
|
):
|
|
result_loss = loss_fn(pytree.tree_flatten(result)[0])
|
|
compiled_loss = loss_fn(pytree.tree_flatten(result_compiled)[0])
|
|
                    self.assertTrue(
                        not torch.isnan(result_loss) and not torch.isinf(result_loss)
                    )
|
|
self.assertTrue(
|
|
not torch.isnan(compiled_loss)
|
|
and not torch.isinf(compiled_loss)
|
|
)
|
|
|
|
self.assertEqual(result_loss, compiled_loss)
|
|
|
|
result_loss.backward()
|
|
compiled_loss.backward()
|
|
|
|
model_parameters = dict(model.named_parameters())
|
|
compiled_parameters = dict(compiled_model.named_parameters())
|
|
for name, param in model_parameters.items():
|
|
self.assertEqual(param, compiled_parameters[name])
|
|
self.assertEqual(
|
|
param.grad,
|
|
compiled_parameters[name].grad,
|
|
atol=1e-4,
|
|
rtol=1e-4,
|
|
)
|
|
|
|
for inp1, inp2 in zip(
|
|
pytree.tree_flatten(cloned_inputs)[0],
|
|
pytree.tree_flatten(cloned_inputs2)[0],
|
|
):
|
|
if inp1.requires_grad:
|
|
self.assertEqual(
|
|
inp1.grad,
|
|
inp2.grad,
|
|
atol=1e-4,
|
|
rtol=1e-4,
|
|
)
|
|
|
|
self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_simple_control_flow(self, device, dynamic, autograd):
|
|
# while_loop control flow without nesting
|
|
self._run_test(
|
|
model=WhileLoopModels.Simple(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_nested_control_flow(self, device, dynamic, autograd):
|
|
# while_loop control flow with nesting
|
|
self._run_test(
|
|
model=WhileLoopModels.Nested(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
num_counters=2,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_with_outer_code(self, device, dynamic, autograd):
|
|
# while_loop control flow with outer code
|
|
self._run_test(
|
|
model=WhileLoopModels.OuterCode(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [False, True])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_with_parameters(self, device, dynamic, autograd):
|
|
# while_loop control flow with parameters
|
|
self._run_test(
|
|
model=WhileLoopModels.Parameters(device),
|
|
inputs=(torch.randn(10, 20, dtype=torch.float64),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
# dynamic=True doesn't work now due to
|
|
# https://github.com/pytorch/pytorch/issues/123596
|
|
@parametrize("dynamic", [False])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_with_outer_buffers(self, device, dynamic, autograd):
|
|
        # while_loop control flow with outer buffers
|
|
self._run_test(
|
|
model=WhileLoopModels.OuterBuffers(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_with_pytree_inputs(self, device, dynamic, autograd):
|
|
self._run_test(
|
|
model=WhileLoopModels.PytreeCarry(),
|
|
inputs=(
|
|
(
|
|
[torch.randn(10, 20)],
|
|
{"x": torch.randn(10, 20), "y": torch.randn(10, 20)},
|
|
),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_with_data_dependent_ops(self, device, dynamic, autograd):
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.DataDependentOpInSubgraph(),
|
|
inputs=(
|
|
torch.tensor([1, 2, 3, 4, 5]),
|
|
torch.tensor(
|
|
[True, True, True, True, True],
|
|
),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_with_data_dependent_in_out(self, device, dynamic, autograd):
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
"capture_scalar_outputs": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.DataDependentInOut(),
|
|
inputs=(
|
|
torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
|
|
torch.tensor(
|
|
[True, True, True, True, True],
|
|
),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_with_data_dependent_in_out_mismatch(self, dynamic):
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected body_fn_output and carried_inputs to have same metadata but found",
|
|
):
|
|
with torch._dynamo.config.patch(
|
|
{
|
|
"capture_dynamic_output_shape_ops": True,
|
|
}
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.DataDependentInOutMismatch(),
|
|
inputs=(
|
|
torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
|
|
torch.tensor(
|
|
[True, True, True, True, True],
|
|
),
|
|
),
|
|
device="cpu",
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
def test_while_loop_infinite_loop_error(self):
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"while_loop doesn't work unless it is captured completely",
|
|
):
|
|
self._run_test(
|
|
model=WhileLoopModels.InfiniteLoop(),
|
|
inputs=(torch.tensor([1, 2, 3, 4, 5]),),
|
|
device="cpu",
|
|
dynamic=False,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_zero_loop(self, device, dynamic):
|
|
for model in [
|
|
WhileLoopModels.ZeroLoop(),
|
|
WhileLoopModels.ZeroLoop2(),
|
|
WhileLoopModels.ZeroLoop3(),
|
|
WhileLoopModels.ZeroLoop4(),
|
|
]:
|
|
self._run_test(
|
|
model=model,
|
|
inputs=(torch.tensor([1, 2, 3, 4, 5]),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@torch._dynamo.config.patch(
|
|
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
|
|
)
|
|
@parametrize("autograd", [False, True])
|
|
def test_while_loop_with_unbacked_symint_closure(self, device, dynamic, autograd):
|
|
self._run_test(
|
|
model=WhileLoopModels.UnbackedSymIntClosure(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", [GPU_TYPE])
|
|
def test_while_loop_models_with_mixed_device(self, device):
|
|
self._run_test(
|
|
model=WhileLoopModels.MixedDevice(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected body_fn_output and carried_inputs to have same metadata but found",
|
|
):
|
|
            # Errors at the front end because the device is promoted to a
            # different one after the first iteration
|
|
self._run_test(
|
|
model=WhileLoopModels.MixedDevice2(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=True,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch(
|
|
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
|
|
)
|
|
def test_while_loop_with_sym_expr_cond(self, device, dynamic, autograd):
|
|
self._run_test(
|
|
model=WhileLoopModels.SymExprCond(),
|
|
inputs=(
|
|
torch.randn(10, 20),
|
|
torch.randn(10, 20),
|
|
),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("autograd", [False, True])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_with_conv(self, device, dynamic, autograd):
|
|
self._run_test(
|
|
model=WhileLoopModels.Conv(device),
|
|
inputs=(torch.randn(2, 4, 4, 4, dtype=torch.float64),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
autograd=autograd,
|
|
)
|
|
|
|
@requires_gpu
|
|
@parametrize("device", ["cpu", GPU_TYPE])
|
|
@parametrize("dynamic", [True, False])
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_while_loop_stack_output_simple(self, device, dynamic):
|
|
self._run_test(
|
|
model=WhileLoopModels.WhileLoopStackOutputSimple(device),
|
|
inputs=(torch.randn(3, 3, dtype=torch.float32),),
|
|
device=device,
|
|
dynamic=dynamic,
|
|
)
|
|
|
|
|
|
class AssociativeScanTests(TestCase):
|
|
@requires_gpu
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("backend", ["inductor"])
|
|
@parametrize("device", [torch.device("cpu"), GPU_TYPE])
|
|
# This test will fail as flip in combination with particular input lengths
|
|
# produces weird results.
|
|
    # This is under investigation in
|
|
# https://github.com/pytorch/pytorch/issues/131805
|
|
@decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE)
|
|
def test_associative_scan_CUDA_flip(self, combine_mode, backend, device):
|
|
def fct(x: torch.Tensor, y: torch.Tensor):
|
|
return x + y
|
|
|
|
# for n in range(10):
|
|
for n in [9]:
|
|
x = torch.arange(n, device=device)
|
|
torch.compiler.reset()
|
|
associative_scan1 = torch.compile(
|
|
associative_scan, backend=backend, fullgraph=True
|
|
)
|
|
associative_scan2 = associative_scan
|
|
|
|
if combine_mode == "pointwise" and device == torch.device("cpu"):
|
|
with self.assertRaisesRegex(Exception, r"."):
|
|
associative_scan1(
|
|
fct, x, 0, reverse=False, combine_mode=combine_mode
|
|
)
|
|
|
|
# Skipping test because combine_mode currently only supports CUDA tensors
|
|
return
|
|
|
|
result1 = associative_scan1(
|
|
fct, x, 0, reverse=False, combine_mode=combine_mode
|
|
)
|
|
result2 = associative_scan2(
|
|
fct, x, 0, reverse=False, combine_mode=combine_mode
|
|
)
|
|
result3 = torch.cumsum(x, 0)
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Flip only non-compiled and compare with compiled reverse=True
|
|
result1 = associative_scan1(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result2 = torch.flip(
|
|
associative_scan2(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Flip only compiled and compare with non-compiled reverse=True
|
|
result1 = torch.flip(
|
|
associative_scan1(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result2 = associative_scan2(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Use reverse=False, but flip both results before and after
|
|
result1 = torch.flip(
|
|
associative_scan1(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result2 = torch.flip(
|
|
associative_scan2(
|
|
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
|
|
),
|
|
[0],
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
# Reverse=True
|
|
result1 = associative_scan1(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result2 = associative_scan2(
|
|
fct, x, 0, reverse=True, combine_mode=combine_mode
|
|
)
|
|
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
|
|
|
|
self.assertEqual(result1, result2)
|
|
self.assertEqual(result1, result3)
|
|
|
|
|
|
class ScanModels:
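    # For scan, combine_fn(carry, x) returns (next_carry, y); scan returns the
    # final carry plus the ys stacked along the scan dim. The tests compare the
    # compiled scan against both eager scan and the _fake_scan reference.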
|
|
class SimpleScan(torch.nn.Module):
|
|
def __init__(self, reverse, dim):
|
|
super().__init__()
|
|
self.reverse = reverse
|
|
self.dim = dim
|
|
|
|
def forward(self, _input, weight, bias):
|
|
def combine_fn(carry, x):
|
|
from torch.utils import _pytree as pytree
|
|
|
|
new_carry = {
|
|
"param": carry["param"] @ x + carry["bias"],
|
|
"bias": carry["bias"].sin(),
|
|
}
|
|
return new_carry, (
|
|
pytree.tree_map(lambda x: x.clone(), new_carry),
|
|
{"dummy": x.sin()},
|
|
)
|
|
|
|
return scan(
|
|
combine_fn,
|
|
{"param": weight, "bias": bias},
|
|
_input,
|
|
reverse=self.reverse,
|
|
dim=self.dim,
|
|
)
|
|
|
|
class ScanLinearWithView(torch.nn.Module):
|
|
def __init__(self, reverse, dim):
|
|
super().__init__()
|
|
self.reverse = reverse
|
|
self.dim = dim
|
|
self.linear = torch.nn.Linear(4, 4, dtype=torch.float64)
|
|
|
|
def forward(self, scan_op, init, xs):
|
|
def combine_fn(carry, x):
|
|
prev_sz = x.size()
|
|
x = self.linear(x.view(-1, x.size(-1)))
|
|
x_view = x.view(*prev_sz)
|
|
return x_view, x_view.clone()
|
|
|
|
return scan_op(combine_fn, init, xs, dim=self.dim, reverse=self.reverse)
|
|
|
|
class ScanConv(torch.nn.Module):
|
|
def __init__(self, reverse, dim):
|
|
super().__init__()
|
|
self.reverse = reverse
|
|
self.dim = dim
|
|
self.conv2d = torch.nn.Conv2d(
|
|
4, 4, (3, 3), stride=(1, 1), padding=(1, 1), dtype=torch.float64
|
|
)
|
|
|
|
# init = torch.randn(2, 4, 4, 4)
|
|
# xs = torch.randn(scan_dim, 2, 4, 4, 4)
|
|
def forward(self, scan_op, init, xs):
|
|
def combine_fn(carry, x):
|
|
x = self.conv2d(x)
|
|
return x, x.clone()
|
|
|
|
return scan_op(combine_fn, init, xs, dim=self.dim, reverse=self.reverse)
|
|
|
|
class ScanInCond(torch.nn.Module):
|
|
def __init__(self, reverse, dim):
|
|
super().__init__()
|
|
self.true_scan_linear = ScanModels.ScanLinearWithView(reverse, dim)
|
|
self.false_scan_linear = ScanModels.ScanLinearWithView(not reverse, dim)
|
|
|
|
def forward(self, scan_op, pred, init, xs):
|
|
def true_fn():
|
|
last_carry, y = self.true_scan_linear(scan_op, init, xs)
|
|
return last_carry.sum(), y.sin()
|
|
|
|
def false_fn():
|
|
last_carry, y = self.false_scan_linear(scan_op, init, xs)
|
|
return -last_carry.sum(), y.cos()
|
|
|
|
return torch.cond(pred, true_fn, false_fn, tuple())
|
|
|
|
class CondInScan(torch.nn.Module):
|
|
def __init__(self, reverse, dim):
|
|
super().__init__()
|
|
self.reverse = reverse
|
|
self.dim = dim
|
|
self.true_linear = torch.nn.Linear(4, 4)
|
|
self.false_linear = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, scan_op, init, xs):
|
|
def combine_fn(carry, x):
|
|
old_sizes = carry.size()
|
|
carry_view = carry.view(-1, carry.size()[-1])
|
|
new_carry_out = torch.cond(
|
|
torch.all(carry_view > 1),
|
|
lambda: self.true_linear(carry_view).sin(),
|
|
lambda: self.false_linear(carry_view).cos(),
|
|
tuple(),
|
|
)
|
|
return carry + new_carry_out.view(*old_sizes), new_carry_out
|
|
|
|
return scan_op(
|
|
combine_fn,
|
|
init,
|
|
xs,
|
|
dim=self.dim,
|
|
reverse=self.reverse,
|
|
)
|
|
|
|
class SimpleWithPytreeInOuts(torch.nn.Module):
|
|
def __init__(self, reverse, dim):
|
|
super().__init__()
|
|
self.reverse = reverse
|
|
self.dim = dim
|
|
|
|
def forward(self, scan_op, _input, weight, bias):
|
|
def combine_fn(carry, x):
|
|
new_carry = {
|
|
"param": carry["param"] @ x + carry["bias"],
|
|
"bias": carry["bias"].sin(),
|
|
}
|
|
return new_carry, (
|
|
pytree.tree_map(lambda x: x.clone(), new_carry),
|
|
{"dummy": x.sin()},
|
|
)
|
|
|
|
return scan_op(
|
|
combine_fn,
|
|
{"param": weight, "bias": bias},
|
|
_input,
|
|
reverse=self.reverse,
|
|
dim=self.dim,
|
|
)
|
|
|
|
class ChunkedCE(torch.nn.Module):
|
|
def __init__(self, chunk_size):
|
|
super().__init__()
|
|
self.chunk_size = chunk_size
|
|
self.ce = lambda logits, target: torch.abs(target - logits).sum()
|
|
|
|
def forward(self, scan_op, _input, weight, target, bias):
|
|
CHUNK_SIZE = self.chunk_size
|
|
|
|
def compute_loss(input_chunk, weight, bias, target):
|
|
logits = torch.addmm(bias, input_chunk, weight.t())
|
|
logits = logits.float()
|
|
loss = self.ce(logits, target)
|
|
return loss
|
|
|
|
grad_weight = torch.zeros_like(weight)
|
|
grad_bias = torch.zeros_like(bias)
|
|
loss_acc = torch.zeros((), device=_input.device)
|
|
|
|
chunks = _input.shape[0] // CHUNK_SIZE
|
|
|
|
_input_chunks = _input.view(CHUNK_SIZE, chunks, *_input.shape[1:])
|
|
target_chunks = target.view(CHUNK_SIZE, chunks, *target.shape[1:])
|
|
|
|
def combine_fn(carry, xs):
|
|
grad_weight, grad_bias, loss_acc = carry
|
|
input_chunk, target_chunk = xs
|
|
(
|
|
(
|
|
chunk_grad_input,
|
|
chunk_grad_weight,
|
|
chunk_grad_bias,
|
|
),
|
|
chunk_loss,
|
|
) = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))(
|
|
input_chunk, weight, bias, target_chunk
|
|
)
|
|
return (
|
|
(
|
|
grad_weight + chunk_grad_weight,
|
|
grad_bias + chunk_grad_bias,
|
|
loss_acc + chunk_loss,
|
|
),
|
|
chunk_grad_input,
|
|
)
|
|
|
|
(grad_weight, grad_bias, loss_acc), grad_inputs = scan_op(
|
|
combine_fn,
|
|
(grad_weight, grad_bias, loss_acc),
|
|
(_input_chunks, target_chunks),
|
|
)
|
|
return (
|
|
grad_weight / chunks,
|
|
grad_bias / chunks,
|
|
loss_acc / chunks,
|
|
grad_inputs.view(-1, *_input.shape[1:]) / chunks,
|
|
)
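    # ChunkedCENoScan below is the reference implementation of the same chunked
    # loss/gradient accumulation written as an eager Python loop; the tests
    # compare it against the scan-based ChunkedCE above.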
|
|
|
|
class ChunkedCENoScan(torch.nn.Module):
|
|
def __init__(self, chunk_size):
|
|
super().__init__()
|
|
self.chunk_size = chunk_size
|
|
self.ce = lambda logits, target: torch.abs(target - logits).sum()
|
|
|
|
def forward(self, scan_op, _input, weight, target, bias):
|
|
CHUNK_SIZE = self.chunk_size
|
|
|
|
def compute_loss(input_chunk, weight, bias, target):
|
|
logits = torch.addmm(bias, input_chunk, weight.t())
|
|
logits = logits.float()
|
|
loss = self.ce(logits, target)
|
|
return loss
|
|
|
|
grad_weight = torch.zeros_like(weight)
|
|
grad_inputs = []
|
|
grad_bias = torch.zeros_like(bias)
|
|
loss_acc = torch.zeros((), device=_input.device)
|
|
|
|
chunks = _input.shape[0] // CHUNK_SIZE
|
|
|
|
def accumulate_chunk(input_chunk, target_chunk):
|
|
(
|
|
(
|
|
chunk_grad_input,
|
|
chunk_grad_weight,
|
|
chunk_grad_bias,
|
|
),
|
|
chunk_loss,
|
|
) = torch.func.grad_and_value(compute_loss, argnums=(0, 1, 2))(
|
|
input_chunk, weight, bias, target_chunk
|
|
)
|
|
grad_weight.add_(chunk_grad_weight)
|
|
grad_bias.add_(chunk_grad_bias)
|
|
loss_acc.add_(chunk_loss)
|
|
return chunk_grad_input
|
|
|
|
accumulate_chunk = torch.compile(accumulate_chunk)
|
|
|
|
input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
|
|
target_chunks = torch.chunk(target, chunks=chunks, dim=0)
|
|
for input_chunk, target_chunk in zip(input_chunks, target_chunks):
|
|
grad_inputs.append(accumulate_chunk(input_chunk, target_chunk))
|
|
return (
|
|
grad_weight / chunks,
|
|
grad_bias / chunks,
|
|
loss_acc / chunks,
|
|
torch.cat(grad_inputs, dim=0) / chunks,
|
|
)
|
|
|
|
class ScanWithClamp(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, scan_op, initial, xs):
|
|
def step(h_prev, x_t):
|
|
h_next = (h_prev + x_t).clamp(min=0.1)
|
|
return h_next, h_next.clone()
|
|
|
|
final, ys = scan_op(step, initial, xs)
|
|
return final, ys
|
|
|
|
|
|
class ScanTests(TestCase):
|
|
def _run_test(
|
|
self,
|
|
model,
|
|
inputs,
|
|
device,
|
|
dynamic,
|
|
autograd=False,
|
|
):
|
|
import copy
|
|
|
|
inputs = [
|
|
inp.requires_grad_(autograd) if inp.dtype.is_floating_point else inp
|
|
for inp in inputs
|
|
]
|
|
inputs = [inp.to(device=device) for inp in inputs]
|
|
model = model.to(device=device)
|
|
for p in model.parameters():
|
|
p.requires_grad_(autograd)
|
|
|
|
model1 = copy.deepcopy(model)
|
|
model2 = copy.deepcopy(model)
|
|
model3 = copy.deepcopy(model)
|
|
model4 = copy.deepcopy(model)
|
|
model3.compile(fullgraph=True, dynamic=dynamic)
|
|
model4.compile(fullgraph=True, dynamic=dynamic)
|
|
|
|
def _run_model(model, inputs):
|
|
cloned_inputs = [
|
|
inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
|
|
]
|
|
fw_result = model(*cloned_inputs)
|
|
loss = loss_fn(fw_result)
|
|
if autograd:
|
|
loss.backward()
|
|
return (
|
|
fw_result,
|
|
loss,
|
|
[
|
|
inp.grad
|
|
for inp in cloned_inputs
|
|
if isinstance(inp, torch.Tensor)
|
|
],
|
|
{n: p.grad for n, p in model.named_parameters()},
|
|
)
|
|
else:
|
|
return fw_result, loss
|
|
|
|
result_exp = _run_model(model1, [_fake_scan] + inputs)
|
|
result_eager = _run_model(model2, [scan] + inputs)
|
|
result_compiled = _run_model(model3, [scan] + inputs)
|
|
result_compiled_exp = _run_model(
|
|
model4,
|
|
[_fake_scan] + inputs,
|
|
)
|
|
|
|
self.assertEqual(result_exp, result_eager)
|
|
self.assertEqual(result_exp, result_compiled)
|
|
self.assertEqual(result_exp, result_compiled_exp)
|
|
|
|
    def _compare_result(
        self,
        model1,
        model2,
        inputs,
        device,
    ):
        inp_on_device = [elem.to(device=device) for elem in inputs]
        cloned_inputs = [arg.clone() for arg in inp_on_device]
        model1_out = model1(scan, *cloned_inputs)
        model2_out = model2(scan, *cloned_inputs)
        self.assertEqual(model1_out, model2_out)

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 2])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_scan_pytree_in_out(self, device, dynamic, reverse, dim, autograd):
        self._run_test(
            model=ScanModels.SimpleWithPytreeInOuts(reverse=reverse, dim=dim),
            inputs=(
                torch.ones(2, 2, 2),
                torch.ones(2, 2),
                torch.ones(2),
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 3])
    @parametrize("scan_length", [1, 5])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_scan_nn_modules(
        self, device, dynamic, reverse, dim, scan_length, autograd
    ):
        init = torch.randn(20, 16, 4, 4, dtype=torch.float64)
        xs = torch.randn(scan_length, 20, 16, 4, 4, dtype=torch.float64)
        xs = xs.movedim(0, dim)
        self._run_test(
            model=ScanModels.ScanLinearWithView(reverse=reverse, dim=dim),
            inputs=(
                init,
                xs,
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 3])
    @parametrize("scan_length", [1, 5])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_scan_conv(self, device, dynamic, reverse, dim, scan_length, autograd):
        init = torch.randn(2, 4, 4, 4, dtype=torch.float64)
        xs = torch.randn(scan_length, 2, 4, 4, 4, dtype=torch.float64)
        xs = xs.movedim(0, dim)
        self._run_test(
            model=ScanModels.ScanConv(reverse=reverse, dim=dim),
            inputs=(
                init,
                xs,
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 3])
    @parametrize("pred", [True, False])
    @parametrize("scan_length", [1, 5])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_scan_in_cond(
        self, device, dynamic, reverse, dim, pred, scan_length, autograd
    ):
        init = torch.randn(4, 4, 4, dtype=torch.float64)
        xs = torch.randn(scan_length, 4, 4, 4, dtype=torch.float64)
        xs = xs.movedim(0, dim)
        self._run_test(
            model=ScanModels.ScanInCond(reverse=reverse, dim=dim),
            inputs=(
                torch.tensor(pred),
                init,
                xs,
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 3])
    @parametrize("scan_length", [1, 5])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_cond_in_scan(self, device, dynamic, reverse, dim, scan_length, autograd):
        init = torch.randn(2, 4, 4, 4)
        xs = torch.randn(scan_length, 4, 4, 4)
        xs = xs.movedim(0, dim)
        self._run_test(
            model=ScanModels.CondInScan(reverse=reverse, dim=dim),
            inputs=(
                init,
                xs,
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_scan_chunked_ce(self, device, dynamic, autograd):
        self._run_test(
            model=ScanModels.ChunkedCE(10),
            inputs=(
                torch.randn(100, 20),
                torch.randn(20, 20),
                torch.randn(100, 20),
                torch.randn(20),
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

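    # Cross-checks the compiled, scan-based ChunkedCE against the eager
    # ChunkedCENoScan reference for a couple of chunk sizes and input shapes.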
    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_scan_compare_chunked_ce_with_no_scan(self, device, dynamic):
        for trunk_size, B, T in zip([10, 20], [10, 100], [20, 40]):
            self._compare_result(
                model1=torch.compile(ScanModels.ChunkedCE(trunk_size), dynamic=dynamic),
                model2=ScanModels.ChunkedCENoScan(trunk_size),
                inputs=(
                    torch.randn(B, T),
                    torch.randn(T, T),
                    torch.randn(B, T),
                    torch.randn(T),
                ),
                device=device,
            )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_scan_with_clamp(self, device, dynamic, autograd):
        B = 4
        T = 8
        H = 16
        self._run_test(
            model=ScanModels.ScanWithClamp(),
            inputs=(
                torch.randn((B, H)),
                torch.randn((T, B, H)),
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )


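# Test modules for torch._higher_order_ops.map. Each forward takes the map
# operator as its first argument so the same module can be run with the real
# map op, the _fake_map reference, or a compiled variant.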
class MapModels:
    class Simple(torch.nn.Module):
        def forward(self, map_op, x):
            a = torch.ones(3, 4, device=x.device)

            def f(x):
                return x.sin() + a

            return map_op(f, x)

    class SimpleWithLinearWithView(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 5)

        def forward(self, map_op, x):
            def f(x):
                return self.linear(x).sin()

            return map_op(f, x.view(4, 3))

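    # Maps over a pytree of inputs (a dict holding a tensor and a nested tuple)
    # and returns a pytree of outputs.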
    class PytreeInOut(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 5)

        def forward(self, map_op, x, y, z):
            def f(x_y_z):
                x = x_y_z["x"]
                y, (z,) = x_y_z["y_z"]
                return self.linear(x).sin(), (self.linear(y), z.cos())

            return map_op(f, {"x": x, "y_z": (y, (z,))})

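    # The body slices its outputs, exercising view handling (hence the name
    # ReinterpretView).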
    class ReinterpretView(torch.nn.Module):
        def forward(self, map_op, x, y, z):
            def f(xyz):
                x, y, z = xyz
                return x.sin()[:2], y.cos()[:2] + z[-2:].clone()

            return map_op(f, (x, y, z))

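    # Nests map inside both branches of a cond driven by a data-dependent
    # predicate.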
    class NestedWithCond(torch.nn.Module):
        def forward(self, map_op, x, y, z):
            def true_fn(x, y, z):
                def inner_f(yz):
                    y, z = yz
                    return y + z

                return map_op(inner_f, (y, z))

            def false_fn(x, y, z):
                def inner_f(yz):
                    y, z = yz
                    return y - z

                return map_op(inner_f, (y, z))

            return torch._higher_order_ops.cond(
                x.sum() > 0, true_fn, false_fn, (x, y, z)
            )


class MapTests(TestCase):
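    # Runs `model` three ways (eagerly with the real map op, eagerly with the
    # _fake_map reference, and compiled with Inductor) and checks that the
    # outputs and, when autograd=True, the parameter gradients all agree.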
    def _run_test(
        self,
        model,
        inputs,
        device,
        dynamic=False,
        autograd=False,
    ):
        import copy

        inputs = [inp.to(device=device) for inp in inputs]
        model = model.to(device=device)
        model_eager = copy.deepcopy(model)
        model_compiled = copy.deepcopy(model)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
            model_compiled
        )

        if autograd:
            pytree.tree_map_only(torch.Tensor, lambda t: t.requires_grad_(True), inputs)

        cloned_inputs = [inp.clone() for inp in inputs]
        result = model(torch._higher_order_ops.map, *cloned_inputs)
        result_exp = model_eager(_fake_map, *cloned_inputs)
        result_compiled = compiled_model(torch._higher_order_ops.map, *cloned_inputs)

        self.assertEqual(result, result_exp)
        self.assertEqual(result, result_compiled)

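        # Backprop through each variant and make sure the parameters and their
        # gradients match across the eager, reference, and compiled models.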
        if autograd:
            loss_fn(result).backward()
            loss_fn(result_exp).backward()
            loss_fn(result_compiled).backward()

            model_params = dict(model.named_parameters())
            model_eager_params = dict(model_eager.named_parameters())
            model_compiled_params = dict(model_compiled.named_parameters())
            for name, param in model_eager_params.items():
                self.assertEqual(param, model_params[name])
                self.assertEqual(param, model_compiled_params[name])
                self.assertEqual(param.grad, model_params[name].grad)
                self.assertEqual(param.grad, model_compiled_params[name].grad)

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_map_simple(self, device, dynamic, autograd):
        self._run_test(
            model=MapModels.Simple(),
            inputs=(torch.randn(3, 4),),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_map_simple_linear_with_view(self, device, dynamic, autograd):
        self._run_test(
            model=MapModels.SimpleWithLinearWithView(),
            inputs=(torch.randn(3, 4),),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_map_pytree_in_out(self, device, dynamic, autograd):
        self._run_test(
            model=MapModels.PytreeInOut(),
            inputs=(
                torch.randn(2, 5, 3),
                torch.randn(2, 5, 3),
                torch.randn(2, 4, 3),
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_map_nested_with_cond(self, device, dynamic, autograd):
        self._run_test(
            model=MapModels.NestedWithCond(),
            inputs=(
                torch.randn(3, 2),
                torch.randn(3, 10, 5),
                torch.randn(3, 10, 5),
            ),
            device=device,
            dynamic=dynamic,
            autograd=autograd,
        )


instantiate_parametrized_tests(CondTests)
instantiate_parametrized_tests(WhileLoopTests)
instantiate_parametrized_tests(AssociativeScanTests)
instantiate_parametrized_tests(ScanTests)
instantiate_parametrized_tests(MapTests)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")