[OpenReg] Add AMP Integration guide for accelerators (#162050)

Fix part of #158917

Add an AMP integration document, using the OpenReg code as an example to explain the integration steps.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162050
Approved by: https://github.com/albanD

Co-authored-by: FFFrog <ljw1101.vip@gmail.com>
zeshengzong
2025-09-30 12:27:08 +00:00
committed by PyTorch MergeBot
parent 7f29c47a4f
commit 77354e22e1
7 changed files with 190 additions and 10 deletions

View File

@ -0,0 +1,72 @@
# Automatic Mixed Precision
## Background
Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference, as sketched in the example below.
Key components include:
- [**Autocast**](https://docs.pytorch.org/docs/stable/amp.html#autocasting): Automatically casts operations to lower-precision dtypes (e.g., `float16` or `bfloat16`) to improve performance while maintaining accuracy.
- [**Gradient Scaling**](https://docs.pytorch.org/docs/stable/amp.html#gradient-scaling): Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision.
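These two components are typically combined as follows. This is a minimal sketch assuming a CUDA device and a placeholder model; any backend that registers autocast support (such as the OpenReg example used throughout this guide) follows the same pattern.

```python
import torch

device = "cuda"  # any accelerator that registers autocast support works the same way
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = torch.amp.GradScaler(device)

inputs = torch.randn(4, 8, device=device)
targets = torch.randint(0, 2, (4,), device=device)

for _ in range(3):
    optimizer.zero_grad()
    # Autocast runs eligible ops in float16 while keeping precision-sensitive ops in float32.
    with torch.autocast(device_type=device, dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)
    # GradScaler scales the loss so that small float16 gradients do not underflow.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

The backward pass runs outside the autocast region; `scaler.step` unscales the gradients before invoking the optimizer and skips the step if they contain infs or NaNs.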
## Design
### Casting Strategy
The [`CastPolicy`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L416-L438) enum defines the type conversion rules. Each enum value represents a set of type conversion requirements for a group of operators, ensuring consistent handling of operations that prioritize either precision or performance; the sketch after the table illustrates the observable effect of each policy.

| Policy | Explanation |
| :--- | :--- |
| **`lower_precision_fp`** | Cast all inputs to the backend's lower-precision floating point type before running the op. |
| **`fp32`** | Cast all inputs to `at::kFloat` before running the op. |
| **`fp32_set_opt_dtype`** | Run the op in `at::kFloat`, while respecting a user-specified output dtype if provided. |
| **`fp32_append_dtype`** | Append `at::kFloat` to the arguments and redispatch to the dtype-aware overload. |
| **`promote`** | Promote all inputs to the “widest” dtype before execution. |
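
The observable effect of each policy can be checked from Python. The sketch below assumes a CUDA build with the stock autocast operator lists, where `mm` is registered as `lower_precision_fp`, `asin` as `fp32`, and `addcmul` as `promote`:

```python
import torch

a = torch.randn(4, 4, device="cuda")                        # float32
b = torch.randn(4, 4, device="cuda", dtype=torch.float16)   # float16

with torch.autocast(device_type="cuda", dtype=torch.float16):
    print(torch.mm(a, a).dtype)          # lower_precision_fp -> torch.float16
    print(torch.asin(b).dtype)           # fp32               -> torch.float32
    print(torch.addcmul(a, a, b).dtype)  # promote            -> torch.float32 (widest input)
```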
### Operator Lists
PyTorch defines a general list of operators for each of the casting policies above, which serves as a reference for developers of new accelerators.

| Policy | Operators List |
| :--- | :--- |
| **`lower_precision_fp`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L819-L852) |
| **`fp32`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L854-L912) |
| **`fp32_set_opt_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L914-L931) |
| **`fp32_append_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L933-L958) |
| **`promote`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L960-L971) |
## Implementation
### Python Integration
Implement the `get_amp_supported_dtype` function to return the data types supported by the new accelerator in the AMP context.
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
   :language: python
   :start-after: LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
   :end-before: LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
   :linenos:
```
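With the `torch_openreg` example backend installed (it is auto-loaded on `import torch`), `torch.autocast` consults this hook to validate the requested dtype. A rough sketch of the expected behavior:

```python
import torch

# float16 is in the list returned by get_amp_supported_dtype, so autocast is enabled.
with torch.autocast(device_type="openreg", dtype=torch.float16):
    pass

# float32 is not in the list, so a UserWarning is emitted and autocast is
# disabled for this region.
with torch.autocast(device_type="openreg", dtype=torch.float32):
    pass
```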
### C++ Integration
This section shows how AMP registers autocast kernels for the `AutocastPrivateUse1` dispatch key.
- Register a fallback that makes unhandled ops fall through to their normal implementations.
- Register specific aten kernels under `AutocastPrivateUse1` using the `KERNEL_PRIVATEUSEONE` helper macro, which maps an op to the desired casting behavior via the `at::autocast::CastPolicy` enum.
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: AMP FALLTHROUGH
   :end-before: LITERALINCLUDE END: AMP FALLTHROUGH
   :linenos:

.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: AMP IMPL
   :end-before: LITERALINCLUDE END: AMP IMPL
   :emphasize-lines: 3,6,8-10
   :linenos:
```
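
Once these registrations are in place, their effect is observable from Python. The sketch below (assuming the `torch_openreg` example backend is installed) mirrors the accompanying tests for the three registrations shown above:

```python
import torch

x = torch.randn(2, 3, device="openreg")
y = torch.randn(3, 3, device="openreg")

with torch.autocast(device_type="openreg", dtype=torch.float16):
    # mm is registered with the lower_precision_fp policy.
    assert torch.mm(x, y).dtype == torch.float16
    # asin is registered with the fp32 policy.
    assert torch.asin(x).dtype == torch.float32
    # binary_cross_entropy is registered with a kernel that always raises.
    try:
        torch.nn.functional.binary_cross_entropy(
            torch.rand(2, device="openreg"), torch.rand(2, device="openreg")
        )
    except RuntimeError as e:
        print("Rejected under autocast:", e)
```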

View File

@ -44,6 +44,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
autoload
operators
amp
```
[OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"

View File

@ -25,6 +25,8 @@ The goal of `torch_openreg` is **not to implement a fully functional, high-perfo
torch_openreg/
├── CMakeLists.txt
├── csrc
│ ├── amp
│ │ └── autocast_mode.cpp
│ ├── aten
│ │ ├── native
│ │ │ ├── Extra.cpp
@ -59,6 +61,8 @@ torch_openreg/
│ └── stub.c
├── __init__.py
└── openreg
├── amp
│ └── __init__.py
├── __init__.py
├── meta.py
└── random.py
@ -95,11 +99,12 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
**Key Directories**:
- `csrc/`: Core device implementation, including operator registration, runtime, etc.
- `csrc/amp/`: AMP (Automatic Mixed Precision) autocast kernel registration
- `csrc/aten/`: Operator registration
- `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
- `csrc/aten/native/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
- `csrc/aten/native/OpenRegExtra.cpp`: Implementations for other types of operators.
- `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
- `torch_openreg/csrc/`: Python C++ binding code.
@ -126,13 +131,18 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
### Autoload
- Autoload Mechanism
When `import torch` is executed, installed accelerators (such as `torch_openreg`) are loaded automatically, providing the same experience as the built-in backends.
- Registering the backend with Python `entry points`: See `setup` in `setup.py`
- Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
- Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
### AMP (Automatic Mixed Precision)
`AMP` provides convenience methods for mixed precision, in which some operations use the `torch.float32` datatype while others use a lower-precision floating point datatype (`torch.float16` or `torch.bfloat16`), as sketched below.
- Register specific operator conversion rules: See `autocast_mode.cpp` in `csrc/amp`.
- Declare the data types supported by the accelerator: See `get_amp_supported_dtype` in `torch_openreg/openreg/amp/__init__.py`.
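A minimal usage sketch, assuming `torch_openreg` is installed (`mm` is one of the ops registered with the `lower_precision_fp` policy):
```python
import torch

x = torch.randn(2, 3, device="openreg")
y = torch.randn(3, 3, device="openreg")

with torch.autocast(device_type="openreg", dtype=torch.float16):
    # mm is autocast to the lower-precision dtype, so the result is float16.
    print(torch.mm(x, y).dtype)  # torch.float16
```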
## Installation and Usage
@ -168,11 +178,13 @@ print("Result z:\n", z)
print(f"Device of z: {z.device}")
```
## Documentation
Please refer to [this](https://docs.pytorch.org/docs/main/accelerator/index.html) for a series of documents on integrating new accelerators into PyTorch, which will be kept in sync with the `OpenReg` codebase as well.
## Future Plans
- **Enhance Features**:
- Autoload
- AMP
- Device-agnostic APIs
- Memory Management
- Generator
@ -180,5 +192,3 @@ print(f"Device of z: {z.device}")
- Custom Tensor&Storage
- ...
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.

View File

@ -0,0 +1,37 @@
#include <ATen/autocast_mode.h>
using at::Tensor;
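// binary_cross_entropy is deliberately rejected under autocast: the kernel below is
// registered for AutocastPrivateUse1 and always raises, steering users toward the
// *_with_logits variants, which are safe to autocast.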
Tensor binary_cross_entropy_banned(
    const Tensor&,
    const Tensor&,
    const std::optional<Tensor>&,
    int64_t) {
  TORCH_CHECK(
      false,
      "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
      "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
      "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
      "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
      "safe to autocast.");
}
// LITERALINCLUDE START: AMP FALLTHROUGH
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
// LITERALINCLUDE END: AMP FALLTHROUGH
// LITERALINCLUDE START: AMP IMPL
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // lower_precision_fp
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)

  // fp32
  KERNEL_PRIVATEUSEONE(asin, fp32)

  m.impl(
      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
      TORCH_FN((&binary_cross_entropy_banned)));
}
// LITERALINCLUDE END: AMP IMPL

View File

@ -0,0 +1,50 @@
# Owner(s): ["module: PrivateUse1"]
import torch
from torch.testing._internal.common_utils import run_tests, TestCase


class TestAutocast(TestCase):
    def test_autocast_with_unsupported_type(self):
        with self.assertWarnsRegex(
            UserWarning,
            "In openreg autocast, but the target dtype torch.float32 is not supported.",
        ):
            with torch.autocast(device_type="openreg", dtype=torch.float32):
                _ = torch.ones(10)

    def test_autocast_operator_not_supported(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.",
        ):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(2, 3, device="openreg")
            with torch.autocast(device_type="openreg", dtype=torch.float16):
                _ = torch.nn.functional.binary_cross_entropy(x, y)

    def test_autocast_low_precision(self):
        with torch.amp.autocast(device_type="openreg", dtype=torch.float16):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(3, 3, device="openreg")
            result = torch.mm(x, y)
            self.assertEqual(result.dtype, torch.float16)

    def test_autocast_fp32(self):
        with torch.amp.autocast(device_type="openreg"):
            x = torch.randn(2, device="openreg", dtype=torch.float16)
            result = torch.asin(x)
            self.assertEqual(result.dtype, torch.float32)

    def test_autocast_default_dtype(self):
        openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg")
        self.assertEqual(openreg_fast_dtype, torch.half)

    def test_autocast_set_dtype(self):
        for dtype in [torch.float16, torch.bfloat16]:
            torch.set_autocast_dtype("openreg", dtype)
            self.assertEqual(torch.get_autocast_dtype("openreg"), dtype)


if __name__ == "__main__":
    run_tests()

View File

@ -3,6 +3,7 @@ import torch
import torch_openreg._C # type: ignore[misc]
from . import meta # noqa: F401
from .amp import get_amp_supported_dtype # noqa: F401
_initialized = False

View File

@ -0,0 +1,9 @@
import torch
# LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]
# LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE