DeepSpeed/deepspeed/module_inject/containers/features/meta_tensor.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod
from packaging import version as pkg_version
import torch


class MetaTensorContainer(ABC):
    """
    NOTE: If you are using this feature with a container that
    also inherits from `HybridEngineContainer`, ensure that `MetaTensorContainer`
    is inherited before `HybridEngineContainer` in the class definition.
    """

    def __init__(self, **kwargs):
        # Meta tensor support requires torch >= 1.10.
        if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
            raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
        super().__init__(**kwargs)
        self.is_meta = False
        self.ckpt_load_enabled = True

    def initialize_tensors(self, enable_training=False):
        super().initialize_tensors(enable_training=enable_training)
        # If the QKV weight came back on the meta device, treat the whole
        # container as meta; real data arrives later via load_params().
        self.is_meta = self.qkvw.is_meta
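
    # Standalone sketch (illustrative, not part of the upstream file): what
    # `is_meta` detects. A tensor on the "meta" device carries only shape and
    # dtype, no storage:
    #
    #   w = torch.empty(8, 8, device='meta')
    #   assert w.is_meta  # True: no data yet
    #   # Reading w's values would fail -- they must be loaded later, which
    #   # is exactly what load_params() does for meta-initialized containers.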

    def apply_tensor_parallelism(self, mp_replace, **kwargs):
        if self.is_meta:
            # No data to slice yet; just mirror the absence of the optional
            # bias terms on the injected module.
            if self.qkvb is None:
                self.module.attention.attn_qkvb = None
            if self.dense_b is None:
                self.module.attention.attn_ob = None
        else:
            super().apply_tensor_parallelism(mp_replace, **kwargs)
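
    # Standalone sketch (illustrative only): the slicing that `mp_replace`
    # applies to real (non-meta) weights amounts to giving each
    # model-parallel rank one shard of a weight, e.g.:
    #
    #   w = torch.randn(12, 12)
    #   rank, world_size = 0, 4
    #   shard = torch.chunk(w, world_size, dim=0)[rank]  # this rank's slice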

    def copy_data_to_new_module(self):
        if self.is_meta:
            # attn_nw is None in this branch, so this records the absence of
            # the attention LayerNorm params on the new module; real values
            # are filled in later by load_params().
            if self.attn_nw is None:
                self.module.mlp.attn_nw = self.attn_nw
                self.module.mlp.attn_nb = self.attn_nb
        else:
            super().copy_data_to_new_module()

    def transpose(self):
        # Meta tensors hold no data, so there is nothing to transpose.
        if not self.is_meta:
            super().transpose()

    @abstractmethod
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        """
        Load all the transformer parameters from the checkpoint file (sd).
        In addition to the parameter names, two more flags are needed to
        read the data correctly from the checkpoint and split the qkv heads
        in the right order:
            1. `use_load_prefix` (Default: False): specifies whether the name
               of the model's first abstraction layer is needed when searching
               for a parameter's name in a checkpoint file. For more
               information on how this is used, please see
               https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py
            2. `split_qkv` (Default: True): used when splitting the qkv
               parameter into heads. If it is False, the heads of q, k, and v
               are stored together and need to be split in the
               DeepSpeed-Inference API.
        """
        raise NotImplementedError("A load_params() function must be defined in the model container "
                                  "when inheriting the MetaTensorContainer feature")
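

# --- Illustrative sketch (not part of the upstream file) -------------------
# A minimal outline of what a concrete container's load_params() tends to
# look like. `DS_ExampleContainer`, the base class name, and the checkpoint
# key below are hypothetical; real containers in this package map each entry
# of the checkpoint dict `sd` onto the injected module's attributes,
# quantizing and tensor-slicing along the way.
#
#   class DS_ExampleContainer(MetaTensorContainer, BaseTransformerContainer):
#
#       def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
#           # 1) read each named parameter out of the checkpoint dict `sd`,
#           #    using `prefix` (and `use_load_prefix`) to build the full key
#           qkv_w = sd[prefix + 'attn.qkv.weight']
#           # 2) optionally quantize, then shard across ranks via mp_replace
#           module.attention.attn_qkvw = mp_replace.qkv_copy(
#               module.attention.attn_qkvw, weight_quantizer.quantize(qkv_w))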