Revert "[ROCm] enable HIPMallocAsyncAllocator (#149145)"

This reverts commit ee1a2b7810126258ce64d1e22b59fae81a3f7bcb.

Reverted https://github.com/pytorch/pytorch/pull/149145 on behalf of https://github.com/izaitsevfb due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/149145#issuecomment-2738115728))
PyTorch MergeBot
2025-03-19 21:12:13 +00:00
parent 37bb7f79c6
commit e1d143cb7b
6 changed files with 8 additions and 80 deletions

c10/cuda/CUDAAllocatorConfig.cpp

@@ -220,40 +220,6 @@ size_t CUDAAllocatorConfig::parseAllocatorConfig(
     const std::vector<std::string>& config,
     size_t i,
     bool& used_cudaMallocAsync) {
-  // For ease of maintenance and understanding, the CUDA and ROCm
-  // implementations of this function are separated. This avoids having many
-  // #ifdef's throughout.
-#ifdef USE_ROCM
-  // Ease burden on ROCm users by allowing either cuda or hip tokens.
-  // cuda token is broken up to prevent hipify matching it.
-#define PYTORCH_TOKEN1 \
-  "cud"               \
-  "aMallocAsync"
-#define PYTORCH_TOKEN2 "hipMallocAsync"
-  consumeToken(config, ++i, ':');
-  if (++i < config.size()) {
-    TORCH_CHECK(
-        ((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
-         (config[i] == PYTORCH_TOKEN2)),
-        "Unknown allocator backend, "
-        "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
-    used_cudaMallocAsync =
-        (config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
-    TORCH_INTERNAL_ASSERT(
-        config[i] == get()->name() ||
-            (config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
-        "Allocator backend parsed at runtime != "
-        "allocator backend parsed at load time, ",
-        config[i],
-        " != ",
-        get()->name());
-  } else {
-    TORCH_CHECK(false, "Error parsing backend value", "");
-  }
-  return i;
-#undef PYTORCH_TOKEN1
-#undef PYTORCH_TOKEN2
-#else // USE_ROCM
   consumeToken(config, ++i, ':');
   if (++i < config.size()) {
     TORCH_CHECK(
@@ -261,6 +227,8 @@ size_t CUDAAllocatorConfig::parseAllocatorConfig(
         "Unknown allocator backend, "
        "options are native and cudaMallocAsync");
     used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
+#ifndef USE_ROCM
+    // HIP supports hipMallocAsync and does not need to check versions
     if (used_cudaMallocAsync) {
 #if CUDA_VERSION >= 11040
       int version = 0;
@@ -278,6 +246,7 @@ size_t CUDAAllocatorConfig::parseAllocatorConfig(
             CUDA_VERSION);
 #endif
     }
+#endif
     TORCH_INTERNAL_ASSERT(
         config[i] == get()->name(),
         "Allocator backend parsed at runtime != "
@@ -286,7 +255,6 @@ size_t CUDAAllocatorConfig::parseAllocatorConfig(
     TORCH_CHECK(false, "Error parsing backend value", "");
   }
   return i;
-#endif // USE_ROCM
 }
 
 void CUDAAllocatorConfig::parseArgs(const char* env) {

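After the revert, every build goes through the CUDA-only parse path above: users select the async allocator with, e.g., `PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync`. Below is a minimal, self-contained sketch of that validation step (not PyTorch code; `parseBackendToken` and its pre-split token layout are illustrative assumptions, while the accepted values and error messages mirror the diff):

```cpp
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative stand-in for the parse above: given pre-split tokens such as
// {"backend", ":", "cudaMallocAsync"}, validate the value and report whether
// the async backend was requested.
size_t parseBackendToken(const std::vector<std::string>& config, size_t i,
                         bool& used_cudaMallocAsync) {
  if (i + 2 >= config.size() || config[i + 1] != ":") {
    throw std::runtime_error("Error parsing backend value");
  }
  const std::string& value = config[i + 2];
  if (value != "native" && value != "cudaMallocAsync") {
    throw std::runtime_error(
        "Unknown allocator backend, options are native and cudaMallocAsync");
  }
  used_cudaMallocAsync = (value == "cudaMallocAsync");
  return i + 2; // index of the last token consumed
}

int main() {
  bool used_cudaMallocAsync = false;
  const std::vector<std::string> config{"backend", ":", "cudaMallocAsync"};
  parseBackendToken(config, 0, used_cudaMallocAsync);
  std::cout << "async backend requested: " << used_cudaMallocAsync << '\n';
}
```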
c10/cuda/CUDAAllocatorConfig.h

@@ -81,12 +81,6 @@ class C10_CUDA_API CUDAAllocatorConfig {
     static CUDAAllocatorConfig* s_instance = ([]() {
       auto inst = new CUDAAllocatorConfig();
       const char* env = getenv("PYTORCH_CUDA_ALLOC_CONF");
-#ifdef USE_ROCM
-      // convenience for ROCm users, allow alternative HIP token
-      if (!env) {
-        env = getenv("PYTORCH_HIP_ALLOC_CONF");
-      }
-#endif
       inst->parseArgs(env);
       return inst;
     })();

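The six removed lines implemented the ROCm convenience fallback: read `PYTORCH_CUDA_ALLOC_CONF` first and, only when it is unset on a ROCm build, try `PYTORCH_HIP_ALLOC_CONF`. A minimal sketch of the pattern, assuming only the two variable names from the diff (the helper function itself is illustrative):

```cpp
#include <cstdlib>
#include <iostream>

// Sketch of the fallback removed by this revert: prefer the CUDA-named
// variable; on ROCm builds, fall back to the HIP-named one.
const char* allocConfEnv() {
  const char* env = std::getenv("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
  if (env == nullptr) {
    env = std::getenv("PYTORCH_HIP_ALLOC_CONF");
  }
#endif
  return env; // nullptr when no variable is set
}

int main() {
  const char* env = allocConfEnv();
  std::cout << (env ? env : "<unset>") << '\n';
}
```

After the revert, only `PYTORCH_CUDA_ALLOC_CONF` is consulted, on both CUDA and ROCm builds.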
c10/cuda/CUDACachingAllocator.cpp

@@ -3955,13 +3955,7 @@ struct BackendStaticInitializer {
   // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
   // works, maybe we should move all of CUDAAllocatorConfig here?
   CUDAAllocator* parseEnvForBackend() {
-    auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
-#ifdef USE_ROCM
-    // convenience for ROCm users to allow either CUDA or HIP env var
-    if (!val.has_value()) {
-      val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
-    }
-#endif
+    const auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
     if (val.has_value()) {
       const std::string& config = val.value();
@@ -3977,15 +3971,7 @@ struct BackendStaticInitializer {
         std::vector<std::string> kv(it2, end2);
         if (kv.size() >= 2) {
           if (kv[0] == "backend") {
-#ifdef USE_ROCM
-            // convenience for ROCm users to allow either CUDA or HIP env var
-            if (kv[1] ==
-                    "cud"
-                    "aMallocAsync" ||
-                kv[1] == "hipMallocAsync")
-#else
             if (kv[1] == "cudaMallocAsync")
-#endif
               return CudaMallocAsync::allocator();
             if (kv[1] == "native")
               return &Native::allocator;

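A detail worth calling out in the removed block: the token is written as the adjacent literals `"cud" "aMallocAsync"`. Standard C++ concatenates adjacent string literals at translation time, so the runtime value is exactly `cudaMallocAsync`, while the source never contains a contiguous `cuda` substring for hipify's textual pass to rewrite. A tiny demonstration (illustrative, not PyTorch code):

```cpp
#include <cassert>
#include <string>

int main() {
  // Adjacent literals concatenate at compile time; hipify's regex never
  // sees "cuda" as one token in the source text.
  const std::string token =
      "cud"
      "aMallocAsync";
  assert(token == "cudaMallocAsync");
  return 0;
}
```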
c10/cuda/CUDAMallocAsyncAllocator.cpp

@@ -14,7 +14,7 @@ namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync {
 using namespace c10::CachingAllocator;
 using namespace c10::CachingDeviceAllocator;
 
-#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
+#if CUDA_VERSION >= 11040
 
 // CUDA device allocator that uses cudaMallocAsync to implement
 // the same interface as CUDACachingAllocator.cpp.
@@ -504,9 +504,9 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
         CUDAGuard g(static_cast<c10::DeviceIndex>(dev));
 
         cudaMemPool_t mempool = nullptr;
-        C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev));
-        C10_CUDA_CHECK(cudaDeviceSynchronize());
-        C10_CUDA_CHECK(cudaMemPoolTrimTo(mempool, 0));
+        cudaDeviceGetDefaultMemPool(&mempool, dev);
+        cudaDeviceSynchronize();
+        cudaMemPoolTrimTo(mempool, 0);
       }
     }
   }

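The second hunk restores the unchecked mempool calls that predate #149145; the reverted PR had wrapped them in `C10_CUDA_CHECK` so that failures surface instead of being silently dropped. For readers unfamiliar with that pattern, here is a minimal stand-in (the `CHECK_CUDA` macro is illustrative, not PyTorch's; building it requires a CUDA 11.2+ toolkit, which introduced the mempool APIs):

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime_api.h>

// Illustrative equivalent of the check-every-call pattern: abort with a
// readable message when a runtime call does not return cudaSuccess.
#define CHECK_CUDA(expr)                                      \
  do {                                                        \
    cudaError_t err_ = (expr);                                \
    if (err_ != cudaSuccess) {                                \
      std::fprintf(stderr, "%s failed: %s\n", #expr,          \
                   cudaGetErrorString(err_));                 \
      std::exit(EXIT_FAILURE);                                \
    }                                                         \
  } while (0)

int main() {
  cudaMemPool_t mempool = nullptr;
  CHECK_CUDA(cudaDeviceGetDefaultMemPool(&mempool, /*device=*/0));
  CHECK_CUDA(cudaDeviceSynchronize());
  // Trim the default pool down to zero reserved bytes, as the code above does.
  CHECK_CUDA(cudaMemPoolTrimTo(mempool, 0));
  return 0;
}
```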
torch/utils/collect_env.py

@@ -456,8 +456,6 @@ def get_pip_packages(run_lambda, patterns=None):
 
 def get_cachingallocator_config():
     ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
-    if not ca_config:
-        ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '')
     return ca_config

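With the fallback gone, `collect_env` reports only `PYTORCH_CUDA_ALLOC_CONF`; a `PYTORCH_HIP_ALLOC_CONF` setting no longer shows up in bug-report output. A rough C++ analogue of the post-revert function, for comparison with the other env-var readers in this commit (illustrative; the real code is the Python above):

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

// Mirror of get_cachingallocator_config() after this revert: return the
// allocator config string, or an empty string when the variable is unset.
std::string cachingAllocatorConfig() {
  const char* v = std::getenv("PYTORCH_CUDA_ALLOC_CONF");
  return v != nullptr ? std::string(v) : std::string();
}

int main() {
  std::cout << "PYTORCH_CUDA_ALLOC_CONF=" << cachingAllocatorConfig() << '\n';
}
```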
torch/utils/hipify/cuda_to_hip_mappings.py

@@ -4051,23 +4051,6 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
             ("hipMemset3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED),
         ),
         ("cudaMemGetInfo", ("hipMemGetInfo", CONV_MEM, API_RUNTIME)),
-        ("cudaDeviceGetDefaultMemPool", ("hipDeviceGetDefaultMemPool", CONV_MEM, API_RUNTIME)),
-        ("cudaMemAccessDesc", ("hipMemAccessDesc", CONV_MEM, API_RUNTIME)),
-        ("cudaMemAccessFlagsProtReadWrite", ("hipMemAccessFlagsProtReadWrite", CONV_MEM, API_RUNTIME)),
-        ("cudaMemLocationTypeDevice", ("hipMemLocationTypeDevice", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolAttrReleaseThreshold", ("hipMemPoolAttrReleaseThreshold", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolAttrReservedMemCurrent", ("hipMemPoolAttrReservedMemCurrent", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolAttrReservedMemHigh", ("hipMemPoolAttrReservedMemHigh", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolAttrUsedMemCurrent", ("hipMemPoolAttrUsedMemCurrent", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolAttrUsedMemHigh", ("hipMemPoolAttrUsedMemHigh", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolReuseAllowInternalDependencies", ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolReuseAllowOpportunistic", ("hipMemPoolReuseAllowOpportunistic", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolReuseFollowEventDependencies", ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_RUNTIME)),
-        ("cudaMemPool_t", ("hipMemPool_t", CONV_MEM, API_RUNTIME)),
         (
             "cudaArrayGetInfo",
             ("hipArrayGetInfo", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED),
@@ -8608,7 +8591,6 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
 CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
     [
         ("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)),
-        ("PYTORCH_CUDA_ALLOC_CONF", ("PYTORCH_CUDA_ALLOC_CONF", API_CAFFE2)),
         ("cuda_stream", ("hip_stream", API_CAFFE2)),
         # if the header is a native hip folder (under hip directory),
         # there is no need to add a hip path to it; the trie in hipify script
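All of the removed entries belong to hipify's identifier map, which drives a plain source-to-source rename when PyTorch sources are translated for ROCm; dropping them means these mempool identifiers are no longer rewritten by this table. A toy illustration of the mechanism (the lookup driver is illustrative; only the two mappings shown are taken from the diff):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Two CUDA->HIP renames removed above; hipify applies such mappings
  // textually across the source tree.
  const std::unordered_map<std::string, std::string> cuda_to_hip{
      {"cudaMemPoolTrimTo", "hipMemPoolTrimTo"},
      {"cudaMemPool_t", "hipMemPool_t"},
  };
  for (const auto& [cuda, hip] : cuda_to_hip) {
    std::cout << cuda << " -> " << hip << '\n';
  }
}
```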