pow: fix meta function output argument dtype check. (#140287)

Tracking issue: #138399

This PR changes the `pow` C++ implementation, making its C++ meta kernel consistent with
its Python ref implementation. The following example shows the inconsistency between the
two:

```python
def run(device):
    S = (5,)
    a = torch.rand(S, device=device, dtype=torch.float32)
    b = 2
    out = torch.empty(S, device=device, dtype=torch.float64)
    return torch.pow(a, b, out=out)

>>> run("cpu")
Traceback (most recent call last):
  File "test.py", line 34, in run
    return torch.pow(a, b, out=out)
RuntimeError: Found dtype Double but expected Float

>>> run("meta")
tensor(..., device='meta', size=(5,), dtype=torch.float64)
```

**~Update:~**

~Note that this happens only for `pow.Tensor_Scalar` overloads. Therefore, this PR needed
further 2 modifications:~

- ~Split the `pow` ref implementation, making `pow.Tensor_Scalar` error on mismatching
output dtypes~
- ~Create a dispatch for `pow` when `_refs.pow()` is called~

**Update:**

Changing the `TensorIteratorConfig` for `pow.Tensor_Scalar` was easier and,
after the discussion below, more correct. The solution was to change the
`TensorIteratorBase::build_output_borrowing_argument_owning_unary_op` function,
setting:

- `cast_common_dtype_to_outputs`; and
- `enforce_safe_casting_to_output`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140287
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi
2024-11-20 00:14:30 +00:00
committed by PyTorch MergeBot
parent a9e54f64ee
commit 446ea2aea5
5 changed files with 42 additions and 19 deletions

View File

@ -206,7 +206,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("softmax"),
xfail("sort"),
xfail("sparse.sampled_addmm"),
xfail("square"),
xfail("squeeze_copy"),
xfail("t_copy"),
xfail("take"),