Compare commits

...

3 Commits

4 changed files with 150 additions and 21 deletions

View File

@@ -15497,6 +15497,43 @@ if RUN_GPU:
             fn()
 
+        @config.patch(implicit_fallbacks=True)
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_default_device_context(self):
+            @torch.library.custom_op(
+                "mylib::cg_unsafe_op",
+                mutates_args=[],
+                schema="(Tensor x) -> Tensor",
+                device_types=GPU_TYPE,
+                tags=(torch._C.Tag.cudagraph_unsafe,),
+            )
+            def cg_unsafe_op(x) -> torch.Tensor:
+                return x + 1
+
+            @cg_unsafe_op.register_fake
+            def _(x) -> torch.Tensor:
+                return torch.empty_like(x)
+
+            def f(x):
+                x += 1
+                y = cg_unsafe_op(x)
+                y += 1
+                return y
+
+            f = torch.compile(f, mode="reduce-overhead")
+            inp = torch.randn(2, device=GPU_TYPE)
+            _, (code,) = run_and_get_code(f, inp)
+
+            if config.cpp_wrapper:
+                FileCheck().check_count(
+                    "AOTICudaGuard device_guard(0)", 1, exactly=True
+                ).run(code)
+            else:
+                FileCheck().check_count(
+                    "with torch.cuda._DeviceGuard(0)", 1, exactly=True
+                ).run(code)
+
     class RNNTest(TestCase):
         device_type = GPU_TYPE

View File

@@ -1179,14 +1179,15 @@ class PythonWrapperCodegen(CodeGen):
         )
 
     def write_get_raw_stream_header(self) -> None:
+        import_get_raw_stream_str = V.graph.device_ops.import_get_raw_stream_as(
+            "get_raw_stream"
+        )
         if config.triton.autotune_at_compile_time:
-            self.kernel_autotune_calls.writeline(
-                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
-            )
+            if not self.kernel_autotune_calls.contains(import_get_raw_stream_str):
+                self.kernel_autotune_calls.writeline(import_get_raw_stream_str)
         if not V.graph.cpp_wrapper:
-            self.imports.writeline(
-                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
-            )
+            if not self.imports.contains(import_get_raw_stream_str):
+                self.imports.writeline(import_get_raw_stream_str)
 
     @cache_on_self
     def write_get_raw_stream_header_once(self) -> None:
@@ -1333,7 +1334,7 @@ class PythonWrapperCodegen(CodeGen):
     # that stream caching happens per graph instance. this
     # is important for nested subgraph codegening.
     def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
-        self.write_get_raw_stream_header_once()
+        self.write_get_raw_stream_header()
         name = f"stream{device_idx}"
         if config.triton.autotune_at_compile_time:
             self.kernel_autotune_calls.writeline(

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import collections
+import contextlib
 import dataclasses
 import functools
 import inspect
@@ -19,7 +20,7 @@ from typing_extensions import ParamSpec, TypeAlias
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Iterator, Sequence
     from types import ModuleType
 
     import sympy
@@ -2248,6 +2249,9 @@ class Scheduler:
         for node in self.nodes:
             node.prune_deps()
 
+        # See [Note: Graph Partition Device Contexts]
+        self.default_device_context: Optional[torch.device] = None
+
         self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
             self.get_donated_buffers()
         )
@@ -5194,6 +5198,80 @@ class Scheduler:
             [node.get_name() for node in signature.output_nodes]
         )
 
+    def use_default_device_context(
+        self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature]
+    ) -> contextlib.AbstractContextManager[None]:
+        @contextlib.contextmanager
+        def ctx() -> Iterator[None]:
+            self.update_graph_partition_default_device(partitions, signatures)
+
+            if self.default_device_context and device_need_guard(
+                self.default_device_context.type
+            ):
+                assert self.default_device_context.index is not None, (
+                    "device should have an index"
+                )
+                V.graph.wrapper_code.codegen_device_guard_enter(
+                    self.default_device_context.index
+                )
+
+            try:
+                yield
+            finally:
+                if self.default_device_context and device_need_guard(
+                    self.default_device_context.type
+                ):
+                    V.graph.wrapper_code.codegen_device_guard_exit()
+                self.default_device_context = None
+
+        return ctx()
+
+    def update_graph_partition_default_device(
+        self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature]
+    ) -> None:
+        # Note: [Graph Partition Device Contexts]
+        # Entering a device context takes 60 microseconds and exiting a device
+        # context takes 20 microseconds. If all graph partitions and
+        # cudagraph-unsafe ops happen on the same device, we can share the
+        # device context.
+        if len(partitions) == 1 and not signatures[0].skip_cudagraph:
+            # If there is only 1 cudagraph partition, the device context
+            # should happen within the cudagraph partition, which
+            # would be removed by cudagraph.
+            return
+
+        def get_cudagraph_partition_device(partition: PartitionType) -> torch.device:
+            partition_device = partition[0].get_device()
+            assert partition_device is not None
+            return partition_device
+
+        def all_on_target_device(
+            partition: PartitionType, target_device: torch.device
+        ) -> bool:
+            for node in partition:
+                device = node.get_device()
+                if device != target_device:
+                    return False
+            return True
+
+        cudagraph_partition_device = None
+        for partition, signature in zip(partitions, signatures):
+            if not signature.skip_cudagraph:
+                cudagraph_partition_device = get_cudagraph_partition_device(partition)
+                break
+
+        # all partitions skip cudagraph
+        if cudagraph_partition_device is None:
+            return
+
+        for partition, signature in zip(partitions, signatures):
+            if signature.skip_cudagraph and not all_on_target_device(
+                partition, cudagraph_partition_device
+            ):
+                return
+
+        self.default_device_context = cudagraph_partition_device
+
     def _codegen_partitions(self) -> None:
         """
         Split nodes into partitions and codegen each partition into separate functions.
@@ -5206,15 +5284,16 @@
         msg = f"cudagraph partition into {len(partitions)} partitions"
         maybe_log_cudagraph_partition(msg=msg, prefix="")
 
-        for partition, signature in zip(partitions, signatures):
-            assert len(partition) >= 1, (
-                f"Each partition must have at least one node but found {len(partition)}"
-            )
+        with self.use_default_device_context(partitions, signatures):
+            for partition, signature in zip(partitions, signatures):
+                assert len(partition) >= 1, (
+                    f"Each partition must have at least one node but found {len(partition)}"
+                )
 
-            if signature.skip_cudagraph:
-                self._codegen(partition)
-            else:
-                self._codegen_partition_wrapper(partition, signature)
+                if signature.skip_cudagraph:
+                    self._codegen(partition)
+                else:
+                    self._codegen_partition_wrapper(partition, signature)
 
         num_partitions = next(self._graph_partition_counter)
         V.graph.wrapper_code.set_all_partition_names(num_partitions)
@@ -5247,7 +5326,11 @@
                 )
                 seen.add(key)
 
-        self.current_device = None
+        self.current_device = self.default_device_context
+
+        if self.default_device_context and config.triton.autotune_at_compile_time:
+            V.graph.wrapper_code.write_get_raw_stream_header()
+
         for node in nodes:
             if log.isEnabledFor(logging.DEBUG):
                 try:
@@ -5326,10 +5409,15 @@
             ):
                 self.flush()
 
-        if self.current_device and device_need_guard(self.current_device.type):
-            # exit the outermost CUDA device guard. this is
-            # important for nested indentation codegen-ing.
-            V.graph.wrapper_code.codegen_device_guard_exit()
+        if self.current_device != self.default_device_context:
+            # when default_device_context is not None, we are codegen
+            # for graph partitions and all nodes must be on
+            # the same default device.
+            assert self.current_device is not None
+            if device_need_guard(self.current_device.type):
+                # exit the outermost CUDA device guard. this is
+                # important for nested indentation codegen-ing.
+                V.graph.wrapper_code.codegen_device_guard_exit()
 
         self.flush()
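
The use_default_device_context helper above wraps partition codegen in a single, conditionally entered device guard. Below is a minimal, runnable sketch of that pattern only: a contextlib.contextmanager that enters a shared guard when every partition agrees on one device, and always cleans up in finally. The Codegen class, shared_device_guard name, and device strings are illustrative stand-ins, not Inductor APIs.

import contextlib
from collections.abc import Iterator
from typing import Optional


class Codegen:
    def __init__(self) -> None:
        self.default_device: Optional[str] = None
        self.lines: list[str] = []

    def shared_device_guard(
        self, partitions: list[list[str]]
    ) -> contextlib.AbstractContextManager[None]:
        @contextlib.contextmanager
        def ctx() -> Iterator[None]:
            devices = {dev for partition in partitions for dev in partition}
            # Share a guard only when every partition agrees on a single device.
            self.default_device = devices.pop() if len(devices) == 1 else None
            if self.default_device is not None:
                self.lines.append(f"enter guard({self.default_device})")
            try:
                yield
            finally:
                if self.default_device is not None:
                    self.lines.append("exit guard")
                self.default_device = None

        return ctx()


cg = Codegen()
with cg.shared_device_guard([["cuda:0"], ["cuda:0"]]):
    cg.lines.append("codegen partition_0")
    cg.lines.append("codegen partition_1")
print("\n".join(cg.lines))
# enter guard(cuda:0)
# codegen partition_0
# codegen partition_1
# exit guard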

View File

@@ -1460,6 +1460,9 @@ class IndentedBuffer:
         res.writelines(other._lines)
         return res
 
+    def contains(self, new_line: Union[DeferredLineBase, LineContext, str]) -> bool:
+        return new_line in self._lines
+
 
 class FakeIndentedBuffer(IndentedBuffer):
     def __init__(self) -> None:
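
The new contains method is what lets write_get_raw_stream_header be called more than once per graph without duplicating the import line. A small self-contained sketch of that check-before-write idiom follows; Buffer is a stand-in class rather than torch._inductor.utils.IndentedBuffer, and the import string is illustrative.

class Buffer:
    def __init__(self) -> None:
        self._lines: list[str] = []

    def writeline(self, line: str) -> None:
        self._lines.append(line)

    def contains(self, line: str) -> bool:
        return line in self._lines


imports = Buffer()
# Illustrative string; in Inductor it comes from
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream").
stream_import = "from torch._C import _cuda_getCurrentRawStream as get_raw_stream"
for _ in range(3):  # the header may now be requested repeatedly per graph
    if not imports.contains(stream_import):
        imports.writeline(stream_import)
assert imports._lines == [stream_import]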