“Stop Halving My Batch!” · Default back-off 0.5 → 0.9 (#3684)

* feat(memory): change default find_executable_batch_size to change by 10% instead of 50%

* Update test_memory_utils.py

* Apply style fixes

---------

Co-authored-by: Amit Moryossef <amitmoryossef@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Marc Sun
2025-07-16 12:32:46 +02:00
committed by GitHub
parent 0408ab12d7
commit 3b13453bbf
2 changed files with 52 additions and 4 deletions

View File

@@ -121,7 +121,7 @@ def find_executable_batch_size(
):
"""
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
-CUDNN, the batch size is cut in half and passed to `function`
+CUDNN, the batch size is multiplied by 0.9 and passed to `function`
`function` must take in a `batch_size` parameter as its first argument.
@@ -153,7 +153,7 @@ def find_executable_batch_size(
def reduce_batch_size_fn():
nonlocal batch_size
-batch_size = batch_size // 2
+batch_size = int(batch_size * 0.9)
return batch_size
def decorator(*args, **kwargs):

View File

@@ -61,7 +61,31 @@ class MemoryTest(unittest.TestCase):
raise_fake_out_of_memory()
mock_training_loop_function()
-assert batch_sizes == [128, 64, 32, 16, 8]
+assert batch_sizes == [
+128,
+115,
+103,
+92,
+82,
+73,
+65,
+58,
+52,
+46,
+41,
+36,
+32,
+28,
+25,
+22,
+19,
+17,
+15,
+13,
+11,
+9,
+8,
+]
def test_memory_explicit(self):
batch_sizes = []
@@ -75,7 +99,31 @@ class MemoryTest(unittest.TestCase):
return batch_size, arg1
bs, arg1 = mock_training_loop_function("hello")
-assert batch_sizes == [128, 64, 32, 16, 8]
+assert batch_sizes == [
+128,
+115,
+103,
+92,
+82,
+73,
+65,
+58,
+52,
+46,
+41,
+36,
+32,
+28,
+25,
+22,
+19,
+17,
+15,
+13,
+11,
+9,
+8,
+]
assert [bs, arg1] == [8, "hello"]
def test_start_zero(self):