Mirror of https://github.com/deepspeedai/DeepSpeed.git
This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile uses PyTorch's Dynamo to capture the computation graph and modifies it to incorporate DeepSpeed's optimizations seamlessly. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
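For context, here is a minimal usage sketch of how DeepCompile is intended to be enabled from a training script. The `compile` section keys and the `engine.compile()` call are assumptions based on this description and the public DeepSpeed API, not the authoritative setup; consult the DeepCompile documentation for the exact config schema.

# Minimal sketch of enabling DeepCompile (assumed config keys, not verified against the final schema).
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},   # DeepCompile currently targets ZeRO-1 and ZeRO-3
    "compile": {"deepcompile": True},    # assumed key that turns the feature on
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
engine.compile()  # capture the graph with Dynamo and apply DeepSpeed's graph passes

x = torch.randn(8, 1024, device=engine.device)
loss = engine(x).sum()
engine.backward(loss)
engine.step()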
545 lines · 20 KiB · C++
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "z3.h"
#include "deepcompile.h"

#define USE_C10D_NCCL

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace dc {

const size_t TIMEOUT_SYMMETRIC_MEMORY_BARRIER = 60000;
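
// Executor for the ZeRO-3 custom ops inserted into the compiled graph. It drives
// parameter all-gather/release, gradient reduce-scatter, and host offload/reload
// on dedicated CUDA streams, using CUDA events to order them against compute.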
class Z3CustomOpExecutor : public CustomOpExecutor {
public:
    Z3CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
                       std::shared_ptr<DSParamRegistry> param_registry,
                       std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
                       std::vector<long> ds_ids,
                       ncclComm_t nccl_comm,
                       at::cuda::CUDAStream ag_stream,
                       at::cuda::CUDAStream rs_stream,
                       at::cuda::CUDAStream copy_stream,
                       at::cuda::CUDAStream offload_stream,
                       at::cuda::CUDAStream reload_stream,
                       bool pre_div_reduce)
        : CustomOpExecutor(process_group,
                           param_registry,
                           reduce_buckets,
                           ds_ids,
                           nccl_comm,
                           rs_stream,
                           copy_stream,
                           pre_div_reduce),
          ag_stream_(ag_stream),
          offload_stream_(offload_stream),
          reload_stream_(reload_stream)
    {
        for (long ds_id : ds_ids_) {
            ag_comm_done_events_[ds_id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
            ag_comp_done_events_[ds_id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

            param_use_count_[ds_id] = 0;
        }
    }
    ~Z3CustomOpExecutor() {}
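
    // Called at the end of the backward pass: if parameters were updated, drop the
    // accumulated-grad flags and invalidate the gathered params, then hand the
    // reload buffers back once the current stream no longer needs them.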
    void endBackward() override
    {
        if (param_updated_) {
            for (auto& it : has_acc_grad_) {
                it.second = false;
                param_registry_->setValid(it.first, false);
            }
        }

        for (auto& it : reload_buffers_) {
            it.second.record_stream(at::cuda::getCurrentCUDAStream());
        }
        reload_buffers_.clear();
    }
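
    // Gathers a partitioned parameter into output_buf on the all-gather stream.
    // Uses ncclAllGather by default; when a SymmetricMemory handle is provided,
    // each rank copies its shard into its symmetric buffer and pulls the remote
    // shards directly, with barriers before and after the exchange.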
    void launchAllGather(at::Tensor output_buf,
                         long ds_id,
                         c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        const DSParam& param = param_registry_->getParam(ds_id);
        const at::Tensor& ds_tensor = param.getDSTensor();

        if (symm_mem == nullptr) {
            ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
                                                output_buf.data_ptr(),
                                                ds_tensor.numel(),
                                                get_nccl_data_type(ds_tensor.scalar_type()),
                                                nccl_comm_,
                                                ag_stream_);

            if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); }
        } else {
            at::cuda::CUDAStreamGuard guard(ag_stream_);
            int world_size = process_group_->getSize();
            int rank = process_group_->getRank();

            at::Tensor local_buf =
                symm_mem->get_buffer(rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
            local_buf.copy_(ds_tensor, true);

            symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
            auto chunks = output_buf.flatten().chunk(world_size);
            for (int step = 0; step < world_size; step++) {
                int remote_rank = (rank - step + world_size) % world_size;
                auto src_buf = symm_mem->get_buffer(
                    remote_rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
                chunks[remote_rank].copy_(src_buf.flatten(), true);
            }
            symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
        }

        param_registry_->registerGatheredParam(ds_id, output_buf);
        param_registry_->setValid(ds_id, true);
    }
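
    // Returns the fully gathered parameter, reusing the cached copy when it is
    // still valid. Otherwise it records a "compute done" event, makes the
    // all-gather stream wait on it, launches the gather, and records the
    // "communication done" event that waitAllgather later blocks on.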
    at::Tensor allgatherParam(long ds_id,
                              c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }

        const DSParam& param = param_registry_->getParam(ds_id);
        const at::Tensor& ds_tensor = param.getDSTensor();
        at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
                                    ? param_registry_->getGatheredParam(ds_id)
                                    : torch::empty(param.getShape(), ds_tensor.options());

        assert(hasKey(ag_comp_done_events_, ds_id));
        ag_comp_done_events_[ds_id]->record();
        ag_comp_done_events_[ds_id]->block(ag_stream_);

        launchAllGather(output_buf, ds_id, symm_mem);

        ag_comm_done_events_[ds_id]->record(ag_stream_);
        return output_buf;
    }
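
    // Proactive prefetch: batches the all-gathers for every parameter whose
    // gathered copy is not yet valid inside a single ncclGroupStart/End so they
    // are issued as one grouped NCCL call on the all-gather stream.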
    void prefetchParamsFused(std::vector<int64_t> ds_ids,
                             c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
    {
        std::vector<int64_t> invalid_ds_ids;
        for (const auto& ds_id : ds_ids) {
            if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
        }

        std::unordered_map<long, at::Tensor> output_bufs;
        for (long ds_id : invalid_ds_ids) {
            const DSParam& param = param_registry_->getParam(ds_id);
            if (param_registry_->hasGatheredParam(ds_id)) {
                output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
            } else {
                output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
            }
        }

        for (long ds_id : invalid_ds_ids) {
            ag_comp_done_events_[ds_id]->record();
            ag_comp_done_events_[ds_id]->block(ag_stream_);
        }

        ncclGroupStart();
        for (long ds_id : invalid_ds_ids) {
            assert(hasKey(output_bufs, ds_id));
            launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
        }
        ncclGroupEnd();

        for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
    }
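
    // Releases a gathered parameter once all of its users in the graph have run.
    // The use count is (re)initialized to n_users on the first call; when it
    // reaches zero and the parameter is not persistent, the gathered buffer is
    // replaced with an empty tensor and unregistered so its memory can be reused.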
    void releaseParam(long ds_id, long n_users)
    {
        const DSParam& param = param_registry_->getParam(ds_id);

        assert(hasKey(param_use_count_, ds_id));
        if (param_use_count_[ds_id] == 0) { param_use_count_[ds_id] = n_users; }
        param_use_count_[ds_id]--;

        if (param_use_count_[ds_id] == 0 && !param.isPersistent()) {
            at::Tensor gathered_param = param_registry_->getGatheredParam(ds_id);

            if (gathered_param.defined()) {  // gathered param is undefined while profiling
                const auto options = gathered_param.options();
                at::Tensor empty_buffer = torch::empty({0}, options);
                gathered_param.set_data(empty_buffer);
            }

            param_registry_->unregisterGatheredParam(ds_id);
        }
    }

    at::Tensor waitAllgather(at::Tensor v, long ds_id)
    {
        assert(hasKey(ag_comm_done_events_, ds_id));
        ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream());
        return v;
    }
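
    // Flushes the reduce bucket for one dtype: reduce-scatters every pending
    // gradient into its grad buffer. When a gradient was already accumulated for
    // a parameter, the result is first written to a temporary buffer and then
    // added to the existing grad buffer. With pre_div_reduce_, send buffers are
    // divided by the world size and summed instead of using ncclAvg.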
    void flushReduceBucket(at::ScalarType scalar_type) override
    {
        if (!hasKey(reduce_tasks_, scalar_type)) { return; }

        int64_t tmp_recv_numel = 0;
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(rs_stream_);

            if (has_acc_grad_.at(t.getDSId())) {
                tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel();
            }
        }

        at::Tensor tmp_recv_buf = at::Tensor();
        if (tmp_recv_numel > 0) {
            at::cuda::CUDAStreamGuard guard(rs_stream_);
            tmp_recv_buf = torch::empty({tmp_recv_numel},
                                        at::TensorOptions().dtype(scalar_type).device(at::kCUDA));
        }

        ncclGroupStart();
        int64_t offset = 0;
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();

            bool acc_grad = has_acc_grad_.at(t.getDSId());

            if (acc_grad) {
                recv_buf =
                    tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())});
            }

            ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
            if (pre_div_reduce_) {
                at::cuda::CUDAStreamGuard guard(rs_stream_);
                t.getSendBuf().div_(process_group_->getSize());
            }
            ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(),
                                                    recv_buf.data_ptr(),
                                                    recv_buf.numel(),
                                                    get_nccl_data_type(scalar_type),
                                                    op,
                                                    nccl_comm_,
                                                    rs_stream_);
            if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); }

            if (acc_grad) { offset += recv_buf.numel(); }
        }
        ncclGroupEnd();

        {
            at::cuda::CUDAStreamGuard guard(rs_stream_);
            int64_t offset = 0;
            for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
                bool acc_grad = has_acc_grad_.at(t.getDSId());

                if (acc_grad) {
                    auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
                    recv_buf.add_(tmp_recv_buf.index(
                        {torch::indexing::Slice(offset, offset + recv_buf.numel())}));
                    offset += recv_buf.numel();
                }
                has_acc_grad_[t.getDSId()] = true;
            }
        }

        reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);

        // Not very sure if this is necessary
        // Want to prevent grad tensor from being released before the copy is done
        auto comp_stream = at::cuda::getCurrentCUDAStream();
        for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
            auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
            copy_done_event->block(comp_stream);
        }
        reduce_tasks_[scalar_type].clear();

        if (tmp_recv_numel > 0) { tmp_recv_buf.record_stream(rs_stream_); }
    }
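
    // Tensor offload/reload helpers: offloadTensor copies a tensor into a lazily
    // allocated pinned host buffer on the offload stream; reloadTensor copies it
    // back into a fresh device tensor on the reload stream. The wait* variants
    // block the current stream on the corresponding event before returning the buffer.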
    at::Tensor offloadTensor(at::Tensor tensor, long id)
    {
        if (!hasKey(offload_events_, id)) {
            offload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
            offload_comp_done_events_[id] =
                std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

            const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU);
            offload_buffers_[id] = at::empty_like(tensor, options);
        }

        offload_comp_done_events_[id]->record();
        offload_comp_done_events_[id]->block(offload_stream_);
        {
            at::cuda::CUDAStreamGuard guard(offload_stream_);
            offload_buffers_.at(id).copy_(tensor, true);
        }

        tensor.record_stream(offload_stream_);

        offload_events_[id]->record(offload_stream_);
        assert(hasKey(offload_buffers_, id));
        return offload_buffers_.at(id);
    }

    at::Tensor reloadTensor(at::Tensor tensor, long id)
    {
        if (!hasKey(reload_events_, id)) {
            reload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
        }

        assert(hasKey(offload_buffers_, id));
        offload_events_[id]->block(reload_stream_);

        at::Tensor ten;
        {
            at::cuda::CUDAStreamGuard guard(reload_stream_);

            assert(hasKey(offload_buffers_, id));
            at::Tensor buf = offload_buffers_.at(id);
            const auto options = at::TensorOptions().device(torch::kCUDA);
            ten = at::empty_like(buf, options);
            ten.copy_(buf, true);

            reload_buffers_[id] = ten;
        }

        reload_events_[id]->record(reload_stream_);
        return ten;
    }

    at::Tensor waitOffload(at::Tensor tensor, long id)
    {
        assert(hasKey(offload_events_, id));
        offload_events_[id]->block(at::cuda::getCurrentCUDAStream());

        assert(hasKey(offload_buffers_, id));
        return offload_buffers_.at(id);
    }

    at::Tensor waitReload(at::Tensor tensor, long id)
    {
        assert(hasKey(reload_events_, id));
        reload_events_[id]->block(at::cuda::getCurrentCUDAStream());

        assert(hasKey(reload_buffers_, id));
        auto ten = reload_buffers_.at(id);

        // We can't release here because the tensor is still being used
        // We will need "freeReloadedTensor" after the last user of the tensor to call
        // ".record_stream". As it is a bit complicated, we clear the buffer and do at the end of
        // the backward pass for now. reload_buffers_.erase(id);
        return ten;
    }

    void offloadParameter(at::Tensor tensor, long ds_id) { param_registry_->offload(ds_id); }
    void reloadParameter(at::Tensor tensor, long ds_id) { param_registry_->reload(ds_id); }

    bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); }

    bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); }

private:
    at::cuda::CUDAStream ag_stream_;
    at::cuda::CUDAStream offload_stream_;
    at::cuda::CUDAStream reload_stream_;

    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comp_done_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comm_done_events_;

    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_comp_done_events_;
    std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> reload_events_;
    std::unordered_map<long, at::Tensor> offload_buffers_;
    std::unordered_map<long, at::Tensor> reload_buffers_;

    std::unordered_map<long, long> param_use_count_;
};
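
// Dedicated side streams shared by all Z3 executors: all-gather, reduce-scatter,
// grad copy, and host offload/reload each get their own high-priority stream from
// the pool so communication and transfers can overlap with compute.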
static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
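
// The free functions below are the entry points used by the compiled graphs and
// the Python-side setup code: register_graph_z3 creates the executor for a graph
// id, register_z3_param registers a partitioned parameter, and the remaining
// calls resolve the executor by graph id and forward to it. The *_meta variants
// are shape-only/no-op implementations used during tracing.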
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids)
{
    executors[graph_id] = std::make_shared<Z3CustomOpExecutor>(process_group,
                                                               param_registry,
                                                               reduce_buckets,
                                                               ds_ids,
                                                               nccl_comm,
                                                               ag_stream,
                                                               rs_stream,
                                                               copy_stream,
                                                               offload_stream,
                                                               reload_stream,
                                                               pre_div_reduce);
}

void register_z3_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent)
{
    param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
    if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }
}

at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);

    if (sync_before_allgather) { c10::cuda::device_synchronize(); }
    auto ret = executor->allgatherParam(ds_id, symm_mem);
    if (sync_after_allgather) { c10::cuda::device_synchronize(); }
    return ret;
}

void set_persistent(long ds_id)
{
    param_registry->setPersistent(ds_id, true);

    // Allocate buffer here
    // Memory fragmentation will be more severe if we allocate in forward/backward
    for (auto& it : executors) {
        if (it.second->hasParam(ds_id)) {
            auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
            executor->allgatherParam(ds_id, symm_mem);
        }
    }
}

void prefetch_params_fused(long graph_id,
                           const std::vector<at::Tensor> params,
                           const std::vector<long>& ds_ids)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->prefetchParamsFused(ds_ids, symm_mem);
}

void prefetch_params_fused_meta(long graph_id,
                                const std::vector<at::Tensor> params,
                                const std::vector<long>& ds_ids)
{
}

// for profiling
void invalidate_gathered_param(long ds_id)
{
    const DSParam& param = param_registry->getParam(ds_id);
    if (param.isPersistent()) { return; }

    param_registry->unregisterGatheredParam(ds_id);
    param_registry->registerGatheredParam(ds_id, at::Tensor());
}

void clear_all_gathered_params()
{
    for (const auto& it : param_registry->getParams()) {
        long ds_id = it.first;
        const DSParam& param = param_registry->getParam(ds_id);
        if (param.isPersistent()) { continue; }
        if (param_registry->hasGatheredParam(ds_id)) {
            param_registry->unregisterGatheredParam(ds_id);
        }
    }
}

at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
{
    const DSParam& param = param_registry->getParam(ds_id);
    auto options = param.getDSTensor().options().device(c10::kMeta);
    at::Tensor output_buf = torch::empty(param.getShape(), options);
    return output_buf;
}

at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->releaseParam(ds_id, n_users);

    if (clone_custom_op_output) { return dummy.clone(); }
    return dummy;
}

at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
    return dummy;
}

at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->waitAllgather(v, ds_id);
    return v;
}

at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id) { return v; }

at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->offloadTensor(tensor, id);
}

at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->reloadTensor(tensor, id);
}

at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    return executor->waitOffload(tensor, id);
}

at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    if (profile && !executor->hasReloadBuffer(id)) { return tensor; }
    return executor->waitReload(tensor, id);
}

at::Tensor test_call(at::Tensor a)
{
    std::cout << "test_call" << std::endl;
    return a;
}

void reload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->reloadParameter(tensor, ds_id);
}

void offload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
    auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
    executor->offloadParameter(tensor, ds_id);
}

void reload_parameter_meta(at::Tensor param_tensor, long graph_id, long ds_id) {}
void offload_parameter_meta(at::Tensor tensor, long graph_id, long ds_id) {}

}  // namespace dc