diff --git a/scripts/check-torchcuda.py b/scripts/check-torchcuda.py index 0723c9888..639d11ad5 100755 --- a/scripts/check-torchcuda.py +++ b/scripts/check-torchcuda.py @@ -57,23 +57,24 @@ elif res.returncode == 2: files = [] for file in sys.argv[1:]: - if not file.endswith(".cpp"): + if file.endswith(".py"): files.append(file) -res = subprocess.run( - ["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files], - capture_output=True, -) -if res.returncode == 0: - err(''' +if len(files) > 0: + res = subprocess.run( + ["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files], + capture_output=True, + ) + if res.returncode == 0: + err(''' Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor. Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)", and add the following import line: 'from deepspeed.accelerator import get_accelerator' ''') - err(res.stdout.decode("utf-8")) - sys.exit(1) -elif res.returncode == 2: - err(f"Error invoking grep on {', '.join(files)}:") - err(res.stderr.decode("utf-8")) - sys.exit(2) + err(res.stdout.decode("utf-8")) + sys.exit(1) + elif res.returncode == 2: + err(f"Error invoking grep on {', '.join(files)}:") + err(res.stderr.decode("utf-8")) + sys.exit(2)