Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-18 17:45:09 +08:00

Compare commits: 24 commits (ciflow/tru ... whc/shardi)
| SHA1 |
|---|
| 653c0ecf35 |
| 057434a442 |
| 9fd0af1c3b |
| 53305e5379 |
| ea5f2aceda |
| 83557a528f |
| 54d05a0874 |
| bfddfde50c |
| b6570615f8 |
| 226850cc66 |
| f8a2ce3b9a |
| e2c6834584 |
| 0e7235ed73 |
| 3522e0ce74 |
| 50bf1f0b81 |
| c78e64622e |
| 5623628894 |
| 2aba180114 |
| 45b2c3d312 |
| 5b1e112cf9 |
| 5e6ac5c6e1 |
| 79317dc7a7 |
| 96a4c4b3d1 |
| 05bcfcc5d1 |
.spin/cmds.py (new file, 330 lines)
@@ -0,0 +1,330 @@
import hashlib
import subprocess
import sys
from pathlib import Path

import click
import spin


def file_digest(file, algorithm: str):
    try:
        return hashlib.file_digest(file, algorithm)
    except AttributeError:
        pass  # Fallback to manual implementation below
    hash = hashlib.new(algorithm)
    while chunk := file.read(8192):
        hash.update(chunk)
    return hash


def _hash_file(file):
    with open(file, "rb") as f:
        hash = file_digest(f, "sha256")
    return hash.hexdigest()


def _hash_files(files):
    hashes = {file: _hash_file(file) for file in files}
    return hashes


def _read_hashes(hash_file: Path):
    if not hash_file.exists():
        return {}
    with hash_file.open("r") as f:
        lines = f.readlines()
    hashes = {}
    for line in lines:
        hash = line[:64]
        file = line[66:].strip()
        hashes[file] = hash
    return hashes


def _updated_hashes(hash_file, files_to_hash):
    old_hashes = _read_hashes(hash_file)
    new_hashes = _hash_files(files_to_hash)
    if new_hashes != old_hashes:
        return new_hashes
    return None


@click.command()
def regenerate_version():
    """Regenerate version.py."""
    cmd = [
        sys.executable,
        "-m",
        "tools.generate_torch_version",
        "--is-debug=false",
    ]
    spin.util.run(cmd)


TYPE_STUBS = [
    (
        "Pytorch type stubs",
        Path(".lintbin/.pytorch-type-stubs.sha256"),
        [
            "aten/src/ATen/native/native_functions.yaml",
            "aten/src/ATen/native/tags.yaml",
            "tools/autograd/deprecated.yaml",
        ],
        [
            sys.executable,
            "-m",
            "tools.pyi.gen_pyi",
            "--native-functions-path",
            "aten/src/ATen/native/native_functions.yaml",
            "--tags-path",
            "aten/src/ATen/native/tags.yaml",
            "--deprecated-functions-path",
            "tools/autograd/deprecated.yaml",
        ],
    ),
    (
        "Datapipes type stubs",
        None,
        [],
        [
            sys.executable,
            "torch/utils/data/datapipes/gen_pyi.py",
        ],
    ),
]


@click.command()
def regenerate_type_stubs():
    """Regenerate type stubs."""
    for name, hash_file, files_to_hash, cmd in TYPE_STUBS:
        if hash_file:
            if hashes := _updated_hashes(hash_file, files_to_hash):
                click.echo(
                    f"Changes detected in type stub files for {name}. Regenerating..."
                )
                spin.util.run(cmd)
                hash_file.parent.mkdir(parents=True, exist_ok=True)
                with hash_file.open("w") as f:
                    for file, hash in hashes.items():
                        f.write(f"{hash}  {file}\n")
                click.echo("Type stubs and hashes updated.")
            else:
                click.echo(f"No changes detected in type stub files for {name}.")
        else:
            click.echo(f"No hash file for {name}. Regenerating...")
            spin.util.run(cmd)
            click.echo("Type stubs regenerated.")


@click.command()
def regenerate_clangtidy_files():
    """Regenerate clang-tidy files."""
    cmd = [
        sys.executable,
        "-m",
        "tools.linter.clang_tidy.generate_build_files",
    ]
    spin.util.run(cmd)


#: These linters are expected to need less than 3s cpu time total
VERY_FAST_LINTERS = {
    "ATEN_CPU_GPU_AGNOSTIC",
    "BAZEL_LINTER",
    "C10_NODISCARD",
    "C10_UNUSED",
    "CALL_ONCE",
    "CMAKE_MINIMUM_REQUIRED",
    "CONTEXT_DECORATOR",
    "COPYRIGHT",
    "CUBINCLUDE",
    "DEPLOY_DETECTION",
    "ERROR_PRONE_ISINSTANCE",
    "EXEC",
    "HEADER_ONLY_LINTER",
    "IMPORT_LINTER",
    "INCLUDE",
    "LINTRUNNER_VERSION",
    "MERGE_CONFLICTLESS_CSV",
    "META_NO_CREATE_UNBACKED",
    "NEWLINE",
    "NOQA",
    "NO_WORKFLOWS_ON_FORK",
    "ONCE_FLAG",
    "PYBIND11_INCLUDE",
    "PYBIND11_SPECIALIZATION",
    "PYPIDEP",
    "PYPROJECT",
    "RAWCUDA",
    "RAWCUDADEVICE",
    "ROOT_LOGGING",
    "TABS",
    "TESTOWNERS",
    "TYPEIGNORE",
    "TYPENOSKIP",
    "WORKFLOWSYNC",
}


#: These linters are expected to take a few seconds, but less than 10s cpu time total
FAST_LINTERS = {
    "CMAKE",
    "DOCSTRING_LINTER",
    "GHA",
    "NATIVEFUNCTIONS",
    "RUFF",
    "SET_LINTER",
    "SHELLCHECK",
    "SPACES",
}


#: These linters are expected to take more than 10s cpu time total;
#: some need more than 1 hour.
SLOW_LINTERS = {
    "ACTIONLINT",
    "CLANGFORMAT",
    "CLANGTIDY",
    "CODESPELL",
    "FLAKE8",
    "GB_REGISTRY",
    "PYFMT",
    "PYREFLY",
    "TEST_DEVICE_BIAS",
    "TEST_HAS_MAIN",
}


ALL_LINTERS = VERY_FAST_LINTERS | FAST_LINTERS | SLOW_LINTERS


LINTRUNNER_CACHE_INFO = (
    Path(".lintbin/.lintrunner.sha256"),
    [
        "requirements.txt",
        "pyproject.toml",
        ".lintrunner.toml",
    ],
)


LINTRUNNER_BASE_CMD = [
    "uvx",
    "--python",
    "3.10",
    "lintrunner@0.12.7",
]


@click.command()
def setup_lint():
    """Set up lintrunner with current CI version."""
    cmd = LINTRUNNER_BASE_CMD + ["init"]
    subprocess.run(cmd, check=True, capture_output=True, text=True)


def _check_linters():
    cmd = LINTRUNNER_BASE_CMD + ["list"]
    ret = spin.util.run(cmd, output=False, stderr=subprocess.PIPE)
    linters = {l.strip() for l in ret.stdout.decode().strip().split("\n")[1:]}
    unknown_linters = linters - ALL_LINTERS
    missing_linters = ALL_LINTERS - linters
    if unknown_linters:
        click.secho(
            f"Unknown linters found; please add them to the correct category "
            f"in .spin/cmds.py: {', '.join(unknown_linters)}",
            fg="yellow",
        )
    if missing_linters:
        click.secho(
            f"Missing linters found; please update the corresponding category "
            f"in .spin/cmds.py: {', '.join(missing_linters)}",
            fg="yellow",
        )
    return unknown_linters, missing_linters


@spin.util.extend_command(
    setup_lint,
    doc=f"""
    If configuration has changed, update lintrunner.

    Compares the stored old hashes of configuration files with new ones and
    performs setup via setup-lint if the hashes have changed.
    Hashes are stored in {LINTRUNNER_CACHE_INFO[0]}; the following files are
    considered: {", ".join(LINTRUNNER_CACHE_INFO[1])}.
    """,
)
@click.pass_context
def lazy_setup_lint(ctx, parent_callback, **kwargs):
    if hashes := _updated_hashes(*LINTRUNNER_CACHE_INFO):
        click.echo(
            "Changes detected in lint configuration files. Setting up linting tools..."
        )
        parent_callback(**kwargs)
        hash_file = LINTRUNNER_CACHE_INFO[0]
        hash_file.parent.mkdir(parents=True, exist_ok=True)
        with hash_file.open("w") as f:
            for file, hash in hashes.items():
                f.write(f"{hash}  {file}\n")
        click.echo("Linting tools set up and hashes updated.")
    else:
        click.echo("No changes detected in lint configuration files. Skipping setup.")
    click.echo("Regenerating version...")
    ctx.invoke(regenerate_version)
    click.echo("Regenerating type stubs...")
    ctx.invoke(regenerate_type_stubs)
    click.echo("Done.")
    _check_linters()


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def lint(ctx, apply_patches, **kwargs):
    """Lint all files."""
    ctx.invoke(lazy_setup_lint)
    all_files_linters = VERY_FAST_LINTERS | FAST_LINTERS
    changed_files_linters = SLOW_LINTERS
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    all_files_cmd = cmd + [
        "--take",
        ",".join(all_files_linters),
        "--all-files",
    ]
    spin.util.run(all_files_cmd)
    changed_files_cmd = cmd + [
        "--take",
        ",".join(changed_files_linters),
    ]
    spin.util.run(changed_files_cmd)


@click.command()
@click.pass_context
def fixlint(ctx, **kwargs):
    """Autofix all files."""
    ctx.invoke(lint, apply_patches=True)


@click.command()
@click.option("-a", "--apply-patches", is_flag=True)
@click.pass_context
def quicklint(ctx, apply_patches, **kwargs):
    """Lint changed files."""
    ctx.invoke(lazy_setup_lint)
    cmd = LINTRUNNER_BASE_CMD
    if apply_patches:
        cmd += ["--apply-patches"]
    spin.util.run(cmd)


@click.command()
@click.pass_context
def quickfix(ctx, **kwargs):
    """Autofix changed files."""
    ctx.invoke(quicklint, apply_patches=True)
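For reference, the hash cache written above follows the `sha256sum` convention (64-character digest, two spaces, then the path), which is what `_read_hashes` slices apart. A minimal sketch of that round trip, assuming the helpers from `.spin/cmds.py` above are in scope (the cache path and file list here are made up for illustration):

```python
from pathlib import Path

# Hypothetical cache file and inputs; any repo-relative text files work.
files = ["pyproject.toml", "requirements.txt"]
cache = Path(".lintbin/.example.sha256")

if hashes := _updated_hashes(cache, files):   # None when nothing changed
    cache.parent.mkdir(parents=True, exist_ok=True)
    with cache.open("w") as f:
        for file, hash in hashes.items():
            f.write(f"{hash}  {file}\n")      # the format _read_hashes expects
```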
@@ -223,6 +223,62 @@ CONVERT_FROM_BF16_TEMPLATE(double)
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif

#ifdef __ARM_FEATURE_BF16

// clang-[17, 20] crashes when autovectorizing static cast to bf16
// Below is a workaround to have some vectorization
// Works decently well for smaller int types
template <typename from_type>
inline void convertToBf16Impl(
    const from_type* __restrict src,
    c10::BFloat16* __restrict dst,
    uint64_t n) {
  bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
  uint64_t loopBound = n - (n % 16);
  uint64_t i = 0;
  for (; i < loopBound; i += 16) {
    float32x4_t a, b, c, d;
    a[0] = static_cast<float>(src[i]);
    a[1] = static_cast<float>(src[i + 1]);
    a[2] = static_cast<float>(src[i + 2]);
    a[3] = static_cast<float>(src[i + 3]);
    b[0] = static_cast<float>(src[i + 4]);
    b[1] = static_cast<float>(src[i + 5]);
    b[2] = static_cast<float>(src[i + 6]);
    b[3] = static_cast<float>(src[i + 7]);
    c[0] = static_cast<float>(src[i + 8]);
    c[1] = static_cast<float>(src[i + 9]);
    c[2] = static_cast<float>(src[i + 10]);
    c[3] = static_cast<float>(src[i + 11]);
    d[0] = static_cast<float>(src[i + 12]);
    d[1] = static_cast<float>(src[i + 13]);
    d[2] = static_cast<float>(src[i + 14]);
    d[3] = static_cast<float>(src[i + 15]);

    vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
    vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
  }

#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
  for (; i < n; i++) {
    float a = static_cast<float>(src[i]);
    dstPtr[i] = vcvth_bf16_f32(a);
  }
}

#define CONVERT_TO_BF16_TEMPLATE(from_type) \
  template <> \
  inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
    return convertToBf16Impl<from_type>(src, dst, n); \
  }

CONVERT_TO_BF16_TEMPLATE(uint8_t)
CONVERT_TO_BF16_TEMPLATE(int8_t)
CONVERT_TO_BF16_TEMPLATE(int16_t)
CONVERT_TO_BF16_TEMPLATE(int32_t)

#endif

inline void convertBoolToBfloat16Impl(
    const bool* __restrict src,
    c10::BFloat16* __restrict dst,
aten/src/ATen/native/mkldnn/xpu/ScaledBlas.cpp (new file, 342 lines)
@@ -0,0 +1,342 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif

namespace at::native {

using at::blas::ScalingType;
using at::blas::SwizzleType;

namespace {
/*
 * Scaling Type Determination:
 * ---------------------------
 * Conditions and corresponding Scaling Types:
 *
 * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
 *   - Returns BlockWise (with additional size checks).
 *
 * - Else if scale.numel() == 1:
 *   - Returns TensorWise.
 *
 * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
 * 1:
 *   - Returns RowWise.
 *
 * - Otherwise:
 *   - Returns Error.
 */

bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return at::isFloat8Type(t.scalar_type()) &&
      scale.scalar_type() == at::kFloat && scale.numel() == 1;
}

bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return (
      at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
      scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
      scale.is_contiguous());
}

bool is_desired_scaling(
    const at::Tensor& t,
    const at::Tensor& scale,
    ScalingType desired_scaling) {
  auto result = desired_scaling == ScalingType::TensorWise
      ? is_tensorwise_scaling(t, scale)
      : is_rowwise_scaling(t, scale);
  return result;
}

std::pair<ScalingType, ScalingType> get_joint_scaling(
    std::initializer_list<std::pair<ScalingType, ScalingType>> options,
    const at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& scale_a,
    const at::Tensor& scale_b) {
  for (auto [lhs, rhs] : options) {
    if (is_desired_scaling(a, scale_a, lhs) &&
        is_desired_scaling(b.t(), scale_b.t(), rhs)) {
      return {lhs, rhs};
    }
  }
  TORCH_CHECK(
      false,
      "Invalid scaling configuration.\n"
      "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
      "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
      a.size(0),
      ", 1) and scale_b should be (1, ",
      b.size(1),
      "), and both should be contiguous.\n"
      "Got a.dtype()=",
      a.scalar_type(),
      ", scale_a.dtype()=",
      scale_a.scalar_type(),
      ", scale_a.size()=",
      scale_a.sizes(),
      ", scale_a.stride()=",
      scale_a.strides(),
      ", ",
      "b.dtype()=",
      b.scalar_type(),
      ", scale_b.dtype()=",
      scale_b.scalar_type(),
      ", scale_b.size()=",
      scale_b.sizes(),
      " and scale_b.stride()=",
      scale_b.strides());
}

Tensor& _scaled_gemm(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const ScalingType scaling_choice_a,
    const ScalingType scaling_choice_b,
    const std::optional<Tensor>& bias,
    const bool use_fast_accum,
    Tensor& out,
    const std::optional<Tensor>& alpha = std::nullopt) {
  // TODO: scale_result and alpha are not defined or used!
  std::optional<Tensor> scaled_result = std::nullopt;
  at::native::onednn::scaled_matmul(
      mat1,
      mat2,
      out,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      scaled_result,
      use_fast_accum);

  return out;
}

} // namespace

// Computes matrix multiply + bias while applying scaling to input and output
// matrices. Scales are only applicable when matrices are of Float8 type and
// assumed to be equal to 1.0 by default. If output matrix type is 16 or 32-bit
// type, scale_result is not applied. Known limitations:
//  - Only works if mat1 is row-major and mat2 is column-major
//  - Only works if matrices sizes are divisible by 32
//  - If 1-dimensional tensors are used then scale_a should be size =
//  mat1.size(0)
//    and scale_b should have size = to mat2.size(1)
//  Arguments:
//    - `mat1`: the first operand of the matrix multiply, can be type
//    `torch.float8_e4m3fn` or `torch.float8_e5m2`
//    - `mat2`: the second operand of the matrix multiply, can be type
//    `torch.float8_e4m3fn` or `torch.float8_e5m2`
//    - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
//    - `out_dtype`: the output dtype, can either be a float8 or a higher
//    precision floating point type
//    - `scale_a`: a tensor with the inverse scale of `mat1`, whose
//    shape/strides/dtype depend on the scaling scheme
//    - `scale_b`: a tensor with the inverse scale of `mat2`, whose
//    shape/strides/dtype depend on the scaling scheme
//    - `scale_result`: a scalar tensor with the scale of the output, only
//    utilized if the output is a float8 type
//    - `use_fast_accum`: Not applicable for XPU. For now, it should always be
//    false.
//    - `out`: a reference to the output tensor

Tensor& _scaled_mm_out_xpu(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  // Note: fast_accum is not supported in XPU for now.
  TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");

  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");

  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      " and ",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ")");

  // Check what type of scaling we are doing based on inputs. This list is
  // sorted by decreasing priority.

  // List of supported datatypes for XPU with oneDNN:
  // https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
  auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
      {
          std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
          std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
      },
      mat1,
      mat2,
      scale_a,
      scale_b);
  TORCH_CHECK(
      !scale_result ||
          (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
      "scale_result must be a float scalar");
  TORCH_CHECK(
      !bias || bias->numel() == mat2.sizes()[1],
      "Bias must be size ",
      mat2.sizes()[1],
      " but got ",
      bias->numel());
  TORCH_CHECK(
      mat1.sizes()[1] % 16 == 0,
      "Expected trailing dimension of mat1 to be divisible by 16 ",
      "but got mat1 shape: (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      ").");
  TORCH_CHECK(
      mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
      "mat2 shape (",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ") must be divisible by 16");
  // Check types
  TORCH_CHECK(
      !out_dtype || *out_dtype == out.scalar_type(),
      "out_dtype must match output matrix type");
  TORCH_CHECK(
      at::isFloat8Type(mat1.scalar_type()),
      "Expected mat1 to be Float8 matrix got ",
      mat1.scalar_type());
  TORCH_CHECK(
      at::isFloat8Type(mat2.scalar_type()),
      "Expected mat2 to be Float8 matrix got ",
      mat2.scalar_type());
  // TODO: oneDNN currently only supports e4m3 with group scales on BMG. It does
  // not support 2D scales, only 1D. Needs to add more checks there.

  if (bias) {
    TORCH_CHECK(
        bias->scalar_type() == kFloat ||
            bias->scalar_type() == c10::ScalarType::BFloat16 ||
            bias->scalar_type() == c10::ScalarType::Half,
        "Bias must be Float32 or BFloat16 or Half, but got ",
        bias->scalar_type());
  }

  {
    auto bias_ = bias.value_or(Tensor());
    auto scale_result_ = scale_result.value_or(Tensor());

    // NOLINTNEXTLINE(*c-array*)
    TensorArg targs[]{
        {out, "out", 0},
        {mat1, "mat1", 1},
        {mat2, "mat2", 2},
        {bias_, "bias", 3},
        {scale_a, "scale_a", 4},
        {scale_b, "scale_b", 5},
        {scale_result_, "scale_result", 6}};
    checkAllSameGPU(__func__, targs);
  }

  // Validation checks have passed; resize the output to the actual size
  IntArrayRef mat1_sizes = mat1.sizes();
  IntArrayRef mat2_sizes = mat2.sizes();
  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
  // kernels do not support this case).
  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
    // `out` was created with `at::empty`. In the case where we are multiplying
    // MxK by KxN and K is the zero dim, we need to initialize here to properly
    // return a tensor of zeros.
    if (mat1_sizes[1] == 0) {
      out.zero_();
    }

    return out;
  }

  // TODO: scale_result is not supported yet!
  return _scaled_gemm(
      mat1,
      mat2,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      use_fast_accum,
      out);
}

Tensor _scaled_mm_xpu(
    const Tensor& mat_a,
    const Tensor& mat_b,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum) {
  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
  return _scaled_mm_out_xpu(
      mat_a,
      mat_b,
      scale_a,
      scale_b,
      bias,
      scale_result,
      out_dtype,
      use_fast_accum,
      out);
}

} // namespace at::native
@@ -1,3 +1,4 @@
#include <ATen/BlasBackend.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
@@ -8,7 +9,6 @@
#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

at::Tensor broadcast_bias2D(
    at::Tensor& dst,
    at::Tensor& bias,
@@ -328,4 +328,236 @@ void quantized_matmul(
  result.copy_(dst);
}

// Describes how to configure oneDNN scales for a given role/ScalingType
struct ScaleSpec {
  // specifies the way scale values will be applied to an ARG tensor.
  int mask;
  // specifies how scales are grouped along dimensions where
  // multiple scale factors are used.
  dnnl::memory::dims groups;
  // specifies data type for scale factors.
  dnnl::memory::data_type dtype;

  // Helper to compute expected number of elements for scale tensors
  // arg_type: "src" for SRC (groups pattern {1, X}),
  //           "wei" for WEIGHTS (groups pattern {X, 1})
  int64_t expected_numel(
      int64_t outer_dim,
      int64_t inner_dim,
      const std::string& arg_type) const {
    if (groups == dnnl::memory::dims{1, 1})
      return 1; // tensorwise scaling

    TORCH_CHECK(
        arg_type == "src" || arg_type == "wei",
        "Expected arg_type to be 'src' or 'wei', but got '",
        arg_type,
        "'");

    // For rowwise: SRC groups={1, K}, WEI groups={K, 1}
    TORCH_INTERNAL_ASSERT(
        (groups == dnnl::memory::dims{1, inner_dim} ||
         groups == dnnl::memory::dims{inner_dim, 1}),
        "The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
        groups,
        ".");
    return outer_dim;
  }

  // Normalize an incoming scale tensor to contiguous storage and appropriate
  // dtype/view
  at::Tensor normalize(const at::Tensor& scale) const {
    TORCH_INTERNAL_ASSERT(
        dtype == dnnl::memory::data_type::f32,
        "tensor scale currently must be f32, but got scale dtype: ",
        scale.scalar_type());
    return scale.to(at::kFloat).contiguous();
  }
};

// This function defines how to set scales mask and groups according to:
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
// The returned value will be used in
// `set_scales(arg, mask, groups, data_type)`.
inline ScaleSpec make_scale_spec(
    at::blas::ScalingType scaling_type,
    int64_t M,
    int64_t K,
    int64_t N,
    const std::string& arg_type) {
  TORCH_CHECK(
      arg_type == "src" || arg_type == "wei",
      "Expected arg_type to be 'src' or 'wei', but got '",
      arg_type,
      "'");
  TORCH_INTERNAL_ASSERT(
      (scaling_type == at::blas::ScalingType::TensorWise ||
       scaling_type == at::blas::ScalingType::RowWise),
      "Currently only support scaling_type for TensorWise or RowWise");
  int64_t dim = K; // Currently only K is used for grouping
  bool is_src = (arg_type == "src");
  if (scaling_type == at::blas::ScalingType::TensorWise) {
    // Scale tensorwise. The same as `--attr-scales=common`.
    // mask=0 : scale whole tensor
    // groups={1, 1}: indicates that there is only one group for scaling
    return {0, {1, 1}, dnnl::memory::data_type::f32};
  } else {
    // (scaling_type == at::blas::ScalingType::RowWise)
    // Scale RowWise. The same as `--attr-scales=per_dim_01`.
    // mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
    // SRC: groups={1, K}, WEIGHTS: groups={K, 1}
    return {
        (1 << 0) | (1 << 1),
        is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
        dnnl::memory::data_type::f32};
  }
}

sycl::event scaled_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    Tensor& result,
    const Tensor& scale_a,
    const Tensor& scale_b,
    at::blas::ScalingType scaling_choice_a,
    at::blas::ScalingType scaling_choice_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    bool use_fast_accum) {
  auto& engine = GpuEngineManager::Instance().get_engine();
  auto& stream = GpuStreamManager::Instance().get_stream();

  // This function performs the following steps:
  // 1. create memory descriptor
  // 2. call write_to_dnnl_memory() to actually write memory
  // 3. execute

  const int64_t M = mat1.size(0);
  const int64_t K = mat1.size(1);
  const int64_t N = mat2.size(1);

  // 1.1 Create memory descriptor
  dnnl::memory::desc src_md = get_onednn_md(mat1);
  dnnl::memory::desc weights_md = get_onednn_md(mat2);
  dnnl::memory::desc dst_md = get_onednn_md(result);

  // scale_a and scale_b have already been checked in the `is_desired_scaling()`
  // call, so we can directly get their memory descriptors and set them later.
  dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
  dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);

  dnnl::memory::desc bias_md;
  bool with_bias = bias.has_value();
  at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
  if (with_bias) {
    if (possible_reshaped_bias.dim() == 1) {
      possible_reshaped_bias =
          possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
      bias_md = get_onednn_md(possible_reshaped_bias);
    } else {
      bias_md = get_onednn_md(possible_reshaped_bias);
    }
  }

  // 1.2 Create primitive descriptor and set scales mask
  const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
  const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");

  dnnl::primitive_attr op_attr = dnnl::primitive_attr();

#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn())
    op_attr.set_deterministic(true);
#endif

  std::vector<int64_t> default_groups;
  op_attr.set_scales(
      DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
  op_attr.set_scales(
      DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
  // The scale_result tensor currently only supports a scalar (TensorWise scaling).
  bool with_dst_scale = scale_result && scale_result->defined();
  if (with_dst_scale) {
    op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
  }

  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  // 1.3 Create the matmul primitive descriptor
  dnnl::matmul::primitive_desc matmul_pd = with_bias
      ? dnnl::matmul::primitive_desc(
            engine, src_md, weights_md, bias_md, dst_md, op_attr)
      : dnnl::matmul::primitive_desc(
            engine, src_md, weights_md, dst_md, op_attr);

  // 1.4 (Possible) Additional Checks
  // TODO: In case a memory desc does not align with the actual tensor,
  // we might need to reorder weights similar to CPU's reorder_if_differ_in()
  // call, e.g. when the weights differ from matmul_pd.weights_desc().

  // 2. Prepare memory

  // Create memory
  auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
  auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
  dnnl::memory b_usr_m;
  if (with_bias) {
    b_usr_m =
        make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
  }

  // Prepare runtime scale memories (flat 1-D views) using the specs
  auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
                                      int64_t expected_numel,
                                      const at::Tensor& scale_tensor) {
    at::Tensor prepared = spec.normalize(scale_tensor);
    TORCH_CHECK(
        prepared.numel() == expected_numel,
        "Scale buffer length mismatch. Expected ",
        expected_numel,
        ", got ",
        prepared.numel());
    dnnl::memory::desc scale_md(
        {prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
    return make_onednn_memory(scale_md, engine, prepared.data_ptr());
  };

  auto scratchpad =
      make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);

  // 3. Setup Args for exec
  std::unordered_map<int, dnnl::memory> args;
  args.insert({DNNL_ARG_SRC, src_usr_m});
  args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
  args.insert({DNNL_ARG_DST, dst_usr_m});
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
  if (with_bias) {
    args.insert({DNNL_ARG_BIAS, b_usr_m});
  }

  // Attach runtime scales using specs
  auto src_sc_mem = make_scale_mem_from_spec(
      src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
  auto wei_sc_mem = make_scale_mem_from_spec(
      wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
  if (with_dst_scale) {
    // Bind single f32 scalar as DST scale
    at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
    dnnl::memory::desc dst_sc_md(
        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
    auto dst_sc_mem =
        make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
  }

  dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
  sycl::event matmul_fwd_event =
      dnnl::sycl_interop::execute(matmul_p, stream, args);
  return matmul_fwd_event;
}

} // namespace at::native::onednn
@@ -78,6 +78,10 @@ dnnl::memory::data_type get_onednn_dtype(
      return dnnl::memory::data_type::f32;
    case at::ScalarType::BFloat16:
      return dnnl::memory::data_type::bf16;
    case at::ScalarType::Float8_e4m3fn:
      return dnnl::memory::data_type::f8_e4m3;
    case at::ScalarType::Float8_e5m2:
      return dnnl::memory::data_type::f8_e5m2;
    default:
      if (!allow_undef) {
        TORCH_CHECK(
@@ -1,6 +1,7 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@@ -202,4 +203,16 @@ void sdpa_backward(
    Tensor& grad_query,
    Tensor& grad_key,
    Tensor& grad_value);

sycl::event scaled_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    Tensor& result,
    const Tensor& scale_a,
    const Tensor& scale_b,
    at::blas::ScalingType scaling_choice_a,
    at::blas::ScalingType scaling_choice_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    bool use_fast_accum);
} // namespace at::native::onednn
@@ -118,6 +118,11 @@ if(INTERN_BUILD_ATEN_OPS)
        list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
      endif()
    endif()
    if("${_arch}" STREQUAL "121a")
      if(_existing_arch_flags MATCHES ".*compute_120.*")
        list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
      endif()
    endif()
  endforeach()
  list(JOIN _file_compile_flags " " _file_compile_flags)

@@ -126,7 +131,7 @@ if(INTERN_BUILD_ATEN_OPS)

  _BUILD_FOR_ADDITIONAL_ARCHS(
    "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
    "89;90a;100a;103a;120a")
    "89;90a;100a;103a;120a;121a")
  _BUILD_FOR_ADDITIONAL_ARCHS(
    "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
    "90a")
docs/source/accelerator/device.md (new file, 113 lines)
@@ -0,0 +1,113 @@
# Device Management

## Background

Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.

The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.

## Design

Accelerator vendors need to implement these core functions:

| Function Name | Description | Application Scenarios |
| --- | --- | --- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |

These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.

## Implementation

This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs

### C++ Side

Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
   :linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
   :linenos:
```

### Binding

Expose the C++ functions to Python using pybind11:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
   :end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
   :linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
   :end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
   :linenos:
   :emphasize-lines: 5
```

### Python Side

Wrap the C++ bindings with user-friendly Python functions:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
   :language: python
   :start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
   :end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
   :linenos:
```

Here's the complete mapping from C++ to Python:

| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| --- | --- | --- | --- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
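
A minimal usage sketch of the Python user APIs in the table above (illustrative only; it assumes the OpenReg example extension is built and importable as `torch_openreg`, and that more than one device is reported):

```python
import torch
import torch_openreg  # assumption: importing this registers the backend as torch.openreg

count = torch.openreg.device_count()              # backed by _getDeviceCount
print(f"{count} OpenReg device(s) available")

if count > 1:
    torch.openreg.set_device(1)                   # backed by _setDevice
    assert torch.openreg.current_device() == 1    # backed by _getDevice
```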

## Guard

Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.

Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :linenos:
```

**What needs to be implemented:**

1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend

This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
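
At the Python level the same switch-and-restore behavior can be emulated with the functions above; the context manager below is a hypothetical helper (not part of the OpenReg example) that mirrors what the C++ guard does on construction and destruction:

```python
import contextlib

import torch


@contextlib.contextmanager
def openreg_device(index):
    # Remember the current device, switch, and always restore on exit,
    # mirroring exchangeDevice()/setDevice() in the guard implementation.
    previous = torch.openreg.current_device()
    torch.openreg.set_device(index)
    try:
        yield
    finally:
        torch.openreg.set_device(previous)
```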

[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"
@@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1

device
hooks
autoload
operators
@@ -376,3 +376,19 @@ keep-runtime-typing = true

[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"

[tool.spin]
package = 'torch'

[tool.spin.commands]
"Build" = [
    ".spin/cmds.py:lint",
    ".spin/cmds.py:fixlint",
    ".spin/cmds.py:quicklint",
    ".spin/cmds.py:quickfix",
]
"Regenerate" = [
    ".spin/cmds.py:regenerate_version",
    ".spin/cmds.py:regenerate_type_stubs",
    ".spin/cmds.py:regenerate_clangtidy_files",
]
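With these entries registered, the commands defined in `.spin/cmds.py` become available through spin's CLI. A hedged sketch of driving a few of them programmatically (it assumes `spin` is installed, as added to the requirements below, and that click's default dash-for-underscore command naming applies):

```python
import subprocess

# Run some of the registered commands from the repository root.
for command in ("lint", "quicklint", "regenerate-version"):
    subprocess.run(["spin", command], check=True)
```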
@@ -14,6 +14,7 @@ lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
optree>=0.13.0
psutil
spin
sympy>=1.13.3
typing-extensions>=4.13.2
wheel
@@ -4,17 +4,12 @@

#include <c10/util/Exception.h>

void orCheckFail(
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg = "");
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");

#define OPENREG_CHECK(EXPR, ...) \
  do { \
    const orError_t __err = EXPR; \
    if (__err != orSuccess) { \
      orCheckFail( \
          __func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
    } \
#define OPENREG_CHECK(EXPR, ...) \
  do { \
    const orError_t __err = EXPR; \
    if (C10_UNLIKELY(__err != orSuccess)) { \
      orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
    } \
  } while (0)
@@ -1,3 +1,4 @@
#include <c10/util/Exception.h>
#include <include/openreg.h>

#include "OpenRegException.h"
@@ -9,21 +10,22 @@ orError_t GetDeviceCount(int* dev_count) {
  return orGetDeviceCount(dev_count);
}

orError_t GetDevice(c10::DeviceIndex* device) {
orError_t GetDevice(DeviceIndex* device) {
  int tmp_device = -1;
  auto err = orGetDevice(&tmp_device);
  *device = static_cast<c10::DeviceIndex>(tmp_device);
  *device = static_cast<DeviceIndex>(tmp_device);
  return err;
}

orError_t SetDevice(c10::DeviceIndex device) {
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {
  int cur_device = -1;
  orGetDevice(&cur_device);
  OPENREG_CHECK(orGetDevice(&cur_device));
  if (device == cur_device) {
    return orSuccess;
  }
  return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION

int device_count_impl() {
  int count = 0;
@@ -31,34 +33,37 @@ int device_count_impl() {
  return count;
}

OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
OPENREG_EXPORT DeviceIndex device_count() noexcept {
  // initialize number of devices only once
  static int count = []() {
    try {
      auto result = device_count_impl();
      TORCH_CHECK(
          result <= std::numeric_limits<c10::DeviceIndex>::max(),
          result <= std::numeric_limits<DeviceIndex>::max(),
          "Too many devices, DeviceIndex overflowed");
      return result;
    } catch (const c10::Error& ex) {
    } catch (const Error& ex) {
      // We don't want to fail, but still log the warning
      // msg() returns the message without the stack trace
      TORCH_WARN("Device initialization: ", ex.msg());
      return 0;
    }
  }();
  return static_cast<c10::DeviceIndex>(count);
  return static_cast<DeviceIndex>(count);
}

OPENREG_EXPORT c10::DeviceIndex current_device() {
  c10::DeviceIndex cur_device = -1;
  GetDevice(&cur_device);
OPENREG_EXPORT DeviceIndex current_device() {
  DeviceIndex cur_device = -1;
  OPENREG_CHECK(GetDevice(&cur_device));
  return cur_device;
}

OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
  SetDevice(device);
// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
  check_device_index(device);
  OPENREG_CHECK(SetDevice(device));
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION

OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
  int current_device = -1;
@@ -71,4 +76,8 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
  return current_device;
}

OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
  check_device_index(to_device);
  return ExchangeDevice(to_device);
}
} // namespace c10::openreg
@@ -9,10 +9,20 @@

namespace c10::openreg {

OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);

OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);

static inline void check_device_index(int64_t device) {
  TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
      "The device index is out of range. It must be in [0, ",
      static_cast<int>(c10::openreg::device_count()),
      "), but got ",
      static_cast<int>(device),
      ".");
}

} // namespace c10::openreg
@@ -2,6 +2,8 @@

namespace c10::openreg {

// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION

} // namespace c10::openreg
@@ -11,6 +11,7 @@

namespace c10::openreg {

// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;

@@ -58,6 +59,7 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {

    set_device(d.index());
  }
  // LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE

  /**
   * Set the current device to c10::Device, without checking for errors
@@ -27,6 +27,10 @@ class TestDevice(TestCase):
        self.assertEqual(torch.accelerator.current_device_index(), 1)
        self.assertEqual(torch.accelerator.current_device_index(), device)

    def test_invalid_device_index(self):
        with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
            torch.accelerator.set_device_index(2)


if __name__ == "__main__":
    run_tests()
@@ -34,18 +34,21 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR

// LITERALINCLUDE START: MODULE SET DEVICE HELPER

PyObject* _setDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
  auto device = THPUtils_unpackLong(arg);

  auto device = THPUtils_unpackDeviceIndex(arg);
  torch::utils::device_lazy_init(at::kPrivateUse1);
  c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
  c10::openreg::set_device(device);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// LITERALINCLUDE END: MODULE SET DEVICE HELPER

PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");
@@ -41,8 +41,13 @@ def current_device():
    return torch_openreg._C._get_device()


# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None:
    return torch_openreg._C._set_device(device)
    if device >= 0:
        torch_openreg._C._set_device(device)


# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION


def init():
@@ -65,6 +65,7 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
curr_backend = dist.get_default_backend_for_device(device_type)


class SimpleModel(nn.Module):
@@ -422,10 +423,10 @@ class TestFullyShard2DStateDict(DTensorTestBase):
    @property
    def backend(self):
        # need to specify gloo backend for testing cpu offload
        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
        return f"cpu:gloo,{device_type}:{curr_backend}"

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_fully_shard_tp_2d_set_full_state_dict(self):
        dummy_model = SimpleModel().to(device_type)
        mesh_2d = init_device_mesh(
@@ -514,8 +515,8 @@ class Test2dFSDP1ParallelIntegration(DTensorTestBase):
        ).to_local()
        self.assertEqual(param_m2, param_m1)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_2d_ddp_integration_functionality(self) -> None:
        model, twod_model, dp_pg = self.init_model(self.device_type)
        optim = torch.optim.Adam(model.parameters(), lr=3e-5)
@@ -566,8 +567,8 @@ class TestNew2dParallelTraining(DTensorTestBase):
        p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
        self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_2d_fsdp_state_enable_extension(self):
        mesh_2d = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
@@ -642,18 +643,18 @@ class TestNew2dParallelTraining(DTensorTestBase):
        # Ensure all params are still the same after optimizer update.
        self._compare_params(model, model_2d)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_2d_e2e_training_default(self):
        self._test_2d_e2e_training()

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_2d_e2e_training_use_orig_params(self):
        self._test_2d_e2e_training(use_orig_params=True)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_2d_e2e_training_not_use_orig_params(self):
        # TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
        # self._test_2d_e2e_training(recompute_activation=True)
@@ -666,10 +667,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
    @property
    def backend(self):
        # need to specify gloo backend for testing cpu offload
        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
        return f"cpu:gloo,{device_type}:{curr_backend}"

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_fsdp_2d_extension(self):
        """
        Test whether _fsdp_extension from FSDPstate has been set correctly.
@@ -700,8 +701,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
        model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
        self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@@ -756,8 +757,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
            torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
        )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_load_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@@ -811,8 +812,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
        self.assertEqual(v1.device_mesh, v2.device_mesh)
        self.assertEqual(v1.placements, v2.placements)

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_comms
    @parametrize("is_even_sharded_model", [True, False])
    def test_2d_optim_state_dict(self, is_even_sharded_model):
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@@ -899,9 +900,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
        else:
            self.assertEqual(new_state, state)

    @skip_if_lt_x_gpu(4)
    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_fsdp1_tp_2d_set_full_state_dict(self):
        """
        This is a workaround for loading full state dict into a FSDP1+TP 2D model.
@@ -29,8 +29,8 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
at_least_x_gpu,
MultiProcessTestCase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
@@ -40,7 +40,6 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_XPU,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

@@ -107,11 +106,9 @@ class ComposabilityTest(MultiProcessTestCase):
def device(self):
return self.rank

@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
def test_pp_and_dcp(self):
"""
Test that pipeline parallelism and distributed checkpointing can be used together and
@@ -201,11 +198,9 @@ class ComposabilityTest(MultiProcessTestCase):

_dcp_test(self)

@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[
@@ -355,11 +350,9 @@ class ComposabilityTest(MultiProcessTestCase):

torch.distributed.destroy_process_group()

@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[
@@ -550,11 +543,9 @@ class ComposabilityTest(MultiProcessTestCase):

torch.distributed.destroy_process_group()

@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[

@@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
@@ -18,8 +17,8 @@ from torch.distributed.algorithms.ddp_comm_hooks import (
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
DistributedTestBase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
@@ -30,9 +29,12 @@ if TEST_WITH_DEV_DBG_ASAN:
sys.exit(0)


device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


def gpus_for_rank(world_size):
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
visible_devices = list(range(torch.accelerator.device_count()))
gpus_per_process = torch.accelerator.device_count() // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@@ -60,27 +62,7 @@ class TestDdpCommHook(nn.Module):
return self.t0(x ** (1 + rank))


class DistributedDataParallelCommHookTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
self._spawn_processes()

def tearDown(self):
try:
os.remove(self.file_name)
except OSError:
pass

def _get_process_group_nccl(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
return dist.distributed_c10d._get_default_group()

class DistributedDataParallelCommHookTest(DistributedTestBase):
@property
def world_size(self):
return 2
@@ -119,14 +101,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
param = next(model.parameters())
return param.grad

@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_allreduce_hook(self):
"""
This unit test verifies the ``allreduce`` hook registered case gives same result
with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)

# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@@ -135,14 +117,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):

torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_fp16compress_hook(self):
"""
This unit test verifies the ``fp16 compress`` hook registered case
gives close result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)

# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@@ -151,14 +133,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):

torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_tensor_hook(self):
"""
This unit test verifies the ``quantize per tensor`` hook registered case
gives close result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)

# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@@ -167,14 +149,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):

torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_channel_hook(self):
"""
This unit test verifies the ``quantize per channel`` hook registered case
gives close result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)

# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@@ -185,14 +167,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):

torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_noop_hook(self):
"""
This unit test verifies the ``noop`` hook registered case and a subsequent allreduce
gives same result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)

# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@@ -204,10 +186,10 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):

torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_is_last_hook(self):
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)

def hook(flags, bucket):
flags.append(bucket.is_last())

@@ -32,7 +32,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
class TestStateDictUtils(DTensorTestBase):
@property
def world_size(self):
return min(4, torch.cuda.device_count())
return min(4, torch.accelerator.device_count())

@with_comms
@skip_if_lt_x_gpu(2)
@@ -49,7 +49,7 @@ class TestStateDictUtils(DTensorTestBase):
dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertTrue(gathered_state_dict["dtensor"].is_cuda)
self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type)

@with_comms
@skip_if_lt_x_gpu(4)
@@ -69,14 +69,16 @@ class TestStateDictUtils(DTensorTestBase):
)
if dist.get_rank() in (0, 2):
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertFalse(gathered_state_dict["dtensor"].is_cuda)
self.assertNotEqual(
gathered_state_dict["dtensor"].device.type, self.device_type
)
else:
self.assertEqual(gathered_state_dict, {})

@with_comms
@skip_if_lt_x_gpu(4)
def test_cpu_and_ranks_only(self):
device = torch.device("cuda")
device = torch.device(self.device_type)
state_dict = {
"tensor1": torch.arange(10, device=device),
"tensor2": torch.ones(10, device=device),
@@ -85,7 +87,7 @@ class TestStateDictUtils(DTensorTestBase):
cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2))
if dist.get_rank() in (0, 2):
for v in cpu_state_dict.values():
self.assertFalse(v.is_cuda)
self.assertNotEqual(v.device.type, self.device_type)
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
else:
@@ -109,27 +111,27 @@ class TestStateDictUtils(DTensorTestBase):
for _ in range(10):
tensor, dtensor = create_dtensor()
ltensor.append(tensor)
ltensor.append(torch.ones(10, device=torch.device("cuda")))
ltensor.append(torch.ones(10, device=torch.device(self.device_type)))
ldtensor.append(dtensor)
ldtensor.append(torch.ones(10, device=torch.device("cuda")))
ldtensor.append(torch.ones(10, device=torch.device(self.device_type)))

tensor, dtensor = create_dtensor()
dist_state_dict = {
"local": dtensor,
"list": ldtensor,
"arange": torch.arange(10, device=torch.device("cuda")),
"arange": torch.arange(10, device=torch.device(self.device_type)),
}
state_dict = {
"local": tensor,
"list": ltensor,
"arange": torch.arange(10, device=torch.device("cuda")),
"arange": torch.arange(10, device=torch.device(self.device_type)),
}
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))

@with_comms
@skip_if_lt_x_gpu(2)
def test_create_cpu_state_dict(self):
device = torch.device("cuda")
device = torch.device(self.device_type)
rank = dist.get_rank()
# Scale tensors based on world size
# to fit in the tensor shards accurately.
@@ -149,7 +151,7 @@ class TestStateDictUtils(DTensorTestBase):
metadata=ShardMetadata(
shard_offsets=[5 * rank, 0],
shard_sizes=[5, 10],
placement=f"rank:{rank}/cuda:{rank}",
placement=f"rank:{rank}/{self.device_type}:{rank}",
),
)
],
@@ -159,7 +161,7 @@ class TestStateDictUtils(DTensorTestBase):
torch.arange(50 * scale_factor, device=device).reshape(
5 * scale_factor, 10
),
init_device_mesh("cuda", mesh_shape=(self.world_size,)),
init_device_mesh(self.device_type, mesh_shape=(self.world_size,)),
[Shard(0)],
),
"non_tensor_bytes_io": copy.deepcopy(buffer),
@@ -245,7 +247,7 @@ class TestStateDictUtils(DTensorTestBase):
even_tensor = torch.randn(self.world_size, 2)
uneven_tensor = torch.randn(1, 2)

mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
even_dtensor = distribute_tensor(
torch.randn(self.world_size, 2), mesh, [Shard(0)]
)
@@ -273,10 +275,10 @@ class TestStateDictUtils(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
def test_cpu_offload_for_dtensor(self):
device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
sd = {
"k": DTensor.from_local(
torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)]
torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)]
)
}
cpu_sd = _create_cpu_state_dict(sd)
@@ -290,12 +292,12 @@ class TestStateDictUtils(DTensorTestBase):

self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.cuda.synchronize()
torch.accelerator.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
sd["k"] += 1
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.cuda.synchronize()
torch.accelerator.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))

@@ -7,7 +7,7 @@
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
@ -40,7 +40,6 @@ from torch.testing._internal.common_distributed import (
|
||||
skip_if_rocm_multiprocess,
|
||||
skip_if_win32,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -57,7 +56,17 @@ except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
|
||||
|
||||
device_type = str(get_devtype())
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deterministic_algorithms(enabled=True):
|
||||
prev_state = torch.are_deterministic_algorithms_enabled()
|
||||
torch.use_deterministic_algorithms(enabled)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.use_deterministic_algorithms(prev_state)
|
||||
|
||||
|
||||
class TestZeroRedundancyOptimizer(DistributedTestBase):
|
||||
@ -1241,7 +1250,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
enabled=True, deterministic=True, benchmark=False
|
||||
)
|
||||
if "cuda" in device
|
||||
else torch.use_deterministic_algorithms(True)
|
||||
else deterministic_algorithms(True)
|
||||
)
|
||||
with det_ctx:
|
||||
device_ids = [rank] if requires_ddp_rank(device) else None
|
||||
|
||||
@ -32,6 +32,7 @@ from torch.distributed.tensor._ops._einsum_strategy import (
|
||||
)
|
||||
from torch.distributed.tensor._ops.utils import (
|
||||
register_op_strategy,
|
||||
register_single_dim_strategy,
|
||||
replicate_op_strategy,
|
||||
)
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
@ -655,5 +656,202 @@ TestStrategyHashingWithLocalTensor = create_local_tensor_test_class(
|
||||
TestStrategyHashing,
|
||||
)
|
||||
|
||||
|
||||
class TestSingleDimStrategy(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_register_single_dim_strategy_replaces_existing_rule(self):
|
||||
"""
|
||||
Test that calling register_single_dim_strategy works and replaces an existing registered rule.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_mm_like_strategy,
|
||||
gen_single_dim_einsum_strategies,
|
||||
)
|
||||
|
||||
mesh = self.build_device_mesh()
|
||||
|
||||
# Create test inputs
|
||||
lhs_tensor = torch.randn(6, 8)
|
||||
rhs_tensor = torch.randn(8, 12)
|
||||
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
|
||||
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
|
||||
|
||||
# Test a specific input sharding combination
|
||||
lhs_placement = (Shard(1),)
|
||||
rhs_placement = (Shard(0),)
|
||||
lhs_spec = DTensorSpec(mesh, lhs_placement, lhs_tensor_meta)
|
||||
rhs_spec = DTensorSpec(mesh, rhs_placement, rhs_tensor_meta)
|
||||
|
||||
# Create the OpSchema for mm operation
|
||||
op_schema = OpSchema(
|
||||
torch.ops.aten.mm.default,
|
||||
(
|
||||
OpStrategy([OpSpec(lhs_spec)]),
|
||||
OpStrategy([OpSpec(rhs_spec)]),
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
# Get the strategies from the old mm_like_strategy (what was used before)
|
||||
old_style_strategy = _mm_like_strategy("mk,kn->mn", mesh, op_schema)
|
||||
|
||||
# Get the strategies from the new register_single_dim_strategy approach
|
||||
# First, we need to get the single dim strategy function
|
||||
def mm_single_dim_strategy_func(op_schema: OpSchema):
|
||||
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
|
||||
|
||||
# Now expand it to full strategy using the same logic as register_single_dim_strategy
|
||||
single_dim_strategies = mm_single_dim_strategy_func(op_schema)
|
||||
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
|
||||
strategy_combs = itertools.product(*all_mesh_dim_strategies)
|
||||
all_strategies = []
|
||||
for strategy_comb in strategy_combs:
|
||||
spec_list = [
|
||||
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
|
||||
]
|
||||
all_strategies.append(
|
||||
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
|
||||
)
|
||||
new_style_strategy = OpStrategy(all_strategies)
|
||||
|
||||
# Verify that both strategies produce the same set of shardings
|
||||
old_strategy_set = {str(strategy) for strategy in old_style_strategy.strategies}
|
||||
new_strategy_set = {str(strategy) for strategy in new_style_strategy.strategies}
|
||||
|
||||
self.assertEqual(
|
||||
old_strategy_set,
|
||||
new_strategy_set,
|
||||
"Old and new strategies should produce the same shardings",
|
||||
)
|
||||
|
||||
# Verify that the registration actually works by checking the propagator
|
||||
propagator = DTensor._op_dispatcher.sharding_propagator
|
||||
|
||||
# Save the original strategy if it exists
|
||||
original_strategy = None
|
||||
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
|
||||
original_strategy = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
|
||||
|
||||
try:
|
||||
# Register a custom single-dim strategy
|
||||
@register_single_dim_strategy(torch.ops.aten.mm.default)
|
||||
def custom_mm_single_dim_strategy(op_schema: OpSchema):
|
||||
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
|
||||
|
||||
# Verify the strategy was registered
|
||||
self.assertIn(
|
||||
torch.ops.aten.mm.default,
|
||||
propagator.op_strategy_funcs,
|
||||
"Strategy should be registered after calling register_single_dim_strategy",
|
||||
)
|
||||
|
||||
# Verify it replaced any existing rule
|
||||
registered_func = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
|
||||
self.assertIsNotNone(
|
||||
registered_func, "Registered strategy function should not be None"
|
||||
)
|
||||
|
||||
# Test that the registered strategy produces valid output
|
||||
result_strategy = registered_func(op_schema)
|
||||
self.assertIsInstance(
|
||||
result_strategy, OpStrategy, "Result should be an OpStrategy"
|
||||
)
|
||||
self.assertGreater(
|
||||
len(result_strategy.strategies),
|
||||
0,
|
||||
"Strategy should contain at least one OpSpec",
|
||||
)
|
||||
|
||||
finally:
|
||||
# Restore original strategy if it existed
|
||||
if original_strategy is not None:
|
||||
propagator.op_strategy_funcs[torch.ops.aten.mm.default] = (
|
||||
original_strategy
|
||||
)
|
||||
else:
|
||||
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
|
||||
del propagator.op_strategy_funcs[torch.ops.aten.mm.default]
|
||||
# Clear the cache
|
||||
propagator.propagate_op_sharding.cache.cache_clear()
|
||||
|
||||
@with_comms
|
||||
def test_single_dim_strategy_shardings_match_full_strategy(self):
|
||||
"""
|
||||
Verify that the shardings produced by a single-dim strategy match those produced
|
||||
by the full strategy implementation.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
gen_single_dim_einsum_strategies,
|
||||
)
|
||||
|
||||
mesh = self.build_device_mesh()
|
||||
|
||||
# Create test inputs
|
||||
lhs_tensor = torch.randn(6, 8)
|
||||
rhs_tensor = torch.randn(8, 12)
|
||||
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
|
||||
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
|
||||
|
||||
# Test multiple input sharding combinations
|
||||
mm_combs = (
|
||||
(Shard(0), Replicate()),
|
||||
(Replicate(), Shard(1)),
|
||||
(Shard(1), Shard(0)),
|
||||
(Replicate(), Replicate()),
|
||||
)
|
||||
|
||||
for lhs_placement, rhs_placement in mm_combs:
|
||||
lhs_spec = DTensorSpec(mesh, (lhs_placement,), lhs_tensor_meta)
|
||||
rhs_spec = DTensorSpec(mesh, (rhs_placement,), rhs_tensor_meta)
|
||||
|
||||
op_schema = OpSchema(
|
||||
torch.ops.aten.mm.default,
|
||||
(
|
||||
OpStrategy([OpSpec(lhs_spec)]),
|
||||
OpStrategy([OpSpec(rhs_spec)]),
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
# Get single-dim strategies
|
||||
single_dim_strategies = gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
|
||||
|
||||
# Expand to full strategy (mimicking what register_single_dim_strategy does)
|
||||
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
|
||||
strategy_combs = itertools.product(*all_mesh_dim_strategies)
|
||||
expanded_strategies = []
|
||||
for strategy_comb in strategy_combs:
|
||||
spec_list = [
|
||||
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
|
||||
]
|
||||
expanded_strategies.append(
|
||||
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
|
||||
)
|
||||
|
||||
# Verify that for the given input shardings, we can find a matching strategy
|
||||
# with zero redistribute cost
|
||||
found_zero_cost_strategy = False
|
||||
for strategy in expanded_strategies:
|
||||
if strategy.input_specs == (lhs_spec, rhs_spec):
|
||||
# This strategy should have zero redistribute cost since inputs match
|
||||
found_zero_cost_strategy = True
|
||||
# In a real strategy, redistribute costs would be computed
|
||||
# Here we just verify the structure is correct
|
||||
self.assertEqual(
|
||||
len(strategy.input_specs),
|
||||
2,
|
||||
"MM should have exactly 2 input specs",
|
||||
)
|
||||
self.assertIsNotNone(
|
||||
strategy.output_specs, "Output spec should not be None"
|
||||
)
|
||||
break
|
||||
|
||||
self.assertTrue(
|
||||
found_zero_cost_strategy,
|
||||
f"Should find a strategy matching input shardings {lhs_placement}, {rhs_placement}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@@ -331,6 +331,25 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
|
||||
self.assertEqual(z.placements, (Replicate(),))
|
||||
self.assertEqual(z.to_local(), input)
|
||||
|
||||
def test_inplace_op_partial_to_replicate(self):
|
||||
# test that in-place operations that require redistribution raise an error
|
||||
# to preserve aliasing semantics (issue #163374)
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
input_tensor = torch.tensor(64.0, device=self.device_type)
|
||||
partial_dt = DTensor.from_local(
|
||||
input_tensor, device_mesh, placements=(Partial(),)
|
||||
)
|
||||
|
||||
self.assertTrue(partial_dt.placements[0].is_partial())
|
||||
|
||||
# Inplace ops that require placement changes (Partial -> Replicate) should error
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"in-place operations that require placement changes are not supported",
|
||||
):
|
||||
partial_dt.clamp_(max=10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -24,7 +24,7 @@ from torch.distributed._functional_collectives import (
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
|
||||
from torch.testing._internal.common_device_type import e4m3_type
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
DistributedTestBase,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
@ -59,12 +59,8 @@ if not dist.is_available():
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
class TestWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
class TestWithNCCL(DistributedTestBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
@ -78,16 +74,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
return torch.device(self.rank)
|
||||
|
||||
def _init_process_group(self) -> None:
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
backend = dist.get_default_backend_for_device(self.device.type)
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
self.create_pg(self.device.type)
|
||||
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
||||
@ -11,13 +11,10 @@ if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfHpu,
|
||||
TEST_CUDA,
|
||||
TEST_HPU,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
@ -29,16 +26,8 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
if TEST_HPU:
|
||||
DEVICE = "hpu"
|
||||
elif TEST_CUDA:
|
||||
DEVICE = "cuda"
|
||||
else:
|
||||
DEVICE = "cpu"
|
||||
|
||||
device_module = torch.get_device_module(DEVICE)
|
||||
device_count = device_module.device_count()
|
||||
BACKEND = dist.get_default_backend_for_device(DEVICE)
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
device_count = torch.accelerator.device_count()
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
@ -49,11 +38,10 @@ def with_comms(func=None):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if DEVICE != "cpu" and device_count < self.world_size:
|
||||
if device_type != "cpu" and device_count < self.world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||
|
||||
kwargs["device"] = DEVICE
|
||||
self.pg = self.create_pg(device=DEVICE)
|
||||
self.pg = self.create_pg(device=device_type)
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
@ -64,7 +52,7 @@ def with_comms(func=None):
|
||||
|
||||
class TestObjectCollectives(DistributedTestBase):
|
||||
@with_comms()
|
||||
def test_all_gather_object(self, device):
|
||||
def test_all_gather_object(self):
|
||||
output = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(object_list=output, obj=self.rank)
|
||||
|
||||
@ -72,7 +60,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(i, v, f"rank: {self.rank}")
|
||||
|
||||
@with_comms()
|
||||
def test_gather_object(self, device):
|
||||
def test_gather_object(self):
|
||||
output = [None] * dist.get_world_size() if self.rank == 0 else None
|
||||
dist.gather_object(obj=self.rank, object_gather_list=output)
|
||||
|
||||
@ -82,7 +70,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_send_recv_object_list(self, device):
|
||||
def test_send_recv_object_list(self):
|
||||
val = 99 if self.rank == 0 else None
|
||||
object_list = [val] * dist.get_world_size()
|
||||
if self.rank == 0:
|
||||
@ -96,7 +84,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(None, object_list[0])
|
||||
|
||||
@with_comms()
|
||||
def test_broadcast_object_list(self, device):
|
||||
def test_broadcast_object_list(self):
|
||||
val = 99 if self.rank == 0 else None
|
||||
object_list = [val] * dist.get_world_size()
|
||||
# TODO test with broadcast_object_list's device argument
|
||||
@ -105,7 +93,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(99, object_list[0])
|
||||
|
||||
@with_comms()
|
||||
def test_scatter_object_list(self, device):
|
||||
def test_scatter_object_list(self):
|
||||
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
|
||||
output_list = [None]
|
||||
dist.scatter_object_list(
|
||||
@ -123,34 +111,30 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
my_pg = dist.new_group(ranks, use_local_synchronization=True)
|
||||
return rank, ranks, my_pg
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_scatter_object(self, device):
|
||||
def test_subpg_scatter_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None]
|
||||
dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
|
||||
self.assertEqual(rank, out_list[0])
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_all_gather_object(self, device):
|
||||
def test_subpg_all_gather_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None] * len(ranks)
|
||||
dist.all_gather_object(out_list, rank, group=my_pg)
|
||||
self.assertEqual(ranks, out_list)
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_gather_object(self, device):
|
||||
def test_subpg_gather_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None] * len(ranks) if rank == ranks[0] else None
|
||||
dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
|
||||
if rank == ranks[0]:
|
||||
self.assertEqual(ranks, out_list)
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_broadcast_object(self, device):
|
||||
def test_subpg_broadcast_object(self):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None]
|
||||
if rank == ranks[0]:
|
||||
@ -159,7 +143,5 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(ranks[0], out_list[0])
|
||||
|
||||
|
||||
devices = ("cpu", "cuda", "hpu")
|
||||
instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -29,7 +29,7 @@ from torch.distributed.tensor._collective_utils import (
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import _Partial, Shard
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -58,7 +58,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
|
||||
os.environ["LOCAL_RANK"] = f"{local_rank}"
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
|
||||
@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.")
|
||||
class DeviceMeshTestGlooBackend(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
|
||||
@ -208,6 +208,21 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
|
||||
)
|
||||
self.assertEqual(y, expected)
|
||||
|
||||
def test_get_remote_tensors(self) -> None:
|
||||
"""
|
||||
Get all remote tensors
|
||||
"""
|
||||
self._init_device()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
my_tensor = symm_mem.empty(1, device=self.device).fill_(self.rank)
|
||||
remote_tensors = torch.ops.symm_mem.get_remote_tensors(my_tensor, group_name)
|
||||
dist.barrier()
|
||||
|
||||
for peer, tensor in enumerate(remote_tensors):
|
||||
self.assertEqual(tensor, peer)
|
||||
|
||||
@skipIfRocm
|
||||
def test_nvshmem_put(self) -> None:
|
||||
self._init_device()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -13,13 +15,16 @@ import torch._inductor.config
|
||||
import torch._inductor.test_case
|
||||
import torch.onnx.operators
|
||||
import torch.utils.cpp_extension
|
||||
from torch._dynamo.aot_compile import ModelInput, SerializableCallable
|
||||
from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
|
||||
from torch._dynamo.exc import PackageError, Unsupported
|
||||
from torch._dynamo.package import DynamoCache
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
from torch.testing._internal.common_utils import instantiate_parametrized_tests
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
TEST_CUDA,
|
||||
)
|
||||
|
||||
|
||||
MY_LAMBDA = lambda x: x + 1 # noqa: E731
|
||||
@ -599,6 +604,92 @@ from user code:
|
||||
actual = compiled_fn(*inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "requires cuda")
|
||||
def test_aot_compile_with_aoti(self):
|
||||
with torch.device("cuda"):
|
||||
from torch._dynamo.hooks import Hooks
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(3, 4), torch.randn(3, 4))
|
||||
|
||||
compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
|
||||
fn,
|
||||
(make_inputs(), {}),
|
||||
Hooks(),
|
||||
torch._TorchCompileAOTInductorWrapper(None, None, None),
|
||||
)
|
||||
|
||||
test_inputs = make_inputs()
|
||||
expected = fn(*test_inputs)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
compiled_fn.save_compiled_function(self.path())
|
||||
with open(self.path(), "rb") as f:
|
||||
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "requires cuda")
|
||||
def test_aot_compile_with_aoti_module(self):
|
||||
with torch.device("cuda"):
|
||||
from torch._dynamo.hooks import Hooks
|
||||
|
||||
mod = SimpleLinearModule()
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(4, 3),)
|
||||
|
||||
compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
|
||||
mod,
|
||||
[ModelInput(make_inputs(), {}, [])],
|
||||
Hooks(),
|
||||
torch._TorchCompileAOTInductorWrapper(None, None, None),
|
||||
)
|
||||
|
||||
def get_grads(m: torch.nn.Module):
|
||||
return {name: p.grad for name, p in m.named_parameters()}
|
||||
|
||||
original_mod = copy.deepcopy(mod)
|
||||
test_inputs = make_inputs()
|
||||
expected = mod(*test_inputs)
|
||||
expected.sum().backward()
|
||||
expected_grads = get_grads(mod)
|
||||
|
||||
actual = compiled_mod(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
serialized = compiled_mod.serialize()
|
||||
compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
actual.sum().backward()
|
||||
self.assertEqual(get_grads(original_mod), expected_grads)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "requires cuda")
|
||||
def test_aot_compile_with_aoti_torch_compile(self):
|
||||
with torch.device("cuda"):
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(3, 4), torch.randn(3, 4))
|
||||
|
||||
compiled_fn = torch.compile(
|
||||
fn, fullgraph=True, options={"use_aoti": True}
|
||||
).aot_compile((make_inputs(), {}))
|
||||
test_inputs = make_inputs()
|
||||
expected = fn(*test_inputs)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
compiled_fn.save_compiled_function(self.path())
|
||||
with open(self.path(), "rb") as f:
|
||||
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||
actual = compiled_fn(*test_inputs)
|
||||
self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@@ -952,7 +952,9 @@ User code traceback:
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: skip: from user code at:
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
assert x is None
|
||||
""",
|
||||
@ -1078,6 +1080,88 @@ from user code:
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback(self, records):
|
||||
def fn(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break under GenericContextWrappingVariable
|
||||
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
|
||||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skip_frame_in_loop_message(self, records):
|
||||
def fn(x):
|
||||
for i in range(2):
|
||||
with GenericCtxMgr():
|
||||
if x.sum() > 0:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
if x.sum() > 0:
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(dynamo=logging.DEBUG)
|
||||
def test_skip_frame_empty_function_message(self, records):
|
||||
def empty_fn(x):
|
||||
pass
|
||||
|
||||
torch.compile(empty_fn, backend="eager")(torch.randn(3))
|
||||
skip_messages = [
|
||||
r
|
||||
for r in records
|
||||
if "intentionally decided to skip the frame" in r.getMessage()
|
||||
]
|
||||
self.assertEqual(len(skip_messages), 1)
|
||||
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
|
||||
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
|
||||
|
||||
self.assertExpectedInline(
|
||||
msg,
|
||||
"""\
|
||||
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
|
||||
Reason: no content in function call empty_fn test_error_messages.py N""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_nested_compile_user_frames(self, records):
|
||||
def fn(x):
|
||||
@ -1624,6 +1708,110 @@ from user code:
|
||||
)
|
||||
|
||||
|
||||
class NestedGraphBreakLoggingTests(
|
||||
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback_nested(self, records):
|
||||
global f1, f2, f3
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def f1(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 2)
|
||||
|
||||
def f3(x):
|
||||
return f2(x + 3)
|
||||
|
||||
torch.compile(f3, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break in user code at test_error_messages.py:N
|
||||
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
|
||||
Graph break under GenericContextWrappingVariable
|
||||
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
|
||||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
|
||||
User code traceback:
|
||||
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
|
||||
torch.compile(f3, backend="eager")(torch.randn(3))
|
||||
File "test_error_messages.py", line N, in f3
|
||||
return f2(x + 3)
|
||||
File "test_error_messages.py", line N, in f2
|
||||
return f1(x + 2)
|
||||
File "test_error_messages.py", line N, in f1
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skip_frame_in_loop_message_nested(self, records):
|
||||
global f1, f2, f3
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def f1(x):
|
||||
for i in range(2):
|
||||
with GenericCtxMgr():
|
||||
if x.sum() > 0:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 4)
|
||||
|
||||
def f3(x):
|
||||
return f2(x + 5)
|
||||
|
||||
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break in user code at test_error_messages.py:N
|
||||
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
|
||||
Data-dependent branching
|
||||
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
|
||||
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
|
||||
Hint: Use `torch.cond` to express dynamic control flow.
|
||||
|
||||
Developer debug context: attempted to jump with TensorVariable()
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
|
||||
User code traceback:
|
||||
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
|
||||
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
|
||||
File "test_error_messages.py", line N, in f3
|
||||
return f2(x + 5)
|
||||
File "test_error_messages.py", line N, in f2
|
||||
return f1(x + 4)
|
||||
File "test_error_messages.py", line N, in f1
|
||||
if x.sum() > 0:
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
||||
@ -14036,6 +14036,44 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
|
||||
except Exception as e:
|
||||
self.fail(f"torch.compile failed with error: {e}")
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_tensorify_track_item_symint(self):
|
||||
def _random_resize(image: torch.Tensor):
|
||||
image_metanet = image
|
||||
default_patch_size = 14
|
||||
rand_cnn_resolution = (224, 256)
|
||||
min_nump = rand_cnn_resolution[0] // default_patch_size
|
||||
max_nump = rand_cnn_resolution[1] // default_patch_size
|
||||
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
|
||||
torch._check(new_nump > 0)
|
||||
torch._check(new_nump * default_patch_size > 1)
|
||||
|
||||
image_metanet = F.interpolate(
|
||||
image_metanet,
|
||||
size=(new_nump * default_patch_size, new_nump * default_patch_size),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
img_h_new, img_w_new = image_metanet.shape[2:]
|
||||
|
||||
return (img_h_new, img_w_new), image_metanet
|
||||
|
||||
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
|
||||
|
||||
# Test the function
|
||||
input_tensor = torch.rand(1, 3, 224, 224)
|
||||
(h, w), output = _random_resize_compiled(input_tensor)
|
||||
|
||||
# Verify output properties
|
||||
self.assertEqual(output.shape[0], 1)
|
||||
self.assertEqual(output.shape[1], 3)
|
||||
self.assertEqual(output.shape[2], h)
|
||||
self.assertEqual(output.shape[3], w)
|
||||
self.assertTrue(h % 14 == 0)
|
||||
self.assertTrue(w % 14 == 0)
|
||||
self.assertTrue(224 <= h <= 256)
|
||||
self.assertTrue(224 <= w <= 256)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -3249,7 +3249,14 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
V_sliced = V[:, :, :-128]
|
||||
|
||||
out_eager = flex_attention(Q, K_sliced, V_sliced)
|
||||
out_compiled = func(Q, K_sliced, V_sliced)
|
||||
|
||||
out_compiled, code = run_and_get_code(func, Q, K_sliced, V_sliced)
|
||||
|
||||
# Make sure flex attention kernels have flex_attention in name
|
||||
FileCheck().check_regex("triton_tem_fused_flex_attention.*").run(code[0])
|
||||
FileCheck().check_regex("triton_tem_fused_flex_attention_backward.*").run(
|
||||
code[1]
|
||||
)
|
||||
|
||||
grad = torch.rand_like(out_eager)
|
||||
|
||||
|
||||
@@ -2439,6 +2439,35 @@ class _TorchCompileInductorWrapper:
|
||||
reset_cudagraph_trees()
|
||||
|
||||
|
||||
class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
|
||||
compiler_name = "aotinductor"
|
||||
|
||||
def __init__(self, mode, options, dynamic):
|
||||
super().__init__(mode, options, dynamic)
|
||||
self.apply_options({"cpp_wrapper": True})
|
||||
self.apply_options({"aot_inductor.package": True})
|
||||
|
||||
def __call__(self, model_, inputs_):
|
||||
from contextlib import nullcontext
|
||||
from unittest import mock
|
||||
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
fake_mode = detect_fake_mode(inputs_)
|
||||
ctx = (
|
||||
mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
|
||||
if fake_mode
|
||||
else nullcontext()
|
||||
)
|
||||
with (
|
||||
V.set_aot_compilation(True),
|
||||
ctx,
|
||||
torch._inductor.config.patch("enable_autograd_for_aot", True),
|
||||
):
|
||||
return super().__call__(model_, inputs_)
|
||||
|
||||
|
||||
class _TorchCompileWrapper:
|
||||
def __init__(self, backend, mode, options, dynamic):
|
||||
from torch._dynamo.backends.registry import lookup_backend
|
||||
@ -2672,8 +2701,10 @@ def compile(
|
||||
backend = bisect_backend
|
||||
|
||||
guard_filter_fn = None
|
||||
use_aoti = False
|
||||
if options and isinstance(options, dict):
|
||||
guard_filter_fn = options.pop("guard_filter_fn", None)
|
||||
use_aoti = options.pop("use_aoti", False)
|
||||
|
||||
if torch.compiler.is_exporting():
|
||||
warnings.warn(
|
||||
@ -2700,7 +2731,10 @@ def compile(
|
||||
return export_wrapped_fn
|
||||
|
||||
if backend == "inductor":
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
if use_aoti:
|
||||
backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
backend = _TorchCompileWrapper(backend, mode, options, dynamic)
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ class CompileArtifacts:
|
||||
argdefs: Optional[tuple[Any, ...]]
|
||||
source_info: "SourceInfo"
|
||||
device_type: str
|
||||
backend_name: str
|
||||
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
|
||||
|
||||
def check_compatibility(self) -> None:
|
||||
@ -166,7 +167,8 @@ class AOTCompiledFunction:
|
||||
state = pickle.loads(data)
|
||||
state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
|
||||
deserializer, compiled_fn_state = state["compiled_fn"]
|
||||
state["compiled_fn"] = deserializer(compiled_fn_state)
|
||||
with torch._inductor.config.patch(enable_autograd_for_aot=True):
|
||||
state["compiled_fn"] = deserializer(compiled_fn_state)
|
||||
state["original_code"] = SerializedCode.to_code_object(state["original_code"])
|
||||
|
||||
artifacts = CompileArtifacts(**state)
|
||||
@ -273,6 +275,7 @@ def aot_compile_fullgraph(
|
||||
argdefs=fn.__defaults__,
|
||||
source_info=source_info,
|
||||
device_type=device_type,
|
||||
backend_name=getattr(backend, "compiler_name", "unknown"),
|
||||
)
|
||||
aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)
|
||||
|
||||
|
||||
@ -1870,7 +1870,7 @@ class ConvertFrame:
|
||||
raise
|
||||
|
||||
soft_fail = isinstance(e, Unsupported)
|
||||
|
||||
code = frame.f_code
|
||||
# This is a soft failure. In the sense, the code path reaches here
|
||||
# when we do not support graph breaks on bytecodes like LOAD_ATTR,
|
||||
# BUILD_SET etc. In such case, we can fallback to eager without
|
||||
@ -1885,7 +1885,13 @@ class ConvertFrame:
|
||||
user_stack_formatted = "".join(
|
||||
traceback.format_list(user_stack)
|
||||
)
|
||||
user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
|
||||
frame_info = exc.format_frame_info(code)
|
||||
user_stack_trace = (
|
||||
"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
|
||||
f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
|
||||
"The graph break occurred in the following user code:\n"
|
||||
f"{user_stack_formatted}"
|
||||
)
|
||||
torch._logging.trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
@ -1897,6 +1903,7 @@ class ConvertFrame:
|
||||
graph_break_log.debug(
|
||||
user_stack_trace,
|
||||
exc_info=True,
|
||||
stack_info=config.verbose,
|
||||
)
|
||||
|
||||
if not config.suppress_errors and not soft_fail:
|
||||
|
||||
@ -794,6 +794,38 @@ def format_error_msg_verbose(
|
||||
return msg
|
||||
|
||||
|
||||
def format_frame_info(code: types.CodeType) -> str:
|
||||
return (
|
||||
f"{getattr(code, 'co_name', '<unknown>')} "
|
||||
f"({getattr(code, 'co_filename', '<unknown>')} "
|
||||
f"line {getattr(code, 'co_firstlineno', 0)})"
|
||||
)
|
||||
|
||||
|
||||
def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
|
||||
if code is not None:
|
||||
frame_info = format_frame_info(code)
|
||||
return (
|
||||
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
|
||||
f"Reason: {reason}"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
|
||||
f"Reason: {reason}"
|
||||
)
|
||||
|
||||
|
||||
def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
|
||||
frame_info = format_frame_info(code)
|
||||
return (
|
||||
"Skipping frame because there is a graph break in a for/while loop\n"
|
||||
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
|
||||
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
|
||||
f"{frame_summary}"
|
||||
)
|
||||
|
||||
|
||||
def format_error_msg(
|
||||
exc: Exception,
|
||||
code: types.CodeType,
|
||||
|
||||
@@ -94,6 +94,8 @@ from .exc import (
BackendCompilerFailed,
collapse_resume_frames,
format_graph_break_message,
format_loop_skip_frame_message,
format_skip_frame_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
StepUnsupported,

@@ -605,9 +607,9 @@ def generic_jump(
)
# compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg)

@@ -883,9 +885,9 @@ def break_graph_if_unsupported(
)

if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg) from excp

@@ -4626,8 +4628,9 @@ class InstructionTranslator(InstructionTranslatorBase):
and not self.error_on_graph_break
and not self.is_tracing_resume_prologue
):
raise exc.SkipFrame("because no content in function call")

raise exc.SkipFrame(
format_skip_frame_message(self.f_code, "no content in function call")
)
self.instruction_pointer = None
_step_logger()(
logging.INFO,

@@ -2248,12 +2248,15 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
try:
val.data_ptr() # will throw for functorch tensors
except RuntimeError as e:
from .exc import SkipFrame
from .exc import format_skip_frame_message, SkipFrame

# This will be GradTrackingTensor/BatchedTensor/etc
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame(
f"torch.compile cannot be run in context: {functorch_subclass_name}"
format_skip_frame_message(
None,
f"torch.compile cannot be run in context: {functorch_subclass_name}",
)
) from e

@@ -42,6 +42,7 @@ from torch._guards import Source
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
format_skip_frame_message,
get_dynamo_observed_exception,
handle_observed_exception,
InfiniteGeneratorError,

@@ -1652,8 +1653,13 @@ class SkipFunctionVariable(VariableTracker):
skip_frame_msg = kwargs.get("msg")
if skip_frame_msg:
skip_frame_msg = skip_frame_msg.as_python_constant()
else:
skip_frame_msg = ""
raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
format_skip_frame_message(
tx.f_code,
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
)
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported

@@ -3652,24 +3652,26 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
# - lifted args from tracing subgraph: [score_mod_other_buffers, mask_fn_other_buffers]
_, _, _, inp_arg_block_mask, inp_arg_scale, inp_arg_kernel_options = inp_args
block_mask = tuple(inp_arg_block_mask + (mask_fn_node,))
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=inp_args[:3]
+ (
score_mod_node,
block_mask,
inp_arg_scale,
inp_arg_kernel_options,
score_mod_lifted_args,
mask_fn_lifted_args,
with torch.fx.experimental.proxy_tensor.set_original_aten_op(self.value):
proxy = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=inp_args[:3]
+ (
score_mod_node,
block_mask,
inp_arg_scale,
inp_arg_kernel_options,
score_mod_lifted_args,
mask_fn_lifted_args,
),
kwargs={},
),
kwargs={},
),
example_value=None,
)
example_value=None,
)
return proxy


class AutogradFunctionApplyVariable(VariableTracker):

@@ -511,6 +511,7 @@ class GenericAOTAutogradResult(Generic[TForward, TBackward]):
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_fw_func._boxed_call = True
disable_amp = torch._C._is_any_autocast_enabled()

if needs_autograd:

@@ -356,9 +356,10 @@ def trace_flex_attention(
)
# pyrefly: ignore [missing-attribute]
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", flex_attention, proxy_args, {}
)
with torch.fx.experimental.proxy_tensor.set_original_aten_op(flex_attention):
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", flex_attention, proxy_args, {}
)
return track_tensor_tree(
example_out,
out_proxy,

@@ -1114,23 +1115,26 @@ def flex_attention_backward_proxy_torch_dispatch_mode(
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
]:
assert mode is not None, "Mode should always be enabled for python fallback key"
return trace_flex_attention_backward(
mode,
query,
key,
value,
out,
logsumexp,
grad_out,
grad_logsumexp,
fw_graph,
joint_graph,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)
with torch.fx.experimental.proxy_tensor.set_original_aten_op(
flex_attention_backward
):
return trace_flex_attention_backward(
mode,
query,
key,
value,
out,
logsumexp,
grad_out,
grad_logsumexp,
fw_graph,
joint_graph,
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
)


@flex_attention_backward.py_functionalize_impl

@@ -1640,7 +1640,9 @@ class _InProcessFxCompile(FxCompile):
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)
return CompiledAOTI(
filename=compiled_fn, device_type=graph.device_type
)

# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore [unbound-name]

@@ -2713,7 +2715,7 @@ def _compile_fx_main(
or torch._guards.TracingContext(fake_mode)
)

if V.aot_compilation:
if V.aot_compilation and not config.enable_autograd_for_aot:
from .utils import is_valid_aoti_model_name

is_valid_aoti_model_name()

@@ -1190,6 +1190,8 @@ autotune_lookup_table: dict[str, dict[str, Any]] = {}

file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))

enable_autograd_for_aot: bool = False


def get_worker_log_path() -> Optional[str]:
log_loc = None

@@ -773,9 +773,86 @@ class CompiledAOTI(OutputCode):
"""

filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
device_type: str
current_callable: Optional[Callable[..., Any]] = None
_cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)

def __post_init__(self):
if not config.aot_inductor.link_libtorch:
return

if (
torch._inductor.cpp_builder._IS_MACOS
or torch._inductor.cpp_builder._IS_WINDOWS
):
return

if config.aot_inductor.cross_target_platform == "windows":
return

if config.aot_inductor.package_cpp_only:
return

if not config.enable_autograd_for_aot:
return

if isinstance(self.filename, list):
current_callable = next(
fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
)
else:
current_callable = self.filename

if isinstance(current_callable, torch.fx.GraphModule):
self.current_callable = current_callable
return

if self.device_type.startswith("cuda"):
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg]
current_callable,
1,
self.device_type,
"",
True,
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
elif self.device_type == "cpu":
current_callable = (
torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg]
current_callable, 1
).run # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
raise RuntimeError(f"unsupported device type {self.device_type}")
self.current_callable = current_callable
self._boxed_call = True
for file in self._cached_files:
if not os.path.exists(file):
with open(file, "wb") as f:
f.write(self._cached_files[file])

def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")
if self.current_callable is None:
raise RuntimeError("AOTInductor compiled so is not loaded")
return self.current_callable(inputs)

def prepare_for_serialization(self) -> None:
self.current_callable = None
self._cached_files = {}
filenames: list[str] = []
if isinstance(self.filename, list):
filenames = self.filename # type: ignore[assignment]
elif isinstance(self.filename, str):
filenames = [self.filename]
for name in filenames:
with open(name, "rb") as f:
self._cached_files[name] = f.read()

def __getstate__(self):
state = self.__dict__.copy()
state["current_callable"] = None
return state

def post_compile(
self,

@@ -783,10 +860,8 @@ class CompiledAOTI(OutputCode):
constants: CompiledFxGraphConstants,
graph_kwargs: _CompileFxKwargs,
) -> None:
pass

def prepare_for_serialization(self) -> None:
pass
if self.current_callable is None:
self.__post_init__()

def set_triton_bundle(self, triton_bundle: Any) -> None:
pass
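The new __post_init__/__call__ pair above follows a lazy load-then-dispatch pattern. Below is a minimal, self-contained sketch of that pattern with a made-up FakeRunner standing in for the real torch._C._aoti.AOTIModelContainerRunner* bindings; all names in this sketch are illustrative, not part of the patch.

from typing import Any, Callable, Optional, Sequence


class FakeRunner:
    def run(self, inputs: Sequence[Any]) -> Any:
        return [x * 2 for x in inputs]  # stand-in for the compiled model


class LazyCompiledArtifact:
    def __init__(self) -> None:
        self.current_callable: Optional[Callable[[Sequence[Any]], Any]] = None

    def load(self) -> None:
        # In CompiledAOTI.__post_init__ this is where the .so is wrapped in an
        # AOTI model container runner and its bound .run method is stored.
        self.current_callable = FakeRunner().run

    def __call__(self, inputs: Sequence[Any]) -> Any:
        if self.current_callable is None:
            raise RuntimeError("compiled artifact is not loaded")
        return self.current_callable(inputs)


art = LazyCompiledArtifact()
art.load()
print(art([1, 2, 3]))  # [2, 4, 6]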
@@ -781,9 +781,19 @@ def get_fused_kernel_name(
) -> str:
all_origins = aggregate_origins(node_schedule)
if descriptive_names == "original_aten":

def get_origin_meta_str(origin):
original_aten = origin.meta["original_aten"]
key = ""
if isinstance(original_aten, torch._ops.OpOverload):
key = original_aten._overloadpacket.__name__
elif isinstance(original_aten, torch._ops.HigherOrderOperator):
key = str(original_aten.name())
return key

# Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
sources = [
origin.meta["original_aten"]._overloadpacket.__name__
get_origin_meta_str(origin)
for origin in all_origins
if origin.op == "call_function"
and "original_aten" in origin.meta

@@ -794,12 +804,22 @@ def get_fused_kernel_name(
# Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
sources = []
for origin in all_origins:
if origin.op == "call_function" and "source_fn_stack" in origin.meta:
source_fn = origin.meta["source_fn_stack"][-1]
if origin.op == "call_function":
source_fn = None
suffix = ""
if "source_fn_stack" in origin.meta:
source_fn = origin.meta["source_fn_stack"][-1]
elif "fwd_source_fn_stack" in origin.meta:
# backward nodes have "fwd_source_fn_stack" instead
source_fn = origin.meta["fwd_source_fn_stack"][-1]
suffix = "backward"
if not source_fn:
continue
if isinstance(source_fn[1], str):
sources.append(source_fn[1])
sources.append(source_fn[1] + suffix)
else:
sources.append(source_fn[1].__name__)
sources.append(source_fn[1].__name__ + suffix)

sources = sorted(OrderedSet(sources))
elif descriptive_names == "inductor_node":
sources = [

@@ -852,11 +872,20 @@ def get_kernel_metadata(

for node in inductor_nodes:
if "original_aten" in node.meta and node.meta["original_aten"] is not None:
key = str(node.meta["original_aten"]._overloadpacket)
original_aten_dict[key].append(node.name)
original_aten = node.meta["original_aten"]
key = None
if isinstance(original_aten, torch._ops.OpOverload):
key = str(original_aten._overloadpacket)
elif isinstance(original_aten, torch._ops.HigherOrderOperator):
key = str(original_aten.name())
if key:
original_aten_dict[key].append(node.name)
if "from_node" in node.meta:
key = node.meta["from_node"][0].name
from_node_dict[key].append(node.name)
elif node.meta.get("partitioner_tag") == "is_backward":
# backward nodes currently don't have a "from node"
from_node_dict[node.name].append(node.name)
sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
metadata = (
f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "

@@ -891,10 +891,14 @@ class TorchLogsFormatter(logging.Formatter):
# exception handling - copied from logging.Formatter.format
s = record.message
if record.exc_info:
from torch._dynamo import config

should_format_exc = config.verbose or artifact_name != "graph_breaks"
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if should_format_exc:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if s[-1:] != "\n":
s = s + "\n"
@ -24,7 +24,73 @@ __all__ = [
class EventList(list):
"""A list of Events (for pretty printing)."""
"""A list of profiling events with helper methods for analysis and visualization.

EventList extends the standard Python list to provide specialized methods for
working with profiling events (FunctionEvent or FunctionEventAvg objects).
It includes utilities for aggregating statistics, formatting output tables,
and exporting profiling data.

This class is typically returned by profiler methods and should not be
instantiated directly by users.

Args:
*args: Standard list arguments.
use_device (str, optional): Device type for profiling ("cuda", "xpu", etc.).
profile_memory (bool, optional): Whether memory profiling was enabled. Default: False.
with_flops (bool, optional): Whether to include FLOP counts. Default: False.

Attributes:
_use_device (str): Device type being profiled.
_profile_memory (bool): Whether memory profiling is enabled.
_with_flops (bool): Whether FLOP counting is enabled.
_tree_built (bool): Whether the event tree structure has been built.

Key Methods:
table(...): Format events as a table string for display.
export_chrome_trace(path): Export to Chrome tracing format.
export_stacks(path, metric): Export stack traces with metrics.
key_averages(...): Compute averaged statistics grouped by operation name.
total_average(): Compute aggregate totals across all events (sums, not averages).

Properties:
self_cpu_time_total: Sum of self CPU time across all events.

Example::

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
x = torch.randn(100, 100)
y = torch.matmul(x, x)

# EventList is returned by prof.events()
events = prof.events()

# Display as formatted table
print(
events.table(
sort_by="cpu_time_total", row_limit=20, top_level_events_only=False
)
)

# Export to Chrome tracing format
events.export_chrome_trace("trace.json")

# Get averaged statistics
avg_events = events.key_averages()
print(avg_events.table())

# Export stack traces
events.export_stacks("stacks.txt", "self_cpu_time_total")

See Also:
- :class:`FunctionEvent`: Individual profiling event
- :class:`FunctionEventAvg`: Averaged profiling statistics
- :meth:`table`: Format events as a readable table
- :meth:`key_averages`: Aggregate events by operation name
"""

def __init__(self, *args, **kwargs):
use_device = kwargs.pop("use_device", None)

@@ -373,10 +439,23 @@ class EventList(list):
return avg_list

def total_average(self):
"""Averages all events.
"""Compute aggregate statistics across all events.

Accumulates statistics from all events into a single FunctionEventAvg object.
This is primarily useful for computing total metrics (total CPU time, total
memory usage, etc.) across the entire profiling session, regardless of
operation type.

Note:
This sums up times and counts across ALL different operations, so the
"average" metrics (like cpu_time) represent the average time per operation
call across the entire session, mixing all operation types together.
For per-operation averages, use :meth:`key_averages` instead.

Returns:
A FunctionEventAvg object.
FunctionEventAvg: A single aggregate object with key="Total" containing
accumulated statistics.

"""
total_stat = FunctionEventAvg()
for evt in self:

@@ -471,7 +550,64 @@ Kernel = namedtuple("Kernel", ["name", "device", "duration"])


class FunctionEvent(FormattedTimesMixin):
"""Profiling information about a single function."""
"""Profiling information about a single function.

FunctionEvent records the execution of a single operation during profiling.
These events are obtained from the profiler/kineto and contain detailed
timing and memory usage information.

.. note::
FunctionEvent objects are typically created by the profiler/kineto and should not
be instantiated directly by users. Access them through the profiler's output.

Attributes:
id (int): Unique identifier for this event.
node_id (int): Node identifier for distributed profiling (-1 if not applicable).
name (str): Name of the profiled function/operator.
overload_name (str): Overload name for the operator (requires _ExperimentalConfig(capture_overload_names=True) set).
trace_name (str): Same as name, just changes ProfilerStep* to ProfilerStep#.
time_range (Interval): Time interval containing start and end timestamps in microseconds.
thread (int): Thread ID where the operation started.
fwd_thread (int): Thread ID of the corresponding forward operation.
kernels (List[Kernel]): List of device kernels launched by this operation.
count (int): Number of times this event was called (usually 1).
cpu_children (List[FunctionEvent]): Direct CPU child operations.
cpu_parent (FunctionEvent): Direct CPU parent operation.
input_shapes (Tuple[int, ...]): Shapes of input tensors (requires record_shapes=true).
concrete_inputs (List[Any]): Concrete input values (requires record_shapes=true).
kwinputs (Dict[str, Any]): Keyword arguments (requires record_shapes=true).
stack (List[str]): Python stack trace where the operation was called (requires with_stack=true).
scope (int): at::RecordScope identifier (0=forward, 1=backward, etc.).
use_device (str): Device type being profiled ("cuda", "xpu", etc.).
cpu_memory_usage (int): CPU memory allocated in bytes.
device_memory_usage (int): Device memory allocated in bytes.
is_async (bool): Whether this is an asynchronous operation.
is_remote (bool): Whether this operation occurred on a remote node.
sequence_nr (int): Sequence number for autograd operations.
device_type (DeviceType): Type of device (CPU, CUDA, XPU, PrivateUse1, etc.).
device_index (int): Index of the device (e.g., GPU 0, 1, 2).
device_resource_id (int): Resource ID on the device (i.e. stream ID).
is_legacy (bool): Whether this is from the legacy profiler.
flops (int): Estimated floating point operations.
is_user_annotation (bool): Whether this is a user-annotated region.
metadata_json (str): Additional metadata in JSON format.

Properties:
cpu_time_total (float): Total CPU time in microseconds.
device_time_total (float): Total device (CUDA/XPU/etc) time in microseconds.
self_cpu_time_total (float): CPU time excluding child operations.
self_device_time_total (float): Device time excluding child operations.
self_cpu_memory_usage (int): CPU memory usage excluding child operations.
self_device_memory_usage (int): Device memory usage excluding child operations.
cpu_time (float): Average CPU time per call.
device_time (float): Average device time per call.
key (str): Key used for grouping events (usually same as name).

See Also:
- :class:`torch.profiler.profile`: Context manager for profiling
- :class:`EventList`: List container for FunctionEvent objects with helper methods
- :class:`FunctionEventAvg`: Averaged statistics over multiple FunctionEvent objects
"""

def __init__(
self,

@@ -701,7 +837,50 @@ class FunctionEvent(FormattedTimesMixin):


class FunctionEventAvg(FormattedTimesMixin):
"""Used to average stats over multiple FunctionEvent objects."""
"""Averaged profiling statistics over multiple FunctionEvent objects.

FunctionEventAvg aggregates statistics from multiple FunctionEvent objects
with the same key (typically same operation name). This is useful for getting
average performance metrics across multiple invocations of the same operation.

This class is typically created by calling :meth:`EventList.key_averages()` on
a profiler's event list.

Attributes:
key (str): Grouping key for the events (typically operation name).
count (int): Total number of events aggregated.
node_id (int): Node identifier for distributed profiling (-1 if not applicable).
is_async (bool): Whether the operations are asynchronous.
is_remote (bool): Whether the operations occurred on a remote node.
use_device (str): Device type being profiled ("cuda", "xpu", etc.).
cpu_time_total (int): Accumulated total CPU time in microseconds.
device_time_total (int): Accumulated total device time in microseconds.
self_cpu_time_total (int): Accumulated self CPU time (excluding children) in microseconds.
self_device_time_total (int): Accumulated self device time (excluding children) in microseconds.
input_shapes (List[List[int]]): Input tensor shapes (requires record_shapes=true).
overload_name (str): Operator overload name (requires _ExperimentalConfig(capture_overload_names=True) set).
stack (List[str]): Python stack trace where the operation was called (requires with_stack=true).
scope (int): at::RecordScope identifier (0=forward, 1=backward, etc.).
cpu_memory_usage (int): Accumulated CPU memory usage in bytes.
device_memory_usage (int): Accumulated device memory usage in bytes.
self_cpu_memory_usage (int): Accumulated self CPU memory usage in bytes.
self_device_memory_usage (int): Accumulated self device memory usage in bytes.
cpu_children (List[FunctionEvent]): CPU child events.
cpu_parent (FunctionEvent): CPU parent event.
device_type (DeviceType): Type of device (CPU, CUDA, XPU, PrivateUse1, etc.).
is_legacy (bool): Whether from legacy profiler.
flops (int): Total floating point operations.
is_user_annotation (bool): Whether this is a user-annotated region.

Properties:
cpu_time (float): Average CPU time per invocation.
device_time (float): Average device time per invocation.

See Also:
- :class:`EventList.key_averages`: Method that creates FunctionEventAvg objects
- :class:`FunctionEvent`: Individual profiling event
- :class:`EventList`: Container for profiling events
"""

def __init__(self) -> None:
self.key: Optional[str] = None

@@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
int,
const std::string&,
const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&,
const bool>())
.def(
"run",
&AOTIModelContainerRunnerCuda::run,

@@ -465,6 +465,39 @@ lib.define(
"_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
)

lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
"""
Given a local tensor and a group name, return a tuple of tensors that are
symmetric on other devices. The returned tensors are ordered by rank IDs. The
length of the tuple equals the size of the group.

Note: this API works only when `world_within_direct_access()` returns True, i.e.
only when the group is within NVLink domain or similar. It does not work across
network interfaces.
"""


@torch.library.impl(lib, "get_remote_tensors", "CUDA")
def _get_remote_tensors_default(
local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
hdl = rendezvous(local, group_name)
if hdl is None:
raise ValueError("Tensor is not allocated from Symmetric Memory")

return tuple(
hdl.get_remote_tensor(peer, local.size(), local.dtype)
for peer in range(hdl.world_size)
)


@torch.library.impl(lib, "get_remote_tensors", "Meta")
def _get_remote_tensors_meta(
local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
group = c10d._resolve_process_group(group_name)
return tuple(torch.empty_like(local) for _ in range(group.size()))


class _ScaleMode(Enum):
UNSCALED = "unscaled"
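A hedged usage sketch for the op defined above. It assumes the library shown here registers into the symm_mem op namespace, that the script is launched with torchrun on NVLink-connected GPUs, and an arbitrary buffer size; none of that is specified by the patch itself.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def show_peer_buffers() -> None:
    rank = dist.get_rank()
    # allocate from symmetric memory so the rendezvous inside the op succeeds
    buf = symm_mem.empty(8, dtype=torch.float32, device=f"cuda:{rank}")
    buf.fill_(float(rank))
    remote = torch.ops.symm_mem.get_remote_tensors(buf, dist.group.WORLD.group_name)
    # one view per rank, ordered by rank id
    assert len(remote) == dist.get_world_size()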
@@ -337,19 +337,34 @@ class OpDispatcher:
if is_inplace_op:
# inplace op should return self instead of re-wrapping
if output_sharding.output_spec is not None:
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)

# NOTE: aten.squeeze_.dim is an inplace op but it also may change
# the inplace argument's tensor meta. Here we choose to special case
# this op because as far as I know this is the only inplace op that
# has such a behavior. We can extend this special case if necessary.
if op_call == aten.squeeze_.dim:
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)
# update the spec to handle tensor meta changes
args[0]._spec = output_spec
# use return_and_correct_aliasing to match the outer and the inner
# aliasing. See https://github.com/pytorch/pytorch/pull/158954
return return_and_correct_aliasing(op_call, args, kwargs, args[0])
else:
# For all other inplace ops, check if placement changes are required
# Inplace operations that change placement are not supported because
# they would require redistribution, which breaks aliasing semantics.
# If there are views into the tensor, the views would not be updated.
if args[0]._spec.placements != output_spec.placements:
raise RuntimeError(
f"{op_call}: in-place operations that require placement changes "
f"are not supported. The operation would change placement from "
f"{args[0]._spec.placements} to {output_spec.placements}, "
f"which requires redistribution and breaks aliasing semantics. "
f"Please use the out-of-place version of this operation instead."
)
# Most inplace ops don't change tensor meta, so no spec update needed
return args[0]
else:
return None
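The essence of the new guard, reduced to a few lines; the placements here are hypothetical, while in the dispatcher the comparison runs between the in-place argument's current spec and the propagated output spec.

from torch.distributed.tensor.placement_types import Partial, Shard

current = (Shard(0),)      # placements the in-place argument already has
propagated = (Partial(),)  # placements sharding propagation asked for
if current != propagated:
    print(f"in-place op rejected: would change placements {current} -> {propagated}")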
@@ -23,6 +23,7 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
prod,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,

@@ -237,10 +238,119 @@ def dot_strategy(op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("i,i->", mesh, op_schema)


@register_op_strategy(aten.mm.default)
def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# @register_op_strategy(aten.mm.default)
# def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# mesh = op_schema.get_mesh_from_args()
# return _mm_like_strategy("mk,kn->mn", mesh, op_schema)


from ._einsum_strategy import EinsumDims


def gen_single_dim_einsum_strategies(
equation: str,
mesh: DeviceMesh,
*,
linearity: bool = False,
) -> list[list[Placement]]:
"""
Generate a strategy list for the ops that follow einsum style notation.

In principle, each mesh dim is independent of other device mesh dim when we
generate strategies. So we generate strategy over each device mesh dim and
do product combination on all mesh dims. We basically follow the below rule
for each device mesh dim:

1. Shard on contracting dim: When both inputs shard on contracting dim over
the same device dim. The result will be Partial over that device dim.

2. Shard on noncontracting dim:
2.1: Shard on batch dim: output, both inputs all should shard on batch
dim.
2.2: Shard on lhs only dim or rhs only dim: both output and lhs or rhs
input should shard on this free dim.

3. Linearity (Partial): If enabled, set Partial on output and inputs over
the same device mesh dim.
"""
# parse einop equation and extract dims
input_dims, output_dim = EinsumDims.parse_equation(equation)
edims = EinsumDims.parse_dims(input_dims, output_dim)

# generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R]
strategies_over_one_mesh_dim = []

# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
strategies_over_one_mesh_dim.append(placement_list)

# split batch dim
for batch_dim in edims.batch_dims:
output_batch_dim = output_dim.index(batch_dim)
placement_list = [Shard(output_batch_dim)]
for input_dim in input_dims:
input_batch_dim = input_dim.index(batch_dim)
placement_list.append(Shard(input_batch_dim))

strategies_over_one_mesh_dim.append(placement_list)

# split contracting dim
for contracting_dim in edims.contracting_dims:
# Contracting dim can shard on same device axis for both inputs. This
# results in the output being Partial on that device axis. For example:
# bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
placement_list = [Partial()]
for input_dim in input_dims:
input_contracting_dim = input_dim.index(contracting_dim)
placement_list.append(Shard(input_contracting_dim))

strategies_over_one_mesh_dim.append(placement_list)

# split lhs free dim
for lhs_dim in edims.lhs_out_only_dims:
lhs_free_dim_output = output_dim.index(lhs_dim)
lhs_free_dim_input = input_dims[0].index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: list[Placement] = [
Shard(lhs_free_dim_output),
Shard(lhs_free_dim_input),
Replicate(),
]
strategies_over_one_mesh_dim.append(lhs_placement_list)

# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim_output = output_dim.index(rhs_dim)
rhs_free_dim_input = input_dims[1].index(rhs_dim)
rhs_placement_list: list[Placement] = [
Shard(rhs_free_dim_output),
Replicate(),
Shard(rhs_free_dim_input),
]
strategies_over_one_mesh_dim.append(rhs_placement_list)

# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()]
for _ in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)

return strategies_over_one_mesh_dim


@register_single_dim_strategy(aten.mm.default)
def mm_single_dim_strategy(op_schema: OpSchema) -> list[Placement]:
self_strategy, mat2_strategy = op_schema.args_schema
if not isinstance(self_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
if not isinstance(mat2_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
# generate all possible strategies for mm
mesh = op_schema.get_mesh_from_args()
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)


@register_op_strategy(aten.addmm.default)
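Worked illustration of what gen_single_dim_einsum_strategies produces for mm ("mk,kn->mn") on one mesh dim, written out by hand from the rules in the docstring above; each row is [output, lhs, rhs].

from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

single_dim_mm_strategies = [
    [Replicate(), Replicate(), Replicate()],  # fully replicated
    [Partial(), Shard(1), Shard(0)],          # shard contracting dim k; output is partial
    [Shard(0), Shard(0), Replicate()],        # shard lhs-only dim m
    [Shard(1), Replicate(), Shard(1)],        # shard rhs-only dim n
]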
@@ -18,6 +18,7 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
normalize_dim,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor.placement_types import (
Partial,

@@ -488,6 +489,58 @@ def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType:
return pointwise_strategy(op_schema, linearity=linearity_type)


def single_mesh_dim_pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> list[list[Placement]]:
return single_mesh_dim_common_pointwise_strategy(op_schema.args_schema, linearity)

def single_mesh_dim_common_pointwise_strategy(
args_schema: Sequence[object],
linearity: int = -1,
scalar_tensor_idx: Optional[int] = None
) -> list[list[Placement]]:
"""
Common strategy for pointwise operations.

Args:
args_schema: Input arguments schema

linearity: depending on the operator, we support different types of linearity
-1: the operation does not support linearity
0: the unary operation that supports linearity, output propagates partial.
1: the binary operation supports add linearity, where it requires every operand
to be partial, output propagates partial.
2: the binary operation supports multiplicative linearity, where it requires
the primary operand to be partial, and the other operands to be replicate,
output propagates partial.
scalar_tensor_idx: Index of the Replicate scalar tensor for which we allow the mesh
to be different from the mesh of followed_strategy
"""
# handle broadcasting
common_shape = torch.broadcast_shapes(
*[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
)

placements_list = []
for i in range(len(common_shape)):
# Shard output dim i, and then shard the corresponding arguments if they have a corresponding (non broadcast) dim
shard_placements = [Shard(i)]
for arg in args_schema:
if isinstance(arg, OpStrategy):
common_dim_to_arg_dim = infer_broadcast_dims_map(common_shape, arg.shape)
if common_dim_to_arg_dim[i] >= 0:
shard_placements.append(Shard(common_dim_to_arg_dim[i]))
else:
shard_placements.append(Replicate())

placements_list.append(shard_placements)

if linearity > 0:
# TODO implement partial
# TODO: can the same op support both add and multiplicative linearity?
pass

# TODO: handle scalar_tensor_idx
return placements_list

def common_pointwise_strategy(
args_schema: Sequence[object],
followed_strategy: OpStrategy,

@@ -623,11 +676,15 @@ for op in linear_pointwise_ops:
linear_pointwise_strategy
)

for op in pointwise_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
pointwise_strategy
)
# for op in pointwise_ops:
# register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
# pointwise_strategy
# )

for op in pointwise_ops:
register_single_dim_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
single_mesh_dim_pointwise_strategy
)

# TODO: add all for_each ops
for_each_ops = [
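For a broadcasted pointwise call such as torch.add(x, y) with x of shape (4, 8) and y of shape (8,), the single-dim strategy above yields the rows below ([output, x, y]); the shapes are made up for illustration.

from torch.distributed.tensor.placement_types import Replicate, Shard

pointwise_single_dim_rows = [
    [Shard(0), Shard(0), Replicate()],  # output dim 0 has no matching dim in y
    [Shard(1), Shard(1), Shard(0)],     # output dim 1 maps to y's dim 0
]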
@@ -42,6 +42,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true
aten = torch.ops.aten


# WHC- i think anywhere this is used, we can replace it with a corresponding single-dim passthrough strategy
# (anyshard, replicate, partial can all pass through- and then expand that to the mesh dims later)
def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
# For ops with a single tensor input, we perform a 1:1 mapping such that
# for each strategy that the input supports, we create a corresponding strategy.

@@ -98,6 +100,28 @@ register_op_strategy(
)(propagate_single_input_strategy)


"""
WHC- equal_strategy is an example baking an optimization into the sharding rule.

The unoptimized equal strategy (for one mesh dim) should look like this
S, S -> S
R, R -> R
P, P -> P * - this could work, i think, if we supported a Partial of boolean and reduction?
And this should be expanded to the full mesh.

But what this rule actually does is
- compare the two tensor args to equal- look at the strategies for each, which represent the I-O sharding relationship for the
op that produced those tensor args. Pick the one that has the strategy (OpSpec) with the most Shard() placements in it.
Why? because converting the other arg from R->S is cheaper than converting S->R

- start with the assumption that the 'equal' op has the same strategy as the op that produced its max-shard input
- then adjust the placements from partial to replicate since we don't support partial in equal
- finally, produce an OpSpec that only populates the 'output_specs' of OpSpec

TODO: why is it ok to populate only the output_specs of an OpSpec? Is it defined to mean that all input specs are the same as the output spec?
"""


@register_op_strategy(
[
aten.equal.default,

@@ -141,6 +165,19 @@ def equal_strategy(op_schema: OpSchema) -> StrategyType:
return equal_strategy


"""
WHC
seems like we could replace this with single-mesh strategy
S->S
R->R
P->R

The P->R thing is odd, but makes sense:
* can't support P->P since it would be incorrect to create a new 'partial' tensor from ones, which would no longer be ones if we replicated them
* don't want to omit the support for input Partial because we'd force a replication on the input which would be wasteful
"""


@register_op_strategy(
[
aten.empty_like.default,

@@ -489,6 +526,19 @@ def replicate_tensor_dim(
)


"""
WHC- example of a complicated 'follow your inputs' strategy that would be useful to try out as a simple rule

seems very simple to write this way

assert input, src same ndim
for i in range(input.ndim):
if i != slice_dim:
Shard(i), Shard(i) -> Shard(i)

"""


@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
# 1. number of dimensions in input and src need to match.

@@ -4,8 +4,7 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from typing import cast, Optional, Union

import torch
from torch._prims_common import DimsSequenceType, DimsType

@@ -28,10 +27,7 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)


_T = TypeVar("_T")
_P = ParamSpec("_P")
# from torch.testing._internal.distributed._tensor.common_dtensor import redistribute


# convenient wrapper to register sharding propagation rules

@@ -54,11 +50,69 @@ def register_prop_rule(
return wrapper


def register_op_strategy(
op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# pyre-fixme[2]: Parameter must be annotated.
def _expand_single_dim_strategy_to_mesh(single_dim_strategy: Callable[[OpSchema], list[list[Placement]]]) -> Callable[[OpSchema], StrategyType]:
"""
Expands the single_mesh_dim impl across all mesh dims, and expands ShardingPlaceholder into all
sharding types used by inputs.
"""
def expanded_strategy(op_schema: OpSchema) -> StrategyType:

strategies_over_one_mesh_dim = single_dim_strategy(op_schema)
inputs_strategy = op_schema.args_strategy
mesh = inputs_strategy[0].mesh

# TODO: handle 'ShardingPlaceholder' expansion (doesn't exist yet)
# TODO: add Replicate since it's implicit in single_dim strategies
# TODO: filter out 'invalid' placements
# - ShardVar needs to say whether 'even sharding' is required or not

# copied from einsum strategy..
# TODO: identify differences between this and 'expand_' util
all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
arg_specs = spec_list[1:]
src_strategies = [s for s in op_schema.args_schema if isinstance(s, OpStrategy)]
assert len(arg_specs) == len(src_strategies), "expected one src strategy per arg spec"
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:], redistribute_cost=[
generate_redistribute_costs(src_strategy, arg_spec) for (src_strategy, arg_spec) in zip(src_strategies, arg_specs)
])
)

return OpStrategy(all_strategies)

return expanded_strategy

def register_single_dim_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
[Callable[[OpSchema], list[list[Placement]]]], Callable[[OpSchema], list[list[Placement]]]
]:
"""
Registers a simplified op strategy that only considers a single mesh dim, taking care to expand it
to cover all the mesh dims present in the runtime inputs.
"""
def expanded_registration_wrapper(
single_dim_strategy: Callable[[OpSchema], list[list[Placement]]],
) -> Callable[[OpSchema], list[list[Placement]]]:
_expanded_strategy = _expand_single_dim_strategy_to_mesh(single_dim_strategy)
register_op_strategy(op, schema_info)(_expanded_strategy)

return single_dim_strategy

return expanded_registration_wrapper


def register_op_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
# For every ATen op that accepts any args in this list,
# the arg itself can impact the strides (and potentially the sharding strategy)
# of the output tensor.
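A small runnable sketch of the expansion step performed by _expand_single_dim_strategy_to_mesh above: take the cartesian product of the single-dim rows (one pick per mesh dim), then regroup per operand. The two sample rows are borrowed from the mm example and the 2-D mesh size is assumed.

import itertools

from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

single_dim_rows = [
    [Replicate(), Replicate(), Replicate()],
    [Partial(), Shard(1), Shard(0)],
]
mesh_ndim = 2  # assume a 2-D device mesh
for comb in itertools.product(*([single_dim_rows] * mesh_ndim)):
    # zip(*comb) regroups placements per operand across the two mesh dims,
    # e.g. the output operand may end up with (Replicate(), Partial())
    per_operand_placements = [tuple(p) for p in zip(*comb)]
    print(per_operand_placements)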
@@ -68,7 +122,9 @@ def register_op_strategy(
"memory_format",
]

def wrapper(impl):
def wrapper(
impl: Callable[[OpSchema], StrategyType],
) -> Callable[[OpSchema], StrategyType]:
if isinstance(op, list):
overloads = op
else:

@@ -159,7 +215,10 @@ def prod(xs: Iterable[int]) -> int:


def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is shardable according to the spec."""
"""Check if the spec matches these criteria:
* any Shard placements in spec refer to valid tensor dims
* no empty local tensors (uneven sharding OK, as long as last rank has >0 size)
"""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):

@@ -225,6 +284,9 @@ def infer_broadcast_dims_map(
) -> list[int]:
# infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
# this is aligned with the broadcast semantics
# e.g. if common_shape = [1, 2, 3, 4] and input_shape = [2, 3, 4],
# broadcast_dims_map will be [-1, 0, 1, 2]
# meaning that dim 0 in the output has no mapping to the input, and dim 1 in the output maps to dim 0 in the input
common_ndim = len(common_shape)
input_ndim = len(input_shape)
broadcast_dims_map = [-1] * common_ndim

@@ -1543,7 +1543,9 @@ ORIGINAL_ATEN: Optional[object] = None


@contextmanager
def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
def set_original_aten_op(
func: OpOverload | torch._ops.HigherOrderOperator,
) -> Generator[None, None, None]:
global ORIGINAL_ATEN
if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
ORIGINAL_ATEN = func

@@ -207,12 +207,19 @@ def tensorify_python_scalars(
and node.target is torch.ops.aten._local_scalar_dense.default
):
dtype = node.args[0].meta["val"].dtype
if not dtype.is_floating_point:
continue

assert isinstance(node.args[0], fx.Node), node.args[0]

s = node.meta["val"].node.expr

expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)

# only tensorify if the dtype is floating point
if not dtype.is_floating_point:
continue

expr_to_tensor_proxy[s] = MetaProxy(
node.args[0], tracer=tracer, fake_mode=fake_mode
)

@@ -220,9 +227,7 @@ def tensorify_python_scalars(
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
expr_to_tensor_proxy[s], torch.float64
)
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)

# pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(

@@ -43,6 +43,7 @@ from torch.distributed.tensor.parallel import (
SequenceParallel,
)
from torch.testing._internal.common_distributed import (
ACCELERATOR_DIST_BACKENDS,
MultiProcContinuousTest,
MultiProcessTestCase,
MultiThreadedTestCase,

@@ -396,14 +397,17 @@ class DTensorTestBase(MultiProcessTestCase):
return init_device_mesh(self.device_type, (self.world_size,))

def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
if backend is None:
backend = self.backend

requires_gpu = any(
gpu_backend in backend for gpu_backend in ACCELERATOR_DIST_BACKENDS
)
if requires_gpu and torch.accelerator.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

curr_backend = dist.get_default_backend_for_device(self.device_type)

if backend is None:
backend = self.backend

if backend not in [
"nccl",
"gloo",