[OpenReg] Add AMP Integration guide for accelerators (#162050)

Fix part of #158917. Add an AMP integration document and use the OpenReg code as an example to explain the integration steps.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162050
Approved by: https://github.com/albanD
Co-authored-by: FFFrog <ljw1101.vip@gmail.com>
@@ -25,6 +25,8 @@ The goal of `torch_openreg` is **not to implement a fully functional, high-perfo
torch_openreg/
├── CMakeLists.txt
├── csrc
│   ├── amp
│   │   └── autocast_mode.cpp
│   ├── aten
│   │   ├── native
│   │   │   ├── Extra.cpp
@@ -59,6 +61,8 @@ torch_openreg/
│   └── stub.c
├── __init__.py
└── openreg
    ├── amp
    │   └── __init__.py
    ├── __init__.py
    ├── meta.py
    └── random.py
@@ -95,11 +99,12 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
**Key Directories**:

- `csrc/`: Core device implementation, including operator registration, runtime, etc.
  - `csrc/amp/`: AMP (Automatic Mixed Precision) integration.
  - `csrc/aten/`: Operator registration
    - `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
      - `csrc/aten/native/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
      - `csrc/aten/native/OpenRegExtra.cpp`: Implementations for other types of operators.
  - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
  - `torch_openreg/csrc/`: Python C++ binding code.
@@ -126,13 +131,18 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll

### Autoload

- Autoload Mechanism

  When `import torch` runs, installed accelerators (such as `torch_openreg`) are loaded automatically, providing the same experience as the built-in backends.

- Registering the backend with Python `entry points`: See `setup` in `setup.py` (a minimal sketch of this and of `_autoload` is shown below)
- Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
- Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
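
For orientation, here is a minimal sketch of the first two pieces. It is illustrative only: the `torch.backends` entry-point group name and the `_autoload` body are assumptions about the autoload mechanism, not a copy of the repository's actual `setup.py`.

```python
# setup.py (sketch): register an entry point so `import torch` can discover
# this backend. The group name "torch.backends" is an assumption here.
from setuptools import setup

setup(
    name="torch_openreg",
    entry_points={
        "torch.backends": [
            "torch_openreg = torch_openreg:_autoload",
        ],
    },
)
```

The callable referenced by the entry point can stay lightweight, because importing the package already pulls in the C extension that registers the backend:

```python
# torch_openreg/__init__.py (sketch): invoked by PyTorch after the package is
# imported; module-level imports (torch_openreg._C) have already registered
# the PrivateUse1 backend, so nothing extra is required in this sketch.
def _autoload():
    pass
```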
### AMP (Automatic Mixed Precision)

`AMP` provides convenience methods for mixed precision, where some operations use the `torch.float32` datatype while others use a lower-precision floating point datatype: `torch.float16` or `torch.bfloat16`.

- Register specific operator conversion rules: See `autocast_mode.cpp` in `csrc/amp`.
- Add support for new data types for different accelerators: See `get_amp_supported_dtype` in `torch_openreg/openreg/amp/__init__.py`

A short usage sketch is shown below.
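For illustration, a short usage sketch that mirrors the rules registered in `csrc/amp/autocast_mode.cpp` and the tests added in this PR (it assumes the extension is installed and an `openreg` device is available):

```python
import torch
import torch_openreg  # noqa: F401  # explicit import is unnecessary once autoload is enabled

x = torch.rand(2, 3, device="openreg")
y = torch.randn(3, 3, device="openreg")

with torch.autocast(device_type="openreg", dtype=torch.float16):
    out = torch.mm(x, y)   # `mm` is registered as lower_precision_fp -> runs in float16
    ang = torch.asin(x)    # `asin` is registered as fp32 -> runs in float32

print(out.dtype, ang.dtype)  # torch.float16 torch.float32
```

Operators not listed in either category fall through to their regular kernels, thanks to the `makeFallthrough()` registration in `autocast_mode.cpp`.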
## Installation and Usage
@@ -168,11 +178,13 @@ print("Result z:\n", z)
print(f"Device of z: {z.device}")
```

## Documentation

Please refer to [this](https://docs.pytorch.org/docs/main/accelerator/index.html) for a series of documents on integrating new accelerators into PyTorch, which will be kept in sync with the `OpenReg` codebase as well.

## Future Plans
- **Enhance Features**:
  - Autoload
  - AMP
  - Device-agnostic APIs
  - Memory Management
  - Generator
@@ -180,5 +192,3 @@ print(f"Device of z: {z.device}")
  - Custom Tensor & Storage
  - ...
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.
@@ -0,0 +1,37 @@
#include <ATen/autocast_mode.h>

using at::Tensor;

Tensor binary_cross_entropy_banned(
    const Tensor&,
    const Tensor&,
    const std::optional<Tensor>&,
    int64_t) {
  TORCH_CHECK(
      false,
      "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
      "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
      "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
      "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
      "safe to autocast.");
}

// LITERALINCLUDE START: AMP FALLTHROUTH
TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
// LITERALINCLUDE END: AMP FALLTHROUTH

// LITERALINCLUDE START: AMP IMPL
TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // lower_precision_fp
  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)

  // fp32
  KERNEL_PRIVATEUSEONE(asin, fp32)

  m.impl(
      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
      TORCH_FN((&binary_cross_entropy_banned)));
}
// LITERALINCLUDE END: AMP IMPL
@@ -0,0 +1,50 @@
# Owner(s): ["module: PrivateUse1"]

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


class TestAutocast(TestCase):
    def test_autocast_with_unsupported_type(self):
        with self.assertWarnsRegex(
            UserWarning,
            "In openreg autocast, but the target dtype torch.float32 is not supported.",
        ):
            with torch.autocast(device_type="openreg", dtype=torch.float32):
                _ = torch.ones(10)

    def test_autocast_operator_not_supported(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.",
        ):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(2, 3, device="openreg")
            with torch.autocast(device_type="openreg", dtype=torch.float16):
                _ = torch.nn.functional.binary_cross_entropy(x, y)

    def test_autocast_low_precision(self):
        with torch.amp.autocast(device_type="openreg", dtype=torch.float16):
            x = torch.randn(2, 3, device="openreg")
            y = torch.randn(3, 3, device="openreg")
            result = torch.mm(x, y)
            self.assertEqual(result.dtype, torch.float16)

    def test_autocast_fp32(self):
        with torch.amp.autocast(device_type="openreg"):
            x = torch.randn(2, device="openreg", dtype=torch.float16)
            result = torch.asin(x)
            self.assertEqual(result.dtype, torch.float32)

    def test_autocast_default_dtype(self):
        openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg")
        self.assertEqual(openreg_fast_dtype, torch.half)

    def test_autocast_set_dtype(self):
        for dtype in [torch.float16, torch.bfloat16]:
            torch.set_autocast_dtype("openreg", dtype)
            self.assertEqual(torch.get_autocast_dtype("openreg"), dtype)


if __name__ == "__main__":
    run_tests()
@@ -3,6 +3,7 @@ import torch
import torch_openreg._C  # type: ignore[misc]

from . import meta  # noqa: F401
from .amp import get_amp_supported_dtype  # noqa: F401


_initialized = False
@@ -0,0 +1,9 @@
import torch


# LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]


# LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
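For context, a sketch of how this hook surfaces to users, matching the warning asserted in the test file above (it assumes the extension is installed):

```python
import torch
import torch_openreg  # noqa: F401

# torch.float32 is not returned by get_amp_supported_dtype(), so entering
# autocast with it should emit a UserWarning ("In openreg autocast, but the
# target dtype torch.float32 is not supported.") and leave autocast
# effectively disabled for the region.
with torch.autocast(device_type="openreg", dtype=torch.float32):
    _ = torch.ones(10)
```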