mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 23:53:48 +08:00
* add fallback path for kernels used in megatron * temporary numactl WA for SPR 56core * adapt core allocation according to number of ranks * add switch to turn on numactl * detect number of cores on the system * allow selecting a subset of the cores on the system to bind * remove unneeded changes * add ccl backend * change nccl to ccl * remove unused code * add comm/ccl to ops * initial ccl comm support * first broadcast case passed * add CCL_Backend to DeepSpeed * support comm timer for CPU * support barrier for comm backend * support specifying master address from deepspeed command line * support pytorch 2.0 * remove 'block' from api * Tweak for debug Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com> * Remove unnecessary directory Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com> * Add bf16 kernel support for inference * Add temporary torch implementation for cpu inference * Add softmax ops cpu fallback for inference * bind cores to numa domain as well * merge latest change in gma/numactl * initial bf16 kernel support with fallback path * initial fallback path for bloom kernel injection * fix softmax attn mask * check KMP_AFFINITY to avoid conflict with numactl * New CCLBackend which utilizes TorchBackend for initialization * rollback last change because there is result error * fix bloom injection policy TP could not work issue. 
injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")} * Use TorchBackend to initialize CCLBackend, make behavior consistent * remove comm under deepspeed/ops * add license header * code clean up * fix format issue * remove magic number in main address * add caching support but do not turn it on by default * change name of inference_cuda_module to inference_module * Check for is_synchronized_device in accelerator before getting Event * fix typo * Fix fallback path of softmax kernel on CUDA device for BF16 data type, because CUDA tril does not support BF16 datatype, enforce fp32 data type * add cpu backend files * change CPU_Accelerator op_builder_dir * remove cpu_kernel_path * using CPU_Accelerator on non-cuda device * fix deepspeed.op_builder => deepspeed.ops.op_builder * add alias for num_gpus: num_accelerators * allow loading cpu_builder in build stage * Assume cuda available if torch not installed * add oneccl_binding_pt to requirements * move oneccl-binding-pt to separate requirements-cpu.txt * add missing file * use dependency_links in setuptools.setup() call for additional dependency links * install oneccl_bind_pt in workflows * change oneccl_bind_pt's version from 1.13 to 2.0 * use intel_extension_for_pytorch as indicator that CPU_Accelerator should be used * Add indicator for Accelerator used * change foo.c to foo.cpp * exclude 'cpu' directory in CUDA op builder reflection * add a cpu-inference workflow * run cpu-inference workflow on self-hosted instance * change cpu runs-on node to v100 node * print out python version in workflow * add verbose in pip command to understand oneccl_bind_pt install issue * update cpu-inference workflow * add a stage to detect instance instruction sets * add back bf16 support for CPU inference * enable autoTP for bloom Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * update workflow to detect cpu instruction sets * temporary WA for Intel Extension for PyTorch AVX2 instruction set detection * change cpu-inference 
workflow machine to ubuntu-20.04 * add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable policy for llama * use a special ipex build to test avx2 detection fix * fix format * fix test fail issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix gptj sharded checkpoint loading problem Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * return a not-implemented builder in get_op_builder in cpu_backend * support cpu device in tests * use cpuinfo to extract number of CPUs * use ~/tmp as transformer cache rather than /blob/ * Add support for mpich launcher with prefer_deepspeed_comm * add missing modification in accelerator * enable IMPI launcher * remove unused file and fix formatting * clean up ccl.cpp * Less confusing error message when certain op builders are not implemented * Fix license header * Add license header * add license headers * add license header * fix cuda specific code in test * update CPU workflow * use numactl to bind to core * allow bind_cores_to_rank in multi-node impi runner * fix format error * Remove InferenceBuilder * fix format error in numa.py * check whether op is in installed ops in ds_report.py * allow overriding the accelerator with DS_ACCELERATOR='cuda','cpu' or 'xpu' * lazy init class_dict in CUDA_Accelerator to avoid cyclic initialization of CUDA_Accelerator * put short path in the beginning in real_accelerator.py * device_count returns number of NUMA nodes * fix typo * install numactl in cpu workflow * Follow comments * Better implementation of device_count() and current_device() * remove dependency_link for Intel Extension for DeepSpeed * check is_synchronized_device in timer only once * remove env mapping WA in cpu_accelerator * fix duplicate definition * fix format error * refine ccl backend selection * move comments to the right place * remove prefer_deepspeed_comm, use CCLBackend by default * refactor fallback path * Fix execution 
failure in kernel injection path * do not refactor kernel injection fallback path in residual_add because it contains a function call with side effects * guard residual_add fallback path with environ DS_KI_FALLBACK=True * fix format error * add test for allreduce on CPU workflow * fix format error * Fallback to TorchBackend if CCLBackend kernels are not implemented * Update Intel Extension for PyTorch installation link * Don't specify version number of Intel Extension for PyTorch * install oneCCL for CCLBackend * fix link path for CPU comm kernels * fix source oneCCL environment * source oneCCL env before running UT * Give more specific instruction when CCL_ROOT is not defined --------- Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com> Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: sdp <sdp@aia-sdp-spr-108864.jf.intel.com> Co-authored-by: Cao, Zhong Z <zhong.z.cao@intel.com> Co-authored-by: Zhenhuan Chen <zhenhuan.chen@intel.com> Co-authored-by: baodii <di.bao@intel.com> Co-authored-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: jianan-gu <jianan.gu@intel.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
261 lines
7.0 KiB
Python
261 lines
7.0 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
import torch
|
|
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
|
|
import oneccl_bindings_for_pytorch # noqa: F401
|
|
import psutil
|
|
import os
|
|
|
|
|
|
# Accelerator for Intel CPU
class CPU_Accelerator(DeepSpeedAccelerator):
    """DeepSpeed accelerator that runs on the host CPU.

    Collective communication uses the 'ccl' backend. Memory statistics are
    approximated by the resident set size (RSS) of the current process as
    reported by psutil, with the peak tracked in ``self.max_mem``.
    Stream/event APIs are no-ops and ``synchronize`` does nothing; the device
    reports itself as synchronized.
    """

    def __init__(self):
        self._name = 'cpu'
        self._communication_backend_name = 'ccl'
        # Peak RSS observed so far; raised by get_rss(), reset by reset_rss().
        self.max_mem = psutil.Process().memory_info().rss

    def is_synchronized_device(self):
        """Report that this device is synchronous (see synchronize(), a no-op)."""
        return True

    # Device APIs
    def device_name(self, device_index=None):
        """Return the device type string 'cpu'; device_index is ignored."""
        return 'cpu'

    def device(self, device_index=None):
        """Return None: no per-index device object is provided for CPU."""
        return None

    def set_device(self, device_index):
        """No-op: there is no per-device selection on CPU."""
        return

    def current_device(self):
        """Return this process's local rank as the current device index.

        Read from the LOCAL_RANK environment variable (default 0). The value is
        converted to int so the return type does not depend on whether the
        variable happens to be set.
        """
        return int(os.environ.get('LOCAL_RANK', 0))

    def current_device_name(self):
        """Return the device type string 'cpu'."""
        return 'cpu'

    def device_count(self):
        """Return the number of CPU "devices" on this node, as an int.

        If LOCAL_SIZE is set to a positive integer, that value is used.
        Otherwise the count is the number of NUMA nodes that own at least one
        core.
        """
        local_size = int(os.environ.get('LOCAL_SIZE', 0))
        if local_size > 0:
            return local_size

        from deepspeed.utils.numa import get_numa_cores

        # Count NUMA nodes as CPU accelerators. On machines with HBM in flat
        # mode, HBM shows up as a separate NUMA node with no cores on it;
        # those nodes are ignored.
        return sum(1 for core_list in get_numa_cores() if len(core_list) > 0)

    def synchronize(self, device_index=None):
        """No-op: the CPU device is synchronous."""
        return

    # RNG APIs
    def random(self):
        """Return the torch.random module."""
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        """Set the state of torch's CPU default generator.

        torch.set_rng_state() accepts only the state tensor. device_index is
        kept for interface compatibility and ignored, because CPU uses the
        single torch.default_generator (see default_generator()).
        """
        return torch.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        """Return the CPU RNG state; device_index is ignored."""
        return torch.get_rng_state()

    def manual_seed(self, seed):
        """Seed torch's default RNG."""
        return torch.manual_seed(seed)

    def manual_seed_all(self, seed):
        """Seed torch's default RNG (CPU has no per-device generators)."""
        return torch.manual_seed(seed)

    def initial_seed(self, seed=None):
        """Return the initial seed of torch's default generator.

        torch.initial_seed() takes no arguments. ``seed`` is still accepted
        (now optional) for compatibility with existing callers and is ignored.
        """
        return torch.initial_seed()

    def default_generator(self, device_index):
        """Return torch.default_generator; device_index is ignored."""
        return torch.default_generator

    # Streams/Events
    @property
    def Stream(self):
        """No stream class on CPU."""
        return None

    def stream(self, stream):
        """Return a do-nothing context manager in place of a device stream.

        This mirrors torch.cuda.stream(), which is a context manager, so
        ``with accelerator.stream(s): ...`` works. CPU has no streams, so
        entering the context has no effect.
        """
        import contextlib
        return contextlib.nullcontext()

    def current_stream(self, device_index=None):
        """No streams on CPU; return None."""
        return None

    def default_stream(self, device_index=None):
        """No streams on CPU; return None."""
        return None

    @property
    def Event(self):
        """No event class on CPU."""
        return None

    # Memory management
    def empty_cache(self):
        """No-op: there is no device allocator cache to release."""
        return

    def get_rss(self):
        """Return the current process RSS (bytes) and update the tracked peak."""
        mem = psutil.Process().memory_info().rss
        if mem > self.max_mem:
            self.max_mem = mem
        return mem

    def reset_rss(self):
        """Reset the tracked peak to the current RSS and return that RSS."""
        mem = psutil.Process().memory_info().rss
        self.max_mem = mem
        return mem

    def memory_allocated(self, device_index=None):
        """Current RSS, used as a stand-in for allocated device memory."""
        return self.get_rss()

    def max_memory_allocated(self, device_index=None):
        """Peak RSS observed so far (refreshed before returning)."""
        self.get_rss()
        return self.max_mem

    def reset_max_memory_allocated(self, device_index=None):
        """Reset the tracked peak RSS to the current RSS."""
        self.reset_rss()
        return

    def memory_cached(self, device_index=None):
        """Current RSS, used as a stand-in for cached device memory."""
        return self.get_rss()

    def max_memory_cached(self, device_index=None):
        """Peak RSS observed so far (refreshed before returning)."""
        self.get_rss()
        return self.max_mem

    def reset_max_memory_cached(self, device_index=None):
        """Reset the tracked peak RSS to the current RSS."""
        self.reset_rss()
        return

    def memory_stats(self, device_index=None):
        """Return the current RSS.

        NOTE(review): this is a plain int, unlike torch.cuda.memory_stats(),
        which returns a dict; callers that index into the result will fail.
        """
        return self.get_rss()

    def reset_peak_memory_stats(self, device_index=None):
        """Reset the tracked peak RSS to the current RSS."""
        self.reset_rss()
        return

    def memory_reserved(self, device_index=None):
        """Current RSS, used as a stand-in for reserved device memory."""
        return self.get_rss()

    def max_memory_reserved(self, device_index=None):
        """Peak RSS observed so far (refreshed before returning)."""
        self.get_rss()
        return self.max_mem

    def total_memory(self, device_index=None):
        """Total physical memory of the host as reported by psutil."""
        return psutil.virtual_memory().total

    # Misc
    def amp(self):
        """Return the torch.cpu.amp module."""
        return torch.cpu.amp

    def is_available(self):
        """The CPU is always available."""
        return True

    def range_push(self, msg):
        """No-op profiler range marker."""
        # TODO: ITT ranges (torch.profiler.itt.range_push) are not supported yet.
        return

    def range_pop(self):
        """No-op profiler range marker."""
        # TODO: ITT ranges (torch.profiler.itt.range_pop) are not supported yet.
        return

    def lazy_call(self, callback):
        """Invoke ``callback`` immediately and return its result."""
        return callback()

    def communication_backend_name(self):
        """Name of the collective communication backend ('ccl')."""
        return self._communication_backend_name

    # Data types
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        return True

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.FloatTensor

    @property
    def HalfTensor(self):
        return torch.HalfTensor

    @property
    def IntTensor(self):
        return torch.IntTensor

    @property
    def LongTensor(self):
        return torch.LongTensor

    def pin_memory(self, tensor):
        """Return ``tensor`` unchanged; pinning is not applied on CPU."""
        return tensor

    def op_builder_dir(self):
        """Return the import path of the CPU op-builder package."""
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401
            return "op_builder.cpu"
        except ImportError:
            return "deepspeed.ops.op_builder.cpu"

    def on_accelerator(self, tensor):
        """Return True if ``tensor`` lives on a CPU device."""
        return str(tensor.device).startswith('cpu')

    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, op_name):
        """Instantiate the op builder named ``op_name``, or return None if unavailable."""
        builder_class = self.get_op_builder(op_name)
        if builder_class is not None:
            return builder_class()
        return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        """Return the op-builder class named ``class_name``.

        Only CCLCommBuilder is implemented for CPU. Any other name maps to
        NotImplementedBuilder rather than None.
        """
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401
            from op_builder.cpu import CCLCommBuilder, NotImplementedBuilder
        except ImportError:
            from deepspeed.ops.op_builder.cpu import CCLCommBuilder, NotImplementedBuilder

        if class_name == "CCLCommBuilder":
            return CCLCommBuilder
        # return a NotImplementedBuilder to avoid get NoneType[Name] in unit tests
        return NotImplementedBuilder

    def build_extension(self):
        """Return torch's BuildExtension class for compiling C++ ops."""
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension
|