Compare commits


2 Commits

SHA1         Message                                    Date
35b5994ed8   Update [ghstack-poisoned]                  2025-11-12 23:57:49 +00:00
4e2045e211   Update (base update) [ghstack-poisoned]    2025-11-12 23:57:49 +00:00
11 changed files with 40 additions and 42 deletions

View File

@@ -12,22 +12,20 @@ constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB
AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() {
static AcceleratorAllocatorConfig instance;
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env) \
auto env##_name = c10::utils::get_env(#env); \
if (env##_name.has_value()) { \
instance.parseArgs(env##_name.value()); \
return true; \
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env, deprecated) \
auto env##_name = c10::utils::get_env(#env); \
if (env##_name.has_value()) { \
if (deprecated) { \
TORCH_WARN_ONCE(#env " is deprecated, use PYTORCH_ALLOC_CONF instead"); \
} \
instance.parseArgs(env##_name.value()); \
return true; \
}
static bool env_flag [[maybe_unused]] = []() {
// Parse allocator configuration from environment variables.
// The first two entries are kept for backward compatibility with legacy
// CUDA and HIP environment variable names. The new unified variable
// (PYTORCH_ALLOC_CONF) should be used going forward.
// Note: keep the parsing order and logic stable to avoid potential
// performance regressions in internal tests.
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF, false)
// Keep this for backwards compatibility
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF, /*deprecated=*/true)
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF, /*deprecated=*/true)
return false;
}();
#undef C10_ALLOCATOR_CONFIG_PARSE_ENV
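
For reference, a minimal Python sketch (not part of this diff) of the lookup order the macro above implements: the unified variable wins, and the legacy names still work but warn once.

import os
import warnings

def read_alloc_conf():
    # The unified variable is consulted first.
    value = os.environ.get("PYTORCH_ALLOC_CONF")
    if value is not None:
        return value
    # Legacy names are still honored, with a one-time deprecation warning.
    for legacy in ("PYTORCH_CUDA_ALLOC_CONF", "PYTORCH_HIP_ALLOC_CONF"):
        value = os.environ.get(legacy)
        if value is not None:
            warnings.warn(f"{legacy} is deprecated, use PYTORCH_ALLOC_CONF instead")
            return value
    return None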

View File

@@ -120,18 +120,16 @@ class C10_CUDA_API CUDAAllocatorConfig {
static CUDAAllocatorConfig& instance() {
static CUDAAllocatorConfig* s_instance = ([]() {
auto inst = new CUDAAllocatorConfig();
auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
}
#ifdef USE_ROCM
// convenience for ROCm users, allow alternative HIP token
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
// Note: keep the parsing order and logic stable to avoid potential
// performance regressions in internal tests.
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
}
if (env.has_value()) {
inst->parseArgs(env.value());
}

View File

@@ -1566,7 +1566,7 @@ class DeviceCachingAllocator {
reserved_bytes - allocated_bytes - allocated_in_private_pools),
" is reserved by PyTorch but unallocated.",
" If reserved but unallocated memory is large try setting",
" PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid"
" PYTORCH_ALLOC_CONF=expandable_segments:True to avoid"
" fragmentation. See documentation for Memory Management "
" (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)");
}
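
A quick way to observe the gap this message refers to (assuming a CUDA-enabled build); a persistently large reserved-but-unallocated figure is the fragmentation that expandable_segments targets.

import torch

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    # Reserved-but-unallocated memory is what the OOM message above reports;
    # if it stays large, try PYTORCH_ALLOC_CONF=expandable_segments:True.
    print(f"allocated={allocated} reserved={reserved} gap={reserved - allocated}")
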
@@ -4449,16 +4449,16 @@ struct BackendStaticInitializer {
// instance based on its value. If no valid configuration is found, it falls
// back to the default Native allocator.
CUDAAllocator* parseEnvForBackend() {
auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
auto val = c10::utils::get_env("PYTORCH_ALLOC_CONF");
if (!val.has_value()) {
val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
}
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (!val.has_value()) {
val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
if (!val.has_value()) {
val = c10::utils::get_env("PYTORCH_ALLOC_CONF");
}
if (val.has_value()) {
c10::CachingAllocator::ConfigTokenizer tokenizer(val.value());
for (size_t i = 0; i < tokenizer.size(); i++) {
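
Roughly, the tokenizer walks comma-separated option:value pairs; a simplified Python equivalent (illustrative only, the real ConfigTokenizer is more involved) of how the backend option is picked out:

def parse_alloc_conf(conf):
    # Split "backend:native,max_split_size_mb:128" into option -> value pairs;
    # the backend entry decides which allocator implementation the static
    # initializer above instantiates.
    options = {}
    for entry in conf.split(","):
        if not entry:
            continue
        key, _, value = entry.partition(":")
        options[key.strip()] = value.strip()
    return options

print(parse_alloc_conf("backend:native,max_split_size_mb:128"))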

View File

@@ -13,8 +13,8 @@ For more information on CUDA runtime environment variables, see `CUDA Environmen
- Description
* - ``PYTORCH_NO_CUDA_MEMORY_CACHING``
- If set to ``1``, disables caching of memory allocations in CUDA. This can be useful for debugging.
* - ``PYTORCH_CUDA_ALLOC_CONF``
- For a more in depth explanation of this environment variable, see :ref:`cuda-memory-management`.
* - ``PYTORCH_ALLOC_CONF``
- For a more in depth explanation of this environment variable, see :ref:`cuda-memory-management`. ``PYTORCH_CUDA_ALLOC_CONF`` is deprecated and is provided only for backward compatibility.
* - ``PYTORCH_NVML_BASED_CUDA_CHECK``
- If set to ``1``, before importing PyTorch modules that check if CUDA is available, PyTorch will use NVML to check if the CUDA driver is functional instead of using the CUDA runtime. This can be helpful if forked processes fail with a CUDA initialization error.
* - ``TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT``

View File

@@ -1,6 +1,6 @@
.. meta::
:description: A guide to torch.cuda, a PyTorch module to run CUDA operations
:keywords: memory management, PYTORCH_CUDA_ALLOC_CONF, optimize PyTorch, CUDA
:keywords: memory management, PYTORCH_ALLOC_CONF, optimize PyTorch, CUDA
.. _cuda-semantics:
@@ -488,7 +488,7 @@ underlying allocation patterns produced by your code.
.. _cuda-memory-envvars:
Optimizing memory usage with ``PYTORCH_CUDA_ALLOC_CONF``
Optimizing memory usage with ``PYTORCH_ALLOC_CONF``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Use of a caching allocator can interfere with memory checking tools such as
@@ -496,8 +496,9 @@ Use of a caching allocator can interfere with memory checking tools such as
``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
The behavior of the caching allocator can be controlled via the environment variable
``PYTORCH_CUDA_ALLOC_CONF``.
The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...``
``PYTORCH_ALLOC_CONF``. ``PYTORCH_CUDA_ALLOC_CONF`` is deprecated and is provided only
for backward compatibility.
The format is ``PYTORCH_ALLOC_CONF=<option>:<value>,<option2>:<value2>...``
Available options:
* ``backend`` allows selecting the underlying allocator implementation.
@@ -699,7 +700,7 @@ Mixing different CUDA system allocators in the same program
-----------------------------------------------------------
Depending on your use case, :meth:`~torch.cuda.change_current_allocator` may not be what you
want to use, since it swaps the CUDA allocator for the entire program (similar to
``PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync``). For instance, if the swapped allocator doesn't
``PYTORCH_ALLOC_CONF=backend:cudaMallocAsync``). For instance, if the swapped allocator doesn't
have caching mechanism, you will lose all the benefits of PyTorch's CUDACachingAllocator. Instead,
you can selectively mark a region of PyTorch code to use a custom allocator using
:class:`torch.cuda.MemPool`. This will let you use multiple CUDA system allocators in the same
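
As a usage note for the documentation above, the variable has to be set before PyTorch initializes the CUDA allocator (typically before the import); the option names below are illustrative values from the documented list.

import os

os.environ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:128,garbage_collection_threshold:0.8"

import torch  # noqa: E402  (imported after setting the variable on purpose)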

View File

@@ -4629,7 +4629,7 @@ print(torch.cuda.get_allocator_backend())
def test_allocator_memory_fraction_setting(self):
def make_env(fraction):
env = os.environ.copy()
var = "PYTORCH_CUDA_ALLOC_CONF"
var = "PYTORCH_ALLOC_CONF"
key = "per_process_memory_fraction"
value = [
x
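
In the spirit of the make_env helper above, a stand-alone sketch (hypothetical, not the test itself) of launching a child process with the fraction passed through the unified variable:

import os
import subprocess
import sys

env = os.environ.copy()
# per_process_memory_fraction is the option the test exercises.
env["PYTORCH_ALLOC_CONF"] = "per_process_memory_fraction:0.5"
subprocess.run(
    [sys.executable, "-c", "import torch; print(torch.cuda.is_available())"],
    env=env,
    check=True,
)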

View File

@@ -909,7 +909,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
}
py::dict allocator_settings;
py::str last_allocator_settings_s = "PYTORCH_CUDA_ALLOC_CONF";
py::str last_allocator_settings_s = "PYTORCH_ALLOC_CONF";
py::str max_split_size_s = "max_split_size";
py::str garbage_collection_threshold_s = "garbage_collection_threshold";
py::str expandable_segments_s = "expandable_segments";
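
These strings become keys in the memory snapshot, so callers of the private snapshot API now see the settings reported under the new name; for example (key names assumed from the code above):

import torch

if torch.cuda.is_available():
    snapshot = torch.cuda.memory._snapshot()
    # Allocator settings, including the last parsed config string, are grouped
    # under "allocator_settings" in the snapshot dictionary.
    print(snapshot.get("allocator_settings", {}))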

View File

@@ -453,7 +453,7 @@ std::string _memory_snapshot_pickled() {
}
auto allocator_settings = new_dict();
IValue last_allocator_settings_s = "PYTORCH_CUDA_ALLOC_CONF";
IValue last_allocator_settings_s = "PYTORCH_ALLOC_CONF";
IValue max_split_size_s = "max_split_size";
IValue garbage_collection_threshold_s = "garbage_collection_threshold";
IValue expandable_segments_s = "expandable_segments";

View File

@@ -1306,7 +1306,7 @@ def _set_allocator_settings(env: str):
def get_allocator_backend() -> str:
r"""Return a string describing the active allocator backend as set by
``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
``PYTORCH_ALLOC_CONF``. Currently available backends are
``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
(CUDA's built-in asynchronous allocator).
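
A minimal check of the documented behavior:

import torch

# Prints "native" by default, or "cudaMallocAsync" when
# PYTORCH_ALLOC_CONF=backend:cudaMallocAsync was set before startup.
print(torch.cuda.get_allocator_backend())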

View File

@@ -1601,7 +1601,7 @@ def allocator_option_enabled_fn(allocator_config, _, option):
EXPANDABLE_SEGMENTS: bool = TestEnvironment.def_flag(
"EXPANDABLE_SEGMENTS",
env_var="PYTORCH_CUDA_ALLOC_CONF",
env_var="PYTORCH_ALLOC_CONF",
enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'),
)
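
For illustration, a small stand-in (hypothetical, not the real allocator_option_enabled_fn) for the kind of check the flag relies on: whether the option appears with a truthy value in the config variable.

import os

def option_enabled(option, env_var="PYTORCH_ALLOC_CONF"):
    conf = os.environ.get(env_var, "")
    for entry in conf.split(","):
        key, _, value = entry.partition(":")
        if key.strip() == option:
            return value.strip().lower() in ("true", "1")
    return False

print(option_enabled("expandable_segments"))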

View File

@@ -654,11 +654,12 @@ def get_pip_packages(run_lambda, patterns=None):
return pip_version, filtered_out
def get_cachingallocator_config():
ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if not ca_config:
ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "")
return ca_config
def get_cachingallocator_config() -> str:
"""Return the caching allocator configuration from environment variables."""
for var in ("PYTORCH_ALLOC_CONF", "PYTORCH_CUDA_ALLOC_CONF", "PYTORCH_HIP_ALLOC_CONF"):
if config := os.environ.get(var):
return config
return ""
def get_cuda_module_loading_config():
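
If the helper stays importable (an assumption; collect_env is usually run as a script), it can be exercised directly:

from torch.utils.collect_env import get_cachingallocator_config

# Returns the first non-empty value among PYTORCH_ALLOC_CONF,
# PYTORCH_CUDA_ALLOC_CONF and PYTORCH_HIP_ALLOC_CONF, or "" if none is set.
print(get_cachingallocator_config() or "<unset>")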