mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Allow num_program specification for TMA workspace (#152844)
Summary: Allow TMA workspace creation to accept a `num_programs` argument, which defaults to `num_sms` when not specified. A kernel needs a total of `num_programs * num_tma_descriptors` descriptors. Test Plan: CI. Differential Revision: D74189599 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152844 Approved by: https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
cc954848d4
commit
e3064bf0e3
@ -1420,12 +1420,15 @@ def get_num_sms() -> int:
|
||||
def get_tma_workspace_arg(
|
||||
num_tma_descriptors: int,
|
||||
device: torch.device,
|
||||
num_programs: Optional[int] = None,
|
||||
) -> WorkspaceArg:
|
||||
"""Builds and returns a WorkspaceArg for the device side TMA workspace buffer."""
|
||||
from .codegen.common import WorkspaceArg, WorkspaceZeroMode
|
||||
|
||||
if num_programs is None:
|
||||
num_programs = get_num_sms()
|
||||
zero_mode = WorkspaceZeroMode.from_bool(False)
|
||||
size = get_num_sms() * num_tma_descriptors * TMA_DESCRIPTOR_SIZE
|
||||
size = num_programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE
|
||||
return WorkspaceArg(
|
||||
count=size,
|
||||
zero_mode=zero_mode,
|
||||
|
Reference in New Issue
Block a user