[hop] refactor only_consist_of with find_mismatched_vars (#140105)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140105
Approved by: https://github.com/zou3519
This commit is contained in:
Yidi Wu
2024-11-07 21:29:21 -08:00
committed by PyTorch MergeBot
parent 70a0906f24
commit 0fcd024f59

View File

@ -83,18 +83,36 @@ def dynamo_under_activation_checkpoint(tx: "InstructionTranslator"):
tx.output.current_tracer.under_activation_checkpoint = orig_val tx.output.current_tracer.under_activation_checkpoint = orig_val
def only_consist_of(var, types, allow_none=False): def find_mismatched_vars(var, types, allow_none=False):
if isinstance(var, types): """
return True Recursively finds variables whose type is not an instance of the specified types.
if allow_none and var.is_python_constant() and var.as_python_constant() is None: Args:
return True var: The variable to check.
types: A tuple of allowed types.
allow_none (bool): Whether to allow None values. Defaults to False.
Returns:
A set of variables whose type is not an instance of the specified types.
"""
mismatched_vars = set()
if isinstance(var, (TupleVariable, ListVariable)): if isinstance(var, (TupleVariable, ListVariable)):
return all(only_consist_of(item, types, allow_none) for item in var.items) for item in var.items:
if isinstance(var, ConstDictVariable): mismatched_vars.update(find_mismatched_vars(item, types, allow_none))
return all( elif isinstance(var, ConstDictVariable):
only_consist_of(item, types, allow_none) for item in var.items.values() for value in var.items.values():
) mismatched_vars.update(find_mismatched_vars(value, types, allow_none))
return False else:
def _is_none(var):
return var.is_python_constant() and var.as_python_constant() is None
if not isinstance(var, types) and not (allow_none and _is_none(var)):
mismatched_vars.add(var)
return mismatched_vars
def only_consist_of(var, types, allow_none=False):
mismatch_vars = find_mismatched_vars(var, types, allow_none=allow_none)
return len(mismatch_vars) == 0
# A more read-able syntax sugar for creating a UserFunctionVariable for f # A more read-able syntax sugar for creating a UserFunctionVariable for f