mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 22:14:53 +08:00
Revert "[WIP] customize the C++ class for valueT"
This reverts commit c152817926989033fe22bf3f3ec64fde1ba5a129. Reverted https://github.com/pytorch/pytorch/pull/76911 on behalf of https://github.com/suo
This commit is contained in:
@ -31,25 +31,7 @@ from torchgen.api.types import (
|
||||
SymIntT,
|
||||
)
|
||||
|
||||
|
||||
_valueT = None
|
||||
|
||||
|
||||
def getValueT():
|
||||
global _valueT
|
||||
if not _valueT:
|
||||
raise NotImplementedError(
|
||||
"The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
|
||||
)
|
||||
|
||||
return _valueT
|
||||
|
||||
|
||||
def setValueT(val):
|
||||
global _valueT
|
||||
_valueT = val
|
||||
|
||||
|
||||
valueT = BaseCppType("torch::lazy", "Value")
|
||||
# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
|
||||
# making it easier to represent special properties of an arg.
|
||||
tensorListValueT = BaseCppType("torch::lazy", "Value")
|
||||
@ -75,17 +57,17 @@ def process_ir_type(
|
||||
"""
|
||||
if isinstance(typ, BaseType):
|
||||
if typ.name == BaseTy.Tensor:
|
||||
return BaseCType(getValueT())
|
||||
return BaseCType(valueT)
|
||||
elif typ.name == BaseTy.Scalar:
|
||||
# at::scalar has special handling,
|
||||
# and is wrapped in an lazy::Value just like at::tensor
|
||||
return BaseCType(getValueT())
|
||||
return BaseCType(valueT)
|
||||
elif typ.name == BaseTy.ScalarType:
|
||||
return BaseCType(scalarTypeT)
|
||||
elif typ.name == BaseTy.int:
|
||||
return BaseCType(longT)
|
||||
elif typ.name == BaseTy.SymInt:
|
||||
return BaseCType(getValueT())
|
||||
return BaseCType(valueT)
|
||||
elif typ.name == BaseTy.bool:
|
||||
return BaseCType(boolT)
|
||||
elif typ.name == BaseTy.float:
|
||||
@ -105,7 +87,7 @@ def process_ir_type(
|
||||
elif isinstance(typ, ListType):
|
||||
if str(typ.elem) == "Tensor?":
|
||||
# TODO(whc) is this actually correct? or should it use a Vector like above
|
||||
return ListCType(OptionalCType(BaseCType(getValueT())))
|
||||
return ListCType(OptionalCType(BaseCType(valueT)))
|
||||
elif str(typ.elem) == "Tensor":
|
||||
# this is a TensorList which comes in from GetTensorList as a Value
|
||||
return BaseCType(tensorListValueT)
|
||||
@ -123,7 +105,7 @@ def isValueType(typ: CType) -> bool:
|
||||
if isinstance(typ, BaseCType):
|
||||
# I am regretting my naming conventions, but now we are wrapping at::scalar in
|
||||
# lazy value, while preserving other 'scalar' types as scalars in the IR
|
||||
return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT
|
||||
return typ.type == valueT or typ.type == scalarT or typ.type == SymIntT
|
||||
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
|
||||
return isValueType(typ.elem)
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user