mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Graph Partition] interface for custom cg wrapper (#162207)
This PR adds an interface that allows users to specify a custom cudagraph wrapper. User example: [vllm](https://github.com/vllm-project/vllm/pull/24281) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162207 Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg
This commit is contained in:
committed by
PyTorch MergeBot
parent
b2b4add0e7
commit
c0983e6cc0
@ -3395,8 +3395,8 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
|
||||
def is_using_cudagraph_partition() -> bool:
    """
    Return whether graph partition is being used together with cudagraphs.

    Graph partition must be enabled (``config.graph_partition``), and the
    partitions must be wrapped either by inductor's own cudagraphs
    (``config.triton.cudagraphs``) or by a user-supplied customized
    partition wrapper.
    """
    return (
        torch._inductor.config.triton.cudagraphs
        # A registered custom wrapper counts even when triton.cudagraphs is off.
        or _unstable_customized_partition_wrapper.wrapper is not None
    ) and torch._inductor.config.graph_partition
|
||||
|
||||
|
||||
def dtype_from_size(size: int) -> torch.dtype:
|
||||
@ -3621,3 +3621,48 @@ def python_subprocess_env() -> dict[str, str]:
|
||||
env["PYTHONHOME"] = sysconfig.get_path("data")
|
||||
|
||||
return env
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CUDAGraphWrapperMetadata:
    """
    Metadata for Customized CUDAGraphWrapper.

    Instances are immutable (frozen dataclass).

    Currently assumes there is 1 dynamo graph and will extend to
    multiple graphs in the future.
    """

    # The number of partitions that are cudagraphable.
    num_partitions: int

    # Index of the current partition.
    partition_index: int
|
||||
|
||||
|
||||
# A partition function; its signature is not constrained here.
PartitionFnType = Callable[..., Any]

# A customized wrapper: takes a partition function plus its metadata and
# returns the (wrapped) partition function.
CUDAGraphWrapperType = Callable[
    [PartitionFnType, CUDAGraphWrapperMetadata], PartitionFnType
]
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
    """
    Holder for an optional customized cudagraph partition wrapper.

    ``wrapper`` is ``None`` unless a wrapper has been assigned.
    """

    wrapper: Optional[CUDAGraphWrapperType] = None
|
||||
|
||||
|
||||
# A customized partition wrapper supplied by the user. The interface should be:
|
||||
#
|
||||
# def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType
|
||||
#
|
||||
# Inductor generates N wrapper functions for N partition functions, and mechanically wraps
|
||||
# each partition fn with the generated wrapper function. Users need to handle all details
|
||||
# such as static inputs, dynamic shapes, etc.
|
||||
# Users could customize the wrapper based on the metadata. One example is to have special
|
||||
# handling for the first and last wrapper function.
|
||||
#
|
||||
# Warning: This API is unstable and may change in the future.
|
||||
# Module-level holder instance; its `wrapper` attribute is None until a wrapper is set.
_unstable_customized_partition_wrapper = CUDAGraphWrapper()
|
||||
|
||||
|
||||
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
    """
    Set the customized partition wrapper.

    Stores ``wrapper`` on the module-level
    ``_unstable_customized_partition_wrapper`` holder, replacing any value
    stored there before.

    Args:
        wrapper: Callable taking a partition function and its
            ``CUDAGraphWrapperMetadata`` and returning a partition function.
    """
    _unstable_customized_partition_wrapper.wrapper = wrapper
|
||||
|
Reference in New Issue
Block a user