mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx] split_module subgraph should always have an output node (#139275)
Fixes https://github.com/pytorch/pytorch/issues/138207 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139275 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
e3e3ab805b
commit
0cf4cc3d5f
@ -18,6 +18,7 @@ import torch.fx.experimental.optimization as optimization
|
||||
from torch.fx._symbolic_trace import symbolic_trace
|
||||
from torch.fx.experimental import merge_matmul
|
||||
from torch.fx.experimental.accelerator_partitioner import Partitioner
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
|
||||
from torch.fx.experimental.partitioner_utils import (
|
||||
Device,
|
||||
@ -819,6 +820,23 @@ terrible spacing
|
||||
split(x), traced(x)
|
||||
)
|
||||
|
||||
def test_split_module_return_node(self):
|
||||
def foo(x):
|
||||
x.add_(1)
|
||||
|
||||
gm = make_fx(foo, tracing_mode="fake")(torch.randn(3,))
|
||||
|
||||
def cb(_):
|
||||
return 1
|
||||
|
||||
sp_gm = split_module(gm, None, cb)
|
||||
submod_gm = sp_gm.submod_1
|
||||
for node in submod_gm.graph.nodes:
|
||||
if node.op == "output":
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Expected the subgraph to have an output node.")
|
||||
|
||||
|
||||
def test_split_module_kwargs_expansion(self):
|
||||
class ModuleWithKwargsExpansion(torch.nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user