Cache the get_device_module result (#149207)

Summary: As title.

Test Plan: OSS CIs.

Reviewed By: chaos5958

Differential Revision: D71084180

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149207
Approved by: https://github.com/jansel
This commit is contained in:
Jun Luo
2025-03-19 03:20:38 +00:00
committed by PyTorch MergeBot
parent 01a57981aa
commit 14dc6e732d

View File

@ -12,6 +12,7 @@ on an NVIDIA GPU with compute capability >= 3.0.
import builtins
import ctypes
import functools
import glob
import importlib
import inspect
@ -2690,6 +2691,7 @@ else:
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@functools.cache
def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
"""
Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).