mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[BE] use DeviceIndex instead of int64_t for related device interfaces (#103068)
This PR unifies the device interfaces in aten/*cpp and torch/csrc/*cpp to use **c10::DeviceIndex**. Pull Request resolved: https://github.com/pytorch/pytorch/pull/103068 Approved by: https://github.com/malfet
This commit is contained in:
@ -661,6 +661,7 @@ def argument_type_str(
|
||||
BaseTy.Storage,
|
||||
BaseTy.Layout,
|
||||
BaseTy.Device,
|
||||
BaseTy.DeviceIndex,
|
||||
BaseTy.MemoryFormat,
|
||||
BaseTy.Dimname,
|
||||
BaseTy.Stream,
|
||||
@ -907,7 +908,7 @@ def argument_type_str_pyi(t: Type) -> str:
|
||||
add_optional = True
|
||||
|
||||
if isinstance(t, BaseType):
|
||||
if t.name == BaseTy.int:
|
||||
if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
|
||||
ret = "_int"
|
||||
if t.name == BaseTy.SymInt:
|
||||
ret = "Union[_int, SymInt]"
|
||||
@ -1255,6 +1256,8 @@ def arg_parser_unpack_method(
|
||||
return "scalartypeWithDefault" if has_default_init else "scalartype"
|
||||
elif t.name == BaseTy.Device:
|
||||
return "deviceWithDefault" if has_default_init else "device"
|
||||
elif t.name == BaseTy.DeviceIndex:
|
||||
return "toInt64"
|
||||
elif t.name == BaseTy.int:
|
||||
return "toInt64"
|
||||
elif t.name == BaseTy.SymInt:
|
||||
|
Reference in New Issue
Block a user