[inductor] Use int64_t as index type for all platforms 4 (#133892)

This PR runs in parallel with https://github.com/pytorch/pytorch/pull/133819 and adds changes that address @jansel's comments.
1. For `torch/_inductor/codegen/cpp_wrapper_cpu.py`, revert to the original code that appends `LL` on MacOS and Windows: bdc14ad89a
2. For `torch/_inductor/codegen/cpp_utils.py`, append `LL` on MacOS and Windows for large constants, and fix its UTs: 3a56b76ce0

------------------------------
Another solution for https://github.com/pytorch/pytorch/pull/133615: use `int64_t` as the index type on all platforms.

### Development notes:
The mentioned PR (https://github.com/pytorch/pytorch/pull/133615) fixes the index type not matching the `parse_arg` argument types. After reviewing it with @jansel, Jason thinks we need to unify `INDEX_TYPE` across all platforms.
The current code is cumbersome:
```python
INDEX_TYPE = "int64_t" if _IS_WINDOWS else "long"
```

So I made two attempts to unify `INDEX_TYPE`, as either `long` or `int64_t`.
Using `long` as the index type: https://github.com/pytorch/pytorch/pull/133768
Using `int64_t` as the index type: https://github.com/pytorch/pytorch/pull/133782

After that, we discussed which type to select as the final solution.
![image](https://github.com/user-attachments/assets/b23fa577-2d40-4bd6-b934-fb7994fe0bb0)

The `long` type has a different definition and size across OSs and compilers, so @jansel decided that we should use `int64_t` on all platforms. I therefore continued my work based on https://github.com/pytorch/pytorch/pull/133782.
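To illustrate why `long` is a poor portable index type, here is a minimal sketch that reports the width of each C integer type on the machine running it. It is not part of this PR; the helper name `integer_widths` is hypothetical, and it relies only on the standard-library `ctypes` module.

```python
import ctypes


def integer_widths():
    """Return the width in bits of a few C integer types on this platform."""
    return {
        # `long` is 64 bits on 64-bit Linux/MacOS but 32 bits on Windows.
        "long": ctypes.sizeof(ctypes.c_long) * 8,
        "long long": ctypes.sizeof(ctypes.c_longlong) * 8,
        # `int64_t` is 64 bits by definition on every platform.
        "int64_t": ctypes.sizeof(ctypes.c_int64) * 8,
    }


if __name__ == "__main__":
    for name, bits in integer_widths().items():
        print(f"{name}: {bits} bits")
```

Running this on different OSs should show `int64_t` staying at 64 bits while `long` varies, which is the motivation for using `int64_t` as the single index type.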

However, https://github.com/pytorch/pytorch/pull/133782 still had two issues:
1. `std::min`/`std::max` could not match a function instance for the given argument types. This was fixed and validated in PR: https://github.com/pytorch/pytorch/pull/133812
2. The CUDA `TestMemoryPlanning::test_cpp_wrapper` failure caused by the wrong index type. It is fixed in this PR.

So this PR is the final solution.

### Changes:
**1. Use `int64_t` as the index type on all OSs: `Windows`, `Linux` and `MacOS`.**
**2. Use `static_cast<int64_t>(constant)` to convert constants passed to `div_floor_integer` to its argument type (`int64_t`).**
**3. Update the `parse_arg` function signature to `int64_t`, which follows the index type.**
**4. Append a double L (`LL`) to constants on Windows and MacOS, because their `int64_t` is `long long`.**
**5. Fix the `std::min/std::max` type mismatch with a `static_cast` to `INDEX_TYPE`.**
**6. Fix UTs, including CUDA `TestMemoryPlanning::test_cpp_wrapper` and `test_indexing.py`.**
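The changes above can be sketched as a few standalone string-emitting helpers. This is a simplified illustration, not the actual `CppPrinter` code: the function names and the `platform` parameter are hypothetical, added so the behavior can be exercised on any OS. The casts matter because `std::min` and `std::max` are templates that need both arguments to deduce to the same type.

```python
import sys

# Single index type on every platform (change 1).
INDEX_TYPE = "int64_t"


def print_integer(value, platform=sys.platform):
    """Emit a C++ integer literal with a platform-dependent suffix (change 4)."""
    suffix = "LL" if platform in ("darwin", "win32") else "L"
    return f"{int(value)}{suffix}"


def print_floor_div(x, div):
    """Emit an integer floor division with both operands cast (change 2)."""
    return (
        f"c10::div_floor_integer(static_cast<int64_t>({x}), "
        f"static_cast<int64_t>({div}))"
    )


def print_min(a, b):
    """Emit a two-argument std::min with both operands cast (change 5)."""
    return (
        f"std::min(static_cast<{INDEX_TYPE}>({a}), "
        f"static_cast<{INDEX_TYPE}>({b}))"
    )


if __name__ == "__main__":
    print(print_integer(4096, platform="linux"))
    print(print_integer(4096, platform="win32"))
    print(print_floor_div("ks0", print_integer(8, platform="linux")))
    print(print_min("ks0", print_integer(64, platform="linux")))
```

Casting both operands explicitly means the emitted C++ does not depend on whether a given literal or variable happens to be `long` or `long long` on the target compiler.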

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133892
Approved by: https://github.com/jansel
Xu Han
2024-08-20 16:54:12 +00:00
committed by PyTorch MergeBot
parent 3caf3baabb
commit fbf3fc2a30
4 changed files with 41 additions and 17 deletions


@@ -83,7 +83,7 @@ LAYOUT_TO_ATEN = {
 _IS_WINDOWS = sys.platform == "win32"
-INDEX_TYPE = "int64_t" if _IS_WINDOWS else "long"
+INDEX_TYPE = "int64_t"
 GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"])
@@ -222,7 +222,9 @@ class CppCSEVariable(CSEVariable):
 class CppPrinter(ExprPrinter):
     def _print_Integer(self, expr):
-        return f"{int(expr)}LL" if _IS_WINDOWS else f"{int(expr)}L"
+        return (
+            f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
+        )
     def _print_Where(self, expr):
         c = self.paren(self.doprint(expr.args[0]))
@@ -236,7 +238,7 @@ class CppPrinter(ExprPrinter):
         if div != 1:
             div = self.paren(self.doprint(div))
             if expr.is_integer:
-                x = f"c10::div_floor_integer({x}, {div})"
+                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
             else:
                 x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
         mod = self.paren(self.doprint(mod))
@@ -247,7 +249,7 @@ class CppPrinter(ExprPrinter):
         x = self.paren(self.doprint(x))
         div = self.paren(self.doprint(div))
         if expr.is_integer:
-            return f"c10::div_floor_integer({x}, {div})"
+            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
         return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
     def _print_floor(self, expr):
@@ -345,7 +347,7 @@ class CppPrinter(ExprPrinter):
     def _print_Min(self, expr):
         args = [self._print(a) for a in expr.args]
         if len(args) == 2:
-            return f"std::min({args[0]}, {args[1]})"
+            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
         else:
             # Initializer list overload
             il = "{" + ", ".join(args) + "}"
@@ -354,7 +356,7 @@ class CppPrinter(ExprPrinter):
     def _print_Max(self, expr):
         args = [self._print(a) for a in expr.args]
         if len(args) == 2:
-            return f"std::max({args[0]}, {args[1]})"
+            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
         else:
             # Initializer list overload
             il = "{" + ", ".join(args) + "}"