[dynamo][testing] Update AOTEagerandRecordGraphs backend (#138231)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138231
Approved by: https://github.com/StrongerXi, https://github.com/mlazos, https://github.com/aakhundov
This commit is contained in:
Animesh Jain
2024-10-17 16:31:31 -07:00
committed by PyTorch MergeBot
parent 8a5dd7f59b
commit e714ebf664
2 changed files with 19 additions and 9 deletions

View File

@ -1677,7 +1677,7 @@ def forward(self, x_1, output_1):
if dynamic:
self.assertExpectedInline(
backend.graphs[0].code.strip(),
backend.fw_graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False)
@ -1690,7 +1690,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
)
else:
self.assertExpectedInline(
backend.graphs[0].code.strip(),
backend.fw_graphs[0].code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False)

View File

@ -246,24 +246,34 @@ class EagerAndRecordGraphs:
return gm.forward
# Equivalent to backend="aot_eager", but also records graphs that
# we can assert on
class AOTEagerAndRecordGraphs:
class AotEagerAndRecordGraphs:
def __init__(self) -> None:
self.graphs: List[torch.fx.GraphModule] = []
self.fw_graphs: List[torch.fx.GraphModule] = []
self.bw_graphs: List[torch.fx.GraphModule] = []
def __call__(
self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable[..., Any]:
def save_graph(gm: torch.fx.GraphModule, *args: Any, **kwargs: Any) -> Any:
self.graphs.append(gm)
self.graphs.append(gm)
def fw_compiler(
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable[..., Any]:
self.fw_graphs.append(gm)
return gm.forward
def bw_compiler(
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable[..., Any]:
self.bw_graphs.append(gm)
return gm.forward
return aot_eager(
gm,
example_inputs,
fw_compiler=save_graph,
bw_compiler=save_graph,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
)