mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-16 07:24:54 +08:00

Compare commits: 1 commit (update_sub ... ruisi/fix_)

| Author | SHA1 | Date |
|---|---|---|
|  | 57a76b44de |  |
@ -1,342 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/BlasBackend.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/ceil_div.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/xpu/Blas.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif

namespace at::native {

using at::blas::ScalingType;
using at::blas::SwizzleType;

namespace {
/*
 * Scaling Type Determination:
 * ---------------------------
 * Conditions and corresponding Scaling Types:
 *
 * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
 *   - Returns BlockWise (with additional size checks).
 *
 * - Else if scale.numel() == 1:
 *   - Returns TensorWise.
 *
 * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1:
 *   - Returns RowWise.
 *
 * - Otherwise:
 *   - Returns Error.
 */

bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return at::isFloat8Type(t.scalar_type()) &&
      scale.scalar_type() == at::kFloat && scale.numel() == 1;
}

bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
  return (
      at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat &&
      scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 &&
      scale.is_contiguous());
}

bool is_desired_scaling(
    const at::Tensor& t,
    const at::Tensor& scale,
    ScalingType desired_scaling) {
  auto result = desired_scaling == ScalingType::TensorWise
      ? is_tensorwise_scaling(t, scale)
      : is_rowwise_scaling(t, scale);
  return result;
}

std::pair<ScalingType, ScalingType> get_joint_scaling(
    std::initializer_list<std::pair<ScalingType, ScalingType>> options,
    const at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& scale_a,
    const at::Tensor& scale_b) {
  for (auto [lhs, rhs] : options) {
    if (is_desired_scaling(a, scale_a, lhs) &&
        is_desired_scaling(b.t(), scale_b.t(), rhs)) {
      return {lhs, rhs};
    }
  }
  TORCH_CHECK(
      false,
      "Invalid scaling configuration.\n"
      "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
      "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (",
      a.size(0),
      ", 1) and scale_b should be (1, ",
      b.size(1),
      "), and both should be contiguous.\n"
      "Got a.dtype()=",
      a.scalar_type(),
      ", scale_a.dtype()=",
      scale_a.scalar_type(),
      ", scale_a.size()=",
      scale_a.sizes(),
      ", scale_a.stride()=",
      scale_a.strides(),
      ", ",
      "b.dtype()=",
      b.scalar_type(),
      ", scale_b.dtype()=",
      scale_b.scalar_type(),
      ", scale_b.size()=",
      scale_b.sizes(),
      " and scale_b.stride()=",
      scale_b.strides());
}
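
For orientation, here is a minimal Python sketch (not part of the patch) of the same shape/dtype rules that `is_tensorwise_scaling` / `is_rowwise_scaling` encode above. It assumes a PyTorch build with float8 dtypes; the function name `classify_scaling` is purely illustrative.

```python
import torch

def classify_scaling(t: torch.Tensor, scale: torch.Tensor) -> str:
    """Illustrative mirror of the C++ predicates above."""
    is_fp8 = t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    if is_fp8 and scale.dtype == torch.float32 and scale.numel() == 1:
        return "TensorWise"
    if (
        is_fp8
        and scale.dtype == torch.float32
        and scale.dim() == 2
        and scale.size(0) == t.size(0)
        and scale.size(1) == 1
        and scale.is_contiguous()
    ):
        return "RowWise"
    return "Error"

a = torch.randn(32, 64).to(torch.float8_e4m3fn)
print(classify_scaling(a, torch.tensor(1.0)))   # TensorWise: single float32 scalar
print(classify_scaling(a, torch.ones(32, 1)))   # RowWise: one float32 scale per row
```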

Tensor& _scaled_gemm(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const ScalingType scaling_choice_a,
    const ScalingType scaling_choice_b,
    const std::optional<Tensor>& bias,
    const bool use_fast_accum,
    Tensor& out,
    const std::optional<Tensor>& alpha = std::nullopt) {
  // TODO: scale_result and alpha are not defined or used!
  std::optional<Tensor> scaled_result = std::nullopt;
  at::native::onednn::scaled_matmul(
      mat1,
      mat2,
      out,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      scaled_result,
      use_fast_accum);

  return out;
}

} // namespace

// Computes matrix multiply + bias while applying scaling to the input and
// output matrices. Scales are only applicable when matrices are of Float8
// type and are assumed to be equal to 1.0 by default. If the output matrix
// type is a 16- or 32-bit type, scale_result is not applied.
// Known limitations:
//  - Only works if mat1 is row-major and mat2 is column-major
//  - Only works if matrix sizes are divisible by 32
//  - If 1-dimensional tensors are used then scale_a should have size =
//    mat1.size(0) and scale_b should have size = mat2.size(1)
//  Arguments:
//    - `mat1`: the first operand of the matrix multiply, can be type
//      `torch.float8_e4m3fn` or `torch.float8_e5m2`
//    - `mat2`: the second operand of the matrix multiply, can be type
//      `torch.float8_e4m3fn` or `torch.float8_e5m2`
//    - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
//    - `out_dtype`: the output dtype, can either be a float8 or a higher
//      precision floating point type
//    - `scale_a`: a tensor with the inverse scale of `mat1`, whose
//      shape/strides/dtype depend on the scaling scheme
//    - `scale_b`: a tensor with the inverse scale of `mat2`, whose
//      shape/strides/dtype depend on the scaling scheme
//    - `scale_result`: a scalar tensor with the scale of the output, only
//      utilized if the output is a float8 type
//    - `use_fast_accum`: Not applicable for XPU. For now, it should always be
//      false.
//    - `out`: a reference to the output tensor

Tensor& _scaled_mm_out_xpu(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  // Note: fast_accum is not supported in XPU for now.
  TORCH_CHECK(!use_fast_accum, "fast_accum is not supported in XPU for now.");

  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");

  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      " and ",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ")");

  // Check what type of scaling we are doing based on inputs. This list is
  // sorted by decreasing priority.

  // List of supported datatypes for XPU with oneDNN:
  // https://uxlfoundation.github.io/oneDNN/dev_guide_matmul.html#data-types
  auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
      {
          std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
          std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
      },
      mat1,
      mat2,
      scale_a,
      scale_b);
  TORCH_CHECK(
      !scale_result ||
          (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
      "scale_result must be a float scalar");
  TORCH_CHECK(
      !bias || bias->numel() == mat2.sizes()[1],
      "Bias must be size ",
      mat2.sizes()[1],
      " but got ",
      bias->numel());
  TORCH_CHECK(
      mat1.sizes()[1] % 16 == 0,
      "Expected trailing dimension of mat1 to be divisible by 16 ",
      "but got mat1 shape: (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      ").");
  TORCH_CHECK(
      mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0,
      "mat2 shape (",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ") must be divisible by 16");
  // Check types
  TORCH_CHECK(
      !out_dtype || *out_dtype == out.scalar_type(),
      "out_dtype must match output matrix type");
  TORCH_CHECK(
      at::isFloat8Type(mat1.scalar_type()),
      "Expected mat1 to be Float8 matrix got ",
      mat1.scalar_type());
  TORCH_CHECK(
      at::isFloat8Type(mat2.scalar_type()),
      "Expected mat2 to be Float8 matrix got ",
      mat2.scalar_type());
  // TODO: oneDNN currently only supports e4m3 with group scales on BMG, and
  // only 1D (not 2D) scales. More checks need to be added here.

  if (bias) {
    TORCH_CHECK(
        bias->scalar_type() == kFloat ||
            bias->scalar_type() == c10::ScalarType::BFloat16 ||
            bias->scalar_type() == c10::ScalarType::Half,
        "Bias must be Float32 or BFloat16 or Half, but got ",
        bias->scalar_type());
  }

  {
    auto bias_ = bias.value_or(Tensor());
    auto scale_result_ = scale_result.value_or(Tensor());

    // NOLINTNEXTLINE(*c-array*)
    TensorArg targs[]{
        {out, "out", 0},
        {mat1, "mat1", 1},
        {mat2, "mat2", 2},
        {bias_, "bias", 3},
        {scale_a, "scale_a", 4},
        {scale_b, "scale_b", 5},
        {scale_result_, "scale_result", 6}};
    checkAllSameGPU(__func__, targs);
  }

  // Validation checks have passed; resize the output to the actual size.
  IntArrayRef mat1_sizes = mat1.sizes();
  IntArrayRef mat2_sizes = mat2.sizes();
  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm
  // kernels do not support this case).
  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
    // `out` was created with `at::empty`. In the case where we are multiplying
    // MxK by KxN and K is the zero dim, we need to initialize here to properly
    // return a tensor of zeros.
    if (mat1_sizes[1] == 0) {
      out.zero_();
    }

    return out;
  }

  // TODO: scale_result is not supported yet!
  return _scaled_gemm(
      mat1,
      mat2,
      scale_a,
      scale_b,
      scaling_choice_a,
      scaling_choice_b,
      bias,
      use_fast_accum,
      out);
}

Tensor _scaled_mm_xpu(
    const Tensor& mat_a,
    const Tensor& mat_b,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum) {
  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
  return _scaled_mm_out_xpu(
      mat_a,
      mat_b,
      scale_a,
      scale_b,
      bias,
      scale_result,
      out_dtype,
      use_fast_accum,
      out);
}
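
A hedged usage sketch of the public entry point that these kernels back. This is illustrative only: the exact `torch._scaled_mm` signature and float8 support vary across PyTorch versions and devices, and the `"xpu"` device assumes an XPU build with this patch applied.

```python
import torch

device = "xpu"          # assumption: an XPU build with these kernels
M, K, N = 32, 64, 48    # all divisible by 16, as the checks above require

a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)          # row-major mat1
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()      # column-major mat2

# Tensor-wise scaling: one float32 scalar per operand.
out_tw = torch._scaled_mm(
    a, b,
    scale_a=torch.tensor(1.0, device=device),
    scale_b=torch.tensor(1.0, device=device),
    out_dtype=torch.bfloat16,
)

# Row-wise scaling: (M, 1) for mat1 and (1, N) for mat2, float32 and contiguous.
out_rw = torch._scaled_mm(
    a, b,
    scale_a=torch.ones(M, 1, device=device),
    scale_b=torch.ones(1, N, device=device),
    out_dtype=torch.bfloat16,
)
```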

} // namespace at::native
@ -1,4 +1,3 @@
#include <ATen/BlasBackend.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <c10/core/ScalarType.h>
@ -9,6 +8,7 @@
#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

at::Tensor broadcast_bias2D(
    at::Tensor& dst,
    at::Tensor& bias,
@ -328,236 +328,4 @@ void quantized_matmul(
  result.copy_(dst);
}

// Describes how to configure oneDNN scales for a given role/ScalingType
struct ScaleSpec {
  // specifies the way scale values will be applied to an ARG tensor.
  int mask;
  // specifies how scales are grouped along dimensions where
  // multiple scale factors are used.
  dnnl::memory::dims groups;
  // specifies data type for scale factors.
  dnnl::memory::data_type dtype;

  // Helper to compute expected number of elements for scale tensors
  // arg_type: "src" for SRC (groups pattern {1, X}),
  //           "wei" for WEIGHTS (groups pattern {X, 1})
  int64_t expected_numel(
      int64_t outer_dim,
      int64_t inner_dim,
      const std::string& arg_type) const {
    if (groups == dnnl::memory::dims{1, 1})
      return 1; // tensorwise scaling

    TORCH_CHECK(
        arg_type == "src" || arg_type == "wei",
        "Expected arg_type to be 'src' or 'wei', but got '",
        arg_type,
        "'");

    // For rowwise: SRC groups={1, K}, WEI groups={K, 1}
    TORCH_INTERNAL_ASSERT(
        (groups == dnnl::memory::dims{1, inner_dim} ||
         groups == dnnl::memory::dims{inner_dim, 1}),
        "The groups must be either {1, inner_dim} or {inner_dim, 1}. But got ",
        groups,
        ".");
    return outer_dim;
  }

  // Normalize an incoming scale tensor to contiguous storage and appropriate
  // dtype/view
  at::Tensor normalize(const at::Tensor& scale) const {
    TORCH_INTERNAL_ASSERT(
        dtype == dnnl::memory::data_type::f32,
        "tensor scale currently must be f32, but got scale dtype: ",
        scale.scalar_type());
    return scale.to(at::kFloat).contiguous();
  }
};

// This function defines how to set scales mask and groups according to:
// https://github.com/uxlfoundation/oneDNN/blob/main/tests/benchdnn/doc/knobs_attr.md#--attr-scales
// The returned value will be used in
// `set_scales(arg, mask, groups, data_type)`.
inline ScaleSpec make_scale_spec(
    at::blas::ScalingType scaling_type,
    int64_t M,
    int64_t K,
    int64_t N,
    const std::string& arg_type) {
  TORCH_CHECK(
      arg_type == "src" || arg_type == "wei",
      "Expected arg_type to be 'src' or 'wei', but got '",
      arg_type,
      "'");
  TORCH_INTERNAL_ASSERT(
      (scaling_type == at::blas::ScalingType::TensorWise ||
       scaling_type == at::blas::ScalingType::RowWise),
      "Currently only support scaling_type for TensorWise or RowWise");
  int64_t dim = K; // Currently only K is used for grouping
  bool is_src = (arg_type == "src");
  if (scaling_type == at::blas::ScalingType::TensorWise) {
    // Scale tensorwise. The same as `--attr-scales=common`.
    // mask=0 : scale whole tensor
    // groups={1, 1}: indicates that there is only one group for scaling
    return {0, {1, 1}, dnnl::memory::data_type::f32};
  } else {
    // (scaling_type == at::blas::ScalingType::RowWise)
    // Scale RowWise. The same as `--attr-scales=per_dim_01`.
    // mask={(1 << 0) | (1 << 1)}: Scale on both dim0 and dim1
    // SRC: groups={1, K}, WEIGHTS: groups={K, 1}
    return {
        (1 << 0) | (1 << 1),
        is_src ? dnnl::memory::dims{1, dim} : dnnl::memory::dims{dim, 1},
        dnnl::memory::data_type::f32};
  }
}
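
As a quick reference, the mask/groups values that `make_scale_spec` produces follow the oneDNN `--attr-scales` convention linked above. The Python sketch below (not part of the patch) mirrors that mapping purely for illustration:

```python
# Illustrative mirror of make_scale_spec; the scale dtype is always f32.
def make_scale_spec(scaling_type: str, K: int, arg_type: str) -> dict:
    assert arg_type in ("src", "wei")
    if scaling_type == "TensorWise":
        # --attr-scales=common: one scale for the whole tensor.
        return {"mask": 0, "groups": (1, 1)}
    # RowWise (--attr-scales=per_dim_01): scale along both dims, grouped over K
    # so there is one scale per row of SRC or per column of WEIGHTS.
    groups = (1, K) if arg_type == "src" else (K, 1)
    return {"mask": (1 << 0) | (1 << 1), "groups": groups}

print(make_scale_spec("TensorWise", 64, "src"))  # {'mask': 0, 'groups': (1, 1)}
print(make_scale_spec("RowWise", 64, "src"))     # {'mask': 3, 'groups': (1, 64)}
print(make_scale_spec("RowWise", 64, "wei"))     # {'mask': 3, 'groups': (64, 1)}
```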

sycl::event scaled_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    Tensor& result,
    const Tensor& scale_a,
    const Tensor& scale_b,
    at::blas::ScalingType scaling_choice_a,
    at::blas::ScalingType scaling_choice_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    bool use_fast_accum) {
  auto& engine = GpuEngineManager::Instance().get_engine();
  auto& stream = GpuStreamManager::Instance().get_stream();

  // This function performs the following steps:
  // 1. create memory descriptors
  // 2. call write_to_dnnl_memory() to actually write memory
  // 3. execute

  const int64_t M = mat1.size(0);
  const int64_t K = mat1.size(1);
  const int64_t N = mat2.size(1);

  // 1.1 Create memory descriptors
  dnnl::memory::desc src_md = get_onednn_md(mat1);
  dnnl::memory::desc weights_md = get_onednn_md(mat2);
  dnnl::memory::desc dst_md = get_onednn_md(result);

  // scale_a and scale_b have already been checked by the `is_desired_scaling()`
  // call, so we can directly get their memory descriptors and set them later.
  dnnl::memory::desc scale_a_md = get_onednn_md(scale_a);
  dnnl::memory::desc scale_b_md = get_onednn_md(scale_b);

  dnnl::memory::desc bias_md;
  bool with_bias = bias.has_value();
  at::Tensor possible_reshaped_bias = bias.value_or(at::Tensor());
  if (with_bias) {
    if (possible_reshaped_bias.dim() == 1) {
      possible_reshaped_bias =
          possible_reshaped_bias.reshape({1, possible_reshaped_bias.size(0)});
      bias_md = get_onednn_md(possible_reshaped_bias);
    } else {
      bias_md = get_onednn_md(possible_reshaped_bias);
    }
  }

  // 1.2 Create primitive descriptor and set scales mask
  const ScaleSpec src_spec = make_scale_spec(scaling_choice_a, M, K, N, "src");
  const ScaleSpec wei_spec = make_scale_spec(scaling_choice_b, M, K, N, "wei");

  dnnl::primitive_attr op_attr = dnnl::primitive_attr();

#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn())
    op_attr.set_deterministic(true);
#endif

  std::vector<int64_t> default_groups;
  op_attr.set_scales(
      DNNL_ARG_SRC, src_spec.mask, src_spec.groups, src_spec.dtype);
  op_attr.set_scales(
      DNNL_ARG_WEIGHTS, wei_spec.mask, wei_spec.groups, wei_spec.dtype);
  // The scale_result tensor currently only supports a scalar (TensorWise scaling).
  bool with_dst_scale = scale_result && scale_result->defined();
  if (with_dst_scale) {
    op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32);
  }

  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  // 1.3 Create the matmul primitive descriptor
  dnnl::matmul::primitive_desc matmul_pd = with_bias
      ? dnnl::matmul::primitive_desc(
            engine, src_md, weights_md, bias_md, dst_md, op_attr)
      : dnnl::matmul::primitive_desc(
            engine, src_md, weights_md, dst_md, op_attr);

  // 1.4 (Possible) Additional Checks
  // TODO: If a memory descriptor does not align with the actual tensor, we
  // might need to reorder the weights, similar to the CPU's
  // reorder_if_differ_in() call (e.g. when the weights do not match
  // matmul_pd.weights_desc()).

  // 2. Prepare memory

  // Create memory
  auto src_usr_m = make_onednn_memory(src_md, engine, mat1.data_ptr());
  auto weights_usr_m = make_onednn_memory(weights_md, engine, mat2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_md, engine, result.data_ptr());
  dnnl::memory b_usr_m;
  if (with_bias) {
    b_usr_m =
        make_onednn_memory(bias_md, engine, possible_reshaped_bias.data_ptr());
  }

  // Prepare runtime scale memories (flat 1-D views) using the specs
  auto make_scale_mem_from_spec = [&](const ScaleSpec& spec,
                                      int64_t expected_numel,
                                      const at::Tensor& scale_tensor) {
    at::Tensor prepared = spec.normalize(scale_tensor);
    TORCH_CHECK(
        prepared.numel() == expected_numel,
        "Scale buffer length mismatch. Expected ",
        expected_numel,
        ", got ",
        prepared.numel());
    dnnl::memory::desc scale_md(
        {prepared.numel()}, spec.dtype, dnnl::memory::format_tag::x);
    return make_onednn_memory(scale_md, engine, prepared.data_ptr());
  };

  auto scratchpad =
      make_onednn_memory(matmul_pd.scratchpad_desc(), engine, nullptr);

  // 3. Setup Args for exec
  std::unordered_map<int, dnnl::memory> args;
  args.insert({DNNL_ARG_SRC, src_usr_m});
  args.insert({DNNL_ARG_WEIGHTS, weights_usr_m});
  args.insert({DNNL_ARG_DST, dst_usr_m});
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
  if (with_bias) {
    args.insert({DNNL_ARG_BIAS, b_usr_m});
  }

  // Attach runtime scales using specs
  auto src_sc_mem = make_scale_mem_from_spec(
      src_spec, src_spec.expected_numel(M, K, "src"), scale_a);
  auto wei_sc_mem = make_scale_mem_from_spec(
      wei_spec, wei_spec.expected_numel(N, K, "wei"), scale_b);
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_mem});
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_sc_mem});
  if (with_dst_scale) {
    // Bind single f32 scalar as DST scale
    at::Tensor dst_scale_f32 = scale_result->to(at::kFloat).contiguous();
    dnnl::memory::desc dst_sc_md(
        {1}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::x);
    auto dst_sc_mem =
        make_onednn_memory(dst_sc_md, engine, dst_scale_f32.data_ptr());
    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_mem});
  }

  dnnl::matmul matmul_p = dnnl::matmul(matmul_pd);
  sycl::event matmul_fwd_event =
      dnnl::sycl_interop::execute(matmul_p, stream, args);
  return matmul_fwd_event;
}

} // namespace at::native::onednn

@ -78,10 +78,6 @@ dnnl::memory::data_type get_onednn_dtype(
      return dnnl::memory::data_type::f32;
    case at::ScalarType::BFloat16:
      return dnnl::memory::data_type::bf16;
    case at::ScalarType::Float8_e4m3fn:
      return dnnl::memory::data_type::f8_e4m3;
    case at::ScalarType::Float8_e5m2:
      return dnnl::memory::data_type::f8_e5m2;
    default:
      if (!allow_undef) {
        TORCH_CHECK(

@ -1,7 +1,6 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
@ -203,16 +202,4 @@ void sdpa_backward(
    Tensor& grad_query,
    Tensor& grad_key,
    Tensor& grad_value);

sycl::event scaled_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    Tensor& result,
    const Tensor& scale_a,
    const Tensor& scale_b,
    at::blas::ScalingType scaling_choice_a,
    at::blas::ScalingType scaling_choice_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    bool use_fast_accum);
} // namespace at::native::onednn

@ -1,113 +0,0 @@
# Device Management

## Background

Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.

The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.

## Design

Accelerator vendors need to implement these core functions:

| Function Name | Description | Application Scenarios |
| --- | --- | --- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |

These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.

## Implementation

This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs

### C++ Side

Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
   :linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
   :linenos:
```

### Binding

Expose the C++ functions to Python using pybind11:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
   :end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
   :linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
   :end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
   :linenos:
   :emphasize-lines: 5
```

### Python Side

Wrap the C++ bindings with user-friendly Python functions:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
   :language: python
   :start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
   :end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
   :linenos:
```

Here's the complete mapping from C++ to Python:

| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| --- | --- | --- | --- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
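
A short usage sketch of the Python-side APIs from the table above. It is illustrative only and assumes the `torch_openreg` extension is installed and reports at least two devices:

```python
import torch
import torch_openreg  # registers the "openreg" (PrivateUse1) backend

print(torch.openreg.device_count())      # e.g. 2
print(torch.openreg.current_device())    # 0 by default

torch.openreg.set_device(1)              # switch the active device
assert torch.openreg.current_device() == 1

# The generic accelerator API routes to the same bindings:
torch.accelerator.set_device_index(0)
print(torch.accelerator.current_device_index())  # 0
```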

## Guard

Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.

Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :linenos:
```

**What needs to be implemented:**

1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend

This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
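
The sketch below illustrates the guard idea from the Python side: switch on entry and always restore on exit. It is an assumption-laden example, not the C++ guard itself; it uses the `torch.accelerator` index APIs shown in the tests and assumes the backend is exposed under the `"openreg"` device string.

```python
from contextlib import contextmanager

import torch

@contextmanager
def openreg_device_guard(index: int):
    """Illustrative Python analogue of an RAII device guard."""
    prev = torch.accelerator.current_device_index()
    torch.accelerator.set_device_index(index)
    try:
        yield
    finally:
        # Restored even if the body raises, mirroring guard destruction.
        torch.accelerator.set_device_index(prev)

with openreg_device_guard(1):
    x = torch.ones(4, device="openreg")  # assumption: allocated on device 1
# the previous device index is active again here
```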

[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"
@ -42,7 +42,6 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
   :glob:
   :maxdepth: 1

   device
   hooks
   autoload
   operators

@ -4,12 +4,17 @@

#include <c10/util/Exception.h>

void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
void orCheckFail(
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg = "");

#define OPENREG_CHECK(EXPR, ...)                                               \
  do {                                                                         \
    const orError_t __err = EXPR;                                              \
    if (C10_UNLIKELY(__err != orSuccess)) {                                    \
      orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
    }                                                                          \
#define OPENREG_CHECK(EXPR, ...)                                               \
  do {                                                                         \
    const orError_t __err = EXPR;                                              \
    if (__err != orSuccess) {                                                  \
      orCheckFail(                                                             \
          __func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
    }                                                                          \
  } while (0)

@ -1,4 +1,3 @@
#include <c10/util/Exception.h>
#include <include/openreg.h>

#include "OpenRegException.h"
@ -10,22 +9,21 @@ orError_t GetDeviceCount(int* dev_count) {
  return orGetDeviceCount(dev_count);
}

orError_t GetDevice(DeviceIndex* device) {
orError_t GetDevice(c10::DeviceIndex* device) {
  int tmp_device = -1;
  auto err = orGetDevice(&tmp_device);
  *device = static_cast<DeviceIndex>(tmp_device);
  *device = static_cast<c10::DeviceIndex>(tmp_device);
  return err;
}
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {

orError_t SetDevice(c10::DeviceIndex device) {
  int cur_device = -1;
  OPENREG_CHECK(orGetDevice(&cur_device));
  orGetDevice(&cur_device);
  if (device == cur_device) {
    return orSuccess;
  }
  return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION

int device_count_impl() {
  int count = 0;
@ -33,37 +31,34 @@ int device_count_impl() {
  return count;
}

OPENREG_EXPORT DeviceIndex device_count() noexcept {
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
  // initialize number of devices only once
  static int count = []() {
    try {
      auto result = device_count_impl();
      TORCH_CHECK(
          result <= std::numeric_limits<DeviceIndex>::max(),
          result <= std::numeric_limits<c10::DeviceIndex>::max(),
          "Too many devices, DeviceIndex overflowed");
      return result;
    } catch (const Error& ex) {
    } catch (const c10::Error& ex) {
      // We don't want to fail, but still log the warning
      // msg() returns the message without the stack trace
      TORCH_WARN("Device initialization: ", ex.msg());
      return 0;
    }
  }();
  return static_cast<DeviceIndex>(count);
  return static_cast<c10::DeviceIndex>(count);
}

OPENREG_EXPORT DeviceIndex current_device() {
  DeviceIndex cur_device = -1;
  OPENREG_CHECK(GetDevice(&cur_device));
OPENREG_EXPORT c10::DeviceIndex current_device() {
  c10::DeviceIndex cur_device = -1;
  GetDevice(&cur_device);
  return cur_device;
}

// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
  check_device_index(device);
  OPENREG_CHECK(SetDevice(device));
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
  SetDevice(device);
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION

OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
  int current_device = -1;
@ -76,8 +71,4 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
  return current_device;
}

OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
  check_device_index(to_device);
  return ExchangeDevice(to_device);
}
} // namespace c10::openreg

@ -9,20 +9,10 @@

namespace c10::openreg {

OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);

OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);

static inline void check_device_index(int64_t device) {
  TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
      "The device index is out of range. It must be in [0, ",
      static_cast<int>(c10::openreg::device_count()),
      "), but got ",
      static_cast<int>(device),
      ".");
}

} // namespace c10::openreg

@ -2,8 +2,6 @@

namespace c10::openreg {

// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION

} // namespace c10::openreg

@ -11,7 +11,6 @@

namespace c10::openreg {

// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;

@ -59,7 +58,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {

    set_device(d.index());
  }
  // LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE

  /**
   * Set the current device to c10::Device, without checking for errors

@ -27,10 +27,6 @@ class TestDevice(TestCase):
        self.assertEqual(torch.accelerator.current_device_index(), 1)
        self.assertEqual(torch.accelerator.current_device_index(), device)

    def test_invalid_device_index(self):
        with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
            torch.accelerator.set_device_index(2)


if __name__ == "__main__":
    run_tests()

@ -34,21 +34,18 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR

// LITERALINCLUDE START: MODULE SET DEVICE HELPER

PyObject* _setDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
  auto device = THPUtils_unpackDeviceIndex(arg);
  auto device = THPUtils_unpackLong(arg);

  torch::utils::device_lazy_init(at::kPrivateUse1);
  c10::openreg::set_device(device);
  c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// LITERALINCLUDE END: MODULE SET DEVICE HELPER

PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");

@ -41,13 +41,8 @@ def current_device():
    return torch_openreg._C._get_device()


# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None:
    if device >= 0:
        torch_openreg._C._set_device(device)


# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
    return torch_openreg._C._set_device(device)


def init():

@ -331,25 +331,6 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
        self.assertEqual(z.placements, (Replicate(),))
        self.assertEqual(z.to_local(), input)

    def test_inplace_op_partial_to_replicate(self):
        # test that in-place operations that require redistribution raise an error
        # to preserve aliasing semantics (issue #163374)
        device_mesh = self.build_device_mesh()

        input_tensor = torch.tensor(64.0, device=self.device_type)
        partial_dt = DTensor.from_local(
            input_tensor, device_mesh, placements=(Partial(),)
        )

        self.assertTrue(partial_dt.placements[0].is_partial())

        # Inplace ops that require placement changes (Partial -> Replicate) should error
        with self.assertRaisesRegex(
            RuntimeError,
            "in-place operations that require placement changes are not supported",
        ):
            partial_dt.clamp_(max=10)


if __name__ == "__main__":
    run_tests()

@ -10,6 +10,7 @@ import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
import torch.fx as fx
from torch._C import FileCheck
from torch._dynamo.utils import counters, same
from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
@ -238,6 +239,49 @@ graph():
        self.assertTrue(same(out, correct))
        self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(get_patches())
    def test_schedulable_wait(self):
        """Test whether a wait node is schedulable or not."""
        from torch._inductor.fx_passes.bucketing import _schedulable_wait_node

        def test_graph():
            graph = fx.Graph()

            inp = graph.placeholder("inp")
            group_size = graph.placeholder("group_size")
            group_name = graph.placeholder("group_name")

            ag_0_out = graph.call_function(
                torch.ops._c10d_functional.all_gather_into_tensor.default,
                args=(inp, group_size, group_name),
            )
            ag_0_wait = graph.call_function(
                torch.ops._c10d_functional.wait_tensor.default,
                args=(ag_0_out,),
            )
            ag_1_out = graph.call_function(
                torch.ops._c10d_functional.all_gather_into_tensor.default,
                args=(ag_0_wait, group_size, group_name),
            )
            ag_1_wait = graph.call_function(
                torch.ops._c10d_functional.wait_tensor.default,
                args=(ag_1_out,),
            )
            ag_2_wait = graph.call_function(
                torch.ops._c10d_functional.wait_tensor.default,
                args=(ag_1_wait,),
            )

            graph.output(ag_2_wait)
            return graph

        graph = test_graph()
        schedulable = {"wait_tensor_default", "wait_tensor_default_1"}
        for node in list(graph.nodes):
            expected = node.name in schedulable
            assert _schedulable_wait_node(node) is expected

    @torch._inductor.config.patch(get_patches())
    def test_reorder_compute_for_overlap_mul(self):
        def func(a, *, tag, ranks, group_size):

@ -23,7 +23,12 @@ from torch._inductor.comms import (
    sink_waits_iterative,
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor
from torch._inductor.fx_passes.bucketing import (
    is_all_gather_into_tensor,
    is_all_reduce_tensor,
    is_all_to_all_tensor,
    is_reduce_scatter_tensor,
)
from torch._inductor.scheduler import (
    _get_mm_like_fn,
    BaseSchedulerNode,
@ -2193,7 +2198,7 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
        self.assertEqual(saved_values, [wt1])

    @skip_if_lt_x_gpu(2)
    def test_comm_analysis(self):
    def test_all_gather_comm_analysis(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        torch.cuda.set_device(self.rank)
        c10d.init_process_group(
@ -2234,6 +2239,140 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
            )
            assert est_ms_nccl > 0

    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_comm_analysis(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        torch.cuda.set_device(self.rank)
        c10d.init_process_group(
            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        group = c10d.distributed_c10d._get_default_group()
        group_name = "default"
        torch._C._distributed_c10d._register_process_group(
            group_name, torch.distributed.group.WORLD
        )
        group_size = group.size()

        def func(inp, group_size, group_name):
            rs_0_out = torch.ops._c10d_functional.reduce_scatter_tensor(
                inp, "sum", group_size, group_name
            )
            rs_0_wait = torch.ops.c10d_functional.wait_tensor(rs_0_out)
            rs_1_out = torch.ops._c10d_functional.reduce_scatter_tensor(
                rs_0_wait, "sum", group_size, group_name
            )
            rs_1_wait = torch.ops.c10d_functional.wait_tensor(rs_1_out)
            return rs_1_wait

        gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
        g = gm.graph
        for n in g.nodes:
            if is_reduce_scatter_tensor(n):
                from torch._inductor.comm_analysis import (
                    estimate_nccl_collective_runtime_from_fx_node,
                )

                est_ms = estimate_nccl_collective_runtime_from_fx_node(
                    n, use_nccl_estimator=False
                )
                assert est_ms > 0
                est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
                    n, use_nccl_estimator=True
                )
                assert est_ms_nccl > 0

    @skip_if_lt_x_gpu(2)
    def test_all_reduce_comm_analysis(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        torch.cuda.set_device(self.rank)
        c10d.init_process_group(
            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        group = c10d.distributed_c10d._get_default_group()
        group_name = "default"
        torch._C._distributed_c10d._register_process_group(
            group_name, torch.distributed.group.WORLD
        )
        group_size = group.size()

        def func(inp, group_size, group_name):
            ar_0_out = torch.ops._c10d_functional.all_reduce(inp, "sum", group_name)
            ar_0_wait = torch.ops.c10d_functional.wait_tensor(ar_0_out)
            ar_1_out = torch.ops._c10d_functional.all_reduce(
                ar_0_wait, "sum", group_name
            )
            ar_1_wait = torch.ops.c10d_functional.wait_tensor(ar_1_out)
            return ar_1_wait

        gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
        g = gm.graph
        for n in g.nodes:
            if is_all_reduce_tensor(n):
                from torch._inductor.comm_analysis import (
                    estimate_nccl_collective_runtime_from_fx_node,
                )

                est_ms = estimate_nccl_collective_runtime_from_fx_node(
                    n, use_nccl_estimator=False
                )
                assert est_ms > 0
                est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
                    n, use_nccl_estimator=True
                )
                assert est_ms_nccl > 0

    @skip_if_lt_x_gpu(2)
    def test_all_to_all_comm_analysis(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        torch.cuda.set_device(self.rank)
        c10d.init_process_group(
            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        group = c10d.distributed_c10d._get_default_group()
        group_name = "default"
        torch._C._distributed_c10d._register_process_group(
            group_name, torch.distributed.group.WORLD
        )
        group_size = group.size()

        def func(inp, group_size, group_name):
            chunk = inp.numel() // self.world_size
            split_sizes = [chunk] * self.world_size
            a2a_0_out = torch.ops._c10d_functional.all_to_all_single(
                inp,
                split_sizes,
                split_sizes,
                group_name,
            )
            a2a_0_wait = torch.ops.c10d_functional.wait_tensor(a2a_0_out)
            a2a_1_out = torch.ops._c10d_functional.all_to_all_single(
                a2a_0_wait,
                split_sizes,
                split_sizes,
                group_name,
            )
            a2a_1_wait = torch.ops.c10d_functional.wait_tensor(a2a_1_out)
            return a2a_1_wait

        gm = make_fx(func)(
            torch.ones(group_size * 4, 1, device=self.device), group_size, group_name
        )
        g = gm.graph
        for n in g.nodes:
            if is_all_to_all_tensor(n):
                from torch._inductor.comm_analysis import (
                    estimate_nccl_collective_runtime_from_fx_node,
                )

                est_ms = estimate_nccl_collective_runtime_from_fx_node(
                    n, use_nccl_estimator=False
                )
                assert est_ms > 0
                est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
                    n, use_nccl_estimator=True
                )
                assert est_ms_nccl > 0


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

@ -208,21 +208,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
|
||||
)
|
||||
self.assertEqual(y, expected)
|
||||
|
||||
def test_get_remote_tensors(self) -> None:
|
||||
"""
|
||||
Get all remote tensors
|
||||
"""
|
||||
self._init_device()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
symm_mem.enable_symm_mem_for_group(group_name)
|
||||
|
||||
my_tensor = symm_mem.empty(1, device=self.device).fill_(self.rank)
|
||||
remote_tensors = torch.ops.symm_mem.get_remote_tensors(my_tensor, group_name)
|
||||
dist.barrier()
|
||||
|
||||
for peer, tensor in enumerate(remote_tensors):
|
||||
self.assertEqual(tensor, peer)
|
||||
|
||||
@skipIfRocm
|
||||
def test_nvshmem_put(self) -> None:
|
||||
self._init_device()
|
||||
|
||||
@ -952,9 +952,7 @@ User code traceback:
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
Graph break: skip: from user code at:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
assert x is None
|
||||
""",
|
||||
@ -1080,88 +1078,6 @@ from user code:
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback(self, records):
|
||||
def fn(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break under GenericContextWrappingVariable
|
||||
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
|
||||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skip_frame_in_loop_message(self, records):
|
||||
def fn(x):
|
||||
for i in range(2):
|
||||
with GenericCtxMgr():
|
||||
if x.sum() > 0:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
if x.sum() > 0:
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(dynamo=logging.DEBUG)
|
||||
def test_skip_frame_empty_function_message(self, records):
|
||||
def empty_fn(x):
|
||||
pass
|
||||
|
||||
torch.compile(empty_fn, backend="eager")(torch.randn(3))
|
||||
skip_messages = [
|
||||
r
|
||||
for r in records
|
||||
if "intentionally decided to skip the frame" in r.getMessage()
|
||||
]
|
||||
self.assertEqual(len(skip_messages), 1)
|
||||
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
|
||||
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
|
||||
|
||||
self.assertExpectedInline(
|
||||
msg,
|
||||
"""\
|
||||
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
|
||||
Reason: no content in function call empty_fn test_error_messages.py N""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_nested_compile_user_frames(self, records):
|
||||
def fn(x):
|
||||
@ -1708,110 +1624,6 @@ from user code:
|
||||
)
|
||||
|
||||
|
||||
class NestedGraphBreakLoggingTests(
|
||||
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback_nested(self, records):
|
||||
global f1, f2, f3
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def f1(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 2)
|
||||
|
||||
def f3(x):
|
||||
return f2(x + 3)
|
||||
|
||||
        torch.compile(f3, backend="eager")(torch.randn(3))
        self.assertEqual(len(records), 1)
        self.assertExpectedInline(
            munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
            """\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.

Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]

For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
User code traceback:
  File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
    torch.compile(f3, backend="eager")(torch.randn(3))
  File "test_error_messages.py", line N, in f3
    return f2(x + 3)
  File "test_error_messages.py", line N, in f2
    return f1(x + 2)
  File "test_error_messages.py", line N, in f1
    torch._dynamo.graph_break()
""",
        )

    @make_logging_test(graph_breaks=True)
    def test_skip_frame_in_loop_message_nested(self, records):
        global f1, f2, f3

        class GenericCtxMgr:
            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                pass

        def f1(x):
            for i in range(2):
                with GenericCtxMgr():
                    if x.sum() > 0:
                        x = x + 1
            return x

        def f2(x):
            return f1(x + 4)

        def f3(x):
            return f2(x + 5)

        result = torch.compile(f3, backend="eager")(torch.randn(3))  # noqa: F841
        self.assertEqual(len(records), 1)
        self.assertExpectedInline(
            munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
            """\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.

Developer debug context: attempted to jump with TensorVariable()

For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
User code traceback:
  File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
    result = torch.compile(f3, backend="eager")(torch.randn(3))  # noqa: F841
  File "test_error_messages.py", line N, in f3
    return f2(x + 5)
  File "test_error_messages.py", line N, in f2
    return f1(x + 4)
  File "test_error_messages.py", line N, in f1
    if x.sum() > 0:
""",
        )
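Aside: the second hint in the expected message above points at `torch.cond`. A minimal sketch, not part of this diff and with made-up names, of how the data-dependent branch in `f1` could be expressed with `torch.cond`:

# Hypothetical rewrite of f1 using torch.cond; illustrative only.
import torch

def f1_cond(x):
    def add_one(t):
        return t + 1

    def keep(t):
        # branch functions must not alias their inputs, so return a copy
        return t.clone()

    for _ in range(2):
        # predicate is a 0-dim bool tensor; both branches take and return tensors
        x = torch.cond(x.sum() > 0, add_one, keep, (x,))
    return x

out = torch.compile(f1_cond, backend="eager", fullgraph=True)(torch.randn(3))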


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -14036,44 +14036,6 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
            except Exception as e:
                self.fail(f"torch.compile failed with error: {e}")

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_tensorify_track_item_symint(self):
        def _random_resize(image: torch.Tensor):
            image_metanet = image
            default_patch_size = 14
            rand_cnn_resolution = (224, 256)
            min_nump = rand_cnn_resolution[0] // default_patch_size
            max_nump = rand_cnn_resolution[1] // default_patch_size
            new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
            torch._check(new_nump > 0)
            torch._check(new_nump * default_patch_size > 1)

            image_metanet = F.interpolate(
                image_metanet,
                size=(new_nump * default_patch_size, new_nump * default_patch_size),
                mode="bilinear",
                align_corners=True,
            )
            img_h_new, img_w_new = image_metanet.shape[2:]

            return (img_h_new, img_w_new), image_metanet

        _random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)

        # Test the function
        input_tensor = torch.rand(1, 3, 224, 224)
        (h, w), output = _random_resize_compiled(input_tensor)

        # Verify output properties
        self.assertEqual(output.shape[0], 1)
        self.assertEqual(output.shape[1], 3)
        self.assertEqual(output.shape[2], h)
        self.assertEqual(output.shape[3], w)
        self.assertTrue(h % 14 == 0)
        self.assertTrue(w % 14 == 0)
        self.assertTrue(224 <= h <= 256)
        self.assertTrue(224 <= w <= 256)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
Submodule third_party/kineto updated: 36f6d71922...6fcbc53d33
@@ -1870,7 +1870,7 @@ class ConvertFrame:
                raise

            soft_fail = isinstance(e, Unsupported)
            code = frame.f_code

            # This is a soft failure. In the sense, the code path reaches here
            # when we do not support graph breaks on bytecodes like LOAD_ATTR,
            # BUILD_SET etc. In such case, we can fallback to eager without
@@ -1885,13 +1885,7 @@ class ConvertFrame:
                    user_stack_formatted = "".join(
                        traceback.format_list(user_stack)
                    )
                    frame_info = exc.format_frame_info(code)
                    user_stack_trace = (
                        "Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
                        f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
                        "The graph break occurred in the following user code:\n"
                        f"{user_stack_formatted}"
                    )
                    user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
                    torch._logging.trace_structured(
                        "artifact",
                        metadata_fn=lambda: {
@@ -1903,7 +1897,6 @@ class ConvertFrame:
                    graph_break_log.debug(
                        user_stack_trace,
                        exc_info=True,
                        stack_info=config.verbose,
                    )

            if not config.suppress_errors and not soft_fail:
@@ -794,38 +794,6 @@ def format_error_msg_verbose(
    return msg


def format_frame_info(code: types.CodeType) -> str:
    return (
        f"{getattr(code, 'co_name', '<unknown>')} "
        f"({getattr(code, 'co_filename', '<unknown>')} "
        f"line {getattr(code, 'co_firstlineno', 0)})"
    )


def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
    if code is not None:
        frame_info = format_frame_info(code)
        return (
            f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
            f"Reason: {reason}"
        )
    else:
        return (
            f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
            f"Reason: {reason}"
        )


def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
    frame_info = format_frame_info(code)
    return (
        "Skipping frame because there is a graph break in a for/while loop\n"
        f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
        f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
        f"{frame_summary}"
    )


def format_error_msg(
    exc: Exception,
    code: types.CodeType,
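For orientation, `format_frame_info` in the hunk above only stitches together attributes of a code object. A small standalone illustration, not part of this diff (file path and line number depend on where it runs):

# Standalone sketch of the string shape format_frame_info produces.
import types


def describe(code: types.CodeType) -> str:
    return (
        f"{getattr(code, 'co_name', '<unknown>')} "
        f"({getattr(code, 'co_filename', '<unknown>')} "
        f"line {getattr(code, 'co_firstlineno', 0)})"
    )


def my_fn():
    pass


# prints something like: my_fn (/tmp/example.py line 13)
print(describe(my_fn.__code__))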
@@ -94,8 +94,6 @@ from .exc import (
    BackendCompilerFailed,
    collapse_resume_frames,
    format_graph_break_message,
    format_loop_skip_frame_message,
    format_skip_frame_message,
    get_stack_above_dynamo,
    ResumePrologueTracingError,
    StepUnsupported,
@@ -607,9 +605,9 @@ def generic_jump(
            )
            # compile a partial subgraph prefix then jump into user code
            if self.maybe_has_backedge():
                msg = format_loop_skip_frame_message(
                    self.f_code,
                    "".join(traceback.format_list([self.frame_summary()])),
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop\n"
                    f"{self.frame_summary()}"
                )
                log.info(msg)
                raise exc.SkipFrame(msg)
@@ -885,9 +883,9 @@ def break_graph_if_unsupported(
            )

            if self.maybe_has_backedge():
                msg = format_loop_skip_frame_message(
                    self.f_code,
                    "".join(traceback.format_list([self.frame_summary()])),
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop\n"
                    f"{self.frame_summary()}"
                )
                log.info(msg)
                raise exc.SkipFrame(msg) from excp
@@ -4628,9 +4626,8 @@ class InstructionTranslator(InstructionTranslatorBase):
            and not self.error_on_graph_break
            and not self.is_tracing_resume_prologue
        ):
            raise exc.SkipFrame(
                format_skip_frame_message(self.f_code, "no content in function call")
            )
            raise exc.SkipFrame("because no content in function call")

        self.instruction_pointer = None
        _step_logger()(
            logging.INFO,
@@ -2248,15 +2248,12 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
    try:
        val.data_ptr()  # will throw for functorch tensors
    except RuntimeError as e:
        from .exc import format_skip_frame_message, SkipFrame
        from .exc import SkipFrame

        # This will be GradTrackingTensor/BatchedTensor/etc
        functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
        raise SkipFrame(
            format_skip_frame_message(
                None,
                f"torch.compile cannot be run in context: {functorch_subclass_name}",
            )
            f"torch.compile cannot be run in context: {functorch_subclass_name}"
        ) from e
@@ -42,7 +42,6 @@ from torch._guards import Source
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
    format_skip_frame_message,
    get_dynamo_observed_exception,
    handle_observed_exception,
    InfiniteGeneratorError,
@@ -1653,13 +1652,8 @@ class SkipFunctionVariable(VariableTracker):
            skip_frame_msg = kwargs.get("msg")
            if skip_frame_msg:
                skip_frame_msg = skip_frame_msg.as_python_constant()
            else:
                skip_frame_msg = ""
            raise SkipFrame(
                format_skip_frame_message(
                    tx.f_code,
                    f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
                )
                f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
            )
        elif self.value is torch._dynamo.step_unsupported:
            raise StepUnsupported
@@ -23,6 +23,7 @@ class NCCL_COLL(IntEnum):
    ALL_GATHER = 1
    REDUCE_SCATTER = 2
    ALL_TO_ALL = 3
    UNSUPPORTED = 4


class NVIDIA_GPU_TYPE(IntEnum):
@@ -53,10 +54,10 @@ def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
        return NCCL_COLL.ALL_GATHER
    elif "reduce_scatter" in kernel_name:
        return NCCL_COLL.REDUCE_SCATTER
    elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
    elif any(comm in kernel_name for comm in ("all_to_all", "alltoall")):
        return NCCL_COLL.ALL_TO_ALL
    else:
        raise ValueError(f"Unsupported collective kernel: {kernel_name}")
        return NCCL_COLL.UNSUPPORTED


def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
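A quick sketch of how the classifier behaves with the new UNSUPPORTED fallback. Not part of this diff; the kernel-name strings are made up, and the expected results assume the substring checks visible in the hunk above:

# Illustrative only: made-up kernel names exercising the classifier.
from torch._inductor.comm_analysis import NCCL_COLL, get_collective_type_from_kernel_name

assert get_collective_type_from_kernel_name("reduce_scatter_tensor") == NCCL_COLL.REDUCE_SCATTER
assert get_collective_type_from_kernel_name("some_alltoall_kernel") == NCCL_COLL.ALL_TO_ALL
# Names that match no known collective now map to UNSUPPORTED instead of raising.
assert get_collective_type_from_kernel_name("triton_fused_mm_relu") == NCCL_COLL.UNSUPPORTED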
@@ -347,13 +348,12 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:


def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
    size = 0
    sz_bytes = 0
    for node in fx_node.all_input_nodes:
        if (t := node.meta.get("val")) is not None:
            size += t.numel() * t.element_size()

            # TODO - symbolic
    return size
            numel = get_size_numel(t.size())
            sz_bytes += numel * get_dtype_size(t.dtype)
    return sz_bytes


def estimate_nccl_collective_runtime_from_fx_node(
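As a concrete sanity check of the byte count computed above (illustrative numbers, not part of this diff):

# Illustrative arithmetic for the size estimate: bytes = numel * element size.
import torch

t = torch.empty(1024, 1024, dtype=torch.bfloat16)  # stand-in for a node's "val" meta tensor
numel = t.numel()                  # 1_048_576
bytes_per_elem = t.element_size()  # 2 for bfloat16
print(numel * bytes_per_elem)      # 2_097_152 bytes, i.e. 2 MiB for this input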
@@ -10,6 +10,10 @@ import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.comm_analysis import (
    get_collective_type_from_kernel_name,
    NCCL_COLL,
)
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.experimental.proxy_tensor import make_fx
@@ -52,6 +56,23 @@ def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    return (group_name, reduce_op, dtype)


def _schedulable_wait_node(node: torch.fx.Node) -> bool:
    """
    Add additional check on if the wait node is schedulable
    We should not schedule a fx node that is:
    1. wait on a collective that is not callable
    2. wait on a non-NCCL communication node
    """
    if not is_wait_tensor(node):
        return False
    assert isinstance(node.args[0], torch.fx.Node)
    assert isinstance(node.args[0].target.name(), str)
    is_callable: bool = node.args[0].op == "call_function"
    coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name())
    is_collective: bool = coll != NCCL_COLL.UNSUPPORTED
    return is_callable and is_collective


def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None:
    if is_all_gather_into_tensor(node):
        group_key_fn = (
@@ -138,7 +159,6 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target is torch.ops._c10d_functional.wait_tensor.default
        and node.args[0].op == "call_function"
    )


@@ -149,6 +169,13 @@ def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
    )


def is_all_to_all_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target is torch.ops._c10d_functional.all_to_all_single.default
    )


def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
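For intuition, `_schedulable_wait_node` narrows the broader `is_wait_tensor` filter: waits on non-NCCL or non-callable communication nodes are left alone. A rough sketch over a hypothetical graph module `gm`, not part of this diff:

# Rough sketch: enumerate only the wait nodes the scheduler will now consider.
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor


def schedulable_waits(gm):
    for node in gm.graph.nodes:
        if is_wait_tensor(node) and not _schedulable_wait_node(node):
            # e.g. a wait whose producer maps to NCCL_COLL.UNSUPPORTED
            continue
        if _schedulable_wait_node(node):
            yield node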
@@ -8,9 +8,9 @@ import torch
import torch.fx as fx
from torch._dynamo.graph_deduplication import _stable_topological_sort
from torch._inductor.fx_passes.bucketing import (
    _schedulable_wait_node,
    is_all_gather_into_tensor as is_all_gather,
    is_reduce_scatter_tensor as is_reduce_scatter,
    is_wait_tensor,
    merge_all_gather_bucket,
    merge_reduce_scatter_bucket,
)
@@ -36,7 +36,10 @@ class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
    """

    def __init__(
        self, node_users: dict[fx.Node, OrderedSet[fx.Node]], *args: Any, **kwargs: Any
        self,
        node_users: dict[fx.Node, OrderedSet[fx.Node]],
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.node_users = node_users
@@ -97,7 +100,7 @@ class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
        )

        # Identify the new wait and start
        new_waits = [n for n in new_nodes if is_wait_tensor(n)]
        new_waits = [n for n in new_nodes if _schedulable_wait_node(n)]
        assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}"
        new_wait = new_waits[0]
        new_start = new_wait.args[0]
@@ -186,7 +189,7 @@ class ManualOverlapScheduler(OverlapScheduler):
    def _identify_collectives(self) -> None:
        """Identify all collective operations."""
        for node in self.nodes:
            if is_wait_tensor(node):
            if _schedulable_wait_node(node):
                start = node.args[0]
                info = CollectiveInfo(
                    start_node=start,
@@ -11,7 +11,8 @@ from typing import Any, Literal
import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.fx_passes.bucketing import is_wait_tensor
from torch._inductor.comm_analysis import estimate_fx_collective_size
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor
from torch._inductor.fx_passes.memory_estimator import (
    _is_releasable,
    build_memory_profile,
@@ -67,16 +68,6 @@ def estimate_collective_time(
    )


def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
    size = 0
    for node in fx_node.all_input_nodes:
        if (t := node.meta.get("val")) is not None:
            # todo - symbolic
            size += t.numel() * t.element_size()

    return size


def is_compute_node(n: fx.Node) -> bool:
    """
    Should we consider this node computationally expensive ?
@@ -318,7 +309,7 @@ class OverlapScheduler:
    def _identify_collectives(self) -> None:
        """Identify all collective operations."""
        for node in self.nodes:
            if is_wait_tensor(node):
            if _schedulable_wait_node(node):
                start = node.args[0]
                coll_time_ms = estimate_collective_time(
                    start, custom_runtime_estimation=self.custom_runtime_estimation
@@ -531,7 +522,7 @@ class OverlapScheduler:
                self._handle_compute(node)
            elif node in self.collective_info:
                self._handle_collective_start(node)
            elif is_wait_tensor(node):
            elif _schedulable_wait_node(node):
                self._handle_wait(node)
            else:
                self._handle_other(node)
@@ -596,7 +587,7 @@ class OverlapScheduler:
    def _compute_score(self, node: fx.Node) -> object:
        """Compute priority score for a node"""

        if is_wait_tensor(node):
        if _schedulable_wait_node(node):
            info = self.collective_info[self.wait_to_start[node]]
            # defer waits locally if they are exposed.
            compute_local_priority = int(info.is_exposed)
@@ -827,7 +818,7 @@ class OverlapScheduler:
            # thus forcing it to be exposed.
            # however, if it is already hidden or it cannot be possible hidden,
            # it's fine to schedule it
            if is_wait_tensor(node):
            if _schedulable_wait_node(node):
                info = self.collective_info[self.wait_to_start[node]]
                if info.hiding_node and info.hiding_node != curr_compute_node:
                    continue
@@ -875,7 +866,7 @@ class OverlapScheduler:
        assert all(n not in self.scheduled for n in path)
        for node in sorted(path, key=lambda n: self.node_idx[n]):
            assert not (is_compute_node(node) or node in self.unscheduled_collectives)
            if is_wait_tensor(node):
            if _schedulable_wait_node(node):
                # When we schedule wait tensors, we also force realization of all
                # collectives enqueued prior to their corresponding collective.
                # It's possible the scheduling of one wait tensor here has forced
@@ -891,14 +891,10 @@ class TorchLogsFormatter(logging.Formatter):
        # exception handling - copied from logging.Formatter.format
        s = record.message
        if record.exc_info:
            from torch._dynamo import config

            should_format_exc = config.verbose or artifact_name != "graph_breaks"
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if should_format_exc:
                if not record.exc_text:
                    record.exc_text = self.formatException(record.exc_info)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if s[-1:] != "\n":
                s = s + "\n"
@@ -465,39 +465,6 @@ lib.define(
    "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
)

lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
"""
Given a local tensor and a group name, return a tuple of tensors that are
symmetric on other devices. The returned tensors are ordered by rank IDs. The
length of the tuple equals to the size of the group.

Note: this API works only when `world_within_direct_access()` returns True, i.e.
only when the group is within NVLink domain or similar. It does not work across
network interfaces.
"""


@torch.library.impl(lib, "get_remote_tensors", "CUDA")
def _get_remote_tensors_default(
    local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
    hdl = rendezvous(local, group_name)
    if hdl is None:
        raise ValueError("Tensor is not allocated from Symmetric Memory")

    return tuple(
        hdl.get_remote_tensor(peer, local.size(), local.dtype)
        for peer in range(hdl.world_size)
    )


@torch.library.impl(lib, "get_remote_tensors", "Meta")
def _get_remote_tensors_meta(
    local: torch.Tensor, group_name: str
) -> tuple[torch.Tensor, ...]:
    group = c10d._resolve_process_group(group_name)
    return tuple(torch.empty_like(local) for _ in range(group.size()))


class _ScaleMode(Enum):
    UNSCALED = "unscaled"
@@ -337,34 +337,19 @@ class OpDispatcher:
        if is_inplace_op:
            # inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
                output_spec = output_sharding.output_spec
                assert isinstance(output_spec, DTensorSpec)
                assert isinstance(args[0], dtensor.DTensor)

                # NOTE: aten.squeeze_.dim is an inplace op but it also may change
                # the inplace argument's tensor meta. Here we choose to special case
                # this op because as far as I know this is the only inplace op that
                # has such as behavior. We can extend this special case if necessary.
                if op_call == aten.squeeze_.dim:
                    # update the spec to handle tensor meta changes
                    output_spec = output_sharding.output_spec
                    assert isinstance(output_spec, DTensorSpec)
                    assert isinstance(args[0], dtensor.DTensor)
                    args[0]._spec = output_spec
                    # use return_and_correct_aliasing to match the outer and the inner
                    # aliasing. See https://github.com/pytorch/pytorch/pull/158954
                    return return_and_correct_aliasing(op_call, args, kwargs, args[0])
                else:
                    # For all other inplace ops, check if placement changes are required
                    # Inplace operations that change placement are not supported because
                    # they would require redistribution, which breaks aliasing semantics.
                    # If there are views into the tensor, the views would not be updated.
                    if args[0]._spec.placements != output_spec.placements:
                        raise RuntimeError(
                            f"{op_call}: in-place operations that require placement changes "
                            f"are not supported. The operation would change placement from "
                            f"{args[0]._spec.placements} to {output_spec.placements}, "
                            f"which requires redistribution and breaks aliasing semantics. "
                            f"Please use the out-of-place version of this operation instead."
                        )
                    # Most inplace ops don't change tensor meta, so no spec update needed
                return args[0]
            else:
                return None
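The comments in the hunk above hinge on aliasing. A tiny plain-tensor illustration of why an in-place op must keep writing to the same storage; purely for intuition, not DTensor-specific and not part of this diff:

# Plain-tensor intuition for the aliasing argument above.
import torch

base = torch.zeros(4)
view = base[:2]          # shares storage with `base`
base.add_(1)             # true in-place update: the view observes it
assert view[0].item() == 1.0
# If the "in-place" op had instead materialized a redistributed copy in new
# storage, `view` would still point at the old buffer and silently go stale.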
@@ -207,19 +207,12 @@ def tensorify_python_scalars(
            and node.target is torch.ops.aten._local_scalar_dense.default
        ):
            dtype = node.args[0].meta["val"].dtype
            if not dtype.is_floating_point:
                continue

            assert isinstance(node.args[0], fx.Node), node.args[0]

            s = node.meta["val"].node.expr

            expr_to_sym_proxy[s] = MetaProxy(
                node, tracer=tracer, fake_mode=fake_mode
            )

            # only tensorify if the dtype is floating point
            if not dtype.is_floating_point:
                continue

            expr_to_tensor_proxy[s] = MetaProxy(
                node.args[0], tracer=tracer, fake_mode=fake_mode
            )
@@ -227,7 +220,9 @@ def tensorify_python_scalars(
            expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
                expr_to_tensor_proxy[s], torch.float64
            )

            expr_to_sym_proxy[s] = MetaProxy(
                node, tracer=tracer, fake_mode=fake_mode
            )
            # pyrefly: ignore [bad-argument-type]
        elif (sym_expr := _get_sym_val(node)) is not None:
            if sym_expr not in expr_to_sym_proxy and not isinstance(
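For orientation, the pass in the hunk above handles floats that enter the graph through `aten._local_scalar_dense` (i.e. `Tensor.item()`). A rough user-level sketch of the pattern it targets; illustrative only and not part of this diff:

# Rough user-level sketch of the pattern tensorify_python_scalars targets.
import torch
import torch._dynamo

torch._dynamo.config.capture_scalar_outputs = True


def scale_by_mean(x):
    s = x.mean().item()  # float backed by aten._local_scalar_dense in the graph
    return x * s         # the pass can keep `s` as a tensor instead of
                         # specializing the graph on a concrete Python float


compiled = torch.compile(scale_by_mean, fullgraph=True)
print(compiled(torch.randn(8)))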