mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
Add copyright + some ruff lint things (#2523)
* Copyright and ruff stuff * lol
This commit is contained in:
4
Makefile
4
Makefile
@ -12,13 +12,13 @@ extra_quality_checks:
|
||||
|
||||
# this target runs checks on all files
|
||||
quality:
|
||||
ruff $(check_dirs)
|
||||
ruff check $(check_dirs)
|
||||
ruff format --check $(check_dirs)
|
||||
doc-builder style src/accelerate docs/source --max_len 119 --check_only
|
||||
|
||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||
style:
|
||||
ruff $(check_dirs) --fix
|
||||
ruff check $(check_dirs) --fix
|
||||
ruff format $(check_dirs)
|
||||
doc-builder style src/accelerate docs/source --max_len 119
|
||||
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import threading
|
||||
import time
|
||||
|
@ -75,4 +75,4 @@ end_time = time.time()
|
||||
if PartialState().is_last_process:
|
||||
output = torch.stack(tuple(output[0]))
|
||||
print(f"Time of first pass: {first_batch}")
|
||||
print(f"Average time per batch: {(end_time - start_time)/5}")
|
||||
print(f"Average time per batch: {(end_time - start_time) / 5}")
|
||||
|
@ -74,4 +74,4 @@ end_time = time.time()
|
||||
if PartialState().is_last_process:
|
||||
output = torch.stack(tuple(output[0]))
|
||||
print(f"Time of first pass: {first_batch}")
|
||||
print(f"Average time per batch: {(end_time - start_time)/5}")
|
||||
print(f"Average time per batch: {(end_time - start_time) / 5}")
|
||||
|
@ -86,4 +86,4 @@ end_time = time.time()
|
||||
if PartialState().is_last_process:
|
||||
output = torch.stack(tuple(output[0]))
|
||||
print(f"Time of first pass: {first_batch}")
|
||||
print(f"Average time per batch: {(end_time - start_time)/5}")
|
||||
print(f"Average time per batch: {(end_time - start_time) / 5}")
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
import runhouse as rh
|
||||
|
@ -3,10 +3,12 @@ line-length = 119
|
||||
target-version = "py38"
|
||||
|
||||
[tool.ruff.lint]
|
||||
preview = true
|
||||
ignore-init-module-imports = true
|
||||
extend-select = [
|
||||
"B009", # static getattr
|
||||
"B010", # static setattr
|
||||
"CPY", # Copyright
|
||||
"E", # PEP8 errors
|
||||
"F", # PEP8 formatting
|
||||
"I", # Import sorting
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
__version__ = "0.28.0.dev0"
|
||||
|
||||
from .accelerator import Accelerator
|
||||
|
@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
@ -1 +1,14 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .selection_menu import BulletMenu
|
||||
|
@ -16,7 +16,6 @@
|
||||
Utilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import string
|
||||
import sys
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""
|
||||
Main driver for the selection menu, based on https://github.com/bchao1/bullet
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import sys
|
||||
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from types import MethodType
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .testing import (
|
||||
DEFAULT_LAUNCH_COMMAND,
|
||||
are_the_same_tensors,
|
||||
|
@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
@ -181,7 +181,7 @@ def training_function(config, args):
|
||||
accelerator.print("resumed checkpoint performance:", accuracy)
|
||||
accelerator.print("resumed checkpoint's scheduler's lr:", lr_scheduler.get_lr()[0])
|
||||
accelerator.print("resumed optimizers's lr:", optimizer.param_groups[0]["lr"])
|
||||
with open(os.path.join(args.output_dir, f"state_{starting_epoch-1}.json")) as f:
|
||||
with open(os.path.join(args.output_dir, f"state_{starting_epoch - 1}.json")) as f:
|
||||
resumed_state = json.load(f)
|
||||
assert resumed_state["accuracy"] == accuracy, "Accuracy mismatch, loading from checkpoint failed"
|
||||
assert (
|
||||
|
@ -113,8 +113,8 @@ def generate_predictions(model, dataloader, accelerator):
|
||||
def test_torch_metrics(
|
||||
accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16
|
||||
):
|
||||
model, ddp_model, dataloader = get_basic_setup(accelerator, num_samples, batch_size)
|
||||
logits, targs = generate_predictions(ddp_model, dataloader, accelerator)
|
||||
_, ddp_model, dataloader = get_basic_setup(accelerator, num_samples, batch_size)
|
||||
logits, _ = generate_predictions(ddp_model, dataloader, accelerator)
|
||||
assert (
|
||||
len(logits) == num_samples
|
||||
), f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(logits)}"
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -1,4 +1,20 @@
|
||||
# Test file to ensure that in general certain situational setups for notebooks work.
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Test file to ensure that in general certain situational setups for notebooks work.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from pytest import raises
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .constants import (
|
||||
MODEL_NAME,
|
||||
OPTIMIZER_NAME,
|
||||
|
@ -354,10 +354,10 @@ class DeepSpeedConfigIntegration(AccelerateTestCase):
|
||||
)
|
||||
assert accelerator.deepspeed_config["zero_allow_untested_optimizer"]
|
||||
assert accelerator.deepspeed_config["train_batch_size"], 16
|
||||
assert type(model) == DeepSpeedEngine
|
||||
assert type(optimizer) == DeepSpeedOptimizerWrapper
|
||||
assert type(lr_scheduler) == AcceleratedScheduler
|
||||
assert type(accelerator.deepspeed_engine_wrapped) == DeepSpeedEngineWrapper
|
||||
assert type(model) is DeepSpeedEngine
|
||||
assert type(optimizer) is DeepSpeedOptimizerWrapper
|
||||
assert type(lr_scheduler) is AcceleratedScheduler
|
||||
assert type(accelerator.deepspeed_engine_wrapped) is DeepSpeedEngineWrapper
|
||||
|
||||
elif optim_type == DS_OPTIMIZER and scheduler_type == DS_SCHEDULER:
|
||||
# Test DeepSpeed optimizer + DeepSpeed scheduler
|
||||
@ -411,10 +411,10 @@ class DeepSpeedConfigIntegration(AccelerateTestCase):
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
|
||||
)
|
||||
assert type(model) == DeepSpeedEngine
|
||||
assert type(optimizer) == DeepSpeedOptimizerWrapper
|
||||
assert type(lr_scheduler) == DeepSpeedSchedulerWrapper
|
||||
assert type(accelerator.deepspeed_engine_wrapped) == DeepSpeedEngineWrapper
|
||||
assert type(model) is DeepSpeedEngine
|
||||
assert type(optimizer) is DeepSpeedOptimizerWrapper
|
||||
assert type(lr_scheduler) is DeepSpeedSchedulerWrapper
|
||||
assert type(accelerator.deepspeed_engine_wrapped) is DeepSpeedEngineWrapper
|
||||
|
||||
elif optim_type == CUSTOM_OPTIMIZER and scheduler_type == DS_SCHEDULER:
|
||||
# Test custom optimizer + DeepSpeed scheduler
|
||||
@ -445,11 +445,11 @@ class DeepSpeedConfigIntegration(AccelerateTestCase):
|
||||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
|
||||
)
|
||||
assert type(model) == DeepSpeedEngine
|
||||
assert type(optimizer) == DeepSpeedOptimizerWrapper
|
||||
assert type(lr_scheduler) == DeepSpeedSchedulerWrapper
|
||||
assert type(accelerator.deepspeed_engine_wrapped) == DeepSpeedEngineWrapper
|
||||
elif optim_type == DS_OPTIMIZER and scheduler_type == CUSTOM_SCHEDULER:
|
||||
assert type(model) is DeepSpeedEngine
|
||||
assert type(optimizer) is DeepSpeedOptimizerWrapper
|
||||
assert type(lr_scheduler) is DeepSpeedSchedulerWrapper
|
||||
assert type(accelerator.deepspeed_engine_wrapped) is DeepSpeedEngineWrapper
|
||||
elif optim_type == DS_OPTIMIZER and scheduler_type is CUSTOM_SCHEDULER:
|
||||
# Test deepspeed optimizer + custom scheduler
|
||||
deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.ds_config_file[ZERO2])
|
||||
with mockenv_context(**self.dist_env):
|
||||
@ -992,8 +992,8 @@ class DeepSpeedIntegrationTest(TempDirTestCase):
|
||||
]
|
||||
)
|
||||
for i in range(3):
|
||||
if f"stage_{i+1}" in spec:
|
||||
cmd_stage.extend([f"--zero_stage={i+1}"])
|
||||
if f"stage_{i + 1}" in spec:
|
||||
cmd_stage.extend([f"--zero_stage={i + 1}"])
|
||||
break
|
||||
cmd_stage.extend(
|
||||
[
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
@ -24,7 +24,6 @@ Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/lau
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from argparse import REMAINDER, ArgumentParser
|
||||
|
@ -1,3 +1,16 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
from datetime import date
|
||||
|
@ -15,6 +15,7 @@
|
||||
Script to close stale issue. Taken in part from the AllenNLP repository.
|
||||
https://github.com/allenai/allennlp.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime as dt
|
||||
from datetime import timezone
|
||||
|
Reference in New Issue
Block a user