Files
DeepSpeed/deepspeed/utils/torch.py
Max Kovalenko 456c9ac679 Stage3: Use new torch grad accumulation hooks API (#6773)
* This commit addresses DeepSpeed issue
[#6718](https://github.com/microsoft/DeepSpeed/issues/6718).
* The existing code has been using the grad_acc node hook to reduce
parameter gradients. Constructs such as `param.data = replicated_tensor.data`,
used in `allgather_params(..)`, are compiled into `param.set()`, which
prevents the hook assigned to the grad_acc node from being called.
* Starting from PyTorch 2.1 there is a new and robust hook API on the
param itself: `param.register_post_accumulate_grad_hook(..)`
* This commit makes use of the proper API depending on the PyTorch
version.
* It also disables compile for PyTorch versions < 2.1.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
2025-01-03 07:48:24 -08:00
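
To make the two hook mechanisms described in the commit message concrete, here is a minimal, illustrative sketch (not part of the commit). The tensor `p` and the printed messages are made up, and the `hasattr` check is just one way to detect the newer API in place of the module's `required_torch_version` helper shown below:

import torch

p = torch.nn.Parameter(torch.randn(3))

if hasattr(p, "register_post_accumulate_grad_hook"):
    # PyTorch >= 2.1: hook the parameter directly; the hook receives the param.
    handle = p.register_post_accumulate_grad_hook(lambda param: print("grad accumulated:", param.shape))
else:
    # Older PyTorch: hook the AccumulateGrad node reached through a view of the param.
    grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]
    handle = grad_acc.register_hook(lambda *unused: print("grad accumulated:", p.shape))

p.sum().backward()  # the registered hook fires once p's gradient has been accumulated
handle.remove()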

32 lines
888 B
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from packaging import version as pkg_version

import torch


def required_torch_version(min_version=None, max_version=None):
    """Return True if the installed torch version lies within [min_version, max_version] (inclusive)."""
    assert min_version or max_version, "Must provide a min_version or max_version argument"

    torch_version = pkg_version.parse(torch.__version__)

    if min_version and pkg_version.parse(str(min_version)) > torch_version:
        return False

    if max_version and pkg_version.parse(str(max_version)) < torch_version:
        return False

    return True


def register_grad_hook(param, hook):
    if required_torch_version(min_version=2.1):
        # PyTorch >= 2.1 exposes a post-accumulate-grad hook directly on the parameter.
        return param.register_post_accumulate_grad_hook(hook)
    else:
        # Older PyTorch: hook the AccumulateGrad node reached through a temporary view of the param.
        param_tmp = param.expand_as(param)
        grad_acc = param_tmp.grad_fn.next_functions[0][0]
        return grad_acc.register_hook(hook)
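
A minimal usage sketch of the file above, assuming the module is importable as `deepspeed.utils.torch`; the model, hook, and printed message are hypothetical and only illustrate how `register_grad_hook` might be called:

import torch

from deepspeed.utils.torch import register_grad_hook


def make_hook(name):
    # The hook ignores its arguments so it works with both registration paths:
    # the new API passes the param itself, the legacy AccumulateGrad hook passes gradient info.
    def hook(*_):
        print(f"gradient accumulated for {name}")

    return hook


model = torch.nn.Linear(4, 2)
handles = [register_grad_hook(p, make_hook(n)) for n, p in model.named_parameters()]

loss = model(torch.randn(8, 4)).sum()
loss.backward()  # each hook fires after its parameter's gradient has been accumulated

for h in handles:
    h.remove()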