This PR is part of a series attempting to re-submit #134592 as smaller PRs. In fx tests:

- Add and use a common raise_on_run_directly method for when a user directly runs a test file that should not be run that way; it names the file the user should have run instead.
- Raise a RuntimeError on tests which have been disabled (not run).
- Remove any remaining uses of "unittest.main()".

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154715
Approved by: https://github.com/Skylion007
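For context, a minimal sketch of what the shared helper might look like (an illustrative assumption, not the actual implementation; the real helper lives in torch.testing._internal.common_utils and its exact signature may differ):

    # Hypothetical sketch of raise_on_run_directly, for illustration only.
    def raise_on_run_directly(file_to_run: str) -> None:
        # Point the user at the file they should have run instead.
        raise RuntimeError(
            "This test file is not meant to be run directly. "
            f"Run: python {file_to_run}"
        )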
# Owner(s): ["module: fx"]

import torch
import torch.fx

from torch.testing._internal.common_utils import raise_on_run_directly, TestCase


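# Base module: multiply the input by a parameter matrix, optionally followed
# by a relu. Which branch runs depends only on parameter shape metadata, so
# the control flow becomes fixed when traced with param_shapes_constant=True
# (exercised by TestConstParamShapeInControlFlow below).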
class MyModuleBase(torch.nn.Module):
    def forward(self, x):
        matrx = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrx)
        else:
            return torch.relu(torch.mm(x, matrx))

    def get_mul_matrix(self):
        return self.param

    def no_relu(self):
        raise Exception("not implemented")  # noqa: TRY002


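# Each subclass below overrides no_relu() using a different shape-query API
# that the tracer must evaluate to a constant. This one uses the .shape
# attribute.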
class MyModuleParamShape(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.shape[0] < 10


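# Same check as above, but via the .size() method.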
class MyModuleParamSize(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.size()[0] < 10


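# Uses .dim(); also selects the multiply matrix by indexing when the
# parameter is 3-D.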
class MyModuleParamDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.dim() == 3) else self.param

    def no_relu(self):
        return self.param.dim() == 3


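# Uses the .ndim attribute instead of .dim().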
class MyModuleParamNDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.ndim == 3) else self.param

    def no_relu(self):
        return self.param.ndim == 3


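# Uses .numel() to count parameter elements.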
class MyModuleParamNumEl(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.numel() < 10 * 3


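# Uses .nelement(), an alias of .numel().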
class MyModuleParamNElement(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.nelement() < 10 * 3


class TestConstParamShapeInControlFlow(TestCase):
    def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
        """
        Verify that one module performs only a matmul, while the other
        performs a matmul followed by a relu.
        """
        x = torch.randn(10, 5)
        torch.testing.assert_close(
            mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())
        )

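        # With param_shapes_constant=True, shape queries on module parameters
        # (.shape, .size(), .dim(), .ndim, .numel(), .nelement()) are
        # evaluated eagerly at trace time instead of becoming Proxy nodes,
        # so the if/else in forward() is resolved to a single branch.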
        tracer = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph = tracer.trace(mm_only_mod)

        # verify the graph module calculates the same result
        graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph)
        torch.testing.assert_close(
            graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())
        )

        # Use a module with a different parameter shape to go down the other
        # code path
        x = torch.randn(10, 15)
        torch.testing.assert_close(
            relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
        )

        tracer2 = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph2 = tracer2.trace(relu_mod)

        # verify the graph module calculates the same result
        graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2)
        torch.testing.assert_close(
            graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
        )

        graph1_node_targets = [n.target for n in traced_graph.nodes]
        graph2_node_targets = [n.target for n in traced_graph2.nodes]

        # the second graph has an extra relu function call node
        assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets
        assert (
            torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
        )

    def test_param_shape_const(self):
        mymod = MyModuleParamShape(in_channels=5)
        mymod2 = MyModuleParamShape(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_size_const(self):
        mymod = MyModuleParamSize(in_channels=5)
        mymod2 = MyModuleParamSize(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_dim_const(self):
        mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_ndim_const(self):
        mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_numel_const(self):
        mymod = MyModuleParamNumEl(in_channels=5)
        mymod2 = MyModuleParamNumEl(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_nelement_const(self):
        mymod = MyModuleParamNElement(in_channels=5)
        mymod2 = MyModuleParamNElement(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)


if __name__ == "__main__":
    raise_on_run_directly("test/test_fx.py")