Files
pytorch/test/test_numpy_interop.py
Khushi Agrawal 3ded7b1da3 Replace get_all_ type macros with the ATen dispatch macros. (#71561)
Summary:
Hi, Team!
This PR is motivated by https://github.com/pytorch/pytorch/pull/71153#discussion_r782446738. It aims to replace the `get_all` type macros with the ATen dispatch macros.

The files it touches are listed below. (Thanks, Lezcano, for the idea!)

<details>
<summary>

`test/test_autograd.py`</summary>

<p>

```python
43:from torch.testing._internal.common_dtype import get_all_dtypes
8506:        floating_dt = [dt for dt in get_all_dtypes() if dt.is_floating_point]
```

</p>
</details>

<details>
<summary>

`test/test_binary_ufuncs.py`</summary>

<p>

```python
26:    all_types_and_complex_and, integral_types_and, get_all_dtypes, get_all_int_dtypes, get_all_math_dtypes,
27:    get_all_complex_dtypes, get_all_fp_dtypes,
935:    dtypes(*get_all_dtypes(include_bool=False, include_complex=False))
1035:    dtypes(*get_all_dtypes(
1488:    dtypes(*(get_all_dtypes(include_bool=False, include_bfloat16=False)))
1879:    dtypes(*product(get_all_dtypes(include_complex=False), get_all_dtypes(include_complex=False)))
1887:    dtypes(*(get_all_int_dtypes() + [torch.bool]))
1913:    dtypes(*(get_all_fp_dtypes()))
1941:    dtypes(*(get_all_fp_dtypes()))
1977:    dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
2019:    dtypes(*product(get_all_fp_dtypes(), get_all_fp_dtypes()))
2048:    dtypes(*get_all_dtypes())
2110:    dtypes(*product(get_all_dtypes(include_complex=False),
2111:                     get_all_dtypes(include_complex=False)))
2128:            types = [torch.bool, torch.bfloat16] + get_all_int_dtypes()
2173:        if dtypes[1] in get_all_fp_dtypes():
2178:    dtypes(*product(get_all_fp_dtypes(),
2179:                     get_all_fp_dtypes()))
2260:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128})
2261:    dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128})
2273:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128})
2274:    dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128})
2307:    dtypes(*get_all_math_dtypes('cpu'))
2319:    dtypes(*get_all_fp_dtypes(include_bfloat16=False))
2331:    dtypes(*get_all_int_dtypes())
2356:    dtypes(*get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
2393:        if dtype in get_all_int_dtypes():
2614:    dtypes(*get_all_dtypes())
2624:    dtypes(*tuple(itertools.combinations_with_replacement(get_all_dtypes(), 2)))
2806:    dtypes(*list(product(get_all_dtypes(include_complex=False),
2807:                          get_all_dtypes(include_complex=False))))
2866:    dtypes(*list(product(get_all_complex_dtypes(),
2867:                          get_all_complex_dtypes())))
2902:    dtypes(*product(get_all_dtypes(), get_all_dtypes()))
2906:    dtypes(*product(get_all_dtypes(), get_all_dtypes()))
2910:    dtypes(*product(get_all_dtypes(), get_all_dtypes()))
3019:        dtypes = [torch.float, torch.double] + get_all_complex_dtypes()
3221:    dtypes(*get_all_dtypes(include_complex=False))
3407:    dtypes(*list(product(get_all_dtypes(include_bool=False),
3408:                          get_all_dtypes(include_bool=False))))
3504:    dtypes(*product(get_all_dtypes(include_complex=False, include_bfloat16=False),
3505:                     get_all_dtypes(include_complex=False, include_bfloat16=False)))
3516:            if x.dtype in get_all_int_dtypes() + [torch.bool]:
3643:    dtypes(*product(get_all_dtypes(include_complex=False,
3645:                     get_all_dtypes(include_complex=False,
```

</p>
</details>

<details>
<summary>

`test/test_complex.py`</summary>

<p>

```python
6:from torch.testing._internal.common_dtype import get_all_complex_dtypes
11:    dtypes(*get_all_complex_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_foreach.py`</summary>

<p>

```python
18:    get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes,
142:            if dtype in get_all_int_dtypes():
179:            disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
201:            disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
205:                disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool]
211:                disable_fastpath |= dtype not in get_all_complex_dtypes()
241:                bool_int_div = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
246:                    disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool]
248:                    disable_fastpath |= dtype not in get_all_complex_dtypes()
250:                    disable_fastpath |= True and dtype not in get_all_complex_dtypes()
307:        disable_fastpath = dtype in get_all_int_dtypes() + [torch.bool]
365:        if opinfo.name == "_foreach_abs" and dtype in get_all_complex_dtypes():
376:    ops(foreach_unary_op_db, dtypes=get_all_dtypes())
393:         dtypes=get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
401:    ops(foreach_minmax_op_db, dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True))
426:            if ord in (1, 2) and dtype in torch.testing.get_all_fp_dtypes():
439:    dtypes(*get_all_dtypes())
449:    ops(foreach_binary_op_db, dtypes=get_all_dtypes())
481:    ops(foreach_binary_op_db, dtypes=get_all_dtypes())
536:            if dtype in get_all_int_dtypes() + [torch.bool] and foreach_op == torch._foreach_div:
545:    ops(foreach_binary_op_db, dtypes=get_all_dtypes())
637:    ops(foreach_pointwise_op_db, allowed_dtypes=get_all_fp_dtypes(include_half=False, include_bfloat16=False))
```

</p>
</details>

<details>
<summary>

`test/test_linalg.py`</summary>

<p>

```python
29:    all_types, floating_types, floating_and_complex_types, get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes,
30:    get_all_fp_dtypes,
111:    dtypes(*(get_all_dtypes()))
794:        float_and_complex_dtypes = get_all_fp_dtypes() + get_all_complex_dtypes()
807:    dtypes(*(get_all_int_dtypes()))
828:    dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes()))
841:        if dtype in get_all_complex_dtypes():
844:    dtypes(*itertools.product(get_all_dtypes(),
845:                               get_all_dtypes()))
855:        for dtypes0, dtypes1, dtypes2 in product(get_all_dtypes(), repeat=3):
5607:                  *get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater)))
5608:    dtypes(*(set(get_all_dtypes()) - {torch.half, torch.bool}))
5644:    dtypes(*(get_all_complex_dtypes() + get_all_fp_dtypes()))
6255:    dtypesIfCUDA(*get_all_complex_dtypes(),
6256:                  *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)),
6292:    dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
6323:    dtypesIfCUDA(*get_all_complex_dtypes(),
6324:                  *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
6325:    dtypes(*get_all_complex_dtypes(), *get_all_fp_dtypes())
6358:    dtypesIfCUDA(*([torch.float, torch.double] + get_all_complex_dtypes()))
6556:    dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes())
6668:    dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes())
6741:    dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_nn.py`</summary>

<p>

```python
37:from torch.testing._internal.common_dtype import integral_types, get_all_fp_dtypes, get_all_math_dtypes
50:    onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types, \
8862:                for device in get_all_device_types():
9629:            for dt1 in get_all_math_dtypes(device):
9630:                for dt2 in get_all_math_dtypes(device):
9631:                    for dt3 in get_all_math_dtypes(device):
9648:            for input_dtype in get_all_math_dtypes(device):
9664:            for input_dtype in get_all_math_dtypes(device):
13015:    dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
13034:    dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
13159:    dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
17400:    dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
17768:    dtypesIfCUDA(*get_all_fp_dtypes())
17773:    dtypesIfCUDA(*get_all_fp_dtypes())
17778:    dtypesIfCUDA(*get_all_fp_dtypes())
17783:    dtypesIfCUDA(*get_all_fp_dtypes())
17788:    dtypesIfCUDA(*get_all_fp_dtypes())
17793:    dtypesIfCUDA(*get_all_fp_dtypes())
17798:    dtypesIfCUDA(*get_all_fp_dtypes())
17963:    dtypesIfCUDA(*get_all_fp_dtypes())
17977:    dtypesIfCUDA(*get_all_fp_dtypes())
18684:    def test_cross_entropy_loss_prob_target_all_reductions(self, device):
```

</p>
</details>

<details>
<summary>

`test/test_numpy_interop.py`</summary>

<p>

```python
12:from torch.testing._internal.common_dtype import get_all_dtypes
399:    dtypes(*get_all_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_ops.py`</summary>

<p>

```python
12:from torch.testing._internal.common_dtype import floating_and_complex_types_and, get_all_dtypes
86:        for dtype in get_all_dtypes():
```

</p>
</details>

<details>
<summary>

`test/test_reductions.py`</summary>

<p>

```python
16:    get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes,
360:         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
366:         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
394:         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
750:        for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]:
1404:    dtypes(*get_all_dtypes(include_bool=False, include_complex=False))
1457:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1458:              get_all_complex_dtypes()))
1465:            return dtype in get_all_int_dtypes()
1494:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)))
1501:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)))
1507:    dtypes(*(get_all_complex_dtypes()))
1514:        dtypes = list(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))
1523:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)))
1531:        if dtype in get_all_fp_dtypes():
1608:    dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False,
1837:    dtypes(*get_all_dtypes(include_bool=False, include_complex=False))
1855:    dtypes(*(set(get_all_dtypes(include_bool=False, include_complex=False)) - {torch.uint8}))
3219:        for dtype in get_all_dtypes(include_half=True, include_bfloat16=False,
```

</p>
</details>

<details>
<summary>

`test/test_serialization.py`</summary>

<p>

```python
26:from torch.testing._internal.common_dtype import get_all_dtypes
586:        for device, dtype in product(devices, get_all_dtypes()):
589:            for other_dtype in get_all_dtypes():
```

</p>
</details>

<details>
<summary>

`test/test_shape_ops.py`</summary>

<p>

```python
18:from torch.testing._internal.common_dtype import get_all_dtypes
230:    dtypes(*get_all_dtypes(include_complex=False, include_bool=False, include_half=False,
232:    dtypesIfCUDA(*get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
344:    dtypes(*get_all_dtypes())
443:    dtypes(*get_all_dtypes())
461:    dtypes(*get_all_dtypes())
570:    dtypes(*get_all_dtypes(include_complex=False))
```

</p>
</details>

<details>
<summary>

`test/test_sort_and_select.py`</summary>

<p>

```python
12:    all_types, all_types_and, floating_types_and, get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes,
136:    dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128})
231:    dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128})
296:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
647:    dtypesIfCUDA(*get_all_fp_dtypes())
678:    dtypesIfCUDA(*(get_all_dtypes(include_complex=False,
682:    dtypes(*(get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False)))
739:    dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128})
740:    dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
799:    dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128})
800:    dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
```

</p>
</details>

<details>
<summary>

`test/test_sparse.py`</summary>

<p>

```python
20:from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes
29:    floating_and_complex_types, floating_and_complex_types_and, get_all_dtypes, get_all_int_dtypes,
1963:            return dtype in get_all_int_dtypes()
1994:    dtypes(*get_all_dtypes(include_bool=False, include_half=False,
2103:            return dtype in get_all_int_dtypes()
2138:    dtypes(*get_all_dtypes(include_bool=False, include_half=False,
2626:        all_sparse_dtypes = get_all_dtypes(include_complex=True)
2633:        all_sparse_dtypes = get_all_dtypes(include_complex=True)
3230:    dtypes(*get_all_complex_dtypes(),
3231:            *get_all_fp_dtypes(include_half=False, include_bfloat16=False))
3234:                  *get_all_fp_dtypes(
```

</p>
</details>

<details>
<summary>

`test/test_sparse_csr.py`</summary>

<p>

```python
7:from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes, floating_and_complex_types, make_tensor
17:from torch.testing._internal.common_dtype import floating_types, get_all_dtypes
120:    dtypes(*get_all_dtypes())
133:    dtypes(*get_all_dtypes())
150:    dtypes(*get_all_dtypes())
180:    dtypes(*get_all_dtypes())
201:    dtypes(*get_all_dtypes())
210:    dtypes(*get_all_dtypes())
225:    dtypes(*get_all_dtypes())
244:    dtypes(*get_all_dtypes())
263:    dtypes(*get_all_dtypes())
285:    dtypes(*get_all_dtypes())
411:    dtypes(*get_all_dtypes())
482:    dtypes(*get_all_dtypes())
502:    dtypes(*get_all_dtypes())
562:    dtypes(*get_all_dtypes())
588:    dtypesIfCUDA(*get_all_complex_dtypes(),
589:                  *get_all_fp_dtypes(include_half=SM53OrLater, include_bfloat16=SM80OrLater))
745:    dtypesIfCUDA(*get_all_complex_dtypes(),
746:                  *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
765:    dtypesIfCUDA(*get_all_complex_dtypes(),
766:                  *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
801:                  *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater,
841:                  *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater,
1182:    dtypes(*get_all_dtypes())
1276:    dtypes(*get_all_dtypes(include_bool=False, include_half=False, include_bfloat16=False))
1286:    dtypes(*get_all_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_tensor_creation_ops.py`</summary>

<p>

```python
21:    onlyCUDA, skipCPUIf, dtypesIfCUDA, skipMeta, get_all_device_types)
23:    get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
150:        for dt in get_all_dtypes():
160:        for dt in get_all_dtypes():
314:        dtypes = [dtype for dtype in get_all_dtypes() if dtype != torch.bfloat16]
1012:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1013:              get_all_complex_dtypes()))
1032:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1033:              get_all_complex_dtypes()))
1050:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1051:              get_all_complex_dtypes()))
1745:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1779:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1868:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1926:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1954:            do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device)
1956:            do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, None)
1957:            do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device)
2538:        for device in get_all_device_types():
2645:        for dtype in get_all_dtypes():
2678:    dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False) +
2679:              get_all_complex_dtypes()))
2716:    dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False))
2827:            for dt in get_all_dtypes():
2913:    dtypes(*get_all_dtypes(include_bool=False, include_half=False))
2914:    dtypesIfCUDA(*get_all_dtypes(include_bool=False, include_half=True))
3028:    dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes()))
3033:    dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes()))
3074:    dtypes(*get_all_dtypes(include_bool=False, include_half=False, include_complex=False))
3075:    dtypesIfCUDA(*((get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16])
3077:                    else get_all_dtypes(include_bool=False, include_half=True, include_complex=False)))
3873:    dtypes(*get_all_dtypes())
3884:    dtypes(*get_all_dtypes(include_bool=False))
3916:            for other in get_all_dtypes():
3922:    dtypes(*get_all_dtypes())
3932:    dtypes(*get_all_dtypes(include_bool=False))
3955:    dtypes(*get_all_dtypes(include_bool=False))
3961:    dtypes(*get_all_dtypes(include_bool=False))
3965:    dtypes(*get_all_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_testing.py`</summary>

<p>

```python
25:from torch.testing._internal.common_dtype import get_all_dtypes
31:    dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False,
```

</p>
</details>

<details>
<summary>

`test/test_torch.py`</summary>

<p>

```python
51:    expectedAlertNondeterministic, get_all_device_types, skipXLA)
57:    get_all_fp_dtypes, get_all_int_dtypes, get_all_math_dtypes, get_all_dtypes, get_all_complex_dtypes
296:            for d in get_all_device_types():
323:            for device in get_all_device_types():
324:                for dt1 in get_all_dtypes():
325:                    for dt2 in get_all_dtypes():
343:            all_dtypes = get_all_dtypes()
350:            all_dtypes = get_all_dtypes()
781:            for dtype in get_all_dtypes():
986:            for device in get_all_device_types():
1017:            for device in get_all_device_types():
1018:                for dtype in get_all_math_dtypes(device):
2792:            for device in get_all_device_types():
3186:    dtypes(*get_all_dtypes())
3195:        for error_dtype in get_all_dtypes():
3203:    dtypes(*get_all_dtypes())
3212:        for error_dtype in get_all_dtypes():
4539:    dtypes(*get_all_fp_dtypes())
4545:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
4577:    dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False))
4578:    dtypesIfCPU(*(get_all_fp_dtypes(include_half=False, include_bfloat16=True)))
4579:    dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False)))
4599:    dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
4600:    dtypesIfCPU(*(get_all_dtypes(include_half=False, include_bfloat16=False, include_complex=False)))
4601:    dtypesIfCUDA(*(get_all_dtypes(include_bfloat16=False, include_complex=False)))
4613:        for p_dtype in get_all_fp_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False):
4628:    dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
4629:    dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False)))
4640:    dtypes(*get_all_fp_dtypes())
4723:    dtypes(*get_all_fp_dtypes())
4735:    dtypes(*get_all_fp_dtypes(include_bfloat16=False))
4736:    dtypesIfCUDA(*get_all_fp_dtypes())
4747:    dtypes(*get_all_fp_dtypes())
4761:    dtypes(*get_all_fp_dtypes())
4771:    dtypes(*get_all_fp_dtypes())
4792:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
5302:    dtypes(*get_all_dtypes(include_bfloat16=False))
5322:    dtypes(*get_all_dtypes(include_half=False, include_bfloat16=False))
5323:    dtypesIfCPU(*get_all_dtypes(include_bfloat16=False))
5324:    dtypesIfCUDA(*get_all_dtypes(include_bfloat16=False))
5591:        for dt in get_all_dtypes():
5611:        for dt in get_all_dtypes():
5678:        for dt in get_all_dtypes():
5696:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
5697:    dtypes(*set(get_all_math_dtypes('cpu')))
5746:    dtypes(*get_all_dtypes())
5780:    dtypes(*get_all_dtypes())
5885:    dtypes(*get_all_dtypes())
5902:    dtypes(*get_all_dtypes())
5945:    dtypes(*get_all_dtypes())
5979:    dtypes(*get_all_dtypes(include_bool=False))
6049:    dtypes(*get_all_dtypes(include_bool=False))
6092:    dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) +
6093:              get_all_complex_dtypes()))
6094:    dtypesIfCPU(*get_all_dtypes())
6095:    dtypesIfCUDA(*get_all_dtypes())
6122:    dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) +
6123:              get_all_complex_dtypes()))
6124:    dtypesIfCPU(*get_all_dtypes())
6125:    dtypesIfCUDA(*get_all_dtypes())
6163:    dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) +
6164:              get_all_complex_dtypes()))
6165:    dtypesIfCPU(*get_all_dtypes())
6166:    dtypesIfCUDA(*get_all_dtypes())
6190:    dtypes(*(get_all_complex_dtypes() +
6191:              get_all_int_dtypes()))
6238:    dtypes(*get_all_dtypes())
6323:    dtypes(*get_all_dtypes())
6389:    dtypes(*product(get_all_dtypes(), (torch.uint8, torch.bool)))
6699:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
6700:    dtypes(*set(get_all_math_dtypes('cpu')))
7452:    dtypes(*get_all_dtypes(include_bool=False))
7461:    dtypes(*get_all_dtypes(include_bool=False))
7477:    dtypes(*get_all_dtypes(include_bool=False))
7496:    dtypes(*get_all_dtypes(include_bool=False))
7538:    dtypes(*get_all_dtypes(include_bool=False))
8162:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() +
8163:              get_all_complex_dtypes()))
8175:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() +
8176:              get_all_complex_dtypes()))
```

</p>
</details>

<details>
<summary>

`test/test_type_promotion.py`</summary>

<p>

```python
14:    get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes
187:        for dtype in get_all_dtypes():
262:        dtypes1 = get_all_math_dtypes('cuda')
263:        dtypes2 = get_all_math_dtypes(device)
339:    dtypes(*itertools.product(get_all_dtypes(), get_all_dtypes()))
468:            for dt1 in get_all_math_dtypes(device):
469:                for dt2 in get_all_math_dtypes(device):
519:            for dt1 in get_all_math_dtypes(device):
520:                for dt2 in get_all_math_dtypes(device):
528:        for dt in get_all_math_dtypes(device):
561:        for dtype in get_all_dtypes():
766:                                          dtypes=get_all_math_dtypes(device))
771:                                          dtypes=get_all_math_dtypes(device))
782:                                          dtypes=get_all_math_dtypes(device))
879:        dtypes = get_all_dtypes(include_bfloat16=False)
898:        dtypes = get_all_dtypes(include_bfloat16=False, include_bool=False)
965:    dtypesIfCUDA(*itertools.product(get_all_dtypes(include_bfloat16=False, include_complex=False),
966:                                     get_all_dtypes(include_bfloat16=False, include_complex=False)))
967:    dtypes(*itertools.product(get_all_dtypes(include_half=False, include_bfloat16=False,
969:                               get_all_dtypes(include_half=False, include_bfloat16=False,
976:            return dtype in get_all_int_dtypes() + [torch.bool]
979:            return dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False)
```

</p>
</details>

<details>
<summary>

`test/test_unary_ufuncs.py`</summary>

<p>

```python
24:    floating_types_and, all_types_and_complex_and, floating_and_complex_types_and, get_all_dtypes, get_all_math_dtypes,
25:    get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
517:    dtypes(*(get_all_int_dtypes() + [torch.bool] +
518:              get_all_fp_dtypes(include_bfloat16=False)))
596:    dtypes(*get_all_fp_dtypes(include_half=True, include_bfloat16=False))
611:        invalid_input_dtypes = get_all_int_dtypes() + \
612:            get_all_complex_dtypes() + \
619:        for dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False):
1048:    dtypes(*get_all_math_dtypes('cpu'))
1182:    dtypesIfCUDA(*get_all_fp_dtypes())
1190:    dtypesIfCUDA(*get_all_fp_dtypes())
1205:    dtypesIfCUDA(*get_all_fp_dtypes())
1215:    dtypesIfCUDA(*get_all_fp_dtypes())
1307:    dtypes(*(get_all_dtypes(include_bool=False)))
1349:    dtypes(*(get_all_fp_dtypes(include_half=False) +
1350:              get_all_complex_dtypes()))
1351:    dtypesIfCUDA(*(get_all_fp_dtypes(include_half=True) +
1352:                    get_all_complex_dtypes()))
```

</p>
</details>

<details>
<summary>

`test/test_view_ops.py`</summary>

<p>

```python
19:    get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
124:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
131:    dtypes(*get_all_dtypes(include_bfloat16=False))
213:            for view_dtype in [*get_all_fp_dtypes(), *get_all_complex_dtypes()]:
220:    dtypes(*get_all_dtypes())
224:        for view_dtype in get_all_dtypes():
305:    dtypes(*get_all_complex_dtypes(include_complex32=True))
343:    dtypes(*get_all_dtypes())
354:    dtypes(*get_all_dtypes())
364:    dtypes(*get_all_dtypes())
374:    dtypes(*get_all_dtypes())
384:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
395:    dtypes(*get_all_complex_dtypes())
426:    dtypes(*get_all_complex_dtypes())
451:    dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
1263:    dtypes(*(torch.testing.get_all_dtypes()))
1279:    dtypes(*(torch.testing.get_all_dtypes()))
1405:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1406:              get_all_complex_dtypes()))
1471:    dtypes(*get_all_dtypes(include_bfloat16=False))
1574:    dtypes(*get_all_dtypes())
1601:    dtypes(*get_all_dtypes(include_bfloat16=False))
1632:    dtypes(*get_all_dtypes(include_bfloat16=False))
1711:        for dt in get_all_dtypes():
1717:        for dt in get_all_dtypes():
1724:        for dt in get_all_dtypes():
```

</p>
</details>

I'm looking forward to your viewpoints. Thanks :)

cc: mruberry kshitij12345 anjali411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71561

Reviewed By: samdow

Differential Revision: D34856571

Pulled By: mruberry

fbshipit-source-id: 0dca038bcad5cf69906245c496d2e61ac3876335
(cherry picked from commit b058f67b4313143efa714ab105f36e74083131b9)
2022-03-15 20:31:41 +00:00

432 lines
17 KiB
Python

# Owner(s): ["module: numpy"]
import torch
import numpy as np
from itertools import product
from torch.testing._internal.common_utils import \
(TestCase, run_tests)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
from torch.testing._internal.common_dtype import all_types_and_complex_and
# For testing handling NumPy objects and sending tensors to / accepting
# arrays from NumPy.
class TestNumPyInterop(TestCase):
# Note: the warning this tests for only appears once per program, so
# other instances of this warning should be addressed to avoid
# the tests depending on the order in which they're run.
@onlyCPU
def test_numpy_non_writeable(self, device):
    # Wrapping a read-only ndarray with torch.from_numpy is expected to
    # emit a UserWarning.
    readonly = np.zeros(5)
    readonly.setflags(write=False)
    with self.assertWarns(UserWarning):
        torch.from_numpy(readonly)
@onlyCPU
def test_numpy_unresizable(self, device) -> None:
    # Direction 1: ndarray -> tensor. Once a tensor has been created from
    # the array, resizing the array must raise ValueError.
    source_array = np.zeros((2, 2))
    # Never read again; presumably kept bound so the tensor created from
    # the array is still alive while the resize is attempted.
    shared_tensor = torch.from_numpy(source_array)
    with self.assertRaises(ValueError):
        source_array.resize((5, 5))

    # Direction 2: tensor -> ndarray. After .numpy() neither side may be
    # resized: the tensor raises RuntimeError, the array raises ValueError.
    source_tensor = torch.randn(5, 5)
    shared_array = source_tensor.numpy()
    self.assertRaises(RuntimeError, source_tensor.resize_, 10, 10)
    with self.assertRaises(ValueError):
        shared_array.resize((10, 10))
@onlyCPU
def test_to_numpy(self, device) -> None:
    # Checks Tensor.numpy() element-by-element for several dtypes and
    # memory layouts (1D, 2D, storage offsets, transposed, strided with
    # holes), and that the returned array is writeable and aliases the
    # tensor's memory.

    def get_castable_tensor(shape, dtype):
        # Build a random tensor of `dtype` whose values were sampled in a
        # wider type and clamped to a range that `dtype` can represent.
        if dtype.is_floating_point:
            dtype_info = torch.finfo(dtype)
            # can't directly use min and max, because for double, max - min
            # is greater than double range and sampling always gives inf.
            low = max(dtype_info.min, -1e10)
            high = min(dtype_info.max, 1e10)
            t = torch.empty(shape, dtype=torch.float64).uniform_(low, high)
        else:
            # can't directly use min and max, because for int64_t, max - min
            # is greater than int64_t range and triggers UB.
            low = max(torch.iinfo(dtype).min, int(-1e10))
            high = min(torch.iinfo(dtype).max, int(1e10))
            t = torch.empty(shape, dtype=torch.int64).random_(low, high)
        return t.to(dtype)

    dtypes = [
        torch.uint8,
        torch.int8,
        torch.short,
        torch.int,
        torch.half,
        torch.float,
        torch.double,
        torch.long,
    ]

    for dtp in dtypes:
        # 1D
        sz = 10
        x = get_castable_tensor(sz, dtp)
        y = x.numpy()
        for i in range(sz):
            self.assertEqual(x[i], y[i])

        # 1D > 0 storage offset
        xm = get_castable_tensor(sz * 2, dtp)
        x = xm.narrow(0, sz - 1, sz)
        self.assertTrue(x.storage_offset() > 0)
        y = x.numpy()
        for i in range(sz):
            self.assertEqual(x[i], y[i])

        def check2d(x, y):
            # Compares an (sz1, sz2) tensor/array pair element-wise; sz1
            # and sz2 are read from the enclosing scope at call time.
            for i in range(sz1):
                for j in range(sz2):
                    self.assertEqual(x[i][j], y[i][j])

        # empty
        x = torch.tensor([]).to(dtp)
        y = x.numpy()
        self.assertEqual(y.size, 0)

        # contiguous 2D
        sz1 = 3
        sz2 = 5
        x = get_castable_tensor((sz1, sz2), dtp)
        y = x.numpy()
        check2d(x, y)
        self.assertTrue(y.flags['C_CONTIGUOUS'])

        # with storage offset
        xm = get_castable_tensor((sz1 * 2, sz2), dtp)
        x = xm.narrow(0, sz1 - 1, sz1)
        y = x.numpy()
        self.assertTrue(x.storage_offset() > 0)
        check2d(x, y)
        self.assertTrue(y.flags['C_CONTIGUOUS'])

        # non-contiguous 2D
        x = get_castable_tensor((sz2, sz1), dtp).t()
        y = x.numpy()
        check2d(x, y)
        self.assertFalse(y.flags['C_CONTIGUOUS'])

        # with storage offset
        xm = get_castable_tensor((sz2 * 2, sz1), dtp)
        x = xm.narrow(0, sz2 - 1, sz2).t()
        y = x.numpy()
        self.assertTrue(x.storage_offset() > 0)
        check2d(x, y)

        # non-contiguous 2D with holes
        xm = get_castable_tensor((sz2 * 2, sz1 * 2), dtp)
        x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t()
        y = x.numpy()
        self.assertTrue(x.storage_offset() > 0)
        check2d(x, y)

        if dtp != torch.half:
            # check writeable
            x = get_castable_tensor((3, 4), dtp)
            y = x.numpy()
            self.assertTrue(y.flags.writeable)
            y[0][1] = 3
            self.assertTrue(x[0][1] == 3)
            # y is now a view of the transpose, so y[0][1] aliases x[1][0].
            # Asserting x[0][1] here would pass trivially because of the
            # write above and would not exercise the transposed view.
            y = x.t().numpy()
            self.assertTrue(y.flags.writeable)
            y[0][1] = 3
            self.assertTrue(x[1][0] == 3)
def test_to_numpy_bool(self, device) -> None:
    # A bool tensor must convert to an np.bool_ array with matching
    # elements; checked for a two-element and a single-element tensor.
    for values in ([True, False], [True]):
        tensor = torch.tensor(values, dtype=torch.bool)
        self.assertEqual(tensor.dtype, torch.bool)
        array = tensor.numpy()
        self.assertEqual(array.dtype, np.bool_)
        for tensor_elem, array_elem in zip(tensor, array):
            self.assertEqual(tensor_elem, array_elem)
def test_from_numpy(self, device) -> None:
    # Checks torch.from_numpy() for a range of NumPy dtypes, an
    # unsupported dtype, and several memory layouts (offset,
    # non-contiguous, with holes, zero-sized, ill-sized strides).
    dtypes = [
        np.double,
        np.float64,
        # np.double and np.float64 are the same type, so without this
        # entry single precision would never be exercised.
        np.float32,
        np.float16,
        np.complex64,
        np.complex128,
        np.int64,
        np.int32,
        np.int16,
        np.int8,
        np.uint8,
        np.longlong,
        np.bool_,
    ]
    complex_dtypes = [
        np.complex64,
        np.complex128,
    ]

    for dtype in dtypes:
        array = np.array([1, 2, 3, 4], dtype=dtype)
        tensor_from_array = torch.from_numpy(array)
        # TODO: change to tensor equality check once HalfTensor
        # implements `==`
        for i in range(len(array)):
            self.assertEqual(tensor_from_array[i], array[i])
        # ufunc 'remainder' not supported for complex dtypes
        if dtype not in complex_dtypes:
            # This is a special test case for Windows
            # https://github.com/pytorch/pytorch/issues/22615
            array2 = array % 2
            tensor_from_array2 = torch.from_numpy(array2)
            for i in range(len(array2)):
                self.assertEqual(tensor_from_array2[i], array2[i])

    # Test unsupported type
    array = np.array([1, 2, 3, 4], dtype=np.uint16)
    with self.assertRaises(TypeError):
        torch.from_numpy(array)

    # check storage offset
    x = np.linspace(1, 125, 125)
    x.shape = (5, 5, 5)
    x = x[1]
    expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[1]
    self.assertEqual(torch.from_numpy(x), expected)

    # check noncontiguous
    x = np.linspace(1, 25, 25)
    x.shape = (5, 5)
    expected = torch.arange(1, 26, dtype=torch.float64).view(5, 5).t()
    self.assertEqual(torch.from_numpy(x.T), expected)

    # check noncontiguous with holes
    x = np.linspace(1, 125, 125)
    x.shape = (5, 5, 5)
    x = x[:, 1]
    expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[:, 1]
    self.assertEqual(torch.from_numpy(x), expected)

    # check zero dimensional
    x = np.zeros((0, 2))
    self.assertEqual(torch.from_numpy(x).shape, (0, 2))
    x = np.zeros((2, 0))
    self.assertEqual(torch.from_numpy(x).shape, (2, 0))

    # check ill-sized strides raise exception
    # (a 3-byte stride is not a multiple of the 8-byte float64 item size)
    x = np.array([3., 5., 8.])
    x.strides = (3,)
    self.assertRaises(ValueError, lambda: torch.from_numpy(x))
@skipMeta
def test_from_list_of_ndarray_warning(self, device):
    # Constructing a tensor from a Python list of ndarrays must emit the
    # "extremely slow" UserWarning.
    expected_warning = r"Creating a tensor from a list of numpy.ndarrays is extremely slow"
    ndarray_list = [np.array([value]) for value in (0, 1)]
    with self.assertWarnsOnceRegex(UserWarning, expected_warning):
        torch.tensor(ndarray_list, device=device)
@onlyCPU
def test_ctor_with_numpy_scalar_ctor(self, device) -> None:
    # A NumPy scalar passed to torch.tensor must round-trip through .item()
    # to a value equal to the original scalar, for each listed scalar type.
    scalar_types = [
        np.double,  # alias of np.float64; kept so the alias spelling stays covered
        np.float64,
        np.float32,  # previously missing: float64 was listed twice instead
        np.float16,
        np.int64,
        np.int32,
        np.int16,
        np.uint8,
        np.bool_,
    ]
    for scalar_type in scalar_types:
        np_scalar = scalar_type(42)
        self.assertEqual(np_scalar, torch.tensor(np_scalar).item())
@onlyCPU
def test_numpy_index(self, device):
    # NumPy int32 scalars are not Python ints, yet indexing a tensor with one
    # must select the same row as indexing with the equivalent Python int.
    x = torch.randn(5, 5)
    for np_idx in np.array([0, 1, 2], dtype=np.int32):
        self.assertNotIsInstance(np_idx, int)
        self.assertEqual(x[np_idx], x[int(np_idx)])
@onlyCPU
def test_numpy_array_interface(self, device):
    # Exercises the NumPy array protocol on tensors: __array__ with and
    # without an explicit dtype, __array_wrap__ via NumPy ufuncs, a few float
    # math ufuncs, and a ufunc with a boolean-valued result.
    def assert_elementwise_equal(actual, expected):
        # Compare element by element, checking lengths first so a short
        # operand cannot be silently truncated by zip.
        self.assertEqual(len(actual), len(expected))
        for got, want in zip(actual, expected):
            self.assertEqual(got, want)

    tensor_types = [
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.HalfTensor,
        torch.LongTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.ByteTensor,
    ]
    # Paired positionally with tensor_types.
    np_dtypes = [
        np.float64,
        np.float32,
        np.float16,
        np.int64,
        np.int32,
        np.int16,
        np.uint8,
    ]

    for tp, np_dtype in zip(tensor_types, np_dtypes):
        # Unsigned types get non-negative data only.
        # Only concrete class can be given where "Type[number[_64Bit]]" is expected
        if np.dtype(np_dtype).kind == 'u':  # type: ignore[misc]
            values = [1, 2, 3, 4]
        else:
            values = [1, -2, 3, -4]
        # .type expects a XxxTensor, which have no type hints on
        # purpose, so ignore during mypy type checking
        x = torch.tensor(values).type(tp)  # type: ignore[call-overload]
        array = np.array(values, dtype=np_dtype)

        # Test __array__ w/o dtype argument
        asarray = np.asarray(x)
        self.assertIsInstance(asarray, np.ndarray)
        self.assertEqual(asarray.dtype, np_dtype)
        assert_elementwise_equal(asarray, x)

        # Test __array_wrap__, same dtype
        abs_x = np.abs(x)
        abs_array = np.abs(array)
        self.assertIsInstance(abs_x, tp)
        assert_elementwise_equal(abs_x, abs_array)

    # Test __array__ with dtype argument
    for np_dtype in np_dtypes:
        x = torch.IntTensor([1, -2, 3, -4])
        asarray = np.asarray(x, dtype=np_dtype)
        self.assertEqual(asarray.dtype, np_dtype)
        # Only concrete class can be given where "Type[number[_64Bit]]" is expected
        if np.dtype(np_dtype).kind == 'u':  # type: ignore[misc]
            # Negative values wrap around in an unsigned dtype. Use astype()
            # for the reference: passing out-of-bound Python ints straight to
            # np.array(..., dtype=unsigned) is deprecated since NumPy 1.24 and
            # raises OverflowError in NumPy 2.
            expected = np.array([1, -2, 3, -4]).astype(np_dtype)
        else:
            expected = x
        assert_elementwise_equal(asarray, expected)

    # Test some math functions with float types
    float_types = [torch.DoubleTensor, torch.FloatTensor]
    float_dtypes = [np.float64, np.float32]
    for tp, np_dtype in zip(float_types, float_dtypes):
        x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
        array = np.array([1, 2, 3, 4], dtype=np_dtype)
        for func_name in ('sin', 'sqrt', 'ceil'):
            ufunc = getattr(np, func_name)
            res_x = ufunc(x)
            res_array = ufunc(array)
            self.assertIsInstance(res_x, tp)
            assert_elementwise_equal(res_x, res_array)

    # Test functions with boolean return value
    for tp, np_dtype in zip(tensor_types, np_dtypes):
        x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
        array = np.array([1, 2, 3, 4], dtype=np_dtype)
        geq2_x = np.greater_equal(x, 2)
        geq2_array = np.greater_equal(array, 2).astype('uint8')
        self.assertIsInstance(geq2_x, torch.ByteTensor)
        assert_elementwise_equal(geq2_x, geq2_array)
@onlyCPU
def test_multiplication_numpy_scalar(self, device) -> None:
    # Multiplying a grad-requiring float tensor by a NumPy scalar, in either
    # operand order, must yield a torch.Tensor that keeps the tensor's dtype
    # and still requires grad.
    numpy_scalar_types = (np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8)
    tensor_dtypes = (torch.float, torch.double)
    for np_dtype, t_dtype in product(numpy_scalar_types, tensor_dtypes):
        # mypy raises an error when np.floatXY(2.0) is called
        # even though this is valid code
        np_sc = np_dtype(2.0)  # type: ignore[abstract, arg-type]
        t = torch.ones(2, requires_grad=True, dtype=t_dtype)
        # Tensor-on-the-left first, then scalar-on-the-left.
        for result in (t * np_sc, np_sc * t):
            self.assertIsInstance(result, torch.Tensor)
            self.assertTrue(result.dtype == t_dtype)
            self.assertTrue(result.requires_grad)
@onlyCPU
def test_parse_numpy_int(self, device):
    # NumPy integer scalars must be parsed like Python ints wherever torch
    # expects an int or a Scalar, and a value too large for that parse must be
    # reported as an "Overflow" RuntimeError.

    # Build the largest uint64 explicitly. The former `np.uint64(-1)` relied on
    # silent wraparound, which NumPy 2 rejects with OverflowError before torch
    # is ever called; this spelling yields the same value on every version.
    uint64_max = np.uint64(np.iinfo(np.uint64).max)
    # Only concrete class can be given where "Type[number[_64Bit]]" is expected
    with self.assertRaisesRegex(RuntimeError, "Overflow"):
        torch.mean(torch.randn(1, 1), uint64_max)  # type: ignore[call-overload]

    # https://github.com/pytorch/pytorch/issues/29252
    scalar = 3
    for nptype in (np.int16, np.int8, np.uint8, np.int32, np.int64):
        np_arr = np.array([scalar], dtype=nptype)
        np_val = np_arr[0]

        # np integral type can be treated as a python int in native functions with
        # int parameters:
        self.assertEqual(torch.ones(5).diag(scalar), torch.ones(5).diag(np_val))
        self.assertEqual(torch.ones([2, 2, 2, 2]).mean(scalar), torch.ones([2, 2, 2, 2]).mean(np_val))

        # numpy integral type parses like a python int in custom python bindings:
        self.assertEqual(torch.Storage(np_val).size(), scalar)  # type: ignore[attr-defined]

        tensor = torch.tensor([2], dtype=torch.int)
        tensor[0] = np_val
        self.assertEqual(tensor[0], np_val)

        # Original reported issue, np integral type parses to the correct
        # PyTorch integral type when passed for a `Scalar` parameter in
        # arithmetic operations:
        t = torch.from_numpy(np_arr)
        self.assertEqual((t + np_val).dtype, t.dtype)
        self.assertEqual((np_val + t).dtype, t.dtype)
def test_has_storage_numpy(self, device):
    # A tensor built from a NumPy array on `device`, for every combination of
    # source NumPy dtype and requested torch dtype below, must expose a storage.
    source_dtypes = (np.float32, np.float64, np.int64, np.int32, np.int16, np.uint8)
    target_dtypes = (torch.float32, torch.double, torch.int, torch.long, torch.uint8)
    for np_dtype in source_dtypes:
        arr = np.array([1], dtype=np_dtype)
        for torch_dtype in target_dtypes:
            converted = torch.tensor(arr, device=device, dtype=torch_dtype)
            self.assertIsNotNone(converted.storage())
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_numpy_scalar_cmp(self, device, dtype):
    # Compares the first element of 0-d, 1-d and 2-d tensors (as a tensor and
    # as a Python scalar) against the first element of the NumPy conversion
    # (as a NumPy scalar and as a Python scalar), in every pairing.
    if dtype.is_complex:
        payloads = (complex(1, 3),
                    [complex(1, 3), 0, 2j],
                    [[complex(3, 1), 0], [-1j, 5]])
    else:
        payloads = (3,
                    [1, 0, -3],
                    [[3, 0, -1], [3, 5, 4]])
    tensors = tuple(torch.tensor(payload, dtype=dtype, device=device) for payload in payloads)

    for tensor in tensors:
        if dtype == torch.bfloat16:
            # Conversion to NumPy is expected to fail for bfloat16.
            with self.assertRaises(TypeError):
                tensor.cpu().numpy()
            continue

        np_array = tensor.cpu().numpy()
        first_tensor_elem = tensor.flatten()[0]
        first_array_elem = np_array.flatten()[0]
        torch_side = (first_tensor_elem, first_tensor_elem.item())
        numpy_side = (first_array_elem, first_array_elem.item())
        for t, a in product(torch_side, numpy_side):
            self.assertEqual(t, a)
            # TODO: Imaginary part is dropped in this case. Need fix.
            # https://github.com/pytorch/pytorch/issues/43579
            imaginary_dropped = (dtype == torch.complex64
                                 and torch.is_tensor(t)
                                 and type(a) is np.complex64)
            check = self.assertFalse if imaginary_dropped else self.assertTrue
            check(t == a)
# NOTE(review): presumably generates the device-specific variants of
# TestNumPyInterop and registers them in this module's globals — confirm
# against instantiate_device_type_tests.
instantiate_device_type_tests(TestNumPyInterop, globals())

# Run the tests only when this file is executed directly as a script.
if __name__ == '__main__':
    run_tests()