Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[triton][export] serialization in internal path + unit tests (#162200)
Summary: Packages triton artifacts so they are runnable in nativert when wrappers exist.

Test Plan: unit tests

Rollback Plan:

Differential Revision: D81368559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162200
Approved by: https://github.com/angelayi
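For context, a minimal sketch (assuming triton is installed; the kernel name, signature, and config below are illustrative, not taken from the PR's test suite) of the kind of kernel this serialization path targets: an Autotuner with exactly one config and num_ctas == 1, whose num_warps and shared-memory metadata the serializer records in the hunks below.

import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({}, num_warps=8)],  # single config, as the serializer requires
    key=[],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # standard elementwise add; only the single autotune config matters for serialization
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)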
@@ -673,7 +673,7 @@ def forward(self, x):
                 kwargs.append(arg.arg)

         self.assertEqual(len(args), 4)
-        self.assertEqual(len(kwargs), 4)
+        self.assertEqual(len(kwargs), 5)

         for i in range(3):
             self.assertIsNotNone(args[i].as_tensor)
@@ -686,6 +686,7 @@ def forward(self, x):
         self.assertEqual(
             kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
         )  # num warps
+        self.assertEqual(kwargs[4].as_int, 0)  # shared mem bytes

         self.assertEqual(len(triton_node.outputs), 1)
         self.assertIsNotNone(triton_node.outputs[0].as_tensors)
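A reading aid for the two test hunks above, not code from the PR: based on the kwargs_new dict built later in this diff, the serialized triton node carries its keyword inputs in the order sketched below, which is why the test indexes kwargs[3] (num warps) and kwargs[4] (shared memory bytes).

expected_kwarg_order = [
    "name",                 # kernel.fn.__name__
    "grid",                 # node.kwargs["grid"][0]
    "output_indices",
    "num_warps",            # kwargs[3]: 8 for the autotuned model, else 4
    "shared_memory_bytes",  # kwargs[4]: 0 for the kernels in these tests
]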
@@ -510,6 +510,59 @@ class Final(type):
         return type.__new__(metacls, name, bases, dict(classdict))


+def get_triton_kernel_and_cache_entry(node: torch.fx.Node):
+    assert (
+        node.target
+        is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
+    )
+
+    assert has_triton(), "triton required to serialize triton kernels"
+    from triton.runtime.autotuner import Autotuner
+
+    assert isinstance(node.kwargs["kernel_idx"], int)
+    kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel(
+        node.kwargs["kernel_idx"]
+    )
+
+    kNumWarpsDefault = 4
+
+    # currently we only support specialization of
+    # num_warps -- so search for the entry that
+    # matches the value from the associated kernel
+    if isinstance(kernel, Autotuner):
+        assert len(kernel.configs) == 1
+        num_warps = kernel.configs[0].num_warps
+        assert kernel.configs[0].num_ctas == 1, (
+            "serialization only supports num_ctas == 1"
+        )
+        kernel = kernel.fn
+    else:
+        num_warps = kNumWarpsDefault
+
+    if hasattr(kernel, "device_caches"):
+        caches = kernel.device_caches
+        assert len(caches.keys()) == 1
+        cache = next(iter(caches.values()))[0]
+    elif hasattr(kernel, "cache"):
+        # old path, still used for cpu triton builds
+        caches = kernel.cache
+        assert len(caches.keys()) == 1
+        cache = next(iter(caches.values()))
+    else:
+        raise AssertionError(f"kernel caches not found for kernel {kernel.__name__}")
+
+    # can also get num_warps, num_ctas, etc. from here
+    if len(cache.keys()) == 1:
+        return kernel, next(iter(cache.values()))
+    else:
+        for cache_entry in cache.values():
+            if cache_entry.metadata.num_warps == num_warps:
+                return kernel, cache_entry
+        raise AssertionError(
+            f"couldn't find a kernel cache entry with metadata matching the autotuner configs for kernel {kernel.__name__}"
+        )
+
+
 @final
 class GraphModuleSerializer(metaclass=Final):
     def __init__(
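The helper's cache lookup reduces to: if the per-device cache holds a single compiled entry, take it; otherwise match on num_warps. Below is a runnable toy version of that selection logic using stand-in objects rather than real triton cache entries (FakeMeta and pick_cache_entry are hypothetical names, not part of triton or this PR).

from dataclasses import dataclass


@dataclass
class FakeMeta:
    # stand-in for a compiled kernel's metadata
    num_warps: int
    shared: int = 0


def pick_cache_entry(cache, num_warps):
    if len(cache) == 1:
        # only one specialization compiled, nothing to disambiguate
        return next(iter(cache.values()))
    # otherwise match on num_warps, mirroring get_triton_kernel_and_cache_entry
    for entry in cache.values():
        if entry.num_warps == num_warps:
            return entry
    raise AssertionError("no cache entry matches the requested num_warps")


cache = {"sig0": FakeMeta(num_warps=4), "sig1": FakeMeta(num_warps=8)}
assert pick_cache_entry(cache, 8).num_warps == 8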
@@ -676,8 +729,8 @@ class GraphModuleSerializer(metaclass=Final):
             node.target
             is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
         ):
-            assert has_triton(), "triton required to serialize triton kernels"
-            from triton.runtime.autotuner import Autotuner
+            kernel, kernel_cache_entry = get_triton_kernel_and_cache_entry(node)
+            kernel_cache_metadata = kernel_cache_entry.metadata

             meta_val = node.meta["val"]
             assert isinstance(meta_val, dict)
@@ -685,21 +738,6 @@ class GraphModuleSerializer(metaclass=Final):
             output_keys = meta_val.keys()
             output_indices = []

-            assert isinstance(node.kwargs["kernel_idx"], int)
-            kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel(
-                node.kwargs["kernel_idx"]
-            )
-
-            if isinstance(kernel, Autotuner):
-                assert len(kernel.configs) == 1
-                num_warps = kernel.configs[0].num_warps
-                assert kernel.configs[0].num_ctas == 1, (
-                    "serialization only supports num_ctas == 1"
-                )
-                kernel = kernel.fn
-            else:
-                num_warps = 4
-
             constexpr_keys = set()
             for p in kernel.params:
                 if p.is_constexpr:
@@ -732,9 +770,12 @@ class GraphModuleSerializer(metaclass=Final):
                 "name": kernel.fn.__name__,
                 "grid": node.kwargs["grid"][0],
                 "output_indices": output_indices,
-                "num_warps": num_warps,
+                "num_warps": kernel_cache_metadata.num_warps,
             }

+            if hasattr(kernel_cache_metadata, "shared"):
+                kwargs_new["shared_memory_bytes"] = kernel_cache_metadata.shared
+
             ex_node = Node(
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_hoo_inputs(args_new, kwargs_new),
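The hasattr guard in the last hunk means shared_memory_bytes is only emitted when the compiled kernel's metadata exposes a `shared` field, so consumers should treat the key as optional. A small sketch of that pattern in isolation (build_extra_kwargs and its metadata argument are illustrative, not APIs from this PR):

def build_extra_kwargs(metadata):
    # num_warps always comes from the compiled-kernel metadata
    extra = {"num_warps": metadata.num_warps}
    # `shared` is not guaranteed to exist on every metadata object, hence the guard
    if hasattr(metadata, "shared"):
        extra["shared_memory_bytes"] = metadata.shared
    return extra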