mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 18:13:46 +08:00
* feat(memory): change default find_executable_batch_size to decrease by 10% instead of 50%
* Update test_memory_utils.py
* Apply style fixes

Co-authored-by: Amit Moryossef <amitmoryossef@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
185 lines
5.3 KiB
Python
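As the commit message notes, find_executable_batch_size now shrinks the batch size by 10% on each out-of-memory retry instead of halving it. A minimal sketch of the schedule the tests below assert, assuming each retry computes int(batch_size * 0.9):

bs = 128
sizes = [bs]
while bs > 8:  # in the mocked training loops below, batch_size == 8 succeeds
    bs = int(bs * 0.9)
    sizes.append(bs)
print(sizes)  # [128, 115, 103, 92, ..., 11, 9, 8], matching the asserted lists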
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from torch import nn

from accelerate.test_utils import (
    memory_allocated_func,
    require_non_cpu,
    require_non_torch_xla,
    torch_device,
)
from accelerate.utils.memory import find_executable_batch_size, release_memory


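# The memory utilities detect an OOM by inspecting the exception message, so
# raising a RuntimeError containing "CUDA out of memory." is enough to
# simulate a real OOM in these tests.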
def raise_fake_out_of_memory():
    raise RuntimeError("CUDA out of memory.")


class ModelForTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 4)
        self.batchnorm = nn.BatchNorm1d(4)
        self.linear2 = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


class BigModelForTest(ModelForTest):
    def __init__(self):
        super().__init__()
        self.linear3 = nn.Linear(5, 1000)

    def forward(self, x):
        return self.linear3(super().forward(x))


class MemoryTest(unittest.TestCase):
    def test_memory_implicit(self):
        batch_sizes = []

        @find_executable_batch_size(starting_batch_size=128)
        def mock_training_loop_function(batch_size):
            nonlocal batch_sizes
            batch_sizes.append(batch_size)
            if batch_size != 8:
                raise_fake_out_of_memory()

        mock_training_loop_function()
        assert batch_sizes == [
            128,
            115,
            103,
            92,
            82,
            73,
            65,
            58,
            52,
            46,
            41,
            36,
            32,
            28,
            25,
            22,
            19,
            17,
            15,
            13,
            11,
            9,
            8,
        ]

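    # The decorator supplies `batch_size` itself; any extra arguments are
    # forwarded to the wrapped function unchanged.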
    def test_memory_explicit(self):
        batch_sizes = []

        @find_executable_batch_size(starting_batch_size=128)
        def mock_training_loop_function(batch_size, arg1):
            nonlocal batch_sizes
            batch_sizes.append(batch_size)
            if batch_size != 8:
                raise_fake_out_of_memory()
            return batch_size, arg1

        bs, arg1 = mock_training_loop_function("hello")
        assert batch_sizes == [
            128,
            115,
            103,
            92,
            82,
            73,
            65,
            58,
            52,
            46,
            41,
            36,
            32,
            28,
            25,
            22,
            19,
            17,
            15,
            13,
            11,
            9,
            8,
        ]
        assert [bs, arg1] == [8, "hello"]

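    # If the search reaches zero without a successful run, the decorator
    # raises instead of retrying forever.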
    def test_start_zero(self):
        @find_executable_batch_size(starting_batch_size=0)
        def mock_training_loop_function(batch_size):
            pass

        with self.assertRaises(RuntimeError) as cm:
            mock_training_loop_function()
        assert "No executable batch size found, reached zero." in cm.exception.args[0]

    def test_approach_zero(self):
        @find_executable_batch_size(starting_batch_size=16)
        def mock_training_loop_function(batch_size):
            if batch_size > 0:
                raise_fake_out_of_memory()
            pass

        with self.assertRaises(RuntimeError) as cm:
            mock_training_loop_function()
        assert "No executable batch size found, reached zero." in cm.exception.args[0]

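    # Calling the wrapped function with an explicit batch size is a TypeError:
    # the decorator injects `batch_size`, and the error message shows the
    # signature the caller should have used.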
    def test_verbose_guard(self):
        @find_executable_batch_size(starting_batch_size=128)
        def mock_training_loop_function(batch_size, arg1, arg2):
            if batch_size != 8:
                raise_fake_out_of_memory()

        with self.assertRaises(TypeError) as cm:
            mock_training_loop_function(128, "hello", "world")
        assert "Batch size was passed into `f`" in cm.exception.args[0]
        assert "`f(arg1='hello', arg2='world')" in cm.exception.args[0]

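    # Exceptions that do not look like an OOM are re-raised immediately
    # instead of triggering a retry.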
    def test_any_other_error(self):
        @find_executable_batch_size(starting_batch_size=16)
        def mock_training_loop_function(batch_size):
            raise ValueError("Oops, we had an error!")

        with self.assertRaises(ValueError) as cm:
            mock_training_loop_function()
        assert "Oops, we had an error!" in cm.exception.args[0]

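    # `release_memory` drops the references it is given (returning None
    # placeholders), runs garbage collection, and empties the device cache,
    # so allocation should return to the baseline.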
    @require_non_cpu
    @require_non_torch_xla
    def test_release_memory(self):
        starting_memory = memory_allocated_func()

        if torch_device.startswith("hpu"):
            # hpu has a minimum memory allocation that cannot be released,
            # we need to surpass it by using a bigger model (>5767296 bytes)
            model = BigModelForTest()
        else:
            model = ModelForTest()

        model.to(torch_device)
        assert memory_allocated_func() > starting_memory
        model = release_memory(model)
        assert memory_allocated_func() == starting_memory