Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
DeepCompile: Specify tensor aliasing in C++ op schema (#7597)
The PyTorch C++ op schema [1] allows tensor storage aliasing to be specified by annotating input/output types with `(a)`. Torch inductor uses this information to decide where to insert explicit `del` statements for tensors that are no longer needed. If an op's schema disagrees with its implementation, inductor-generated code is likely to release tensors earlier than expected and produce wrong results.

`wait_allgather` and `release_param` return their first argument unchanged, so that aliasing should be annotated in their schemas. Also remove the code related to `clone_custom_op_output`, as it was solely a workaround for the aforementioned issue.

[1] https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md

Fixes: #7596

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
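For illustration only (not part of this commit or of the DeepSpeed sources): a minimal sketch of the aliasing annotation on a custom op registration, using a made-up namespace `example_ns` and a hypothetical op `passthrough` that, like `wait_allgather` and `release_param`, returns its first argument unchanged.

// Minimal sketch (hypothetical op, assumed names): registering an op whose
// output aliases its first input. The "(a)" marker on both the input and the
// return type declares that they share storage.
#include <torch/library.h>

// Hypothetical op: passes its input through untouched, so the output aliases
// the input's storage.
at::Tensor passthrough(at::Tensor x, int64_t tag) { return x; }

TORCH_LIBRARY(example_ns, m)
{
    // Without the annotation, the schema claims a fresh output tensor, so
    // inductor may emit `del` for x right after the call even though the
    // result still uses x's storage:
    // m.def("passthrough(Tensor x, int tag) -> Tensor");

    // With "(a)", the output is tied to the input, matching the
    // implementation above, so inductor keeps x alive while the result is
    // still needed.
    m.def("passthrough(Tensor(a) x, int tag) -> Tensor(a)");
}

TORCH_LIBRARY_IMPL(example_ns, CompositeExplicitAutograd, m) { m.impl("passthrough", passthrough); }

Once the aliasing is declared in the schema, inductor no longer frees the aliased input prematurely, which is why the `dummy.clone()` workaround guarded by `clone_custom_op_output` can be dropped in the diff below.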
@@ -17,7 +17,6 @@ c10::intrusive_ptr<c10d::ProcessGroup> process_group = nullptr;
 c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
 ncclComm_t nccl_comm;
 bool use_symm_mem;
-bool clone_custom_op_output;
 bool profile = false;
 bool pre_div_reduce = true;

@@ -130,8 +129,7 @@ static T get_config(pybind11::object& config, const char* name)

 void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
           pybind11::object& config,
-          int64_t initial_reduce_bucket_size,
-          bool _clone_custom_op_output)
+          int64_t initial_reduce_bucket_size)
 {
     process_group = pg;

@@ -157,7 +155,6 @@ void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
     reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(
         initial_reduce_bucket_size, get_config<bool>(config, "double_buffer"));
     use_symm_mem = get_config<bool>(config, "symmetric_memory");
-    clone_custom_op_output = _clone_custom_op_output;
     free_activation_threshold = get_config<int64_t>(config, "free_activation_threshold");

     sync_before_reduce = get_config<bool>(config, "sync_before_reduce");
@@ -12,8 +12,8 @@ TORCH_LIBRARY(dc, m)
 {
     m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
     m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
-    m.def("wait_allgather(Tensor a, int graph_id, int id) -> Tensor");
-    m.def("release_param(Tensor a, int graph_id, int id, int n_users) -> Tensor");
+    m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)");
+    m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)");
     m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
     m.def("free_tensors(Tensor[] a) -> ()");
     m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
@@ -530,8 +530,6 @@ at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users)
 {
     auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
     executor->releaseParam(ds_id, n_users);
-
-    if (clone_custom_op_output) { return dummy.clone(); }
     return dummy;
 }
@@ -98,7 +98,6 @@ extern c10::intrusive_ptr<c10d::ProcessGroup> process_group;
 extern c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem;
 extern ncclComm_t nccl_comm;
 extern bool use_symm_mem;
-extern bool clone_custom_op_output;
 extern bool profile;
 extern bool pre_div_reduce;

@@ -591,8 +590,7 @@ void free_tensors_meta(std::vector<at::Tensor> tensors);

 void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
           pybind11::object& config,
-          int64_t initial_reduce_bucket_size,
-          bool _clone_custom_op_output);
+          int64_t initial_reduce_bucket_size);
 void reset();
 void cleanup();
@@ -10,7 +10,7 @@ import torch
 from deepspeed.accelerator import get_accelerator
 from .passes import zero1_compile, zero3_compile
 from .backend import make_backend, launch_compile_passes, init_schedule
-from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor
+from .util import get_deepcompile_handle, add_pre_backward_hook

 WARMUP = 5

@@ -24,7 +24,7 @@ def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_
     optimizer._grad_acc_hooks.clear()

     dc = get_deepcompile_handle()
-    dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size(), is_backend_inductor(backend))
+    dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size())

     grad_buffer = {}
@@ -13,7 +13,7 @@ from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
 from .passes import zero3_compile, prefetch, selective_gather, offload_parameters
 from .backend import make_backend, launch_compile_passes, init_schedule
 from .patch_fake_tensor import patch_fake_tensor
-from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor
+from .util import get_deepcompile_handle, add_pre_backward_hook

 WARMUP = 5

@@ -28,7 +28,7 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):
     get_accelerator().empty_cache()

     dc = get_deepcompile_handle()
-    dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size(), is_backend_inductor(backend))
+    dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size())

     # Unset hooks
     for m in engine.module.modules():