Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-19 18:14:54 +08:00

Compare commits (2 commits): ciflow/tru… ... update-vll…

| Author | SHA1 | Date |
|---|---|---|
|  | d79ccd0bba |  |
|  | d6bb3ad8b9 |  |
@@ -84,7 +84,6 @@ class VllmTestRunner(BaseRunner):
self.VLLM_TEST_WHLS_REGEX = [
"xformers/*.whl",
"vllm/vllm*.whl",
"flashinfer-python/flashinfer*.whl",
]

def prepare(self):
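The entries in VLLM_TEST_WHLS_REGEX above are shell-style glob patterns matched against the downloaded wheel directory. As a rough illustration only (not the runner's actual code; the directory name and helper function are assumptions), resolving such patterns could look like:

```python
from pathlib import Path

# Hypothetical wheel directory; the real runner defines its own location.
WHEEL_DIR = Path("externals/vllm/wheels")

def resolve_wheels(patterns: list[str]) -> list[Path]:
    """Expand glob patterns such as 'vllm/vllm*.whl' relative to WHEEL_DIR."""
    found: list[Path] = []
    for pattern in patterns:
        found.extend(sorted(WHEEL_DIR.glob(pattern)))
    return found

print(resolve_wheels(["xformers/*.whl", "vllm/vllm*.whl"]))
```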

.github/ci_commit_pins/vision.txt (vendored): 2 lines changed
@@ -1 +1 @@
617079d944b0e72632311c30ae2bbdf1168b901e
2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc

.github/ci_configs/vllm/Dockerfile (vendored): 35 lines changed
@@ -1,4 +1,4 @@
ARG CUDA_VERSION=12.8.1
ARG CUDA_VERSION=12.9.1
ARG PYTHON_VERSION=3.12

# BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine,

@@ -124,7 +124,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
git clone https://github.com/facebookresearch/xformers.git

pushd xformers
git checkout v0.0.32.post2
git checkout v0.0.33.post1
git submodule update --init --recursive
python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose
popd

@@ -256,7 +256,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy

# Install build and runtime dependencies, this is needed for flashinfer install
# Install build and runtime dependencies
COPY requirements/build.txt requirements/build.txt
COPY use_existing_torch.py use_existing_torch.py
RUN python3 use_existing_torch.py

@@ -294,33 +294,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system /wheels/xformers/*.whl --verbose

# Build FlashInfer from source
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.14.post1"

RUN --mount=type=cache,target=/root/.cache/uv \
git clone --depth 1 --recursive --shallow-submodules \
--branch ${FLASHINFER_GIT_REF} \
${FLASHINFER_GIT_REPO} flashinfer \
&& echo "Building FlashInfer with AOT for arches: ${torch_cuda_arch_list}" \
&& cd flashinfer \
&& python3 -m flashinfer.aot \
&& python3 -m build --no-isolation --wheel --outdir ../wheels/flashinfer \
&& cd .. \
&& rm -rf flashinfer

# Install FlashInfer
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system wheels/flashinfer/*.whl --verbose

# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm\|^flashinfer' > build_summary.txt
RUN pip freeze | grep -E 'torch|xformers|vllm'
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm' > build_summary.txt
################### VLLM INSTALLED IMAGE ####################

@@ -331,4 +307,3 @@ FROM scratch as export-wheels
COPY --from=base /workspace/xformers-dist /wheels/xformers
COPY --from=build /workspace/vllm-dist /wheels/vllm
COPY --from=vllm-base /workspace/build_summary.txt /wheels/build_summary.txt
COPY --from=vllm-base /workspace/wheels/flashinfer /wheels/flashinfer-python
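The build_summary.txt step above records the resolved torch/xformers/vllm versions via `uv pip freeze`. A rough Python equivalent of that check (a sketch, not part of the Dockerfile; the package list is an assumption) could query the installed distributions directly:

```python
from importlib import metadata

# Packages whose versions the build summary is meant to capture (assumed list).
PACKAGES = ["torch", "torchvision", "torchaudio", "xformers", "vllm"]

def build_summary(packages: list[str]) -> str:
    """Return one 'name==version' line per installed package."""
    lines = []
    for name in packages:
        try:
            lines.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            lines.append(f"{name}: not installed")
    return "\n".join(lines)

if __name__ == "__main__":
    print(build_summary(PACKAGES))
```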

.github/scripts/prepare_vllm_wheels.sh (vendored): 2 lines changed
@@ -88,7 +88,7 @@ repackage_wheel() {
${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1

pushd externals/vllm/wheels
for package in xformers flashinfer-python vllm; do
for package in xformers vllm; do
repackage_wheel $package
done
popd
@@ -117,10 +117,6 @@ Tensor& relu_mps_(Tensor& self) {

TORCH_IMPL_FUNC(log_softmax_mps_out)
(const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"log_softmax for complex is not supported for MPS");
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "log_softmax for bool is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;

@@ -164,10 +160,6 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)

TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
(const Tensor& grad_output, const Tensor& output, int64_t dim, ScalarType input_dtype, const Tensor& out) {
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(grad_output.scalar_type()),
"log_softmax for complex is not supported for MPS");
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kBool, "log_softmax for bool is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryGradCachedGraph;

@@ -208,7 +200,6 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
}

std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_mps(const Tensor& self, Tensor& output, Tensor& buffer) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// NOTE: buffer is only used by CPU dispatch, we just ignore it here
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;

@@ -715,7 +706,6 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor
if (output.numel() == 0)
return;

TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// this can't pass anyway because a 0-dimensional tensor has "size" 1, which
// can't be evenly halved, but give a nicer error message here.
TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");

@@ -829,7 +819,6 @@ TORCH_IMPL_FUNC(softplus_out_mps)
(const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Not implemented for long");
// Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
// \log(1 + \exp(\beta * x))` element-wise.
// For numerical stability the implementation reverts to the linear function

@@ -980,8 +969,6 @@ TORCH_IMPL_FUNC(mish_out_mps)
(const Tensor& self, const Tensor& result) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS");

if (result.numel() == 0)
return;

@@ -1030,8 +1017,6 @@ TORCH_IMPL_FUNC(mish_out_mps)
Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS");

Tensor grad_input = at::empty_like(self, self.suggest_memory_format());
if (grad_input.numel() == 0)

@@ -1221,7 +1206,6 @@ TORCH_IMPL_FUNC(silu_out_mps)(const Tensor& self, const Tensor& result) {
using CachedGraph = MPSUnaryCachedGraph;

TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");

// Empty output
if (result.numel() == 0)
@@ -80,11 +80,6 @@ static void grid_sampler_2d_mps_impl(Tensor& output,
MPSGraphTensor* outputTensor_ = nil;
};

// Crashes with
// MPSGraphUtilities.mm:97:0: error: 'mps.sample_grid' op operand #0 must be tensor of mps native type values, but got
// 'tensor<2x3x5x20xcomplex<f32>>'
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"grid_sampler_2d is not supported for complex on MPS");
@autoreleasepool {
std::string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" +
std::to_string(interpolation_mode) + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);
@@ -240,7 +240,7 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
bool check_errors) {
using namespace mps;

TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat,
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
"linalg.lu_factor(): MPS doesn't support complex types.");
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");

@@ -364,7 +364,8 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
const Tensor& info) {
using namespace mps;

TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS only supports floats.");
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
"linalg.lu_factor(): MPS doesn't support complex types.");
Tensor A_t, B_t;
// If 'left' is false, reinterpret the problem so that Ax = B becomes A^T ⋅ (x^T) = B^T
// Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output

@@ -1057,8 +1058,7 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
using namespace mps;

checkInputsSolver(A, B, left, "linalg.solve_triangular");
TORCH_CHECK(A.scalar_type() == kFloat && B.scalar_type() == kFloat,
"linalg.solve.triangular(); Only float is supported!");
TORCH_CHECK(!A.is_complex() && !B.is_complex(), "linalg.solve.triangular(); Not supported for complex yet!");
Tensor A_t, B_t;
std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr);
at::native::resize_output(out, B_t.sizes());
@@ -416,8 +416,6 @@ static void nllnd_loss_forward_impl(Tensor& output,
int64_t reduction,
int64_t ignore_index,
bool is2D) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(output.scalar_type()),
"nlld_loss for complex is not supported for MPS");
std::vector<long long> reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end());
reshapedTarget.push_back(1);

@@ -826,9 +824,6 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) {
std::string op_name = __func__;
using namespace mps;
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"huber_loss for complex is not supported for MPS");
TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes")
TORCH_CHECK(output.is_mps());

@@ -597,7 +597,6 @@ static void avg_pool2d_template(const Tensor& input,
bool count_include_pad,
const std::optional<int64_t> divisor_override,
const std::string& op_name) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), "Not implemented for complex");
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
const bool is_backward_pass = grad_output.defined();
const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0;

@@ -916,8 +915,6 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"Max pooling for complex is not supported for MPS");
bool use_graph = use_graph_for_max_pool2d(kernel_size, stride);
if (use_graph) {
auto indices_memory_format = indices.suggest_memory_format();

@@ -970,8 +967,6 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)
bool ceil_mode,
const Tensor& indices,
const Tensor& grad_input) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"Max pooling for complex is not supported for MPS");
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor
@@ -269,22 +269,17 @@ static void reduction_out_mps(const Tensor& input_t,
name:nil];
castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil];
} else if (reduction_type == MPSReductionType::NANSUM) {
// Integral types cannot contain NaN, so just do regular sum
if (([castInputTensor dataType] & MPSDataTypeFloatBit) == 0) {
castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil];
} else {
// Create a 0 tensor of the same shape as inputTensor
auto zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType];
// Find NaNs
auto nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil];
// Replace NaNs with 0
auto nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:zeros
falsePredicateTensor:castInputTensor
name:nil];
// Sum
castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil];
}
// Create a 0 tensor of the same shape as inputTensor
MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType];
// Find NaNs
MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil];
// Replace NaNs with 0
MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:zeros
falsePredicateTensor:castInputTensor
name:nil];
// Sum
castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil];
}

MPSGraphTensor* outputTensor = castOutputTensor;
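The NANSUM path above implements nansum as "replace NaNs with zero, then sum", with a plain sum for integral inputs that cannot hold NaN. A minimal PyTorch sketch of that identity (illustration only, not the MPS kernel itself):

```python
import torch

def nansum_by_masking(x: torch.Tensor) -> torch.Tensor:
    """Replace NaNs with 0 and sum, mirroring the MPSGraph select + reductionSum above."""
    if not x.is_floating_point():
        # Integral types cannot contain NaN, so a regular sum suffices.
        return x.sum()
    cleaned = torch.where(torch.isnan(x), torch.zeros_like(x), x)
    return cleaned.sum()

x = torch.tensor([1.0, float("nan"), 2.0])
assert torch.equal(nansum_by_masking(x), torch.nansum(x))
```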
@@ -447,7 +442,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
const std::optional<Scalar>& correction,
bool keepdim,
StdVarType stdVarType) {
TORCH_CHECK_NOT_IMPLEMENTED(input_t.scalar_type() != kLong, "Not implemented for MPS");
using CachedGraph = MPSUnaryCachedGraph;

IntArrayRef input_shape = input_t.sizes();

@@ -39,7 +39,6 @@ static void get_shapes(MPSShape* input_shape_readonly,
TORCH_IMPL_FUNC(softmax_mps_out)
(const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) {
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS");
TORCH_CHECK(c10::isFloatingType(input_.scalar_type()), "softmax only supported for floating types");
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);

if (input_.numel() == 0) {

@@ -18,10 +18,6 @@ static Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tens
MPSStream* stream = getCurrentMPSStream();
bool has_weights = weights.defined();

// Crashes with
// MPSGraphUtilities.mm:190:0: error: 'mps.scatter' op operand #2 must be tensor of int values, but got 'tensor<5xi1>'
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "bincount is not supported for Bool");

@autoreleasepool {
std::string key = "bincount_mps_impl" + getTensorsStringKey({self, weights});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -4617,7 +4617,7 @@
dispatch:
CompositeExplicitAutograd: permute
MPS: permute_mps
SparseCPU, SparseCUDA, SparseMPS: permute_sparse_coo
SparseCPU, SparseCUDA: permute_sparse_coo
tags: core

- func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
@@ -48,7 +48,7 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}

void StorageImpl::incref_pyobject() const noexcept {
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.

@@ -59,12 +59,12 @@ void StorageImpl::incref_pyobject() const noexcept {
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}

void StorageImpl::decref_pyobject() const noexcept {
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}

bool StorageImpl::try_incref_pyobject() const noexcept {
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;

@@ -105,11 +105,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}

void incref_pyobject() const noexcept override final;
void incref_pyobject() const override final;

void decref_pyobject() const noexcept override final;
void decref_pyobject() const override final;

bool try_incref_pyobject() const noexcept override final;
bool try_incref_pyobject() const override final;

size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive

@@ -988,7 +988,7 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}

void TensorImpl::incref_pyobject() const noexcept {
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.

@@ -999,12 +999,12 @@ void TensorImpl::incref_pyobject() const noexcept {
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}

void TensorImpl::decref_pyobject() const noexcept {
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}

bool TensorImpl::try_incref_pyobject() const noexcept {
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;

@@ -2178,11 +2178,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}

void incref_pyobject() const noexcept override final;
void incref_pyobject() const override final;

void decref_pyobject() const noexcept override final;
void decref_pyobject() const override final;

bool try_incref_pyobject() const noexcept override final;
bool try_incref_pyobject() const override final;

private:
// See NOTE [std::optional operator usage in CUDA]
@@ -68,10 +68,6 @@ inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
}

inline bool is_uniquely_owned(uint64_t combined_refcount) {
return (combined_refcount & ~detail::kHasPyObject) == detail::kUniqueRef;
}

// The only requirement for refcount increment is that it happens-before
// decrement, so no additional memory ordering is needed.
inline uint64_t atomic_combined_refcount_increment(

@@ -291,9 +287,9 @@ class C10_API intrusive_ptr_target {
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const noexcept {}
virtual void decref_pyobject() const noexcept {}
virtual bool try_incref_pyobject() const noexcept {
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}

@@ -367,7 +363,7 @@ class intrusive_ptr final {
template <typename, typename...>
friend class pybind11::class_;

void retain_() noexcept {
void retain_() {
if (target_ != NullType::singleton()) {
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);

@@ -381,7 +377,9 @@ class intrusive_ptr final {
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) {
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
@@ -394,60 +392,51 @@ class intrusive_ptr final {

void reset_() noexcept {
if (target_ != NullType::singleton()) {
reset_not_null_(target_);
}
}

// C10_NOINLINE to keep binary size a bit smaller. We pass TTarget* here
// to avoid an extra pointer dereference in the call from reset_().
C10_NOINLINE static void reset_not_null_(TTarget* target) noexcept {
if (detail::is_uniquely_owned(
target->combined_refcount_.load(std::memory_order_acquire))) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target deletion
// call (e.g. calling use_count()) without a data race.
target->combined_refcount_.store(0, std::memory_order_relaxed);
delete target;
return;
}

auto combined_refcount = detail::atomic_combined_refcount_decrement(
target->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
if (detail::weakcount(combined_refcount) == 1) {
delete target;
if (is_uniquely_owned()) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
// call (e.g. calling use_count()) without a data race.
target_->combined_refcount_.store(0, std::memory_order_relaxed);
delete target_;
return;
}
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
release_resources_and_decrement_weakrefs_(target);
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (has_pyobject && new_refcount == 1) {
target->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}

C10_NOINLINE static void release_resources_and_decrement_weakrefs_(
TTarget* target) noexcept {
// justification for const_cast: release_resources is basically a
// destructor and a destructor always mutates the object, even for
// const objects.
const_cast<std::remove_const_t<TTarget>*>(target)->release_resources();
if (detail::atomic_weakcount_decrement(target->combined_refcount_) == 0) {
delete target;
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
if (!should_delete) {
// justification for const_cast: release_resources is basically a
// destructor and a destructor always mutates the object, even for
// const objects.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<std::remove_const_t<TTarget>*>(target_)
->release_resources();
should_delete = detail::atomic_weakcount_decrement(
target_->combined_refcount_) == 0;
}
if (should_delete) {
delete target_;
}
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
@@ -618,8 +607,9 @@ class intrusive_ptr final {
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
return detail::is_uniquely_owned(
target_->combined_refcount_.load(std::memory_order_acquire));
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}

/**
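The is_uniquely_owned check above treats the strong count, the weak count, and the kHasPyObject flag as fields packed into one 64-bit word: mask out the PyObject bit and compare against the "one strong reference, one weak reference" pattern. A small Python model of that check follows; the bit layout and constant values below are assumptions for illustration only, not c10's actual constants.

```python
# Assumed layout: strong refcount in bits 0-31, weak refcount in bits 32-62,
# and a single has-PyObject flag in bit 63. c10's real constants may differ.
K_REFERENCE_COUNT_ONE = 1
K_WEAK_COUNT_ONE = 1 << 32
K_HAS_PYOBJECT = 1 << 63
K_UNIQUE_REF = K_REFERENCE_COUNT_ONE | K_WEAK_COUNT_ONE  # one strong + one weak

def is_uniquely_owned(combined: int) -> bool:
    """True when both counts are 1, ignoring the PyObject flag bit."""
    return (combined & ~K_HAS_PYOBJECT) == K_UNIQUE_REF

# A freshly constructed, Python-wrapped object is still uniquely owned.
assert is_uniquely_owned(K_UNIQUE_REF | K_HAS_PYOBJECT)
# A second strong reference breaks uniqueness.
assert not is_uniquely_owned(K_UNIQUE_REF + K_REFERENCE_COUNT_ONE)
```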
@@ -1184,7 +1174,9 @@ inline void incref(intrusive_ptr_target* self) {
self->combined_refcount_, detail::kReferenceCountOne);

#ifndef C10_MOBILE
if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) {
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
@@ -428,14 +428,7 @@ class TestFullyShardCommunication(FSDPTest):
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1571
def test_set_reduce_scatter_divide_factor(self):
self.run_subtests(
{
"divide_factor": [self.world_size * 2, self.world_size],
"mesh_shape": [
(self.world_size,),
(self.world_size // 2, 2),
(self.world_size, 1),
],
},
{"divide_factor": [self.world_size * 2, self.world_size]},
self._test_set_reduce_scatter_divide_factor,
)
self.run_subtests(

@@ -443,31 +436,18 @@ class TestFullyShardCommunication(FSDPTest):
self._test_set_reduce_scatter_divide_factor_mixed_prevision,
)

def _test_set_reduce_scatter_divide_factor(
self, divide_factor: float, mesh_shape: tuple[int] | tuple[int, int]
):
def _test_set_reduce_scatter_divide_factor(self, divide_factor: float):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0, weight_tying=False)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
mesh_dim_names = ("outer",) if len(mesh_shape) == 1 else ("outer", "inner")
mesh = init_device_mesh(
device_type.type, mesh_shape, mesh_dim_names=mesh_dim_names
)
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module, reshard_after_forward=False, mesh=mesh)
model = fully_shard(model, reshard_after_forward=False, mesh=mesh)
fully_shard(module, reshard_after_forward=False)
model = fully_shard(model, reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
model.set_gradient_divide_factor(divide_factor)

# Get ref_model params which should have the specific division factor applied
block_params = set()
for ref_mod in ref_model.modules():
if isinstance(ref_mod, TransformerBlock):
block_params.update(ref_mod.parameters())
non_block_params = set(ref_model.parameters()) - block_params
model.set_reduce_scatter_divide_factor(divide_factor)

torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)

@@ -476,18 +456,16 @@ class TestFullyShardCommunication(FSDPTest):
ref_loss = ref_model(inp).sum()
ref_loss.backward()
for param in ref_model.parameters():
factor = divide_factor if param in non_block_params else self.world_size
param.grad.mul_(1.0 / factor)
param.grad.mul_(1.0 / divide_factor)
dist.all_reduce(param.grad)
loss = model(inp).sum()
loss.backward()
ref_optim.step()
optim.step()
self.assertEqual(ref_loss, loss)
# Check parity before calling zero_grad so that grads are also checked
check_sharded_parity(self, ref_model, model)
ref_optim.zero_grad()
optim.zero_grad()
self.assertEqual(ref_loss, loss)
check_sharded_parity(self, ref_model, model)

def _test_set_reduce_scatter_divide_factor_mixed_prevision(
self, divide_factor: float

@@ -506,7 +484,7 @@ class TestFullyShardCommunication(FSDPTest):
fully_shard(mlp, mp_policy=mp_policy)
model = fully_shard(model, mp_policy=mp_policy)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
model.set_gradient_divide_factor(divide_factor)
model.set_reduce_scatter_divide_factor(divide_factor)

torch.manual_seed(42 + self.rank)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
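The test changes above track an API rename: the per-module hook exercised here is now called set_gradient_divide_factor rather than set_reduce_scatter_divide_factor. A minimal usage sketch under that assumption (the toy module and factor are illustrative; it assumes an initialized process group and a backend that supports FSDP collectives, and the import path for fully_shard may vary by version):

```python
import torch
from torch.distributed.fsdp import fully_shard  # import path may differ by version

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
model = fully_shard(model)

# Divide gradients by a custom factor instead of the default world-size-based
# factor when reducing them (the renamed API exercised by the test above).
model.set_gradient_divide_factor(4.0)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
```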
@@ -34,11 +34,7 @@ from torch.distributed.tensor._ops.utils import (
register_op_strategy,
replicate_op_strategy,
)
from torch.distributed.tensor.debug import (
_clear_fast_path_sharding_prop_cache,
_clear_python_sharding_prop_cache,
CommDebugMode,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,

@@ -483,8 +479,7 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
del propagator.op_to_schema_info[op_overload]
else:
propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
_clear_fast_path_sharding_prop_cache()
_clear_python_sharding_prop_cache()
propagator.propagate_op_sharding.cache.cache_clear()


def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool:

@@ -650,28 +645,6 @@ class TestStrategyHashing(DTensorTestBase):
self.assertEqual(out1.full_tensor(), out2.full_tensor())


class TestStrategyOperation(DTensorTestBase):
@property
def world_size(self):
return 2

@with_comms
def test_cache_clean(self):
mesh = self.build_device_mesh()
test_op = torch.ops.mylib.numpy_sin
x = torch.randn(2, device=self.device_type)
y = torch.randn(2, device=self.device_type)
x_dt = distribute_tensor(x, mesh, [Shard(0)])
y_dt = distribute_tensor(y, mesh, [Shard(0)])
with op_strategy_context(test_op.default, replicate_op_strategy):
self._test_op_on_dtensor(test_op, x_dt, y_dt)
with self.assertRaisesRegex(
NotImplementedError,
f"Operator {test_op.default} does not have a sharding strategy registered",
):
self._test_op_on_dtensor(test_op, x_dt, y_dt)


DistTensorReplicateStrategyRegistrationTestWithLocalTensor = (
create_local_tensor_test_class(
DistTensorReplicateStrategyRegistrationTest,
@@ -585,10 +585,6 @@ class GraphModule(torch.nn.Module):
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None

# No stacktrace found for following nodes
record_event_default = torch.ops.streams.record_event.default(2, 1); record_event_default = None
wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None

# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
@@ -1405,7 +1405,7 @@ class TestConverter(TestCase):
)
# qnnpack not supported on s390x
@xfailIfS390X
def test_ts2ep_convert_quantized_model1(self):
def test_ts2ep_convert_quantized_model(self):
class Standalone(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -640,13 +640,16 @@ class TestPasses(TestCase):
self.assertExpectedInline(
without_token_ep.graph_module.code.strip(),
"""\
def forward(self, obj_attr, x):
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None
getitem_1 = takes_foo_tuple_return_default[0]
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
def forward(self, token, obj_attr, x):
with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None
getitem = with_effects[0]
getitem_1 = with_effects[1]
getitem_2 = with_effects[2]; with_effects = None
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None
return (takes_foo_default,)""", # noqa: B950
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None
getitem_3 = with_effects_1[0]
getitem_4 = with_effects_1[1]; with_effects_1 = None
return (getitem_3, getitem_4)""", # noqa: B950
)

def test_fakify_script_objects(self):
@@ -461,9 +461,9 @@ def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950
)
self.assertExpectedInline(

@@ -1087,12 +1087,10 @@ def forward(self, token, tq, x):
str(ep.graph_module.graph).strip(),
"""\
graph():
%token : [num_users=1] = placeholder[target=token]
%tq : [num_users=2] = placeholder[target=tq]
%x : [num_users=1] = placeholder[target=x]
%with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
return (getitem, tq)""", # noqa: B950
%queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {})
return (tq,)""", # noqa: B950
)

def test_deepcopy(self):
@@ -870,100 +870,6 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
finally:
handle.destroy()

@unittest.skipIf(not TEST_CUDA, "triton")
def test_export_invoke_subgraph(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
recorded_list = []

@torch.library.custom_op("mylib::record_memory", mutates_args=())
def record_memory(prefix: str, module_name: str) -> None:
torch.cuda.synchronize()
mem_alloc = torch.cuda.memory_allocated() / 1024**2
mem_reserved = torch.cuda.memory_reserved() / 1024**2
memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB"
recorded_list.append(memory_str)

@record_memory.register_fake
def record_memory_fake(prefix, module_name):
return

record_memory.register_effect(_EffectType.ORDERED)

class N(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(1024, 1024)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(1024, 1024)

@torch.compiler.nested_compile_region
def forward(self, x):
torch.ops.mylib.record_memory("forward", "N")
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod_list = torch.nn.ModuleList(N() for _ in range(3))

def forward(self, x):
for m in self.mod_list:
x = m(x)
torch.ops.mylib.record_memory("forward", "N")
return (x,)

model = M().to("cuda")
torch.cuda.reset_peak_memory_stats()

x = torch.randn(32, 1024, requires_grad=True, device="cuda")

ep = torch.export.export(model, (x,))
ep = ep.run_decompositions()
self.assertEqual(len(list(ep.graph_module.named_modules())), 2)

self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None
getitem = invoke_subgraph[0]
getitem_1 = invoke_subgraph[1]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None
getitem_2 = invoke_subgraph_1[0]
getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None
repeated_subgraph0_2 = self.repeated_subgraph0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None
getitem_4 = invoke_subgraph_2[0]
getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None
with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None
getitem_6 = with_effects[0]; with_effects = None
return (getitem_6, getitem_5)""",
)

self.assertExpectedInline(
ep.graph_module.repeated_subgraph0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None
getitem = with_effects[0]; with_effects = None
permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None
addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None
relu = torch.ops.aten.relu.default(addmm); addmm = None
permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None
addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None
return (getitem, addmm_1)""",
)

recorded_list.clear()
out2 = ep.module()(x)
self.assertEqual(len(recorded_list), 4)
self.assertTrue(torch.allclose(model(x)[0], out2[0]))


if __name__ == "__main__":
run_tests()
@@ -995,6 +995,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
@expectedFailureMPS
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
@gradcheck_semantics()
def test_permute(self, device, dtype, coalesced, gradcheck):

@@ -1034,8 +1035,7 @@ class TestSparse(TestSparseBase):
else:
self.assertFalse(s_permuted.is_coalesced())

kwargs = {"eps": 1e-4} if device == "mps:0" else {}
gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_(), **kwargs)
gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
else:
# otherwise check if exception is thrown
fail_message = "transpositions between sparse and dense dimensions are not allowed"
@@ -357,9 +357,6 @@ class TestFFT(TestCase):
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
allowed_dtypes=(torch.cfloat, torch.cdouble))
@toleranceOverride({
torch.cfloat : tol(2e-4, 1.3e-6),
})
def test_reference_nd(self, device, dtype, op):
if op.ref is None:
raise unittest.SkipTest("No reference implementation")
@@ -33,7 +33,7 @@ from .graph_capture_wrappers import (
handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams, insert_backward_syncs
from .streams import assign_backward_streams
from .utils import (
call_and_expect_output_descs,
copy_fwd_metadata_to_bw_nodes,

@@ -477,8 +477,6 @@ def aot_dispatch_autograd_graph(
# After copying metadata, assign streams to gradient accumulation nodes
assign_backward_streams(fx_g)

insert_backward_syncs(fx_g)

fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.
@@ -3,7 +3,6 @@ from typing import Optional, TypeAlias
import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args
from torch._dynamo.variables.streams import get_current_stream, new_event


Node: TypeAlias = torch.fx.Node

@@ -13,14 +12,6 @@ def is_gradient_acc(node: Node) -> bool:
return node.meta.get("is_gradient_acc", False)


def is_bwd_node(node: Node) -> bool:
return node.meta.get("partitioner_tag") == "is_backward"


def get_device(node: Node) -> torch.device:
return node.meta["val"].device


def get_stream(node: Node) -> Optional[int]:
maybe_annotation = node.meta.get("custom", None)
if maybe_annotation is not None:

@@ -29,13 +20,6 @@ def get_stream(node: Node) -> Optional[int]:
return None


def get_stream_or_current_stream(node: Node) -> int:
ind = get_stream(node)
if ind is None:
ind = get_current_stream(get_device(node))
return ind


def set_stream(node: Node, ind: int) -> None:
if "custom" in node.meta:
node.meta["custom"].update({"stream": ind})

@@ -43,36 +27,6 @@ def set_stream(node: Node, ind: int) -> None:
node.meta["custom"] = {"stream": ind}


def insert_sync(
graph: torch.fx.Graph,
consumer: Node,
producer: Node,
node_to_wait_event_ind: dict[Node, int],
) -> None:
if producer not in node_to_wait_event_ind:
node_to_wait_event_ind[producer] = new_event()

with graph.inserting_after(producer):
node = graph.call_function(
torch.ops.streams.record_event.default,
(
node_to_wait_event_ind[producer],
get_stream_or_current_stream(producer),
),
)
node.meta["partitioner_tag"] = "must_be_in_backward"

with graph.inserting_before(consumer):
node = graph.call_function(
torch.ops.streams.wait_event.default,
(
node_to_wait_event_ind[producer],
get_stream_or_current_stream(consumer),
),
)
node.meta["partitioner_tag"] = "must_be_in_backward"


def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
"""Assigns backward streams to gradient accumulation nodes"""

@@ -97,18 +51,3 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
if ind is not None:
set_stream(node, ind)
break


def insert_backward_syncs(gm: torch.fx.GraphModule) -> None:
"""Inserts stream syncs for backward nodes if consumer and producer are on different streams"""
node_to_wait_event_ind = {}
for node in gm.graph.nodes:
if is_bwd_node(node):
flat_args = _get_flat_args(node, {})
cur_node_stream = get_stream(node)

for arg in flat_args:
if is_bwd_node(arg):
arg_stream = get_stream(arg)
if arg_stream != cur_node_stream and get_device(arg).type != "cpu":
insert_sync(gm.graph, node, arg, node_to_wait_event_ind)
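The removed insert_sync helper wires a record_event on the producer's stream to a wait_event on the consumer's stream at the FX-graph level. The eager-mode idiom it mirrors looks roughly like this (a sketch of the general CUDA event pattern, not the deleted graph-pass code; it assumes a CUDA device is available):

```python
import torch

producer_stream = torch.cuda.Stream()
consumer_stream = torch.cuda.Stream()
event = torch.cuda.Event()

with torch.cuda.stream(producer_stream):
    x = torch.randn(1024, device="cuda")
    y = x * 2          # work produced on the producer stream
    event.record()     # analogous to torch.ops.streams.record_event

with torch.cuda.stream(consumer_stream):
    consumer_stream.wait_event(event)  # analogous to torch.ops.streams.wait_event
    z = y + 1          # now safe to consume y on the other stream
```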
@@ -713,9 +713,6 @@ class InvokeSubgraphCache(HopSubgraphCache):
self.lazy_bwd_cache: dict[
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
] = defaultdict(dict)
self.effects_cache: dict[
str, set
] = {} # Maps identifier -> set of effect types

def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
self.dynamo_installed_submodules[fn_id].append(identifier)

@@ -754,21 +751,6 @@ class InvokeSubgraphCache(HopSubgraphCache):

return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None))

def add_effects(self, identifier: str, effects: set) -> None:
"""Store the effect types for a given invoke_subgraph identifier."""
if prev_effects := self.effects_cache.get(identifier, None):
assert effects == prev_effects, (
"Different number of effects were found for invoke_subgraph "
f"call with identifier {identifier}. \n"
f"Previously we had the following effects: {prev_effects}.\n"
f"But now we have: {effects}."
)
self.effects_cache[identifier] = effects

def get_effects(self, identifier: str) -> Optional[set]:
"""Retrieve the effect types for a given invoke_subgraph identifier."""
return self.effects_cache.get(identifier, None)


class HopDispatchSetCache:
def __init__(self) -> None:
@@ -80,7 +80,6 @@ class InvokeSubgraphHOP(HigherOrderOperator):
assert all(
isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator))
for o in operands
if o is not None
), (
f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}"
)
@@ -305,62 +304,6 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):


def get_output_metadata(subgraph, *operands):
"""
Extract metadata about the subgraph outputs WITHOUT executing the subgraph.
This avoids running side-effectful operations twice (once here, once in forward).
We analyze the graph structure statically to extract metadata.
"""
# Unwrap FunctionalizeCtxWrapper if present
if isinstance(subgraph, FunctionalizeCtxWrapper):
subgraph = subgraph.subgraph

# If not a GraphModule, fall back to execution-based metadata extraction
if not isinstance(subgraph, torch.fx.GraphModule):
return _get_output_metadata_by_execution(subgraph, *operands)

output_metadata = OutputMetadata()

# Extract output arguments from the output node
# The output node has args=(output_values,) where output_values is a tuple/list
output_node = next(reversed(subgraph.graph.find_nodes(op="output")))
output_metadata.num_fw_outs = len(output_node.args[0])

for idx, output_arg in enumerate(output_node.args[0]):
if not isinstance(output_arg, torch.fx.Node):
if isinstance(output_arg, int):
output_metadata.indexes_with_symint.add(idx)
output_metadata.indexes_with_no_grad.add(idx)
continue

# Check node metadata for type information
if output_arg.meta.get("val") is None:
# If we don't have complete metadata for all outputs, fall back to execution
# This is important for correctness (e.g., detecting SymInts) even though it
# runs side-effectful operations
return _get_output_metadata_by_execution(subgraph, *operands)

val = output_arg.meta["val"]
if isinstance(val, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
output_metadata.indexes_with_no_grad.add(idx)
elif isinstance(val, torch.Tensor):
# Check if tensor requires grad from metadata
if hasattr(val, "requires_grad") and not val.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)
else:
# Non-tensor, non-symint (shouldn't happen but be safe)
output_metadata.indexes_with_no_grad.add(idx)

return output_metadata


def _get_output_metadata_by_execution(subgraph, *operands):
"""
Fallback: Extract metadata by executing the subgraph.
This should only be used when static analysis fails.
WARNING: This will run side-effectful operations!
"""

with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
# args are functional tensors, generate some example tensors

@@ -380,15 +323,19 @@ def _get_output_metadata_by_execution(subgraph, *operands):

num_fw_outs = len(fw_outs)

# Collect the indexes of none in the output to check that the grad
# is None at the corresponding index in the backward. This check is
# performed in the autograd.Function - InvokeSubgraphAutogradOp.
# Also collect the indexes of no_grad in the output to filter out
# the grad_outs in the `backward` method.
output_metadata = OutputMetadata()
output_metadata.num_fw_outs = num_fw_outs

output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if isinstance(fw_out, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
elif not fw_out.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)

return output_metadata
@ -615,34 +562,7 @@ def _(ctx, subgraph, identifier, *operands):
|
||||
do_auto_functionalize_v2,
|
||||
)
|
||||
|
||||
# (in the functionalization metadata phase) Capture tokens before
|
||||
tokens_before = dict(ctx.mode._tokens)
|
||||
|
||||
# Check if this subgraph has effects stored in the cache
|
||||
invoke_subgraph_cache = get_invoke_subgraph_cache()
|
||||
effects = None
|
||||
if invoke_subgraph_cache:
|
||||
effects = invoke_subgraph_cache.get_effects(identifier)
|
||||
|
||||
if effects:
|
||||
assert len(effects) == 1, "Multiple effects within a subgraph NYI"
|
||||
tokens = ctx.mode._tokens
|
||||
effects = next(iter(effects))
|
||||
token_input = tokens[effects]
|
||||
|
||||
operands = (token_input, *operands)
|
||||
|
||||
def wrap_subgraph(subgraph):
|
||||
def wrapped_subgraph(token, *args):
|
||||
res = subgraph(*args)
|
||||
return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res
|
||||
|
||||
return wrapped_subgraph
|
||||
|
||||
subgraph = wrap_subgraph(subgraph)
|
||||
|
||||
unwrapped_operands = ctx.unwrap_tensors(operands)
|
||||
|
||||
hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands)
|
||||
if can_auto_functionalize(hop_instance):
|
||||
# NOTE: [auto_functionalize x invoke_subgraph caching]
|
||||
@ -667,28 +587,6 @@ def _(ctx, subgraph, identifier, *operands):
    # of invoke_subgraph ops if input aliasing/mutation is detected.
    functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
    out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)

    if effects:
        (new_token, *out) = out
        ctx.mode._tokens[effects] = new_token

    # (in the functionalization metadata phase) Capture tokens after and see if
    # there are any differences (there are new effects or the token value for an
    # effect type has changed)
    tokens_after = dict(ctx.mode._tokens)
    discovered_effects = set()
    for effect_type, token in tokens_after.items():
        if effect_type not in tokens_before or tokens_before[effect_type] is not token:
            discovered_effects.add(effect_type)

    if discovered_effects:
        assert ctx.mode._allow_token_discovery, (
            f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}."
        )
        # Store discovered effects in the cache by identifier
        if invoke_subgraph_cache:
            invoke_subgraph_cache.add_effects(identifier, discovered_effects)

    return ctx.wrap_tensors(out)
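Effect discovery above works by snapshotting the token map before and after tracing and comparing entries by identity. A small standalone sketch of that diffing step (diff_tokens is a hypothetical name, not part of the diff):

def diff_tokens(tokens_before, tokens_after):
    # A token that is new, or whose value was swapped out, marks a discovered effect.
    return {
        effect_type
        for effect_type, token in tokens_after.items()
        if effect_type not in tokens_before or tokens_before[effect_type] is not token
    }


before = {"print": object()}
after = {"print": object(), "rng": object()}  # "print" token replaced, "rng" is new
assert diff_tokens(before, after) == {"print", "rng"}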
@ -35,18 +35,6 @@ class EffectHolder:
            if namespace == "higher_order":
                return

            # These classes do not have side effects as they just store quantization
            # params, so we don't need to mark them as ordered
            skip_classes = (
                "__torch__.torch.classes.quantized.Conv2dPackedParamsBase",
                "__torch__.torch.classes.quantized.Conv3dPackedParamsBase",
                "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase",
                "__torch__.torch.classes.quantized.LinearPackedParamsBase",
                "__torch__.torch.classes.xnnpack.Conv2dOpContext",
                "__torch__.torch.classes.xnnpack.LinearOpContext",
                "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext",
            )

            opname = f"{namespace}::{opname}"
            if torch._C._get_operation_overload(opname, overload) is not None:
                # Since we call this when destroying the library, sometimes the
@ -54,9 +42,6 @@ class EffectHolder:
                schema = torch._C._get_schema(opname, overload)
                for arg in schema.arguments:
                    if isinstance(arg.type, torch.ClassType):
                        type_str = arg.type.str()  # pyrefly: ignore[missing-attribute]
                        if type_str in skip_classes:
                            continue
                        self._effect = EffectType.ORDERED
                        return
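The loop above marks an operator as having an ORDERED effect when any schema argument is a torchbind class type that is not in the skip list of pure parameter holders. A hedged sketch of that classification rule, operating on plain type strings rather than a real schema (is_ordered_effect and SKIP_CLASSES are hypothetical names):

SKIP_CLASSES = {
    "__torch__.torch.classes.quantized.LinearPackedParamsBase",
}


def is_ordered_effect(arg_type_strs):
    # An op is treated as ORDERED if any argument is a torchbind class that is
    # not a pure parameter holder from the skip list.
    return any(
        t.startswith("__torch__.torch.classes.") and t not in SKIP_CLASSES
        for t in arg_type_strs
    )


assert is_ordered_effect(["Tensor", "__torch__.torch.classes.my_ns.Logger"])
assert not is_ordered_effect(["Tensor", "__torch__.torch.classes.quantized.LinearPackedParamsBase"])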
@ -390,27 +390,31 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
  m.def("_supported_activities", []() {
    std::set<torch::profiler::impl::ActivityType> activities{
        torch::profiler::impl::ActivityType::CPU};
#if defined(USE_KINETO)
#if (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
    if (at::getNumGPUs() > 0) {
      activities.insert(torch::profiler::impl::ActivityType::CUDA);
    }
#endif // (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
#if (!defined(LIBKINETO_NOXPUPTI))
    if (at::hasXPU()) {
      activities.insert(torch::profiler::impl::ActivityType::XPU);
    }
#endif // (!defined(LIBKINETO_NOXPUPTI))
#if defined(USE_KINETO) && \
    (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
    if (at::hasMTIA()) {
      activities.insert(torch::profiler::impl::ActivityType::MTIA);
    }
    if (at::hasHPU()) {
      activities.insert(torch::profiler::impl::ActivityType::HPU);
    }
    if (at::getNumGPUs() > 0) {
      activities.insert(torch::profiler::impl::ActivityType::CUDA);
    }
#elif defined(USE_KINETO)
    if (at::hasXPU()) {
      activities.insert(torch::profiler::impl::ActivityType::XPU);
    }
    if (at::hasHPU()) {
      activities.insert(torch::profiler::impl::ActivityType::HPU);
    }
    if (at::hasMTIA()) {
      activities.insert(torch::profiler::impl::ActivityType::MTIA);
    }
    if (c10::get_privateuse1_backend() != "privateuseone") {
      activities.insert(torch::profiler::impl::ActivityType::PrivateUse1);
    }
#endif // defined(USE_KINETO)
#endif
    return activities;
  });
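The _supported_activities binding above ultimately backs the Python-side query of available profiler activities. A small usage sketch of the public API built on top of it (output varies by machine; not part of the diff):

import torch
from torch.profiler import profile, supported_activities

acts = supported_activities()  # e.g. {CPU} on a CPU-only box, {CPU, CUDA} with a GPU
with profile(activities=list(acts)) as prof:
    torch.randn(64, 64) @ torch.randn(64, 64)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))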
@ -1200,27 +1200,25 @@ get_thread_local_native_sharding_propagator_cache() {
        py::reinterpret_borrow<py::dict>(PyThreadState_GetDict());
    // We need to clean up before Python detaches from the thread if
    // the thread is being destroyed.
    if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) {
      thread_dict["__DTensor_fastpath_thread_cache_cleanup"] =
          py::capsule(new std::thread::id(this_thread_id), [](void* p) {
            auto* ptid = reinterpret_cast<std::thread::id*>(p);
            {
              std::lock_guard<std::mutex> inner_lock(
                  native_sharding_propagator_cache_cleanup_mutex);
              auto it = all_thread_caches.find(*ptid);
              if (it != all_thread_caches.end()) {
                // We need to both:
                // 1) free python objects, and
                it->second->reset();
                // 2) make sure we don't try to come back and mess with
                // a destroyed thread-local at module unload (e.g.,
                // process exit) time.
                all_thread_caches.erase(it);
              }
      thread_dict["__DTensor_fastpath_thread_cache_cleanup"] =
          py::capsule(new std::thread::id(this_thread_id), [](void* p) {
            auto* ptid = reinterpret_cast<std::thread::id*>(p);
            {
              std::lock_guard<std::mutex> inner_lock(
                  native_sharding_propagator_cache_cleanup_mutex);
              auto it = all_thread_caches.find(*ptid);
              if (it != all_thread_caches.end()) {
                // We need to both:
                // 1) free python objects, and
                it->second->reset();
                // 2) make sure we don't try to come back and mess with
                // a destroyed thread-local at module unload (e.g.,
                // process exit) time.
                all_thread_caches.erase(it);
              }
            }
            delete ptid;
          });
    }
  }
            delete ptid;
          });
  }
  return native_sharding_propagator_cache_DO_NOT_USE.value();
}
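The C++ above keeps one sharding-propagator cache per thread and registers a cleanup hook so a dying thread's entry is removed before Python detaches from it. A rough Python analogue of the same per-thread cache-with-cleanup pattern (all names are hypothetical; this is a sketch, not the diff's implementation):

import threading

_registry_lock = threading.Lock()
_all_thread_caches = {}


def get_thread_cache():
    tid = threading.get_ident()
    with _registry_lock:
        return _all_thread_caches.setdefault(tid, {})


def cleanup_thread_cache():
    # The C++ version hangs this off a capsule in the thread's state dict so it
    # runs when Python detaches from the thread; here we call it explicitly.
    tid = threading.get_ident()
    with _registry_lock:
        _all_thread_caches.pop(tid, None)


def worker():
    get_thread_cache()["hits"] = 1
    cleanup_thread_cache()


t = threading.Thread(target=worker)
t.start()
t.join()
assert not _all_thread_caches  # the worker's entry was removed at "thread exit"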
@ -547,12 +547,8 @@ def foreach_reduce(
                op=reduce_scatter_op,
            )
        else:
            # For single GPU, just copy the input to output (no actual reduce-scatter needed), and
            # account for a possible gradient_divide_factor.
            if gradient_divide_factor is not None:
                reduce_output.copy_(reduce_scatter_input / gradient_divide_factor)
            else:
                reduce_output.copy_(reduce_scatter_input)
            # For single GPU, just copy the input to output (no actual reduce-scatter needed)
            reduce_output.copy_(reduce_scatter_input)
        reduce_scatter_event = reduce_scatter_stream.record_event()
        post_reduce_stream = reduce_scatter_stream
        if all_reduce_group is not None:  # HSDP or DDP/replicate
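On a single rank, reduce-scatter has nothing to reduce or scatter, so the branch above simply copies the input into the output; the longer branch shown first additionally pre-divided by gradient_divide_factor. A tiny sketch of that degenerate case with plain tensors (the factor value is made up):

import torch

world_size = 1
gradient_divide_factor = 2.0  # hypothetical value

reduce_scatter_input = torch.arange(4.0)
reduce_output = torch.empty_like(reduce_scatter_input)

# world_size == 1: the "scattered" shard is just the (already reduced) input,
# optionally pre-divided as the longer branch did.
reduce_output.copy_(reduce_scatter_input / gradient_divide_factor)
torch.testing.assert_close(reduce_output, reduce_scatter_input / 2)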
@ -725,21 +721,20 @@ def _get_gradient_divide_factors(
    if all_reduce_group is not None:
        data_parallel_size *= all_reduce_group.size()

    if factor is None:
        factor = float(data_parallel_size)

    if not overflow_risk and not force_sum_reduction_for_comms:
        if factor is None:
        if factor == data_parallel_size:
            # Warning: NCCL ReduceOp.AVG may produce incorrect results with
            # world size 1.
            if data_parallel_size == 1:
                return None, None, ReduceOp.SUM, ReduceOp.SUM
            return None, None, ReduceOp.AVG, ReduceOp.AVG
        if reduce_scatter_group is not None and factor == reduce_scatter_group.size():
            reduce_scatter_op = ReduceOp.AVG
        else:
            reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
        return None, None, reduce_scatter_op, ReduceOp.SUM
        return None, None, reduce_scatter_op, ReduceOp.SUM

    if factor is None:
        factor = float(data_parallel_size)
    pre_factor: Optional[float]
    if overflow_risk:
        # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
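The branch above chooses between ReduceOp.AVG and a pre-multiplied SUM with weight 1/factor; the two are equivalent when factor equals the group size, and with world size 1 the code falls back to plain SUM because of the NCCL AVG caveat noted in the comment. A standalone sketch of the equivalence, simulated with plain tensors instead of NCCL collectives (not part of the diff):

import torch

world_size, factor = 4, 4.0
shards = [torch.full((2,), float(rank + 1)) for rank in range(world_size)]

avg = torch.stack(shards).mean(dim=0)  # what ReduceOp.AVG computes
premul_sum = torch.stack([s / factor for s in shards]).sum(dim=0)  # premul-sum(1/factor)

torch.testing.assert_close(avg, premul_sum)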
@ -15,105 +15,113 @@ from .graph_signature import (
)


def _get_custom_obj_for_node(node, inputs_to_lifted_custom_objs, constants):
    """Extract the custom object from a node's arguments."""
    custom_obj_node = node
    custom_obj_meta = custom_obj_node.meta["val"]  # type: ignore[union-attr]
    assert isinstance(custom_obj_meta, CustomObjArgument)

    if custom_obj_meta.fake_val:
        return custom_obj_meta.fake_val
    elif custom_obj_node.name in inputs_to_lifted_custom_objs:  # type: ignore[union-attr]
        return constants[inputs_to_lifted_custom_objs[custom_obj_node.name]]  # type: ignore[union-attr]
    else:
        raise RuntimeError(f"Unable to find custom obj for node {node}")


def _replace_with_effects_node(
    node, ep, inputs_to_lifted_custom_objs, output_tokens, input_tokens, module
def _remove_effect_tokens_from_graph_helper(
    ep, num_tokens, input_token_names, output_token_names
):
    """Replace a with_effects node with the underlying function call."""
    # Get the input nodes
    token_node, func, *node_args = node.args
    if token_node.op == "placeholder":
        input_tokens.append(token_node)
    inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs

    assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))
    output_node = None
    with_effect_nodes: list[torch.fx.Node] = []

    # Get the schema for the function
    if func is torch.ops.higher_order.call_torchbind:
        custom_obj = _get_custom_obj_for_node(
            node_args[0], inputs_to_lifted_custom_objs, ep.constants
        )
        schema = _get_schema(func, [custom_obj] + node_args[1:])
    else:
        schema = _get_schema(func, node_args)
    # Output node needs to check its args against output_token_names (collected from output_spec)
    # Therefore, we only need to find the top-level output node
    output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output")))
    for module in ep.graph_module.modules():
        if not isinstance(module, torch.fx.GraphModule):
            continue

    # Create the replacement node
    with module.graph.inserting_before(node):
        new_node = module.graph.call_function(func, tuple(node_args), node.kwargs)
        for node in module.graph.nodes:
            if not (node.op == "call_function" and node.target is with_effects):
                continue

    # Update getitem nodes that extract outputs from with_effects
    for user in list(node.users.keys()):
        assert user.target is operator.getitem
        # getitem(with_effects, 0) is the token node
        if user.args[1] == 0:
            for user_user in list(user.users.keys()):
                if user_user.op == "output":
                    output_tokens.append(user)
            with_effect_nodes.append(node)

    # Fix up the getitem nodes based on return count
    if len(schema.returns) == 1:
        # Single return: replace getitem(with_effects, 1) with the node itself
        for user in list(node.users.keys()):
            if user.args[1] == 1:
    # Remove tokens from outputs
    assert output_node is not None
    output_args = output_node.args[0]
    assert len(output_args) >= num_tokens
    out_token_nodes = output_args[:num_tokens]
    output_node.args = (tuple(output_args[num_tokens:]),)
    for out_token in out_token_nodes:
        assert out_token.name in output_token_names
        out_token.users.clear()
        ep.graph.erase_node(out_token)

    # Replace with_effects(token, func, args) with just func(args)
    for node in reversed(with_effect_nodes):
        func = node.args[1]
        assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))

        if func is torch.ops.higher_order.call_torchbind:
            custom_obj_meta = node.args[2].meta["val"]  # type: ignore[union-attr]
            assert isinstance(custom_obj_meta, CustomObjArgument)
            if custom_obj_meta.fake_val:
                custom_obj = custom_obj_meta.fake_val
            elif node.args[2].name in inputs_to_lifted_custom_objs:  # type: ignore[union-attr]
                custom_obj = ep.constants[
                    inputs_to_lifted_custom_objs[node.args[2].name]  # type: ignore[union-attr]
                ]
            else:
                raise RuntimeError(f"Unable to find custom obj for node {node}")
            schema = _get_schema(func, (custom_obj,) + node.args[3:])
        else:
            schema = _get_schema(func, node.args[2:])

        with ep.graph.inserting_before(node):
            new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)
            for k, v in node.meta.items():
                new_node.meta[k] = v
                if k == "unbacked_bindings":
                    # Remove the extra layer for effect token
                    old_bindings = new_node.meta[k]
                    new_bindings = {
                        k: path[1:] if path else path for k, path in old_bindings.items()
                    }
                    new_node.meta[k] = new_bindings

            node.replace_all_uses_with(new_node)

    # Update user getitem nodes
    for user in list(new_node.users.keys()):
        assert user.target is operator.getitem
        # getitem(with_effects, 0) == token
        if user.args[1] == 0:
            ep.graph.erase_node(user)

        if len(schema.returns) == 1:
            # If the function has 1 return then it will just directly return the
            # result -- we don't need a getitem. So we can replace all the
            # getitem(with_effects, 1) with just the node itself.
            for user in list(new_node.users.keys()):
                assert user.args[1] == 1
                user.replace_all_uses_with(new_node)
            new_node.meta["val"] = node.meta["val"][1]
        elif len(schema.returns) > 1:
            # Multiple returns: shift getitem indices down by 1
            for user in list(node.users.keys()):
                if user.args[1] >= 1:
                    user.args = (new_node, user.args[1] - 1)
            new_node.meta["val"] = node.meta["val"][1:]
        else:
            # No returns
            assert len(schema.returns) == 0
            assert len(new_node.users) == 0
            new_node.meta["val"] = None

    # Copy metadata from old node to new node
    for k, v in node.meta.items():
        new_node.meta[k] = v
        if k == "unbacked_bindings":
            # Remove the extra layer for effect token
            old_bindings = new_node.meta[k]
            new_bindings = {
                k: path[1:] if path else path for k, path in old_bindings.items()
            }
            new_node.meta[k] = new_bindings
        new_node.meta["val"] = node.meta["val"][1]
    elif len(schema.returns) > 1:
        # If the function has more than 1 return then since we got rid of
        # the 1st return value (the token), we need to bump all the other
        # getitem calls by 1 down
        for user in list(new_node.users.keys()):
            assert user.args[1] >= 1
            user.args = (user.args[0], user.args[1] - 1)

        new_node.meta["val"] = node.meta["val"][1:]
    else:
        assert len(schema.returns) == 0
        assert len(new_node.users) == 0
        new_node.meta["val"] = None


def _replace_invoke_subgraph_node(node, module, output_tokens, input_tokens):
    """Replace an invoke_subgraph node to remove the token argument."""
    assert node.args[0].op == "get_attr"
    submod = getattr(module, node.args[0].target)
    if not submod.meta.get("has_with_effects", False):
        return
        ep.graph.erase_node(node)

    # Remove token from inputs
    subgraph, identifier, token, *operands = node.args
    node.args = (subgraph, identifier, *operands)
    if token.op == "placeholder":
        input_tokens.append(token)
    # Remove tokens from inputs
    placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
    assert len(placeholders) >= num_tokens
    inp_token_nodes = placeholders[:num_tokens]
    for inp_token in inp_token_nodes:
        assert inp_token.name in input_token_names
        ep.graph.erase_node(inp_token)

    # Update getitem nodes to account for removed token output
    for user in list(node.users.keys()):
        if user.args[1] >= 1:
            user.args = (node, user.args[1] - 1)
        elif user.args[1] == 0:
            for user_user in list(user.users.keys()):
                if user_user.op == "output":
                    output_tokens.append(user)
    ep.graph.eliminate_dead_code()


def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
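When with_effects(token, func, *args) is replaced by func(*args), the leading token disappears from the outputs, so every surviving getitem index is shifted down by one, as the code above does. A minimal sketch of that index shift (shift_getitem_indices is a hypothetical helper, not part of the diff):

def shift_getitem_indices(getitem_indices):
    # getitem(node, 0) was the dropped token; every remaining index moves down by one.
    return [i - 1 for i in getitem_indices if i >= 1]


assert shift_getitem_indices([0, 1, 2]) == [0, 1]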
@ -124,65 +132,6 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:

    This function does an inplace modification on the given ExportedProgram.
    """
    print("before", ep)
    inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs

    # mark submodules with effects as having effects. This will be used in the following pass to remove effects from subgraphs
    for _, module in ep.graph_module.named_modules():
        if not isinstance(module, torch.fx.GraphModule):
            continue

        with_effect_nodes = [
            node for node in module.graph.nodes if node.target is with_effects
        ]
        if len(with_effect_nodes) > 0:
            module.meta["has_with_effects"] = True

    # Process each module with the replace hook to ensure graph signature is updated
    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
        for _, module in ep.graph_module.named_modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            input_tokens = []
            output_tokens = []

            # Process with_effects and invoke_subgraph nodes
            for node in module.graph.nodes:
                if node.target is with_effects:
                    _replace_with_effects_node(
                        node,
                        ep,
                        inputs_to_lifted_custom_objs,
                        output_tokens,
                        input_tokens,
                        module,
                    )
                elif node.target is torch.ops.higher_order.invoke_subgraph:
                    _replace_invoke_subgraph_node(
                        node, module, output_tokens, input_tokens
                    )

            # Remove tokens from the output node
            if len(output_tokens) > 0:
                output_node = next(reversed(module.graph.find_nodes(op="output")))
                output_args = output_node.args[0]
                assert len(output_args) >= len(output_tokens), (
                    f"{output_args} output arguments found\n"
                    f"{output_tokens} output tokens found\n"
                    f"{module.graph}"
                )
                output_node.args = (tuple(output_args[len(output_tokens) :]),)

            module.graph.eliminate_dead_code()

            # Remove tokens from the input placeholders
            for node in module.graph.nodes:
                if node.op == "placeholder" and node in input_tokens:
                    module.graph.erase_node(node)

            module.recompile()

    num_tokens: int = 0
    input_token_names: list[str] = []
    new_input_specs: list[InputSpec] = []
@ -210,5 +159,9 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:

    assert num_tokens == num_out_tokens

    print("after", ep)
    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
        _remove_effect_tokens_from_graph_helper(
            ep, num_tokens, input_token_names, output_token_names
        )

    return ep
@ -748,23 +748,11 @@ def _unlift_exported_program_lifted_states(
) -> torch.fx.GraphModule:
    check_guards = check_guards and _ok_to_generate_guards_fn()

    source_node_dict = {
        node.name: node for node in ep.graph.nodes if node.op != "placeholder"
    }
    # placeholder node name might change after deepcopy
    placeholder_source_node_dict = {
        node.target: node for node in ep.graph.nodes if node.op == "placeholder"
    }

    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    new_gm.meta.update(ep.graph_module.meta)
    ep = copy.copy(ep)
    ep._graph_module = new_gm

    # TODO T206340015
    if ep.verifiers[0].dialect != "TRAINING":
        ep = _remove_effect_tokens(ep)

    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
    forward_arg_names = (
        sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
@ -798,13 +786,19 @@ def _unlift_exported_program_lifted_states(
        for out_spec in ep.graph_signature.output_specs
    ]

    source_node_dict = {
        node.name: node for node in ep.graph.nodes if node.op != "placeholder"
    }
    # placeholder node name might change after deepcopy
    placeholder_source_node_dict = {
        node.target: node for node in ep.graph.nodes if node.op == "placeholder"
    }
    for node in new_gm.graph.nodes:
        source_node = None
        if node.op == "placeholder":
            source_node = placeholder_source_node_dict.get(node.target)
        else:
            if node.name in source_node_dict:
                source_node = source_node_dict.get(node.name)
            source_node = source_node_dict.get(node.name)
        node.meta["from_node"] = [
            NodeSource(
                source_node,
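The lookup tables above index the original graph's nodes by name, and placeholders by target, since placeholder names can change after deepcopy, so provenance can be attached to each node of the copied graph. A rough standalone sketch of the same idea, using a plain dict instead of the NodeSource records used in the diff:

import copy

import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


gm = fx.symbolic_trace(M())
new_gm = fx.GraphModule(gm, copy.deepcopy(gm.graph))

source_by_name = {n.name: n for n in gm.graph.nodes if n.op != "placeholder"}
placeholder_by_target = {n.target: n for n in gm.graph.nodes if n.op == "placeholder"}

for node in new_gm.graph.nodes:
    if node.op == "placeholder":
        source = placeholder_by_target.get(node.target)
    else:
        source = source_by_name.get(node.name)
    # Plain-dict provenance instead of the NodeSource records used in the diff.
    node.meta["from_node"] = [{"source_name": getattr(source, "name", None)}]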
@ -753,9 +753,7 @@ class Node(_NodeBase):
            # between eager and compiled execution, regardless of generator usage
            return True

            from torch._higher_order_ops.effects import has_effects

            return self.target in _side_effectful_functions or has_effects(self.target)
            return self.target in _side_effectful_functions

        # Check if an impure module.
        if self.op == "call_module":