Compare commits

..

2 Commits

Author SHA1 Message Date
8d0a3eeaf7 Release: v0.13.2 2022-10-17 11:09:16 -04:00
2a810a0ebd [Device map] nn.Parameter don't have children (#747)
* [Device map] nn.Parameter don't have children

* Update src/accelerate/utils/modeling.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2022-10-17 11:07:54 -04:00
3 changed files with 3 additions and 3 deletions

View File

@ -32,7 +32,7 @@ extras["sagemaker"] = [
setup(
name="accelerate",
version="0.13.1",
version="0.13.2",
description="Accelerate",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@ -2,7 +2,7 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
__version__ = "0.13.1"
__version__ = "0.13.2"
from .accelerator import Accelerator
from .big_modeling import cpu_offload, disk_offload, dispatch_model, init_empty_weights, load_checkpoint_and_dispatch

View File

@ -258,7 +258,7 @@ def get_max_layer_size(
modules_to_treat = modules.copy()
while len(modules_to_treat) > 0:
module_name, module = modules_to_treat.pop(0)
modules_children = list(module.named_children())
modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
# No splitting this one so we compare to the max_size
size = module_sizes[module_name]