[inductor] Allow num_program specification for TMA workspace (#152844)

Summary:
Allow TMA workspace creation to accept an optional `num_programs` argument, which defaults to `num_sms` when not specified.

A kernel needs a total of `num_programs * num_tma_descriptors` descriptors, so the workspace is sized accordingly.
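The sizing rule can be sketched in isolation as follows. This is a minimal illustration, not the inductor implementation: `tma_workspace_bytes` and `ASSUMED_NUM_SMS` are hypothetical names, the SM count stands in for `get_num_sms()`, and the value given to `TMA_DESCRIPTOR_SIZE` is an assumption made only for the arithmetic.

```python
from typing import Optional

# Illustrative values only; both are assumptions, not taken from this commit.
TMA_DESCRIPTOR_SIZE = 128  # assumed bytes per TMA descriptor
ASSUMED_NUM_SMS = 132      # hypothetical stand-in for get_num_sms()


def tma_workspace_bytes(
    num_tma_descriptors: int,
    num_programs: Optional[int] = None,
) -> int:
    """Return the workspace size in bytes: one set of descriptors per program."""
    if num_programs is None:
        # Default to one program per SM, matching the default described above.
        num_programs = ASSUMED_NUM_SMS
    return num_programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE
```

Under these assumptions, a kernel with 3 descriptors launched with 64 programs would need `64 * 3 * 128 = 24576` bytes, while omitting `num_programs` falls back to the per-SM default.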

Test Plan: CI.

Differential Revision: D74189599

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152844
Approved by: https://github.com/drisspg
Mandar Deshpande
2025-05-05 23:02:52 +00:00
committed by PyTorch MergeBot
parent cc954848d4
commit e3064bf0e3

@@ -1420,12 +1420,15 @@ def get_num_sms() -> int:
 def get_tma_workspace_arg(
     num_tma_descriptors: int,
     device: torch.device,
+    num_programs: Optional[int] = None,
 ) -> WorkspaceArg:
     """Builds and returns a WorkspaceArg for the device side TMA workspace buffer."""
     from .codegen.common import WorkspaceArg, WorkspaceZeroMode
+    if num_programs is None:
+        num_programs = get_num_sms()
     zero_mode = WorkspaceZeroMode.from_bool(False)
-    size = get_num_sms() * num_tma_descriptors * TMA_DESCRIPTOR_SIZE
+    size = num_programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE
     return WorkspaceArg(
         count=size,
         zero_mode=zero_mode,