Compare commits

...

4 Commits

Author SHA1 Message Date
185b311b1c Automated submodule update: FBGEMM 2025-11-13 00:50:11 -08:00
ce4f31f662 [OpenReg][Feat][Docs] Enrich hook implementation and add focused documentation (#165980)
## Summary
This PR enriches the implementation of `OpenRegHooks.h` and adds focused documentation for `OpenReg` hooks.

## Key Changes
- A new document: `docs/source/accelerator/hooks.md`
- New `OpenReg` hooks like `isBuilt()`, `isAvailable()` and so on...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165980
Approved by: https://github.com/fffrog

Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
2025-11-13 08:36:18 +00:00
2c846bb614 [xpu][test] Port embedding, indexing, and native_mha test files for Intel GPU (#165886)
We port test_indexing, test_native_mha, and test_embedding for Intel GPU in this PR.
We enable Intel GPU with the following methods while trying our best to keep the original code style:

Use torch.accelerator for generic GPU support.
Skip cases with known issues when running on XPU.
Use torch.nn.attention.sdpa_kernel() to replace torch.backends.cuda.sdp_kernel() for Intel GPU, since torch.backends.cuda.sdp_kernel() is deprecated and XPU does not support it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165886
Approved by: https://github.com/guangyey, https://github.com/albanD
2025-11-13 08:17:23 +00:00
8c86ccfbc9 [DebugMode] .show_stack_trace inline (#167589)
Shows inline stack traces with `.debug_string(show_stack_trace=True)`. For backward ops, we use `.fwd_stack_trace` when available.

Needs some improvement for:
- backward: not all dispatch calls run under an autograd node, so some only have generic traces (e.g. `loss.backward()`)
- compiled regions: the stack trace isn't very meaningful to begin with (e.g. it points to a codegened line)

Sample output for test_nn_module (forward + backward):
```
    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])
    aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:405 in forward, code: return self.xyz(self.abc(x))
    aten::t(t: f32[4, 4])
    aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:429 in test_nn_module, code: out = mod(inp).sum()
    aten::sum(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:429 in test_nn_module, code: out = mod(inp).sum()
    aten::expand(t: f32[], [4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:405 in forward, code: return self.xyz(self.abc(x))
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::sum.dim_IntList(t: f32[4, 4], [0], True)
    aten::view(t: f32[1, 4], [4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:405 in forward, code: return self.xyz(self.abc(x))
    aten::t(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::sum.dim_IntList(t: f32[4, 4], [0], True)
    aten::view(t: f32[1, 4], [4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])
    aten::mm(t: f32[4, 4], t: f32[4, 4])
    aten::t(t: f32[4, 4])
    aten::sum.dim_IntList(t: f32[4, 4], [0], True)
    aten::view(t: f32[1, 4], [4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:396 in forward, code: return self.l2(self.l1(x))
    aten::t(t: f32[4, 4])

    # File: /data/users/pianpwk/pytorch/test/distributed/tensor/debug/test_debug_mode.py:430 in test_nn_module, code: out.backward()
    aten::detach(t: f32[4, 4])
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167589
Approved by: https://github.com/yushangdi
2025-11-13 08:15:27 +00:00
14 changed files with 447 additions and 61 deletions

View File

@ -0,0 +1,164 @@
# Accelerator Hooks
## Background
OpenReg hooks provide a mechanism for integrating custom accelerator devices into PyTorch's runtime system. OpenReg (Open Registration) is PyTorch's extensibility framework that allows accelerator vendors to register custom device backends without modifying PyTorch core code.
## Design
The following tables list all hooks that accelerator vendors need to implement when integrating a new device backend. These hooks are categorized into two priority levels:
- **High Priority Hooks**: Core APIs that the PyTorch runtime directly depends on. Accelerator vendors should implement all high-priority hooks to ensure full PyTorch compatibility and basic device functionality.
- **Low Priority Hooks**: Device-management and utility APIs that PyTorch does not directly depend on. These hooks improve the user experience and multi-device support but are *optional*; vendors can implement them based on their specific requirements and use cases.
### High Priority Hooks
| Hook Method | Description | Application Scenario |
| ---------------------------------- | --------------------------------------------------------- | -------------------------------------------------------------------------------- |
| `init()` | Initializes the accelerator runtime and device contexts | Set up necessary state when PyTorch first accesses the device |
| `hasPrimaryContext(DeviceIndex)` | Checks if a primary context exists for the device | Determine whether device initialization has occurred |
| `getDefaultGenerator(DeviceIndex)` | Returns the default random number generator for a device | Access the device's primary RNG for reproducible random operations |
| `getNewGenerator(DeviceIndex)` | Creates a new independent random number generator | Create isolated RNG instances for parallel operations |
| `getDeviceFromPtr(void*)` | Determines which device a memory pointer belongs to | Identify the accelerator device associated with a memory allocation |
| `getPinnedMemoryAllocator()` | Returns an allocator for pinned (page-locked) host memory | Allocate host memory that can be efficiently transferred to/from the accelerator |
| `isPinnedPtr(void*)` | Checks if a pointer points to pinned memory | Validate memory types before performing operations |
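For a flavor of the memory-related entries in this table, here is a minimal sketch of `getDeviceFromPtr` and `getPinnedMemoryAllocator`. The `MyBackendHooks` struct, `my_pointer_info()`, and `my_pinned_allocator()` are hypothetical stand-ins for a vendor's own types and runtime API; the real OpenReg versions live in `OpenRegHooks.h`.
```c++
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/util/Exception.h>

// Hypothetical vendor runtime API -- replace with your backend's own calls.
struct MyPtrInfo {
  bool on_device;    // true if the pointer lives in accelerator memory
  int device_index;  // which accelerator owns it
};
MyPtrInfo my_pointer_info(void* ptr);
at::Allocator* my_pinned_allocator();

struct MyBackendHooks : public at::PrivateUse1HooksInterface {
  // Map a raw pointer back to the accelerator device that owns it.
  at::Device getDeviceFromPtr(void* data) const override {
    auto info = my_pointer_info(data);
    TORCH_CHECK(info.on_device, "pointer does not belong to an accelerator device");
    return at::Device(at::DeviceType::PrivateUse1,
                      static_cast<c10::DeviceIndex>(info.device_index));
  }

  // Host allocator whose memory can be transferred efficiently to the device.
  at::Allocator* getPinnedMemoryAllocator() const override {
    return my_pinned_allocator();
  }
};
```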
### Low Priority Hooks
| Hook Method | Description | Application Scenario |
| ---------------------------------- | ---------------------------------------------------------------------------- | -------------------------------------------------------------------- |
| `isBuilt()` | Returns whether the accelerator backend is built/compiled into the extension | Check whether the accelerator library is available at compile time |
| `isAvailable()` | Returns whether the accelerator hardware is available at runtime | Verify whether accelerator devices can be detected and initialized |
| `deviceCount()` | Returns the number of available accelerator devices | Enumerate all available accelerator devices for device selection |
| `setCurrentDevice(DeviceIndex)` | Sets the active device for the current thread | Switch the current thread's context to a specific accelerator device |
| `getCurrentDevice()` | Returns the currently active device index | Query which accelerator device is active in the current thread |
| `exchangeDevice(DeviceIndex)` | Atomically exchanges the current device and returns the previous one | Temporarily switch devices and restore the previous device afterward |
| `maybeExchangeDevice(DeviceIndex)` | Conditionally exchanges device only if the index is valid | Safely attempt device switching with validation |
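To make this table concrete, the following is a minimal sketch of the device-management hooks, assuming hypothetical vendor runtime calls `my_device_count()`, `my_set_device()`, and `my_current_device()`; again, the real OpenReg implementation lives in `OpenRegHooks.h`.
```c++
#include <ATen/detail/PrivateUse1HooksInterface.h>

// Hypothetical vendor runtime API -- replace with your backend's own calls.
int my_device_count();
void my_set_device(int device);
int my_current_device();

struct MyBackendHooks : public at::PrivateUse1HooksInterface {
  bool isBuilt() const override {
    return true;  // this backend is compiled into the extension
  }
  bool isAvailable() const override {
    return my_device_count() > 0;  // at least one device detected at runtime
  }
  c10::DeviceIndex deviceCount() const override {
    return static_cast<c10::DeviceIndex>(my_device_count());
  }
  void setCurrentDevice(c10::DeviceIndex device) const override {
    my_set_device(device);
  }
  c10::DeviceIndex getCurrentDevice() const override {
    return static_cast<c10::DeviceIndex>(my_current_device());
  }
  c10::DeviceIndex exchangeDevice(c10::DeviceIndex device) const override {
    // A real backend would perform this swap atomically in its runtime.
    auto prev = getCurrentDevice();
    setCurrentDevice(device);
    return prev;
  }
  c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device) const override {
    // No-op (and report the current device) when the index is out of range.
    if (device < 0 || device >= deviceCount()) {
      return getCurrentDevice();
    }
    return exchangeDevice(device);
  }
};
```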
## Implementation
Let's take `getDefaultGenerator` as an implementation example:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
:linenos:
```
In this implementation:
1. **Override the base interface**: The `getDefaultGenerator` method overrides the virtual method from `at::PrivateUse1HooksInterface`.
2. **Delegate to device-specific implementation**: It calls `getDefaultOpenRegGenerator(device_index)`, which manages a per-device generator instance.
3. **Return device-specific generator**: The returned `at::Generator` wraps an `OpenRegGeneratorImpl` that implements device-specific random number generation.
This pattern applies to all hooks: override the interface method, validate inputs, delegate to your device-specific API, and return results in PyTorch's expected format.
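A skeletal illustration of those four steps, using the same hypothetical `MyBackendHooks` backend as above; `getDefaultMyBackendGenerator` and `my_device_count()` are assumed vendor pieces, not real OpenReg symbols.
```c++
#include <ATen/core/Generator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/util/Exception.h>

int my_device_count();                                                    // assumed vendor API
const at::Generator& getDefaultMyBackendGenerator(c10::DeviceIndex idx);  // assumed helper (see Layer 6 below)

struct MyBackendHooks : public at::PrivateUse1HooksInterface {
  // 1. Override the base interface method.
  const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) const override {
    // 2. Validate inputs.
    TORCH_CHECK(
        device_index >= 0 && device_index < my_device_count(),
        "invalid device index: ", device_index);
    // 3. Delegate to the device-specific API, and
    // 4. return the result in PyTorch's expected type (at::Generator).
    return getDefaultMyBackendGenerator(device_index);
  }
};
```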
## Integration Example
The following sections demonstrate how PyTorch integrates with accelerator hooks when accessing the default random number generator. The example traces the complete flow from user-facing Python code down to the device-specific implementation.
### Layer 1: User Code
User code initiates the operation by calling `manual_seed` to set the random seed for reproducible results:
```python
import torch
torch.openreg.manual_seed(42)
```
### Layer 2: Extension Python API
The Python API layer handles device management and calls into the C++ extension (defined in [`torch_openreg/openreg/random.py`][random.py]):
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py
:language: python
:start-after: LITERALINCLUDE START: OPENREG MANUAL SEED
:end-before: LITERALINCLUDE END: OPENREG MANUAL SEED
:linenos:
```
The `manual_seed` function gets the current device index and calls `torch_openreg._C._get_default_generator(idx)` to obtain the device-specific generator, then sets the seed on it.
### Layer 3: Python/C++ Bridge
The C++ extension exposes `_getDefaultGenerator` to Python, which bridges to PyTorch's core runtime:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
:linenos:
:emphasize-lines: 10-11
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 3
```
This function unpacks the device index from Python, creates a `PrivateUse1` device object, and calls `at::globalContext().defaultGenerator()`. PyTorch's context then dispatches to the registered hooks.
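The heart of that bridge, stripped of the Python argument handling, looks roughly like this sketch (not the actual `Module.cpp` code):
```c++
#include <ATen/Context.h>
#include <ATen/core/Generator.h>
#include <c10/core/Device.h>

// Given a device index already unpacked from the Python argument...
at::Generator get_default_generator_for(int64_t index) {
  // ...build a PrivateUse1 device object for that index...
  c10::Device device(c10::DeviceType::PrivateUse1,
                     static_cast<c10::DeviceIndex>(index));
  // ...and ask the global context for its default generator. The context
  // forwards the call to whichever hooks were registered for PrivateUse1.
  return at::globalContext().defaultGenerator(device);
}
```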
### Layer 4: PyTorch Core Context
PyTorch's Context class dispatches to the appropriate accelerator hooks ([`aten/src/ATen/Context.h`][Context.h]):
```{eval-rst}
.. literalinclude:: ../../../aten/src/ATen/Context.h
:language: c++
:lines: 60-103
:linenos:
:emphasize-lines: 8-9, 24-25
```
This layered architecture enables PyTorch to remain device-agnostic while delegating hardware-specific operations to accelerator implementations. The hooks are registered once at module load time:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK REGISTER
:end-before: LITERALINCLUDE END: OPENREG HOOK REGISTER
:linenos:
:emphasize-lines: 4
```
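A condensed sketch of that load-time registration idiom, assuming the hooks are handed to PyTorch through `at::RegisterPrivateUse1HooksInterface` and reusing the hypothetical `MyBackendHooks` from earlier:
```c++
#include <ATen/detail/PrivateUse1HooksInterface.h>

struct MyBackendHooks : public at::PrivateUse1HooksInterface {
  // Minimal override so the sketch stands alone; real backends implement
  // the hooks from the tables above.
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    return true;
  }
};

// Runs when the extension's shared library is loaded: hand PyTorch a
// process-lifetime hooks instance for the PrivateUse1 device type.
static bool register_hook_flag [[maybe_unused]] = []() {
  at::RegisterPrivateUse1HooksInterface(new MyBackendHooks());
  return true;
}();
```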
### Layer 5: Accelerator Hooks
The hooks interface provides the abstraction that PyTorch uses to delegate to device-specific implementations:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG HOOK EXAMPLES
:end-before: LITERALINCLUDE END: OPENREG HOOK EXAMPLES
:linenos:
```
The `getDefaultGenerator` hook method overrides the base interface and delegates to `getDefaultOpenRegGenerator`, which manages the actual generator instances.
### Layer 6: Device-Specific Implementation
The device-specific implementation manages per-device generator instances:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
:end-before: LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
:linenos:
```
This function maintains a static vector of generators (one per device), initializes them on first access, validates the device index, and returns the appropriate generator instance.
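A condensed sketch of a registry with that shape, where `make_my_generator()` stands in for the backend's generator factory and `my_device_count()` for its device query (both hypothetical, not OpenReg symbols):
```c++
#include <ATen/core/Generator.h>
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <vector>

int my_device_count();                                  // assumed vendor API
at::Generator make_my_generator(c10::DeviceIndex idx);  // assumed factory around your GeneratorImpl

const at::Generator& getDefaultMyBackendGenerator(c10::DeviceIndex device_index) {
  // One default generator per device, created lazily on first access and
  // kept alive for the lifetime of the process.
  static std::vector<at::Generator> default_generators = []() {
    std::vector<at::Generator> gens;
    const auto count = my_device_count();
    gens.reserve(count);
    for (int i = 0; i < count; ++i) {
      gens.push_back(make_my_generator(static_cast<c10::DeviceIndex>(i)));
    }
    return gens;
  }();
  // Validate the index before handing back a reference.
  TORCH_CHECK(
      device_index >= 0 &&
          static_cast<size_t>(device_index) < default_generators.size(),
      "invalid device index: ", device_index);
  return default_generators[device_index];
}
```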
[random.py]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py#L48-L53 "random.py"
[Context.h]: https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/Context.h#L61-L102 "Context.h"

View File

@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
hooks
autoload
operators
amp

View File

@ -5,6 +5,7 @@ static std::vector<at::Generator> default_generators;
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR IMPL
const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
static bool flag [[maybe_unused]] = []() {
auto deivce_nums = device_count();
@ -24,5 +25,6 @@ const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
}
return default_generators[idx];
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR IMPL
} // namespace c10::openreg

View File

@ -1,5 +1,6 @@
#include "OpenRegHooks.h"
// LITERALINCLUDE START: OPENREG HOOK REGISTER
namespace c10::openreg {
static bool register_hook_flag [[maybe_unused]] = []() {
@ -9,3 +10,4 @@ static bool register_hook_flag [[maybe_unused]] = []() {
}();
} // namespace c10::openreg
// LITERALINCLUDE END: OPENREG HOOK REGISTER

View File

@ -8,17 +8,58 @@
#include <include/openreg.h>
#include "OpenRegFunctions.h"
#include "OpenRegGenerator.h"
namespace c10::openreg {
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
struct OPENREG_EXPORT OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
OpenRegHooksInterface() {};
~OpenRegHooksInterface() override = default;
bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
void init() const override {
// Initialize OpenReg runtime if needed
// This is called when PyTorch first accesses the device
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
return true;
}
bool isBuilt() const override {
// This extension is compiled as part of the OpenReg test extension.
return true;
}
bool isAvailable() const override {
// Consider OpenReg available if there's at least one device reported.
return device_count() > 0;
}
DeviceIndex deviceCount() const override {
return device_count();
}
void setCurrentDevice(DeviceIndex device) const override {
set_device(device);
}
DeviceIndex getCurrentDevice() const override {
return current_device();
}
DeviceIndex exchangeDevice(DeviceIndex device) const override {
return ExchangeDevice(device);
}
DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
// Only exchange if the requested device is valid; otherwise, no-op and return current
auto count = device_count();
if (device < 0 || device >= count) {
return getCurrentDevice();
}
return exchangeDevice(device);
}
at::Allocator* getPinnedMemoryAllocator() const override {
return at::getHostAllocator(at::kPrivateUse1);
}
@ -30,12 +71,23 @@ struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
return attr.type == orMemoryTypeHost;
}
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
at::Device getDeviceFromPtr(void* data) const override {
orPointerAttributes attr{};
auto err = orPointerGetAttributes(&attr, data);
if (err == orSuccess && attr.type == orMemoryTypeDevice) {
return at::Device(at::DeviceType::PrivateUse1, static_cast<int>(attr.device));
} else {
TORCH_CHECK(false, "failed to get device from pointer");
}
return at::Device(at::DeviceType::PrivateUse1, current_device());
}
// LITERALINCLUDE START: OPENREG HOOK EXAMPLES
const at::Generator& getDefaultGenerator(DeviceIndex device_index) const override {
return getDefaultOpenRegGenerator(device_index);
}
// LITERALINCLUDE END: OPENREG HOOK EXAMPLES
at::Generator getNewGenerator(c10::DeviceIndex device_index) const override {
at::Generator getNewGenerator(DeviceIndex device_index) const override {
return at::make_generator<OpenRegGeneratorImpl>(device_index);
}
};

View File

@ -17,6 +17,7 @@ static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE START: OPENREG GET DEFAULT GENERATOR
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
@ -31,6 +32,7 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
PyObject* _setDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
@ -73,6 +75,7 @@ PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) {
END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE START: OPENREG MODULE METHODS
static PyMethodDef methods[] = {
{"_init", _initExtension, METH_NOARGS, nullptr},
{"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
@ -81,7 +84,7 @@ static PyMethodDef methods[] = {
{"_exchangeDevice", _exchangeDevice, METH_O, nullptr},
{"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
// LITERALINCLUDE END: OPENREG MODULE METHODS
/*
* When ASAN is enabled, PyTorch modifies the dlopen flag during import,
* causing all global and weak symbols in _C.so and its dependent libraries

View File

@ -45,6 +45,7 @@ def initial_seed() -> int:
return default_generator.initial_seed()
# LITERALINCLUDE START: OPENREG MANUAL SEED
def manual_seed(seed: int) -> None:
seed = int(seed)
@ -53,6 +54,9 @@ def manual_seed(seed: int) -> None:
default_generator.manual_seed(seed)
# LITERALINCLUDE END: OPENREG MANUAL SEED
def manual_seed_all(seed: int) -> None:
seed = int(seed)

View File

@ -450,6 +450,9 @@ class TestDTensorDebugMode(TestCase):
op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
self.assertTrue(
"self.l2(self.l1(x))" in debug_mode.debug_string(show_stack_trace=True)
)
@unittest.skipIf(not HAS_GPU, "requires GPU")
@unittest.skipIf(not has_triton_package(), "requires triton")

View File

@ -7,16 +7,17 @@ from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
dtypesIfXPU,
instantiate_device_type_tests,
largeTensorTest,
onlyCUDA,
onlyNativeDeviceTypes,
onlyOn,
skipCUDAIf,
skipMeta,
skipXPUIf,
TEST_WITH_ROCM,
)
from torch.testing._internal.common_nn import NNTestCase
@ -29,6 +30,13 @@ from torch.testing._internal.common_utils import (
run_tests,
set_default_dtype,
skipIfTorchDynamo,
TEST_CUDA,
TEST_XPU,
)
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
@ -36,7 +44,7 @@ class TestEmbeddingNN(NNTestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA/XPU unavailable")
def test_embedding_max_norm_unsorted_repeating_indices(self):
def create_embedding(device):
# Seed RNG so we get the same Embedding each time
@ -48,8 +56,8 @@ class TestEmbeddingNN(NNTestCase):
ix = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000)
out_cpu = create_embedding("cpu")(ix)
ix = ix.to("cuda")
out = create_embedding("cuda")(ix)
ix = ix.to(device_type)
out = create_embedding(device_type)(ix)
self.assertEqual(out.cpu(), out_cpu)
def test_embedding_sparse_basic(self):
@ -81,9 +89,9 @@ class TestEmbeddingNN(NNTestCase):
self.assertEqual(embedding.embedding_dim, 3)
self.assertEqual(embedding.num_embeddings, 10)
if torch.cuda.is_available():
embedding.to("cuda")
self.assertEqual(embedding.weight.device.type, "cuda")
if not torch.accelerator.is_available():
embedding.to(device_type)
self.assertEqual(embedding.weight.device.type, device_type)
embedding.to("cpu")
self.assertEqual(embedding.weight.device.type, "cpu")
@ -182,11 +190,11 @@ class TestEmbeddingNN(NNTestCase):
self.assertEqual(res_old, res_F)
# https://github.com/pytorch/pytorch/issues/130806
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
@largeTensorTest("40GB", device="cuda")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA/XPU not available")
@largeTensorTest("40GB", device=device_type)
def test_large_tensors(self):
input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
w = torch.randn([16032, 16384], device="cuda")
input = torch.randint(low=0, high=16032, size=[131072], device=device_type)
w = torch.randn([16032, 16384], device=device_type)
out = torch.nn.functional.embedding(input, w)
self.assertEqual(out.dim(), 2)
self.assertEqual(out.numel(), 2147483648)
@ -308,6 +316,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
torch.nn.functional.embedding(indices, weight)
@dtypesIfCUDA(torch.float16, torch.float64)
@dtypesIfXPU(torch.float16, torch.float64)
@dtypes(torch.float64)
def test_embedding_backward(self, device, dtype):
embedding = nn.Embedding(10, 3, sparse=True)
@ -348,6 +357,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
else (torch.float, torch.double, torch.half)
)
)
@dtypesIfXPU(torch.float32, torch.double, torch.half)
@dtypes(torch.float32)
def test_embedding_max_norm_backward(self, device, dtype):
# can't use gradcheck since in place renorm makes analytical gradients different from produced ones
@ -372,6 +382,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
else (torch.float, torch.double, torch.half)
)
)
@dtypesIfXPU(torch.float32, torch.double, torch.half)
@dtypes(torch.float32)
def test_embedding_max_norm_fwd_AD(self, device, dtype):
if torch.device(device).type == "xla":
@ -396,6 +407,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
else (torch.float, torch.double, torch.half)
)
)
@dtypesIfXPU(torch.float32, torch.double, torch.half)
@dtypes(torch.float32)
def test_embedding_padding_idx(self, device, dtype):
embedding = nn.Embedding(10, 20, padding_idx=0).to(device, dtype)
@ -488,6 +500,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
@onlyNativeDeviceTypes
@dtypes(torch.float32, torch.float64)
@dtypesIfCUDA(torch.half, torch.bfloat16)
@dtypesIfXPU(torch.half, torch.bfloat16)
def test_embedding_bag_1D_padding_idx(self, device, dtype):
num_features = 3
max_indices_per_bag = 10
@ -632,11 +645,12 @@ class TestEmbeddingNNDeviceType(NNTestCase):
weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol
)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
@dtypes(
torch.bfloat16,
)
@largeTensorTest("80GB", device="cuda")
@largeTensorTest("80GB", device="xpu")
def test_embedding_backward_large_batch_overflow(self, device, dtype):
"""
Test that embedding_dense_backward handles large batches that exceed INT32_MAX thread IDs.
@ -708,6 +722,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
@onlyNativeDeviceTypes
@dtypes(torch.float32, torch.float64)
@dtypesIfCUDA(torch.half, torch.bfloat16)
@dtypesIfXPU(torch.half, torch.bfloat16)
def test_embedding_bag_2D_padding_idx(self, device, dtype):
# Use a Python implementation of embedding_bag with padding_idx support
# to check torch.nn.functional.embedding_bag correctness
@ -818,7 +833,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
rtol = None
self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
@dtypes(
*(
(torch.float, torch.double, torch.bfloat16, torch.half)
@ -854,6 +869,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
self.assertEqual(output, torch.zeros_like(output))
@skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.")
@skipXPUIf(True, "no out-of-bounds check on XPU for perf.")
@dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long)))
@parametrize_test("padding_idx", [None, 0])
@parametrize_test("mode", ["sum", "mean", "max"])
@ -1066,6 +1082,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
(torch.float, torch.double, torch.half),
)
)
@dtypesIfXPU(
*itertools.product(
(torch.int, torch.long),
(torch.int, torch.long),
(torch.float32, torch.double, torch.half),
)
)
def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
# Test empty input and per sample weight, and backward pass. There was a CUDA
# invalid configuration bug (more context in #46572)
@ -1132,6 +1155,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
(torch.float, torch.double, torch.half),
)
)
@dtypesIfXPU(
*itertools.product(
(torch.int, torch.long),
(torch.int, torch.long),
(torch.float32, torch.double, torch.half),
)
)
def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
def test_per_sample_weights(mode, trainable_scale):
es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
@ -1193,6 +1223,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
(torch.float, torch.double, torch.half),
)
)
@dtypesIfXPU(
*itertools.product(
(torch.int, torch.long),
(torch.int, torch.long),
(torch.float32, torch.double, torch.half),
)
)
def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
def test_per_sample_weights_new_offsets(
mode, trainable_scale, include_last_offset, has_weight=True
@ -1357,6 +1394,11 @@ class TestEmbeddingNNDeviceType(NNTestCase):
(torch.int, torch.long), (torch.half, torch.float, torch.double)
)
)
@dtypesIfXPU(
*itertools.product(
(torch.int, torch.long), (torch.half, torch.float32, torch.double)
)
)
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
def run_tests(mode, sparse, trainable_per_sample_weights):
@ -1390,8 +1432,8 @@ class TestEmbeddingNNDeviceType(NNTestCase):
):
run_tests(mode, sparse, trainable_per_sample_weights)
# Test CUDA Dense on half precision
if device == "cuda":
# Test CUDA/XPU Dense on half precision
if device != "cpu":
modes = ("sum",)
sparsity = (False,)
trainable_scale = (True, False)
@ -1552,9 +1594,18 @@ class TestEmbeddingNNDeviceType(NNTestCase):
(torch.float, torch.double, torch.half),
)
)
@dtypesIfXPU(
*itertools.product(
(torch.int, torch.long),
(torch.int, torch.long),
(torch.float32, torch.double, torch.half),
)
)
def test_embedding_bag_device(self, device, dtypes):
if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu":
self.skipTest("bfloat16 not supported with Jetson cpu")
if dtypes[2] == torch.float64 and "xpu" in device:
self.skipTest("https://github.com/intel/torch-xpu-ops/issues/2295")
with set_default_dtype(torch.double):
self._test_EmbeddingBag(
device,
@ -1582,10 +1633,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
)
test_backward = False
if self.device_type == "cuda":
if self.device_type != "cpu":
# see 'todo' in test_embedding_bag.
test_backward = dtypes[2] is not torch.float16
elif self.device_type == "cpu":
else:
# TODO: figure out why precision on sparse embeddings isn't the
# same as for dense.
test_backward = (
@ -1626,6 +1677,13 @@ class TestEmbeddingNNDeviceType(NNTestCase):
(torch.float, torch.double, torch.half),
)
)
@dtypesIfXPU(
*itertools.product(
(torch.int, torch.long),
(torch.int, torch.long),
(torch.float32, torch.double, torch.half),
)
)
def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)
@ -1703,7 +1761,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
bag(x, per_sample_weights=F.softmax(w, dim=-1))
instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals())
instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals(), allow_xpu=True)
instantiate_parametrized_tests(TestEmbeddingNN)
if __name__ == "__main__":

View File

@ -17,12 +17,14 @@ from torch.testing._internal.common_device_type import (
dtypesIfCPU,
dtypesIfCUDA,
dtypesIfMPS,
dtypesIfXPU,
expectedFailureMPS,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
onlyNativeDeviceTypes,
onlyOn,
skipXLA,
skipXPUIf,
)
from torch.testing._internal.common_dtype import (
all_mps_types_and,
@ -38,6 +40,7 @@ from torch.testing._internal.common_utils import (
skipIfTorchDynamo,
TEST_CUDA,
TEST_MPS,
TEST_XPU,
TestCase,
xfailIfTorchDynamo,
)
@ -598,8 +601,8 @@ class TestIndexing(TestCase):
# test invalid index fails
reference = torch.empty(10, dtype=dtype, device=device)
# can't test cuda because it is a device assert
if not reference.is_cuda:
# can't test cuda/xpu because it is a device assert
if reference.device.type == "cpu":
for err_idx in (10, -11):
with self.assertRaisesRegex(IndexError, r"out of"):
reference[err_idx]
@ -744,7 +747,7 @@ class TestIndexing(TestCase):
assert_get_eq(reference, indexer)
assert_set_eq(reference, indexer, 212)
assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
if torch.cuda.is_available():
if torch.accelerator.is_available():
assert_backward_eq(reference, indexer)
reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)
@ -1009,7 +1012,7 @@ class TestIndexing(TestCase):
@skipIfTorchDynamo(
"This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
)
@serialTest(TEST_CUDA or TEST_MPS)
@serialTest(TEST_CUDA or TEST_XPU or TEST_MPS)
def test_index_put_accumulate_large_tensor(self, device):
# This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
N = (1 << 31) + 5
@ -1086,7 +1089,7 @@ class TestIndexing(TestCase):
out_cpu = t.index_put_(indices, values2d, accumulate=True)
self.assertEqual(out_cuda.cpu(), out_cpu)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_index_put_large_indices(self, device):
def generate_indices(num_indices: int, index_range: int):
indices = []
@ -1138,7 +1141,7 @@ class TestIndexing(TestCase):
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_index_put_accumulate_non_contiguous(self, device):
t = torch.zeros((5, 2, 2))
t_dev = t.to(device)
@ -1157,7 +1160,7 @@ class TestIndexing(TestCase):
self.assertEqual(out_cuda.cpu(), out_cpu)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_index_put_deterministic_with_optional_tensors(self, device):
def func(x, i, v):
with DeterministicGuard(True):
@ -1188,7 +1191,7 @@ class TestIndexing(TestCase):
indices = torch.tensor([1, 4, 3])
indices_dev = indices.to(device)
val = torch.randn(4)
out_cuda = func1(t_dev, indices_dev, val.cuda())
out_cuda = func1(t_dev, indices_dev, val.to(device))
out_cpu = func1(t, indices, val)
self.assertEqual(out_cuda.cpu(), out_cpu)
@ -1321,6 +1324,14 @@ class TestIndexing(TestCase):
torch.float8_e5m2,
torch.float8_e4m3fn,
)
@dtypesIfXPU(
torch.cfloat,
torch.cdouble,
torch.half,
torch.long,
torch.bool,
torch.bfloat16,
)
@dtypesIfMPS(torch.float, torch.float16, torch.long, torch.bool)
def test_index_put_src_datatype(self, device, dtype):
src = torch.ones(3, 2, 4, device=device, dtype=dtype)
@ -1332,6 +1343,7 @@ class TestIndexing(TestCase):
@dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
@dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool)
@dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool)
@dtypesIfXPU(torch.half, torch.long, torch.bfloat16, torch.bool)
def test_index_src_datatype(self, device, dtype):
src = torch.ones(3, 2, 4, device=device, dtype=dtype)
# test index
@ -1630,7 +1642,7 @@ class TestIndexing(TestCase):
self.assertRaisesRegex(IndexError, "invalid index", runner)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_invalid_device(self, device):
idx = torch.tensor([0, 1])
b = torch.zeros(5, device=device)
@ -1642,7 +1654,7 @@ class TestIndexing(TestCase):
lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate),
)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_cpu_indices(self, device):
idx = torch.tensor([0, 1])
b = torch.zeros(2, device=device)
@ -1718,7 +1730,7 @@ class TestIndexing(TestCase):
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
torch.take_along_dim(t, indices, dim=7)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
@dtypes(torch.float)
def test_gather_take_along_dim_cross_device(self, device, dtype):
shape = (2, 3, 1, 4)
@ -1748,7 +1760,7 @@ class TestIndexing(TestCase):
):
torch.take_along_dim(t.cpu(), indices, dim=0)
@onlyCUDA
@onlyOn(["cuda", "xpu"])
def test_cuda_broadcast_index_use_deterministic_algorithms(self, device):
with DeterministicGuard(True):
idx1 = torch.tensor([0])
@ -1969,6 +1981,7 @@ class TestIndexing(TestCase):
return (x, index, src)
@onlyNativeDeviceTypes
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1973")
@expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/161029
def test_index_copy_deterministic(self, device: torch.device) -> None:
for dim in range(3):
@ -2011,6 +2024,7 @@ class TestIndexing(TestCase):
self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)
@onlyNativeDeviceTypes
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1973")
def test_index_put_non_accumulate_deterministic(self, device) -> None:
with DeterministicGuard(True):
for i in range(3):
@ -2048,6 +2062,7 @@ class TestIndexing(TestCase):
# The test fails for zero-dimensional tensors on XLA
@onlyNativeDeviceTypes
@dtypes(*all_types_complex_float8_and(torch.half, torch.bool, torch.bfloat16))
@dtypesIfXPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
def test_index_select(self, device, dtype):
num_src, num_out = 3, 5
@ -2361,8 +2376,8 @@ class NumpyTests(TestCase):
def test_trivial_fancy_out_of_bounds(self, device):
a = torch.zeros(5, device=device)
ind = torch.ones(20, dtype=torch.int64, device=device)
if a.is_cuda:
raise unittest.SkipTest("CUDA asserts instead of raising an exception")
if a.device.type in ["cuda", "xpu"]:
raise unittest.SkipTest("CUDA/XPU asserts instead of raising an exception")
ind[-1] = 10
self.assertRaises(IndexError, a.__getitem__, ind)
self.assertRaises(IndexError, a.__setitem__, ind, 0)
@ -2397,9 +2412,9 @@ class NumpyTests(TestCase):
instantiate_device_type_tests(
TestIndexing, globals(), except_for="meta", allow_mps=True
TestIndexing, globals(), except_for="meta", allow_mps=True, allow_xpu=True
)
instantiate_device_type_tests(NumpyTests, globals(), except_for="meta")
instantiate_device_type_tests(NumpyTests, globals(), except_for="meta", allow_xpu=True)
if __name__ == "__main__":
run_tests()

View File

@ -6,11 +6,14 @@ import torch
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
dtypesIfXPU,
instantiate_device_type_tests,
onlyCUDA,
onlyOn,
skipMeta,
skipXPUIf,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_WITH_ROCM
from torch.nn.attention import SDPBackend
class TestMHADeviceType(TestCase):
@torch.no_grad()
@ -89,6 +92,7 @@ class TestMHADeviceType(TestCase):
torch.testing.assert_close(v, correct_v)
@dtypesIfCUDA(torch.float)
@dtypesIfXPU(torch.float)
@dtypes(torch.float)
@skipMeta
def test_transform_bias_rescale_qkv(self, device, dtype):
@ -99,9 +103,11 @@ class TestMHADeviceType(TestCase):
)
@dtypesIfCUDA(torch.float)
@dtypesIfXPU(torch.float)
@dtypes(torch.float)
@skipMeta
@onlyCUDA
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/2182")
@onlyOn(["cuda", "xpu"])
def test_transform_bias_rescale_qkv_nested(self, device, dtype):
for use_padding in (False, True):
with self.subTest(use_padding=use_padding):
@ -185,9 +191,9 @@ class TestMHADeviceType(TestCase):
embed_dim=embed_dim, num_heads=num_heads, qkv=native_qkv, proj=native_proj
).to(dtype)
if device == "cuda":
pt = pt.cuda()
npt = npt.cuda()
if device == "cuda" or device == "xpu":
pt = pt.to(device)
npt = npt.to(device)
ypt, weight_pt = pt(
q,
@ -266,6 +272,7 @@ class TestMHADeviceType(TestCase):
self.assertEqual(weight_pt, weight_npt)
@dtypesIfCUDA(torch.float, torch.half)
@dtypesIfXPU(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@parametrize("use_nt", [False, True])
@ -285,10 +292,25 @@ class TestMHADeviceType(TestCase):
with self.subTest(use_padding=use_padding, pad_all=pad_all,
use_nt=use_nt, need_weights=need_weights,
average_attn_weights=average_attn_weights):
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False
) if not fused else torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_mem_efficient=True
sdpa_backends_fused = [
SDPBackend.MATH,
SDPBackend.OVERRIDEABLE,
SDPBackend.CUDNN_ATTENTION,
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
]
sdpa_backends_not_fused = [
SDPBackend.MATH,
SDPBackend.OVERRIDEABLE,
SDPBackend.CUDNN_ATTENTION,
]
if device == "xpu":
sdpa_backends_fused = [SDPBackend.OVERRIDEABLE, SDPBackend.MATH]
sdpa_backends_not_fused = [SDPBackend.MATH]
with torch.nn.attention.sdpa_kernel(
sdpa_backends_not_fused
) if not fused else torch.nn.attention.sdpa_kernel(
sdpa_backends_fused
):
self._test_multihead_attention_impl(
device,
@ -302,6 +324,7 @@ class TestMHADeviceType(TestCase):
)
@dtypesIfCUDA(torch.float, torch.half)
@dtypesIfXPU(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
@ -316,6 +339,7 @@ class TestMHADeviceType(TestCase):
)
@dtypesIfCUDA(torch.float, torch.half)
@dtypesIfXPU(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
@ -330,7 +354,7 @@ class TestMHADeviceType(TestCase):
)
instantiate_device_type_tests(TestMHADeviceType, globals())
instantiate_device_type_tests(TestMHADeviceType, globals(), allow_xpu=True)
if __name__ == "__main__":
run_tests()

View File

@ -297,7 +297,9 @@ class _ParsedStackTrace:
# get File:lineno code from stack_trace
def _parse_stack_trace(stack_trace: str):
def _parse_stack_trace(
stack_trace: str, filter_fn: Optional[Callable[[str, str, str], bool]] = None
):
if stack_trace is None:
return None
pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
@ -314,6 +316,8 @@ def _parse_stack_trace(stack_trace: str):
name = matches.group(3)
# next line should be the code
code = lines[idx + 1].strip()
if filter_fn and not filter_fn(file, name, code):
continue
return _ParsedStackTrace(file, lineno, name, code)
return None

View File

@ -34,6 +34,8 @@ Usage::
import contextlib
import functools
import inspect
import os
import traceback
import weakref
from collections.abc import Callable
@ -41,6 +43,7 @@ from typing import Any, Optional, TYPE_CHECKING # noqa: F401
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.graph import _parse_stack_trace
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
@ -193,6 +196,16 @@ def _get_stack_trace() -> str:
return "".join(summary.format())
def _get_user_stack_trace(stack_trace_str: str) -> str | None:
# Extract user code stack trace, filtering out torch internals.
torch_dir = os.path.dirname(inspect.getfile(torch))
filter_fn = lambda file, name, code: not file.startswith(torch_dir + os.path.sep) # noqa: E731
trace = _parse_stack_trace(stack_trace_str, filter_fn=filter_fn)
if trace:
return f"File: {trace.file}:{trace.lineno} in {trace.name}, code: {trace.code}"
return None
def _maybe_get_autograd_trace() -> str | None:
if torch._C._current_autograd_node() is not None:
tb = torch._C._current_autograd_node().metadata.get("traceback_") # type: ignore[attr-defined]
@ -781,14 +794,55 @@ class DebugMode(TorchDispatchMode):
self.operators.append(call)
return call
def debug_string(self) -> str:
def debug_string(self, show_stack_trace: bool = False) -> str:
"""
show_stack_trace: If True, display one-line stack trace summaries above groups
of operations (similar to gm.print_readable() style).
Requires record_stack_trace=True.
"""
with torch._C.DisableTorchFunction():
result = ""
result += "\n".join(
" " + " " * op.call_depth + op.render(self.record_tensor_attributes)
for op in self.operators
)
return result
if not show_stack_trace:
result = "\n".join(
" "
+ " " * op.call_depth
+ op.render(self.record_tensor_attributes)
for op in self.operators
)
return result
# Group operations by stack trace
lines = []
prev_stack_summary = None
for op in self.operators:
# Get the stack trace: prefer fwd_stack_trace, fallback to stack_trace
stack_trace = None
if hasattr(op, "fwd_stack_trace") and op.fwd_stack_trace:
stack_trace = op.fwd_stack_trace
elif hasattr(op, "stack_trace") and op.stack_trace:
stack_trace = op.stack_trace
stack_summary = None
if stack_trace:
stack_summary = _get_user_stack_trace(stack_trace)
if stack_summary and stack_summary != prev_stack_summary:
# add blank line before stack trace comment for readability
if lines: # don't add blank line at the very start
lines.append("")
indent = " " * (op.call_depth + 1)
lines.append(indent + "# " + stack_summary)
prev_stack_summary = stack_summary
# Add the operation line
line = (
" "
+ " " * op.call_depth
+ op.render(self.record_tensor_attributes)
)
lines.append(line)
return "\n".join(lines)
@staticmethod
@contextlib.contextmanager