Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-03 07:24:58 +08:00)

Compare commits: v2.4.0-rc5 ... v2.4.0-rc7 (26 commits)
| SHA1 |
|---|
| 499621e7bb |
| e5bda62849 |
| 705e3ae420 |
| b26cde49b6 |
| 12ad767daf |
| 1164d3cb9c |
| 9533637daa |
| fadd3cc4ab |
| 80277a50bc |
| d0831d65aa |
| ca8d4d1751 |
| 3d7d7927ca |
| 5f7de217cb |
| 072d9e8ac9 |
| 1f84579407 |
| 4d83bca8d8 |
| 04339eec05 |
| 22a4d46e2b |
| 560869918d |
| 2bf37985b1 |
| 491e9e2d4a |
| ec19059347 |
| 04e98d3d0e |
| 699c056479 |
| 49d2eec960 |
| 165e09874b |
@ -1 +1 @@
01cbe5045a6898c9a925f01435c8277b2fe6afcc
21eae954efa5bf584da70324b640288c3ee7aede

.github/scripts/amd/package_triton_wheel.sh (vendored, 1 change)
@ -94,6 +94,7 @@ done
# Copy Include Files
cp -r $ROCM_HOME/include/hip $TRITON_ROCM_DIR/include
cp -r $ROCM_HOME/include/roctracer $TRITON_ROCM_DIR/include
cp -r $ROCM_HOME/include/hsa $TRITON_ROCM_DIR/include

# Copy linker
mkdir -p $TRITON_ROCM_DIR/llvm/bin
.github/workflows/create_release.yml (vendored, 42 changes)
@ -5,6 +5,11 @@ on:
branches:
- main
- release/*
tags:
# Final Release tags look like: v1.11.0
- v[0-9]+.[0-9]+.[0-9]+
# Release candidate tags look like: v1.11.0-rc1
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
release:
types: [published]
pull_request:
@ -18,6 +23,8 @@ jobs:
# https://github.com/softprops/action-gh-release?tab=readme-ov-file#permissions
permissions:
contents: write
outputs:
pt_release_name: ${{ steps.release_name.outputs.pt_release_name }}
steps:
- uses: malfet/checkout@silent-checkout
with:
@ -49,11 +56,44 @@ jobs:
# Create archive
tar -czf "$PT_RELEASE_FILE" "$PT_RELEASE_NAME"
echo "Created source archive $PT_RELEASE_FILE with content: $(ls -a "$PT_RELEASE_NAME")"
- name: Upload source distribution
- name: Upload source distribution for release
if: ${{ github.event_name == 'release' }}
uses: softprops/action-gh-release@v1
with:
files: ${{env.PT_RELEASE_FILE}}
- name: Upload source distribution to GHA artifacts for release tags
if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
uses: actions/upload-artifact@v2
with:
name: ${{ env.PT_RELEASE_FILE }}
path: ${{ env.PT_RELEASE_FILE }}
- name: Set output
id: release_name
run: echo "::set-output name=pt_release_name::${{ env.PT_RELEASE_NAME }}.tar.gz"

upload_source_code_to_s3:
if: ${{ github.repository == 'pytorch/pytorch' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
runs-on: linux.2xlarge
environment: sourcecode-upload
name: Upload source code to S3 for release tags
permissions:
id-token: write
needs: release
steps:
- uses: actions/download-artifact@v2
with:
name: ${{ needs.release.outputs.pt_release_name }}
- name: Configure AWS credentials(PyTorch account)
uses: aws-actions/configure-aws-credentials@v3
with:
role-to-assume: arn:aws:iam::749337293305:role/gha_pytorch_source_code_upload_role
aws-region: us-east-1
- uses: seemethere/upload-artifact-s3@v5
with:
s3-bucket: pytorch
s3-prefix: source_code/test
if-no-files-found: warn
path: ${{ needs.release.outputs.pt_release_name }}

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
@ -17,7 +17,7 @@ static void metaFallback(
"while using an operator with PT2 compilation APIs (torch.compile/torch.export); "
"in order to use this operator with those APIs you'll need to add a fake impl. "
"Please see the following for next steps: "
"https://pytorch.org/docs/main/notes/custom_operators.html");
"https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html");
}

TORCH_LIBRARY_IMPL(_, Meta, m) {
@ -152,9 +152,6 @@ void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) {
* Unregisters a CUDA graph from the RNG state.
*/
void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
// Ensures that the RNG state is not currently being captured.
at::cuda::assertNotCapturing(
"Cannot unregister the state during capturing stage.");
// Verify the graph was previously registered.
TORCH_CHECK(
registered_graphs_.find(graph) != registered_graphs_.end(),
@ -84,11 +84,23 @@ struct GemmParams : OpParams {
return c10::str(transa, transb, "_", m, "_", n, "_", k);
}

size_t GetSizeA() const {
return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
}

size_t GetSizeB() const {
return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
}

size_t GetSizeC() const {
return sizeof(T) * ldc * n;
}

size_t GetSize(bool duplicate_inputs) const {
size_t size = sizeof(T) * ldc * n;
size_t size = GetSizeC();
if (duplicate_inputs) {
size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
size += GetSizeA();
size += GetSizeB();
}
return size;
}
@ -98,13 +110,13 @@ struct GemmParams : OpParams {
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = ldc * n * sizeof(T);
size_t c_size = GetSizeC();
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
size_t a_size = GetSizeA();
size_t b_size = GetSizeB();
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
copy->duplicate_inputs_ = true;
@ -153,11 +165,23 @@ struct GemmStridedBatchedParams : OpParams {
return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
}

size_t GetSizeA() const {
return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m) * batch;
}

size_t GetSizeB() const {
return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k) * batch;
}

size_t GetSizeC() const {
return sizeof(T) * ldc * n * batch;
}

size_t GetSize(bool duplicate_inputs) const {
size_t size = sizeof(T) * stride_c * batch;
size_t size = GetSizeC();
if (duplicate_inputs) {
size += sizeof(T) * stride_a * batch;
size += sizeof(T) * stride_b * batch;
size += GetSizeA();
size += GetSizeB();
}
return size;
}
@ -167,13 +191,13 @@ struct GemmStridedBatchedParams : OpParams {
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = batch * stride_c * sizeof(T);
size_t c_size = GetSizeC();
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = sizeof(T) * stride_a * batch;
size_t b_size = sizeof(T) * stride_b * batch;
size_t a_size = GetSizeA();
size_t b_size = GetSizeB();
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
copy->duplicate_inputs_ = true;
@ -226,11 +250,23 @@ struct ScaledGemmParams : OpParams {
return c10::str(transa, transb, "_", m, "_", n, "_", k);
}

size_t GetSizeA() const {
return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
}

size_t GetSizeB() const {
return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
}

size_t GetSizeC() const {
return sizeof(T) * ldc * n;
}

size_t GetSize(bool duplicate_inputs) const {
size_t size = sizeof(T) * ldc * n;
size_t size = GetSizeC();
if (duplicate_inputs) {
size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
size += GetSizeA();
size += GetSizeB();
}
return size;
}
@ -240,13 +276,13 @@ struct ScaledGemmParams : OpParams {
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = ldc * n * sizeof(T);
size_t c_size = GetSizeC();
copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
size_t a_size = GetSizeA();
size_t b_size = GetSizeB();
copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
copy->duplicate_inputs_ = true;
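The refactor above centralizes operand-size computation in GetSizeA/GetSizeB/GetSizeC instead of repeating the sizeof(T) * ld * ... expressions. A rough Python illustration of the same bookkeeping (a hypothetical helper, not part of the diff), using GetSizeA's rule that a non-transposed A occupies lda * k elements and a transposed A occupies lda * m:

    def gemm_a_bytes(element_size: int, lda: int, transa: str, m: int, k: int) -> int:
        # mirrors GetSizeA(): sizeof(T) * lda * (k if A is not transposed else m)
        cols = k if transa in ("n", "N") else m
        return element_size * lda * cols

    # e.g. a float32 (4-byte) A stored with lda=128, not transposed, k=64:
    assert gemm_a_bytes(4, 128, "n", 128, 64) == 4 * 128 * 64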
@ -375,9 +375,9 @@ void TuningContext::EnableNumericsCheck(bool value) {
}

bool TuningContext::IsNumericsCheckEnabled() const {
static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (env != nullptr && strcmp(env, "0") == 0) {
return false;
const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (env != nullptr && strcmp(env, "1") == 0) {
return true;
}
return numerics_check_enable_;
}
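The hunk above changes PYTORCH_TUNABLEOP_NUMERICAL_CHECK from an opt-out flag ("0" disables) into an opt-in flag ("1" enables), falling back to the programmatic setting otherwise. A minimal sketch of exercising it from Python, assuming a ROCm/CUDA build with TunableOp support (sizes and iteration count are arbitrary):

    import os
    import torch

    # Opt in to numerics checking; any other value defers to numerics_check_enable_.
    os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"

    torch.cuda.tunable.enable(True)                   # turn TunableOp on
    torch.cuda.tunable.set_max_tuning_iterations(10)  # keep the demo short

    a = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    c = a @ b  # this GEMM is tuned; each candidate is checked against the default result

    torch.cuda.tunable.enable(False)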
@ -124,8 +124,11 @@ class TunableOp {
std::string id_name = "Default";
ParamsT* reference_params = nullptr;

// numeric check option is controlled by non-static env var, so check it once per tuned operator
bool do_numerics_check = ctx->IsNumericsCheckEnabled();

// calculate a reference answer for numerical check
if (ctx->IsNumericsCheckEnabled()) {
if (do_numerics_check) {
reference_params = params->DeepCopy(false);
TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
}
@ -156,10 +159,11 @@ class TunableOp {
for (size_t i = 0; i < op_names_.size(); i++) {
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer

if (ctx->IsNumericsCheckEnabled()) {
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
@ -807,6 +807,7 @@ struct ReduceOp {
bool is_last_block_done = mark_block_finished();

if (is_last_block_done) {
__threadfence(); // complete the acquire pattern after atomic
value = ident;
if (config.should_block_x_reduce()) {
index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;

@ -595,6 +595,7 @@ struct ReduceJitOp {
bool is_last_block_done = mark_block_finished();

if (is_last_block_done) {
__threadfence(); //complete acquire pattern
value = ident;
if (config.should_block_x_reduce()) {
uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
@ -18,7 +18,7 @@ void throwNullDataPtrError() {
"If you're using torch.compile/export/fx, it is likely that we are erroneously "
"tracing into a custom kernel. To fix this, please wrap the custom kernel into "
"an opaque custom op. Please see the following for details: "
"https://pytorch.org/docs/main/notes/custom_operators.html");
"https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html");
}

// NOTE: [FakeTensor.data_ptr deprecation]

@ -1580,7 +1580,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
"If you're using torch.compile/export/fx, it is likely that we are erroneously "
"tracing into a custom kernel. To fix this, please wrap the custom kernel into "
"an opaque custom op. Please see the following for details: "
"https://pytorch.org/docs/main/notes/custom_operators.html\n"
"https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html\n"
"If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call "
"mutable_data() or raw_mutable_data() to actually allocate memory.");
// Caller does the type check.
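Both messages above now point at the custom-ops landing page. For context, a hedged sketch of what "wrap the custom kernel into an opaque custom op" can look like with the Python API (torch.library.custom_op, available in 2.4); the kernel body is a stand-in for a real C++/CUDA call, and "mylib::scale" is a made-up name:

    import torch

    @torch.library.custom_op("mylib::scale", mutates_args=())
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        # stand-in for an opaque C++/CUDA kernel
        return (x * factor).contiguous()

    @scale.register_fake
    def _(x, factor):
        # the "fake impl" that the meta fallback error message asks for
        return torch.empty_like(x)

    out = scale(torch.randn(4), 2.0)  # eager call
    # with the fake impl registered, torch.compile/torch.export can trace calls to it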
@ -19,8 +19,8 @@ are much faster in ``lower_precision_fp``. Other ops, like reductions, often req
|
||||
range of ``float32``. Mixed precision tries to match each op to its appropriate datatype.
|
||||
|
||||
Ordinarily, "automatic mixed precision training" with datatype of ``torch.float16`` uses :class:`torch.autocast` and
|
||||
:class:`torch.amp.GradScaler` together, as shown in the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>`
|
||||
and `CUDA Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
|
||||
:class:`torch.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples<amp-examples>`
|
||||
and `Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
|
||||
However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and may be used separately if desired.
|
||||
As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
|
||||
datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.
|
||||
@ -256,6 +256,100 @@ In this case, combine the two layers using :func:`torch.nn.functional.binary_cro
|
||||
or :mod:`torch.nn.BCEWithLogitsLoss`. ``binary_cross_entropy_with_logits`` and ``BCEWithLogits``
|
||||
are safe to autocast.
|
||||
|
||||
.. _autocast-xpu-op-reference:
|
||||
|
||||
XPU Op-Specific Behavior (Experimental)
|
||||
---------------------------------------
|
||||
The following lists describe the behavior of eligible ops in autocast-enabled regions.
|
||||
These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`,
|
||||
as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces,
|
||||
they go through autocasting regardless of the namespace.
|
||||
|
||||
Ops not listed below do not go through autocasting. They run in the type
|
||||
defined by their inputs. However, autocasting may still change the type
|
||||
in which unlisted ops run if they're downstream from autocasted ops.
|
||||
|
||||
If an op is unlisted, we assume it's numerically stable in ``float16``.
|
||||
If you believe an unlisted op is numerically unstable in ``float16``,
|
||||
please file an issue.
|
||||
|
||||
XPU Ops that can autocast to ``float16``
|
||||
""""""""""""""""""""""""""""""""""""""""
|
||||
|
||||
``addbmm``,
|
||||
``addmm``,
|
||||
``addmv``,
|
||||
``addr``,
|
||||
``baddbmm``,
|
||||
``bmm``,
|
||||
``chain_matmul``,
|
||||
``multi_dot``,
|
||||
``conv1d``,
|
||||
``conv2d``,
|
||||
``conv3d``,
|
||||
``conv_transpose1d``,
|
||||
``conv_transpose2d``,
|
||||
``conv_transpose3d``,
|
||||
``GRUCell``,
|
||||
``linear``,
|
||||
``LSTMCell``,
|
||||
``matmul``,
|
||||
``mm``,
|
||||
``mv``,
|
||||
``RNNCell``
|
||||
|
||||
XPU Ops that can autocast to ``float32``
|
||||
""""""""""""""""""""""""""""""""""""""""
|
||||
|
||||
``__pow__``,
|
||||
``__rdiv__``,
|
||||
``__rpow__``,
|
||||
``__rtruediv__``,
|
||||
``binary_cross_entropy_with_logits``,
|
||||
``cosine_embedding_loss``,
|
||||
``cosine_similarity``,
|
||||
``cumsum``,
|
||||
``dist``,
|
||||
``exp``,
|
||||
``group_norm``,
|
||||
``hinge_embedding_loss``,
|
||||
``kl_div``,
|
||||
``l1_loss``,
|
||||
``layer_norm``,
|
||||
``log``,
|
||||
``log_softmax``,
|
||||
``margin_ranking_loss``,
|
||||
``nll_loss``,
|
||||
``normalize``,
|
||||
``poisson_nll_loss``,
|
||||
``pow``,
|
||||
``reciprocal``,
|
||||
``rsqrt``,
|
||||
``soft_margin_loss``,
|
||||
``softmax``,
|
||||
``softmin``,
|
||||
``sum``,
|
||||
``triplet_margin_loss``
|
||||
|
||||
XPU Ops that promote to the widest input type
|
||||
"""""""""""""""""""""""""""""""""""""""""""""
|
||||
These ops don't require a particular dtype for stability, but take multiple inputs
|
||||
and require that the inputs' dtypes match. If all of the inputs are
|
||||
``float16``, the op runs in ``float16``. If any of the inputs is ``float32``,
|
||||
autocast casts all inputs to ``float32`` and runs the op in ``float32``.
|
||||
|
||||
``bilinear``,
|
||||
``cross``,
|
||||
``grid_sample``,
|
||||
``index_put``,
|
||||
``scatter_add``,
|
||||
``tensordot``
|
||||
|
||||
Some ops not listed here (e.g., binary ops like ``add``) natively promote
|
||||
inputs without autocasting's intervention. If inputs are a mixture of ``float16``
|
||||
and ``float32``, these ops run in ``float32`` and produce ``float32`` output,
|
||||
regardless of whether autocast is enabled.
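A minimal usage sketch for the XPU autocast behavior described above (assumes a PyTorch build with Intel GPU support; shapes are arbitrary):

    import torch

    model = torch.nn.Linear(8, 8).to("xpu")
    x = torch.randn(4, 8, device="xpu")

    with torch.autocast(device_type="xpu", dtype=torch.float16):
        y = model(x)            # linear is on the float16 list above
        loss = y.float().sum()  # keep the reduction in float32

    print(y.dtype)  # expected: torch.float16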
|
||||
|
||||
.. _autocast-cpu-op-reference:
|
||||
|
||||
CPU Op-Specific Behavior
|
||||
|
||||
@ -1,22 +1,22 @@
|
||||
.. _amp-examples:
|
||||
|
||||
CUDA Automatic Mixed Precision examples
|
||||
Automatic Mixed Precision examples
|
||||
=======================================
|
||||
|
||||
.. currentmodule:: torch.cuda.amp
|
||||
.. currentmodule:: torch.amp
|
||||
|
||||
Ordinarily, "automatic mixed precision training" means training with
|
||||
:class:`torch.autocast` and :class:`torch.cuda.amp.GradScaler` together.
|
||||
:class:`torch.autocast` and :class:`torch.amp.GradScaler` together.
|
||||
|
||||
Instances of :class:`torch.autocast` enable autocasting for chosen regions.
|
||||
Autocasting automatically chooses the precision for GPU operations to improve performance
|
||||
while maintaining accuracy.
|
||||
|
||||
Instances of :class:`torch.cuda.amp.GradScaler` help perform the steps of
|
||||
gradient scaling conveniently. Gradient scaling improves convergence for networks with ``float16``
|
||||
Instances of :class:`torch.amp.GradScaler` help perform the steps of
|
||||
gradient scaling conveniently. Gradient scaling improves convergence for networks with ``float16`` (by default on CUDA and XPU)
|
||||
gradients by minimizing gradient underflow, as explained :ref:`here<gradient-scaling>`.
|
||||
|
||||
:class:`torch.autocast` and :class:`torch.cuda.amp.GradScaler` are modular.
|
||||
:class:`torch.autocast` and :class:`torch.amp.GradScaler` are modular.
|
||||
In the samples below, each is used as its individual documentation suggests.
|
||||
|
||||
(Samples here are illustrative. See the
|
||||
@ -109,7 +109,7 @@ Calling ``scaler.unscale_(optimizer)`` before clipping enables you to clip unsca
|
||||
this iteration, so ``scaler.step(optimizer)`` knows not to redundantly unscale gradients before
|
||||
(internally) calling ``optimizer.step()``.
|
||||
|
||||
.. currentmodule:: torch.cuda.amp.GradScaler
|
||||
.. currentmodule:: torch.amp.GradScaler
|
||||
|
||||
.. warning::
|
||||
:meth:`unscale_<unscale_>` should only be called once per optimizer per :meth:`step<step>` call,
|
||||
@ -155,7 +155,7 @@ where you called :meth:`step<step>` for a full effective batch::
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
.. currentmodule:: torch.cuda.amp
|
||||
.. currentmodule:: torch.amp
|
||||
|
||||
Gradient penalty
|
||||
----------------
|
||||
@ -241,7 +241,7 @@ Here's how that looks for the same L2 penalty::
|
||||
Working with Multiple Models, Losses, and Optimizers
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. currentmodule:: torch.cuda.amp.GradScaler
|
||||
.. currentmodule:: torch.amp.GradScaler
|
||||
|
||||
If your network has multiple losses, you must call :meth:`scaler.scale<scale>` on each of them individually.
|
||||
If your network has multiple optimizers, you may call :meth:`scaler.unscale_<unscale_>` on any of them individually,
|
||||
@ -250,7 +250,7 @@ and you must call :meth:`scaler.step<step>` on each of them individually.
|
||||
However, :meth:`scaler.update<update>` should only be called once,
|
||||
after all optimizers used this iteration have been stepped::
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
scaler = torch.amp.GradScaler()
|
||||
|
||||
for epoch in epochs:
|
||||
for input, target in data:
|
||||
@ -282,7 +282,7 @@ while the other one does not. Since step skipping occurs rarely (every several
|
||||
this should not impede convergence. If you observe poor convergence after adding gradient scaling
|
||||
to a multiple-optimizer model, please report a bug.
|
||||
|
||||
.. currentmodule:: torch.cuda.amp
|
||||
.. currentmodule:: torch.amp
|
||||
|
||||
.. _amp-multigpu:
|
||||
|
||||
@ -347,7 +347,7 @@ is to disable autocast and force execution in ``float32`` ( or ``dtype``) at any
|
||||
output = imported_function(input1.float(), input2.float())
|
||||
|
||||
If you're the function's author (or can alter its definition) a better solution is to use the
|
||||
:func:`torch.cuda.amp.custom_fwd` and :func:`torch.cuda.amp.custom_bwd` decorators as shown in
|
||||
:func:`torch.amp.custom_fwd` and :func:`torch.amp.custom_bwd` decorators as shown in
|
||||
the relevant case below.
|
||||
|
||||
Functions with multiple inputs or autocastable ops
|
||||
@ -380,20 +380,21 @@ Functions that need a particular ``dtype``
|
||||
------------------------------------------
|
||||
|
||||
Consider a custom function that requires ``torch.float32`` inputs.
|
||||
Apply :func:`custom_fwd(cast_inputs=torch.float32)<custom_fwd>` to ``forward``
|
||||
and :func:`custom_bwd<custom_bwd>` (with no arguments) to ``backward``.
|
||||
If ``forward`` runs in an autocast-enabled region, the decorators cast floating-point CUDA Tensor
|
||||
inputs to ``float32``, and locally disable autocast during ``forward`` and ``backward``::
|
||||
Apply :func:`custom_fwd(device_type='cuda', cast_inputs=torch.float32)<custom_fwd>` to ``forward``
|
||||
and :func:`custom_bwd(device_type='cuda')<custom_bwd>` to ``backward``.
|
||||
If ``forward`` runs in an autocast-enabled region, the decorators cast floating-point Tensor
|
||||
inputs to ``float32`` on designated device assigned by the argument `device_type <../amp.html>`_,
|
||||
`CUDA` in this example, and locally disable autocast during ``forward`` and ``backward``::
|
||||
|
||||
class MyFloat32Func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
...
|
||||
return fwd_output
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@custom_bwd(device_type='cuda')
|
||||
def backward(ctx, grad):
|
||||
...
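Filling in the elided bodies, a self-contained sketch of the updated decorator usage (assumes a CUDA device; the doubling op is a placeholder for the real float32-only computation):

    import torch
    from torch.amp import custom_fwd, custom_bwd

    class MyFloat32Func(torch.autograd.Function):
        @staticmethod
        @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return input * 2.0  # placeholder body

        @staticmethod
        @custom_bwd(device_type="cuda")
        def backward(ctx, grad):
            (input,) = ctx.saved_tensors
            return grad * 2.0

    x = torch.randn(4, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = MyFloat32Func.apply(x.half())  # cast back to float32 by custom_fwd
    out.sum().backward()
    print(out.dtype)  # torch.float32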
|
||||
|
||||
|
||||
@ -3,54 +3,4 @@
|
||||
PyTorch Custom Operators Landing Page
|
||||
=====================================
|
||||
|
||||
PyTorch offers a large library of operators that work on Tensors (e.g. :func:`torch.add`,
|
||||
:func:`torch.sum`, etc). However, you may wish to bring a new custom operation to PyTorch
|
||||
and get it to work with subsystems like :func:`torch.compile`, autograd, and :func:`torch.vmap`.
|
||||
In order to do so, you must register the custom operation with PyTorch via the Python
|
||||
:ref:`torch-library-docs` or C++ TORCH_LIBRARY APIs.
|
||||
|
||||
TL;DR
|
||||
-----
|
||||
|
||||
How do I author a custom op from Python?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
..
|
||||
[comment] TODO(rzou): The following will be a link to a tutorial on the PyTorch tutorials site in 2.4
|
||||
|
||||
Please see the `Python Custom Operators tutorial <https://colab.research.google.com/drive/1xCh5BNHxGnutqGLMHaHwm47cbDL9CB1g#scrollTo=gg6WorNtKzeh>`_
|
||||
|
||||
|
||||
How do I integrate custom C++ and/or CUDA code with PyTorch?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
..
|
||||
[comment] TODO(rzou): The following will be a link to a tutorial on the PyTorch tutorials site in 2.4
|
||||
|
||||
Please see the `Custom C++ and CUDA Operators tutorial <https://docs.google.com/document/d/1-LdJZBzlxiF0Tm-8NfbyFvRJaofdwRgLcycXGmlIpS0>`_
|
||||
|
||||
|
||||
For more details
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
Please see `The Custom Operators Manual (gdoc) <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU>`_
|
||||
(we're working on moving the information to our docs site). We recommend that you
|
||||
first read one of the tutorials above and then use the Custom Operators Manual as a reference;
|
||||
it is not meant to be read head to toe.
|
||||
|
||||
When should I create a Custom Operator?
|
||||
---------------------------------------
|
||||
If your operation is expressible as a composition of built-in PyTorch operators
|
||||
then please write it as a Python function and call it instead of creating a
|
||||
custom operator. Use the operator registration APIs to create a custom op if you
|
||||
are calling into some library that PyTorch doesn't understand (e.g. custom C/C++ code,
|
||||
a custom CUDA kernel, or Python bindings to C/C++/CUDA extensions).
|
||||
|
||||
Why should I create a Custom Operator?
|
||||
--------------------------------------
|
||||
|
||||
It is possible to use a C/C++/CUDA kernel by grabbing a Tensor's data pointer
|
||||
and passing it to a pybind'ed kernel. However, this approach doesn't compose with
|
||||
PyTorch subsystems like autograd, torch.compile, vmap, and more. In order
|
||||
for an operation to compose with PyTorch subsystems, it must be registered
|
||||
via the operator registration APIs.
|
||||
`This page has moved. Click here for the new page. <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_
|
||||
|
||||
docs/source/notes/get_start_xpu.rst (new file, 339 lines)
@ -0,0 +1,339 @@
|
||||
PyTorch 2.4: Getting Started on Intel GPU
|
||||
=========================================
|
||||
|
||||
The support for Intel GPUs is released alongside PyTorch v2.4.
|
||||
|
||||
This release only supports build from source for Intel GPUs.
|
||||
|
||||
Hardware Prerequisites
|
||||
----------------------
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Supported Hardware
|
||||
- Intel® Data Center GPU Max Series
|
||||
* - Supported OS
|
||||
- Linux
|
||||
|
||||
|
||||
PyTorch for Intel GPUs is compatible with Intel® Data Center GPU Max Series and only supports OS Linux with release 2.4.
|
||||
|
||||
Software Prerequisites
|
||||
----------------------
|
||||
|
||||
As a prerequisite, install the driver and required packages by following the `PyTorch Installation Prerequisites for Intel GPUs <https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html>`_.
|
||||
|
||||
Set up Environment
|
||||
------------------
|
||||
|
||||
Before you begin, you need to set up the environment. This can be done by sourcing the ``setvars.sh`` script provided by the ``intel-for-pytorch-gpu-dev`` and ``intel-pti-dev`` packages.
|
||||
|
||||
.. code-block::
|
||||
|
||||
source ${ONEAPI_ROOT}/setvars.sh
|
||||
|
||||
.. note::
|
||||
The ``ONEAPI_ROOT`` is the folder you installed your ``intel-for-pytorch-gpu-dev`` and ``intel-pti-dev`` packages. Typically, it is located at ``/opt/intel/oneapi/`` or ``~/intel/oneapi/``.
|
||||
|
||||
Build from source
|
||||
-----------------
|
||||
|
||||
Now that we have all the required packages installed and the environment activated, use the following commands to install ``pytorch``, ``torchvision``, and ``torchaudio`` by building from source. For more details, refer to the official guides in `PyTorch from source <https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support>`_, `Vision from source <https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md#development-installation>`_ and `Audio from source <https://pytorch.org/audio/main/build.linux.html>`_.
|
||||
|
||||
.. code-block::
|
||||
|
||||
# Get PyTorch Source Code
|
||||
git clone --recursive https://github.com/pytorch/pytorch
|
||||
cd pytorch
|
||||
git checkout main # or checkout the specific release version >= v2.4
|
||||
git submodule sync
|
||||
git submodule update --init --recursive
|
||||
|
||||
# Get required packages for compilation
|
||||
conda install cmake ninja
|
||||
pip install -r requirements.txt
|
||||
|
||||
# PyTorch for Intel GPUs only supports the Linux platform for now.
|
||||
# Install the required packages for pytorch compilation.
|
||||
conda install intel::mkl-static intel::mkl-include
|
||||
|
||||
# (optional) If using torch.compile with inductor/triton, install the matching version of triton
|
||||
# Run from the pytorch directory after cloning
|
||||
# For Intel GPU support, please explicitly `export USE_XPU=1` before running command.
|
||||
USE_XPU=1 make triton
|
||||
|
||||
# If you would like to compile PyTorch with new C++ ABI enabled, then first run this command:
|
||||
export _GLIBCXX_USE_CXX11_ABI=1
|
||||
|
||||
# pytorch build from source
|
||||
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
|
||||
python setup.py develop
|
||||
cd ..
|
||||
|
||||
# (optional) If using torchvision.
|
||||
# Get torchvision Code
|
||||
git clone https://github.com/pytorch/vision.git
|
||||
cd vision
|
||||
git checkout main # or specific version
|
||||
python setup.py develop
|
||||
cd ..
|
||||
|
||||
# (optional) If using torchaudio.
|
||||
# Get torchaudio Code
|
||||
git clone https://github.com/pytorch/audio.git
|
||||
cd audio
|
||||
pip install -r requirements.txt
|
||||
git checkout main # or specific version
|
||||
git submodule sync
|
||||
git submodule update --init --recursive
|
||||
python setup.py develop
|
||||
cd ..
|
||||
|
||||
Check availability for Intel GPU
|
||||
--------------------------------
|
||||
|
||||
.. note::
|
||||
Make sure the environment is properly set up by following `Environment Set up <#set-up-environment>`_ before running the code.
|
||||
|
||||
To check if your Intel GPU is available, you would typically use the following code:
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
torch.xpu.is_available() # torch.xpu is the API for Intel GPU support
|
||||
|
||||
If the output is ``False``, ensure that you have Intel GPU in your system and correctly follow the `PyTorch Installation Prerequisites for Intel GPUs <https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html>`_. Then, check that the PyTorch compilation is correctly finished.
|
||||
|
||||
Minimum Code Change
|
||||
-------------------
|
||||
|
||||
If you are migrating code from ``cuda``, you would change references from ``cuda`` to ``xpu``. For example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
# CUDA CODE
|
||||
tensor = torch.tensor([1.0, 2.0]).to("cuda")
|
||||
|
||||
# CODE for Intel GPU
|
||||
tensor = torch.tensor([1.0, 2.0]).to("xpu")
|
||||
|
||||
The following points outline the support and limitations for PyTorch with Intel GPU:
|
||||
|
||||
#. Both training and inference workflows are supported.
|
||||
#. Both eager mode and ``torch.compile`` are supported.
|
||||
#. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported.
|
||||
#. Models that depend on third-party components will not be supported until PyTorch v2.5 or later.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
This section contains usage examples for both inference and training workflows.
|
||||
|
||||
Inference Examples
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Here are a few inference workflow examples.
|
||||
|
||||
|
||||
Inference with FP32
|
||||
"""""""""""""""""""
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
|
||||
model.eval()
|
||||
data = torch.rand(1, 3, 224, 224)
|
||||
|
||||
######## code changes #######
|
||||
model = model.to("xpu")
|
||||
data = data.to("xpu")
|
||||
######## code changes #######
|
||||
|
||||
with torch.no_grad():
|
||||
model(data)
|
||||
|
||||
print("Execution finished")
|
||||
|
||||
Inference with AMP
|
||||
""""""""""""""""""
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
|
||||
model.eval()
|
||||
data = torch.rand(1, 3, 224, 224)
|
||||
|
||||
#################### code changes #################
|
||||
model = model.to("xpu")
|
||||
data = data.to("xpu")
|
||||
#################### code changes #################
|
||||
|
||||
with torch.no_grad():
|
||||
d = torch.rand(1, 3, 224, 224)
|
||||
############################# code changes #####################
|
||||
d = d.to("xpu")
|
||||
# set dtype=torch.bfloat16 for BF16
|
||||
with torch.autocast(device_type="xpu", dtype=torch.float16, enabled=True):
|
||||
############################# code changes #####################
|
||||
model(data)
|
||||
|
||||
print("Execution finished")
|
||||
|
||||
Inference with ``torch.compile``
|
||||
""""""""""""""""""""""""""""""""
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
|
||||
model.eval()
|
||||
data = torch.rand(1, 3, 224, 224)
|
||||
ITERS = 10
|
||||
|
||||
######## code changes #######
|
||||
model = model.to("xpu")
|
||||
data = data.to("xpu")
|
||||
######## code changes #######
|
||||
|
||||
model = torch.compile(model)
|
||||
for i in range(ITERS):
|
||||
with torch.no_grad():
|
||||
model(data)
|
||||
|
||||
print("Execution finished")
|
||||
|
||||
Training Examples
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
Here are a few training workflow examples.
|
||||
|
||||
Train with FP32
|
||||
"""""""""""""""
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
LR = 0.001
|
||||
DOWNLOAD = True
|
||||
DATA = "datasets/cifar10/"
|
||||
|
||||
transform = torchvision.transforms.Compose(
|
||||
[
|
||||
torchvision.transforms.Resize((224, 224)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
root=DATA,
|
||||
train=True,
|
||||
transform=transform,
|
||||
download=DOWNLOAD,
|
||||
)
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)
|
||||
|
||||
model = torchvision.models.resnet50()
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
|
||||
model.train()
|
||||
######################## code changes #######################
|
||||
model = model.to("xpu")
|
||||
criterion = criterion.to("xpu")
|
||||
######################## code changes #######################
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
########## code changes ##########
|
||||
data = data.to("xpu")
|
||||
target = target.to("xpu")
|
||||
########## code changes ##########
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print(batch_idx)
|
||||
torch.save(
|
||||
{
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
},
|
||||
"checkpoint.pth",
|
||||
)
|
||||
|
||||
print("Execution finished")
|
||||
|
||||
Train with AMP
|
||||
""""""""""""""
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
LR = 0.001
|
||||
DOWNLOAD = True
|
||||
DATA = "datasets/cifar10/"
|
||||
|
||||
use_amp=True
|
||||
|
||||
transform = torchvision.transforms.Compose(
|
||||
[
|
||||
torchvision.transforms.Resize((224, 224)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
root=DATA,
|
||||
train=True,
|
||||
transform=transform,
|
||||
download=DOWNLOAD,
|
||||
)
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)
|
||||
|
||||
model = torchvision.models.resnet50()
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
|
||||
scaler = torch.amp.GradScaler(enabled=use_amp)
|
||||
|
||||
model.train()
|
||||
######################## code changes #######################
|
||||
model = model.to("xpu")
|
||||
criterion = criterion.to("xpu")
|
||||
######################## code changes #######################
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
########## code changes ##########
|
||||
data = data.to("xpu")
|
||||
target = target.to("xpu")
|
||||
########## code changes ##########
|
||||
# set dtype=torch.bfloat16 for BF16
|
||||
with torch.autocast(device_type="xpu", dtype=torch.float16, enabled=use_amp):
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
print(batch_idx)
|
||||
|
||||
torch.save(
|
||||
{
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
},
|
||||
"checkpoint.pth",
|
||||
)
|
||||
|
||||
print("Execution finished")
|
||||
@ -779,4 +779,5 @@ Tensor class reference
|
||||
Tensor.where
|
||||
Tensor.xlogy
|
||||
Tensor.xlogy_
|
||||
Tensor.xpu
|
||||
Tensor.zero_
|
||||
|
||||
@ -22,7 +22,7 @@ written in Python and it marks the transition of PyTorch from C++ to Python.
|
||||
* **TorchInductor** is the default ``torch.compile`` deep learning compiler
|
||||
that generates fast code for multiple accelerators and backends. You
|
||||
need to use a backend compiler to make speedups through ``torch.compile``
|
||||
possible. For NVIDIA and AMD GPUs, it leverages OpenAI Triton as the key
|
||||
possible. For NVIDIA, AMD and Intel GPUs, it leverages OpenAI Triton as the key
|
||||
building block.
|
||||
|
||||
* **AOT Autograd** captures not only the user-level code, but also backpropagation,
|
||||
|
||||
@ -15,7 +15,8 @@ understanding of how you can use ``torch.compile`` in your own programs.
|
||||
.. note::
|
||||
To run this script, you need to have at least one GPU on your machine.
|
||||
If you do not have a GPU, you can remove the ``.to(device="cuda:0")`` code
|
||||
in the snippet below and it will run on CPU.
|
||||
in the snippet below and it will run on CPU. You can also set device to
|
||||
``xpu:0`` to run on Intel® GPUs.
|
||||
|
||||
.. code:: python
|
||||
|
||||
|
||||
setup.py (13 changes)
@ -561,7 +561,6 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
||||
"libomp.dylib" if os.uname().machine == "arm64" else "libiomp5.dylib"
|
||||
)
|
||||
omp_rpath_lib_path = os.path.join("@rpath", omp_lib_name)
|
||||
omp_loader_lib_path = os.path.join("@loader_path", omp_lib_name)
|
||||
if omp_rpath_lib_path not in libs:
|
||||
return
|
||||
|
||||
@ -572,17 +571,16 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
||||
continue
|
||||
target_lib = os.path.join(self.build_lib, "torch", "lib", omp_lib_name)
|
||||
self.copy_file(source_lib, target_lib)
|
||||
# Change OMP library load path to loader_path and delete old rpath
|
||||
# Delete old rpath and add @loader_lib to the rpath
|
||||
# This should prevent delocate from attempting to package another instance
|
||||
# of OpenMP library in torch wheel
|
||||
# of OpenMP library in torch wheel as well as loading two libomp.dylib into
|
||||
# the address space, as libraries are cached by their unresolved names
|
||||
subprocess.check_call(
|
||||
[
|
||||
"install_name_tool",
|
||||
"-change",
|
||||
omp_rpath_lib_path,
|
||||
omp_loader_lib_path,
|
||||
"-delete_rpath",
|
||||
"-rpath",
|
||||
rpath,
|
||||
"@loader_path",
|
||||
libtorch_cpu_path,
|
||||
]
|
||||
)
|
||||
@ -1134,7 +1132,6 @@ def main():
|
||||
"networkx",
|
||||
"jinja2",
|
||||
"fsspec",
|
||||
'mkl>=2021.1.1,<=2021.4.0; platform_system == "Windows"',
|
||||
]
|
||||
|
||||
if sys.version_info >= (3, 12, 0):
|
||||
|
||||
@ -104,8 +104,10 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
|
||||
for name, dtensor in state_dict.items():
|
||||
self.assertEqual(dtensor.device.type, "cpu")
|
||||
|
||||
# Temporarily disable 2D state dict test while strided sharding is being developed.
|
||||
# TODO: re-enable this test once 2d state_dict is ready.
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_2d_state_dict_save_load(self):
|
||||
def _temp_disable_test_2d_state_dict_save_load(self):
|
||||
dp_size = 2
|
||||
global_mesh = init_device_mesh(
|
||||
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
|
||||
|
||||
@ -990,9 +990,31 @@ class TestFullyShard2DTraining(FSDPTest):
|
||||
optim.step()
|
||||
ref_optim.step()
|
||||
|
||||
# TODO: remove this test when 2d state_dict is ready.
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skipIfRocm
|
||||
def test_raise_not_implemented_state_dict_if_2d(self):
|
||||
def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
|
||||
_model = Transformer.parallelize(_model, mesh["tp"], use_seq_parallel)
|
||||
for layer in _model.layers:
|
||||
fully_shard(layer, mesh=mesh["dp"])
|
||||
fully_shard(_model, mesh=mesh["dp"])
|
||||
return _model
|
||||
|
||||
global_mesh = self.init_global_mesh()
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
model_args = ModelArgs(dropout_p=0.0)
|
||||
model = parallelize(Transformer(model_args), global_mesh, True)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, "2D"):
|
||||
get_model_state_dict(model)
|
||||
|
||||
# Temporarily disable 2D state dict test while strided sharding is being developed.
|
||||
# TODO: re-enable this test once 2d state_dict is ready.
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_temp_dir
|
||||
def test_train_parity_2d_transformer_checkpoint_resume(self):
|
||||
def _temp_disable_test_train_parity_2d_transformer_checkpoint_resume(self):
|
||||
"""
|
||||
Tests train parity of a 2D transformer without checkpointing against a
|
||||
2D transformer with a checkpoint save/load.
|
||||
|
||||
@ -536,6 +536,16 @@ class DTensorTest(DTensorTestBase):
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
# Test weights_only load
|
||||
try:
|
||||
torch.serialization.add_safe_globals(
|
||||
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
|
||||
)
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
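The pattern above (allow-list the classes, load with weights_only=True, clean up) also applies outside DTensor. A hypothetical minimal version with a user-defined class, assuming PyTorch >= 2.4 for torch.serialization.add_safe_globals:

    import io
    import torch

    class MyMeta:
        def __init__(self, note):
            self.note = note

    buf = io.BytesIO()
    torch.save({"w": torch.ones(2), "meta": MyMeta("hello")}, buf)
    buf.seek(0)

    try:
        # without this, weights_only=True refuses to rebuild MyMeta
        torch.serialization.add_safe_globals([MyMeta])
        reloaded = torch.load(buf, weights_only=True)
        print(reloaded["w"], reloaded["meta"].note)
    finally:
        torch.serialization.clear_safe_globals()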
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
|
||||
@ -18,7 +18,9 @@ from torch.distributed.checkpoint.state_dict import (
|
||||
_patch_model_state_dict,
|
||||
_patch_optimizer_state_dict,
|
||||
get_model_state_dict,
|
||||
get_optimizer_state_dict,
|
||||
get_state_dict,
|
||||
set_state_dict,
|
||||
)
|
||||
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
|
||||
from torch.distributed.checkpoint.utils import CheckpointException
|
||||
@ -417,6 +419,48 @@ class TestNoCPU(DTensorTestBase):
|
||||
f.result()
|
||||
|
||||
|
||||
class TestInitStateDict(DTensorTestBase):
|
||||
@with_temp_dir
|
||||
def test_init_state_dict(self):
|
||||
temp_dir = self.temp_dir
|
||||
model = TestDummyModel()
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.1)
|
||||
|
||||
state_dict_to_save = {
|
||||
"model": get_model_state_dict(model),
|
||||
"optimizer": get_optimizer_state_dict(model, optim),
|
||||
}
|
||||
DCP.save(state_dict_to_save, checkpoint_id=temp_dir)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_2 = TestDummyModel()
|
||||
# Changing the learning rate for optimizer, which is not a tensor.
|
||||
optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.2)
|
||||
|
||||
msd = get_model_state_dict(model_2)
|
||||
osd = get_optimizer_state_dict(model_2, optim_2)
|
||||
|
||||
state_dict_to_load = {"model": msd, "optimizer": osd}
|
||||
DCP.load(state_dict_to_load, checkpoint_id=temp_dir)
|
||||
|
||||
# We need to check that the two variables point to the same object in memory,
|
||||
# since we claim DCP is in-place loading.
|
||||
self.assertTrue(msd is state_dict_to_load["model"])
|
||||
self.assertTrue(osd is state_dict_to_load["optimizer"])
|
||||
|
||||
# set_state_dict calls load_state_dict for model and optimizer.
|
||||
# so we should see the optim_2.param_groups learning rate is 0.1 instead of 0.2 now.
|
||||
set_state_dict(
|
||||
model_2,
|
||||
optim_2,
|
||||
model_state_dict=state_dict_to_load["model"],
|
||||
optim_state_dict=state_dict_to_load["optimizer"],
|
||||
)
|
||||
self.assertEqual(msd, get_model_state_dict(model_2))
|
||||
self.assertEqual(osd, get_optimizer_state_dict(model_2, optim_2))
|
||||
self.assertEqual(optim_2.param_groups[0]["lr"], 0.1)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestE2ESaveAndLoad)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -30,7 +30,11 @@ from torch.distributed.checkpoint.state_dict import (
|
||||
set_optimizer_state_dict,
|
||||
StateDictOptions,
|
||||
)
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
|
||||
from torch.distributed.fsdp import (
|
||||
FullyShardedDataParallel as FSDP,
|
||||
ShardingStrategy,
|
||||
StateDictType,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.distributed.optim import _apply_optimizer_in_backward
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -67,7 +71,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
return min(4, torch.cuda.device_count())
|
||||
|
||||
def _test_save_load(
|
||||
self,
|
||||
@ -564,55 +568,71 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
|
||||
self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_broadcast_from_rank0(self) -> None:
|
||||
def inner_test(wrapper):
|
||||
model = CompositeParamModel(device=torch.device("cuda"))
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
fsdp_model = wrapper(copy.deepcopy(model))
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
|
||||
def _test_broadcast_from_rank0(self, wrapper) -> None:
|
||||
model = CompositeParamModel(device=torch.device("cuda"))
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
fsdp_model = wrapper(copy.deepcopy(model))
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
|
||||
|
||||
batch = torch.rand(8, 100, device="cuda")
|
||||
model(batch).sum().backward()
|
||||
optim.step()
|
||||
states, optim_states = get_state_dict(model, optim)
|
||||
batch = torch.rand(8, 100, device="cuda")
|
||||
model(batch).sum().backward()
|
||||
optim.step()
|
||||
states, optim_states = get_state_dict(model, optim)
|
||||
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
|
||||
def check(equal):
|
||||
fsdp_states = get_model_state_dict(
|
||||
fsdp_model,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
fsdp_optim_states = get_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
if equal:
|
||||
self.assertEqual(states, fsdp_states)
|
||||
self.assertEqual(optim_states, fsdp_optim_states)
|
||||
else:
|
||||
self.assertNotEqual(states, fsdp_states)
|
||||
self.assertNotEqual(optim_states, fsdp_optim_states)
|
||||
|
||||
check(equal=True)
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
check(equal=False)
|
||||
|
||||
# Drop the states to simulate loading from rank0
|
||||
if dist.get_rank() > 0:
|
||||
load_states = {}
|
||||
load_states2 = {}
|
||||
load_optim_states = {}
|
||||
def check(equal):
|
||||
fsdp_states = get_model_state_dict(
|
||||
fsdp_model,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
fsdp_optim_states = get_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
if equal:
|
||||
self.assertEqual(states, fsdp_states)
|
||||
self.assertEqual(optim_states, fsdp_optim_states)
|
||||
else:
|
||||
load_states = copy.deepcopy(states)
|
||||
load_states2 = copy.deepcopy(states)
|
||||
load_optim_states = copy.deepcopy(optim_states)
|
||||
self.assertNotEqual(states, fsdp_states)
|
||||
self.assertNotEqual(optim_states, fsdp_optim_states)
|
||||
|
||||
check(equal=True)
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
check(equal=False)
|
||||
|
||||
# Drop the states to simulate loading from rank0
|
||||
if dist.get_rank() > 0:
|
||||
load_states = {}
|
||||
load_states2 = {}
|
||||
load_optim_states = {}
|
||||
else:
|
||||
load_states = copy.deepcopy(states)
|
||||
load_states2 = copy.deepcopy(states)
|
||||
load_optim_states = copy.deepcopy(optim_states)
|
||||
|
||||
set_model_state_dict(
|
||||
fsdp_model,
|
||||
model_state_dict=load_states,
|
||||
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
|
||||
)
|
||||
set_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
optim_state_dict=load_optim_states,
|
||||
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
|
||||
)
|
||||
|
||||
check(equal=True)
|
||||
# Verify the `strict` flag.
|
||||
load_states = load_states2
|
||||
if load_states:
|
||||
key = next(iter(load_states.keys()))
|
||||
load_states.pop(key)
|
||||
with self.assertRaisesRegex(RuntimeError, "Missing key"):
|
||||
set_model_state_dict(
|
||||
fsdp_model,
|
||||
model_state_dict=load_states,
|
||||
@ -620,30 +640,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
broadcast_from_rank0=True, full_state_dict=True
|
||||
),
|
||||
)
|
||||
set_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
optim_state_dict=load_optim_states,
|
||||
options=StateDictOptions(
|
||||
broadcast_from_rank0=True, full_state_dict=True
|
||||
),
|
||||
)
|
||||
|
||||
check(equal=True)
|
||||
# Verify the `strict` flag.
|
||||
load_states = load_states2
|
||||
if load_states:
|
||||
key = next(iter(load_states.keys()))
|
||||
load_states.pop(key)
|
||||
with self.assertRaisesRegex(RuntimeError, "Missing key"):
|
||||
set_model_state_dict(
|
||||
fsdp_model,
|
||||
model_state_dict=load_states,
|
||||
options=StateDictOptions(
|
||||
broadcast_from_rank0=True, full_state_dict=True
|
||||
),
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_broadcast_from_rank0(self) -> None:
|
||||
device_mesh = init_device_mesh("cuda", (self.world_size,))
|
||||
self.run_subtests(
|
||||
{
|
||||
@ -652,7 +652,24 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
functools.partial(FSDP, device_mesh=device_mesh),
|
||||
]
|
||||
},
|
||||
inner_test,
|
||||
self._test_broadcast_from_rank0,
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_broadcast_from_rank0_hsdp(self) -> None:
|
||||
device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
|
||||
self.run_subtests(
|
||||
{
|
||||
"wrapper": [
|
||||
functools.partial(
|
||||
FSDP,
|
||||
device_mesh=device_mesh,
|
||||
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
|
||||
),
|
||||
]
|
||||
},
|
||||
self._test_broadcast_from_rank0,
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@ -813,6 +830,33 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
):
|
||||
get_model_state_dict(model)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_shared_weight(self):
|
||||
class TiedEmbeddingModel(nn.Module):
|
||||
def __init__(self, vocab_size, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
||||
self.decoder = nn.Linear(embedding_dim, vocab_size)
|
||||
self.decoder.weight = self.embedding.weight # Tying weights
|
||||
|
||||
def forward(self, input):
|
||||
input = (input * 10).to(torch.int)
|
||||
embedded = self.embedding(input)
|
||||
output = self.decoder(embedded)
|
||||
return output
|
||||
|
||||
def init_model_optim():
|
||||
device_mesh = init_device_mesh("cuda", (self.world_size,))
|
||||
orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
|
||||
orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
|
||||
copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
|
||||
dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
|
||||
dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
|
||||
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
|
||||
|
||||
self._test_save_load(init_model_optim)
|
||||
|
||||
|
||||
class TestNoComm(MultiProcessTestCase):
    def setUp(self) -> None:

@ -246,14 +246,15 @@ class MiscTests(torch._inductor.test_case.TestCase):
            return module.foobar(x)

        with self.assertWarnsOnceRegex(
            UserWarning, ".*https://pytorch.org/docs/main/notes/custom_operators.html.*"
            UserWarning,
            ".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
        ):
            f(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        first_graph_break = list(counters["graph_break"].keys())[0]
        self.assertExpectedInline(
            first_graph_break,
            """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/docs/main/notes/custom_operators.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
            """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
        )

        cpp_source = """

@ -10452,7 +10452,6 @@ if HAS_GPU and not TEST_WITH_ASAN:

            return kernels

        @expectedFailureXPU
        def test_divisible_by_16_covers_numel_args(self):
            torch._dynamo.reset()


@ -4534,7 +4534,7 @@ class TestLinalg(TestCase):
        try:
            import os
            os.remove(filename)
        finally:
        except FileNotFoundError:
            pass

        # disables TunableOp, no file will be written, restore to default values
@ -4545,6 +4545,90 @@ class TestLinalg(TestCase):
        assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off after resetting"
        assert torch.cuda.tunable.get_max_tuning_iterations() == 100

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_bmm_tunableop_rocm(self, device, dtype):
        # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault
        torch.cuda.tunable.enable(True)
        ordinal = torch.cuda.current_device()
        filename = f"tunableop_results{ordinal}.csv"
        torch.cuda.tunable.set_filename(filename)
        iterations = torch.cuda.tunable.get_max_tuning_iterations()
        torch.cuda.tunable.set_max_tuning_iterations(10)
        # the following 3 cases cover all previous failure cases and are here to catch regressions
        B = 16
        N = M = K = 256
        dtype = torch.bfloat16
        device = torch.device("cuda:0")
        # case 1
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        out = torch.bmm(i1, i2)
        # case 2
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 2, 0))
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 0, 2))
        out = torch.bmm(i1, i2)
        # case 3
        i1 = torch.randn((N, B, M), device=device, dtype=dtype)
        i1 = torch.permute(i1, (1, 0, 2))
        i2 = torch.randn((M, B, K), device=device, dtype=dtype)
        i2 = torch.permute(i2, (1, 2, 0))
        out = torch.bmm(i1, i2)
        # clean up, remove any file that was generated
        try:
            import os
            os.remove(filename)
        except FileNotFoundError:
            pass
        # reset back to prior settings
        torch.cuda.tunable.set_max_tuning_iterations(iterations)
        torch.cuda.tunable.enable(False)

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
        from torch.testing._internal.common_utils import CudaMemoryLeakCheck
        import os
        # run operator first without tuning to ensure all rocm libs are loaded,
        # otherwise false positive mem leak
        B = 16
        N = M = K = 256
        dtype = torch.bfloat16
        device = torch.device("cuda:0")
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
        out = torch.bmm(i1, i2)
        # enable tunableop numeric check via env variable.
        PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK"
        prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK)
        try:
            os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1"
            torch.cuda.tunable.enable(True)
            ordinal = torch.cuda.current_device()
            filename = f"tunableop_results{ordinal}.csv"
            torch.cuda.tunable.set_filename(filename)
            iterations = torch.cuda.tunable.get_max_tuning_iterations()
            torch.cuda.tunable.set_max_tuning_iterations(10)
            with CudaMemoryLeakCheck(self):
                out = torch.bmm(i1, i2)
                torch.cuda.tunable.set_max_tuning_iterations(iterations)
                torch.cuda.tunable.enable(False)
                # clean up, remove any file that was generated
                try:
                    os.remove(filename)
                except FileNotFoundError:
                    pass
        finally:
            if prev_val is None:
                del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK]
            else:
                os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val


    @dtypes(torch.float, torch.complex64)
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
        a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)

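A condensed sketch of the TunableOp workflow the two tests above follow; the filename and iteration budget are arbitrary:

import torch

torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_filename("tunableop_results0.csv")
prev_iters = torch.cuda.tunable.get_max_tuning_iterations()
torch.cuda.tunable.set_max_tuning_iterations(10)

a = torch.randn(16, 256, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(16, 256, 256, device="cuda", dtype=torch.bfloat16)
torch.bmm(a, b)  # GEMMs are tuned and the selected solutions go to the CSV

torch.cuda.tunable.set_max_tuning_iterations(prev_iters)
torch.cuda.tunable.enable(False)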
@ -1798,29 +1798,27 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
        self.assertTrue(len(w) == 0)

    def test_parameterlistdict_pickle(self):
        # warning from torch.load call in _load_from_bytes used in UntypedStorage.__reduce__
        WEIGHTS_ONLY_WARN = "You are using `torch.load` with `weights_only=False`"
        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
        with warnings.catch_warnings(record=True) as w:
            with self.assertWarnsRegex(FutureWarning, WEIGHTS_ONLY_WARN):
                m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

        # Test whether loading from older checkpoints works without triggering warnings
        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
        with warnings.catch_warnings(record=True) as w:
            with self.assertWarnsRegex(FutureWarning, WEIGHTS_ONLY_WARN):
                m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
        with warnings.catch_warnings(record=True) as w:
            with self.assertWarnsRegex(FutureWarning, WEIGHTS_ONLY_WARN):
                m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

        # Test whether loading from older checkpoints works without triggering warnings
        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
        with warnings.catch_warnings(record=True) as w:
            with self.assertWarnsRegex(FutureWarning, WEIGHTS_ONLY_WARN):
                m = pickle.loads(pickle.dumps(m))
        self.assertTrue(len(w) == 0)

    def test_weight_norm_pickle(self):
        m = torch.nn.utils.weight_norm(nn.Linear(5, 7))

@ -15,7 +15,7 @@ import pickle
import shutil
import pathlib
import platform
from collections import OrderedDict
from collections import namedtuple, OrderedDict
from copy import deepcopy
from itertools import product

@ -23,7 +23,7 @@ from torch._utils_internal import get_file_path_2
from torch._utils import _rebuild_tensor
from torch.utils._import_utils import import_dill
from torch.serialization import check_module_version_greater_or_equal, get_default_load_endianness, \
    set_default_load_endianness, LoadEndianness
    set_default_load_endianness, LoadEndianness, SourceChangeWarning

from torch.testing._internal.common_utils import (
    IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName,
@ -804,6 +804,17 @@ class serialization_method:
    def __exit__(self, *args, **kwargs):
        torch.save = self.torch_save

Point = namedtuple('Point', ['x', 'y'])

class ClassThatUsesBuildInstruction:
    def __init__(self, num):
        self.num = num

    def __reduce_ex__(self, proto):
        # Third item, state here will cause pickle to push a BUILD instruction
        return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}


@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
class TestBothSerialization(TestCase):
    @parametrize("weights_only", (True, False))
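Why the helper class above matters: a three-element return from __reduce_ex__ makes pickle emit a BUILD instruction, which is the opcode the weights_only unpickler changes later in this diff have to handle for allowlisted classes. A small standalone illustration (WithState is a stand-in, not from the diff):

import pickle
import pickletools


class WithState:
    def __init__(self, num):
        self.num = num

    def __reduce_ex__(self, proto):
        # The third item (state) forces a BUILD opcode into the pickle stream.
        return WithState, (self.num,), {"foo": "bar"}


payload = pickle.dumps(WithState(2), protocol=2)
pickletools.dis(payload)  # disassembly shows GLOBAL ... REDUCE ... BUILD, then STOP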
@ -854,7 +865,9 @@ class TestOldSerialization(TestCase, SerializationMixin):
            loaded = torch.load(checkpoint)
            self.assertTrue(isinstance(loaded, module.Net))
            if can_retrieve_source:
                self.assertEqual(len(w), 0)
                self.assertEqual(len(w), 1)
                self.assertEqual(w[0].category, FutureWarning)
                self.assertTrue("You are using `torch.load` with `weights_only=False`" in str(w[0].message))

        # Replace the module with different source
        fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
@ -865,8 +878,9 @@ class TestOldSerialization(TestCase, SerializationMixin):
            loaded = torch.load(checkpoint)
            self.assertTrue(isinstance(loaded, module.Net))
            if can_retrieve_source:
                self.assertEqual(len(w), 1)
                self.assertTrue(w[0].category, 'SourceChangeWarning')
                self.assertEqual(len(w), 2)
                self.assertEqual(w[0].category, FutureWarning)
                self.assertEqual(w[1].category, SourceChangeWarning)

    def test_serialization_container(self):
        self._test_serialization_container('file', tempfile.NamedTemporaryFile)
@ -1040,8 +1054,79 @@ class TestSerialization(TestCase, SerializationMixin):
            self.assertIsNone(torch.load(f, weights_only=False))
            f.seek(0)
            # Safe load should assert
            with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
            with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
                torch.load(f, weights_only=True)
            try:
                torch.serialization.add_safe_globals([print])
                f.seek(0)
                torch.load(f, weights_only=True)
            finally:
                torch.serialization.clear_safe_globals()

    def test_weights_only_safe_globals_newobj(self):
        # This will use NEWOBJ
        p = Point(x=1, y=2)
        with BytesIOContext() as f:
            torch.save(p, f)
            f.seek(0)
            with self.assertRaisesRegex(pickle.UnpicklingError,
                                        "GLOBAL __main__.Point was not an allowed global by default"):
                torch.load(f, weights_only=True)
            f.seek(0)
            try:
                torch.serialization.add_safe_globals([Point])
                loaded_p = torch.load(f, weights_only=True)
                self.assertEqual(loaded_p, p)
            finally:
                torch.serialization.clear_safe_globals()

    def test_weights_only_safe_globals_build(self):
        counter = 0

        def fake_set_state(obj, *args):
            nonlocal counter
            counter += 1

        c = ClassThatUsesBuildInstruction(2)
        with BytesIOContext() as f:
            torch.save(c, f)
            f.seek(0)
            with self.assertRaisesRegex(pickle.UnpicklingError,
                                        "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
                torch.load(f, weights_only=True)
            try:
                torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
                # Test dict update path
                f.seek(0)
                loaded_c = torch.load(f, weights_only=True)
                self.assertEqual(loaded_c.num, 2)
                self.assertEqual(loaded_c.foo, 'bar')
                # Test setstate path
                ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
                f.seek(0)
                loaded_c = torch.load(f, weights_only=True)
                self.assertEqual(loaded_c.num, 2)
                self.assertEqual(counter, 1)
                self.assertFalse(hasattr(loaded_c, 'foo'))
            finally:
                torch.serialization.clear_safe_globals()
                ClassThatUsesBuildInstruction.__setstate__ = None

    @parametrize("unsafe_global", [True, False])
    def test_weights_only_error(self, unsafe_global):
        sd = {'t': TwoTensor(torch.randn(2), torch.randn(2))}
        pickle_protocol = torch.serialization.DEFAULT_PROTOCOL if unsafe_global else 5
        with BytesIOContext() as f:
            torch.save(sd, f, pickle_protocol=pickle_protocol)
            f.seek(0)
            if unsafe_global:
                with self.assertRaisesRegex(pickle.UnpicklingError,
                                            r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` to allowlist"):
                    torch.load(f, weights_only=True)
            else:
                with self.assertRaisesRegex(pickle.UnpicklingError,
                                            "file an issue with the following so that we can make `weights_only=True`"):
                    torch.load(f, weights_only=True)

    @parametrize('weights_only', (False, True))
    def test_serialization_math_bits(self, weights_only):

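The allowlisting workflow these tests cover, as a rough standalone sketch (Config is a placeholder class, not part of the test suite):

import io
import pickle
import torch


class Config:
    def __init__(self, lr):
        self.lr = lr


buf = io.BytesIO()
torch.save({"cfg": Config(0.1), "w": torch.ones(2)}, buf)
buf.seek(0)
try:
    torch.load(buf, weights_only=True)      # rejected: Config is not an allowed global
except pickle.UnpicklingError:
    pass
buf.seek(0)
try:
    torch.serialization.add_safe_globals([Config])
    obj = torch.load(buf, weights_only=True)  # now permitted
finally:
    torch.serialization.clear_safe_globals()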
@ -1680,7 +1680,7 @@ err_epilogue = (
    "(and fall back to eager-mode PyTorch) on all ops "
    "that have do not have the 'pt2_compliant_tag'. "
    "Please see the following doc for how to mark this op as PT2 compliant "
    "https://pytorch.org/docs/main/notes/custom_operators.html"
    "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
)



@ -655,7 +655,7 @@ class SkipFunctionVariable(VariableTracker):
                f"so the PyTorch team can add support for it and see the next case for a workaround. "
                f"If it is a third-party C/C++ Python extension, please "
                f"either wrap it into a PyTorch-understood custom operator "
                f"(see https://pytorch.org/docs/main/notes/custom_operators.html "
                f"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
                f"for more details) or, if it is traceable, use "
                f"torch.compiler.allow_in_graph."
            )

@ -827,6 +827,7 @@ def fx_codegen_and_compile(
            else:
                output_strides.append(None)

        _check_triton_bf16_support(graph)
        compiled_fn = graph.compile_to_fn()
        num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
        metrics.num_bytes_accessed += num_bytes
@ -1596,3 +1597,34 @@ def handle_dynamo_export_graph(
        return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))

    return wrapper


def _check_triton_bf16_support(graph: GraphLowering) -> None:
    def warn_and_skip(device) -> None:
        from torch._dynamo.exc import SkipFrame

        device_props = torch.cuda.get_device_properties(device)
        warnings.warn(
            f"{device_props.name} does not support bfloat16 compilation natively, skipping"
        )
        raise SkipFrame("BF16 is not supported")

    for inp in graph.graph_inputs.values():
        device = getattr(inp, "get_device", lambda: torch.device("meta"))()
        if device.type != "cuda" or inp.get_dtype() != torch.bfloat16:
            continue
        # Print warning and skip frame if attempting to compile for bfloat16
        # on device without hardware support for dtype
        if torch.cuda.is_bf16_supported(including_emulation=False):
            return
        warn_and_skip(device)

    for out in graph.graph_outputs:
        device = getattr(out, "get_device", lambda: torch.device("meta"))()
        if device.type != "cuda" or out.get_dtype() != torch.bfloat16:
            continue
        # Print warning and skip frame if attempting to compile for bfloat16
        # on device without hardware support for dtype
        if torch.cuda.is_bf16_supported(including_emulation=False):
            return
        warn_and_skip(device)

@ -571,6 +571,9 @@ def _get_torch_related_args(include_pytorch: bool, aot_mode: bool):
        if not aot_mode:
            libraries.append("torch_python")

        if _IS_WINDOWS:
            libraries.append("sleef")

    # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690
    if not config.abi_compatible:
        libraries.append("c10")

@ -687,7 +687,7 @@ class Reduction(Loops):
        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))

        should_split = (
            get_device_type(device) == "cuda"
            is_gpu(get_device_type(device))
            and reduction_type
            not in {
                "argmax",
@ -702,9 +702,13 @@
            return ReductionHint.DEFAULT, 1

        device_interface = get_interface_for_device(get_device_type(device))
        num_sm = device_interface.Worker.get_device_properties(
            device
        ).multi_processor_count
        device_properties = device_interface.Worker.get_device_properties(device)
        if get_device_type(device) == "xpu":
            num_sm = device_properties.gpu_subslice_count
        else:
            # default is cuda behavior
            num_sm = device_properties.multi_processor_count

        min_elements_per_thread = 32
        max_elements_per_thread = 512
        threads_per_sm = 2048

@ -962,3 +962,47 @@ class CallbackRegistry(Generic[P]):
                logger.exception(
                    "Exception in callback for %s registered with gpu trace", self.name
                )


# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

IMPORT_MAPPING = {
    "__builtin__": "builtins",
    "copy_reg": "copyreg",
    "Queue": "queue",
    "repr": "reprlib",
    "_abcoll": "collections.abc",
    # Non-mutual mappings.
    "UserDict": "collections",
    "UserList": "collections",
    "UserString": "collections",
    "whichdb": "dbm",
    "StringIO": "io",
    "cStringIO": "io",
}


# This contains rename rules that are easy to handle. We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
NAME_MAPPING = {
    ("__builtin__", "xrange"): ("builtins", "range"),
    ("__builtin__", "reduce"): ("functools", "reduce"),
    ("__builtin__", "intern"): ("sys", "intern"),
    ("__builtin__", "unichr"): ("builtins", "chr"),
    ("__builtin__", "unicode"): ("builtins", "str"),
    ("__builtin__", "long"): ("builtins", "int"),
    ("itertools", "izip"): ("builtins", "zip"),
    ("itertools", "imap"): ("builtins", "map"),
    ("itertools", "ifilter"): ("builtins", "filter"),
    ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
    ("itertools", "izip_longest"): ("itertools", "zip_longest"),
    ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
    ("UserList", "UserList"): ("collections", "UserList"),
    ("UserString", "UserString"): ("collections", "UserString"),
    # Non-mutual mappings.
    ("__builtin__", "basestring"): ("builtins", "str"),
    ("exceptions", "StandardError"): ("builtins", "Exception"),
    ("UserDict", "UserDict"): ("collections", "UserDict"),
}

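Roughly how these tables are meant to be applied when a protocol-2 checkpoint references a legacy (Python 2 era) module path, mirroring the Unpickler change later in this diff:

from torch._utils import IMPORT_MAPPING, NAME_MAPPING

module, name = "__builtin__", "xrange"
if (module, name) in NAME_MAPPING:
    module, name = NAME_MAPPING[(module, name)]   # ("builtins", "range")
elif module in IMPORT_MAPPING:
    module = IMPORT_MAPPING[module]               # e.g. "copy_reg" -> "copyreg"
full_path = f"{module}.{name}"                    # "builtins.range"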
@ -23,6 +23,7 @@
# weights = torch.load(buf, weights_only = True)

import functools as _functools
import warnings
from collections import Counter, OrderedDict
from pickle import (
    APPEND,
@ -68,6 +69,8 @@ from sys import maxsize
from typing import Any, Dict, List

import torch
from torch._utils import IMPORT_MAPPING, NAME_MAPPING


_marked_safe_globals_list: List[Any] = []

@ -97,7 +100,8 @@ def _clear_safe_globals():
def _get_user_allowed_globals():
    rc: Dict[str, Any] = {}
    for f in _marked_safe_globals_list:
        rc[f"{f.__module__}.{f.__name__}"] = f
        module, name = f.__module__, f.__name__
        rc[f"{module}.{name}"] = f
    return rc


@ -170,6 +174,7 @@ class Unpickler:
        self.readline = file.readline
        self.read = file.read
        self.memo: Dict[int, Any] = {}
        self.proto: int = -1

    def load(self):
        """Read a pickled object representation from the open file.
@ -190,6 +195,13 @@
            if key[0] == GLOBAL[0]:
                module = readline()[:-1].decode("utf-8")
                name = readline()[:-1].decode("utf-8")
                # Patch since torch.save default protocol is 2
                # users will be running this code in python > 3
                if self.proto == 2:
                    if (module, name) in NAME_MAPPING:
                        module, name = NAME_MAPPING[(module, name)]
                    elif module in IMPORT_MAPPING:
                        module = IMPORT_MAPPING[module]
                full_path = f"{module}.{name}"
                if full_path in _get_allowed_globals():
                    self.append(_get_allowed_globals()[full_path])
@ -198,15 +210,18 @@
                else:
                    raise RuntimeError(
                        f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
                        "Please use `torch.serialization.add_safe_globals` to allowlist this global "
                        "if you trust this class/function."
                        f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
                        "this global if you trust this class/function."
                    )
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
                if cls is not torch.nn.Parameter:
                if cls is torch.nn.Parameter:
                    self.append(torch.nn.Parameter(*args))
                elif cls in _get_user_allowed_globals().values():
                    self.append(cls.__new__(cls, *args))
                else:
                    raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
                self.append(torch.nn.Parameter(*args))
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
@ -228,9 +243,14 @@
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    inst.__dict__.update(state)
                elif type(inst) in _get_user_allowed_globals().values():
                    if hasattr(inst, "__setstate__"):
                        inst.__setstate__(state)
                    else:
                        inst.__dict__.update(state)
                else:
                    raise RuntimeError(
                        f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
                        f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
@ -334,8 +354,14 @@ class Unpickler:
                self.append(decode_long(data))
            # First and last deserializer ops
            elif key[0] == PROTO[0]:
                # Read and ignore proto version
                read(1)[0]
                self.proto = read(1)[0]
                if self.proto != 2:
                    warnings.warn(
                        f"Detected pickle protocol {self.proto} in the checkpoint, which was "
                        "not the default pickle protocol used by `torch.load` (2). The weights_only "
                        "Unpickler might not support all instructions implemented by this protocol, "
                        "please file an issue for adding support if you encounter this."
                    )
            elif key[0] == STOP[0]:
                rc = self.stack.pop()
                return rc

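A sketch of when the new protocol check fires: torch.save defaults to pickle protocol 2 (torch.serialization.DEFAULT_PROTOCOL), so the warning only appears for checkpoints written with an explicitly different pickle_protocol, and such files may still hit opcodes the weights_only unpickler does not implement:

import io
import pickle
import warnings
import torch

buf = io.BytesIO()
torch.save(torch.ones(2), buf, pickle_protocol=4)
buf.seek(0)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    try:
        torch.load(buf, weights_only=True)  # warns: "Detected pickle protocol 4 ..."
    except (RuntimeError, pickle.UnpicklingError):
        pass  # unsupported protocol-4 instructions can still raise after the warning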
@ -80,7 +80,7 @@ class autocast:
            loss.backward()
            optimizer.step()

    See the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

@ -47,12 +47,14 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
      return ncclResult_t::ncclInvalidArgument;
    case torch::cuda::nccl::ncclResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    case torch::cuda::nccl::ncclResult::RemoteError:
      return ncclResult_t::ncclRemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case torch::cuda::nccl::ncclResult::InProgress:
      return ncclResult_t::ncclInProgress;
#endif
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
@ -72,12 +74,14 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
      return torch::cuda::nccl::ncclResult::InvalidArgument;
    case ncclInvalidUsage:
      return torch::cuda::nccl::ncclResult::InvalidUsage;
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    case ncclRemoteError:
      return torch::cuda::nccl::ncclResult::RemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case ncclInProgress:
      return torch::cuda::nccl::ncclResult::InProgress;
#endif
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }

@ -44,8 +44,9 @@ enum class ncclResult {
  InternalError = 3,
  InvalidArgument = 4,
  InvalidUsage = 5,
  NumResults = 6,
  InProgress = 7
  RemoteError = 6,
  InProgress = 7,
  NumResults = 8
};

/* Reduction operation selector */

@ -128,7 +128,7 @@ def is_available() -> bool:
    return torch._C._cuda_getDeviceCount() > 0


def is_bf16_supported():
def is_bf16_supported(including_emulation: bool = True):
    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
    # Check for ROCm, if true return true, no ROCM_VERSION check required,
    # since it is supported on AMD GPU archs.
@ -147,6 +147,9 @@ def is_bf16_supported():
    ):
        return True

    if not including_emulation:
        return False

    # Finally try to create a bfloat16 device.
    return _check_bf16_tensor_supported(device)


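Usage-wise, the new flag separates native support from the emulated fallback; a small sketch:

import torch

if torch.cuda.is_available():
    native_bf16 = torch.cuda.is_bf16_supported(including_emulation=False)
    bf16_usable = torch.cuda.is_bf16_supported()  # default still includes emulation
    # Inductor's _check_triton_bf16_support earlier in this diff relies on the
    # strict form to skip compiling bfloat16 graphs on devices without native support.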
@ -149,6 +149,22 @@ class FSDPParamGroup:
        # partial reduce output (only reduce-scattered but not all-reduced)
        self._partial_reduce_output: Optional[torch.Tensor] = None

        # TODO: remove this hook and hook register once 2D state dict is supported.
        def _raise_not_implemented_if_2d(*args: Any, **kwargs: Any) -> None:
            raise NotImplementedError(
                "2D state_dict is under development. Please check "
                "https://github.com/pytorch/pytorch/issues/129627 for more details."
            )

        modules_with_2d_params: Set[nn.Module] = set()
        for fsdp_param in self.fsdp_params:
            module = fsdp_param._module_info.module
            if len(fsdp_param._spmd_placements) > 1:
                modules_with_2d_params.add(module)
        for module in modules_with_2d_params:
            module.register_state_dict_pre_hook(_raise_not_implemented_if_2d)
            module._register_load_state_dict_pre_hook(_raise_not_implemented_if_2d)

    # Initialization #
    def _init_mp_dtypes(self) -> None:
        for fsdp_param in self.fsdp_params:

@ -11,7 +11,7 @@ from torch.distributed._tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed.checkpoint.planner import _Checkpointable

from torch.utils._pytree import tree_map_only
from torch.utils._pytree import tree_map_only_

from .metadata import (
    BytesStorageMetadata,
@ -295,13 +295,7 @@ def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:


def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    state_dict_assigned_storage = tree_map_only(
        torch.Tensor, lambda v: _init_meta_tensor(v), state_dict
    )
    # The inplace version of tree_map_only, tree_map_only_ doesn't seem to work.
    # So we need to temporariy update the each element in the state dict with meta tensor.
    for k in state_dict.keys():
        state_dict[k] = state_dict_assigned_storage[k]
    tree_map_only_(torch.Tensor, _init_meta_tensor, state_dict)


def _init_meta_tensor(value: Any) -> Any:

@ -151,6 +151,9 @@ class _StateDictInfo(StateDictOptions):
    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    shared_params_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    submodule_prefixes: Set[str] = field(default_factory=set)
    handle_model: bool = True
    handle_optim: bool = True
@ -284,14 +287,29 @@ def _verify_options(
    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    shared_params_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    for name, param in _iterate_valid_model_state(model):
        if isinstance(param, _EXTRA_STATE):
            continue

        fqns = _get_fqns(model, name)
        if not isinstance(param, _EXTRA_STATE):
            fqn_param_mapping[param] = fqns
        fqn = fqn_param_mapping.get(param, None)
        if fqn is not None:
            cast(Set[str], fqn_param_mapping[param]).update(fqns)
            shared_params_mapping[param] = fqn_param_mapping[param]
        else:
            # We need to do copy as _get_fqns is lru_cached
            fqn_param_mapping[param] = fqns.copy()
        for fqn in fqns:
            if not isinstance(param, _EXTRA_STATE):
                fqn_param_mapping[fqn] = param

    for param_, fqns_ in list(shared_params_mapping.items()):
        for fqn in fqns_:
            shared_params_mapping[fqn] = cast(torch.Tensor, param_)

    submodule_prefixes: Set[str] = set()
    if submodules:
        submodules = set(submodules)
@ -359,6 +377,7 @@ def _verify_options(
    return _StateDictInfo(
        **asdict(options),
        fqn_param_mapping=fqn_param_mapping,
        shared_params_mapping=shared_params_mapping,
        submodule_prefixes=submodule_prefixes,
        fsdp_context=fsdp_context,
        fsdp_modules=cast(List[nn.Module], fsdp_modules),
@ -448,7 +467,7 @@ def _get_model_state_dict(

    for key in list(state_dict.keys()):
        fqns = _get_fqns(model, key)
        assert len(fqns) == 1
        assert len(fqns) == 1, (key, fqns)
        fqn = next(iter(fqns))
        if fqn != key:
            # As we only support FSDP, DDP, and TP, the only cases are
@ -795,6 +814,19 @@ def _split_optim_state_dict(
            pg_state.append({_PARAMS: []})
            for param in param_group[_PARAMS]:
                for fqn in info.fqn_param_mapping[param]:
                    if fqn in info.shared_params_mapping:
                        in_params = False
                        for loaded_param_group in cast(
                            ListDictValueType, optim_state_dict[_PG]
                        ):
                            if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                                in_params = True
                                break
                    else:
                        in_params = True
                    if not in_params:
                        continue

                    params = pg_state[-1][_PARAMS]
                    assert isinstance(params, list)
                    params.append(fqn)
@ -803,9 +835,7 @@ def _split_optim_state_dict(
                        for loaded_param_group in cast(
                            ListDictValueType, optim_state_dict[_PG]
                        ):
                            params = loaded_param_group[_PARAMS]
                            assert isinstance(params, list)
                            if fqn in params:
                            if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                                pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1

    for param_group in cast(ListDictValueType, optim_state_dict[_PG]):

@ -341,14 +341,14 @@ def _broadcast_processed_state(
    group: Optional[dist.ProcessGroup],
) -> Dict[str, Any]:
    objects: List[Any] = [None]
    if fsdp_state.rank == 0:
    if dist.get_rank(group) == 0:
        objects[0] = tree_map_only(
            torch.Tensor,
            lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype),  # type: ignore[union-attr]
            optim_state,
        )
    dist.broadcast_object_list(objects, src=0, group=group)
    if fsdp_state.rank == 0:
    if dist.get_rank(group) == 0:
        return optim_state
    else:
        return objects[0]
@ -357,7 +357,7 @@ def _broadcast_processed_state(
def _broadcast_state(
    fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]
) -> Any:
    if fsdp_state.rank == 0:
    if dist.get_rank(group) == 0:
        if not isinstance(state, torch.Tensor) or state.dim() == 0:
            return state
        tensor = state.to(fsdp_state.compute_device)

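The motivation for switching to dist.get_rank(group): with a non-default process group, rank 0 of that group is generally not global rank 0. A hedged sketch, assuming an 8-rank job where ranks 4 through 7 form their own group:

import torch.distributed as dist

group = dist.new_group(ranks=[4, 5, 6, 7])  # every rank must participate in creation
if dist.get_rank(group) == 0:               # true on global rank 4, false on global rank 0
    pass  # this rank acts as the broadcast source for the group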
@ -571,7 +571,7 @@ def register_fake(
    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://pytorch.org/docs/main/notes/custom_operators.html
    https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

    Examples:
        >>> import torch

@ -3,6 +3,7 @@ import difflib
import functools
import os
import io
import re
import shutil
import struct
import sys
@ -166,10 +167,28 @@ def get_safe_globals() -> List[Any]:

def add_safe_globals(safe_globals: List[Any]) -> None:
    '''
    Marks the given globals as safe for ``weights_only`` load.
    Marks the given globals as safe for ``weights_only`` load. For example, functions
    added to this list can be called during unpickling, classes could be instantiated
    and have state set.

    Args:
        safe_globals (List[Any]): list of globals to mark as safe

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     torch.serialization.add_safe_globals([MyTensor])
        ...     torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #          [-0.8234,  2.0500, -0.3657]])
    '''
    _weights_only_unpickler._add_safe_globals(safe_globals)

@ -872,7 +891,7 @@ def load(
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = False,
    weights_only: Optional[bool] = None,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any
) -> Any:
@ -976,12 +995,38 @@ def load(
    """
    torch._C._log_api_usage_once("torch.load")
    UNSAFE_MESSAGE = (
        "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
        " will likely succeed, but it can result in arbitrary code execution."
        " Do it only if you get the file from a trusted source. Alternatively, to load"
        " with `weights_only` please check the recommended steps in the following error message."
        " WeightsUnpickler error: "
        "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
        "but it can result in arbitrary code execution. Do it only if you got the file from a "
        "trusted source."
    )
    DOCS_MESSAGE = (
        "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
        "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
    )

    def _get_wo_message(message: str) -> str:
        pattern = r"GLOBAL (\S+) was not an allowed global by default."
        has_unsafe_global = re.search(pattern, message) is not None
        if has_unsafe_global:
            updated_message = (
                "Weights only load failed. This file can still be loaded, to do so you have two options "
                f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
                "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
                + message
            )
        else:
            updated_message = (
                f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
                "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
                "error: " + message
            )
        return updated_message + DOCS_MESSAGE

    if weights_only is None:
        weights_only, warn_weights_only = False, True
    else:
        warn_weights_only = False

    # Add ability to force safe only weight loads via environment variable
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
        weights_only = True
@ -991,6 +1036,21 @@
            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
    else:
        if pickle_module is None:
            if warn_weights_only:
                warnings.warn(
                    "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
                    "the default pickle module implicitly. It is possible to construct malicious pickle data "
                    "which will execute arbitrary code during unpickling (See "
                    "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
                    "In a future release, the default value for `weights_only` will be flipped to `True`. This "
                    "limits the functions that could be executed during unpickling. Arbitrary objects will no "
                    "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
                    "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
                    "`weights_only=True` for any use case where you don't have full control of the loaded file. "
                    "Please open an issue on GitHub for any issues related to this experimental feature.",
                    FutureWarning,
                    stacklevel=2,
                )
            pickle_module = pickle

    # make flipping default BC-compatible
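Behavior sketch for the new default handling: omitting weights_only keeps the old semantics but now emits the FutureWarning above, while passing it explicitly opts in or out silently:

import io
import torch

buf = io.BytesIO()
torch.save({"w": torch.ones(2)}, buf)

buf.seek(0)
state = torch.load(buf)                     # loads, warns about weights_only=False
buf.seek(0)
state = torch.load(buf, weights_only=True)  # plain tensors load fine, no warning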
@ -1033,12 +1093,14 @@ def load(
                                 overall_storage=overall_storage,
                                 **pickle_load_args)
                except RuntimeError as e:
                    raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
            return _load(opened_zipfile,
                         map_location,
                         pickle_module,
                         overall_storage=overall_storage,
                         **pickle_load_args)
                    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
            return _load(
                opened_zipfile,
                map_location,
                pickle_module,
                overall_storage=overall_storage,
                **pickle_load_args,
            )
        if mmap:
            f_name = "" if not isinstance(f, str) else f"{f}, "
            raise RuntimeError("mmap can only be used with files saved with "
@ -1048,8 +1110,10 @@ def load(
            try:
                return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
            except RuntimeError as e:
                raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
                raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
        return _legacy_load(
            opened_file, map_location, pickle_module, **pickle_load_args
        )


# Register pickling support for layout instances such as

@ -961,6 +961,9 @@ def CppExtension(name, sources, *args, **kwargs):
    libraries.append('torch')
    libraries.append('torch_cpu')
    libraries.append('torch_python')
    if IS_WINDOWS:
        libraries.append("sleef")

    kwargs['libraries'] = libraries

    kwargs['language'] = 'c++'

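A minimal setup.py sketch for the extension builder touched above; the extension name and source file are placeholders:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="my_ext",
    ext_modules=[CppExtension("my_ext", ["my_ext.cpp"])],
    cmdclass={"build_ext": BuildExtension},
)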