Mirror of https://github.com/huggingface/accelerate.git
synced 2025-11-19 17:14:29 +08:00
Compare commits
4 Commits
feat/async ... v0.34.1
| Author | SHA1 | Date |
|---|---|---|
| | beb43781d7 | |
| | e13bef2c78 | |
| | 73a1531e58 | |
| | 159c0dd02a | |
setup.py (2 changed lines)
@@ -49,7 +49,7 @@ extras["sagemaker"] = [
 
 setup(
     name="accelerate",
-    version="0.34.0.dev0",
+    version="0.34.1",
     description="Accelerate",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "0.34.0.dev0"
+__version__ = "0.34.1"
 
 from .accelerator import Accelerator
 from .big_modeling import (
@@ -416,25 +416,6 @@ class DataLoaderAdapter:
         else:
             self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
 
-        # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
-        # In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
-        # StatefulDataLoader
-        #
-        # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
-        # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
-        # dispatching scattered throughout various functions and files.
-        #
-        # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
-        # transparently.
-        #
-        # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
-        # but this would not be backwards compatible with existing code which assumes
-        # DataLoaderShard/DataLoaderDispatcher are DataLoaders.
-        base_cls = self.__class__
-        base_cls_name = self.__class__.__name__
-        parent_cls_name = self.base_dataloader.__class__
-        self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
-
         if hasattr(self.base_dataloader, "state_dict"):
             self.dl_state_dict = self.base_dataloader.state_dict()
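The block removed above built a throwaway subclass at runtime with `type()` so that `isinstance(adapter, DataLoader)` and `isinstance(adapter, StatefulDataLoader)` checks would pass. A minimal, self-contained sketch (hypothetical `Adapter`/`Base` classes, not accelerate's) of why that pattern collides with pickling, which is what the rest of this commit addresses:

```python
# Sketch: a dynamic type() mixin makes isinstance pass but breaks pickling, because
# pickle resolves classes by importable name and the synthetic class only lives on
# the instance. Adapter/Base are hypothetical stand-ins, not accelerate classes.
import pickle


class Base:
    """Stands in for DataLoader/StatefulDataLoader in this sketch."""


class Adapter:
    def __init__(self):
        # Build a new subclass on the fly and swap it in, mirroring the removed code.
        self.__class__ = type(self.__class__.__name__, (self.__class__, Base), {})


obj = Adapter()
print(isinstance(obj, Base))  # True: the dynamic mixin satisfies isinstance checks

try:
    pickle.dumps(obj)
except pickle.PicklingError as err:
    # The name "Adapter" resolves to the original class, not the synthetic one
    # attached to obj, so serialization fails.
    print("pickling failed:", err)
```

Because the synthetic class cannot be looked up by module path, serializing a prepared dataloader (or an `Accelerator` holding one) did not work, which is what the `__reduce__` additions and pickling tests below are about.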
@@ -451,6 +432,18 @@ class DataLoaderAdapter:
     def load_state_dict(self, state_dict):
         self.base_dataloader.load_state_dict(state_dict)
 
+    @property
+    def __class__(self):
+        """
+        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
+        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
+        object.
+        """
+        return self.base_dataloader.__class__
+
+    def __len__(self):
+        return len(self.base_dataloader)
+
     def adjust_state_dict_for_prefetch(self):
         """
         Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
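The replacement keeps those `isinstance` checks working without mutating the type: `__class__` becomes a read-only property that reports the wrapped dataloader's class, while the object's real type stays an ordinary, importable class. A minimal sketch of the pattern, using hypothetical `Wrapper`/`Inner` classes rather than accelerate's:

```python
# Sketch of the __class__-as-property pattern; Wrapper/Inner are hypothetical
# stand-ins for DataLoaderAdapter and DataLoader.
class Inner:
    def __len__(self):
        return 10


class Wrapper:
    def __init__(self, inner):
        self.inner = inner

    @property
    def __class__(self):
        # Report the wrapped object's class so isinstance(obj, Inner) passes.
        return self.inner.__class__

    def __len__(self):
        # Nothing is inherited from Inner, so needed methods are forwarded explicitly.
        return len(self.inner)


w = Wrapper(Inner())
print(isinstance(w, Inner))  # True, via the __class__ property
print(type(w) is Wrapper)    # True, the real type is unchanged
print(len(w))                # 10, forwarded to the inner object
```

The trade-off is that nothing is actually inherited from the wrapped class anymore, which is why the adapter gains its own `__len__` here and why the `__reduce__` overrides below are needed for pickling.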
@@ -580,6 +573,15 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
         self.iteration += 1
         self.end()
 
+    def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
+        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+        """
+        args = super().__reduce__()
+        return (DataLoaderShard, *args[1:])
+
     def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
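Since `__class__` now reports the wrapped dataloader's class, pickle's default reduction no longer describes the real object, which is what the docstring above calls "broken". The commit repairs this per subclass by reusing `super().__reduce__()` and substituting the concrete class; the sketch below (hypothetical `Wrapper`/`Inner` again, redefined so it runs standalone) shows the underlying idea with a hand-written `__reduce__`:

```python
# Sketch of the __reduce__ fix: tell pickle the real constructor and its argument
# explicitly, so the lying __class__ property is never consulted during dumping.
import pickle


class Inner:
    def __len__(self):
        return 10


class Wrapper:
    def __init__(self, inner):
        self.inner = inner

    @property
    def __class__(self):
        return self.inner.__class__

    def __reduce__(self):
        # Rebuild as Wrapper(inner) on load, regardless of what __class__ reports.
        return (Wrapper, (self.inner,))


restored = pickle.loads(pickle.dumps(Wrapper(Inner())))
print(type(restored).__name__)      # Wrapper: the real type survives the round trip
print(isinstance(restored, Inner))  # True: the __class__ property still reports Inner
```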
@@ -865,7 +867,7 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
             self.dataset.set_epoch(epoch)
 
     def __len__(self):
-        whole_length = super().__len__()
+        whole_length = len(self.base_dataloader)
         if self.split_batches:
             return whole_length
         elif self._drop_last:
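For reference, a quick worked example of the length arithmetic in this method, with hypothetical numbers (10 batches in the wrapped loader, 3 processes); the `drop_last` branch sits outside the hunk, so only the two visible cases are shown:

```python
import math

whole_length = 10   # len(self.base_dataloader) in the snippet above
num_processes = 3

# split_batches=True: every process steps through all batches
print(whole_length)                             # 10
# otherwise (not dropping the last incomplete shard): round up
print(math.ceil(whole_length / num_processes))  # 4
```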
@@ -873,6 +875,15 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
         else:
             return math.ceil(whole_length / self.state.num_processes)
 
+    def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
+        be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+        """
+        args = super().__reduce__()
+        return (DataLoaderDispatcher, *args[1:])
+
     @property
     def total_batch_size(self):
         return (
@@ -1211,6 +1222,18 @@ class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
                 yield batch
         self.end()
 
+    def __len__(self):
+        return len(self.base_dataloader) - self.skip_batches
+
+    def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
+        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+        """
+        args = super().__reduce__()
+        return (SkipDataLoader, *args[1:])
+
 
 def skip_first_batches(dataloader, num_batches=0):
     """
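A quick check of what the new `SkipDataLoader.__len__` reports, mirroring the unit-test constructions further down (`range(16)` with `batch_size=4`, skipping 2 batches); the expected values assume a plain single-process run:

```python
from accelerate.data_loader import SkipDataLoader

skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)

print(len(skip_dl))                           # expected: 2 (4 batches minus the 2 skipped)
print([batch.tolist() for batch in skip_dl])  # expected: [[8, 9, 10, 11], [12, 13, 14, 15]]
```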
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pickle
 import tempfile
 import warnings
 from typing import List
@@ -247,6 +248,16 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
         assert "only supported for map-style datasets" in str(w[-1].message)
 
 
+def test_pickle_accelerator():
+    accelerator = create_accelerator()
+    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
+    _ = accelerator.prepare(data_loader)
+    pickled_accelerator = pickle.dumps(accelerator)
+    unpickled_accelerator = pickle.loads(pickled_accelerator)
+    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
+    assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__
+
+
 def test_data_loader(data_loader, accelerator):
     # Prepare the DataLoader
     data_loader = accelerator.prepare(data_loader)
@@ -368,6 +379,9 @@ def main():
     test_join_raises_warning_for_non_ddp_distributed(accelerator)
     accelerator.state.distributed_type = original_state
 
+    accelerator.print("Test pickling an accelerator")
+    test_pickle_accelerator()
+
     dataset = DummyDataset()
     # Conventional Dataloader with shuffle=False
     loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
@@ -1500,7 +1500,7 @@ class FullyShardedDataParallelPlugin:
 
     def set_state_dict_type(self):
         """
-        Set the state dict config based on the `StateDictType.
+        Set the state dict config based on the `StateDictType`.
         """
         from torch.distributed.fsdp.fully_sharded_data_parallel import (
             FullOptimStateDictConfig,
@@ -1538,9 +1538,7 @@ class FullyShardedDataParallelPlugin:
 
         # First base off of `_no_split_modules`
         no_split_modules = getattr(model, "_no_split_modules", None)
-        default_transformer_cls_names_to_wrap = (
-            ",".join(model._no_split_modules) if no_split_modules is not None else ""
-        )
+        default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []
         if self.auto_wrap_policy == transformer_auto_wrap_policy:
             if self.transformer_cls_names_to_wrap is None:
                 self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap
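The one-line replacement above changes the default from a comma-joined string to an actual list of layer-class names. A tiny illustration of the difference, with a hypothetical `_no_split_modules` value; the list form is presumably what the later per-class lookup expects to iterate over:

```python
# Illustration only: no_split_modules is a made-up value, not read from a real model.
no_split_modules = ["LlamaDecoderLayer", "LlamaRMSNorm"]

old_default = ",".join(no_split_modules) if no_split_modules is not None else ""
new_default = list(no_split_modules) if no_split_modules is not None else []

print(old_default)  # 'LlamaDecoderLayer,LlamaRMSNorm'      -> a single string
print(new_default)  # ['LlamaDecoderLayer', 'LlamaRMSNorm'] -> a list of class names
```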
@@ -49,6 +49,7 @@ from accelerate.utils.other import patch_environment
 set_seed(42)
 
 BERT_BASE_CASED = "bert-base-cased"
+LLAMA_TESTING = "hf-internal-testing/tiny-random-LlamaForCausalLM"
 FP16 = "fp16"
 BF16 = "bf16"
 dtypes = [FP16, BF16]
@@ -136,15 +137,19 @@ class FSDPPluginIntegration(AccelerateTestCase):
         assert fsdp_plugin.state_dict_config.rank0_only
 
     def test_auto_wrap_policy(self):
-        model = AutoModel.from_pretrained(BERT_BASE_CASED)
+        for model_name in [LLAMA_TESTING, BERT_BASE_CASED]:
+            model = AutoModel.from_pretrained(model_name)
+            layer_to_wrap = "LlamaDecoderLayer" if model_name == LLAMA_TESTING else "BertLayer"
             for policy in FSDP_AUTO_WRAP_POLICY:
                 env = self.fsdp_env.copy()
                 env["FSDP_AUTO_WRAP_POLICY"] = policy
                 transformer_cls_to_wrap = None
                 min_num_params = None
                 env.pop("FSDP_TRANSFORMER_CLS_TO_WRAP", None)
                 env.pop("FSDP_MIN_NUM_PARAMS", None)
                 if policy == "TRANSFORMER_BASED_WRAP":
-                    env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer"
-                    transformer_cls_to_wrap = "BertLayer"
+                    env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = layer_to_wrap
+                    transformer_cls_to_wrap = layer_to_wrap
                 elif policy == "SIZE_BASED_WRAP":
                     env["FSDP_MIN_NUM_PARAMS"] = "2000"
                     min_num_params = 2000
@@ -27,6 +27,7 @@ from torch.utils.data import DataLoader, TensorDataset
 
 from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
 from accelerate.accelerator import Accelerator
+from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches
 from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import (
     require_bnb,
@@ -647,6 +648,52 @@ class AcceleratorTester(AccelerateTestCase):
         model_loaded = pickle.loads(pickle.dumps(model))
         model_loaded(inputs)
 
+    @parameterized.expand([True, False])
+    def test_can_pickle_dataloader(self, dispatch_batches):
+        """
+        Test that pickling a prepared dataloader works.
+        """
+        data = torch.arange(10).to(torch_device)
+        ds = torch.utils.data.TensorDataset(data)
+        dl = torch.utils.data.DataLoader(ds)
+        skip_dl = skip_first_batches(dl, 2)
+
+        # Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality
+        # TODO: Add support for pickling StatefulDataLoader
+        dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)
+        accelerator = Accelerator(dataloader_config=dataloader_config)
+
+        original_dl, _ = accelerator.prepare(dl, skip_dl)
+        if dispatch_batches:
+            assert isinstance(original_dl, DataLoaderDispatcher)
+        else:
+            assert isinstance(original_dl, DataLoaderShard)
+
+        prepared_model_dumps = pickle.dumps(accelerator)
+
+        model_loaded = pickle.loads(prepared_model_dumps)
+        assert len(model_loaded._dataloaders) == 2
+
+        # Assert equality of recovered and original dataloader
+        loaded_dl = model_loaded._dataloaders[0]
+        assert isinstance(loaded_dl, DataLoader)
+        if dispatch_batches:
+            assert isinstance(loaded_dl, DataLoaderDispatcher)
+        else:
+            assert isinstance(loaded_dl, DataLoaderShard)
+        assert len(loaded_dl) == len(original_dl)
+        assert [i for i in loaded_dl] == [i for i in original_dl]
+
+        # Test skip dataloader works as expected as well
+        loaded_skip_dl = model_loaded._dataloaders[1]
+        assert isinstance(loaded_skip_dl, DataLoader)
+        if dispatch_batches:
+            assert isinstance(loaded_dl, DataLoaderDispatcher)
+        else:
+            assert isinstance(loaded_dl, DataLoaderShard)
+        assert len(loaded_skip_dl) == len(original_dl) - 2
+        assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]
+
     # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.
     @require_torchdata_stateful_dataloader
     def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
@@ -420,6 +420,14 @@ class DataLoaderTester(unittest.TestCase):
         skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
         dl_shard = DataLoaderShard(range(16), batch_size=4)
         dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
+
+        # Test dataloaders are instances of instantiated classes
+        # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
+        assert isinstance(skip_dl, SkipDataLoader)
+        assert isinstance(dl_shard, DataLoaderShard)
+        assert isinstance(dl_dispatcher, DataLoaderDispatcher)
+
+        # Test dataloaders are instances of base classes
         assert isinstance(skip_dl, DataLoader)
         assert isinstance(dl_shard, DataLoader)
         assert isinstance(dl_dispatcher, DataLoader)
@@ -556,6 +564,13 @@ class StatefulDataLoaderTester(unittest.TestCase):
         skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
         dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
         dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
+
+        # Test dataloaders are instances of instantiated classes
+        # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
+        assert isinstance(skip_dl, SkipDataLoader)
+        assert isinstance(dl_shard, DataLoaderShard)
+        assert isinstance(dl_dispatcher, DataLoaderDispatcher)
+
         assert isinstance(skip_dl, StatefulDataLoader)
         assert isinstance(dl_shard, StatefulDataLoader)
         assert isinstance(dl_dispatcher, StatefulDataLoader)