[OpenReg] Add AMP Integration guide for accelerators (#162050)
Fixes part of #158917. Adds an AMP integration document, with OpenReg code as an example, to explain the steps of integration.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162050
Approved by: https://github.com/albanD
Co-authored-by: FFFrog <ljw1101.vip@gmail.com>
committed by PyTorch MergeBot
parent 7f29c47a4f
commit 77354e22e1
72
docs/source/accelerator/amp.md
Normal file
@@ -0,0 +1,72 @@
# Automatic Mixed Precision

## Background

Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.

Key components include:

- [**Autocast**](https://docs.pytorch.org/docs/stable/amp.html#autocasting): Automatically casts operations to lower-precision data types (e.g., `float16` or `bfloat16`) to improve performance while maintaining accuracy.
- [**Gradient Scaling**](https://docs.pytorch.org/docs/stable/amp.html#gradient-scaling): Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision.
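For context, a typical AMP training step on an existing backend combines these two pieces as in the sketch below (illustrative only, using the public `torch.amp` API; the model, data, and the `cuda` device are placeholders):

```python
import torch

# Illustrative sketch of a standard AMP training step; not backend-specific.
model = torch.nn.Linear(8, 4).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):
    inputs = torch.randn(16, 8, device="cuda")
    targets = torch.randn(16, 4, device="cuda")

    optimizer.zero_grad()
    # Autocast selects per-op dtypes according to the registered cast policies.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    # Gradient scaling guards small float16 gradients against underflow.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```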
## Design

### Casting Strategy

The [`CastPolicy`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L416-L438) enum defines the type conversion rules. Each enum value represents a set of type conversion requirements for a group of operators, ensuring consistent handling of operations that prioritize either precision or performance.

| Policy | Explanation |
| :--- | :--- |
| **`lower_precision_fp`** | Cast all inputs to `lower_precision_fp` before executing the op. |
| **`fp32`** | Cast all inputs to `at::kFloat` before running the op. |
| **`fp32_set_opt_dtype`** | Run the op in `at::kFloat`, while respecting a user-specified output dtype if provided. |
| **`fp32_append_dtype`** | Append `at::kFloat` to the args and redispatch to the type-aware overload. |
| **`promote`** | Promote all inputs to the “widest” dtype before execution. |
### Operator Lists

PyTorch defines a general list of operators for each of the casting strategies mentioned above, as a reference for developers of new accelerators.

| Policy | Operator List |
| :--- | :--- |
| **`lower_precision_fp`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L819-L852) |
| **`fp32`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L854-L912) |
| **`fp32_set_opt_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L914-L931) |
| **`fp32_append_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L933-L958) |
| **`promote`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L960-L971) |
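To make the policies concrete, the following sketch shows their observable effect inside an autocast region. It assumes a backend whose lists assign `mm` to `lower_precision_fp` and `asin` to `fp32`, as the reference lists linked above do; `cuda` is used only as an example device:

```python
import torch

# Sketch: observable effect of the cast policies inside an autocast region.
# Assumes a device with autocast support; "cuda" is used as an example here.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    a = torch.randn(4, 4, device="cuda")
    b = torch.randn(4, 4, device="cuda")

    # `mm` follows the lower_precision_fp policy: inputs are cast to float16.
    print(torch.mm(a, b).dtype)        # torch.float16

    # `asin` follows the fp32 policy: inputs are cast up to float32.
    print(torch.asin(a.half()).dtype)  # torch.float32
```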
## Implementation

### Python Integration

Implement the `get_amp_supported_dtype` method to return the data types supported by the new accelerator in the AMP context.

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
   :language: python
   :start-after: LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
   :end-before: LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
   :linenos:
```
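Once the backend reports its supported dtypes, an autocast region can be entered with one of them. A minimal usage sketch (assuming the `torch_openreg` package is installed, so it is autoloaded on `import torch`):

```python
import torch  # torch_openreg is autoloaded through its entry point

# float16 is one of the dtypes reported by get_amp_supported_dtype,
# so autocast accepts it for the "openreg" device.
with torch.autocast(device_type="openreg", dtype=torch.float16):
    x = torch.randn(2, 3, device="openreg")
    y = torch.randn(3, 3, device="openreg")
    print(torch.mm(x, y).dtype)  # torch.float16
```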
### C++ Integration

This section shows how AMP registers autocast kernels for the `AutocastPrivateUse1` dispatch key.

- Register a fallback that makes unhandled ops fall through to their normal implementations.
- Register specific aten kernels under `AutocastPrivateUse1` using the `KERNEL_PRIVATEUSEONE` helper macro, which maps an op to the desired precision implementation (selected via the `at::autocast::CastPolicy` enum).

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: AMP FALLTHROUGH
   :end-before: LITERALINCLUDE END: AMP FALLTHROUGH
   :linenos:

.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: AMP IMPL
   :end-before: LITERALINCLUDE END: AMP IMPL
   :emphasize-lines: 3,6,8-10
   :linenos:
```
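The effect of the fallthrough registration is visible from Python: ops without an autocast rule simply run their normal kernels and keep their input dtype. A sketch (using `cuda` as a stand-in for any backend with autocast support; `relu` is assumed to have no autocast rule, as in the built-in lists):

```python
import torch

# Sketch: ops without a registered autocast rule fall through to their
# normal kernels, so the output dtype follows the inputs, not a cast policy.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    t = torch.randn(4, 4, device="cuda")

    print(torch.relu(t).dtype)   # torch.float32 -- relu has no autocast rule
    print(torch.mm(t, t).dtype)  # torch.float16 -- mm uses lower_precision_fp
```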
@@ -44,6 +44,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k

autoload
operators
amp
```

[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"
@@ -25,6 +25,8 @@ The goal of `torch_openreg` is **not to implement a fully functional, high-perfo

torch_openreg/
├── CMakeLists.txt
├── csrc
│ ├── amp
│ │ └── autocast_mode.cpp
│ ├── aten
│ │ ├── native
│ │ │ ├── Extra.cpp
@@ -59,6 +61,8 @@ torch_openreg/
│ └── stub.c
├── __init__.py
└── openreg
├── amp
│ └── __init__.py
├── __init__.py
├── meta.py
└── random.py
@@ -95,11 +99,12 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll

**Key Directories**:

- `csrc/`: Core device implementation, including operator registration, runtime, etc.
- `csrc/amp/`: AMP (Automatic Mixed Precision) autocast kernel registration.
- `csrc/aten/`: Operator registration
- `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
  - `csrc/aten/native/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
  - `csrc/aten/native/OpenRegExtra.cpp`: Implementations for other types of operators.
- `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
  - `torch_openreg/csrc/`: Python C++ binding code.
@@ -126,13 +131,18 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
### Autoload

When `import torch` is executed, installed accelerators (such as `torch_openreg`) are loaded automatically, achieving the same experience as the built-in backends.

- Registering the backend with Python `entry points`: See `setup` in `setup.py`, and the sketch after this list
- Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
- Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
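A minimal sketch of the entry-point registration (an illustrative `setup.py` fragment, not the repository's actual file; `torch.backends` is the entry-point group consumed by the autoload mechanism, and the callable name mirrors `_autoload` above):

```python
# Illustrative fragment only; see `setup` in the real setup.py for the
# authoritative version.
from setuptools import setup

setup(
    name="torch_openreg",
    entry_points={
        # `import torch` scans this group and calls the referenced function.
        "torch.backends": [
            "torch_openreg = torch_openreg:_autoload",
        ],
    },
)
```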
### AMP (Automatic Mixed Precision)

`AMP` provides convenience methods for mixed precision, where some operations use the `torch.float32` datatype while others use a lower-precision floating point datatype: `torch.float16` or `torch.bfloat16`.

- Register specific operator conversion rules: See `autocast_mode.cpp` in `csrc/amp`.
- Add support for new data types for different accelerators: See `get_amp_supported_dtype` in `torch_openreg/openreg/amp/__init__.py`

## Installation and Usage
@@ -168,11 +178,13 @@ print("Result z:\n", z)
print(f"Device of z: {z.device}")
```

## Documentation

Please refer to [this](https://docs.pytorch.org/docs/main/accelerator/index.html) for a series of documents on integrating new accelerators into PyTorch, which will be kept in sync with the `OpenReg` codebase as well.

## Future Plans

- **Enhance Features**:
  - Autoload
  - AMP
  - Device-agnostic APIs
  - Memory Management
  - Generator
@@ -180,5 +192,3 @@ print(f"Device of z: {z.device}")
  - Custom Tensor&Storage
  - ...
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.
@@ -0,0 +1,37 @@
#include <ATen/autocast_mode.h>

using at::Tensor;

Tensor binary_cross_entropy_banned(
    const Tensor&,
    const Tensor&,
    const std::optional<Tensor>&,
    int64_t) {
  TORCH_CHECK(
      false,
      "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
      "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
      "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
      "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
      "safe to autocast.");
}

// LITERALINCLUDE START: AMP FALLTHROUGH
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
// LITERALINCLUDE END: AMP FALLTHROUGH

// LITERALINCLUDE START: AMP IMPL
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // lower_precision_fp
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)

  // fp32
  KERNEL_PRIVATEUSEONE(asin, fp32)

  m.impl(
      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
      TORCH_FN((&binary_cross_entropy_banned)));
}
// LITERALINCLUDE END: AMP IMPL
@@ -0,0 +1,50 @@
# Owner(s): ["module: PrivateUse1"]

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


class TestAutocast(TestCase):
    def test_autocast_with_unsupported_type(self):
        with self.assertWarnsRegex(
            UserWarning,
            "In openreg autocast, but the target dtype torch.float32 is not supported.",
        ):
            with torch.autocast(device_type="openreg", dtype=torch.float32):
                _ = torch.ones(10)

    def test_autocast_operator_not_supported(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.",
        ):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(2, 3, device="openreg")
            with torch.autocast(device_type="openreg", dtype=torch.float16):
                _ = torch.nn.functional.binary_cross_entropy(x, y)

    def test_autocast_low_precision(self):
        with torch.amp.autocast(device_type="openreg", dtype=torch.float16):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(3, 3, device="openreg")
            result = torch.mm(x, y)
            self.assertEqual(result.dtype, torch.float16)

    def test_autocast_fp32(self):
        with torch.amp.autocast(device_type="openreg"):
            x = torch.randn(2, device="openreg", dtype=torch.float16)
            result = torch.asin(x)
            self.assertEqual(result.dtype, torch.float32)

    def test_autocast_default_dtype(self):
        openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg")
        self.assertEqual(openreg_fast_dtype, torch.half)

    def test_autocast_set_dtype(self):
        for dtype in [torch.float16, torch.bfloat16]:
            torch.set_autocast_dtype("openreg", dtype)
            self.assertEqual(torch.get_autocast_dtype("openreg"), dtype)


if __name__ == "__main__":
    run_tests()
@@ -3,6 +3,7 @@ import torch
import torch_openreg._C  # type: ignore[misc]

from . import meta  # noqa: F401
from .amp import get_amp_supported_dtype  # noqa: F401


_initialized = False
@@ -0,0 +1,9 @@
import torch


# LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]


# LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE