Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| ------ | ---------- | ------------------------- | --------------------------- |
|        | 57a76b44de | fix all-to-all estimation | 2025-11-14 17:05:47 -08:00  |
34 changed files with 293 additions and 1192 deletions

View File

@ -1,342 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace {
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - Else if scale.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) ==
* 1:
* - Returns RowWise.
*
* - Otherwise:
* - Returns Error.
*/
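For reference, the selection order above can be restated in a few lines of Python. The `classify_scaling` helper below is illustrative only (it is not part of ATen) and mirrors the two helpers implemented next, `is_tensorwise_scaling` and `is_rowwise_scaling`; the BlockWise branch mentioned above is not covered:

```python
import torch

def classify_scaling(t: torch.Tensor, scale: torch.Tensor) -> str:
    # Mirrors the documented priority: TensorWise first, then RowWise, else error.
    is_fp8 = t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    if is_fp8 and scale.dtype == torch.float32 and scale.numel() == 1:
        return "TensorWise"
    if (
        is_fp8
        and scale.dtype == torch.float32
        and scale.dim() == 2
        and scale.size(0) == t.size(0)
        and scale.size(1) == 1
        and scale.is_contiguous()
    ):
        return "RowWise"
    raise ValueError("Invalid scaling configuration")

a = torch.randn(32, 64).to(torch.float8_e4m3fn)
print(classify_scaling(a, torch.tensor(1.0)))   # TensorWise
print(classify_scaling(a, torch.ones(32, 1)))   # RowWise
```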
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return at::isFloat8Type(t.scalar_type()) &&
scale.scalar_type() == at::kFloat && scale.numel() == 1;
}
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (
at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
scale.is_contiguous());
}
bool is_desired_scaling(
const at::Tensor& t,
const at::Tensor& scale,
ScalingType desired_scaling) {
auto result = desired_scaling == ScalingType::TensorWise
? is_tensorwise_scaling(t, scale)
: is_rowwise_scaling(t, scale);
return result;
}
std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a,
const at::Tensor& b,
const at::Tensor& scale_a,
const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) &&
is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
}
TORCH_CHECK(
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
a.size(0),
", 1) and scale_b should be (1, ",
b.size(1),
"), and both should be contiguous.\n"
"Got a.dtype()=",
a.scalar_type(),
", scale_a.dtype()=",
scale_a.scalar_type(),
", scale_a.size()=",
scale_a.sizes(),
", scale_a.stride()=",
scale_a.strides(),
", ",
"b.dtype()=",
b.scalar_type(),
", scale_b.dtype()=",
scale_b.scalar_type(),
", scale_b.size()=",
scale_b.sizes(),
" and scale_b.stride()=",
scale_b.strides());
}
Tensor& _scaled_gemm(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const ScalingType scaling_choice_a,
const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out,
const std::optional<Tensor>& alpha = std::nullopt) {
// TODO: scale_result and alpha are not defined or used!
std::optional<Tensor> scaled_result = std::nullopt;
at::native::onednn::scaled_matmul(
mat1,
mat2,
out,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
scaled_result,
use_fast_accum);
return out;
}
} // namespace
// Computes matrix multiply + bias while applying scaling to the input and output
// matrices. Scales are only applicable when the matrices are of Float8 type and
// are assumed to be 1.0 by default. If the output matrix type is a 16- or 32-bit
// type, scale_result is not applied. Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
// - Only works if matrix sizes are divisible by 32
// - If 1-dimensional tensors are used, then scale_a should have size =
//   mat1.size(0) and scale_b should have size = mat2.size(1)
// Arguments:
// - `mat1`: the first operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `mat2`: the second operand of the matrix multiply, can be type
// `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher
// precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose
// shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only
// utilized if the output is a float8 type
// - `use_fast_accum`: Not applicable for XPU. For now, it should always be
// false.
// - `out`: a reference to the output tensor
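A rough Python-level sketch of this contract is shown below. It assumes a build where the scaled-mm kernel is registered for the chosen device; `torch._scaled_mm` is a private operator whose signature can change between releases, so treat this as illustrative only:

```python
import torch

device = "xpu"  # assumes an XPU build; the same call shape applies on CUDA

M, K, N = 32, 64, 48  # K and N divisible by 16, matching the checks below
a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)       # row-major
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()   # column-major K x N

# TensorWise scaling: a single float32 scale per operand
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([32, 48])
```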
Tensor& _scaled_mm_out_xpu(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
// Note: fast_accum is not supported in XPU for now.
TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");
// Check what type of scaling we are doing based on inputs. This list is
// sorted by decreasing priority.
// List of supported datatypes for XPU with oneDNN:
// https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
},
mat1,
mat2,
scale_a,
scale_b);
TORCH_CHECK(
!scale_result ||
(scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(
!bias || bias->numel() == mat2.sizes()[1],
"Bias must be size ",
mat2.sizes()[1],
" but got ",
bias->numel());
TORCH_CHECK(
mat1.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
").");
TORCH_CHECK(
mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
"mat2 shape (",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
") must be divisible by 16");
// Check types
TORCH_CHECK(
!out_dtype || *out_dtype == out.scalar_type(),
"out_dtype must match output matrix type");
TORCH_CHECK(
at::isFloat8Type(mat1.scalar_type()),
"Expected mat1 to be Float8 matrix got ",
mat1.scalar_type());
TORCH_CHECK(
at::isFloat8Type(mat2.scalar_type()),
"Expected mat2 to be Float8 matrix got ",
mat2.scalar_type());
// TODO: oneDNN currently only supports e4m3 with group scales on BMG, and
// only 1D scales (not 2D). More checks need to be added here.
if (bias) {
TORCH_CHECK(
bias->scalar_type() == kFloat ||
bias->scalar_type() == c10::ScalarType::BFloat16 ||
bias->scalar_type() == c10::ScalarType::Half,
"Bias must be Float32 or BFloat16 or Half, but got ",
bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
auto scale_result_ = scale_result.value_or(Tensor());
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{
{out, "out", 0},
{mat1, "mat1", 1},
{mat2, "mat2", 2},
{bias_, "bias", 3},
{scale_a, "scale_a", 4},
{scale_b, "scale_b", 5},
{scale_result_, "scale_result", 6}};
checkAllSameGPU(__func__, targs);
}
// Validation checks have passed; resize the output to its actual size.
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
// If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
// kernels do not support this case).
if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
// `out` was created with `at::empty`. In the case where we are multiplying
// MxK by KxN and K is the zero dim, we need to initialize here to properly
// return a tensor of zeros.
if (mat1_sizes[1] == 0) {
out.zero_();
}
return out;
}
// TODO: scale_result is not supported yet.
return _scaled_gemm(
mat1,
mat2,
scale_a,
scale_b,
scaling_choice_a,
scaling_choice_b,
bias,
use_fast_accum,
out);
}
Tensor _scaled_mm_xpu(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_xpu(
mat_a,
mat_b,
scale_a,
scale_b,
bias,
scale_result,
out_dtype,
use_fast_accum,
out);
}
} // namespace at::native

View File

@ -1,4 +1,3 @@
#include <ATen/BlasBackend.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
@ -9,6 +8,7 @@
#include <oneapi/dnnl/dnnl.hpp>
namespace at::native::onednn {
at::Tensor broadcast_bias2D(
at::Tensor& dst,
at::Tensor& bias,
@ -328,236 +328,4 @@ void quantized_matmul(
result.copy_(dst);
}
// Describes how to configure oneDNN scales for a given role/ScalingType
struct ScaleSpec {
// specifies the way scale values will be applied to an ARG tensor.
int mask;
// specifies how scales are grouped along dimensions where
// multiple scale factors are used.
dnnl::memory::dims groups;
// specifies data type for scale factors.
dnnl::memory::data_type dtype;
// Helper to compute expected number of elements for scale tensors
// arg_type: "src" for SRC (groups pattern {1, X}),
// "wei" for WEIGHTS (groups pattern {X, 1})
int64_t expected_numel(
int64_t outer_dim,
int64_t inner_dim,
const std::string& arg_type) const {
if (groups == dnnl::memory::dims{1, 1})
return 1; // tensorwise scaling
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
// For rowwise: SRC groups={1, K}, WEI groups={K, 1}
TORCH_INTERNAL_ASSERT(
(groups == dnnl::memory::dims{1, inner_dim} ||
groups == dnnl::memory::dims{inner_dim, 1}),
"The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
groups,
".");
return outer_dim;
}
// Normalize an incoming scale tensor to contiguous storage and appropriate
// dtype/view
at::Tensor normalize(const at::Tensor& scale) const {
TORCH_INTERNAL_ASSERT(
dtype == dnnl::memory::data_type::f32,
"tensor scale currently must be f32, but got scale dtype: ",
scale.scalar_type());
return scale.to(at::kFloat).contiguous();
}
};
// This function defines how to set scales mask and groups according to:
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
// The returned value will be used in
// `set_scales(arg, mask, groups, data_type)`.
inline ScaleSpec make_scale_spec(
at::blas::ScalingType scaling_type,
int64_t M,
int64_t K,
int64_t N,
const std::string& arg_type) {
TORCH_CHECK(
arg_type == "src" || arg_type == "wei",
"Expected arg_type to be 'src' or 'wei', but got '",
arg_type,
"'");
TORCH_INTERNAL_ASSERT(
(scaling_type == at::blas::ScalingType::TensorWise ||
scaling_type == at::blas::ScalingType::RowWise),
"Currently only support scaling_type for TensorWise or RowWise");
int64_t dim = K; // Currently only K is used for grouping
bool is_src = (arg_type == "src");
if (scaling_type == at::blas::ScalingType::TensorWise) {
// Scale tensorwise. The same as `--attr-scales=common`.
// mask=0 : scale whole tensor
// groups={1, 1}: indicates that there is only one group for scaling
return {0, {1, 1}, dnnl::memory::data_type::f32};
} else {
// (scaling_type == at::blas::ScalingType::RowWise)
// Scale RowWise. The same as `--attr-scales=per_dim_01`.
// mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
// SRC: groups={1, K}, WEIGHTS: groups={K, 1}
return {
(1 << 0) | (1 << 1),
is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
dnnl::memory::data_type::f32};
}
}
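The mask/groups convention above can be restated compactly in Python. The `scale_spec` helper below is purely illustrative (it is not part of the oneDNN or ATen APIs) and simply mirrors the values `make_scale_spec` produces for the two supported scaling types:

```python
# Illustrative restatement of make_scale_spec's output for SRC/WEIGHTS.
def scale_spec(scaling_type: str, k: int, arg_type: str) -> dict:
    assert arg_type in ("src", "wei")
    if scaling_type == "TensorWise":
        # --attr-scales=common: a single scale for the whole tensor
        return {"mask": 0, "groups": (1, 1)}
    assert scaling_type == "RowWise"
    # --attr-scales=per_dim_01: scale along both dims; SRC groups span K
    # along dim 1, WEIGHTS groups span K along dim 0
    groups = (1, k) if arg_type == "src" else (k, 1)
    return {"mask": (1 << 0) | (1 << 1), "groups": groups}

print(scale_spec("RowWise", k=64, arg_type="src"))  # {'mask': 3, 'groups': (1, 64)}
print(scale_spec("RowWise", k=64, arg_type="wei"))  # {'mask': 3, 'groups': (64, 1)}
```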
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum) {
auto& engine = GpuEngineManager::Instance().get_engine();
auto& stream = GpuStreamManager::Instance().get_stream();
// This function performs the following steps:
// 1. create memory descriptor
// 2. call write_to_dnnl_memory() to actually write memory
// 3. execute
const int64_t M = mat1.size(0);
const int64_t K = mat1.size(1);
const int64_t N = mat2.size(1);
// 1.1 Create memory descriptor
dnnl::memory::desc src_md = get_onednn_md(mat1);
dnnl::memory::desc weights_md = get_onednn_md(mat2);
dnnl::memory::desc dst_md = get_onednn_md(result);
// scale_a and scale_b have already been checked in the `is_desired_scaling()`
// call, so we can directly get their memory descriptors and set them later.
dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);
dnnl::memory::desc bias_md;
bool with_bias = bias.has_value();
at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
if (with_bias) {
if (possible_reshaped_bias.dim() == 1) {
possible_reshaped_bias =
possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
bias_md = get_onednn_md(possible_reshaped_bias);
} else {
bias_md = get_onednn_md(possible_reshaped_bias);
}
}
// 1.2 Create primitive descriptor and set scales mask
const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");
dnnl::primitive_attr op_attr = dnnl::primitive_attr();
#if ONEDNN_SUPPORT_DETERMINISTIC
if (at::globalContext().deterministicAlgorithms() ||
at::globalContext().deterministicMkldnn())
op_attr.set_deterministic(true);
#endif
std::vector<int64_t> default_groups;
op_attr.set_scales(
DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
op_attr.set_scales(
DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
// The scale_result tensor currently only supports a scalar (TensorWise scaling).
bool with_dst_scale = scale_result && scale_result->defined();
if (with_dst_scale) {
op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
}
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// 1.3 Create the matmul primitive descriptor
dnnl::matmul::primitive_desc matmul_pd = with_bias
? dnnl::matmul::primitive_desc(
engine, src_md, weights_md, bias_md, dst_md, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_md, weights_md, dst_md, op_attr);
// 1.4 (Possible) Additional Checks
// TODO: If a memory descriptor does not match the actual tensor (for example,
// the weights descriptor differs from matmul_pd.weights_desc()), we might need
// to reorder the weights, similar to the CPU's reorder_if_differ_in() call.
// 2. Prepare memory
// Create memory
auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
dnnl::memory b_usr_m;
if (with_bias) {
b_usr_m =
make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
}
// Prepare runtime scale memories (flat 1-D views) using the specs
auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
int64_t expected_numel,
const at::Tensor& scale_tensor) {
at::Tensor prepared = spec.normalize(scale_tensor);
TORCH_CHECK(
prepared.numel() == expected_numel,
"Scale buffer length mismatch. Expected ",
expected_numel,
", got ",
prepared.numel());
dnnl::memory::desc scale_md(
{prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
return make_onednn_memory(scale_md, engine, prepared.data_ptr());
};
auto scratchpad =
make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);
// 3. Setup Args for exec
std::unordered_map<int, dnnl::memory> args;
args.insert({DNNL_ARG_SRC, src_usr_m});
args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
args.insert({DNNL_ARG_DST, dst_usr_m});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, b_usr_m});
}
// Attach runtime scales using specs
auto src_sc_mem = make_scale_mem_from_spec(
src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
auto wei_sc_mem = make_scale_mem_from_spec(
wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
if (with_dst_scale) {
// Bind single f32 scalar as DST scale
at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
dnnl::memory::desc dst_sc_md(
{1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
auto dst_sc_mem =
make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
}
dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
sycl::event matmul_fwd_event =
dnnl::sycl_interop::execute(matmul_p, stream, args);
return matmul_fwd_event;
}
} // namespace at::native::onednn

View File

@ -78,10 +78,6 @@ dnnl::memory::data_type get_onednn_dtype(
return dnnl::memory::data_type::f32;
case at::ScalarType::BFloat16:
return dnnl::memory::data_type::bf16;
case at::ScalarType::Float8_e4m3fn:
return dnnl::memory::data_type::f8_e4m3;
case at::ScalarType::Float8_e5m2:
return dnnl::memory::data_type::f8_e5m2;
default:
if (!allow_undef) {
TORCH_CHECK(

View File

@ -1,7 +1,6 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@ -203,16 +202,4 @@ void sdpa_backward(
Tensor& grad_query,
Tensor& grad_key,
Tensor& grad_value);
sycl::event scaled_matmul(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result,
const Tensor& scale_a,
const Tensor& scale_b,
at::blas::ScalingType scaling_choice_a,
at::blas::ScalingType scaling_choice_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
bool use_fast_accum);
} // namespace at::native::onednn

View File

@ -1,113 +0,0 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
### Binding
Expose the C++ functions to Python using pybind11:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
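Assuming the `torch_openreg` package from this example backend is installed, a minimal usage sketch of the user-facing API in the table above might look like this (the import name and device indices are illustrative):

```python
import torch
import torch_openreg  # registers the OpenReg backend (example package from this guide)

print(torch.openreg.device_count())      # total number of OpenReg devices

torch.openreg.set_device(0)              # make device 0 the active device
print(torch.openreg.current_device())    # -> 0
```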
## Guard
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
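In Python terms, the guard's switch-and-restore behavior is roughly the following sketch, using the generic `torch.accelerator` index APIs; the helper function itself is illustrative, not a PyTorch API, and assumes at least two devices are present:

```python
import torch

def run_on_device(idx: int, fn):
    # What a device guard does: remember the current device, switch, run, restore.
    prev = torch.accelerator.current_device_index()
    torch.accelerator.set_device_index(idx)
    try:
        return fn()
    finally:
        torch.accelerator.set_device_index(prev)

run_on_device(1, lambda: print(torch.accelerator.current_device_index()))  # 1
print(torch.accelerator.current_device_index())  # back to the previous index
```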
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -42,7 +42,6 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
device
hooks
autoload
operators

View File

@ -4,12 +4,17 @@
#include <c10/util/Exception.h>
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg = "");
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (C10_UNLIKELY(__err != orSuccess)) { \
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (__err != orSuccess) { \
orCheckFail( \
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
} while (0)

View File

@ -1,4 +1,3 @@
#include <c10/util/Exception.h>
#include <include/openreg.h>
#include "OpenRegException.h"
@ -10,22 +9,21 @@ orError_t GetDeviceCount(int* dev_count) {
return orGetDeviceCount(dev_count);
}
orError_t GetDevice(DeviceIndex* device) {
orError_t GetDevice(c10::DeviceIndex* device) {
int tmp_device = -1;
auto err = orGetDevice(&tmp_device);
*device = static_cast<DeviceIndex>(tmp_device);
*device = static_cast<c10::DeviceIndex>(tmp_device);
return err;
}
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {
orError_t SetDevice(c10::DeviceIndex device) {
int cur_device = -1;
OPENREG_CHECK(orGetDevice(&cur_device));
orGetDevice(&cur_device);
if (device == cur_device) {
return orSuccess;
}
return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
int device_count_impl() {
int count = 0;
@ -33,37 +31,34 @@ int device_count_impl() {
return count;
}
OPENREG_EXPORT DeviceIndex device_count() noexcept {
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
auto result = device_count_impl();
TORCH_CHECK(
result <= std::numeric_limits<DeviceIndex>::max(),
result <= std::numeric_limits<c10::DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed");
return result;
} catch (const Error& ex) {
} catch (const c10::Error& ex) {
// We don't want to fail, but still log the warning
// msg() returns the message without the stack trace
TORCH_WARN("Device initialization: ", ex.msg());
return 0;
}
}();
return static_cast<DeviceIndex>(count);
return static_cast<c10::DeviceIndex>(count);
}
OPENREG_EXPORT DeviceIndex current_device() {
DeviceIndex cur_device = -1;
OPENREG_CHECK(GetDevice(&cur_device));
OPENREG_EXPORT c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device);
return cur_device;
}
// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
check_device_index(device);
OPENREG_CHECK(SetDevice(device));
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
SetDevice(device);
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1;
@ -76,8 +71,4 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
return current_device;
}
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
check_device_index(to_device);
return ExchangeDevice(to_device);
}
} // namespace c10::openreg

View File

@ -9,20 +9,10 @@
namespace c10::openreg {
OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
static inline void check_device_index(int64_t device) {
TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
"The device index is out of range. It must be in [0, ",
static_cast<int>(c10::openreg::device_count()),
"), but got ",
static_cast<int>(device),
".");
}
} // namespace c10::openreg

View File

@ -2,8 +2,6 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
} // namespace c10::openreg

View File

@ -11,7 +11,6 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
@ -59,7 +58,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
set_device(d.index());
}
// LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
/**
* Set the current device to c10::Device, without checking for errors

View File

@ -27,10 +27,6 @@ class TestDevice(TestCase):
self.assertEqual(torch.accelerator.current_device_index(), 1)
self.assertEqual(torch.accelerator.current_device_index(), device)
def test_invalid_device_index(self):
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
torch.accelerator.set_device_index(2)
if __name__ == "__main__":
run_tests()

View File

@ -34,21 +34,18 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
PyObject* _setDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
auto device = THPUtils_unpackDeviceIndex(arg);
auto device = THPUtils_unpackLong(arg);
torch::utils::device_lazy_init(at::kPrivateUse1);
c10::openreg::set_device(device);
c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");

View File

@ -41,13 +41,8 @@ def current_device():
return torch_openreg._C._get_device()
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None:
if device >= 0:
torch_openreg._C._set_device(device)
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
return torch_openreg._C._set_device(device)
def init():

View File

@ -331,25 +331,6 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
self.assertEqual(z.placements, (Replicate(),))
self.assertEqual(z.to_local(), input)
def test_inplace_op_partial_to_replicate(self):
# test that in-place operations that require redistribution raise an error
# to preserve aliasing semantics (issue #163374)
device_mesh = self.build_device_mesh()
input_tensor = torch.tensor(64.0, device=self.device_type)
partial_dt = DTensor.from_local(
input_tensor, device_mesh, placements=(Partial(),)
)
self.assertTrue(partial_dt.placements[0].is_partial())
# Inplace ops that require placement changes (Partial -> Replicate) should error
with self.assertRaisesRegex(
RuntimeError,
"in-place operations that require placement changes are not supported",
):
partial_dt.clamp_(max=10)
if __name__ == "__main__":
run_tests()

View File

@ -10,6 +10,7 @@ import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
import torch.fx as fx
from torch._C import FileCheck
from torch._dynamo.utils import counters, same
from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
@ -238,6 +239,49 @@ graph():
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_patches())
def test_schedulable_wait(self):
"""Test that if a wait node is scheduable or not."""
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node
def test_graph():
graph = fx.Graph()
inp = graph.placeholder("inp")
group_size = graph.placeholder("group_size")
group_name = graph.placeholder("group_name")
ag_0_out = graph.call_function(
torch.ops._c10d_functional.all_gather_into_tensor.default,
args=(inp, group_size, group_name),
)
ag_0_wait = graph.call_function(
torch.ops._c10d_functional.wait_tensor.default,
args=(ag_0_out,),
)
ag_1_out = graph.call_function(
torch.ops._c10d_functional.all_gather_into_tensor.default,
args=(ag_0_wait, group_size, group_name),
)
ag_1_wait = graph.call_function(
torch.ops._c10d_functional.wait_tensor.default,
args=(ag_1_out,),
)
ag_2_wait = graph.call_function(
torch.ops._c10d_functional.wait_tensor.default,
args=(ag_1_wait,),
)
graph.output(ag_2_wait)
return graph
graph = test_graph()
schedulable = {"wait_tensor_default", "wait_tensor_default_1"}
for node in list(graph.nodes):
expected = node.name in schedulable
assert _schedulable_wait_node(node) is expected
@torch._inductor.config.patch(get_patches())
def test_reorder_compute_for_overlap_mul(self):
def func(a, *, tag, ranks, group_size):

View File

@ -23,7 +23,12 @@ from torch._inductor.comms import (
sink_waits_iterative,
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor
from torch._inductor.fx_passes.bucketing import (
is_all_gather_into_tensor,
is_all_reduce_tensor,
is_all_to_all_tensor,
is_reduce_scatter_tensor,
)
from torch._inductor.scheduler import (
_get_mm_like_fn,
BaseSchedulerNode,
@ -2193,7 +2198,7 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
self.assertEqual(saved_values, [wt1])
@skip_if_lt_x_gpu(2)
def test_comm_analysis(self):
def test_all_gather_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
@ -2234,6 +2239,140 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
)
assert est_ms_nccl > 0
@skip_if_lt_x_gpu(2)
def test_reduce_scatter_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
group = c10d.distributed_c10d._get_default_group()
group_name = "default"
torch._C._distributed_c10d._register_process_group(
group_name, torch.distributed.group.WORLD
)
group_size = group.size()
def func(inp, group_size, group_name):
rs_0_out = torch.ops._c10d_functional.reduce_scatter_tensor(
inp, "sum", group_size, group_name
)
rs_0_wait = torch.ops.c10d_functional.wait_tensor(rs_0_out)
rs_1_out = torch.ops._c10d_functional.reduce_scatter_tensor(
rs_0_wait, "sum", group_size, group_name
)
rs_1_wait = torch.ops.c10d_functional.wait_tensor(rs_1_out)
return rs_1_wait
gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
g = gm.graph
for n in g.nodes:
if is_reduce_scatter_tensor(n):
from torch._inductor.comm_analysis import (
estimate_nccl_collective_runtime_from_fx_node,
)
est_ms = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=False
)
assert est_ms > 0
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=True
)
assert est_ms_nccl > 0
@skip_if_lt_x_gpu(2)
def test_all_reduce_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
group = c10d.distributed_c10d._get_default_group()
group_name = "default"
torch._C._distributed_c10d._register_process_group(
group_name, torch.distributed.group.WORLD
)
group_size = group.size()
def func(inp, group_size, group_name):
ar_0_out = torch.ops._c10d_functional.all_reduce(inp, "sum", group_name)
ar_0_wait = torch.ops.c10d_functional.wait_tensor(ar_0_out)
ar_1_out = torch.ops._c10d_functional.all_reduce(
ar_0_wait, "sum", group_name
)
ar_1_wait = torch.ops.c10d_functional.wait_tensor(ar_1_out)
return ar_1_wait
gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
g = gm.graph
for n in g.nodes:
if is_all_reduce_tensor(n):
from torch._inductor.comm_analysis import (
estimate_nccl_collective_runtime_from_fx_node,
)
est_ms = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=False
)
assert est_ms > 0
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=True
)
assert est_ms_nccl > 0
@skip_if_lt_x_gpu(2)
def test_all_to_all_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
group = c10d.distributed_c10d._get_default_group()
group_name = "default"
torch._C._distributed_c10d._register_process_group(
group_name, torch.distributed.group.WORLD
)
group_size = group.size()
def func(inp, group_size, group_name):
chunk = inp.numel() // self.world_size
split_sizes = [chunk] * self.world_size
a2a_0_out = torch.ops._c10d_functional.all_to_all_single(
inp,
split_sizes,
split_sizes,
group_name,
)
a2a_0_wait = torch.ops.c10d_functional.wait_tensor(a2a_0_out)
a2a_1_out = torch.ops._c10d_functional.all_to_all_single(
a2a_0_wait,
split_sizes,
split_sizes,
group_name,
)
a2a_1_wait = torch.ops.c10d_functional.wait_tensor(a2a_1_out)
return a2a_1_wait
gm = make_fx(func)(
torch.ones(group_size * 4, 1, device=self.device), group_size, group_name
)
g = gm.graph
for n in g.nodes:
if is_all_to_all_tensor(n):
from torch._inductor.comm_analysis import (
estimate_nccl_collective_runtime_from_fx_node,
)
est_ms = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=False
)
assert est_ms > 0
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=True
)
assert est_ms_nccl > 0
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -208,21 +208,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
)
self.assertEqual(y, expected)
def test_get_remote_tensors(self) -> None:
"""
Get all remote tensors
"""
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
my_tensor = symm_mem.empty(1, device=self.device).fill_(self.rank)
remote_tensors = torch.ops.symm_mem.get_remote_tensors(my_tensor, group_name)
dist.barrier()
for peer, tensor in enumerate(remote_tensors):
self.assertEqual(tensor, peer)
@skipIfRocm
def test_nvshmem_put(self) -> None:
self._init_device()

View File

@ -952,9 +952,7 @@ User code traceback:
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
Graph break: skip: from user code at:
File "test_error_messages.py", line N, in fn
assert x is None
""",
@ -1080,88 +1078,6 @@ from user code:
""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback(self, records):
def fn(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
self.assertExpectedInline(
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
"""\
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message(self, records):
def fn(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
if x.sum() > 0:
""",
)
@make_logging_test(dynamo=logging.DEBUG)
def test_skip_frame_empty_function_message(self, records):
def empty_fn(x):
pass
torch.compile(empty_fn, backend="eager")(torch.randn(3))
skip_messages = [
r
for r in records
if "intentionally decided to skip the frame" in r.getMessage()
]
self.assertEqual(len(skip_messages), 1)
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
self.assertExpectedInline(
msg,
"""\
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
Reason: no content in function call empty_fn test_error_messages.py N""",
)
@make_logging_test(graph_breaks=True)
def test_nested_compile_user_frames(self, records):
def fn(x):
@ -1708,110 +1624,6 @@ from user code:
)
class NestedGraphBreakLoggingTests(
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
):
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
def f2(x):
return f1(x + 2)
def f3(x):
return f2(x + 3)
torch.compile(f3, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
User code traceback:
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
torch.compile(f3, backend="eager")(torch.randn(3))
File "test_error_messages.py", line N, in f3
return f2(x + 3)
File "test_error_messages.py", line N, in f2
return f1(x + 2)
File "test_error_messages.py", line N, in f1
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
def f2(x):
return f1(x + 4)
def f3(x):
return f2(x + 5)
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
User code traceback:
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
File "test_error_messages.py", line N, in f3
return f2(x + 5)
File "test_error_messages.py", line N, in f2
return f1(x + 4)
File "test_error_messages.py", line N, in f1
if x.sum() > 0:
""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -14036,44 +14036,6 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
except Exception as e:
self.fail(f"torch.compile failed with error: {e}")
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_tensorify_track_item_symint(self):
def _random_resize(image: torch.Tensor):
image_metanet = image
default_patch_size = 14
rand_cnn_resolution = (224, 256)
min_nump = rand_cnn_resolution[0] // default_patch_size
max_nump = rand_cnn_resolution[1] // default_patch_size
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
torch._check(new_nump > 0)
torch._check(new_nump * default_patch_size > 1)
image_metanet = F.interpolate(
image_metanet,
size=(new_nump * default_patch_size, new_nump * default_patch_size),
mode="bilinear",
align_corners=True,
)
img_h_new, img_w_new = image_metanet.shape[2:]
return (img_h_new, img_w_new), image_metanet
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
# Test the function
input_tensor = torch.rand(1, 3, 224, 224)
(h, w), output = _random_resize_compiled(input_tensor)
# Verify output properties
self.assertEqual(output.shape[0], 1)
self.assertEqual(output.shape[1], 3)
self.assertEqual(output.shape[2], h)
self.assertEqual(output.shape[3], w)
self.assertTrue(h % 14 == 0)
self.assertTrue(w % 14 == 0)
self.assertTrue(224 <= h <= 256)
self.assertTrue(224 <= w <= 256)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1870,7 +1870,7 @@ class ConvertFrame:
raise
soft_fail = isinstance(e, Unsupported)
code = frame.f_code
# This is a soft failure, in the sense that the code path reaches here
# when we do not support graph breaks on bytecodes like LOAD_ATTR,
# BUILD_SET, etc. In such cases, we can fall back to eager without
@ -1885,13 +1885,7 @@ class ConvertFrame:
user_stack_formatted = "".join(
traceback.format_list(user_stack)
)
frame_info = exc.format_frame_info(code)
user_stack_trace = (
"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
"The graph break occurred in the following user code:\n"
f"{user_stack_formatted}"
)
user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
@ -1903,7 +1897,6 @@ class ConvertFrame:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=config.verbose,
)
if not config.suppress_errors and not soft_fail:

View File

@ -794,38 +794,6 @@ def format_error_msg_verbose(
return msg
def format_frame_info(code: types.CodeType) -> str:
return (
f"{getattr(code, 'co_name', '<unknown>')} "
f"({getattr(code, 'co_filename', '<unknown>')} "
f"line {getattr(code, 'co_firstlineno', 0)})"
)
def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
if code is not None:
frame_info = format_frame_info(code)
return (
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: {reason}"
)
else:
return (
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
f"Reason: {reason}"
)
def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
frame_info = format_frame_info(code)
return (
"Skipping frame because there is a graph break in a for/while loop\n"
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
f"{frame_summary}"
)
def format_error_msg(
exc: Exception,
code: types.CodeType,

View File

@ -94,8 +94,6 @@ from .exc import (
BackendCompilerFailed,
collapse_resume_frames,
format_graph_break_message,
format_loop_skip_frame_message,
format_skip_frame_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
StepUnsupported,
@ -607,9 +605,9 @@ def generic_jump(
)
# compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge():
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
)
log.info(msg)
raise exc.SkipFrame(msg)
@ -885,9 +883,9 @@ def break_graph_if_unsupported(
)
if self.maybe_has_backedge():
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
)
log.info(msg)
raise exc.SkipFrame(msg) from excp
@ -4628,9 +4626,8 @@ class InstructionTranslator(InstructionTranslatorBase):
and not self.error_on_graph_break
and not self.is_tracing_resume_prologue
):
raise exc.SkipFrame(
format_skip_frame_message(self.f_code, "no content in function call")
)
raise exc.SkipFrame("because no content in function call")
self.instruction_pointer = None
_step_logger()(
logging.INFO,

View File

@ -2248,15 +2248,12 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
try:
val.data_ptr() # will throw for functorch tensors
except RuntimeError as e:
from .exc import format_skip_frame_message, SkipFrame
from .exc import SkipFrame
# This will be GradTrackingTensor/BatchedTensor/etc
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame(
format_skip_frame_message(
None,
f"torch.compile cannot be run in context: {functorch_subclass_name}",
)
f"torch.compile cannot be run in context: {functorch_subclass_name}"
) from e

View File

@ -42,7 +42,6 @@ from torch._guards import Source
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
format_skip_frame_message,
get_dynamo_observed_exception,
handle_observed_exception,
InfiniteGeneratorError,
@ -1653,13 +1652,8 @@ class SkipFunctionVariable(VariableTracker):
skip_frame_msg = kwargs.get("msg")
if skip_frame_msg:
skip_frame_msg = skip_frame_msg.as_python_constant()
else:
skip_frame_msg = ""
raise SkipFrame(
format_skip_frame_message(
tx.f_code,
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
)
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported

View File

@ -23,6 +23,7 @@ class NCCL_COLL(IntEnum):
ALL_GATHER = 1
REDUCE_SCATTER = 2
ALL_TO_ALL = 3
UNSUPPORTED = 4
class NVIDIA_GPU_TYPE(IntEnum):
@ -53,10 +54,10 @@ def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
return NCCL_COLL.ALL_GATHER
elif "reduce_scatter" in kernel_name:
return NCCL_COLL.REDUCE_SCATTER
elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
elif any(comm in kernel_name for comm in ("all_to_all", "alltoall")):
return NCCL_COLL.ALL_TO_ALL
else:
raise ValueError(f"Unsupported collective kernel: {kernel_name}")
return NCCL_COLL.UNSUPPORTED
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
@ -347,13 +348,12 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
size = 0
sz_bytes = 0
for node in fx_node.all_input_nodes:
if (t := node.meta.get("val")) is not None:
size += t.numel() * t.element_size()
# TODO - symbolic
return size
numel = get_size_numel(t.size())
sz_bytes += numel * get_dtype_size(t.dtype)
return sz_bytes
def estimate_nccl_collective_runtime_from_fx_node(

View File

@ -10,6 +10,10 @@ import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.comm_analysis import (
get_collective_type_from_kernel_name,
NCCL_COLL,
)
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.experimental.proxy_tensor import make_fx
@ -52,6 +56,23 @@ def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
return (group_name, reduce_op, dtype)
def _schedulable_wait_node(node: torch.fx.Node) -> bool:
"""
Additional check on whether a wait node is schedulable.
We should not schedule an fx node that is:
1. a wait on a collective that is not callable
2. a wait on a non-NCCL communication node
"""
if not is_wait_tensor(node):
return False
assert isinstance(node.args[0], torch.fx.Node)
assert isinstance(node.args[0].target.name(), str)
is_callable: bool = node.args[0].op == "call_function"
coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name())
is_collective: bool = coll != NCCL_COLL.UNSUPPORTED
return is_callable and is_collective
def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None:
if is_all_gather_into_tensor(node):
group_key_fn = (
@ -138,7 +159,6 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target is torch.ops._c10d_functional.wait_tensor.default
and node.args[0].op == "call_function"
)
@ -149,6 +169,13 @@ def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
)
def is_all_to_all_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target is torch.ops._c10d_functional.all_to_all_single.default
)
def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type]

View File

@ -8,9 +8,9 @@ import torch
import torch.fx as fx
from torch._dynamo.graph_deduplication import _stable_topological_sort
from torch._inductor.fx_passes.bucketing import (
_schedulable_wait_node,
is_all_gather_into_tensor as is_all_gather,
is_reduce_scatter_tensor as is_reduce_scatter,
is_wait_tensor,
merge_all_gather_bucket,
merge_reduce_scatter_bucket,
)
@ -36,7 +36,10 @@ class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
"""
def __init__(
self, node_users: dict[fx.Node, OrderedSet[fx.Node]], *args: Any, **kwargs: Any
self,
node_users: dict[fx.Node, OrderedSet[fx.Node]],
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.node_users = node_users
@ -97,7 +100,7 @@ class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
)
# Identify the new wait and start
new_waits = [n for n in new_nodes if is_wait_tensor(n)]
new_waits = [n for n in new_nodes if _schedulable_wait_node(n)]
assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}"
new_wait = new_waits[0]
new_start = new_wait.args[0]
@ -186,7 +189,7 @@ class ManualOverlapScheduler(OverlapScheduler):
def _identify_collectives(self) -> None:
"""Identify all collective operations."""
for node in self.nodes:
if is_wait_tensor(node):
if _schedulable_wait_node(node):
start = node.args[0]
info = CollectiveInfo(
start_node=start,

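A minimal sketch of the post-bucketing bookkeeping shown in the hunk above, assuming `_schedulable_wait_node` is importable from the bucketing pass as in the new import; `new_nodes` stands in for whatever the merge helper returns:

from torch._inductor.fx_passes.bucketing import _schedulable_wait_node


def pick_new_wait_and_start(new_nodes):
    # After merging a bucket, exactly one schedulable wait is expected;
    # its first argument is the bucketed collective start node.
    new_waits = [n for n in new_nodes if _schedulable_wait_node(n)]
    assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}"
    new_wait = new_waits[0]
    return new_wait, new_wait.args[0]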
View File

@ -11,7 +11,8 @@ from typing import Any, Literal
import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.fx_passes.bucketing import is_wait_tensor
from torch._inductor.comm_analysis import estimate_fx_collective_size
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor
from torch._inductor.fx_passes.memory_estimator import (
_is_releasable,
build_memory_profile,
@ -67,16 +68,6 @@ def estimate_collective_time(
)
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
size = 0
for node in fx_node.all_input_nodes:
if (t := node.meta.get("val")) is not None:
# todo - symbolic
size += t.numel() * t.element_size()
return size
def is_compute_node(n: fx.Node) -> bool:
"""
Should we consider this node computationally expensive?
@ -318,7 +309,7 @@ class OverlapScheduler:
def _identify_collectives(self) -> None:
"""Identify all collective operations."""
for node in self.nodes:
if is_wait_tensor(node):
if _schedulable_wait_node(node):
start = node.args[0]
coll_time_ms = estimate_collective_time(
start, custom_runtime_estimation=self.custom_runtime_estimation
@ -531,7 +522,7 @@ class OverlapScheduler:
self._handle_compute(node)
elif node in self.collective_info:
self._handle_collective_start(node)
elif is_wait_tensor(node):
elif _schedulable_wait_node(node):
self._handle_wait(node)
else:
self._handle_other(node)
@ -596,7 +587,7 @@ class OverlapScheduler:
def _compute_score(self, node: fx.Node) -> object:
"""Compute priority score for a node"""
if is_wait_tensor(node):
if _schedulable_wait_node(node):
info = self.collective_info[self.wait_to_start[node]]
# defer waits locally if they are exposed.
compute_local_priority = int(info.is_exposed)
@ -827,7 +818,7 @@ class OverlapScheduler:
# thus forcing it to be exposed.
# however, if it is already hidden or it cannot possibly be hidden,
# it's fine to schedule it
if is_wait_tensor(node):
if _schedulable_wait_node(node):
info = self.collective_info[self.wait_to_start[node]]
if info.hiding_node and info.hiding_node != curr_compute_node:
continue
@ -875,7 +866,7 @@ class OverlapScheduler:
assert all(n not in self.scheduled for n in path)
for node in sorted(path, key=lambda n: self.node_idx[n]):
assert not (is_compute_node(node) or node in self.unscheduled_collectives)
if is_wait_tensor(node):
if _schedulable_wait_node(node):
# When we schedule wait tensors, we also force realization of all
# collectives enqueued prior to their corresponding collective.
# It's possible the scheduling of one wait tensor here has forced

View File

@ -891,14 +891,10 @@ class TorchLogsFormatter(logging.Formatter):
# exception handling - copied from logging.Formatter.format
s = record.message
if record.exc_info:
from torch._dynamo import config
should_format_exc = config.verbose or artifact_name != "graph_breaks"
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if should_format_exc:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if s[-1:] != "\n":
s = s + "\n"

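The simplified branch is the stock `logging.Formatter` caching idiom; a standalone sketch using only the standard library (the class name here is illustrative, not the PyTorch formatter):

import logging


class ExcCachingFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        s = record.getMessage()
        if record.exc_info and not record.exc_text:
            # Render the traceback once and cache it on the record so that
            # repeated formatting (or multiple handlers) reuses the same text.
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if not s.endswith("\n"):
                s += "\n"
            s += record.exc_text
        return s

Because `exc_text` lives on the record, `formatException` runs at most once per record no matter how many handlers format it.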
View File

@ -465,39 +465,6 @@ lib.define(
"_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
)
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
"""
Given a local tensor and a group name, return a tuple of tensors that are
symmetric on other devices. The returned tensors are ordered by rank IDs. The
length of the tuple equals the size of the group.
Note: this API works only when `world_within_direct_access()` returns True, i.e.
only when the group is within an NVLink domain or similar. It does not work across
network interfaces.
"""
@torch.library.impl(lib, "get_remote_tensors", "CUDA")
def _get_remote_tensors_default(
local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
hdl = rendezvous(local, group_name)
if hdl is None:
raise ValueError("Tensor is not allocated from Symmetric Memory")
return tuple(
hdl.get_remote_tensor(peer, local.size(), local.dtype)
for peer in range(hdl.world_size)
)
@torch.library.impl(lib, "get_remote_tensors", "Meta")
def _get_remote_tensors_meta(
local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
group = c10d._resolve_process_group(group_name)
return tuple(torch.empty_like(local) for _ in range(group.size()))
class _ScaleMode(Enum):
UNSCALED = "unscaled"

View File

@ -337,34 +337,19 @@ class OpDispatcher:
if is_inplace_op:
# inplace op should return self instead of re-wrapping
if output_sharding.output_spec is not None:
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)
# NOTE: aten.squeeze_.dim is an inplace op but it also may change
# the inplace argument's tensor meta. Here we choose to special case
# this op because as far as I know this is the only inplace op that
# has such behavior. We can extend this special case if necessary.
if op_call == aten.squeeze_.dim:
# update the spec to handle tensor meta changes
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], dtensor.DTensor)
args[0]._spec = output_spec
# use return_and_correct_aliasing to match the outer and the inner
# aliasing. See https://github.com/pytorch/pytorch/pull/158954
return return_and_correct_aliasing(op_call, args, kwargs, args[0])
else:
# For all other inplace ops, check if placement changes are required
# Inplace operations that change placement are not supported because
# they would require redistribution, which breaks aliasing semantics.
# If there are views into the tensor, the views would not be updated.
if args[0]._spec.placements != output_spec.placements:
raise RuntimeError(
f"{op_call}: in-place operations that require placement changes "
f"are not supported. The operation would change placement from "
f"{args[0]._spec.placements} to {output_spec.placements}, "
f"which requires redistribution and breaks aliasing semantics. "
f"Please use the out-of-place version of this operation instead."
)
# Most inplace ops don't change tensor meta, so no spec update needed
return args[0]
else:
return None

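Schematically, the NOTE above reduces to a single special case: only `aten.squeeze_.dim` may change its input's tensor meta in place, so the spec refresh is scoped to that op. A toy sketch under that reading (the `_FakeDTensor` placeholder and `finish_inplace` name are not DTensor internals):

import torch
from dataclasses import dataclass


@dataclass
class _FakeDTensor:
    _spec: object  # placeholder for dtensor.DTensor's sharding spec


def finish_inplace(op_call, self_tensor: _FakeDTensor, output_spec) -> _FakeDTensor:
    # squeeze_.dim is the one in-place op known to alter tensor meta,
    # so only it triggers a spec update before returning self.
    if op_call == torch.ops.aten.squeeze_.dim:
        self_tensor._spec = output_spec
    return self_tensor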
View File

@ -207,19 +207,12 @@ def tensorify_python_scalars(
and node.target is torch.ops.aten._local_scalar_dense.default
):
dtype = node.args[0].meta["val"].dtype
if not dtype.is_floating_point:
continue
assert isinstance(node.args[0], fx.Node), node.args[0]
s = node.meta["val"].node.expr
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
# only tensorify if the dtype is floating point
if not dtype.is_floating_point:
continue
expr_to_tensor_proxy[s] = MetaProxy(
node.args[0], tracer=tracer, fake_mode=fake_mode
)
@ -227,7 +220,9 @@ def tensorify_python_scalars(
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
expr_to_tensor_proxy[s], torch.float64
)
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
# pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(