Fix multiple errors while parsing NativeFunctions from YAML (#127413)

Fixing multiple errors in parse_native_yaml when loading NativeFunctions from Yaml file.

Add assertions that validates parsed data.

Fixes #127404, #127405, #127406, #127407, #127408, #127409, #127410, #127411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127413
Approved by: https://github.com/ezyang
This commit is contained in:
Daniil Kutz
2024-05-30 16:25:02 +00:00
committed by PyTorch MergeBot
parent ea5c17de90
commit b506d37331
2 changed files with 23 additions and 2 deletions

View File

@ -165,9 +165,11 @@ def parse_native_yaml_struct(
rs: List[NativeFunction] = []
bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
for e in es:
assert isinstance(e, dict), f"expected to be dict: {e}"
assert isinstance(e.get("__line__"), int), e
loc = Location(path, e["__line__"])
funcs = e.get("func")
assert funcs is not None, f"missed 'func' in {e}"
with context(lambda: f"in {loc}:\n {funcs}"):
func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
rs.append(func)
@ -268,7 +270,11 @@ def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
base_func_map[f.func.name.name].append(f)
for f in funcs:
if f.structured_delegate is not None:
delegate_func = func_map[f.structured_delegate]
delegate_func = func_map.get(f.structured_delegate)
assert delegate_func is not None, (
f"{f.func.name} is marked as a structured_delegate pointing to "
f"{f.structured_delegate}, but {f.structured_delegate} is missing."
)
assert delegate_func.structured, (
f"{f.func.name} is marked as a structured_delegate pointing to "
f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "

View File

@ -626,6 +626,9 @@ class NativeFunction:
assert device_check_s is None or isinstance(
device_check_s, str
), f"not a str: {device_check_s}"
assert (
device_check_s is None or device_check_s in DeviceCheckType.__members__
), f"illegal device_check: {device_check_s}"
device_check: DeviceCheckType
if device_check_s is None:
device_check = DeviceCheckType.ExactSame
@ -706,7 +709,12 @@ class NativeFunction:
for ks, v in raw_dispatch.items():
if ks == "__line__":
continue # not worth tracking line numbers for dispatch entries
assert isinstance(ks, str), e
assert isinstance(
ks, str
), f"illegal dispatch key '{ks}' in {raw_dispatch}"
assert isinstance(
v, str
), f"illegal dispatch value '{v}' in {raw_dispatch}"
for k in ks.split(","):
dispatch_key = DispatchKey.parse(k.strip())
num_dispatch_keys += 1
@ -2006,8 +2014,12 @@ class Argument:
def parse(arg: str) -> "Argument":
name: str
default: Optional[str]
assert " " in arg, f"illegal argument '{arg}'"
type_and_annot, name_and_default = arg.rsplit(" ", 1)
if "=" in name_and_default:
assert (
name_and_default.count("=") == 1
), f"illegal argument with default value: '{name_and_default}'"
name, default = name_and_default.split("=")
else:
name = name_and_default
@ -2792,6 +2804,9 @@ class Precompute:
)
arg, with_list_raw = raw_replace_item.split(" -> ")
assert (
" " not in arg
), f"illegal kernel param name '{arg}' in precomputed parameters'"
with_list = with_list_raw.split(",")
with_list_args = [Argument.parse(name.strip()) for name in with_list]
replace[arg] = with_list_args