[triton][export] serialization in internal path + unit tests (#162200)

Summary: This change packages Triton artifacts so they are runnable in NativeRT when wrappers exist.

Test Plan:
unit tests

Rollback Plan:

Differential Revision: D81368559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162200
Approved by: https://github.com/angelayi
This commit is contained in:
dolpm
2025-09-10 09:49:08 +00:00
committed by PyTorch MergeBot
parent f0ae3a57f6
commit d9832d8425
2 changed files with 61 additions and 19 deletions

View File

@ -673,7 +673,7 @@ def forward(self, x):
kwargs.append(arg.arg)
self.assertEqual(len(args), 4)
self.assertEqual(len(kwargs), 4)
self.assertEqual(len(kwargs), 5)
for i in range(3):
self.assertIsNotNone(args[i].as_tensor)
@ -686,6 +686,7 @@ def forward(self, x):
self.assertEqual(
kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
) # num warps
self.assertEqual(kwargs[4].as_int, 0) # shared mem bytes
self.assertEqual(len(triton_node.outputs), 1)
self.assertIsNotNone(triton_node.outputs[0].as_tensors)

View File

@ -510,6 +510,59 @@ class Final(type):
return type.__new__(metacls, name, bases, dict(classdict))
def get_triton_kernel_and_cache_entry(node: torch.fx.Node):
    """Resolve the triton kernel and its compiled-cache entry for a wrapper node.

    Given an fx node whose target is the functional triton-kernel-wrapper HOP,
    look up the underlying kernel in the kernel side table and find the compiled
    cache entry whose metadata matches the kernel's configuration.

    Args:
        node: An fx node targeting
            ``torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional``,
            with an integer ``kernel_idx`` kwarg indexing the kernel side table.

    Returns:
        A ``(kernel, cache_entry)`` tuple, where ``kernel`` is the unwrapped
        JIT function (autotuner wrapper removed) and ``cache_entry`` is the
        matching compiled-kernel cache entry.

    Raises:
        AssertionError: If triton is unavailable, the kernel has no cache,
            or no cache entry matches the expected ``num_warps``.
    """
    assert (
        node.target
        is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
    )
    assert has_triton(), "triton required to serialize triton kernels"
    from triton.runtime.autotuner import Autotuner

    assert isinstance(node.kwargs["kernel_idx"], int)
    kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel(
        node.kwargs["kernel_idx"]
    )

    # Triton's default when a kernel is launched without an explicit config.
    NUM_WARPS_DEFAULT = 4

    # We currently only support specialization on num_warps, so search for
    # the cache entry whose metadata matches the value from the kernel.
    if isinstance(kernel, Autotuner):
        assert len(kernel.configs) == 1
        num_warps = kernel.configs[0].num_warps
        assert kernel.configs[0].num_ctas == 1, (
            "serialization only supports num_ctas == 1"
        )
        # Unwrap the autotuner to get at the underlying JIT function.
        kernel = kernel.fn
    else:
        num_warps = NUM_WARPS_DEFAULT

    if hasattr(kernel, "device_caches"):
        caches = kernel.device_caches
        assert len(caches) == 1
        cache = next(iter(caches.values()))[0]
    elif hasattr(kernel, "cache"):
        # old path, still used for cpu triton builds
        caches = kernel.cache
        assert len(caches) == 1
        cache = next(iter(caches.values()))
    else:
        raise AssertionError(f"kernel caches not found for kernel {kernel.__name__}")

    # NOTE: cache-entry metadata also carries num_warps, num_ctas, etc.,
    # which callers may read (e.g. for shared-memory size).
    if len(cache) == 1:
        return kernel, next(iter(cache.values()))
    for cache_entry in cache.values():
        if cache_entry.metadata.num_warps == num_warps:
            return kernel, cache_entry
    raise AssertionError(
        f"couldn't find a kernel cache entry with metadata matching the autotuner configs for kernel {kernel.__name__}"
    )
@final
class GraphModuleSerializer(metaclass=Final):
def __init__(
@ -676,8 +729,8 @@ class GraphModuleSerializer(metaclass=Final):
node.target
is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
):
assert has_triton(), "triton required to serialize triton kernels"
from triton.runtime.autotuner import Autotuner
kernel, kernel_cache_entry = get_triton_kernel_and_cache_entry(node)
kernel_cache_metadata = kernel_cache_entry.metadata
meta_val = node.meta["val"]
assert isinstance(meta_val, dict)
@ -685,21 +738,6 @@ class GraphModuleSerializer(metaclass=Final):
output_keys = meta_val.keys()
output_indices = []
assert isinstance(node.kwargs["kernel_idx"], int)
kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel(
node.kwargs["kernel_idx"]
)
if isinstance(kernel, Autotuner):
assert len(kernel.configs) == 1
num_warps = kernel.configs[0].num_warps
assert kernel.configs[0].num_ctas == 1, (
"serialization only supports num_ctas == 1"
)
kernel = kernel.fn
else:
num_warps = 4
constexpr_keys = set()
for p in kernel.params:
if p.is_constexpr:
@ -732,9 +770,12 @@ class GraphModuleSerializer(metaclass=Final):
"name": kernel.fn.__name__,
"grid": node.kwargs["grid"][0],
"output_indices": output_indices,
"num_warps": num_warps,
"num_warps": kernel_cache_metadata.num_warps,
}
if hasattr(kernel_cache_metadata, "shared"):
kwargs_new["shared_memory_bytes"] = kernel_cache_metadata.shared
ex_node = Node(
target=self.serialize_operator(node.target),
inputs=self.serialize_hoo_inputs(args_new, kwargs_new),