mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[hop] refactor only_consist_of with find_mismatched_vars (#140105)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140105 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
70a0906f24
commit
0fcd024f59
@ -83,18 +83,36 @@ def dynamo_under_activation_checkpoint(tx: "InstructionTranslator"):
|
||||
tx.output.current_tracer.under_activation_checkpoint = orig_val
|
||||
|
||||
|
||||
def only_consist_of(var, types, allow_none=False):
|
||||
if isinstance(var, types):
|
||||
return True
|
||||
if allow_none and var.is_python_constant() and var.as_python_constant() is None:
|
||||
return True
|
||||
def find_mismatched_vars(var, types, allow_none=False):
|
||||
"""
|
||||
Recursively finds variables whose type is not an instance of the specified types.
|
||||
Args:
|
||||
var: The variable to check.
|
||||
types: A tuple of allowed types.
|
||||
allow_none (bool): Whether to allow None values. Defaults to False.
|
||||
Returns:
|
||||
A set of variables whose type is not an instance of the specified types.
|
||||
"""
|
||||
mismatched_vars = set()
|
||||
if isinstance(var, (TupleVariable, ListVariable)):
|
||||
return all(only_consist_of(item, types, allow_none) for item in var.items)
|
||||
if isinstance(var, ConstDictVariable):
|
||||
return all(
|
||||
only_consist_of(item, types, allow_none) for item in var.items.values()
|
||||
)
|
||||
return False
|
||||
for item in var.items:
|
||||
mismatched_vars.update(find_mismatched_vars(item, types, allow_none))
|
||||
elif isinstance(var, ConstDictVariable):
|
||||
for value in var.items.values():
|
||||
mismatched_vars.update(find_mismatched_vars(value, types, allow_none))
|
||||
else:
|
||||
|
||||
def _is_none(var):
|
||||
return var.is_python_constant() and var.as_python_constant() is None
|
||||
|
||||
if not isinstance(var, types) and not (allow_none and _is_none(var)):
|
||||
mismatched_vars.add(var)
|
||||
return mismatched_vars
|
||||
|
||||
|
||||
def only_consist_of(var, types, allow_none=False):
|
||||
mismatch_vars = find_mismatched_vars(var, types, allow_none=allow_none)
|
||||
return len(mismatch_vars) == 0
|
||||
|
||||
|
||||
# A more read-able syntax sugar for creating a UserFunctionVariable for f
|
||||
|
Reference in New Issue
Block a user