mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 10:03:46 +08:00
“Stop Halving My Batch!” · Default back-off 0.5 → 0.9 (#3684)
* feat(memory): change default find_executable_batch_size to change by 10% instead of 50% * Update test_memory_utils.py * Apply style fixes --------- Co-authored-by: Amit Moryossef <amitmoryossef@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@ -121,7 +121,7 @@ def find_executable_batch_size(
|
||||
):
|
||||
"""
|
||||
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
|
||||
CUDNN, the batch size is cut in half and passed to `function`
|
||||
CUDNN, the batch size is multiplied by 0.9 and passed to `function`
|
||||
|
||||
`function` must take in a `batch_size` parameter as its first argument.
|
||||
|
||||
@ -153,7 +153,7 @@ def find_executable_batch_size(
|
||||
|
||||
def reduce_batch_size_fn():
|
||||
nonlocal batch_size
|
||||
batch_size = batch_size // 2
|
||||
batch_size = int(batch_size * 0.9)
|
||||
return batch_size
|
||||
|
||||
def decorator(*args, **kwargs):
|
||||
|
@ -61,7 +61,31 @@ class MemoryTest(unittest.TestCase):
|
||||
raise_fake_out_of_memory()
|
||||
|
||||
mock_training_loop_function()
|
||||
assert batch_sizes == [128, 64, 32, 16, 8]
|
||||
assert batch_sizes == [
|
||||
128,
|
||||
115,
|
||||
103,
|
||||
92,
|
||||
82,
|
||||
73,
|
||||
65,
|
||||
58,
|
||||
52,
|
||||
46,
|
||||
41,
|
||||
36,
|
||||
32,
|
||||
28,
|
||||
25,
|
||||
22,
|
||||
19,
|
||||
17,
|
||||
15,
|
||||
13,
|
||||
11,
|
||||
9,
|
||||
8,
|
||||
]
|
||||
|
||||
def test_memory_explicit(self):
|
||||
batch_sizes = []
|
||||
@ -75,7 +99,31 @@ class MemoryTest(unittest.TestCase):
|
||||
return batch_size, arg1
|
||||
|
||||
bs, arg1 = mock_training_loop_function("hello")
|
||||
assert batch_sizes == [128, 64, 32, 16, 8]
|
||||
assert batch_sizes == [
|
||||
128,
|
||||
115,
|
||||
103,
|
||||
92,
|
||||
82,
|
||||
73,
|
||||
65,
|
||||
58,
|
||||
52,
|
||||
46,
|
||||
41,
|
||||
36,
|
||||
32,
|
||||
28,
|
||||
25,
|
||||
22,
|
||||
19,
|
||||
17,
|
||||
15,
|
||||
13,
|
||||
11,
|
||||
9,
|
||||
8,
|
||||
]
|
||||
assert [bs, arg1] == [8, "hello"]
|
||||
|
||||
def test_start_zero(self):
|
||||
|
Reference in New Issue
Block a user