mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor][ez] add src_hash property for Templates (#161468)
# why enable caching/overriding/filtering based on src hash later # what - KernelTemplate has a src_hash that is None by default - sha256 on TritonTemplate of the template src code - None on ExternKernelChoice to have same API # testing n/a (not in use in this change) Differential Revision: [](https://our.internmc.facebook.com/intern/diff/) Differential Revision: [D81821149](https://our.internmc.facebook.com/intern/diff/D81821149) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161468 Approved by: https://github.com/eellison ghstack dependencies: #161351, #161350, #162293
This commit is contained in:
committed by
PyTorch MergeBot
parent
269c9907a0
commit
25f1a5d8d1
@ -2407,8 +2407,9 @@ class KernelTemplate:
|
||||
|
||||
return get_dtype
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
def __init__(self, name: str, hash: Optional[str] = None) -> None:
|
||||
self.name = name
|
||||
self._hash = hash
|
||||
|
||||
@property
|
||||
def uid(self) -> str:
|
||||
@ -2421,6 +2422,17 @@ class KernelTemplate:
|
||||
# TODO(coconutruben): add some central registration to assert on global uniqueness
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def src_hash(self) -> Union[str, None]:
|
||||
"""
|
||||
source hash for a Template.
|
||||
|
||||
Templates can optionally provide a src hash to make it easier to cache/validate that
|
||||
a template has not changed from one version to another. Override this if that detection
|
||||
is different for your specific Template
|
||||
"""
|
||||
return self._hash
|
||||
|
||||
def choice_or_none(self, **kwargs: Any) -> Optional[ChoiceCaller]:
|
||||
"""
|
||||
Maybe generates a new ChoiceCaller and returns it, or None if generation fails.
|
||||
|
@ -39,7 +39,7 @@ class KernelTemplateChoice:
|
||||
"""
|
||||
Lazily evaluate and return the ChoiceCaller for this template choice.
|
||||
|
||||
On first access, calls template.choice_or_None() with the stored parameters.
|
||||
On first access, calls template.choice_or_none() with the stored parameters.
|
||||
If successful, caches and returns the ChoiceCaller. If it fails, caches
|
||||
and returns None. Subsequent accesses return the cached value.
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
@ -1433,7 +1434,7 @@ class TritonTemplate(KernelTemplate):
|
||||
cache_codegen_enabled_for_template=False,
|
||||
prologue_loads_all_inputs=False,
|
||||
) -> None:
|
||||
super().__init__(name)
|
||||
super().__init__(name, hash=hashlib.sha256(source.encode("utf-8")).hexdigest())
|
||||
self.grid = grid
|
||||
self.template = self._template_from_string(source)
|
||||
assert name not in self.all_templates, "duplicate template name"
|
||||
@ -1888,6 +1889,10 @@ class ExternKernelChoice:
|
||||
self.op_overload = op_overload
|
||||
self.use_fallback_kernel = use_fallback_kernel
|
||||
self.kernel_creator = kernel_creator
|
||||
# match the API for KernelTemplate as they can be treated the same
|
||||
# There is no src hash for ExternKernelChoice in the traditional sense
|
||||
# so we indicate this by returning None
|
||||
self.src_hash = None
|
||||
|
||||
def to_callable(self):
|
||||
return getattr(extern_kernels, self.name)
|
||||
|
Reference in New Issue
Block a user