Fix identity expansion (#165066)

In some cases, we wrap indexing with `Identity` to prevent expansion from int32 -> int64 range. There are some checks in codegen which intend to check for constants, which did not handle Identity. Update these checks and update Identity so that it recursively prints inputs.

Fix for https://github.com/pytorch/pytorch/issues/164700

Replaces https://github.com/pytorch/pytorch/pull/160190 cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @njriasan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165066
Approved by: https://github.com/njriasan, https://github.com/shunting314, https://github.com/jansel
This commit is contained in:
eellison
2025-10-09 14:46:53 -07:00
committed by PyTorch MergeBot
parent 70925bdf82
commit d272ed4b3e
3 changed files with 54 additions and 2 deletions

View File

@ -1328,6 +1328,10 @@ class Identity(sympy.Function):
# pyrefly: ignore # missing-attribute
return f"Identity({self.args[0]})"
def _sympystr(self, printer):
"""Controls how sympy's StrPrinter prints this"""
return f"({printer.doprint(self.args[0])})"
def _eval_is_real(self):
# pyrefly: ignore # missing-attribute
return self.args[0].is_real