mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Device generic test framework + dtypes
@ -1,5 +1,7 @@
|
||||
PyTorch's test framework lets you write tests that can run on all available device types. These "device generic" tests are a great way to ensure operations run properly no matter what hardware PyTorch is using.
|
||||
|
||||
### Writing a Device Generic Test
|
||||
|
||||
To write a device generic test you have to do three things:
|
||||
|
||||
1. Your test should accept two arguments, self and device. The latter will be a string designating a device, like 'cpu,' 'cuda,' or 'xla.'
|
||||
@ -8,10 +10,52 @@ To write a device generic test you have to do three things:
|
||||
|
||||
When the test suite is run it will instantiate a device-specific version of the device generic test class. `TestTorchDeviceType` becomes `TestTorchDeviceTypeCPU` and `TestTorchDeviceTypeCUDA`, for example. Its tests are added to each of these classes and have the device type appended to their name. `test_diagonal`, for example, becomes `test_diagonal_cpu` and `test_diagonal_cuda`, respectively. These tests are called with the appropriate device type string when run.
|
||||
|
||||
In Python, a test suite like this
|
||||
|
||||
```python
|
||||
TestTorchDeviceType(TestCase):
|
||||
def testX(self, device):
|
||||
...
|
||||
```
|
||||
|
||||
is translated to
|
||||
|
||||
```python
|
||||
TestTorchDeviceTypeCPU(TestCase):
|
||||
def testX_cpu(self, device='cpu'):
|
||||
...
|
||||
TestTorchDeviceTypeCUDA(TestCase):
|
||||
def testX_cuda(self, device='cuda'):
|
||||
```
|
||||
|
||||
These tests can be run directly, `python test_torch.py TestTorchDeviceTypeCPU.test_diagonal_cpu`, or filtered using pytest. The command `pytest test_torch.py -k 'test_diagonal'` will run both `test_diagonal_cpu` and `test_diagonal_cuda`.
|
||||
|
||||
See `TestTorchDeviceType` in test_torch.py for examples.
|
||||
|
||||
### Writing a Device Generic Test with Dtype Variants
|
||||
|
||||
Tests can accept a third argument, 'dtype,' if they use the @dtypes decorator, like so:
|
||||
|
||||
```python
|
||||
@dtypes(torch.half, torch.float, torch.double)
|
||||
testX(self, device, dtype)
|
||||
```
|
||||
|
||||
This will instantiate variants of testX for each available device type and each specified dtype...
|
||||
|
||||
```python
|
||||
testX_cpu_half(self, 'cpu', torch.half)
|
||||
testX_cpu_float(self, 'cpu', torch.float)
|
||||
testX_cpu_double(self, 'cpu', torch.double)
|
||||
testX_cuda_half(self, 'cuda', torch.half)
|
||||
...
|
||||
```
|
||||
|
||||
These tests can be run just like device generic tests without dtypes.
|
||||
|
||||
|
||||
### Older Methods (Do Not Use)
|
||||
|
||||
Please do not use older methods such as:
|
||||
|
||||
1) `use_cuda` variants:
|
||||
|
Reference in New Issue
Block a user