[export] Rewrite exportdb formatting. (#129260)

Summary: It'll be easier to generate examples if the code doesn't depend on the exportdb library.

Test Plan: CI

Differential Revision: D58886554

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129260
Approved by: https://github.com/tugsbayasgalan
Author: Zhengxu Chen
Date: 2024-06-25 21:04:53 +00:00
Committed by: PyTorch MergeBot
Parent: 551e412718
Commit: e58ef5b65f

40 changed files with 203 additions and 367 deletions
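For context, the rewrite replaces the `@export_case(...)` decorator with plain module-level attributes that the examples package collects on import. A minimal sketch of the new file layout, with hypothetical file, class, and tag names:

import torch

class HypotheticalExample(torch.nn.Module):
    """The class docstring doubles as the case description."""

    def forward(self, x):
        return x + 1

example_inputs = (torch.randn(3, 2),)
tags = {"python.example"}  # hypothetical tag
model = HypotheticalExample()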


@ -32,8 +32,6 @@ def generate_example_rst(example_case: ExportCase):
)
with open(source_file) as file:
source_code = file.read()
source_code = re.sub(r"from torch\._export\.db\.case import .*\n", "", source_code)
source_code = re.sub(r"@export_case\((.|\n)*?\)\n", "", source_code)
source_code = source_code.replace("\n", "\n ")
splitted_source_code = re.split(r"@export_rewrite_case.*\n", source_code)
@ -43,6 +41,7 @@ def generate_example_rst(example_case: ExportCase):
}, f"more than one @export_rewrite_case decorator in {source_code}"
# Generate contents of the .rst file
# TODO(zhxchen17) Update template when we switch to example_args and example_kwargs.
title = f"{example_case.name}"
doc_contents = f"""{title}
{'^' * (len(title))}
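For reference, the `'^' * len(title)` expression underlines each case name as an rst section heading; a tiny sketch of the rendered output:

title = "assume_constant_result"
print(f"{title}\n{'^' * len(title)}")
# assume_constant_result
# ^^^^^^^^^^^^^^^^^^^^^^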


@ -122,9 +122,8 @@ def to_snake_case(name):
 def _make_export_case(m, name, configs):
-    if not issubclass(m, torch.nn.Module):
+    if not isinstance(m, torch.nn.Module):
         raise TypeError("Export case class should be a torch.nn.Module.")
-    m = m()
     if "description" not in configs:
         # Fallback to docstring if description is missing.
@ -148,14 +147,7 @@ def export_case(**kwargs):
assert module is not None
_MODULES.add(module)
normalized_name = to_snake_case(m.__name__)
module_name = module.__name__.split(".")[-1]
if module_name != normalized_name:
raise RuntimeError(
f'Module name "{module.__name__}" is inconsistent with exported program '
+ f'name "{m.__name__}". Please rename the module to "{normalized_name}".'
)
case = _make_export_case(m, module_name, configs)
register_db_case(case)
return case
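The check above ties each example module's filename to the snake_case form of its class name via `to_snake_case`. A plausible sketch of such a converter (an assumption for illustration; the actual helper in case.py may differ):

import re

def to_snake_case(name):
    # Insert underscores at lower/upper boundaries, then lowercase.
    s = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s).lower()

assert to_snake_case("AssumeConstantResult") == "assume_constant_result"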


@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import dataclasses
import glob
import importlib
import inspect
from os.path import basename, dirname, isfile, join
import torch
@ -9,17 +10,24 @@ from torch._export.db.case import (
_EXAMPLE_CONFLICT_CASES,
_EXAMPLE_REWRITE_CASES,
SupportLevel,
export_case,
ExportCase,
)
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
]
def _collect_examples():
case_names = glob.glob(join(dirname(__file__), "*.py"))
case_names = [
basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py")
]
# Import all modules in the current directory.
from . import * # noqa: F403
case_fields = {f.name for f in dataclasses.fields(ExportCase)}
for case_name in case_names:
case = __import__(case_name, globals(), locals(), [], 1)
variables = [name for name in dir(case) if name in case_fields]
export_case(**{v: getattr(case, v) for v in variables})(case.model)
_collect_examples()
def all_examples():
return _EXAMPLE_CASES
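The new `_collect_examples` forwards any module-level name that matches an `ExportCase` dataclass field (`model`, `example_inputs`, `tags`, ...) to `export_case`. A toy illustration of that field-matching idea, using a stand-in dataclass rather than the real `ExportCase`:

import dataclasses

@dataclasses.dataclass
class ToyCase:  # stand-in for ExportCase, for illustration only
    model: object
    example_inputs: tuple
    tags: set = dataclasses.field(default_factory=set)

fields = {f.name for f in dataclasses.fields(ToyCase)}
module_vars = {"model": "m", "example_inputs": (1,), "helper": None}
picked = {k: v for k, v in module_vars.items() if k in fields}
assert picked == {"model": "m", "example_inputs": (1,)}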


@ -2,24 +2,19 @@
import torch
import torch._dynamo as torchdynamo
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2), torch.tensor(4)),
tags={"torch.escape-hatch"},
)
class AssumeConstantResult(torch.nn.Module):
"""
Applying the `assume_constant_result` decorator to burn non-traceable code into the graph as a constant.
"""
def __init__(self):
super().__init__()
@torchdynamo.assume_constant_result
def get_item(self, y):
return y.int().item()
def forward(self, x, y):
return x[: self.get_item(y)]
example_inputs = (torch.randn(3, 2), torch.tensor(4))
tags = {"torch.escape-hatch"}
model = AssumeConstantResult()
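With the decorator gone, an example file like this one can be exercised directly; a hedged usage sketch, assuming the module-level names above:

import torch

ep = torch.export.export(model, example_inputs)  # model, example_inputs from above
print(ep)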


@ -1,9 +1,6 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
@ -13,10 +10,6 @@ class MyAutogradFunction(torch.autograd.Function):
def backward(ctx, grad_output):
return grad_output + 1
@export_case(
example_inputs=(torch.randn(3, 2),),
)
class AutogradFunction(torch.nn.Module):
"""
TorchDynamo does not keep track of backward() on autograd functions. We recommend to
@ -25,3 +18,6 @@ class AutogradFunction(torch.nn.Module):
def forward(self, x):
return MyAutogradFunction.apply(x)
example_inputs = (torch.randn(3, 2),)
model = AutogradFunction()


@ -1,12 +1,6 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 4),),
)
class ClassMethod(torch.nn.Module):
"""
Class methods are inlined during tracing.
@ -23,3 +17,6 @@ class ClassMethod(torch.nn.Module):
def forward(self, x):
x = self.linear(x)
return self.method(x) * self.__class__.method(x) * type(self).method(x)
example_inputs = (torch.randn(3, 4),)
model = ClassMethod()


@ -1,10 +1,8 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from functorch.experimental.control_flow import cond
class MySubModule(torch.nn.Module):
def foo(self, x):
return x.cos()
@ -12,14 +10,6 @@ class MySubModule(torch.nn.Module):
def forward(self, x):
return self.foo(x)
@export_case(
example_inputs=(torch.randn(3),),
tags={
"torch.cond",
"torch.dynamic-shape",
},
)
class CondBranchClassMethod(torch.nn.Module):
"""
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
@ -45,3 +35,10 @@ class CondBranchClassMethod(torch.nn.Module):
def forward(self, x):
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
example_inputs = (torch.randn(3),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondBranchClassMethod()
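A standalone sketch of the cond() calling convention described in this case's docstring: both branches take the same operands and return a single tensor:

import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.cos()

def false_fn(x):
    return x.sin()

x = torch.randn(3)
out = cond(x.shape[0] > 2, true_fn, false_fn, [x])  # takes the true_fn branch here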


@ -1,17 +1,8 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from functorch.experimental.control_flow import cond
@export_case(
example_inputs=(torch.randn(3),),
tags={
"torch.cond",
"torch.dynamic-shape",
},
)
class CondBranchNestedFunction(torch.nn.Module):
"""
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
@ -26,8 +17,6 @@ class CondBranchNestedFunction(torch.nn.Module):
NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def __init__(self):
super().__init__()
def forward(self, x):
def true_fn(x):
@ -43,3 +32,10 @@ class CondBranchNestedFunction(torch.nn.Module):
return inner_false_fn(x)
return cond(x.shape[0] < 10, true_fn, false_fn, [x])
example_inputs = (torch.randn(3),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondBranchNestedFunction()


@ -1,17 +1,8 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from functorch.experimental.control_flow import cond
@export_case(
example_inputs=(torch.randn(6),),
tags={
"torch.cond",
"torch.dynamic-shape",
},
)
class CondBranchNonlocalVariables(torch.nn.Module):
"""
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
@ -43,9 +34,6 @@ class CondBranchNonlocalVariables(torch.nn.Module):
NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def __init__(self):
super().__init__()
def forward(self, x):
my_tensor_var = x + 100
my_primitive_var = 3.14
@ -62,3 +50,10 @@ class CondBranchNonlocalVariables(torch.nn.Module):
false_fn,
[x, my_tensor_var, torch.tensor(my_primitive_var)],
)
example_inputs = (torch.randn(6),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondBranchNonlocalVariables()


@ -1,14 +1,8 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from functorch.experimental.control_flow import cond
@export_case(
example_inputs=(torch.tensor(True), torch.randn(3, 2)),
tags={"torch.cond", "python.closure"},
)
class CondClosedOverVariable(torch.nn.Module):
"""
torch.cond() supports branches closed over arbitrary variables.
@ -22,3 +16,7 @@ class CondClosedOverVariable(torch.nn.Module):
return x - 2
return cond(pred, true_fn, false_fn, [x + 1])
example_inputs = (torch.tensor(True), torch.randn(3, 2))
tags = {"torch.cond", "python.closure"}
model = CondClosedOverVariable()


@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from torch.export import Dim
from functorch.experimental.control_flow import cond
@ -9,15 +8,6 @@ x = torch.randn(3, 2)
y = torch.randn(2)
dim0_x = Dim("dim0_x")
@export_case(
example_inputs=(x, y),
tags={
"torch.cond",
"torch.dynamic-shape",
},
extra_inputs=(torch.randn(2, 2), torch.randn(2)),
dynamic_shapes={"x": {0: dim0_x}, "y": None},
)
class CondOperands(torch.nn.Module):
"""
The operands passed to cond() must be:
@ -27,9 +17,6 @@ class CondOperands(torch.nn.Module):
NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def __init__(self):
super().__init__()
def forward(self, x, y):
def true_fn(x, y):
return x + y
@ -38,3 +25,12 @@ class CondOperands(torch.nn.Module):
return x - y
return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
example_inputs = (x, y)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
extra_inputs = (torch.randn(2, 2), torch.randn(2))
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
model = CondOperands()
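A hedged sketch of exporting this case with dim 0 of `x` marked dynamic, mirroring the `dynamic_shapes` spec declared above (`extra_inputs` exercise the new shape):

import torch

ep = torch.export.export(
    CondOperands(), (x, y), dynamic_shapes={"x": {0: dim0_x}, "y": None}
)
print(ep.module()(torch.randn(5, 2), torch.randn(2)))  # a new batch size still works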


@ -1,17 +1,8 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from functorch.experimental.control_flow import cond
@export_case(
example_inputs=(torch.randn(6, 4, 3),),
tags={
"torch.cond",
"torch.dynamic-shape",
},
)
class CondPredicate(torch.nn.Module):
"""
The conditional statement (aka predicate) passed to cond() must be one of the following:
@ -21,10 +12,14 @@ class CondPredicate(torch.nn.Module):
NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def __init__(self):
super().__init__()
def forward(self, x):
pred = x.dim() > 2 and x.shape[2] > 10
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
example_inputs = (torch.randn(6, 4, 3),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondPredicate()


@ -1,16 +1,7 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.tensor(4),),
tags={
"torch.dynamic-value",
"torch.escape-hatch",
},
)
class ConstrainAsSizeExample(torch.nn.Module):
"""
If the value is not known at tracing time, you can provide a hint so that we
@ -19,11 +10,16 @@ class ConstrainAsSizeExample(torch.nn.Module):
tensor.
"""
def __init__(self):
super().__init__()
def forward(self, x):
a = x.item()
torch._check_is_size(a)
torch._check(a <= 5)
return torch.zeros((a, 5))
example_inputs = (torch.tensor(4),)
tags = {
"torch.dynamic-value",
"torch.escape-hatch",
}
model = ConstrainAsSizeExample()
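A hedged sketch of how the `torch._check*` hints play out at export time: they bound the data-dependent value so the output shape can stay symbolic, and a different input value should still replay:

import torch

ep = torch.export.export(model, (torch.tensor(4),))
print(ep.module()(torch.tensor(3)).shape)  # expected: torch.Size([3, 5])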


@ -1,16 +1,7 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.tensor(4), torch.randn(5, 5)),
tags={
"torch.dynamic-value",
"torch.escape-hatch",
},
)
class ConstrainAsValueExample(torch.nn.Module):
"""
If the value is not known at tracing time, you can provide a hint so that we
@ -19,9 +10,6 @@ class ConstrainAsValueExample(torch.nn.Module):
tensor.
"""
def __init__(self):
super().__init__()
def forward(self, x, y):
a = x.item()
torch._check(a >= 0)
@ -30,3 +18,11 @@ class ConstrainAsValueExample(torch.nn.Module):
if a < 6:
return y.sin()
return y.cos()
example_inputs = (torch.tensor(4), torch.randn(5, 5))
tags = {
"torch.dynamic-value",
"torch.escape-hatch",
}
model = ConstrainAsValueExample()


@ -3,9 +3,6 @@ import functools
import torch
from torch._export.db.case import export_case
def test_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
@ -13,10 +10,6 @@ def test_decorator(func):
return wrapper
@export_case(
example_inputs=(torch.randn(3, 2), torch.randn(3, 2)),
)
class Decorator(torch.nn.Module):
"""
Decorator calls are inlined into the exported function during tracing.
@ -25,3 +18,6 @@ class Decorator(torch.nn.Module):
@test_decorator
def forward(self, x, y):
return x + y
example_inputs = (torch.randn(3, 2), torch.randn(3, 2))
model = Decorator()


@ -1,22 +1,17 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2), torch.tensor(4)),
tags={"python.data-structure"},
)
class Dictionary(torch.nn.Module):
"""
Dictionary structures are inlined and flattened during tracing.
"""
def __init__(self):
super().__init__()
def forward(self, x, y):
elements = {}
elements["x2"] = x * x
y = y * elements["x2"]
return {"y": y}
example_inputs = (torch.randn(3, 2), torch.tensor(4))
tags = {"python.data-structure"}
model = Dictionary()


@ -1,19 +1,10 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"python.assert"},
)
class DynamicShapeAssert(torch.nn.Module):
"""
A basic usage of Python assertions.
"""
def __init__(self):
super().__init__()
def forward(self, x):
# assertion with error message
@ -21,3 +12,7 @@ class DynamicShapeAssert(torch.nn.Module):
# assertion without error message
assert x.shape[0] > 1
return x
example_inputs = (torch.randn(3, 2),)
tags = {"python.assert"}
model = DynamicShapeAssert()


@ -1,20 +1,15 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"torch.dynamic-shape"},
)
class DynamicShapeConstructor(torch.nn.Module):
"""
Tensor constructors should be captured with dynamic shape inputs rather
than being baked in with static shape.
"""
def __init__(self):
super().__init__()
def forward(self, x):
return torch.zeros(x.shape[0] * 2)
example_inputs = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeConstructor()


@ -1,13 +1,6 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2, 2),),
tags={"torch.dynamic-shape", "python.control-flow"},
)
class DynamicShapeIfGuard(torch.nn.Module):
"""
An `if` statement with a backed dynamic shape predicate will be specialized into
@ -20,3 +13,7 @@ class DynamicShapeIfGuard(torch.nn.Module):
return x.cos()
return x.sin()
example_inputs = (torch.randn(3, 2, 2),)
tags = {"torch.dynamic-shape", "python.control-flow"}
model = DynamicShapeIfGuard()


@ -1,24 +1,19 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from functorch.experimental.control_flow import map
@export_case(
example_inputs=(torch.randn(3, 2), torch.randn(2)),
tags={"torch.dynamic-shape", "torch.map"},
)
class DynamicShapeMap(torch.nn.Module):
"""
functorch map() maps a function over the first tensor dimension.
"""
def __init__(self):
super().__init__()
def forward(self, xs, y):
def body(x, y):
return x + y
return map(body, xs, y)
example_inputs = (torch.randn(3, 2), torch.randn(2))
tags = {"torch.dynamic-shape", "torch.map"}
model = DynamicShapeMap()
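A small sketch of the map() semantics described above, assuming eager execution of the control-flow op: the body is applied to each slice of `xs` along dim 0, with `y` passed through:

import torch
from functorch.experimental.control_flow import map

xs = torch.randn(3, 2)
y = torch.randn(2)
out = map(lambda x, y: x + y, xs, y)
assert out.shape == (3, 2)  # one result per slice along dim 0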


@ -1,25 +1,21 @@
# mypy: allow-untyped-defs
import torch
-from torch._export.db.case import export_case, SupportLevel
+from torch._export.db.case import SupportLevel
from torch.export import Dim
x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
@export_case(
example_inputs=(x,),
tags={"torch.dynamic-shape", "python.builtin"},
support_level=SupportLevel.NOT_SUPPORTED_YET,
dynamic_shapes={"x": {0: dim0_x}},
)
class DynamicShapeRound(torch.nn.Module):
"""
Calling round on dynamic shapes is not supported.
"""
def __init__(self):
super().__init__()
def forward(self, x):
return x[: round(x.shape[0] / 2)]
x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
example_inputs = (x,)
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET
dynamic_shapes = {"x": {0: dim0_x}}
model = DynamicShapeRound()


@ -1,21 +1,15 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"torch.dynamic-shape"},
)
class DynamicShapeSlicing(torch.nn.Module):
"""
Slices with dynamic shape arguments should be captured into the graph
rather than being baked in.
"""
def __init__(self):
super().__init__()
def forward(self, x):
return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
example_inputs = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeSlicing()


@ -1,23 +1,17 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(10, 10),),
tags={"torch.dynamic-shape"},
)
class DynamicShapeView(torch.nn.Module):
"""
Dynamic shapes should be propagated to view arguments instead of being
baked into the exported graph.
"""
def __init__(self):
super().__init__()
def forward(self, x):
new_x_shape = x.size()[:-1] + (2, 5)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1)
example_inputs = (torch.randn(10, 10),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeView()


@ -1,26 +1,12 @@
# mypy: allow-untyped-defs
import torch
-from torch._export.db.case import export_case, ExportArgs, SupportLevel
+from torch._export.db.case import ExportArgs
@export_case(
example_inputs=ExportArgs(
torch.randn(4),
(torch.randn(4), torch.randn(4)),
*[torch.randn(4), torch.randn(4)],
mykw0=torch.randn(4),
input0=torch.randn(4), input1=torch.randn(4)
),
tags={"python.data-structure"},
support_level=SupportLevel.SUPPORTED,
)
class FnWithKwargs(torch.nn.Module):
"""
Keyword arguments are not supported at the moment.
"""
def __init__(self):
super().__init__()
def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
out = pos0
@ -31,3 +17,13 @@ class FnWithKwargs(torch.nn.Module):
out = out * mykw0
out = out * mykwargs["input0"] * mykwargs["input1"]
return out
example_inputs = ExportArgs(
torch.randn(4),
(torch.randn(4), torch.randn(4)),
*[torch.randn(4), torch.randn(4)],
mykw0=torch.randn(4),
input0=torch.randn(4), input1=torch.randn(4)
)
tags = {"python.data-structure"}
model = FnWithKwargs()


@ -1,22 +1,17 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"torch.dynamic-shape", "python.data-structure", "python.assert"},
)
class ListContains(torch.nn.Module):
"""
List containment relations can be checked on a dynamic shape or constants.
"""
def __init__(self):
super().__init__()
def forward(self, x):
assert x.size(-1) in [6, 2]
assert x.size(0) not in [4, 5, 6]
assert "monkey" not in ["cow", "pig"]
return x + x
example_inputs = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
model = ListContains()


@ -3,22 +3,12 @@ from typing import List
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],),
tags={"python.control-flow", "python.data-structure"},
)
class ListUnpack(torch.nn.Module):
"""
Lists are treated as a static construct, therefore unpacking should be
erased after tracing.
"""
def __init__(self):
super().__init__()
def forward(self, args: List[torch.Tensor]):
"""
Lists are treated as a static construct, therefore unpacking should be
@ -26,3 +16,7 @@ class ListUnpack(torch.nn.Module):
"""
x, *y = args
return x + y[0]
example_inputs = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
tags = {"python.control-flow", "python.data-structure"}
model = ListUnpack()


@ -1,14 +1,9 @@
# mypy: allow-untyped-defs
import torch
-from torch._export.db.case import export_case, SupportLevel
+from torch._export.db.case import SupportLevel
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"python.object-model"},
support_level=SupportLevel.NOT_SUPPORTED_YET,
)
class ModelAttrMutation(torch.nn.Module):
"""
Attribute mutation is not supported.
@ -24,3 +19,9 @@ class ModelAttrMutation(torch.nn.Module):
def forward(self, x):
self.attr_list = self.recreate_list()
return x.sum() + self.attr_list[0].sum()
example_inputs = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()


@ -1,20 +1,11 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2), torch.randn(2)),
tags={"python.closure"},
)
class NestedFunction(torch.nn.Module):
"""
Nested functions are traced through. Side effects on global captures
are not supported though.
"""
def __init__(self):
super().__init__()
def forward(self, a, b):
x = a + b
@ -26,3 +17,7 @@ class NestedFunction(torch.nn.Module):
return x * y + z
return closure(x)
example_inputs = (torch.randn(3, 2), torch.randn(2))
tags = {"python.closure"}
model = NestedFunction()


@ -3,21 +3,11 @@ import contextlib
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"python.context-manager"},
)
class NullContextManager(torch.nn.Module):
"""
Null context manager in Python will be traced out.
"""
def __init__(self):
super().__init__()
def forward(self, x):
"""
Null context manager in Python will be traced out.
@ -25,3 +15,7 @@ class NullContextManager(torch.nn.Module):
ctx = contextlib.nullcontext()
with ctx:
return x.sin() + x.cos()
example_inputs = (torch.randn(3, 2),)
tags = {"python.context-manager"}
model = NullContextManager()


@ -1,14 +1,9 @@
# mypy: allow-untyped-defs
import torch
-from torch._export.db.case import export_case, SupportLevel
+from torch._export.db.case import SupportLevel
@export_case(
example_inputs=(torch.randn(2, 3),),
tags={"python.object-model"},
support_level=SupportLevel.NOT_SUPPORTED_YET,
)
class OptionalInput(torch.nn.Module):
"""
Tracing through optional input is not supported yet
@ -18,3 +13,9 @@ class OptionalInput(torch.nn.Module):
if y is not None:
return x + y
return x
example_inputs = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = OptionalInput()


@ -1,21 +1,16 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case, SupportLevel
from torch.utils import _pytree as pytree
@export_case(
example_inputs=({1: torch.randn(3, 2), 2: torch.randn(3, 2)},),
support_level=SupportLevel.SUPPORTED,
)
class PytreeFlatten(torch.nn.Module):
"""
Pytree from PyTorch can be captured by TorchDynamo.
"""
def __init__(self):
super().__init__()
def forward(self, x):
y, spec = pytree.tree_flatten(x)
return y[0] + 1
example_inputs = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},)
model = PytreeFlatten()
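A standalone sketch of the pytree round trip that tracing captures here:

import torch
from torch.utils import _pytree as pytree

inputs = {1: torch.randn(3, 2), 2: torch.randn(3, 2)}
leaves, spec = pytree.tree_flatten(inputs)  # leaves: list of tensors, spec: structure
restored = pytree.tree_unflatten(leaves, spec)  # rebuilds the original dict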


@ -1,17 +1,11 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
from torch.export import Dim
x = torch.randn(3, 2)
dim1_x = Dim("dim1_x")
@export_case(
example_inputs=(x,),
tags={"torch.dynamic-shape"},
dynamic_shapes={"x": {1: dim1_x}},
)
class ScalarOutput(torch.nn.Module):
"""
Returning scalar values from the graph is supported, in addition to Tensor
@ -22,3 +16,8 @@ class ScalarOutput(torch.nn.Module):
def forward(self, x):
return x.shape[1] + 1
example_inputs = (x,)
tags = {"torch.dynamic-shape"}
dynamic_shapes = {"x": {1: dim1_x}}
model = ScalarOutput()


@ -3,16 +3,9 @@ from enum import Enum
import torch
from torch._export.db.case import export_case
class Animal(Enum):
COW = "moo"
@export_case(
example_inputs=(torch.randn(3, 2),),
)
class SpecializedAttribute(torch.nn.Module):
"""
Model attributes are specialized.
@ -28,3 +21,6 @@ class SpecializedAttribute(torch.nn.Module):
return x * x + self.b
else:
raise ValueError("bad")
example_inputs = (torch.randn(3, 2),)
model = SpecializedAttribute()


@ -1,23 +1,17 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"python.control-flow"},
)
class StaticForLoop(torch.nn.Module):
"""
A for loop with a constant number of iterations should be unrolled in the exported graph.
"""
def __init__(self):
super().__init__()
def forward(self, x):
ret = []
for i in range(10): # constant
ret.append(i + x)
return ret
example_inputs = (torch.randn(3, 2),)
tags = {"python.control-flow"}
model = StaticForLoop()


@ -1,24 +1,18 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2, 2),),
tags={"python.control-flow"},
)
class StaticIf(torch.nn.Module):
"""
An `if` statement with a static predicate value should be traced through with the
taken branch.
"""
def __init__(self):
super().__init__()
def forward(self, x):
if len(x.shape) == 3:
return x + torch.ones(1, 1, 1)
return x
example_inputs = (torch.randn(3, 2, 2),)
tags = {"python.control-flow"}
model = StaticIf()


@ -1,14 +1,7 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case, SupportLevel
@export_case(
example_inputs=(torch.randn(3, 2), "attr"),
tags={"python.builtin"},
support_level=SupportLevel.SUPPORTED,
)
class TensorSetattr(torch.nn.Module):
"""
setattr() calls on tensors are not supported.
@ -16,3 +9,7 @@ class TensorSetattr(torch.nn.Module):
def forward(self, x, attr):
setattr(x, attr, torch.randn(3, 2))
return x + 4
example_inputs = (torch.randn(3, 2), "attr")
tags = {"python.builtin"}
model = TensorSetattr()


@ -1,14 +1,9 @@
# mypy: allow-untyped-defs
import torch
-from torch._export.db.case import export_case, SupportLevel
+from torch._export.db.case import SupportLevel
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"torch.operator"},
support_level=SupportLevel.NOT_SUPPORTED_YET,
)
class TorchSymMin(torch.nn.Module):
"""
torch.sym_min operator is not supported in export.
@ -16,3 +11,9 @@ class TorchSymMin(torch.nn.Module):
def forward(self, x):
return x.sum() + torch.sym_min(x.size(0), 100)
example_inputs = (torch.randn(3, 2),)
tags = {"torch.operator"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = TorchSymMin()


@ -1,42 +1,22 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case, SupportLevel, export_rewrite_case
class A:
@classmethod
def func(cls, x):
return 1 + x
@export_case(
example_inputs=(torch.randn(3, 4),),
tags={"python.builtin"},
support_level=SupportLevel.SUPPORTED,
)
class TypeReflectionMethod(torch.nn.Module):
"""
type() calls on custom objects followed by attribute accesses are not allowed
due to their overly dynamic nature.
"""
def __init__(self):
super().__init__()
def forward(self, x):
a = A()
return type(a).func(x)
@export_rewrite_case(parent=TypeReflectionMethod)
class TypeReflectionMethodRewrite(torch.nn.Module):
"""
Custom object class methods will be inlined.
"""
def __init__(self):
super().__init__()
def forward(self, x):
return A.func(x)
example_inputs = (torch.randn(3, 4),)
tags = {"python.builtin"}
model = TypeReflectionMethod()


@ -1,14 +1,7 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import export_case, SupportLevel
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={"torch.mutation"},
support_level=SupportLevel.SUPPORTED,
)
class UserInputMutation(torch.nn.Module):
"""
Directly mutate user input in forward
@ -17,3 +10,8 @@ class UserInputMutation(torch.nn.Module):
def forward(self, x):
x.mul_(2)
return x.cos()
example_inputs = (torch.randn(3, 2),)
tags = {"torch.mutation"}
model = UserInputMutation()


@ -5,13 +5,6 @@ import torch._export.db.examples as examples
TEMPLATE = '''import torch
from torch._export.db.case import export_case
@export_case(
example_inputs=(torch.randn(3, 2),),
tags={{}},
)
def {case_name}(x):
"""
"""