mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] Rewrite exportdb formatting. (#129260)
Summary: It'll be easier to generate examples if the code doesn't depend on exportdb library. Test Plan: CI Differential Revision: D58886554 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129260 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
551e412718
commit
e58ef5b65f
@ -32,8 +32,6 @@ def generate_example_rst(example_case: ExportCase):
|
||||
)
|
||||
with open(source_file) as file:
|
||||
source_code = file.read()
|
||||
source_code = re.sub(r"from torch\._export\.db\.case import .*\n", "", source_code)
|
||||
source_code = re.sub(r"@export_case\((.|\n)*?\)\n", "", source_code)
|
||||
source_code = source_code.replace("\n", "\n ")
|
||||
splitted_source_code = re.split(r"@export_rewrite_case.*\n", source_code)
|
||||
|
||||
@ -43,6 +41,7 @@ def generate_example_rst(example_case: ExportCase):
|
||||
}, f"more than one @export_rewrite_case decorator in {source_code}"
|
||||
|
||||
# Generate contents of the .rst file
|
||||
# TODO(zhxchen17) Update template when we switch to example_args and example_kwargs.
|
||||
title = f"{example_case.name}"
|
||||
doc_contents = f"""{title}
|
||||
{'^' * (len(title))}
|
||||
|
@ -122,9 +122,8 @@ def to_snake_case(name):
|
||||
|
||||
|
||||
def _make_export_case(m, name, configs):
|
||||
if not issubclass(m, torch.nn.Module):
|
||||
if not isinstance(m, torch.nn.Module):
|
||||
raise TypeError("Export case class should be a torch.nn.Module.")
|
||||
m = m()
|
||||
|
||||
if "description" not in configs:
|
||||
# Fallback to docstring if description is missing.
|
||||
@ -148,14 +147,7 @@ def export_case(**kwargs):
|
||||
|
||||
assert module is not None
|
||||
_MODULES.add(module)
|
||||
normalized_name = to_snake_case(m.__name__)
|
||||
module_name = module.__name__.split(".")[-1]
|
||||
if module_name != normalized_name:
|
||||
raise RuntimeError(
|
||||
f'Module name "{module.__name__}" is inconsistent with exported program '
|
||||
+ f'name "{m.__name__}". Please rename the module to "{normalized_name}".'
|
||||
)
|
||||
|
||||
case = _make_export_case(m, module_name, configs)
|
||||
register_db_case(case)
|
||||
return case
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import dataclasses
|
||||
import glob
|
||||
import importlib
|
||||
import inspect
|
||||
from os.path import basename, dirname, isfile, join
|
||||
|
||||
import torch
|
||||
@ -9,17 +10,24 @@ from torch._export.db.case import (
|
||||
_EXAMPLE_CONFLICT_CASES,
|
||||
_EXAMPLE_REWRITE_CASES,
|
||||
SupportLevel,
|
||||
export_case,
|
||||
ExportCase,
|
||||
)
|
||||
|
||||
|
||||
modules = glob.glob(join(dirname(__file__), "*.py"))
|
||||
__all__ = [
|
||||
basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
|
||||
]
|
||||
def _collect_examples():
|
||||
case_names = glob.glob(join(dirname(__file__), "*.py"))
|
||||
case_names = [
|
||||
basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py")
|
||||
]
|
||||
|
||||
# Import all module in the current directory.
|
||||
from . import * # noqa: F403
|
||||
case_fields = {f.name for f in dataclasses.fields(ExportCase)}
|
||||
for case_name in case_names:
|
||||
case = __import__(case_name, globals(), locals(), [], 1)
|
||||
variables = [name for name in dir(case) if name in case_fields]
|
||||
export_case(**{v: getattr(case, v) for v in variables})(case.model)
|
||||
|
||||
_collect_examples()
|
||||
|
||||
def all_examples():
|
||||
return _EXAMPLE_CASES
|
||||
|
@ -2,24 +2,19 @@
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2), torch.tensor(4)),
|
||||
tags={"torch.escape-hatch"},
|
||||
)
|
||||
class AssumeConstantResult(torch.nn.Module):
|
||||
"""
|
||||
Applying `assume_constant_result` decorator to burn make non-tracable code as constant.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@torchdynamo.assume_constant_result
|
||||
def get_item(self, y):
|
||||
return y.int().item()
|
||||
|
||||
def forward(self, x, y):
|
||||
return x[: self.get_item(y)]
|
||||
|
||||
example_inputs = (torch.randn(3, 2), torch.tensor(4))
|
||||
tags = {"torch.escape-hatch"}
|
||||
model = AssumeConstantResult()
|
||||
|
@ -1,9 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
class MyAutogradFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
@ -13,10 +10,6 @@ class MyAutogradFunction(torch.autograd.Function):
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output + 1
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
)
|
||||
class AutogradFunction(torch.nn.Module):
|
||||
"""
|
||||
TorchDynamo does not keep track of backward() on autograd functions. We recommend to
|
||||
@ -25,3 +18,6 @@ class AutogradFunction(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return MyAutogradFunction.apply(x)
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
model = AutogradFunction()
|
||||
|
@ -1,12 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 4),),
|
||||
)
|
||||
class ClassMethod(torch.nn.Module):
|
||||
"""
|
||||
Class methods are inlined during tracing.
|
||||
@ -23,3 +17,6 @@ class ClassMethod(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
return self.method(x) * self.__class__.method(x) * type(self).method(x)
|
||||
|
||||
example_inputs = (torch.randn(3, 4),)
|
||||
model = ClassMethod()
|
||||
|
@ -1,10 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
||||
|
||||
class MySubModule(torch.nn.Module):
|
||||
def foo(self, x):
|
||||
return x.cos()
|
||||
@ -12,14 +10,6 @@ class MySubModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return self.foo(x)
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3),),
|
||||
tags={
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
},
|
||||
)
|
||||
class CondBranchClassMethod(torch.nn.Module):
|
||||
"""
|
||||
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
||||
@ -45,3 +35,10 @@ class CondBranchClassMethod(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
|
||||
|
||||
example_inputs = (torch.randn(3),)
|
||||
tags = {
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
}
|
||||
model = CondBranchClassMethod()
|
||||
|
@ -1,17 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3),),
|
||||
tags={
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
},
|
||||
)
|
||||
class CondBranchNestedFunction(torch.nn.Module):
|
||||
"""
|
||||
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
||||
@ -26,8 +17,6 @@ class CondBranchNestedFunction(torch.nn.Module):
|
||||
|
||||
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
def true_fn(x):
|
||||
@ -43,3 +32,10 @@ class CondBranchNestedFunction(torch.nn.Module):
|
||||
return inner_false_fn(x)
|
||||
|
||||
return cond(x.shape[0] < 10, true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.randn(3),)
|
||||
tags = {
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
}
|
||||
model = CondBranchNestedFunction()
|
||||
|
@ -1,17 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(6),),
|
||||
tags={
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
},
|
||||
)
|
||||
class CondBranchNonlocalVariables(torch.nn.Module):
|
||||
"""
|
||||
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
|
||||
@ -43,9 +34,6 @@ class CondBranchNonlocalVariables(torch.nn.Module):
|
||||
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
my_tensor_var = x + 100
|
||||
my_primitive_var = 3.14
|
||||
@ -62,3 +50,10 @@ class CondBranchNonlocalVariables(torch.nn.Module):
|
||||
false_fn,
|
||||
[x, my_tensor_var, torch.tensor(my_primitive_var)],
|
||||
)
|
||||
|
||||
example_inputs = (torch.randn(6),)
|
||||
tags = {
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
}
|
||||
model = CondBranchNonlocalVariables()
|
||||
|
@ -1,14 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.tensor(True), torch.randn(3, 2)),
|
||||
tags={"torch.cond", "python.closure"},
|
||||
)
|
||||
class CondClosedOverVariable(torch.nn.Module):
|
||||
"""
|
||||
torch.cond() supports branches closed over arbitrary variables.
|
||||
@ -22,3 +16,7 @@ class CondClosedOverVariable(torch.nn.Module):
|
||||
return x - 2
|
||||
|
||||
return cond(pred, true_fn, false_fn, [x + 1])
|
||||
|
||||
example_inputs = (torch.tensor(True), torch.randn(3, 2))
|
||||
tags = {"torch.cond", "python.closure"}
|
||||
model = CondClosedOverVariable()
|
||||
|
@ -1,7 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from torch.export import Dim
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
||||
@ -9,15 +8,6 @@ x = torch.randn(3, 2)
|
||||
y = torch.randn(2)
|
||||
dim0_x = Dim("dim0_x")
|
||||
|
||||
@export_case(
|
||||
example_inputs=(x, y),
|
||||
tags={
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
},
|
||||
extra_inputs=(torch.randn(2, 2), torch.randn(2)),
|
||||
dynamic_shapes={"x": {0: dim0_x}, "y": None},
|
||||
)
|
||||
class CondOperands(torch.nn.Module):
|
||||
"""
|
||||
The operands passed to cond() must be:
|
||||
@ -27,9 +17,6 @@ class CondOperands(torch.nn.Module):
|
||||
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
def true_fn(x, y):
|
||||
return x + y
|
||||
@ -38,3 +25,12 @@ class CondOperands(torch.nn.Module):
|
||||
return x - y
|
||||
|
||||
return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
|
||||
|
||||
example_inputs = (x, y)
|
||||
tags = {
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
}
|
||||
extra_inputs = (torch.randn(2, 2), torch.randn(2))
|
||||
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
|
||||
model = CondOperands()
|
||||
|
@ -1,17 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(6, 4, 3),),
|
||||
tags={
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
},
|
||||
)
|
||||
class CondPredicate(torch.nn.Module):
|
||||
"""
|
||||
The conditional statement (aka predicate) passed to cond() must be one of the following:
|
||||
@ -21,10 +12,14 @@ class CondPredicate(torch.nn.Module):
|
||||
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
pred = x.dim() > 2 and x.shape[2] > 10
|
||||
|
||||
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
|
||||
|
||||
example_inputs = (torch.randn(6, 4, 3),)
|
||||
tags = {
|
||||
"torch.cond",
|
||||
"torch.dynamic-shape",
|
||||
}
|
||||
model = CondPredicate()
|
||||
|
@ -1,16 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.tensor(4),),
|
||||
tags={
|
||||
"torch.dynamic-value",
|
||||
"torch.escape-hatch",
|
||||
},
|
||||
)
|
||||
class ConstrainAsSizeExample(torch.nn.Module):
|
||||
"""
|
||||
If the value is not known at tracing time, you can provide hint so that we
|
||||
@ -19,11 +10,16 @@ class ConstrainAsSizeExample(torch.nn.Module):
|
||||
tensor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
a = x.item()
|
||||
torch._check_is_size(a)
|
||||
torch._check(a <= 5)
|
||||
return torch.zeros((a, 5))
|
||||
|
||||
|
||||
example_inputs = (torch.tensor(4),)
|
||||
tags = {
|
||||
"torch.dynamic-value",
|
||||
"torch.escape-hatch",
|
||||
}
|
||||
model = ConstrainAsSizeExample()
|
||||
|
@ -1,16 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.tensor(4), torch.randn(5, 5)),
|
||||
tags={
|
||||
"torch.dynamic-value",
|
||||
"torch.escape-hatch",
|
||||
},
|
||||
)
|
||||
class ConstrainAsValueExample(torch.nn.Module):
|
||||
"""
|
||||
If the value is not known at tracing time, you can provide hint so that we
|
||||
@ -19,9 +10,6 @@ class ConstrainAsValueExample(torch.nn.Module):
|
||||
tensor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
a = x.item()
|
||||
torch._check(a >= 0)
|
||||
@ -30,3 +18,11 @@ class ConstrainAsValueExample(torch.nn.Module):
|
||||
if a < 6:
|
||||
return y.sin()
|
||||
return y.cos()
|
||||
|
||||
|
||||
example_inputs = (torch.tensor(4), torch.randn(5, 5))
|
||||
tags = {
|
||||
"torch.dynamic-value",
|
||||
"torch.escape-hatch",
|
||||
}
|
||||
model = ConstrainAsValueExample()
|
||||
|
@ -3,9 +3,6 @@ import functools
|
||||
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
def test_decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
@ -13,10 +10,6 @@ def test_decorator(func):
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2), torch.randn(3, 2)),
|
||||
)
|
||||
class Decorator(torch.nn.Module):
|
||||
"""
|
||||
Decorators calls are inlined into the exported function during tracing.
|
||||
@ -25,3 +18,6 @@ class Decorator(torch.nn.Module):
|
||||
@test_decorator
|
||||
def forward(self, x, y):
|
||||
return x + y
|
||||
|
||||
example_inputs = (torch.randn(3, 2), torch.randn(3, 2))
|
||||
model = Decorator()
|
||||
|
@ -1,22 +1,17 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2), torch.tensor(4)),
|
||||
tags={"python.data-structure"},
|
||||
)
|
||||
class Dictionary(torch.nn.Module):
|
||||
"""
|
||||
Dictionary structures are inlined and flattened along tracing.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
elements = {}
|
||||
elements["x2"] = x * x
|
||||
y = y * elements["x2"]
|
||||
return {"y": y}
|
||||
|
||||
example_inputs = (torch.randn(3, 2), torch.tensor(4))
|
||||
tags = {"python.data-structure"}
|
||||
model = Dictionary()
|
||||
|
@ -1,19 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"python.assert"},
|
||||
)
|
||||
class DynamicShapeAssert(torch.nn.Module):
|
||||
"""
|
||||
A basic usage of python assertion.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# assertion with error message
|
||||
@ -21,3 +12,7 @@ class DynamicShapeAssert(torch.nn.Module):
|
||||
# assertion without error message
|
||||
assert x.shape[0] > 1
|
||||
return x
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"python.assert"}
|
||||
model = DynamicShapeAssert()
|
||||
|
@ -1,20 +1,15 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"torch.dynamic-shape"},
|
||||
)
|
||||
class DynamicShapeConstructor(torch.nn.Module):
|
||||
"""
|
||||
Tensor constructors should be captured with dynamic shape inputs rather
|
||||
than being baked in with static shape.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.zeros(x.shape[0] * 2)
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"torch.dynamic-shape"}
|
||||
model = DynamicShapeConstructor()
|
||||
|
@ -1,13 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2, 2),),
|
||||
tags={"torch.dynamic-shape", "python.control-flow"},
|
||||
)
|
||||
class DynamicShapeIfGuard(torch.nn.Module):
|
||||
"""
|
||||
`if` statement with backed dynamic shape predicate will be specialized into
|
||||
@ -20,3 +13,7 @@ class DynamicShapeIfGuard(torch.nn.Module):
|
||||
return x.cos()
|
||||
|
||||
return x.sin()
|
||||
|
||||
example_inputs = (torch.randn(3, 2, 2),)
|
||||
tags = {"torch.dynamic-shape", "python.control-flow"}
|
||||
model = DynamicShapeIfGuard()
|
||||
|
@ -1,24 +1,19 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from functorch.experimental.control_flow import map
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2), torch.randn(2)),
|
||||
tags={"torch.dynamic-shape", "torch.map"},
|
||||
)
|
||||
class DynamicShapeMap(torch.nn.Module):
|
||||
"""
|
||||
functorch map() maps a function over the first tensor dimension.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, xs, y):
|
||||
def body(x, y):
|
||||
return x + y
|
||||
|
||||
return map(body, xs, y)
|
||||
|
||||
example_inputs = (torch.randn(3, 2), torch.randn(2))
|
||||
tags = {"torch.dynamic-shape", "torch.map"}
|
||||
model = DynamicShapeMap()
|
||||
|
@ -1,25 +1,21 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel
|
||||
from torch._export.db.case import SupportLevel
|
||||
from torch.export import Dim
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
dim0_x = Dim("dim0_x")
|
||||
|
||||
@export_case(
|
||||
example_inputs=(x,),
|
||||
tags={"torch.dynamic-shape", "python.builtin"},
|
||||
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
||||
dynamic_shapes={"x": {0: dim0_x}},
|
||||
)
|
||||
class DynamicShapeRound(torch.nn.Module):
|
||||
"""
|
||||
Calling round on dynamic shapes is not supported.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x[: round(x.shape[0] / 2)]
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
dim0_x = Dim("dim0_x")
|
||||
example_inputs = (x,)
|
||||
tags = {"torch.dynamic-shape", "python.builtin"}
|
||||
support_level = SupportLevel.NOT_SUPPORTED_YET
|
||||
dynamic_shapes = {"x": {0: dim0_x}}
|
||||
model = DynamicShapeRound()
|
||||
|
@ -1,21 +1,15 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"torch.dynamic-shape"},
|
||||
)
|
||||
class DynamicShapeSlicing(torch.nn.Module):
|
||||
"""
|
||||
Slices with dynamic shape arguments should be captured into the graph
|
||||
rather than being baked in.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"torch.dynamic-shape"}
|
||||
model = DynamicShapeSlicing()
|
||||
|
@ -1,23 +1,17 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(10, 10),),
|
||||
tags={"torch.dynamic-shape"},
|
||||
)
|
||||
class DynamicShapeView(torch.nn.Module):
|
||||
"""
|
||||
Dynamic shapes should be propagated to view arguments instead of being
|
||||
baked into the exported graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
new_x_shape = x.size()[:-1] + (2, 5)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1)
|
||||
|
||||
example_inputs = (torch.randn(10, 10),)
|
||||
tags = {"torch.dynamic-shape"}
|
||||
model = DynamicShapeView()
|
||||
|
@ -1,26 +1,12 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, ExportArgs, SupportLevel
|
||||
from torch._export.db.case import ExportArgs
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=ExportArgs(
|
||||
torch.randn(4),
|
||||
(torch.randn(4), torch.randn(4)),
|
||||
*[torch.randn(4), torch.randn(4)],
|
||||
mykw0=torch.randn(4),
|
||||
input0=torch.randn(4), input1=torch.randn(4)
|
||||
),
|
||||
tags={"python.data-structure"},
|
||||
support_level=SupportLevel.SUPPORTED,
|
||||
)
|
||||
class FnWithKwargs(torch.nn.Module):
|
||||
"""
|
||||
Keyword arguments are not supported at the moment.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
|
||||
out = pos0
|
||||
@ -31,3 +17,13 @@ class FnWithKwargs(torch.nn.Module):
|
||||
out = out * mykw0
|
||||
out = out * mykwargs["input0"] * mykwargs["input1"]
|
||||
return out
|
||||
|
||||
example_inputs = ExportArgs(
|
||||
torch.randn(4),
|
||||
(torch.randn(4), torch.randn(4)),
|
||||
*[torch.randn(4), torch.randn(4)],
|
||||
mykw0=torch.randn(4),
|
||||
input0=torch.randn(4), input1=torch.randn(4)
|
||||
)
|
||||
tags = {"python.data-structure"}
|
||||
model = FnWithKwargs()
|
||||
|
@ -1,22 +1,17 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"torch.dynamic-shape", "python.data-structure", "python.assert"},
|
||||
)
|
||||
class ListContains(torch.nn.Module):
|
||||
"""
|
||||
List containment relation can be checked on a dynamic shape or constants.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
assert x.size(-1) in [6, 2]
|
||||
assert x.size(0) not in [4, 5, 6]
|
||||
assert "monkey" not in ["cow", "pig"]
|
||||
return x + x
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
|
||||
model = ListContains()
|
||||
|
@ -3,22 +3,12 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],),
|
||||
tags={"python.control-flow", "python.data-structure"},
|
||||
)
|
||||
class ListUnpack(torch.nn.Module):
|
||||
"""
|
||||
Lists are treated as static construct, therefore unpacking should be
|
||||
erased after tracing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, args: List[torch.Tensor]):
|
||||
"""
|
||||
Lists are treated as static construct, therefore unpacking should be
|
||||
@ -26,3 +16,7 @@ class ListUnpack(torch.nn.Module):
|
||||
"""
|
||||
x, *y = args
|
||||
return x + y[0]
|
||||
|
||||
example_inputs = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
|
||||
tags = {"python.control-flow", "python.data-structure"}
|
||||
model = ListUnpack()
|
||||
|
@ -1,14 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel
|
||||
from torch._export.db.case import SupportLevel
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"python.object-model"},
|
||||
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
||||
)
|
||||
class ModelAttrMutation(torch.nn.Module):
|
||||
"""
|
||||
Attribute mutation is not supported.
|
||||
@ -24,3 +19,9 @@ class ModelAttrMutation(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
self.attr_list = self.recreate_list()
|
||||
return x.sum() + self.attr_list[0].sum()
|
||||
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"python.object-model"}
|
||||
support_level = SupportLevel.NOT_SUPPORTED_YET
|
||||
model = ModelAttrMutation()
|
||||
|
@ -1,20 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2), torch.randn(2)),
|
||||
tags={"python.closure"},
|
||||
)
|
||||
class NestedFunction(torch.nn.Module):
|
||||
"""
|
||||
Nested functions are traced through. Side effects on global captures
|
||||
are not supported though.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, a, b):
|
||||
x = a + b
|
||||
@ -26,3 +17,7 @@ class NestedFunction(torch.nn.Module):
|
||||
return x * y + z
|
||||
|
||||
return closure(x)
|
||||
|
||||
example_inputs = (torch.randn(3, 2), torch.randn(2))
|
||||
tags = {"python.closure"}
|
||||
model = NestedFunction()
|
||||
|
@ -3,21 +3,11 @@ import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"python.context-manager"},
|
||||
)
|
||||
class NullContextManager(torch.nn.Module):
|
||||
"""
|
||||
Null context manager in Python will be traced out.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Null context manager in Python will be traced out.
|
||||
@ -25,3 +15,7 @@ class NullContextManager(torch.nn.Module):
|
||||
ctx = contextlib.nullcontext()
|
||||
with ctx:
|
||||
return x.sin() + x.cos()
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"python.context-manager"}
|
||||
model = NullContextManager()
|
||||
|
@ -1,14 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel
|
||||
from torch._export.db.case import SupportLevel
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(2, 3),),
|
||||
tags={"python.object-model"},
|
||||
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
||||
)
|
||||
class OptionalInput(torch.nn.Module):
|
||||
"""
|
||||
Tracing through optional input is not supported yet
|
||||
@ -18,3 +13,9 @@ class OptionalInput(torch.nn.Module):
|
||||
if y is not None:
|
||||
return x + y
|
||||
return x
|
||||
|
||||
|
||||
example_inputs = (torch.randn(2, 3),)
|
||||
tags = {"python.object-model"}
|
||||
support_level = SupportLevel.NOT_SUPPORTED_YET
|
||||
model = OptionalInput()
|
||||
|
@ -1,21 +1,16 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=({1: torch.randn(3, 2), 2: torch.randn(3, 2)},),
|
||||
support_level=SupportLevel.SUPPORTED,
|
||||
)
|
||||
class PytreeFlatten(torch.nn.Module):
|
||||
"""
|
||||
Pytree from PyTorch can be captured by TorchDynamo.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
y, spec = pytree.tree_flatten(x)
|
||||
return y[0] + 1
|
||||
|
||||
example_inputs = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},),
|
||||
model = PytreeFlatten()
|
||||
|
@ -1,17 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
from torch.export import Dim
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
dim1_x = Dim("dim1_x")
|
||||
|
||||
@export_case(
|
||||
example_inputs=(x,),
|
||||
tags={"torch.dynamic-shape"},
|
||||
dynamic_shapes={"x": {1: dim1_x}},
|
||||
)
|
||||
class ScalarOutput(torch.nn.Module):
|
||||
"""
|
||||
Returning scalar values from the graph is supported, in addition to Tensor
|
||||
@ -22,3 +16,8 @@ class ScalarOutput(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x.shape[1] + 1
|
||||
|
||||
example_inputs = (x,)
|
||||
tags = {"torch.dynamic-shape"}
|
||||
dynamic_shapes = {"x": {1: dim1_x}}
|
||||
model = ScalarOutput()
|
||||
|
@ -3,16 +3,9 @@ from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
class Animal(Enum):
|
||||
COW = "moo"
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
)
|
||||
class SpecializedAttribute(torch.nn.Module):
|
||||
"""
|
||||
Model attributes are specialized.
|
||||
@ -28,3 +21,6 @@ class SpecializedAttribute(torch.nn.Module):
|
||||
return x * x + self.b
|
||||
else:
|
||||
raise ValueError("bad")
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
model = SpecializedAttribute()
|
||||
|
@ -1,23 +1,17 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"python.control-flow"},
|
||||
)
|
||||
class StaticForLoop(torch.nn.Module):
|
||||
"""
|
||||
A for loop with constant number of iterations should be unrolled in the exported graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
ret = []
|
||||
for i in range(10): # constant
|
||||
ret.append(i + x)
|
||||
return ret
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"python.control-flow"}
|
||||
model = StaticForLoop()
|
||||
|
@ -1,24 +1,18 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2, 2),),
|
||||
tags={"python.control-flow"},
|
||||
)
|
||||
class StaticIf(torch.nn.Module):
|
||||
"""
|
||||
`if` statement with static predicate value should be traced through with the
|
||||
taken branch.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
if len(x.shape) == 3:
|
||||
return x + torch.ones(1, 1, 1)
|
||||
|
||||
return x
|
||||
|
||||
example_inputs = (torch.randn(3, 2, 2),)
|
||||
tags = {"python.control-flow"}
|
||||
model = StaticIf()
|
||||
|
@ -1,14 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2), "attr"),
|
||||
tags={"python.builtin"},
|
||||
support_level=SupportLevel.SUPPORTED,
|
||||
)
|
||||
class TensorSetattr(torch.nn.Module):
|
||||
"""
|
||||
setattr() call onto tensors is not supported.
|
||||
@ -16,3 +9,7 @@ class TensorSetattr(torch.nn.Module):
|
||||
def forward(self, x, attr):
|
||||
setattr(x, attr, torch.randn(3, 2))
|
||||
return x + 4
|
||||
|
||||
example_inputs = (torch.randn(3, 2), "attr")
|
||||
tags = {"python.builtin"}
|
||||
model = TensorSetattr()
|
||||
|
@ -1,14 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel
|
||||
from torch._export.db.case import SupportLevel
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"torch.operator"},
|
||||
support_level=SupportLevel.NOT_SUPPORTED_YET,
|
||||
)
|
||||
class TorchSymMin(torch.nn.Module):
|
||||
"""
|
||||
torch.sym_min operator is not supported in export.
|
||||
@ -16,3 +11,9 @@ class TorchSymMin(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x.sum() + torch.sym_min(x.size(0), 100)
|
||||
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"torch.operator"}
|
||||
support_level = SupportLevel.NOT_SUPPORTED_YET
|
||||
model = TorchSymMin()
|
||||
|
@ -1,42 +1,22 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel, export_rewrite_case
|
||||
|
||||
|
||||
class A:
|
||||
@classmethod
|
||||
def func(cls, x):
|
||||
return 1 + x
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 4),),
|
||||
tags={"python.builtin"},
|
||||
support_level=SupportLevel.SUPPORTED,
|
||||
)
|
||||
class TypeReflectionMethod(torch.nn.Module):
|
||||
"""
|
||||
type() calls on custom objects followed by attribute accesses are not allowed
|
||||
due to its overly dynamic nature.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
a = A()
|
||||
return type(a).func(x)
|
||||
|
||||
|
||||
@export_rewrite_case(parent=TypeReflectionMethod)
|
||||
class TypeReflectionMethodRewrite(torch.nn.Module):
|
||||
"""
|
||||
Custom object class methods will be inlined.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return A.func(x)
|
||||
example_inputs = (torch.randn(3, 4),)
|
||||
tags = {"python.builtin"}
|
||||
model = TypeReflectionMethod()
|
||||
|
@ -1,14 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
from torch._export.db.case import export_case, SupportLevel
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={"torch.mutation"},
|
||||
support_level=SupportLevel.SUPPORTED,
|
||||
)
|
||||
class UserInputMutation(torch.nn.Module):
|
||||
"""
|
||||
Directly mutate user input in forward
|
||||
@ -17,3 +10,8 @@ class UserInputMutation(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x.mul_(2)
|
||||
return x.cos()
|
||||
|
||||
|
||||
example_inputs = (torch.randn(3, 2),)
|
||||
tags = {"torch.mutation"}
|
||||
model = UserInputMutation()
|
||||
|
@ -5,13 +5,6 @@ import torch._export.db.examples as examples
|
||||
|
||||
TEMPLATE = '''import torch
|
||||
|
||||
from torch._export.db.case import export_case
|
||||
|
||||
|
||||
@export_case(
|
||||
example_inputs=(torch.randn(3, 2),),
|
||||
tags={{}},
|
||||
)
|
||||
def {case_name}(x):
|
||||
"""
|
||||
"""
|
||||
|
Reference in New Issue
Block a user