pre-commit check for torch.cuda in code (#2981)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Author: Ma, Guokai
Date: 2023-03-24 11:29:54 +08:00
Committed by: GitHub
Commit: 090d49e79f (parent e80ae08886)
10 changed files with 95 additions and 20 deletions


@@ -68,3 +68,12 @@ repos:
hooks:
- id: flake8
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
- repo: local
hooks:
- id: check-torchcuda
name: check-torchcuda
entry: ./scripts/check-torchcuda.py
language: script
exclude: ^(.github/workflows/|scripts/check-torchcuda.py|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
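For context, a minimal sketch of the rewrite this hook is meant to drive. It is not part of the commit; the old call is kept only in a comment, shown with the #ignore-cuda marker it would need to survive the check, and the replacement uses DeepSpeed's get_accelerator() API.

```python
import torch
from deepspeed.accelerator import get_accelerator

# Old, CUDA-only form; flagged by check-torchcuda unless it carries the marker shown:
#     device_index = torch.cuda.current_device()  #ignore-cuda
# Accelerator-agnostic replacement:
device_index = get_accelerator().current_device()
x = torch.ones(4, device=get_accelerator().device_name(device_index))
```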


@@ -309,8 +309,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops pro
import torchvision.models as models
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator
with torch.cuda.device(0):
with get_accelerator().device(0):
model = models.alexnet()
batch_size = 256
flops, macs, params = get_model_profile(model=model, # model
@@ -334,6 +335,7 @@ from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator
def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -350,7 +352,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
return inputs
with torch.cuda.device(0):
with get_accelerator().device(0):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
batch_size = 4
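The hunks above are truncated by the diff context window. Below is a minimal, self-contained sketch of the updated AlexNet example; the remaining get_model_profile arguments (input_shape, print_profile, as_string) are assumed from the tutorial rather than shown in this diff.

```python
import torchvision.models as models
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator

with get_accelerator().device(0):
    model = models.alexnet()
    batch_size = 256
    flops, macs, params = get_model_profile(
        model=model,                            # model to profile
        input_shape=(batch_size, 3, 224, 224),  # shape of one input batch
        print_profile=True,                     # print the per-module profile
        as_string=True,                         # return human-readable strings
    )
```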


@@ -92,7 +92,7 @@ def _set_cuda_rng_state(new_state, device=-1):
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
@@ -499,7 +499,7 @@ def get_cpu_activations_for_backward(args, inputs):
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
2) the states in the model parallel tracker are also properly
tracked/set/reset.
3) Performance activation partitioning, contiguous memory optimization
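The #ignore-cuda markers added above exempt these intentionally CUDA-specific docstring references from the new check. A small hypothetical illustration of the same escape hatch on executable code (names chosen for the example, not taken from this file):

```python
import torch

# Genuinely CUDA-specific code keeps torch.cuda and carries the marker on the
# same line, since check-torchcuda greps line by line for "#ignore-cuda".
if torch.cuda.is_available():  #ignore-cuda
    saved_state = torch.cuda.get_rng_state()  #ignore-cuda
    torch.cuda.set_rng_state(saved_state)  #ignore-cuda
```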


@@ -31,15 +31,10 @@ def print_rank_0(message, debug=False, force=False):
print(message)
device = get_accelerator().device_name()
if device == 'cuda':
try:
autocast_custom_fwd = torch.cuda.amp.custom_fwd
autocast_custom_bwd = torch.cuda.amp.custom_bwd
except (ImportError, AttributeError) as exp:
autocast_custom_fwd = noop_decorator
autocast_custom_bwd = noop_decorator
else:
try:
autocast_custom_fwd = get_accelerator().amp().custom_fwd
autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError) as exp:
autocast_custom_fwd = noop_decorator
autocast_custom_bwd = noop_decorator
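A short sketch of how the accelerator-provided decorators selected above are typically consumed. The LinearFunction class is hypothetical; it only illustrates that custom_fwd/custom_bwd wrap autograd Function methods the same way torch.cuda.amp's versions do.

```python
import torch
from deepspeed.accelerator import get_accelerator

def noop_decorator(func):
    return func

try:
    autocast_custom_fwd = get_accelerator().amp().custom_fwd
    autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError):
    # Accelerators without AMP support fall back to no-op decorators.
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator

class LinearFunction(torch.autograd.Function):
    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x.matmul(weight.t())

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        return grad_output.matmul(weight), grad_output.t().matmul(x)
```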


@@ -849,7 +849,7 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
| Description | Default |
| ------------------------------------------------------------- | ------- |
| Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` |
| Inserts get_accelerator().synchronize() at each checkpoint boundary. | `false` |
<i>**profile**</i>: [boolean]
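For reference, a hypothetical ds_config fragment locating the fields touched in this doc section. The key name synchronize_checkpoint_boundary is assumed from DeepSpeed's activation-checkpointing config; only its description row appears in the hunk above.

```python
# Assumed configuration keys; only the description text is visible in the diff.
ds_config = {
    "activation_checkpointing": {
        "synchronize_checkpoint_boundary": False,  # when True, get_accelerator().synchronize() runs at each boundary
        "profile": False,
    }
}
```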


@@ -316,8 +316,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops pro
import torchvision.models as models
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator
with torch.cuda.device(0):
with get_accelerator().device(0):
model = models.alexnet()
batch_size = 256
flops, macs, params = get_model_profile(model=model, # model
@@ -341,6 +342,7 @@ from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator
def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -357,7 +359,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
return inputs
with torch.cuda.device(0):
with get_accelerator().device(0):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
batch_size = 4
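The BERT example in this file is truncated the same way. Below is a self-contained sketch of how the input constructor feeds get_model_profile after the accelerator change; the constructor body and the kwargs plumbing are assumed from the tutorial, not shown in this diff.

```python
from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator

def bert_input_constructor(batch_size, seq_len, tokenizer):
    # Build a dummy batch of fixed length for profiling.
    fake_seq = tokenizer.pad_token * (seq_len - 2)  # leave room for [CLS] and [SEP]
    inputs = tokenizer([fake_seq] * batch_size,
                       padding=True,
                       truncation=True,
                       return_tensors="pt")
    inputs["labels"] = torch.tensor([1] * batch_size)
    return inputs

with get_accelerator().device(0):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    batch_size = 4
    seq_len = 128
    flops, macs, params = get_model_profile(
        model,
        kwargs=bert_input_constructor(batch_size, seq_len, tokenizer),
        print_profile=True,
        as_string=True,
    )
```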


@@ -275,7 +275,7 @@ DeepSpeed's `save_checkpoint()`.
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['cuda_rng_state'] = get_accelerator().get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
model.save_checkpoint(args.save, iteration, client_state = sd)
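A hypothetical load-side counterpart to the save snippet above, wrapped in a helper so it is self-contained. It assumes the accelerator API exposes set_rng_state() symmetrically to the get_rng_state() call used here, and that model and mpu follow the tutorial's Megatron setup.

```python
import random
import numpy as np
import torch
from deepspeed.accelerator import get_accelerator

def load_rng_state(model, load_dir, mpu):
    """Restore the RNG state saved as client_state by save_checkpoint() above (sketch)."""
    _, client_state = model.load_checkpoint(load_dir)
    random.setstate(client_state['random_rng_state'])
    np.random.set_state(client_state['np_rng_state'])
    torch.set_rng_state(client_state['torch_rng_state'])
    get_accelerator().set_rng_state(client_state['cuda_rng_state'])
    mpu.get_cuda_rng_tracker().set_states(client_state['rng_tracker_states'])
    return client_state
```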

scripts/check-torchcuda.py (new executable file, 66 lines)

@@ -0,0 +1,66 @@
#!/usr/bin/env python3
from __future__ import annotations
'''Copyright The Microsoft DeepSpeed Team'''
"""
Checks each file in sys.argv for the string "torch.cuda".
Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py
"""

import subprocess
import sys


def err(s: str) -> None:
    print(s, file=sys.stderr)


# There are many ways we could search for the string "torch.cuda", but `git
# grep --no-index` is nice because
#  - it's very fast (as compared to iterating over the file in Python)
#  - we can reasonably assume it's available on all machines
#  - unlike plain grep, which is slower and has different flags on MacOS versus
#    Linux, git grep is always the same.
res = subprocess.run(
    [
        "git",
        "grep",
        "-Hn",
        "--no-index",
        "-e",
        r"torch\.cuda",
        "--and",
        "--not",
        "-e",
        "#ignore-cuda",
        *sys.argv[1:],
    ],
    capture_output=True,
)
if res.returncode == 0:
    err('Error: The string "torch.cuda" was found.\n'
        'Please replace all calls to torch.cuda with "get_accelerator()" and add the following import line:\n\n'
        '    from deepspeed.accelerator import get_accelerator\n\n'
        'If your code is meant to be CUDA-specific, please add the following comment on the line that uses torch.cuda:\n\n'
        '    #ignore-cuda\n')
    err(res.stdout.decode("utf-8"))
    sys.exit(1)
elif res.returncode == 2:
    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
    err(res.stderr.decode("utf-8"))
    sys.exit(2)

res = subprocess.run(
    ["git", "grep", "-Hn", "--no-index", r"\.cuda()", *sys.argv[1:]],
    capture_output=True,
)
if res.returncode == 0:
    err('Error: The string ".cuda()" was found. This implies converting a tensor to a CUDA tensor.\n'
        'Please replace all calls to tensor.cuda() with "tensor.to(get_accelerator().device_name())" and add the following import line:\n\n'
        '    from deepspeed.accelerator import get_accelerator')
    err(res.stdout.decode("utf-8"))
    sys.exit(1)
elif res.returncode == 2:
    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
    err(res.stderr.decode("utf-8"))
    sys.exit(2)
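To make the second check concrete, a hypothetical before/after for tensor placement; the old call is kept only in a comment for contrast.

```python
import torch
from deepspeed.accelerator import get_accelerator

x = torch.ones(2, 3)
# Old form, flagged by the second grep:
#     x = x.cuda()
# Accelerator-agnostic replacement:
x = x.to(get_accelerator().device_name())
```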


@@ -124,9 +124,10 @@ class DistributedExec(ABC):
return fixture_kwargs
def _launch_procs(self, num_procs):
if torch.cuda.is_available() and torch.cuda.device_count() < num_procs:
if get_accelerator().is_available(
) and get_accelerator().device_count() < num_procs:
pytest.skip(
f"Skipping test because not enough GPUs are available: {num_procs} required, {torch.cuda.device_count()} available"
f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
)
mp.set_start_method('forkserver', force=True)
skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason
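A hypothetical pytest marker built from the same two accelerator queries used above, for tests that prefer to skip declaratively rather than inside _launch_procs; the marker name is invented for this example.

```python
import pytest
from deepspeed.accelerator import get_accelerator

# Skip when fewer than two accelerator devices are present (assumed helper, not in this commit).
requires_two_devices = pytest.mark.skipif(
    not get_accelerator().is_available() or get_accelerator().device_count() < 2,
    reason="at least 2 accelerator devices are required",
)

@requires_two_devices
def test_needs_two_devices():
    assert get_accelerator().device_count() >= 2
```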


@@ -8,7 +8,7 @@ from deepspeed.git_version_info import torch_info
def skip_on_arch(min_arch=7):
if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
if torch.cuda.get_device_capability()[0] < min_arch:
if torch.cuda.get_device_capability()[0] < min_arch: #ignore-cuda
pytest.skip(f"needs higher compute capability than {min_arch}")
else:
assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
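A hypothetical caller of skip_on_arch, sketched as a unit test that gates a bf16 path on compute capability; the import path for the helper is assumed and may differ from the repository layout.

```python
import torch
import deepspeed
from unit.util import skip_on_arch  # import path assumed for this sketch

def test_bf16_matmul():
    # Skip on CUDA devices older than compute capability 8.0 (pre-Ampere),
    # where bfloat16 kernels are typically unavailable.
    skip_on_arch(min_arch=8)
    device = deepspeed.accelerator.get_accelerator().device_name()
    x = torch.ones(4, 4, dtype=torch.bfloat16, device=device)
    assert (x @ x).dtype == torch.bfloat16
```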