Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-13 21:59:16 +08:00)
Compare commits

6 commits: `v1.0.1...v0.7-relea`

- 3eac8e7a66
- b3691db1d6
- c6657791d7
- a6e49ed045
- f0bb5f0ed5
- 11e8b33217
setup.py

```diff
@@ -36,7 +36,7 @@ extras["sagemaker"] = [
 setup(
     name="accelerate",
-    version="0.7.0.dev0",
+    version="0.7.1",
     description="Accelerate",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
```
src/accelerate/__init__.py

```diff
@@ -2,7 +2,7 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
 
-__version__ = "0.7.0.dev0"
+__version__ = "0.7.1"
 
 from .accelerator import Accelerator
 from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs
```
src/accelerate/commands/config/cluster.py

```diff
@@ -64,7 +64,7 @@ def get_cluster_input():
     else:
         use_cpu = False
 
-    deepspeed_config = None
+    deepspeed_config = {}
     if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO]:
         use_deepspeed = _ask_field(
             "Do you want to use DeepSpeed? [yes/NO]: ",
@@ -78,7 +78,6 @@ def get_cluster_input():
                 is_deepspeed_available()
             ), "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
 
-        deepspeed_config = {}
     if distributed_type == DistributedType.DEEPSPEED:
         deepspeed_config["zero_stage"] = _ask_field(
             "What should be your DeepSpeed's ZeRO optimization stage (0, 1, 2, 3)? [2]: ",
@@ -99,6 +98,7 @@ def get_cluster_input():
             default=1,
         )
 
+    fsdp_config = {}
     if distributed_type in [DistributedType.MULTI_GPU]:
         use_fsdp = _ask_field(
             "Do you want to use FullyShardedDataParallel? [yes/NO]: ",
@@ -108,7 +108,6 @@ def get_cluster_input():
         )
         if use_fsdp:
             distributed_type = DistributedType.FSDP
-        fsdp_config = {}
     if distributed_type == DistributedType.FSDP:
         fsdp_config["sharding_strategy"] = _ask_field(
             "What should be your sharding strategy ([1] FULL_SHARD, [2] SHARD_GRAD_OP)? [1]: ",
```
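Taken together, these four hunks move both feature configs to an "initialize early, fill later" shape: `deepspeed_config` and `fsdp_config` start out as empty dicts before any question is asked, so declining DeepSpeed or FSDP can no longer leave a `None` behind for the config writer. A minimal sketch of the pattern (not the real questionnaire), with plain booleans standing in for the interactive `_ask_field` answers:

```python
# Sketch of the "initialize early, fill later" pattern the hunks above
# converge on; `use_deepspeed` / `use_fsdp` stand in for prompt answers.
def gather_configs(use_deepspeed: bool, use_fsdp: bool) -> dict:
    deepspeed_config = {}  # always a dict from the start, never None
    fsdp_config = {}

    if use_deepspeed:
        deepspeed_config["zero_stage"] = 2
    if use_fsdp:
        fsdp_config["sharding_strategy"] = 1

    return {"deepspeed_config": deepspeed_config, "fsdp_config": fsdp_config}


# Declining both features now yields empty dicts instead of None, which is
# what the dict-typed fields of ClusterConfig (below) expect.
assert gather_configs(False, False) == {"deepspeed_config": {}, "fsdp_config": {}}
```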
src/accelerate/commands/config/config_args.py

```diff
@@ -139,6 +139,13 @@ class ClusterConfig(BaseConfig):
     # args for fsdp
     fsdp_config: dict = None
 
+    def __post_init__(self):
+        if self.deepspeed_config is None:
+            self.deepspeed_config = {}
+        if self.fsdp_config is None:
+            self.fsdp_config = {}
+        return super().__post_init__()
+
 
 @dataclass
 class SageMakerConfig(BaseConfig):
```
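The `__post_init__` hook applies the same guarantee on the loading side: a config file saved before these sections existed deserializes with the fields at their `None` defaults, and the hook coerces them to empty dicts. A self-contained sketch, with `BaseConfig` and the remaining fields omitted (hypothetical stand-in class for illustration):

```python
from dataclasses import dataclass


@dataclass
class ClusterConfigSketch:
    # Defaults are None rather than {} because dataclasses reject mutable
    # defaults in field declarations.
    deepspeed_config: dict = None
    fsdp_config: dict = None

    def __post_init__(self):
        # Coerce None to {} so downstream code can read keys unconditionally,
        # even for configs written before these sections existed.
        if self.deepspeed_config is None:
            self.deepspeed_config = {}
        if self.fsdp_config is None:
            self.fsdp_config = {}


cfg = ClusterConfigSketch()
assert cfg.deepspeed_config == {} and cfg.fsdp_config == {}
```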
src/accelerate/memory_utils.py

```diff
@@ -73,6 +73,8 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
             f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
         )
         while True:
+            if batch_size == 0:
+                raise RuntimeError("No executable batch size found, reached zero.")
             try:
                 return function(batch_size, *args, **kwargs)
             except Exception as e:
```
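For context, `find_executable_batch_size` wraps a training function, injects `batch_size`, and halves it each time execution fails with an out-of-memory error; the two added lines turn what was an infinite shrink loop into a hard `RuntimeError` once the size reaches zero. A usage sketch, assuming the v0.7-era import path (later releases moved this helper under `accelerate.utils`):

```python
from accelerate.memory_utils import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    # `batch_size` is injected and halved by the decorator on every
    # out-of-memory failure; callers must not pass it themselves.
    ...


train()  # tries 128, 64, 32, ... and now raises RuntimeError at 0
```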
tests/test_memory_utils.py

```diff
@@ -50,6 +50,26 @@ class MemoryTest(unittest.TestCase):
         self.assertListEqual(batch_sizes, [128, 64, 32, 16, 8])
         self.assertListEqual([bs, arg1], [8, "hello"])
 
+    def test_start_zero(self):
+        @find_executable_batch_size(starting_batch_size=0)
+        def mock_training_loop_function(batch_size):
+            pass
+
+        with self.assertRaises(RuntimeError) as cm:
+            mock_training_loop_function()
+        self.assertIn("No executable batch size found, reached zero.", cm.exception.args[0])
+
+    def test_approach_zero(self):
+        @find_executable_batch_size(starting_batch_size=16)
+        def mock_training_loop_function(batch_size):
+            if batch_size > 0:
+                raise_fake_out_of_memory()
+            pass
+
+        with self.assertRaises(RuntimeError) as cm:
+            mock_training_loop_function()
+        self.assertIn("No executable batch size found, reached zero.", cm.exception.args[0])
+
     def test_verbose_guard(self):
         @find_executable_batch_size(starting_batch_size=128)
         def mock_training_loop_function(batch_size, arg1, arg2):
@@ -60,3 +80,12 @@ class MemoryTest(unittest.TestCase):
             mock_training_loop_function(128, "hello", "world")
         self.assertIn("Batch size was passed into `f`", cm.exception.args[0])
         self.assertIn("`f(arg1='hello', arg2='world')", cm.exception.args[0])
+
+    def test_any_other_error(self):
+        @find_executable_batch_size(starting_batch_size=16)
+        def mock_training_loop_function(batch_size):
+            raise ValueError("Oops, we had an error!")
+
+        with self.assertRaises(ValueError) as cm:
+            mock_training_loop_function()
+        self.assertIn("Oops, we had an error!", cm.exception.args[0])
```