mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reviewed By: avikchaudhuri Differential Revision: D67530154 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143688 Approved by: https://github.com/tugsbayasgalan
38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
from unittest import mock
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._dynamo.test_case
|
|
from torch._inductor.utils import pass_execution_and_save
|
|
|
|
|
|
class FxPassesPreGradTests(torch._dynamo.test_case.TestCase):
|
|
@mock.patch("torch._inductor.utils.ShapeProp.propagate")
|
|
def test_pass_execution_and_save(self, mock_shape_prop):
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.ones(4, 4))
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.param + x
|
|
|
|
def fx_pass(graph: torch.fx.GraphModule) -> None:
|
|
return
|
|
|
|
sample_input = torch.randn(4, 4)
|
|
m = TestModule()
|
|
m(sample_input)
|
|
exported_program = torch.export.export(m, (sample_input,), strict=True)
|
|
gm = exported_program.graph_module
|
|
|
|
pass_execution_and_save(fx_pass, gm, sample_input, "Apply testing pass")
|
|
mock_shape_prop.assert_called_once()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
run_tests()
|