mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In jit tests: - Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run. - Raise a RuntimeError on tests which have been disabled (not run) Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725 Approved by: https://github.com/clee2000
205 lines
6.3 KiB
Python
205 lines
6.3 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
import unittest
|
|
from itertools import product
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.testing import FileCheck
|
|
|
|
|
|
# torchvision is an optional dependency: record whether it can be imported so
# that tests needing it can be skipped instead of failing at import time.
try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

# Decorator that skips a test when torchvision could not be imported.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
|
|
|
# Make the helper files in test/ importable: this resolves to the directory
# two levels above this file's real (symlink-resolved) path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
# Functional activations exercised by the conversion tests below. The tests
# derive operator names from ``__name__`` (``aten::<name>(`` for the functional
# form and ``aten::<name>_`` for the in-place form).
activations = [
    F.celu,
    F.elu,
    F.hardsigmoid,
    F.hardswish,
    F.hardtanh,
    F.leaky_relu,
    F.relu,
    F.relu6,
    F.rrelu,
    F.selu,
    F.silu,
]
|
|
|
|
|
|
class TestFunctionalToInplaceActivation(JitTestCase):
    """Tests for the ``functional_to_inplace_activation`` JIT graph pass."""

    def test_check_no_type_promotion(self):
        """Check that every listed activation preserves its input dtype.

        Combinations for which the activation raises ``RuntimeError`` are
        skipped.
        """
        dtypes = [
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        ]
        # restore_mutation.h contains a mapping from activation operators
        # to whether they allow type conversion. Use this checking to
        # guard the mapping, and if any later change breaks the assumption
        # we need to update the mapping correspondingly.
        for activation, dtype in product(activations, dtypes):
            inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
            # Guard only the activation call, so the dtype assertion below is
            # never covered by the RuntimeError handler.
            try:
                out = activation(inp)
            except RuntimeError:
                # Skip the not implemented error
                continue
            self.assertEqual(dtype, out.dtype)

    def test_functional_to_inplace_activation(self):
        """Check that the pass rewrites each functional activation in place.

        For every entry of ``activations`` the scripted graph must contain
        ``aten::<name>(`` before the pass, and ``aten::<name>_`` but no
        ``aten::<name>(`` after it; the scripted function must still agree
        with the eager function.
        """
        for activation in activations:

            def test_basic(x):
                y = x + 1
                z = activation(y)
                return z

            name = activation.__name__
            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{name}(").run(fn.graph)
            self.run_pass("functional_to_inplace_activation", fn.graph)
            FileCheck().check_not(f"aten::{name}(").run(fn.graph)
            FileCheck().check(f"aten::{name}_").run(fn.graph)
            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    def test_no_functional_to_inplace(self):
        """Check three cases in which the pass must not emit an in-place op."""

        # inplace conversion should not happen because sigmoid may
        # perform type conversion
        def test1():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        fn = torch.jit.script(test1)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::sigmoid_").run(fn.graph)

        # inplace conversion should not happen because y is alias
        # the input x
        def test2(x):
            y = x[0]
            z = torch.relu(y)
            return z

        fn = torch.jit.script(test2)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

        # inplace conversion should not happen because self.x is
        # at the global scope
        class Test3(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        """Check that a frozen, scripted ResNet-18 matches eager after the pass."""
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("functional_to_inplace_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))
|
|
|
|
|
|
class TestInplaceToFunctionalActivation(JitTestCase):
    """Tests for the ``inplace_to_functional_activation`` JIT graph pass."""

    def test_inplace_to_functional_activation(self):
        """Check that the pass rewrites in-place activations as functional ones.

        The first loop covers every entry of ``activations`` called with
        ``inplace=True``; the second covers ``torch.relu_``,
        ``torch.sigmoid_`` and ``torch.tanh_`` and additionally compares the
        scripted result with the eager result.
        """
        for activation in activations:

            def test_basic(x):
                y = x + 1
                activation(y, inplace=True)
                return y

            name = activation.__name__
            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{name}_").run(fn.graph)
            self.run_pass("inplace_to_functional_activation", fn.graph)
            FileCheck().check_not(f"aten::{name}_").run(fn.graph)
            FileCheck().check(f"aten::{name}(").run(fn.graph)

        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:

            def test_basic(x):
                y = x + 1
                activation(y)
                return y

            # Here ``__name__`` already ends in "_" (the in-place spelling);
            # dropping the last character gives the functional operator name.
            name = activation.__name__
            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{name}").run(fn.graph)
            self.run_pass("inplace_to_functional_activation", fn.graph)
            FileCheck().check_not(f"aten::{name}").run(fn.graph)
            FileCheck().check(f"aten::{name[:-1]}(").run(fn.graph)

            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        """Check that a frozen, scripted ResNet-18 matches eager after the pass."""
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("inplace_to_functional_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))
|
|
|
|
|
|
# This file is not meant to be executed directly; when it is, hand the path of
# the test file that should be run instead to raise_on_run_directly.
if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")
|