Device generic test framework + dtypes

Mike Ruberry
2019-09-18 21:32:57 -07:00
parent 96866ae771
commit 629d8d819a

@ -1,5 +1,7 @@
PyTorch's test framework lets you write tests that can run on all available device types. These "device generic" tests are a great way to ensure operations run properly no matter what hardware PyTorch is using.
### Writing a Device Generic Test
To write a device generic test you have to do three things:
1. Your test should accept two arguments, self and device. The latter will be a string designating a device, like 'cpu,' 'cuda,' or 'xla.'
@ -8,10 +10,52 @@ To write a device generic test you have to do three things:
When the test suite is run it will instantiate a device-specific version of the device generic test class. `TestTorchDeviceType` becomes `TestTorchDeviceTypeCPU` and `TestTorchDeviceTypeCUDA`, for example. Its tests are added to each of these classes and have the device type appended to their name. `test_diagonal`, for example, becomes `test_diagonal_cpu` and `test_diagonal_cuda`, respectively. These tests are called with the appropriate device type string when run.
In Python, a test suite like this
```python
TestTorchDeviceType(TestCase):
def testX(self, device):
...
```
is translated to
```python
TestTorchDeviceTypeCPU(TestCase):
def testX_cpu(self, device='cpu'):
...
TestTorchDeviceTypeCUDA(TestCase):
def testX_cuda(self, device='cuda'):
```
These tests can be run directly, `python test_torch.py TestTorchDeviceTypeCPU.test_diagonal_cpu`, or filtered using pytest. The command `pytest test_torch.py -k 'test_diagonal'` will run both `test_diagonal_cpu` and `test_diagonal_cuda`.
See `TestTorchDeviceType` in test_torch.py for examples.
### Writing a Device Generic Test with Dtype Variants
Tests can accept a third argument, 'dtype,' if they use the @dtypes decorator, like so:
```python
@dtypes(torch.half, torch.float, torch.double)
testX(self, device, dtype)
```
This will instantiate variants of testX for each available device type and each specified dtype...
```python
testX_cpu_half(self, 'cpu', torch.half)
testX_cpu_float(self, 'cpu', torch.float)
testX_cpu_double(self, 'cpu', torch.double)
testX_cuda_half(self, 'cuda', torch.half)
...
```
These tests can be run just like device generic tests without dtypes.
### Older Methods (Do Not Use)
Please do not use older methods such as:
1) `use_cuda` variants: