DeepSpeed/tests/unit/util.py
Stas Bekman c2bb53f20f TiledMLP + SequenceTiledCompute: improve the bs>1 use-case (#7422)
Improved TiledMLP and SequenceTiledCompute for bs>1

This PR:
- extends the testing utils to add `CaptureStd*`, `CaptureLogger`
context managers
- extends the test to run both bs=1 and bs=2
- uses an uneven seqlen to test varlen shards
- flattens the bs and seqlen dims to avoid problems with grad tensor strides
  when bs>1: the MLP doesn't care about the bs dimension, so a pretend seqlen
  of `bs*seqlen` is used instead and the shape is restored for the grad at the
  end (see the sketch below)
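
A minimal sketch of the flattening idea (hypothetical `x`/`mlp` names, not the
actual TiledMLP code):

    bs, seqlen, hidden = x.shape           # x: [bs, seqlen, hidden], bs > 1
    x = x.reshape(bs * seqlen, hidden)     # pretend seqlen of bs*seqlen
    out = mlp(x)                           # mlp doesn't care about the bs dim
    out = out.reshape(bs, seqlen, hidden)  # restore the shape for the grad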

---------

Signed-off-by: Stas Bekman <stas@stason.org>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2025-07-16 09:30:08 -07:00


# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported
from deepspeed.git_version_info import torch_info
from io import StringIO
import deepspeed
import logging
import pytest
import re
import sys
import torch

def skip_on_arch(min_arch=7):
    if get_accelerator().device_name() == 'cuda':
        if torch.cuda.get_device_capability()[0] < min_arch:  #ignore-cuda
            pytest.skip(f"needs higher compute capability than {min_arch}")
    else:
        assert is_current_accelerator_supported()
        return
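
# Example (a sketch): skip a test that needs Ampere (compute capability 8.x) or newer:
#
#   def test_needs_ampere():
#       skip_on_arch(min_arch=8)
#       ...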


def skip_on_cuda(valid_cuda):
    split_version = lambda x: map(int, x.split('.')[:2])
    if get_accelerator().device_name() == 'cuda':
        CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
        CUDA_VERSION = (CUDA_MAJOR * 10) + CUDA_MINOR
        if valid_cuda.count(CUDA_VERSION) == 0:
            pytest.skip(f"requires cuda versions {valid_cuda}")
    else:
        assert is_current_accelerator_supported()
        return
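
# Example (a sketch; version codes are major*10+minor, so cuda 11.8 -> 118):
#
#   def test_cuda_version_specific():
#       skip_on_cuda(valid_cuda=[117, 118])
#       ...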


def bf16_required_version_check(accelerator_check=True):
    split_version = lambda x: map(int, x.split('.')[:2])
    TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version'])
    NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version'])
    CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])

    # Sometimes bf16 tests are runnable even if not natively supported by accelerator
    if accelerator_check:
        accelerator_pass = get_accelerator().is_bf16_supported()
    else:
        accelerator_pass = True

    torch_version_available = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
    cuda_version_available = CUDA_MAJOR >= 11
    nccl_version_available = NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)
    npu_available = get_accelerator().device_name() == 'npu'
    hpu_available = get_accelerator().device_name() == 'hpu'
    xpu_available = get_accelerator().device_name() == 'xpu'

    if torch_version_available and cuda_version_available and nccl_version_available and accelerator_pass:
        return True
    elif npu_available:
        return True
    elif hpu_available:
        return True
    elif xpu_available:
        return True
    else:
        return False
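
# Typical guard in a test (a sketch):
#
#   if not bf16_required_version_check():
#       pytest.skip("bf16 needs torch>=1.10, cuda>=11 and nccl>=2.10, or an npu/hpu/xpu accelerator")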


def required_amp_check():
    from importlib.util import find_spec
    if find_spec('apex') is None:
        return False
    else:
        return True


class no_child_process_in_deepspeed_io:

    def __enter__(self):
        # deepspeed_io defaults to creating a dataloader that uses a
        # multiprocessing pool. Our tests use pools and we cannot nest pools in
        # python. Therefore we're injecting this kwarg to ensure that no pools
        # are used in the dataloader.
        self.old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io

        def new_method(*args, **kwargs):
            kwargs["num_local_io_workers"] = 0
            return self.old_method(*args, **kwargs)

        deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method

    def __exit__(self, *_):
        deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method
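
# Example (a minimal sketch; `model`, `config` and `dataset` are hypothetical):
#
#   with no_child_process_in_deepspeed_io():
#       engine, *_ = deepspeed.initialize(model=model, config=config)
#       loader = engine.deepspeed_io(dataset)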


def torch_assert_equal(actual, expected, **kwargs):
    """
    Compare two tensors or non-tensor numbers for their equality.
    Add msg=blah to add an additional comment when the assert fails.
    """
    return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)


def torch_assert_close(actual, expected, **kwargs):
    """
    Compare two tensors or non-tensor numbers for their closeness.
    Add msg=blah to add an additional comment when the assert fails.

    For the default values of `rtol` and `atol`, which are dtype-dependent, see the table at
    https://docs.pytorch.org/docs/stable/testing.html#torch.testing.assert_close
    For example, for bf16 they are `rtol=1.6e-2` and `atol=1e-5`.

    The check passes (doesn't raise) when `|actual - expected| <= atol + rtol * |expected|`.
    """
    return torch.testing.assert_close(actual, expected, **kwargs)


def torch_assert_dicts_of_tensors_equal(actual, expected, **kwargs):
    """
    Compare two dicts of tensors or non-tensor numbers for their equality.
    Add msg=blah to add an additional comment when the assert fails.
    """
    for k in actual.keys():
        torch.testing.assert_close(actual[k], expected[k], rtol=0.0, atol=0.0, **kwargs)
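
# Example (a sketch): the helpers accept the same kwargs as torch.testing.assert_close:
#
#   torch_assert_equal(actual, expected, msg="outputs diverged")
#   torch_assert_close(actual_bf16, expected_bf16)  # uses dtype-dependent rtol/atol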


# CaptureStd, CaptureLogger context managers from
# https://github.com/stas00/ml-engineering/blob/master/testing/testing_utils.py

# When any function contains print() calls that get overwritten, like progress bars,
# special care needs to be applied, since under pytest -s captured output (capsys
# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
# \r's. This helper function ensures that the buffer will contain the same output
# with and without -s in pytest, by turning:
#    foo bar\r tar mar\r final message
# into:
#    final message
# It can handle a single string or a multiline buffer.
def apply_print_resets(buf):
    return re.sub(r"^.*\r", "", buf, 0, re.M)


class CaptureStd:
    """
    Context manager to capture:

    - stdout: replay it, clean it up and make it available via ``obj.out``
    - stderr: replay it and make it available via ``obj.err``
    - combined: combine the chosen streams and make it available via ``obj.combined``

    init arguments:

    - out - capture stdout: ``True``/``False``, default ``True``
    - err - capture stderr: ``True``/``False``, default ``True``
    - replay - whether to replay or not: ``True``/``False``, default ``True``. By default each
      captured stream gets replayed back on the context's exit, so that one can see what the test
      was doing. If this behavior is not wanted and the captured data shouldn't be replayed, pass
      ``replay=False`` to disable this feature.

    Examples::

        # to capture stdout only with auto-replay
        with CaptureStdout() as cs:
            print("Secret message")
        assert "message" in cs.out

        # to capture stderr only with auto-replay
        import sys

        with CaptureStderr() as cs:
            print("Warning: ", file=sys.stderr)
        assert "Warning" in cs.err

        # to capture both streams with auto-replay
        with CaptureStd() as cs:
            print("Secret message")
            print("Warning: ", file=sys.stderr)
        assert "message" in cs.out
        assert "Warning" in cs.err

        # to capture just one of the streams, and not the other, with auto-replay
        with CaptureStd(err=False) as cs:
            print("Secret message")
        assert "message" in cs.out
        # but best use the stream-specific subclasses

        # to capture without auto-replay
        with CaptureStd(replay=False) as cs:
            print("Secret message")
        assert "message" in cs.out

        # sometimes it's easier to not try to figure out if it's stdout or stderr, and yet at
        # other times the software may send the same output to stderr or stdout depending on
        # environment, so to make the test robust a combined entry of both streams is available
    """

    def __init__(self, out=True, err=True, replay=True):
        self.replay = replay

        if out:
            self.out_buf = StringIO()
            self.out = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.out_buf = None
            self.out = "not capturing stdout"

        if err:
            self.err_buf = StringIO()
            self.err = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.err_buf = None
            self.err = "not capturing stderr"

        self.combined = "error: CaptureStd context is unfinished yet, called too early"

    def __enter__(self):
        if self.out_buf is not None:
            self.out_old = sys.stdout
            sys.stdout = self.out_buf

        if self.err_buf is not None:
            self.err_old = sys.stderr
            sys.stderr = self.err_buf

        self.combined = ""

        return self

    def __exit__(self, *exc):
        if self.out_buf is not None:
            sys.stdout = self.out_old
            captured = self.out_buf.getvalue()
            if self.replay:
                sys.stdout.write(captured)
            self.out = apply_print_resets(captured)
            self.combined += self.out

        if self.err_buf is not None:
            sys.stderr = self.err_old
            captured = self.err_buf.getvalue()
            if self.replay:
                sys.stderr.write(captured)
            self.err = captured
            self.combined += self.err

    def __repr__(self):
        msg = ""
        if self.out_buf:
            msg += f"stdout: {self.out}\n"
        if self.err_buf:
            msg += f"stderr: {self.err}\n"
        return msg


# in tests it's best to capture only the stream that's wanted, otherwise
# it's easy to miss things, so unless you need to capture both streams, use the
# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
# disable the stream you don't need to test.


class CaptureStdout(CaptureStd):
    """Same as CaptureStd but captures only stdout"""

    def __init__(self, replay=True):
        super().__init__(err=False, replay=replay)


class CaptureStderr(CaptureStd):
    """Same as CaptureStd but captures only stderr"""

    def __init__(self, replay=True):
        super().__init__(out=False, replay=replay)


class CaptureLogger:
    """
    Context manager to capture `logging` streams

    Args:
    - logger: `logging` logger object

    Results:
        The captured output is available via `self.out`

    Example::

        >>> from transformers import logging
        >>> from transformers.testing_utils import CaptureLogger
        >>> msg = "Testing 1, 2, 3"
        >>> logging.set_verbosity_info()
        >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
        >>> with CaptureLogger(logger) as cl:
        ...     logger.info(msg)
        >>> assert cl.out == msg + "\n"
    """

    def __init__(self, logger):
        self.logger = logger
        self.io = StringIO()
        self.sh = logging.StreamHandler(self.io)
        self.out = ""

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()

    def __repr__(self):
        return f"captured: {self.out}\n"