[Inductor][CPP] Fix Data Type issue of frexp (#143746)

**Summary**
Fix issue: https://github.com/pytorch/pytorch/issues/143729. `frexp` has 1 input but 2 output tensor with different data type, current `deduce_dtype_for_cpp_cse_variable` can't deduce the data type for each output correctly due to missing of output index. In this PR, we set the data type of cse var in the codegen of `frexp` and avoid it being overridden in the following flow.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_frexp
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143746
Approved by: https://github.com/jgong5
This commit is contained in:
leslie-fang-intel
2024-12-23 04:10:58 -08:00
committed by PyTorch MergeBot
parent 01980cac38
commit 74028cfd0c
3 changed files with 22 additions and 5 deletions

View File

@ -211,7 +211,11 @@ class CppCSEVariable(CSEVariable):
if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
self.is_vec = True
# NOTE [Deduce dtype of CppCSEVariable at runtime]
self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs)
if self.dtype is None:
# Take frexp for example: 2 output with different data type.
# The output dtype can't be deduced, since we don't know the idx
# of return tensor everywhere invoking update_on_args
self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs)
assert self.dtype is not None
def _set_dependent_itervars(self, index: sympy.Expr):