Avoid graph break by removing redundant requires_grad attr change (#7158)

This PR is a continuation of the ongoing effort to improve DeepSpeed
performance when running under `torch.compile`.

Dynamo breaks the graph because the assignment `flat_tensor.requires_grad = False`:

* Is a side-effecting operation on tensor metadata
* Occurs in a context where Dynamo expects static tensor properties for tracing (a minimal repro sketch follows this list)
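
For illustration only (this is not DeepSpeed code, and the function names are made up for the example), the sketch below uses `torch._dynamo.explain` (PyTorch 2.1+) to compare a function that mutates `requires_grad` on a freshly allocated tensor against one that relies on the `torch.empty()` default. Depending on the PyTorch version, the first variant may report a graph break while the second traces as a single graph.

```python
import torch

# Hypothetical toy functions, mirroring the pattern the PR removes.
def with_metadata_mutation(x):
    flat = torch.empty(x.numel(), dtype=x.dtype, device=x.device)
    flat.requires_grad = False  # side-effecting metadata mutation Dynamo may break on
    return flat.copy_(x.flatten())

def without_metadata_mutation(x):
    flat = torch.empty(x.numel(), dtype=x.dtype, device=x.device)  # requires_grad is already False
    return flat.copy_(x.flatten())

x = torch.randn(4, 4)
# torch._dynamo.explain reports graph-break counts without executing compiled code.
print(torch._dynamo.explain(with_metadata_mutation)(x).graph_break_count)
print(torch._dynamo.explain(without_metadata_mutation)(x).graph_break_count)
```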

The `flat_tensor.requires_grad = False` assignment is redundant and can be
safely removed because:
* The `_allgather_params()` function is already decorated with `@torch.no_grad()`, which ensures the desired behavior
* `flat_tensor` is created with `torch.empty()`, which sets `requires_grad=False` by default (see the short check after this list)
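
For reference, a minimal standalone check of the two guarantees above (the helper name below is hypothetical, not the actual DeepSpeed function):

```python
import torch

# Guarantee 1: torch.empty() allocates tensors with requires_grad=False by default.
t = torch.empty(8)
assert t.requires_grad is False

# Guarantee 2: under @torch.no_grad(), autograd recording is disabled for the
# whole function body, so nothing inside tracks gradients anyway.
@torch.no_grad()
def build_flat_tensor():  # hypothetical stand-in for _allgather_params()
    flat = torch.empty(8)
    return flat * 2  # ops inside the no_grad region stay untracked

assert build_flat_tensor().requires_grad is False
```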

---------

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Commit: d40cf4662c (parent 1ca83a6bb9)
Author: Max Kovalenko
Date: 2025-03-24 21:50:30 +02:00
Committed by: GitHub


```diff
@@ -1899,7 +1899,6 @@ class Init(InsertPostInitMethodToModuleSubClasses):
         tensor_size = partition_size * self.num_partitions
         flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
-        flat_tensor.requires_grad = False
         partitions = []
         for i in range(self.num_partitions):
             start = partition_size * i
```