Compare commits

...

1 Commit

de2300bde8 [export] Make RNNs exportable on GPUs (#163245)
Summary:

Completing https://github.com/pytorch/pytorch/pull/155734

Test Plan:
buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_lstm_gpu
buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_gru_gpu
buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_export_rnn_flatten_parameters_gpu

Differential Revision: D82687470
Date: 2025-11-14 09:59:28 -08:00
2 changed files with 96 additions and 6 deletions
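
For orientation, what the commit enables is tracing torch.nn RNN modules (LSTM/GRU) that live on a GPU through torch.export. A minimal sketch of that flow, mirroring the tests added below (the module name and the literal "cuda" device string are illustrative, not from the diff):

import torch
from torch.export import export


class LSTMModel(torch.nn.Module):  # illustrative name
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(
            input_size=4, hidden_size=5, num_layers=1, batch_first=True
        )

    def forward(self, x):
        out, _ = self.rnn(x)
        return out


m = LSTMModel().to("cuda")
x = torch.randn(2, 3, 4, device="cuda")
ep = export(m, (x,))  # the call this PR makes work for RNN modules on GPU
torch.testing.assert_close(m(x), ep.module()(x))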

View File

@@ -8013,6 +8013,82 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
        ):
            _ = export(mod, inp, strict=True)

    @requires_gpu
    @testing.expectedFailureSerDer
    @testing.expectedFailureSerDerNonStrict
    def test_export_lstm_gpu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.LSTM(
                    input_size=4, hidden_size=5, num_layers=1, batch_first=True
                )

            def forward(self, x):
                out, _ = self.rnn(x)
                return out

        m = M().to(GPU_TYPE)
        x = torch.randn(2, 3, 4, device=GPU_TYPE)
        ep = export(m, (x,))
        self.assertTrue(callable(ep.module()))
        eager_out = m(x)
        export_out = ep.module()(x)
        self.assertEqual(eager_out, export_out)

    @requires_gpu
    @testing.expectedFailureSerDer
    @testing.expectedFailureSerDerNonStrict
    def test_export_gru_gpu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.rnn = torch.nn.GRU(
                    input_size=4, hidden_size=5, num_layers=1, batch_first=True
                )

            def forward(self, x):
                out, _ = self.rnn(x)
                return out

        m = M().to(GPU_TYPE)
        x = torch.randn(2, 3, 4, device=GPU_TYPE)
        ep = export(m, (x,))
        self.assertTrue(callable(ep.module()))
        eager_out = m(x)
        export_out = ep.module()(x)
        self.assertEqual(eager_out, export_out)

    @requires_gpu
    @testing.expectedFailureCppSerDes
    @testing.expectedFailureSerDer
    @testing.expectedFailureSerDerNonStrict
    def test_export_rnn_flatten_parameters_gpu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(
                    input_size=3, hidden_size=4, num_layers=2, batch_first=True
                )

            def forward(self, x):
                self.lstm.flatten_parameters()
                out, (h, c) = self.lstm(x)
                return out

        m = M().to(GPU_TYPE)
        x = torch.randn(1, 5, 3, device=GPU_TYPE)
        ep = export(m, (x,), strict=False)
        eager_out = m(x)
        export_out = ep.module()(x)
        self.assertEqual(eager_out, export_out)

    def test_device_to_static(self):
        class Module(torch.nn.Module):
            def forward(self, x):

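The flatten_parameters() call exercised by the last test above is what the RNNBase change in the next file has to tolerate under tracing. For reference, in eager mode the call behaves roughly as follows; this is a hedged sketch, not part of the diff (on CUDA with cuDNN it repacks the weights into a single contiguous buffer, otherwise it is a no-op):

import torch

lstm = torch.nn.LSTM(
    input_size=3, hidden_size=4, num_layers=2, batch_first=True
).to("cuda")
lstm.flatten_parameters()  # repack weights for cuDNN; a no-op if they are already flat
out, (h, c) = lstm(torch.randn(1, 5, 3, device="cuda"))
print(out.shape)  # torch.Size([1, 5, 4])
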
View File

@@ -253,12 +253,26 @@ class RNNBase(Module):
         # a sufficient check, because overlapping parameter buffers that don't completely
         # alias would break the assumptions of the uniqueness check in
         # Module.named_parameters().
-        unique_data_ptrs = {
-            p.data_ptr()  # type: ignore[union-attr]
-            for p in self._flat_weights
-        }
-        if len(unique_data_ptrs) != len(self._flat_weights):
-            return
+        # Try to use data_ptr() first, fallback to untyped_storage() if not accessible
+        try:
+            unique_data_ptrs = {
+                p.data_ptr()  # type: ignore[union-attr]
+                for p in self._flat_weights
+            }
+            if len(unique_data_ptrs) != len(self._flat_weights):
+                return
+        except RuntimeError:
+            # PT2 specific path to make RNN traceable with fake tensor
+            from torch._subclasses.fake_tensor import FakeTensor
+
+            if not all(isinstance(p, FakeTensor) for p in self._flat_weights):  # type: ignore[union-attr]
+                raise
+            unique_storage_refs = {
+                p.untyped_storage()
+                for p in self._flat_weights  # type: ignore[union-attr]
+            }
+            if len(unique_storage_refs) != len(self._flat_weights):
+                return

         with torch.cuda.device_of(first_fw):
             import torch.backends.cudnn.rnn as rnn
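
The try/except fallback exists because export traces the module with fake tensors, which have no real allocation behind them: calling data_ptr() on a FakeTensor raises a RuntimeError, while untyped_storage() still returns a (fake) storage object that works as a key for the uniqueness check. A small sketch of that behaviour, assuming FakeTensorMode is used roughly the way export uses it internally:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    t = torch.empty(2, 3)        # a FakeTensor: shapes and dtypes only, no real data
    try:
        t.data_ptr()             # the path the try-block above attempts first
    except RuntimeError as e:
        print("data_ptr() not available on fake tensors:", e)
    s = t.untyped_storage()      # still usable as a set key, as in the fallback branch
    print(type(s))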