[fx] fix split_module with symint (#160093)

Fixes https://github.com/pytorch/pytorch/issues/155220

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160093
Approved by: https://github.com/ezyang
kshitij12345
2025-08-13 05:50:11 +00:00
committed by PyTorch MergeBot
parent 685f15dbea
commit 199e9abb6a
2 changed files with 91 additions and 1 deletion
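As context for the diff below: `split_module` partitions a traced `GraphModule` according to a user-supplied callback that maps each node to a partition id. A minimal sketch of that API (not part of this commit; the module and names here are illustrative), which is the same entry point the new test drives through a `torch.compile` backend:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class Tiny(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2
        return b - 3

gm = symbolic_trace(Tiny())

# Assign the first two compute nodes to partition 0 and the rest to partition 1.
# The callback may be invoked more than once per node, so cache the ids.
partition_of = {}
def split_callback(node) -> int:
    if node not in partition_of:
        partition_of[node] = 0 if len(partition_of) < 2 else 1
    return partition_of[node]

split_gm = split_module(gm, root_m=None, split_callback=split_callback,
                        keep_original_order=True)
print(split_gm.graph)  # calls submod_0 and submod_1 in the original order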


@@ -53,7 +53,7 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, get_new_module_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase, TEST_WITH_CROSSREF
from torch.testing._internal.jit_utils import JitTestCase
import torch.utils._pytree as pytree
@@ -963,6 +963,95 @@ terrible spacing
# `keep_original_order=True`
_test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True))
@unittest.skipIf(TEST_WITH_CROSSREF, "See https://github.com/pytorch/pytorch/issues/160077")
def test_split_module_symint_dependency_handling(self):
# Based on the code from - transformers/models/granitemoe/modeling_granitemoe.py
class GraniteMoeTopKGating(torch.nn.Module):
def __init__(self, input_size: int, num_experts: int, top_k: int):
super().__init__()
self.num_experts = num_experts
self.input_size = input_size
self.top_k = top_k
self.layer = torch.nn.Linear(input_size, num_experts, bias=False)
def forward(self, hidden_states):
# compute the top_k routing decision
logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
# compute the number of inputs given to each expert
zeros = torch.zeros(
[top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
) # [num_tokens, num_experts]
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
expert_size = gates.long().sum(0) # [num_experts,]
expert_size = expert_size.tolist()
# sort and group input tokens according to expert assignment
top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
_, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
# gather the gate values for grouped input tokens
top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
return index_sorted_experts, batch_index, batch_gates, expert_size, logits
class GraniteMoeMoE(torch.nn.Module):
def __init__(self):
super().__init__()
self.input_size = 32
self.num_local_experts = 4
num_experts_per_tok = 2
self.router = GraniteMoeTopKGating(
input_size=self.input_size,
num_experts=self.num_local_experts,
top_k=num_experts_per_tok,
)
def forward(self, layer_input):
_, batch_index, _, expert_size, _ = self.router(layer_input)
expert_inputs = layer_input[batch_index]
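# Splitting by the data-dependent `expert_size` is the SymInt pattern this test exercises.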
return expert_inputs.split(expert_size, dim=0)
moe = GraniteMoeMoE()
inp = torch.randn([32, 32])
expected = moe(inp)
PARTITION_ID = 0
PARTITION_OPS_CTR = 0
NODE_PARTITION_MAP = {}
# `callback` is called multiple times with the same `node` in `split_module`.
# Cache the result so that the partition id stays consistent across calls.
def callback(node) -> int:
nonlocal PARTITION_ID, PARTITION_OPS_CTR, NODE_PARTITION_MAP
if node in NODE_PARTITION_MAP:
return NODE_PARTITION_MAP[node]
if PARTITION_OPS_CTR % 5 == 0:
PARTITION_ID += 1
PARTITION_OPS_CTR += 1
NODE_PARTITION_MAP[node] = PARTITION_ID
return PARTITION_ID
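# Custom torch.compile backend: split the captured FX graph into submodules
# and return the split GraphModule to run in place of the original.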
def backend(gm, inps):
split_gm = split_module(gm, root_m=None, split_callback=callback,
keep_original_order=True, keep_original_node_name=True)
return split_gm
actual = torch.compile(moe, backend=backend)(inp)
torch.testing.assert_close(actual, expected)
def test_normalize_binary_operators(self):
ops_to_test = {
torch.add,


@@ -248,6 +248,7 @@ def split_module(
s_def_partition = partitions[s_defined]
s_def_partition.outputs.setdefault(s_node.name)
s_def_partition.dependents.setdefault(used)
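# Fix: record that the partition using the SymInt depends on the partition
# defining it (previously only the reverse `dependents` edge was set).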
use_partition.dependencies.setdefault(s_defined)
if defined is not None:
use_partition.dependencies.setdefault(defined)