[ts_converter] Fix prim::dtype (#128517)

Summary: prim::dtype has the signature `(Tensor a) -> int`: it takes the dtype of the tensor and returns the integer corresponding to that dtype based on the enum in ScalarType.h. Previously we converted prim::dtype by returning the actual dtype of the tensor (e.g. torch.float32). This caused incorrect control-flow behavior, specifically in code that checks `prim::dtype(tensor) in [3, 5, 7]`, where [3, 5, 7] correspond to torch.int32, torch.float16, and torch.float64. That check would always return False because we were comparing torch.float32 against the integers [3, 5, 7], which is a type mismatch.
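
As a rough illustration (not part of this commit), the mismatch and the intended behavior look like the following; the enum values come from ScalarType.h and mirror the _TORCH_DTYPE_TO_ENUM table added below:

    import torch

    # ScalarType.h enum values for the dtypes mentioned above:
    # int32 -> 3, float16 -> 5, float32 -> 6, float64 -> 7.
    _DTYPE_TO_ENUM = {torch.int32: 3, torch.float16: 5, torch.float32: 6, torch.float64: 7}

    t = torch.ones(2, dtype=torch.float16)

    # Old conversion: prim::dtype produced the dtype object itself, so the
    # membership check compared torch.float16 against plain ints and never matched.
    print(t.dtype in [3, 5, 7])                   # False (type mismatch)

    # New conversion: prim::dtype produces the ScalarType integer, so the
    # check takes the intended branch.
    print(_DTYPE_TO_ENUM[t.dtype] in [3, 5, 7])   # True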

Test Plan: 7/22 internal models are now convertible and runnable in eager and sigmoid! P1410243909

Reviewed By: jiashenC

Differential Revision: D58469232

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128517
Approved by: https://github.com/jiashenC
Author: Angela Yi
Date: 2024-06-12 23:02:50 +00:00
Committed by: PyTorch MergeBot
Parent: 2fa6f80b13
Commit: 3bc2004f91
2 changed files with 32 additions and 7 deletions

@@ -491,11 +491,11 @@ class TestConverter(TestCase):
     def test_ts2ep_converter_contains(self):
         class MIn(torch.nn.Module):
             def forward(self, x: torch.Tensor):
-                return x.dtype in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+                return x.dtype in [torch.float32, torch.float64]

         class MNotIn(torch.nn.Module):
             def forward(self, x: torch.Tensor):
-                return x.dtype in [-1]
+                return x.dtype in [torch.int8]

         class MTensorIn(torch.nn.Module):
             def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
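
The updated test bodies still exercise the integer comparison path: in TorchScript, `x.dtype` is typed as int and dtype literals such as torch.float32 are baked in as their ScalarType values, so the dtype lists above lower to integer membership checks. A minimal sketch of that assumption (not part of the diff):

    import torch

    class MIn(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            return x.dtype in [torch.float32, torch.float64]

    scripted = torch.jit.script(MIn())
    print(scripted(torch.ones(2)))                     # True  (float32 -> 6, which is in [6, 7])
    print(scripted(torch.ones(2, dtype=torch.int8)))   # False (int8 -> 1)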

@@ -54,6 +54,35 @@ def get_node_for_param_and_buffer(fx_graph, name, is_top_level_graph):
     return fx_graph.placeholder(name)


+_TORCH_DTYPE_TO_ENUM = {
+    torch.uint8: 0,
+    torch.int8: 1,
+    torch.int16: 2,
+    torch.int32: 3,
+    torch.int64: 4,
+    torch.float16: 5,
+    torch.float32: 6,
+    torch.float64: 7,
+    torch.complex32: 8,
+    torch.complex64: 9,
+    torch.complex128: 10,
+    torch.bool: 11,
+    torch.bfloat16: 15,
+}
+
+
+def get_dtype_as_int(tensor):
+    """
+    prim::dtype has the signature "(Tensor a) -> int", where it gets the dtype of
+    the tensor and returns the integer corresponding to this dtype based on the
+    enum in ScalarType.h.
+    """
+    dtype = tensor.dtype
+    if dtype not in _TORCH_DTYPE_TO_ENUM:
+        raise RuntimeError(f"Unsupported dtype {dtype}")
+    return _TORCH_DTYPE_TO_ENUM[dtype]
+
+
 # Those operators will be automatically populated to a instance method
 # of TS2FXGraphConverter with name convert_<namespace>_<opname>().
 # Please check __init__ for method population implementations.
@@ -63,6 +92,7 @@ kind_to_standard_operators = {
     "aten::__isnot__": operator.is_not,
     "aten::__not__": operator.not_,
     "aten::__contains__": operator.contains,
+    "prim::dtype": get_dtype_as_int,
 }
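
A quick usage sketch for the new helper registered above; the import path is an assumption based on where TS2FXGraphConverter is defined and is not shown on this page:

    import torch
    from torch._export.converter import get_dtype_as_int  # assumed module path

    assert get_dtype_as_int(torch.ones(2)) == 6                        # torch.float32
    assert get_dtype_as_int(torch.ones(2, dtype=torch.float64)) == 7   # torch.float64
    assert get_dtype_as_int(torch.ones(2, dtype=torch.int32)) == 3     # torch.int32
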
@@ -358,11 +388,6 @@ class TS2FXGraphConverter:
         else:
             raise ValueError(f"Unsupported JitType ({input_type}) when get device")

-    def convert_prim_dtype(self, node: torch._C.Node):
-        dtype = node.input().type().dtype()
-        output_name = node.output().debugName()
-        self.constant_map[output_name] = dtype
-
     def convert_prim_GetAttr(self, node: torch._C.Node):
         def get_attr(name: str):
             if name in self.attribute_map:
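
With the hand-written convert_prim_dtype removed, prim::dtype is handled through the same generic path as the other kind_to_standard_operators entries. A simplified, hypothetical sketch of that table-driven pattern (illustrative names only, not the actual TS2FXGraphConverter internals):

    import operator
    from typing import Any, Callable, Dict, List

    import torch

    def get_dtype_as_int(tensor: torch.Tensor) -> int:
        # Same idea as the helper added in this diff (subset of the full table).
        return {torch.float32: 6, torch.float64: 7, torch.int32: 3}[tensor.dtype]

    # Node kind -> plain Python callable, instead of one convert_* method per op.
    kind_to_standard_operators: Dict[str, Callable] = {
        "aten::__contains__": operator.contains,
        "prim::dtype": get_dtype_as_int,
    }

    def convert_node(kind: str, args: List[Any]) -> Any:
        # Generic dispatch: look up the callable for the node kind and apply it.
        return kind_to_standard_operators[kind](*args)

    print(convert_node("prim::dtype", [torch.ones(2)]))          # 6
    print(convert_node("aten::__contains__", [[3, 5, 7], 6]))    # False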