mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Fix release of IPG buffer (#7376)
#6993 broke many paths in the ZeRO1/2 optimizer. This PR fixes most of the issues that PR caused. Currently we still have one error with tests in `unit/runtime/zero`. ``` ====================================== short test summary info ====================================== FAILED test_zero.py::TestParamPartitioningSkipInit::test[dtype1] - RuntimeError: mat1 and mat2 must have the same dtype, but got Half and BFloat16 ========= 1 failed, 204 passed, 66 skipped, 15 deselected, 5 warnings in 2305.03s (0:38:25) ========= ``` --------- Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
This commit is contained in:
@ -109,7 +109,6 @@ class IPGBucket:
|
||||
has_moe_params: bool = False
|
||||
|
||||
def clear(self):
    """Reset this bucket to its empty state.

    Empties every per-bucket container (buffer, params, grads) in place
    and zeroes the running element count.
    """
    # All three containers support in-place .clear(); emptying them in
    # place (rather than rebinding) keeps any external references valid.
    for container in (self.buffer, self.params, self.grads):
        container.clear()
    self.elements = 0
|
||||
@ -734,7 +733,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
def _release_ipg_buffers(self):
|
||||
if self.contiguous_gradients:
|
||||
for bucket in self.ipg_buckets.values():
|
||||
bucket.clear()
|
||||
bucket.buffer.clear()
|
||||
|
||||
self.grads_in_partition = None
|
||||
self.grads_in_partition_offset = 0
|
||||
@ -1443,10 +1442,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
if self.contiguous_gradients:
|
||||
if comm_dtype in self.extra_large_param_to_reduce:
|
||||
assert len(bucket.params) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
|
||||
_, _, param_id = self.params[0]
|
||||
assert self.get_param_id(self.extra_large_param_to_reduce
|
||||
_, _, param_id = bucket.params[0]
|
||||
assert self.get_param_id(self.extra_large_param_to_reduce[comm_dtype]
|
||||
) == param_id, "param in ipg bucket does not match extra-large param"
|
||||
extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce)
|
||||
extra_large_grad_reduc = self.get_gradient_for_reduction(
|
||||
self.extra_large_param_to_reduce[comm_dtype])
|
||||
self.average_tensor(extra_large_grad_reduc.view(-1), comm_dtype)
|
||||
del self.extra_large_param_to_reduce[comm_dtype]
|
||||
else:
|
||||
|
Reference in New Issue
Block a user