diff --git a/docs/source/accelerator/amp.md b/docs/source/accelerator/amp.md
new file mode 100644
index 000000000000..ac78436f56a5
--- /dev/null
+++ b/docs/source/accelerator/amp.md
@@ -0,0 +1,72 @@
+# Automatic Mixed Precision
+
+## Background
+
+Automatic Mixed Precision (AMP) enables the use of both single precision (32-bit) and half precision (16-bit) floating point types during training or inference.
+
+Key components include:
+
+- [**Autocast**](https://docs.pytorch.org/docs/stable/amp.html#autocasting): Automatically casts operations to lower precision (e.g., float16 or bfloat16) to improve performance while maintaining accuracy.
+- [**Gradient Scaling**](https://docs.pytorch.org/docs/stable/amp.html#gradient-scaling): Dynamically scales gradients during backpropagation to prevent underflow when training with mixed precision.
+
+## Design
+
+### Casting Strategy
+
+The [`CastPolicy`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L416-L438) enum defines the type conversion rules. Each enum value represents a set of type conversion requirements for a group of operators, ensuring consistent handling of operations that prioritize either precision or performance.
+
+| Policy | Explanation |
+| :--- | :--- |
+| **`lower_precision_fp`** | Cast all inputs to `lower_precision_fp` before running the op. |
+| **`fp32`** | Cast all inputs to `at::kFloat` before running the op. |
+| **`fp32_set_opt_dtype`** | Execute in `at::kFloat`, while respecting a user-specified output dtype if provided. |
+| **`fp32_append_dtype`** | Append `at::kFloat` to the arguments and redispatch to the type-aware overload. |
+| **`promote`** | Promote all inputs to the “widest” dtype before execution. |
+
+### Operator Lists
+
+PyTorch defines a general list of operators for each of the casting strategies mentioned above, as a reference for developers of new accelerators.
+
+| Policy | Operator List |
+| :--- | :--- |
+| **`lower_precision_fp`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L819-L852) |
+| **`fp32`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L854-L912) |
+| **`fp32_set_opt_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L914-L931) |
+| **`fp32_append_dtype`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L933-L958) |
+| **`promote`** | [List Link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.h#L960-L971) |
+
+## Implementation
+
+### Python Integration
+
+Implement the `get_amp_supported_dtype` method to return the data types supported by the new accelerator in the AMP context.
+
+```{eval-rst}
+.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
+    :language: python
+    :start-after: LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
+    :end-before: LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE
+    :linenos:
+```
+
+### C++ Integration
+
+This section shows how AMP registers autocast kernels for the `AutocastPrivateUse1` dispatch key:
+
+- Register a fallback that makes unhandled ops fall through to their normal implementations.
+- Register specific aten kernels under `AutocastPrivateUse1` using the `KERNEL_PRIVATEUSEONE` helper macro, which maps an op to the desired cast policy (an `at::autocast::CastPolicy` enum value).
+
+```{eval-rst}
+.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
+    :language: c++
+    :start-after: LITERALINCLUDE START: AMP FALLTHROUGH
+    :end-before: LITERALINCLUDE END: AMP FALLTHROUGH
+    :linenos:
+
+.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
+    :language: c++
+    :start-after: LITERALINCLUDE START: AMP IMPL
+    :end-before: LITERALINCLUDE END: AMP IMPL
+    :emphasize-lines: 3,6,8-10
+    :linenos:
+```
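+
+With these kernels registered, the user-visible behavior on the OpenReg reference backend looks as follows. This is a usage sketch adapted from OpenReg's `tests/test_autocast.py`; the exact operator coverage depends on which kernels a backend chooses to register.
+
+```python
+import torch
+
+x = torch.randn(2, 3, device="openreg")
+y = torch.randn(3, 3, device="openreg")
+
+with torch.autocast(device_type="openreg", dtype=torch.float16):
+    print(torch.mm(x, y).dtype)  # torch.float16: mm is registered as lower_precision_fp
+    print(torch.asin(x).dtype)   # torch.float32: asin is registered as fp32
+    # aten::binary_cross_entropy is registered with a kernel that raises an error,
+    # pointing users to binary_cross_entropy_with_logits instead.
+```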
diff --git a/docs/source/accelerator/index.md b/docs/source/accelerator/index.md
index 70f25812bb9e..3e8e5c895699 100644
--- a/docs/source/accelerator/index.md
+++ b/docs/source/accelerator/index.md
@@ -44,6 +44,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
 
 autoload
 operators
+amp
 ```
 
 [OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL"
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/README.md b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md
index fd3eaf649abe..7f0e037fab27 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/README.md
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md
@@ -25,6 +25,8 @@ The goal of `torch_openreg` is **not to implement a fully functional, high-perfo
 torch_openreg/
 ├── CMakeLists.txt
 ├── csrc
+│   ├── amp
+│   │   └── autocast_mode.cpp
 │   ├── aten
 │   │   ├── native
 │   │   │   ├── Extra.cpp
@@ -59,6 +61,8 @@ torch_openreg/
 │   └── stub.c
 ├── __init__.py
 └── openreg
+    ├── amp
+    │   └── __init__.py
     ├── __init__.py
     ├── meta.py
     └── random.py
@@ -95,11 +99,12 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
 **Key Directories**:
 
 - `csrc/`: Core device implementation, including operator registration, runtime, etc.
+  - `csrc/amp/`: AMP (Automatic Mixed Precision) autocast kernel registration
   - `csrc/aten/`: Operator registration
     - `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
       - `csrc/aten/native/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
      - `csrc/aten/native/OpenRegExtra.cpp`: Implementations for other types of operators.
-  - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc. 
+  - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
 - `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
 - `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
   - `torch_openreg/csrc/`: Python C++ binding code.
@@ -126,13 +131,18 @@ There are 4 DSOs in torch_openreg, and the dependencies between them are as foll
 
 ### Autoload
 
-- Autoload Machanism
+When `import torch` is executed, installed accelerators (such as `torch_openreg`) will be automatically loaded, achieving the same experience as the built-in backends.
 
-  When `import torch`, installed accelerators (such as `torch_openreg`) will be automatically loaded, achieving the same experience as the built-in backends.
+- Register the backend with Python `entry points`: See `setup` in `setup.py` (a minimal sketch follows below)
+- Add a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
+- Dynamically load the backend without explicit imports: See [Usage Example](#usage-example)
 
-  - Registering the backend with Python `entry points`: See `setup` in `setup.py`
-  - Adding a callable function for backend initialization: See `_autoload` in `torch_openreg/__init__.py`
-  - Dynamically loading the backend without explicit imports: See [Usage Example](#usage-example)
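+
+A minimal sketch of the entry-point declaration (assuming the `torch.backends` entry-point group used by PyTorch's autoload mechanism; the authoritative configuration is the actual `setup()` call in `setup.py`):
+
+```python
+# setup.py (sketch): make the backend discoverable by torch's autoload mechanism
+from setuptools import setup
+
+setup(
+    name="torch_openreg",
+    entry_points={
+        "torch.backends": [
+            "torch_openreg = torch_openreg:_autoload",
+        ],
+    },
+)
+```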
+### AMP (Automatic Mixed Precision)
+
+`AMP` provides convenience methods for mixed precision, in which some operations use the `torch.float32` datatype and other operations use a lower-precision floating point datatype: `torch.float16` or `torch.bfloat16`.
+
+- Register specific operator conversion rules: See `autocast_mode.cpp` in `csrc/amp`.
+- Add support for new data types for different accelerators: See `get_amp_supported_dtype` in `torch_openreg/openreg/amp/__init__.py`.
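+
+For example (mirroring `tests/test_autocast.py`), the default autocast dtype and the handling of an unsupported dtype look like this:
+
+```python
+import torch
+
+# float16 and bfloat16 are reported by get_amp_supported_dtype();
+# float16 is the default autocast dtype for this backend.
+print(torch.get_autocast_dtype("openreg"))  # torch.float16
+
+# Requesting a dtype outside the supported list only emits a warning.
+with torch.autocast(device_type="openreg", dtype=torch.float32):
+    ...
+```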
 
 ## Installation and Usage
@@ -168,11 +178,13 @@ print("Result z:\n", z)
 print(f"Device of z: {z.device}")
 ```
 
+## Documentation
+
+Please refer to [the accelerator integration guide](https://docs.pytorch.org/docs/main/accelerator/index.html) for a series of documents on integrating new accelerators into PyTorch; it will be kept in sync with the `OpenReg` codebase.
+
 ## Future Plans
 
 - **Enhance Features**:
-  - Autoload
-  - AMP
   - Device-agnostic APIs
   - Memory Management
   - Generator
@@ -180,5 +192,3 @@ print(f"Device of z: {z.device}")
   - Custom Tensor&Storage
   - ...
 - **Improve Tests**: Add more test cases related to the integration mechanism.
-- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
-- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
new file mode 100644
index 000000000000..129e4beabce1
--- /dev/null
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/amp/autocast_mode.cpp
@@ -0,0 +1,37 @@
+#include <ATen/autocast_mode.h>
+
+using at::Tensor;
+
+Tensor binary_cross_entropy_banned(
+    const Tensor&,
+    const Tensor&,
+    const std::optional<Tensor>&,
+    int64_t) {
+  TORCH_CHECK(
+      false,
+      "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
+      "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
+      "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
+      "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
+      "safe to autocast.");
+}
+
+// LITERALINCLUDE START: AMP FALLTHROUGH
+TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
+  m.fallback(torch::CppFunction::makeFallthrough());
+}
+// LITERALINCLUDE END: AMP FALLTHROUGH
+
+// LITERALINCLUDE START: AMP IMPL
+TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
+  // lower_precision_fp
+  KERNEL_PRIVATEUSEONE(mm, lower_precision_fp)
+
+  // fp32
+  KERNEL_PRIVATEUSEONE(asin, fp32)
+
+  m.impl(
+      TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
+      TORCH_FN((&binary_cross_entropy_banned)));
+}
+// LITERALINCLUDE END: AMP IMPL
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py
new file mode 100644
index 000000000000..01423741660d
--- /dev/null
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py
@@ -0,0 +1,50 @@
+# Owner(s): ["module: PrivateUse1"]
+
+import torch
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+
+class TestAutocast(TestCase):
+    def test_autocast_with_unsupported_type(self):
+        with self.assertWarnsRegex(
+            UserWarning,
+            "In openreg autocast, but the target dtype torch.float32 is not supported.",
+        ):
+            with torch.autocast(device_type="openreg", dtype=torch.float32):
+                _ = torch.ones(10)
+
+    def test_autocast_operator_not_supported(self):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.",
+        ):
+            x = torch.randn(2, 3, device="openreg")
+            y = torch.randn(2, 3, device="openreg")
+            with torch.autocast(device_type="openreg", dtype=torch.float16):
+                _ = torch.nn.functional.binary_cross_entropy(x, y)
+
+    def test_autocast_low_precision(self):
+        with torch.amp.autocast(device_type="openreg", dtype=torch.float16):
+            x = torch.randn(2, 3, device="openreg")
+            y = torch.randn(3, 3, device="openreg")
+            result = torch.mm(x, y)
+            self.assertEqual(result.dtype, torch.float16)
+
+    def test_autocast_fp32(self):
+        with torch.amp.autocast(device_type="openreg"):
+            x = torch.randn(2, device="openreg", dtype=torch.float16)
+            result = torch.asin(x)
+            self.assertEqual(result.dtype, torch.float32)
+
+    def test_autocast_default_dtype(self):
+        openreg_fast_dtype = torch.get_autocast_dtype(device_type="openreg")
+        self.assertEqual(openreg_fast_dtype, torch.half)
+
+    def test_autocast_set_dtype(self):
+        for dtype in [torch.float16, torch.bfloat16]:
+            torch.set_autocast_dtype("openreg", dtype)
+            self.assertEqual(torch.get_autocast_dtype("openreg"), dtype)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
index 670f54245fb0..7c8712666a21 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
@@ -3,6 +3,7 @@
 import torch
 import torch_openreg._C  # type: ignore[misc]
 from . import meta  # noqa: F401
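+# Re-export the AMP hook so torch.autocast can query this backend's supported dtypes.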
+from .amp import get_amp_supported_dtype  # noqa: F401
 
 _initialized = False
 
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
new file mode 100644
index 000000000000..02775c54df7f
--- /dev/null
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/amp/__init__.py
@@ -0,0 +1,9 @@
+import torch
+
+
+# LITERALINCLUDE START: AMP GET_SUPPORTED_DTYPE
+def get_amp_supported_dtype():
+    return [torch.float16, torch.bfloat16]
+
+
+# LITERALINCLUDE END: AMP GET_SUPPORTED_DTYPE