mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
small cleanups
@ -51,7 +51,7 @@ PyTorch extends Python test frameworks like unittest and pytest with its own tes
|
||||
|
||||
"Instantiating," or, equivalently, "generating" tests programmatically has pros and cons. A pro is that it makes writing tests simpler. A con is that it's harder to understand what tests actually cover and where generated tests come from. The next section will elaborate more on PyTorch's test coverage.
|
||||
|
||||
Before we look at the device-generic test framework, note that using it is not a requirement for all tests. Tests in PyTorch not using the framework are simply methods whose names start with "test" and that accept only `self` as an argument, like this:
|
||||
Before we look at the device-generic test framework, note that using it is not a requirement for all tests. Tests in PyTorch not using the framework are simply `TestCase` class methods whose names start with "test" and that accept only `self` as an argument, like this:
|
||||
|
||||
```python
|
||||
# a simple test
|
||||
@ -59,7 +59,7 @@ def test_foo(self):
|
||||
...
|
||||
```
|
||||
|
||||
Many of PyTorch's test classes use the device-generic framework, however. These classes and the tests written in them are actually templates, however. The simplest version of these test templates just accept an additional device argument:
|
||||
Many of PyTorch's test classes use the device-generic framework, however. These classes and the tests written in them are actually templates. The simplest version of these test templates just accept an additional device argument:
|
||||
|
||||
```python
|
||||
class TestFoo(TestCase):
|
||||
@ -69,7 +69,7 @@ class TestFoo(TestCase):
|
||||
...
|
||||
```
|
||||
|
||||
When this test class is run it will typically generate a version of this template for each available device type. For example, on a machine with a CUDA device this template will create two two classes: `TestFooCPU` and `TestFooCUDA`, with tests `test_foo_cpu` and `test_foo_cuda`, respecitvely. The `device` argument passed to `test_foo_cpu` will be "cpu", and the `device` argument passed to `test_foo_cuda` will be a string representing a CUDA device like "cuda:0" or "cuda:1". To elaborate, we can think of the template being translated at runtime to:
|
||||
When this test class is run it will typically generate a version of this template for each available device type. For example, on a machine with a CUDA device this template will create two two classes: `TestFooCPU` and `TestFooCUDA`, with tests `test_foo_cpu` and `test_foo_cuda`, respectively. The `device` argument passed to `test_foo_cpu` will be "cpu", and the `device` argument passed to `test_foo_cuda` will be a string representing a CUDA device like "cuda:0" or "cuda:1". To elaborate, we can think of the template being translated at runtime to:
|
||||
|
||||
```python
|
||||
class TestFooCPU(TestCase):
|
||||
@ -158,7 +158,11 @@ Using the device-generic test framework is essential when testing new tensor ope
|
||||
|
||||
The OpInfo pattern, described in this section, is the future of testing tensor operations in PyTorch. This pattern is intended to make testing tensor operations simpler, since tensor operations have so much in common that writing a test from scratch for each operation would be redundant and exhausting. Instead, OpInfos contain metadata that test templates can use to test properties of many operators at once.
|
||||
|
||||
The OpInfo class and every OpInfo is defined in [common_methods_invocations.py](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py). The OpInfo class has tensor operation-related metadata like which dtypes the operator supports, and whether it supports inplace gradients or not. One of the most important properties of each OpInfo is its `sample_inputs()` function, used to acquire "sample inputs" to the function with different properties.
|
||||
The OpInfo class and every OpInfo is defined in [common_methods_invocations.py](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py). The OpInfo class has tensor operation-related metadata like
|
||||
- which dtypes the operator supports
|
||||
- whether it supports inplace gradients or not.
|
||||
|
||||
One of the most important properties of each OpInfo is its `sample_inputs()` method, used to acquire "sample inputs" to the function with different properties.
|
||||
|
||||
In addition to the OpInfo base class, there are several derived classes with additional structure. For example, UnaryUfuncInfo is a specialization of OpInfos for unary universal functions. The details of each of these classes is beyond the scope of this section. See the documentation in common_methods_invocations.py for details.
|
||||
|
||||
@ -184,9 +188,11 @@ This OpInfo, or more precisely this UnaryUfuncInfo, is for [torch.asinh()](https
|
||||
|
||||
This is a lot to review at once. In practice writing an OpInfo is typically an iterative process where a simple OpInfo is written, and then the test suites will identify issues with the operator's specification and suggest changes. For example, if an OpInfo incorrectly states that an operator supports a dtype, like `torch.half`, then a test will fail and point out that the operator does not, in fact, support that dtype. The OpInfo can then be updated using this information. If you need help implementing an OpInfo then file an issue.
|
||||
|
||||
OpInfos are used in two ways. First, test_ops.py tests common properties of every tensor operation. Second, tests like test_unary_ufuncs.py can use classes derived from OpInfo to test properties of a subset of the tensor operations. Understanding what test_ops.py tests is important when implementing a new tensor operator or updating an existing one so you know what additional tests, if any, are needed. Knowing how to write a test template that consumes OpInfos is interesting, but only directly useful if you plan to write a test template that runs on multiple tensor operators.
|
||||
OpInfos are used in two ways. First, test_ops.py tests common properties of every tensor operation. Second, tests like test_unary_ufuncs.py can use classes derived from OpInfo to create exemplar Tensors to test properties of a subset of the tensor operations. Understanding what test_ops.py tests is important when implementing a new tensor operator or updating an existing one so you know what additional tests, if any, are needed. Knowing how to write a test template that consumes OpInfos is interesting, but only directly useful if you plan to write a test template that runs on multiple tensor operators.
|
||||
|
||||
test_ops.py has tests that use the `@ops` decorator. Like the `@dtypes` decorator, the `@ops` decorator is part of the device-generic test framework and instantiates variants of test templates. For example:
|
||||
### The `@ops` decorator
|
||||
|
||||
test_ops.py has tests that use the `@ops` decorator. Like the `@dtypes` decorator, the `@ops` decorator is part of the device-generic test framework and instantiates variants of test templates based on an iterable of `OpInfo`s. For example:
|
||||
|
||||
```
|
||||
@ops(op_db)
|
||||
@ -194,7 +200,7 @@ def test_foo(self, device, dtype, op):
|
||||
...
|
||||
```
|
||||
|
||||
This will instantiate a variant of test_foo for every OpInfo (all OpInfos are included in the `op_db` iterable), for every dtype that OpInfo's operator supports, and on each device type. As of this writing this is hundreds of tests, and test templates designed to work on every tensor operation are inherently restricted to testing fundamental properties of these operators. For example, that their derivatives are properly implemented. In particular, test_ops.py tests the following:
|
||||
This will instantiate a variant of test_foo for every OpInfo (all OpInfos are included in the `op_db` iterable), for every dtype that OpInfo's operator supports, and on each device type. As of this writing this is hundreds of tests. Test templates designed to work on every tensor operation are inherently restricted to testing fundamental properties of these operators. For example, that their derivatives are properly implemented. In particular, test_ops.py tests the following:
|
||||
|
||||
- that the operation's dtypes are correctly enumerated
|
||||
- that the operation's function, method, and inplace variants are equivalent
|
||||
@ -265,7 +271,7 @@ Please do not use older methods such as:
|
||||
# .......
|
||||
```
|
||||
|
||||
This solution have lots of code duplication and prone to errors.
|
||||
This solution has lots of code duplication and is prone to errors.
|
||||
|
||||
4) Casting lambdas variants:
|
||||
|
||||
|
Reference in New Issue
Block a user