mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Partially addresses #123062. Ran lintrunner on `test/fx` with the command `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123622. Approved by: https://github.com/ezyang
133 lines
2.7 KiB
Python
133 lines
2.7 KiB
Python
# Owner(s): ["oncall: fx"]
|
|
|
|
import itertools
|
|
|
|
import torch
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.fx.graph_module import GraphModule
|
|
from torch.fx.passes.dialect.common.cse_pass import CSEPass
|
|
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
def FactoryFunctionCall(x, device):
    """Add ``x`` to a tensor of threes created by a factory call on ``device``.

    The filled tensor takes its shape from ``x`` and its device from the
    explicit ``device`` argument.
    """
    return torch.add(torch.full(x.shape, 3, device=device), x)
|
|
|
|
|
|
def TorchTensorCall(x):
    """Return ``x`` plus a scalar tensor built with ``torch.tensor``."""
    return x + torch.tensor(3)
|
|
|
|
|
|
def TakeList(x):
    """Concatenate ``x`` with itself, passing a list of tensors to ``torch.cat``."""
    return torch.cat([x, x])
|
|
|
|
|
|
def ReturnList(x):
    """Return the tuple of tensors produced by ``torch.split``.

    ``x`` is not read; it presumably exists only to fit the harness's
    single-argument calling convention. A 5x2 ``arange`` tensor is split
    into chunks of 1 and 4 rows.
    """
    grid = torch.arange(10).reshape(5, 2)
    return torch.split(grid, [1, 4])
|
|
|
|
|
|
def Mutation(x):
    """Mutate an intermediate (not the input) in place, then combine it with ``x``."""
    tmp = (x + 2).add_(1)
    return x + tmp
|
|
|
|
|
|
def MutationInput(x):
    """Increment the input ``x`` in place, then return ``x + (x + 2)``.

    The caller's tensor is modified.
    """
    x.add_(1)
    return x + (x + 2)
|
|
|
|
|
|
def MutationFactory(x, device):
    """Mutate a factory-created tensor on ``device`` in place, then add ``x``."""
    filled = torch.full(x.shape, 3, device=device).add_(1)
    return x + filled
|
|
|
|
|
|
def MutationTorchTensorCall(x):
    """Mutate a freshly built ``torch.tensor`` scalar in place, then add it to ``x``."""
    scalar = torch.tensor(3).add_(1)
    return x + scalar
|
|
|
|
|
|
def MutationMetadata(x):
    """Resize ``x`` in place to two elements and return that same tensor.

    Only the size changes (a metadata mutation). ``resize_`` returns ``x``
    itself.
    """
    return x.resize_(2)
|
|
|
|
|
|
# Passes under test; each is instantiated with no arguments by the tests.
Passes = [CSEPass]

# Programs called as ``f(inp)``. TorchTensorCall was previously defined but
# never registered, so it was dead code and an untested case.
Test_Cases = [
    TorchTensorCall,
    TakeList,
    ReturnList,
    Mutation,
    MutationInput,
    MutationMetadata,
    MutationTorchTensorCall,
]

# Programs called as ``f(inp, device)``: they build tensors on ``device``.
Factory_Test_Cases = [FactoryFunctionCall, MutationFactory]

# Always test on CPU; add CUDA only when it is available.
Devices = ["cpu"]
if torch.cuda.is_available():
    Devices.append("cuda")
|
|
|
|
|
|
def name_fn(common_pass, f, device):
    """Return the suffix for a generated test, ``<PassType>_<function>_<device>``.

    The pass is instantiated once so that its concrete type name can be read.
    """
    pass_type_name = type(common_pass()).__name__
    case_name = f.__name__
    return f"{pass_type_name}_{case_name}_{device}"
|
|
|
|
|
|
@instantiate_parametrized_tests
class TestCommonPass(TestCase):
    """Numerical-equivalence tests for the passes listed in ``Passes``.

    Each test traces a small program with ``make_fx``, then runs a pass over
    the resulting module. It checks that the rewritten ``GraphModule``
    produces the same output as running the program eagerly.
    """

    def _check_pass(self, common_pass, f, inp, *extra_args):
        """Trace ``f``, apply ``common_pass``, and compare with eager ``f``.

        Args:
            common_pass: pass class. It is instantiated with no arguments and
                called on the traced module; its result exposes
                ``graph_module``.
            f: program under test, called as ``f(inp, *extra_args)``.
            inp: input tensor used both for tracing and for the comparison.
            *extra_args: extra positional arguments (e.g. the device string)
                forwarded to ``f`` and to the rewritten module.
        """
        traced_m = make_fx(f)(inp, *extra_args)
        P = common_pass()

        res = P(traced_m)
        modified_m = res.graph_module
        # Unlike a bare ``assert``, this check still runs under ``python -O``.
        self.assertIsInstance(modified_m, GraphModule)

        # Eager ``f`` runs on ``inp`` and the rewritten module on a clone.
        # An in-place mutation of the input by one side (e.g. MutationInput)
        # therefore cannot change the other side's input.
        inp_copy = inp.clone()
        expected = f(inp, *extra_args)
        result = modified_m(inp_copy, *extra_args)

        self.assertEqual(result, expected)

    @parametrize(
        "common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn
    )
    def test_correctness(self, common_pass, f, device):
        """Check programs that take only the input; ``device`` places the input."""
        inp = torch.randn(10, device=device)
        self._check_pass(common_pass, f, inp)

    @parametrize(
        "common_pass,f,device",
        itertools.product(Passes, Factory_Test_Cases, Devices),
        name_fn,
    )
    def test_correctness_factory(self, common_pass, f, device):
        """Check programs that also receive ``device`` as their second argument."""
        inp = torch.randn(10, device=device)
        self._check_pass(common_pass, f, inp, device)
|
|
|
|
|
|
# Entry point when this file is executed directly rather than collected by a
# test runner.
if __name__ == "__main__":
    # Hand off to PyTorch's shared test runner from common_utils.
    run_tests()
|