fsdp.set_: convey to functionalization that it mutates storage (#132322)

Fixes https://github.com/pytorch/pytorch/issues/132197

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132322
Approved by: https://github.com/albanD, https://github.com/yf225
ghstack dependencies: #132243, #132337
commit af8b8a47cb (parent 1a0db29932)
Author: Brian Hirsh
Date:   2024-08-02 11:59:39 -07:00
Committed by: PyTorch MergeBot

3 changed files with 33 additions and 0 deletions
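For context, the flag this commit exposes is the one functionalization already tracks for native set_(): each functional tensor records whether its storage was swapped out, and AOTDispatcher consumes that record. Below is a minimal sketch of querying it, assuming the private bindings (_to_functional_tensor, _enable_functionalization, _functionalize_was_storage_changed) behave as they do in this era of PyTorch; trace_and_inspect is a hypothetical helper, not part of this commit.

import torch

def trace_and_inspect(fn, *args):
    # Wrap inputs in functional tensors, run fn under functionalization,
    # then ask each input whether its storage was swapped out.
    func_args = [torch._to_functional_tensor(a) for a in args]
    torch._enable_functionalization(reapply_views=True)
    try:
        fn(*func_args)
    finally:
        torch._disable_functionalization()
    return [torch._functionalize_was_storage_changed(a) for a in func_args]

def f(x, y):
    x.set_(y)  # native set_ marks "storage changed" on x's wrapper

print(trace_and_inspect(f, torch.randn(3), torch.randn(3)))
# expected under these assumptions: [True, False]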


@@ -5500,6 +5500,32 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
        self.assertEqual(z2, (x_clone + 1).sin())
        self.assertEqual(z3, (x_clone + 1).sin())

    # https://github.com/pytorch/pytorch/issues/132197
    def test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients(self):
        set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
        if not set_available:
            return

        @torch.compile(backend="aot_eager_decomp_partition")
        def f(x, l):
            z = x.sin()
            y = x + 1
            # graph input has its storage mutated
            torch.ops.fsdp.set_.default(x, y)
            z2 = x.sin()
            return z2, l**2

        x = torch.randn(3)
        x_test = x.clone()
        l = torch.randn(3, requires_grad=True)
        result, _ = f(x, l)
        result_test, _ = torch.compile(f, backend="aot_eager_decomp_partition")(
            x_test, l
        )
        self.assertEqual(result, result_test)
        self.assertEqual(x, x_test)

    def test_changing_stride(self):
        cnt = torch._dynamo.testing.CompileCounter()


@@ -808,6 +808,9 @@ def gen_pyi(
        "_functionalize_was_storage_changed": [
            "def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ..."
        ],
        "_functionalize_set_storage_changed": [
            "def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..."
        ],
        "_functionalize_has_metadata_mutation": [
            "def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ..."
        ],
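As a quick sanity check of the new stub (hypothetical snippet, not from this commit): the setter is assumed to flip the same flag that the neighboring _functionalize_was_storage_changed getter reads, which is what lets fsdp.set_'s Python functionalization rule mimic native set_.

import torch

t = torch._to_functional_tensor(torch.randn(2))
torch._functionalize_set_storage_changed(t)   # newly stubbed private binding
assert torch._functionalize_was_storage_changed(t)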


@@ -118,6 +118,10 @@ nn.Parameter in order to see the result of .set_.
def set__functionalize(tensor, data):
    torch._sync(tensor)
    torch._sync(data)
    # AOTDispatcher needs to know if any inputs had their storages mutated.
    # (Why? It sometimes detaches inputs before sending them into the graph
    # when it sees that they do not need any gradients computed.)
    torch._functionalize_set_storage_changed(tensor)
    tensor_inner = torch._from_functional_tensor(tensor)
    data_inner = torch._from_functional_tensor(data)
    with torch._C._ExcludeDispatchKeyGuard(
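To illustrate the comment above in isolation, here is a short, runnable sketch (not part of this commit) of why an untracked storage swap is lost once an input has been detached: detach() produces a new TensorImpl that shares the old storage, so re-pointing the alias's storage never reaches the original tensor.

import torch

x = torch.randn(3)
alias = x.detach()            # stand-in for the detached graph input
alias.set_(torch.zeros(3))    # the graph's storage mutation (like fsdp.set_)
print(torch.equal(x, alias))  # False: x still points at its old storage

# Unless functionalization reports "storage changed", a runtime wrapper has
# no way to know it must propagate the swap back, e.g. via x.set_(alias).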