mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][user-defined] User class.__new__ instead of special casing (#146677)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146677 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
de6efa1feb
commit
ee8a06f1f6
@ -1093,42 +1093,30 @@ class BuiltinVariable(VariableTracker):
|
||||
and name_var.is_python_constant()
|
||||
):
|
||||
return obj.method_setattr_standard(tx, name_var, val)
|
||||
if self.fn is object and name == "__new__":
|
||||
assert len(args) == 1
|
||||
assert len(kwargs) == 0
|
||||
return tx.output.side_effects.track_object_new_from_user_defined_class(
|
||||
args[0]
|
||||
)
|
||||
if self.fn is object and name == "__init__":
|
||||
# object.__init__ is a no-op
|
||||
return variables.ConstantVariable(None)
|
||||
if self.fn is dict and name == "__new__":
|
||||
assert len(args) == 1
|
||||
assert len(kwargs) == 0
|
||||
dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
|
||||
if isinstance(args[0], BuiltinVariable) and args[0].fn is dict:
|
||||
return dict_vt
|
||||
# We don't have to set the underlying dict_vt in
|
||||
# UserDefinedDictVariable because it will be set to empty
|
||||
# ConstDictVariableTracker in the constructor.
|
||||
return tx.output.side_effects.track_object_new_from_user_defined_class(
|
||||
args[0]
|
||||
)
|
||||
if self.fn is dict and name == "fromkeys":
|
||||
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
||||
|
||||
if self.fn is dict:
|
||||
resolved_fn = getattr(self.fn, name)
|
||||
if resolved_fn in dict_methods:
|
||||
if isinstance(args[0], variables.UserDefinedDictVariable):
|
||||
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
|
||||
elif isinstance(args[0], variables.ConstDictVariable):
|
||||
return args[0].call_method(tx, name, args[1:], kwargs)
|
||||
if name == "__new__":
|
||||
# Supported __new__ methods
|
||||
if self.fn is object and len(args) == 1:
|
||||
assert len(kwargs) == 0
|
||||
return tx.output.side_effects.track_new_user_defined_object(
|
||||
self, args[0], args[1:]
|
||||
)
|
||||
|
||||
if self.fn is dict and len(args) == 1 and not kwargs:
|
||||
dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
|
||||
if isinstance(args[0], BuiltinVariable) and args[0].fn is dict:
|
||||
return dict_vt
|
||||
# We don't have to set the underlying dict_vt in
|
||||
# UserDefinedDictVariable because it will be set to empty
|
||||
# ConstDictVariableTracker in the constructor.
|
||||
return tx.output.side_effects.track_new_user_defined_object(
|
||||
self,
|
||||
args[0],
|
||||
args[1:],
|
||||
)
|
||||
|
||||
if self.fn is tuple:
|
||||
resolved_fn = getattr(self.fn, name)
|
||||
if (
|
||||
resolved_fn is tuple.__new__
|
||||
self.fn is tuple
|
||||
and len(args) == 2
|
||||
and args[1].has_unpack_var_sequence(tx)
|
||||
and not kwargs
|
||||
@ -1140,20 +1128,29 @@ class BuiltinVariable(VariableTracker):
|
||||
if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple:
|
||||
return tuple_vt
|
||||
|
||||
result = (
|
||||
tx.output.side_effects.track_object_new_from_user_defined_class(
|
||||
args[0]
|
||||
)
|
||||
result = tx.output.side_effects.track_new_user_defined_object(
|
||||
self,
|
||||
args[0],
|
||||
args[1:],
|
||||
)
|
||||
# side_effects data structure does not support methods like
|
||||
# tuple.__new__ that uses *args parameters for the __new__
|
||||
# method. Therefore, we manage the *args related functionality
|
||||
# here. For other datastructures, this is done in the __init__
|
||||
# method.
|
||||
result.set_new_args(args[1])
|
||||
result.set_underlying_tuple_vt(tuple_vt)
|
||||
return result
|
||||
|
||||
if self.fn is object and name == "__init__":
|
||||
# object.__init__ is a no-op
|
||||
return variables.ConstantVariable(None)
|
||||
|
||||
if self.fn is dict and name == "fromkeys":
|
||||
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
||||
|
||||
if self.fn is dict:
|
||||
resolved_fn = getattr(self.fn, name)
|
||||
if resolved_fn in dict_methods:
|
||||
if isinstance(args[0], variables.UserDefinedDictVariable):
|
||||
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
|
||||
elif isinstance(args[0], variables.ConstDictVariable):
|
||||
return args[0].call_method(tx, name, args[1:], kwargs)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
||||
|
||||
Reference in New Issue
Block a user