[dynamo][testing] Update AOTEagerandRecordGraphs backend (#138231)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138231 Approved by: https://github.com/StrongerXi, https://github.com/mlazos, https://github.com/aakhundov
Committed by: PyTorch MergeBot
Parent: 8a5dd7f59b
Commit: e714ebf664
@@ -1677,7 +1677,7 @@ def forward(self, x_1, output_1):
         if dynamic:
             self.assertExpectedInline(
-                backend.graphs[0].code.strip(),
+                backend.fw_graphs[0].code.strip(),
                 """\
 def forward(self, arg0_1, arg1_1, arg2_1):
     zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False)
@@ -1690,7 +1690,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
             )
         else:
             self.assertExpectedInline(
-                backend.graphs[0].code.strip(),
+                backend.fw_graphs[0].code.strip(),
                 """\
 def forward(self, arg0_1, arg1_1):
     zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False)
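
The two test hunks above switch the expected-graph assertions from backend.graphs[0] (the graph Dynamo captures, still at the torch-op level) to backend.fw_graphs[0] (the forward graph produced by AOT Autograd, lowered to aten ops such as torch.ops.aten.zeros_like.default). A minimal sketch of that distinction, assuming the backend is importable from torch._dynamo.testing and using a made-up toy function rather than the PR's actual test:

import torch
from torch._dynamo.testing import AotEagerAndRecordGraphs

backend = AotEagerAndRecordGraphs()

@torch.compile(backend=backend)
def fn(x):
    # Hypothetical toy function, not taken from the PR's tests.
    return torch.zeros_like(x) + x

fn(torch.randn(3))

print(backend.graphs[0].code)     # Dynamo graph: torch-level ops
print(backend.fw_graphs[0].code)  # AOT forward graph: aten-level ops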
@@ -246,24 +246,34 @@ class EagerAndRecordGraphs:
         return gm.forward
 
 
 # Equivalent to backend="aot_eager", but also records graphs that
 # we can assert on
-class AOTEagerAndRecordGraphs:
+class AotEagerAndRecordGraphs:
     def __init__(self) -> None:
         self.graphs: List[torch.fx.GraphModule] = []
+        self.fw_graphs: List[torch.fx.GraphModule] = []
+        self.bw_graphs: List[torch.fx.GraphModule] = []
 
     def __call__(
         self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
     ) -> Callable[..., Any]:
-        def save_graph(gm: torch.fx.GraphModule, *args: Any, **kwargs: Any) -> Any:
-            self.graphs.append(gm)
-            return gm.forward
+        self.graphs.append(gm)
+
+        def fw_compiler(
+            gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
+        ) -> Callable[..., Any]:
+            self.fw_graphs.append(gm)
+            return gm.forward
+
+        def bw_compiler(
+            gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
+        ) -> Callable[..., Any]:
+            self.bw_graphs.append(gm)
+            return gm.forward
 
         return aot_eager(
             gm,
             example_inputs,
-            fw_compiler=save_graph,
-            bw_compiler=save_graph,
+            fw_compiler=fw_compiler,
+            bw_compiler=bw_compiler,
         )
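
With this change the backend records three kinds of graphs: graphs (what Dynamo traced), fw_graphs (AOT Autograd forward graphs), and bw_graphs (AOT Autograd backward graphs). A usage sketch, assuming the class is imported from torch._dynamo.testing and that a backward pass has run so the backward graph has been compiled:

import torch
from torch._dynamo.testing import AotEagerAndRecordGraphs

backend = AotEagerAndRecordGraphs()

@torch.compile(backend=backend)
def fn(x):
    # Hypothetical example function.
    return torch.sin(x).cos().sum()

x = torch.randn(4, requires_grad=True)
fn(x).backward()

# One Dynamo graph, one AOT forward graph, and (after backward) one AOT backward graph.
print(len(backend.graphs), len(backend.fw_graphs), len(backend.bw_graphs))
print(backend.bw_graphs[0].code)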