Mirror of https://github.com/huggingface/accelerate.git
synced 2025-11-12 06:54:28 +08:00
Compare commits
28 Commits
| SHA1 |
|---|
| 499a5e506a |
| e93cb7a3bd |
| de3a54137a |
| 379b3d7c09 |
| b3d181fa6b |
| 93ee98d2b4 |
| adcb68b17d |
| 13656cda38 |
| 5ab0a2d6f7 |
| f7e0c26881 |
| d46d1e85fd |
| 8e61853039 |
| 82971af8c5 |
| 703c702ecb |
| f477f935b6 |
| 2575bc829e |
| 25bf0bcafb |
| b206aad14b |
| 1320618cf7 |
| e67aa2e525 |
| 38c4138de0 |
| cd37674729 |
| 03a5f8870d |
| 68013d06b9 |
| 16d20d7bc9 |
| 1495069ad1 |
| 58d58a1a8d |
| 9a4ba4ab90 |
2 changes: .github/deploy_doc.sh (vendored)

@@ -35,4 +35,4 @@ function deploy_doc(){
# You can find the commit for each tag on https://github.com/huggingface/accelerate/tags
deploy_doc "main" main
deploy_doc "main" # No stable-release yet
deploy_doc "0fbbbc5" # v0.1.0 Latest stable release
17 additions: .github/workflows/quality.yml (vendored, new file)

@@ -0,0 +1,17 @@
```yaml
name: Quality Check

on: [pull_request]

jobs:
  quality:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.6
      uses: actions/setup-python@v2
      with:
        python-version: 3.6
    - name: Install Python dependencies
      run: pip install -e .[quality]
    - name: Run Quality check
      run: make quality
```
17 additions: .github/workflows/test.yml (vendored, new file)

@@ -0,0 +1,17 @@
```yaml
name: Run Tests

on: [pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.6
      uses: actions/setup-python@v2
      with:
        python-version: 3.6
    - name: Install Python dependencies
      run: pip install -e .[test]
    - name: Run Tests
      run: make test
```
186 changes: README.md

@@ -46,86 +46,29 @@ limitations under the License.

## Easy to integrate

🤗 Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boiler code needed to use multi-GPUs/TPU/fp16.
🤗 Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16.

🤗 Accelerate abstracts exactly and only the boiler code related to multi-GPUs/TPU/fp16 and let the rest of your code unchanged.
🤗 Accelerate abstracts exactly and only the boilerplate code related to multi-GPUs/TPU/fp16 and leaves the rest of your code unchanged.

Here is an example:
<table>
|
||||
<tr>
|
||||
<th> Original training code <br> (CPU or mono-GPU only)</th>
|
||||
<th> With Accelerate <br> (CPU/GPU/multi-GPUs/TPUs/fp16) </th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
|
||||
device = 'cpu'
|
||||
|
||||
model = torch.nn.Transformer().to(device)
|
||||
optim = torch.optim.Adam(
|
||||
model.parameters()
|
||||
)
|
||||
|
||||
dataset = load_dataset('my_dataset')
|
||||
data = torch.utils.data.Dataloader(
|
||||
dataset
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
model.train()
|
||||
for epoch in range(10):
|
||||
for source, targets in data:
|
||||
source = source.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(source, targets)
|
||||
loss = F.cross_entropy(
|
||||
output, targets
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
```python
|
||||
```diff
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import load_dataset
|
||||
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
+ accelerator = Accelerator()
|
||||
- device = 'cpu'
|
||||
+ device = accelerator.device
|
||||
|
||||
model = torch.nn.Transformer().to(device)
|
||||
optim = torch.optim.Adam(
|
||||
model.parameters()
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
|
||||
dataset = load_dataset('my_dataset')
|
||||
data = torch.utils.data.Dataloader(
|
||||
dataset
|
||||
)
|
||||
data = torch.utils.data.DataLoader(dataset, shuffle=True)
|
||||
|
||||
+ model, optim, data = accelerator.prepare(
|
||||
+ model, optim, data
|
||||
+ )
|
||||
+ model, optim, data = accelerator.prepare(model, optim, data)
|
||||
|
||||
model.train()
|
||||
for epoch in range(10):
|
||||
@ -135,126 +78,59 @@ for epoch in range(10):
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(source, targets)
|
||||
loss = F.cross_entropy(
|
||||
output, targets
|
||||
)
|
||||
output = model(source)
|
||||
loss = F.cross_entropy(output, targets)
|
||||
|
||||
+ accelerate.backward(loss)
|
||||
+ accelerator.backward(loss)
|
||||
- loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
As you can see in this example, by adding 5-lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp16).
As you can see on this example, by adding 5-lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp16).

In particular, the same code can then be run without modification on your local machine for debugging or your training environment.
The same code can then in particular run without modification on your local machine for debugging or your training environment.

🤗 Accelerate even handles the device placement for you (which requires a few more changes to your code, but is safer in general), so you can even simplify your training loop further:
🤗 Accelerate even handles the device placement for you (a bit more changes to your code but safer in general), so you can even simplify your training loop further:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th> Original training code <br> (CPU or mono-GPU only)</th>
|
||||
<th> With Accelerate <br> (CPU/GPU/multi-GPUs/TPUs/fp16) </th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
|
||||
device = 'cpu'
|
||||
|
||||
model = torch.nn.Transformer().to(device)
|
||||
optim = torch.optim.Adam(
|
||||
model.parameters()
|
||||
)
|
||||
|
||||
dataset = load_dataset('my_dataset')
|
||||
data = torch.utils.data.Dataloader(
|
||||
dataset
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
model.train()
|
||||
for epoch in range(10):
|
||||
for source, targets in data:
|
||||
source = source.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(source, targets)
|
||||
loss = F.cross_entropy(
|
||||
output, targets
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
```python
|
||||
```diff
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import load_dataset
|
||||
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
+ accelerator = Accelerator()
|
||||
+ device = accelerator.device
|
||||
- device = 'cpu'
|
||||
|
||||
+ model = torch.nn.Transformer()
|
||||
optim = torch.optim.Adam(
|
||||
model.parameters()
|
||||
)
|
||||
- model = torch.nn.Transformer().to(device)
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
|
||||
dataset = load_dataset('my_dataset')
|
||||
data = torch.utils.data.Dataloader(
|
||||
dataset
|
||||
)
|
||||
data = torch.utils.data.DataLoader(dataset, shuffle=True)
|
||||
|
||||
+ model, optim, data = accelerator.prepare(
|
||||
+ model, optim, data
|
||||
+ )
|
||||
+ model, optim, data = accelerator.prepare(model, optim, data)
|
||||
|
||||
model.train()
|
||||
for epoch in range(10):
|
||||
for source, targets in data:
|
||||
-
|
||||
-
|
||||
- source = source.to(device)
|
||||
- targets = targets.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(source, targets)
|
||||
loss = F.cross_entropy(
|
||||
output, targets
|
||||
)
|
||||
output = model(source)
|
||||
loss = F.cross_entropy(output, targets)
|
||||
|
||||
+ accelerate.backward(loss)
|
||||
+ accelerator.backward(loss)
|
||||
- loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Launching script

🤗 Accelerate also provides a CLI tool that allows you to quickly configure and test your training environment then launch the scripts. No need to remember how to use `torch.distributed.launch` or to write a specific launcher for TPU training!
🤗 Accelerate also provides an optional CLI tool that allows you to quickly configure and test your training environment before launching the scripts. No need to remember how to use `torch.distributed.launch` or to write a specific launcher for TPU training!

On your machine(s) just run:

```bash

@@ -270,14 +146,14 @@ accelerate launch my_script.py --args_to_my_script

For instance, here is how you would run the GLUE example on the MRPC task (from the root of the repo):

```bash
accelerate launch examples/glue_example.py --task_name mrpc --model_name_or_path bert-base-cased
accelerate launch examples/nlp_example.py
```

## Why should I use 🤗 Accelerate?

You should use 🤗 Accelerate when you want to easily run your training scripts in a distributed environment without having to renounce full control over your training loop. This is not a high-level framework above PyTorch, just a thin wrapper so you don't have to learn a new library. In fact, the whole API of 🤗 Accelerate is in one class, the `Accelerator` object.

## Why shouldn't use 🤗 Accelerate?
## Why shouldn't I use 🤗 Accelerate?

You shouldn't use 🤗 Accelerate if you don't want to write a training loop yourself. There are plenty of high-level libraries above PyTorch that will offer you that, 🤗 Accelerate is not one of them.
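Concretely, since the whole API lives in that single `Accelerator` class, a typical conversion boils down to a handful of calls. Here is a minimal sketch; the `model`, `optimizer`, `dataloader` and `loss_function` objects are placeholders, as in the snippets above.

```python
# Minimal sketch of the Accelerator-based training loop shown above.
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for source, targets in dataloader:
    optimizer.zero_grad()
    output = model(source)
    loss = loss_function(output, targets)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
```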
@ -1,10 +1,10 @@
|
||||
// These two things need to be updated at each release for the version selector.
|
||||
// Last stable version
|
||||
const stableVersion = "v0.0.1"
|
||||
const stableVersion = "v0.1.0"
|
||||
// Dictionary doc folder to label. The last stable version should have an empty key.
|
||||
const versionMapping = {
|
||||
"main": "main",
|
||||
"": "v0.0.1 (stable)",
|
||||
"": "v0.1.0 (stable)",
|
||||
}
|
||||
|
||||
function addIcon() {
|
||||
|
||||
@ -14,17 +14,18 @@
|
||||
#
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.abspath('../../src'))
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../src"))
|
||||
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = u'accelerate'
|
||||
copyright = u'2020, The Hugging Face Team, Licenced under the Apache License, Version 2.0'
|
||||
author = u'huggingface'
|
||||
project = "accelerate"
|
||||
copyright = "2020, The Hugging Face Team, Licenced under the Apache License, Version 2.0"
|
||||
author = "huggingface"
|
||||
|
||||
# The short X.Y version
|
||||
version = u'0.1.0'
|
||||
version = "0.2.0"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
@ -36,27 +37,28 @@ version = u'0.1.0'
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.extlinks',
|
||||
'sphinx.ext.coverage',
|
||||
'sphinx.ext.napoleon',
|
||||
'recommonmark',
|
||||
'sphinx.ext.viewcode',
|
||||
'sphinx_markdown_tables',
|
||||
'sphinx_copybutton'
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.extlinks",
|
||||
"sphinx.ext.coverage",
|
||||
"sphinx.ext.napoleon",
|
||||
"recommonmark",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx_markdown_tables",
|
||||
"sphinx_copybutton",
|
||||
"sphinxext.opengraph",
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
templates_path = ["_templates"]
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
# You can specify multiple suffix as a list of string:
|
||||
#
|
||||
source_suffix = ['.rst', '.md']
|
||||
source_suffix = [".rst", ".md"]
|
||||
# source_suffix = '.rst'
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = 'index'
|
||||
master_doc = "index"
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
@ -68,7 +70,7 @@ language = None
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store']
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = None
|
||||
@ -82,20 +84,31 @@ copybutton_prompt_is_regexp = True
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
# further. For a list of options available for each theme, see the
|
||||
# documentation.
|
||||
#
|
||||
html_theme_options = {
|
||||
'analytics_id': 'UA-83738774-2'
|
||||
}
|
||||
html_theme_options = {"analytics_id": "UA-83738774-2"}
|
||||
|
||||
# Configuration for OpenGraph and Twitter Card Tags.
|
||||
# These are responsible for creating nice shareable social images https://ahrefs.com/blog/open-graph-meta-tags/
|
||||
# https://ogp.me/#type_website
|
||||
ogp_image = "https://huggingface.co/front/thumbnails/docs/accelerate.png"
|
||||
ogp_description = "Run your raw PyTorch training script on any kind of device. 🤗 Accelerate provides an easy API to make your scripts run with mixed precision and on any kind of distributed setting (multi-GPUs, TPUs etc.)"
|
||||
ogp_description_length = 160
|
||||
|
||||
ogp_custom_meta_tags = [
|
||||
f'<meta name="twitter:image" content="{ogp_image}">',
|
||||
f'<meta name="twitter:description" content="{ogp_description}">',
|
||||
]
|
||||
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
html_static_path = ["_static"]
|
||||
|
||||
# Custom sidebar templates, must be a dictionary that maps document names
|
||||
# to template names.
|
||||
@ -107,17 +120,17 @@ html_static_path = ['_static']
|
||||
#
|
||||
# html_sidebars = {}
|
||||
|
||||
# This must be the name of an image file (path relative to the configuration
|
||||
# directory) that is the favicon of the docs. Modern browsers use this as
|
||||
# the icon for tabs, windows and bookmarks. It should be a Windows-style
|
||||
# This must be the name of an image file (path relative to the configuration
|
||||
# directory) that is the favicon of the docs. Modern browsers use this as
|
||||
# the icon for tabs, windows and bookmarks. It should be a Windows-style
|
||||
# icon file (.ico).
|
||||
html_favicon = 'favicon.ico'
|
||||
html_favicon = "favicon.ico"
|
||||
|
||||
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'acceleratedoc'
|
||||
htmlhelp_basename = "acceleratedoc"
|
||||
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
@ -126,15 +139,12 @@ latex_elements = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#
|
||||
# 'papersize': 'letterpaper',
|
||||
|
||||
# The font size ('10pt', '11pt' or '12pt').
|
||||
#
|
||||
# 'pointsize': '10pt',
|
||||
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#
|
||||
# 'preamble': '',
|
||||
|
||||
# Latex figure (float) alignment
|
||||
#
|
||||
# 'figure_align': 'htbp',
|
||||
@ -144,8 +154,7 @@ latex_elements = {
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, 'accelerate.tex', u'accelerate Documentation',
|
||||
u'huggingface', 'manual'),
|
||||
(master_doc, "accelerate.tex", "accelerate Documentation", "huggingface", "manual"),
|
||||
]
|
||||
|
||||
|
||||
@ -153,10 +162,7 @@ latex_documents = [
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [
|
||||
(master_doc, 'accelerate', u'accelerate Documentation',
|
||||
[author], 1)
|
||||
]
|
||||
man_pages = [(master_doc, "accelerate", "accelerate Documentation", [author], 1)]
|
||||
|
||||
|
||||
# -- Options for Texinfo output ----------------------------------------------
|
||||
@ -165,9 +171,15 @@ man_pages = [
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, 'accelerate', u'accelerate Documentation',
|
||||
author, 'accelerate', 'One line description of project.',
|
||||
'Miscellaneous'),
|
||||
(
|
||||
master_doc,
|
||||
"accelerate",
|
||||
"accelerate Documentation",
|
||||
author,
|
||||
"accelerate",
|
||||
"One line description of project.",
|
||||
"Miscellaneous",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@ -186,11 +198,13 @@ epub_title = project
|
||||
# epub_uid = ''
|
||||
|
||||
# A list of files that should not be packed into the epub file.
|
||||
epub_exclude_files = ['search.html']
|
||||
epub_exclude_files = ["search.html"]
|
||||
|
||||
|
||||
def setup(app):
|
||||
app.add_css_file('css/huggingface.css')
|
||||
app.add_css_file('css/code-snippets.css')
|
||||
app.add_js_file('js/custom.js')
|
||||
app.add_css_file("css/huggingface.css")
|
||||
app.add_css_file("css/code-snippets.css")
|
||||
app.add_js_file("js/custom.js")
|
||||
|
||||
|
||||
# -- Extension configuration -------------------------------------------------
|
||||
|
||||
@ -47,53 +47,58 @@ A traditional training loop in PyTorch looks like this:
|
||||
|
||||
Changing it to work with accelerate is really easy and only adds a few lines of code:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: diff
|
||||
|
||||
from accelerate import Accelerator
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator()
|
||||
# Use the device given by the `accelerator` object.
|
||||
device = accelerator.device
|
||||
my_model.to(device)
|
||||
# Pass every important object (model, optimizer, dataloader) to `accelerator.prepare`
|
||||
my_model, my_optimizer, my_training_dataloader = accelerate.prepare(
|
||||
my_model, my_optimizer, my_training_dataloader
|
||||
)
|
||||
+ accelerator = Accelerator()
|
||||
# Use the device given by the `accelerator` object.
|
||||
+ device = accelerator.device
|
||||
my_model.to(device)
|
||||
# Pass every important object (model, optimizer, dataloader) to `accelerator.prepare`
|
||||
+ my_model, my_optimizer, my_training_dataloader = accelerate.prepare(
|
||||
+ my_model, my_optimizer, my_training_dataloader
|
||||
+ )
|
||||
|
||||
for batch in my_training_dataloader:
|
||||
my_optimizer.zero_grad()
|
||||
inputs, targets = batch
|
||||
inputs = inputs.to(device)
|
||||
targets = targets.to(device)
|
||||
outputs = my_model(inputs)
|
||||
loss = my_loss_function(outputs, targets)
|
||||
# Just a small change for the backward instruction
|
||||
accelerate.backward(loss)
|
||||
my_optimizer.step()
|
||||
for batch in my_training_dataloader:
|
||||
my_optimizer.zero_grad()
|
||||
inputs, targets = batch
|
||||
inputs = inputs.to(device)
|
||||
targets = targets.to(device)
|
||||
outputs = my_model(inputs)
|
||||
loss = my_loss_function(outputs, targets)
|
||||
# Just a small change for the backward instruction
|
||||
- loss.backward()
|
||||
+ accelerate.backward(loss)
|
||||
my_optimizer.step()
|
||||
|
||||
and with this, your script can now run in a distributed environment (multi-GPU, TPU).
|
||||
|
||||
You can even simplify your script a bit by letting 🤗 Accelerate handle the device placement for you (which is safer,
|
||||
especially for TPU training):
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: diff
|
||||
|
||||
from accelerate import Accelerator
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator()
|
||||
# Pass every important object (model, optimizer, dataloader) to `accelerator.prepare`
|
||||
my_model, my_optimizer, my_training_dataloader = accelerate.prepare(
|
||||
my_model, my_optimizer, my_training_dataloader
|
||||
)
|
||||
+ accelerator = Accelerator()
|
||||
- my_model.to(device)
|
||||
# Pass every important object (model, optimizer, dataloader) to `accelerator.prepare`
|
||||
+ my_model, my_optimizer, my_training_dataloader = accelerate.prepare(
|
||||
+ my_model, my_optimizer, my_training_dataloader
|
||||
+ )
|
||||
|
||||
for batch in my_training_dataloader:
|
||||
my_optimizer.zero_grad()
|
||||
inputs, targets = batch
|
||||
outputs = my_model(inputs)
|
||||
loss = my_loss_function(outputs, targets)
|
||||
# Just a small change for the backward instruction
|
||||
accelerate.backward(loss)
|
||||
my_optimizer.step()
|
||||
for batch in my_training_dataloader:
|
||||
my_optimizer.zero_grad()
|
||||
inputs, targets = batch
|
||||
- inputs = inputs.to(device)
|
||||
- targets = targets.to(device)
|
||||
outputs = my_model(inputs)
|
||||
loss = my_loss_function(outputs, targets)
|
||||
# Just a small change for the backward instruction
|
||||
- loss.backward()
|
||||
+ accelerate.backward(loss)
|
||||
my_optimizer.step()
|
||||
|
||||
|
||||
Script launcher
|
||||
@ -139,10 +144,16 @@ Supported integrations
|
||||
quicktour
|
||||
installation
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Guides
|
||||
|
||||
sagemaker
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: API reference
|
||||
|
||||
accelerator
|
||||
kwargs
|
||||
internal
|
||||
|
||||
@ -78,6 +78,8 @@ Utilities
|
||||
|
||||
.. autofunction:: accelerate.utils.set_seed
|
||||
|
||||
.. autofunction:: accelerate.utils.synchronize_rng_state
|
||||
|
||||
.. autofunction:: accelerate.utils.synchronize_rng_states
|
||||
|
||||
.. autofunction:: accelerate.utils.wait_for_everyone
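As a quick illustration of the utilities above, here is a sketch of how ``set_seed`` and ``wait_for_everyone`` are typically combined in a training script; the model and the save path are placeholders, not part of the documented API.

.. code-block:: python

    # Sketch: seed every process identically, then make sure all processes have
    # finished training before saving anything.
    from accelerate import Accelerator
    from accelerate.utils import set_seed, wait_for_everyone

    accelerator = Accelerator()
    set_seed(42)  # seeds the Python, NumPy and torch random number generators

    # ... build the dataloaders, model and optimizer, then train ...

    wait_for_everyone()  # block until every process reaches this point
    accelerator.save(model.state_dict(), "model.pt")  # placeholder model and path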
|
||||
|
||||
30 additions: docs/source/kwargs.rst (new file)

@@ -0,0 +1,30 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.


Kwargs Handlers
=======================================================================================================================

The following objects can be passed to the main :class:`~accelerate.Accelerator` to customize how some PyTorch objects
related to distributed training or mixed precision are created.


DistributedDataParallelKwargs
-----------------------------------------------------------------------------------------------------------------------

.. autoclass:: accelerate.DistributedDataParallelKwargs


GradScalerKwargs
-----------------------------------------------------------------------------------------------------------------------

.. autoclass:: accelerate.GradScalerKwargs
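A minimal sketch of how one of these handlers might be passed to the :class:`~accelerate.Accelerator`; the ``kwargs_handlers`` argument name and the ``find_unused_parameters`` field are assumptions here, so check the rendered class documentation for the exact signature.

.. code-block:: python

    # Sketch: customize how the model is wrapped in DistributedDataParallel.
    from accelerate import Accelerator, DistributedDataParallelKwargs

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])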
@@ -130,6 +130,13 @@ do with the :meth:`~accelerate.Accelerator.gather` method.

Any instruction using your training dataloader length (for instance if you need the number of total training steps
to create a learning rate scheduler) should go after the call to :meth:`~accelerate.Accelerator.prepare`.

.. Warning::

    The :meth:`~accelerate.Accelerator.gather` method requires the tensors to be all the same size on each process. If
    you have tensors of different sizes on each process (for instance when dynamically padding to the maximum length in
    a batch), you should use the :meth:`~accelerate.Accelerator.pad_across_processes` method to pad your tensor to the
    biggest size across processes.
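For instance, a token-classification style evaluation loop that combines the two methods might look like the sketch below; the ``model``, ``eval_dataloader`` and ``metric`` objects, the pad index of ``-100`` and the exact keyword arguments of ``pad_across_processes`` are assumptions for illustration only.

.. code-block:: python

    # Sketch: gather tensors whose sequence length may differ across processes.
    import torch

    model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        # Pad to the largest size across processes before gathering, otherwise
        # gathering would fail on mismatched shapes.
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
        labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=-100)
        metric.add_batch(
            predictions=accelerator.gather(predictions),
            references=accelerator.gather(labels),
        )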
|
||||
|
||||
|
||||
Launching your distributed script
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
@ -207,7 +214,8 @@ lof of time. In practice, that means you must take special care to have all your
|
||||
shape (so no dynamic padding for instance if you are in an NLP problem) and should not use layer with for loops that
|
||||
have different lengths depending on the inputs (such as an LSTM) or the training will be excruciatingly slow.
|
||||
|
||||
To introduce special behavior in your script for TPUs you can check the :obj:`distributed_type` of your :obj:`accelerator`:
|
||||
To introduce special behavior in your script for TPUs you can check the :obj:`distributed_type` of your
|
||||
:obj:`accelerator`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -340,20 +348,30 @@ library handles the sharding of your data between processes by changing that :ob
|
||||
|
||||
The :class:`~accelerate.data_loader.DataLoaderShard` subclasses :obj:`DataLoader` to add the following functionality:
|
||||
|
||||
- it synchronizes the torch random number generators of all processes at each new iteration, to ensure any
|
||||
- it synchronizes the appropriate random number generator of all processes at each new iteration, to ensure any
|
||||
randomization (like shuffling) is done the exact same way across processes.
|
||||
- it puts the batches on the proper device before yielding them (unless you have opted out of
|
||||
:obj:`device_placement=True`).
|
||||
|
||||
The random number generator synchronization will by default synchronize:
|
||||
|
||||
- the :obj:`generator` attribute of a given sampler (like the PyTorch :obj:`RandomSampler`) for PyTorch >= 1.6
|
||||
- the main random number generator in PyTorch <=1.5.1
|
||||
|
||||
You can choose which random number generator(s) to synchronize with the :obj:`rng_types` argument of the main
|
||||
:class:`~accelerate.Accelerator`. In PyTorch >= 1.6, it is recommended to rely on local :obj:`generator` to avoid
|
||||
setting the same seed in the main random number generator in all processes.
|
||||
|
||||
.. Warning::
|
||||
|
||||
The random number generator synchronization will affect any other potential random artifacts you could have in your
|
||||
dataset (like random data augmentation) in the sense all processes will get the same random numbers from the torch
|
||||
random modules (so will apply the same random data augmentation if it's controlled by torch). While this is usually
|
||||
fine, you should use the random number generator from the Python :obj:`random` module or NumPy for your data
|
||||
augmentation if you think this will be a problem.
|
||||
Synchronization the main torch (or CUDA or XLA) random number generator will affect any other potential random
|
||||
artifacts you could have in your dataset (like random data augmentation) in the sense all processes will get the
|
||||
same random numbers from the torch random modules (so will apply the same random data augmentation if it's
|
||||
controlled by torch).
|
||||
|
||||
The randomization part of your sampler on the other hand should absolutely be done using the torch random number
|
||||
generator (like in the traditional :obj:`RandomSampler`).
|
||||
.. Note::
|
||||
|
||||
The randomization part of your custom sampler, batch sampler or iterable dataset should be done using a local
|
||||
:obj:`torch.Generator` object (in PyTorch >= 1.6), see the traditional :obj:`RandomSampler`, as an example.
|
||||
|
||||
See more details about the internals in the :doc:`Internals page <internal>`.
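A small sketch of the recommendation above, giving the sampler its own local generator and only synchronizing that generator; the string value ``"generator"`` for :obj:`rng_types` and the ``train_dataset`` object are assumptions, see the :class:`~accelerate.Accelerator` documentation for the accepted values.

.. code-block:: python

    # Sketch: rely on a local torch.Generator (PyTorch >= 1.6) instead of the
    # global torch random number generator.
    import torch
    from torch.utils.data import DataLoader, RandomSampler
    from accelerate import Accelerator

    generator = torch.Generator().manual_seed(42)
    sampler = RandomSampler(train_dataset, generator=generator)
    train_dataloader = DataLoader(train_dataset, sampler=sampler, batch_size=16)

    # Only synchronize the sampler's local generator across processes.
    accelerator = Accelerator(rng_types=["generator"])
    train_dataloader = accelerator.prepare(train_dataloader)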
|
||||
|
||||
169
docs/source/sagemaker.rst
Normal file
169
docs/source/sagemaker.rst
Normal file
@ -0,0 +1,169 @@
|
||||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
Amazon SageMaker
|
||||
=======================================================================================================================
|
||||
|
||||
Hugging Face and Amazon introduced new `Hugging Face Deep Learning Containers (DLCs)
|
||||
<https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-training-containers>`_ to
|
||||
make it easier than ever to train Hugging Face Transformer models in `Amazon SageMaker
|
||||
<https://aws.amazon.com/sagemaker/>`_.
|
||||
|
||||
This guide shows how to use the new 🤗 DLCs with Amazon SageMaker to run your 🤗 Accelerate scripts and raw training loops.
|
||||
|
||||
|
||||
|
||||
Getting Started
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Setup & Installation
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Before you can run your 🤗 Accelerate scripts on Amazon SageMaker you need to sign up for an AWS account. If you do not
|
||||
have an AWS account yet learn more `here <https://docs.aws.amazon.com/sagemaker/latest/dg/gs-set-up.html>`__.
|
||||
|
||||
After you have your AWS account, you need to install the ``sagemaker`` SDK for 🤗 Accelerate with:
|
||||
|
||||
.. code-block::
|
||||
|
||||
pip install "accelerate[sagemaker]" --upgrade
|
||||
|
||||
|
||||
🤗 Accelerate currently uses the 🤗 DLCs, with ``transformers``, ``datasets`` and ``tokenizers`` pre-installed. 🤗
|
||||
Accelerate is not in the DLC yet (will soon be added!) so to use it within Amazon SageMaker you need to create a
|
||||
``requirements.txt`` in the same directory where your training script is located and add it as dependency.
|
||||
|
||||
.. code-block::
|
||||
|
||||
accelerate
|
||||
|
||||
You should also add any other dependencies you have to this ``requirements.txt``.
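For instance, running the computer-vision example from this repository on SageMaker might use a ``requirements.txt`` like the one below; the extra entries are only an illustration based on that example's stated dependencies, not something this guide prescribes.

.. code-block:: text

    accelerate
    timm
    torchvision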
|
||||
|
||||
|
||||
Configure 🤗 Accelerate
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
You can configure the launch configuration for Amazon SageMaker the same as you do for non SageMaker training jobs with
|
||||
the 🤗 Accelerate CLI.
|
||||
|
||||
.. code-block::
|
||||
|
||||
accelerate config
|
||||
# In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): 1
|
||||
|
||||
|
||||
🤗 Accelerate will go through a questionnaire about your Amazon SageMaker setup and create a config file you can edit.
|
||||
|
||||
.. note::
|
||||
🤗 Accelerate does not save any of your credentials.
|
||||
|
||||
|
||||
Prepare a 🤗 Accelerate fine-tuning script
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The training script is very similar to a training script you might run outside of SageMaker, but to save your model
|
||||
after training you need to specify either ``/opt/ml/model`` or use ``os.environ["SM_MODEL_DIR"]`` as your save
|
||||
directory. After training, artifacts in this directory are uploaded to S3.
|
||||
|
||||
|
||||
.. code-block:: diff
|
||||
|
||||
- torch.save('/opt/ml/model`)
|
||||
+ accelerator.save('/opt/ml/model')
|
||||
|
||||
|
||||
.. warning::
|
||||
SageMaker doesn’t support argparse actions. If you want to use, for example, boolean hyperparameters, you need to
|
||||
specify type as bool in your script and provide an explicit True or False value for this hyperparameter. `[REF]
|
||||
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#prepare-a-pytorch-training-script>`__.
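A sketch of what that looks like in practice, mirroring the change made to ``nlp_example.py`` in this same set of commits; the script name in the launch comment is a placeholder.

.. code-block:: python

    # Sketch: boolean hyperparameters declared with an explicit type instead of
    # `action="store_true"`, so they can be passed as SageMaker hyperparameters.
    import argparse

    parser = argparse.ArgumentParser(description="Simple example of training script.")
    parser.add_argument("--fp16", type=bool, default=False, help="If passed, will use FP16 training.")
    parser.add_argument("--cpu", type=bool, default=False, help="If passed, will train on the CPU.")
    args = parser.parse_args()

    # Launched with explicit values, e.g.:
    #   accelerate launch train_script.py --fp16 True --cpu False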
|
||||
|
||||
|
||||
Launch Training
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
You can launch your training with 🤗 Accelerate CLI with
|
||||
|
||||
.. code-block::
|
||||
|
||||
accelerate launch path_to_script.py --args_to_the_script
|
||||
|
||||
|
||||
This will launch your training script using your configuration. The only thing you have to do is provide all the
|
||||
arguments needed by your training script as named arguments.
|
||||
|
||||
**Examples**
|
||||
|
||||
.. note::
|
||||
If you run one of the example scripts, don't forget to add ``accelerator.save('/opt/ml/model')`` to it.
|
||||
|
||||
.. code-block::
|
||||
|
||||
accelerate launch ./examples/sagemaker_example.py
|
||||
|
||||
|
||||
Outputs:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Configuring Amazon SageMaker environment
|
||||
Converting Arguments to Hyperparameters
|
||||
Creating Estimator
|
||||
2021-04-08 11:56:50 Starting - Starting the training job...
|
||||
2021-04-08 11:57:13 Starting - Launching requested ML instancesProfilerReport-1617883008: InProgress
|
||||
.........
|
||||
2021-04-08 11:58:54 Starting - Preparing the instances for training.........
|
||||
2021-04-08 12:00:24 Downloading - Downloading input data
|
||||
2021-04-08 12:00:24 Training - Downloading the training image..................
|
||||
2021-04-08 12:03:39 Training - Training image download completed. Training in progress..
|
||||
........
|
||||
epoch 0: {'accuracy': 0.7598039215686274, 'f1': 0.8178438661710037}
|
||||
epoch 1: {'accuracy': 0.8357843137254902, 'f1': 0.882249560632689}
|
||||
epoch 2: {'accuracy': 0.8406862745098039, 'f1': 0.8869565217391304}
|
||||
........
|
||||
2021-04-08 12:05:40 Uploading - Uploading generated training model
|
||||
2021-04-08 12:05:40 Completed - Training job completed
|
||||
Training seconds: 331
|
||||
Billable seconds: 331
|
||||
You can find your model data at: s3://your-bucket/accelerate-sagemaker-1-2021-04-08-11-56-47-108/output/model.tar.gz
|
||||
|
||||
|
||||
|
||||
Advanced Features
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Distributed Training: Data Parallelism
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
*currently in development, will be supported soon.*
|
||||
|
||||
Distributed Training: Model Parallelism
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
*currently in development, will be supported soon.*
|
||||
|
||||
Python packages and dependencies
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
🤗 Accelerate currently uses the 🤗 DLCs, with ``transformers``, ``datasets`` and ``tokenizers`` pre-installed. If you
|
||||
want to use different/other Python packages you can do this by adding them to the ``requirements.txt``. These packages
|
||||
will be installed before your training script is started.
|
||||
|
||||
Remote scripts: Use scripts located on Github
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
*undecided if feature is needed. Contact us if you would like this feature.*
|
||||
|
||||
Use Spot Instances
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
*undecided if feature is needed. Contact us if you would like this feature.*
|
||||
@ -18,7 +18,13 @@ limitations under the License.
|
||||
|
||||
## Simple NLP example
|
||||
|
||||
The [nlp_example.py](./nlp_example.py) script is a simple example to train a Bert model on a classification task ([GLUE's MRPC]()).
|
||||
The [nlp_example.py](./nlp_example.py) script is a simple example to train a Bert model on a classification task ([GLUE's MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398)).
|
||||
|
||||
Prior to running it you should install 🤗 Dataset and 🤗 Transformers:
|
||||
|
||||
```bash
|
||||
pip install datasets, transformers
|
||||
```
|
||||
|
||||
The same script can be run in any of the following configurations:
|
||||
- single CPU or single GPU
|
||||
@ -89,3 +95,91 @@ To run it in each of these various modes, use the following commands:
|
||||
```
|
||||
* In PyTorch:
|
||||
Add an `xmp.spawn` line in your script as you usually do.
|
||||
|
||||
|
||||
## Simple vision example
|
||||
|
||||
The [cv_example.py](./cv_example.py) script is a simple example to fine-tune a ResNet-50 on a classification task ([Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/)).
|
||||
|
||||
The same script can be run in any of the following configurations:
|
||||
- single CPU or single GPU
|
||||
- multi GPUS (using PyTorch distributed mode)
|
||||
- (multi) TPUs
|
||||
- fp16 (mixed-precision) or fp32 (normal precision)
|
||||
|
||||
Prior to running it you should install timm and torchvision:
|
||||
|
||||
```bash
|
||||
pip install timm, torchvision
|
||||
```
|
||||
|
||||
and you should download the data with the following commands:
|
||||
|
||||
```bash
|
||||
wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
|
||||
tar -xzf images.tar.gz
|
||||
```
|
||||
|
||||
To run it in each of these various modes, use the following commands:
|
||||
- single CPU:
|
||||
* from a server without GPU
|
||||
```bash
|
||||
python ./cv_example.py --data_dir path_to_data
|
||||
```
|
||||
* from any server by passing `cpu=True` to the `Accelerator`.
|
||||
```bash
|
||||
python ./cv_example.py --data_dir path_to_data --cpu
|
||||
```
|
||||
* from any server with Accelerate launcher
|
||||
```bash
|
||||
accelerate launch --cpu ./cv_example.py --data_dir path_to_data
|
||||
```
|
||||
- single GPU:
|
||||
```bash
|
||||
python ./nlp_example.py # from a server with a GPU
|
||||
```
|
||||
- with fp16 (mixed-precision)
|
||||
* from any server by passing `fp16=True` to the `Accelerator`.
|
||||
```bash
|
||||
python ./cv_example.py --data_dir path_to_data --fp16
|
||||
```
|
||||
* from any server with Accelerate launcher
|
||||
```bash
|
||||
accelerate launch --fb16 ./cv_example.py --data_dir path_to_data
|
||||
- multi GPUS (using PyTorch distributed mode)
|
||||
* With Accelerate config and launcher
|
||||
```bash
|
||||
accelerate config # This will create a config file on your server
|
||||
accelerate launch ./cv_example.py --data_dir path_to_data # This will run the script on your server
|
||||
```
|
||||
* With traditional PyTorch launcher
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node 2 --use_env ./cv_example.py --data_dir path_to_data
|
||||
```
|
||||
- multi GPUs, multi node (several machines, using PyTorch distributed mode)
|
||||
* With Accelerate config and launcher, on each machine:
|
||||
```bash
|
||||
accelerate config # This will create a config file on each server
|
||||
accelerate launch ./cv_example.py --data_dir path_to_data # This will run the script on each server
|
||||
```
|
||||
* With PyTorch launcher only
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node 2 \
|
||||
--use_env \
|
||||
--node_rank 0 \
|
||||
--master_addr master_node_ip_address \
|
||||
./cv_example.py --data_dir path_to_data # On the first server
|
||||
python -m torch.distributed.launch --nproc_per_node 2 \
|
||||
--use_env \
|
||||
--node_rank 1 \
|
||||
--master_addr master_node_ip_address \
|
||||
./cv_example.py --data_dir path_to_data # On the second server
|
||||
```
|
||||
- (multi) TPUs
|
||||
* With Accelerate config and launcher
|
||||
```bash
|
||||
accelerate config # This will create a config file on your TPU server
|
||||
accelerate launch ./cv_example.py --data_dir path_to_data # This will run the script on each server
|
||||
```
|
||||
* In PyTorch:
|
||||
Add an `xmp.spawn` line in your script as you usually do.
|
||||
|
||||
197
examples/cv_example.py
Normal file
197
examples/cv_example.py
Normal file
@ -0,0 +1,197 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from timm import create_model
|
||||
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
|
||||
|
||||
|
||||
########################################################################
|
||||
# This is a fully working simple example to use Accelerate
|
||||
#
|
||||
# This example trains a ResNet50 on the Oxford-IIT Pet Dataset
|
||||
# in any of the following settings (with the same script):
|
||||
# - single CPU or single GPU
|
||||
# - multi GPUS (using PyTorch distributed mode)
|
||||
# - (multi) TPUs
|
||||
# - fp16 (mixed-precision) or fp32 (normal precision)
|
||||
#
|
||||
# To run it in each of these various modes, follow the instructions
|
||||
# in the readme for examples:
|
||||
# https://github.com/huggingface/accelerate/tree/main/examples
|
||||
#
|
||||
########################################################################
|
||||
|
||||
|
||||
# Function to get the label from the filename
|
||||
def extract_label(fname):
|
||||
stem = fname.split(os.path.sep)[-1]
|
||||
return re.search(r"^(.*)_\d+\.jpg$", stem).groups()[0]
|
||||
|
||||
|
||||
class PetsDataset(Dataset):
|
||||
def __init__(self, file_names, image_transform=None, label_to_id=None):
|
||||
self.file_names = file_names
|
||||
self.image_transform = image_transform
|
||||
self.label_to_id = label_to_id
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_names)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
fname = self.file_names[idx]
|
||||
raw_image = PIL.Image.open(fname)
|
||||
image = raw_image.convert("RGB")
|
||||
if self.image_transform is not None:
|
||||
image = self.image_transform(image)
|
||||
label = extract_label(fname)
|
||||
if self.label_to_id is not None:
|
||||
label = self.label_to_id[label]
|
||||
return {"image": image, "label": label}
|
||||
|
||||
|
||||
def training_function(config, args):
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu)
|
||||
|
||||
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
|
||||
lr = config["lr"]
|
||||
num_epochs = int(config["num_epochs"])
|
||||
seed = int(config["seed"])
|
||||
batch_size = int(config["batch_size"])
|
||||
image_size = config["image_size"]
|
||||
if not isinstance(image_size, (list, tuple)):
|
||||
image_size = (image_size, image_size)
|
||||
|
||||
# Grab all the image filenames
|
||||
file_names = [os.path.join(args.data_dir, fname) for fname in os.listdir(args.data_dir) if fname.endswith(".jpg")]
|
||||
|
||||
# Build the label correspondences
|
||||
all_labels = [extract_label(fname) for fname in file_names]
|
||||
id_to_label = list(set(all_labels))
|
||||
id_to_label.sort()
|
||||
label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}
|
||||
|
||||
# Set the seed before splitting the data.
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Split our filenames between train and validation
|
||||
random_perm = np.random.permutation(len(file_names))
|
||||
cut = int(0.8 * len(file_names))
|
||||
train_split = random_perm[:cut]
|
||||
eval_split = random_perm[cut:]
|
||||
|
||||
# For training we use a simple RandomResizedCrop
|
||||
train_tfm = Compose([RandomResizedCrop(image_size, scale=(0.5, 1.0)), ToTensor()])
|
||||
train_dataset = PetsDataset(
|
||||
[file_names[i] for i in train_split], image_transform=train_tfm, label_to_id=label_to_id
|
||||
)
|
||||
|
||||
# For evaluation, we use a deterministic Resize
|
||||
eval_tfm = Compose([Resize(image_size), ToTensor()])
|
||||
eval_dataset = PetsDataset([file_names[i] for i in eval_split], image_transform=eval_tfm, label_to_id=label_to_id)
|
||||
|
||||
# Instantiate dataloaders.
|
||||
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4)
|
||||
eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=batch_size, num_workers=4)
|
||||
|
||||
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
|
||||
model = create_model("resnet50d", pretrained=True, num_classes=len(label_to_id))
|
||||
|
||||
# We could avoid this line since the accelerator is set with `device_placement=True` (default value).
|
||||
# Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
|
||||
# creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
|
||||
model = model.to(accelerator.device)
|
||||
|
||||
# Freezing the base model
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
for param in model.get_classifier().parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
# We normalize the batches of images to be a bit faster.
|
||||
mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None].to(accelerator.device)
|
||||
std = torch.tensor(model.default_cfg["std"])[None, :, None, None].to(accelerator.device)
|
||||
|
||||
# Instantiate optimizer
|
||||
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr / 25)
|
||||
|
||||
# Prepare everything
|
||||
# There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
|
||||
# prepare method.
|
||||
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# Instantiate learning rate scheduler after preparing the training dataloader as the prepare method
|
||||
# may change its length.
|
||||
lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader))
|
||||
|
||||
# Now we train the model
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
|
||||
inputs = (batch["image"] - mean) / std
|
||||
outputs = model(inputs)
|
||||
loss = torch.nn.functional.cross_entropy(outputs, batch["label"])
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
model.eval()
|
||||
accurate = 0
|
||||
num_elems = 0
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
|
||||
inputs = (batch["image"] - mean) / std
|
||||
with torch.no_grad():
|
||||
outputs = model(inputs)
|
||||
predictions = outputs.argmax(dim=-1)
|
||||
accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch["label"])
|
||||
num_elems += accurate_preds.shape[0]
|
||||
accurate += accurate_preds.long().sum()
|
||||
|
||||
eval_metric = accurate.item() / num_elems
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Simple example of training script.")
|
||||
parser.add_argument("--data_dir", required=True, help="The data folder on disk.")
|
||||
parser.add_argument("--fp16", action="store_true", help="If passed, will use FP16 training.")
|
||||
parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
|
||||
args = parser.parse_args()
|
||||
config = {"lr": 3e-2, "num_epochs": 3, "seed": 42, "batch_size": 64, "image_size": 224}
|
||||
training_function(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from accelerate import Accelerator, DistributedType
|
||||
@ -30,7 +31,7 @@ from transformers import (
|
||||
########################################################################
|
||||
# This is a fully working simple example to use Accelerate
|
||||
#
|
||||
# This example train a Bert base model on GLUE MRPC
|
||||
# This example trains a Bert base model on GLUE MRPC
|
||||
# in any of the following settings (with the same script):
|
||||
# - single CPU or single GPU
|
||||
# - multi GPUS (using PyTorch distributed mode)
|
||||
@ -39,7 +40,7 @@ from transformers import (
|
||||
#
|
||||
# To run it in each of these various modes, follow the instructions
|
||||
# in the readme for examples:
|
||||
# https://github.com/huggingface/accelerate/examples
|
||||
# https://github.com/huggingface/accelerate/tree/main/examples
|
||||
#
|
||||
########################################################################
|
||||
|
||||
@ -49,6 +50,9 @@ EVAL_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def training_function(config, args):
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu)
|
||||
|
||||
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
|
||||
lr = config["lr"]
|
||||
num_epochs = int(config["num_epochs"])
|
||||
@ -97,14 +101,13 @@ def training_function(config, args):
|
||||
)
|
||||
|
||||
set_seed(seed)
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu)
|
||||
|
||||
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
# If setting devices manually, this line absolutely needs to be before the optimizer creation otherwise training
|
||||
# will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
|
||||
|
||||
# We could avoid this line since the accelerator is set with `device_placement=True` (default value).
|
||||
# Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
|
||||
# creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
|
||||
model = model.to(accelerator.device)
|
||||
|
||||
# Instantiate optimizer
|
||||
@ -125,7 +128,7 @@ def training_function(config, args):
|
||||
num_training_steps=len(train_dataloader) * num_epochs,
|
||||
)
|
||||
|
||||
# Now we train the model - We prune bad trials after each epoch if needed
|
||||
# Now we train the model
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
@ -144,7 +147,8 @@ def training_function(config, args):
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
batch.to(accelerator.device)
|
||||
outputs = model(**batch)
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
metric.add_batch(
|
||||
predictions=accelerator.gather(predictions),
|
||||
@ -158,16 +162,8 @@ def training_function(config, args):
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Simple example of training script.")
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="If passed, will use FP16 training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu",
|
||||
action="store_true",
|
||||
help="If passed, will train on the CPU.",
|
||||
)
|
||||
parser.add_argument("--fp16", type=bool, default=False, help="If passed, will use FP16 training.")
|
||||
parser.add_argument("--cpu", type=bool, default=False, help="If passed, will train on the CPU.")
|
||||
args = parser.parse_args()
|
||||
config = {"lr": 2e-5, "num_epochs": 3, "correct_bias": True, "seed": 42, "batch_size": 16}
|
||||
training_function(config, args)
|
||||
|
||||
1 addition: examples/requirements.txt (new file)

@@ -0,0 +1 @@
accelerate # used to be installed in Amazon SageMaker environment
40
setup.py
40
setup.py
@ -17,11 +17,27 @@ from setuptools import find_packages
|
||||
|
||||
extras = {}
|
||||
extras["quality"] = ["black >= 20.8b1", "isort >= 5.5.4", "flake8 >= 3.8.3"]
|
||||
extras["docs"] = ["recommonmark", "sphinx==3.2.1", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"]
|
||||
extras["docs"] = [
|
||||
"docutils==0.16.0",
|
||||
"recommonmark",
|
||||
"sphinx==3.2.1",
|
||||
"sphinx-markdown-tables",
|
||||
"sphinx-rtd-theme==0.4.3",
|
||||
"sphinx-copybutton",
|
||||
"sphinxext-opengraph==0.4.1",
|
||||
]
|
||||
extras["test"] = [
|
||||
"pytest",
|
||||
"pytest-xdist",
|
||||
]
|
||||
|
||||
extras["sagemaker"] = [
|
||||
"sagemaker", # boto3 is a required package in sagemaker
|
||||
]
|
||||
|
||||
setup(
|
||||
name="accelerate",
|
||||
version="0.1.0",
|
||||
version="0.2.0",
|
||||
description="Accelerate",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
@ -32,13 +48,15 @@ setup(
|
||||
url="https://github.com/huggingface/accelerate",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
entry_points={"console_scripts": [
|
||||
"accelerate=accelerate.commands.accelerate_cli:main",
|
||||
"accelerate-config=accelerate.commands.config:main",
|
||||
"accelerate-launch=accelerate.commands.launch:main",
|
||||
]},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"accelerate=accelerate.commands.accelerate_cli:main",
|
||||
"accelerate-config=accelerate.commands.config:main",
|
||||
"accelerate-launch=accelerate.commands.launch:main",
|
||||
]
|
||||
},
|
||||
python_requires=">=3.6.0",
|
||||
install_requires=["torch>=1.4.0"],
|
||||
install_requires=["torch>=1.4.0", "pyaml>=20.4.0"],
|
||||
extras_require=extras,
|
||||
classifiers=[
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
@ -58,7 +76,7 @@ setup(
# 1. Change the version in __init__.py, setup.py as well as docs/source/conf.py.
# 2. Commit these changes with the message: "Release: VERSION"
# 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
# Push the tag to git: git push --tags origin master
# Push the tag to git: git push --tags origin main
# 4. Run the following commands in the top-level directory:
# python setup.py bdist_wheel
# python setup.py sdist
@ -69,6 +87,6 @@ setup(
# pip install -i https://testpypi.python.org/pypi accelerate
# 7. Upload the final version to actual pypi:
# twine upload dist/* -r pypi
# 8. Add release notes to RELEASE.md and the tag in github once everything is looking hunky-dory.
# 9. Add the release version to docs/source/_static/js/custom.js and .circleci/deploy.sh
# 8. Add release notes to the tag in github once everything is looking hunky-dory.
# 9. Add the release version to docs/source/_static/js/custom.js and .github/deploy_doc.sh
# 10. Update the version in __init__.py, setup.py to the new version "-dev" and push to master
@ -2,8 +2,9 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

__version__ = "0.1.0"
__version__ = "0.2.0"

from .accelerator import Accelerator
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs
from .state import DistributedType
from .utils import synchronize_rng_states
@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import torch

from packaging import version

from .data_loader import prepare_data_loader
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, KwargsHandler
from .optimizer import AcceleratedOptimizer
from .state import AcceleratorState, DistributedType
from .utils import extract_model_from_parallel, gather, save, wait_for_everyone
from .utils import RNGType, extract_model_from_parallel, gather, pad_across_processes, save, wait_for_everyone


class Accelerator:
@ -42,6 +45,21 @@ class Accelerator:
cpu (:obj:`bool`, `optional`):
Whether or not to force the script to execute on CPU. Will ignore GPU available if set to :obj:`True` and
force the execution on one process only.
rng_types (list of :obj:`str` or :class:`~accelerate.utils.RNGType`):
The list of random number generators to synchronize at the beginning of each iteration in your prepared
dataloaders. Should be one or several of:

- :obj:`"torch"`: the base torch random number generator
- :obj:`"cuda"`: the CUDA random number generator (GPU only)
- :obj:`"xla"`: the XLA random number generator (TPU only)
- :obj:`"generator"`: the :obj:`torch.Generator` of the sampler (or batch sampler if there is no sampler in
your dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.

Will default to :obj:`["torch"]` for PyTorch versions <=1.5.1 and :obj:`["generator"]` for PyTorch versions
>= 1.6.
kwargs_handlers (list of kwargs handlers, `optional`)
A list of :obj:`KwargHandler` to customize how the objects related to distributed training or mixed
precision are created. See :doc:`kwargs` for more information.

Attributes

@ -50,23 +68,51 @@ class Accelerator:
"""

def __init__(
self, device_placement: bool = True, split_batches: bool = False, fp16: bool = None, cpu: bool = False
self,
device_placement: bool = True,
split_batches: bool = False,
fp16: bool = None,
cpu: bool = False,
rng_types: Optional[List[Union[str, RNGType]]] = None,
kwargs_handlers: Optional[List[KwargsHandler]] = None,
):
self.state = AcceleratorState(fp16=fp16, cpu=cpu, _from_accelerator=True)

self.device_placement = device_placement
self.split_batches = split_batches

# Kwargs handlers
self.ddp_handler = None
self.scaler_handler = None
if kwargs_handlers is not None:
for handler in kwargs_handlers:
assert isinstance(handler, KwargsHandler), f"Unsupported kwargs handler passed: {handler}."
if isinstance(handler, DistributedDataParallelKwargs):
if self.ddp_handler is not None:
raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.")
else:
self.ddp_handler = handler
elif isinstance(handler, GradScalerKwargs):
if self.scaler_handler is not None:
raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.")
else:
self.scaler_handler = handler

# Mixed precision attributes
self.scaler = None
self.native_amp = False
if self.state.use_fp16:
self.native_amp = version.parse(torch.__version__) >= version.parse("1.6")
self.scaler = torch.cuda.amp.GradScaler()
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
self.scaler = torch.cuda.amp.GradScaler(**kwargs)

# Internal references to the training objects
self._optimizers = []

# RNG Types
if rng_types is None:
self.rng_types = ["torch"] if version.parse(torch.__version__) <= version.parse("1.5.1") else ["generator"]

@property
def distributed_type(self):
return self.state.distributed_type
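As a usage sketch for the `rng_types` argument documented above (the dataset below is made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

# Keep both the base torch RNG and the sampler generator in sync across processes.
accelerator = Accelerator(rng_types=["torch", "generator"])

dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Each process now shuffles with synchronized random states at every epoch.
dataloader = accelerator.prepare(dataloader)
```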
@ -164,16 +210,18 @@ class Accelerator:
if isinstance(obj, torch.optim.Optimizer):
obj._switch_parameters(mapping)

return result
return result if len(result) > 1 else result[0]

def prepare_model(self, model):
if self.device_placement:
model = model.to(self.device)
if self.distributed_type == DistributedType.MULTI_GPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[self.local_process_index],
output_device=self.local_process_index,
**kwargs,
)
if self.native_amp:
model.forward = torch.cuda.amp.autocast()(model.forward)
@ -187,6 +235,7 @@ class Accelerator:
process_index=self.process_index,
split_batches=self.split_batches,
put_on_device=self.device_placement,
rng_types=self.rng_types,
)

def prepare_optimizer(self, optimizer):
@ -240,6 +289,23 @@ class Accelerator:
"""
return gather(tensor)

def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
"""
Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
they can safely be gathered.

Args:
tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
The data to gather.
dim (:obj:`int`, `optional`, defaults to 0):
The dimension on which to pad.
pad_index (:obj:`int`, `optional`, defaults to 0):
The value with which to pad.
pad_first (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to pad at the beginning or the end.
"""
return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
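A hedged sketch of how `pad_across_processes` combines with `gather` during evaluation; `accelerator`, `model` and `batch` are assumed to exist as in a prepared evaluation loop:

```python
# Per-process batches can have different lengths on the last evaluation step.
with torch.no_grad():
    predictions = model(**batch).logits.argmax(dim=-1)

# Pad every process's tensor to the same size before gathering, using a value
# that cannot collide with real labels, then drop the padding afterwards.
predictions = accelerator.pad_across_processes(predictions, dim=0, pad_index=-100)
all_predictions = accelerator.gather(predictions)
all_predictions = all_predictions[all_predictions != -100]
```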
def unwrap_model(self, model):
"""
Unwraps the :obj:`model` from the additional layer possibly added by :meth:`~accelerate.Accelerator.prepare`.
@ -1,184 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from accelerate.state import DistributedType
|
||||
|
||||
|
||||
hf_cache_home = os.path.expanduser(
|
||||
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
||||
)
|
||||
cache_dir = os.path.join(hf_cache_home, "accelerate")
|
||||
default_config_file = os.path.join(cache_dir, "default_config.json")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LaunchConfig:
|
||||
distributed_type: DistributedType
|
||||
num_processes: int
|
||||
fp16: bool
|
||||
machine_rank: int = 0
|
||||
num_machines: int = 1
|
||||
main_process_ip: Optional[str] = None
|
||||
main_process_port: Optional[int] = None
|
||||
main_training_function: str = "main"
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file=None):
|
||||
json_file = default_config_file if json_file is None else json_file
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
return cls(**json.load(f))
|
||||
|
||||
def to_json_file(self, json_file):
|
||||
with open(json_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(self.__dict__, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
|
||||
def config_command_parser(subparsers=None):
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("config")
|
||||
else:
|
||||
parser = argparse.ArgumentParser("Accelerate config command")
|
||||
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
help=(
|
||||
"The path to use to store the config file. Will default to a file named default_config.json in the cache "
|
||||
"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
|
||||
"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
|
||||
"with 'huggingface'."
|
||||
),
|
||||
)
|
||||
|
||||
if subparsers is not None:
|
||||
parser.set_defaults(func=config_command)
|
||||
return parser
|
||||
|
||||
|
||||
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
|
||||
ask_again = True
|
||||
while ask_again:
|
||||
result = input(input_text)
|
||||
try:
|
||||
if default is not None and len(result) == 0:
|
||||
return default
|
||||
return convert_value(result) if convert_value is not None else result
|
||||
except:
|
||||
if error_message is not None:
|
||||
print(error_message)
|
||||
|
||||
|
||||
def get_user_input():
|
||||
def _convert_distributed_mode(value):
|
||||
value = int(value)
|
||||
return DistributedType(["NO", "MULTI_GPU", "TPU"][value])
|
||||
|
||||
def _convert_yes_no_to_bool(value):
|
||||
return {"yes": True, "no": False}[value.lower()]
|
||||
|
||||
distributed_type = _ask_field(
|
||||
"Which type of machine are you using? ([0] No distributed training, [1] multi-GPU, [2] TPU): ",
|
||||
_convert_distributed_mode,
|
||||
error_message="Please enter 0, 1 or 2.",
|
||||
)
|
||||
|
||||
machine_rank = 0
|
||||
num_machines = 1
|
||||
main_process_ip = None
|
||||
main_process_port = None
|
||||
if distributed_type == DistributedType.MULTI_GPU:
|
||||
num_machines = _ask_field(
|
||||
"How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
)
|
||||
if num_machines > 1:
|
||||
machine_rank = _ask_field(
|
||||
"What is the rank of this machine (from 0 to the number of machines - 1 )? [0]: ",
|
||||
lambda x: int(x),
|
||||
default=0,
|
||||
)
|
||||
main_process_ip = _ask_field(
|
||||
"What is the IP address of the machine that will host the main process? ",
|
||||
)
|
||||
main_process_ip = _ask_field(
|
||||
"What is the port you will use to communicate with the main process? ",
|
||||
lambda x: int(x),
|
||||
)
|
||||
if distributed_type == DistributedType.TPU:
|
||||
main_training_function = _ask_field(
|
||||
"What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
|
||||
default="main",
|
||||
)
|
||||
else:
|
||||
main_training_function = "main"
|
||||
|
||||
num_processes = _ask_field(
|
||||
"How many processes in total will you use? [1]: ",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
error_message="Please enter an integer.",
|
||||
)
|
||||
|
||||
if distributed_type != DistributedType.TPU:
|
||||
fp16 = _ask_field(
|
||||
"Do you wish to use FP16 (mixed precision)? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
else:
|
||||
fp16 = False
|
||||
|
||||
return LaunchConfig(
|
||||
distributed_type=distributed_type,
|
||||
num_processes=num_processes,
|
||||
fp16=fp16,
|
||||
machine_rank=machine_rank,
|
||||
num_machines=num_machines,
|
||||
main_process_ip=main_process_ip,
|
||||
main_process_port=main_process_port,
|
||||
main_training_function=main_training_function,
|
||||
)
|
||||
|
||||
|
||||
def config_command(args):
|
||||
config = get_user_input()
|
||||
if args.config_file is not None:
|
||||
config_file = args.config_file
|
||||
else:
|
||||
if not os.path.isdir(cache_dir):
|
||||
os.makedirs(cache_dir)
|
||||
config_file = default_config_file
|
||||
|
||||
config.to_json_file(config_file)
|
||||
|
||||
|
||||
def main():
|
||||
parser = config_command_parser()
|
||||
args = parser.parse_args()
|
||||
config_command(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
85
src/accelerate/commands/config/__init__.py
Normal file
@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from accelerate.state import ComputeEnvironment
|
||||
|
||||
from .cluster import get_cluster_input
|
||||
from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401
|
||||
from .config_utils import _ask_field, _convert_compute_environment
|
||||
from .sagemaker import get_sagemaker_input
|
||||
|
||||
|
||||
def get_user_input():
|
||||
compute_environment = _ask_field(
|
||||
"In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): ",
|
||||
_convert_compute_environment,
|
||||
error_message="Please enter 0 or 1",
|
||||
)
|
||||
if compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
|
||||
config = get_sagemaker_input()
|
||||
else:
|
||||
config = get_cluster_input()
|
||||
return config
|
||||
|
||||
|
||||
def config_command_parser(subparsers=None):
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("config")
|
||||
else:
|
||||
parser = argparse.ArgumentParser("Accelerate config command")
|
||||
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
help=(
|
||||
"The path to use to store the config file. Will default to a file named default_config.json in the cache "
|
||||
"location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
|
||||
"such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
|
||||
"with 'huggingface'."
|
||||
),
|
||||
)
|
||||
|
||||
if subparsers is not None:
|
||||
parser.set_defaults(func=config_command)
|
||||
return parser
|
||||
|
||||
|
||||
def config_command(args):
|
||||
config = get_user_input()
|
||||
if args.config_file is not None:
|
||||
config_file = args.config_file
|
||||
else:
|
||||
if not os.path.isdir(cache_dir):
|
||||
os.makedirs(cache_dir)
|
||||
config_file = default_yaml_config_file
|
||||
|
||||
if config_file.endswith(".json"):
|
||||
config.to_json_file(config_file)
|
||||
else:
|
||||
config.to_yaml_file(config_file)
|
||||
|
||||
|
||||
def main():
|
||||
parser = config_command_parser()
|
||||
args = parser.parse_args()
|
||||
config_command(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
88
src/accelerate/commands/config/cluster.py
Normal file
@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from accelerate.state import ComputeEnvironment, DistributedType
|
||||
|
||||
from .config_args import ClusterConfig
|
||||
from .config_utils import _ask_field, _convert_distributed_mode, _convert_yes_no_to_bool
|
||||
|
||||
|
||||
def get_cluster_input():
|
||||
distributed_type = _ask_field(
|
||||
"Which type of machine are you using? ([0] No distributed training, [1] multi-GPU, [2] TPU): ",
|
||||
_convert_distributed_mode,
|
||||
error_message="Please enter 0, 1 or 2.",
|
||||
)
|
||||
|
||||
machine_rank = 0
|
||||
num_machines = 1
|
||||
main_process_ip = None
|
||||
main_process_port = None
|
||||
if distributed_type == DistributedType.MULTI_GPU:
|
||||
num_machines = _ask_field(
|
||||
"How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
)
|
||||
if num_machines > 1:
|
||||
machine_rank = _ask_field(
|
||||
"What is the rank of this machine (from 0 to the number of machines - 1 )? [0]: ",
|
||||
lambda x: int(x),
|
||||
default=0,
|
||||
)
|
||||
main_process_ip = _ask_field(
|
||||
"What is the IP address of the machine that will host the main process? ",
|
||||
)
|
||||
main_process_port = _ask_field(
|
||||
"What is the port you will use to communicate with the main process? ",
|
||||
lambda x: int(x),
|
||||
)
|
||||
if distributed_type == DistributedType.TPU:
|
||||
main_training_function = _ask_field(
|
||||
"What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
|
||||
default="main",
|
||||
)
|
||||
else:
|
||||
main_training_function = "main"
|
||||
|
||||
num_processes = _ask_field(
|
||||
"How many processes in total will you use? [1]: ",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
error_message="Please enter an integer.",
|
||||
)
|
||||
|
||||
if distributed_type != DistributedType.TPU:
|
||||
fp16 = _ask_field(
|
||||
"Do you wish to use FP16 (mixed precision)? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
else:
|
||||
fp16 = False
|
||||
|
||||
return ClusterConfig(
|
||||
compute_environment=ComputeEnvironment.LOCAL_MACHINE,
|
||||
distributed_type=distributed_type,
|
||||
num_processes=num_processes,
|
||||
fp16=fp16,
|
||||
machine_rank=machine_rank,
|
||||
num_machines=num_machines,
|
||||
main_process_ip=main_process_ip,
|
||||
main_process_port=main_process_port,
|
||||
main_training_function=main_training_function,
|
||||
)
|
||||
131
src/accelerate/commands/config/config_args.py
Normal file
@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import yaml
|
||||
from accelerate.state import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
||||
|
||||
|
||||
hf_cache_home = os.path.expanduser(
|
||||
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
||||
)
|
||||
cache_dir = os.path.join(hf_cache_home, "accelerate")
|
||||
default_json_config_file = os.path.join(cache_dir, "default_config.json")
|
||||
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")
|
||||
|
||||
# For backward compatibility: the default config is the json one if it's the only existing file.
|
||||
if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
|
||||
default_config_file = default_yaml_config_file
|
||||
else:
|
||||
default_config_file = default_json_config_file
|
||||
|
||||
|
||||
def load_config_from_file(config_file):
|
||||
config_file = config_file if config_file is not None else default_config_file
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
if config_file.endswith(".json"):
|
||||
if (
|
||||
json.load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
|
||||
is ComputeEnvironment.LOCAL_MACHINE
|
||||
):
|
||||
config_class = ClusterConfig
|
||||
else:
|
||||
config_class = SageMakerConfig
|
||||
return config_class.from_json_file(json_file=config_file)
|
||||
else:
|
||||
if (
|
||||
yaml.safe_load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
|
||||
is ComputeEnvironment.LOCAL_MACHINE
|
||||
):
|
||||
config_class = ClusterConfig
|
||||
else:
|
||||
config_class = SageMakerConfig
|
||||
return config_class.from_yaml_file(yaml_file=config_file)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
compute_environment: ComputeEnvironment
|
||||
distributed_type: Union[DistributedType, SageMakerDistributedType]
|
||||
fp16: bool
|
||||
|
||||
def to_dict(self):
|
||||
result = self.__dict__
|
||||
# For serialization, it's best to convert Enums to strings (or their underlying value type).
|
||||
for key, value in result.items():
|
||||
if isinstance(value, Enum):
|
||||
result[key] = value.value
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file=None):
|
||||
json_file = default_json_config_file if json_file is None else json_file
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
config_dict = json.load(f)
|
||||
if "compute_environment" not in config_dict:
|
||||
config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_json_file(self, json_file):
|
||||
with open(json_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
@classmethod
|
||||
def from_yaml_file(cls, yaml_file=None):
|
||||
yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
|
||||
with open(yaml_file, "r", encoding="utf-8") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
if "compute_environment" not in config_dict:
|
||||
config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_yaml_file(self, yaml_file):
|
||||
with open(yaml_file, "w", encoding="utf-8") as f:
|
||||
yaml.safe_dump(self.to_dict(), f)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.compute_environment, str):
|
||||
self.compute_environment = ComputeEnvironment(self.compute_environment)
|
||||
if isinstance(self.distributed_type, str):
|
||||
self.distributed_type = DistributedType(self.distributed_type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClusterConfig(BaseConfig):
|
||||
num_processes: int
|
||||
machine_rank: int = 0
|
||||
num_machines: int = 1
|
||||
main_process_ip: Optional[str] = None
|
||||
main_process_port: Optional[int] = None
|
||||
main_training_function: str = "main"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageMakerConfig(BaseConfig):
|
||||
ec2_instance_type: str
|
||||
iam_role_name: str
|
||||
profile: Optional[str] = None
|
||||
region: str = "us-east-1"
|
||||
num_machines: int = 1
|
||||
base_job_name: str = f"accelerate-sagemaker-{num_machines}"
|
||||
pytorch_version: str = "1.6"
|
||||
transformers_version: str = "4.4"
|
||||
49
src/accelerate/commands/config/config_utils.py
Normal file
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from accelerate.state import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
||||
|
||||
|
||||
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
|
||||
ask_again = True
|
||||
while ask_again:
|
||||
result = input(input_text)
|
||||
try:
|
||||
if default is not None and len(result) == 0:
|
||||
return default
|
||||
return convert_value(result) if convert_value is not None else result
|
||||
except:
|
||||
if error_message is not None:
|
||||
print(error_message)
|
||||
|
||||
|
||||
def _convert_compute_environment(value):
|
||||
value = int(value)
|
||||
return ComputeEnvironment(["LOCAL_MACHINE", "AMAZON_SAGEMAKER"][value])
|
||||
|
||||
|
||||
def _convert_distributed_mode(value):
|
||||
value = int(value)
|
||||
return DistributedType(["NO", "MULTI_GPU", "TPU"][value])
|
||||
|
||||
|
||||
def _convert_sagemaker_distributed_mode(value):
|
||||
value = int(value)
|
||||
return SageMakerDistributedType(["NO", "DATA_PARALLEL", "MODEL_PARALLEL"][value])
|
||||
|
||||
|
||||
def _convert_yes_no_to_bool(value):
|
||||
return {"yes": True, "no": False}[value.lower()]
|
||||
158
src/accelerate/commands/config/sagemaker.py
Normal file
@ -0,0 +1,158 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
|
||||
from accelerate.state import ComputeEnvironment, SageMakerDistributedType
|
||||
from accelerate.utils import is_boto3_available
|
||||
|
||||
from .config_args import SageMakerConfig
|
||||
from .config_utils import _ask_field, _convert_sagemaker_distributed_mode, _convert_yes_no_to_bool
|
||||
|
||||
|
||||
if is_boto3_available():
|
||||
import boto3 # noqa: F401
|
||||
|
||||
|
||||
def _create_iam_role_for_sagemaker(role_name):
|
||||
iam_client = boto3.client("iam")
|
||||
|
||||
sagemaker_trust_policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
|
||||
],
|
||||
}
|
||||
try:
|
||||
# create the role, associated with the chosen trust policy
|
||||
iam_client.create_role(
|
||||
RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2)
|
||||
)
|
||||
policy_document = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"sagemaker:*",
|
||||
"ecr:GetDownloadUrlForLayer",
|
||||
"ecr:BatchGetImage",
|
||||
"ecr:BatchCheckLayerAvailability",
|
||||
"ecr:GetAuthorizationToken",
|
||||
"cloudwatch:PutMetricData",
|
||||
"cloudwatch:GetMetricData",
|
||||
"cloudwatch:GetMetricStatistics",
|
||||
"cloudwatch:ListMetrics",
|
||||
"logs:CreateLogGroup",
|
||||
"logs:CreateLogStream",
|
||||
"logs:DescribeLogStreams",
|
||||
"logs:PutLogEvents",
|
||||
"logs:GetLogEvents",
|
||||
"s3:CreateBucket",
|
||||
"s3:ListBucket",
|
||||
"s3:GetBucketLocation",
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
],
|
||||
"Resource": "*",
|
||||
}
|
||||
],
|
||||
}
|
||||
# attach policy to role
|
||||
iam_client.put_role_policy(
|
||||
RoleName=role_name,
|
||||
PolicyName=f"{role_name}_policy_permission",
|
||||
PolicyDocument=json.dumps(policy_document, indent=2),
|
||||
)
|
||||
except iam_client.exceptions.EntityAlreadyExistsException:
|
||||
print(f"role {role_name} already exists. Using existing one")
|
||||
|
||||
|
||||
def _get_iam_role_arn(role_name):
|
||||
iam_client = boto3.client("iam")
|
||||
return iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
|
||||
|
||||
|
||||
def get_sagemaker_input():
|
||||
credentials_configuration = _ask_field(
|
||||
"How do you want to authorize? ([0] AWS Profile, [1] Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)): ",
|
||||
lambda x: int(x),
|
||||
)
|
||||
aws_profile = None
|
||||
if credentials_configuration == 0:
|
||||
aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
|
||||
os.environ["AWS_PROFILE"] = aws_profile
|
||||
else:
|
||||
print(
|
||||
"Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
|
||||
"`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
|
||||
)
|
||||
aws_access_key_id = _ask_field("AWS Access Key ID: ")
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
|
||||
|
||||
aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
|
||||
|
||||
aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
|
||||
os.environ["AWS_DEFAULT_REGION"] = aws_region
|
||||
|
||||
role_management = _ask_field(
|
||||
"Do you already have an IAM Role for executing Amazon SageMaker Training Jobs? ([0] provide IAM Role name, [1] create new IAM role using credentials: ",
|
||||
lambda x: int(x),
|
||||
)
|
||||
if role_management == 0:
|
||||
iam_role_name = _ask_field("Enter your IAM role name: ")
|
||||
else:
|
||||
iam_role_name = "accelerate_sagemaker_execution_role"
|
||||
print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
|
||||
_create_iam_role_for_sagemaker(iam_role_name)
|
||||
|
||||
distributed_type = _ask_field(
|
||||
"Which type of machine are you using? ([0] No distributed training, [1] data parallelism, [2] model parallelism): ",
|
||||
_convert_sagemaker_distributed_mode,
|
||||
error_message="Please enter 0, 1 or 2",
|
||||
)
|
||||
|
||||
# using the best two instances for single-gpu training or multi-gpu -> can turn into question to make it more diverse
|
||||
ec2_instance_type = "ml.p3.2xlarge" if distributed_type == SageMakerDistributedType.NO else "ml.p3dn.24xlarge"
|
||||
num_machines = 1
|
||||
if (
|
||||
distributed_type == SageMakerDistributedType.DATA_PARALLEL
|
||||
or distributed_type == SageMakerDistributedType.MODEL_PARALLEL
|
||||
):
|
||||
raise NotImplementedError("Model or Data Parallelism is not implemented yet. We are working on it")
|
||||
num_machines = _ask_field(
|
||||
"How many machines do you want use? [2]: ",
|
||||
lambda x: int(x),
|
||||
default=2,
|
||||
)
|
||||
fp16 = _ask_field(
|
||||
"Do you wish to use FP16 (mixed precision)? [yes/NO]: ",
|
||||
_convert_yes_no_to_bool,
|
||||
default=False,
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
|
||||
return SageMakerConfig(
|
||||
compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
|
||||
distributed_type=distributed_type,
|
||||
ec2_instance_type=ec2_instance_type,
|
||||
profile=aws_profile,
|
||||
region=aws_region,
|
||||
iam_role_name=iam_role_name,
|
||||
fp16=fp16,
|
||||
num_machines=num_machines,
|
||||
)
|
||||
@ -20,17 +20,20 @@ import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from ast import literal_eval
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Dict, List
|
||||
|
||||
from accelerate.commands.config import LaunchConfig, default_config_file
|
||||
from accelerate.state import DistributedType
|
||||
from accelerate.commands.config import default_config_file, load_config_from_file
|
||||
from accelerate.commands.config.config_args import SageMakerConfig
|
||||
from accelerate.state import ComputeEnvironment, DistributedType
|
||||
from accelerate.utils import is_sagemaker_available
|
||||
|
||||
|
||||
class _AddOneArg():
|
||||
class _AddOneArg:
|
||||
def __init__(self, launcher):
|
||||
self.launcher = launcher
|
||||
|
||||
|
||||
def __call__(self, index):
|
||||
self.launcher()
|
||||
|
||||
@ -68,12 +71,10 @@ def launch_command_parser(subparsers=None):
|
||||
parser.add_argument(
|
||||
"--machine_rank", type=int, default=0, help="The rank of the machine on which this script is launched."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--main_process_ip", type=Optional[str], default=None, help="The IP address of the machine of rank 0."
|
||||
)
|
||||
parser.add_argument("--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0.")
|
||||
parser.add_argument(
|
||||
"--main_process_port",
|
||||
type=Optional[int],
|
||||
type=int,
|
||||
default=None,
|
||||
help="The port to use to communicate with the machine of rank 0.",
|
||||
)
|
||||
@ -83,6 +84,18 @@ def launch_command_parser(subparsers=None):
|
||||
default=None,
|
||||
help="The name of the main function to be executed in your script (only for TPU training).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aws_access_key_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aws_secret_access_key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job",
|
||||
)
|
||||
parser.add_argument(
|
||||
"training_script",
|
||||
type=str,
|
||||
@ -171,22 +184,125 @@ def tpu_launcher(args):
|
||||
xmp.spawn(main_function, args=(), nprocs=args.num_processes)
|
||||
|
||||
|
||||
def _convert_nargs_to_dict(nargs: List[str]) -> Dict[str, str]:
if len(nargs) == 0:
return {}
# helper function to infer type for the argument parser

def _infer_type(s):
try:
s = float(s)

if s // 1 == s:
return int(s)
return s
except ValueError:
return s

parser = argparse.ArgumentParser()
_, unknown = parser.parse_known_args(nargs)
for index, argument in enumerate(unknown):
if argument.startswith(("-", "--")):
action = None
if index + 1 < len(unknown):  # checks if next index would be in list
if unknown[index + 1].startswith(("-", "--")):  # checks if next element is a key
# raise an error if element is store_true or store_false
raise ValueError(
"SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
)
else:  # raise an error if last element is store_true or store_false
raise ValueError(
"SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
)
# add the argument to the parser (action is always None here since store_true/store_false raise above)
if action is None:
parser.add_argument(argument, type=_infer_type)
else:
parser.add_argument(argument, action=action)

return {
key: (literal_eval(value) if value == "True" or value == "False" else value)
for key, value in parser.parse_args(nargs).__dict__.items()
}
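For illustration, a rough sketch of what the conversion above produces; the argument names and values are made up:

```python
from accelerate.commands.launch import _convert_nargs_to_dict

# Extra script arguments forwarded to the SageMaker launcher become hyperparameters.
hyperparameters = _convert_nargs_to_dict(["--num_epochs", "3", "--learning_rate", "5e-5", "--do_eval", "True"])
# -> {"num_epochs": 3, "learning_rate": 5e-05, "do_eval": True}
# A bare flag such as "--do_eval" with no value would raise a ValueError, as the code above explains.
```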
def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
|
||||
if not is_sagemaker_available():
|
||||
raise ImportError(
|
||||
"Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`"
|
||||
)
|
||||
from sagemaker.huggingface import HuggingFace
|
||||
|
||||
# configure environment
|
||||
print("Configuring Amazon SageMaker environment")
|
||||
os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region
|
||||
|
||||
# configure credentials
|
||||
if sagemaker_config.profile is not None:
|
||||
os.environ["AWS_PROFILE"] = sagemaker_config.profile
|
||||
elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None:
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
"You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile"
|
||||
)
|
||||
|
||||
# extract needed arguments
|
||||
source_dir = os.path.dirname(args.training_script)
|
||||
if not source_dir: # checks if string is empty
|
||||
source_dir = "."
|
||||
entry_point = os.path.basename(args.training_script)
|
||||
if not entry_point.endswith(".py"):
|
||||
raise ValueError(f'Your training script should be a python script and not "{entry_point}"')
|
||||
|
||||
print("Converting Arguments to Hyperparameters")
|
||||
hyperparameters = _convert_nargs_to_dict(args.training_script_args)
|
||||
|
||||
environment = {"USE_FP16": args.fp16} # Environment variables to be set for use during training job
|
||||
|
||||
# configure distribution set up
|
||||
distribution = None # TODO: not yet implemented
|
||||
|
||||
# configure session
|
||||
print("Creating Estimator")
|
||||
huggingface_estimator = HuggingFace(
|
||||
entry_point=entry_point,
|
||||
source_dir=source_dir,
|
||||
role=sagemaker_config.iam_role_name,
|
||||
transformers_version="4.4",
|
||||
pytorch_version="1.6",
|
||||
py_version="py36",
|
||||
base_job_name=sagemaker_config.base_job_name,
|
||||
instance_count=sagemaker_config.num_machines,
|
||||
instance_type=sagemaker_config.ec2_instance_type,
|
||||
debugger_hook_config=False,
|
||||
distribution=distribution,
|
||||
hyperparameters=hyperparameters,
|
||||
environment=environment,
|
||||
)
|
||||
|
||||
huggingface_estimator.fit()
|
||||
print(f"You can find your model data at: {huggingface_estimator.model_data}")
|
||||
|
||||
|
||||
def launch_command(args):
|
||||
# Sanity checks
|
||||
if args.multi_gpu and args.tpu:
|
||||
raise ValueError("You can only pick one between `--multi_gpu` and `--tpu`.")
|
||||
|
||||
defaults = None
|
||||
# Get the default from the config file.
|
||||
if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:
|
||||
defaults = LaunchConfig.from_json_file(json_file=args.config_file)
|
||||
defaults = load_config_from_file(args.config_file)
|
||||
if not args.multi_gpu and not args.tpu:
|
||||
args.multi_gpu = defaults.distributed_type == DistributedType.MULTI_GPU
|
||||
args.tpu = defaults.distributed_type == DistributedType.TPU
|
||||
if args.num_processes is None:
|
||||
if args.num_processes is None and defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
|
||||
args.num_processes = defaults.num_processes
|
||||
if not args.fp16:
|
||||
args.fp16 = defaults.fp16
|
||||
if args.main_training_function is None:
|
||||
if args.main_training_function is None and defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
|
||||
args.main_training_function = defaults.main_training_function
|
||||
else:
|
||||
if args.num_processes is None:
|
||||
@ -197,6 +313,8 @@ def launch_command(args):
|
||||
multi_gpu_launcher(args)
|
||||
elif args.tpu and not args.cpu:
|
||||
tpu_launcher(args)
|
||||
elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
|
||||
sagemaker_launcher(defaults, args)
|
||||
else:
|
||||
simple_launcher(args)
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
|
||||
@ -20,7 +20,7 @@ from torch.utils.data import BatchSampler, DataLoader, IterableDataset
|
||||
from packaging import version
|
||||
|
||||
from .state import AcceleratorState, DistributedType, is_tpu_available
|
||||
from .utils import send_to_device, synchronize_rng_states
|
||||
from .utils import RNGType, send_to_device, synchronize_rng_states
|
||||
|
||||
|
||||
if is_tpu_available():
|
||||
@ -262,16 +262,29 @@ class DataLoaderShard(DataLoader):
The dataset to use to build this dataloader.
device (:obj:`torch.device`, `optional`):
If passed, the device to put all batches on.
rng_types (list of :obj:`str` or :class:`~accelerate.utils.RNGType`):
The list of random number generators to synchronize at the beginning of each iteration. Should be one or
several of:

- :obj:`"torch"`: the base torch random number generator
- :obj:`"cuda"`: the CUDA random number generator (GPU only)
- :obj:`"xla"`: the XLA random number generator (TPU only)
- :obj:`"generator"`: an optional :obj:`torch.Generator`
generator (:obj:`torch.Generator`, `optional`):
A random number generator to keep synchronized across processes.
kwargs:
All other keyword arguments to pass to the regular :obj:`DataLoader` initialization.
"""

def __init__(self, dataset, device=None, **kwargs):
def __init__(self, dataset, device=None, rng_types=None, generator=None, **kwargs):
super().__init__(dataset, **kwargs)
self.device = device
self.rng_types = rng_types
self.generator = generator

def __iter__(self):
synchronize_rng_states()
if self.rng_types is not None:
synchronize_rng_states(self.rng_types, self.generator)
state = AcceleratorState()
for batch in super().__iter__():
if state.distributed_type == DistributedType.TPU:
@ -286,6 +299,7 @@ def prepare_data_loader(
|
||||
process_index: Optional[int] = None,
|
||||
split_batches: bool = False,
|
||||
put_on_device: bool = False,
|
||||
rng_types: Optional[List[Union[str, RNGType]]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Wraps a PyTorch :obj:`DataLoader` to generate batches for one of the processes only.
|
||||
@ -318,6 +332,15 @@ def prepare_data_loader(
|
||||
put_on_device (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to put the batches on :obj:`device` (only works if the batches are nested list, tuples or
|
||||
dictionaries of tensors).
|
||||
rng_types (list of :obj:`str` or :class:`~accelerate.utils.RNGType`):
|
||||
The list of random number generators to synchronize at the beginning of each iteration. Should be one or
|
||||
several of:
|
||||
|
||||
- :obj:`"torch"`: the base torch random number generator
|
||||
- :obj:`"cuda"`: the CUDA random number generator (GPU only)
|
||||
- :obj:`"xla"`: the XLA random number generator (TPU only)
|
||||
- :obj:`"generator"`: the :obj:`torch.Generator` of the sampler (or batch sampler if there is no sampler in
|
||||
your dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
|
||||
@ -342,9 +365,12 @@ def prepare_data_loader(
|
||||
|
||||
new_dataset = dataloader.dataset
|
||||
new_batch_sampler = dataloader.batch_sampler
|
||||
generator = getattr(dataloader, "generator", None)
|
||||
# No change if no multiprocess
|
||||
if num_processes != 1:
|
||||
if isinstance(new_dataset, IterableDataset):
|
||||
if getattr(dataloader.dataset, "generator", None) is not None:
|
||||
generator = dataloader.dataset.generator
|
||||
new_dataset = IterableDatasetShard(
|
||||
new_dataset,
|
||||
batch_size=dataloader.batch_size,
|
||||
@ -355,6 +381,13 @@ def prepare_data_loader(
|
||||
)
|
||||
else:
|
||||
# New batch sampler for the current process.
|
||||
if hasattr(dataloader.sampler, "generator"):
|
||||
if dataloader.sampler.generator is None:
|
||||
dataloader.sampler.generator = torch.Generator()
|
||||
generator = dataloader.sampler.generator
|
||||
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
|
||||
elif getattr(dataloader.batch_sampler, "generator", None) is not None:
|
||||
generator = dataloader.batch_sampler.generator
|
||||
new_batch_sampler = BatchSamplerShard(
|
||||
dataloader.batch_sampler,
|
||||
num_processes=num_processes,
|
||||
@ -369,8 +402,12 @@ def prepare_data_loader(
|
||||
"sampler",
|
||||
"batch_sampler",
|
||||
"drop_last",
|
||||
"generator",
|
||||
]
|
||||
|
||||
if rng_types is not None and generator is None and "generator" in rng_types:
|
||||
rng_types.remove("generator")
|
||||
|
||||
kwargs = {
|
||||
k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
|
||||
for k in _PYTORCH_DATALOADER_KWARGS
|
||||
@ -380,5 +417,7 @@ def prepare_data_loader(
|
||||
new_dataset,
|
||||
device=device if put_on_device else None,
|
||||
batch_sampler=new_batch_sampler,
|
||||
rng_types=rng_types,
|
||||
generator=generator,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
73
src/accelerate/kwargs_handlers.py
Normal file
@ -0,0 +1,73 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from dataclasses import dataclass


class KwargsHandler:
"""
Internal mixin that implements a :obj:`to_kwargs()` method for a dataclass.
"""

def to_dict(self):
return copy.deepcopy(self.__dict__)

def to_kwargs(self):
"""
Returns a dictionary containing the attributes with values different from the default of this class.
"""
default_dict = self.__class__().to_dict()
this_dict = self.to_dict()
return {k: v for k, v in this_dict.items() if default_dict[k] != v}


@dataclass
class DistributedDataParallelKwargs(KwargsHandler):
"""
Use this object in your :class:`~accelerate.Accelerator` to customize how your model is wrapped in a
:obj:`torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this `wrapper
<https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__ for more information
on each argument.

.. warning::

:obj:`gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions.
"""

dim: int = 0
broadcast_buffers: bool = True
bucket_cap_mb: int = 25
find_unused_parameters: bool = False
check_reduction: bool = False
gradient_as_bucket_view: bool = False


@dataclass
class GradScalerKwargs(KwargsHandler):
"""
Use this object in your :class:`~accelerate.Accelerator` to customize the behavior of mixed precision, specifically
how the :obj:`torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this `scaler
<https://pytorch.org/docs/stable/amp.html?highlight=gradscaler>`__ for more information on each argument.

.. warning::

:obj:`GradScaler` is only available in PyTorch 1.5.0 and later versions.
"""

init_scale: float = 65536.0
growth_factor: float = 2.0
backoff_factor: float = 0.5
growth_interval: int = 2000
enabled: bool = True
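As a usage sketch (not part of the diff itself), these handlers are meant to be passed to the `Accelerator` constructor; only fields changed from their defaults are forwarded to the underlying PyTorch objects:

```python
from accelerate import Accelerator, DistributedDataParallelKwargs, GradScalerKwargs

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
scaler_kwargs = GradScalerKwargs(init_scale=2.0 ** 14)

# The DDP wrapper and the GradScaler created by the Accelerator pick up these settings.
accelerator = Accelerator(fp16=True, kwargs_handlers=[ddp_kwargs, scaler_kwargs])
```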
@ -63,6 +63,18 @@ class AcceleratedOptimizer(torch.optim.Optimizer):
def param_groups(self):
return self.optimizer.param_groups

@param_groups.setter
def param_groups(self, param_groups):
self.optimizer.param_groups = param_groups

@property
def defaults(self):
return self.optimizer.defaults

@defaults.setter
def defaults(self, defaults):
self.optimizer.defaults = defaults

def add_param_group(self, param_group):
self.optimizer.add_param_group(param_group)
@ -53,6 +53,38 @@ class DistributedType(str, Enum):
TPU = "TPU"


class SageMakerDistributedType(str, Enum):
"""
Represents a type of distributed environment.

Values:

- **NO** -- Not a distributed environment, just a single process.
- **DATA_PARALLEL** -- using sagemaker distributed data parallelism.
- **MODEL_PARALLEL** -- using sagemaker distributed model parallelism.
"""

# Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box.
NO = "NO"
DATA_PARALLEL = "DATA_PARALLEL"
MODEL_PARALLEL = "MODEL_PARALLEL"


class ComputeEnvironment(str, Enum):
"""
Represents a type of the compute environment.

Values:

- **LOCAL_MACHINE** -- private/custom cluster hardware.
- **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment.
"""

# Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box.
LOCAL_MACHINE = "LOCAL_MACHINE"
AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"
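A small sketch of why these enums subclass `str`, as the comments above note; the config dictionary below is invented:

```python
import json

from accelerate.state import ComputeEnvironment, SageMakerDistributedType

# Because the enums also subclass `str`, they serialize as plain strings without a custom encoder.
config = {
    "compute_environment": ComputeEnvironment.AMAZON_SAGEMAKER,
    "distributed_type": SageMakerDistributedType.NO,
}
print(json.dumps(config))  # {"compute_environment": "AMAZON_SAGEMAKER", "distributed_type": "NO"}
```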
# Inspired by Alex Martelli's 'Borg'.
class AcceleratorState:
"""
@ -95,6 +127,7 @@ class AcceleratorState:
self.process_index = torch.distributed.get_rank()
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
self.device = torch.device("cuda", self.local_process_index)
torch.cuda.set_device(self.device)
self.use_fp16 = parse_flag_from_env("USE_FP16", False) if fp16 is None else fp16
else:
self.distributed_type = DistributedType.NO
@ -2,5 +2,5 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from .testing import are_the_same_tensors, execute_subprocess_async, require_multi_gpu, require_tpu
from .testing import are_the_same_tensors, execute_subprocess_async, require_cuda, require_multi_gpu, require_tpu
from .training import RegressionDataset, RegressionModel
@ -22,6 +22,7 @@ from accelerate.data_loader import prepare_data_loader
|
||||
from accelerate.state import AcceleratorState, DistributedType
|
||||
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
|
||||
from accelerate.utils import gather, set_seed, synchronize_rng_states
|
||||
from packaging import version
|
||||
|
||||
|
||||
def init_state_check():
|
||||
@ -34,10 +35,16 @@ def init_state_check():
|
||||
|
||||
def rng_sync_check():
|
||||
state = AcceleratorState()
|
||||
synchronize_rng_states()
|
||||
synchronize_rng_states(["torch"])
|
||||
assert are_the_same_tensors(torch.get_rng_state())
|
||||
if state.distributed_type == DistributedType.MULTI_GPU:
|
||||
synchronize_rng_states(["cuda"])
|
||||
assert are_the_same_tensors(torch.cuda.get_rng_state())
|
||||
if version.parse(torch.__version__) >= version.parse("1.6.0"):
|
||||
generator = torch.Generator()
|
||||
synchronize_rng_states(["generator"], generator=generator)
|
||||
assert are_the_same_tensors(generator.get_state())
|
||||
|
||||
if state.local_process_index == 0:
|
||||
print("All rng are properly synched.")
|
||||
|
||||
@@ -101,13 +108,14 @@ def dl_preparation_check():
    print("Shuffled dataloader passing.")


def mock_training(length, batch_size):
def mock_training(length, batch_size, generator):
    set_seed(42)
    generator.manual_seed(42)
    train_set = RegressionDataset(length=length)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
    for epoch in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
@@ -119,21 +127,23 @@ def mock_training(length, batch_size):

def training_check():
    state = AcceleratorState()
    generator = torch.Generator()
    batch_size = 8
    length = batch_size * 4 * state.num_processes

    train_set, old_model = mock_training(length, batch_size * state.num_processes)
    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator)
    assert are_the_same_tensors(old_model.a)
    assert are_the_same_tensors(old_model.b)

    accelerator = Accelerator()
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    for _ in range(3):
    generator.manual_seed(42)
    for epoch in range(3):
        for batch in train_dl:
            model.zero_grad()
            output = model(batch["x"])
@@ -145,15 +155,16 @@ def training_check():
    assert torch.allclose(old_model.a, model.a)
    assert torch.allclose(old_model.b, model.b)

    accelerator.print("Training yielded the same results on one CPU or distributes setup with no batch split.")
    accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

    accelerator = Accelerator(split_batches=True)
    train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True)
    train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
@@ -170,12 +181,13 @@ def training_check():

    # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
    accelerator = Accelerator(fp16=True)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
    set_seed(42)
    generator.manual_seed(42)
    for _ in range(3):
        for batch in train_dl:
            model.zero_grad()
@@ -188,6 +200,7 @@ def training_check():
    assert torch.allclose(old_model.a, model.a)
    assert torch.allclose(old_model.b, model.b)


def main():
    accelerator = Accelerator()
    state = accelerator.state
@@ -207,5 +220,6 @@ def main():
    print("\n**Training integration test**")
    training_check()


if __name__ == "__main__":
    main()
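The thread running through `mock_training` and `training_check` is a single `torch.Generator` handed to every shuffled `DataLoader` and reseeded before each run, so the shuffling order no longer depends on the global RNG. A small standalone sketch of that pattern (the dataset and seed are illustrative, and it assumes a PyTorch version that accepts the `generator` argument; the test gates generator handling on 1.6+):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(32).float())
generator = torch.Generator()


def epoch_order(seed):
    # Reseeding the shared generator fully determines the shuffle order.
    generator.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, generator=generator)
    return [batch[0].tolist() for batch in loader]


# Two runs with the same seed visit the samples in the same order.
assert epoch_order(42) == epoch_order(42)
```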
@@ -33,10 +33,19 @@ def are_the_same_tensors(tensor):
    return True


def require_cuda(test_case):
    """
    Decorator marking a test that requires CUDA. These tests are skipped when there are no GPU available.
    """
    if not torch.cuda.is_available():
        return unittest.skip("test requires a GPU")(test_case)
    else:
        return test_case


def require_tpu(test_case):
    """
    Decorator marking a test that requires TPUs. These tests are skipped when there are no TPUs available.

    """
    if not is_tpu_available():
        return unittest.skip("test requires TPU")(test_case)
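Both decorators follow the usual `unittest.skip` pattern, so they can be applied to individual test methods. A quick, hypothetical usage example (the test class below is not part of the repository):

```python
import unittest

import torch

from accelerate.test_utils import require_cuda


class ExampleTester(unittest.TestCase):
    @require_cuda
    def test_runs_on_gpu(self):
        # Skipped automatically on machines without a CUDA device.
        tensor = torch.ones(2, 2, device="cuda")
        self.assertEqual(tensor.sum().item(), 4.0)
```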
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import random
from enum import Enum
from typing import List, Optional, Union

import numpy as np
import torch
@@ -24,6 +27,21 @@ if is_tpu_available():
    import torch_xla.core.xla_model as xm


def is_boto3_available():
    return importlib.util.find_spec("boto3") is not None


def is_sagemaker_available():
    return importlib.util.find_spec("sagemaker") is not None


class RNGType(Enum):
    TORCH = "torch"
    CUDA = "cuda"
    XLA = "xla"
    GENERATOR = "generator"
def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch``.
@@ -36,27 +54,46 @@ def set_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available
    if is_tpu_available():
        xm.set_rng_state(seed)


def synchronize_rng_states():
    """
    Helper function to synchronize the rng states in distributed training.
    """
def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
    # Get the proper rng state
    if rng_type == RNGType.TORCH:
        rng_state = torch.get_rng_state()
    elif rng_type == RNGType.CUDA:
        rng_state = torch.cuda.get_rng_state()
    elif rng_type == RNGType.XLA:
        assert is_tpu_available(), "Can't synchronize XLA seeds on an environment without TPUs."
        rng_state = torch.tensor(xm.get_rng_state())
    elif rng_type == RNGType.GENERATOR:
        assert generator is not None, "Need a generator to synchronize its seed."
        rng_state = generator.get_state()

    # Broadcast the rng state from device 0 to other devices
    state = AcceleratorState()
    if state.distributed_type == DistributedType.TPU:
        rng_state = torch.get_rng_state()
        rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
        torch.set_rng_state(rng_state)
    elif state.distributed_type == DistributedType.MULTI_GPU:
        rng_state = torch.get_rng_state().to(state.device)
        # Broadcast the state from process 0 to all the others.
        rng_state = rng_state.to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        torch.set_rng_state(rng_state.cpu())
        rng_state = rng_state.cpu()

        # Broadcast the state from process 0 to all the others.
        rng_state = torch.cuda.get_rng_state().to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        torch.cuda.set_rng_state(rng_state.cpu())
    # Set the broadcast rng state
    if rng_type == RNGType.TORCH:
        torch.set_rng_state(rng_state)
    elif rng_type == RNGType.CUDA:
        torch.cuda.set_rng_state(rng_state)
    elif rng_type == RNGType.XLA:
        xm.set_rng_state(rng_state.item())
    elif rng_type == RNGType.GENERATOR:
        generator.set_state(rng_state)


def synchronize_rng_states(rng_types: List[Union[str, RNGType]], generator: Optional[torch.Generator] = None):
    for rng_type in rng_types:
        synchronize_rng_state(RNGType(rng_type), generator=generator)
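With the new signature, callers choose exactly which RNG states to realign and can pass either strings or `RNGType` members. A hedged sketch of the intended per-epoch call (the surrounding loop is illustrative, not from the library):

```python
import torch

from accelerate.utils import synchronize_rng_states

generator = torch.Generator()  # the generator driving DataLoader shuffling

for epoch in range(3):
    # Realign the global torch RNG and the dedicated generator across processes
    # before iterating, so every process shuffles identically this epoch.
    synchronize_rng_states(["torch", "generator"], generator=generator)
    # ... run the epoch over the prepared dataloader ...
```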
def send_to_device(tensor, device):
@@ -77,10 +114,7 @@ def send_to_device(tensor, device):
    elif isinstance(tensor, dict):
        return type(tensor)({k: send_to_device(v, device) for k, v in tensor.items()})
    elif not hasattr(tensor, "to"):
        raise TypeError(
            f"Can't send the values of type {type(tensor)} to device {device}, only of nested list/tuple/dicts "
            "of tensors or objects having a `to` method."
        )
        return tensor
    return tensor.to(device)
@@ -123,7 +157,7 @@ def _gpu_gather(tensor):

def gather(tensor):
    """
    Recusrively gather tensor in a nested list/tuple/dictionary of tensors from all devices.
    Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.

    Args:
        tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
@@ -140,6 +174,47 @@ def gather(tensor):
    return tensor
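Because `gather` recurses into lists, tuples and dicts, an evaluation loop can collect a whole batch dictionary in one call. A small sketch (on a single process `gather` is effectively a no-op; under multi-GPU each value is concatenated across processes):

```python
import torch

from accelerate.utils import gather

# A dict of per-process tensors, e.g. model outputs for this process's shard.
outputs = {"logits": torch.randn(8, 2), "labels": torch.zeros(8, dtype=torch.long)}

gathered = gather(outputs)
# The nested structure is preserved: same keys, with tensors concatenated
# along dim 0 when several processes are involved.
assert set(gathered.keys()) == {"logits", "labels"}
```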
def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
    """
    Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they
    can safely be gathered.

    Args:
        tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
            The data to gather.
        dim (:obj:`int`, `optional`, defaults to 0):
            The dimension on which to pad.
        pad_index (:obj:`int`, `optional`, defaults to 0):
            The value with which to pad.
        pad_first (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to pad at the beginning or the end.
    """
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(pad_across_processes(t, dim=dim, pad_index=pad_index) for t in tensor)
    elif isinstance(tensor, dict):
        return type(tensor)({k: pad_across_processes(v, dim=dim, pad_index=pad_index) for k, v in tensor.items()})
    elif not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors.")

    # Gather all sizes
    size = torch.tensor(tensor.shape, device=tensor.device)[None]
    sizes = gather(size).cpu()
    # Then pad to the maximum size
    max_size = max(s[dim] for s in sizes)
    old_size = tensor.shape
    new_size = list(old_size)
    new_size[dim] = max_size
    new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
    if pad_first:
        indices = tuple(
            slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size))
        )
    else:
        indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
    new_tensor[indices] = tensor
    return new_tensor
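`pad_across_processes` is the companion to `gather` for ragged outputs: each process pads its tensor to the largest size seen on any process, after which gathering is safe. A hedged usage sketch mirroring the multi-GPU test further down (it assumes the `Accelerator.pad_across_processes` wrapper exercised there and the `Accelerator.gather` convenience method):

```python
import torch

from accelerate import Accelerator

accelerator = Accelerator()

# Each process may have a different number of rows (e.g. generated sequences).
num_rows = accelerator.state.process_index + 2
predictions = torch.randint(0, 10, (num_rows, 10), device=accelerator.device)

# Pad to the longest size across processes, then gather safely.
predictions = accelerator.pad_across_processes(predictions, dim=0, pad_index=0)
predictions = accelerator.gather(predictions)
```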

def wait_for_everyone():
    """
    Introduces a blocking point in the script, making sure all processes have reached this point before continuing.
tests/test_kwargs_handlers.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import sys
import unittest
from dataclasses import dataclass

import torch

from accelerate import Accelerator, DistributedDataParallelKwargs, GradScalerKwargs
from accelerate.kwargs_handlers import KwargsHandler
from accelerate.test_utils import execute_subprocess_async, require_cuda, require_multi_gpu


@dataclass
class MockClass(KwargsHandler):
    a: int = 0
    b: bool = False
    c: float = 3.0


class DataLoaderTester(unittest.TestCase):
    def test_kwargs_handler(self):
        # If no defaults are changed, `to_kwargs` returns an empty dict.
        self.assertDictEqual(MockClass().to_kwargs(), {})
        self.assertDictEqual(MockClass(a=2).to_kwargs(), {"a": 2})
        self.assertDictEqual(MockClass(a=2, b=True).to_kwargs(), {"a": 2, "b": True})
        self.assertDictEqual(MockClass(a=2, c=2.25).to_kwargs(), {"a": 2, "c": 2.25})

    @require_cuda
    def test_grad_scaler_kwargs(self):
        # If no defaults are changed, `to_kwargs` returns an empty dict.
        scaler_handler = GradScalerKwargs(init_scale=1024, growth_factor=2)
        accelerator = Accelerator(fp16=True, kwargs_handlers=[scaler_handler])
        print(accelerator.use_fp16)
        scaler = accelerator.scaler

        # Check the kwargs have been applied
        self.assertEqual(scaler._init_scale, 1024.0)
        self.assertEqual(scaler._growth_factor, 2.0)

        # Check the other values are at the default
        self.assertEqual(scaler._backoff_factor, 0.5)
        self.assertEqual(scaler._growth_interval, 2000)
        self.assertEqual(scaler._enabled, True)

    @require_multi_gpu
    def test_ddp_kwargs(self):
        distributed_args = f"""
        -m torch.distributed.launch
        --nproc_per_node={torch.cuda.device_count()}
        --use_env
        {inspect.getfile(self.__class__)}
        """.split()
        cmd = [sys.executable] + distributed_args
        execute_subprocess_async(cmd, env=os.environ.copy())


if __name__ == "__main__":
    ddp_scaler = DistributedDataParallelKwargs(bucket_cap_mb=15, find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_scaler])
    model = torch.nn.Linear(100, 200)
    model = accelerator.prepare(model)

    # Check the values changed in kwargs
    error_msg = ""
    observed_bucket_cap_map = model.bucket_bytes_cap // (1024 * 1024)
    if observed_bucket_cap_map != 15:
        error_msg += f"Kwargs badly passed, should have `15` but found {observed_bucket_cap_map}.\n"
    if model.find_unused_parameters is not True:
        error_msg += f"Kwargs badly passed, should have `True` but found {model.find_unused_parameters}.\n"

    # Check the values of the defaults
    if model.dim != 0:
        error_msg += f"Default value not respected, should have `0` but found {model.dim}.\n"
    if model.broadcast_buffers is not True:
        error_msg += f"Default value not respected, should have `True` but found {model.broadcast_buffers}.\n"
    if model.gradient_as_bucket_view is not False:
        error_msg += f"Default value not respected, should have `False` but found {model.gradient_as_bucket_view}.\n"

    # Raise error at the end to make sure we don't stop at the first failure.
    if len(error_msg) > 0:
        raise ValueError(error_msg)
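Taken together, the new test file documents the kwargs-handler mechanism: a dataclass whose non-default fields (`to_kwargs`) are forwarded to the underlying torch object built by `prepare`. A condensed, hedged sketch of how a user script would rely on it:

```python
from accelerate import Accelerator, DistributedDataParallelKwargs, GradScalerKwargs

# Only fields changed from their defaults are forwarded to DDP / GradScaler.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
scaler_kwargs = GradScalerKwargs(init_scale=1024)

accelerator = Accelerator(fp16=True, kwargs_handlers=[ddp_kwargs, scaler_kwargs])
# model = accelerator.prepare(model)  # the DDP wrapper gets find_unused_parameters=True
```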
@@ -20,6 +20,7 @@ import unittest
import torch

import accelerate
from accelerate import Accelerator
from accelerate.test_utils import execute_subprocess_async, require_multi_gpu


@@ -39,3 +40,43 @@ class MultiGPUTester(unittest.TestCase):
        """.split()
        cmd = [sys.executable] + distributed_args
        execute_subprocess_async(cmd, env=os.environ.copy())

    @require_multi_gpu
    def test_pad_across_processes(self):
        distributed_args = f"""
        -m torch.distributed.launch
        --nproc_per_node={torch.cuda.device_count()}
        --use_env
        {inspect.getfile(self.__class__)}
        """.split()
        cmd = [sys.executable] + distributed_args
        execute_subprocess_async(cmd, env=os.environ.copy())


if __name__ == "__main__":
    accelerator = Accelerator()
    shape = (accelerator.state.process_index + 2, 10)
    tensor = torch.randint(0, 10, shape).to(accelerator.device)

    error_msg = ""

    tensor1 = accelerator.pad_across_processes(tensor)
    if tensor1.shape[0] != accelerator.state.num_processes + 1:
        error_msg += f"Found shape {tensor1.shape} but should have {accelerator.state.num_processes + 1} at dim 0."
    if not torch.equal(tensor1[: accelerator.state.process_index + 2], tensor):
        error_msg += "Tensors have different values."
    if not torch.all(tensor1[accelerator.state.process_index + 2 :] == 0):
        error_msg += "Padding was not done with the right value (0)."

    tensor2 = accelerator.pad_across_processes(tensor, pad_first=True)
    if tensor2.shape[0] != accelerator.state.num_processes + 1:
        error_msg += f"Found shape {tensor2.shape} but should have {accelerator.state.num_processes + 1} at dim 0."
    index = accelerator.state.num_processes - accelerator.state.process_index - 1
    if not torch.equal(tensor2[index:], tensor):
        error_msg += "Tensors have different values."
    if not torch.all(tensor2[:index] == 0):
        error_msg += "Padding was not done with the right value (0)."

    # Raise error at the end to make sure we don't stop at the first failure.
    if len(error_msg) > 0:
        raise ValueError(error_msg)
tests/test_sagemaker.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import unittest
from dataclasses import dataclass

import pytest
from accelerate.commands.config.config_args import SageMakerConfig
from accelerate.commands.launch import _convert_nargs_to_dict
from accelerate.state import ComputeEnvironment


@dataclass
class MockLaunchConfig(SageMakerConfig):
    compute_environment = ComputeEnvironment.AMAZON_SAGEMAKER
    fp16 = True
    ec2_instance_type = "ml.p3.2xlarge"
    iam_role_name = "accelerate_sagemaker_execution_role"
    profile = "hf-sm"
    region = "us-east-1"
    num_machines = 1
    base_job_name = "accelerate-sagemaker-1"
    pytorch_version = "1.6"
    transformers_version = "4.4"
    training_script = "train.py"
    success_training_script_args = [
        "--model_name_or_path",
        "bert",
        "--do_train",
        "False",
        "--epochs",
        "3",
        "--learning_rate",
        "5e-5",
        "--max_steps",
        "50.5",
    ]
    fail_training_script_args = [
        "--model_name_or_path",
        "bert",
        "--do_train",
        "--do_test",
        "False",
        "--do_predict",
        "--epochs",
        "3",
        "--learning_rate",
        "5e-5",
        "--max_steps",
        "50.5",
    ]


class SageMakerLaunch(unittest.TestCase):
    def test_args_convert(self):
        # If no defaults are changed, `to_kwargs` returns an empty dict.
        converted_args = _convert_nargs_to_dict(MockLaunchConfig.success_training_script_args)
        assert isinstance(converted_args["model_name_or_path"], str)
        assert isinstance(converted_args["do_train"], bool)
        assert isinstance(converted_args["epochs"], int)
        assert isinstance(converted_args["learning_rate"], float)
        assert isinstance(converted_args["max_steps"], float)

        with pytest.raises(ValueError):
            _convert_nargs_to_dict(MockLaunchConfig.fail_training_script_args)
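The assertions pin down what the launcher expects from `_convert_nargs_to_dict`: the raw, nargs-style script arguments come back as a dict with Python types inferred from the string values (`"False"` becomes a bool, `"3"` an int, `"5e-5"` a float), while the malformed mix of value-less flags in `fail_training_script_args` raises a `ValueError`. A short usage sketch grounded in those assertions:

```python
from accelerate.commands.launch import _convert_nargs_to_dict

converted = _convert_nargs_to_dict(
    ["--model_name_or_path", "bert", "--do_train", "False", "--epochs", "3", "--learning_rate", "5e-5"]
)
# Per the test above, the values come back typed: str for model_name_or_path,
# bool for do_train, int for epochs and float for learning_rate.
```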