Revert "[WIP] customize the C++ class for valueT"

This reverts commit c152817926989033fe22bf3f3ec64fde1ba5a129.

Reverted https://github.com/pytorch/pytorch/pull/76911 on behalf of https://github.com/suo
This commit is contained in:
PyTorch MergeBot
2022-05-06 22:36:04 +00:00
parent 70b4746329
commit ec841b0346
3 changed files with 11 additions and 39 deletions

View File

@ -31,25 +31,7 @@ from torchgen.api.types import (
SymIntT,
)
_valueT = None
def getValueT():
global _valueT
if not _valueT:
raise NotImplementedError(
"The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
)
return _valueT
def setValueT(val):
global _valueT
_valueT = val
valueT = BaseCppType("torch::lazy", "Value")
# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")
@ -75,17 +57,17 @@ def process_ir_type(
"""
if isinstance(typ, BaseType):
if typ.name == BaseTy.Tensor:
return BaseCType(getValueT())
return BaseCType(valueT)
elif typ.name == BaseTy.Scalar:
# at::scalar has special handling,
# and is wrapped in an lazy::Value just like at::tensor
return BaseCType(getValueT())
return BaseCType(valueT)
elif typ.name == BaseTy.ScalarType:
return BaseCType(scalarTypeT)
elif typ.name == BaseTy.int:
return BaseCType(longT)
elif typ.name == BaseTy.SymInt:
return BaseCType(getValueT())
return BaseCType(valueT)
elif typ.name == BaseTy.bool:
return BaseCType(boolT)
elif typ.name == BaseTy.float:
@ -105,7 +87,7 @@ def process_ir_type(
elif isinstance(typ, ListType):
if str(typ.elem) == "Tensor?":
# TODO(whc) is this actually correct? or should it use a Vector like above
return ListCType(OptionalCType(BaseCType(getValueT())))
return ListCType(OptionalCType(BaseCType(valueT)))
elif str(typ.elem) == "Tensor":
# this is a TensorList which comes in from GetTensorList as a Value
return BaseCType(tensorListValueT)
@ -123,7 +105,7 @@ def isValueType(typ: CType) -> bool:
if isinstance(typ, BaseCType):
# I am regretting my naming conventions, but now we are wrapping at::scalar in
# lazy value, while preserving other 'scalar' types as scalars in the IR
return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT
return typ.type == valueT or typ.type == scalarT or typ.type == SymIntT
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
return isValueType(typ.elem)
return False