mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fsdp.set_: convey to functionalization that it mutates storage (#132322)
Fixes https://github.com/pytorch/pytorch/issues/132197 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132322 Approved by: https://github.com/albanD, https://github.com/yf225 ghstack dependencies: #132243, #132337
This commit is contained in:
committed by
PyTorch MergeBot
parent
1a0db29932
commit
af8b8a47cb
@ -5500,6 +5500,32 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
self.assertEqual(z2, (x_clone + 1).sin())
|
||||
self.assertEqual(z3, (x_clone + 1).sin())
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/132197
|
||||
def test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients(self):
|
||||
set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
|
||||
if not set_available:
|
||||
return
|
||||
|
||||
@torch.compile(backend="aot_eager_decomp_partition")
|
||||
def f(x, l):
|
||||
z = x.sin()
|
||||
y = x + 1
|
||||
# graph input has its storage mutated
|
||||
torch.ops.fsdp.set_.default(x, y)
|
||||
z2 = x.sin()
|
||||
return z2, l**2
|
||||
|
||||
x = torch.randn(3)
|
||||
x_test = x.clone()
|
||||
l = torch.randn(3, requires_grad=True)
|
||||
result, _ = f(x, l)
|
||||
result_test, _ = torch.compile(f, backend="aot_eager_decomp_partition")(
|
||||
x_test, l
|
||||
)
|
||||
|
||||
self.assertEqual(result, result_test)
|
||||
self.assertEqual(x, x_test)
|
||||
|
||||
def test_changing_stride(self):
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
|
@ -808,6 +808,9 @@ def gen_pyi(
|
||||
"_functionalize_was_storage_changed": [
|
||||
"def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ..."
|
||||
],
|
||||
"_functionalize_set_storage_changed": [
|
||||
"def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..."
|
||||
],
|
||||
"_functionalize_has_metadata_mutation": [
|
||||
"def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ..."
|
||||
],
|
||||
|
@ -118,6 +118,10 @@ nn.Parameter in order to see the result of .set_.
|
||||
def set__functionalize(tensor, data):
|
||||
torch._sync(tensor)
|
||||
torch._sync(data)
|
||||
# AOTDispatcher needs to know if any inputs had their storages mutated.
|
||||
# (Why? It sometimes detaches inputs before sending them into the graph,
|
||||
# when it sees that they do not need to have any gradients computed)
|
||||
torch._functionalize_set_storage_changed(tensor)
|
||||
tensor_inner = torch._from_functional_tensor(tensor)
|
||||
data_inner = torch._from_functional_tensor(data)
|
||||
with torch._C._ExcludeDispatchKeyGuard(
|
||||
|
Reference in New Issue
Block a user