mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pre_compile] Add check for cuda and hardware version (#162438)
if we detect compiled model is using cuda in meaningful way, we should store information about cuda + hardware Example: `SystemInfo(python_version='3.12.9', torch_version='2.9.0a0+gite02b0e6', cuda_version='12.6', triton_version=(3, 4), gpu_name='NVIDIA PG509-210')` Pull Request resolved: https://github.com/pytorch/pytorch/pull/162438 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
ae97eb86f7
commit
ccb450b190
@ -19,7 +19,6 @@ import inspect
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import types
|
||||
@ -30,7 +29,12 @@ from typing_extensions import Never
|
||||
import torch
|
||||
import torch._inductor.package
|
||||
from torch._dynamo.exc import PackageError
|
||||
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
|
||||
from torch._dynamo.graph_utils import _graph_uses_non_cpu
|
||||
from torch._dynamo.precompile_context import (
|
||||
PrecompileCacheArtifact,
|
||||
PrecompileContext,
|
||||
SystemInfo,
|
||||
)
|
||||
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
||||
from torch.compiler._cache import CacheArtifactFactory
|
||||
|
||||
@ -275,13 +279,18 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
|
||||
class _DynamoCacheEntry:
|
||||
codes: list[_DynamoCodeCacheEntry]
|
||||
inlined_sources: set[InlinedSource]
|
||||
python_version: str = platform.python_version()
|
||||
torch_version: str = torch.__version__
|
||||
use_cuda: bool
|
||||
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
|
||||
|
||||
@property
|
||||
def backend_ids(self) -> set[_BackendId]:
|
||||
return {backend_id for code in self.codes for backend_id in code.backend_ids}
|
||||
|
||||
def check_versions(self) -> None:
|
||||
"""Check if the current system is compatible with the system used to create this cache entry."""
|
||||
current_system_info = SystemInfo.current()
|
||||
self.system_info.check_compatibility(current_system_info, self.use_cuda)
|
||||
|
||||
|
||||
@CacheArtifactFactory.register
|
||||
class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
|
||||
@ -369,6 +378,8 @@ class CompilePackage:
|
||||
|
||||
self._current_entry: Optional[_DynamoCodeCacheEntry] = None
|
||||
self._installed_globals: dict[types.ModuleType, list[str]] = {}
|
||||
# whether cuda is used
|
||||
self._use_cuda = False
|
||||
|
||||
# For debugging/testing purpose only.
|
||||
self._cached_backends: dict[_BackendId, Any] = {}
|
||||
@ -397,14 +408,7 @@ class CompilePackage:
|
||||
assert self._innermost_fn is not None
|
||||
if dynamo is not None:
|
||||
assert isinstance(dynamo, _DynamoCacheEntry)
|
||||
if dynamo.python_version != platform.python_version():
|
||||
raise RuntimeError(
|
||||
f"Compile package was created with a different Python version: {dynamo.python_version}"
|
||||
)
|
||||
if dynamo.torch_version != torch.__version__:
|
||||
raise RuntimeError(
|
||||
f"Compile package was created with a different PyTorch version: {dynamo.torch_version}"
|
||||
)
|
||||
dynamo.check_versions()
|
||||
if not ignore_inlined_sources:
|
||||
for code in dynamo.inlined_sources:
|
||||
m = importlib.import_module(code.module)
|
||||
@ -534,6 +538,9 @@ class CompilePackage:
|
||||
)
|
||||
)
|
||||
|
||||
def update_use_cuda(self, graph: Optional[torch.fx.Graph]) -> None:
|
||||
self._use_cuda = _graph_uses_non_cpu(graph)
|
||||
|
||||
def bypass_current_entry(self) -> None:
|
||||
assert self._current_entry is not None
|
||||
self._current_entry.bypassed = True
|
||||
@ -670,7 +677,9 @@ class CompilePackage:
|
||||
def cache_entry(self) -> _DynamoCacheEntry:
|
||||
self.validate()
|
||||
return _DynamoCacheEntry(
|
||||
codes=list(self._codes.values()), inlined_sources=self._inlined_sources
|
||||
codes=list(self._codes.values()),
|
||||
inlined_sources=self._inlined_sources,
|
||||
use_cuda=self._use_cuda,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user