Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
DeepCompile: Fuse allgather and downcast (#7588)
With autocast enabled, most weights are downcast before being used in computation. Today zero3_compile gathers the FP32 weights and only then downcasts them. That is sub-optimal: FP32 weights consume more allgather bandwidth and take more time to downcast. To reduce communication and downcast time, this PR fuses the allgather and the downcast in the dc ops. The target dtype is now passed to allgather_param() and prefetch_params_fused(), which downcast the (partial) weights before launching the allgathers. This corresponds to issue 1 of #7577.

Tested with https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect torch and memory profiles, and with DINOV2_DEPTH = SIGLIP_DEPTH = 3, LLAMA2_DEPTH = 4 for faster compilation) on 5090 GPUs, which have limited inter-GPU bandwidth. Time per step decreases from 438ms to 337ms and peak GPU memory usage from 9.5GB to 8.5GB.

Profiles of a single step before this PR (torch profiler trace and memory timeline):
- https://github.com/user-attachments/assets/d9fe5296-7731-4542-924b-421ff7415054
- https://github.com/user-attachments/assets/aa192802-8633-4e36-b2c4-f28b1b432663

After this PR:
- https://github.com/user-attachments/assets/18a0e09c-155b-4783-adb5-b4d36c5c3691
- https://github.com/user-attachments/assets/16a2ca74-8a89-4db9-9b68-81844295c61b

This PR also reduces peak memory usage because `fast_free_schedule()` today always places param allgathers and downcasts at the beginning of the graph. While the original FP32 params can be freed early, all FP16/BF16-casted params are kept in GPU memory at the beginning of the backward graph, leading to a higher peak in memory usage.

P.S. Probably due to organization branch rule settings, I cannot find an option to allow reviewers to modify the branch, so I will update the branch per reviewers' comments and rebase if needed.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
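For intuition, here is a minimal, self-contained sketch (plain PyTorch only, not the dc ops themselves; all names in it are made up for the example) of why downcasting the shards before the allgather is both correct and cheaper. In the actual change, the compiled graph passes the target dtype to `torch.ops.dc.allgather_param` / `torch.ops.dc.prefetch_params_fused`, as in the diff below.

```python
import torch

# Element-wise casts commute with concatenation, so casting each shard *before*
# communication yields the same gathered values as gathering in FP32 and casting
# afterwards, while halving the bytes on the wire.
world_size = 4
shard = torch.randn(1024, dtype=torch.float32)        # one rank's FP32 shard
shards = [shard.clone() for _ in range(world_size)]   # stand-in for the allgather result

gather_then_cast = torch.cat(shards).to(torch.bfloat16)               # pre-PR order
cast_then_gather = torch.cat([s.to(torch.bfloat16) for s in shards])  # fused order
assert torch.equal(gather_then_cast, cast_then_gather)

bytes_fp32 = shard.numel() * shard.element_size()                     # 4 bytes/element
bytes_bf16 = shard.numel() * shard.to(torch.bfloat16).element_size()  # 2 bytes/element
print(f"per-rank allgather payload: {bytes_fp32} B (fp32) vs {bytes_bf16} B (bf16)")
```

The same commutation argument is what lets `prefetchParamsFused()` downcast each shard on the allgather stream before launching the grouped NCCL allgathers.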
@@ -10,8 +10,10 @@
 TORCH_LIBRARY(dc, m)
 {
-    m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
-    m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
+    m.def("allgather_param(Tensor a, int graph_id, int id, ScalarType? dtype = None) -> Tensor");
+    m.def(
+        "prefetch_params_fused(int graph_id, Tensor[] params, int[] ids,"
+        " ScalarType[]? dtypes = None) -> ()");
     m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)");
     m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)");
     m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
@@ -68,7 +68,12 @@ public:
                          c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
     {
         const DSParam& param = param_registry_->getParam(ds_id);
-        const at::Tensor& ds_tensor = param.getDSTensor();
+        at::Tensor ds_tensor = param.getDSTensor();
+
+        if (ds_tensor.scalar_type() != output_buf.scalar_type()) {
+            at::cuda::CUDAStreamGuard guard(ag_stream_);
+            ds_tensor = ds_tensor.to(output_buf.scalar_type(), true, true);
+        }
 
         if (symm_mem == nullptr) {
             // Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size)
@@ -110,6 +115,7 @@ public:
     }
 
     at::Tensor allgatherParam(long ds_id,
+                              std::optional<at::ScalarType> dtype,
                               c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
     {
         const DSParam& param = param_registry_->getParam(ds_id);
@@ -118,11 +124,16 @@ public:
         const int64_t true_numel = static_cast<int64_t>(productDim(param.getShape()));
         const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;
         const int64_t padded_numel = static_cast<int64_t>(world_size) * padded_per_rank;
+        at::ScalarType target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type();
 
         if (param_registry_->isValid(ds_id)) {
             // Return a view sliced to the true size with the original shape
+            //
+            // Persistent params are gathered in their original dtype which may
+            // be different from the requested.
             auto base = param_registry_->getGatheredParam(ds_id);
             return base.flatten()
+                .to(target_dtype)
                 .index({torch::indexing::Slice(0, true_numel)})
                 .view(param.getShape());
         }
@@ -134,7 +145,7 @@ public:
         }
         if (!output_buf.defined()) {
             at::cuda::CUDAStreamGuard guard(ag_stream_);
-            output_buf = torch::empty({padded_numel}, ds_tensor.options());
+            output_buf = torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype));
         }
 
         assert(hasKey(ag_comp_done_events_, ds_id));
@@ -150,16 +161,20 @@ public:
             .view(param.getShape());
     }
 
-    void prefetchParamsFused(std::vector<int64_t> ds_ids,
+    void prefetchParamsFused(const std::vector<long>& ds_ids,
+                             const std::optional<std::vector<at::ScalarType>> dtypes,
                              c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
     {
-        std::vector<int64_t> invalid_ds_ids;
-        for (const auto& ds_id : ds_ids) {
-            if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
+        std::vector<std::tuple<long, std::optional<at::ScalarType>>> invalid_params;
+        for (int i = 0; i < ds_ids.size(); i++) {
+            if (!param_registry_->isValid(ds_ids[i])) {
+                auto dtype = dtypes ? dtypes.value()[i] : std::optional<at::ScalarType>();
+                invalid_params.push_back(std::make_tuple(ds_ids[i], dtype));
+            }
         }
 
         std::unordered_map<long, at::Tensor> output_bufs;
-        for (long ds_id : invalid_ds_ids) {
+        for (const auto& [ds_id, dtype] : invalid_params) {
             const DSParam& param = param_registry_->getParam(ds_id);
             const at::Tensor& ds_tensor = param.getDSTensor();
             const int world_size = process_group_->getSize();
@@ -173,22 +188,26 @@ public:
                     continue;
                 }
             }
-            output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options());
+            auto target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type();
+            output_bufs[ds_id] =
+                torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype));
         }
 
-        for (long ds_id : invalid_ds_ids) {
+        for (const auto& [ds_id, _] : invalid_params) {
            ag_comp_done_events_[ds_id]->record();
            ag_comp_done_events_[ds_id]->block(ag_stream_);
         }
 
         ncclGroupStart();
-        for (long ds_id : invalid_ds_ids) {
+        for (const auto& [ds_id, _] : invalid_params) {
            assert(hasKey(output_bufs, ds_id));
            launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
         }
         ncclGroupEnd();
 
-        for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
+        for (const auto& [ds_id, _] : invalid_params) {
+            ag_comm_done_events_[ds_id]->record(ag_stream_);
+        }
     }
 
     void releaseParam(long ds_id, long n_users)
@@ -458,12 +477,15 @@ void register_z3_param(long ds_id,
     }
 }
 
-at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
+at::Tensor allgather_param(at::Tensor param_tensor,
+                           long graph_id,
+                           long ds_id,
+                           std::optional<at::ScalarType> dtype)
 {
     auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
 
     if (sync_before_allgather) { c10::cuda::device_synchronize(); }
-    auto ret = executor->allgatherParam(ds_id, symm_mem);
+    auto ret = executor->allgatherParam(ds_id, dtype, symm_mem);
     if (sync_after_allgather) { c10::cuda::device_synchronize(); }
     return ret;
 }
@@ -477,22 +499,25 @@ void set_persistent(long ds_id)
     for (auto& it : executors) {
         if (it.second->hasParam(ds_id)) {
             auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
-            executor->allgatherParam(ds_id, symm_mem);
+            auto dtype = param_registry->getParam(ds_id).getDtype();
+            executor->allgatherParam(ds_id, dtype, symm_mem);
         }
     }
 }
 
 void prefetch_params_fused(long graph_id,
-                           const std::vector<at::Tensor> params,
-                           const std::vector<long>& ds_ids)
+                           const std::vector<at::Tensor>& params,
+                           const std::vector<long>& ds_ids,
+                           const std::optional<std::vector<at::ScalarType>>& dtypes)
 {
     auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
-    executor->prefetchParamsFused(ds_ids, symm_mem);
+    executor->prefetchParamsFused(ds_ids, dtypes, symm_mem);
 }
 
 void prefetch_params_fused_meta(long graph_id,
-                                const std::vector<at::Tensor> params,
-                                const std::vector<long>& ds_ids)
+                                const std::vector<at::Tensor>& params,
+                                const std::vector<long>& ds_ids,
+                                const std::optional<std::vector<at::ScalarType>>& dtypes)
 {
 }
 
@@ -518,11 +543,14 @@ void clear_all_gathered_params()
     }
 }
 
-at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
+at::Tensor allgather_param_meta(at::Tensor param_tensor,
+                                long graph_id,
+                                long ds_id,
+                                std::optional<at::ScalarType> dtype)
 {
     const DSParam& param = param_registry->getParam(ds_id);
     auto options = param.getDSTensor().options().device(c10::kMeta);
-    at::Tensor output_buf = torch::empty(param.getShape(), options);
+    at::Tensor output_buf = torch::empty(param.getShape(), options.dtype(dtype));
     return output_buf;
 }
 
@@ -21,18 +21,26 @@ void register_z3_param(long ds_id,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent);
-at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id);
+at::Tensor allgather_param(at::Tensor param_tensor,
+                           long graph_id,
+                           long ds_id,
+                           std::optional<at::ScalarType> dtype);
 void set_persistent(long ds_id);
 void prefetch_params_fused(long graph_id,
-                           const std::vector<at::Tensor> params,
-                           const std::vector<long>& ds_ids);
+                           const std::vector<at::Tensor>& params,
+                           const std::vector<long>& ds_ids,
+                           const std::optional<std::vector<at::ScalarType>>& dtypes);
 void prefetch_params_fused_meta(long graph_id,
-                                const std::vector<at::Tensor> params,
-                                const std::vector<long>& ds_ids);
+                                const std::vector<at::Tensor>& params,
+                                const std::vector<long>& ds_ids,
+                                const std::optional<std::vector<at::ScalarType>>& dtypes);
 // for profiling
 void invalidate_gathered_param(long ds_id);
 void clear_all_gathered_params();
-at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id);
+at::Tensor allgather_param_meta(at::Tensor param_tensor,
+                                long graph_id,
+                                long ds_id,
+                                std::optional<at::ScalarType> dtype);
 at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
 at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);
 at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);
@@ -18,6 +18,7 @@
 #include <c10/cuda/CUDAStream.h>
 #include <torch/csrc/cuda/nccl.h>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
+#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
 
 #if __has_include(<torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>)
@@ -261,6 +262,7 @@ public:
         : id_(id),
           shape_(std::move(ds_shape)),
           ds_tensor_(ds_tensor),
+          ds_dtype_(ds_tensor.scalar_type()),
           grad_buffer_(grad_buffer),
           partitioned_(partitioned),
           offset_(offset),
@@ -272,6 +274,7 @@ public:
 
     long getId() const { return id_; }
     std::vector<int64_t> getShape() const { return shape_; }
+    at::ScalarType getDtype() const { return ds_dtype_; }
     at::Tensor getDSTensor() const
     {
         // If the reload event exists and is complete, return the reloaded tensor (if defined)
@@ -343,6 +346,7 @@ public:
 private:
     long id_;
     std::vector<int64_t> shape_;
+    at::ScalarType ds_dtype_;
     at::Tensor ds_tensor_;
     at::Tensor ds_reload_tensor_;
     at::Tensor grad_buffer_;
@@ -3,7 +3,7 @@
 
 # DeepSpeed Team
 
-from typing import Callable, Any, List
+from typing import Callable, Any, List, Dict
 from collections import defaultdict
 
 import torch
@@ -60,7 +60,8 @@ def add_args_process(graph: Graph,
 def add_postprocess(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
-                    extra_args: List[int] = [],
+                    extra_args: List[Any] = [],
+                    extra_kwargs: Dict[str, Any] = {},
                     name=None,
                     meta={}) -> Node:
     # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
@@ -70,7 +71,7 @@ def add_postprocess(graph: Graph,
         args += (a, )
 
     node_users = node.users.keys()
-    new_node = graph.create_node('call_function', fn, args, {}, name=name)
+    new_node = graph.create_node('call_function', fn, args, extra_kwargs, name=name)
     users = {}
     for u in node_users:
         if u != new_node:
@@ -10,7 +10,7 @@ import _operator
 import torch
 from torch.fx import Graph, Node, GraphModule
 
-from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
+from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op
 from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
 from ..profilers.graph_profile import ProfilingInterpreter
 from ..list_schedule import fast_free_schedule
@@ -21,14 +21,15 @@ from deepspeed.accelerator import get_accelerator
 NAME = "zero3_compile"
 
 
-def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
+def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int, dtype: torch.dtype):
     new_ag_node = add_postprocess(graph,
                                   node,
                                   torch.ops.dc.allgather_param.default,
                                   extra_args=[graph_id, ds_id],
+                                  extra_kwargs={"dtype": dtype},
                                   name=f"allgather_ds_param_{node.target}_{ds_id}",
                                   meta=_make_node_meta(node, ds_id, True))
-    new_ag_node.meta["val"] = node.meta["val"]
+    new_ag_node.meta["val"] = node.meta["val"].to(dtype)
 
     # Set the previous node back to output
     # We don't want to change the output node to allgather
@@ -42,7 +43,7 @@ def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
                                     extra_args=[graph_id, ds_id],
                                     name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
                                     meta=_make_node_meta(node, ds_id, False))
-    new_wait_node.meta["val"] = node.meta["val"]
+    new_wait_node.meta["val"] = new_ag_node.meta["val"]
 
     return new_ag_node
 
@@ -74,9 +75,30 @@ def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes):
         if len(pn.users) == 0:
             continue
 
-        add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
+        # If the only use of the parameter is a type-cast to a smaller type, fuse it with all-gather.
+        fuse_typecast = False
+        target_dtype = param_manager.params[pn.name].dtype
+        if len([user for user in pn.users if user.op != "output"]) == 1:
+            typecast_node = next(iter(pn.users))
+
+            is_cast, casted_dtype = is_cast_op(typecast_node)
+            if is_cast and casted_dtype.itemsize < target_dtype.itemsize:
+                fuse_typecast = True
+                target_dtype = casted_dtype
+
+        add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name], target_dtype)
+        if fuse_typecast:
+            users = node_to_uses[typecast_node]
+            wait_node = typecast_node.args[0]
+            for user in list(typecast_node.users.keys()):
+                if user.op == "output":
+                    wait_node.meta["original_output_name"] = typecast_node.name
+                user.replace_input_with(typecast_node, wait_node)
+            graph.erase_node(typecast_node)
+        else:
+            users = node_to_uses[pn]
+
         ds_id = param_manager.ds_ids[pn.name]
-        users = node_to_uses[pn]
         for user in users:
             # release_param() only accepts tensors as its first argument. If
             # `user` is a tuple, we should release the param after any of
@@ -130,9 +130,15 @@ class ProfilingInterpreter(Interpreter):
         assert isinstance(args, tuple)
         assert isinstance(kwargs, dict)
 
+        partitioned_params = {}
+
         def rebuild_param_if_necessary(v):
             if hasattr(v, "ds_id"):
                 v.all_gather(param_list=[v])
+                if hasattr(v, "ds_target_dtype"):
+                    casted = v.to(v.ds_target_dtype)
+                    partitioned_params[id(casted)] = v
+                    return casted
             return v
 
         args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x))
@@ -191,6 +197,8 @@ class ProfilingInterpreter(Interpreter):
         tensor_size = _node_size(out)
 
         def partition_param_if_necessary(v):
+            if id(v) in partitioned_params:
+                v = partitioned_params[id(v)]
             if hasattr(v, "ds_id") and not v.ds_persist:
                 v.partition(param_list=[v], has_been_updated=False)
             return v
@@ -227,6 +235,8 @@ class ProfilingInterpreter(Interpreter):
             assert hasattr(out, "ds_id")
             if not out.ds_persist:
                 self.nz3.invalidate_gathered_param(args[2])
+            if "dtype" in n.kwargs:
+                setattr(out, "ds_target_dtype", n.kwargs["dtype"])
             self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]
 
         return out
@@ -5,7 +5,7 @@
 
 import functools
 import operator
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Optional
 from collections import defaultdict
 
 import torch
@@ -131,6 +131,15 @@ def is_comm_op(node: Node) -> bool:
     return "comm" in node.meta and node.meta["comm"]
 
 
+def is_cast_op(node: Node) -> Tuple[bool, Optional[torch.dtype]]:
+    if node.op == "call_function":
+        if node.target == torch.ops.prims.convert_element_type.default:
+            return (True, node.args[1])
+        elif node.target == torch.ops.aten._to_copy.default and set(node.kwargs.keys()) == {"dtype"}:
+            return (True, node.kwargs["dtype"])
+    return (False, None)
+
+
 def exclude_from_act_offload(node: Node) -> bool:
     return node.target in sym_size_ops
 
@@ -180,3 +180,36 @@ class TestDeepCompile(DistributedTest):
         }
 
         compare_loss(self, config_dict, dtype)
+
+    @pytest.mark.parametrize('dtype', ["bfloat16", "float16"])
+    @pytest.mark.parametrize('zero_stage', [3])
+    def test_fusing_allgather_and_autocast(self, zero_stage, dtype):
+        """Test that allgather and autocast can be correctly fused with DeepCompile"""
+        if not required_torch_version(min_version=2.6):
+            pytest.skip("DeepCompile requires PyTorch >= v2.6")
+
+        if get_accelerator().device_name() == "cpu":
+            pytest.skip("CPU does not support this test yet")
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.00015
+                }
+            },
+            "torch_autocast": {
+                "enable": True,
+                "dtype": dtype,
+            },
+            "zero_optimization": {
+                "stage": zero_stage,
+            },
+            "compile": {
+                "deepcompile": True
+            }
+        }
+
+        compare_loss(self, config_dict, torch.float32)