Compare commits

...

6 Commits

Author SHA1 Message Date
3eac8e7a66 Release: v0.7.1 2022-04-29 09:10:14 -04:00
b3691db1d6 Patchfix infinite loop (#335) 2022-04-29 09:09:22 -04:00
c6657791d7 Add guards for batch size finder (#334)
* Fix zero reached

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2022-04-29 09:09:12 -04:00
a6e49ed045 Fix fsdp config in cluster (#331)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2022-04-29 09:09:03 -04:00
f0bb5f0ed5 Fix default config dicts (#329)
* Fix default config dicts

* style
2022-04-28 11:24:24 -04:00
11e8b33217 Release: v0.7.0 2022-04-28 11:03:50 -04:00
6 changed files with 42 additions and 5 deletions

View File

@@ -36,7 +36,7 @@ extras["sagemaker"] = [
 setup(
     name="accelerate",
-    version="0.7.0.dev0",
+    version="0.7.1",
     description="Accelerate",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",

View File

@@ -2,7 +2,7 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
-__version__ = "0.7.0.dev0"
+__version__ = "0.7.1"
 from .accelerator import Accelerator
 from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs

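A quick way to confirm that an installed copy picked up the bump above is to read the version string back; this is an illustrative snippet, not part of the diff:

import accelerate

# Expected to print "0.7.1" once this patch release is installed.
print(accelerate.__version__)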
View File

@@ -64,7 +64,7 @@ def get_cluster_input():
     else:
         use_cpu = False
-    deepspeed_config = None
+    deepspeed_config = {}
     if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO]:
         use_deepspeed = _ask_field(
             "Do you want to use DeepSpeed? [yes/NO]: ",
@@ -78,7 +78,6 @@
                 is_deepspeed_available()
             ), "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
-            deepspeed_config = {}
     if distributed_type == DistributedType.DEEPSPEED:
         deepspeed_config["zero_stage"] = _ask_field(
             "What should be your DeepSpeed's ZeRO optimization stage (0, 1, 2, 3)? [2]: ",
@@ -99,6 +98,7 @@
             default=1,
         )
+    fsdp_config = {}
     if distributed_type in [DistributedType.MULTI_GPU]:
         use_fsdp = _ask_field(
             "Do you want to use FullyShardedDataParallel? [yes/NO]: ",
@@ -108,7 +108,6 @@
         )
         if use_fsdp:
             distributed_type = DistributedType.FSDP
-            fsdp_config = {}
     if distributed_type == DistributedType.FSDP:
         fsdp_config["sharding_strategy"] = _ask_field(
             "What should be your sharding strategy ([1] FULL_SHARD, [2] SHARD_GRAD_OP)? [1]: ",

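The four hunks above move the empty-dict initialization of deepspeed_config and fsdp_config out of the opt-in branches, so both dicts always exist by the time keys are assigned or the config is written out. A minimal sketch of the hazard the old placement allowed (names simplified, not taken verbatim from accelerate):

# Old pattern: the dict was only created when the user opted in, so a later
# key assignment or serialization step could hit None instead of a dict.
use_deepspeed = False
deepspeed_config = None
if use_deepspeed:
    deepspeed_config = {}
# deepspeed_config["zero_stage"] = 2  # TypeError: 'NoneType' object does not support item assignment

# New pattern: always start from an empty dict and only populate it when asked.
deepspeed_config = {}
if use_deepspeed:
    deepspeed_config["zero_stage"] = 2
print(deepspeed_config)  # {} when DeepSpeed is not used, never None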
View File

@@ -139,6 +139,13 @@ class ClusterConfig(BaseConfig):
     # args for fsdp
     fsdp_config: dict = None
 
+    def __post_init__(self):
+        if self.deepspeed_config is None:
+            self.deepspeed_config = {}
+        if self.fsdp_config is None:
+            self.fsdp_config = {}
+        return super().__post_init__()
+
 
 @dataclass
 class SageMakerConfig(BaseConfig):

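The __post_init__ added to ClusterConfig is the standard dataclass workaround for mutable defaults: the dict fields default to None and are replaced with fresh empty dicts after construction, so instances never share state and callers can always treat the field as a dict. A self-contained sketch of the same pattern, using a hypothetical Config class rather than accelerate's BaseConfig:

from dataclasses import dataclass

@dataclass
class Config:
    # `deepspeed_config: dict = {}` would raise "mutable default ... is not allowed",
    # so default to None and swap in an empty dict after initialization.
    deepspeed_config: dict = None
    fsdp_config: dict = None

    def __post_init__(self):
        if self.deepspeed_config is None:
            self.deepspeed_config = {}
        if self.fsdp_config is None:
            self.fsdp_config = {}

a, b = Config(), Config()
a.fsdp_config["sharding_strategy"] = 1
print(b.fsdp_config)  # {} -- each instance gets its own dict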
View File

@@ -73,6 +73,8 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
                 f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
             )
         while True:
+            if batch_size == 0:
+                raise RuntimeError("No executable batch size found, reached zero.")
             try:
                 return function(batch_size, *args, **kwargs)
             except Exception as e:

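For context, find_executable_batch_size wraps a training function and retries it with a halved batch size each time an out-of-memory-style error is raised; the guard added above turns the previously infinite loop into a RuntimeError once the batch size reaches zero. A minimal usage sketch, assuming the accelerate.memory_utils import path used by the 0.7.x releases:

from accelerate.memory_utils import find_executable_batch_size  # import path assumed for 0.7.x

@find_executable_batch_size(starting_batch_size=128)
def training_function(batch_size):
    # batch_size is injected by the decorator; on an OOM error it is halved and retried,
    # and as of this patch a RuntimeError is raised instead of spinning forever at zero.
    print(f"Trying batch size {batch_size}")
    # ... build dataloaders and train with this batch size ...

training_function()  # called without the batch_size argument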
View File

@@ -50,6 +50,26 @@ class MemoryTest(unittest.TestCase):
         self.assertListEqual(batch_sizes, [128, 64, 32, 16, 8])
         self.assertListEqual([bs, arg1], [8, "hello"])
 
+    def test_start_zero(self):
+        @find_executable_batch_size(starting_batch_size=0)
+        def mock_training_loop_function(batch_size):
+            pass
+
+        with self.assertRaises(RuntimeError) as cm:
+            mock_training_loop_function()
+            self.assertIn("No executable batch size found, reached zero.", cm.exception.args[0])
+
+    def test_approach_zero(self):
+        @find_executable_batch_size(starting_batch_size=16)
+        def mock_training_loop_function(batch_size):
+            if batch_size > 0:
+                raise_fake_out_of_memory()
+            pass
+
+        with self.assertRaises(RuntimeError) as cm:
+            mock_training_loop_function()
+            self.assertIn("No executable batch size found, reached zero.", cm.exception.args[0])
+
     def test_verbose_guard(self):
         @find_executable_batch_size(starting_batch_size=128)
         def mock_training_loop_function(batch_size, arg1, arg2):
@@ -60,3 +80,12 @@ class MemoryTest(unittest.TestCase):
             mock_training_loop_function(128, "hello", "world")
         self.assertIn("Batch size was passed into `f`", cm.exception.args[0])
         self.assertIn("`f(arg1='hello', arg2='world')", cm.exception.args[0])
+
+    def test_any_other_error(self):
+        @find_executable_batch_size(starting_batch_size=16)
+        def mock_training_loop_function(batch_size):
+            raise ValueError("Oops, we had an error!")
+
+        with self.assertRaises(ValueError) as cm:
+            mock_training_loop_function()
+            self.assertIn("Oops, we had an error!", cm.exception.args[0])