[BE] use DeviceIndex instead of int64_t for related device interfaces (#103068)

This PR unifies the device interfaces in aten/*cpp and torch/csrc/*cpp to use  **c10::DeviceIndex**.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103068
Approved by: https://github.com/malfet
This commit is contained in:
cyy
2023-08-25 20:16:11 +00:00
committed by PyTorch MergeBot
parent 4656e09431
commit d9fb7166d6
18 changed files with 73 additions and 56 deletions

View File

@ -661,6 +661,7 @@ def argument_type_str(
BaseTy.Storage,
BaseTy.Layout,
BaseTy.Device,
BaseTy.DeviceIndex,
BaseTy.MemoryFormat,
BaseTy.Dimname,
BaseTy.Stream,
@ -907,7 +908,7 @@ def argument_type_str_pyi(t: Type) -> str:
add_optional = True
if isinstance(t, BaseType):
if t.name == BaseTy.int:
if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
ret = "_int"
if t.name == BaseTy.SymInt:
ret = "Union[_int, SymInt]"
@ -1255,6 +1256,8 @@ def arg_parser_unpack_method(
return "scalartypeWithDefault" if has_default_init else "scalartype"
elif t.name == BaseTy.Device:
return "deviceWithDefault" if has_default_init else "device"
elif t.name == BaseTy.DeviceIndex:
return "toInt64"
elif t.name == BaseTy.int:
return "toInt64"
elif t.name == BaseTy.SymInt: