[dynamo][user-defined] User class.__new__ instead of special casing (#146677)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146677
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain
2025-02-09 17:56:28 -08:00
committed by PyTorch MergeBot
parent de6efa1feb
commit ee8a06f1f6
5 changed files with 225 additions and 123 deletions

View File

@ -1093,42 +1093,30 @@ class BuiltinVariable(VariableTracker):
and name_var.is_python_constant()
):
return obj.method_setattr_standard(tx, name_var, val)
if self.fn is object and name == "__new__":
assert len(args) == 1
assert len(kwargs) == 0
return tx.output.side_effects.track_object_new_from_user_defined_class(
args[0]
)
if self.fn is object and name == "__init__":
# object.__init__ is a no-op
return variables.ConstantVariable(None)
if self.fn is dict and name == "__new__":
assert len(args) == 1
assert len(kwargs) == 0
dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
if isinstance(args[0], BuiltinVariable) and args[0].fn is dict:
return dict_vt
# We don't have to set the underlying dict_vt in
# UserDefinedDictVariable because it will be set to empty
# ConstDictVariableTracker in the constructor.
return tx.output.side_effects.track_object_new_from_user_defined_class(
args[0]
)
if self.fn is dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
if self.fn is dict:
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
if name == "__new__":
# Supported __new__ methods
if self.fn is object and len(args) == 1:
assert len(kwargs) == 0
return tx.output.side_effects.track_new_user_defined_object(
self, args[0], args[1:]
)
if self.fn is dict and len(args) == 1 and not kwargs:
dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
if isinstance(args[0], BuiltinVariable) and args[0].fn is dict:
return dict_vt
# We don't have to set the underlying dict_vt in
# UserDefinedDictVariable because it will be set to empty
# ConstDictVariableTracker in the constructor.
return tx.output.side_effects.track_new_user_defined_object(
self,
args[0],
args[1:],
)
if self.fn is tuple:
resolved_fn = getattr(self.fn, name)
if (
resolved_fn is tuple.__new__
self.fn is tuple
and len(args) == 2
and args[1].has_unpack_var_sequence(tx)
and not kwargs
@ -1140,20 +1128,29 @@ class BuiltinVariable(VariableTracker):
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
return tuple_vt
result = (
tx.output.side_effects.track_object_new_from_user_defined_class(
args[0]
)
result = tx.output.side_effects.track_new_user_defined_object(
self,
args[0],
args[1:],
)
# side_effects data structure does not support methods like
# tuple.__new__ that uses *args parameters for the __new__
# method. Therefore, we manage the *args related functionality
# here. For other datastructures, this is done in the __init__
# method.
result.set_new_args(args[1])
result.set_underlying_tuple_vt(tuple_vt)
return result
if self.fn is object and name == "__init__":
# object.__init__ is a no-op
return variables.ConstantVariable(None)
if self.fn is dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
if self.fn is dict:
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
return super().call_method(tx, name, args, kwargs)
def _call_int_float(self, tx: "InstructionTranslator", arg):