[dtensor] directly return local_tensor under no_grad (#128145)

as titled, skip the autograd function and directly return the
local_tensor if it's under no_grad context, this would avoid creating
views

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128145
Approved by: https://github.com/awgu
ghstack dependencies: #128112
This commit is contained in:
Wanchao Liang
2024-06-06 09:43:38 -07:00
committed by PyTorch MergeBot
parent 747fc35ff5
commit 3df53c2a8f
2 changed files with 8 additions and 0 deletions

View File

@ -331,6 +331,11 @@ class DTensorTest(DTensorTestBase):
except RuntimeError:
self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
# test the case under no-grad we directly return the local tensor
with torch.no_grad():
local_no_grad = sharded_tensor.to_local()
assert local_no_grad is sharded_tensor._local_tensor
@with_comms
def test_to_local_grad_hint(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))