Compare commits

2 Commits

Author                      SHA1         Message                                                              Date
Huy Do <huydhn@gmail.com>   d79ccd0bba   More cleanup                                                         2025-11-18 23:13:58 -08:00
Huy Do <huydhn@gmail.com>   d6bb3ad8b9   [vLLM] Update xformers==0.0.33.post1 and remove flashinfer-python    2025-11-18 20:46:53 -08:00
38 changed files with 276 additions and 757 deletions

View File

@ -84,7 +84,6 @@ class VllmTestRunner(BaseRunner):
self.VLLM_TEST_WHLS_REGEX = [
"xformers/*.whl",
"vllm/vllm*.whl",
"flashinfer-python/flashinfer*.whl",
]
def prepare(self):
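
For context, the entries in VLLM_TEST_WHLS_REGEX are glob patterns that tell the test runner which prebuilt wheels to pick up; with the flashinfer entry gone, only the xformers and vLLM wheels are collected. A minimal sketch of that collection step in plain Python, with a hypothetical wheels directory standing in for the runner's real workspace:

from pathlib import Path

# Hypothetical wheels root; the real runner resolves this from its own workspace layout.
wheels_root = Path("/workspace/wheels")
patterns = ["xformers/*.whl", "vllm/vllm*.whl"]

found = [whl for pattern in patterns for whl in wheels_root.glob(pattern)]
for whl in found:
    print("would install:", whl.name)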

View File

@ -1 +1 @@
617079d944b0e72632311c30ae2bbdf1168b901e
2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc

View File

@ -1,4 +1,4 @@
ARG CUDA_VERSION=12.8.1
ARG CUDA_VERSION=12.9.1
ARG PYTHON_VERSION=3.12
# BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine,
@ -124,7 +124,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
git clone https://github.com/facebookresearch/xformers.git
pushd xformers
git checkout v0.0.32.post2
git checkout v0.0.33.post1
git submodule update --init --recursive
python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose
popd
@ -256,7 +256,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
# Install build and runtime dependencies, this is needed for flashinfer install
# Install build and runtime dependencies
COPY requirements/build.txt requirements/build.txt
COPY use_existing_torch.py use_existing_torch.py
RUN python3 use_existing_torch.py
@ -294,33 +294,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system /wheels/xformers/*.whl --verbose
# Build FlashInfer from source
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
RUN --mount=type=cache,target=/root/.cache/uv \
git clone --depth 1 --recursive --shallow-submodules \
--branch ${FLASHINFER_GIT_REF} \
${FLASHINFER_GIT_REPO} flashinfer \
&& echo "Building FlashInfer with AOT for arches: ${torch_cuda_arch_list}" \
&& cd flashinfer \
&& python3 -m flashinfer.aot \
&& python3 -m build --no-isolation --wheel --outdir ../wheels/flashinfer \
&& cd .. \
&& rm -rf flashinfer
# Install FlashInfer
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system wheels/flashinfer/*.whl --verbose
# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm\|^flashinfer' > build_summary.txt
RUN pip freeze | grep -E 'torch|xformers|vllm'
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm' > build_summary.txt
################### VLLM INSTALLED IMAGE ####################
@ -331,4 +307,3 @@ FROM scratch as export-wheels
COPY --from=base /workspace/xformers-dist /wheels/xformers
COPY --from=build /workspace/vllm-dist /wheels/vllm
COPY --from=vllm-base /workspace/build_summary.txt /wheels/build_summary.txt
COPY --from=vllm-base /workspace/wheels/flashinfer /wheels/flashinfer-python
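
Taken together with the commit message, these hunks move the image to CUDA 12.9.1 and xformers v0.0.33.post1 and drop the FlashInfer source build plus its exported wheel. A small sanity-check sketch one could run inside the resulting image (assumptions: the packages are importable there, and "flashinfer-python" is the distribution name that should now be absent):

from importlib import metadata

# Packages the image is still expected to ship.
for pkg in ("torch", "xformers", "vllm"):
    try:
        print(f"{pkg}=={metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg} is not installed")

# flashinfer-python should be gone once its build stage is removed.
try:
    metadata.version("flashinfer-python")
    print("unexpected: flashinfer-python is still installed")
except metadata.PackageNotFoundError:
    print("flashinfer-python not installed, as expected")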

View File

@ -88,7 +88,7 @@ repackage_wheel() {
${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1
pushd externals/vllm/wheels
for package in xformers flashinfer-python vllm; do
for package in xformers vllm; do
repackage_wheel $package
done
popd

View File

@ -117,10 +117,6 @@ Tensor& relu_mps_(Tensor& self) {
TORCH_IMPL_FUNC(log_softmax_mps_out)
(const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"log_softmax for complex is not supported for MPS");
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "log_softmax for bool is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
@ -164,10 +160,6 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)
TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
(const Tensor& grad_output, const Tensor& output, int64_t dim, ScalarType input_dtype, const Tensor& out) {
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(grad_output.scalar_type()),
"log_softmax for complex is not supported for MPS");
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kBool, "log_softmax for bool is not supported for MPS");
using namespace mps;
using CachedGraph = MPSUnaryGradCachedGraph;
@ -208,7 +200,6 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
}
std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_mps(const Tensor& self, Tensor& output, Tensor& buffer) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// NOTE: buffer is only used by CPU dispatch, we just ignore it here
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
@ -715,7 +706,6 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor
if (output.numel() == 0)
return;
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// this can't pass anyway because a 0-dimensional tensor has "size" 1, which
// can't be evenly halved, but give a nicer error message here.
TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");
@ -829,7 +819,6 @@ TORCH_IMPL_FUNC(softplus_out_mps)
(const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Not implemented for long");
// Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
// \log(1 + \exp(\beta * x))` element-wise.
// For numerical stability the implementation reverts to the linear function
@ -980,8 +969,6 @@ TORCH_IMPL_FUNC(mish_out_mps)
(const Tensor& self, const Tensor& result) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS");
if (result.numel() == 0)
return;
@ -1030,8 +1017,6 @@ TORCH_IMPL_FUNC(mish_out_mps)
Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
using namespace mps;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS");
Tensor grad_input = at::empty_like(self, self.suggest_memory_format());
if (grad_input.numel() == 0)
@ -1221,7 +1206,6 @@ TORCH_IMPL_FUNC(silu_out_mps)(const Tensor& self, const Tensor& result) {
using CachedGraph = MPSUnaryCachedGraph;
TORCH_CHECK(self.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64");
// Empty output
if (result.numel() == 0)

View File

@ -80,11 +80,6 @@ static void grid_sampler_2d_mps_impl(Tensor& output,
MPSGraphTensor* outputTensor_ = nil;
};
// Crashes with
// MPSGraphUtilities.mm:97:0: error: 'mps.sample_grid' op operand #0 must be tensor of mps native type values, but got
// 'tensor<2x3x5x20xcomplex<f32>>'
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"grid_sampler_2d is not supported for complex on MPS");
@autoreleasepool {
std::string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" +
std::to_string(interpolation_mode) + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);

View File

@ -240,7 +240,7 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
bool check_errors) {
using namespace mps;
TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat,
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
"linalg.lu_factor(): MPS doesn't support complex types.");
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");
@ -364,7 +364,8 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
const Tensor& info) {
using namespace mps;
TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS only supports floats.");
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
"linalg.lu_factor(): MPS doesn't support complex types.");
Tensor A_t, B_t;
// If 'left' is false, reinterpret the problem so that Ax = B becomes A^T ⋅ (x^T) = B^T
// Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output
@ -1057,8 +1058,7 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
using namespace mps;
checkInputsSolver(A, B, left, "linalg.solve_triangular");
TORCH_CHECK(A.scalar_type() == kFloat && B.scalar_type() == kFloat,
"linalg.solve.triangular(); Only float is supported!");
TORCH_CHECK(!A.is_complex() && !B.is_complex(), "linalg.solve.triangular(); Not supported for complex yet!");
Tensor A_t, B_t;
std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr);
at::native::resize_output(out, B_t.sizes());

View File

@ -416,8 +416,6 @@ static void nllnd_loss_forward_impl(Tensor& output,
int64_t reduction,
int64_t ignore_index,
bool is2D) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(output.scalar_type()),
"nlld_loss for complex is not supported for MPS");
std::vector<long long> reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end());
reshapedTarget.push_back(1);
@ -826,9 +824,6 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) {
std::string op_name = __func__;
using namespace mps;
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"huber_loss for complex is not supported for MPS");
TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes")
TORCH_CHECK(output.is_mps());

View File

@ -597,7 +597,6 @@ static void avg_pool2d_template(const Tensor& input,
bool count_include_pad,
const std::optional<int64_t> divisor_override,
const std::string& op_name) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), "Not implemented for complex");
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
const bool is_backward_pass = grad_output.defined();
const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0;
@ -916,8 +915,6 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"Max pooling for complex is not supported for MPS");
bool use_graph = use_graph_for_max_pool2d(kernel_size, stride);
if (use_graph) {
auto indices_memory_format = indices.suggest_memory_format();
@ -970,8 +967,6 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)
bool ceil_mode,
const Tensor& indices,
const Tensor& grad_input) {
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()),
"Max pooling for complex is not supported for MPS");
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor

View File

@ -269,22 +269,17 @@ static void reduction_out_mps(const Tensor& input_t,
name:nil];
castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil];
} else if (reduction_type == MPSReductionType::NANSUM) {
// Integral types cannot contain NaN, so just do regular sum
if (([castInputTensor dataType] & MPSDataTypeFloatBit) == 0) {
castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil];
} else {
// Create a 0 tensor of the same shape as inputTensor
auto zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType];
// Find NaNs
auto nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil];
// Replace NaNs with 0
auto nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:zeros
falsePredicateTensor:castInputTensor
name:nil];
// Sum
castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil];
}
// Create a 0 tensor of the same shape as inputTensor
MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType];
// Find NaNs
MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil];
// Replace NaNs with 0
MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:zeros
falsePredicateTensor:castInputTensor
name:nil];
// Sum
castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil];
}
MPSGraphTensor* outputTensor = castOutputTensor;
@ -447,7 +442,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
const std::optional<Scalar>& correction,
bool keepdim,
StdVarType stdVarType) {
TORCH_CHECK_NOT_IMPLEMENTED(input_t.scalar_type() != kLong, "Not implemented for MPS");
using CachedGraph = MPSUnaryCachedGraph;
IntArrayRef input_shape = input_t.sizes();
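
Both variants of the NANSUM branch above compute a NaN-ignoring sum by masking NaNs to zero before a regular reduction (the longer variant merely skips the masking for integral inputs, which cannot hold NaN). In eager PyTorch terms, the graph built here is equivalent to:

import torch

x = torch.tensor([1.0, float("nan"), 2.0])
# Replace NaNs with 0, then take an ordinary sum -- what the zeros/isNaN/select/reductionSum graph does.
masked = torch.where(torch.isnan(x), torch.zeros_like(x), x)
print(masked.sum())     # tensor(3.)
print(torch.nansum(x))  # tensor(3.), the operator this reduction implements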

View File

@ -39,7 +39,6 @@ static void get_shapes(MPSShape* input_shape_readonly,
TORCH_IMPL_FUNC(softmax_mps_out)
(const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) {
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS");
TORCH_CHECK(c10::isFloatingType(input_.scalar_type()), "softmax only supported for floating types");
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
if (input_.numel() == 0) {

View File

@ -18,10 +18,6 @@ static Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tens
MPSStream* stream = getCurrentMPSStream();
bool has_weights = weights.defined();
// Crashes with
// MPSGraphUtilities.mm:190:0: error: 'mps.scatter' op operand #2 must be tensor of int values, but got 'tensor<5xi1>'
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "bincount is not supported for Bool");
@autoreleasepool {
std::string key = "bincount_mps_impl" + getTensorsStringKey({self, weights});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {

View File

@ -4617,7 +4617,7 @@
dispatch:
CompositeExplicitAutograd: permute
MPS: permute_mps
SparseCPU, SparseCUDA, SparseMPS: permute_sparse_coo
SparseCPU, SparseCUDA: permute_sparse_coo
tags: core
- func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)

View File

@ -48,7 +48,7 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void StorageImpl::incref_pyobject() const noexcept {
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
@ -59,12 +59,12 @@ void StorageImpl::incref_pyobject() const noexcept {
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const noexcept {
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const noexcept {
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;

View File

@ -105,11 +105,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void incref_pyobject() const noexcept override final;
void incref_pyobject() const override final;
void decref_pyobject() const noexcept override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const noexcept override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive

View File

@ -988,7 +988,7 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}
void TensorImpl::incref_pyobject() const noexcept {
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
@ -999,12 +999,12 @@ void TensorImpl::incref_pyobject() const noexcept {
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const noexcept {
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const noexcept {
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;

View File

@ -2178,11 +2178,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void incref_pyobject() const noexcept override final;
void incref_pyobject() const override final;
void decref_pyobject() const noexcept override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const noexcept override final;
bool try_incref_pyobject() const override final;
private:
// See NOTE [std::optional operator usage in CUDA]

View File

@ -68,10 +68,6 @@ inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
}
inline bool is_uniquely_owned(uint64_t combined_refcount) {
return (combined_refcount & ~detail::kHasPyObject) == detail::kUniqueRef;
}
// The only requirement for refcount increment is that it happens-before
// decrement, so no additional memory ordering is needed.
inline uint64_t atomic_combined_refcount_increment(
@ -291,9 +287,9 @@ class C10_API intrusive_ptr_target {
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const noexcept {}
virtual void decref_pyobject() const noexcept {}
virtual bool try_incref_pyobject() const noexcept {
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}
@ -367,7 +363,7 @@ class intrusive_ptr final {
template <typename, typename...>
friend class pybind11::class_;
void retain_() noexcept {
void retain_() {
if (target_ != NullType::singleton()) {
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);
@ -381,7 +377,9 @@ class intrusive_ptr final {
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) {
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
@ -394,60 +392,51 @@ class intrusive_ptr final {
void reset_() noexcept {
if (target_ != NullType::singleton()) {
reset_not_null_(target_);
}
}
// C10_NOINLINE to keep binary size a bit smaller. We pass TTarget* here
// to avoid an extra pointer dereference in the call from reset_().
C10_NOINLINE static void reset_not_null_(TTarget* target) noexcept {
if (detail::is_uniquely_owned(
target->combined_refcount_.load(std::memory_order_acquire))) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target deletion
// call (e.g. calling use_count()) without a data race.
target->combined_refcount_.store(0, std::memory_order_relaxed);
delete target;
return;
}
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
if (detail::weakcount(combined_refcount) == 1) {
delete target;
if (is_uniquely_owned()) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
// call (e.g. calling use_count()) without a data race.
target_->combined_refcount_.store(0, std::memory_order_relaxed);
delete target_;
return;
}
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
release_resources_and_decrement_weakrefs_(target);
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (has_pyobject && new_refcount == 1) {
target->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
C10_NOINLINE static void release_resources_and_decrement_weakrefs_(
TTarget* target) noexcept {
// justification for const_cast: release_resources is basically a
// destructor and a destructor always mutates the object, even for
// const objects.
const_cast<std::remove_const_t<TTarget>*>(target)->release_resources();
if (detail::atomic_weakcount_decrement(target->combined_refcount_) == 0) {
delete target;
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
if (!should_delete) {
// justification for const_cast: release_resources is basically a
// destructor and a destructor always mutates the object, even for
// const objects.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<std::remove_const_t<TTarget>*>(target_)
->release_resources();
should_delete = detail::atomic_weakcount_decrement(
target_->combined_refcount_) == 0;
}
if (should_delete) {
delete target_;
}
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
@ -618,8 +607,9 @@ class intrusive_ptr final {
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
return detail::is_uniquely_owned(
target_->combined_refcount_.load(std::memory_order_acquire));
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}
/**
@ -1184,7 +1174,9 @@ inline void incref(intrusive_ptr_target* self) {
self->combined_refcount_, detail::kReferenceCountOne);
#ifndef C10_MOBILE
if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) {
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
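
The intrusive_ptr changes above all revolve around a single 64-bit combined_refcount_ that packs the strong count, the weak count, and a has-PyObject flag, and they fire incref_pyobject()/decref_pyobject() exactly on the 1 -> 2 and 2 -> 1 strong-count transitions. A rough Python model of that bookkeeping; the bit layout (strong count in the low 32 bits, one weak-count unit at bit 32, PyObject flag at bit 63) is an assumption made only for illustration:

K_REF_ONE = 1                          # one strong reference (assumed encoding)
K_WEAK_ONE = 1 << 32                   # one weak reference
K_HAS_PYOBJECT = 1 << 63               # "this object has a PyObject wrapper" flag
K_UNIQUE_REF = K_REF_ONE | K_WEAK_ONE  # exactly one strong and one weak reference

def refcount(combined):
    return combined & 0xFFFFFFFF

def has_pyobject(combined):
    return (combined & K_HAS_PYOBJECT) != 0

def is_uniquely_owned(combined):
    # Ignore the PyObject flag, then check for exactly one strong and one weak reference.
    return (combined & ~K_HAS_PYOBJECT) == K_UNIQUE_REF

combined = K_UNIQUE_REF | K_HAS_PYOBJECT  # freshly created object that already has a PyObject
combined += K_REF_ONE                     # incref: strong count goes 1 -> 2
if has_pyobject(combined) and refcount(combined) == 2:
    print("incref_pyobject(): keep the PyObject alive")          # mirrors the check in retain_()

combined -= K_REF_ONE                     # decref: strong count goes 2 -> 1
if has_pyobject(combined) and refcount(combined) == 1:
    print("decref_pyobject(): only the PyObject still owns it")  # mirrors the check in reset_()

print(is_uniquely_owned(combined))        # True: one strong ref, one weak ref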

View File

@ -428,14 +428,7 @@ class TestFullyShardCommunication(FSDPTest):
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1571
def test_set_reduce_scatter_divide_factor(self):
self.run_subtests(
{
"divide_factor": [self.world_size * 2, self.world_size],
"mesh_shape": [
(self.world_size,),
(self.world_size // 2, 2),
(self.world_size, 1),
],
},
{"divide_factor": [self.world_size * 2, self.world_size]},
self._test_set_reduce_scatter_divide_factor,
)
self.run_subtests(
@ -443,31 +436,18 @@ class TestFullyShardCommunication(FSDPTest):
self._test_set_reduce_scatter_divide_factor_mixed_prevision,
)
def _test_set_reduce_scatter_divide_factor(
self, divide_factor: float, mesh_shape: tuple[int] | tuple[int, int]
):
def _test_set_reduce_scatter_divide_factor(self, divide_factor: float):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0, weight_tying=False)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
mesh_dim_names = ("outer",) if len(mesh_shape) == 1 else ("outer", "inner")
mesh = init_device_mesh(
device_type.type, mesh_shape, mesh_dim_names=mesh_dim_names
)
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module, reshard_after_forward=False, mesh=mesh)
model = fully_shard(model, reshard_after_forward=False, mesh=mesh)
fully_shard(module, reshard_after_forward=False)
model = fully_shard(model, reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
model.set_gradient_divide_factor(divide_factor)
# Get ref_model params which should have the specific division factor applied
block_params = set()
for ref_mod in ref_model.modules():
if isinstance(ref_mod, TransformerBlock):
block_params.update(ref_mod.parameters())
non_block_params = set(ref_model.parameters()) - block_params
model.set_reduce_scatter_divide_factor(divide_factor)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
@ -476,18 +456,16 @@ class TestFullyShardCommunication(FSDPTest):
ref_loss = ref_model(inp).sum()
ref_loss.backward()
for param in ref_model.parameters():
factor = divide_factor if param in non_block_params else self.world_size
param.grad.mul_(1.0 / factor)
param.grad.mul_(1.0 / divide_factor)
dist.all_reduce(param.grad)
loss = model(inp).sum()
loss.backward()
ref_optim.step()
optim.step()
self.assertEqual(ref_loss, loss)
# Check parity before calling zero_grad so that grads are also checked
check_sharded_parity(self, ref_model, model)
ref_optim.zero_grad()
optim.zero_grad()
self.assertEqual(ref_loss, loss)
check_sharded_parity(self, ref_model, model)
def _test_set_reduce_scatter_divide_factor_mixed_prevision(
self, divide_factor: float
@ -506,7 +484,7 @@ class TestFullyShardCommunication(FSDPTest):
fully_shard(mlp, mp_policy=mp_policy)
model = fully_shard(model, mp_policy=mp_policy)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
model.set_gradient_divide_factor(divide_factor)
model.set_reduce_scatter_divide_factor(divide_factor)
torch.manual_seed(42 + self.rank)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
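
The divide-factor test above relies on a simple identity: dividing each local gradient by the factor and then summing across ranks equals a plain average whenever the factor matches the data-parallel world size. A toy single-process illustration of that arithmetic (no distributed setup, just the equivalence):

import torch

world_size = 4
divide_factor = float(world_size)
# Pretend these are the per-rank gradients of one parameter.
per_rank_grads = [torch.full((2,), float(rank + 1)) for rank in range(world_size)]

# "Pre-divide then sum", as the reference model in the test does before its all-reduce.
pre_divided = sum(g / divide_factor for g in per_rank_grads)
# Plain average, as ReduceOp.AVG would produce.
averaged = sum(per_rank_grads) / world_size

print(torch.allclose(pre_divided, averaged))  # True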

View File

@ -34,11 +34,7 @@ from torch.distributed.tensor._ops.utils import (
register_op_strategy,
replicate_op_strategy,
)
from torch.distributed.tensor.debug import (
_clear_fast_path_sharding_prop_cache,
_clear_python_sharding_prop_cache,
CommDebugMode,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@ -483,8 +479,7 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
del propagator.op_to_schema_info[op_overload]
else:
propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
_clear_fast_path_sharding_prop_cache()
_clear_python_sharding_prop_cache()
propagator.propagate_op_sharding.cache.cache_clear()
def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool:
@ -650,28 +645,6 @@ class TestStrategyHashing(DTensorTestBase):
self.assertEqual(out1.full_tensor(), out2.full_tensor())
class TestStrategyOperation(DTensorTestBase):
@property
def world_size(self):
return 2
@with_comms
def test_cache_clean(self):
mesh = self.build_device_mesh()
test_op = torch.ops.mylib.numpy_sin
x = torch.randn(2, device=self.device_type)
y = torch.randn(2, device=self.device_type)
x_dt = distribute_tensor(x, mesh, [Shard(0)])
y_dt = distribute_tensor(y, mesh, [Shard(0)])
with op_strategy_context(test_op.default, replicate_op_strategy):
self._test_op_on_dtensor(test_op, x_dt, y_dt)
with self.assertRaisesRegex(
NotImplementedError,
f"Operator {test_op.default} does not have a sharding strategy registered",
):
self._test_op_on_dtensor(test_op, x_dt, y_dt)
DistTensorReplicateStrategyRegistrationTestWithLocalTensor = (
create_local_tensor_test_class(
DistTensorReplicateStrategyRegistrationTest,

View File

@ -585,10 +585,6 @@ class GraphModule(torch.nn.Module):
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
# No stacktrace found for following nodes
record_event_default = torch.ops.streams.record_event.default(2, 1); record_event_default = None
wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)

View File

@ -1405,7 +1405,7 @@ class TestConverter(TestCase):
)
# qnnpack not supported on s390x
@xfailIfS390X
def test_ts2ep_convert_quantized_model1(self):
def test_ts2ep_convert_quantized_model(self):
class Standalone(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -640,13 +640,16 @@ class TestPasses(TestCase):
self.assertExpectedInline(
without_token_ep.graph_module.code.strip(),
"""\
def forward(self, obj_attr, x):
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None
getitem_1 = takes_foo_tuple_return_default[0]
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
def forward(self, token, obj_attr, x):
with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None
getitem = with_effects[0]
getitem_1 = with_effects[1]
getitem_2 = with_effects[2]; with_effects = None
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None
return (takes_foo_default,)""", # noqa: B950
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None
getitem_3 = with_effects_1[0]
getitem_4 = with_effects_1[1]; with_effects_1 = None
return (getitem_3, getitem_4)""", # noqa: B950
)
def test_fakify_script_objects(self):

View File

@ -461,9 +461,9 @@ def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950
)
self.assertExpectedInline(
@ -1087,12 +1087,10 @@ def forward(self, token, tq, x):
str(ep.graph_module.graph).strip(),
"""\
graph():
%token : [num_users=1] = placeholder[target=token]
%tq : [num_users=2] = placeholder[target=tq]
%x : [num_users=1] = placeholder[target=x]
%with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
return (getitem, tq)""", # noqa: B950
%queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {})
return (tq,)""", # noqa: B950
)
def test_deepcopy(self):

View File

@ -870,100 +870,6 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
finally:
handle.destroy()
@unittest.skipIf(not TEST_CUDA, "triton")
def test_export_invoke_subgraph(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
recorded_list = []
@torch.library.custom_op("mylib::record_memory", mutates_args=())
def record_memory(prefix: str, module_name: str) -> None:
torch.cuda.synchronize()
mem_alloc = torch.cuda.memory_allocated() / 1024**2
mem_reserved = torch.cuda.memory_reserved() / 1024**2
memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB"
recorded_list.append(memory_str)
@record_memory.register_fake
def record_memory_fake(prefix, module_name):
return
record_memory.register_effect(_EffectType.ORDERED)
class N(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(1024, 1024)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(1024, 1024)
@torch.compiler.nested_compile_region
def forward(self, x):
torch.ops.mylib.record_memory("forward", "N")
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod_list = torch.nn.ModuleList(N() for _ in range(3))
def forward(self, x):
for m in self.mod_list:
x = m(x)
torch.ops.mylib.record_memory("forward", "N")
return (x,)
model = M().to("cuda")
torch.cuda.reset_peak_memory_stats()
x = torch.randn(32, 1024, requires_grad=True, device="cuda")
ep = torch.export.export(model, (x,))
ep = ep.run_decompositions()
self.assertEqual(len(list(ep.graph_module.named_modules())), 2)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None
getitem = invoke_subgraph[0]
getitem_1 = invoke_subgraph[1]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None
getitem_2 = invoke_subgraph_1[0]
getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None
repeated_subgraph0_2 = self.repeated_subgraph0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None
getitem_4 = invoke_subgraph_2[0]
getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None
with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None
getitem_6 = with_effects[0]; with_effects = None
return (getitem_6, getitem_5)""",
)
self.assertExpectedInline(
ep.graph_module.repeated_subgraph0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None
getitem = with_effects[0]; with_effects = None
permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None
addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None
relu = torch.ops.aten.relu.default(addmm); addmm = None
permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None
addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None
return (getitem, addmm_1)""",
)
recorded_list.clear()
out2 = ep.module()(x)
self.assertEqual(len(recorded_list), 4)
self.assertTrue(torch.allclose(model(x)[0], out2[0]))
if __name__ == "__main__":
run_tests()

View File

@ -995,6 +995,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
@expectedFailureMPS
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
@gradcheck_semantics()
def test_permute(self, device, dtype, coalesced, gradcheck):
@ -1034,8 +1035,7 @@ class TestSparse(TestSparseBase):
else:
self.assertFalse(s_permuted.is_coalesced())
kwargs = {"eps": 1e-4} if device == "mps:0" else {}
gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_(), **kwargs)
gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
else:
# otherwise check if exception is thrown
fail_message = "transpositions between sparse and dense dimensions are not allowed"

View File

@ -357,9 +357,6 @@ class TestFFT(TestCase):
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
allowed_dtypes=(torch.cfloat, torch.cdouble))
@toleranceOverride({
torch.cfloat : tol(2e-4, 1.3e-6),
})
def test_reference_nd(self, device, dtype, op):
if op.ref is None:
raise unittest.SkipTest("No reference implementation")

View File

@ -33,7 +33,7 @@ from .graph_capture_wrappers import (
handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams, insert_backward_syncs
from .streams import assign_backward_streams
from .utils import (
call_and_expect_output_descs,
copy_fwd_metadata_to_bw_nodes,
@ -477,8 +477,6 @@ def aot_dispatch_autograd_graph(
# After copying metadata, assign streams to gradient accumulation nodes
assign_backward_streams(fx_g)
insert_backward_syncs(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.

View File

@ -3,7 +3,6 @@ from typing import Optional, TypeAlias
import torch.fx
import torch.fx.traceback
from torch._dynamo.graph_utils import _get_flat_args
from torch._dynamo.variables.streams import get_current_stream, new_event
Node: TypeAlias = torch.fx.Node
@ -13,14 +12,6 @@ def is_gradient_acc(node: Node) -> bool:
return node.meta.get("is_gradient_acc", False)
def is_bwd_node(node: Node) -> bool:
return node.meta.get("partitioner_tag") == "is_backward"
def get_device(node: Node) -> torch.device:
return node.meta["val"].device
def get_stream(node: Node) -> Optional[int]:
maybe_annotation = node.meta.get("custom", None)
if maybe_annotation is not None:
@ -29,13 +20,6 @@ def get_stream(node: Node) -> Optional[int]:
return None
def get_stream_or_current_stream(node: Node) -> int:
ind = get_stream(node)
if ind is None:
ind = get_current_stream(get_device(node))
return ind
def set_stream(node: Node, ind: int) -> None:
if "custom" in node.meta:
node.meta["custom"].update({"stream": ind})
@ -43,36 +27,6 @@ def set_stream(node: Node, ind: int) -> None:
node.meta["custom"] = {"stream": ind}
def insert_sync(
graph: torch.fx.Graph,
consumer: Node,
producer: Node,
node_to_wait_event_ind: dict[Node, int],
) -> None:
if producer not in node_to_wait_event_ind:
node_to_wait_event_ind[producer] = new_event()
with graph.inserting_after(producer):
node = graph.call_function(
torch.ops.streams.record_event.default,
(
node_to_wait_event_ind[producer],
get_stream_or_current_stream(producer),
),
)
node.meta["partitioner_tag"] = "must_be_in_backward"
with graph.inserting_before(consumer):
node = graph.call_function(
torch.ops.streams.wait_event.default,
(
node_to_wait_event_ind[producer],
get_stream_or_current_stream(consumer),
),
)
node.meta["partitioner_tag"] = "must_be_in_backward"
def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
"""Assigns backward streams to gradient accumulation nodes"""
@ -97,18 +51,3 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
if ind is not None:
set_stream(node, ind)
break
def insert_backward_syncs(gm: torch.fx.GraphModule) -> None:
"""Inserts stream syncs for backward nodes if consumer and producer are on different streams"""
node_to_wait_event_ind = {}
for node in gm.graph.nodes:
if is_bwd_node(node):
flat_args = _get_flat_args(node, {})
cur_node_stream = get_stream(node)
for arg in flat_args:
if is_bwd_node(arg):
arg_stream = get_stream(arg)
if arg_stream != cur_node_stream and get_device(arg).type != "cpu":
insert_sync(gm.graph, node, arg, node_to_wait_event_ind)
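
The insert_sync/insert_backward_syncs logic shown above pairs a record_event on the producer's stream with a wait_event on the consumer's stream whenever a backward node reads a value produced on a different (non-CPU) stream. The eager-mode analogue of that record/wait pairing looks roughly like this (a sketch; it only does real work on a CUDA machine):

import torch

if torch.cuda.is_available():
    producer = torch.cuda.Stream()
    consumer = torch.cuda.Stream()
    done = torch.cuda.Event()

    with torch.cuda.stream(producer):
        x = torch.randn(1 << 20, device="cuda")
        y = x * 2                  # work enqueued on the producer stream
        done.record(producer)      # record_event once the producer's work is queued

    with torch.cuda.stream(consumer):
        consumer.wait_event(done)  # wait_event before the consumer touches y
        z = y + 1

    torch.cuda.synchronize()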

View File

@ -713,9 +713,6 @@ class InvokeSubgraphCache(HopSubgraphCache):
self.lazy_bwd_cache: dict[
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
] = defaultdict(dict)
self.effects_cache: dict[
str, set
] = {} # Maps identifier -> set of effect types
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
self.dynamo_installed_submodules[fn_id].append(identifier)
@ -754,21 +751,6 @@ class InvokeSubgraphCache(HopSubgraphCache):
return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None))
def add_effects(self, identifier: str, effects: set) -> None:
"""Store the effect types for a given invoke_subgraph identifier."""
if prev_effects := self.effects_cache.get(identifier, None):
assert effects == prev_effects, (
"Different number of effects were found for invoke_subgraph "
f"call with identifier {identifier}. \n"
f"Previously we had the following effects: {prev_effects}.\n"
f"But now we have: {effects}."
)
self.effects_cache[identifier] = effects
def get_effects(self, identifier: str) -> Optional[set]:
"""Retrieve the effect types for a given invoke_subgraph identifier."""
return self.effects_cache.get(identifier, None)
class HopDispatchSetCache:
def __init__(self) -> None:

View File

@ -80,7 +80,6 @@ class InvokeSubgraphHOP(HigherOrderOperator):
assert all(
isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator))
for o in operands
if o is not None
), (
f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}"
)
@ -305,62 +304,6 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
def get_output_metadata(subgraph, *operands):
"""
Extract metadata about the subgraph outputs WITHOUT executing the subgraph.
This avoids running side-effectful operations twice (once here, once in forward).
We analyze the graph structure statically to extract metadata.
"""
# Unwrap FunctionalizeCtxWrapper if present
if isinstance(subgraph, FunctionalizeCtxWrapper):
subgraph = subgraph.subgraph
# If not a GraphModule, fall back to execution-based metadata extraction
if not isinstance(subgraph, torch.fx.GraphModule):
return _get_output_metadata_by_execution(subgraph, *operands)
output_metadata = OutputMetadata()
# Extract output arguments from the output node
# The output node has args=(output_values,) where output_values is a tuple/list
output_node = next(reversed(subgraph.graph.find_nodes(op="output")))
output_metadata.num_fw_outs = len(output_node.args[0])
for idx, output_arg in enumerate(output_node.args[0]):
if not isinstance(output_arg, torch.fx.Node):
if isinstance(output_arg, int):
output_metadata.indexes_with_symint.add(idx)
output_metadata.indexes_with_no_grad.add(idx)
continue
# Check node metadata for type information
if output_arg.meta.get("val") is None:
# If we don't have complete metadata for all outputs, fall back to execution
# This is important for correctness (e.g., detecting SymInts) even though it
# runs side-effectful operations
return _get_output_metadata_by_execution(subgraph, *operands)
val = output_arg.meta["val"]
if isinstance(val, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
output_metadata.indexes_with_no_grad.add(idx)
elif isinstance(val, torch.Tensor):
# Check if tensor requires grad from metadata
if hasattr(val, "requires_grad") and not val.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)
else:
# Non-tensor, non-symint (shouldn't happen but be safe)
output_metadata.indexes_with_no_grad.add(idx)
return output_metadata
def _get_output_metadata_by_execution(subgraph, *operands):
"""
Fallback: Extract metadata by executing the subgraph.
This should only be used when static analysis fails.
WARNING: This will run side-effectful operations!
"""
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
# args are functional tensors, generate some example tensors
@ -380,15 +323,19 @@ def _get_output_metadata_by_execution(subgraph, *operands):
num_fw_outs = len(fw_outs)
# Collect the indexes of none in the output to check that the grad
# is None at the corresponding index in the backward. This check is
# performed in the autograd.Function - InvokeSubgraphAutogradOp.
# Also collect the indexes of no_grad in the output to filter out
# the grad_outs in the `backward` method.
output_metadata = OutputMetadata()
output_metadata.num_fw_outs = num_fw_outs
output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if isinstance(fw_out, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
elif not fw_out.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)
return output_metadata
@ -615,34 +562,7 @@ def _(ctx, subgraph, identifier, *operands):
do_auto_functionalize_v2,
)
# (in the functionalization metadata phase) Capture tokens before
tokens_before = dict(ctx.mode._tokens)
# Check if this subgraph has effects stored in the cache
invoke_subgraph_cache = get_invoke_subgraph_cache()
effects = None
if invoke_subgraph_cache:
effects = invoke_subgraph_cache.get_effects(identifier)
if effects:
assert len(effects) == 1, "Multiple effects within a subgraph NYI"
tokens = ctx.mode._tokens
effects = next(iter(effects))
token_input = tokens[effects]
operands = (token_input, *operands)
def wrap_subgraph(subgraph):
def wrapped_subgraph(token, *args):
res = subgraph(*args)
return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res
return wrapped_subgraph
subgraph = wrap_subgraph(subgraph)
unwrapped_operands = ctx.unwrap_tensors(operands)
hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands)
if can_auto_functionalize(hop_instance):
# NOTE: [auto_functionalize x invoke_subgraph caching]
@ -667,28 +587,6 @@ def _(ctx, subgraph, identifier, *operands):
# of invoke_subgraph ops if input aliasing/mutation is detected.
functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)
if effects:
(new_token, *out) = out
ctx.mode._tokens[effects] = new_token
# (in the functionalization metadata phase) Capture tokens after and see if
# there are any differences (there are new effects or the token value for an
# effect type has changed)
tokens_after = dict(ctx.mode._tokens)
discovered_effects = set()
for effect_type, token in tokens_after.items():
if effect_type not in tokens_before or tokens_before[effect_type] is not token:
discovered_effects.add(effect_type)
if discovered_effects:
assert ctx.mode._allow_token_discovery, (
f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}."
)
# Store discovered effects in the cache by identifier
if invoke_subgraph_cache:
invoke_subgraph_cache.add_effects(identifier, discovered_effects)
return ctx.wrap_tensors(out)
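
The get_output_metadata variant shown above classifies each subgraph output by reading node.meta["val"] from the FX output node (SymInt vs. tensor, requires_grad or not) instead of executing the subgraph. A small sketch of that style of static inspection, assuming a graph produced by make_fx so that nodes actually carry a fake "val":

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x * 2, x.sum()

gm = make_fx(f, tracing_mode="fake")(torch.randn(3, requires_grad=True))
output_node = next(reversed(gm.graph.find_nodes(op="output")))
for idx, arg in enumerate(output_node.args[0]):
    val = arg.meta.get("val") if isinstance(arg, torch.fx.Node) else None
    needs_grad = isinstance(val, torch.Tensor) and val.requires_grad
    print(idx, type(val).__name__, "requires_grad" if needs_grad else "no grad")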

View File

@ -35,18 +35,6 @@ class EffectHolder:
if namespace == "higher_order":
return
# These classes do not have side effects as they just store quantization
# params, so we dont need to mark them as ordered
skip_classes = (
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase",
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase",
"__torch__.torch.classes.quantized.EmbeddingPackedParamsBase",
"__torch__.torch.classes.quantized.LinearPackedParamsBase",
"__torch__.torch.classes.xnnpack.Conv2dOpContext",
"__torch__.torch.classes.xnnpack.LinearOpContext",
"__torch__.torch.classes.xnnpack.TransposeConv2dOpContext",
)
opname = f"{namespace}::{opname}"
if torch._C._get_operation_overload(opname, overload) is not None:
# Since we call this when destroying the library, sometimes the
@ -54,9 +42,6 @@ class EffectHolder:
schema = torch._C._get_schema(opname, overload)
for arg in schema.arguments:
if isinstance(arg.type, torch.ClassType):
type_str = arg.type.str() # pyrefly: ignore[missing-attribute]
if type_str in skip_classes:
continue
self._effect = EffectType.ORDERED
return

View File

@ -390,27 +390,31 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
m.def("_supported_activities", []() {
std::set<torch::profiler::impl::ActivityType> activities{
torch::profiler::impl::ActivityType::CPU};
#if defined(USE_KINETO)
#if (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
if (at::getNumGPUs() > 0) {
activities.insert(torch::profiler::impl::ActivityType::CUDA);
}
#endif // (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
#if (!defined(LIBKINETO_NOXPUPTI))
if (at::hasXPU()) {
activities.insert(torch::profiler::impl::ActivityType::XPU);
}
#endif // (!defined(LIBKINETO_NOXPUPTI))
#if defined(USE_KINETO) && \
(!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
if (at::hasMTIA()) {
activities.insert(torch::profiler::impl::ActivityType::MTIA);
}
if (at::hasHPU()) {
activities.insert(torch::profiler::impl::ActivityType::HPU);
}
if (at::getNumGPUs() > 0) {
activities.insert(torch::profiler::impl::ActivityType::CUDA);
}
#elif defined(USE_KINETO)
if (at::hasXPU()) {
activities.insert(torch::profiler::impl::ActivityType::XPU);
}
if (at::hasHPU()) {
activities.insert(torch::profiler::impl::ActivityType::HPU);
}
if (at::hasMTIA()) {
activities.insert(torch::profiler::impl::ActivityType::MTIA);
}
if (c10::get_privateuse1_backend() != "privateuseone") {
activities.insert(torch::profiler::impl::ActivityType::PrivateUse1);
}
#endif // defined(USE_KINETO)
#endif
return activities;
});

View File

@ -1200,27 +1200,25 @@ get_thread_local_native_sharding_propagator_cache() {
py::reinterpret_borrow<py::dict>(PyThreadState_GetDict());
// We need to clean up before Python detaches from the thread if
// the thread is being destroyed.
if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) {
thread_dict["__DTensor_fastpath_thread_cache_cleanup"] =
py::capsule(new std::thread::id(this_thread_id), [](void* p) {
auto* ptid = reinterpret_cast<std::thread::id*>(p);
{
std::lock_guard<std::mutex> inner_lock(
native_sharding_propagator_cache_cleanup_mutex);
auto it = all_thread_caches.find(*ptid);
if (it != all_thread_caches.end()) {
// We need to both:
// 1) free python objects, and
it->second->reset();
// 2) make sure we don't try to come back and mess with
// a destroyed thread-local at module unload (e.g.,
// process exit) time.
all_thread_caches.erase(it);
}
thread_dict["__DTensor_fastpath_thread_cache_cleanup"] =
py::capsule(new std::thread::id(this_thread_id), [](void* p) {
auto* ptid = reinterpret_cast<std::thread::id*>(p);
{
std::lock_guard<std::mutex> inner_lock(
native_sharding_propagator_cache_cleanup_mutex);
auto it = all_thread_caches.find(*ptid);
if (it != all_thread_caches.end()) {
// We need to both:
// 1) free python objects, and
it->second->reset();
// 2) make sure we don't try to come back and mess with
// a destroyed thread-local at module unload (e.g.,
// process exit) time.
all_thread_caches.erase(it);
}
delete ptid;
});
}
}
delete ptid;
});
}
return native_sharding_propagator_cache_DO_NOT_USE.value();
}

View File

@ -547,12 +547,8 @@ def foreach_reduce(
op=reduce_scatter_op,
)
else:
# For single GPU, just copy the input to output (no actual reduce-scatter needed), and
# account for a possible gradient_divide_factor.
if gradient_divide_factor is not None:
reduce_output.copy_(reduce_scatter_input / gradient_divide_factor)
else:
reduce_output.copy_(reduce_scatter_input)
# For single GPU, just copy the input to output (no actual reduce-scatter needed)
reduce_output.copy_(reduce_scatter_input)
reduce_scatter_event = reduce_scatter_stream.record_event()
post_reduce_stream = reduce_scatter_stream
if all_reduce_group is not None: # HSDP or DDP/replicate
@ -725,21 +721,20 @@ def _get_gradient_divide_factors(
if all_reduce_group is not None:
data_parallel_size *= all_reduce_group.size()
if factor is None:
factor = float(data_parallel_size)
if not overflow_risk and not force_sum_reduction_for_comms:
if factor is None:
if factor == data_parallel_size:
# Warning: NCCL ReduceOp.AVG may produce incorrect results with
# world size 1.
if data_parallel_size == 1:
return None, None, ReduceOp.SUM, ReduceOp.SUM
return None, None, ReduceOp.AVG, ReduceOp.AVG
if reduce_scatter_group is not None and factor == reduce_scatter_group.size():
reduce_scatter_op = ReduceOp.AVG
else:
reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
return None, None, reduce_scatter_op, ReduceOp.SUM
return None, None, reduce_scatter_op, ReduceOp.SUM
if factor is None:
factor = float(data_parallel_size)
pre_factor: Optional[float]
if overflow_risk:
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid

View File

@ -15,105 +15,113 @@ from .graph_signature import (
)
def _get_custom_obj_for_node(node, inputs_to_lifted_custom_objs, constants):
"""Extract the custom object from a node's arguments."""
custom_obj_node = node
custom_obj_meta = custom_obj_node.meta["val"] # type: ignore[union-attr]
assert isinstance(custom_obj_meta, CustomObjArgument)
if custom_obj_meta.fake_val:
return custom_obj_meta.fake_val
elif custom_obj_node.name in inputs_to_lifted_custom_objs: # type: ignore[union-attr]
return constants[inputs_to_lifted_custom_objs[custom_obj_node.name]] # type: ignore[union-attr]
else:
raise RuntimeError(f"Unable to find custom obj for node {node}")
def _replace_with_effects_node(
node, ep, inputs_to_lifted_custom_objs, output_tokens, input_tokens, module
def _remove_effect_tokens_from_graph_helper(
ep, num_tokens, input_token_names, output_token_names
):
"""Replace a with_effects node with the underlying function call."""
# Get the input nodes
token_node, func, *node_args = node.args
if token_node.op == "placeholder":
input_tokens.append(token_node)
inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs
assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))
output_node = None
with_effect_nodes: list[torch.fx.Node] = []
# Get the schema for the function
if func is torch.ops.higher_order.call_torchbind:
custom_obj = _get_custom_obj_for_node(
node_args[0], inputs_to_lifted_custom_objs, ep.constants
)
schema = _get_schema(func, [custom_obj] + node_args[1:])
else:
schema = _get_schema(func, node_args)
# The output node needs to check its args against output_token_names (collected from output_spec).
# Therefore, we only need to find the top-level output node.
output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output")))
for module in ep.graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
# Create the replacement node
with module.graph.inserting_before(node):
new_node = module.graph.call_function(func, tuple(node_args), node.kwargs)
for node in module.graph.nodes:
if not (node.op == "call_function" and node.target is with_effects):
continue
# Update getitem nodes that extract outputs from with_effects
for user in list(node.users.keys()):
assert user.target is operator.getitem
# getitem(with_effects, 0) is the token node
if user.args[1] == 0:
for user_user in list(user.users.keys()):
if user_user.op == "output":
output_tokens.append(user)
with_effect_nodes.append(node)
# Fix up the getitem nodes based on return count
if len(schema.returns) == 1:
# Single return: replace getitem(with_effects, 1) with the node itself
for user in list(node.users.keys()):
if user.args[1] == 1:
# Remove tokens from outputs
assert output_node is not None
output_args = output_node.args[0]
assert len(output_args) >= num_tokens
out_token_nodes = output_args[:num_tokens]
output_node.args = (tuple(output_args[num_tokens:]),)
for out_token in out_token_nodes:
assert out_token.name in output_token_names
out_token.users.clear()
ep.graph.erase_node(out_token)
# Replace with_effects(token, func, args) with just func(args)
for node in reversed(with_effect_nodes):
func = node.args[1]
assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))
if func is torch.ops.higher_order.call_torchbind:
custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr]
assert isinstance(custom_obj_meta, CustomObjArgument)
if custom_obj_meta.fake_val:
custom_obj = custom_obj_meta.fake_val
elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr]
custom_obj = ep.constants[
inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr]
]
else:
raise RuntimeError(f"Unable to find custom obj for node {node}")
schema = _get_schema(func, (custom_obj,) + node.args[3:])
else:
schema = _get_schema(func, node.args[2:])
with ep.graph.inserting_before(node):
new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)
for k, v in node.meta.items():
new_node.meta[k] = v
if k == "unbacked_bindings":
# Remove the extra layer for effect token
old_bindings = new_node.meta[k]
new_bindings = {
k: path[1:] if path else path for k, path in old_bindings.items()
}
new_node.meta[k] = new_bindings
node.replace_all_uses_with(new_node)
# Update user getitem nodes
for user in list(new_node.users.keys()):
assert user.target is operator.getitem
# getitem(with_effects, 0) == token
if user.args[1] == 0:
ep.graph.erase_node(user)
if len(schema.returns) == 1:
# If the function has 1 return then it will just directly return the
# result -- we don't need a getitem. So we can replace all the
# getitem(with_effects, 1) calls with just the node itself.
for user in list(new_node.users.keys()):
assert user.args[1] == 1
user.replace_all_uses_with(new_node)
new_node.meta["val"] = node.meta["val"][1]
elif len(schema.returns) > 1:
# Multiple returns: shift getitem indices down by 1
for user in list(node.users.keys()):
if user.args[1] >= 1:
user.args = (new_node, user.args[1] - 1)
new_node.meta["val"] = node.meta["val"][1:]
else:
# No returns
assert len(schema.returns) == 0
assert len(new_node.users) == 0
new_node.meta["val"] = None
# Copy metadata from old node to new node
for k, v in node.meta.items():
new_node.meta[k] = v
if k == "unbacked_bindings":
# Remove the extra layer for effect token
old_bindings = new_node.meta[k]
new_bindings = {
k: path[1:] if path else path for k, path in old_bindings.items()
}
new_node.meta[k] = new_bindings
new_node.meta["val"] = node.meta["val"][1]
elif len(schema.returns) > 1:
# If the function has more than 1 return, then since we got rid of
# the 1st return value (the token), we need to bump all the other
# getitem calls down by 1.
for user in list(new_node.users.keys()):
assert user.args[1] >= 1
user.args = (user.args[0], user.args[1] - 1)
new_node.meta["val"] = node.meta["val"][1:]
else:
assert len(schema.returns) == 0
assert len(new_node.users) == 0
new_node.meta["val"] = None
def _replace_invoke_subgraph_node(node, module, output_tokens, input_tokens):
"""Replace an invoke_subgraph node to remove the token argument."""
assert node.args[0].op == "get_attr"
submod = getattr(module, node.args[0].target)
if not submod.meta.get("has_with_effects", False):
return
ep.graph.erase_node(node)
# Remove token from inputs
subgraph, identifier, token, *operands = node.args
node.args = (subgraph, identifier, *operands)
if token.op == "placeholder":
input_tokens.append(token)
# Remove tokens from inputs
placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
assert len(placeholders) >= num_tokens
inp_token_nodes = placeholders[:num_tokens]
for inp_token in inp_token_nodes:
assert inp_token.name in input_token_names
ep.graph.erase_node(inp_token)
# Update getitem nodes to account for removed token output
for user in list(node.users.keys()):
if user.args[1] >= 1:
user.args = (node, user.args[1] - 1)
elif user.args[1] == 0:
for user_user in list(user.users.keys()):
if user_user.op == "output":
output_tokens.append(user)
ep.graph.eliminate_dead_code()
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
@ -124,65 +132,6 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
This function does an inplace modification on the given ExportedProgram.
"""
print("before", ep)
inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs
# Mark submodules that contain with_effects nodes as having effects. This is used by the following pass to remove effects from subgraphs.
for _, module in ep.graph_module.named_modules():
if not isinstance(module, torch.fx.GraphModule):
continue
with_effect_nodes = [
node for node in module.graph.nodes if node.target is with_effects
]
if len(with_effect_nodes) > 0:
module.meta["has_with_effects"] = True
# Process each module with the replace hook to ensure graph signature is updated
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
for _, module in ep.graph_module.named_modules():
if not isinstance(module, torch.fx.GraphModule):
continue
input_tokens = []
output_tokens = []
# Process with_effects and invoke_subgraph nodes
for node in module.graph.nodes:
if node.target is with_effects:
_replace_with_effects_node(
node,
ep,
inputs_to_lifted_custom_objs,
output_tokens,
input_tokens,
module,
)
elif node.target is torch.ops.higher_order.invoke_subgraph:
_replace_invoke_subgraph_node(
node, module, output_tokens, input_tokens
)
# Remove tokens from the output node
if len(output_tokens) > 0:
output_node = next(reversed(module.graph.find_nodes(op="output")))
output_args = output_node.args[0]
assert len(output_args) >= len(output_tokens), (
f"{output_args} output arguments found\n"
f"{output_tokens} output tokens found\n"
f"{module.graph}"
)
output_node.args = (tuple(output_args[len(output_tokens) :]),)
module.graph.eliminate_dead_code()
# Remove tokens from the input placeholders
for node in module.graph.nodes:
if node.op == "placeholder" and node in input_tokens:
module.graph.erase_node(node)
module.recompile()
num_tokens: int = 0
input_token_names: list[str] = []
new_input_specs: list[InputSpec] = []
@ -210,5 +159,9 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
assert num_tokens == num_out_tokens
print("after", ep)
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
_remove_effect_tokens_from_graph_helper(
ep, num_tokens, input_token_names, output_token_names
)
return ep
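
As a companion to the helper restored above, here is a minimal sketch of the core rewrite it performs: with_effects(token, func, *args) becomes func(*args), and the getitem users are re-indexed. The name strip_effect_calls is illustrative only; the sketch assumes func returns a tuple of results and that the token outputs are trimmed from the graph outputs by the caller, as the real pass does.

import operator

import torch.fx as fx


def strip_effect_calls(graph: fx.Graph, with_effects_target) -> None:
    """Rewrite with_effects(token, func, *args) into func(*args)."""
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target is not with_effects_target:
            continue
        _token, func, *func_args = node.args
        with graph.inserting_before(node):
            new_node = graph.call_function(func, tuple(func_args), node.kwargs)
        new_node.meta.update(node.meta)
        for user in list(node.users):
            assert user.target is operator.getitem
            idx = user.args[1]
            if idx >= 1:
                # Index 0 was the token, so every real result shifts down by one.
                user.args = (new_node, idx - 1)
        # Erase immediately if nothing references the old node any more;
        # otherwise the dead-code pass below cleans up the leftover token getitem.
        if not node.users:
            graph.erase_node(node)
    graph.eliminate_dead_code()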

View File

@ -748,23 +748,11 @@ def _unlift_exported_program_lifted_states(
) -> torch.fx.GraphModule:
check_guards = check_guards and _ok_to_generate_guards_fn()
source_node_dict = {
node.name: node for node in ep.graph.nodes if node.op != "placeholder"
}
# placeholder node name might change after deepcopy
placeholder_source_node_dict = {
node.target: node for node in ep.graph.nodes if node.op == "placeholder"
}
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
new_gm.meta.update(ep.graph_module.meta)
ep = copy.copy(ep)
ep._graph_module = new_gm
# TODO T206340015
if ep.verifiers[0].dialect != "TRAINING":
ep = _remove_effect_tokens(ep)
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
forward_arg_names = (
sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
@ -798,13 +786,19 @@ def _unlift_exported_program_lifted_states(
for out_spec in ep.graph_signature.output_specs
]
source_node_dict = {
node.name: node for node in ep.graph.nodes if node.op != "placeholder"
}
# placeholder node name might change after deepcopy
placeholder_source_node_dict = {
node.target: node for node in ep.graph.nodes if node.op == "placeholder"
}
for node in new_gm.graph.nodes:
source_node = None
if node.op == "placeholder":
source_node = placeholder_source_node_dict.get(node.target)
else:
if node.name in source_node_dict:
source_node = source_node_dict.get(node.name)
source_node = source_node_dict.get(node.name)
node.meta["from_node"] = [
NodeSource(
source_node,

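A small sketch of the provenance lookup pattern used above when populating from_node: non-placeholder nodes are keyed by name, while placeholders are keyed by target because their names can change after the deepcopy. The helper names below (build_source_lookup, find_source) are illustrative, not part of the patch.

import torch.fx as fx


def build_source_lookup(graph: fx.Graph):
    # Non-placeholder nodes keep stable names, so key them by name.
    by_name = {n.name: n for n in graph.nodes if n.op != "placeholder"}
    # Placeholder names may change after deepcopy, but their targets do not.
    by_target = {n.target: n for n in graph.nodes if n.op == "placeholder"}
    return by_name, by_target


def find_source(node: fx.Node, by_name: dict, by_target: dict):
    if node.op == "placeholder":
        return by_target.get(node.target)
    return by_name.get(node.name)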
View File

@ -753,9 +753,7 @@ class Node(_NodeBase):
# between eager and compiled execution, regardless of generator usage
return True
from torch._higher_order_ops.effects import has_effects
return self.target in _side_effectful_functions or has_effects(self.target)
return self.target in _side_effectful_functions
# Check if an impure module.
if self.op == "call_module":