Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 23:53:48 +08:00)
This commit addresses DeepSpeed issue [#6718](https://github.com/microsoft/DeepSpeed/issues/6718).

* The existing code has been using the grad_acc node hook to reduce param grads. Constructs such as `param.data = replicated_tensor.data` used in `allgather_params(..)` are compiled into `param.set()`, so the hook assigned to the grad_acc node is never called.
* Starting with PyTorch 2.1 there is a new and robust hook API on the param itself: `param.register_post_accumulate_grad_hook(..)`.
* This commit makes use of the proper API depending on the PyTorch version.
* It also disables compile for PyTorch versions < 2.1.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
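A minimal sketch (not part of the commit) contrasting the two hook styles the commit describes. In eager mode both hooks fire after the gradient is accumulated; the commit's point is that under `torch.compile` the grad_acc-node hook can be skipped when `param.data` is reassigned (traced as `param.set()`), while the post-accumulate-grad hook registered on the param itself remains reliable.

```python
import torch

param = torch.nn.Parameter(torch.ones(4))

# Old approach: hook the AccumulateGrad node reachable through an expanded view.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
grad_acc.register_hook(lambda *_: print("grad_acc node hook fired"))

# New approach (PyTorch >= 2.1): hook directly on the parameter.
if hasattr(param, "register_post_accumulate_grad_hook"):
    param.register_post_accumulate_grad_hook(lambda p: print("post-accumulate hook fired"))

(param * 2).sum().backward()
```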
32 lines
888 B
Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from packaging import version as pkg_version

import torch


def required_torch_version(min_version=None, max_version=None):
    assert min_version or max_version, "Must provide a min_version or max_version argument"

    torch_version = pkg_version.parse(torch.__version__)

    if min_version and pkg_version.parse(str(min_version)) > torch_version:
        return False

    if max_version and pkg_version.parse(str(max_version)) < torch_version:
        return False

    return True


def register_grad_hook(param, hook):
    if required_torch_version(min_version=2.1):
        # PyTorch >= 2.1: register the hook directly on the parameter; it fires
        # once the gradient has been fully accumulated into param.grad.
        return param.register_post_accumulate_grad_hook(hook)
    else:
        # Older PyTorch: reach the AccumulateGrad node through an expanded view
        # of the parameter and attach the hook there.
        param_tmp = param.expand_as(param)
        grad_acc = param_tmp.grad_fn.next_functions[0][0]
        return grad_acc.register_hook(hook)
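A hedged usage sketch of the helper above, not DeepSpeed's actual call site: it attaches a gradient-reduction callback to every parameter of a tiny model via `register_grad_hook`. The name `reduce_partitioned_grad` and the model are illustrative only; the lambda ignores its positional arguments so the same callback works with both registration paths (the new hook passes the param, the grad_acc-node hook passes grad inputs/outputs).

```python
import torch

def reduce_partitioned_grad(param):
    # Placeholder for real reduce/partition logic (illustrative only).
    print(f"reducing grad of shape {tuple(param.grad.shape)}")

model = torch.nn.Linear(8, 8)
handles = [
    register_grad_hook(p, lambda *args, p=p: reduce_partitioned_grad(p))
    for p in model.parameters()
]

model(torch.randn(2, 8)).sum().backward()  # hooks fire once each grad is accumulated
```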