mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add flag to fx.passes.split_module to normalize input names (#157733)
This is useful for vLLM, which runs AOTAutograd directly on graphs after they have been split. I created a new flag for this instead of reusing `keep_original_node_name` (please let me know if you think I should reuse this). The reasoning is: - The names of the placeholder nodes is different from the targets of the placehoder nodes. The targets are the actual input names. - Backwards compatibility: this API has been out for ~4 years, it looks public, and it has extensive public use. For example, this change would actually be BC-breaking to vLLM (they rely on the subgraph input names being different at the moment). Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/157733 Approved by: https://github.com/ezyang
This commit is contained in:
@ -791,6 +791,46 @@ terrible spacing
|
||||
|
||||
self.assertEqual(orig_out, submodules_out)
|
||||
|
||||
def test_split_module_input_names(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x, a0, a1, b0, b1, c0, c1):
|
||||
x = x + (a0 ** 2) + (a1 / 2)
|
||||
x = x + (b0 ** 2) + (b1 / 2)
|
||||
x = x + (c0 ** 2) + (c1 / 2)
|
||||
return x
|
||||
|
||||
mod = Mod()
|
||||
traced = torch.fx.symbolic_trace(mod)
|
||||
|
||||
seen = 0
|
||||
|
||||
def split(n):
|
||||
nonlocal seen
|
||||
result = seen // 4
|
||||
seen += 1
|
||||
return result
|
||||
|
||||
split = split_module(traced, mod, split, keep_original_input_name=False)
|
||||
|
||||
# All the submodules should take in the inputs in the same order.
|
||||
args = [torch.tensor(2.), torch.tensor(3.), torch.tensor(4.)]
|
||||
output0 = split.submod_0(*args)
|
||||
output1 = split.submod_1(*args)
|
||||
output2 = split.submod_2(*args)
|
||||
self.assertEqual(output0, output1)
|
||||
self.assertEqual(output1, output2)
|
||||
|
||||
# Each submodule should have normalized input names
|
||||
def check_ph(gm):
|
||||
nodes = list(gm.graph.nodes)
|
||||
self.assertEqual(nodes[0].target, "arg_0")
|
||||
self.assertEqual(nodes[1].target, "arg_1")
|
||||
self.assertEqual(nodes[2].target, "arg_2")
|
||||
|
||||
check_ph(split.submod_0)
|
||||
check_ph(split.submod_1)
|
||||
check_ph(split.submod_2)
|
||||
|
||||
def test_split_module_dead_code(self):
|
||||
class ModWithDeadCode(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
Reference in New Issue
Block a user