[export] Publicize validate function (#132777)

as titled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132777
Approved by: https://github.com/zhxchen17
This commit is contained in:
angelayi
2024-08-07 23:10:05 +00:00
committed by PyTorch MergeBot
parent 21d4c48059
commit c327710a87
3 changed files with 13 additions and 8 deletions

View File

@ -108,7 +108,7 @@ class TestVerifier(TestCase):
return self.linear(x)
ep = export(M(), (torch.randn(10, 10),))
ep._validate()
ep.validate()
def test_ep_verifier_invalid_param(self) -> None:
class M(torch.nn.Module):
@ -128,14 +128,14 @@ class TestVerifier(TestCase):
kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param"
)
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
ep._validate()
ep.validate()
# Add non-torch.nn.Parameter parameter to the state dict
ep.state_dict["bad_param"] = torch.randn(100)
with self.assertRaisesRegex(
SpecViolationError, "not an instance of torch.nn.Parameter"
):
ep._validate()
ep.validate()
def test_ep_verifier_invalid_buffer(self) -> None:
class M(torch.nn.Module):
@ -156,7 +156,7 @@ class TestVerifier(TestCase):
persistent=True,
)
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
ep._validate()
ep.validate()
def test_ep_verifier_buffer_mutate(self) -> None:
class M(torch.nn.Module):
@ -179,7 +179,7 @@ class TestVerifier(TestCase):
return output
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
ep._validate()
ep.validate()
def test_ep_verifier_invalid_output(self) -> None:
class M(torch.nn.Module):
@ -213,7 +213,7 @@ class TestVerifier(TestCase):
)
with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
ep._validate()
ep.validate()
if __name__ == "__main__":

View File

@ -1366,7 +1366,7 @@ class ExportedProgramSerializer(metaclass=Final):
Args:
exported_program: Exported Program to serialize
"""
exported_program._validate()
exported_program.validate()
gm_serializer = GraphModuleSerializer(
exported_program.graph_signature, exported_program.module_call_graph

View File

@ -716,7 +716,7 @@ class ExportedProgram:
assert all(issubclass(v, Verifier) for v in verifiers)
self._verifiers = verifiers
# Validate should be always the last step of the constructor.
self._validate()
self.validate()
@property
@compatibility(is_backward_compatible=False)
@ -1131,6 +1131,11 @@ class ExportedProgram:
input_placeholders, flat_args_with_path, self.range_constraints
)
@compatibility(is_backward_compatible=False)
def validate(self):
self._validate()
# TODO: remove this
@final
def _validate(self):
assert (