scripts: Check .is_cuda only in non-C++ files (#7561)

The check-torchcuda.py today will search for all occurrences of .is_cuda
in the repository when a commit only modifies C++ headers and sources,
which I believe is not intended.

Check usage of .is_cuda only when a commit modifies any non-C++ file.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
This commit is contained in:
Junjie Mao
2025-09-19 13:01:50 +08:00
committed by GitHub
parent 2585881ae9
commit 6b731c5c96

View File

@ -57,23 +57,24 @@ elif res.returncode == 2:
files = [] files = []
for file in sys.argv[1:]: for file in sys.argv[1:]:
if not file.endswith(".cpp"): if file.endswith(".py"):
files.append(file) files.append(file)
res = subprocess.run( if len(files) > 0:
["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files], res = subprocess.run(
capture_output=True, ["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files],
) capture_output=True,
if res.returncode == 0: )
err(''' if res.returncode == 0:
err('''
Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor. Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor.
Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)", Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)",
and add the following import line: and add the following import line:
'from deepspeed.accelerator import get_accelerator' 'from deepspeed.accelerator import get_accelerator'
''') ''')
err(res.stdout.decode("utf-8")) err(res.stdout.decode("utf-8"))
sys.exit(1) sys.exit(1)
elif res.returncode == 2: elif res.returncode == 2:
err(f"Error invoking grep on {', '.join(files)}:") err(f"Error invoking grep on {', '.join(files)}:")
err(res.stderr.decode("utf-8")) err(res.stderr.decode("utf-8"))
sys.exit(2) sys.exit(2)