mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ts_converter] Fix prim::dtype (#128517)
Summary: prim::dtype has the signature `(Tensor a) -> int`, where it gets the dtype of the tensor and returns the integer corresponding to this dtype based on the enum in ScalarType.h. Previously we were converting prim::dtype by returning the actual dtype of the tensor (ex. torch.float32). This causes some incorrect control flow to behavior, specifically where it checks if `prim::dtype(tensor) in [3, 5, 7]`, where [3, 5, 7] correspond to torch.int32, torch.float16, torch.float64. This control flow would always returns False because we would be comparing torch.float32 against the integers [3, 5, 7], which is a type mismatch. Test Plan: 7/22 internal models now are convertable and runnable in eager and sigmoid! P1410243909 Reviewed By: jiashenC Differential Revision: D58469232 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128517 Approved by: https://github.com/jiashenC
This commit is contained in:
committed by
PyTorch MergeBot
parent
2fa6f80b13
commit
3bc2004f91
@ -491,11 +491,11 @@ class TestConverter(TestCase):
|
||||
def test_ts2ep_converter_contains(self):
|
||||
class MIn(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x.dtype in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
return x.dtype in [torch.float32, torch.float64]
|
||||
|
||||
class MNotIn(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x.dtype in [-1]
|
||||
return x.dtype in [torch.int8]
|
||||
|
||||
class MTensorIn(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
|
||||
|
@ -54,6 +54,35 @@ def get_node_for_param_and_buffer(fx_graph, name, is_top_level_graph):
|
||||
return fx_graph.placeholder(name)
|
||||
|
||||
|
||||
_TORCH_DTYPE_TO_ENUM = {
|
||||
torch.uint8: 0,
|
||||
torch.int8: 1,
|
||||
torch.int16: 2,
|
||||
torch.int32: 3,
|
||||
torch.int64: 4,
|
||||
torch.float16: 5,
|
||||
torch.float32: 6,
|
||||
torch.float64: 7,
|
||||
torch.complex32: 8,
|
||||
torch.complex64: 9,
|
||||
torch.complex128: 10,
|
||||
torch.bool: 11,
|
||||
torch.bfloat16: 15,
|
||||
}
|
||||
|
||||
|
||||
def get_dtype_as_int(tensor):
|
||||
"""
|
||||
prim::dtype has the signature "Tensor a) -> int", where it gets the dtype of
|
||||
the tensor and returns the integer corresponding to this dtype based on the
|
||||
enum in ScalarType.h
|
||||
"""
|
||||
dtype = tensor.dtype
|
||||
if dtype not in _TORCH_DTYPE_TO_ENUM:
|
||||
raise RuntimeError(f"Unsupported dtype {dtype}")
|
||||
return _TORCH_DTYPE_TO_ENUM[dtype]
|
||||
|
||||
|
||||
# Those operators will be automatically populated to a instance method
|
||||
# of TS2FXGraphConverter with name convert_<namespace>_<opname>().
|
||||
# Please check __init__ for method population implementations.
|
||||
@ -63,6 +92,7 @@ kind_to_standard_operators = {
|
||||
"aten::__isnot__": operator.is_not,
|
||||
"aten::__not__": operator.not_,
|
||||
"aten::__contains__": operator.contains,
|
||||
"prim::dtype": get_dtype_as_int,
|
||||
}
|
||||
|
||||
|
||||
@ -358,11 +388,6 @@ class TS2FXGraphConverter:
|
||||
else:
|
||||
raise ValueError(f"Unsupported JitType ({input_type}) when get device")
|
||||
|
||||
def convert_prim_dtype(self, node: torch._C.Node):
|
||||
dtype = node.input().type().dtype()
|
||||
output_name = node.output().debugName()
|
||||
self.constant_map[output_name] = dtype
|
||||
|
||||
def convert_prim_GetAttr(self, node: torch._C.Node):
|
||||
def get_attr(name: str):
|
||||
if name in self.attribute_map:
|
||||
|
Reference in New Issue
Block a user