mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: replace export_for_training with export Test Plan: CI Rollback Plan: Differential Revision: D81935792 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162396 Approved by: https://github.com/angelayi, https://github.com/jerryzh168
224 lines
7.8 KiB
Python
224 lines
7.8 KiB
Python
# Owner(s): ["oncall: export"]
|
|
import unittest
|
|
|
|
import torch
|
|
from functorch.experimental import control_flow
|
|
from torch import Tensor
|
|
from torch._dynamo.eval_frame import is_dynamo_supported
|
|
from torch._export.verifier import SpecViolationError, Verifier
|
|
from torch.export import export
|
|
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
|
|
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
|
|
|
|
|
|
@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
|
|
class TestVerifier(TestCase):
|
|
def test_verifier_basic(self) -> None:
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + y
|
|
|
|
f = Foo()
|
|
|
|
ep = export(f, (torch.randn(100), torch.randn(100)), strict=True)
|
|
|
|
verifier = Verifier()
|
|
verifier.check(ep)
|
|
|
|
def test_verifier_call_module(self) -> None:
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
return self.linear(x)
|
|
|
|
gm = torch.fx.symbolic_trace(M())
|
|
|
|
verifier = Verifier()
|
|
with self.assertRaises(SpecViolationError):
|
|
verifier._check_graph_module(gm)
|
|
|
|
def test_verifier_no_functional(self) -> None:
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + y
|
|
|
|
f = Foo()
|
|
|
|
ep = export(
|
|
f, (torch.randn(100), torch.randn(100)), strict=True
|
|
).run_decompositions({})
|
|
for node in ep.graph.nodes:
|
|
if node.target == torch.ops.aten.add.Tensor:
|
|
node.target = torch.ops.aten.add_.Tensor
|
|
|
|
verifier = Verifier()
|
|
with self.assertRaises(SpecViolationError):
|
|
verifier.check(ep)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
def test_verifier_higher_order(self) -> None:
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + y
|
|
|
|
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x - y
|
|
|
|
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
|
|
|
|
f = Foo()
|
|
|
|
ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True)
|
|
|
|
verifier = Verifier()
|
|
verifier.check(ep)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
def test_verifier_nested_invalid_module(self) -> None:
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + y
|
|
|
|
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x - y
|
|
|
|
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
|
|
|
|
f = Foo()
|
|
|
|
ep = export(
|
|
f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True
|
|
).run_decompositions({})
|
|
for node in ep.graph_module.true_graph_0.graph.nodes:
|
|
if node.target == torch.ops.aten.add.Tensor:
|
|
node.target = torch.ops.aten.add_.Tensor
|
|
|
|
verifier = Verifier()
|
|
with self.assertRaises(SpecViolationError):
|
|
verifier.check(ep)
|
|
|
|
def test_ep_verifier_basic(self) -> None:
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
return self.linear(x)
|
|
|
|
ep = export(M(), (torch.randn(10, 10),), strict=True)
|
|
ep.validate()
|
|
|
|
def test_ep_verifier_invalid_param(self) -> None:
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.register_parameter(
|
|
name="a", param=torch.nn.Parameter(torch.randn(100))
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + y + self.a
|
|
|
|
ep = export(M(), (torch.randn(100), torch.randn(100)), strict=True)
|
|
|
|
# Parameter doesn't exist in the state dict
|
|
ep.graph_signature.input_specs[0] = InputSpec(
|
|
kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param"
|
|
)
|
|
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
|
|
ep.validate()
|
|
|
|
# Add non-torch.nn.Parameter parameter to the state dict
|
|
ep.state_dict["bad_param"] = torch.randn(100)
|
|
with self.assertRaisesRegex(
|
|
SpecViolationError, "not an instance of torch.nn.Parameter"
|
|
):
|
|
ep.validate()
|
|
|
|
def test_ep_verifier_invalid_buffer(self) -> None:
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.a = torch.tensor(3.0)
|
|
|
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + y + self.a
|
|
|
|
ep = export(M(), (torch.randn(100), torch.randn(100)), strict=True)
|
|
|
|
# Buffer doesn't exist in the state dict
|
|
ep.graph_signature.input_specs[0] = InputSpec(
|
|
kind=InputKind.BUFFER,
|
|
arg=TensorArgument(name="c_a"),
|
|
target="bad_buffer",
|
|
persistent=True,
|
|
)
|
|
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
|
|
ep.validate()
|
|
|
|
def test_ep_verifier_buffer_mutate(self) -> None:
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
|
|
|
|
self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
|
|
self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
|
|
|
|
def forward(self, x1, x2):
|
|
# Use the parameter, buffers, and both inputs in the forward method
|
|
output = (
|
|
x1 + self.my_parameter
|
|
) * self.my_buffer1 + x2 * self.my_buffer2
|
|
|
|
# Mutate one of the buffers (e.g., increment it by 1)
|
|
self.my_buffer2.add_(1.0)
|
|
return output
|
|
|
|
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True)
|
|
ep.validate()
|
|
|
|
def test_ep_verifier_invalid_output(self) -> None:
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
|
|
|
|
self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
|
|
self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))
|
|
|
|
def forward(self, x1, x2):
|
|
# Use the parameter, buffers, and both inputs in the forward method
|
|
output = (
|
|
x1 + self.my_parameter
|
|
) * self.my_buffer1 + x2 * self.my_buffer2
|
|
|
|
# Mutate one of the buffers (e.g., increment it by 1)
|
|
self.my_buffer2.add_(1.0)
|
|
return output
|
|
|
|
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True)
|
|
|
|
output_node = list(ep.graph.nodes)[-1]
|
|
output_node.args = (
|
|
(
|
|
output_node.args[0][0],
|
|
next(iter(ep.graph.nodes)),
|
|
),
|
|
)
|
|
|
|
with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
|
|
ep.validate()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|