[Graph Partition] interface for custom cg wrapper (#162207)

This PR adds an interface that allows users to specify a custom cudagraph wrapper. Example user: [vllm](https://github.com/vllm-project/vllm/pull/24281)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162207
Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg
This commit is contained in:
Boyuan Feng
2025-09-06 03:13:01 +00:00
committed by PyTorch MergeBot
parent b2b4add0e7
commit c0983e6cc0
3 changed files with 71 additions and 3 deletions

View File

@ -3395,8 +3395,8 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
def is_using_cudagraph_partition() -> bool:
    """
    Report whether graph partitions are going to be wrapped.

    The result is truthy only when ``config.graph_partition`` is enabled and
    at least one of the following holds:

    * ``config.triton.cudagraphs`` is enabled, or
    * a customized partition wrapper has been set on
      ``_unstable_customized_partition_wrapper``.
    """
    return (
        torch._inductor.config.triton.cudagraphs
        or _unstable_customized_partition_wrapper.wrapper is not None
    ) and torch._inductor.config.graph_partition
def dtype_from_size(size: int) -> torch.dtype:
@ -3621,3 +3621,48 @@ def python_subprocess_env() -> dict[str, str]:
env["PYTHONHOME"] = sysconfig.get_path("data")
return env
@dataclasses.dataclass(frozen=True)
class CUDAGraphWrapperMetadata:
    """
    Metadata for Customized CUDAGraphWrapper.

    Currently assumes there is 1 dynamo graph and will extend to
    multiple graphs in the future.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    # The number of partitions that are cudagraphable.
    num_partitions: int

    # Index of the current partition.
    partition_index: int
# A partition function: any callable, with no constraint on its signature.
PartitionFnType = Callable[..., Any]

# A customized wrapper: takes a partition function together with the
# CUDAGraphWrapperMetadata of that partition and returns a partition function.
CUDAGraphWrapperType = Callable[
    [PartitionFnType, CUDAGraphWrapperMetadata],
    PartitionFnType,
]
class CUDAGraphWrapper:
    """
    Holder for a customized partition wrapper.

    ``wrapper`` is ``None`` until a wrapper callable is assigned to it.
    """

    wrapper: Optional[CUDAGraphWrapperType] = None
# A customized partition wrapper supplied by the user. Its interface should be:
#
#     def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType
#
# Inductor generates N wrapper functions for N partition functions and mechanically wraps
# each partition fn with the generated wrapper function. Users need to handle all details,
# such as static inputs, dynamic shapes, etc.
# Users can customize the wrapper based on the metadata. One example is to apply special
# handling to the first and last wrapper functions.
#
# Warning: This API is unstable and may change in the future.
_unstable_customized_partition_wrapper = CUDAGraphWrapper()
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
    """
    Register ``wrapper`` as the customized partition wrapper.

    The callable is stored on the module-level
    ``_unstable_customized_partition_wrapper`` holder, replacing whatever
    value was stored there before.

    Args:
        wrapper: Callable that takes a partition function and its
            ``CUDAGraphWrapperMetadata`` and returns a partition function.
    """
    _unstable_customized_partition_wrapper.wrapper = wrapper