Update (base update)

[ghstack-poisoned]
This commit is contained in:
Guilherme Leobas
2025-09-16 15:45:19 -03:00
parent be4b63b40c
commit 0203d39965
3 changed files with 13 additions and 7 deletions

View File

@ -1510,9 +1510,18 @@ class BuiltinVariable(VariableTracker):
assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat))
return SymNodeVariable.create(tx, arg.as_proxy() != 0)
if isinstance(arg, (ConstDictVariable, UserDefinedDictVariable)):
if isinstance(arg, ConstDictVariable):
return ConstantVariable.create(len(arg.items) > 0)
if isinstance(arg, UserDefinedObjectVariable):
# for user defined objects, first try __bool__ if defined, else
# __len__. If neither is defined, then any instance is considered True
if arg.call_obj_hasattr(tx, "__bool__").value:
return arg.call_method(tx, "__bool__", [], {})
elif arg.call_obj_hasattr(tx, "__len__").value:
length = arg.call_method(tx, "__len__", [], {})
return ConstantVariable.create(length.value > 0)
else:
return ConstantVariable.create(True)
# TODO handle more cases and merge this with this with `generic_jump`.
def call_str(self, tx: "InstructionTranslator", arg):

View File

@ -1371,6 +1371,8 @@ class DictItemsVariable(DictViewVariable):
return dict_items
def call_method(self, tx, name, args, kwargs):
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":
assert len(args) == 1
if isinstance(args[0], DictItemsVariable):

View File

@ -246,11 +246,6 @@ class IteratorVariable(VariableTracker):
def has_force_unpack_var_sequence(self, tx) -> bool:
return True
def call_obj_hasattr(self, tx, name):
if name in ("__iter__", "__next__"):
return ConstantVariable.create(True)
return super().call_obj_hasattr(tx, name)
class ObjectIteratorVariable(IteratorVariable):
"""