Compare commits

...

13 Commits

7 changed files with 449 additions and 15 deletions

View File

@@ -8,6 +8,21 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
    def import_get_raw_stream_as(self, name):
        return f"from torch._C import _cuda_getCurrentRawStream as {name}"

    def generate_stream_creation(self, stream_pool, default_stream_id):
        stream_creation_str = ""
        for index, num_used in enumerate(stream_pool):
            if num_used > 0 and index != default_stream_id:
                stream_creation_str += (
                    f"stream{index}_raw = torch.cuda.Stream()\n"
                )
                stream_creation_str += (
                    f"stream{index} = stream{index}_raw.cuda_stream\n"
                )
        # generate for default stream
        default_stream_str = (
            f"streamdata = torch._C._cuda_getCurrentStream({default_stream_id})\n"
        )
        default_stream_str += (
            f"stream{default_stream_id}_raw = torch.cuda.Stream("
            f"stream_id=streamdata[0], device_index=streamdata[1], "
            f"device_type=streamdata[2])\n"
        )
        return stream_creation_str + default_stream_str

    def set_device(self, device_idx):
        return f"torch.cuda.set_device({device_idx})"

View File

@@ -56,7 +56,6 @@ from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta

if TYPE_CHECKING:
    import triton
@@ -360,6 +359,14 @@ class MemoryPlanningLine(WrapperLine):
class AllocateLine(MemoryPlanningLine):
    node: ir.Buffer

    def set_user_stream(self, stream_id):
        self.user_streams = [
            stream_id,
        ]

    def add_user_stream(self, stream_id):
        self.user_streams.append(stream_id)

    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine(self.wrapper)
@@ -383,7 +390,42 @@ class AllocateLine(MemoryPlanningLine):
    def codegen(self, code: IndentedBuffer) -> None:
        assert self.node.get_name() not in V.graph.removed_buffers
        line = self.wrapper.make_buffer_allocation(self.node)
        # if self has an attribute named user_streams, the buffer is used by multiple streams
        if hasattr(self, "user_streams"):
            DEFAULT_STREAM_ID = V.graph.stream_graph.DEFAULT_STREAM_ID
            # redundant stream ids were already removed when user_streams was set
            if len(self.user_streams) == 1:
                if self.user_streams[0] != DEFAULT_STREAM_ID:
                    code.writeline(
                        f"torch.cuda.set_stream(stream{self.user_streams[0]}_raw)"
                    )
                    code.writeline(line)
                    code.writeline(
                        f"torch.cuda.set_stream(stream{DEFAULT_STREAM_ID}_raw)"
                    )
                else:
                    code.writeline(line)
            elif len(self.user_streams) > 1:
                # assign the `empty_strided` to the first user stream
                assign_stream = self.user_streams[0]
                event_name = f"event_allocate_{self.node.get_name()}"
                code.writeline(f"{event_name} = torch.cuda.Event()")
                if assign_stream != DEFAULT_STREAM_ID:
                    code.writeline(f"torch.cuda.set_stream(stream{assign_stream}_raw)")
                code.writeline(line)
                if assign_stream != DEFAULT_STREAM_ID:
                    code.writeline(
                        f"torch.cuda.set_stream(stream{DEFAULT_STREAM_ID}_raw)"
                    )
                code.writeline(f"{event_name}.record(stream{assign_stream}_raw)")
                for user_stream in self.user_streams[1:]:
                    code.writeline(f"stream{user_stream}_raw.wait_event({event_name})")
            else:
                raise AssertionError(f"invalid user_streams: {self.user_streams}")
        else:
            code.writeline(line)
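
Assuming a buffer buf0 whose first user runs on stream 1 and which is also read from stream 2 (stream 0 being the default), the multi-stream branch above would emit code along these lines; the buffer shape and the stream variables here are illustrative stand-ins, while the event/stream calls are standard torch.cuda API:

import torch

stream0_raw = torch.cuda.Stream()  # stands in for the wrapped default stream
stream1_raw = torch.cuda.Stream()
stream2_raw = torch.cuda.Stream()

event_allocate_buf0 = torch.cuda.Event()
torch.cuda.set_stream(stream1_raw)                          # switch to the first user stream
buf0 = torch.empty_strided((4, 4), (4, 1), device="cuda")   # the allocation line
torch.cuda.set_stream(stream0_raw)                          # restore the default stream
event_allocate_buf0.record(stream1_raw)                     # publish the allocation
stream2_raw.wait_event(event_allocate_buf0)                 # other user streams wait before use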
@dataclasses.dataclass
@@ -620,6 +662,10 @@ class PythonWrapperCodegen(CodeGen):
        self.kernel_autotune_calls.writeline(
            V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
        )
        if config.multiple_streams:
            self.header.writeline(
                V.graph.device_ops.generate_stream_creation(
                    V.graph.stream_graph.stream_pool,
                    V.graph.stream_graph.DEFAULT_STREAM_ID,
                )
            )

    def add_meta_once(self, meta: TritonMetaParams) -> str:
        meta = repr(meta)
@@ -704,7 +750,7 @@ class PythonWrapperCodegen(CodeGen):
    def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
        self.write_get_raw_stream_header_once()
        name = f"stream{device_idx}"
        self.header.writeline(f"{name} = get_raw_stream({device_idx})")
        return name

    def get_codegened_graph(self):
@@ -762,7 +808,86 @@ class PythonWrapperCodegen(CodeGen):
    def generate_end(self, result: IndentedBuffer) -> None:
        return

    def cuda_event_dependency(self, node_name, kernel_IndentedBuffer):
        """
        TODO(Yueming): would it be better to move the dependency detection to the
        stream scheduler?
        """
        ssnode = V.graph.stream_graph.buf_to_ssnode[node_name]

        def update_event_dependency(tmp_ssnode):
            dependent_buffers = set()
            for predecessor in tmp_ssnode.predecessors:
                if predecessor.is_nop_node:
                    for prepredecessor in predecessor.predecessors:
                        if (
                            prepredecessor.stream_id != tmp_ssnode.stream_id
                            and not prepredecessor.is_nop_node
                        ):
                            dependent_buffers.add(prepredecessor)
                else:
                    if (
                        predecessor.stream_id != tmp_ssnode.stream_id
                        and not predecessor.is_nop_node
                    ):
                        dependent_buffers.add(predecessor)
            for buffer in dependent_buffers:
                kernel_IndentedBuffer.writeline(
                    f"stream{tmp_ssnode.stream_id}_raw.wait_event(event_{buffer.get_name()})"
                )

        update_event_dependency(ssnode)
        for predecessor in ssnode.predecessors:
            if predecessor.is_nop_node:
                update_event_dependency(predecessor)

    def cuda_event_create(self, node_name, kernel_IndentedBuffer):
        ssnode = V.graph.stream_graph.buf_to_ssnode[node_name]
        kernel_IndentedBuffer.writeline(
            f"event_{ssnode.get_name()} = torch.cuda.Event()"
        )

    def cuda_event_record(self, node_name, kernel_IndentedBuffer):
        ssnode = V.graph.stream_graph.buf_to_ssnode[node_name]
        kernel_IndentedBuffer.writeline(
            f"event_{ssnode.get_name()}.record(stream{ssnode.stream_id}_raw)"
        )

    def generate_kernel_w_stream(self, node_name, call_strs, stream_switch=True):
        """
        Args:
            node_name: name of the caller buffer
            call_strs: list of strings or a single string to write
            stream_switch: if True, wrap the call with set_stream switches when the
                node does not run on the default stream
        """
        kernel_IndentedBuffer = IndentedBuffer()
        self.cuda_event_dependency(node_name, kernel_IndentedBuffer)
        ssnode = V.graph.stream_graph.buf_to_ssnode[node_name]
        if ssnode.cuda_event:
            self.cuda_event_create(node_name, kernel_IndentedBuffer)
        stream_id = ssnode.stream_id
        DEFAULT_STREAM_ID = V.graph.stream_graph.DEFAULT_STREAM_ID
        if stream_id != DEFAULT_STREAM_ID:
            if stream_switch:
                kernel_IndentedBuffer.writeline(
                    f"torch.cuda.set_stream(stream{stream_id}_raw)"
                )
            if isinstance(call_strs, list):
                kernel_IndentedBuffer.writelines(call_strs)
            else:
                kernel_IndentedBuffer.writeline(call_strs)
            if stream_switch:
                kernel_IndentedBuffer.writeline(
                    f"torch.cuda.set_stream(stream{DEFAULT_STREAM_ID}_raw)"
                )
        else:
            if isinstance(call_strs, list):
                kernel_IndentedBuffer.writelines(call_strs)
            else:
                kernel_IndentedBuffer.writeline(call_strs)
        if ssnode.cuda_event:
            self.cuda_event_record(node_name, kernel_IndentedBuffer)
        for line in [_ for _ in kernel_IndentedBuffer.getrawvalue().split("\n") if _]:
            self.writeline(line)
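
Putting the pieces together, a hedged sketch of what generate_kernel_w_stream writes for a node assigned to stream 2 whose cross-stream predecessor buf0 recorded event_buf0, and which itself needs an event; some_kernel and the buffer contents are hypothetical, the stream/event calls are real torch.cuda API:

import torch

stream0_raw = torch.cuda.Stream()   # stands in for the wrapped default stream
stream1_raw = torch.cuda.Stream()
stream2_raw = torch.cuda.Stream()

def some_kernel(x):                 # hypothetical kernel body
    return x * 2

buf0 = torch.ones(4, device="cuda")
event_buf0 = torch.cuda.Event()
event_buf0.record(stream1_raw)       # pretend buf0 was produced on stream 1

stream2_raw.wait_event(event_buf0)   # cuda_event_dependency
event_buf1 = torch.cuda.Event()      # cuda_event_create
torch.cuda.set_stream(stream2_raw)   # switch away from the default stream
buf1 = some_kernel(buf0)             # the call_strs body
torch.cuda.set_stream(stream0_raw)   # switch back
event_buf1.record(stream2_raw)       # cuda_event_record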
    def generate_fallback_kernel(self, fallback_kernel, args):
        self.generate_extern_kernel_alloc(fallback_kernel, args)
@@ -780,20 +905,21 @@ class PythonWrapperCodegen(CodeGen):
            ending = f".clone(){ending}"

        if no_return:
            self.writeline(
                f"{self.declare}{kernel_name}({', '.join(args)}){ending}",
                node_name=output_name,
            )
        else:
            call_strs = [
                f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}"
            ]
            if (
                self.supports_intermediate_hooks
                and config.generate_intermediate_hooks
                and origin_node is not None
            ):
                counters["inductor"]["intermediate_hooks"] += 1
                call_strs.append(
                    f"run_intermediate_hooks({origin_node.name!r}, {output_name})"
                )
            self.writelines(call_strs, node_name=output_name)

    def generate_extern_kernel_out(
        self, kernel: str, out: str, out_view: Optional[str], args: List[str]
@@ -877,6 +1003,7 @@ class PythonWrapperCodegen(CodeGen):
    def _generate(self, is_inference):
        if config.profile_bandwidth:
            self.write_triton_header_once()

        result = IndentedBuffer()
        result.splice(self.imports)
        result.writeline("")
@@ -1787,12 +1914,31 @@ class PythonWrapperCodegen(CodeGen):
        )
        self.kernel_autotune_names.add(kernel_name)

    def writeline(self, line, caller=None, node_name=None):
        if config.multiple_streams:
            if caller is not None:
                assert isinstance(caller, ExternKernel)
                node_name = caller.name
            if node_name is not None:
                self.generate_kernel_w_stream(node_name, line)
            else:
                self.lines.append(line)
        else:
            self.lines.append(line)

    def writelines(self, lines, caller=None, node_name=None):
        if config.multiple_streams:
            if caller is not None:
                assert isinstance(caller, ExternKernel)
                node_name = caller.name
            if node_name is not None:
                self.generate_kernel_w_stream(node_name, lines)
            else:
                for line in lines:
                    self.writeline(line)
        else:
            for line in lines:
                self.writeline(line)
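
A condensed, self-contained model of the routing the two methods above implement (the class below is a stand-in for illustration, not the real wrapper): lines tagged with a node name are diverted through the stream-aware path, everything else is appended verbatim.

class MiniWrapper:
    def __init__(self, multiple_streams: bool):
        self.multiple_streams = multiple_streams
        self.lines = []

    def generate_kernel_w_stream(self, node_name, line):
        # stand-in for the real method: wrap the line with stream switches
        self.lines += [
            f"torch.cuda.set_stream(stream_of_{node_name})",
            line,
            "torch.cuda.set_stream(default_stream)",
        ]

    def writeline(self, line, node_name=None):
        if self.multiple_streams and node_name is not None:
            self.generate_kernel_w_stream(node_name, line)
        else:
            self.lines.append(line)

w = MiniWrapper(multiple_streams=True)
w.writeline("buf0 = kernel(arg0)", node_name="buf0")
w.writeline("del arg0")  # untagged lines pass through unchanged
print("\n".join(w.lines))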
    def enter_context(self, ctx):
        self.lines.append(LineContext(ctx))
@@ -1915,7 +2061,20 @@ class PythonWrapperCodegen(CodeGen):
            self.codegen_deferred_allocation(name, layout)
            return

        new_allocateline = AllocateLine(self, buffer)
        # when the node is a nop node, collect user streams from its predecessors
        if config.multiple_streams:
            ssnode = V.graph.stream_graph.buf_to_ssnode[buffer.get_name()]
            new_allocateline.set_user_stream(ssnode.stream_id)
            user_streams = set()
            if ssnode.is_nop_node:
                for predecessor in ssnode.predecessors:
                    if predecessor.stream_id != ssnode.stream_id:
                        user_streams.add(predecessor.stream_id)
            for user_stream in user_streams:
                new_allocateline.add_user_stream(user_stream)
        self.writeline(new_allocateline)

    def codegen_free(self, buffer):
        name = buffer.get_name()

View File

@@ -506,6 +506,12 @@ optimize_scatter_upon_const_tensor = (
    os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
)

# Enable multiple streams
multiple_streams = os.environ.get("TORCHINDUCTOR_MULTIPLE_STREAMS", "0") == "1"

# We can set different modes for multiple streams.
# "all": schedule kernels to run on different streams as much as possible.
# "distributed": only schedule communication kernels to run on different streams.
multiple_streams_mode = os.environ.get("TORCHINDUCTOR_MULTIPLE_STREAMS_MODE", "all")
# The multiprocessing start method to use for inductor workers in the codecache.
# Can be "subprocess" or "fork".
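
A minimal way to exercise the new flags (assumed usage; since the config module reads the environment at import time, the variables are set before torch is imported):

import os

os.environ["TORCHINDUCTOR_MULTIPLE_STREAMS"] = "1"
os.environ["TORCHINDUCTOR_MULTIPLE_STREAMS_MODE"] = "distributed"

import torch

@torch.compile
def f(x):
    return (x.relu() + 1).sum()

print(f(torch.randn(64, device="cuda")))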

View File

@@ -418,7 +418,8 @@ class GraphLowering(torch.fx.Interpreter):
        self.creation_time = time.time()
        self.name = name  # type: ignore[assignment]
        self.cpp_wrapper = cpp_wrapper
        self.stream_graph = None

        # record multi_kernel choice for cpp_wrapper so the second pass knows
        # which sub-kernel is picked. Copy cpp_wrapper to another variable
        # since the cpp_wrapper flag is set to False for the first pass of codegen.

View File

@@ -42,6 +42,7 @@ from torch.utils._sympy.symbol import free_symbol_is_type, SymT
from torch.utils._triton import has_triton
from . import comms, config, dependencies, ir, metrics
from .codecache import write_text
from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime
@@ -1838,6 +1839,9 @@ class Scheduler:
        if config.combo_kernels:
            self.create_combo_kernel_nodes(num_ck_nodes=None)
        self.process_grouped_nodes()
        if config.multiple_streams:
            from .stream_scheduler import stream_schedule

            stream_schedule(self.nodes)
        self.compute_last_usage()
        V.debug.ir_post_fusion(self.nodes)
        V.debug.graph_diagram(self.nodes)
@@ -2132,6 +2136,7 @@ class Scheduler:
        for node in self.nodes:
            for buf in node.get_outputs():
                buf.set_users(name_to_users[buf.get_name()].items)
        self.name_to_users = name_to_users

    def dead_node_elimination(self) -> None:
        """

View File

@@ -0,0 +1,245 @@
import logging
import sys
from os import environ
from typing import Dict, Union

import torch
from torch.utils._ordered_set import OrderedSet

from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
)
from .virtualized import V

log = logging.getLogger(__name__)

# The default limit is 1000, but the call depth for dig_node can exceed it. TODO: FIXME
sys.setrecursionlimit(5000)

UNASSIGNED_STREAM_ID = -1
def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
    if isinstance(snode, FusedSchedulerNode):
        return any(in_output(x) for x in snode.snodes)
    return any(
        isinstance(user.node, OutputNode)
        for buf in snode.get_outputs()
        for user in buf.users
    )
class SSNode:
    """
    A Stream Scheduler Node is a wrapper around the original node. It carries the
    information of the original node plus the state needed for stream scheduling.

    Attributes:
        node_id: the index of the node in the graph.
        successors: {buf_name: SSNode}. Successors of the node; buf_name is the name
            of the original node (scheduler node or fused node).
        predecessors: {buf_name: SSNode}. Predecessors of the node; buf_name is the
            name of the original node (scheduler node or fused node).
        first_predecessor: SSNode. A node should have the same stream id as its first
            predecessor.
        fake_successors, fake_predecessors: {buf_name: SSNode}.
        name: the name of the original node.
        original_user_names: the names of the original users of the node. If the
            original node is a fused node, we check the users of the scheduler nodes
            included in it.
        stream_id: the stream id of the node. -1 means not assigned.
        snode_names: the names of the scheduler nodes included in this fused node.
        cuda_event: whether this node needs to generate a CUDA event. CUDA events
            keep the data dependencies between streams.
        is_fused: whether this node is a fused node.
        is_nop_node: whether this node is a nop node.
        node_type: the type of the node. One of "template", "extern", "foreach",
            "fused_or_schedule", "nop".
        device: None by default. If it is a CPU node, we skip it.
        skip_cpu_nodes: it is meaningless to add stream switches for CPU nodes.
    """

    def __init__(self, original_node, skip_cpu_nodes=True) -> None:
        assert original_node is not None
        self.successors = OrderedSet()
        self.predecessors = OrderedSet()
        self.first_successor = None
        self.first_predecessor = None
        # self.fake_successors = {}
        # self.fake_predecessors = {}
        self.name = original_node.get_name()
        self.original_user_names = []
        # A SchedulerNode is an operation node now, which is different from a SchedulerBuffer.
        self.original_node = original_node
        self.to_output_node = False
        # -1 means not assigned. It is important to use -1 instead of 0, because 0 is
        # a valid stream id.
        self.stream_id = UNASSIGNED_STREAM_ID
        if isinstance(original_node, FusedSchedulerNode):
            self.snode_names = [node.get_name() for node in original_node.snodes]
        else:
            self.snode_names = []
        # mark if this node needs to generate a CUDA event
        self.cuda_event = False
        self.is_nop_node = isinstance(original_node, NopKernelSchedulerNode)
        self.node_type = None
        self.device = None
        self.skip_cpu_nodes = skip_cpu_nodes
        # is it enough to check whether the node is a cpu node this way?
        if self.skip_cpu_nodes and original_node:
            if hasattr(original_node, "group"):
                for device in original_node.group:
                    if isinstance(device, torch.device) and device.type != "cpu":
                        self.device = device.type
                        break
                else:
                    self.device = "cpu"
            elif hasattr(original_node, "get_device"):
                self.device = original_node.get_device().type
        self.to_output_node = in_output(original_node)
        self.is_fused = isinstance(original_node, FusedSchedulerNode)

    def get_name(self):
        return self.name
class SSGraph:
    """
    The Stream Scheduler Graph records all the information needed for stream scheduling.

    Attributes:
        ssnodes: [SSNode]. All SSNodes in the graph. The order matters.
        op_to_ssnode: {op name: SSNode}. Mapping from the original node name to the
            SSNode. The names include scheduler node names and fused node names. For
            example, buf4, buf5, and buf4_buf5 all point to the same SSNode.
        reverse_level: {SSNode: level}. Levels counted back from the OUTPUT node. The
            level of the OUTPUT node is 0, its predecessors are level 1, their
            predecessors are level 2, and so on.
        reverse_level_predecessors: {SSNode: reverse_predecessor_node}. A node's
            predecessor in reverse order.
        critical_path: [SSNode]. The critical path of the graph. All nodes on the
            critical path are assigned to the default stream.
        stream_pool_size: how many extra CUDA streams to allocate. TODO: it would be
            better to use the max number of nodes in the same level of reverse_level.
        stream_pool: [use_count, ]. Indexed by stream id; counts how many nodes have
            been assigned to each stream.
        final_order: [SSNode]. The final order of the nodes after reordering. The
            order matters.
        max_stream_id: the max stream id used in the graph.
        arg_to_stream: the argument-to-stream mapping. {arg_name: stream_id}
    """

    def __init__(self, snodes) -> None:
        self.ssnodes = []
        self.op_to_ssnode = {}
        self.buf_to_ssnode = {}
        # It records the levels back from the OUTPUT node. {ssnode: level, }
        self.reverse_level = {}
        self.reverse_level_predecessors = {}
        self.critical_path = []
        self.stream_pool_size = int(environ.get("STREAM_POOL_SIZE", 31))
        self.stream_pool = [0] * (self.stream_pool_size + 1)
        self.arg_to_stream = {}
        self.final_order = []
        self.max_stream_id = 0
        self.skip_cpu_nodes = True
        self.to_output_nodes = []
        # By default, we follow the original default stream assignment, which takes
        # stream{device_index} as the default stream id.
        self.DEFAULT_STREAM_ID = None
        self.build_graph(snodes)
    def build_graph(self, snodes):
        for snode in snodes:
            if (
                self.DEFAULT_STREAM_ID is None
                and not isinstance(snode, NopKernelSchedulerNode)
                and (device := snode.get_device())
            ):
                self.DEFAULT_STREAM_ID = device.index
            new_ssnode = SSNode(snode)
            self.ssnodes.append(new_ssnode)
            self.op_to_ssnode[snode.get_name()] = new_ssnode
            if new_ssnode.to_output_node:
                self.to_output_nodes.append(new_ssnode.name)
            if new_ssnode.is_fused:
                for tmp_name in new_ssnode.snode_names:
                    self.op_to_ssnode[tmp_name] = new_ssnode
            for schedulerbuffer in snode.get_outputs():
                self.buf_to_ssnode[schedulerbuffer.get_name()] = new_ssnode
        # build dependencies
        # {buf1: op2, }
        buf_last_update_op: Dict[str, str] = {}
        for snode in snodes:
            deps = snode.read_writes.reads
            for schedulerbuffer in snode.get_outputs():
                self.buf_to_ssnode[schedulerbuffer.get_name()] = self.op_to_ssnode[
                    schedulerbuffer.defining_op.get_name()
                ]
            for dep in deps:
                # we only need to care about buffers here
                last_update_op = buf_last_update_op.get(dep.name, None)
                if last_update_op:
                    dep_node = self.op_to_ssnode[last_update_op]
                    self.op_to_ssnode[snode.get_name()].predecessors.add(dep_node)
                    dep_node.successors.add(self.op_to_ssnode[snode.get_name()])
            for output in snode.outputs_by_name.keys():
                buf_last_update_op[output] = snode.get_name()
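
The dependency pass above reduces to standard last-writer bookkeeping; a standalone sketch with made-up op and buffer names:

# (op_name, reads, writes) in schedule order
ops = [
    ("op0", [], ["buf0"]),
    ("op1", ["buf0"], ["buf1"]),
    ("op2", ["buf0", "buf1"], ["buf2"]),
]
buf_last_update_op = {}
edges = set()
for name, reads, writes in ops:
    for buf in reads:
        if buf in buf_last_update_op:
            edges.add((buf_last_update_op[buf], name))  # predecessor -> successor
    for buf in writes:
        buf_last_update_op[buf] = name  # this op is now the last writer
print(sorted(edges))  # [('op0', 'op1'), ('op0', 'op2'), ('op1', 'op2')]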
    def pattern_distributed(self):
        from .ir import FallbackKernel
        from .scheduler import ExternKernelSchedulerNode

        for ssnode in self.ssnodes:
            if ssnode.stream_id != UNASSIGNED_STREAM_ID:
                continue
            # copy-in
            if (
                isinstance(ssnode.original_node, ExternKernelSchedulerNode)
                and isinstance(ssnode.original_node.node, FallbackKernel)
                and "torch.ops.fsdp.all_gather_copy_in.default"
                in ssnode.original_node.node.python_kernel_name
            ):
                new_stream_id = self.stream_pool_pop()
                ssnode.stream_id = new_stream_id
                tmp_queue = list(ssnode.successors)
                while tmp_queue:
                    tmp_ssnode = tmp_queue.pop()
                    if tmp_ssnode.stream_id != UNASSIGNED_STREAM_ID:
                        continue
                    # copy-out
                    if (
                        isinstance(tmp_ssnode.original_node, ExternKernelSchedulerNode)
                        and isinstance(tmp_ssnode.original_node.node, FallbackKernel)
                        and "torch.ops.fsdp.split_with_sizes_copy.default"
                        in tmp_ssnode.original_node.node.python_kernel_name
                    ):
                        extern_kernel_node_count = 0
                        for predecessor in tmp_ssnode.predecessors:
                            predecessor.stream_id = new_stream_id
                            if isinstance(
                                predecessor.original_node, ExternKernelSchedulerNode
                            ):
                                extern_kernel_node_count += 1
                            elif isinstance(
                                predecessor.original_node, NopKernelSchedulerNode
                            ):
                                continue
                            else:
                                raise RuntimeError(
                                    f"Unexpected predecessor {predecessor} for copy_out node {tmp_ssnode}"
                                )
                        assert extern_kernel_node_count == 1, (
                            f"Expected exactly one extern kernel node as predecessor for "
                            f"copy_out node {tmp_ssnode}, but got {extern_kernel_node_count}. "
                            f"Pattern match failed."
                        )
                    else:
                        tmp_queue += list(tmp_ssnode.successors)
                    tmp_ssnode.stream_id = new_stream_id
            else:
                ssnode.stream_id = self.DEFAULT_STREAM_ID
    def stream_pool_pop(self, predecessor=None):
        if predecessor is not None:
            self.stream_pool[predecessor.stream_id] += 1
            return predecessor.stream_id
        else:
            min_value = min(self.stream_pool)
            min_stream = self.stream_pool.index(min_value)
            self.stream_pool[min_stream] += 1
            return min_stream
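
The pool policy above is a least-used counter; note that stream_assign below pre-loads the default stream's slot with a large count so it is never popped. A toy illustration:

stream_pool = [0, 0, 0, 0]   # use counts for streams 0..3
stream_pool[0] = 10**6       # default stream: effectively never popped

def stream_pool_pop(pool):
    min_stream = pool.index(min(pool))  # least-used stream; lowest index wins ties
    pool[min_stream] += 1
    return min_stream

print([stream_pool_pop(stream_pool) for _ in range(5)])  # [1, 2, 3, 1, 2]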
    def event_assign(self):
        # if at least one of a node's successors is not on the same stream, we need
        # to add an event
        for ssnode in self.ssnodes:
            for successor in ssnode.successors:
                if successor.stream_id != ssnode.stream_id:
                    ssnode.cuda_event = True
                    break
                # TODO: double check how we process nop nodes now.
                if successor.is_nop_node:
                    for successor_successor in successor.successors:
                        if successor_successor.stream_id != ssnode.stream_id:
                            ssnode.cuda_event = True
                            break
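
The rule above in isolation: a node needs an event exactly when some successor was assigned a different stream (with an extra look-through for nop successors). A minimal check over a three-node toy graph:

nodes = {"a": (0, ["b", "c"]), "b": (1, []), "c": (0, [])}  # name: (stream, successors)
needs_event = {
    name: any(nodes[succ][0] != stream for succ in succs)
    for name, (stream, succs) in nodes.items()
}
print(needs_event)  # {'a': True, 'b': False, 'c': False}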
    def stream_assign(self):
        # Avoid handing out the default stream when popping a new stream from the pool.
        self.stream_pool[self.DEFAULT_STREAM_ID] = len(self.ssnodes) + 2
        self.pattern_distributed()

        def check_all_nodes_assigned():
            for ssnode in self.ssnodes:
                if ssnode.stream_id == UNASSIGNED_STREAM_ID:
                    log.info(
                        f"Hanging node {ssnode.get_name()} found when doing stream assignment."
                    )
                    self.dig_node(ssnode)
                    return False
            return True

        while not check_all_nodes_assigned():
            pass
def stream_schedule(snodes):
    # Needs to match the place that calls `write_get_raw_stream`
    ssgraph = SSGraph(snodes)
    ssgraph.stream_assign()
    ssgraph.event_assign()
    V.graph.stream_graph = ssgraph
    return ssgraph

View File

@@ -283,7 +283,10 @@ if HAS_PYDOT:
                fname = self._shorten_file_name(parsed_stack_trace.file)
                label += (
                    f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}"
                    + r"\n"
                )
            stream_id = node.meta.get('stream_id', None)
            if stream_id is not None:
                label += f"|stream={stream_id}" + r"\n"
            return label + "}"

        def _tensor_meta_to_label(self, tm) -> str: