mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Implement VariableTracker.python_type() (#134215)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134215 Approved by: https://github.com/amjames, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
0043dcd79e
commit
2c99f17a32
@ -207,7 +207,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
Raises:
|
||||
NotImplementedError: If the method is not implemented in a subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{self} has no type")
|
||||
try:
|
||||
return type(self.as_python_constant())
|
||||
except NotImplementedError:
|
||||
raise NotImplementedError(f"{self} has no type") from None
|
||||
|
||||
def as_python_constant(self):
|
||||
"""For constants"""
|
||||
|
@ -637,9 +637,6 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
return f"{self.__class__.__name__}({name})"
|
||||
|
||||
def python_type(self):
|
||||
return type(self.fn)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.fn
|
||||
|
||||
|
@ -85,9 +85,6 @@ class ConstantVariable(VariableTracker):
|
||||
def __str__(self) -> str:
|
||||
return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
@ -222,9 +219,6 @@ class EnumVariable(VariableTracker):
|
||||
def __str__(self) -> str:
|
||||
return f"EnumVariable({type(self.value)})"
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
|
@ -626,9 +626,6 @@ class SkipFunctionVariable(VariableTracker):
|
||||
self.value = value
|
||||
self.reason = reason
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
|
@ -30,9 +30,6 @@ class ItertoolsVariable(VariableTracker):
|
||||
def __repr__(self) -> str:
|
||||
return f"ItertoolsVariable({self.value})"
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
|
@ -1192,9 +1192,6 @@ class TypingVariable(VariableTracker):
|
||||
)
|
||||
unimplemented("typing")
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
@ -1312,9 +1309,6 @@ class NumpyVariable(VariableTracker):
|
||||
) -> "VariableTracker":
|
||||
unimplemented("numpy")
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
@ -1489,9 +1483,6 @@ class ConstantLikeVariable(VariableTracker):
|
||||
super().__init__(**kwargs)
|
||||
self.value = value
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
|
@ -1328,9 +1328,6 @@ class TensorSubclassVariable(VariableTracker):
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
|
||||
class UntypedStorageVariable(VariableTracker):
|
||||
_nonvar_fields = {
|
||||
|
@ -173,9 +173,6 @@ class BaseTorchVariable(VariableTracker):
|
||||
def as_proxy(self):
|
||||
return self.value
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
|
@ -115,9 +115,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
def python_type(self):
|
||||
return type(self.value)
|
||||
|
||||
def as_proxy(self):
|
||||
return self.value
|
||||
|
||||
|
Reference in New Issue
Block a user