Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[fx] fix split_module with symint (#160093)
Fixes https://github.com/pytorch/pytorch/issues/155220

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160093
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 685f15dbea
Commit: 199e9abb6a
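For readers unfamiliar with the API this commit touches, here is a minimal usage sketch (not part of the commit) of torch.fx.passes.split_module.split_module driven by a split_callback with keep_original_order=True. The toy module and the partitioning rule are invented for illustration; the SymInt case itself is exercised by the regression test in the diff below.

# Illustrative only: the module and partitioning rule below are not from this PR.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x)).sum(dim=-1)


gm = symbolic_trace(Small())

# split_module may query the same node more than once, so cache the
# assignment to keep partition ids stable (the new test uses the same trick).
node_to_partition = {}


def split_callback(node) -> int:
    if node not in node_to_partition:
        node_to_partition[node] = len(node_to_partition) // 2
    return node_to_partition[node]


split_gm = split_module(
    gm, root_m=None, split_callback=split_callback, keep_original_order=True
)
print(split_gm.graph)  # submod_* calls appear in the original node order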
@@ -53,7 +53,7 @@ from torch.testing._internal.common_device_type import (
 )
 from torch.testing._internal.common_methods_invocations import op_db
 from torch.testing._internal.common_nn import module_tests, get_new_module_tests
-from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
+from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase, TEST_WITH_CROSSREF
 from torch.testing._internal.jit_utils import JitTestCase
 import torch.utils._pytree as pytree
@@ -963,6 +963,95 @@ terrible spacing
         # `keep_original_order=True`
         _test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True))
 
+    @unittest.skipIf(TEST_WITH_CROSSREF, "See https://github.com/pytorch/pytorch/issues/160077")
+    def test_split_module_symint_dependency_handling(self):
+        # Based on the code from transformers/models/granitemoe/modeling_granitemoe.py
+        class GraniteMoeTopKGating(torch.nn.Module):
+            def __init__(self, input_size: int, num_experts: int, top_k: int):
+                super().__init__()
+
+                self.num_experts = num_experts
+                self.input_size = input_size
+                self.top_k = top_k
+
+                self.layer = torch.nn.Linear(input_size, num_experts, bias=False)
+
+            def forward(self, hidden_states):
+                # compute the top_k routing decision
+                logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
+                top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
+                top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
+
+                # compute the number of inputs given to each expert
+                zeros = torch.zeros(
+                    [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
+                )  # [num_tokens, num_experts]
+                gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
+                expert_size = gates.long().sum(0)  # [num_experts,]
+                expert_size = expert_size.tolist()
+
+                # sort and group input tokens according to expert assignment
+                top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
+                _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
+                batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]
+
+                # gather the gate values for grouped input tokens
+                top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
+                batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]
+
+                return index_sorted_experts, batch_index, batch_gates, expert_size, logits
+
+        class GraniteMoeMoE(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.input_size = 32
+                self.num_local_experts = 4
+
+                num_experts_per_tok = 2
+                self.router = GraniteMoeTopKGating(
+                    input_size=self.input_size,
+                    num_experts=self.num_local_experts,
+                    top_k=num_experts_per_tok,
+                )
+
+            def forward(self, layer_input):
+                _, batch_index, _, expert_size, _ = self.router(layer_input)
+                expert_inputs = layer_input[batch_index]
+                return expert_inputs.split(expert_size, dim=0)
+
+        moe = GraniteMoeMoE()
+        inp = torch.randn([32, 32])
+
+        expected = moe(inp)
+
+        PARTITION_ID = 0
+        PARTITION_OPS_CTR = 0
+        NODE_PARTITION_MAP = {}
+
+        # `callback` is called multiple times with the same `node` in `split_module`.
+        # Cache the result so that the partition id is consistent across calls.
+        def callback(node) -> int:
+            nonlocal PARTITION_ID, PARTITION_OPS_CTR, NODE_PARTITION_MAP
+            if node in NODE_PARTITION_MAP:
+                return NODE_PARTITION_MAP[node]
+
+            if PARTITION_OPS_CTR % 5 == 0:
+                PARTITION_ID += 1
+
+            PARTITION_OPS_CTR += 1
+
+            NODE_PARTITION_MAP[node] = PARTITION_ID
+            return PARTITION_ID
+
+        def backend(gm, inps):
+            split_gm = split_module(gm, root_m=None, split_callback=callback,
+                                    keep_original_order=True, keep_original_node_name=True)
+            return split_gm
+
+        actual = torch.compile(moe, backend=backend)(inp)
+        torch.testing.assert_close(actual, expected)
+
     def test_normalize_binary_operators(self):
         ops_to_test = {
             torch.add,
@@ -248,6 +248,7 @@ def split_module(
                 s_def_partition = partitions[s_defined]
                 s_def_partition.outputs.setdefault(s_node.name)
                 s_def_partition.dependents.setdefault(used)
+                use_partition.dependencies.setdefault(s_defined)
         if defined is not None:
             use_partition.dependencies.setdefault(defined)
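For context (not part of the diff): the outputs, dependents, and dependencies fields on a Partition are dicts used as ordered sets (hence setdefault with a single argument), and the added line records that a partition consuming a SymInt depends on the partition that defines it, mirroring what is already done for ordinary values just below. A small, hypothetical inspection sketch, assuming a split_gm produced as in the sketch above or in the new test:

# Hypothetical check: each submodule call in the parent graph should only
# consume values (including SymInts) produced by earlier nodes.
for node in split_gm.graph.nodes:
    if node.op == "call_module":
        print(node.target, "<-", [inp.name for inp in node.all_input_nodes])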