pre-commit check for torch.cuda in code (#2981)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
@@ -68,3 +68,12 @@ repos:
   hooks:
   - id: flake8
     args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
+
+- repo: local
+  hooks:
+  - id: check-torchcuda
+    name: check-torchcuda
+    entry: ./scripts/check-torchcuda.py
+    language: script
+    exclude: ^(.github/workflows/|scripts/check-torchcuda.py|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
+    # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
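For context, a minimal sketch of the substitution this hook enforces. It is illustrative only, not part of the diff, and uses only accelerator calls that appear elsewhere in this PR:

from deepspeed.accelerator import get_accelerator

# Flagged by the hook: a bare torch.cuda call with no #ignore-cuda marker, e.g.
#     torch.cuda.synchronize()
# Accepted: the same operation routed through the accelerator abstraction,
# which resolves to CUDA, XPU, or another backend at runtime.
get_accelerator().synchronize()

# Device queries follow the same pattern.
if get_accelerator().is_available():
    print(get_accelerator().device_count(), "devices of type", get_accelerator().device_name())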
@@ -309,8 +309,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops pro
 import torchvision.models as models
 import torch
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator

-with torch.cuda.device(0):
+with get_accelerator().device(0):
     model = models.alexnet()
     batch_size = 256
     flops, macs, params = get_model_profile(model=model, # model
@@ -334,6 +335,7 @@ from functools import partial
 import torch
 from transformers import BertForSequenceClassification, BertTokenizer
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator


 def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -350,7 +352,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
     return inputs


-with torch.cuda.device(0):
+with get_accelerator().device(0):
     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
     model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
     batch_size = 4
@@ -92,7 +92,7 @@ def _set_cuda_rng_state(new_state, device=-1):

     Arguments:
         new_state (torch.ByteTensor): The desired state
-    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
+    This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
     with a single change: the input state is not cloned. Cloning caused
     major performance issues for +4 GPU cases.
     """
@@ -499,7 +499,7 @@ def get_cpu_activations_for_backward(args, inputs):
 class CheckpointFunction(torch.autograd.Function):
     """This function is adapted from torch.utils.checkpoint with
     two main changes:
-    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
+    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
     2) the states in the model parallel tracker are also properly
        tracked/set/reset.
     3) Performance activation partitioning, contiguous memory optimization
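The `#ignore-cuda` marker added in the two docstrings above is the escape hatch recognized by the new hook. A hypothetical snippet (not from this diff) showing how an intentionally CUDA-specific line is annotated so the check passes:

import torch

# Deliberately CUDA-specific: keep the torch.cuda call, but tag the line so that
# scripts/check-torchcuda.py's "torch\.cuda --and --not -e #ignore-cuda" grep skips it.
if torch.cuda.is_available():  #ignore-cuda
    state = torch.cuda.get_rng_state()  #ignore-cuda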
@@ -31,15 +31,10 @@ def print_rank_0(message, debug=False, force=False):
         print(message)


-device = get_accelerator().device_name()
-if device == 'cuda':
-    try:
-        autocast_custom_fwd = torch.cuda.amp.custom_fwd
-        autocast_custom_bwd = torch.cuda.amp.custom_bwd
-    except (ImportError, AttributeError) as exp:
-        autocast_custom_fwd = noop_decorator
-        autocast_custom_bwd = noop_decorator
-else:
+try:
+    autocast_custom_fwd = get_accelerator().amp().custom_fwd
+    autocast_custom_bwd = get_accelerator().amp().custom_bwd
+except (ImportError, AttributeError) as exp:
     autocast_custom_fwd = noop_decorator
     autocast_custom_bwd = noop_decorator

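The hunk above collapses the per-device branching into a single accelerator-based resolution of the autocast decorators, with a no-op fallback. A hedged sketch of how such decorators are typically applied to a custom autograd function follows; the class and its math are illustrative only, not taken from this PR:

import torch
from deepspeed.accelerator import get_accelerator


def noop_decorator(func):
    return func


# Resolve the decorators through the accelerator, as in the hunk above.
try:
    autocast_custom_fwd = get_accelerator().amp().custom_fwd
    autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError):
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator


class ToyLinearFunction(torch.autograd.Function):
    """Illustrative only: a minimal linear op wrapped with the resolved decorators."""

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input.matmul(weight.t())

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Gradients for a 2D input and weight: d/d(input) and d/d(weight).
        return grad_output.matmul(weight), grad_output.t().matmul(input)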
@@ -849,7 +849,7 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s

 | Description                                                                 | Default |
 | --------------------------------------------------------------------------- | ------- |
-| Inserts torch.cuda.synchronize() at each checkpoint boundary.               | `false` |
+| Inserts get_accelerator().synchronize() at each checkpoint boundary.        | `false` |


 <i>**profile**</i>: [boolean]
@@ -316,8 +316,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops pro
 import torchvision.models as models
 import torch
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator

-with torch.cuda.device(0):
+with get_accelerator().device(0):
     model = models.alexnet()
     batch_size = 256
     flops, macs, params = get_model_profile(model=model, # model
@@ -341,6 +342,7 @@ from functools import partial
 import torch
 from transformers import BertForSequenceClassification, BertTokenizer
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator


 def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -357,7 +359,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
     return inputs


-with torch.cuda.device(0):
+with get_accelerator().device(0):
     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
     model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
     batch_size = 4
@@ -275,7 +275,7 @@ DeepSpeed's `save_checkpoint()`.
     sd['random_rng_state'] = random.getstate()
     sd['np_rng_state'] = np.random.get_state()
     sd['torch_rng_state'] = torch.get_rng_state()
-    sd['cuda_rng_state'] = torch.cuda.get_rng_state()
+    sd['cuda_rng_state'] = get_accelerator().get_rng_state()
     sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

     model.save_checkpoint(args.save, iteration, client_state = sd)
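The tutorial hunk above only touches the save path. Below is a hedged sketch of the matching restore after checkpoint load; it assumes the accelerator object also exposes a `set_rng_state()` counterpart (not shown in this diff) and reuses the `model`, `args`, and `mpu` objects from the surrounding tutorial:

import random
import numpy as np
import torch
from deepspeed.accelerator import get_accelerator

# Assumed mirror of the save path: client_state comes back from DeepSpeed's
# load_checkpoint(), and each RNG is restored from it.
_, sd = model.load_checkpoint(args.load)
random.setstate(sd['random_rng_state'])
np.random.set_state(sd['np_rng_state'])
torch.set_rng_state(sd['torch_rng_state'])
get_accelerator().set_rng_state(sd['cuda_rng_state'])  # assumption: accelerator counterpart of get_rng_state()
mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])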
scripts/check-torchcuda.py (new executable file, 66 lines)
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+'''Copyright The Microsoft DeepSpeed Team'''
+"""
+Checks each file in sys.argv for the string "torch.cuda".
+Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py
+"""
+
+import subprocess
+import sys
+
+
+def err(s: str) -> None:
+    print(s, file=sys.stderr)
+
+
+# There are many ways we could search for the string "torch.cuda", but `git
+# grep --no-index` is nice because
+#  - it's very fast (as compared to iterating over the file in Python)
+#  - we can reasonably assume it's available on all machines
+#  - unlike plain grep, which is slower and has different flags on MacOS versus
+#    Linux, git grep is always the same.
+res = subprocess.run(
+    [
+        "git",
+        "grep",
+        "-Hn",
+        "--no-index",
+        "-e",
+        r"torch\.cuda",
+        "--and",
+        "--not",
+        "-e",
+        "#ignore-cuda",
+        *sys.argv[1:]
+    ],
+    capture_output=True,
+)
+if res.returncode == 0:
+    err('Error: The string "torch.cuda" was found.\nPlease replace all calls to torch.cuda with "get_accelerator()" and add the following import line:\n\n from deepspeed.accelerator import get_accelerator\n\nIf your code is meant to be cuda specific, please add the following comment in the line with torch.cuda:\n\n #ignore-cuda\n'
+        )
+    err(res.stdout.decode("utf-8"))
+    sys.exit(1)
+elif res.returncode == 2:
+    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
+    err(res.stderr.decode("utf-8"))
+    sys.exit(2)
+
+res = subprocess.run(
+    ["git",
+     "grep",
+     "-Hn",
+     "--no-index",
+     r"\.cuda()",
+     *sys.argv[1:]],
+    capture_output=True,
+)
+if res.returncode == 0:
+    err('Error: The string ".cuda()" was found. This implies converting a tensor to a cuda tensor. Please replace all calls to tensor.cuda() with "tensor.to(get_accelerator().device_name())" and add the following import line:\nfrom deepspeed.accelerator import get_accelerator'
+        )
+    err(res.stdout.decode("utf-8"))
+    sys.exit(1)
+elif res.returncode == 2:
+    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
+    err(res.stderr.decode("utf-8"))
+    sys.exit(2)
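To make the two greps above concrete, here is a hypothetical file (not part of the PR) annotated with how each line would fare under the check:

import torch
from deepspeed.accelerator import get_accelerator

# Passes: device interaction goes through the accelerator abstraction.
get_accelerator().synchronize()
x = torch.ones(4).to(get_accelerator().device_name())

# Passes: intentionally CUDA-specific, so the line carries the escape marker
# that the first grep's "--and --not -e #ignore-cuda" clause honors.
capability = torch.cuda.get_device_capability()[0]  #ignore-cuda

# Fails the first grep: a bare torch.cuda reference with no marker like #ignore-cuda.
torch.cuda.synchronize()

# Fails the second grep: ".cuda()" conversions are flagged unconditionally.
y = torch.ones(4).cuda()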
@@ -124,9 +124,10 @@ class DistributedExec(ABC):
         return fixture_kwargs

     def _launch_procs(self, num_procs):
-        if torch.cuda.is_available() and torch.cuda.device_count() < num_procs:
+        if get_accelerator().is_available(
+        ) and get_accelerator().device_count() < num_procs:
             pytest.skip(
-                f"Skipping test because not enough GPUs are available: {num_procs} required, {torch.cuda.device_count()} available"
+                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
             )
         mp.set_start_method('forkserver', force=True)
         skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
@@ -8,7 +8,7 @@ from deepspeed.git_version_info import torch_info

 def skip_on_arch(min_arch=7):
     if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
-        if torch.cuda.get_device_capability()[0] < min_arch:
+        if torch.cuda.get_device_capability()[0] < min_arch:  #ignore-cuda
             pytest.skip(f"needs higher compute capability than {min_arch}")
     else:
         assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'