diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f37e492c861f..f23b35047fcc 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -72,6 +73,27 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +TORCH_API inline void emptyCache() { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->emptyCache(); +} + +TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); +} + +TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); +} + +TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetPeakStats(device_index); +} + } // namespace at::accelerator namespace at { diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index c6f2fb108040..ce593a9acf51 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,3 +25,26 @@ synchronize device_index ``` + +```{eval-rst} +.. automodule:: torch.accelerator.memory +``` +```{eval-rst} +.. currentmodule:: torch.accelerator.memory +``` + +## Memory management +```{eval-rst} +.. 
autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + memory_allocated + memory_reserved + memory_stats + reset_accumulated_memory_stats + reset_peak_memory_stats +``` diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9e03c7dba830..fb7e9c5ce56e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2435,6 +2435,11 @@ def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... +def _accelerator_isAllocatorInitialized() -> _bool: ... +def _accelerator_emptyCache() -> None: ... +def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... +def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... +def _accelerator_resetPeakStats(device_index: _int) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e9e48f1cf306..4d1a78df1f74 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,6 +8,16 @@ from typing_extensions import deprecated import torch from ._utils import _device_t, _get_device_index +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + memory_allocated, + memory_reserved, + memory_stats, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) __all__ = [ @@ -15,9 +25,17 @@ __all__ = [ "current_device_idx", # deprecated "current_device_index", "current_stream", + "empty_cache", "device_count", "device_index", "is_available", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", "set_device_idx", # deprecated "set_device_index", 
"set_stream", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py new file mode 100644 index 000000000000..d34a11a3a02e --- /dev/null +++ b/torch/accelerator/memory.py @@ -0,0 +1,201 @@ +from collections import OrderedDict +from typing import Any + +import torch + +from ._utils import _device_t, _get_device_index + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other application. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return + torch._C._accelerator_emptyCache() + + +def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: + r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from device memory allocation. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. 
+ - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of June 2025, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of June 2025, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed device memory allocation calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of device memory allocation calls. + - ``"num_device_free"``: number of device memory free calls. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. 
+ """ + if not torch._C._accelerator_isAllocatorInitialized(): + return OrderedDict() + device_index = _get_device_index(device_index, optional=True) + stats = torch._C._accelerator_getDeviceStats(device_index) + flat_stats = [] + + def flatten(prefix: str, value: Any) -> None: + if isinstance(value, dict): + for k, v in value.items(): + nested_prefix = f"{prefix}.{k}" if prefix else k + flatten(nested_prefix, v) + else: + flat_stats.append((prefix, value)) + + flatten("", stats) + flat_stats.sort() + return OrderedDict(flat_stats) + + +def memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory occupied by tensors + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory occupied by tensors + in bytes for a given device index. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. 
+ """ + return memory_stats(device_index).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory managed by the caching allocator + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory managed by the caching allocator + in bytes for a given device index. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.peak", 0) + + +def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. 
note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator <accelerators>` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetAccumulatedStats(device_index) + + +def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "peak" stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator <accelerators>` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetPeakStats(device_index) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 3a97c0794684..59cb8047467c 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -77,6 +77,70 @@ void initModule(PyObject* module) { m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); + + m.def("_accelerator_isAllocatorInitialized", []() { + const auto device_type = at::accelerator::getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->initialized(); + }); + + m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); + + m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { + using c10::CachingAllocator::Stat; + using c10::CachingAllocator::StatArray; + using c10::CachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + + const auto stats = at::accelerator::getDeviceStats(device_index); + const auto stat_to_dict = 
[](const Stat& stat) -> py::dict { + py::dict dict; + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict { + const std::array<const char*, static_cast<size_t>(StatType::NUM_TYPES)> + kStatTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(kStatTypeNames.size())) { + dict[kStatTypeNames[i]] = stat_to_dict(stats[i]); + } + return dict; + }; + + py::dict result; + result["num_alloc_retries"] = stats.num_alloc_retries; + result["num_ooms"] = stats.num_ooms; + result["max_split_size"] = stats.max_split_size; + result["num_sync_all_streams"] = stats.num_sync_all_streams; + result["num_device_alloc"] = stats.num_device_alloc; + result["num_device_free"] = stats.num_device_free; + result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes); + result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes); + result["active_bytes"] = stat_array_to_dict(stats.active_bytes); + result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes); + result["allocation"] = stat_array_to_dict(stats.allocation); + result["segment"] = stat_array_to_dict(stats.segment); + result["active"] = stat_array_to_dict(stats.active); + result["inactive_split"] = stat_array_to_dict(stats.inactive_split); + result["inactive_split_bytes"] = + stat_array_to_dict(stats.inactive_split_bytes); + result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations); + result["oversize_segments"] = stat_to_dict(stats.oversize_segments); + return result; + }); + + m.def( + "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetAccumulatedStats(device_index); + }); + + m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetPeakStats(device_index); + }); } } // namespace torch::accelerator diff --git 
a/torch/cuda/memory.py b/torch/cuda/memory.py index 63e59096162f..1bd6f9edc031 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,9 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool - (as of October 2019, for size >= 1MB allocations). + (as of June 2025, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool - (as of October 2019, for size < 1MB allocations). + (as of June 2025, for size < 1MB allocations). Metric type: