Files
DeepSpeed/csrc/utils/tensor_cast.cpp
Olatunji Ruwase 24a1d8f936 DeepNVMe update (#7215)
- FastPersist
- ZeRO-Inference+SGLang

---------

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: jerryyangli <jerryyangli@gmail.com>
Co-authored-by: Yang Li <yangli2@microsoft.com>
Co-authored-by: Guanhua Wang <alexwgh333@gmail.com>
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
Co-authored-by: Bing Xie <67908712+xiexbing@users.noreply.github.com>
Co-authored-by: cassieesvelt <73311224+cassieesvelt@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: swli <47371259+lucasleesw@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Co-authored-by: Ubuntu <jomayeri@microsoft.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
2025-06-06 18:49:41 -04:00

27 lines
797 B
C++

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "tensor_cast.h"
/**
 * View the raw bytes of a tensor as a 1-D uint8 tensor (no copy).
 *
 * The result aliases src_tensor's memory. It has src_tensor.nbytes() elements,
 * dtype kUInt8, and the same layout and device as src_tensor. Writes through
 * either tensor are visible through the other.
 *
 * @param src_tensor Tensor whose bytes are to be viewed. It must be
 *        non-overlapping and dense, so that its elements occupy exactly the byte
 *        range [data_ptr(), data_ptr() + nbytes()).
 * @return The byte view described above. If src_tensor holds at most one byte,
 *         src_tensor itself is returned, with its dtype and shape unchanged.
 * @throws c10::Error if src_tensor has gaps or overlapping elements.
 */
at::Tensor cast_to_byte_tensor(at::Tensor& src_tensor)
{
    // NOTE(review): 0- and 1-byte tensors are returned as-is, so the result is
    // not guaranteed to be a 1-D uint8 tensor in that case. Confirm that callers
    // rely on this before changing it.
    if (src_tensor.nbytes() <= 1) { return src_tensor; }

    // from_blob treats data_ptr() as a flat span of nbytes() bytes. That span
    // only matches the tensor's contents when its elements tile memory exactly.
    // Sliced views leave gaps, and expanded (stride-0) views would make the span
    // run past the end of the underlying allocation.
    TORCH_CHECK(src_tensor.is_non_overlapping_and_dense(),
                "cast_to_byte_tensor: expected a non-overlapping and dense tensor");

    auto options = torch::TensorOptions()
                       .dtype(torch::kUInt8)
                       .layout(src_tensor.layout())
                       .device(src_tensor.device());

    // Use int64_t (ATen's size type), not long. On LLP64 platforms long is
    // 32-bit and would truncate sizes above 2 GiB.
    const int64_t num_bytes = static_cast<int64_t>(src_tensor.nbytes());

    // The deleter holds a reference to src_tensor, so the source storage stays
    // alive for as long as the byte view does. Without it, the view could dangle.
    return at::from_blob(
        src_tensor.data_ptr(), {num_bytes}, [owner = src_tensor](void*) {}, options);
}
/**
 * Apply the single-tensor cast_to_byte_tensor overload to every tensor in a list.
 *
 * @param tensor_list Tensors to view as bytes. Each element is passed to the
 *        single-tensor overload in order.
 * @return One result per input, in the same order as tensor_list.
 */
std::vector<at::Tensor> cast_to_byte_tensor(std::vector<at::Tensor>& tensor_list)
{
    std::vector<at::Tensor> byte_tensors;
    // The output size is known up front: allocate once instead of growing per element.
    byte_tensors.reserve(tensor_list.size());
    // Bind by reference to avoid a refcount increment/decrement for each element copy.
    for (auto& src_tensor : tensor_list) {
        byte_tensors.push_back(cast_to_byte_tensor(src_tensor));
    }
    return byte_tensors;
}