[OpenReg] Add AMP Integration guide for accelerators (#162050)

Fix part of #158917

Add an AMP integration document, using the OpenReg code as an example to explain the integration steps.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162050
Approved by: https://github.com/albanD

Co-authored-by: FFFrog <ljw1101.vip@gmail.com>
zeshengzong
2025-09-30 12:27:08 +00:00
committed by PyTorch MergeBot
parent 7f29c47a4f
commit 77354e22e1
7 changed files with 190 additions and 10 deletions

View File

@ -25,6 +25,8 @@ The goal of `torch_openreg` is **not to implement a fully functional, high-perfo
torch_openreg/
├── CMakeLists.txt
├── csrc
│   ├── amp
│   │   └── autocast_mode.cpp
│   ├── aten
│   │   ├── native
│   │   │   ├── Extra.cpp
@ -59,6 +61,8 @@ torch_openreg/
│   └── stub.c
├── __init__.py
└── openreg
    ├── amp
    │   └── __init__.py
    ├── __init__.py
    ├── meta.py
    └── random.py
@ -95,11 +99,12 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
**Key Directories**:
- `csrc/`: Core device implementation, including operator registration, runtime, etc.
  - `csrc/amp/`: AMP (Automatic Mixed Precision) autocast kernel registrations.
  - `csrc/aten/`: Operator registration
    - `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
      - `csrc/aten/native/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
      - `csrc/aten/native/OpenRegExtra.cpp`: Implementations for other types of operators.
  - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
  - `torch_openreg/csrc/`: Python C++ binding code.
@ -126,13 +131,18 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
### Autoload

- Autoload Mechanism

  When `import torch`, installed accelerators (such as `torch_openreg`) will be automatically loaded, achieving the same experience as the built-in backends (a minimal sketch of the two registration pieces follows this list).

  - Registering the backend with Python `entry points`: See `setup` in `setup.py`
  - Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
  - Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
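
For illustration only, a minimal sketch of the two registration pieces, assuming the `torch.backends` entry-point group that PyTorch's autoload hook scans at `import torch`; the actual implementation lives in `setup.py` and `torch_openreg/__init__.py`:

```python
# setup.py (sketch): expose the backend through an entry point so that
# `import torch` can discover and initialize it without an explicit import.
from setuptools import setup

setup(
    name="torch_openreg",
    entry_points={
        "torch.backends": [
            "torch_openreg = torch_openreg:_autoload",
        ],
    },
)
```

```python
# torch_openreg/__init__.py (sketch): the initialization callable referenced above.
def _autoload():
    # Importing the package already registers the "openreg" device with PyTorch,
    # so the hook itself does not need to do any extra work here.
    pass
```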
### AMP (Automatic Mixed Precision)

`AMP` provides convenience methods for mixed precision, where some operations use the `torch.float32` datatype and other operations use a lower-precision floating point datatype, `torch.float16` or `torch.bfloat16` (a usage sketch follows the list below).

- Registering conversion rules for specific operators: See `autocast_mode.cpp` in `csrc/amp`.
- Adding the supported lower-precision data types for the accelerator: See `get_amp_supported_dtype` in `torch_openreg/openreg/amp/__init__.py`.
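
For illustration, a short end-user sketch mirroring the autocast tests added in this PR: `mm` is registered under the `lower_precision_fp` rule and `asin` under the `fp32` rule, so their outputs come back as `torch.float16` and `torch.float32` respectively.

```python
import torch
import torch_openreg  # optional once autoload is enabled

x = torch.randn(2, 3, device="openreg")
y = torch.randn(3, 3, device="openreg")

with torch.autocast(device_type="openreg", dtype=torch.float16):
    out = torch.mm(x, y)   # lower_precision_fp rule -> computed in float16
    val = torch.asin(out)  # fp32 rule -> inputs cast up, result is float32

print(out.dtype, val.dtype)  # torch.float16 torch.float32
```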
## Installation and Usage
@ -168,11 +178,13 @@ print("Result z:\n", z)
print(f"Device of z: {z.device}")
```
## Documentation
Please refer to [this](https://docs.pytorch.org/docs/main/accelerator/index.html) for a series of documents on integrating new accelerators into PyTorch, which will be kept in sync with the `OpenReg` codebase as well.
## Future Plans
- **Enhance Features**:
  - Autoload
  - AMP
  - Device-agnostic APIs
  - Memory Management
  - Generator
@ -180,5 +192,3 @@ print(f"Device of z: {z.device}")
  - Custom Tensor&Storage
  - ...
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.

View File

@ -0,0 +1,37 @@
#include <ATen/autocast_mode.h>
using at::Tensor;
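// Some operators are deliberately banned under autocast: route them to a
// function that raises a descriptive error (mirroring the in-tree CUDA
// autocast handling of binary_cross_entropy).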
Tensor binary_cross_entropy_banned(
    const Tensor&,
    const Tensor&,
    const std::optional<Tensor>&,
    int64_t) {
  TORCH_CHECK(
      false,
      "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
      "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
      "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
      "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
      "safe to autocast.");
}
// LITERALINCLUDE START: AMP FALLTHROUTH
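// Any op without an explicit autocast rule below falls through to its regular kernel.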
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
// LITERALINCLUDE END: AMP FALLTHROUTH
// LITERALINCLUDE START: AMP IMPL
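// KERNEL_PRIVATEUSEONE registers an autocast wrapper for the given op on the
// AutocastPrivateUse1 key, casting inputs to the listed precision category
// before dispatching to the regular kernel.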
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // lower_precision_fp
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)

  // fp32
  KERNEL_PRIVATEUSEONE(asin, fp32)

  m.impl(
      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
      TORCH_FN((&binary_cross_entropy_banned)));
}
// LITERALINCLUDE END: AMP IMPL

View File

@ -0,0 +1,50 @@
# Owner(s): ["module: PrivateUse1"]

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


class TestAutocast(TestCase):
    def test_autocast_with_unsupported_type(self):
        with self.assertWarnsRegex(
            UserWarning,
            "In openreg autocast, but the target dtype torch.float32 is not supported.",
        ):
            with torch.autocast(device_type="openreg", dtype=torch.float32):
                _ = torch.ones(10)

    def test_autocast_operator_not_supported(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.",
        ):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(2, 3, device="openreg")
            with torch.autocast(device_type="openreg", dtype=torch.float16):
                _ = torch.nn.functional.binary_cross_entropy(x, y)

    def test_autocast_low_precision(self):
        with torch.amp.autocast(device_type="openreg", dtype=torch.float16):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(3, 3, device="openreg")
            result = torch.mm(x, y)
            self.assertEqual(result.dtype, torch.float16)

    def test_autocast_fp32(self):
        with torch.amp.autocast(device_type="openreg"):
            x = torch.randn(2, device="openreg", dtype=torch.float16)
            result = torch.asin(x)
            self.assertEqual(result.dtype, torch.float32)

    def test_autocast_default_dtype(self):
        openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg")
        self.assertEqual(openreg_fast_dtype, torch.half)

    def test_autocast_set_dtype(self):
        for dtype in [torch.float16, torch.bfloat16]:
            torch.set_autocast_dtype("openreg", dtype)
            self.assertEqual(torch.get_autocast_dtype("openreg"), dtype)


if __name__ == "__main__":
    run_tests()

View File

@ -3,6 +3,7 @@ import torch
import torch_openreg._C # type: ignore[misc]
from . import meta # noqa: F401
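# `get_amp_supported_dtype` is re-exported here so that torch.autocast can query
# the backend's supported low-precision dtypes through this device module.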
from .amp import get_amp_supported_dtype # noqa: F401
_initialized = False

View File

@ -0,0 +1,9 @@
import torch
# LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
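# Queried by torch.autocast to validate the requested dtype for this backend;
# an unsupported dtype only triggers a warning (see the autocast tests above).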
def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]
# LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE