mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Check for boolean values as argument on pow function. (#114133)
Hello everyone! 😄 Also @lezcano , nice to meet you! :) Sorry if I miss anything, this is my first time around here. 🙃 This PR basically makes the same behaviour for cuda when using `torch.pow`. Basically Python considers True as 1 and False as 0. I just added this check into `pow` function. From what I understood, when I do `.equal` for `Scalar` that is boolean, I'm sure that types match so that won't cause more trouble. I know that the issue suggest to disable this case but that could be a little more complicated, in my humble opinion. And that can create some compability problems too, I guess. My argument is that code below is correct for native language, so I guess it does makes sense sending booleans as Scalar. ``` $ x = True $ x + x 2 ``` This was my first test: ``` Python 3.12.0 | packaged by Anaconda, Inc. | (main, Oct 2 2023, 17:29:18) [GCC 11.2.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import torch >>> torch.pow(torch.tensor([1, 2], device='cuda'), True) tensor([1, 2], device='cuda:0') >>> torch.pow(torch.tensor([1, 2]), True) tensor([1, 2]) >>> torch.pow(torch.tensor([1, 2]), False) tensor([1, 1]) >>> torch.pow(torch.tensor([1, 2], device='cuda'), False) tensor([1, 1], device='cuda:0') ``` I've run `test_torch.py` and got following results, so my guess is that I didn't break anything. I was just looking for a test that uses linear regression, as suggested. ``` Ran 1619 tests in 52.363s OK (skipped=111) [TORCH_VITAL] Dataloader.enabled True [TORCH_VITAL] Dataloader.basic_unit_test TEST_VALUE_STRING [TORCH_VITAL] CUDA.used true ``` (I can paste whole log, if necessary) If this is a bad idea overall, dont worry about it. It's not a big deal, it's actually a two line change 😅 so can we talk of how do things in a different strategy. For the record I've signed the agreement already. And I didn't run linter because it's not working 😞 . Looks like PyYaml 6.0 is broken and there's a 6.0.1 fix already but I have no idea how to update that 😅 Fixes #113198 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114133 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
aca6446a6e
commit
5f504d1de7
@ -50,9 +50,9 @@ TORCH_IMPL_FUNC(pow_Tensor_Tensor_out) (const Tensor& base, const Tensor& exp, c
|
|||||||
}
|
}
|
||||||
|
|
||||||
TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
|
TORCH_IMPL_FUNC(pow_Tensor_Scalar_out) (const Tensor& base, const Scalar& exp, const Tensor& out) {
|
||||||
if (exp.equal(0.0)) {
|
if (exp.equal(0.0) || exp.equal(false)) {
|
||||||
out.fill_(1);
|
out.fill_(1);
|
||||||
} else if (exp.equal(1.0)) {
|
} else if (exp.equal(1.0) || exp.equal(true) ) {
|
||||||
out.copy_(base);
|
out.copy_(base);
|
||||||
} else {
|
} else {
|
||||||
pow_tensor_scalar_stub(device_type(), *this, exp);
|
pow_tensor_scalar_stub(device_type(), *this, exp);
|
||||||
|
@ -1345,7 +1345,7 @@ class TestBinaryUfuncs(TestCase):
|
|||||||
(100, 100), low=1, high=range_high, dtype=dtype, device=device
|
(100, 100), low=1, high=range_high, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]
|
exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
|
||||||
complex_exponents = [
|
complex_exponents = [
|
||||||
-2.5j,
|
-2.5j,
|
||||||
-1.0j,
|
-1.0j,
|
||||||
|
Reference in New Issue
Block a user