Compare commits

...

3 Commits

Author SHA1 Message Date
cc54c965d2 [pipelining] Validate stage input/output shape/dtype
Address classes of user error stemming from (possibly unintentional)
dynamic shape usage, or from a mismatch between configuration-time and
run-time data shapes/dtypes.

Input shape to first stage:
- validate to ensure no dynamic shape issues
- only first stage, since subsequent stages use pre-allocated recv
  buffers that can't change even if they wanted to

Output shape from all stages:
- a wrong output shape/dtype from one stage could silently corrupt data
  or hang when the recv op on the subsequent stage tries to consume the
  wrong amount of data
- validation of first stage input and all stages' outputs inductively
  verifies all shapes

Shape/dtype are the most critical properties, as they directly
determine the number of bytes on the wire.  Strides and other tensor
properties may also matter; the validation function can be adjusted
accordingly if needed.

Ensure a clear error is raised when a tensor shape is not compatible,
rather than relying on some underlying error to bubble up or, worse,
allowing a silent correctness issue.
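
To make the inductive scheme concrete, here is a minimal sketch of the
idea (hypothetical names, not the exact code in this PR): metadata is
frozen as meta tensors at configuration time, and every microbatch is
checked against it.

def _check(desc, expected, given):
    # Shape/dtype determine the number of bytes on the wire, so a
    # mismatch must raise a clear error rather than corrupt the
    # matching send/recv pair.
    if expected.shape != given.shape:
        raise RuntimeError(f"{desc}: shape {given.shape} != {expected.shape}")
    if expected.dtype != given.dtype:
        raise RuntimeError(f"{desc}: dtype {given.dtype} != {expected.dtype}")

def _stage_forward(stage, args):
    # `stage` is assumed to expose is_first, inputs_meta, outputs_meta,
    # and a submod returning a tuple; these names are illustrative.
    if stage.is_first:
        # Only stage 0 takes external inputs; later stages read from
        # recv buffers whose shape/dtype were frozen at init.
        for e, a in zip(stage.inputs_meta, args):
            _check("input", e, a)
    outputs = stage.submod(*args)
    for e, o in zip(stage.outputs_meta, outputs):
        # Checked on every stage: valid input plus valid output per
        # stage inductively validates the whole pipeline.
        _check("output", e, o)
    return outputs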

ghstack-source-id: 4c6500857a8fde9957e596cca6e32ca1f6dc4f00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126732
2024-05-21 13:40:18 -07:00
3b19956cb2 [c10d] Add assertRaisesRegexOnRank helper for distributed
Allow asserting that an exception is raised only on the specified rank,
but not on other ranks.  Useful especially for pipeline parallelism.
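
Sketch of the intended usage inside a multi-rank test (illustrative
only; the exception, helper, and tensor sizes stand in for real test
code):

# On rank 0 this enters assertRaisesRegex; on every other rank it
# returns a no-op nullcontext, so the same line runs on all ranks.
with self.assertRaisesRegexOnRank(0, PipeliningShapeError, "shape mismatch"):
    _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))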

ghstack-source-id: 7a27f9e128f465e52e503617914261af4dbbbb41
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126731
2024-05-20 17:00:26 -07:00
f75b79493e [pipelining] Add pipeline stage test
ghstack-source-id: 4168a30e8f58566e4e35ecb7458bd918ea8f40a0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126721
2024-05-20 16:59:56 -07:00
9 changed files with 441 additions and 66 deletions

View File

@@ -17,9 +17,8 @@ class ExampleCode(torch.nn.Module):
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)
def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)):
def forward(self, x):
x = torch.mm(x, self.mm_param0)
x = x + y
x = torch.relu(x)
# try passing a value that doesn't require_grad across skip boundaries
a_constant = self.cval.clone()
@@ -32,6 +31,29 @@ class ExampleCode(torch.nn.Module):
return x
class ModelWithKwargs(torch.nn.Module):
default_dhid = 512
default_batch_size = 256
def __init__(self, d_hid: int = default_dhid):
super().__init__()
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)
def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)):
x = torch.mm(x, self.mm_param0)
x = x + y
x = self.lin0(x)
x = torch.relu(x)
pipe_split()
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
x = torch.relu(x)
return x
# MLP Layer
class MLPModule(torch.nn.Module):
def __init__(self, d_hid):

View File

@@ -16,7 +16,7 @@ batch_size = 256
torch.manual_seed(0)
class ExampleCode(torch.nn.Module):
class ModelWithKwargs(torch.nn.Module):
def __init__(self):
super().__init__()
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
@@ -44,7 +44,7 @@ class ExampleCode(torch.nn.Module):
class ChunkSpecTests(TestCase):
def test_chunk_spec(self):
mod = ExampleCode()
mod = ModelWithKwargs()
x = torch.randn(batch_size, d_hid)
y = torch.randn(batch_size, d_hid)

View File

@@ -8,7 +8,7 @@ import tempfile
import torch
import torch.distributed as dist
from model_registry import ExampleCode, MultiMLP
from model_registry import ModelWithKwargs, MultiMLP
from torch.distributed.pipelining import (
pipeline,
PipelineStage,
@@ -50,60 +50,11 @@ class ScheduleTest(MultiProcContinousTest):
dev_id = cls.rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{dev_id}")
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_ec_forward(self):
# Setting this flag for numerical stability
torch.distributed.pipelining.microbatch._debug_mask_minibatches = True
mod = ExampleCode(d_hid)
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)
y = torch.randn(batch_size, d_hid, device=self.device)
pipe = pipeline(
mod,
chunks,
example_args=(x,),
example_kwargs={"y": y},
)
stage = PipelineStage(
pipe,
self.rank,
device=self.device,
)
# Attach to a schedule
schedule = ScheduleGPipe(stage, chunks)
# Run
if self.rank == 0:
schedule.step(x, y=y)
else:
out = schedule.step()
dist.barrier()
# Last rank checks result
if self.rank == self.world_size - 1:
ref_out = mod(x, y=y)
torch.testing.assert_close(out, ref_out)
# Test qualname mapping
submod_keys = stage.submod.state_dict().keys()
# Confirm keys are consistent with original model
old_keys = mod.state_dict().keys()
assert all(k in old_keys for k in submod_keys)
# Reset this flag
torch.distributed.pipelining.microbatch._debug_mask_minibatches = False
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_ec_backward(self, ScheduleClass):
mod = ExampleCode(d_hid)
mod = ModelWithKwargs(d_hid)
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)

View File

@@ -0,0 +1,284 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile
import torch
import torch.distributed as dist
from model_registry import ExampleCode, ModelWithKwargs, MultiMLP
from torch.distributed.pipelining import (
ManualPipelineStage,
pipeline,
PipelineStage,
ScheduleGPipe,
)
from torch.distributed.pipelining._utils import PipeliningShapeError
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
MultiProcContinousTest,
requires_nccl,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skip_but_pass_in_sandcastle_if,
)
from torch.utils._pytree import tree_map_only
d_hid = 512
batch_size = 256
chunks = 4
torch.manual_seed(0)
def get_dtype_change_hook(new_dtype):
"""A simple hook for simulating mixed precision"""
def dtype_change_hook(module, input, output):
def f(x):
return x.to(new_dtype)
return tree_map_only(torch.Tensor, f, output)
return dtype_change_hook
def get_flatten_hook():
"""A simple hook for simulating wrong model output shape"""
def flatten_hook(module, input, output):
def f(x):
return x.flatten()
return tree_map_only(torch.Tensor, f, output)
return flatten_hook
class StageTest(MultiProcContinousTest):
@classmethod
def backend_str(cls) -> str:
# Testing with NCCL backend
return "nccl"
@classmethod
def setUpClass(cls):
"""
Class-scope test fixture. Run once for entire test class, before any test starts.
Set up the device.
"""
super().setUpClass()
dev_id = cls.rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{dev_id}")
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ModelClass", [ExampleCode, MultiMLP])
def test_tracer(self, ModelClass):
mod = ModelClass(d_hid)
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)
pipe = pipeline(
mod,
chunks,
example_args=(x,),
)
stage = PipelineStage(
pipe,
self.rank,
device=self.device,
)
# Attach to a schedule
schedule = ScheduleGPipe(stage, chunks)
# Run
def _run_step(x):
if self.rank == 0:
return schedule.step(x)
else:
return schedule.step()
out = _run_step(x)
# Last rank checks result
if self.rank == self.world_size - 1:
ref_out = mod(x)
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)
# Test qualname mapping
submod_keys = stage.submod.state_dict().keys()
# Confirm keys are consistent with original model
old_keys = mod.state_dict().keys()
assert all(k in old_keys for k in submod_keys)
if self.rank == 0:
# Intended to run this code on all ranks, but if rank 0 throws, it
# won't perform the send that unblocks rank 1.
# TODO(whc) can't test this until fixing args/kwargs issue
# with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
# _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x.to(torch.int32))
# The output of the stage's mlp layer will be flattened by this hook; the stage should raise an error
handle = stage.submod.register_forward_hook(get_flatten_hook())
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(x)
handle.remove()
stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ModelClass", [ModelWithKwargs])
def test_tracer_kwargs(self, ModelClass):
mod = ModelClass(d_hid)
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)
y = torch.randn(batch_size, d_hid, device=self.device)
pipe = pipeline(
mod,
chunks,
example_args=(x,),
example_kwargs={"y": y},
)
stage = PipelineStage(
pipe,
self.rank,
device=self.device,
)
# Attach to a schedule
schedule = ScheduleGPipe(stage, chunks)
# Run
def _run_step(x):
if self.rank == 0:
return schedule.step(x, y=y)
else:
return schedule.step()
# Last rank checks result
out = _run_step(x)
if self.rank == self.world_size - 1:
ref_out = mod(x, y=y)
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)
# Test qualname mapping
submod_keys = stage.submod.state_dict().keys()
# Confirm keys are consistent with original model
old_keys = mod.state_dict().keys()
assert all(k in old_keys for k in submod_keys)
if self.rank == 0:
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x.to(torch.int32))
# The output of the stage's mlp layer will be flattened by this hook; the stage should raise an error
handle = stage.submod.register_forward_hook(get_flatten_hook())
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(x)
handle.remove()
stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_manual(self):
full_mod = MultiMLP(d_hid).to(self.device)
stage_mod = full_mod.get_submodule(f"mlp{self.rank}")
stage_mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)
stage = ManualPipelineStage(
stage_mod,
self.rank,
self.world_size,
self.device,
chunks,
input_args=x.chunk(chunks)[0],
)
# Attach to a schedule
schedule = ScheduleGPipe(stage, chunks)
# Run
def _run_step(x):
if self.rank == 0:
return schedule.step(x)
else:
return schedule.step()
out = _run_step(x)
# Last rank checks result
if self.rank == self.world_size - 1:
ref_out = full_mod(x)
torch.testing.assert_close(out, ref_out)
if self.rank == 0:
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x.to(torch.int32))
# The output of the stage's mlp layer will be flattened by this hook; the stage should raise an error
handle = stage_mod.register_forward_hook(get_flatten_hook())
with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
_run_step(x)
handle.remove()
stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
_run_step(x)
instantiate_parametrized_tests(StageTest)
if __name__ == "__main__":
# Check if GPU and NCCL are available
if not (
dist.is_available()
and dist.is_nccl_available()
and torch.cuda.device_count() > 1
):
print(
"c10d NCCL not available or not enough GPUs, skipping tests",
file=sys.stderr,
)
sys.exit(0)
rank = int(os.getenv("RANK", -1))
world_size = int(os.getenv("WORLD_SIZE", 2))
if rank != -1:
# Launched with torchrun or other multi-proc launchers. Directly run the test.
StageTest.run_rank(rank, world_size)
else:
# Launched as a single process. Spawn subprocesses to run the tests.
# Also need a rendezvous file for `init_process_group` purposes.
rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
torch.multiprocessing.spawn(
StageTest.run_rank,
nprocs=world_size,
args=(world_size, rdvz_file),
)
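
The fault-injection hooks above need no distributed setup to follow; a
standalone sketch of the same pattern on a plain module (the module and
shapes here are made up for illustration):

import torch
from torch.utils._pytree import tree_map_only

lin = torch.nn.Linear(8, 8)

def dtype_change_hook(module, input, output):
    # Simulates a mixed-precision wrapper silently changing output dtype.
    return tree_map_only(torch.Tensor, lambda x: x.to(torch.bfloat16), output)

handle = lin.register_forward_hook(dtype_change_hook)
out = lin(torch.randn(2, 8))
print(out.dtype)  # torch.bfloat16 -- downstream dtype validation would trip
handle.remove()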

View File

@@ -2,6 +2,7 @@
import logging
import operator
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -16,7 +17,7 @@ from torch.nn.parallel import DistributedDataParallel
from ._backward import stage_backward
from ._debug import map_debug_info
from ._IR import Pipe
from ._utils import flatten_args, modify_graph_op_device
from ._utils import flatten_args, modify_graph_op_device, validate_tensors_metadata
logger = logging.getLogger(__name__)
@@ -26,7 +27,8 @@ class RootArgPlaceholder:
Placeholder for model-level inputs.
"""
pass
def __init__(self, tensor):
self.meta = tensor.to("meta")
class RecvInfo:
@@ -118,6 +120,7 @@ class PipelineStageBase(ABC):
)
# Run time states
self._outputs_meta: Optional[Tuple[torch.Tensor]] = None
# map microbatch ID to list of forward tensor args
self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {}
# Current forward chunk id
@@ -176,6 +179,25 @@ class PipelineStageBase(ABC):
"""
return self.stage_index == self.num_stages - 1
def _configure_outputs_meta(self, outputs_meta: Iterable[torch.Tensor]):
"""
Track the output shapes/dtype of this stage since they determine the send operation(s) which must match
recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial
configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
which could show up as hangs, silent corruption, or other errors.
"""
assert (
self._outputs_meta is None
), "Attempting to reconfigure output_meta, which is not supported"
self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment]
def get_outputs_meta(self) -> Tuple[torch.Tensor]:
"""Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
assert (
self._outputs_meta is not None
), "Attempted to get_outputs_meta() without configuring output meta"
return self._outputs_meta
def _create_grad_send_info(
self,
args_recv_info: Tuple,
@@ -453,6 +475,7 @@ class PipelineStageBase(ABC):
`args` and `kwargs` are the inputs from *outside* this stage; they
apply only to the first stage in most cases.
"""
if self.is_first:
# First stage doesn't need to receive anything
composite_args = args
@@ -463,6 +486,8 @@ class PipelineStageBase(ABC):
composite_args = self._retrieve_recv_activations()
composite_kwargs = {}
self._validate_fwd_input(args, kwargs)
# Compute forward
try:
output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)
@@ -498,6 +523,7 @@ class PipelineStageBase(ABC):
logger.debug(
f"{self.log_prefix} Forwarded chunk {self.fwd_chunk_id}, outputs: {map_debug_info(output)}" # noqa: G004
)
self._validate_fwd_outputs(output_tuple)
self.fwd_chunk_id += 1
return output
@@ -543,6 +569,42 @@ class PipelineStageBase(ABC):
)
self.bwd_chunk_id += 1
def _validate_fwd_input(self, args, kwargs):
"""Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage."""
if self.is_first:
# TODO why is there a separate recv_info for each pipeline chunk?
expected_args = self.args_recv_info[self.fwd_chunk_id]
else:
expected_args = tuple()
if len(kwargs):
# TODO- need a mapping of kwarg to position in self.args_recv_info
# without it, we just validate shapes for args and ignore kwargs
logger.warning(
"Unable to validate input info for traced modules with kwargs."
)
expected_args = expected_args[: len(expected_args) - len(kwargs)]
# TODO- need a mapping of kwarg to position in self.args_recv_info
# maybe it's impossible to tell whether the len mismatches because
# (a) the user passed an extra arg or missed an arg
# (b) the user did not pass a kwarg, which has a default value baked into expected_args
expected_tensors_meta = [
e.meta if isinstance(e, RootArgPlaceholder) else e.buffer
for e in expected_args
]
validate_tensors_metadata("forward input args", expected_tensors_meta, args)
def _validate_fwd_outputs(self, outputs: Tuple[torch.Tensor]):
"""Raises a RuntimeError if this stage produces an output of unexpected shape/dtype.
Most likely, this is caused either by an incorrect user specification of output shapes, or because
shape inference was done on the original model but at runtime the model is wrapped with something like
mixed precision which changes the output dtype.
"""
expected_tensors_meta = self.get_outputs_meta()
validate_tensors_metadata("forward outputs", expected_tensors_meta, outputs)
class _PipelineStage(PipelineStageBase):
def __init__(
@@ -656,10 +718,11 @@ class _PipelineStage(PipelineStageBase):
"""
Create a receive buffer for a placeholder.
"""
example_value = placeholder.meta["val"]
if arg_node.op == "placeholder":
# This is a root level placeholder, thus an input argument to the entire model.
# We are likely at stage 0, hence no need to create a receive buffer.
return RootArgPlaceholder()
return RootArgPlaceholder(example_value)
# Figure out the source stage of this input
while arg_node.target is operator.getitem:
@@ -672,7 +735,6 @@ class _PipelineStage(PipelineStageBase):
src_stage = self.get_stage_index_of_submod(arg_node.name)
# Create a receive buffer for this placeholder
example_value = placeholder.meta["val"]
logger.debug(
f"{self.log_prefix} " # noqa: G004
f"Creating recv buffer for input '{placeholder.name}' "
@@ -757,9 +819,21 @@ class _PipelineStage(PipelineStageBase):
if dst_rank is not None:
dsts.append(dst_rank)
output_node = self._get_output_node()
output_vals: Tuple[torch.Tensor] = tuple(
v.meta["val"] for v in flatten_args(output_node.args)
)
self._configure_outputs_meta(output_vals)
logger.debug(f"{self.log_prefix} " f"Send info: {act_send_info}") # noqa: G004
return act_send_info
def _get_output_node(self):
output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"]
assert len(output_nodes) == 1
output_node = output_nodes[0]
return output_node
def _create_grad_recv_info(
self,
act_send_info: Dict,
@@ -769,9 +843,8 @@
"""
# Dict[output_index, RecvInfo]
grad_recv_info: Dict[int, RecvInfo] = {}
output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"]
assert len(output_nodes) == 1
output_node = output_nodes[0]
output_node = self._get_output_node()
# The output node may take multiple args, meaning the submod has multiple output values.
output_vals = flatten_args(output_node.args)
@@ -1028,6 +1101,8 @@ class ManualPipelineStage(PipelineStageBase):
else:
self.outputs = create_empty_tensors(output_args, device)
self._configure_outputs_meta(self.outputs)
# these are the buffers used in backwards send/recv, they are allocated later
self.outputs_grad: List[torch.Tensor] = []
@@ -1061,7 +1136,7 @@
self.args_recv_info[chunk_id] = recv_infos
else:
self.args_recv_info[chunk_id] = tuple(
[RootArgPlaceholder() for _ in self.inputs]
[RootArgPlaceholder(i) for i in self.inputs]
)
# Send info during forward for each activation
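
As an aside on the RootArgPlaceholder change above: meta tensors carry
shape/dtype/strides but no storage, which is what makes them suitable
for recording expected metadata. For reference:

import torch

t = torch.randn(4, 8)
m = t.to("meta")
print(m.shape, m.dtype, m.stride())  # torch.Size([4, 8]) torch.float32 (8, 1)
# m holds no data; it exists purely to describe and validate metadata.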

View File

@@ -8,7 +8,7 @@ from ._IR import (
pipeline,
SplitPoint,
)
from ._PipelineStage import PipelineStage
from ._PipelineStage import ManualPipelineStage, PipelineStage
from .PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
@@ -24,6 +24,7 @@ __all__ = [
"pipeline",
"ArgsChunkSpec",
"KwargsChunkSpec",
"ManualPipelineStage",
"PipelineStage",
"Schedule1F1B",
"ScheduleGPipe",

View File

@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Dict, Optional
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import fx
@@ -132,3 +132,35 @@ class QualnameMapMixin:
return self.tracer_qualname_map[name_before_split]
else:
return name_before_split
class PipeliningShapeError(RuntimeError):
"""Shape mismatch between configured and runtime values."""
def validate_tensor_metadata(desc, expected, given):
if not expected.shape == given.shape:
raise PipeliningShapeError(
f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
)
if not expected.dtype == given.dtype:
raise PipeliningShapeError(
f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
)
if not expected.stride() == given.stride():
raise PipeliningShapeError(
f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
)
def validate_tensors_metadata(
desc,
expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor]],
actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor]],
):
if len(expected_tensors) != len(actual_tensors):
raise PipeliningShapeError(
f"Number of {desc} ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
)
for i in range(len(expected_tensors)):
validate_tensor_metadata(f"{desc}[{i}]", expected_tensors[i], actual_tensors[i])
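
A quick standalone illustration of the helper catching a dtype drift
(the tensors here are invented for the example; the import path matches
how the tests in this PR reach these private helpers):

import torch
from torch.distributed.pipelining._utils import validate_tensors_metadata

expected = [torch.empty(256, 512, device="meta")]    # configured metadata
actual = [torch.randn(256, 512).to(torch.bfloat16)]  # runtime output
validate_tensors_metadata("forward outputs", expected, actual)
# PipeliningShapeError: forward outputs[0] has a dtype mismatch:
# expected torch.float32 actual torch.bfloat16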

View File

@@ -15,7 +15,7 @@ import time
import traceback
import types
import unittest
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
@@ -540,6 +540,11 @@ class MultiProcessTestCase(TestCase):
return types.MethodType(wrapper, self)
def assertRaisesRegexOnRank(self, rank, expected_exception, expected_regex, *args, **kwargs):
if self.rank == rank:
return self.assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
return nullcontext()
# The main process spawns N subprocesses that run the test.
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
@@ -1303,6 +1308,11 @@ class MultiProcContinousTest(TestCase):
# Rendezvous file
rdvz_file: Optional[str] = None
def assertRaisesRegexOnRank(self, rank, expected_exception, expected_regex, *args, **kwargs):
if self.rank == rank:
return self.assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
return nullcontext()
@classmethod
@abc.abstractmethod
def backend_str(cls) -> str: