mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
[export] Publicize validate function (#132777)
as titled Pull Request resolved: https://github.com/pytorch/pytorch/pull/132777 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
21d4c48059
commit
c327710a87
@ -108,7 +108,7 @@ class TestVerifier(TestCase):
|
||||
return self.linear(x)
|
||||
|
||||
ep = export(M(), (torch.randn(10, 10),))
|
||||
ep._validate()
|
||||
ep.validate()
|
||||
|
||||
def test_ep_verifier_invalid_param(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
@ -128,14 +128,14 @@ class TestVerifier(TestCase):
|
||||
kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param"
|
||||
)
|
||||
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
|
||||
ep._validate()
|
||||
ep.validate()
|
||||
|
||||
# Add non-torch.nn.Parameter parameter to the state dict
|
||||
ep.state_dict["bad_param"] = torch.randn(100)
|
||||
with self.assertRaisesRegex(
|
||||
SpecViolationError, "not an instance of torch.nn.Parameter"
|
||||
):
|
||||
ep._validate()
|
||||
ep.validate()
|
||||
|
||||
def test_ep_verifier_invalid_buffer(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
@ -156,7 +156,7 @@ class TestVerifier(TestCase):
|
||||
persistent=True,
|
||||
)
|
||||
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
|
||||
ep._validate()
|
||||
ep.validate()
|
||||
|
||||
def test_ep_verifier_buffer_mutate(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
@ -179,7 +179,7 @@ class TestVerifier(TestCase):
|
||||
return output
|
||||
|
||||
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
|
||||
ep._validate()
|
||||
ep.validate()
|
||||
|
||||
def test_ep_verifier_invalid_output(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
@ -213,7 +213,7 @@ class TestVerifier(TestCase):
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
|
||||
ep._validate()
|
||||
ep.validate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1366,7 +1366,7 @@ class ExportedProgramSerializer(metaclass=Final):
|
||||
Args:
|
||||
exported_program: Exported Program to serialize
|
||||
"""
|
||||
exported_program._validate()
|
||||
exported_program.validate()
|
||||
|
||||
gm_serializer = GraphModuleSerializer(
|
||||
exported_program.graph_signature, exported_program.module_call_graph
|
||||
|
@ -716,7 +716,7 @@ class ExportedProgram:
|
||||
assert all(issubclass(v, Verifier) for v in verifiers)
|
||||
self._verifiers = verifiers
|
||||
# Validate should be always the last step of the constructor.
|
||||
self._validate()
|
||||
self.validate()
|
||||
|
||||
@property
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@ -1131,6 +1131,11 @@ class ExportedProgram:
|
||||
input_placeholders, flat_args_with_path, self.range_constraints
|
||||
)
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def validate(self):
|
||||
self._validate()
|
||||
|
||||
# TODO: remove this
|
||||
@final
|
||||
def _validate(self):
|
||||
assert (
|
||||
|
Reference in New Issue
Block a user