Mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 18:13:46 +08:00
Add end_training/destroy_pg to everything and unpin numpy (#3030)
* Add end_training/destroy_pg to everything
* Carry over to AcceleratorState
* If forked, ignore
* More numpy fun
* Skip only init
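In short, every example and test script now performs an explicit teardown at the end of the run. A rough sketch of the pattern (illustrative only, not a file from the diff; model and data setup elided):

from accelerate import Accelerator, PartialState

def training_function(config, args):
    accelerator = Accelerator()
    # ... prepare objects and run the training loop ...
    # New: always called at the end, no longer guarded by `if args.with_tracking:`.
    # It closes any open trackers and tears down the process group.
    accelerator.end_training()

# Scripts that only use PartialState (e.g. the distributed inference tests)
# instead end with:
PartialState().destroy_process_group()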
@@ -217,6 +217,7 @@ def training_function(config, args):
     # And call it at the end with no arguments
     # Note: You could also refactor this outside of your training loop function
     inner_training_loop()
+    accelerator.end_training()


 def main():
@@ -276,6 +276,7 @@ def training_function(config, args):
         if args.output_dir is not None:
             output_dir = os.path.join(args.output_dir, output_dir)
         accelerator.save_state(output_dir)
+    accelerator.end_training()


 def main():
@@ -255,6 +255,7 @@ def training_function(config, args):
     preds = torch.stack(test_predictions, dim=0).sum(dim=0).div(int(args.num_folds)).argmax(dim=-1)
     test_metric = metric.compute(predictions=preds, references=test_references)
     accelerator.print("Average test metrics from all folds:", test_metric)
+    accelerator.end_training()


 def main():
@@ -192,6 +192,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
@@ -716,6 +716,7 @@ def main():

         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump({"perplexity": perplexity, "eval_loss": eval_loss.item()}, f)
+    accelerator.end_training()


 if __name__ == "__main__":
@@ -222,6 +222,7 @@ def training_function(config, args):

         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
@@ -399,8 +399,7 @@ def training_function(config, args):
             step=epoch,
         )

-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
@@ -197,6 +197,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
@@ -202,6 +202,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
@@ -703,6 +703,7 @@ def main():

         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump({"perplexity": perplexity}, f)
+    accelerator.end_training()


 if __name__ == "__main__":
@@ -210,6 +210,7 @@ def training_function(config, args):
     # And call it at the end with no arguments
     # Note: You could also refactor this outside of your training loop function
     inner_training_loop()
+    accelerator.end_training()


 def main():
@@ -214,6 +214,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
@@ -203,6 +203,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
@@ -202,6 +202,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
@@ -236,11 +236,7 @@ def training_function(config, args):
             step=epoch,
         )

-    # New Code #
-    # When a run is finished, you should call `accelerator.end_training()`
-    # to close all of the open trackers
-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
@@ -262,8 +262,7 @@ def training_function(config, args):
             output_dir = os.path.join(args.output_dir, output_dir)
         accelerator.save_state(output_dir)

-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
@@ -256,8 +256,7 @@ def training_function(config, args):
             output_dir = os.path.join(args.output_dir, output_dir)
         accelerator.save_state(output_dir)

-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
@@ -180,6 +180,7 @@ def training_function(config, args):
         eval_metric = accurate.item() / num_elems
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
+    accelerator.end_training()


 def main():
@@ -76,3 +76,4 @@ if PartialState().is_last_process:
     output = torch.stack(tuple(output[0]))
     print(f"Time of first pass: {first_batch}")
     print(f"Average time per batch: {(end_time - start_time) / 5}")
+PartialState().destroy_process_group()
@@ -75,3 +75,4 @@ if PartialState().is_last_process:
     output = torch.stack(tuple(output[0]))
     print(f"Time of first pass: {first_batch}")
     print(f"Average time per batch: {(end_time - start_time) / 5}")
+PartialState().destroy_process_group()
@@ -52,3 +52,4 @@ if PartialState().is_last_process:
     next_token_logits = output[0][:, -1, :]
     next_token = torch.argmax(next_token_logits, dim=-1)
     print(tokenizer.batch_decode(next_token))
+PartialState().destroy_process_group()
@@ -87,3 +87,4 @@ if PartialState().is_last_process:
     output = torch.stack(tuple(output[0]))
     print(f"Time of first pass: {first_batch}")
     print(f"Average time per batch: {(end_time - start_time) / 5}")
+PartialState().destroy_process_group()
@@ -185,6 +185,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
setup.py
@@ -70,7 +70,7 @@ setup(
     },
     python_requires=">=3.8.0",
     install_requires=[
-        "numpy>=1.17,<2.0.0",
+        "numpy>=1.17,<3.0.0",
         "packaging>=20.0",
         "psutil",
         "pyyaml",
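Alongside the teardown changes, the numpy pin in setup.py is relaxed from <2.0.0 to <3.0.0, so fresh installs may resolve to numpy 2.x. A quick sanity check of the new constraint (a sketch, not part of the commit):

import numpy as np
from packaging import version

# The relaxed pin accepts any numpy from 1.17 up to, but not including, 3.0.
v = version.parse(np.__version__)
assert version.parse("1.17") <= v < version.parse("3.0.0"), f"unsupported numpy {v}"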
@@ -2727,9 +2727,7 @@ class Accelerator:
         for tracker in self.trackers:
             tracker.finish()

-        if torch.distributed.is_initialized():
-            # needed when using torch.distributed.init_process_group
-            torch.distributed.destroy_process_group()
+        self.state.destroy_process_group()

     def save(self, obj, f, safe_serialization=False):
         """
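Accelerator.end_training() now finishes all trackers and then delegates process-group teardown to self.state.destroy_process_group(), so scripts can call it unconditionally. A minimal usage sketch (the tracker arguments are illustrative):

from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", project_dir="runs")
accelerator.init_trackers("example_run")
# ... training loop ...
# One call closes the trackers and, if torch.distributed was initialized,
# also destroys the process group.
accelerator.end_training()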
@@ -789,6 +789,16 @@ class PartialState:
         self.device = torch.device(device, device_index)
         device_module.set_device(self.device)

+    def destroy_process_group(self, group=None):
+        """
+        Destroys the process group. If one is not specified, the default process group is destroyed.
+        """
+        if self.fork_launched and group is None:
+            return
+        # needed when using torch.distributed.init_process_group
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group(group)
+
     def __getattr__(self, name: str):
         # By this point we know that no attributes of `self` contain `name`,
         # so we just modify the error message
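A small usage sketch for the new PartialState.destroy_process_group (assumed usage; the behaviour mirrors the fork_launched guard above):

from accelerate import PartialState

state = PartialState()
# ... distributed work ...
# A no-op when the process was fork-launched (e.g. via notebook_launcher) and no
# explicit group is passed; otherwise destroys the process group if
# torch.distributed has been initialized.
state.destroy_process_group()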
@@ -983,6 +993,18 @@ class AcceleratorState:
         if reset_partial_state:
             PartialState._reset_state()

+    def destroy_process_group(self, group=None):
+        """
+        Destroys the process group. If one is not specified, the default process group is destroyed.
+
+        If `self.fork_lauched` is `True` and `group` is `None`, nothing happens.
+        """
+        PartialState().destroy_process_group(group)
+
+    @property
+    def fork_launched(self):
+        return PartialState().fork_launched
+
     @property
     def use_distributed(self):
         """
@@ -223,6 +223,7 @@ def training_function(config, args):
     if accelerator.is_main_process:
         with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f:
             json.dump(state, f)
+    accelerator.end_training()


 def main():
@@ -294,6 +294,7 @@ def main():
     if accelerator.is_local_main_process:
         print("**Test that `drop_last` is taken into account**")
     test_gather_for_metrics_drop_last()
+    accelerator.end_training()
     accelerator.state._reset_state()

@@ -240,6 +240,7 @@ def training_function(config, args):
     if accelerator.is_main_process:
         with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f:
             json.dump(train_total_peak_memory, f)
+    accelerator.end_training()


 def main():
@@ -205,6 +205,7 @@ def training_function(config, args):
     if accelerator.is_main_process:
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump(performance_metric, f)
+    accelerator.end_training()


 def main():
@@ -125,5 +125,6 @@ if __name__ == "__main__":
         state.print("Testing CV model...")
         test_resnet()
         test_resnet(3)
+        state.destroy_process_group()
     else:
         print("Less than two GPUs found, not running tests!")
@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch

-from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
+from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState


 class MockModel(torch.nn.Module):
@@ -71,6 +71,7 @@ def main():
     ]:
         print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
         test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
+    PartialState().destroy_process_group()


 if __name__ == "__main__":
@@ -307,6 +307,8 @@ def main():
     loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
     test_data_loader(loader, accelerator)

+    accelerator.end_training()
+

 if __name__ == "__main__":
     main()
@@ -158,3 +158,4 @@ if __name__ == "__main__":
     if accelerator.is_main_process:
         shutil.rmtree(out_path)
     accelerator.wait_for_everyone()
+    accelerator.end_training()
@@ -110,6 +110,8 @@ def main():
     if is_bnb_available():
         print("Test problematic imports (bnb)")
         test_problematic_imports()
+    if NUM_PROCESSES > 1:
+        PartialState().destroy_process_group()


 if __name__ == "__main__":
@@ -173,6 +173,7 @@ def main():
     test_op_checker(state)
     state.print("testing sending tensors across devices")
     test_copy_tensor_to_devices(state)
+    state.destroy_process_group()


 if __name__ == "__main__":
@@ -822,6 +822,8 @@ def main():
     print("\n**Test reinstantiated state**")
     test_reinstantiated_state()

+    state.destroy_process_group()
+

 if __name__ == "__main__":
     main()
@@ -20,7 +20,7 @@ from torch.optim import AdamW
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader

-from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
+from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin
 from accelerate.state import GradientState
 from accelerate.test_utils import RegressionDataset, RegressionModel
 from accelerate.utils import DistributedType, set_seed
@@ -249,9 +249,9 @@ def test_gradient_accumulation_with_opt_and_scheduler(
     split_batches=False, dispatch_batches=False, sync_each_batch=False
 ):
     gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
+    dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
     accelerator = Accelerator(
-        split_batches=split_batches,
-        dispatch_batches=dispatch_batches,
+        dataloader_config=dataloader_config,
         gradient_accumulation_plugin=gradient_accumulation_plugin,
     )
     # Test that context manager behaves properly
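The test now builds the Accelerator from a DataLoaderConfiguration instead of passing split_batches/dispatch_batches directly to the constructor. The stand-alone equivalent (a sketch with illustrative values):

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Dataloader behaviour is grouped into a single config object.
dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=False)
accelerator = Accelerator(dataloader_config=dataloader_config)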
@@ -392,6 +392,7 @@ def main():
                     f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                 )
                 test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
+    state.destroy_process_group()


 def _mp_fn(index):
@@ -14,6 +14,9 @@

 import unittest

+import numpy as np
+from packaging import version
+
 from accelerate import debug_launcher
 from accelerate.test_utils import (
     DEFAULT_LAUNCH_COMMAND,
@@ -29,6 +32,7 @@ from accelerate.utils import patch_environment


 @require_huggingface_suite
+@unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "Test requires numpy version < 2.0")
 class MetricTester(unittest.TestCase):
     def setUp(self):
         self.test_file_path = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_metrics.py")
@@ -27,6 +27,7 @@ from unittest import mock

+import numpy as np
 import torch
 from packaging import version

 # We use TF to parse the logs
 from accelerate import Accelerator
@@ -68,6 +69,7 @@ logger = logging.getLogger(__name__)

 @require_tensorboard
 class TensorBoardTrackingTest(unittest.TestCase):
+    @unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "TB doesn't support numpy 2.0")
     def test_init_trackers(self):
         project_name = "test_project_with_config"
         with tempfile.TemporaryDirectory() as dirpath: