Compare commits

...

1 Commit

Author SHA1 Message Date
277ff19a55 Revert "[cutlass backend] Cache config generation locally and remotely (#154686)"
This reverts commit a4b0023f3b3c28735ac7d4adede6343746c8c946.
2025-06-01 15:49:45 -07:00
3 changed files with 2 additions and 104 deletions

torch/_inductor/codegen/cuda/cutlass_cache.py (file deleted)

@@ -1,93 +0,0 @@
# mypy: allow-untyped-defs
import functools
import hashlib
import json
import logging
import os
import time
from typing import Any, Optional

import torch._inductor.config as config
from torch._inductor.codecache import cutlass_key
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import clear_on_fresh_inductor_cache


log = logging.getLogger(__name__)

CONFIG_PREFIX: str = "configs"


def get_config_request_key(
    arch: str,
    cuda_version: str,
    instantiation_level: str,
) -> str:
    """
    Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level.
    """
    hash_target = "-".join(
        [
            cutlass_key().decode(),
            arch,
            cuda_version,
            instantiation_level,
        ]
    )
    return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8]


def _generate_config_filename(request_key: str) -> str:
    """
    Generate a filename for the full ops.
    """
    return f"{CONFIG_PREFIX}_{request_key}.json"


@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def maybe_fetch_ops() -> Optional[list[Any]]:
    """
    Fetch ops from databases.
    """
    if config.force_disable_caches:
        return None

    # setup
    arch: str = get_cuda_arch()
    # get_cuda_version might return "12.4.0" or "12.4"
    # but we want to use "12.4"
    version: str = ".".join(get_cuda_version().split(".")[:2])
    instantiation_level: str = config.cuda.cutlass_instantiation_level

    # filename and filepath
    request_key: str = get_config_request_key(arch, version, instantiation_level)
    filename: str = _generate_config_filename(request_key)
    filepath: str = os.path.join(cache_dir(), filename)

    # try fetch
    serialized_ops: Optional[list[str]] = None
    start_time = time.time()
    if os.path.isfile(filepath):
        # locally
        with open(filepath) as f:
            serialized_ops = json.load(f)
    elif config.is_fbcode():
        from torch._inductor.fb.cutlass_remote_cache import (
            maybe_fetch_cutlass_configs_from_remote,
        )

        # from remote
        serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath)

    if serialized_ops is None:
        return None

    # deserialize
    serializer = get_cutlass_operation_serializer()
    full_ops = [serializer.deserialize(x) for x in serialized_ops]  # type: ignore[union-attr]
    log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time)
    return full_ops
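
Aside (not part of the diff): a minimal sketch of how the deleted helpers compose a cache key and filename. The arch/version/level inputs are hypothetical, and cutlass_key() is stubbed with a fixed placeholder string because the real function needs a CUTLASS checkout.

import hashlib

def demo_request_key(arch: str, cuda_version: str, instantiation_level: str) -> str:
    # Same recipe as get_config_request_key above, with a placeholder cutlass key.
    hash_target = "-".join(["<cutlass-key>", arch, cuda_version, instantiation_level])
    return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8]

key = demo_request_key("90", "12.4", "0")
print(f"configs_{key}.json")  # the cached-config filename, e.g. configs_ab12cd34.json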

torch/_inductor/codegen/cuda/gemm_template.py

@@ -10,7 +10,6 @@ from typing import Any, Optional, Union
import torch
import torch.utils._pytree as pytree
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.select_algorithm import create_inputs_key
from torch._inductor.utils import clear_on_fresh_inductor_cache
@@ -931,14 +930,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
            log.debug("Using cached ops for %s", self.cache_key)
            return self.filtered_ops_cache[self.cache_key]

        maybe_ops = maybe_fetch_ops()
        if maybe_ops is None:
            log.debug("Cannot fetch ops from cache, generating ops from scratch")
            full_ops = cutlass_utils.gen_ops()
            ops = pytree.tree_flatten(full_ops)[0]
        else:
            log.debug("Using cached ops from cache")
            ops = maybe_ops

        res: dict[str, cutlass_gemm_op.GemmOperation] = {}
        start_time = time.time()
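
Aside (not part of the diff): the ops-generation path above flattens whatever nested structure cutlass_utils.gen_ops() returns into a flat list of ops. A minimal sketch of that pytree call, with placeholder strings standing in for real GemmOperation objects:

import torch.utils._pytree as pytree

full_ops = {"simt": [["op_a", "op_b"]], "tensor_op": ["op_c"]}
ops, _spec = pytree.tree_flatten(full_ops)
print(ops)  # ['op_a', 'op_b', 'op_c']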

torch/_inductor/codegen/cuda/serialization.py

@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import enum
import functools
import json
from enum import Enum
from typing import Optional
@@ -459,7 +458,6 @@ class CUTLASSOperationSerializer:
        return enum_class[json_dict["name"]]


@functools.lru_cache(1)
def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]:
    if not try_import_cutlass():
        return None
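
Aside (not part of the diff): functools.lru_cache(1) here memoizes the getter, so the serializer (or the None from a failed CUTLASS import) is built once per process; the revert drops that memoization along with the functools import. A minimal sketch of the pattern, with assumed stand-in names:

import functools
from typing import Optional

@functools.lru_cache(1)
def get_singleton() -> Optional[str]:
    print("built once")  # side effect makes the single construction visible
    return "serializer"

get_singleton()  # prints "built once"
get_singleton()  # cached: returns the same value without printing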