mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Expands Pyrefly type checking to the files listed in the mypy-strict.ini configuration file. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697 Approved by: https://github.com/ezyang
248 lines
7.8 KiB
Python
248 lines
7.8 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2004-present Facebook. All Rights Reserved.
|
|
|
|
import argparse
|
|
import json
|
|
import unittest
|
|
from collections import defaultdict
|
|
from unittest.mock import Mock, patch
|
|
|
|
# pyrefly: ignore # import-error
|
|
from gen_operators_yaml import (
|
|
fill_output,
|
|
get_parser_options,
|
|
make_filter_from_options,
|
|
verify_all_specified_present,
|
|
)
|
|
|
|
|
|
def _mock_options() -> argparse.Namespace:
    """Return a fake parsed-options Namespace for the fill_output test.

    The test patches ``gen_operators_yaml.parse_options`` to return this
    object. It requests the root ops ``aten::add`` and ``aten::cat``, no model
    filtering, and disables "all overloads" for both static root ops and
    closure ops.
    """
    return argparse.Namespace(
        root_ops="aten::add,aten::cat",
        training_root_ops=[],
        output_path="/tmp",
        dep_graph_yaml_path="dummy_pytorch_op_deps.yaml",
        model_name="test_model",
        model_versions=None,
        model_assets=None,
        model_backends=None,
        models_yaml_path=None,
        include_all_operators=False,
        rule_name="test_rule",
        not_include_all_overloads_static_root_ops=True,
        not_include_all_overloads_closure_ops=True,
    )
|
|
|
|
|
|
def _mock_load_op_dep_graph() -> dict[str, set[str]]:
    """Return a tiny stand-in for the operator dependency graph.

    Each root op (``aten::add``, ``aten::cat``) maps to itself plus the shared
    dependency ``aten::as_strided_``. The result is a plain ``dict``.
    """
    shared_dep = "aten::as_strided_"
    return {root: {root, shared_dep} for root in ("aten::add", "aten::cat")}
|
|
|
|
|
|
class GenOperatorsYAMLTest(unittest.TestCase):
    """Tests for model filtering, verification, and output filling in
    ``gen_operators_yaml``."""

    @staticmethod
    def _abc_model_configs() -> list[dict[str, object]]:
        """Return a fresh two-entry models config for model "abc".

        Entry 0 is version 100 / asset-1, and entry 1 is version 101 / asset-2.
        Both use the CPU backend. A new list is built on every call so tests
        cannot leak state into each other.
        """
        return [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

    def test_filter_creation(self) -> None:
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=None,
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                # Version 102 is not in the requested versions.
                "model": {
                    "name": "abc",
                    "version": 102,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
            {
                # Name "abcd" does not match the requested model name.
                "model": {
                    "name": "abcd",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

        filtered_configs = list(filter(filter_func, config))
        # Use a TestCase assertion rather than a bare ``assert`` so the check
        # still runs under ``python -O``.
        self.assertEqual(
            len(filtered_configs),
            2,
            f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}",
        )

    def test_verification_success(self) -> None:
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=["asset-1", "asset-2"],
            model_backends=None,
        )
        filtered_configs = list(filter(filter_func, self._abc_model_configs()))
        try:
            verify_all_specified_present(
                model_assets=["asset-1", "asset-2"],
                model_versions=["100", "101"],
                selected_models_yaml=filtered_configs,
                rule_name="test",
                model_name="abc",
                new_style_rule=True,
            )
        except RuntimeError as err:
            # RuntimeError is the verification-failure signal that
            # test_verification_fail asserts on. Any other exception is a real
            # bug and propagates as a test error with its own traceback.
            self.fail(
                "expected verify_all_specified_present to succeed instead it raised an exception"
                f": {err!r}"
            )

    def test_verification_fail(self) -> None:
        config = self._abc_model_configs()

        good_assets = ["asset-1", "asset-2"]
        good_versions = ["100", "101"]
        good_name = "abc"

        # Each case requests something the config does not contain, so
        # verification must raise RuntimeError. Each case runs as a subTest, so
        # one failing case does not hide the others.
        cases = [
            ("bad asset", good_name, good_versions, ["asset-1", "asset-2", "asset-3"]),
            ("bad version", good_name, ["100", "101", "102"], good_assets),
            ("bad name", "abcd", good_versions, good_assets),
        ]
        for label, model_name, model_versions, model_assets in cases:
            with self.subTest(label):
                filter_func = make_filter_from_options(
                    model_name=model_name,
                    model_versions=model_versions,
                    model_assets=model_assets,
                    model_backends=None,
                )
                filtered_configs = list(filter(filter_func, config))
                with self.assertRaises(RuntimeError):
                    verify_all_specified_present(
                        model_assets=model_assets,
                        model_versions=model_versions,
                        selected_models_yaml=filtered_configs,
                        rule_name="test",
                        model_name=model_name,
                        new_style_rule=True,
                    )

    @patch("gen_operators_yaml.parse_options", return_value=_mock_options())
    @patch(
        "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
    )
    def test_fill_output_with_arguments_not_include_all_overloads(
        # Stacked @patch decorators are applied bottom-up, so the mock for the
        # bottom patch (load_op_dep_graph) is passed first.
        self,
        mock_load_op_dep_graph: Mock,
        mock_parse_options: Mock,
    ) -> None:
        parser = argparse.ArgumentParser(description="Generate used operators YAML")
        options = get_parser_options(parser)

        model_dict = {
            "model_name": options.model_name,
            "asset_info": {},
            "is_new_style_rule": False,
        }
        output = {"debug_info": [json.dumps(model_dict)]}

        fill_output(output, options)

        # pyrefly: ignore # missing-attribute
        operator_entries = list(output["operators"].values())
        # Guard against a vacuous pass. The mocked options request two root
        # ops, so fill_output is expected to emit at least one operator entry.
        self.assertTrue(
            operator_entries, "expected fill_output to populate output['operators']"
        )
        for op_val in operator_entries:
            self.assertFalse(op_val["include_all_overloads"])
|