[CUDA] Cleanup per-process-memory-fraction in test_cuda.py tests (#140852)

Otherwise, certain sequences of tests will fail with OOM, e.g.:
```
# python test/test_cuda.py -k max_split_expandable -k test_assigning_back_deleter_fns_to_tensor --repeat 100
..
----------------------------------------------------------------------
Ran 2 tests in 0.311s

OK
E.
======================================================================
ERROR: test_assigning_back_deleter_fns_to_tensor (__main__.TestBlockStateAbsorption.test_assigning_back_deleter_fns_to_tensor)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/workspace/pytorch/torch/testing/_internal/common_utils.py", line 3058, in wrapper
    method(*args, **kwargs)
  File "/workspace/pytorch/test/test_cuda.py", line 4320, in test_assigning_back_deleter_fns_to_tensor
    graph, outputs = cudagraphify(foo, [inp])
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/pytorch/test/test_cuda.py", line 4080, in cudagraphify
    fn(*inputs)
  File "/workspace/pytorch/test/test_cuda.py", line 4316, in foo
    int8_cuda(LARGE_BUFFER) + x,
    ~~~~~~~~~~~~~~~~~~~~~~~~^~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 160.00 MiB. GPU 0 has a total capacity of 31.73 GiB of which 31.30 GiB is free. Process 2916661 has 442.00 MiB memory in use. 120.00 MiB allowed; Of the allocated memory 52.00 MiB is allocated by PyTorch, and 6.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

To execute this test, run the following from the base repo dir:
    python test/test_cuda.py TestBlockStateAbsorption.test_assigning_back_deleter_fns_to_tensor
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 2 tests in 0.136s

FAILED (errors=1)
```
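The root cause is that `torch.cuda.set_per_process_memory_fraction` is process-wide state: a test that lowers the fraction and never resets it leaves every later test in the same process capped at the smaller limit, which is why the "120.00 MiB allowed" cap set by `test_max_split_expandable` shows up in the unrelated CUDA-graphs test above. The fix wraps the affected tests in `try`/`finally` so the fraction is always restored to 1.0. A minimal sketch of the pattern, outside the test suite (the helper name below is illustrative and not part of the PR):

```python
import torch

def run_with_memory_fraction(fn, fraction, device=0):
    # Hypothetical helper (not in the PR): cap this process's CUDA caching
    # allocator at `fraction` of total device memory while `fn` runs.
    torch.cuda.set_per_process_memory_fraction(fraction, device)
    try:
        return fn()
    finally:
        # Always restore the default so later code in the same process
        # is not limited by a cap it never set.
        torch.cuda.set_per_process_memory_fraction(1.0, device)
```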

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140852
Approved by: https://github.com/Skylion007
eqy
2024-12-06 21:26:54 +00:00
committed by PyTorch MergeBot
parent 660845a1aa
commit 0a619a212f


@@ -293,36 +293,41 @@ class TestCuda(TestCase):
     @serialTest()
     def test_set_per_process_memory_fraction(self):
-        # test invalid fraction value.
-        with self.assertRaisesRegex(TypeError, "Invalid type"):
-            torch.cuda.set_per_process_memory_fraction(1)
-        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
-            torch.cuda.set_per_process_memory_fraction(-0.1)
-        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
-            torch.cuda.set_per_process_memory_fraction(2.0)
+        try:
+            # test invalid fraction value.
+            with self.assertRaisesRegex(TypeError, "Invalid type"):
+                torch.cuda.set_per_process_memory_fraction(1)
+            with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
+                torch.cuda.set_per_process_memory_fraction(-0.1)
+            with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
+                torch.cuda.set_per_process_memory_fraction(2.0)

-        tensor = torch.zeros(1024, device="cuda")
-        torch.cuda.empty_cache()
-        total_memory = torch.cuda.get_device_properties(0).total_memory
-        torch.cuda.set_per_process_memory_fraction(0.5, 0)
+            tensor = torch.zeros(1024, device="cuda")
+            torch.cuda.empty_cache()
+            total_memory = torch.cuda.get_device_properties(0).total_memory
+            torch.cuda.set_per_process_memory_fraction(0.5, 0)

-        # test 0.499 allocation is ok.
-        application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved()
-        tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda")
-        del tmp_tensor
-        torch.cuda.empty_cache()
+            # test 0.499 allocation is ok.
+            application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved()
+            tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda")
+            del tmp_tensor
+            torch.cuda.empty_cache()

-        application = int(total_memory * 0.5)
-        # it will get OOM when try to allocate more than half memory.
-        oom_regex = (
-            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
-        )
-        with self.assertRaisesRegex(RuntimeError, oom_regex):
-            torch.empty(application, dtype=torch.int8, device="cuda")
+            application = int(total_memory * 0.5)
+            # it will get OOM when try to allocate more than half memory.
+            oom_regex = (
+                "would exceed allowed memory"
+                if TEST_CUDAMALLOCASYNC
+                else "out of memory"
+            )
+            with self.assertRaisesRegex(RuntimeError, oom_regex):
+                torch.empty(application, dtype=torch.int8, device="cuda")

-        # ensure out of memory error doesn't disturb subsequent kernel
-        tensor.fill_(1)
-        self.assertTrue((tensor == 1).all())
+            # ensure out of memory error doesn't disturb subsequent kernel
+            tensor.fill_(1)
+            self.assertTrue((tensor == 1).all())
+        finally:
+            torch.cuda.set_per_process_memory_fraction(1.0, 0)

     @serialTest()
     def test_get_per_process_memory_fraction(self):
@@ -3641,62 +3646,68 @@ class TestCudaMallocAsync(TestCase):
         torch.cuda.memory._record_memory_history(None)

     def test_max_split_expandable(self):
-        torch.cuda.memory.empty_cache()
-        mb = 1024 * 1024
-        _, all_memory = torch.cuda.memory.mem_get_info()
-        pre_reserved = torch.cuda.memory_reserved()
-        total_allowed = 120 * mb + pre_reserved
-        fraction_allowed = total_allowed / all_memory
-        self.assertEqual(int(fraction_allowed * all_memory), total_allowed)
-        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
+        try:
+            torch.cuda.memory.empty_cache()
+            mb = 1024 * 1024
+            _, all_memory = torch.cuda.memory.mem_get_info()
+            pre_reserved = torch.cuda.memory_reserved()
+            total_allowed = 120 * mb + pre_reserved
+            fraction_allowed = total_allowed / all_memory
+            self.assertEqual(int(fraction_allowed * all_memory), total_allowed)
+            torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

-        def alloc(n):
-            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
+            def alloc(n):
+                return torch.ones(n * mb, dtype=torch.int8, device="cuda")

-        torch.cuda.memory._set_allocator_settings(
-            "expandable_segments:False,max_split_size_mb:40"
-        )
-        a = alloc(40)
-        torch.cuda.memory._set_allocator_settings(
-            "expandable_segments:True,max_split_size_mb:40"
-        )
-        b = alloc(40)
-        torch.cuda.memory._set_allocator_settings(
-            "expandable_segments:False,max_split_size_mb:40"
-        )
-        c = alloc(40)
-        with self.assertRaises(torch.OutOfMemoryError):
-            alloc(40)
-        del a, b, c
-        # force release_cached_blocks to run with some expandable segments in the free list
-        alloc(120)
+            torch.cuda.memory._set_allocator_settings(
+                "expandable_segments:False,max_split_size_mb:40"
+            )
+            a = alloc(40)
+            torch.cuda.memory._set_allocator_settings(
+                "expandable_segments:True,max_split_size_mb:40"
+            )
+            b = alloc(40)
+            torch.cuda.memory._set_allocator_settings(
+                "expandable_segments:False,max_split_size_mb:40"
+            )
+            c = alloc(40)
+            with self.assertRaises(torch.OutOfMemoryError):
+                alloc(40)
+            del a, b, c
+            # force release_cached_blocks to run with some expandable segments in the free list
+            alloc(120)
+        finally:
+            torch.cuda.memory.set_per_process_memory_fraction(1.0)

     def test_garbage_collect_expandable(self):
-        torch.cuda.memory.empty_cache()
-        mb = 1024 * 1024
-        _, all_memory = torch.cuda.memory.mem_get_info()
-        pre_reserved = torch.cuda.memory_reserved()
-        total_allowed = 120 * mb + pre_reserved
-        fraction_allowed = total_allowed / all_memory
-        self.assertEqual((fraction_allowed * all_memory), total_allowed)
-        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
+        try:
+            torch.cuda.memory.empty_cache()
+            mb = 1024 * 1024
+            _, all_memory = torch.cuda.memory.mem_get_info()
+            pre_reserved = torch.cuda.memory_reserved()
+            total_allowed = 120 * mb + pre_reserved
+            fraction_allowed = total_allowed / all_memory
+            self.assertEqual((fraction_allowed * all_memory), total_allowed)
+            torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

-        def alloc(n):
-            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
+            def alloc(n):
+                return torch.ones(n * mb, dtype=torch.int8, device="cuda")

-        torch.cuda.memory._set_allocator_settings(
-            "expandable_segments:False,garbage_collection_threshold:0.5"
-        )
-        a = alloc(40)
-        torch.cuda.memory._set_allocator_settings(
-            "expandable_segments:True,garbage_collection_threshold:0.5"
-        )
-        b = alloc(40)
-        del a, b
-        # causes GC to run. The expandable segment block will be split
-        # so GC would not attempt to free it anyway, but this at least makes sure
-        # expandable_segment blocks can be in the free list when this is called.
-        alloc(80)
+            torch.cuda.memory._set_allocator_settings(
+                "expandable_segments:False,garbage_collection_threshold:0.5"
+            )
+            a = alloc(40)
+            torch.cuda.memory._set_allocator_settings(
+                "expandable_segments:True,garbage_collection_threshold:0.5"
+            )
+            b = alloc(40)
+            del a, b
+            # causes GC to run. The expandable segment block will be split
+            # so GC would not attempt to free it anyway, but this at least makes sure
+            # expandable_segment blocks can be in the free list when this is called.
+            alloc(80)
+        finally:
+            torch.cuda.memory.set_per_process_memory_fraction(1.0)

     def test_allocator_settings(self):
         def power2_div(size, div_factor):