# Owner(s): ["module: dynamo", "module: higher order operators"]
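"""Tests for the flat_apply higher-order operator and nonstrict_trace.

flat_apply lets Dynamo keep a call to an arbitrary Python function in the FX
graph by flattening the function and its pytree-registered arguments into
graphable leaves plus TreeSpecs that describe how to rebuild them.
"""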
from dataclasses import dataclass

import torch
import torch._dynamo.test_case
import torch.utils._pytree as pytree
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
    EagerAndRecordGraphs,
    normalize_gm,
)
from torch._higher_order_ops.flat_apply import (
    flat_apply,
    func_to_graphable,
    is_graphable,
    to_graphable,
)


def distance(a, b, norm):
    if norm.typ == "l2":
        return torch.sqrt((a.x - b.x).pow(2) + (a.y - b.y).pow(2))
    elif norm.typ == "l1":
        return (a.x - b.x).abs() + (a.y - b.y).abs()
    # Fail loudly instead of silently returning None for an unknown norm.
    raise ValueError(f"unknown norm type: {norm.typ!r}")


@dataclass(frozen=True)
class Norm:
    typ: str


pytree.register_constant(Norm)
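# NB: register_constant treats a Norm instance as static metadata: it is baked
# into the TreeSpec rather than becoming a graph input, which is why Norm must
# be hashable (hence frozen=True).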


@dataclass
class Point:
    x: torch.Tensor
    y: torch.Tensor


pytree.register_dataclass(Point)
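# NB: Registering Point as a pytree dataclass lets it flatten into its tensor
# fields, each of which is a graphable leaf.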


class FlatApplyTests(torch._dynamo.test_case.TestCase):
    def test_simple(self):
        tensor = torch.tensor

        a = Point(tensor(0.0), tensor(0.0))
        b = Point(tensor(3.0), tensor(4.0))
        norm = Norm("l2")

        args = (a, b)
        kwargs = {"norm": norm}
        empty_list, func_spec = func_to_graphable(distance)
        self.assertEqual(empty_list, [])

        flat_args, in_spec = to_graphable((args, kwargs))

        for arg in flat_args:
            self.assertTrue(is_graphable(arg))

        # Test flat_apply returns same thing as original function
        result = flat_apply(func_spec, in_spec, *flat_args)
        self.assertEqual(result, distance(*args, **kwargs))

    def test_non_tensor_output(self):
        tensor = torch.tensor

        a = Point(tensor(0.0), tensor(0.0))
        b = Point(tensor(3.0), tensor(4.0))

        args = (a, b)
        kwargs = {}

        def f(a, b):
            return [a.x + 1, (b.x + 2, [a.y + 3, 4.0], "5"), 6 + b.y]

        empty_list, func_spec = func_to_graphable(f)
        self.assertEqual(empty_list, [])

        flat_args, in_spec = to_graphable((args, kwargs))

        for arg in flat_args:
            self.assertTrue(is_graphable(arg))

        # Test flat_apply returns same thing as original function
        result = flat_apply(func_spec, in_spec, *flat_args)
        self.assertEqual(result, f(*args, **kwargs))

    def test_nonstrict_trace_dynamo_graph(self):
        class Point:
            x: torch.Tensor
            y: torch.Tensor

            def __init__(self, x, y):
                self.x = x
                self.y = y

        class PointTensor:
            p: Point
            t: torch.Tensor

            def __init__(self, p, t):
                self.p = p
                self.t = t
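
        # register_pytree_node takes a flatten function that returns
        # (children, context) and an unflatten function that rebuilds an
        # instance from those flattened children.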
        torch.utils._pytree.register_pytree_node(
            PointTensor,
            lambda pt: ((pt.p, pt.t), ()),
            lambda children, _: PointTensor(children[0], children[1]),
        )

        torch.utils._pytree.register_pytree_node(
            Point,
            lambda p: ((p.x, p.y), ()),
            lambda children, _: Point(children[0], children[1]),
        )

        def trace_point(p):
            torch._dynamo.graph_break()
            return p.x * p.y

        @torch._dynamo.nonstrict_trace
        def trace_point_tensor(pt):
            torch._dynamo.graph_break()
            return pt.t + trace_point(pt.p)
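
        # The graph_break() calls above would normally split the graph, but
        # nonstrict_trace records the whole call as a single flat_apply node,
        # so the fullgraph=True compile below still succeeds.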
        backend = EagerAndRecordGraphs()

        @torch.compile(fullgraph=True, backend=backend)
        def fn(x, y):
            p = Point(x, y)
            t = x + y
            pt = PointTensor(p, t)
            res = trace_point_tensor(pt)
            return res

        fn(torch.randn(10), torch.randn(10))
        self.assertExpectedInline(
            normalize_gm(backend.graphs[0].print_readable(print_output=False)),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"):
        l_x_ = L_x_
        l_y_ = L_y_

        t: "f32[10]" = l_x_ + l_y_

        trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
        trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec

        res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
        return (res,)
""",  # NOQA: B950
        )

    def test_nonstrict_trace_captured_tensor_post_aot_graph(self):
        cst = torch.ones(1)

        @torch._dynamo.nonstrict_trace
        def trace_me(x, y):
            torch._dynamo.graph_break()
            return x * y + cst
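
        # `cst` is captured from the enclosing scope rather than passed in, so
        # it should surface in the AOT forward graph as a lifted tensor
        # constant (`_tensor_constant0` in the expected output below).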
        backend = AotEagerAndRecordGraphs()

        @torch.compile(fullgraph=True, backend=backend)
        def fn(x, y):
            return trace_me(x, y)

        fn(torch.randn(10), torch.randn(10))

        self.assertExpectedInline(
            normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
        mul: "f32[10]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
        _tensor_constant0: "f32[1]" = self._tensor_constant0
        add: "f32[10]" = torch.ops.aten.add.Tensor(mul, _tensor_constant0); mul = _tensor_constant0 = None
        return (add,)
""",  # NOQA: B950
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()