Revert "Support for expandable segments with cuda graph trees (#128068)"

This reverts commit fdc83610f272610ce50d1a6f5b6354f2df1baabb.

Reverted https://github.com/pytorch/pytorch/pull/128068 on behalf of https://github.com/janeyx99 due to Reverting for breaking ROCm tests on trunk, I think the tests need to be qualified with @onlyCUDA ([comment](https://github.com/pytorch/pytorch/pull/128068#issuecomment-2223672381))
PyTorch MergeBot
2024-07-11 18:58:13 +00:00
parent 1cae60a87e
commit 578388bed8
7 changed files with 16 additions and 196 deletions
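The revert message above suggests qualifying the affected tests with @onlyCUDA. A minimal sketch of what that could look like, using the existing onlyCUDA decorator and instantiate_device_type_tests helper from torch.testing._internal.common_device_type; the test class and test body here are hypothetical, not part of this commit:

# Hypothetical example of restricting a device-type test to the CUDA device type.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class ExpandableSegmentsExampleTest(TestCase):  # hypothetical test class
    @onlyCUDA  # generate and run this test only for the CUDA device type
    def test_allocator_setting(self, device):
        torch.cuda.memory._set_allocator_settings("expandable_segments:True")
        x = torch.ones(4, device=device)
        self.assertEqual(x.sum().item(), 4.0)


instantiate_device_type_tests(ExpandableSegmentsExampleTest, globals())

if __name__ == "__main__":
    run_tests()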

View File

@@ -1543,16 +1543,6 @@ class DeviceCachingAllocator {
// allocate all blocks in the segment
for (size_t i = 0; i < segment_len; ++i) {
// The last block in every expandable segment is the remaining amount of
// available unmapped virtual address space. We shouldn't change it but
// instead check it is correctly formed then skip over allocating it.
if (i == segment_len - 1 && curr_block->expandable_segment_) {
TORCH_CHECK(curr_block->next == nullptr);
TORCH_CHECK(!curr_block->mapped);
TORCH_CHECK(curr_block->allocated == false);
continue;
}
auto& block_state = segment.blocks.at(i);
AllocParams params(
block_state.device,
@@ -1567,11 +1557,8 @@ class DeviceCachingAllocator {
// splitting a block depends on `max_split_size`, which may have changed
// between when the checkpoint was taken and now, so we make sure to recreate
// the behavior from the checkpoint. Keep splitting as long as there is
// space left in the block, because the block already has the size recorded
// in the segment, so any leftover space belongs to the next
// block.
bool split = curr_block->size - block_state.size > 0;
// the behavior from the checkpoint.
bool split = (i + 1) < segment.blocks.size();
// curr_block will become next pointer if it is split, so reassign with
// the returned value
@@ -1594,13 +1581,6 @@ class DeviceCachingAllocator {
curr_block = last_block;
for (size_t i = 0; i < segment_len; ++i, curr_block = curr_block->next) {
if (i == segment_len - 1 && curr_block->expandable_segment_) {
TORCH_CHECK(curr_block->next == nullptr);
TORCH_CHECK(!curr_block->mapped);
TORCH_CHECK(curr_block->allocated == false);
continue;
}
auto& block_state = segment.blocks.at(i);
TORCH_INTERNAL_ASSERT(curr_block != nullptr);
@@ -2450,15 +2430,15 @@ class DeviceCachingAllocator {
total_allocated_memory + size > allowed_memory_maximum) {
p.err = cudaErrorMemoryAllocation;
return false;
} else if (CUDAAllocatorConfig::expandable_segments()) {
} else if (
CUDAAllocatorConfig::expandable_segments() &&
// our checkpointing logic for private pools doesn't support
// the expandable_segments_ structure yet
!p.pool->owner_PrivatePool) {
p.block = try_allocate_expandable_block(
p.device(), p.stream(), p.pool, p.size(), ctx);
if (p.block) {
p.err = cudaSuccess;
if (p.pool->owner_PrivatePool) {
// The block is for a CUDA graph's PrivatePool.
p.pool->owner_PrivatePool->cudaMalloc_count++;
}
} else {
p.err = cudaErrorMemoryAllocation;
}
@@ -2701,13 +2681,6 @@ class DeviceCachingAllocator {
decrease_stat(stats.reserved_bytes[stat_type], unmapped.size);
});
if (block->pool->owner_PrivatePool) {
// The cudaFreed block belonged to a CUDA graph's PrivatePool.
TORCH_INTERNAL_ASSERT(
block->pool->owner_PrivatePool->cudaMalloc_count > 0);
block->pool->owner_PrivatePool->cudaMalloc_count--;
}
stats.num_device_free++;
record_trace(
TraceEntry::SEGMENT_UNMAP,

View File

@@ -1,32 +0,0 @@
# Owner(s): ["module: cuda"]
# re-run CUDA tests, but with the allocator using expandable segments
import os
import pathlib
import sys
import torch
from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS
from torch.testing._internal.common_utils import run_tests
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from dynamo.test_cudagraphs import TestAotCudagraphs # noqa: F401
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.import_test_stats import get_disabled_tests
# Make sure to remove REPO_ROOT after import is done
sys.path.remove(str(REPO_ROOT))
if __name__ == "__main__":
if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS:
get_disabled_tests(".")
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
run_tests()
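For reference, the expandable-segments mode that this deleted shim turned on through _set_allocator_settings can also be requested through the allocator's environment variable. A minimal sketch, assuming the variable is set before the first CUDA allocation:

# Sketch: enable expandable segments through PYTORCH_CUDA_ALLOC_CONF instead of
# the in-process setting. The variable must be set before the caching allocator
# performs its first allocation.
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported after setting the variable

if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")  # first allocation picks up the setting
    print(torch.cuda.memory_stats()["reserved_bytes.all.current"])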

View File

@@ -1,40 +0,0 @@
# Owner(s): ["module: cuda"]
# re-run CUDA tests, but with the allocator using expandable segments
import os
import pathlib
import sys
import torch
from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CUDA
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if HAS_CUDA and not TEST_WITH_ASAN:
from inductor.test_cudagraph_trees import CudaGraphTreeTests # noqa: F401
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.import_test_stats import get_disabled_tests
# Make sure to remove REPO_ROOT after import is done
sys.path.remove(str(REPO_ROOT))
if __name__ == "__main__":
if (
torch.cuda.is_available()
and not IS_JETSON
and not IS_WINDOWS
and HAS_CUDA
and not TEST_WITH_ASAN
):
get_disabled_tests(".")
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
run_tests()
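The deleted shim above re-ran the inductor cudagraph-trees tests with expandable segments enabled. As a rough illustration of the interaction being exercised, here is a minimal sketch that captures and replays a CUDA graph while the allocator is configured for expandable segments; the shapes and warm-up pattern are illustrative, not the test's contents:

# Sketch: CUDA graph capture/replay with the allocator set to use expandable
# segments. Illustrative only.
import torch

if torch.cuda.is_available():
    torch.cuda.memory._set_allocator_settings("expandable_segments:True")

    static_in = torch.randn(1024, device="cuda")
    static_out = torch.empty_like(static_in)

    # Warm up on a side stream before capture, as the CUDA graphs docs suggest.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)

    static_in.fill_(3.0)
    g.replay()  # recomputes static_out from the updated static_in
    assert torch.equal(static_out, torch.full_like(static_in, 6.0))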

View File

@@ -41,7 +41,6 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker
from torch.testing._internal.common_utils import (
EXPANDABLE_SEGMENTS,
freeze_rng_state,
gcIfJetson,
get_cycles_per_ms,
@@ -118,10 +117,6 @@ class TestCuda(TestCase):
del self.autocast_lists
super().tearDown()
@property
def expandable_segments(self):
return EXPANDABLE_SEGMENTS
def test_pinned_memory_with_cudaregister(self):
torch.cuda.memory._set_allocator_settings(
"pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
@@ -2912,21 +2907,6 @@ exit(2)
for stat, expected in zip(stats_to_check, expecteds):
stat = stat + pool_string + ".current"
current = postcapture_stats[stat] - precapture_stats[stat]
# There will only ever be one expandable segment in each of the small and large pools. The way the
# bookkeeping is done in the allocator means that we never increment the number of segments.
if self.expandable_segments and "segment" in stat:
expected = 0
# These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
# expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
# smaller than the page size
if (
self.expandable_segments
and "reserved" in stat
and (numel == cases[3][0] or numel == cases[4][0])
):
expected = 2 * kLargeBuffer
self.assertEqual(
current,
expected,
@@ -2953,27 +2933,6 @@ exit(2)
for stat, expected in zip(stats_to_check, expecteds):
stat = stat + pool_string + ".current"
current = postdel_stats[stat] - precapture_stats[stat]
# There will only ever be one expandable segment in each of the small and large pools. The way the
# bookkeeping is done in the allocator means that we never increment the number of segments.
if self.expandable_segments and "segment" in stat:
expected = 0
# These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
# expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
# smaller than the page size
if (
self.expandable_segments
and "reserved" in stat
and numel == cases[3][0]
):
expected = 2 * kLargeBuffer
if (
self.expandable_segments
and "reserved" in stat
and numel == cases[4][0]
):
expected = kLargeBuffer
self.assertEqual(
current,
expected,
@@ -4565,10 +4524,6 @@ def reconstruct_from_tensor_metadata(metadata):
@unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestBlockStateAbsorption(TestCase):
@property
def expandable_segments(self):
return EXPANDABLE_SEGMENTS
def checkCheckpointedBlock(self, before_block, after_block):
for field in ("size", "state"):
self.assertEqual(before_block[field], after_block[field])
@@ -4896,9 +4851,7 @@ class TestBlockStateAbsorption(TestCase):
graph_thread.join()
no_graph_thread.join()
self.assertEqual(
len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4
)
self.assertEqual(len(get_cudagraph_segments(pool)), 4)
del graph
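The deleted branches above adjust the expected deltas of allocator statistics taken before and after graph capture. For orientation, a small sketch of that before/after pattern using the public torch.cuda.memory_stats() keys; the tensor and the list of stats are illustrative:

# Sketch: diff allocator statistics around a region of interest, mirroring the
# precapture/postcapture comparison in the test above.
import torch

if torch.cuda.is_available():
    torch.ones(1, device="cuda")  # make sure the allocator is initialized
    stats_to_check = ["segment.all", "reserved_bytes.all", "active.all", "active_bytes.all"]

    pre = torch.cuda.memory_stats()
    x = torch.randn(1 << 20, device="cuda")  # the "region of interest"
    post = torch.cuda.memory_stats()

    for stat in stats_to_check:
        key = stat + ".current"
        print(key, post[key] - pre[key])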

View File

@@ -1,34 +1,15 @@
# Owner(s): ["module: cuda"]
# re-run CUDA tests, but with the allocator using expandable segments
import pathlib
import sys
from test_cuda import ( # noqa: F401
TestBlockStateAbsorption,
TestCuda,
TestCudaMallocAsync,
)
import os
import torch
from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS
from torch.testing._internal.common_utils import run_tests
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from tools.stats.import_test_stats import get_disabled_tests
# Make sure to remove REPO_ROOT after import is done
sys.path.remove(str(REPO_ROOT))
if __name__ == "__main__":
if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS:
get_disabled_tests(".")
from torch.testing._internal.common_cuda import IS_JETSON
if torch.cuda.is_available() and not IS_JETSON:
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
TestCuda.expandable_segments = lambda _: True
TestBlockStateAbsorption.expandable_segments = lambda _: True
run_tests()
current_dir = os.path.dirname(os.path.abspath(__file__))
filepath = os.path.join(current_dir, "test_cuda.py")
exec(compile(open(filepath).read(), filepath, mode="exec"))

View File

@@ -1690,7 +1690,7 @@ def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWr
lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
)
if len(allocated_not_in_live_storages) != 0:
if allocated_not_in_live_storages != 0:
formatted = []
for dp, block in allocated_not_in_live_storages.items():
trace = format_tb(block.get("frames", []))
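The one-line change in this hunk toggles between len(allocated_not_in_live_storages) != 0 and allocated_not_in_live_storages != 0. Since the value is a dict (it is iterated with .items() just below), the two guards are not equivalent, which a small self-contained example makes clear:

# A dict compared against 0 is never equal to 0, so the un-len'd guard is
# always true; len(...) != 0 (or plain truthiness) only fires when non-empty.
allocated_not_in_live_storages = {}  # empty: nothing unexpectedly allocated

print(allocated_not_in_live_storages != 0)       # True  - a dict never equals 0
print(len(allocated_not_in_live_storages) != 0)  # False - the dict is empty
print(bool(allocated_not_in_live_storages))      # False - idiomatic emptiness check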

View File

@@ -1360,21 +1360,6 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and ( # noqa: F821
(torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
)
def allocator_option_enabled_fn(allocator_config, _, option):
if allocator_config is None:
return False
allocator_config = allocator_config.split(',') if ',' in allocator_config else [allocator_config]
mapping = dict([var.split(':') for var in allocator_config])
if option in mapping and mapping[option] == 'True':
return True
else:
return False
TestEnvironment.def_flag("EXPANDABLE_SEGMENTS",
env_var="PYTORCH_CUDA_ALLOC_CONF",
enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'))
if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2"))
gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30
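The helper removed above decides whether an option in PYTORCH_CUDA_ALLOC_CONF is enabled and feeds the EXPANDABLE_SEGMENTS test flag. A standalone sketch of the same parsing, with an illustrative config value:

# Sketch of the parsing in allocator_option_enabled_fn: the config string is a
# comma-separated list of option:value pairs, and an option is considered
# enabled only when it maps to the literal string 'True'.
def allocator_option_enabled(allocator_config, option):
    if allocator_config is None:
        return False
    pairs = allocator_config.split(",") if "," in allocator_config else [allocator_config]
    mapping = dict(var.split(":") for var in pairs)
    return mapping.get(option) == "True"


conf = "expandable_segments:True,max_split_size_mb:128"  # illustrative value
print(allocator_option_enabled(conf, "expandable_segments"))  # True
print(allocator_option_enabled(conf, "max_split_size_mb"))    # False (value is not 'True')
print(allocator_option_enabled(None, "expandable_segments"))  # False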