Speed up AMD AOT Inductor lowering by memoizing hipify trie to regex logic (#140156)

Summary:
AMD lowering duration is 1.55x longer than H100. Profiling shows hipification related functions took 22% of overall lowering time.

This diff cuts that time by safely memoize the trie to regex logic. The trick is to incrementally build a state of the trie during the trie construction. The state is the hash of all the words added to the trie.

Differential Revision: D65659445

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140156
Approved by: https://github.com/ColinPeppler

Co-authored-by: Kefei Lu <kefeilu@meta.com>
This commit is contained in:
Kefei Lu
2024-11-09 04:28:56 +00:00
committed by PyTorch MergeBot
parent 8b2e3855a9
commit d2d1258b1b

View File

@ -38,6 +38,8 @@ from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
from typing import Dict, List, Iterator, Optional
from collections.abc import Mapping, Iterable
from enum import Enum
import functools
import hashlib
class CurrentState(Enum):
INITIALIZED = 1
@ -678,9 +680,13 @@ class Trie:
def __init__(self):
"""Initialize the trie with an empty root node."""
self.root = TrieNode()
self._hash = hashlib.md5()
self._digest = self._hash.digest()
def add(self, word):
"""Add a word to the Trie. """
self._hash.update(word.encode())
self._digest = self._hash.digest()
node = self.root
for char in word:
@ -709,8 +715,13 @@ class Trie:
# make sure to check the end-of-word marker present
return '' in node.children
def _pattern(self, root):
"""Convert a Trie into a regular expression pattern"""
@functools.lru_cache # noqa: B019
def _pattern(self, root, digest):
"""Convert a Trie into a regular expression pattern
Memoized on the hash digest of the trie, which is built incrementally
during add().
"""
node = root
if "" in node.children and len(node.children.keys()) == 1:
@ -722,7 +733,7 @@ class Trie:
for char in sorted(node.children.keys()):
if isinstance(node.children[char], TrieNode):
try:
recurse = self._pattern(node.children[char])
recurse = self._pattern(node.children[char], self._digest)
alt.append(self.quote(char) + recurse)
except Exception:
cc.append(self.quote(char))
@ -750,11 +761,11 @@ class Trie:
def pattern(self):
"""Export the Trie to a regex pattern."""
return self._pattern(self.root)
return self._pattern(self.root, self._digest)
def export_to_regex(self):
"""Export the Trie to a regex pattern."""
return self._pattern(self.root)
return self._pattern(self.root, self._digest)
CAFFE2_TRIE = Trie()
CAFFE2_MAP = {}