Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Summary: Apex O2 hooks state_dict to return fp16 weights as fp32, so the exporter cannot identify them as the same tensors. Since this hook is only used by the optimizer, it is safe to remove it while exporting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66700
Reviewed By: zou3519
Differential Revision: D31695132
Pulled By: malfet
fbshipit-source-id: 977bdf57240002498f3ad0f1a8046c352e9860e6
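For context, a minimal sketch of the idea behind the fix, not the actual PR diff: temporarily detach state_dict hooks (such as the hook Apex O2 registers to hand back fp32 copies of fp16 master weights) before export, then restore them. The helper name and the context-manager shape are assumptions; it leans on the internal torch.nn.Module._state_dict_hooks dict.

import contextlib

import torch


@contextlib.contextmanager
def state_dict_hooks_disabled(model):
    # Hypothetical helper: detach every state_dict hook (e.g. the Apex O2
    # hook that returns fp32 copies of fp16 master weights) so that
    # state_dict() hands the exporter the real parameter tensors.
    # Relies on the internal torch.nn.Module._state_dict_hooks dict.
    saved = {}
    for m in model.modules():
        saved[m] = dict(m._state_dict_hooks)
        m._state_dict_hooks.clear()
    try:
        yield model
    finally:
        # Put the hooks back so optimizer checkpointing keeps working.
        for m, hooks in saved.items():
            m._state_dict_hooks.update(hooks)


# Usage: export with the hooks removed; they are restored on exit.
# with state_dict_hooks_disabled(model):
#     torch.onnx.export(model, dummy_input, "model.onnx")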
91 lines
3.1 KiB
Python
# ONNX Runtime export tests that require a CUDA device (fp16, autocast, Apex O2).
import unittest

import onnxruntime  # noqa: F401
import torch

from torch.cuda.amp import autocast

from test_pytorch_common import disableScriptTest, skipIfUnsupportedMinOpsetVersion
from test_pytorch_common import skipIfNoCuda

from test_pytorch_onnx_onnxruntime import TestONNXRuntime


class TestONNXRuntime_cuda(unittest.TestCase):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version

    opset_version = _export_onnx_opset_version
    keep_initializers_as_inputs = True
    onnx_shape_inference = True

    @skipIfUnsupportedMinOpsetVersion(9)
    @skipIfNoCuda
    def test_gelu_fp16(self):
        class GeluModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.gelu(x)

        x = torch.randn(2, 4, 5, 6, requires_grad=True, dtype=torch.float16, device=torch.device("cuda"))
        self.run_test(GeluModel(), x, rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(9)
    @skipIfNoCuda
    @disableScriptTest()
    def test_layer_norm_fp16(self):
        class LayerNormModel(torch.nn.Module):
            def __init__(self):
                super(LayerNormModel, self).__init__()
                self.layer_norm = torch.nn.LayerNorm([10, 10])

            @autocast()
            def forward(self, x):
                return self.layer_norm(x)

        x = torch.randn(20, 5, 10, 10, requires_grad=True, dtype=torch.float16, device=torch.device("cuda"))
        self.run_test(LayerNormModel().cuda(), x, rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(12)
    @skipIfNoCuda
    @disableScriptTest()
    def test_softmaxCrossEntropy_fusion_fp16(self):
        class FusionModel(torch.nn.Module):
            def __init__(self):
                super(FusionModel, self).__init__()
                self.loss = torch.nn.NLLLoss(reduction="none")
                self.m = torch.nn.LogSoftmax(dim=1)

            @autocast()
            def forward(self, input, target):
                output = self.loss(self.m(2 * input), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, dtype=torch.float16, device=torch.device("cuda"))
        target = torch.empty(N, dtype=torch.long, device=torch.device("cuda")).random_(0, C)

        # Include targets that hit the default ignore_index=-100.
        target[target == 1] = -100
        self.run_test(FusionModel(), (input, target))

    # Exercises the fix above: Apex O2's state_dict hook is dropped during export.
    @skipIfNoCuda
    @disableScriptTest()
    def test_apex_o2(self):
        class LinearModel(torch.nn.Module):
            def __init__(self):
                super(LinearModel, self).__init__()
                self.linear = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.linear(x)

        try:
            from apex import amp
        except Exception:
            raise unittest.SkipTest("Apex is not available")
        input = torch.randn(3, 3, device=torch.device("cuda"))
        model = amp.initialize(LinearModel(), opt_level="O2")
        self.run_test(model, input)


# Reuse the CPU suite's fixture and its export-and-compare helper.
TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp
TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test

if __name__ == "__main__":
    unittest.main(TestONNXRuntime_cuda())
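Every test above goes through run_test, borrowed from TestONNXRuntime. As a rough mental model (an assumption, not the real helper's code), it exports the module at the class's opset_version, feeds the same inputs to ONNX Runtime, and compares outputs against eager PyTorch within rtol/atol:

import io

import numpy as np
import onnxruntime
import torch


def run_test_sketch(model, args, rtol=1e-3, atol=1e-7, opset_version=9):
    # Normalize a single tensor to a tuple, mirroring run_test's call style.
    args = args if isinstance(args, tuple) else (args,)
    expected = model(*args)
    # Export to an in-memory buffer rather than a file on disk.
    f = io.BytesIO()
    torch.onnx.export(model, args, f, opset_version=opset_version)
    # Simplification: CPU provider only, and a single model output assumed.
    sess = onnxruntime.InferenceSession(
        f.getvalue(), providers=["CPUExecutionProvider"]
    )
    ort_inputs = {
        inp.name: arg.detach().cpu().numpy()
        for inp, arg in zip(sess.get_inputs(), args)
    }
    (ort_out,) = sess.run(None, ort_inputs)
    np.testing.assert_allclose(
        expected.detach().cpu().numpy(), ort_out, rtol=rtol, atol=atol
    )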