mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Support for expandable segments with cuda graph trees (#128068)"
This reverts commit fdc83610f272610ce50d1a6f5b6354f2df1baabb. Reverted https://github.com/pytorch/pytorch/pull/128068 on behalf of https://github.com/janeyx99 due to Reverting for breaking ROCm tests on trunk, I think the tests need to be qualified with @onlyCUDA ([comment](https://github.com/pytorch/pytorch/pull/128068#issuecomment-2223672381))
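The qualification suggested in the revert message maps to the device-type test framework's `onlyCUDA` decorator. A minimal sketch with a hypothetical test class and body; only `onlyCUDA`, `instantiate_device_type_tests`, `TestCase`, and `run_tests` are real APIs here:

    # Hypothetical sketch: qualifying a device-generic test with @onlyCUDA so the
    # device-type framework only instantiates it for the CUDA device type.
    from torch.testing._internal.common_device_type import (
        instantiate_device_type_tests,
        onlyCUDA,
    )
    from torch.testing._internal.common_utils import TestCase, run_tests

    class TestExpandableSegmentsExample(TestCase):
        @onlyCUDA  # instantiated only for the "cuda" device type
        def test_allocator_setting(self, device):
            import torch
            torch.cuda.memory._set_allocator_settings("expandable_segments:True")

    instantiate_device_type_tests(TestExpandableSegmentsExample, globals())

    if __name__ == "__main__":
        run_tests()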
@@ -1543,16 +1543,6 @@ class DeviceCachingAllocator {
 
     // allocate all blocks in the segment
     for (size_t i = 0; i < segment_len; ++i) {
-      // The last block in every expandable segment is the remaining amount of
-      // available unmapped virtual address space. We shouldn't change it but
-      // instead check it is correctly formed then skip over allocating it.
-      if (i == segment_len - 1 && curr_block->expandable_segment_) {
-        TORCH_CHECK(curr_block->next == nullptr);
-        TORCH_CHECK(!curr_block->mapped);
-        TORCH_CHECK(curr_block->allocated == false);
-        continue;
-      }
-
       auto& block_state = segment.blocks.at(i);
       AllocParams params(
           block_state.device,
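The invariant the removed checks enforce, a trailing unmapped block at the end of every expandable segment, can be observed from Python through the allocator snapshot. A minimal sketch, assuming a build whose private `_snapshot()` output includes the `is_expandable` field; this is a private API and subject to change:

    import torch

    x = torch.empty(1024, device="cuda")
    for seg in torch.cuda.memory._snapshot()["segments"]:
        print(seg["segment_type"], seg.get("is_expandable"), seg["total_size"])
        for block in seg["blocks"]:
            # an expandable segment ends in an unmapped remainder block
            print("  ", block["state"], block["size"])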
@@ -1567,11 +1557,8 @@ class DeviceCachingAllocator {
 
       // splitting a block depends on `max_split_size`, which may have changed
       // between when the checkpoint was taken and now, so we make sure to recreate
-      // the behavior from the checkpoint. Keep splitting as long as there is
-      // space left in the block because the block is already the size of how it
-      // appears in the segment, so any leftover space belongs to the next
-      // block.
-      bool split = curr_block->size - block_state.size > 0;
+      // the behavior from the checkpoint.
+      bool split = (i + 1) < segment.blocks.size();
 
       // curr_block will become next pointer if it is split, so reassign with
       // the returned value
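`max_split_size` is user-configurable at runtime, which is how it can differ between the checkpoint and the restore that the comment describes. A minimal sketch of changing it between the two points:

    import torch

    # max_split_size_mb caps the size of blocks the caching allocator will split;
    # changing it after a private-pool checkpoint was taken creates exactly the
    # mismatch the comment above guards against.
    torch.cuda.memory._set_allocator_settings("max_split_size_mb:128")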
@@ -1594,13 +1581,6 @@ class DeviceCachingAllocator {
     curr_block = last_block;
 
     for (size_t i = 0; i < segment_len; ++i, curr_block = curr_block->next) {
-      if (i == segment_len - 1 && curr_block->expandable_segment_) {
-        TORCH_CHECK(curr_block->next == nullptr);
-        TORCH_CHECK(!curr_block->mapped);
-        TORCH_CHECK(curr_block->allocated == false);
-        continue;
-      }
-
       auto& block_state = segment.blocks.at(i);
       TORCH_INTERNAL_ASSERT(curr_block != nullptr);
 
@@ -2450,15 +2430,15 @@ class DeviceCachingAllocator {
         total_allocated_memory + size > allowed_memory_maximum) {
       p.err = cudaErrorMemoryAllocation;
       return false;
-    } else if (CUDAAllocatorConfig::expandable_segments()) {
+    } else if (
+        CUDAAllocatorConfig::expandable_segments() &&
+        // our checkpointing logic for private pools doesn't support
+        // the expandable_segments_ structure yet
+        !p.pool->owner_PrivatePool) {
       p.block = try_allocate_expandable_block(
           p.device(), p.stream(), p.pool, p.size(), ctx);
       if (p.block) {
         p.err = cudaSuccess;
-        if (p.pool->owner_PrivatePool) {
-          // The block is for a CUDA graph's PrivatePool.
-          p.pool->owner_PrivatePool->cudaMalloc_count++;
-        }
       } else {
         p.err = cudaErrorMemoryAllocation;
       }
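`owner_PrivatePool` is set on memory pools that back CUDA graph captures, so the restored condition keeps expandable segments out of graph capture. A minimal sketch of how such a pool arises from Python (real capture code normally warms the workload up on a side stream first):

    import torch

    static_input = torch.zeros(64, device="cuda")
    g = torch.cuda.CUDAGraph()
    pool = torch.cuda.graph_pool_handle()
    with torch.cuda.graph(g, pool=pool):
        # allocations made during capture come from the graph's PrivatePool
        static_output = static_input * 2
    g.replay()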
@@ -2701,13 +2681,6 @@ class DeviceCachingAllocator {
           decrease_stat(stats.reserved_bytes[stat_type], unmapped.size);
         });
 
-    if (block->pool->owner_PrivatePool) {
-      // The cudaFreed block belonged to a CUDA graph's PrivatePool.
-      TORCH_INTERNAL_ASSERT(
-          block->pool->owner_PrivatePool->cudaMalloc_count > 0);
-      block->pool->owner_PrivatePool->cudaMalloc_count--;
-    }
-
     stats.num_device_free++;
     record_trace(
         TraceEntry::SEGMENT_UNMAP,
@@ -1,32 +0,0 @@
-# Owner(s): ["module: cuda"]
-# run time cuda tests, but with the allocator using expandable segments
-
-import os
-import pathlib
-import sys
-
-import torch
-
-from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS
-from torch.testing._internal.common_utils import run_tests
-
-pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-sys.path.append(pytorch_test_dir)
-
-from dynamo.test_cudagraphs import TestAotCudagraphs  # noqa: F401
-
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
-
-sys.path.insert(0, str(REPO_ROOT))
-from tools.stats.import_test_stats import get_disabled_tests
-
-# Make sure to remove REPO_ROOT after import is done
-sys.path.remove(str(REPO_ROOT))
-
-if __name__ == "__main__":
-    if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS:
-        get_disabled_tests(".")
-
-        torch.cuda.memory._set_allocator_settings("expandable_segments:True")
-
-        run_tests()
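The deleted runner flips the setting in-process via `_set_allocator_settings`; the same configuration can also be applied through the environment before CUDA initializes. A minimal sketch:

    import os

    # PYTORCH_CUDA_ALLOC_CONF is consulted when the caching allocator starts up,
    # so set it before torch performs its first CUDA allocation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    import torch  # imported after the env var so the setting takes effect

    torch.empty(1, device="cuda")  # first allocation uses expandable segments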
@@ -1,40 +0,0 @@
-# Owner(s): ["module: cuda"]
-# run time cuda tests, but with the allocator using expandable segments
-
-import os
-import pathlib
-import sys
-
-import torch
-
-from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS
-from torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN
-from torch.testing._internal.inductor_utils import HAS_CUDA
-
-pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-sys.path.append(pytorch_test_dir)
-
-if HAS_CUDA and not TEST_WITH_ASAN:
-    from inductor.test_cudagraph_trees import CudaGraphTreeTests  # noqa: F401
-
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
-
-sys.path.insert(0, str(REPO_ROOT))
-from tools.stats.import_test_stats import get_disabled_tests
-
-# Make sure to remove REPO_ROOT after import is done
-sys.path.remove(str(REPO_ROOT))
-
-if __name__ == "__main__":
-    if (
-        torch.cuda.is_available()
-        and not IS_JETSON
-        and not IS_WINDOWS
-        and HAS_CUDA
-        and not TEST_WITH_ASAN
-    ):
-        get_disabled_tests(".")
-
-        torch.cuda.memory._set_allocator_settings("expandable_segments:True")
-
-        run_tests()
@@ -41,7 +41,6 @@ from torch.testing._internal.common_device_type import (
 )
 from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker
 from torch.testing._internal.common_utils import (
-    EXPANDABLE_SEGMENTS,
     freeze_rng_state,
     gcIfJetson,
     get_cycles_per_ms,
@@ -118,10 +117,6 @@ class TestCuda(TestCase):
         del self.autocast_lists
         super().tearDown()
 
-    @property
-    def expandable_segments(self):
-        return EXPANDABLE_SEGMENTS
-
     def test_pinned_memory_with_cudaregister(self):
         torch.cuda.memory._set_allocator_settings(
             "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
@@ -2912,21 +2907,6 @@ exit(2)
         for stat, expected in zip(stats_to_check, expecteds):
             stat = stat + pool_string + ".current"
             current = postcapture_stats[stat] - precapture_stats[stat]
-
-            # There will only ever be one expandable segment in each of the small and large pools. The way the
-            # bookkeeping is done in the allocator means that we never increment the number of segments.
-            if self.expandable_segments and "segment" in stat:
-                expected = 0
-            # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
-            # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
-            # smaller than the page size
-            if (
-                self.expandable_segments
-                and "reserved" in stat
-                and (numel == cases[3][0] or numel == cases[4][0])
-            ):
-                expected = 2 * kLargeBuffer
-
             self.assertEqual(
                 current,
                 expected,
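The keys being checked here follow the `<stat>.<pool>.current` naming convention of `torch.cuda.memory_stats()`. A minimal sketch of reading the same counters directly:

    import torch

    x = torch.empty(1024, 1024, device="cuda")
    stats = torch.cuda.memory_stats()
    print(stats["segment.all.current"])                # segments currently held
    print(stats["reserved_bytes.large_pool.current"])  # bytes reserved in the large pool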
@@ -2953,27 +2933,6 @@ exit(2)
         for stat, expected in zip(stats_to_check, expecteds):
             stat = stat + pool_string + ".current"
             current = postdel_stats[stat] - precapture_stats[stat]
-
-            # There will only ever be one expandable segment in each of the small and large pools. The way the
-            # bookkeeping is done in the allocator means that we never increment the number of segments.
-            if self.expandable_segments and "segment" in stat:
-                expected = 0
-            # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
-            # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
-            # smaller than the page size
-            if (
-                self.expandable_segments
-                and "reserved" in stat
-                and numel == cases[3][0]
-            ):
-                expected = 2 * kLargeBuffer
-            if (
-                self.expandable_segments
-                and "reserved" in stat
-                and numel == cases[4][0]
-            ):
-                expected = kLargeBuffer
-
             self.assertEqual(
                 current,
                 expected,
@@ -4565,10 +4524,6 @@ def reconstruct_from_tensor_metadata(metadata):
 @unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestBlockStateAbsorption(TestCase):
-    @property
-    def expandable_segments(self):
-        return EXPANDABLE_SEGMENTS
-
     def checkCheckpointedBlock(self, before_block, after_block):
         for field in ("size", "state"):
             self.assertEqual(before_block[field], after_block[field])
@@ -4896,9 +4851,7 @@ class TestBlockStateAbsorption(TestCase):
         graph_thread.join()
         no_graph_thread.join()
 
-        self.assertEqual(
-            len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4
-        )
+        self.assertEqual(len(get_cudagraph_segments(pool)), 4)
 
         del graph
 
@@ -1,34 +1,15 @@
 # Owner(s): ["module: cuda"]
 # run time cuda tests, but with the allocator using expandable segments
 
-import pathlib
-import sys
-
-from test_cuda import (  # noqa: F401
-    TestBlockStateAbsorption,
-    TestCuda,
-    TestCudaMallocAsync,
-)
+import os
 
 import torch
 
-from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS
-from torch.testing._internal.common_utils import run_tests
-
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
-
-sys.path.insert(0, str(REPO_ROOT))
-from tools.stats.import_test_stats import get_disabled_tests
-
-# Make sure to remove REPO_ROOT after import is done
-sys.path.remove(str(REPO_ROOT))
-
-if __name__ == "__main__":
-    if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS:
-        get_disabled_tests(".")
-
-        torch.cuda.memory._set_allocator_settings("expandable_segments:True")
-        TestCuda.expandable_segments = lambda _: True
-        TestBlockStateAbsorption.expandable_segments = lambda _: True
-
-        run_tests()
+from torch.testing._internal.common_cuda import IS_JETSON
+
+if torch.cuda.is_available() and not IS_JETSON:
+    torch.cuda.memory._set_allocator_settings("expandable_segments:True")
+
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    filepath = os.path.join(current_dir, "test_cuda.py")
+    exec(compile(open(filepath).read(), filepath, mode="exec"))
@@ -1690,7 +1690,7 @@ def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]):
         lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
     )
 
-    if len(allocated_not_in_live_storages) != 0:
+    if allocated_not_in_live_storages != 0:
         formatted = []
         for dp, block in allocated_not_in_live_storages.items():
             trace = format_tb(block.get("frames", []))
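Note what the restored line does: `allocated_not_in_live_storages` is a dict, and in Python a dict compared against an integer is never equal, so `!= 0` is true regardless of contents, while the removed `len(...)` form actually tested emptiness. A minimal illustration:

    d = {}
    print(d != 0)       # True: the branch is taken even with nothing to report
    print(len(d) != 0)  # False: the branch is skipped for an empty dict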
@@ -1360,21 +1360,6 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (  # noqa: F821
     (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
 )
 
-def allocator_option_enabled_fn(allocator_config, _, option):
-    if allocator_config is None:
-        return False
-    allocator_config = allocator_config.split(',') if ',' in allocator_config else [allocator_config]
-    mapping = dict([var.split(':') for var in allocator_config])
-
-    if option in mapping and mapping[option] == 'True':
-        return True
-    else:
-        return False
-
-TestEnvironment.def_flag("EXPANDABLE_SEGMENTS",
-                         env_var="PYTORCH_CUDA_ALLOC_CONF",
-                         enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'))
-
 if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
     num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2"))
     gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30
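The removed helper parses `PYTORCH_CUDA_ALLOC_CONF` as comma-separated `key:value` pairs and reports whether a given option is literally 'True' (its second parameter is unused by the helper). A minimal sketch of how it evaluates, reusing the definition from the hunk above:

    print(allocator_option_enabled_fn(
        "expandable_segments:True,max_split_size_mb:128", None, "expandable_segments"))  # True
    print(allocator_option_enabled_fn(
        "max_split_size_mb:128", None, "expandable_segments"))  # False
    print(allocator_option_enabled_fn(None, None, "expandable_segments"))  # False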