mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 23:53:48 +08:00
Avoid security issues of subprocess shell (#6498)
Avoid security issues of `shell=True` in subprocess --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
@ -10,7 +10,10 @@ import sys
|
|||||||
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
|
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
|
||||||
if not all(map(lambda v: v in os.environ, required_env)):
|
if not all(map(lambda v: v in os.environ, required_env)):
|
||||||
import subprocess
|
import subprocess
|
||||||
subprocess.run("deepspeed $(which ds_bench) " + " ".join(sys.argv[1:]), shell=True)
|
r = subprocess.check_output(["which", "ds_bench"])
|
||||||
|
ds_bench_bin = r.decode('utf-8').strip()
|
||||||
|
safe_cmd = ["deepspeed", ds_bench_bin] + sys.argv[1:]
|
||||||
|
subprocess.run(safe_cmd)
|
||||||
else:
|
else:
|
||||||
args = benchmark_parser().parse_args()
|
args = benchmark_parser().parse_args()
|
||||||
rank = args.local_rank
|
rank = args.local_rank
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
Functionality of swapping tensors to/from (NVMe) storage devices.
|
Functionality of swapping tensors to/from (NVMe) storage devices.
|
||||||
"""
|
"""
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
|
||||||
class Job(object):
|
class Job(object):
|
||||||
@ -39,10 +40,10 @@ class Job(object):
|
|||||||
|
|
||||||
|
|
||||||
def run_job(job):
|
def run_job(job):
|
||||||
args = ' '.join(job.cmd())
|
args = shlex.split(' '.join(job.cmd()))
|
||||||
print(f'args = {args}')
|
print(f'args = {args}')
|
||||||
job.open_output_file()
|
job.open_output_file()
|
||||||
proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
|
proc = subprocess.run(args=args, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
|
||||||
job.close_output_file()
|
job.close_output_file()
|
||||||
assert proc.returncode == 0, \
|
assert proc.returncode == 0, \
|
||||||
f"This command failed: {job.cmd()}"
|
f"This command failed: {job.cmd()}"
|
||||||
|
@ -697,8 +697,9 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True)
|
|||||||
|
|
||||||
master_addr = None
|
master_addr = None
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
hostname_cmd = ["hostname -I"]
|
import shlex
|
||||||
result = subprocess.check_output(hostname_cmd, shell=True)
|
hostname_cmd = shlex.split("hostname -I")
|
||||||
|
result = subprocess.check_output(hostname_cmd)
|
||||||
master_addr = result.decode('utf-8').split()[0]
|
master_addr = result.decode('utf-8').split()[0]
|
||||||
master_addr = comm.bcast(master_addr, root=0)
|
master_addr = comm.bcast(master_addr, root=0)
|
||||||
|
|
||||||
|
@ -54,7 +54,9 @@ class DSElasticAgent(LocalElasticAgent):
|
|||||||
|
|
||||||
if master_addr is None:
|
if master_addr is None:
|
||||||
# master_addr = _get_fq_hostname()
|
# master_addr = _get_fq_hostname()
|
||||||
result = subprocess.check_output("hostname -I", shell=True)
|
import shlex
|
||||||
|
safe_cmd = shlex.split("hostname -I")
|
||||||
|
result = subprocess.check_output(safe_cmd)
|
||||||
master_addr = result.decode('utf-8').split()[0]
|
master_addr = result.decode('utf-8').split()[0]
|
||||||
|
|
||||||
store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
|
store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
|
||||||
|
@ -406,7 +406,7 @@ class MVAPICHRunner(MultiNodeRunner):
|
|||||||
if not mpiname_exists:
|
if not mpiname_exists:
|
||||||
warnings.warn("mpiname does not exist, mvapich is not installed properly")
|
warnings.warn("mpiname does not exist, mvapich is not installed properly")
|
||||||
else:
|
else:
|
||||||
results = subprocess.check_output('mpiname', shell=True)
|
results = subprocess.check_output(['mpiname'])
|
||||||
mpiname_results = results.decode('utf-8').strip()
|
mpiname_results = results.decode('utf-8').strip()
|
||||||
if "MVAPICH2-GDR" in mpiname_results:
|
if "MVAPICH2-GDR" in mpiname_results:
|
||||||
exists = True
|
exists = True
|
||||||
|
@ -20,6 +20,7 @@ import collections
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import signal
|
import signal
|
||||||
import time
|
import time
|
||||||
|
import shlex
|
||||||
|
|
||||||
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner
|
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner
|
||||||
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER
|
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER
|
||||||
@ -445,7 +446,8 @@ def main(args=None):
|
|||||||
if args.ssh_port is not None:
|
if args.ssh_port is not None:
|
||||||
ssh_check_cmd += f"-p {args.ssh_port} "
|
ssh_check_cmd += f"-p {args.ssh_port} "
|
||||||
ssh_check_cmd += f"{first_host} hostname"
|
ssh_check_cmd += f"{first_host} hostname"
|
||||||
subprocess.check_call(ssh_check_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL, shell=True)
|
safe_ssh_cmd = shlex.split(ssh_check_cmd)
|
||||||
|
subprocess.check_call(safe_ssh_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh."
|
f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh."
|
||||||
@ -458,9 +460,9 @@ def main(args=None):
|
|||||||
if args.ssh_port is not None:
|
if args.ssh_port is not None:
|
||||||
ssh_check_cmd += f" -p {args.ssh_port}"
|
ssh_check_cmd += f" -p {args.ssh_port}"
|
||||||
ssh_check_cmd += f" {first_host} hostname -I"
|
ssh_check_cmd += f" {first_host} hostname -I"
|
||||||
hostname_cmd = [ssh_check_cmd]
|
hostname_cmd = shlex.split(ssh_check_cmd)
|
||||||
try:
|
try:
|
||||||
result = subprocess.check_output(hostname_cmd, shell=True)
|
result = subprocess.check_output(hostname_cmd)
|
||||||
except subprocess.CalledProcessError as err:
|
except subprocess.CalledProcessError as err:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
|
"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
|
||||||
|
@ -253,7 +253,8 @@ class OpBuilder(ABC):
|
|||||||
rocm_info = Path("rocminfo")
|
rocm_info = Path("rocminfo")
|
||||||
rocm_gpu_arch_cmd = str(rocm_info) + " | grep -o -m 1 'gfx.*'"
|
rocm_gpu_arch_cmd = str(rocm_info) + " | grep -o -m 1 'gfx.*'"
|
||||||
try:
|
try:
|
||||||
result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True)
|
safe_cmd = shlex.split(rocm_gpu_arch_cmd)
|
||||||
|
result = subprocess.check_output(safe_cmd)
|
||||||
rocm_gpu_arch = result.decode('utf-8').strip()
|
rocm_gpu_arch = result.decode('utf-8').strip()
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
rocm_gpu_arch = ""
|
rocm_gpu_arch = ""
|
||||||
@ -271,7 +272,8 @@ class OpBuilder(ABC):
|
|||||||
rocm_wavefront_size_cmd = str(
|
rocm_wavefront_size_cmd = str(
|
||||||
rocm_info) + " | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'"
|
rocm_info) + " | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'"
|
||||||
try:
|
try:
|
||||||
result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True)
|
safe_cmd = shlex.split(rocm_wavefront_size_cmd)
|
||||||
|
result = subprocess.check_output(rocm_wavefront_size_cmd)
|
||||||
rocm_wavefront_size = result.decode('utf-8').strip()
|
rocm_wavefront_size = result.decode('utf-8').strip()
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
rocm_wavefront_size = "32"
|
rocm_wavefront_size = "32"
|
||||||
@ -432,7 +434,7 @@ class OpBuilder(ABC):
|
|||||||
"to detect the CPU architecture. 'lscpu' does not appear to exist on "
|
"to detect the CPU architecture. 'lscpu' does not appear to exist on "
|
||||||
"your system, will fall back to use -march=native and non-vectorized execution.")
|
"your system, will fall back to use -march=native and non-vectorized execution.")
|
||||||
return None
|
return None
|
||||||
result = subprocess.check_output('lscpu', shell=True)
|
result = subprocess.check_output(['lscpu'])
|
||||||
result = result.decode('utf-8').strip().lower()
|
result = result.decode('utf-8').strip().lower()
|
||||||
|
|
||||||
cpu_info = {}
|
cpu_info = {}
|
||||||
|
14
setup.py
14
setup.py
@ -27,6 +27,7 @@ from setuptools import setup, find_packages
|
|||||||
from setuptools.command import egg_info
|
from setuptools.command import egg_info
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
import shlex
|
||||||
|
|
||||||
torch_available = True
|
torch_available = True
|
||||||
try:
|
try:
|
||||||
@ -157,10 +158,11 @@ if BUILD_OP_DEFAULT:
|
|||||||
|
|
||||||
def command_exists(cmd):
    """Return True when ``cmd`` resolves to an executable on this machine.

    The lookup is done in-process with ``shutil.which`` instead of spawning
    ``bash -c "type <cmd>"``. Tokenising that shell string with
    ``shlex.split`` produces ``["bash", "-c", "type", cmd]``; bash then runs
    ``type`` with no operand (``cmd`` only becomes ``$0``) and exits 0, so
    every command would be reported as present. ``shutil.which`` needs no
    shell, works the same way on Windows and POSIX, and never executes the
    probed command.

    Args:
        cmd: Command name (searched on ``PATH``) or a path to an executable.

    Returns:
        bool: True if an executable was found, False otherwise (including
        for an empty or ``None`` command).
    """
    # Local import keeps this block self-contained; the module-level import
    # list of the surrounding file is not fully visible here.
    import shutil

    if not cmd:
        return False
    return shutil.which(str(cmd)) is not None
|
||||||
|
|
||||||
@ -200,13 +202,13 @@ for op_name, builder in ALL_OPS.items():
|
|||||||
print(f'Install Ops={install_ops}')
|
print(f'Install Ops={install_ops}')
|
||||||
|
|
||||||
# Write out version/git info.
|
# Write out version/git info.
|
||||||
git_hash_cmd = "git rev-parse --short HEAD"
|
git_hash_cmd = shlex.split("bash -c git rev-parse --short HEAD")
|
||||||
git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
|
git_branch_cmd = shlex.split("bash -c git rev-parse --abbrev-ref HEAD")
|
||||||
if command_exists('git') and not is_env_set('DS_BUILD_STRING'):
|
if command_exists('git') and not is_env_set('DS_BUILD_STRING'):
|
||||||
try:
|
try:
|
||||||
result = subprocess.check_output(git_hash_cmd, shell=True)
|
result = subprocess.check_output(git_hash_cmd)
|
||||||
git_hash = result.decode('utf-8').strip()
|
git_hash = result.decode('utf-8').strip()
|
||||||
result = subprocess.check_output(git_branch_cmd, shell=True)
|
result = subprocess.check_output(git_branch_cmd)
|
||||||
git_branch = result.decode('utf-8').strip()
|
git_branch = result.decode('utf-8').strip()
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
git_hash = "unknown"
|
git_hash = "unknown"
|
||||||
|
@ -7,6 +7,7 @@ import unittest
|
|||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
|
||||||
class BaseTestCase(unittest.TestCase):
|
class BaseTestCase(unittest.TestCase):
|
||||||
@ -40,9 +41,9 @@ class BaseTestCase(unittest.TestCase):
|
|||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
||||||
def clean_test_env(self):
|
def clean_test_env(self):
|
||||||
cmd = "dlts_ssh pkill -9 -f /usr/bin/python"
|
cmd = shlex.split("dlts_ssh pkill -9 -f /usr/bin/python")
|
||||||
print(cmd)
|
print(cmd)
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
|
subprocess.run(cmd, check=False, executable='/bin/bash')
|
||||||
time.sleep(20)
|
time.sleep(20)
|
||||||
|
|
||||||
def run_BingBertSquad_test(self, test_config, output):
|
def run_BingBertSquad_test(self, test_config, output):
|
||||||
@ -50,8 +51,8 @@ class BaseTestCase(unittest.TestCase):
|
|||||||
other_args = " " + test_config["other_args"] if "other_args" in test_config else " "
|
other_args = " " + test_config["other_args"] if "other_args" in test_config else " "
|
||||||
|
|
||||||
cmd = "./run_BingBertSquad_sanity.sh -e 1 -g {0} {1} {2}".format(test_config["gpus"], other_args, ds_flag)
|
cmd = "./run_BingBertSquad_sanity.sh -e 1 -g {0} {1} {2}".format(test_config["gpus"], other_args, ds_flag)
|
||||||
|
cmd = shlex.split(cmd)
|
||||||
self.ensure_directory_exists(output)
|
self.ensure_directory_exists(output)
|
||||||
with open(output, "w") as f:
|
with open(output, "w") as f:
|
||||||
print(cmd)
|
print(cmd)
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash', stdout=f, stderr=f)
|
subprocess.run(cmd, check=False, executable='/bin/bash', stdout=f, stderr=f)
|
||||||
|
@ -10,6 +10,7 @@ import unittest
|
|||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shlex
|
||||||
from .test_common import BaseTestCase
|
from .test_common import BaseTestCase
|
||||||
|
|
||||||
LAYERS = 2
|
LAYERS = 2
|
||||||
@ -18,9 +19,9 @@ ATTN_HEADS = 8
|
|||||||
|
|
||||||
|
|
||||||
def remove_file(test_id, filename):
    """Delete ``filename`` if it is an existing regular file (best effort).

    This mirrors the shell statement ``if [ -f FILE ]; then rm -v FILE; fi``
    without involving a shell. Splitting that statement with ``shlex.split``
    and executing the tokens directly cannot work, because ``if``/``then``/
    ``fi`` are shell syntax rather than programs, so the removal is done with
    ``os`` calls instead.

    Args:
        test_id: Identifier of the calling test, used as a log prefix.
        filename: Path of the file to remove. Missing paths and
            non-regular files (e.g. directories) are left untouched.

    Returns:
        None. Failures to delete are reported on stdout and not raised,
        matching the ``check=False`` behaviour of the earlier command.
    """
    if not os.path.isfile(filename):
        # Same outcome as the failed ``[ -f ... ]`` test: nothing to do.
        return

    print(f"{test_id} removing: {filename}")
    try:
        os.remove(filename)
    except OSError as err:
        # Best effort only: a stale log must not abort the test run.
        print(f"{test_id} could not remove {filename}: {err}")
    else:
        print(f"removed '{filename}'")
|
||||||
|
|
||||||
|
|
||||||
def grep_loss_from_file(file_name):
|
def grep_loss_from_file(file_name):
|
||||||
@ -451,9 +452,9 @@ class GPT2CheckpointTestCase(BaseTestCase):
|
|||||||
checkpoint_name = test_config["checkpoint_name"]
|
checkpoint_name = test_config["checkpoint_name"]
|
||||||
#---------------remove old checkpoint---------------#
|
#---------------remove old checkpoint---------------#
|
||||||
try:
|
try:
|
||||||
cmd = f"rm -rf {checkpoint_name}"
|
cmd = shlex.split(f"rm -rf {checkpoint_name}")
|
||||||
print(f"{self.id()} cmd: {cmd}")
|
print(f"{self.id()} cmd: {cmd}")
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
|
subprocess.run(cmd, check=False, executable='/bin/bash')
|
||||||
except:
|
except:
|
||||||
print("No old checkpoint")
|
print("No old checkpoint")
|
||||||
|
|
||||||
@ -474,8 +475,8 @@ class GPT2CheckpointTestCase(BaseTestCase):
|
|||||||
|
|
||||||
# remove previous test log
|
# remove previous test log
|
||||||
try:
|
try:
|
||||||
cmd = f"rm {base_file}"
|
cmd = shlex.split(f"rm {base_file}")
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
|
subprocess.run(cmd, check=False, executable='/bin/bash')
|
||||||
except:
|
except:
|
||||||
print(f"{self.id()} No old logs")
|
print(f"{self.id()} No old logs")
|
||||||
|
|
||||||
@ -489,9 +490,9 @@ class GPT2CheckpointTestCase(BaseTestCase):
|
|||||||
|
|
||||||
# set checkpoint load iteration
|
# set checkpoint load iteration
|
||||||
try:
|
try:
|
||||||
cmd = f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt"
|
cmd = shlex.split(f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt")
|
||||||
print(f"{self.id()} running cmd: {cmd}")
|
print(f"{self.id()} running cmd: {cmd}")
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
|
subprocess.run(cmd, check=False, executable='/bin/bash')
|
||||||
except:
|
except:
|
||||||
print(f"{self.id()} Failed to update the checkpoint iteration file")
|
print(f"{self.id()} Failed to update the checkpoint iteration file")
|
||||||
return False
|
return False
|
||||||
@ -506,8 +507,8 @@ class GPT2CheckpointTestCase(BaseTestCase):
|
|||||||
|
|
||||||
# remove previous test log
|
# remove previous test log
|
||||||
try:
|
try:
|
||||||
cmd = f"rm {test_file}"
|
cmd = shlex.split(f"rm {test_file}")
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
|
subprocess.run(cmd, check=False, executable='/bin/bash')
|
||||||
except:
|
except:
|
||||||
print(f"{self.id()} no previous logs for")
|
print(f"{self.id()} no previous logs for")
|
||||||
self.run_gpt2_test(test_config, test_file)
|
self.run_gpt2_test(test_config, test_file)
|
||||||
|
@ -7,6 +7,7 @@ import unittest
|
|||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
|
||||||
class BaseTestCase(unittest.TestCase):
|
class BaseTestCase(unittest.TestCase):
|
||||||
@ -46,9 +47,9 @@ class BaseTestCase(unittest.TestCase):
|
|||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
||||||
def clean_test_env(self):
|
def clean_test_env(self):
|
||||||
cmd = "dlts_ssh pkill -9 -f /usr/bin/python"
|
cmd = shlex.split("dlts_ssh pkill -9 -f /usr/bin/python")
|
||||||
print(cmd)
|
print(cmd)
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
|
subprocess.run(cmd, check=False, executable='/bin/bash')
|
||||||
time.sleep(20)
|
time.sleep(20)
|
||||||
|
|
||||||
def run_gpt2_test(self, test_config, output):
|
def run_gpt2_test(self, test_config, output):
|
||||||
@ -60,8 +61,8 @@ class BaseTestCase(unittest.TestCase):
|
|||||||
test_config["mp"], test_config["gpus"], test_config["nodes"], test_config["bs"], test_config["steps"],
|
test_config["mp"], test_config["gpus"], test_config["nodes"], test_config["bs"], test_config["steps"],
|
||||||
test_config["layers"], test_config["hidden_size"], test_config["seq_length"], test_config["heads"],
|
test_config["layers"], test_config["hidden_size"], test_config["seq_length"], test_config["heads"],
|
||||||
ckpt_num, other_args, ds_flag)
|
ckpt_num, other_args, ds_flag)
|
||||||
|
cmd = shlex.split(cmd)
|
||||||
self.ensure_directory_exists(output)
|
self.ensure_directory_exists(output)
|
||||||
with open(output, "w") as f:
|
with open(output, "w") as f:
|
||||||
print(cmd)
|
print(cmd)
|
||||||
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash', stdout=f, stderr=f)
|
subprocess.run(cmd, check=False, executable='/bin/bash', stdout=f, stderr=f)
|
||||||
|
@ -58,6 +58,20 @@ def get_master_port(base_port=29500, port_range_size=1000):
|
|||||||
raise IOError('no free ports')
|
raise IOError('no free ports')
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cpu_socket_count(cpuinfo_path="/proc/cpuinfo"):
    """Count the distinct ``physical id`` entries listed in a cpuinfo file.

    Equivalent to ``cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l``
    but done in-process: no child processes are spawned, so nothing is left
    unreaped and the helper does not depend on ``cat``/``grep``/``sort``/``wc``
    being installed.

    Args:
        cpuinfo_path: File to inspect. Defaults to ``/proc/cpuinfo``.

    Returns:
        int: Number of unique lines containing ``physical id``. Returns 0
        when the file cannot be read or has no such lines, which is what the
        shell pipeline printed in those situations.
    """
    try:
        with open(cpuinfo_path, encoding="utf-8", errors="replace") as cpuinfo:
            # A set reproduces ``sort -u``: identical lines collapse to one.
            physical_ids = {line.strip() for line in cpuinfo if "physical id" in line}
    except OSError:
        return 0
    return len(physical_ids)
|
||||||
|
|
||||||
|
|
||||||
def set_accelerator_visible():
|
def set_accelerator_visible():
|
||||||
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||||
xdist_worker_id = get_xdist_worker_id()
|
xdist_worker_id = get_xdist_worker_id()
|
||||||
@ -95,9 +109,7 @@ def set_accelerator_visible():
|
|||||||
num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
|
num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
|
||||||
else:
|
else:
|
||||||
assert get_accelerator().device_name() == 'cpu'
|
assert get_accelerator().device_name() == 'cpu'
|
||||||
cpu_sockets = int(
|
num_accelerators = _get_cpu_socket_count()
|
||||||
subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
|
|
||||||
num_accelerators = cpu_sockets
|
|
||||||
|
|
||||||
if isinstance(num_accelerators, list):
|
if isinstance(num_accelerators, list):
|
||||||
cuda_visible = ",".join(num_accelerators)
|
cuda_visible = ",".join(num_accelerators)
|
||||||
|
Reference in New Issue
Block a user