	Summary: See https://github.com/pytorch/pytorch/issues/130772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137298 Approved by: https://github.com/oulgen, https://github.com/eellison
		
			
				
	
	
		
73 lines | 2.5 KiB | Python
import hashlib
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Callable, Optional, Tuple, Union
from typing_extensions import TypeAlias

import torch.fx.graph


class CustomGraphPass(ABC):
    """
    Implement this interface for custom Graph passes:

    1) The __call__() method contains the implementation of the custom pass.

    2) The uuid() method enables inductor to cache compiled graphs when your custom
    passes are applied. This method can return any identifier as long as it uniquely
    identifies your implementation (and can be pickled). The caching logic includes this
    identifier in its key calculation, i.e., any new value will effectively invalidate
    existing entries. We expect custom passes will typically depend purely on the
    textual representation of the implementation. In that case, we recommend using the
    'get_hash_for_files' helper below to compute a unique hash from the contents of a
    static list of source files, i.e., the source(s) containing the custom pass
    implementation. That approach ensures that any change to the implementation will
    mean a new uuid.

    ** IMPORTANT ** If your custom pass's behavior depends on some external state, then
    you'll need to implement something more complicated (or disable caching).

    EXAMPLE:

    class MyCustomGraphPass(CustomGraphPass):
        def __call__(self, graph: torch.fx.graph.Graph) -> None:
            # my custom graph optimization pass
            #     ...

        def uuid(self) -> Optional[Any]:
            return get_hash_for_files((__file__,))

    """

    @abstractmethod
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """
        Implementation of the custom pass.
        """

    @abstractmethod
    def uuid(self) -> Optional[Any]:
        """
        Return an ID to uniquely identify your custom pass implementation. Return None
        to skip inductor code caching entirely.
        """


CustomGraphPassType: TypeAlias = Optional[
    Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]]
]


@lru_cache(1)
def get_hash_for_files(paths: Tuple[str], extra: str = "") -> bytes:
    """
    Helper to compute a unique hash from the contents of a list of files.
    """
    hasher = hashlib.sha256()
    hasher.update(extra.encode("utf-8"))
    for path in paths:
        with open(path, "rb") as f:
            hasher.update(path.encode("utf-8"))
            hasher.update(f.read())
    return hasher.digest()
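
Below is a minimal usage sketch of the interface above, not part of the file itself. It assumes the module path torch._inductor.custom_graph_pass for this file and the post_grad_custom_post_pass field in torch._inductor.config as the registration hook; both are taken from the surrounding repository and should be verified against your PyTorch version. The pass class and its body are hypothetical placeholders.

from typing import Any, Optional

import torch
import torch._inductor.config as inductor_config
import torch.fx.graph
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class MyInductorPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        # Placeholder body: a real pass would match and rewrite nodes here.
        for node in graph.nodes:
            pass

    def uuid(self) -> Optional[Any]:
        # Hash this source file: any edit to the pass yields a new uuid, so
        # inductor will not reuse graphs cached with the old implementation.
        return get_hash_for_files((__file__,))


# Register the pass (hook name assumed for illustration) and compile as usual.
inductor_config.post_grad_custom_post_pass = MyInductorPass()
compiled_fn = torch.compile(lambda x: x + 1)

Because uuid() feeds into inductor's cache key, returning a content hash of the pass's own source keeps caching enabled while still invalidating stale entries whenever the pass changes; returning None instead disables inductor code caching entirely, as noted in the docstring above.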