mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix multiple errors while parsing NativeFunctions from YAML (#127413)
Fixing multiple errors in parse_native_yaml when loading NativeFunctions from Yaml file. Add assertions that validates parsed data. Fixes #127404, #127405, #127406, #127407, #127408, #127409, #127410, #127411 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127413 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ea5c17de90
commit
b506d37331
@ -165,9 +165,11 @@ def parse_native_yaml_struct(
|
||||
rs: List[NativeFunction] = []
|
||||
bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
|
||||
for e in es:
|
||||
assert isinstance(e, dict), f"expected to be dict: {e}"
|
||||
assert isinstance(e.get("__line__"), int), e
|
||||
loc = Location(path, e["__line__"])
|
||||
funcs = e.get("func")
|
||||
assert funcs is not None, f"missed 'func' in {e}"
|
||||
with context(lambda: f"in {loc}:\n {funcs}"):
|
||||
func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
|
||||
rs.append(func)
|
||||
@ -268,7 +270,11 @@ def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
|
||||
base_func_map[f.func.name.name].append(f)
|
||||
for f in funcs:
|
||||
if f.structured_delegate is not None:
|
||||
delegate_func = func_map[f.structured_delegate]
|
||||
delegate_func = func_map.get(f.structured_delegate)
|
||||
assert delegate_func is not None, (
|
||||
f"{f.func.name} is marked as a structured_delegate pointing to "
|
||||
f"{f.structured_delegate}, but {f.structured_delegate} is missing."
|
||||
)
|
||||
assert delegate_func.structured, (
|
||||
f"{f.func.name} is marked as a structured_delegate pointing to "
|
||||
f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
|
||||
|
@ -626,6 +626,9 @@ class NativeFunction:
|
||||
assert device_check_s is None or isinstance(
|
||||
device_check_s, str
|
||||
), f"not a str: {device_check_s}"
|
||||
assert (
|
||||
device_check_s is None or device_check_s in DeviceCheckType.__members__
|
||||
), f"illegal device_check: {device_check_s}"
|
||||
device_check: DeviceCheckType
|
||||
if device_check_s is None:
|
||||
device_check = DeviceCheckType.ExactSame
|
||||
@ -706,7 +709,12 @@ class NativeFunction:
|
||||
for ks, v in raw_dispatch.items():
|
||||
if ks == "__line__":
|
||||
continue # not worth tracking line numbers for dispatch entries
|
||||
assert isinstance(ks, str), e
|
||||
assert isinstance(
|
||||
ks, str
|
||||
), f"illegal dispatch key '{ks}' in {raw_dispatch}"
|
||||
assert isinstance(
|
||||
v, str
|
||||
), f"illegal dispatch value '{v}' in {raw_dispatch}"
|
||||
for k in ks.split(","):
|
||||
dispatch_key = DispatchKey.parse(k.strip())
|
||||
num_dispatch_keys += 1
|
||||
@ -2006,8 +2014,12 @@ class Argument:
|
||||
def parse(arg: str) -> "Argument":
|
||||
name: str
|
||||
default: Optional[str]
|
||||
assert " " in arg, f"illegal argument '{arg}'"
|
||||
type_and_annot, name_and_default = arg.rsplit(" ", 1)
|
||||
if "=" in name_and_default:
|
||||
assert (
|
||||
name_and_default.count("=") == 1
|
||||
), f"illegal argument with default value: '{name_and_default}'"
|
||||
name, default = name_and_default.split("=")
|
||||
else:
|
||||
name = name_and_default
|
||||
@ -2792,6 +2804,9 @@ class Precompute:
|
||||
)
|
||||
|
||||
arg, with_list_raw = raw_replace_item.split(" -> ")
|
||||
assert (
|
||||
" " not in arg
|
||||
), f"illegal kernel param name '{arg}' in precomputed parameters'"
|
||||
with_list = with_list_raw.split(",")
|
||||
with_list_args = [Argument.parse(name.strip()) for name in with_list]
|
||||
replace[arg] = with_list_args
|
||||
|
Reference in New Issue
Block a user