mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 04:54:55 +08:00
[Inductor][CPP] Fix Data Type issue of frexp (#143746)
**Summary** Fix issue: https://github.com/pytorch/pytorch/issues/143729. `frexp` has 1 input but 2 output tensor with different data type, current `deduce_dtype_for_cpp_cse_variable` can't deduce the data type for each output correctly due to missing of output index. In this PR, we set the data type of cse var in the codegen of `frexp` and avoid it being overridden in the following flow. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_frexp ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/143746 Approved by: https://github.com/jgong5
This commit is contained in:
committed by
PyTorch MergeBot
parent
01980cac38
commit
74028cfd0c
@ -211,7 +211,11 @@ class CppCSEVariable(CSEVariable):
|
||||
if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
|
||||
self.is_vec = True
|
||||
# NOTE [Deduce dtype of CppCSEVariable at runtime]
|
||||
self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs)
|
||||
if self.dtype is None:
|
||||
# Take frexp for example: 2 output with different data type.
|
||||
# The output dtype can't be deduced, since we don't know the idx
|
||||
# of return tensor everywhere invoking update_on_args
|
||||
self.dtype = deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs)
|
||||
assert self.dtype is not None
|
||||
|
||||
def _set_dependent_itervars(self, index: sympy.Expr):
|
||||
|
||||
Reference in New Issue
Block a user