mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Cache the get_device_module result (#149207)
Summary: As title. Test Plan: OSS CIs. Reviewed By: chaos5958 Differential Revision: D71084180 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149207 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
01a57981aa
commit
14dc6e732d
@ -12,6 +12,7 @@ on an NVIDIA GPU with compute capability >= 3.0.
|
||||
|
||||
import builtins
|
||||
import ctypes
|
||||
import functools
|
||||
import glob
|
||||
import importlib
|
||||
import inspect
|
||||
@ -2690,6 +2691,7 @@ else:
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
|
||||
"""
|
||||
Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
|
||||
|
Reference in New Issue
Block a user