Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[BugFix] fix graph partition signature (#27139)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
@@ -90,6 +90,156 @@ def memory_plan_reuse_patched(self):
 assert len(planning_states) == 0
 
+
+# ===================================================
+# torch 2.9 Inductor get_graph_partition_signature monkeypatch
+# ===================================================
+# This change monkeypatches get_graph_partition_signature in pytorch 2.9.0 to
+# fix inductor partition + attention-nvfp4 quant fusion, tested in
+# `tests/compile/test_fusions_e2e.py::test_attn_quant`.
+# For more context, see https://github.com/pytorch/pytorch/pull/165815.
+
+
+def get_graph_partition_signature_patched(
+    self, partitions, skip_cudagraphs: list[bool]
+):
+    """
+    Gets the signature for each graph partition: its input nodes, its output
+    nodes, and whether each input is deallocated within the partition.
+    """
+    from torch._inductor import dependencies
+    from torch._inductor.ir import GraphPartitionSignature, MutationOutput, NoneLayout
+    from torch._inductor.virtualized import V
+    from torch.utils._ordered_set import OrderedSet
+
+    signatures = []
+
+    unmet_output_names = OrderedSet(V.graph.get_output_names())
+    name_to_node = self.get_name_to_nodes()
+
+    def is_none_layout(buf_name: str) -> bool:
+        """
+        Checks whether buf_name has NoneLayout. Buffers with NoneLayout are
+        not allocated, so a graph partition should not take them as inputs
+        or outputs.
+        """
+        buf = self.name_to_buf.get(buf_name, None)
+
+        if buf is None:
+            return False
+
+        if isinstance(buf.node.layout, NoneLayout):
+            if isinstance(buf.node, MutationOutput) and (
+                real_name := self.mutation_real_name.get(buf_name, None)
+            ):
+                return is_none_layout(real_name)
+
+            return True
+
+        return False
+
+    for partition, skip_cudagraph in zip(
+        reversed(partitions), reversed(skip_cudagraphs)
+    ):
+        output_names: OrderedSet[str] = OrderedSet()
+
+        for node in partition:
+            output_names.update(node.outputs_by_name.keys())
+
+        returned_output_names = output_names.intersection(unmet_output_names)
+
+        # All reads/writes are partition inputs, except those generated
+        # within the partition and tensor constants.
+        read_writes = dependencies.ReadWrites.merge_list(
+            [node.read_writes for node in partition]
+        )
+
+        # A WeakDep is a fake dependency on an unused buffer. It should not
+        # appear in partition_input_names for inputs that are actually read
+        # or written.
+        partition_input_names = (
+            OrderedSet(
+                [
+                    x.name
+                    for x in read_writes.reads | read_writes.writes
+                    if not is_none_layout(x.name)
+                ]
+            )
+            - output_names
+        )
+
+        partition_input_names = OrderedSet(
+            self.mutation_real_name.get(name, name) for name in partition_input_names
+        )
+
+        buffer_names_to_free: OrderedSet[str] = OrderedSet()
+        for node in partition:
+            buffer_names_to_free.update(node.last_usage)
+
+        # buffer_names_to_free may contain buffers allocated in previous
+        # graph partitions. These buffers should also be partition inputs.
+        extra_input_names = [
+            name
+            for name in (buffer_names_to_free - output_names)
+            if name in name_to_node
+        ]
+        partition_input_names.update(extra_input_names)
+
+        input_nodes = {
+            name: name_to_node[name]
+            for name in partition_input_names
+            if name in name_to_node
+        }
+        input_deallocation = {
+            name: name in buffer_names_to_free
+            for name in partition_input_names
+            if name in name_to_node
+        }
+
+        # If an input tensor is not freed in the partition function, it should
+        # also be returned as an output. This benefits cudagraph, since the
+        # returned output tensor is a cudagraph-managed tensor with a static
+        # tensor address.
+        extra_output_names = [
+            name
+            for name in partition_input_names
+            if name in name_to_node and name not in buffer_names_to_free
+        ]
+
+        returned_output_names.update(extra_output_names)
+
+        returned_output_names = OrderedSet(
+            self.mutation_real_name.get(name, name) for name in returned_output_names
+        )
+
+        output_nodes = [
+            name_to_node[name]
+            for name in returned_output_names
+            if not is_none_layout(name)
+        ]
+
+        constant_names = [
+            name for name in partition_input_names if name in V.graph.constants
+        ]
+
+        symbol_inputs = self.get_graph_partition_symbol_inputs(partition, input_nodes)
+
+        partition_signature = GraphPartitionSignature(
+            symbol_inputs,
+            input_nodes,
+            output_nodes,
+            input_deallocation,
+            skip_cudagraph,
+            constant_names,
+        )
+
+        signatures.append(partition_signature)
+
+        unmet_output_names = partition_input_names.union(
+            unmet_output_names - returned_output_names
+        )
+
+    return signatures[::-1]
+
+
 # ========================================
 # torch 2.9 Inductor Scheduler monkeypatch
 # ========================================
@@ -196,6 +346,7 @@ def _update_scheduler_patched(self) -> None:
     from torch._inductor.scheduler import Scheduler
 
     Scheduler.should_partition = should_partition_patched
+    Scheduler.get_graph_partition_signature = get_graph_partition_signature_patched
 
     with config.patch("triton.store_cubin", False):
         self.scheduler = Scheduler(self.operations)
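
To see why the signatures are built back-to-front, here is a toy sketch of the unmet_output_names bookkeeping (hypothetical buffer names, plain Python sets instead of Inductor's OrderedSet, and bare input/output name pairs instead of GraphPartitionSignature):

# A buffer produced by partition 0 and read only by partition 1 (buf1 below)
# must still be returned by partition 0, which is only known after the later
# partition's reads have been scanned; hence the reversed() traversal.
graph_outputs = {"buf2"}
partitions = [
    {"produces": {"buf0", "buf1"}, "reads": {"arg0"}},  # partition 0
    {"produces": {"buf2"}, "reads": {"buf1"}},          # partition 1
]

unmet_output_names = set(graph_outputs)
signatures = []
for part in reversed(partitions):
    returned = part["produces"] & unmet_output_names
    inputs = part["reads"] - part["produces"]
    signatures.append((inputs, returned))
    # Anything this partition reads must be produced (or forwarded) earlier.
    unmet_output_names = inputs | (unmet_output_names - returned)

for inputs, outputs in reversed(signatures):
    print(sorted(inputs), "->", sorted(outputs))
# prints: ['arg0'] -> ['buf1']
#         ['buf1'] -> ['buf2']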
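
As applied above, `_update_scheduler_patched` installs the override unconditionally. For illustration only, a minimal sketch of gating such a monkeypatch on the torch release it targets (the helper name and version check below are hypothetical, not vLLM's actual guard):

import torch

def maybe_patch_graph_partition_signature() -> None:
    # Only torch 2.9.x needs the patched signature computation; the fix is
    # expected upstream via https://github.com/pytorch/pytorch/pull/165815.
    if not torch.__version__.startswith("2.9"):
        return
    from torch._inductor.scheduler import Scheduler
    Scheduler.get_graph_partition_signature = get_graph_partition_signature_patched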