mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
pow
: fix meta function output argument dtype check. (#140287)
Tracking issue: #138399 This PR changes the `pow` C++ implementation, making its C++ meta kernel consistent with its Python ref implementation. The following example shows the inconsistency between the two: ```python def run(device): S = (5,) a = torch.rand(S, device=device, dtype=torch.float32) b = 2 out = torch.empty(S, device=device, dtype=torch.float64) return torch.pow(a, b, out=out) >>> run("cpu") Traceback (most recent call last): File "test.py", line 34, in run return torch.pow(a, b, out=out) RuntimeError: Found dtype Double but expected Float >>> run("meta") tensor(..., device='meta', size=(5,), dtype=torch.float64) ``` **~Update:~** ~Note that this happens only for `pow.Tensor_Scalar` overloads. Therefore, this PR needed further 2 modifications:~ - ~Split the `pow` ref implementation, making `pow.Tensor_Scalar` error on mismatching output dtypes~ - ~Create a dispatch for `pow` when `_refs.pow()` is called~ **Update:** Changing the `TensorIteratorConfig` for `pow.Tensor_Scalar` was easier and, after the discussion below, more correct. The solution was to change the `TensorIteratorBase::build_output_borrowing_argument_owning_unary_op` function, setting: - `cast_common_dtype_to_outputs`; and - `enforce_safe_casting_to_output`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140287 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a9e54f64ee
commit
446ea2aea5
@ -1483,7 +1483,7 @@ class TestBinaryUfuncs(TestCase):
|
||||
else:
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Found dtype \\w+ but expected \\w+",
|
||||
r"result type \w+ can't be cast to the desired output type \w+",
|
||||
lambda: actual.pow_(exponent),
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user