[dtensor] directly return local_tensor under no_grad (#128145)
As titled: when to_local() is called under a no_grad context, skip the autograd function and directly return the local_tensor; this avoids creating extra views.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128145
Approved by: https://github.com/awgu
ghstack dependencies: #128112
Committed by: PyTorch MergeBot
Parent: 747fc35ff5
Commit: 3df53c2a8f
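To make the described fast path concrete, below is a minimal, self-contained sketch. ToyWrapper and _ToLocal are hypothetical stand-ins for DTensor and its internal autograd function (chosen so the snippet runs on a single process without a device mesh); they are an illustrative assumption, not the PR's actual code.

import torch


class _ToLocal(torch.autograd.Function):
    """Hypothetical stand-in for the autograd function behind to_local()."""

    @staticmethod
    def forward(ctx, local_tensor):
        # Return a view so the output participates in the autograd graph.
        return local_tensor.view_as(local_tensor)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ToyWrapper:
    """Hypothetical wrapper holding a local tensor, mimicking DTensor._local_tensor."""

    def __init__(self, local_tensor: torch.Tensor):
        self._local_tensor = local_tensor

    def to_local(self) -> torch.Tensor:
        if not torch.is_grad_enabled():
            # Fast path from the commit message: under no_grad there is no
            # graph to record, so skip the autograd function and return the
            # stored tensor directly, avoiding an extra view.
            return self._local_tensor
        return _ToLocal.apply(self._local_tensor)


if __name__ == "__main__":
    w = ToyWrapper(torch.randn(3, 3, requires_grad=True))
    with torch.no_grad():
        assert w.to_local() is w._local_tensor  # same object, no view created
    assert w.to_local() is not w._local_tensor  # grad mode returns a fresh view

The test hunk below asserts the same identity on a real sharded DTensor: under torch.no_grad(), to_local() hands back _local_tensor itself.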
@@ -331,6 +331,11 @@ class DTensorTest(DTensorTestBase):
         except RuntimeError:
             self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
 
+        # test the case under no-grad we directly return the local tensor
+        with torch.no_grad():
+            local_no_grad = sharded_tensor.to_local()
+            assert local_no_grad is sharded_tensor._local_tensor
+
     @with_comms
     def test_to_local_grad_hint(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))