Files
pytorch/tools/test/gen_operators_yaml_test.py
Xuehai Pan c73a92fbf5 [BE][CI] bump ruff to 0.9.2: multiline assert statements (#144546)
Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements

> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
>     f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144546
Approved by: https://github.com/malfet
2025-02-27 20:46:16 +00:00

246 lines
7.8 KiB
Python

#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import argparse
import json
import unittest
from collections import defaultdict
from unittest.mock import Mock, patch
from gen_operators_yaml import (
fill_output,
get_parser_options,
make_filter_from_options,
verify_all_specified_present,
)
def _mock_options():
    """Build the canned options namespace used in place of real CLI parsing.

    The test class patches ``gen_operators_yaml.parse_options`` to return this
    object, so no command line is actually parsed.

    Returns:
        A new ``argparse.Namespace`` on every call, holding two root operators
        (``aten::add`` and ``aten::cat``), no model filters, and both
        ``not_include_all_overloads_*`` flags set to ``True``.
    """
    # Namespace accepts its attributes as keyword arguments, so the whole
    # fixture can be declared in one expression.
    return argparse.Namespace(
        root_ops="aten::add,aten::cat",
        training_root_ops=[],
        output_path="/tmp",
        dep_graph_yaml_path="dummy_pytorch_op_deps.yaml",
        model_name="test_model",
        model_versions=None,
        model_assets=None,
        model_backends=None,
        models_yaml_path=None,
        include_all_operators=False,
        rule_name="test_rule",
        not_include_all_overloads_static_root_ops=True,
        not_include_all_overloads_closure_ops=True,
    )
def _mock_load_op_dep_graph():
    """Return a canned operator dependency graph for the tests.

    The test class patches ``gen_operators_yaml.load_op_dep_graph`` to return
    this mapping instead of reading a dependency-graph YAML file.

    Returns:
        A plain ``dict`` mapping each of the two root operators to a set made
        of the operator itself plus ``aten::as_strided_``.
    """
    # A plain literal is enough: no default factory is ever triggered, so
    # there is no reason to go through a defaultdict first.
    return {
        "aten::add": {"aten::add", "aten::as_strided_"},
        "aten::cat": {"aten::cat", "aten::as_strided_"},
    }
class GenOperatorsYAMLTest(unittest.TestCase):
    """Tests for the helpers imported from ``gen_operators_yaml``.

    Covers ``make_filter_from_options``, ``verify_all_specified_present`` and
    ``fill_output`` (the last one with option parsing and dependency-graph
    loading replaced by mocks).
    """

    def setUp(self) -> None:
        # No shared fixtures: every test builds its own inputs.
        pass

    def test_filter_creation(self) -> None:
        """A name + versions filter keeps exactly two of the four configs.

        The filter is built for model ``abc`` with versions ``100``/``101``
        and no asset or backend restriction. The config list mixes a matching
        name with a non-listed version (102) and a different name (``abcd``).
        """
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=None,
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 102,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
            {
                "model": {
                    "name": "abcd",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]
        filtered_configs = list(filter(filter_func, config))
        # Use a unittest assertion rather than a bare ``assert`` so the check
        # still runs when Python is started with -O.
        self.assertEqual(
            len(filtered_configs),
            2,
            f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}",
        )

    def test_verification_success(self) -> None:
        """Verification passes when every requested asset/version is present."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=["asset-1", "asset-2"],
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]
        filtered_configs = list(filter(filter_func, config))
        try:
            verify_all_specified_present(
                model_assets=["asset-1", "asset-2"],
                model_versions=["100", "101"],
                selected_models_yaml=filtered_configs,
                rule_name="test",
                model_name="abc",
                new_style_rule=True,
            )
        except Exception as exc:
            # Still reported as a test failure (not an error), but the
            # underlying cause is kept in the message for debugging.
            self.fail(
                "expected verify_all_specified_present to succeed instead it "
                f"raised an exception: {exc!r}"
            )

    def test_verification_fail(self) -> None:
        """Verification raises ``RuntimeError`` for each unmet request.

        Three scenarios are run against the same two-entry config: an extra
        asset, an extra version, and a model name that is not in the config.
        Each one runs in its own ``subTest`` so a failure in one scenario
        does not hide the result of the others.
        """
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]
        good_assets = ["asset-1", "asset-2"]
        good_versions = ["100", "101"]
        good_name = "abc"

        # (label, model name, requested versions, requested assets)
        failing_cases = [
            ("bad asset", good_name, good_versions, ["asset-1", "asset-2", "asset-3"]),
            ("bad version", good_name, ["100", "101", "102"], good_assets),
            ("bad name", "abcd", good_versions, good_assets),
        ]
        for label, name, versions, assets in failing_cases:
            with self.subTest(case=label):
                filter_func = make_filter_from_options(
                    model_name=name,
                    model_versions=versions,
                    model_assets=assets,
                    model_backends=None,
                )
                filtered_configs = list(filter(filter_func, config))
                with self.assertRaises(RuntimeError):
                    verify_all_specified_present(
                        model_assets=assets,
                        model_versions=versions,
                        selected_models_yaml=filtered_configs,
                        rule_name="test",
                        model_name=name,
                        new_style_rule=True,
                    )

    @patch("gen_operators_yaml.parse_options", return_value=_mock_options())
    @patch(
        "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
    )
    def test_fill_output_with_arguments_not_include_all_overloads(
        self, mock_load_op_dep_graph: Mock, mock_parse_options: Mock
    ) -> None:
        """No operator written by ``fill_output`` has ``include_all_overloads`` set.

        Stacked ``patch`` decorators inject their mocks bottom-up, so the
        ``load_op_dep_graph`` mock is the first positional argument and the
        ``parse_options`` mock is the second.
        """
        parser = argparse.ArgumentParser(description="Generate used operators YAML")
        options = get_parser_options(parser)
        model_dict = {
            "model_name": options.model_name,
            "asset_info": {},
            "is_new_style_rule": False,
        }
        output = {"debug_info": [json.dumps(model_dict)]}
        fill_output(output, options)
        for op_val in output["operators"].values():
            self.assertFalse(op_val["include_all_overloads"])