Implement VariableTracker.python_type() (#134215)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134215
Approved by: https://github.com/amjames, https://github.com/jansel
This commit is contained in:
Tom Ritchford
2024-09-04 13:44:38 +00:00
committed by PyTorch MergeBot
parent 0043dcd79e
commit 2c99f17a32
9 changed files with 4 additions and 34 deletions

View File

@ -207,7 +207,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError(f"{self} has no type")
try:
return type(self.as_python_constant())
except NotImplementedError:
raise NotImplementedError(f"{self} has no type") from None
def as_python_constant(self):
"""For constants"""

View File

@ -637,9 +637,6 @@ class BuiltinVariable(VariableTracker):
return f"{self.__class__.__name__}({name})"
def python_type(self):
return type(self.fn)
def as_python_constant(self):
return self.fn

View File

@ -85,9 +85,6 @@ class ConstantVariable(VariableTracker):
def __str__(self) -> str:
return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
@ -222,9 +219,6 @@ class EnumVariable(VariableTracker):
def __str__(self) -> str:
return f"EnumVariable({type(self.value)})"
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value

View File

@ -626,9 +626,6 @@ class SkipFunctionVariable(VariableTracker):
self.value = value
self.reason = reason
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value

View File

@ -30,9 +30,6 @@ class ItertoolsVariable(VariableTracker):
def __repr__(self) -> str:
return f"ItertoolsVariable({self.value})"
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value

View File

@ -1192,9 +1192,6 @@ class TypingVariable(VariableTracker):
)
unimplemented("typing")
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
@ -1312,9 +1309,6 @@ class NumpyVariable(VariableTracker):
) -> "VariableTracker":
unimplemented("numpy")
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
@ -1489,9 +1483,6 @@ class ConstantLikeVariable(VariableTracker):
super().__init__(**kwargs)
self.value = value
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value

View File

@ -1328,9 +1328,6 @@ class TensorSubclassVariable(VariableTracker):
def as_python_constant(self):
return self.value
def python_type(self):
return type(self.value)
class UntypedStorageVariable(VariableTracker):
_nonvar_fields = {

View File

@ -173,9 +173,6 @@ class BaseTorchVariable(VariableTracker):
def as_proxy(self):
return self.value
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value

View File

@ -115,9 +115,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
def as_python_constant(self):
return self.value
def python_type(self):
return type(self.value)
def as_proxy(self):
return self.value