# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from torch import nn

from accelerate.test_utils import (
    memory_allocated_func,
    require_non_cpu,
    require_non_torch_xla,
    torch_device,
)
from accelerate.utils.memory import find_executable_batch_size, release_memory


def raise_fake_out_of_memory():
    raise RuntimeError("CUDA out of memory.")


class ModelForTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 4)
        self.batchnorm = nn.BatchNorm1d(4)
        self.linear2 = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


class BigModelForTest(ModelForTest):
    def __init__(self):
        super().__init__()
        self.linear3 = nn.Linear(5, 1000)

    def forward(self, x):
        return self.linear3(super().forward(x))


class MemoryTest(unittest.TestCase):
    def test_memory_implicit(self):
        batch_sizes = []

        @find_executable_batch_size(starting_batch_size=128)
        def mock_training_loop_function(batch_size):
            nonlocal batch_sizes
            batch_sizes.append(batch_size)
            if batch_size != 8:
                raise_fake_out_of_memory()

        mock_training_loop_function()
        assert batch_sizes == [
            128,
            115,
            103,
            92,
            82,
            73,
            65,
            58,
            52,
            46,
            41,
            36,
            32,
            28,
            25,
            22,
            19,
            17,
            15,
            13,
            11,
            9,
            8,
        ]

    def test_memory_explicit(self):
        batch_sizes = []

        @find_executable_batch_size(starting_batch_size=128)
        def mock_training_loop_function(batch_size, arg1):
            nonlocal batch_sizes
            batch_sizes.append(batch_size)
            if batch_size != 8:
                raise_fake_out_of_memory()
            return batch_size, arg1

        bs, arg1 = mock_training_loop_function("hello")
        assert batch_sizes == [
            128,
            115,
            103,
            92,
            82,
            73,
            65,
            58,
            52,
            46,
            41,
            36,
            32,
            28,
            25,
            22,
            19,
            17,
            15,
            13,
            11,
            9,
            8,
        ]
        assert [bs, arg1] == [8, "hello"]

    def test_start_zero(self):
        @find_executable_batch_size(starting_batch_size=0)
        def mock_training_loop_function(batch_size):
            pass

        with self.assertRaises(RuntimeError) as cm:
            mock_training_loop_function()
        assert "No executable batch size found, reached zero." in cm.exception.args[0]

    def test_approach_zero(self):
        @find_executable_batch_size(starting_batch_size=16)
        def mock_training_loop_function(batch_size):
            if batch_size > 0:
                raise_fake_out_of_memory()

        with self.assertRaises(RuntimeError) as cm:
            mock_training_loop_function()
        assert "No executable batch size found, reached zero." in cm.exception.args[0]

    def test_verbose_guard(self):
        @find_executable_batch_size(starting_batch_size=128)
        def mock_training_loop_function(batch_size, arg1, arg2):
            if batch_size != 8:
                raise_fake_out_of_memory()

        with self.assertRaises(TypeError) as cm:
            mock_training_loop_function(128, "hello", "world")
        assert "Batch size was passed into `f`" in cm.exception.args[0]
        assert "`f(arg1='hello', arg2='world')" in cm.exception.args[0]

    def test_any_other_error(self):
        @find_executable_batch_size(starting_batch_size=16)
        def mock_training_loop_function(batch_size):
            raise ValueError("Oops, we had an error!")

        with self.assertRaises(ValueError) as cm:
            mock_training_loop_function()
        assert "Oops, we had an error!" in cm.exception.args[0]

    @require_non_cpu
    @require_non_torch_xla
    def test_release_memory(self):
        starting_memory = memory_allocated_func()
        if torch_device.startswith("hpu"):
            # hpu has a minimum memory allocation that cannot be released,
            # we need to surpass it by using a bigger model (>5767296 bytes)
            model = BigModelForTest()
        else:
            model = ModelForTest()
        model.to(torch_device)

        assert memory_allocated_func() > starting_memory
        model = release_memory(model)
        assert memory_allocated_func() == starting_memory