diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index db67534e1..eeefecd0e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -68,3 +68,12 @@ repos:
   hooks:
     - id: flake8
       args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
+
+- repo: local
+  hooks:
+    - id: check-torchcuda
+      name: check-torchcuda
+      entry: ./scripts/check-torchcuda.py
+      language: script
+      exclude: ^(.github/workflows/|scripts/check-torchcuda.py|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
+      # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
diff --git a/deepspeed/profiling/flops_profiler/README.md b/deepspeed/profiling/flops_profiler/README.md
index 6d749c5df..af23d56ee 100644
--- a/deepspeed/profiling/flops_profiler/README.md
+++ b/deepspeed/profiling/flops_profiler/README.md
@@ -309,8 +309,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops pro
 import torchvision.models as models
 import torch
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator

-with torch.cuda.device(0):
+with get_accelerator().device(0):
     model = models.alexnet()
     batch_size = 256
     flops, macs, params = get_model_profile(model=model, # model
@@ -334,6 +335,7 @@ from functools import partial
 import torch
 from transformers import BertForSequenceClassification, BertTokenizer
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator


 def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -350,7 +352,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
     return inputs


-with torch.cuda.device(0):
+with get_accelerator().device(0):
     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
     model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
     batch_size = 4
diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py
index 743725401..eccf9073a 100644
--- a/deepspeed/runtime/activation_checkpointing/checkpointing.py
+++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -92,7 +92,7 @@ def _set_cuda_rng_state(new_state, device=-1):
     Arguments:
         new_state (torch.ByteTensor): The desired state

-    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
+    This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
     with a single change: the input state is not cloned. Cloning caused
     major performance issues for +4 GPU cases.
     """
@@ -499,7 +499,7 @@ def get_cpu_activations_for_backward(args, inputs):
 class CheckpointFunction(torch.autograd.Function):
     """This function is adapted from torch.utils.checkpoint with
        two main changes:
-           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
+           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
           2) the states in the model parallel tracker are also properly
              tracked/set/reset.
           3) Performance activation partitioning, contiguous memory optimization
diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py
index 6fbcabb16..1ef43d85d 100644
--- a/deepspeed/runtime/zero/linear.py
+++ b/deepspeed/runtime/zero/linear.py
@@ -31,15 +31,10 @@ def print_rank_0(message, debug=False, force=False):
         print(message)


-device = get_accelerator().device_name()
-if device == 'cuda':
-    try:
-        autocast_custom_fwd = torch.cuda.amp.custom_fwd
-        autocast_custom_bwd = torch.cuda.amp.custom_bwd
-    except (ImportError, AttributeError) as exp:
-        autocast_custom_fwd = noop_decorator
-        autocast_custom_bwd = noop_decorator
-else:
+try:
+    autocast_custom_fwd = get_accelerator().amp().custom_fwd
+    autocast_custom_bwd = get_accelerator().amp().custom_bwd
+except (ImportError, AttributeError) as exp:
     autocast_custom_fwd = noop_decorator
     autocast_custom_bwd = noop_decorator
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 2d497bb1b..dde45da12 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -849,7 +849,7 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s

 | Description                                                           | Default |
 | --------------------------------------------------------------------- | ------- |
-| Inserts torch.cuda.synchronize() at each checkpoint boundary.         | `false` |
+| Inserts get_accelerator().synchronize() at each checkpoint boundary.  | `false` |

 **profile**: [boolean]
diff --git a/docs/_tutorials/flops-profiler.md b/docs/_tutorials/flops-profiler.md
index 169bfb18d..24efc2386 100644
--- a/docs/_tutorials/flops-profiler.md
+++ b/docs/_tutorials/flops-profiler.md
@@ -316,8 +316,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops pro
 import torchvision.models as models
 import torch
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator

-with torch.cuda.device(0):
+with get_accelerator().device(0):
     model = models.alexnet()
     batch_size = 256
     flops, macs, params = get_model_profile(model=model, # model
@@ -341,6 +342,7 @@ from functools import partial
 import torch
 from transformers import BertForSequenceClassification, BertTokenizer
 from deepspeed.profiling.flops_profiler import get_model_profile
+from deepspeed.accelerator import get_accelerator


 def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -357,7 +359,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
     return inputs


-with torch.cuda.device(0):
+with get_accelerator().device(0):
     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
     model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
     batch_size = 4
diff --git a/docs/_tutorials/megatron.md b/docs/_tutorials/megatron.md
index 2977f5773..0ccfd3ec0 100644
--- a/docs/_tutorials/megatron.md
+++ b/docs/_tutorials/megatron.md
@@ -275,7 +275,7 @@ DeepSpeed's `save_checkpoint()`.
     sd['random_rng_state'] = random.getstate()
     sd['np_rng_state'] = np.random.get_state()
     sd['torch_rng_state'] = torch.get_rng_state()
-    sd['cuda_rng_state'] = torch.cuda.get_rng_state()
+    sd['cuda_rng_state'] = get_accelerator().get_rng_state()
     sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

     model.save_checkpoint(args.save, iteration, client_state = sd)
diff --git a/scripts/check-torchcuda.py b/scripts/check-torchcuda.py
new file mode 100755
index 000000000..773db41c9
--- /dev/null
+++ b/scripts/check-torchcuda.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+'''Copyright The Microsoft DeepSpeed Team'''
+"""
+Checks each file in sys.argv for the string "torch.cuda".
+Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py
+"""
+
+import subprocess
+import sys
+
+
+def err(s: str) -> None:
+    print(s, file=sys.stderr)
+
+
+# There are many ways we could search for the string "torch.cuda", but `git
+# grep --no-index` is nice because
+#  - it's very fast (as compared to iterating over the file in Python)
+#  - we can reasonably assume it's available on all machines
+#  - unlike plain grep, which is slower and has different flags on MacOS versus
+#    Linux, git grep is always the same.
+res = subprocess.run(
+    [
+        "git",
+        "grep",
+        "-Hn",
+        "--no-index",
+        "-e",
+        r"torch\.cuda",
+        "--and",
+        "--not",
+        "-e",
+        "#ignore-cuda",
+        *sys.argv[1:]
+    ],
+    capture_output=True,
+)
+if res.returncode == 0:
+    err('Error: The string "torch.cuda" was found.\nPlease replace all calls to torch.cuda with "get_accelerator()" and add the following import line:\n\n from deepspeed.accelerator import get_accelerator\n\nIf your code is meant to be cuda specific, please add the following comment on the line with torch.cuda:\n\n #ignore-cuda\n'
+        )
+    err(res.stdout.decode("utf-8"))
+    sys.exit(1)
+elif res.returncode == 2:
+    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
+    err(res.stderr.decode("utf-8"))
+    sys.exit(2)
+
+res = subprocess.run(
+    ["git",
+     "grep",
+     "-Hn",
+     "--no-index",
+     r"\.cuda()",
+     *sys.argv[1:]],
+    capture_output=True,
+)
+if res.returncode == 0:
+    err('Error: The string ".cuda()" was found. This implies converting a tensor to a cuda tensor. Please replace all calls to tensor.cuda() with "tensor.to(get_accelerator().device_name())" and add the following import line:\nfrom deepspeed.accelerator import get_accelerator'
+        )
+    err(res.stdout.decode("utf-8"))
+    sys.exit(1)
+elif res.returncode == 2:
+    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
+    err(res.stderr.decode("utf-8"))
+    sys.exit(2)
diff --git a/tests/unit/common.py b/tests/unit/common.py
index 35e8f3983..acc778a88 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -124,9 +124,10 @@ class DistributedExec(ABC):
         return fixture_kwargs

     def _launch_procs(self, num_procs):
-        if torch.cuda.is_available() and torch.cuda.device_count() < num_procs:
+        if get_accelerator().is_available(
+        ) and get_accelerator().device_count() < num_procs:
             pytest.skip(
-                f"Skipping test because not enough GPUs are available: {num_procs} required, {torch.cuda.device_count()} available"
+                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
             )
         mp.set_start_method('forkserver', force=True)
         skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
diff --git a/tests/unit/util.py b/tests/unit/util.py
index fa74e92d8..c206b39e2 100644
--- a/tests/unit/util.py
+++ b/tests/unit/util.py
@@ -8,7 +8,7 @@ from deepspeed.git_version_info import torch_info

 def skip_on_arch(min_arch=7):
     if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
-        if torch.cuda.get_device_capability()[0] < min_arch:
+        if torch.cuda.get_device_capability()[0] < min_arch: #ignore-cuda
            pytest.skip(f"needs higher compute capability than {min_arch}")
     else:
         assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
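Note for reviewers: as an illustrative sketch (not part of the patch), the accelerator-agnostic pattern the new check-torchcuda hook pushes contributors toward looks roughly like the snippet below. It only combines get_accelerator() calls already exercised in this diff; the tensor and device index are arbitrary placeholders.

```python
import torch
from deepspeed.accelerator import get_accelerator

# Before (would be flagged by scripts/check-torchcuda.py):
#   with torch.cuda.device(0):
#       x = torch.randn(4, 4).cuda()
#       torch.cuda.synchronize()

# After (device-agnostic):
with get_accelerator().device(0):
    # tensor.cuda() becomes tensor.to(get_accelerator().device_name())
    x = torch.randn(4, 4).to(get_accelerator().device_name())
    get_accelerator().synchronize()
```

Contributors can run the hook locally with `pre-commit run check-torchcuda --all-files`; lines that must stay CUDA-specific can carry a `#ignore-cuda` comment to be skipped, as done in checkpointing.py and tests/unit/util.py above.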