mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-14 22:24:32 +08:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d0a3eeaf7 | |||
| 2a810a0ebd |
2
setup.py
2
setup.py
@ -32,7 +32,7 @@ extras["sagemaker"] = [
|
||||
|
||||
setup(
|
||||
name="accelerate",
|
||||
version="0.13.1",
|
||||
version="0.13.2",
|
||||
description="Accelerate",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
__version__ = "0.13.1"
|
||||
__version__ = "0.13.2"
|
||||
|
||||
from .accelerator import Accelerator
|
||||
from .big_modeling import cpu_offload, disk_offload, dispatch_model, init_empty_weights, load_checkpoint_and_dispatch
|
||||
|
||||
@ -258,7 +258,7 @@ def get_max_layer_size(
|
||||
modules_to_treat = modules.copy()
|
||||
while len(modules_to_treat) > 0:
|
||||
module_name, module = modules_to_treat.pop(0)
|
||||
modules_children = list(module.named_children())
|
||||
modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []
|
||||
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
|
||||
# No splitting this one so we compare to the max_size
|
||||
size = module_sizes[module_name]
|
||||
|
||||
Reference in New Issue
Block a user