pytorch/test/jit/test_convert_activation.py
Anthony Barbier bf7e290854 Add __main__ guards to jit tests (#154725)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:

- Add and use a common `raise_on_run_directly` helper for cases where a user directly runs a test file that should not be run that way; it prints the file the user should have run instead (a rough sketch follows below).
- Raise a `RuntimeError` for tests that have been disabled (not run).
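
For illustration, a helper with this behavior might look like the minimal sketch below. This is an assumption about its shape, not the actual implementation in `torch.testing._internal.common_utils`; only the name and the "file to run" argument come from the test file itself.

```python
# Minimal sketch (assumed, not the real PyTorch helper): refuse to run the
# current file directly and point the user at the intended entry point.
def raise_on_run_directly(file_to_run: str) -> None:
    raise RuntimeError(
        "This test file is not meant to be run directly. "
        f"Run it via {file_to_run} instead."
    )
```

In this file it is called as `raise_on_run_directly("test/test_jit.py")` under the `__main__` guard.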

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
2025-06-16 10:28:45 +00:00

# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
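
# Functional activations exercised by both conversion passes below; each has an
# in-place counterpart (exposed via the `inplace=True` keyword and an
# `aten::<name>_` operator) that the passes convert to and from.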
activations = [
    F.celu,
    F.elu,
    F.hardsigmoid,
    F.hardswish,
    F.hardtanh,
    F.leaky_relu,
    F.relu,
    F.relu6,
    F.rrelu,
    F.selu,
    F.silu,
]

class TestFunctionalToInplaceActivation(JitTestCase):
    def test_check_no_type_promotion(self):
        dtypes = [
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        ]
        # restore_mutation.h contains a mapping from activation operators to
        # whether they allow type conversion. This check guards that mapping:
        # an in-place variant writes back into its input buffer and so cannot
        # change the dtype, hence every activation listed above is expected to
        # preserve its input dtype. If a later change breaks that assumption,
        # the mapping needs to be updated accordingly.
        for activation, dtype in product(activations, dtypes):
            inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
            try:
                out = activation(inp)
                self.assertEqual(dtype, out.dtype)
            except RuntimeError:
                # Skip dtypes for which the activation is not implemented
                pass

    def test_functional_to_inplace_activation(self):
        for activation in activations:

            def test_basic(x):
                y = x + 1
                z = activation(y)
                return z

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
            self.run_pass("functional_to_inplace_activation", fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    def test_no_functional_to_inplace(self):
        # inplace conversion should not happen because sigmoid may
        # perform type conversion
        def test1():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        fn = torch.jit.script(test1)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::sigmoid_").run(fn.graph)

        # inplace conversion should not happen because y is an alias
        # of the input x
        def test2(x):
            y = x[0]
            z = torch.relu(y)
            return z

        fn = torch.jit.script(test2)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

        # inplace conversion should not happen because self.x is
        # at the global scope
        class Test3(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("functional_to_inplace_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))


class TestInplaceToFunctionalActivation(JitTestCase):
    def test_inplace_to_functional_activation(self):
        for activation in activations:

            def test_basic(x):
                y = x + 1
                activation(y, inplace=True)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
            self.run_pass("inplace_to_functional_activation", fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)

        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:

            def test_basic(x):
                y = x + 1
                activation(y)
                return y

            fn = torch.jit.script(test_basic)
            self.run_pass("inline", fn.graph)
            self.run_pass("constant_propagation", fn.graph)
            FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
            self.run_pass("inplace_to_functional_activation", fn.graph)
            FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
            FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)
            inp = torch.rand([2, 2])
            self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("inplace_to_functional_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))


if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")