mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-14 22:24:32 +08:00
Compare commits
2 Commits
v0.26.0-re
...
v0.7.0
| Author | SHA1 | Date | |
|---|---|---|---|
| f0bb5f0ed5 | |||
| 11e8b33217 |
2
setup.py
2
setup.py
@ -36,7 +36,7 @@ extras["sagemaker"] = [
|
||||
|
||||
setup(
|
||||
name="accelerate",
|
||||
version="0.7.0.dev0",
|
||||
version="0.7.0",
|
||||
description="Accelerate",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
__version__ = "0.7.0.dev0"
|
||||
__version__ = "0.7.0"
|
||||
|
||||
from .accelerator import Accelerator
|
||||
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs
|
||||
|
||||
@ -139,6 +139,13 @@ class ClusterConfig(BaseConfig):
|
||||
# args for fsdp
|
||||
fsdp_config: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.deepspeed_config is None:
|
||||
self.deepspeed_config = {}
|
||||
if self.fsdp_config is None:
|
||||
self.fsdp_config = {}
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageMakerConfig(BaseConfig):
|
||||
|
||||
Reference in New Issue
Block a user