mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Speed up AMD AOT Inductor lowering by memoizing hipify trie to regex logic (#140156)
Summary: AMD lowering takes 1.55x longer than on H100. Profiling shows that hipification-related functions took 22% of overall lowering time. This diff cuts that time by safely memoizing the trie-to-regex logic. The trick is to incrementally build a state of the trie during trie construction: the state is the hash of all the words added to the trie, and it is used as the memoization key. Differential Revision: D65659445 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140156 Approved by: https://github.com/ColinPeppler Co-authored-by: Kefei Lu <kefeilu@meta.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
8b2e3855a9
commit
d2d1258b1b
@ -38,6 +38,8 @@ from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
|
||||
from typing import Dict, List, Iterator, Optional
|
||||
from collections.abc import Mapping, Iterable
|
||||
from enum import Enum
|
||||
import functools
|
||||
import hashlib
|
||||
|
||||
class CurrentState(Enum):
|
||||
INITIALIZED = 1
|
||||
@ -678,9 +680,13 @@ class Trie:
|
||||
def __init__(self):
|
||||
"""Initialize the trie with an empty root node."""
|
||||
self.root = TrieNode()
|
||||
self._hash = hashlib.md5()
|
||||
self._digest = self._hash.digest()
|
||||
|
||||
def add(self, word):
|
||||
"""Add a word to the Trie. """
|
||||
self._hash.update(word.encode())
|
||||
self._digest = self._hash.digest()
|
||||
node = self.root
|
||||
|
||||
for char in word:
|
||||
@ -709,8 +715,13 @@ class Trie:
|
||||
# make sure to check the end-of-word marker present
|
||||
return '' in node.children
|
||||
|
||||
def _pattern(self, root):
|
||||
"""Convert a Trie into a regular expression pattern"""
|
||||
@functools.lru_cache # noqa: B019
|
||||
def _pattern(self, root, digest):
|
||||
"""Convert a Trie into a regular expression pattern
|
||||
|
||||
Memoized on the hash digest of the trie, which is built incrementally
|
||||
during add().
|
||||
"""
|
||||
node = root
|
||||
|
||||
if "" in node.children and len(node.children.keys()) == 1:
|
||||
@ -722,7 +733,7 @@ class Trie:
|
||||
for char in sorted(node.children.keys()):
|
||||
if isinstance(node.children[char], TrieNode):
|
||||
try:
|
||||
recurse = self._pattern(node.children[char])
|
||||
recurse = self._pattern(node.children[char], self._digest)
|
||||
alt.append(self.quote(char) + recurse)
|
||||
except Exception:
|
||||
cc.append(self.quote(char))
|
||||
@ -750,11 +761,11 @@ class Trie:
|
||||
|
||||
def pattern(self):
|
||||
"""Export the Trie to a regex pattern."""
|
||||
return self._pattern(self.root)
|
||||
return self._pattern(self.root, self._digest)
|
||||
|
||||
def export_to_regex(self):
|
||||
"""Export the Trie to a regex pattern."""
|
||||
return self._pattern(self.root)
|
||||
return self._pattern(self.root, self._digest)
|
||||
|
||||
CAFFE2_TRIE = Trie()
|
||||
CAFFE2_MAP = {}
|
||||
|
Reference in New Issue
Block a user