mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
scripts: Check .is_cuda only in non-C++ files (#7561)
The check-torchcuda.py today will search for all occurrences of .is_cuda in the repository when a commit only modifies C++ headers and sources, which I believe is not intended. Check usage of .is_cuda only when a commit modifies any non-C++ file. Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
This commit is contained in:
@ -57,23 +57,24 @@ elif res.returncode == 2:
|
|||||||
|
|
||||||
files = []
|
files = []
|
||||||
for file in sys.argv[1:]:
|
for file in sys.argv[1:]:
|
||||||
if not file.endswith(".cpp"):
|
if file.endswith(".py"):
|
||||||
files.append(file)
|
files.append(file)
|
||||||
|
|
||||||
res = subprocess.run(
|
if len(files) > 0:
|
||||||
["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files],
|
res = subprocess.run(
|
||||||
capture_output=True,
|
["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files],
|
||||||
)
|
capture_output=True,
|
||||||
if res.returncode == 0:
|
)
|
||||||
err('''
|
if res.returncode == 0:
|
||||||
|
err('''
|
||||||
Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor.
|
Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor.
|
||||||
Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)",
|
Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)",
|
||||||
and add the following import line:
|
and add the following import line:
|
||||||
'from deepspeed.accelerator import get_accelerator'
|
'from deepspeed.accelerator import get_accelerator'
|
||||||
''')
|
''')
|
||||||
err(res.stdout.decode("utf-8"))
|
err(res.stdout.decode("utf-8"))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
elif res.returncode == 2:
|
elif res.returncode == 2:
|
||||||
err(f"Error invoking grep on {', '.join(files)}:")
|
err(f"Error invoking grep on {', '.join(files)}:")
|
||||||
err(res.stderr.decode("utf-8"))
|
err(res.stderr.decode("utf-8"))
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
|
Reference in New Issue
Block a user