Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-20 09:34:28 +08:00)

Compare commits: fork-teste ... debug-test (12 commits)
| SHA1 |
|---|
| 7f48cb52e7 |
| e33aba7371 |
| 068d586938 |
| 76043b402f |
| 3126992054 |
| 656e15e4f8 |
| 1b21f9a630 |
| f592aad8df |
| b69239577f |
| f973f0d5f9 |
| 56580b40c5 |
| 30ac26cf33 |
.github/PULL_REQUEST_TEMPLATE.md (vendored) - 8 lines changed
@@ -37,11 +37,11 @@ members/contributors who may be interested in your PR.

If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.

- Big modeling: @SunMarc
- Fully-Sharded Data Parallism: @muellerzr
- DeepSpeed: @muellerzr
- Fully-Sharded Data Parallism: @pacman100
- DeepSpeed: @pacman100
- Command Line Interface: @muellerzr
- Documentation: @muellerzr
- Core parts of the library: @muellerzr @BenjaminBossan @SunMarc
- Maintained examples: @muellerzr or @SunMarc
- Core parts of the library: @muellerzr @BenjaminBossan
- Maintained examples: @muellerzr or @pacman100

-->
.github/workflows/nightly.yml (vendored) - 324 lines changed
@@ -44,186 +44,186 @@ jobs:
          source activate accelerate
          make test

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples
      # - name: Run examples on GPUs
        # working-directory: accelerate
        # if: always()
        # run: |
          # source activate accelerate
          # pip uninstall comet_ml -y
          # make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
      # - name: Generate Report
        # working-directory: accelerate
        # if: always()
        # run: |
          # pip install slack_sdk tabulate
          # python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_deepspeed_tests_single_gpu:
    runs-on: [self-hosted, single-gpu, nvidia-gpu, t4, ci]
    env:
      CUDA_VISIBLE_DEVICES: "0"
      TEST_TYPE: "single_gpu_deepspeed"
    container:
      image: huggingface/accelerate:gpu-deepspeed-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone & pip install
        run: |
          source activate accelerate
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e . --no-deps
          pip install pytest-reportlog tabulate
  # run_deepspeed_tests_single_gpu:
    # runs-on: [self-hosted, single-gpu, nvidia-gpu, t4, ci]
    # env:
      # CUDA_VISIBLE_DEVICES: "0"
      # TEST_TYPE: "single_gpu_deepspeed"
    # container:
      # image: huggingface/accelerate:gpu-deepspeed-nightly
      # options: --gpus all --shm-size "16gb"
    # defaults:
      # run:
        # shell: bash
    # steps:
      # - name: Update clone & pip install
        # run: |
          # source activate accelerate
          # git clone https://github.com/huggingface/accelerate;
          # cd accelerate;
          # git checkout ${{ github.sha }};
          # pip install -e . --no-deps
          # pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze
      # - name: Show installed libraries
        # run: |
          # source activate accelerate;
          # pip freeze

      - name: Run test on GPUs
        working-directory: accelerate
        run: |
          source activate accelerate
          make test_deepspeed
      # - name: Run test on GPUs
        # working-directory: accelerate
        # run: |
          # source activate accelerate
          # make test_deepspeed

      - name: Run Integration tests on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          make test_integrations
      # - name: Run Integration tests on GPUs
        # working-directory: accelerate
        # if: always()
        # run: |
          # source activate accelerate
          # make test_integrations

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples
      # - name: Run examples on GPUs
        # working-directory: accelerate
        # if: always()
        # run: |
          # source activate accelerate
          # pip uninstall comet_ml -y
          # make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
      # - name: Generate Report
        # working-directory: accelerate
        # if: always()
        # run: |
          # pip install slack_sdk tabulate
          # python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_core_tests_multi_gpu:
    runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
    env:
      CUDA_VISIBLE_DEVICES: "0,1"
      TEST_TYPE: "multi_gpu"
    container:
      image: huggingface/accelerate:gpu-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone
        run: |
          source activate accelerate
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e . --no-deps
          pip install pytest-reportlog tabulate
  # run_core_tests_multi_gpu:
    # runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
    # env:
      # CUDA_VISIBLE_DEVICES: "0,1"
      # TEST_TYPE: "multi_gpu"
    # container:
      # image: huggingface/accelerate:gpu-nightly
      # options: --gpus all --shm-size "16gb"
    # defaults:
      # run:
        # shell: bash
    # steps:
      # - name: Update clone
        # run: |
          # source activate accelerate
          # git clone https://github.com/huggingface/accelerate;
          # cd accelerate;
          # git checkout ${{ github.sha }};
          # pip install -e . --no-deps
          # pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze
      # - name: Show installed libraries
        # run: |
          # source activate accelerate;
          # pip freeze

      - name: Run core and big modeling tests on GPUs
        working-directory: accelerate
        run: |
          source activate accelerate
          make test_core
          make test_big_modeling
          make test_cli
      # - name: Run core and big modeling tests on GPUs
        # working-directory: accelerate
        # run: |
          # source activate accelerate
          # make test_core
          # make test_big_modeling
          # make test_cli

      - name: Run Integration tests on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          make test_integrations
      # - name: Run Integration tests on GPUs
        # working-directory: accelerate
        # if: always()
        # run: |
          # source activate accelerate
          # make test_integrations

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples
      # - name: Run examples on GPUs
        # working-directory: accelerate
        # if: always()
        # run: |
          # source activate accelerate
          # pip uninstall comet_ml -y
          # make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
      # - name: Generate Report
        # working-directory: accelerate
        # if: always()
        # run: |
          # pip install slack_sdk tabulate
          # python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_deepspeed_tests_multi_gpu:
    runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
    env:
      CUDA_VISIBLE_DEVICES: "0,1"
      TEST_TYPE: "multi_gpu_deepspeed"
    container:
      image: huggingface/accelerate:gpu-deepspeed-nightly
      options: --gpus all --shm-size "16gb"
    defaults:
      run:
        shell: bash
    steps:
      - name: Update clone
        run: |
          source activate accelerate
          git clone https://github.com/huggingface/accelerate;
          cd accelerate;
          git checkout ${{ github.sha }};
          pip install -e . --no-deps
          pip install pytest-reportlog tabulate
  # run_deepspeed_tests_multi_gpu:
    # runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
    # env:
      # CUDA_VISIBLE_DEVICES: "0,1"
      # TEST_TYPE: "multi_gpu_deepspeed"
    # container:
      # image: huggingface/accelerate:gpu-deepspeed-nightly
      # options: --gpus all --shm-size "16gb"
    # defaults:
      # run:
        # shell: bash
    # steps:
      # - name: Update clone
        # run: |
          # source activate accelerate
          # git clone https://github.com/huggingface/accelerate;
          # cd accelerate;
          # git checkout ${{ github.sha }};
          # pip install -e . --no-deps
          # pip install pytest-reportlog tabulate

      - name: Show installed libraries
        run: |
          source activate accelerate;
          pip freeze
      # - name: Show installed libraries
        # run: |
          # source activate accelerate;
          # pip freeze

      - name: Run DeepSpeed tests
        working-directory: accelerate
        run: |
          source activate accelerate
          make test_deepspeed
      # - name: Run DeepSpeed tests
        # working-directory: accelerate
        # run: |
          # source activate accelerate
          # make test_deepspeed

      - name: Run Integration tests on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          make test_integrations
      # - name: Run Integration tests on GPUs
        # working-directory: accelerate
        # if: always()
        # run: |
          # source activate accelerate
          # make test_integrations

      - name: Run examples on GPUs
        working-directory: accelerate
        if: always()
        run: |
          source activate accelerate
          pip uninstall comet_ml -y
          make test_examples
      # - name: Run examples on GPUs
        # working-directory: accelerate
        # if: always()
        # run: |
          # source activate accelerate
          # pip uninstall comet_ml -y
          # make test_examples

      - name: Generate Report
        working-directory: accelerate
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
      # - name: Generate Report
        # working-directory: accelerate
        # if: always()
        # run: |
          # pip install slack_sdk tabulate
          # python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run-integration-tests:
    if: always()
    uses: ./.github/workflows/self_hosted_integration_tests.yml
  # run-integration-tests:
    # if: always()
    # uses: ./.github/workflows/self_hosted_integration_tests.yml
Makefile - 4 lines changed
@@ -42,11 +42,7 @@ test_fsdp:
# Since the new version of pytest will *change* how things are collected, we need `deepspeed` to
# run after test_core and test_cli
test:
	$(MAKE) test_core
	$(MAKE) test_cli
	$(MAKE) test_big_modeling
	$(MAKE) test_deepspeed
	$(MAKE) test_fsdp

test_examples:
	python -m pytest -s -v ./tests/test_examples.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_examples.log",)
@@ -430,17 +430,6 @@ args = (model, "fp16", 42, 64)
notebook_launcher(training_loop, args, num_processes=8)
```

To launch the training process with elasticity, enabling fault tolerance, you can use the `elastic_launch` feature provided by PyTorch. This requires setting additional parameters such as `rdzv_backend` and `max_restarts`. Here is an example of how to use `notebook_launcher` with elastic capabilities:

```python
notebook_launcher(
    training_loop,
    args,
    num_processes=2,
    max_restarts=3
)
```

As it's running it will print the progress as well as state how many devices you ran on. This tutorial was ran with two GPUs:

```python out
@@ -15,6 +15,4 @@ rendered properly in your Markdown viewer.

# Utilities for Fully Sharded Data Parallelism

[[autodoc]] utils.merge_fsdp_weights

[[autodoc]] utils.FullyShardedDataParallelPlugin
@@ -161,22 +161,6 @@ When using transformers `save_pretrained`, pass `state_dict=accelerator.get_stat

You can then pass `state` into the `save_pretrained` method. There are several modes for `StateDictType` and `FullStateDictConfig` that you can use to control the behavior of `state_dict`. For more information, see the [PyTorch documentation](https://pytorch.org/docs/stable/fsdp.html).

If you choose to use `StateDictType.SHARDED_STATE_DICT`, the weights of the model during `Accelerator.save_state` will be split into `n` files for each sub-split on the model. To merge them back into
a single dictionary to load back into the model later after training you can use the `merge_weights` utility:

```py
from accelerate.utils import merge_fsdp_weights

# Our weights are saved usually in a `pytorch_model_fsdp_{model_number}` folder
merge_fsdp_weights("pytorch_model_fsdp_0", "output_path", safe_serialization=True)
```
The final output will then either be saved to `model.safetensors` or `pytorch_model.bin` (if `safe_serialization=False` is passed).

This can also be called using the CLI:
```bash
accelerate merge-weights pytorch_model_fsdp_0/ output_path
```


## Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages
* `FULL_SHARD` maps to the DeepSpeed `ZeRO Stage-3`. Shards optimizer states, gradients and parameters.
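For readers following the mapping above, here is a minimal sketch of requesting the `FULL_SHARD` strategy (the one that corresponds to ZeRO Stage-3). It assumes only the `FullyShardedDataParallelPlugin` usage that appears in the test file further down this diff:

```py
# Minimal sketch: ask Accelerate for FULL_SHARD, the FSDP analogue of ZeRO Stage-3.
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

from accelerate import Accelerator, FullyShardedDataParallelPlugin

plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD)
accelerator = Accelerator(fsdp_plugin=plugin)
```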
setup.py - 3 lines changed
@@ -22,7 +22,7 @@ extras["quality"] = [
    "ruff ~= 0.2.1",
]
extras["docs"] = []
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized", "expecttest"]
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
extras["test_dev"] = [
    "datasets",
    "diffusers",

@@ -65,7 +65,6 @@ setup(
            "accelerate-config=accelerate.commands.config:main",
            "accelerate-estimate-memory=accelerate.commands.estimate:main",
            "accelerate-launch=accelerate.commands.launch:main",
            "accelerate-merge-weights=accelerate.commands.merge:main",
        ]
    },
    python_requires=">=3.8.0",
@@ -128,7 +128,9 @@ if is_megatron_lm_available():
        MegatronLMSchedulerWrapper,
        megatron_lm_initialize,
        megatron_lm_prepare_data_loader,
        megatron_lm_prepare_model_optimizer_scheduler,
        megatron_lm_prepare_model,
        megatron_lm_prepare_optimizer,
        megatron_lm_prepare_scheduler,
    )

from torch.distributed.algorithms.join import Join
@@ -994,14 +996,14 @@ class Accelerator:
                model.require_backward_grad_sync = old_require_backward_grad_sync
                model.require_forward_param_sync = old_require_forward_param_sync

    def _do_sync(self):
    def _do_sync(self, force: bool = False):
        "Sets the right `sync_gradients` context and either resets or increases `self.step`"
        if self.gradient_state.sync_with_dataloader and self.gradient_state.end_of_dataloader:
            self.step = 0
            self.gradient_state._set_sync_gradients(True)
        else:
            self.step += 1
            self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)
            self.gradient_state._set_sync_gradients(force or ((self.step % self.gradient_state.num_steps) == 0))

    @property
    def sync_gradients(self):
@@ -1047,21 +1049,12 @@ class Accelerator:
        ...         optimizer.zero_grad()
        ```
        """
        self._do_sync()

        allow_gradient_sync = (
            self.sync_gradients  # must sync if sync gradients need to complete an optimizer step
            or (
                # the no_sync context stops the gradients from reducing during distributed training
                # bringing speedup (potentially at some costs). Here, no_sync can be prevented
                # by setting sync_each_batch = True.
                self.use_distributed  # only relevant in distributed settings
                and self.gradient_state.plugin_kwargs.get("sync_each_batch", False)
            )
        )
        # sync_each_batch=True will guarantee below that self.sync_gradients=True, therefore
        # resulting in the nullcontext always being selected.
        self._do_sync(force=self.gradient_state.plugin_kwargs.get("sync_each_batch", False))
        with contextlib.ExitStack() as cm_stack:
            for m in models:
                cm_stack.enter_context(contextlib.nullcontext() if allow_gradient_sync else self.no_sync(m))
                cm_stack.enter_context(contextlib.nullcontext() if self.sync_gradients else self.no_sync(m))
            yield

    @contextmanager
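A minimal usage sketch of the behaviour changed in this hunk, assuming the `GradientAccumulationPlugin` in `accelerate.utils` exposes `num_steps` and `sync_each_batch` (the `plugin_kwargs` lookup above suggests it does):

```py
# Sketch only: force a gradient sync on every batch while still accumulating
# over 4 steps, so accumulate() never enters the no_sync context.
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

plugin = GradientAccumulationPlugin(num_steps=4, sync_each_batch=True)
accelerator = Accelerator(gradient_accumulation_plugin=plugin)
```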
@@ -1408,7 +1401,7 @@ class Accelerator:
            if (self.device.index is not None) or (current_device_index != 0):
                raise ValueError(
                    "You can't train a model that has been loaded in 8-bit precision on a different device than the one "
                    "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`"
                    "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}"
                )

            if "cpu" in model_devices or "disk" in model_devices:
@@ -1818,6 +1811,7 @@ class Accelerator:
        model = None
        optimizer = None
        scheduler = None
        is_dummy_scheduler = False
        batch_data = None
        for obj in args:
            if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None:
@@ -1843,9 +1837,6 @@ class Accelerator:

        # initialize megatron-lm
        megatron_lm_initialize(self, args_defaults=megatron_lm_plugin.megatron_lm_default_args)

        (model, optimizer, scheduler) = megatron_lm_prepare_model_optimizer_scheduler(self)

        counter = 0
        result = []
        for obj in args:
@@ -1861,6 +1852,13 @@ class Accelerator:
            else:
                result.append(obj)

        if model is not None:
            model = megatron_lm_prepare_model(self)
        if optimizer is not None:
            optimizer = megatron_lm_prepare_optimizer(self, model)
        if scheduler is not None:
            scheduler = megatron_lm_prepare_scheduler(self, optimizer, scheduler)

        if model is not None:
            model = MegatronEngine(self, model, optimizer, scheduler)
        if optimizer is not None:
@@ -397,6 +397,7 @@ def dispatch_model(
        weights_map = OffloadedWeightsLoader(
            state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
        )
        print(weights_map)
    else:
        weights_map = None

@@ -415,7 +416,6 @@ def dispatch_model(

    # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
    # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.

    attach_align_device_hook_on_blocks(
        model,
        execution_device=execution_device,
@@ -18,7 +18,6 @@ from accelerate.commands.config import get_config_parser
from accelerate.commands.env import env_command_parser
from accelerate.commands.estimate import estimate_command_parser
from accelerate.commands.launch import launch_command_parser
from accelerate.commands.merge import merge_command_parser
from accelerate.commands.test import test_command_parser
from accelerate.commands.tpu import tpu_command_parser
from accelerate.commands.utils import CustomArgumentParser

@@ -33,7 +32,6 @@ def main():
    estimate_command_parser(subparsers=subparsers)
    env_command_parser(subparsers=subparsers)
    launch_command_parser(subparsers=subparsers)
    merge_command_parser(subparsers=subparsers)
    tpu_command_parser(subparsers=subparsers)
    test_command_parser(subparsers=subparsers)
@@ -584,12 +584,6 @@ def launch_command_parser(subparsers=None):
        help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
        " (useful only when `use_fsdp` flag is passed).",
    )
    fsdp_args.add_argument(
        "--fsdp_activation_checkpointing",
        default="false",
        type=str,
        help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
    )

    # megatron_lm args
    megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
@@ -1,69 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import merge_fsdp_weights


description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.

This is a CPU-bound process and requires enough RAM to load the entire model state dict."""


def merge_command(args):
    merge_fsdp_weights(
        args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir
    )


def merge_command_parser(subparsers=None):
    if subparsers is not None:
        parser = subparsers.add_parser("merge-weights", description=description)
    else:
        parser = CustomArgumentParser(description=description)

    parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
    parser.add_argument(
        "output_path",
        type=str,
        help="The path to save the merged weights. Defaults to the current directory. ",
    )
    parser.add_argument(
        "--unsafe_serialization",
        action="store_false",
        default=True,
        help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
    )
    parser.add_argument(
        "--remove_checkpoint_dir",
        action="store_true",
        help="Whether to remove the checkpoint directory after merging.",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=merge_command)
    return parser


def main():
    parser = merge_command_parser()
    args = parser.parse_args()
    merge_command(args)


if __name__ == "__main__":
    main()
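The same command can also be driven from Python by reusing the parser defined above; a short sketch with illustrative paths (this mirrors how the test file later in this diff invokes it):

```py
# Illustrative paths only; parse arguments and run the merge directly.
from accelerate.commands.merge import merge_command, merge_command_parser

parser = merge_command_parser()
args = parser.parse_args(["pytorch_model_fsdp_0", "merged_output"])
merge_command(args)  # merges the sharded FSDP checkpoint into a single file
```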
@@ -44,12 +44,6 @@ def notebook_launcher(
    master_addr="127.0.0.1",
    node_rank=0,
    num_nodes=1,
    rdzv_backend="static",
    rdzv_endpoint="",
    rdzv_conf=None,
    rdzv_id="none",
    max_restarts=0,
    monitor_interval=0.1,
):
    """
    Launches a training function, using several processes or multiple nodes if it's possible in the current environment

@@ -84,18 +78,6 @@ def notebook_launcher(
            The rank of the current node.
        num_nodes (`int`, *optional*, defaults to 1):
            The number of nodes to use for training.
        rdzv_backend (`str`, *optional*, defaults to `"static"`):
            The rendezvous method to use, such as 'static' (the default) or 'c10d'
        rdzv_endpoint (`str`, *optional*, defaults to `""`):
            The endpoint of the rdzv sync. storage.
        rdzv_conf (`Dict`, *optional*, defaults to `None`):
            Additional rendezvous configuration.
        rdzv_id (`str`, *optional*, defaults to `"none"`):
            The unique run id of the job.
        max_restarts (`int`, *optional*, defaults to 0):
            The maximum amount of restarts that elastic agent will conduct on workers before failure.
        monitor_interval (`float`, *optional*, defaults to 0.1):
            The interval in seconds that is used by the elastic_agent as a period of monitoring workers.

    Example:

@@ -159,7 +141,6 @@ def notebook_launcher(
        raise ValueError("The node_rank must be less than the number of nodes.")
    if num_processes > 1:
        # Multi-GPU launch
        from torch.distributed.launcher.api import LaunchConfig, elastic_launch
        from torch.multiprocessing import start_processes
        from torch.multiprocessing.spawn import ProcessRaisedException

@@ -217,26 +198,7 @@ def notebook_launcher(
        launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
        print(f"Launching training on {num_processes} GPUs.")
        try:
            if rdzv_conf is None:
                rdzv_conf = {}
            if rdzv_backend == "static":
                rdzv_conf["rank"] = node_rank
            if not rdzv_endpoint:
                rdzv_endpoint = f"{master_addr}:{use_port}"
            launch_config = LaunchConfig(
                min_nodes=num_nodes,
                max_nodes=num_nodes,
                nproc_per_node=num_processes,
                run_id=rdzv_id,
                rdzv_endpoint=rdzv_endpoint,
                rdzv_backend=rdzv_backend,
                rdzv_configs=rdzv_conf,
                max_restarts=max_restarts,
                monitor_interval=monitor_interval,
                start_method="fork",
                log_line_prefix_template=os.environ.get("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE"),
            )
            elastic_launch(config=launch_config, entrypoint=function)(*args)
            start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
        except ProcessRaisedException as e:
            if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
                raise RuntimeError(
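A short sketch of the elastic options in the signature shown above; `training_loop` and `args` stand in for the user's own function and its arguments:

```py
# Sketch based on the signature above: use the c10d rendezvous backend and
# allow up to 3 elastic restarts of failed workers.
notebook_launcher(
    training_loop,
    args,
    num_processes=2,
    rdzv_backend="c10d",
    max_restarts=3,
)
```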
@@ -15,6 +15,7 @@
from __future__ import annotations

import logging
import math
import os
import threading
import warnings

@@ -101,11 +102,6 @@ class ThreadLocalSharedDict(threading.local):
    def __set__(self, obj, value):
        self._storage = value


def _get_shared_dict_type():
    # Prefer global shared dictionary, except when using TPU or `backend == threaded`
    if is_torch_xla_available() or (torch.distributed.is_initialized() and torch.distributed.get_backend() == "threaded"):
        return ThreadLocalSharedDict
    return dict

# Prefer global shared dictionary, except when using TPU.
SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict

@@ -167,9 +163,6 @@ class PartialState:
    ]

    def __init__(self, cpu: bool = False, **kwargs):
        # This is needed when we are launching tests and have the `threaded` backend
        if _get_shared_dict_type() != self._shared_state.__class__:
            PartialState._shared_state = _get_shared_dict_type()()
        self.__dict__ = self._shared_state
        if not self.initialized:
            self._cpu = cpu

@@ -193,7 +186,7 @@ class PartialState:
        self.backend = backend
        self.distributed_type = distributed_type
        use_deepspeed = False
        if not cpu and self.backend != "xla" and not torch.distributed.is_initialized():
        if not cpu and self.backend != "xla":
            if int(os.environ.get("LOCAL_RANK", -1)) != -1:
                # Deal with spawning deepspeed
                if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":

@@ -282,13 +275,9 @@ class PartialState:
        else:
            self.num_processes = torch.distributed.get_world_size()
            self.process_index = torch.distributed.get_rank()
            # Setting `local_process_index` requires some care
            if dist_information is not None:
                self.local_process_index = dist_information.local_rank
            elif backend == "threaded":
                self.local_process_index = self.process_index
            else:
                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
            self.local_process_index = (
                int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
            )
        self.set_device()
        # Now we can change to deepseed
        if use_deepspeed:

@@ -448,9 +437,11 @@ class PartialState:
            length = len(inputs[list(inputs.keys())[0]])
            if not all(len(v) == length for v in inputs.values()):
                raise ValueError("All values in the dictionary must have the same length")
        num_samples_per_process, num_extras = divmod(length, self.num_processes)
        start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
        end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
        num_samples_per_process = math.ceil(length / self.num_processes)
        start_index = self.process_index * num_samples_per_process
        end_index = start_index + num_samples_per_process
        if (len(inputs) % self.num_processes != 0) and (self.process_index == self.num_processes - 1):
            end_index = length

        def _split_values(inputs, start_index, end_index):
            if isinstance(inputs, (list, tuple, torch.Tensor)):

@@ -466,7 +457,7 @@ class PartialState:
                    tensorized_result = send_to_device(result, self.device)
                    result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
                else:
                    result += [result[-1]] * (num_samples_per_process + 1 - len(result))
                    result += [result[-1]] * (num_samples_per_process - len(result))
                return result
            elif isinstance(inputs, dict):
                for key in inputs.keys():

@@ -483,7 +474,7 @@ class PartialState:
                    end_index = len(inputs)
                result_idcs = list(range(start_index, end_index))
                if apply_padding:
                    result_idcs += [end_index - 1] * (num_samples_per_process + 1 - len(result_idcs))
                    result_idcs += [end_index - 1] * (num_samples_per_process - len(result_idcs))
                return inputs.select(result_idcs)
            return inputs
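The even-split arithmetic introduced in the `@@ -448` hunk above can be checked in isolation; a small standalone sketch:

```py
# Standalone check of the divmod-based split used above: with 17 items and
# 4 processes, ranks 0..3 receive 5, 4, 4 and 4 items respectively.
def split_indices(length: int, num_processes: int, process_index: int):
    num_samples_per_process, num_extras = divmod(length, num_processes)
    start_index = process_index * num_samples_per_process + min(process_index, num_extras)
    end_index = start_index + num_samples_per_process + (1 if process_index < num_extras else 0)
    return start_index, end_index

assert [split_indices(17, 4, rank) for rank in range(4)] == [(0, 5), (5, 9), (9, 13), (13, 17)]
```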
@@ -722,10 +713,6 @@ class PartialState:
    ) -> tuple[str, DistributedType]:
        "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
        distributed_type = None
        if torch.distributed.is_initialized():
            backend = torch.distributed.get_backend()
            if backend == "threaded":
                distributed_type = DistributedType.MULTI_GPU
        if sagemaker_dp:
            import smdistributed.dataparallel.torch.torch_smddp  # noqa

@@ -734,7 +721,7 @@ class PartialState:
        elif is_torch_xla_available():
            backend = "xla"
            distributed_type = DistributedType.XLA
        elif not cpu and int(os.environ.get("LOCAL_RANK", -1)) != -1:
        elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
            if is_mlu_available():
                backend = "cncl"
                distributed_type = DistributedType.MULTI_MLU
@@ -1,160 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy, StateDictType
from torch.utils.data import DataLoader

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.commands.merge import merge_command, merge_command_parser
from accelerate.state import AcceleratorState
from accelerate.test_utils.training import RegressionDataset
from accelerate.utils import merge_fsdp_weights, patch_environment, save_fsdp_model


logging.basicConfig(level=logging.INFO)

parser = merge_command_parser()


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 16)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(16, 16)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        return self.linear2(self.activation(self.linear1(x)))


def setup():
    if AcceleratorState._shared_state != {}:
        AcceleratorState()._reset_state()
    plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=ShardingStrategy.FULL_SHARD, state_dict_type=StateDictType.SHARDED_STATE_DICT
    )
    model = TinyModel()
    with patch_environment(fsdp_auto_wrap_policy="SIZE_BASED_WRAP"):
        plugin.set_auto_wrap_policy(model)
    accelerator = Accelerator(fsdp_plugin=plugin)
    model = accelerator.prepare(model)
    return model, plugin, accelerator


def mock_training(accelerator, model):
    train_set = RegressionDataset(length=128, seed=42)
    train_dl = DataLoader(train_set, batch_size=16, shuffle=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
            loss = torch.nn.functional.mse_loss(output, batch["y"])
            accelerator.backward(loss)
            optimizer.step()
    return model


def check_weights(operation, state_1, state_2):
    for weight_1, weight_2 in zip(state_1.values(), state_2.values()):
        if str(weight_1.device) != "cuda":
            weight_1 = weight_1.to("cuda")
        if str(weight_2.device) != "cuda":
            weight_2 = weight_2.to("cuda")
        if operation == "same":
            assert torch.allclose(weight_1, weight_2)
        else:
            assert not torch.allclose(weight_1, weight_2)


def check_safetensors_weights(path, model):
    safe_state_dict = load_file(path / "model.safetensors")
    safe_loaded_model = TinyModel()
    check_weights("diff", model.state_dict(), safe_loaded_model.state_dict())
    safe_loaded_model.load_state_dict(safe_state_dict)
    check_weights("same", model.state_dict(), safe_loaded_model.state_dict())


def check_pytorch_weights(path, model):
    nonsafe_state_dict = torch.load(path / "pytorch_model.bin")
    nonsafe_loaded_model = TinyModel()
    check_weights("diff", model.state_dict(), nonsafe_loaded_model.state_dict())
    nonsafe_loaded_model.load_state_dict(nonsafe_state_dict)
    check_weights("same", model.state_dict(), nonsafe_loaded_model.state_dict())


def test_merge_weights_safetensors(model, path):
    # Should now be saved at `path/merged.safetensors`
    merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=True)
    check_safetensors_weights(path, model)


def test_merge_weights_command_safetensors(model, path):
    args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path)])
    merge_command(args)
    check_safetensors_weights(path, model)


def test_merge_weights_pytorch(model, path):
    # Should now be saved at `path/merged.bin`
    merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, safe_serialization=False)
    check_pytorch_weights(path, model)


def test_merge_weights_command_pytorch(model, path):
    args = parser.parse_args([str(path / "pytorch_model_fsdp_0"), str(path), "--unsafe_serialization"])
    merge_command(args)
    check_pytorch_weights(path, model)


if __name__ == "__main__":
    # Note this test requires at least two accelerators!
    model, plugin, accelerator = setup()
    if accelerator.num_processes > 1:
        try:
            # Initial setup for things
            out_path = Path("test_merge_weights_fsdp_weights")
            if not out_path.exists():
                out_path.mkdir(parents=True, exist_ok=True)

            # Train briefly once weights aren't the baseline
            model = mock_training(accelerator, model)
            accelerator.wait_for_everyone()

            gc.collect()  # Needed for some lingering refs after training
            save_fsdp_model(plugin, accelerator, model, out_path)
            accelerator.wait_for_everyone()

            # Finally we can test
            test_merge_weights_safetensors(model, out_path)
            test_merge_weights_command_safetensors(model, out_path)
            test_merge_weights_pytorch(model, out_path)
            test_merge_weights_command_pytorch(model, out_path)
        except Exception:
            raise
        finally:
            # Cleanup in case of any failures
            if accelerator.is_main_process:
                shutil.rmtree(out_path)
            accelerator.wait_for_everyone()
@@ -16,11 +16,8 @@ Test file to ensure that in general certain situational setups for notebooks wor
"""

import os
import time
from multiprocessing import Queue

from pytest import mark, raises
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from pytest import raises

from accelerate import PartialState, notebook_launcher
from accelerate.test_utils import require_bnb

@@ -32,25 +29,6 @@ def basic_function():
    print(f"PartialState:\n{PartialState()}")


def tough_nut_function(queue: Queue):
    if queue.empty():
        return
    trial = queue.get()
    if trial > 0:
        queue.put(trial - 1)
        raise RuntimeError("The nut hasn't cracked yet! Try again.")

    print(f"PartialState:\n{PartialState()}")


def bipolar_sleep_function(sleep_sec: int):
    state = PartialState()
    if state.process_index % 2 == 0:
        raise RuntimeError("I'm an even process. I don't like to sleep.")
    else:
        time.sleep(sleep_sec)


NUM_PROCESSES = int(os.environ.get("ACCELERATE_NUM_PROCESSES", 1))


@@ -58,36 +36,6 @@ def test_can_initialize():
    notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES)


@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends")
def test_static_rdzv_backend():
    notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES, rdzv_backend="static")


@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends")
def test_c10d_rdzv_backend():
    notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES, rdzv_backend="c10d")


@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
def test_fault_tolerant(max_restarts: int = 3):
    queue = Queue()
    queue.put(max_restarts)
    notebook_launcher(tough_nut_function, (queue,), num_processes=NUM_PROCESSES, max_restarts=max_restarts)


@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring")
def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
    start_time = time.time()
    with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."):
        notebook_launcher(
            bipolar_sleep_function,
            (sleep_sec,),
            num_processes=NUM_PROCESSES,
            monitor_interval=monitor_interval,
        )
    assert time.time() - start_time < sleep_sec, "Monitoring did not stop the process in time."


@require_bnb
def test_problematic_imports():
    with raises(RuntimeError, match="Please keep these imports"):

@@ -99,14 +47,6 @@ def test_problematic_imports():
def main():
    print("Test basic notebook can be ran")
    test_can_initialize()
    print("Test static rendezvous backend")
    test_static_rdzv_backend()
    print("Test c10d rendezvous backend")
    test_c10d_rdzv_backend()
    print("Test fault tolerant")
    test_fault_tolerant()
    print("Test monitoring")
    test_monitoring()
    if is_bnb_available():
        print("Test problematic imports (bnb)")
        test_problematic_imports()
@@ -693,24 +693,6 @@ def test_split_between_processes_tensor():
    state.wait_for_everyone()


def test_split_between_processes_evenly():
    state = AcceleratorState()
    if state.num_processes in (1, 2, 4, 8):
        data = list(range(17))
        num_samples_per_process = len(data) // state.num_processes
        num_extras = len(data) % state.num_processes
        with state.split_between_processes(data) as results:
            if state.process_index < num_extras:
                assert (
                    len(results) == num_samples_per_process + 1
                ), f"Each Process should have even elements. Expected: {num_samples_per_process + 1}, Actual: {len(results)}"
            else:
                assert (
                    len(results) == num_samples_per_process
                ), f"Each Process should have even elements. Expected: {num_samples_per_process}, Actual: {len(results)}"
        state.wait_for_everyone()


def test_trigger():
    accelerator = Accelerator()
    # should start with being false

@@ -775,10 +757,6 @@ def main():
    print("\n**Test split between processes as a tensor**")
    test_split_between_processes_tensor()

    if state.process_index == 0:
        print("\n**Test split between processes evenly**")
    test_split_between_processes_evenly()

    if state.process_index == 0:
        print("\n**Test split between processes as a datasets.Dataset**")
    if is_datasets_available():

@@ -807,10 +785,10 @@ def main():
    if state.distributed_type == DistributedType.DEEPSPEED:
        return

    # if state.local_process_index == 0:
    #     print("\n**Training integration test**")
    # training_check(use_seedable_sampler=False)
    # training_check(use_seedable_sampler=True)
    if state.local_process_index == 0:
        print("\n**Training integration test**")
    training_check(use_seedable_sampler=False)
    training_check(use_seedable_sampler=True)

    if state.local_process_index == 0:
        print("\n**Breakpoint trigger test**")
@@ -267,7 +267,7 @@ def test_gradient_accumulation_with_opt_and_scheduler(
        step_model(model, input, target, accelerator, False)
        opt.step()

        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)):
        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch:
            if split_batches:
                sched.step()
        else:

@@ -284,18 +284,18 @@ def test_gradient_accumulation_with_opt_and_scheduler(
        assert (
            opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"]
        ), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n'
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch
        if accelerator.num_processes > 1:
            check_model_parameters(
                model,
                ddp_model,
                did_step or sync_each_batch,  # syncs at each grad_accum interval of if sync_each_batch==True
                did_step,
                iteration,
                rtol=1e-3,  # needs a relative tolerance due to roundoff errors
                rtol=1e-3,  # somehow needs a relative tolerance
            )

        if did_step:
            opt.zero_grad()  # flush gradients every accum step
        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch:
            opt.zero_grad()  # needs to be guarded by logic as to when we should zero grads
            ddp_opt.zero_grad()

        # Shuffle ddp_input on each iteration
@@ -30,7 +30,7 @@ import torch

import accelerate

from ..state import PartialState, PartialState
from ..state import AcceleratorState, PartialState
from ..utils import (
    gather,
    is_bnb_available,

@@ -427,14 +427,14 @@ class TempDirTestCase(unittest.TestCase):
class AccelerateTestCase(unittest.TestCase):
    """
    A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes
    the `PartialState` class should inherit from this to avoid silent failures due to state being shared between
    the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between
    tests.
    """

    def tearDown(self):
        super().tearDown()
        # Reset the state of the PartialState singleton.
        PartialState._reset_state()
        # Reset the state of the AcceleratorState singleton.
        AcceleratorState._reset_state()
        PartialState._reset_state()


@@ -472,7 +472,7 @@ class MockingTestCase(unittest.TestCase):


def are_the_same_tensors(tensor):
    state = PartialState()
    state = AcceleratorState()
    tensor = tensor[None].clone().to(state.device)
    tensors = gather(tensor).cpu()
    tensor = tensor[0].cpu()
@@ -50,7 +50,6 @@ from .dataclasses import (
    SageMakerDistributedType,
    TensorInformation,
    TorchDynamoPlugin,
    add_model_config_to_megatron_parser,
)
from .environment import (
    are_libraries_initialized,

@@ -180,7 +179,7 @@ if is_deepspeed_available():
    )

from .bnb import has_4bit_bnb_layers, load_and_quantize_model
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, merge_fsdp_weights, save_fsdp_model, save_fsdp_optimizer
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
from .launch import (
    PrepareForLaunch,
    _filter_args,

@@ -205,7 +204,7 @@ from .megatron_lm import (
)
from .megatron_lm import initialize as megatron_lm_initialize
from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
from .megatron_lm import prepare_model as megatron_lm_prepare_model
from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
from .memory import find_executable_batch_size, release_memory
@@ -1398,14 +1398,6 @@ class MegatronLMPlugin:
        default=None,
        metadata={"help": "Custom prepare model function."},
    )
    custom_megatron_datasets_provider_function: Optional[Callable] = field(
        default=None,
        metadata={"help": "Custom megatron train_valid_test datasets provider function."},
    )
    custom_get_batch_function: Optional[Callable] = field(
        default=None,
        metadata={"help": "Custom get batch function."},
    )

    # remaining args such as enabling Alibi/ROPE positional embeddings,
    # wandb logging, Multi-Query Attention, etc.
@@ -1472,15 +1464,87 @@ class MegatronLMPlugin:
            self.megatron_lm_default_args.update(self.other_megatron_args)

    def set_network_size_args(self, model, batch_data=None):
        model_config_type = model.config.model_type.lower()
        for model_type in MODEL_CONFIGS_TO_MEGATRON_PARSERS.keys():
            if model_type in model_config_type:
                MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type](self, model, batch_data)
                return
        raise ValueError(
            f"Accelerate Megatron-LM integration not supports {model_config_type} model. "
            "You can add your own model config parser."
        )
        # Check if the model is either BERT, GPT or T5 else raise error
        # set 'num_layers', 'hidden_size', 'num_attention_heads', 'max_position_embeddings'
        if "megatron-bert" in model.config.model_type.lower():
            model_type_name = "bert"
            num_layers = model.config.num_hidden_layers
            hidden_size = model.config.hidden_size
            num_attention_heads = model.config.num_attention_heads
            max_position_embeddings = model.config.max_position_embeddings
            num_labels = model.config.num_labels
            orig_vocab_size = model.config.vocab_size
            if "maskedlm" in model.__class__.__name__.lower():
                pretraining_flag = True
            if self.seq_length is not None:
                if self.encoder_seq_length is not None:
                    warnings.warn("Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.")
                self.seq_length = self.encoder_seq_length
            elif self.encoder_seq_length is not None:
                self.seq_length = self.encoder_seq_length
            elif batch_data is not None:
                self.seq_length = batch_data["input_ids"].shape[1]
            else:
                self.seq_length = max_position_embeddings
            self.megatron_lm_default_args["seq_length"] = self.seq_length
        elif "gpt2" in model.config.model_type.lower():
            model_type_name = "gpt"
            num_layers = model.config.n_layer
            hidden_size = model.config.n_embd
            num_attention_heads = model.config.n_head
            max_position_embeddings = model.config.n_positions
            orig_vocab_size = model.config.vocab_size
            pretraining_flag = True
            if self.seq_length is not None:
                if self.decoder_seq_length is not None:
                    warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.")
                self.seq_length = self.decoder_seq_length
            elif self.decoder_seq_length is not None:
                self.seq_length = self.decoder_seq_length
            elif batch_data is not None:
                self.seq_length = batch_data["input_ids"].shape[1]
            else:
                self.seq_length = max_position_embeddings
            self.megatron_lm_default_args["seq_length"] = self.seq_length
            self.megatron_lm_default_args["return_logits"] = self.return_logits
            self.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer"
        elif "t5" in model.config.model_type.lower():
            model_type_name = "t5"
            num_layers = model.config.num_layers
            hidden_size = model.config.d_model
            num_attention_heads = model.config.num_heads
            max_position_embeddings = model.config.n_positions if hasattr(model.config, "n_positions") else 1024
            orig_vocab_size = model.config.vocab_size
            pretraining_flag = True
            if self.encoder_seq_length is None:
                if batch_data is not None:
                    self.encoder_seq_length = batch_data["input_ids"].shape[1]
                else:
                    self.encoder_seq_length = max_position_embeddings
            if self.decoder_seq_length is None:
                if batch_data is not None:
                    self.decoder_seq_length = batch_data["labels"].shape[1]
                else:
                    self.decoder_seq_length = max_position_embeddings

            self.megatron_lm_default_args["encoder_seq_length"] = self.encoder_seq_length
            self.megatron_lm_default_args["decoder_seq_length"] = self.decoder_seq_length
        else:
            raise ValueError(
                "🤗 Accelerate Megatron-LM integration supports only BERT, GPT and T5 model. "
                "Please check the model you are using is one of those."
            )

        self.megatron_lm_default_args["model_type_name"] = model_type_name
        self.megatron_lm_default_args["num_layers"] = num_layers
        self.megatron_lm_default_args["hidden_size"] = hidden_size
        self.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
        self.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
        self.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
        self.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
        self.megatron_lm_default_args["model_return_dict"] = model.config.return_dict
        if model_type_name == "bert":
            self.megatron_lm_default_args["num_labels"] = num_labels

    def set_mixed_precision(self, mixed_precision):
        if mixed_precision == "fp16":
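With the registry-based dispatch shown above, support for another architecture can be registered externally. A hypothetical sketch only: the "llama" key and the config attributes are illustrative, and `add_model_config_to_megatron_parser` is the decorator defined in the next hunk and exported via the `accelerate.utils` `__init__` change earlier in this diff:

```py
# Hypothetical example: register a parser for an additional model type.
from accelerate.utils import add_model_config_to_megatron_parser

@add_model_config_to_megatron_parser("llama")
def parse_llama_config(megatron_lm_plugin, model, batch_data):
    args = megatron_lm_plugin.megatron_lm_default_args
    args["model_type_name"] = "gpt"  # decoder-only, closest to the GPT path
    args["num_layers"] = model.config.num_hidden_layers
    args["hidden_size"] = model.config.hidden_size
    args["num_attention_heads"] = model.config.num_attention_heads
    args["max_position_embeddings"] = model.config.max_position_embeddings
    args["orig_vocab_size"] = model.config.vocab_size
    args["pretraining_flag"] = True
```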
@@ -1557,116 +1621,6 @@ class MegatronLMPlugin:
                self.megatron_lm_default_args[key.replace("no_", "")] = True


MODEL_CONFIGS_TO_MEGATRON_PARSERS = {}


def add_model_config_to_megatron_parser(model_type: str):
    def add_model_config_parser_helper(func):
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type] = func
        return wrapper

    return add_model_config_parser_helper


@add_model_config_to_megatron_parser("megatron-bert")
def parse_bert_config(megatron_lm_plugin, model, batch_data):
    model_type_name = "bert"
    num_layers = model.config.num_hidden_layers
    hidden_size = model.config.hidden_size
    num_attention_heads = model.config.num_attention_heads
    max_position_embeddings = model.config.max_position_embeddings
    num_labels = model.config.num_labels
    orig_vocab_size = model.config.vocab_size
    if "maskedlm" in model.__class__.__name__.lower():
        pretraining_flag = True
    if megatron_lm_plugin.seq_length is not None:
        if megatron_lm_plugin.encoder_seq_length is not None:
            warnings.warn("Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.")
        megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length
    elif megatron_lm_plugin.encoder_seq_length is not None:
        megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length
    elif batch_data is not None:
        megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
    else:
        megatron_lm_plugin.seq_length = max_position_embeddings
    megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length
    megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name
    megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers
    megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size
    megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
    megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
    megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
    megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
    megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict
    megatron_lm_plugin.megatron_lm_default_args["num_labels"] = num_labels


@add_model_config_to_megatron_parser("gpt2")
def parse_gpt2_config(megatron_lm_plugin, model, batch_data):
    model_type_name = "gpt"
    num_layers = model.config.n_layer
    hidden_size = model.config.n_embd
    num_attention_heads = model.config.n_head
    max_position_embeddings = model.config.n_positions
    orig_vocab_size = model.config.vocab_size
    pretraining_flag = True
    if megatron_lm_plugin.seq_length is not None:
        if megatron_lm_plugin.decoder_seq_length is not None:
            warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.")
        megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
    elif megatron_lm_plugin.decoder_seq_length is not None:
        megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
    elif batch_data is not None:
        megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
    else:
        megatron_lm_plugin.seq_length = max_position_embeddings
    megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length
    megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits
    megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer"
    megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name
    megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers
    megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size
    megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
    megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
    megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
    megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
    megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict


@add_model_config_to_megatron_parser("t5")
def parse_t5_config(megatron_lm_plugin, model, batch_data):
    model_type_name = "t5"
    num_layers = model.config.num_layers
    hidden_size = model.config.d_model
    num_attention_heads = model.config.num_heads
    max_position_embeddings = model.config.n_positions if hasattr(model.config, "n_positions") else 1024
    orig_vocab_size = model.config.vocab_size
    pretraining_flag = True
    if megatron_lm_plugin.encoder_seq_length is None:
        if batch_data is not None:
            megatron_lm_plugin.encoder_seq_length = batch_data["input_ids"].shape[1]
        else:
            megatron_lm_plugin.encoder_seq_length = max_position_embeddings
    if megatron_lm_plugin.decoder_seq_length is None:
|
||||
if batch_data is not None:
|
||||
megatron_lm_plugin.decoder_seq_length = batch_data["labels"].shape[1]
|
||||
else:
|
||||
megatron_lm_plugin.decoder_seq_length = max_position_embeddings
|
||||
megatron_lm_plugin.megatron_lm_default_args["encoder_seq_length"] = megatron_lm_plugin.encoder_seq_length
|
||||
megatron_lm_plugin.megatron_lm_default_args["decoder_seq_length"] = megatron_lm_plugin.decoder_seq_length
|
||||
megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name
|
||||
megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers
|
||||
megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size
|
||||
megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
|
||||
megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
|
||||
megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
|
||||
megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
|
||||
megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict
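# Illustrative sketch (not part of this diff): how the registry populated by the
# parsers above can be dispatched on a model's `config.model_type`. The "llama"
# key, `parse_llama_config` and `dispatch_model_config_to_megatron` below are
# hypothetical, shown only to demonstrate the registration/lookup pattern; this
# file only registers "megatron-bert", "gpt2" and "t5".
@add_model_config_to_megatron_parser("llama")
def parse_llama_config(megatron_lm_plugin, model, batch_data):
    # A real parser would mirror parse_gpt2_config and fill megatron_lm_default_args.
    megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = "gpt"


def dispatch_model_config_to_megatron(megatron_lm_plugin, model, batch_data):
    model_type = model.config.model_type
    if model_type not in MODEL_CONFIGS_TO_MEGATRON_PARSERS:
        raise ValueError(f"No Megatron-LM config parser registered for '{model_type}'")
    MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type](megatron_lm_plugin, model, batch_data)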
|
||||
|
||||
|
||||
@dataclass
|
||||
class BnbQuantizationConfig:
|
||||
"""
|
||||
|
||||
@ -12,16 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ..logging import get_logger
|
||||
from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME
|
||||
from .imports import is_torch_distributed_available
|
||||
from .modeling import is_peft_model
|
||||
from .other import save
|
||||
from .versions import is_torch_version
|
||||
|
||||
|
||||
@ -31,9 +28,6 @@ if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_availab
|
||||
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
||||
# `dist_cp_format_utils` is only available from PyTorch >= 2.3.0
|
||||
if is_torch_version(">=", "2.3.0") and is_torch_distributed_available():
|
||||
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -213,63 +207,3 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
|
||||
logger.info(f"Optimizer loaded from {ckpt_dir}")
|
||||
flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)
|
||||
optimizer.load_state_dict(flattened_osd)
|
||||
|
||||
|
||||
def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True):
|
||||
"""
|
||||
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
|
||||
|
||||
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
|
||||
"""
|
||||
state_dict = {}
|
||||
save_path = Path(save_path)
|
||||
dist_cp_format_utils._load_state_dict(
|
||||
state_dict,
|
||||
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
|
||||
planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
|
||||
no_dist=True,
|
||||
)
|
||||
save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME
|
||||
|
||||
# To handle if state is a dict like {model: {...}}
|
||||
if len(state_dict.keys()) == 1:
|
||||
state_dict = state_dict[list(state_dict)[0]]
|
||||
save(state_dict, save_path, safe_serialization=safe_serialization)
|
||||
return save_path
|
||||
|
||||
|
||||
def merge_fsdp_weights(
|
||||
checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False
|
||||
):
|
||||
"""
|
||||
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
|
||||
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
|
||||
`safe_serialization` else `pytorch_model.bin`.
|
||||
|
||||
Note: this is a CPU-bound process.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (`str`):
|
||||
The directory containing the FSDP checkpoints (can be either the model or optimizer).
|
||||
output_path (`str`):
|
||||
The path to save the merged checkpoint.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the merged weights with safetensors (recommended).
|
||||
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
|
||||
Whether to remove the checkpoint directory after merging.
|
||||
"""
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if not is_torch_version(">=", "2.3.0"):
    raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0")
|
||||
|
||||
# To setup `save` to work
|
||||
state = PartialState()
|
||||
if state.is_main_process:
|
||||
logger.info(f"Merging FSDP weights from {checkpoint_dir}")
|
||||
save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)
|
||||
logger.info(f"Successfully merged FSDP weights and saved to {save_path}")
|
||||
if remove_checkpoint_dir:
|
||||
logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
state.wait_for_everyone()
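# Minimal usage sketch for `merge_fsdp_weights` above. The paths are placeholders
# for a sharded checkpoint saved with `SHARDED_STATE_DICT`, and the public import
# path is an assumption based on recent accelerate releases:
#
#     from accelerate.utils import merge_fsdp_weights
#
#     merge_fsdp_weights(
#         checkpoint_dir="outputs/checkpoint_0/pytorch_model_fsdp_0",
#         output_path="outputs/merged_model",
#         safe_serialization=True,        # writes model.safetensors
#         remove_checkpoint_dir=False,    # keep the original sharded checkpoint
#     )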
|
||||
|
||||
@ -20,6 +20,7 @@ from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from packaging.version import parse
|
||||
|
||||
from .environment import parse_flag_from_env, str_to_bool
|
||||
from .versions import compare_versions, is_torch_version
|
||||
@ -219,7 +220,12 @@ def is_megatron_lm_available():
|
||||
if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:
|
||||
package_exists = importlib.util.find_spec("megatron") is not None
|
||||
if package_exists:
|
||||
return True
|
||||
try:
|
||||
megatron_version = parse(importlib.metadata.version("megatron-lm"))
|
||||
return compare_versions(megatron_version, ">=", "2.2.0")
|
||||
except Exception as e:
    warnings.warn(f"Failed to parse the Megatron-LM version. Exception: {e}")
    return False
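# Standalone sketch (illustrative helper, not part of accelerate) of the version
# gate used above: resolve the installed distribution through importlib.metadata
# and compare it against a minimum version with `packaging`.
import importlib.metadata
import importlib.util

from packaging.version import parse as _parse_version


def _has_min_version_sketch(package_name: str, dist_name: str, minimum: str) -> bool:
    # e.g. _has_min_version_sketch("megatron", "megatron-lm", "2.2.0")
    if importlib.util.find_spec(package_name) is None:
        return False
    try:
        return _parse_version(importlib.metadata.version(dist_name)) >= _parse_version(minimum)
    except importlib.metadata.PackageNotFoundError:
        return False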
|
||||
|
||||
|
||||
def is_transformers_available():
|
||||
|
||||
@ -13,9 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import math
|
||||
import os
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
|
||||
@ -43,25 +41,15 @@ if is_megatron_lm_available():
|
||||
get_args,
|
||||
get_num_microbatches,
|
||||
get_tensorboard_writer,
|
||||
get_timers,
|
||||
get_tokenizer,
|
||||
mpu,
|
||||
print_rank_0,
|
||||
print_rank_last,
|
||||
)
|
||||
from megatron.arguments import (
|
||||
_add_data_args,
|
||||
_add_validation_args,
|
||||
core_transformer_config_from_args,
|
||||
parse_args,
|
||||
validate_args,
|
||||
)
|
||||
from megatron.arguments import _add_data_args, _add_validation_args, parse_args, validate_args
|
||||
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
|
||||
from megatron.core import mpu, tensor_parallel
|
||||
from megatron.core.distributed import DistributedDataParallel as LocalDDP
|
||||
from megatron.core.distributed import finalize_model_grads
|
||||
from megatron.core.enums import ModelType
|
||||
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
|
||||
from megatron.core.pipeline_parallel import get_forward_backward_func
|
||||
from megatron.core.utils import get_model_config
|
||||
from megatron.data.dataset_utils import build_train_valid_test_datasets
|
||||
from megatron.data.data_samplers import MegatronPretrainingRandomSampler, MegatronPretrainingSampler
|
||||
from megatron.global_vars import set_global_variables
|
||||
from megatron.initialize import (
|
||||
_compile_dependencies,
|
||||
@ -70,22 +58,18 @@ if is_megatron_lm_available():
|
||||
set_jit_fusion_options,
|
||||
write_args_to_tensorboard,
|
||||
)
|
||||
from megatron.model import BertModel, Float16Module, GPTModel, T5Model
|
||||
from megatron.model import BertModel, Float16Module, GPTModel, ModelType, T5Model
|
||||
from megatron.model import DistributedDataParallel as LocalDDP
|
||||
from megatron.model.classification import Classification
|
||||
from megatron.optimizer import get_megatron_optimizer
|
||||
from megatron.schedules import get_forward_backward_func
|
||||
from megatron.text_generation.communication import broadcast_int_list, broadcast_tensor
|
||||
from megatron.text_generation.generation import (
|
||||
beam_search_and_return_on_first_stage,
|
||||
generate_tokens_probs_and_return_on_first_stage,
|
||||
)
|
||||
from megatron.tokenizer.tokenizer import _vocab_size_with_padding
|
||||
from megatron.training import (
|
||||
build_train_valid_test_data_iterators,
|
||||
get_optimizer_param_scheduler,
|
||||
num_floating_point_operations,
|
||||
setup_model_and_optimizer,
|
||||
train_step,
|
||||
training_log,
|
||||
)
|
||||
from megatron.training import get_model, get_optimizer_param_scheduler, training_log
|
||||
from megatron.utils import (
|
||||
average_losses_across_data_parallel_group,
|
||||
calc_params_l2_norm,
|
||||
@ -105,12 +89,10 @@ def model_provider_func(pre_process=True, post_process=True, add_encoder=True, a
|
||||
"The Megatron LM model weights are initialized at random in `accelerator.prepare`. "
|
||||
"Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup."
|
||||
)
|
||||
config = core_transformer_config_from_args(get_args())
|
||||
if args.model_type_name == "bert":
|
||||
if args.pretraining_flag:
|
||||
num_tokentypes = 2 if args.bert_binary_head else 0
|
||||
model = BertModel(
|
||||
config=config,
|
||||
num_tokentypes=num_tokentypes,
|
||||
add_binary_head=args.bert_binary_head,
|
||||
parallel_output=True,
|
||||
@ -119,19 +101,12 @@ def model_provider_func(pre_process=True, post_process=True, add_encoder=True, a
|
||||
)
|
||||
else:
|
||||
model = Classification(
|
||||
config=config,
|
||||
num_classes=args.num_labels,
|
||||
num_tokentypes=2,
|
||||
pre_process=pre_process,
|
||||
post_process=post_process,
|
||||
num_classes=args.num_labels, num_tokentypes=2, pre_process=pre_process, post_process=post_process
|
||||
)
|
||||
elif args.model_type_name == "gpt":
|
||||
model = GPTModel(
|
||||
num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process, config=config
|
||||
)
|
||||
model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process)
|
||||
elif args.model_type_name == "t5":
|
||||
model = T5Model(
|
||||
config=config,
|
||||
num_tokentypes=0,
|
||||
parallel_output=True,
|
||||
pre_process=pre_process,
|
||||
@ -144,8 +119,8 @@ def model_provider_func(pre_process=True, post_process=True, add_encoder=True, a
|
||||
return model
|
||||
|
||||
|
||||
def prepare_model_optimizer_scheduler(accelerator):
|
||||
accelerator.print("Preparing model optimizer scheduler")
|
||||
def prepare_model(accelerator):
|
||||
accelerator.print("Preparing model")
|
||||
args = get_args()
|
||||
if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None:
|
||||
if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None:
|
||||
@ -154,24 +129,15 @@ def prepare_model_optimizer_scheduler(accelerator):
|
||||
)
|
||||
custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function
|
||||
model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func)
|
||||
optimizer = prepare_optimizer(accelerator, model)
|
||||
scheduler = prepare_scheduler(accelerator, optimizer, scheduler=None)
|
||||
else:
|
||||
model_type = ModelType.encoder_or_decoder
|
||||
if args.model_type_name == "t5":
|
||||
if args.model_type_name in ("bert", "gpt"):
|
||||
model_type = ModelType.encoder_or_decoder
|
||||
elif args.model_type_name == "t5":
|
||||
model_type = ModelType.encoder_and_decoder
|
||||
model_provider_func_ = model_provider_func
|
||||
if accelerator.state.megatron_lm_plugin.custom_model_provider_function is not None:
|
||||
model_provider_func_ = accelerator.state.megatron_lm_plugin.custom_model_provider_function
|
||||
(model, optimizer, scheduler) = setup_model_and_optimizer(
|
||||
model_provider_func_,
|
||||
model_type,
|
||||
no_wd_decay_cond=args.no_wd_decay_cond,
|
||||
scale_lr_cond=args.scale_lr_cond,
|
||||
lr_mult=args.lr_mult,
|
||||
)
|
||||
args.model_len = len(model)
|
||||
return model, optimizer, scheduler
|
||||
if args.pipeline_model_parallel_split_rank is None and args.pipeline_model_parallel_size > 1:
|
||||
args.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_size // 2
|
||||
model = get_model(model_provider_func, model_type)
|
||||
return model
|
||||
|
||||
|
||||
# dataloader utilities
|
||||
@ -197,27 +163,31 @@ class MegatronLMDummyDataLoader:
|
||||
for key, value in self.dataset_args.items():
|
||||
setattr(args, key, value)
|
||||
|
||||
def get_train_valid_test_datasets_provider(self, accelerator):
|
||||
def get_train_valid_test_datasets_provider(self):
|
||||
def train_valid_test_datasets_provider(train_val_test_num_samples):
|
||||
"""Build train, valid, and test datasets."""
|
||||
args = get_args()
|
||||
dataset_args = {
|
||||
"data_prefix": args.data_path if isinstance(args.data_path, (list, tuple)) else [args.data_path],
|
||||
"data_prefix": args.data_path,
|
||||
"data_impl": args.data_impl,
|
||||
"splits_string": args.split,
|
||||
"train_valid_test_num_samples": train_val_test_num_samples,
|
||||
"skip_warmup": (not args.mmap_warmup),
|
||||
"seed": args.seed,
|
||||
}
|
||||
if args.model_type_name == "bert":
|
||||
dataset_args.update(
|
||||
{
|
||||
"max_seq_length": args.seq_length,
|
||||
"masked_lm_prob": args.mask_prob,
|
||||
"short_seq_prob": args.short_seq_prob,
|
||||
"binary_head": args.bert_binary_head,
|
||||
}
|
||||
)
|
||||
elif args.model_type_name == "gpt":
|
||||
dataset_args.update(
|
||||
{
|
||||
"max_seq_length": args.seq_length,
|
||||
"seq_length": args.seq_length,
|
||||
}
|
||||
)
|
||||
elif args.model_type_name == "t5":
|
||||
@ -225,36 +195,143 @@ class MegatronLMDummyDataLoader:
|
||||
{
|
||||
"max_seq_length": args.encoder_seq_length,
|
||||
"max_seq_length_dec": args.decoder_seq_length,
|
||||
"masked_lm_prob": args.mask_prob,
|
||||
"short_seq_prob": args.short_seq_prob,
|
||||
"dataset_type": "t5",
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {args.model_type_name}")
|
||||
if args.model_type_name == "gpt":
|
||||
from megatron.data.gpt_dataset import build_train_valid_test_datasets
|
||||
else:
|
||||
from megatron.data.dataset_utils import build_train_valid_test_datasets
|
||||
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args)
|
||||
return train_ds, valid_ds, test_ds
|
||||
|
||||
if accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function is not None:
|
||||
return accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function
|
||||
return train_valid_test_datasets_provider
|
||||
|
||||
def build_train_valid_test_data_iterators(self, accelerator):
|
||||
def build_pretraining_data_loader(self, dataset, consumed_samples):
|
||||
if dataset is None:
|
||||
return None
|
||||
args = get_args()
|
||||
micro_batch_size = args.micro_batch_size * args.num_micro_batches
|
||||
|
||||
# Megatron sampler
|
||||
if args.dataloader_type == "single":
|
||||
batch_sampler = MegatronPretrainingSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=mpu.get_data_parallel_rank(),
|
||||
data_parallel_size=mpu.get_data_parallel_world_size(),
|
||||
)
|
||||
elif args.dataloader_type == "cyclic":
|
||||
batch_sampler = MegatronPretrainingRandomSampler(
|
||||
dataset,
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=mpu.get_data_parallel_rank(),
|
||||
data_parallel_size=mpu.get_data_parallel_world_size(),
|
||||
data_sharding=args.data_sharding,
|
||||
)
|
||||
else:
|
||||
raise Exception(f"{args.dataloader_type} dataloader type is not supported.")
|
||||
|
||||
# Torch dataloader.
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True
|
||||
)
|
||||
|
||||
def build_train_valid_test_data_iterators(self):
|
||||
def cyclic_iter(iter):
|
||||
while True:
|
||||
yield from iter
|
||||
|
||||
args = get_args()
|
||||
|
||||
train_valid_test_dataset_provider = self.get_train_valid_test_datasets_provider(accelerator)
|
||||
if args.virtual_pipeline_model_parallel_size is not None:
|
||||
train_data_iterator = []
|
||||
valid_data_iterator = []
|
||||
test_data_iterator = []
|
||||
for i in range(getattr(args, "model_len", 0)):
|
||||
mpu.set_virtual_pipeline_model_parallel_rank(i)
|
||||
iterators = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
|
||||
train_data_iterator.append(iterators[0])
|
||||
valid_data_iterator.append(iterators[1])
|
||||
test_data_iterator.append(iterators[2])
|
||||
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
|
||||
|
||||
print_rank_0("> building train, validation, and test datasets ...")
|
||||
|
||||
# Backward compatibility, assume fixed batch size.
|
||||
if args.iteration > 0 and args.consumed_train_samples == 0:
    assert args.train_samples is None, "only backward compatibility support for iteration-based training"
    args.consumed_train_samples = args.iteration * args.global_batch_size
|
||||
if args.iteration > 0 and args.consumed_valid_samples == 0:
|
||||
if args.train_samples is None:
|
||||
args.consumed_valid_samples = (
|
||||
(args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size
|
||||
)
|
||||
|
||||
# Data loader only on rank 0 of each model parallel group.
|
||||
if mpu.get_tensor_model_parallel_rank() == 0:
|
||||
# Number of train/valid/test samples.
|
||||
if args.train_samples:
|
||||
train_samples = args.train_samples
|
||||
else:
|
||||
train_samples = args.train_iters * args.global_batch_size
|
||||
eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
|
||||
test_iters = args.eval_iters
|
||||
train_val_test_num_samples = [
|
||||
train_samples,
|
||||
eval_iters * args.global_batch_size,
|
||||
test_iters * args.global_batch_size,
|
||||
]
|
||||
print_rank_0(" > datasets target sizes (minimum size):")
|
||||
print_rank_0(f" train: {train_val_test_num_samples[0]}")
|
||||
print_rank_0(f" validation: {train_val_test_num_samples[1]}")
|
||||
print_rank_0(f" test: {train_val_test_num_samples[2]}")
|
||||
|
||||
# Build the datasets.
|
||||
train_valid_test_datasets_provider = self.get_train_valid_test_datasets_provider()
|
||||
train_ds, valid_ds, test_ds = train_valid_test_datasets_provider(train_val_test_num_samples)
|
||||
|
||||
# Build dataloaders.
|
||||
train_dataloader = self.build_pretraining_data_loader(train_ds, args.consumed_train_samples)
|
||||
valid_dataloader = self.build_pretraining_data_loader(valid_ds, args.consumed_valid_samples)
|
||||
test_dataloader = self.build_pretraining_data_loader(test_ds, 0)
|
||||
|
||||
# Flags to know if we need to do training/validation/testing.
|
||||
do_train = train_dataloader is not None and args.train_iters > 0
|
||||
do_valid = valid_dataloader is not None and args.eval_iters > 0
|
||||
do_test = test_dataloader is not None and args.eval_iters > 0
|
||||
# Need to broadcast num_tokens and num_type_tokens.
|
||||
flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
|
||||
else:
|
||||
train_data_iterator, valid_data_iterator, test_data_iterator = build_train_valid_test_data_iterators(
|
||||
train_valid_test_dataset_provider
|
||||
flags = torch.cuda.LongTensor([0, 0, 0])
|
||||
|
||||
# Broadcast num tokens.
|
||||
torch.distributed.broadcast(
|
||||
flags, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()
|
||||
)
|
||||
args.do_train = flags[0].item()
|
||||
args.do_valid = flags[1].item()
|
||||
args.do_test = flags[2].item()
|
||||
|
||||
# Build iterators.
|
||||
dl_type = args.dataloader_type
|
||||
assert dl_type in ["single", "cyclic"]
|
||||
|
||||
if train_dataloader is not None:
|
||||
train_data_iterator = (
|
||||
iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader))
|
||||
)
|
||||
else:
|
||||
train_data_iterator = None
|
||||
|
||||
if valid_dataloader is not None:
|
||||
valid_data_iterator = (
|
||||
iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader))
|
||||
)
|
||||
else:
|
||||
valid_data_iterator = None
|
||||
|
||||
if test_dataloader is not None:
|
||||
test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader))
|
||||
else:
|
||||
test_data_iterator = None
|
||||
|
||||
return train_data_iterator, valid_data_iterator, test_data_iterator
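# Self-contained illustration of the "single" vs "cyclic" iterator choice made
# above: a plain iterator is exhausted after one pass, while cyclic_iter wraps
# around indefinitely (what the "cyclic" dataloader type relies on). This demo
# is inert and only documents the behavior.
def _cyclic_iter_demo():
    def cyclic_iter(it):
        while True:
            yield from it

    data = [1, 2, 3]
    single = iter(data)                # stops after one epoch
    cyclic = iter(cyclic_iter(data))   # loops forever
    assert [next(cyclic) for _ in range(5)] == [1, 2, 3, 1, 2]
    assert list(single) == [1, 2, 3]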
|
||||
|
||||
@ -265,6 +342,7 @@ def prepare_data_loader(accelerator, dataloader):
|
||||
if not args.megatron_dataset_flag:
|
||||
from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader
|
||||
|
||||
args = get_args()
|
||||
micro_batch_size = args.micro_batch_size * args.num_micro_batches
|
||||
kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}
|
||||
if kwargs["batch_size"] is None:
|
||||
@ -299,26 +377,11 @@ def prepare_data_loader(accelerator, dataloader):
|
||||
) = args.consumed_samples
|
||||
else:
|
||||
args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0
|
||||
args.micro_batch_size = args.micro_batch_size * args.num_micro_batches
|
||||
(
|
||||
train_data_iterator,
|
||||
valid_data_iterator,
|
||||
test_data_iterator,
|
||||
) = dataloader.build_train_valid_test_data_iterators(accelerator)
|
||||
args.micro_batch_size = args.micro_batch_size // args.num_micro_batches
|
||||
|
||||
class DummyMegatronDataloader:
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return {}
|
||||
|
||||
if train_data_iterator is None:
|
||||
train_data_iterator = DummyMegatronDataloader()
|
||||
valid_data_iterator = DummyMegatronDataloader()
|
||||
test_data_iterator = DummyMegatronDataloader()
|
||||
|
||||
) = dataloader.build_train_valid_test_data_iterators()
|
||||
return train_data_iterator, valid_data_iterator, test_data_iterator
|
||||
|
||||
|
||||
@ -342,12 +405,7 @@ class MegatronLMOptimizerWrapper(AcceleratedOptimizer):
|
||||
def prepare_optimizer(accelerator, model):
|
||||
accelerator.print("Preparing optimizer")
|
||||
args = get_args()
|
||||
kwargs = {}
|
||||
for f in dataclasses.fields(OptimizerConfig):
|
||||
if hasattr(args, f.name):
|
||||
kwargs[f.name] = getattr(args, f.name)
|
||||
config = OptimizerConfig(**kwargs)
|
||||
optimizer = get_megatron_optimizer(config, model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult)
|
||||
optimizer = get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult)
|
||||
return optimizer
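# Self-contained sketch of the "fill a config dataclass from argparse args"
# pattern used in prepare_optimizer above. `_OptimizerConfigSketch` is a stand-in
# for megatron.core.optimizer.OptimizerConfig; only fields that exist on the
# namespace are forwarded to the config.
import dataclasses as _dataclasses


@_dataclasses.dataclass
class _OptimizerConfigSketch:
    lr: float = 1e-4
    weight_decay: float = 0.0


def _build_config_sketch(args):
    kwargs = {
        f.name: getattr(args, f.name)
        for f in _dataclasses.fields(_OptimizerConfigSketch)
        if hasattr(args, f.name)
    }
    return _OptimizerConfigSketch(**kwargs)


# _build_config_sketch(argparse.Namespace(lr=3e-4, weight_decay=0.1, other=True))
# -> _OptimizerConfigSketch(lr=0.0003, weight_decay=0.1)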
|
||||
|
||||
|
||||
@ -396,7 +454,7 @@ class AbstractTrainStep(ABC):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
||||
def get_batch_func(self):
|
||||
pass
|
||||
|
||||
def get_forward_step_func(self):
|
||||
@ -414,9 +472,9 @@ class BertTrainStep(AbstractTrainStep):
|
||||
args (`argparse.Namespace`): Megatron-LM arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, accelerator, args):
|
||||
def __init__(self, args):
|
||||
super().__init__("BertTrainStep")
|
||||
self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
|
||||
self.get_batch = self.get_batch_func(args.megatron_dataset_flag)
|
||||
self.loss_func = self.get_loss_func(args.pretraining_flag, args.num_labels)
|
||||
self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head)
|
||||
if not args.model_return_dict:
|
||||
@ -424,7 +482,7 @@ class BertTrainStep(AbstractTrainStep):
|
||||
else:
|
||||
self.model_output_class = SequenceClassifierOutput
|
||||
|
||||
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
||||
def get_batch_func(self, megatron_dataset_flag):
|
||||
def get_batch_megatron(data_iterator):
|
||||
"""Build the batch."""
|
||||
|
||||
@ -437,7 +495,7 @@ class BertTrainStep(AbstractTrainStep):
|
||||
data = next(data_iterator)
|
||||
else:
|
||||
data = None
|
||||
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
tokens = data_b["text"].long()
|
||||
@ -474,8 +532,6 @@ class BertTrainStep(AbstractTrainStep):
|
||||
|
||||
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
|
||||
|
||||
if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
|
||||
return accelerator.state.megatron_lm_plugin.custom_get_batch_function
|
||||
if megatron_dataset_flag:
|
||||
return get_batch_megatron
|
||||
else:
|
||||
@ -545,9 +601,9 @@ class GPTTrainStep(AbstractTrainStep):
|
||||
args (`argparse.Namespace`): Megatron-LM arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, accelerator, args):
|
||||
def __init__(self, args):
|
||||
super().__init__("GPTTrainStep")
|
||||
self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
|
||||
self.get_batch = self.get_batch_func(args.megatron_dataset_flag)
|
||||
self.loss_func = self.get_loss_func()
|
||||
self.forward_step = self.get_forward_step_func()
|
||||
self.eod_token = args.padded_vocab_size - 1
|
||||
@ -562,7 +618,7 @@ class GPTTrainStep(AbstractTrainStep):
|
||||
else:
|
||||
self.model_output_class = CausalLMOutputWithCrossAttentions
|
||||
|
||||
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
||||
def get_batch_func(self, megatron_dataset_flag):
|
||||
def get_batch_megatron(data_iterator):
|
||||
"""Generate a batch"""
|
||||
# Items and their type.
|
||||
@ -574,7 +630,7 @@ class GPTTrainStep(AbstractTrainStep):
|
||||
data = next(data_iterator)
|
||||
else:
|
||||
data = None
|
||||
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
tokens_ = data_b["text"].long()
|
||||
@ -604,8 +660,6 @@ class GPTTrainStep(AbstractTrainStep):
|
||||
)
|
||||
return tokens, labels, loss_mask, attention_mask, position_ids
|
||||
|
||||
if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
|
||||
return accelerator.state.megatron_lm_plugin.custom_get_batch_function
|
||||
if megatron_dataset_flag:
|
||||
return get_batch_megatron
|
||||
else:
|
||||
@ -621,20 +675,7 @@ class GPTTrainStep(AbstractTrainStep):
|
||||
losses = output_tensor
|
||||
losses = losses.float()
|
||||
loss_mask = loss_mask.view(-1).float()
|
||||
if args.context_parallel_size > 1:
|
||||
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
|
||||
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
|
||||
loss = loss[0] / loss[1]
|
||||
else:
|
||||
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
|
||||
|
||||
# Check individual rank losses are not NaN prior to DP all-reduce.
|
||||
if args.check_for_nan_in_loss_and_grad:
|
||||
global_rank = torch.distributed.get_rank()
|
||||
assert not loss.isnan(), (
|
||||
f"Rank {global_rank}: found NaN in local forward loss calculation. "
|
||||
f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}"
|
||||
)
|
||||
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
|
||||
|
||||
# Reduce loss for logging.
|
||||
averaged_loss = average_losses_across_data_parallel_group([loss])
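# Self-contained check of the masked-mean reduction in the loss function above
# (single process, no context parallelism): only positions where loss_mask is 1
# contribute to the scalar loss. Inert demo, not part of the train step.
def _masked_loss_demo():
    import torch

    losses = torch.tensor([[0.5, 1.0, 2.0], [3.0, 0.0, 0.0]])
    loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]).view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    assert torch.isclose(loss, torch.tensor(6.5 / 4))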
|
||||
@ -666,9 +707,9 @@ class T5TrainStep(AbstractTrainStep):
|
||||
args (`argparse.Namespace`): Megatron-LM arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, accelerator, args):
|
||||
def __init__(self, args):
|
||||
super().__init__("T5TrainStep")
|
||||
self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
|
||||
self.get_batch = self.get_batch_func(args.megatron_dataset_flag)
|
||||
self.loss_func = self.get_loss_func()
|
||||
self.forward_step = self.get_forward_step_func()
|
||||
if not args.model_return_dict:
|
||||
@ -707,7 +748,7 @@ class T5TrainStep(AbstractTrainStep):
|
||||
extended_attention_mask = attention_mask_bss < 0.5
|
||||
return extended_attention_mask
|
||||
|
||||
def get_batch_func(self, accelerator, megatron_dataset_flag):
|
||||
def get_batch_func(self, megatron_dataset_flag):
|
||||
def get_batch_megatron(data_iterator):
|
||||
"""Build the batch."""
|
||||
|
||||
@ -719,7 +760,7 @@ class T5TrainStep(AbstractTrainStep):
|
||||
data = next(data_iterator)
|
||||
else:
|
||||
data = None
|
||||
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
tokens_enc = data_b["text_enc"].long()
|
||||
@ -756,8 +797,6 @@ class T5TrainStep(AbstractTrainStep):
|
||||
|
||||
return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask
|
||||
|
||||
if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
|
||||
return accelerator.state.megatron_lm_plugin.custom_get_batch_function
|
||||
if megatron_dataset_flag:
|
||||
return get_batch_megatron
|
||||
else:
|
||||
@ -851,6 +890,8 @@ def initialize(accelerator, extra_args_provider=None, args_defaults={}):
|
||||
print(f"> setting random seeds to {args.seed} ...")
|
||||
_set_random_seed(args.seed, args.data_parallel_random_init)
|
||||
|
||||
args = get_args()
|
||||
|
||||
# Megatron's MPU is the master. Complete initialization right away.
|
||||
finish_mpu_init()
|
||||
|
||||
@ -863,8 +904,7 @@ def initialize(accelerator, extra_args_provider=None, args_defaults={}):
|
||||
# Set pytorch JIT layer fusion options and warmup JIT functions.
|
||||
set_jit_fusion_options()
|
||||
args = get_args()
|
||||
if getattr(args, "padded_vocab_size", None) is None:
|
||||
args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args)
|
||||
args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args)
|
||||
if args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2:
|
||||
args.bert_binary_head = True
|
||||
else:
|
||||
@ -895,11 +935,11 @@ class MegatronEngine(torch.nn.Module):
|
||||
args, **accelerator.state.megatron_lm_plugin.custom_train_step_kwargs
|
||||
)
|
||||
elif args.model_type_name == "bert":
|
||||
self.train_step_handler = BertTrainStep(accelerator, args)
|
||||
self.train_step_handler = BertTrainStep(args)
|
||||
elif args.model_type_name == "gpt":
|
||||
self.train_step_handler = GPTTrainStep(accelerator, args)
|
||||
self.train_step_handler = GPTTrainStep(args)
|
||||
elif args.model_type_name == "t5":
|
||||
self.train_step_handler = T5TrainStep(accelerator, args)
|
||||
self.train_step_handler = T5TrainStep(args)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {args.model_type_name}")
|
||||
self.optimizer.skipped_iter = False
|
||||
@ -909,38 +949,12 @@ class MegatronEngine(torch.nn.Module):
|
||||
self.eval_total_loss_dict = {}
|
||||
self.iteration = 0
|
||||
self.report_memory_flag = True
|
||||
self.num_floating_point_operations_so_far = 0
|
||||
if args.tensorboard_dir is not None:
|
||||
write_args_to_tensorboard()
|
||||
|
||||
def train(self):
|
||||
for model_module in self.module:
|
||||
model_module.train()
|
||||
|
||||
args = get_args()
|
||||
config = get_model_config(self.module[0])
|
||||
# Setup some training config params
|
||||
config.grad_scale_func = self.optimizer.scale_loss
|
||||
if isinstance(self.module[0], LocalDDP) and args.overlap_grad_reduce:
|
||||
assert config.no_sync_func is None, (
|
||||
"When overlap_grad_reduce is True, config.no_sync_func must be None; "
|
||||
"a custom no_sync_func is not supported when overlapping grad-reduce"
|
||||
)
|
||||
config.no_sync_func = [model_chunk.no_sync for model_chunk in self.module]
|
||||
if len(self.module) == 1:
|
||||
config.no_sync_func = config.no_sync_func[0]
|
||||
if args.delay_grad_reduce:
|
||||
config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.module]
|
||||
if len(self.module) == 1:
|
||||
config.grad_sync_func = config.grad_sync_func[0]
|
||||
if args.overlap_param_gather and args.delay_param_gather:
|
||||
config.param_sync_func = [
|
||||
lambda x: self.optimizer.finish_param_sync(model_index, x) for model_index in range(len(self.module))
|
||||
]
|
||||
if len(self.module) == 1:
|
||||
config.param_sync_func = config.param_sync_func[0]
|
||||
config.finalize_model_grads_func = finalize_model_grads
|
||||
|
||||
self.log_eval_results()
|
||||
|
||||
def eval(self):
|
||||
@ -956,10 +970,10 @@ class MegatronEngine(torch.nn.Module):
|
||||
"""
|
||||
|
||||
args = get_args()
|
||||
config = get_model_config(self.module[0])
|
||||
timers = get_timers()
|
||||
|
||||
data_chunks = []
|
||||
if len(batch_data) > 0:
|
||||
data_chunks = []
|
||||
if args.num_micro_batches > 1:
|
||||
for i in range(0, args.num_micro_batches):
|
||||
data_chunks.append(
|
||||
@ -980,18 +994,73 @@ class MegatronEngine(torch.nn.Module):
|
||||
else:
|
||||
batch_data_iterator = iter(data_chunks) if len(batch_data) > 0 else None
|
||||
|
||||
loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
|
||||
forward_step_func=self.train_step_handler.forward_step,
|
||||
data_iterator=batch_data_iterator,
|
||||
model=self.module,
|
||||
optimizer=self.optimizer,
|
||||
opt_param_scheduler=self.scheduler,
|
||||
config=config,
|
||||
# Set grad to zero.
|
||||
if args.DDP_impl == "local" and args.use_contiguous_buffers_in_local_ddp:
|
||||
for partition in self.module:
|
||||
partition.zero_grad_buffer()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Forward pass.
|
||||
forward_backward_func = get_forward_backward_func()
|
||||
losses_reduced = forward_backward_func(
|
||||
self.train_step_handler.forward_step,
|
||||
batch_data_iterator,
|
||||
self.module,
|
||||
self.optimizer,
|
||||
None,
|
||||
forward_only=False,
|
||||
)
|
||||
|
||||
self.optimizer.skipped_iter = skipped_iter == 1
|
||||
# Empty unused memory.
|
||||
if args.empty_unused_memory_level >= 1:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
|
||||
# Reduce gradients.
|
||||
timers("backward-reduce-model-grads").start()
|
||||
self.optimizer.reduce_model_grads(args, timers)
|
||||
timers("backward-reduce-model-grads").stop()
|
||||
|
||||
# Update parameters.
|
||||
timers("optimizer").start()
|
||||
update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(args, timers)
|
||||
timers("optimizer").stop()
|
||||
|
||||
# Gather params.
|
||||
if update_successful:
|
||||
timers("backward-gather-model-params").start()
|
||||
self.optimizer.gather_model_params(args, timers)
|
||||
timers("backward-gather-model-params").stop()
|
||||
|
||||
# Update learning rate.
|
||||
if update_successful:
|
||||
if self.scheduler is not None:
|
||||
increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
|
||||
self.scheduler.step(increment=increment)
|
||||
skipped_iter = 0
|
||||
else:
|
||||
skipped_iter = 1
|
||||
|
||||
self.optimizer.skipped_iter = not update_successful
|
||||
|
||||
# Empty unused memory.
|
||||
if args.empty_unused_memory_level >= 2:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
args.consumed_train_samples += (
|
||||
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
|
||||
)
|
||||
|
||||
if mpu.is_pipeline_last_stage(ignore_virtual=True):
|
||||
# Average loss across microbatches.
|
||||
loss_reduced = {}
|
||||
for key in losses_reduced[0]:
|
||||
losses_reduced_for_key = [x[key] for x in losses_reduced]
|
||||
if len(losses_reduced_for_key[0].shape) == 0:
|
||||
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
|
||||
else:
|
||||
loss_reduced[key] = torch.concat(losses_reduced_for_key)
|
||||
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
|
||||
return {}, skipped_iter, grad_norm, num_zeros_in_grad
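# Self-contained sketch of the microbatch loss averaging done just above:
# scalar entries of each microbatch's loss dict are averaged, tensor entries
# are concatenated. Inert demo, not part of the engine.
def _average_microbatch_losses_demo():
    import torch

    losses_reduced = [{"lm loss": torch.tensor(2.0)}, {"lm loss": torch.tensor(4.0)}]
    loss_reduced = {}
    for key in losses_reduced[0]:
        values = [x[key] for x in losses_reduced]
        if values[0].dim() == 0:
            loss_reduced[key] = sum(values) / len(values)
        else:
            loss_reduced[key] = torch.cat(values)
    assert loss_reduced["lm loss"].item() == 3.0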
|
||||
|
||||
def eval_step(self, **batch_data):
|
||||
"""
|
||||
@ -1017,12 +1086,11 @@ class MegatronEngine(torch.nn.Module):
|
||||
batch_data_iterator = iter(data_chunks)
|
||||
forward_backward_func = get_forward_backward_func()
|
||||
loss_dicts = forward_backward_func(
|
||||
forward_step_func=self.train_step_handler.forward_step,
|
||||
data_iterator=batch_data_iterator,
|
||||
model=self.module,
|
||||
num_microbatches=get_num_microbatches(),
|
||||
seq_length=args.seq_length,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
self.train_step_handler.forward_step,
|
||||
batch_data_iterator,
|
||||
self.module,
|
||||
optimizer=None,
|
||||
timers=None,
|
||||
forward_only=True,
|
||||
)
|
||||
# Empty unused memory
|
||||
@ -1065,9 +1133,6 @@ class MegatronEngine(torch.nn.Module):
|
||||
if self.module[0].training:
|
||||
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)
|
||||
self.iteration += 1
|
||||
batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
|
||||
args.consumed_train_samples += batch_size
|
||||
self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size)
|
||||
if args.tensorboard_dir is not None:
|
||||
# Logging.
|
||||
loss_scale = self.optimizer.get_loss_scale().item()
|
||||
@ -1141,13 +1206,7 @@ class MegatronEngine(torch.nn.Module):
|
||||
args = get_args()
|
||||
args.save = output_dir
|
||||
torch.distributed.barrier()
|
||||
save_checkpoint(
|
||||
self.iteration,
|
||||
self.module,
|
||||
self.optimizer,
|
||||
self.scheduler,
|
||||
num_floating_point_operations_so_far=self.num_floating_point_operations_so_far,
|
||||
)
|
||||
save_checkpoint(self.iteration, self.module, self.optimizer, self.scheduler)
|
||||
torch.distributed.barrier()
|
||||
|
||||
def load_checkpoint(self, input_dir):
|
||||
@ -1156,10 +1215,9 @@ class MegatronEngine(torch.nn.Module):
|
||||
args.consumed_train_samples = 0
|
||||
args.consumed_valid_samples = 0
|
||||
torch.distributed.barrier()
|
||||
iteration, num_floating_point_operations_so_far = load_checkpoint(self.module, self.optimizer, self.scheduler)
|
||||
iteration = load_checkpoint(self.module, self.optimizer, self.scheduler)
|
||||
torch.distributed.barrier()
|
||||
self.iteration = iteration
|
||||
self.num_floating_point_operations_so_far = num_floating_point_operations_so_far
|
||||
if args.fp16 and self.iteration == 0:
|
||||
self.optimizer.reload_model_params()
|
||||
|
||||
|
||||
@ -340,6 +340,11 @@ def set_module_tensor_to_device(
|
||||
and value.data_ptr() in tied_params_map
|
||||
and device in tied_params_map[value.data_ptr()]
|
||||
):
|
||||
print("using value from tied_params_map value")
|
||||
print(tensor_name)
|
||||
print(value)
|
||||
print(value.data_ptr())
|
||||
print(tied_params_map[value.data_ptr()][device])
|
||||
module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
|
||||
return
|
||||
elif (
|
||||
@ -347,6 +352,11 @@ def set_module_tensor_to_device(
|
||||
and old_value.data_ptr() in tied_params_map
|
||||
and device in tied_params_map[old_value.data_ptr()]
|
||||
):
|
||||
print("using value from tied_params_map old_value")
|
||||
print(tensor_name)
|
||||
print(value)
|
||||
print(old_value.data_ptr())
|
||||
print(tied_params_map[old_value.data_ptr()][device])
|
||||
module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
|
||||
return
|
||||
|
||||
@ -466,6 +476,8 @@ def set_module_tensor_to_device(
|
||||
and device not in tied_params_map[old_value.data_ptr()]
|
||||
):
|
||||
tied_params_map[old_value.data_ptr()][device] = new_value
|
||||
print("tied_map updated 1 ")
|
||||
print(tied_params_map)
|
||||
elif (
|
||||
value is not None
|
||||
and tied_params_map is not None
|
||||
@ -473,6 +485,8 @@ def set_module_tensor_to_device(
|
||||
and device not in tied_params_map[value.data_ptr()]
|
||||
):
|
||||
tied_params_map[value.data_ptr()][device] = new_value
|
||||
print("tied_map updated 2")
|
||||
print(tied_params_map)
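# Minimal sketch of the bookkeeping structure the debug prints above inspect.
# `tied_params_map` maps the data_ptr() of an original weight to a per-device
# cache of the tensor already materialized on that device, so tied parameters
# are placed once per device instead of being duplicated. Inert demo with CPU
# standing in for "cuda:0".
def _tied_params_map_demo():
    import torch

    weight = torch.randn(4, 4)
    tied_params_map = {weight.data_ptr(): {}}
    device = torch.device("cpu")  # stands in for "cuda:0" in the real code path
    if device not in tied_params_map[weight.data_ptr()]:
        tied_params_map[weight.data_ptr()][device] = weight.to(device)
    # A second module tied to the same weight reuses the cached tensor:
    reused = tied_params_map[weight.data_ptr()][device]
    assert reused.data_ptr() == weight.data_ptr()  # .to("cpu") on a CPU tensor is a no-op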
|
||||
|
||||
|
||||
def named_module_tensors(
|
||||
|
||||
@ -18,7 +18,7 @@ from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..state import PartialState
|
||||
from ..state import AcceleratorState
|
||||
from .constants import CUDA_DISTRIBUTED_TYPES
|
||||
from .dataclasses import DistributedType, RNGType
|
||||
from .imports import is_mlu_available, is_npu_available, is_torch_xla_available, is_xpu_available
|
||||
@ -41,7 +41,7 @@ def set_seed(seed: int, device_specific: bool = False, deterministic: bool = Fal
|
||||
Whether to use deterministic algorithms where available. Can slow down training.
|
||||
"""
|
||||
if device_specific:
|
||||
seed += PartialState().process_index
|
||||
seed += AcceleratorState().process_index
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
@ -84,7 +84,7 @@ def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optiona
|
||||
rng_state = generator.get_state()
|
||||
|
||||
# Broadcast the rng state from device 0 to other devices
|
||||
state = PartialState()
|
||||
state = AcceleratorState()
|
||||
if state.distributed_type == DistributedType.XLA:
|
||||
rng_state = rng_state.to(xm.xla_device())
|
||||
xm.collective_broadcast([rng_state])
|
||||
|
||||
@ -1,76 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from accelerate import PartialState, Accelerator
|
||||
from accelerate.test_utils.testing import assert_exception
|
||||
from accelerate.utils.dataclasses import DistributedType
|
||||
from accelerate.utils.operations import (
|
||||
DistributedOperationException,
|
||||
broadcast,
|
||||
copy_tensor_to_devices,
|
||||
gather,
|
||||
gather_object,
|
||||
pad_across_processes,
|
||||
reduce,
|
||||
)
|
||||
|
||||
|
||||
def create_tensor(state):
|
||||
return (torch.arange(state.num_processes) + 1.0 + (state.num_processes * state.process_index)).to(state.device)
|
||||
|
||||
|
||||
def test_gather(state):
|
||||
tensor = create_tensor(state)
|
||||
gathered_tensor = gather(tensor)
|
||||
assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1))
|
||||
|
||||
|
||||
def test_gather_object(state):
|
||||
# Gather objects in TorchXLA is not supported.
|
||||
if state.distributed_type == DistributedType.XLA:
|
||||
return
|
||||
obj = [state.process_index]
|
||||
gathered_obj = gather_object(obj)
|
||||
assert len(gathered_obj) == state.num_processes, f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}"
|
||||
assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
|
||||
|
||||
|
||||
def main():
|
||||
accelerator = Accelerator()
|
||||
state = accelerator.state
|
||||
if state.local_process_index == 0:
|
||||
print("**Initialization**")
|
||||
state.wait_for_everyone()
|
||||
|
||||
if state.distributed_type == DistributedType.MULTI_GPU:
|
||||
num_processes_per_node = torch.cuda.device_count()
|
||||
else:
|
||||
num_processes_per_node = state.num_processes
|
||||
|
||||
# We only run this test on non-multinode
|
||||
if state.process_index == 0:
|
||||
print("\n**Test gather operation**")
|
||||
test_gather(state)
|
||||
if state.process_index == 0:
|
||||
print("\n**Test gather_object operation**")
|
||||
test_gather_object(state)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -145,6 +145,50 @@ class ModelWithUnusedSubModulesForTest(nn.Module):
|
||||
return self.linear4(self.linear3(self.batchnorm(self.linear2(self.linear1(x)))))
|
||||
|
||||
|
||||
# To test dispatch with tied weights
|
||||
class SubModule(torch.nn.Module):
|
||||
def __init__(self, ref_to_parameter):
|
||||
super().__init__()
|
||||
self.parameter = ref_to_parameter
|
||||
|
||||
def forward(self, x):
|
||||
return x + torch.max(self.parameter)
|
||||
|
||||
|
||||
class LinearModuleAndSubModule(torch.nn.Linear):
|
||||
def __init__(self, in_features, out_features, name):
|
||||
super().__init__(in_features, out_features, bias=False)
|
||||
print("init weights")
|
||||
self.name = name
|
||||
self.weight_submodule = SubModule(self.weight)
|
||||
self.weight_submodule2 = SubModule(self.weight)
|
||||
self.weight_submodule3 = SubModule(self.weight)
|
||||
self.weight_submodule4 = SubModule(self.weight)
|
||||
|
||||
def forward(self, x):
|
||||
print("weight")
|
||||
print(self.weight)
|
||||
print("name")
|
||||
print(self.name)
|
||||
a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)
|
||||
b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)
|
||||
c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)
|
||||
d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)
|
||||
return a + b + c + d
|
||||
|
||||
|
||||
class ModelWithSubmodules(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.module1 = LinearModuleAndSubModule(5000, 5000, "1")
|
||||
self.module2 = LinearModuleAndSubModule(5000, 5000, "2")
|
||||
|
||||
def forward(self, x):
|
||||
a = self.module1(x)
|
||||
b = self.module2(x)
|
||||
return a + b
|
||||
|
||||
|
||||
class BigModelingTester(unittest.TestCase):
|
||||
def test_init_empty_weights(self):
|
||||
# base use
|
||||
@ -484,53 +528,14 @@ class BigModelingTester(unittest.TestCase):
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
# This test fails because sometimes data_ptr() of compute2.weight is the same as compute1.weight.
|
||||
# I checked that the values are not the same but it gives the same address. This does not happen on my local machine.
|
||||
@require_cuda
|
||||
@unittest.skip(
|
||||
"Flaky test, we should have enough coverage with test_dispatch_model_tied_weights_memory_with_nested_offload_cpu test"
|
||||
)
|
||||
def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self):
|
||||
# Test that we do not duplicate tied weights at any point during dispatch_model call.
|
||||
|
||||
torch.cuda.empty_cache() # Needed in case we run several tests in a row.
|
||||
|
||||
class SubModule(torch.nn.Module):
|
||||
def __init__(self, ref_to_parameter):
|
||||
super().__init__()
|
||||
self.parameter = ref_to_parameter
|
||||
|
||||
def forward(self, x):
|
||||
return x + torch.max(self.parameter)
|
||||
|
||||
class LinearModuleAndSubModule(torch.nn.Linear):
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__(in_features, out_features, bias=False)
|
||||
self.weight_submodule = SubModule(self.weight)
|
||||
self.weight_submodule2 = SubModule(self.weight)
|
||||
self.weight_submodule3 = SubModule(self.weight)
|
||||
self.weight_submodule4 = SubModule(self.weight)
|
||||
|
||||
def forward(self, x):
|
||||
a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)
|
||||
b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)
|
||||
c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)
|
||||
d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)
|
||||
return a + b + c + d
|
||||
|
||||
class ModelWithSubmodules(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.compute = LinearModuleAndSubModule(5000, 5000)
|
||||
self.compute1 = LinearModuleAndSubModule(5000, 5000)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.compute(x)
|
||||
b = self.compute1(x)
|
||||
return a + b
|
||||
|
||||
# We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on the device 0 for the whole model forward, and not 600 MB.
|
||||
device_map = {"compute": 0, "compute1": "disk"}
|
||||
device_map = {"module1": 0, "module2": "disk"}
|
||||
|
||||
model = ModelWithSubmodules()
|
||||
|
||||
@ -550,7 +555,13 @@ class BigModelingTester(unittest.TestCase):
|
||||
|
||||
free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
print("before dispatch")
|
||||
print(model.module1.weight)
|
||||
print(model.module2.weight)
|
||||
dispatch_model(model, device_map, offload_dir=tmp_dir)
|
||||
print("after dispatch")
|
||||
print(model.module1.weight)
|
||||
print(model.module2.weight)
|
||||
free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
|
||||
|
||||
assert (free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130
|
||||
@ -564,7 +575,6 @@ class BigModelingTester(unittest.TestCase):
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
assert torch.allclose(expected, output.cpu(), atol=1e-5)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
@ -573,16 +583,16 @@ class BigModelingTester(unittest.TestCase):
|
||||
|
||||
# Check that we have no more references on GPU for the offloaded tied weight.
|
||||
n_non_empty = 0
|
||||
for pointer, pointer_dict in model.compute1.weight_submodule._hf_hook.tied_params_map.items():
|
||||
for pointer, pointer_dict in model.module1.weight_submodule._hf_hook.tied_params_map.items():
|
||||
if len(pointer_dict) > 0:
|
||||
n_non_empty += 1
|
||||
assert n_non_empty == 1 # `compute` layer one.
|
||||
assert n_non_empty == 1 # `module1` layer one.
|
||||
|
||||
n_non_empty = 0
|
||||
for pointer, pointer_dict in model.compute1._hf_hook.tied_params_map.items():
|
||||
for pointer, pointer_dict in model.module1._hf_hook.tied_params_map.items():
|
||||
if len(pointer_dict) > 0:
|
||||
n_non_empty += 1
|
||||
assert n_non_empty == 1 # `compute` layer one.
|
||||
assert n_non_empty == 1 # `module1` layer one.
|
||||
|
||||
assert (free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130
|
||||
|
||||
|
||||
@ -456,7 +456,7 @@ class ModelEstimatorTester(unittest.TestCase):
|
||||
args = self.parser.parse_args(["bert-base-cased", "--dtypes", "float32", "float16"])
|
||||
output = gather_data(args)
|
||||
# The largest layer and total size of the model in bytes
|
||||
largest_layer, total_size = 90669056, 433249280
|
||||
largest_layer, total_size = 89075712, 433249280
|
||||
# Check that full precision -> int4 is calculating correctly
|
||||
assert len(output) == 2, f"Output was missing a precision, expected 2 but received {len(output)}"
|
||||
|
||||
@ -484,7 +484,7 @@ class ModelEstimatorTester(unittest.TestCase):
|
||||
args = self.parser.parse_args(["bert-base-cased", "--dtypes", "float32"])
|
||||
output = gather_data(args)
|
||||
# The largest layer and total size of the model in bytes
|
||||
largest_layer, total_size = 90669056, 433249280
|
||||
largest_layer, total_size = 89075712, 433249280
|
||||
assert (
|
||||
largest_layer == output[0][1]
|
||||
), f"Calculation for largest layer size in `fp32` is incorrect, expected {largest_layer} but received {output[0][1]}"
|
||||
|
||||
@ -41,7 +41,6 @@ class MultiDeviceTester(unittest.TestCase):
|
||||
data_loop_file_path = path_in_accelerate_package("test_utils", "scripts", "test_distributed_data_loop.py")
|
||||
operation_file_path = path_in_accelerate_package("test_utils", "scripts", "test_ops.py")
|
||||
pippy_file_path = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_pippy.py")
|
||||
merge_weights_file_path = path_in_accelerate_package("test_utils", "scripts", "test_merge_weights.py")
|
||||
|
||||
@require_multi_device
|
||||
def test_multi_device(self):
|
||||
@ -64,13 +63,6 @@ class MultiDeviceTester(unittest.TestCase):
|
||||
with patch_environment(omp_num_threads=1):
|
||||
execute_subprocess_async(cmd)
|
||||
|
||||
@require_multi_device
|
||||
def test_multi_device_merge_fsdp_weights(self):
|
||||
print(f"Found {device_count} devices.")
|
||||
cmd = DEFAULT_LAUNCH_COMMAND + [self.merge_weights_file_path]
|
||||
with patch_environment(omp_num_threads=1):
|
||||
execute_subprocess_async(cmd)
|
||||
|
||||
@require_non_torch_xla
|
||||
@require_multi_gpu
|
||||
def test_distributed_data_loop(self):
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiThreadedTestCase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
from accelerate import Accelerator, PartialState
|
||||
from accelerate.test_utils import device_count
|
||||
|
||||
class TrainingTester(MultiThreadedTestCase):
|
||||
@property
|
||||
def world_size(self):
|
||||
return device_count
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_threads()
|
||||
|
||||
# Verify we are running in multiproc
|
||||
def test_distributed_spawning(self):
|
||||
state = PartialState()
|
||||
assert state.local_process_index == torch.distributed.get_rank()
|
||||
assert state.num_processes == torch.distributed.get_world_size()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -224,8 +224,6 @@ class CometMLTest(unittest.TestCase):
|
||||
if "metric" in j.keys():
|
||||
if j["metric"]["metricName"] == key:
|
||||
return j["metric"]["metricValue"]
|
||||
if j.get("key", None) == key:
|
||||
return j["value"]
|
||||
|
||||
def test_init_trackers(self):
|
||||
with tempfile.TemporaryDirectory() as d: