mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 11:14:56 +08:00
[inductor] Use int64_t as index type for all platforms 4 (#133892)
This PR runs in parallel with https://github.com/pytorch/pytorch/pull/133819 and adds changes that address @jansel's comments:

1. For `torch/_inductor/codegen/cpp_wrapper_cpu.py`, revert to the original code, which appends `LL` on MacOS and Windows: bdc14ad89a
2. For `torch/_inductor/codegen/cpp_utils.py`, append `LL` on MacOS and Windows for large constants, and fix the related UTs: 3a56b76ce0

------------------------------

This is another solution for https://github.com/pytorch/pytorch/pull/133615: use `int64_t` as the index type on all platforms.

### Development notes:

The mentioned PR (https://github.com/pytorch/pytorch/pull/133615) fixes the index type not matching the argument types of `parse_arg`. After reviewing it with @jansel, Jason thinks we need to unify `INDEX_TYPE` across all platforms. The current code is cumbersome:

```python
INDEX_TYPE = "int64_t" if _IS_WINDOWS else "long"
```

So I made two attempts to unify `INDEX_TYPE`, as either `long` or `int64_t`:

- Using `long` as the index type: https://github.com/pytorch/pytorch/pull/133768
- Using `int64_t` as the index type: https://github.com/pytorch/pytorch/pull/133782

After that, we discussed which type to select as the final solution. The `long` type has a different definition and size across OSs and compilers, so @jansel decided that we should select `int64_t` for all platforms. I therefore continued my work based on https://github.com/pytorch/pytorch/pull/133782.

https://github.com/pytorch/pytorch/pull/133782 still has two issues:

1. `std::min`/`std::max` could not match function instances by argument types. This was fixed and validated in PR: https://github.com/pytorch/pytorch/pull/133812
2. The CUDA `TestMemoryPlanning::test_cpp_wrapper` failure caused by the wrong index type. It is fixed in this PR.

So this PR is the final solution.

### Changes:

**1. Use `int64_t` as the index type on all OSs: `Windows`, `Linux` and `MacOS`.**

**2. Use static_cast<int64_t>(`constant`) to convert constants passed to `div_floor_integer` to its argument type (`int64_t`).**

**3. Update the `parse_arg` function signature to `int64_t`, following the index type.**

**4. Append double L (`LL`) to constants on Windows and MacOS, because their `int64_t` is `long long`.**

**5. Fix the `std::min/std::max` type mismatch with static_cast to `INDEX_TYPE`.**

**6. Fix UTs, including CUDA `TestMemoryPlanning::test_cpp_wrapper` and `test_indexing.py`.**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133892
Approved by: https://github.com/jansel
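The development notes argue that `long` has a different size depending on the OS and compiler, while `int64_t` does not. This can be checked from Python with the standard `ctypes` module. The following is a minimal sketch and not part of the PR; the helper name `c_integer_sizes` is hypothetical.

```python
import ctypes


def c_integer_sizes():
    """Return the byte width of a few C integer types on the current platform."""
    return {
        "long": ctypes.sizeof(ctypes.c_long),
        "long long": ctypes.sizeof(ctypes.c_longlong),
        "int64_t": ctypes.sizeof(ctypes.c_int64),
    }


if __name__ == "__main__":
    for name, size in c_integer_sizes().items():
        print(f"{name}: {size} bytes")
```

On 64-bit Linux and MacOS, `long` is 8 bytes, while on 64-bit Windows it is 4 bytes. `int64_t` is 8 bytes everywhere, which is why it is the safer choice for a single `INDEX_TYPE`.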
@@ -83,7 +83,7 @@ LAYOUT_TO_ATEN = {
 
 _IS_WINDOWS = sys.platform == "win32"
 
-INDEX_TYPE = "int64_t" if _IS_WINDOWS else "long"
+INDEX_TYPE = "int64_t"
 
 GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
 
@@ -222,7 +222,9 @@ class CppCSEVariable(CSEVariable):
 
 class CppPrinter(ExprPrinter):
     def _print_Integer(self, expr):
-        return f"{int(expr)}LL" if _IS_WINDOWS else f"{int(expr)}L"
+        return (
+            f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
+        )
 
     def _print_Where(self, expr):
         c = self.paren(self.doprint(expr.args[0]))
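The hunk above extends the `LL` suffix from Windows to MacOS as well. The commit message gives the reason: `int64_t` is `long long` on those two platforms. The rule can be tried out in isolation with the sketch below. The helper `int_literal` and its `platform` parameter are hypothetical and exist only so the rule can be exercised for any OS; the real method reads `sys.platform` directly.

```python
import sys


def int_literal(value, platform=None):
    """Format an integer as a C++ literal with a platform-dependent suffix.

    Hypothetical stand-in for the patched `_print_Integer`; `platform`
    defaults to the running interpreter's `sys.platform`.
    """
    if platform is None:
        platform = sys.platform
    # Same condition as the patched code: `LL` on MacOS and Windows, `L` elsewhere.
    suffix = "LL" if platform in ("darwin", "win32") else "L"
    return f"{int(value)}{suffix}"


if __name__ == "__main__":
    for plat in ("linux", "darwin", "win32"):
        print(plat, int_literal(1024, plat))
```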
@@ -236,7 +238,7 @@ class CppPrinter(ExprPrinter):
         if div != 1:
             div = self.paren(self.doprint(div))
             if expr.is_integer:
-                x = f"c10::div_floor_integer({x}, {div})"
+                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
             else:
                 x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
         mod = self.paren(self.doprint(mod))
@@ -247,7 +249,7 @@ class CppPrinter(ExprPrinter):
         x = self.paren(self.doprint(x))
         div = self.paren(self.doprint(div))
         if expr.is_integer:
-            return f"c10::div_floor_integer({x}, {div})"
+            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
         return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
 
     def _print_floor(self, expr):
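The two hunks above wrap both operands of `c10::div_floor_integer` in `static_cast<int64_t>`, so a suffixed constant such as `8LL` and an index variable reach the call with the same type. The following standalone sketch shows the strings the patched branches produce. The function `floor_div_expr` and its parameters are hypothetical; in the real printer the operand strings come from `self.paren(self.doprint(...))`.

```python
def floor_div_expr(x, div, is_integer):
    """Build the C++ floor-division call for already-printed operand strings.

    Hypothetical helper mirroring the patched branches of the printer.
    """
    if is_integer:
        # Both operands are cast to int64_t, matching the index type.
        return (
            f"c10::div_floor_integer(static_cast<int64_t>({x}), "
            f"static_cast<int64_t>({div}))"
        )
    # The floating-point branch is unchanged by the PR.
    return (
        f"c10::div_floor_floating(static_cast<double>({x}), "
        f"static_cast<double>({div}))"
    )


if __name__ == "__main__":
    print(floor_div_expr("(x0)", "(8LL)", is_integer=True))
```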
@@ -345,7 +347,7 @@ class CppPrinter(ExprPrinter):
     def _print_Min(self, expr):
         args = [self._print(a) for a in expr.args]
         if len(args) == 2:
-            return f"std::min({args[0]}, {args[1]})"
+            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
         else:
             # Initializer list overload
             il = "{" + ", ".join(args) + "}"
@@ -354,7 +356,7 @@ class CppPrinter(ExprPrinter):
     def _print_Max(self, expr):
         args = [self._print(a) for a in expr.args]
         if len(args) == 2:
-            return f"std::max({args[0]}, {args[1]})"
+            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
         else:
             # Initializer list overload
             il = "{" + ", ".join(args) + "}"
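The two-argument overloads of `std::min` and `std::max` take both arguments as the same `const T&`, so a call mixing, for example, a `long` variable and a `long long` literal fails template deduction. Casting both sides to `INDEX_TYPE` avoids this. Below is a minimal sketch of the patched string generation. The helper `min_max_expr` is hypothetical, and the final `return` of the initializer-list branch is an assumption because the diff does not show that line.

```python
INDEX_TYPE = "int64_t"


def min_max_expr(func, args):
    """Build a `std::min` / `std::max` call from already-printed arguments.

    Hypothetical helper; `func` is "min" or "max".
    """
    if len(args) == 2:
        # Cast both operands so template deduction sees one type.
        lhs, rhs = (f"static_cast<{INDEX_TYPE}>({a})" for a in args)
        return f"std::{func}({lhs}, {rhs})"
    # Initializer list overload.
    # Assumption: the list is passed as the sole argument; the diff only
    # shows the list being built, not how it is returned.
    il = "{" + ", ".join(args) + "}"
    return f"std::{func}({il})"


if __name__ == "__main__":
    print(min_max_expr("min", ["x0", "16LL"]))
    print(min_max_expr("max", ["a", "b", "c"]))
```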