Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-06 00:54:56 +08:00

Compare commits: 99 commits (bf/partiti... → benchmarki...)
Commits (SHA1; author and date columns were lost in extraction):
0f49e915a9, 2f1217f944, e0bf01e87b, 3b5ae0e9fc, 5f5f654a3e, 21931cbbc6,
ef4d57329b, d62a33c002, 0c00e32632, 0f56318152, 11129d9317, d2f506cae8,
857f21631d, ed348e7026, d311b79c12, e7318b863d, f6dcc45c44, e25074d462,
c381103fd7, 66f53889d5, 24980d2641, d4ab8e74f3, 1c7a70b483, 66ac724b56,
dfe0f48123, 92cebed1bd, b4fe5ca58a, 4de1b25df7, 70539308ac, e313152a33,
3b38989b5f, d23aa7e182, 5bf74753f6, 9db7bcb3fe, 476e0a643a, 473a93eb58,
35a473e364, ee4f433963, e9b97d19b1, a75e3a02be, 9603d6382d, 5fd7004dc9,
e86439ed5b, 203b0efd63, cf7451f279, f58143b945, fdc339003b, 853958f82c,
aadf9eae63, 9a66c30bdc, 1fe9842922, 28e7aa21c5, 9d04c0f352, 1d9b7dd2d1,
fe760b6636, 8e25ba6963, 08c29deb5f, 07405a6cff, dcdaef5206, abc3fdc7ac,
ab6cb85cb0, fde8f6a8b8, b82fb57b67, d64b4a91dd, ef90cc18d7, 39df901b2a,
54f1f29fed, f12ce4e36b, c6fc11af76, 855eff8e8e, 919a1a17e3, a84d8c4a1c,
cde82d25b7, 4d8f3d537a, e79790e14b, fe082c5ffe, 3f10c9d8af, 4b39832412,
247ea229ba, 53affa273b, eaf355cb11, 241f8dc84d, 6be829535f, 555fc05868,
7359705232, 12fc06d267, 3b218e56dc, 4fd8a54a41, b367e5f6a6, fa6ca59079,
6169ca0b65, 75bbd4989c, 8c0f07f944, b8452e55bc, 5075df6fee, f472ea63bb,
cfbd99fdfd, 1ca082d9a1, 70fbd5e08c
@@ -820,16 +820,7 @@ test_inductor_torchbench_smoketest_perf() {
   done
 }
-
-test_inductor_get_core_number() {
-  if [[ "${TEST_CONFIG}" == *aarch64* ]]; then
-    echo "$(($(lscpu | grep 'Cluster(s):' | awk '{print $2}') * $(lscpu | grep 'Core(s) per cluster:' | awk '{print $4}')))"
-  else
-    echo "$(($(lscpu | grep 'Socket(s):' | awk '{print $2}') * $(lscpu | grep 'Core(s) per socket:' | awk '{print $4}')))"
-  fi
-}
-
 test_inductor_set_cpu_affinity(){
   #set jemalloc
   JEMALLOC_LIB="$(find /usr/lib -name libjemalloc.so.2)"
   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
@@ -841,14 +832,23 @@ test_inductor_set_cpu_affinity(){
     export KMP_AFFINITY=granularity=fine,compact,1,0
     export KMP_BLOCKTIME=1
   fi
-  cores=$(test_inductor_get_core_number)
-  # Set number of cores to 16 on Aarch64 for performance runs.
+
+  # Use nproc here instead of lscpu because it takes into account cgroups slice
+  cpus=$(nproc)
+  thread_per_core=$(lscpu | grep 'Thread(s) per core:' | awk '{print $4}')
+  cores=$((cpus / thread_per_core))
+
+  # Set number of cores to 16 on aarch64 for performance runs
   if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
     cores=16
   fi
   export OMP_NUM_THREADS=$cores
-  end_core=$((cores-1))
-  export TASKSET="taskset -c 0-$end_core"
+
+  # Handle cgroups slice start and end CPU
+  start_cpu=$(python -c 'import os; print(min(os.sched_getaffinity(0)))')
+  # Leaving one physical CPU for other tasks
+  end_cpu=$(($(python -c 'import os; print(max(os.sched_getaffinity(0)))') - thread_per_core))
+  export TASKSET="taskset -c $start_cpu-$end_cpu"
 }

 test_inductor_torchbench_cpu_smoketest_perf(){
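The new affinity logic derives the usable CPU range from the process's cgroup affinity mask instead of assuming CPUs 0..N-1 are available. A minimal C++ sketch of the same computation, using the POSIX `sched_getaffinity` call (the hard-coded `threads_per_core` is an assumption here; the script reads it from `lscpu`):

```cpp
// Sketch: compute a taskset-style CPU range from the current affinity mask,
// mirroring the shell logic above. Linux-only; illustrative, not the CI code.
#include <sched.h>
#include <cstdio>

int main() {
    cpu_set_t mask;
    if (sched_getaffinity(0, sizeof(mask), &mask) != 0) return 1;

    int start_cpu = -1, end_cpu = -1;
    for (int cpu = 0; cpu < CPU_SETSIZE; ++cpu) {
        if (CPU_ISSET(cpu, &mask)) {       // CPU belongs to our cgroup slice
            if (start_cpu < 0) start_cpu = cpu;
            end_cpu = cpu;
        }
    }
    int threads_per_core = 2;              // assumption; script queries lscpu
    end_cpu -= threads_per_core;           // leave one physical core free
    std::printf("taskset -c %d-%d\n", start_cpu, end_cpu);
    return 0;
}
```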
.github/ISSUE_TEMPLATE/release-feature-request.yml vendored (new file, 111 lines)
@@ -0,0 +1,111 @@
name: 🚀 Release highlight for proposed Feature
description: Submit a Release highlight for proposed Feature
labels: ["release-feature-request"]

body:
  - type: textarea
    attributes:
      label: Release highlight for proposed Feature
      description: >
        Example: “A torch.special module, analogous to SciPy's special module.”

  - type: input
    id: contact
    attributes:
      label: Point(s) of contact
      description: How can we get in touch with you if we need more info?
      placeholder: ex. github username
    validations:
      required: false

  - type: dropdown
    attributes:
      label: Release Mode (pytorch/pytorch features only)
      description: |
        If "out-of-tree", please include the GH repo name
      options:
        - In-tree
        - Out-of-tree
    validations:
      required: true

  - type: textarea
    attributes:
      label: Out-Of-Tree Repo
      description: >
        please include the GH repo name
    validations:
      required: false

  - type: textarea
    attributes:
      label: Description and value to the user
      description: >
        Please provide a brief description of the feature and how it will benefit the user.
    validations:
      required: false

  - type: textarea
    attributes:
      label: Link to design doc, GitHub issues, past submissions, etc
    validations:
      required: false

  - type: textarea
    attributes:
      label: What feedback adopters have provided
      description: >
        Please list users/teams that have tried the feature and provided feedback. If that feedback motivated material changes (API, doc, etc..), a quick overview of the changes and the status (planned, in progress, implemented) would be helpful as well.
    validations:
      required: false

  - type: dropdown
    attributes:
      label: Plan for documentations / tutorials
      description: |
        Select One of the following options
      options:
        - Tutorial exists
        - Will submit a PR to pytorch/tutorials
        - Will submit a PR to a repo
        - Tutorial is not needed
    validations:
      required: true

  - type: textarea
    attributes:
      label: Additional context for tutorials
      description: >
        Please provide a link for existing tutorial or link to a repo or context for why tutorial is not needed.
    validations:
      required: false

  - type: dropdown
    attributes:
      label: Marketing/Blog Coverage
      description: |
        Are you requesting feature Inclusion in the release blogs?
      options:
        - "Yes"
        - "No"
    validations:
      required: true

  - type: textarea
    attributes:
      label: Are you requesting other marketing assistance with this feature?
      description: >
        E.g. supplementary blogs, social media amplification, etc.
    validations:
      required: false

  - type: textarea
    attributes:
      label: Release Version
      description: >
        Please include release version for marketing coverage.
    validations:
      required: false

  - type: textarea
    attributes:
      label: OS / Platform / Compute Coverage
      description: >
        Please list the platforms supported by the proposed feature. If the feature supports all the platforms, write "all". Goal of this section is to clearly share if this feature works in all PyTorch configurations or is it limited to only certain platforms/configurations (e.g. CPU only, GPU only, Linux only, etc...)
    validations:
      required: false

  - type: textarea
    attributes:
      label: Testing Support (CI, test cases, etc..)
      description: >
        Please provide an overview of test coverage. This includes unit testing and integration testing, but if E2E validation testing has been done to show that the feature works for a certain set of use cases or models please mention that as well.
    validations:
      required: false
.github/ci_commit_pins/torchbench.txt vendored (2 lines changed)
@@ -1 +1 @@
-6693f5845f212d8af3513f8b8d275d5b65db9caf
+e03a63be43e33596f7f0a43b0f530353785e4a59
.github/workflows/inductor-rocm-mi300.yml vendored (8 lines changed)
@@ -38,12 +38,12 @@ jobs:
       opt_out_experiments: lf

   linux-jammy-rocm-py3_10-inductor-build:
-    name: rocm-py3.10-inductor
+    name: rocm-py3.10-inductor-mi300
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      build-environment: linux-jammy-rocm-py3.10
+      build-environment: linux-jammy-rocm-py3.10-mi300
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
       test-matrix: |
         { include: [
@@ -56,11 +56,11 @@ jobs:
     permissions:
       id-token: write
       contents: read
-    name: rocm-py3.10-inductor
+    name: rocm-py3.10-inductor-mi300
     uses: ./.github/workflows/_rocm-test.yml
     needs: linux-jammy-rocm-py3_10-inductor-build
     with:
-      build-environment: linux-jammy-rocm-py3.10
+      build-environment: linux-jammy-rocm-py3.10-mi300
       docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.test-matrix }}
     secrets: inherit
.github/workflows/periodic-rocm-mi300.yml vendored (8 lines changed)
@@ -50,12 +50,12 @@ jobs:
       curr_ref_type: ${{ github.ref_type }}

   linux-jammy-rocm-py3_10-build:
-    name: linux-jammy-rocm-py3.10
+    name: linux-jammy-rocm-py3.10-mi300
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      build-environment: linux-jammy-rocm-py3.10
+      build-environment: linux-jammy-rocm-py3.10-mi300
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
       test-matrix: |
         { include: [
@@ -69,13 +69,13 @@ jobs:
     permissions:
       id-token: write
       contents: read
-    name: linux-jammy-rocm-py3.10
+    name: linux-jammy-rocm-py3.10-mi300
     uses: ./.github/workflows/_rocm-test.yml
     needs:
       - linux-jammy-rocm-py3_10-build
       - target-determination
     with:
-      build-environment: linux-jammy-rocm-py3.10
+      build-environment: linux-jammy-rocm-py3.10-mi300
       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
     secrets: inherit
.github/workflows/rocm-mi300.yml vendored (8 lines changed)
@@ -38,12 +38,12 @@ jobs:

   linux-jammy-rocm-py3_10-build:
     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
-    name: linux-jammy-rocm-py3.10
+    name: linux-jammy-rocm-py3.10-mi300
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      build-environment: linux-jammy-rocm-py3.10
+      build-environment: linux-jammy-rocm-py3.10-mi300
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
       sync-tag: rocm-build
       test-matrix: |
@@ -61,13 +61,13 @@ jobs:
     permissions:
       id-token: write
       contents: read
-    name: linux-jammy-rocm-py3.10
+    name: linux-jammy-rocm-py3.10-mi300
     uses: ./.github/workflows/_rocm-test.yml
     needs:
       - linux-jammy-rocm-py3_10-build
       - target-determination
     with:
-      build-environment: linux-jammy-rocm-py3.10
+      build-environment: linux-jammy-rocm-py3.10-mi300
       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
     secrets: inherit
@@ -1160,12 +1160,6 @@ exclude_patterns = [
     'torch/_inductor/autoheuristic/artifacts/**',
     # These files are all grandfathered in, feel free to remove from this list
     # as necessary
-    'test/_nvfuser/__init__.py',
-    'test/_nvfuser/test_dynamo.py',
-    'test/_nvfuser/test_python_frontend.py',
-    'test/_nvfuser/test_torchscript.py',
-    'test/delete.py',
-    'test/expect/__init__.py',
     'test/quantization/__init__.py',
     'test/quantization/core/__init__.py',
     'test/quantization/core/experimental/apot_fx_graph_mode_ptq.py',
@@ -1322,12 +1316,6 @@ exclude_patterns = [
     'torch/_export/passes/const_prop_pass.py',
     'torch/_export/passes/functionalize_side_effectful_ops_pass.py',
     'torch/_export/passes/replace_sym_size_ops_pass.py',
-    'torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py',
-    'torch/_export/serde/__init__.py',
-    'torch/_export/serde/schema.py',
-    'torch/_export/serde/serialize.py',
-    'torch/_export/serde/upgrade.py',
-    'torch/_export/trace.py',
     'torch/testing/_internal/__init__.py',
     'torch/testing/_internal/autocast_test_lists.py',
     'torch/testing/_internal/autograd_function_db.py',
@@ -1444,7 +1432,6 @@ exclude_patterns = [
     'torch/utils/throughput_benchmark.py',
     'torch/utils/viz/__init__.py',
     'torch/utils/viz/_cycles.py',
-    'torch/utils/weak.py',
 ]
 init_command = [
     'python3',
@@ -184,6 +184,12 @@ new_local_repository(
     path = "third_party/nlohmann",
 )

+new_local_repository(
+    name = "moodycamel",
+    build_file = "//third_party:moodycamel.BUILD",
+    path = "third_party/concurrentqueue",
+)
+
 new_local_repository(
     name = "tensorpipe",
     build_file = "//third_party:tensorpipe.BUILD",
@@ -78,7 +78,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
       return CUDA_R_64I;
     case c10::ScalarType::BFloat16:
       return CUDA_R_16BF;
-#if defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300)
+#if !defined(USE_ROCM) || ROCM_VERSION >= 60300
     case c10::ScalarType::Float8_e4m3fn:
       return CUDA_R_8F_E4M3;
     case c10::ScalarType::Float8_e5m2:
@@ -139,7 +139,7 @@ void CUDAGraph::capture_end() {
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
   // cudaGraphInstantiateWithFlags
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
-#if (defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
+#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
   int version = 0;
   AT_CUDA_CHECK(cudaDriverGetVersion(&version));
   if (version < 11040) {
@@ -154,7 +154,7 @@ void CUDAGraph::capture_end() {
 #endif
   //Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
   //It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
-#if (defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
+#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
   } else {
     AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
                                                 graph_,
@@ -135,6 +135,7 @@ CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) {
 } // namespace (anonymous)

 DEFINE_DISPATCH(gemm_stub);
+DEFINE_DISPATCH(gemm_no_downcast_stub);

 void gemm(
     TransposeType transa, TransposeType transb,
@@ -452,18 +453,18 @@ void gemm(
   // for the fallback path, first compute gemm with beta = 0,
   // and then add c in full precision.
   int64_t c_size = n * m;
-  std::vector<at::BFloat16> bfloat_c(c_size, 0.f);
-  gemm_stub(
+  std::vector<float> float_c(c_size, 0.f);
+  gemm_no_downcast_stub(
       at::kCPU, at::kBFloat16,
-      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m);
+      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
   for (const auto j : c10::irange(n)) {
     for (const auto i : c10::irange(m)) {
       auto offset = j * ldc + i;
       // beta == 0 won't propagate NaN from C
       if (beta == 0.f) {
-        c[offset] = c10::convert<float>(bfloat_c[j * m + i]);
+        c[offset] = float_c[j * m + i];
       } else {
-        c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]);
+        c[offset] = beta * c[offset] + float_c[j * m + i];
       }
     }
   }
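Restating the fallback's arithmetic as a formula (just the identity the code above implements, not new behavior): gemm computes the BLAS update with the product term held in fp32 and the beta term applied afterwards, elementwise.

```latex
\[
  T_{ij} = \alpha \sum_{p} \mathrm{op}(A)_{ip}\,\mathrm{op}(B)_{pj}
  \qquad (\text{accumulated in fp32, } \beta = 0),
\]
\[
  C_{ij} \leftarrow
  \begin{cases}
    T_{ij} & \beta = 0 \quad (\text{avoids propagating NaN from } C),\\
    \beta\, C_{ij} + T_{ij} & \text{otherwise.}
  \end{cases}
\]
```

Keeping the intermediate buffer in float rather than bf16, as this change does, removes one rounding step between the product and the final store.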
@@ -29,6 +29,18 @@ using gemm_fn = void(*)(

 DECLARE_DISPATCH(gemm_fn, gemm_stub)

+using gemm_no_downcast_fn = void(*)(
+    at::ScalarType type,
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const Scalar& alpha,
+    const void *a, int64_t lda,
+    const void *b, int64_t ldb,
+    const Scalar& beta,
+    void *c, int64_t ldc);
+
+DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub)
+
 template <typename scalar_t>
 void gemm(
     TransposeType transa, TransposeType transb,
@@ -24,6 +24,7 @@
 #include <ATen/native/cpu/SerialStackImpl.h>
 #include <ATen/native/cpu/StackKernel.h>
 #include <ATen/quantized/QTensorImpl.h>
+#include <c10/core/Contiguity.h>
 #include <c10/core/GradMode.h>
 #include <c10/util/Exception.h>
 #include <c10/util/SmallVector.h>
@@ -1993,11 +1994,15 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
     TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
   }

-  if (self.is_contiguous() && !self.is_mkldnn()) {
+  auto sym_sizes = self.sym_sizes();
+  auto sym_strides = self.sym_strides();
+  auto sym_numel = self.sym_numel();
+  if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) &&
+      !self.is_mkldnn()) {
     return self.view_symint(proposed_shape);
   }

-  c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
+  c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel);

   if (self.is_mkldnn()) {
     return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
@@ -2005,8 +2010,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {

   // `computeStride` returns the proper strides to use if this
   // `reshape` can be just a view.
-  auto stride =
-      at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
+  auto stride = at::detail::computeStride(sym_sizes, sym_strides, shape);

   // NB: Even though we have viewable geometry and the target strides here,
   // we do not just call `as_strided` on `self` because the backward
@@ -99,7 +99,7 @@ auto sum(int64_t N, Func f) {
   return partial_sums[0];
 }

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
 gemm_notrans_(
     int64_t m,
@@ -111,7 +111,7 @@ gemm_notrans_(
     const scalar_t* b,
     int64_t ldb,
     opmath_t beta,
-    scalar_t* c,
+    out_t* c,
     int64_t ldc) {
   // c *= beta
   scale_(m, n, beta, c, ldc);
@@ -135,7 +135,7 @@ gemm_notrans_(
 }

 // std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
 gemm_notrans_(
     int64_t m,
@@ -147,7 +147,7 @@ gemm_notrans_(
     const scalar_t* b,
     int64_t ldb,
     opmath_t beta,
-    scalar_t* c,
+    out_t* c,
     int64_t ldc) {
   // c += alpha * (a @ b)
   for (const auto i : c10::irange(m)) {
@@ -165,7 +165,7 @@ gemm_notrans_(
   }
 }

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 void gemm_transa_(
     TransposeType transa,
     int64_t m, int64_t n, int64_t k,
@@ -173,7 +173,7 @@ void gemm_transa_(
     const scalar_t *a, int64_t lda,
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    out_t *c, int64_t ldc) {
   // c = alpha * (a.T @ b) + beta * c
   const scalar_t *a_ = a;
   for (const auto i : c10::irange(m)) {
@@ -225,6 +225,7 @@ void gemm_transb_impl(
   }
 }

+// in this case, scalar_t == opmath_t == out_t so out_t template param is not needed
 template <typename scalar_t, typename opmath_t>
 std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
 gemm_transb_(
@@ -247,7 +248,7 @@ gemm_transb_(
 }

 // std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
 gemm_transb_(
     TransposeType transb,
@@ -260,7 +261,7 @@ gemm_transb_(
     const scalar_t* b,
     int64_t ldb,
     opmath_t beta,
-    scalar_t* c,
+    out_t* c,
     int64_t ldc) {
   // We need to calculate full-precision dot products for correctness;
   // users notice error accumulation with reduced-width types (e.g.,
@@ -304,7 +305,7 @@ gemm_transb_(
   }
 }

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 void gemm_transab_(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
@@ -312,7 +313,7 @@ void gemm_transab_(
     const scalar_t *a, int64_t lda,
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    out_t *c, int64_t ldc) {
   // c = beta * c + alpha * (a.T @ b.T)
   for (const auto i : c10::irange(m)) {
     for (const auto j : c10::irange(n)) {
@@ -436,7 +437,7 @@ void gemm_transa_(
 }
 #endif // !defined(C10_MOBILE)

-template <typename scalar_t, typename opmath_t>
+template <typename scalar_t, typename opmath_t, typename out_t>
 void gemm_core_(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
@@ -444,7 +445,7 @@ void gemm_core_(
     const scalar_t *a, int64_t lda,
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
-    scalar_t *c, int64_t ldc) {
+    out_t *c, int64_t ldc) {
   if (transa == TransposeType::NoTranspose &&
       transb == TransposeType::NoTranspose) {
     return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
@@ -493,6 +494,27 @@ void cpublas_gemm_impl(
   });
 }

+void cpublas_gemm_no_downcast_impl(
+    at::ScalarType type,
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const Scalar& alpha,
+    const void *a, int64_t lda,
+    const void *b, int64_t ldb,
+    const Scalar& beta,
+    void *c, int64_t ldc) {
+  _AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_no_downcast_impl", [&]{
+    using opmath_t = at::opmath_type<scalar_t>;
+    gemm_core_(
+        transa, transb, m, n, k,
+        alpha.to<opmath_t>(),
+        static_cast<const scalar_t *>(a), lda,
+        static_cast<const scalar_t *>(b), ldb,
+        beta.to<opmath_t>(),
+        static_cast<opmath_t *>(c), ldc);
+  });
+}
+
 void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
   if (type == at::kBool) {
     auto a = _a.to<bool>();
@@ -530,6 +552,7 @@ void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t i


 REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl)
+REGISTER_DISPATCH(cpublas::gemm_no_downcast_stub, &cpublas::cpublas_gemm_no_downcast_impl)
 REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl)
 REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl)
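The new `out_t` template parameter decouples the output/accumulator type from the input type, so a bf16-by-bf16 product can be written straight into an fp32 buffer. A self-contained sketch of the idea (a naive column-major kernel, not the blocked/vectorized kernels above):

```cpp
// Minimal sketch: GEMM kernel whose output type is independent of the input
// type, so reduced-precision inputs can accumulate into fp32. Illustrative.
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename scalar_t, typename opmath_t, typename out_t>
void gemm_notrans_naive(
    int64_t m, int64_t n, int64_t k, opmath_t alpha,
    const scalar_t* a, int64_t lda,
    const scalar_t* b, int64_t ldb,
    opmath_t beta, out_t* c, int64_t ldc) {
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      opmath_t acc = 0;  // accumulate in the wider math type
      for (int64_t p = 0; p < k; ++p) {
        acc += static_cast<opmath_t>(a[p * lda + i]) *
               static_cast<opmath_t>(b[j * ldb + p]);
      }
      // Only one conversion at the very end; with out_t == opmath_t this is
      // a no-op, which is exactly what the no-downcast path relies on.
      c[j * ldc + i] = static_cast<out_t>(beta * c[j * ldc + i] + alpha * acc);
    }
  }
}

int main() {
  std::vector<float> a(4, 1.f), b(4, 2.f), c(4, 0.f);
  gemm_notrans_naive<float, float, float>(
      2, 2, 2, 1.f, a.data(), 2, b.data(), 2, 0.f, c.data(), 2);
  std::printf("%f\n", c[0]);  // prints 4.000000
}
```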
@@ -3,6 +3,7 @@
 #include <ATen/core/ATen_fwd.h>
 #include <ATen/core/interned_strings.h>
 #include <ATen/native/ConvUtils.h>
+#include <ATen/native/mkldnn/xpu/FusionUtils.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
 #include <ATen/native/utils/ParamUtils.h>
 #include <ATen/ops/full.h>
@@ -309,81 +310,6 @@ static at::Tensor view3d(const at::Tensor& tensor) {
   return tensor.squeeze(2);
 }

-Attr get_onednn_conv_sum_attr(
-    const Tensor& input_r,
-    const Tensor& weight_r,
-    IntArrayRef stride_,
-    IntArrayRef padding_,
-    IntArrayRef dilation_,
-    Tensor& accumu,
-    double scale,
-    Tensor& output,
-    bool& is_fused,
-    Attr attr = Attr(),
-    bool force_inplace = false) {
-  is_fused = true;
-  if (scale == 0.f)
-    return attr;
-
-  auto ndim = input_r.ndimension();
-  auto output_size = conv_dst_size(
-      ndim,
-      input_r.sizes(),
-      weight_r.sizes(),
-      padding_,
-      padding_,
-      stride_,
-      dilation_);
-  MemoryFormat mem_fmt = at::MemoryFormat::Contiguous;
-  auto input_fmt = input_r.suggest_memory_format();
-  auto input_is_cl =
-      (input_fmt == at::MemoryFormat::ChannelsLast ||
-       input_fmt == at::MemoryFormat::ChannelsLast3d);
-  auto weight_fmt = weight_r.suggest_memory_format();
-  auto weight_is_cl =
-      (weight_fmt == at::MemoryFormat::ChannelsLast ||
-       weight_fmt == at::MemoryFormat::ChannelsLast3d);
-
-  bool propagate_channels_last = input_is_cl || weight_is_cl;
-  if (propagate_channels_last)
-    mem_fmt = get_cl_tag_by_ndim(ndim);
-
-  Tensor out = at::empty(output_size, input_r.options().memory_format(mem_fmt));
-  if (!onednn::binary_valid(out, accumu)) {
-    is_fused = false;
-    return attr;
-  }
-
-  // For post-sum and post-binary-add, onednn needs sum/binary scale=1.f
-  // Thus we need the following transformation
-  // conv(src, wei) + scale * accumu
-  // scale * (1/scale * conv(src, wei) + sum (or binary))
-  if (scale != 1.f)
-    attr.append_post_eltwise(
-        /* scale */ 1.f,
-        /* alpha */ 1.f / scale,
-        /* beta */ 0.f,
-        attr.kind_with_linear);
-
-  if (force_inplace) {
-    // If sizes are the same, post sum is used.
-    output = accumu;
-    attr.append_post_sum(/* sum_scale */ 1.f);
-  } else {
-    // If sizes are different, post binary is used.
-    attr.append_post_binary(attr.kind_with_binary_add, accumu);
-  }
-
-  if (scale != 1.f)
-    attr.append_post_eltwise(
-        /* scale */ 1.f,
-        /* alpha */ scale,
-        /* beta */ 0.f,
-        attr.kind_with_linear);
-
-  return attr;
-}
-
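The scale rewrite in the removed helper rests on a simple identity, needed because oneDNN's post-sum and post-binary-add require a scale of 1 (restated here from the comment above, not new behavior):

```latex
\[
  \mathrm{conv}(src, wei) + s \cdot accumu
  \;=\;
  s \cdot \Bigl( \tfrac{1}{s} \cdot \mathrm{conv}(src, wei) + accumu \Bigr),
  \qquad s \neq 0,
\]
```

so the attribute chain appends a linear post-op with alpha = 1/s before the sum or binary add, and another with alpha = s after it.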
 } // namespace impl

 using namespace impl;
@@ -476,6 +402,8 @@ Tensor _convolution_out(
         params.output_padding,
         params.groups);
     output = at::empty(dst_tz, input.options(), mfmt);
+  } else {
+    output = output_r;
   }

   onednn::deconvolution(
@@ -518,6 +446,8 @@ Tensor _convolution_out(
         params.stride,
         params.dilation);
     output = at::empty(dst_tz, input.options(), mfmt);
+  } else {
+    output = output_r;
   }
   onednn::convolution(
       output,
@@ -751,6 +681,119 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
   return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
 }

+Tensor convolution_pointwise(
+    const Tensor& input_t,
+    const Tensor& weight_t,
+    const std::optional<Tensor>& bias_opt,
+    IntArrayRef padding,
+    IntArrayRef stride,
+    IntArrayRef dilation,
+    int64_t groups,
+    std::string_view attr,
+    torch::List<std::optional<at::Scalar>> scalars,
+    std::optional<std::string_view> algorithm) {
+  c10::DeviceGuard device_guard(input_t.device());
+  Attr att;
+  att = construct_unary_attr(att, attr, scalars, algorithm);
+  const Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
+
+  return _convolution(
+      input_t,
+      weight_t,
+      bias,
+      stride,
+      padding,
+      dilation,
+      /*transposed*/ false,
+      /*output_padding*/ {0},
+      groups,
+      att);
+}
+
+Tensor convolution_pointwise_binary(
+    const Tensor& input_t,
+    const Tensor& other_t,
+    const Tensor& weight_t,
+    const std::optional<Tensor>& bias_opt,
+    IntArrayRef padding,
+    IntArrayRef stride,
+    IntArrayRef dilation,
+    int64_t groups,
+    std::string_view binary_attr,
+    std::optional<at::Scalar> alpha,
+    std::optional<std::string_view> unary_attr,
+    torch::List<std::optional<at::Scalar>> unary_scalars,
+    std::optional<std::string_view> unary_algorithm) {
+  c10::DeviceGuard device_guard(input_t.device());
+  Tensor output;
+  Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
+  // Step1: Construct binary attr
+  Attr attr;
+  attr = construct_binary_attr(attr, binary_attr, other_t);
+  // Step2: Append unary attr
+  if (unary_attr.has_value())
+    attr = construct_unary_attr(
+        attr, unary_attr.value(), unary_scalars, unary_algorithm);
+
+  Tensor res = _convolution_out(
+      output,
+      input_t,
+      weight_t,
+      bias,
+      stride,
+      padding,
+      dilation,
+      /*transposed*/ false,
+      /*output_padding*/ {0},
+      groups,
+      attr);
+
+  // Step3: Run conv
+  return res;
+}
+
+Tensor& convolution_pointwise_binary_(
+    Tensor& other_t,
+    const Tensor& input_t,
+    const Tensor& weight_t,
+    const std::optional<Tensor>& bias_opt,
+    IntArrayRef padding,
+    IntArrayRef stride,
+    IntArrayRef dilation,
+    int64_t groups,
+    std::string_view binary_attr,
+    std::optional<at::Scalar> alpha,
+    std::optional<std::string_view> unary_attr,
+    torch::List<std::optional<at::Scalar>> unary_scalars,
+    std::optional<std::string_view> unary_algorithm) {
+  c10::DeviceGuard device_guard(input_t.device());
+  Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
+  // Step1: Construct binary attr
+  Attr attr;
+  attr = construct_binary_attr(attr, binary_attr, other_t);
+
+  // Step2: Append unary attr
+  if (unary_attr.has_value())
+    attr = construct_unary_attr(
+        attr, unary_attr.value(), unary_scalars, unary_algorithm);
+
+  _convolution_out(
+      other_t,
+      input_t,
+      weight_t,
+      bias,
+      stride,
+      padding,
+      dilation,
+      /*transposed*/ false,
+      /*output_padding*/ {0},
+      groups,
+      attr);
+
+  // Step3: Run conv
+  return other_t;
+}
+
 TORCH_LIBRARY_IMPL(aten, XPU, m) {
   m.impl("convolution_overrideable", TORCH_FN(convolution_overrideable));
   m.impl(
@@ -758,4 +801,16 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       TORCH_FN(convolution_backward_overrideable));
 }

+TORCH_LIBRARY_IMPL(mkldnn, XPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise"),
+      TORCH_FN(convolution_pointwise));
+  m.impl(
+      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"),
+      TORCH_FN(convolution_pointwise_binary));
+  m.impl(
+      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
+      TORCH_FN(convolution_pointwise_binary_));
+}
+
 } // namespace at::native::xpu
@@ -4981,7 +4981,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias
+    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS, MTIA: _reshape_alias
   # We don't need to support mkldnn since this is handled explicitly by the reshape operator.

 - func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
@@ -10236,7 +10236,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CPU, CUDA, Meta, MPS: unfold
+    CPU, CUDA, Meta, MPS, MTIA: unfold
     QuantizedCPU, QuantizedCUDA: unfold

 - func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
@@ -356,13 +356,14 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
     computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d + 1);
   }

-  return at::_sparse_coo_tensor_with_dims_and_tensors(
-      sparse_dim,
-      dense_dim,
-      computed_sizes,
-      indices,
-      values,
-      values.options().layout(kSparse),
+  return at::native::_sparse_coo_tensor_unsafe(
+      indices,
+      values,
+      computed_sizes,
+      optTypeMetaToScalarType(options.dtype_opt()),
+      options.layout_opt(),
+      options.device_opt(),
+      options.pinned_memory_opt(),
       is_coalesced);
 }
@@ -46,6 +46,7 @@
 #include <ATen/ops/_triton_multi_head_attention_native.h>
 #include <ATen/ops/_triton_scaled_dot_attention.h>
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/linear.h>
 #include <ATen/ops/narrow_native.h>
@@ -963,33 +964,98 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     std::optional<double> scale) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
   // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
   // Key   -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
   // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  Tensor q_t = query.transpose(1, 2);
-  Tensor k_t = key.transpose(1, 2);
-  Tensor v_t = value.transpose(1, 2);
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);

-  sdp::CustomMaskType custom_mask_type = is_causal
-      ? sdp::CustomMaskType::CausalFromTopLeft
-      : sdp::CustomMaskType::NoCustomMask;
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
+                "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+                "the batch size exceeds (", MAX_BATCH_SIZE, ").");
+  }
+  auto process_chunk = [&](const Tensor& q_chunk,
+                           const Tensor& k_chunk,
+                           const Tensor& v_chunk,
+                           const std::optional<Tensor>& bias_chunk)
+      -> std::tuple<Tensor, Tensor, Tensor, Tensor> {
+    Tensor q_t = q_chunk.transpose(1, 2);
+    Tensor k_t = k_chunk.transpose(1, 2);
+    Tensor v_t = v_chunk.transpose(1, 2);

-  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
-      q_t,
-      k_t,
-      v_t,
-      attn_bias,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      dropout_p,
-      static_cast<int64_t>(custom_mask_type),
-      compute_log_sumexp,
-      scale);
+    sdp::CustomMaskType custom_mask_type = is_causal
+        ? sdp::CustomMaskType::CausalFromTopLeft
+        : sdp::CustomMaskType::NoCustomMask;

-  attention = attention.transpose(1, 2);
-  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
+    auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] =
+        at::_efficient_attention_forward(
+            q_t,
+            k_t,
+            v_t,
+            bias_chunk,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            dropout_p,
+            static_cast<int64_t>(custom_mask_type),
+            compute_log_sumexp,
+            scale);
+    attention = attention.transpose(1, 2);
+
+    return std::make_tuple(std::move(attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  };
+
+  // when bs is larger than allowed maximum, process in chunks
+  if (batch_size > MAX_BATCH_SIZE) {
+    int64_t start = 0;
+    int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
+
+    Tensor query_chunk = query.slice(0, start, end);
+    Tensor key_chunk = key.slice(0, start, end);
+    Tensor value_chunk = value.slice(0, start, end);
+    std::optional<Tensor> bias_chunk;
+    if (attn_bias.has_value()) {
+      bias_chunk = attn_bias.value().slice(0, start, end);
+    }
+    auto [attn, log_sumexp, seed, offset] =
+        process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+    int dim = attn.dim();
+    std::vector<int64_t> sizes;
+    sizes.reserve(dim);
+    sizes.push_back(batch_size);
+    for (int i = 1; i < dim; i++) {
+      sizes.push_back(attn.size(i));
+    }
+    Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
+    final_attention.slice(0, start, end).copy_(attn);
+
+    for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
+      end = std::min(start + MAX_BATCH_SIZE, batch_size);
+      query_chunk = query.slice(0, start, end);
+      key_chunk = key.slice(0, start, end);
+      value_chunk = value.slice(0, start, end);
+      if (attn_bias.has_value()) {
+        bias_chunk = attn_bias.value().slice(0, start, end);
+      } else {
+        bias_chunk.reset();
+      }
+
+      auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
+          process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+      final_attention.slice(0, start, end).copy_(chunk_attn);
+    }
+
+    return std::make_tuple(std::move(final_attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  }
+  // when bs is within the allowed size, no need to chunk it
+  else {
+    return process_chunk(query, key, value, attn_bias);
+  }
 }

 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
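The pattern above, slicing an oversized batch into pieces of at most MAX_BATCH_SIZE, running the kernel per piece, and scattering results into a preallocated output, is generic. A stripped-down sketch of the same control flow (illustrative; `process` stands in for the attention call, and a plain row-major buffer stands in for a Tensor sliced along dim 0):

```cpp
// Minimal sketch of batch chunking: apply `process` to slices of at most
// max_batch rows and reassemble the results into one preallocated output.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

std::vector<float> chunked_apply(
    const std::vector<float>& input, int64_t batch, int64_t row_size,
    int64_t max_batch,
    const std::function<void(const float*, float*, int64_t)>& process) {
  std::vector<float> out(input.size());
  for (int64_t start = 0; start < batch; start += max_batch) {
    int64_t end = std::min(start + max_batch, batch);
    // Process rows [start, end) directly into the preallocated output,
    // the analogue of final_attention.slice(0, start, end).copy_(chunk).
    process(input.data() + start * row_size,
            out.data() + start * row_size,
            end - start);
  }
  return out;
}
```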
@@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015

-add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025

@@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015

-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025

@@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015

-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18140000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015

-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16220000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015

@@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000

-update_hint_regression,compile_time_instruction_count,1681000000,0.02
+update_hint_regression,compile_time_instruction_count,1700000000,0.02

-float_args,compile_time_instruction_count,449800000,0.015
+float_args,compile_time_instruction_count,452500000,0.015

@@ -54,24 +54,24 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015

-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2112000000,0.015

-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015

-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015

-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015

-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015

-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015
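Each row reads as benchmark, metric, expected value, and what appears to be a relative tolerance (an assumption based on the column values; the file's own docs are not shown here). Under that reading, a run passes when:

```latex
\[
  \frac{\lvert \text{measured} - \text{expected} \rvert}{\text{expected}} \;\le\; \text{tol},
  \qquad \text{e.g. } \text{tol} = 0.015 \text{ allows } \pm 1.5\%.
\]
```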
@@ -178,6 +178,7 @@ THIRD_PARTY_LIBS = {
     "psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
     "pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
     "pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
+    "moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
     "pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
     "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
     "ruy": ["//third-party/ruy:ruy_xplat_lib", "//third_party:ruy_lib"],
@@ -15,6 +15,7 @@ cxx_library(
         "//third_party:cpuinfo",
         "//third_party:fmt",
         "//third_party:glog",
+        "//third_party:moodycamel",
     ],
     exported_deps = [],
     compiler_flags = [
@@ -96,6 +96,7 @@ if(NOT BUILD_LIBTORCHLESS)
   endif()
   target_link_libraries(c10 PRIVATE fmt::fmt-header-only)
   target_link_libraries(c10 PRIVATE nlohmann)
+  target_link_libraries(c10 PRIVATE moodycamel)

   if(C10_USE_NUMA)
     message(STATUS "NUMA paths:")
@@ -12,24 +12,49 @@ namespace c10 {

 template <typename T>
 bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
-  bool is_contiguous = true;
   if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
-    return is_contiguous;
+    return true;
   }
-  T z = 1;
+
+  T expected_stride = 1;
   // NB: make sure we do signed arithmetic
   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
     const auto& size_d = sizes[d];
-    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
-      if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
-        z *= size_d;
-      } else {
-        is_contiguous = false;
-        break;
-      }
+    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
+      continue;
     }
+
+    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
+      return false;
+    }
+    expected_stride *= size_d;
   }
-  return is_contiguous;
+  return true;
 }

+// This function will return True if the tensor is contiguous, and False if the
+// its not or if we can't determine if it is contiguous due to unbacked symbols
+// (it could be either in that case based on the actual runtime data).
+template <typename T>
+bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
+  if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
+    return true;
+  }
+
+  T expected_stride = 1;
+  // NB: make sure we do signed arithmetic
+  for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
+    const auto& size_d = sizes[d];
+    if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
+      continue;
+    }
+
+    if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
+      return false;
+    }
+    expected_stride *= size_d;
+  }
+  return true;
+}
+
 template <typename T>
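Both functions walk the dimensions right to left and compare each stride against the expected contiguous stride. A concrete, non-symbolic version of the same walk, runnable standalone (plain int64_t instead of SymInt; illustrative only):

```cpp
// Plain-integer version of the expected-stride contiguity walk used above.
#include <cstdint>
#include <cstdio>
#include <vector>

bool is_contiguous(const std::vector<int64_t>& sizes,
                   const std::vector<int64_t>& strides) {
  int64_t expected_stride = 1;
  for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
    if (sizes[d] == 1) continue;             // size-1 dims impose no constraint
    if (strides[d] != expected_stride) return false;
    expected_stride *= sizes[d];
  }
  return true;
}

int main() {
  std::printf("%d\n", is_contiguous({2, 3, 4}, {12, 4, 1}));  // 1: row-major
  std::printf("%d\n", is_contiguous({2, 3, 4}, {1, 2, 6}));   // 0: column-major
}
```

In the symbolic versions, `TORCH_GUARD_OR_FALSE` and `TORCH_GUARD_OR_TRUE` make each comparison fall back conservatively when a size or stride is an unbacked symbol, which is why `definitely_contiguous` can return false for a tensor that turns out to be contiguous at runtime.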
@@ -188,18 +188,21 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
     return guard_bool(file, line);
   }
   virtual bool guard_or_false(const char* file, int64_t line) {
-    // No improvement for unbacked SymBools by default, replace this
-    // with a better implementation!
+    // Note: PT2 primarily uses PythonSymNodeImpl for this functionality.
+    // XLA is currently the main consumer of this fallback path since it uses
+    // ahead-of-time compilation and cannot depend on Python runtime.
     return guard_bool(file, line);
   }
   virtual bool statically_known_true(const char* file, int64_t line) {
-    // No improvement for unbacked SymBools by default, replace this
-    // with a better implementation!
+    // Note: PT2 primarily uses PythonSymNodeImpl for this functionality.
+    // XLA is currently the main consumer of this fallback path since it uses
+    // ahead-of-time compilation and cannot depend on Python runtime.
     return guard_bool(file, line);
   }
   virtual bool guard_or_true(const char* file, int64_t line) {
-    // No improvement for unbacked SymBools by default, replace this
-    // with a better implementation!
+    // Note: PT2 primarily uses PythonSymNodeImpl for this functionality.
+    // XLA is currently the main consumer of this fallback path since it uses
+    // ahead-of-time compilation and cannot depend on Python runtime.
     return guard_bool(file, line);
   }
   virtual bool expect_true(const char* file, int64_t line) {
@@ -833,8 +833,9 @@ class EventPool {

 // CUDA graphs helper
 struct PrivatePool {
-  PrivatePool(MempoolId_t id)
+  PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
       : id(std::move(id)),
+        allocator_(allocator),
         large_blocks(/*small=*/false, this),
         small_blocks(/*small=*/true, this) {}
   PrivatePool(const PrivatePool&) = delete;
@@ -855,8 +856,14 @@ struct PrivatePool {
   // distinguish private blocks by adding a "pool id" check above the stream
   // check in BlockComparator. BlockComparator is performance- critical though,
   // I'd rather not add more logic to it.
+  CUDAAllocator* allocator_;
   BlockPool large_blocks;
   BlockPool small_blocks;
+
+ public:
+  CUDAAllocator* allocator() {
+    return allocator_;
+  }
 };

 MempoolId_t BlockPool::owner_MempoolId() const {
@@ -905,9 +912,8 @@ struct MempoolIdHash {
 };

 cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
-  auto active_pool = MemPoolContext::getActiveMemPool();
-  if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) {
-    *ptr = active_pool->allocator()->raw_alloc(size);
+  if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
+    *ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
     return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
   } else {
     return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
@@ -1277,14 +1283,14 @@ class DeviceCachingAllocator {
             alloc_block(params, false, context, lock))
         // Free all non-split cached blocks and retry alloc.
         || (C10_LIKELY(captures_underway.empty()) &&
-            release_cached_blocks(context) &&
+            release_cached_blocks(context, {0, 0}) &&
             alloc_block(params, true, context, lock));
   }

   // we are about to oom, try to use existing mempools as a last resort
   if (!block_found && params.err == cudaErrorMemoryAllocation) {
     // if already trying to use a mempool, then just oom
-    auto active_pool = MemPoolContext::getActiveMemPool();
+    bool active_pool = params.pool->owner_PrivatePool;
     if (!active_pool) {
       for (MempoolId_t mempool_id : use_on_oom_pools) {
         auto tid = std::this_thread::get_id();
@@ -1671,10 +1677,10 @@ class DeviceCachingAllocator {
   }

   /** returns cached blocks to the system allocator **/
-  void emptyCache() {
+  void emptyCache(MempoolId_t mempool_id) {
     auto context = maybeGatherContext(RecordContext::ALL);
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    release_cached_blocks(context);
+    release_cached_blocks(context, mempool_id);
   }

   /** Retrieves size of largest unused block held by the memory cache **/
@@ -1992,16 +1998,10 @@ class DeviceCachingAllocator {

   /** Dump a complete snapshot of the memory held by the allocator. Potentially
    * VERY expensive. **/
-  std::vector<SegmentInfo> snapshot() {
+  std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
     std::lock_guard<std::recursive_mutex> lock(mutex);

     std::vector<Block*> all_blocks;
-    MempoolId_t mempool_id = {0, 0};
-
-    auto active_mempool = MemPoolContext::getActiveMemPool();
-    if (active_mempool) {
-      mempool_id = active_mempool->id();
-    }

     if (mempool_id.first != 0 || mempool_id.second != 0) {
       // If there is an active mempool, we find the corresponding PrivatePool
@@ -2011,7 +2011,7 @@ class DeviceCachingAllocator {
       all_blocks = get_private_pool_head_blocks(pool->second.get());
     }
   } else {
-    // When snapshot is called outside a MemPoolContext, we return
+    // When snapshot is called with non-default mempool_id, we return
     // all the blocks in the CUDACachingAllocator (as returned by
     // get_all_blocks).
     all_blocks = get_all_blocks();
@@ -2130,11 +2130,11 @@ class DeviceCachingAllocator {
     }
   }

-  void ensureExistsAndIncrefPool(MempoolId_t mempool_id) {
+  void createOrIncrefPool(MempoolId_t mempool_id, CUDAAllocator* allocator) {
     // Create a PrivatePool object if it does not exist yet
     // and increment its use_count
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    ensure_exists_and_incref_pool(mempool_id);
+    create_or_incref_pool(mempool_id, allocator);
   }

   void setUseOnOOM(MempoolId_t mempool_id) {
@@ -2150,7 +2150,7 @@ class DeviceCachingAllocator {
       MempoolId_t mempool_id,
       std::function<bool(cudaStream_t)> filter) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    ensure_exists_and_incref_pool(mempool_id);
+    create_or_incref_pool(mempool_id);
     for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
          ++it2) {
       TORCH_CHECK(
@@ -2272,21 +2272,24 @@ class DeviceCachingAllocator {
     return blocks;
   }

-  void ensure_exists_and_incref_pool(MempoolId_t mempool_id) {
+  void create_or_incref_pool(
+      MempoolId_t mempool_id,
+      CUDAAllocator* allocator = nullptr) {
     auto it = graph_pools.find(mempool_id);
     if (it == graph_pools.end()) {
       // mempool_id does not reference an existing pool.
       // Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
       // usage. use_count is initially 1, which means the pool is
-      // being used since somebody called ensureExistsAndIncrefPool.
+      // being used since somebody called createOrIncrefPool.
       graph_pools.emplace(
-          mempool_id, std::make_unique<PrivatePool>(mempool_id));
+          mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
     } else {
       // mempool_id references an existing pool, which the current CUDAGraph
       // capture or torch.cuda.use_mem_pool will
       // share. Check this pool is live (at least one other capture already
       // references it). Increment it to establish the usage.
       TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
+      TORCH_INTERNAL_ASSERT(allocator == nullptr);
       it->second->use_count++;
     }
   }
@@ -2776,7 +2779,8 @@ class DeviceCachingAllocator {
     bool in_fbcode = false;
 #endif

-    auto active_pool = MemPoolContext::getActiveMemPool();
+    bool active_pool =
+        p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
     if (set_fraction &&
         total_allocated_memory + size > allowed_memory_maximum) {
       p.err = cudaErrorMemoryAllocation;
@@ -2801,12 +2805,6 @@ class DeviceCachingAllocator {
       }
       return bool(p.block);
     } else {
-      if (active_pool && active_pool->allocator() &&
-          p.pool->owner_PrivatePool) {
-        // Ensure that active_pool and p.pool are the same
-        auto pp = get_private_pool(active_pool->id());
-        TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
-      }
       if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
         // At scope exit, acquire the lock again. This provides safety against
         // any potential exceptions in the cudaMallocMaybeCapturing function.
@@ -2926,13 +2924,9 @@ class DeviceCachingAllocator {
     return true;
   }

-  bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
-    MempoolId_t mempool_id = {0, 0};
-    auto active_mempool = MemPoolContext::getActiveMemPool();
-    if (active_mempool) {
-      mempool_id = active_mempool->id();
-    }
-
+  bool release_cached_blocks(
+      const std::shared_ptr<GatheredContext>& context,
+      MempoolId_t mempool_id) {
     if (mempool_id.first == 0 && mempool_id.second == 0) {
       // If there is no active mempool, we work on releasing *all* blocks.
@@ -3005,15 +2999,10 @@ class DeviceCachingAllocator {
         context ? context : block->context_when_segment_allocated);

     auto* pool = block->pool;
-    auto active_pool = MemPoolContext::getActiveMemPool();
-    if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
-      // Ensure that active_pool and pool are the same
-      auto pp = get_private_pool(active_pool->id());
-      TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
-
+    if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) {
       // If there is an active mempool with a given allocator,
       // we use the given allocator's delete function.
-      active_pool->allocator()->raw_delete((void*)block->ptr);
+      pool->owner_PrivatePool->allocator()->raw_delete((void*)block->ptr);
     } else {
       C10_CUDA_CHECK(cudaFree((void*)block->ptr));
     }
@@ -3589,9 +3578,9 @@ class NativeCachingAllocator : public CUDAAllocator {
     }
   }

-  void emptyCache() override {
+  void emptyCache(MempoolId_t mempool_id) override {
     for (auto& da : device_allocator)
-      da->emptyCache();
+      da->emptyCache(mempool_id);
   }

   void enable(bool value) override {
@@ -3639,7 +3628,7 @@ class NativeCachingAllocator : public CUDAAllocator {
     device_allocator[block->device]->recordStream(block, stream);
   }

-  SnapshotInfo snapshot() override {
+  SnapshotInfo snapshot(MempoolId_t mempool_id) override {
     // Set-up converter to convert timestamps from tsc to microseconds.
     auto tsc_to_ns = clock_converter.makeConverter();
     auto tsc_to_us = [=](approx_time_t t_approx) {
@@ -3657,7 +3646,7 @@ class NativeCachingAllocator : public CUDAAllocator {
     // Get the device_traces' TraceEntry lists.
     for (auto& da : device_allocator) {
       result.device_traces.emplace_back(da->trace(tsc_to_us));
-      auto snap = da->snapshot();
+      auto snap = da->snapshot(mempool_id);
       result.segments.insert(result.segments.end(), snap.begin(), snap.end());
     }

@@ -3785,11 +3774,13 @@ class NativeCachingAllocator : public CUDAAllocator {
     device_allocator[device]->resetPeakStats();
   }

-  void ensureExistsAndIncrefPool(
+  void createOrIncrefPool(
       c10::DeviceIndex device,
-      MempoolId_t mempool_id) override {
+      MempoolId_t mempool_id,
+      CUDAAllocator* allocator) override {
     assertValidDevice(device);
-    device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
+    device_allocator[device]->createOrIncrefPool(
+        std::move(mempool_id), allocator);
   }

   void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
@@ -4134,7 +4125,7 @@ MemPool::MemPool(
     id_ = {uuid_++, 0};
   }
   device_ = c10::cuda::current_device();
-  CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
+  CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
   if (use_on_oom) {
     CUDACachingAllocator::setUseOnOOM(device_, id_);
   }
@@ -4143,8 +4134,7 @@ MemPool::MemPool(
 MemPool::~MemPool() {
   TORCH_INTERNAL_ASSERT(use_count() == 1);
   CUDACachingAllocator::releasePool(device_, id_);
-  auto ctx = MemPoolContext(this);
-  c10::cuda::CUDACachingAllocator::emptyCache();
+  c10::cuda::CUDACachingAllocator::emptyCache(id_);
 }

 MempoolId_t MemPool::id() {
@@ -4170,23 +4160,4 @@ MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
   return {uuid_++, 0};
 }

-// Note that active_mempool_ is a global variable here
-// and not inside MemPoolContext class, because in windows we
-// can't use __declspec(dllexport) and __declspec(thread)
-// together: https://stackoverflow.com/a/50967977
-static thread_local MemPool* active_mempool_ = nullptr;
-
-MemPoolContext::MemPoolContext(MemPool* mempool)
-    : prev_mempool_(active_mempool_) {
-  active_mempool_ = mempool;
-}
-
-MemPoolContext::~MemPoolContext() {
-  active_mempool_ = prev_mempool_;
-}
-
-MemPool* MemPoolContext::getActiveMemPool() {
-  return active_mempool_;
-}
-
 } // namespace c10::cuda
|
||||
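Taken together, the allocator changes above replace the thread-local MemPoolContext lookup with pool ownership tracked on the blocks themselves. A minimal C++ sketch of the resulting lifecycle, assuming only the signatures shown in this diff (`pool_lifecycle_sketch`, `my_allocator`, and the id value are illustrative, not code from this PR):

// Hedged sketch of the post-change, pool-scoped API flow.
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>

void pool_lifecycle_sketch(
    c10::cuda::CUDACachingAllocator::CUDAAllocator* my_allocator) {
  c10::DeviceIndex device = c10::cuda::current_device();
  c10::cuda::MempoolId_t id = {123, 0}; // user-created pools use the (uuid, 0) form

  // Create the pool or bump its use_count; optionally attach a custom
  // allocator whose raw_delete() is used when the pool's blocks are freed.
  c10::cuda::CUDACachingAllocator::createOrIncrefPool(device, id, my_allocator);

  // ... allocations routed to the pool happen here ...

  // Tear-down no longer consults a thread-local MemPoolContext: passing the
  // id releases only this pool's cached blocks instead of emptying every pool.
  c10::cuda::CUDACachingAllocator::releasePool(device, id);
  c10::cuda::CUDACachingAllocator::emptyCache(id);
}
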
@@ -211,7 +211,7 @@ class CUDAAllocator : public Allocator {
  virtual bool initialized() = 0;
  virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
  virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
  virtual void emptyCache() = 0;
  virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
  virtual void enable(bool value) = 0;
  virtual bool isEnabled() const = 0;
  virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@@ -221,7 +221,7 @@ class CUDAAllocator : public Allocator {
      c10::DeviceIndex device) = 0;
  virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
  virtual void resetPeakStats(c10::DeviceIndex device) = 0;
  virtual SnapshotInfo snapshot() = 0;
  virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
  virtual void beginAllocateToPool(
      c10::DeviceIndex device,
      MempoolId_t mempool_id,
@@ -239,13 +239,14 @@ class CUDAAllocator : public Allocator {
        " does not yet support getPoolUseCount. "
        "If you need it, please file an issue describing your use case.");
  }
  virtual void ensureExistsAndIncrefPool(
  virtual void createOrIncrefPool(
      c10::DeviceIndex /*device*/,
      MempoolId_t /*mempool_id*/) {
      MempoolId_t /*mempool_id*/,
      CUDAAllocator* allocator = nullptr) {
    TORCH_CHECK(
        false,
        name(),
        " does not yet support ensureExistsAndIncrefPool. "
        " does not yet support createOrIncrefPool. "
        "If you need it, please file an issue describing your use case.");
  }
  virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
@@ -364,7 +365,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
  return get()->setMemoryFraction(fraction, device);
}

inline void emptyCache() {
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
  return get()->emptyCache();
  return get()->emptyCache(mempool_id);
}

@@ -401,8 +402,8 @@ inline void resetPeakStats(c10::DeviceIndex device) {
  return get()->resetPeakStats(device);
}

inline SnapshotInfo snapshot() {
  return get()->snapshot();
inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
  return get()->snapshot(mempool_id);
}

inline std::shared_ptr<AllocatorState> getCheckpointState(
@@ -475,10 +476,11 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
  return get()->releasePool(device, mempool_id);
}
inline void ensureExistsAndIncrefPool(
inline void createOrIncrefPool(
    c10::DeviceIndex device,
    MempoolId_t mempool_id) {
  get()->ensureExistsAndIncrefPool(device, mempool_id);
    MempoolId_t mempool_id,
    CUDAAllocator* allocator_ptr = nullptr) {
  get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
}
inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
  get()->setUseOnOOM(device, mempool_id);
@@ -555,26 +557,4 @@ struct C10_CUDA_API MemPool {
  c10::DeviceIndex device_;
};

// MemPoolContext holds the currently active pool and stashes the previous
// pool. On deletion it makes the previous pool active.
struct C10_CUDA_API MemPoolContext {
  MemPoolContext(MemPool* mempool);

  ~MemPoolContext();

  // getActiveMemPool() can be used to get the currently active pool.
  // For instance: in CUDACachingAllocator, we can route allocations
  // to a user provided allocator, by doing:
  //
  // auto active_pool = MemPoolContext::getActiveMemPool();
  // if (active_pool && active_pool->allocator()) {
  //   ptr = active_pool->allocator()->raw_alloc(size);
  // }
  //
  static MemPool* getActiveMemPool();

 private:
  MemPool* prev_mempool_;
};

} // namespace c10::cuda

@@ -496,7 +496,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
    // introduces performance nondeterminism.
  }

  void emptyCache() override {
  void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    for (int dev = 0; dev < device_count; dev++) {
@@ -778,7 +778,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
        cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
  }

  SnapshotInfo snapshot() override {
  SnapshotInfo snapshot(MempoolId_t mempool_id) override {
    TORCH_CHECK(
        false,
        "Calling snapshot with backend:cudaMallocAsync is not meaningful. "

c10/test/util/Semaphore_test.cpp (new file, 35 lines)
@@ -0,0 +1,35 @@
#include <c10/util/Semaphore.h>
#include <c10/util/irange.h>
#include <gtest/gtest.h>

#include <thread>

using namespace ::testing;

TEST(SemaphoreTest, TestConcurrency) {
  auto num_threads = std::thread::hardware_concurrency();
  auto num_incr = 10000;

  c10::Semaphore sem;

  std::vector<std::thread> threads;
  for ([[maybe_unused]] const auto _ : c10::irange(num_threads)) {
    threads.emplace_back([num_incr = num_incr, &sem]() {
      for ([[maybe_unused]] const auto _ : c10::irange(num_incr)) {
        sem.release();
      }
      for ([[maybe_unused]] const auto _ : c10::irange(num_incr)) {
        sem.acquire();
      }
      sem.release(num_incr);
      for ([[maybe_unused]] const auto _ : c10::irange(num_incr)) {
        sem.acquire();
      }
    });
  }

  std::for_each(
      threads.begin(), threads.end(), [](std::thread& t) { t.join(); });

  EXPECT_FALSE(sem.tryAcquire());
}
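The final EXPECT_FALSE(sem.tryAcquire()) holds by a simple counting argument: each thread performs num_incr releases, then num_incr acquires, then one release(num_incr), then num_incr more acquires, for a net change of num_incr - num_incr + num_incr - num_incr = 0. Once every thread has joined, the count is back at its initial value of zero, so a non-blocking acquire must fail.
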
@@ -289,8 +289,8 @@ class C10_API OutOfMemoryError : public Error {
  using Error::Error;
};

// Used for handling syntacitc erros in input arguments.
// They shuld turn into SytnaxError when the cross into Python
// Used for handling syntactic errors in input arguments.
// These turn into SyntaxError when they cross into Python.
class C10_API SyntaxError : public Error {
  using Error::Error;
};

c10/util/Semaphore.h (new file, 71 lines)
@@ -0,0 +1,71 @@
#pragma once

#include <version>

/*
a simple semaphore interface.
*/

// note: __cpp_lib_semaphore will not be defined in some apple platforms
// even if >= C++20.
#if __has_include(<semaphore>) && defined(__cpp_lib_semaphore) && __cpp_lib_semaphore >= 201907L
#define C10_SEMAPHORE_USE_STL
#endif

#ifdef C10_SEMAPHORE_USE_STL
#include <semaphore>
#else
// To use moodycamel semaphore, we need to include the header file
// for concurrentqueue first. Hiding implementation detail here.
#ifdef BLOCK_SIZE
#pragma push_macro("BLOCK_SIZE")
#undef BLOCK_SIZE
#include <moodycamel/concurrentqueue.h> // @manual
#pragma pop_macro("BLOCK_SIZE")
#else
#include <moodycamel/concurrentqueue.h> // @manual
#endif

#include <moodycamel/lightweightsemaphore.h> // @manual
#endif

namespace c10 {

class Semaphore {
 public:
  Semaphore(int32_t initial_count = 0) : impl_(initial_count) {}

  void release(int32_t n = 1) {
#ifdef C10_SEMAPHORE_USE_STL
    impl_.release(n);
#else
    impl_.signal(n);
#endif
  }

  void acquire() {
#ifdef C10_SEMAPHORE_USE_STL
    impl_.acquire();
#else
    impl_.wait();
#endif
  }

  bool tryAcquire() {
#ifdef C10_SEMAPHORE_USE_STL
    return impl_.try_acquire();
#else
    return impl_.tryWait();
#endif
  }

 private:
#ifdef C10_SEMAPHORE_USE_STL
  std::counting_semaphore<> impl_;
#else
  moodycamel::LightweightSemaphore impl_;
#endif
};
} // namespace c10

#undef C10_SEMAPHORE_USE_STL
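A minimal usage sketch of the new wrapper, assuming only the c10::Semaphore interface shown above (the producer/consumer setup and names are illustrative):

#include <c10/util/Semaphore.h>

#include <mutex>
#include <thread>
#include <vector>

int main() {
  c10::Semaphore items; // count starts at 0
  std::vector<int> queue;
  std::mutex mu;

  std::thread producer([&] {
    for (int i = 0; i < 8; ++i) {
      {
        std::lock_guard<std::mutex> guard(mu);
        queue.push_back(i);
      }
      items.release(); // publish one available item
    }
  });

  std::thread consumer([&] {
    for (int i = 0; i < 8; ++i) {
      items.acquire(); // block until the producer released an item
      std::lock_guard<std::mutex> guard(mu);
      // queue.front() is now guaranteed to exist and can be consumed here.
    }
  });

  producer.join();
  consumer.join();
  return 0;
}
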
@@ -36,6 +36,7 @@ def define_targets(rules):
            ":bit_cast",
            "//c10/macros",
            "@fmt",
            "@moodycamel//:moodycamel",
        ] + rules.select({
            "//c10:using_gflags": ["@com_github_gflags_gflags//:gflags"],
            "//conditions:default": [],

@@ -1154,6 +1154,7 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE)

  list(APPEND Caffe2_DEPENDENCY_LIBS tensorpipe)
  list(APPEND Caffe2_DEPENDENCY_LIBS nlohmann)
  list(APPEND Caffe2_DEPENDENCY_LIBS moodycamel)
  if(USE_CUDA)
    list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS tensorpipe_cuda)
  elseif(USE_ROCM)
@@ -1713,3 +1714,7 @@ target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_
# Include nlohmann-json
add_library(nlohmann INTERFACE IMPORTED)
include_directories(nlohmann SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/nlohmann/include)

# Include moodycamel
add_library(moodycamel INTERFACE IMPORTED)
include_directories(moodycamel SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/concurrentqueue)

@@ -2282,7 +2282,6 @@ coverage_ignore_classes = [
    "UnsynchronizedAccessError",
    # torch.cuda.memory
    "MemPool",
    "MemPoolContext",
    # torch.distributed.elastic.multiprocessing.errors
    "ChildFailedError",
    "ProcessFailure",

@@ -128,7 +128,6 @@ Memory management
   CUDAPluggableAllocator
   change_current_allocator
   MemPool
   MemPoolContext

.. currentmodule:: torch.cuda.memory

setup.py (19 lines changed)
@@ -748,6 +748,25 @@ class build_ext(setuptools.command.build_ext.build_ext):

                self.copy_file(export_lib, target_lib)

            # In ROCm on Windows case copy rocblas and hipblaslt files into
            # torch/lib/rocblas/library and torch/lib/hipblaslt/library
            use_rocm = os.environ.get("USE_ROCM")
            if use_rocm:
                rocm_dir_path = os.environ.get("ROCM_DIR")
                rocm_bin_path = os.path.join(rocm_dir_path, "bin")

                rocblas_dir = os.path.join(rocm_bin_path, "rocblas")
                target_rocblas_dir = os.path.join(target_dir, "rocblas")
                os.makedirs(target_rocblas_dir, exist_ok=True)
                self.copy_tree(rocblas_dir, target_rocblas_dir)

                hipblaslt_dir = os.path.join(rocm_bin_path, "hipblaslt")
                target_hipblaslt_dir = os.path.join(target_dir, "hipblaslt")
                os.makedirs(target_hipblaslt_dir, exist_ok=True)
                self.copy_tree(hipblaslt_dir, target_hipblaslt_dir)
            else:
                report("The specified environment variable does not exist.")

    def build_extensions(self):
        self.create_compile_commands()

test/distributed/checkpoint/test_consolidate_hf_safetensors.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# Owner(s): ["oncall: distributed_checkpointing"]

import os
import sys

import torch
import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class TestConsolidateHFSafeTensors(DTensorTestBase):
    def _create_d_tensors(self) -> None:
        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (self.world_size,)
        mesh_1d = init_device_mesh(self.device_type, mesh_shape)

        # Create local tensor with row-wise sharding
        rows_per_rank = global_tensor.shape[0] // self.world_size
        start_row = self.rank * rows_per_rank
        end_row = start_row + rows_per_rank
        local_tensor = global_tensor[start_row:end_row].clone()

        # Create DTensor with row-wise sharding
        dtensor = DTensor.from_local(
            local_tensor,
            device_mesh=mesh_1d,
            placements=[Shard(0)],
            shape=global_tensor.shape,
            stride=(4, 1),
        )

        # Create local tensor with column-wise sharding
        cols_per_rank = global_tensor.shape[1] // self.world_size
        start_col = self.rank * cols_per_rank
        end_col = start_col + cols_per_rank
        local_tensor_col = global_tensor[:, start_col:end_col].clone()

        # Create DTensor with column-wise sharding
        dtensor_col = DTensor.from_local(
            local_tensor_col,
            device_mesh=mesh_1d,
            placements=[Shard(1)],  # Column-wise sharding
            shape=global_tensor.shape,
            stride=(4, 1),
        )

        state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col}
        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp._HuggingFaceStorageWriter(
                path=self.temp_dir, save_sharded=True
            ),
        )
        dist.barrier()
        os.sync()

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_consolidate_to_one_file(self) -> None:
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        checkpoint_dir = self.temp_dir
        output_dir = os.path.join(checkpoint_dir, "consolidated")
        os.makedirs(output_dir, exist_ok=True)

        self._create_d_tensors()

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)

        if self.rank == 0:
            consolidate_safetensors_files(checkpoint_dir, output_dir)

            file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors")
            loaded_dict = safetensors.torch.load_file(file_path)
            self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"})
            self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
            self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor))
        dist.barrier()

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_consolidate_to_two_files(self):
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        checkpoint_dir = self.temp_dir
        output_dir = os.path.join(checkpoint_dir, "consolidated")
        os.makedirs(output_dir, exist_ok=True)

        self._create_d_tensors()

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)

        if self.rank == 0:
            fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2}
            consolidate_safetensors_files(
                checkpoint_dir, output_dir, fqn_to_index_mapping
            )

            file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors")
            file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors")

            loaded_dict = safetensors.torch.load_file(file1_path)
            self.assertEqual(loaded_dict.keys(), {"dtensor"})
            self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))

            loaded_dict_col = safetensors.torch.load_file(file2_path)
            self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"})
            self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor))
        dist.barrier()


if __name__ == "__main__":
    run_tests()

test/distributed/checkpoint/test_hf_safetensor_e2e.py (new file, 420 lines)
@@ -0,0 +1,420 @@
# Owner(s): ["oncall: distributed_checkpointing"]

import sys

import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import _HuggingFaceLoadPlanner
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


CHECKPOINT_DIR = "checkpoint"


class MyTestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 5)
        self.linear_2 = torch.nn.Linear(5, 1)
        self.emb = torch.nn.EmbeddingBag(5, 10)


class TestSingleRankSaveLoad(TestCase):
    @with_temp_dir
    def test_save(self) -> None:
        try:
            from safetensors.torch import load_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp._HuggingFaceStorageWriter(
                path=CHECKPOINT_DIR
            ),
        )

        state_dict_loaded = load_file(CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))

    @with_temp_dir
    def test_load(self) -> None:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        state_dict_to_load = MyTestModule().state_dict()
        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")

        dist_cp.load(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp._HuggingFaceStorageReader(
                path=CHECKPOINT_DIR
            ),
        )

        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))

    @with_temp_dir
    def test_load_into_empty_dict(self) -> None:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")

        state_dict_loaded = _load_state_dict_from_keys(
            storage_reader=dist_cp._HuggingFaceStorageReader(
                path=CHECKPOINT_DIR
            ),
        )

        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))

    @with_temp_dir
    def test_load_allowing_resize(self) -> None:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")

        state_dict_to_load = {}
        for key in state_dict_to_save.keys():
            state_dict_to_load[key] = torch.zeros(1)

        dist_cp.load(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp._HuggingFaceStorageReader(
                path=CHECKPOINT_DIR
            ),
            planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True),
        )

        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))


ONE_D_PLACEMENTS = [
    [Shard(0)],
    [Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
    ([Replicate()], [Shard(0)]),
    ([Shard(0)], [Replicate()]),
]

TWO_D_PLACEMENTS = [
    [Replicate(), Replicate()],
    [Replicate(), Shard(0)],
    [Shard(0), Replicate()],
    [Shard(0), Shard(0)],
]
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
    for p2 in TWO_D_PLACEMENTS:
        if p1 != p2:
            TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))


@instantiate_parametrized_tests
class TestDTensorReshardPlacementChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor with placement changes, without world_size or mesh_tensor changes.
"""
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_temp_dir
|
||||
def test_1d_to_1d_reshard_placement_change(self) -> None:
|
||||
try:
|
||||
import safetensors
|
||||
except ImportError:
|
||||
print("safetensors not installed")
|
||||
sys.exit(0)
|
||||
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
|
||||
for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
|
||||
original_placement, new_placement = one_d_to_one_d_placements
|
||||
|
||||
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
|
||||
mesh_shape = (self.world_size,)
|
||||
device_mesh = init_device_mesh(self.device_type, mesh_shape)
|
||||
dtensor = distribute_tensor(
|
||||
global_tensor, device_mesh, placements=original_placement
|
||||
)
|
||||
state_dict_to_save = {"dtensor": dtensor}
|
||||
|
||||
dist_cp.save(
|
||||
state_dict=state_dict_to_save,
|
||||
storage_writer=dist_cp._HuggingFaceStorageWriter(
|
||||
path=CHECKPOINT_DIR,
|
||||
save_sharded=True,
|
||||
),
|
||||
)
|
||||
|
||||
zero_dtensor = zeros(
|
||||
[4, 4], device_mesh=device_mesh, placements=new_placement
|
||||
)
|
||||
state_dict_to_load = {"dtensor": zero_dtensor}
|
||||
|
||||
dist_cp.load(
|
||||
state_dict=state_dict_to_load,
|
||||
storage_reader=dist_cp._HuggingFaceStorageReader(
|
||||
CHECKPOINT_DIR,
|
||||
),
|
||||
)
|
||||
|
||||
# materialize the whole tensor to compare with the original global_tensor
|
||||
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
|
||||
device_mesh,
|
||||
placements=[Replicate()],
|
||||
)
|
||||
self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
|
||||
|
||||
# redistribute the tensor back to its original placement for comparison.
|
||||
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
|
||||
device_mesh,
|
||||
placements=original_placement,
|
||||
)
|
||||
self.assertEqual(
|
||||
state_dict_to_save["dtensor"].to_local(),
|
||||
state_dict_to_load["dtensor"].to_local(),
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_2d_to_2d_reshard_placement_change(self) -> None:
|
||||
try:
|
||||
import safetensors
|
||||
except ImportError:
|
||||
print("safetensors not installed")
|
||||
sys.exit(0)
|
||||
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS:
|
||||
original_placement, new_placement = two_d_to_two_d_placements
|
||||
|
||||
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
|
||||
dtensor = distribute_tensor(
|
||||
global_tensor,
|
||||
mesh_2d,
|
||||
placements=original_placement,
|
||||
)
|
||||
state_dict_to_save = {"dtensor": dtensor}
|
||||
|
||||
dist_cp.save(
|
||||
state_dict=state_dict_to_save,
|
||||
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
|
||||
planner=dist_cp.DefaultSavePlanner(),
|
||||
)
|
||||
|
||||
zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
|
||||
state_dict_to_load = {"dtensor": zero_dtensor}
|
||||
|
||||
dist_cp.load(
|
||||
state_dict=state_dict_to_load,
|
||||
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
|
||||
)
|
||||
|
||||
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
|
||||
mesh_2d,
|
||||
placements=[Replicate(), Replicate()],
|
||||
)
|
||||
self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
|
||||
|
||||
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
|
||||
mesh_2d,
|
||||
placements=original_placement,
|
||||
)
|
||||
self.assertEqual(
|
||||
state_dict_to_save["dtensor"].to_local(),
|
||||
state_dict_to_load["dtensor"].to_local(),
|
||||
)
|
||||
|
||||
|
||||
class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
"""
|
||||
    Test DCP resharding for DTensor with placement and mesh_tensor changes.
"""
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_1d_to_2d_reshard_mesh_change(self) -> None:
|
||||
try:
|
||||
import safetensors
|
||||
except ImportError:
|
||||
print("safetensors not installed")
|
||||
sys.exit(0)
|
||||
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
for placements_1d in ONE_D_PLACEMENTS:
|
||||
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
|
||||
mesh_shape = (self.world_size,)
|
||||
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
|
||||
dtensor = distribute_tensor(
|
||||
global_tensor, mesh_1d, placements=placements_1d
|
||||
)
|
||||
state_dict_to_save = {"dtensor": dtensor}
|
||||
|
||||
dist_cp.save(
|
||||
state_dict=state_dict_to_save,
|
||||
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
|
||||
)
|
||||
|
||||
for placements_2d in TWO_D_PLACEMENTS:
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
|
||||
|
||||
zero_dtensor = zeros(
|
||||
[4, 4], device_mesh=mesh_2d, placements=placements_2d
|
||||
)
|
||||
state_dict_to_load = {"dtensor": zero_dtensor}
|
||||
|
||||
dist_cp.load(
|
||||
state_dict=state_dict_to_load,
|
||||
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
|
||||
planner=dist_cp.DefaultLoadPlanner(),
|
||||
)
|
||||
|
||||
                # materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load[
|
||||
"dtensor"
|
||||
].redistribute(
|
||||
mesh_2d,
|
||||
placements=[Replicate(), Replicate()],
|
||||
)
|
||||
self.assertEqual(
|
||||
global_tensor, state_dict_to_load["dtensor"].to_local()
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_2d_to_1d_reshard_mesh_change(self) -> None:
|
||||
try:
|
||||
import safetensors
|
||||
except ImportError:
|
||||
print("safetensors not installed")
|
||||
sys.exit(0)
|
||||
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
for placements_2d in TWO_D_PLACEMENTS:
|
||||
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
|
||||
dtensor = distribute_tensor(
|
||||
global_tensor, mesh_2d, placements=placements_2d
|
||||
)
|
||||
state_dict_to_save = {"dtensor": dtensor}
|
||||
|
||||
dist_cp.save(
|
||||
state_dict=state_dict_to_save,
|
||||
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
|
||||
planner=dist_cp.DefaultSavePlanner(),
|
||||
)
|
||||
|
||||
for placements_1d in ONE_D_PLACEMENTS:
|
||||
mesh_shape = (self.world_size,)
|
||||
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
|
||||
|
||||
zero_dtensor = zeros(
|
||||
[4, 4], device_mesh=mesh_1d, placements=placements_1d
|
||||
)
|
||||
state_dict_to_load = {"dtensor": zero_dtensor}
|
||||
|
||||
dist_cp.load(
|
||||
state_dict=state_dict_to_load,
|
||||
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
|
||||
planner=dist_cp.DefaultLoadPlanner(),
|
||||
)
|
||||
|
||||
                # materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load[
|
||||
"dtensor"
|
||||
].redistribute(
|
||||
mesh_1d,
|
||||
placements=[Replicate()],
|
||||
)
|
||||
self.assertEqual(
|
||||
global_tensor, state_dict_to_load["dtensor"].to_local()
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_dtensor_checkpoint_resharding_with_empty_shard(self):
|
||||
"""
|
||||
Test dtensor checkpoint resharding with dtensor containing empty shards.
|
||||
"""
|
||||
try:
|
||||
import safetensors
|
||||
except ImportError:
|
||||
print("safetensors not installed")
|
||||
sys.exit(0)
|
||||
|
||||
tensor = torch.rand(1).cuda()
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
|
||||
ref_state_dict = {"dtensor": dtensor}
|
||||
|
||||
dist_cp.save(
|
||||
state_dict=ref_state_dict,
|
||||
storage_writer=dist_cp._HuggingFaceStorageWriter(path=self.temp_dir, save_sharded=True),
|
||||
)
|
||||
|
||||
tensor = torch.rand(1).cuda()
|
||||
mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
|
||||
state_dict = {"dtensor": dtensor}
|
||||
dist_cp.load(
|
||||
state_dict=state_dict,
|
||||
storage_reader=dist_cp._HuggingFaceStorageReader(self.temp_dir),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@@ -8,10 +8,7 @@ import tempfile
from unittest.mock import MagicMock

import torch
from torch.distributed.checkpoint._hf_planner import (
    _FqnToFileMapping,
    _HuggingFaceLoadPlanner,
)
from torch.distributed.checkpoint import DefaultLoadPlanner
from torch.distributed.checkpoint._hf_storage import (
    _HuggingFaceStorageReader,
    _HuggingFaceStorageWriter,
@@ -21,24 +18,25 @@ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan, SavePlan
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    _create_write_item_for_tensor,
from torch.distributed.checkpoint.planner import (
    LoadItemType,
    LoadPlan,
    ReadItem,
    SavePlan,
)
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
from torch.distributed.checkpoint.storage import WriteResult
from torch.testing._internal.common_utils import run_tests, TestCase


class TestHfStorage(TestCase):
    def test_write_data_hf(self) -> None:
        mock_module = MagicMock()
        sys.modules["safetensors"] = mock_module
        sys.modules["huggingface_hub"] = mock_module

        mock_module = MagicMock()
        mock_module.save.return_value = b""
        sys.modules["safetensors.torch"] = mock_module
@@ -46,7 +44,7 @@ class TestHfStorage(TestCase):
        with tempfile.TemporaryDirectory() as path:
            writer = _HuggingFaceStorageWriter(
                path=path,
                fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1},
                fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2},
            )
            writer.fs = FileSystem()

@@ -59,7 +57,7 @@ class TestHfStorage(TestCase):

            save_plan = SavePlan(
                [write_item_1, write_item_2],
                storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}),
                storage_data={"fqn_to_file_mapping": {"tensor_0": 1, "tensor_1": 2}},
            )
            save_planner = DefaultSavePlanner()
            save_planner.set_up_planner(state_dict=state_dict)
@@ -76,7 +74,7 @@ class TestHfStorage(TestCase):
                ),
                size_in_bytes=tensor0.numel() * tensor0.element_size(),
                storage_data=_StorageInfo(
                    relative_path="model-00001-of-00001.safetensors",
                    relative_path="model-00001-of-00002.safetensors",
                    offset=0,
                    length=tensor0.numel() * tensor0.element_size(),
                ),
@@ -87,7 +85,68 @@ class TestHfStorage(TestCase):
                ),
                size_in_bytes=tensor1.numel() * tensor1.element_size(),
                storage_data=_StorageInfo(
                    relative_path="model-00001-of-00001.safetensors",
                    relative_path="model-00002-of-00002.safetensors",
                    offset=0,
                    length=tensor1.numel() * tensor1.element_size(),
                ),
            ),
        ]

        self.assertEqual(
            actual_write_results,
            expected_write_results,
        )

    def test_write_data_with_sharding(self) -> None:
        mock_module = MagicMock()
        mock_module.save.return_value = b""
        sys.modules["safetensors.torch"] = mock_module

        with tempfile.TemporaryDirectory() as path:
            writer = _HuggingFaceStorageWriter(
                path=path,
                save_sharded=True,
            )
            writer.fs = FileSystem()

            tensor0 = torch.rand(4)
            tensor1 = torch.rand(10)
            write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0)
            write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1)

            state_dict = {"tensor_0": tensor0, "tensor_1": tensor1}

            save_plan = SavePlan(
                [write_item_1, write_item_2],
                storage_data={"shard_index": 1},
            )
            save_planner = DefaultSavePlanner()
            save_planner.set_up_planner(state_dict=state_dict)

            write_results = writer.write_data(save_plan, save_planner)

            write_results.wait()
            actual_write_results = write_results.value()

            expected_write_results = [
                WriteResult(
                    index=MetadataIndex(
                        fqn="tensor_0", offset=torch.Size([0]), index=None
                    ),
                    size_in_bytes=tensor0.numel() * tensor0.element_size(),
                    storage_data=_StorageInfo(
                        relative_path="shard-00001-model-00001-of-00001.safetensors",
                        offset=0,
                        length=tensor0.numel() * tensor0.element_size(),
                    ),
                ),
                WriteResult(
                    index=MetadataIndex(
                        fqn="tensor_1", offset=torch.Size([0]), index=None
                    ),
                    size_in_bytes=tensor1.numel() * tensor1.element_size(),
                    storage_data=_StorageInfo(
                        relative_path="shard-00001-model-00001-of-00001.safetensors",
                        offset=0,
                        length=tensor1.numel() * tensor1.element_size(),
                    ),
@@ -100,43 +159,84 @@ class TestHfStorage(TestCase):
            )

    def test_read_data_hf(self) -> None:
        mock_module = MagicMock()
        sys.modules["safetensors"] = mock_module
        sys.modules["huggingface_hub"] = mock_module
        mock_safetensors = MagicMock()
        sys.modules["safetensors"] = mock_safetensors

        name = "tensor_0"
        tensor_0 = torch.rand(4)
        mock_module = MagicMock()
        mock_module.load.return_value = {name: tensor_0}
        sys.modules["safetensors.torch"] = mock_module
        # Create test tensors
        tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])

        # Mock the deserialize function to return our test tensors
        # The format matches what's expected in the read_data method
        mock_safetensors.deserialize.return_value = [
            ("tensor_0", {
                "data": tensor_0.numpy().tobytes(),
                "dtype": "F32",
                "shape": [4]
            }),
        ]

        with tempfile.TemporaryDirectory() as path:
            # Create the reader
            reader = _HuggingFaceStorageReader(path=path)
            reader.fs = FileSystem()
            file_name = "model-00001-of-00001"

            pathlib.Path(os.path.join(path, file_name)).touch()
            # Create test file
            file_name = "model-00001-of-00001.safetensors"
            file_path = os.path.join(path, file_name)
            pathlib.Path(file_path).touch()

            reader.set_up_storage_reader(
                Metadata(
                    state_dict_metadata={name: BytesStorageMetadata()},
                    storage_data={name: file_name},
                ),
                is_coordinator=True,
            )
            # Set up storage data with _StorageInfo objects
            storage_data = {
                "tensor_0": _StorageInfo(file_path, 0, tensor_0.numel() * tensor_0.element_size()),
            }

            read_items = _create_read_items(name, BytesStorageMetadata(), file_name)

            reader.storage_data = storage_data

            # Create target tensors that will be updated by read_data
            target_tensor_0 = torch.zeros(4)
            state_dict = {
                "tensor_0": target_tensor_0,
            }

            # Create read items for the load plan
            read_items = []
            for name, tensor in state_dict.items():
                storage_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
                dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
                read_items.append(
                    ReadItem(
                        type=LoadItemType.TENSOR,
                        storage_index=storage_index,
                        dest_index=dest_index,
                        storage_offsets=[0, 0],
                        dest_offsets=[0, 0],
                        lengths=tensor.size(),
                    )
                )

            # Create load plan and planner
            load_plan = LoadPlan(read_items)
            load_planner = _HuggingFaceLoadPlanner()
            load_planner.set_up_planner(state_dict={name: torch.rand(4)})
            load_planner = DefaultLoadPlanner()
            load_planner.set_up_planner(
                state_dict=state_dict,
                metadata=Metadata(
                    state_dict_metadata={
                        "tensor_0": TensorStorageMetadata(
                            properties=TensorProperties(dtype=torch.float32),
                            size=torch.Size([4]),
                            chunks=[ChunkStorageMetadata(offsets=[0], sizes=torch.Size([4]))])},
                    storage_data=storage_data)
            )

            read_data = reader.read_data(load_plan, load_planner)
            read_data.wait()
            # Call read_data
            future = reader.read_data(load_plan, load_planner)
            future.wait()

            loaded_tensor = load_planner.original_state_dict[name]
            self.assertEqual(loaded_tensor, tensor_0)
            # Verify results - the target tensors should now contain the values from our test tensor
            self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0))

    def test_metadata_hf(self) -> None:
    def test_write_metadata_hf(self) -> None:
        mock_module = MagicMock()
        sys.modules["huggingface_hub"] = mock_module
        with tempfile.TemporaryDirectory() as path:
@@ -160,7 +260,6 @@ class TestHfStorage(TestCase):

            writer = _HuggingFaceStorageWriter(
                path=path,
                fqn_to_index_mapping=_FqnToFileMapping({}),
            )
            writer.fs = FileSystem()
            writer.finish(
@@ -185,26 +284,16 @@ class TestHfStorage(TestCase):
                metadata = json.load(f)
            self.assertEqual(metadata, expected_metadata)

            reader = _HuggingFaceStorageReader(path=path)
            reader.fs = FileSystem()
            metadata = reader.read_metadata()
            self.assertEqual(metadata.storage_data, expected_metadata["weight_map"])

    def test_read_metadata_when_metadata_file_does_not_exist(self) -> None:
        mock_module = MagicMock()
        sys.modules["huggingface_hub"] = mock_module

    def test_read_metadata_hf(self):
        with tempfile.TemporaryDirectory() as path:
            reader = _HuggingFaceStorageReader(path=path)
            reader.fs = FileSystem()
            # there is one safetensor file, but no metadata file,
            # so we create metadata from the safetensor file
            keys = ["tensor_0", "tensor_1"]

            key = "tensor_0"
            file_name = "test.safetensors"
            with open(os.path.join(path, file_name), "wb") as f:
                # write metadata the same way it would be in safetensors file
                metadata_contents = json.dumps(
                    {"tensor_0": "value_0", "tensor_1": "value_1"}
                    {'tensor_0': {'dtype': "F32", "shape": [5, 10], "data_offsets": [0, 200]}}
                )
                metadata_bytes = metadata_contents.encode("utf-8")

@@ -216,13 +305,16 @@ class TestHfStorage(TestCase):
            self.assertEqual(
                metadata.state_dict_metadata,
                {
                    keys[0]: BytesStorageMetadata(),
                    keys[1]: BytesStorageMetadata(),
                    key: TensorStorageMetadata(
                        properties=TensorProperties(dtype=torch.float32),
                        size=torch.Size([5, 10]),
                        chunks=[ChunkStorageMetadata(offsets=[0, 0], sizes=torch.Size([5, 10]))],
                    ),
                },
            )
            self.assertEqual(
                metadata.storage_data,
                {keys[0]: file_name, keys[1]: file_name},
                {key: _StorageInfo(os.path.join(path, file_name), 0, 200, transform_descriptors=None)},
            )

@@ -109,6 +109,27 @@ class MLPModule(torch.nn.Module):
        return x


class MLPKWargModule(torch.nn.Module):
    def __init__(self, d_hid: int, layer_num):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)
        self.layer_num = layer_num

    def forward(self, x, unused_kwarg: torch.Tensor = torch.zeros(1)):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        # Test when only 1 module has extra outputs
        # TODO: handle this case later
        # if self.layer_num == 0:
        #     return x, unused_kwarg
        # else:
        #     return x
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
@@ -125,6 +146,29 @@ class MultiMLP(torch.nn.Module):
        return x


# Multi-MLP with kwargs model
class MultiMLPKwargs(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPKWargModule(d_hid, i) for i in range(n_layers)]
        )
        # For testing purpose only, this should be defined by user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x, unused_kwarg: torch.Tensor = torch.zeros(1)):
        for layer in self.layers:
            # TODO: handle this case later
            # if layer.layer_num == 0:
            #     x, _ = layer(x, unused_kwarg)
            # else:
            #     x = layer(x)
            x = layer(x)
        return x


class CustomLinearDx(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):

@@ -4,7 +4,7 @@ import copy
import logging
import tempfile

from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw
from model_registry import ModelWithKwargs, MultiMLP, MultiMLPKwargs, MultiMLPWithDw
from schedule_registry import (
    ScheduleUnbalanced,
    ScheduleVShaped,
@@ -946,6 +946,113 @@ class ScheduleTest(MultiProcContinousTest):
            ref_p = ref_submod.get_parameter(name)
            torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
        "ScheduleClass",
        [ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
    )
    def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
        stages_per_rank = 2
        n_stages = stages_per_rank * self.world_size
        full_mod = MultiMLPKwargs(d_hid, n_layers=n_stages)
        full_mod.to(self.device)

        ref_mod = copy.deepcopy(full_mod)
        x = torch.randn(batch_size, d_hid, device=self.device)
        unused_kwarg = torch.tensor([1.0], device=self.device)

        with torch.no_grad():
            y = ref_mod(x)
            # Add a small perturbation
            target = y + torch.randn(batch_size, d_hid, device=self.device)

        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Get a submodule, e.g. `layers.0` or `layers.1`
        stage_indices = [
            self.rank + i * self.world_size for i in range(stages_per_rank)
        ]
        submod_names = [f"layers.{i}" for i in stage_indices]
        stage_modules = [
            full_mod.get_submodule(submod_name) for submod_name in submod_names
        ]
        # Run reference
        for _ in range(2):
            ref_stage_modules = [
                ref_mod.get_submodule(submod_name) for submod_name in submod_names
            ]
            for stage_module in ref_stage_modules:
                stage_module.zero_grad()

            ref_mod.zero_grad()
            ref_out = ref_mod(x, unused_kwarg=unused_kwarg)
            ref_loss = loss_fn(ref_out, target)
            ref_loss.backward()

        # Create a pipeline stage to wrap that submodule
        stages = [
            PipelineStage(
                stage_module,
                stage_idx,
                n_stages,
                self.device,
            )
            for stage_module, stage_idx in zip(stage_modules, stage_indices)
        ]

        # Attach to a schedule
        num_microbatches = (
            ScheduleClass.num_microbatches
            if hasattr(ScheduleClass, "num_microbatches")
            else 2 * self.world_size
        )
        schedule = ScheduleClass(
            stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
        )

        for _ in range(2):
            # Zero gradients
            for stage_module in stage_modules:
                stage_module.zero_grad()
            if self.rank == 0:
                schedule.step(
                    x,
                    unused_kwarg=unused_kwarg.clone()
                    .unsqueeze(0)
                    .expand(num_microbatches, -1),
                )
            elif self.rank == self.world_size - 1:
                losses = []
                out = schedule.step(target=target, losses=losses)
            else:
                schedule.step()

        dist.barrier()
        # Last rank checks result
        if self.rank == self.world_size - 1:
            # Check output
            torch.testing.assert_close(out, ref_out)

            # Check loss
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Every rank checks gradients
        for stage_module, submod_name in zip(stage_modules, submod_names):
            # Get corresponding submodule from reference model
            ref_submod = ref_mod.get_submodule(submod_name)
            # Check gradients per parameter
            for name, p in stage_module.named_parameters():
                ref_p = ref_submod.get_parameter(name)
                try:
                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=5e-3)
                except AssertionError:
                    print(
                        f"Gradient test failed for {name}: {p.grad=} vs {ref_p.grad=}"
                    )
                    raise


instantiate_parametrized_tests(ScheduleTest)


@@ -56,6 +56,7 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_SANDCASTLE,
    MI300_ARCH,
    parametrize,
    retry_on_connect_failures,
@@ -286,13 +287,15 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):

        # These tests are expected to throw SIGABRT(6); adding the negative sign
        # bc the test return code is actually -6
        # But if we are in Sandcastle, `skip_but_pass_in_sandcastle` would return 0.
        TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else -signal.SIGABRT
        self.special_return_code_checks = {
            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float16.__wrapped__: TEST_NAN_ASSERT_RETURN,
            self.test_nan_assert_float32.__wrapped__: TEST_NAN_ASSERT_RETURN,
            self.test_nan_assert_float64.__wrapped__: TEST_NAN_ASSERT_RETURN,
            self.test_nan_assert_bfloat16.__wrapped__: TEST_NAN_ASSERT_RETURN,
            self.test_nan_assert_float8_e4m3fn.__wrapped__: TEST_NAN_ASSERT_RETURN,
            self.test_nan_assert_float8_e5m2.__wrapped__: TEST_NAN_ASSERT_RETURN,
        }

        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests

test/dynamo/cpython/3_13/test_complex.diff (new file, 231 lines)
@ -0,0 +1,231 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py
|
||||
index 6ff1a8ab29d..ab5bd3dab62 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_complex.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_complex.py
|
||||
@@ -1,16 +1,143 @@
|
||||
+# ======= BEGIN Dynamo patch =======
|
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
 import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+    run_tests,
+    xfailIfTorchDynamo,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
 import sys
-from test import support
-from test.support.testcase import ComplexesAreIdenticalMixin
-from test.support.numbers import (
-    VALID_UNDERSCORE_LITERALS,
-    INVALID_UNDERSCORE_LITERALS,
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
 )
 
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
+import unittest
+import sys
+from test import support
+from test.support.testcase import ComplexesAreIdenticalMixin
 from random import random
 from math import isnan, copysign
+import math
 import operator
 
+VALID_UNDERSCORE_LITERALS = [
+    '0_0_0',
+    '4_2',
+    '1_0000_0000',
+    '0b1001_0100',
+    '0xffff_ffff',
+    '0o5_7_7',
+    '1_00_00.5',
+    '1_00_00.5e5',
+    '1_00_00e5_1',
+    '1e1_0',
+    '.1_4',
+    '.1_4e1',
+    '0b_0',
+    '0x_f',
+    '0o_5',
+    '1_00_00j',
+    '1_00_00.5j',
+    '1_00_00e5_1j',
+    '.1_4j',
+    '(1_2.5+3_3j)',
+    '(.5_6j)',
+]
+INVALID_UNDERSCORE_LITERALS = [
+    # Trailing underscores:
+    '0_',
+    '42_',
+    '1.4j_',
+    '0x_',
+    '0b1_',
+    '0xf_',
+    '0o5_',
+    '0 if 1_Else 1',
+    # Underscores in the base selector:
+    '0_b0',
+    '0_xf',
+    '0_o5',
+    # Old-style octal, still disallowed:
+    '0_7',
+    '09_99',
+    # Multiple consecutive underscores:
+    '4_______2',
+    '0.1__4',
+    '0.1__4j',
+    '0b1001__0100',
+    '0xffff__ffff',
+    '0x___',
+    '0o5__77',
+    '1e1__0',
+    '1e1__0j',
+    # Underscore right before a dot:
+    '1_.4',
+    '1_.4j',
+    # Underscore right after a dot:
+    '1._4',
+    '1._4j',
+    '._5',
+    '._5j',
+    # Underscore right after a sign:
+    '1.0e+_1',
+    '1.0e+_1j',
+    # Underscore right before j:
+    '1.4_j',
+    '1.4e5_j',
+    # Underscore right before e:
+    '1_e1',
+    '1.4_e1',
+    '1.4_e1j',
+    # Underscore right after e:
+    '1e_1',
+    '1.4e_1',
+    '1.4e_1j',
+    # Complex cases with parens:
+    '(1+1.5_j_)',
+    '(1+1.5_j)',
+]
+
 INF = float("inf")
 NAN = float("nan")
 DBL_MAX = sys.float_info.max
@@ -45,7 +172,40 @@ class WithComplex:
     def __complex__(self):
         return self.value
 
-class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
+class ComplexTest(__TestCase):
+
+    def assertFloatIdentical(self, x, y):
+        """Fail unless floats x and y are identical, in the sense that:
+        (1) both x and y are nans, or
+        (2) both x and y are infinities, with the same sign, or
+        (3) both x and y are zeros, with the same sign, or
+        (4) x and y are both finite and nonzero, and x == y
+
+        """
+        msg = 'floats {!r} and {!r} are not identical'
+
+        if math.isnan(x) or math.isnan(y):
+            if math.isnan(x) and math.isnan(y):
+                return
+        elif x == y:
+            if x != 0.0:
+                return
+            # both zero; check that signs match
+            elif math.copysign(1.0, x) == math.copysign(1.0, y):
+                return
+            else:
+                msg += ': zeros have different signs'
+        self.fail(msg.format(x, y))
+
+    def assertComplexesAreIdentical(self, x, y):
+        """Fail unless complex numbers x and y have equal values and signs.
+
+        In particular, if x and y both have real (or imaginary) part
+        zero, but the zeros have different signs, this test will fail.
+
+        """
+        self.assertFloatIdentical(x.real, y.real)
+        self.assertFloatIdentical(x.imag, y.imag)
 
     def assertAlmostEqual(self, a, b):
         if isinstance(a, complex):
@@ -74,6 +234,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
             # check that relative difference < eps
             self.assertTrue(abs((x-y)/y) < eps)
 
+    def assertFloatsAreIdentical(self, x, y):
+        """assert that floats x and y are identical, in the sense that:
+        (1) both x and y are nans, or
+        (2) both x and y are infinities, with the same sign, or
+        (3) both x and y are zeros, with the same sign, or
+        (4) x and y are both finite and nonzero, and x == y
+
+        """
+        msg = 'floats {!r} and {!r} are not identical'
+
+        if isnan(x) or isnan(y):
+            if isnan(x) and isnan(y):
+                return
+        elif x == y:
+            if x != 0.0:
+                return
+            # both zero; check that signs match
+            elif copysign(1.0, x) == copysign(1.0, y):
+                return
+            else:
+                msg += ': zeros have different signs'
+        self.fail(msg.format(x, y))
+
     def assertClose(self, x, y, eps=1e-9):
         """Return true iff complexes x and y "are close"."""
         self.assertCloseAbs(x.real, y.real, eps)
@@ -855,4 +1038,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
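
The identity helpers added above are deliberately stricter than ==: NaNs count as identical to each other, and zeros must agree in sign. A minimal standalone sketch of the same predicate (floats_identical is a hypothetical name, not part of the patched file):

import math

def floats_identical(x, y):
    # NaNs are unequal under ==, but the tests treat them as identical
    if math.isnan(x) or math.isnan(y):
        return math.isnan(x) and math.isnan(y)
    if x != y:
        return False
    if x != 0.0:
        return True  # finite nonzero or infinite, and equal
    # both are zeros: == says 0.0 == -0.0, so compare signs explicitly
    return math.copysign(1.0, x) == math.copysign(1.0, y)

assert floats_identical(float("nan"), float("nan"))
assert not floats_identical(0.0, -0.0)   # equal under ==, different signs
assert floats_identical(float("inf"), float("inf"))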
test/dynamo/cpython/3_13/test_complex.py (new file, 1041 lines; file diff suppressed because it is too large)
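
Every patched test file installs the RedirectImportFinder shown above so that imports such as test.test_math resolve to the standalone copies that live next to the Dynamo tests. The same meta-path technique can be demonstrated in isolation; in this sketch, legacy_json is a made-up alias for the stdlib json module:

import importlib
import importlib.abc
import importlib.util
import sys

class AliasFinder(importlib.abc.MetaPathFinder):
    # Resolve imports of `alias` to the module named `target`.
    def __init__(self, alias, target):
        self.alias = alias
        self.target = target

    def find_spec(self, fullname, path, target=None):
        if fullname != self.alias:
            return None  # let the normal finders handle everything else
        module = importlib.import_module(self.target)
        sys.modules[fullname] = module  # alias the entry in the module cache
        return importlib.util.find_spec(self.target)

sys.meta_path.insert(0, AliasFinder("legacy_json", "json"))

import legacy_json  # served by AliasFinder
assert legacy_json.dumps({"a": 1}) == '{"a": 1}'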
test/dynamo/cpython/3_13/test_iter.diff (new file, 85 lines)
@@ -0,0 +1,85 @@
diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py
index 1b9f3cf7624..d0c68f4314c 100644
--- a/test/dynamo/cpython/3_13/test_iter.py
+++ b/test/dynamo/cpython/3_13/test_iter.py
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+    skipIfTorchDynamo,
+    run_tests,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 # Test iterators.
 
 import sys
@@ -104,7 +158,7 @@ class EmptyIterClass:
 
 # Main test suite
 
-class TestCase(unittest.TestCase):
+class TestCase(__TestCase):
 
     # Helper to check that an iterator returns a given sequence
     def check_iterator(self, it, seq, pickle=True):
@@ -635,6 +689,7 @@ class TestCase(unittest.TestCase):
             pass
 
     # Test zip()'s use of iterators.
+    @skipIfTorchDynamo("infinite loop")
    def test_builtin_zip(self):
         self.assertEqual(list(zip()), [])
         self.assertEqual(list(zip(*[])), [])
@@ -1187,4 +1242,4 @@ class TestCase(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
test/dynamo/cpython/3_13/test_iter.py (new file, 1245 lines; file diff suppressed because it is too large)
test/dynamo/cpython/3_13/test_sort.diff (new file, 101 lines)
@@ -0,0 +1,101 @@
diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py
index 2a7cfb7affa..d661ae544b9 100644
--- a/test/dynamo/cpython/3_13/test_sort.py
+++ b/test/dynamo/cpython/3_13/test_sort.py
@@ -1,3 +1,54 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 from test import support
 import random
 import unittest
@@ -39,7 +90,7 @@ def check(tag, expected, raw, compare=None):
         nerrors += 1
         return
 
-class TestBase(unittest.TestCase):
+class TestBase(__TestCase):
     def testStressfully(self):
         # Try a variety of sizes at and around powers of 2, and at powers of 10.
         sizes = [0]
@@ -151,7 +202,7 @@ class TestBase(unittest.TestCase):
             self.assertEqual(forced, native)
 #==============================================================================
 
-class TestBugs(unittest.TestCase):
+class TestBugs(__TestCase):
 
     def test_bug453523(self):
         # bug 453523 -- list.sort() crasher.
@@ -188,7 +239,7 @@ class TestBugs(unittest.TestCase):
 
 #==============================================================================
 
-class TestDecorateSortUndecorate(unittest.TestCase):
+class TestDecorateSortUndecorate(__TestCase):
 
     def test_decorated(self):
         data = 'The quick Brown fox Jumped over The lazy Dog'.split()
@@ -309,7 +360,7 @@ def check_against_PyObject_RichCompareBool(self, L):
         self.assertIs(opt, ref)
         #note: not assertEqual! We want to ensure *identical* behavior.
 
-class TestOptimizedCompares(unittest.TestCase):
+class TestOptimizedCompares(__TestCase):
     def test_safe_object_compare(self):
         heterogeneous_lists = [[0, 'foo'],
                                [0.0, 'foo'],
@@ -408,4 +459,4 @@ class TestOptimizedCompares(unittest.TestCase):
 #==============================================================================
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
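
Several of the sort tests above drive list.sort() through functools.cmp_to_key, which wraps an old-style three-way comparator in a key class so the sort only ever calls __lt__. A small self-contained illustration:

from functools import cmp_to_key

def descending(a, b):
    # old-style comparator: negative, zero, or positive, like C's strcmp
    return (b > a) - (b < a)

data = [3, 1, 4, 1, 5, 9, 2, 6]
data.sort(key=cmp_to_key(descending))
assert data == [9, 6, 5, 4, 3, 2, 1, 1]
# equivalent to the built-in reverse flag for a plain ordering
assert data == sorted([3, 1, 4, 1, 5, 9, 2, 6], reverse=True)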
test/dynamo/cpython/3_13/test_sort.py (new file, 462 lines)
@@ -0,0 +1,462 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests

__TestCase = CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

from test import support
import random
import unittest
from functools import cmp_to_key

verbose = support.verbose
nerrors = 0


def check(tag, expected, raw, compare=None):
    global nerrors

    if verbose:
        print("    checking", tag)

    orig = raw[:]   # save input in case of error
    if compare:
        raw.sort(key=cmp_to_key(compare))
    else:
        raw.sort()

    if len(expected) != len(raw):
        print("error in", tag)
        print("length mismatch;", len(expected), len(raw))
        print(expected)
        print(orig)
        print(raw)
        nerrors += 1
        return

    for i, good in enumerate(expected):
        maybe = raw[i]
        if good is not maybe:
            print("error in", tag)
            print("out of order at index", i, good, maybe)
            print(expected)
            print(orig)
            print(raw)
            nerrors += 1
            return

class TestBase(__TestCase):
    def testStressfully(self):
        # Try a variety of sizes at and around powers of 2, and at powers of 10.
        sizes = [0]
        for power in range(1, 10):
            n = 2 ** power
            sizes.extend(range(n-1, n+2))
        sizes.extend([10, 100, 1000])

        class Complains(object):
            maybe_complain = True

            def __init__(self, i):
                self.i = i

            def __lt__(self, other):
                if Complains.maybe_complain and random.random() < 0.001:
                    if verbose:
                        print("        complaining at", self, other)
                    raise RuntimeError
                return self.i < other.i

            def __repr__(self):
                return "Complains(%d)" % self.i

        class Stable(object):
            def __init__(self, key, i):
                self.key = key
                self.index = i

            def __lt__(self, other):
                return self.key < other.key

            def __repr__(self):
                return "Stable(%d, %d)" % (self.key, self.index)

        for n in sizes:
            x = list(range(n))
            if verbose:
                print("Testing size", n)

            s = x[:]
            check("identity", x, s)

            s = x[:]
            s.reverse()
            check("reversed", x, s)

            s = x[:]
            random.shuffle(s)
            check("random permutation", x, s)

            y = x[:]
            y.reverse()
            s = x[:]
            check("reversed via function", y, s, lambda a, b: (b>a)-(b<a))

            if verbose:
                print("    Checking against an insane comparison function.")
                print("        If the implementation isn't careful, this may segfault.")
            s = x[:]
            s.sort(key=cmp_to_key(lambda a, b: int(random.random() * 3) - 1))
            check("an insane function left some permutation", x, s)

            if len(x) >= 2:
                def bad_key(x):
                    raise RuntimeError
                s = x[:]
                self.assertRaises(RuntimeError, s.sort, key=bad_key)

            x = [Complains(i) for i in x]
            s = x[:]
            random.shuffle(s)
            Complains.maybe_complain = True
            it_complained = False
            try:
                s.sort()
            except RuntimeError:
                it_complained = True
            if it_complained:
                Complains.maybe_complain = False
                check("exception during sort left some permutation", x, s)

            s = [Stable(random.randrange(10), i) for i in range(n)]
            augmented = [(e, e.index) for e in s]
            augmented.sort()    # forced stable because ties broken by index
            x = [e for e, i in augmented]   # a stable sort of s
            check("stability", x, s)

    def test_small_stability(self):
        from itertools import product
        from operator import itemgetter

        # Exhaustively test stability across all lists of small lengths
        # and only a few distinct elements.
        # This can provoke edge cases that randomization is unlikely to find.
        # But it can grow very expensive quickly, so don't overdo it.
        NELTS = 3
        MAXSIZE = 9

        pick0 = itemgetter(0)
        for length in range(MAXSIZE + 1):
            # There are NELTS ** length distinct lists.
            for t in product(range(NELTS), repeat=length):
                xs = list(zip(t, range(length)))
                # Stability forced by index in each element.
                forced = sorted(xs)
                # Use key= to hide the index from compares.
                native = sorted(xs, key=pick0)
                self.assertEqual(forced, native)
#==============================================================================

class TestBugs(__TestCase):

    def test_bug453523(self):
        # bug 453523 -- list.sort() crasher.
        # If this fails, the most likely outcome is a core dump.
        # Mutations during a list sort should raise a ValueError.

        class C:
            def __lt__(self, other):
                if L and random.random() < 0.75:
                    L.pop()
                else:
                    L.append(3)
                return random.random() < 0.5

        L = [C() for i in range(50)]
        self.assertRaises(ValueError, L.sort)

    def test_undetected_mutation(self):
        # Python 2.4a1 did not always detect mutation
        memorywaster = []
        for i in range(20):
            def mutating_cmp(x, y):
                L.append(3)
                L.pop()
                return (x > y) - (x < y)
            L = [1,2]
            self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
            def mutating_cmp(x, y):
                L.append(3)
                del L[:]
                return (x > y) - (x < y)
            self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
            memorywaster = [memorywaster]

#==============================================================================

class TestDecorateSortUndecorate(__TestCase):

    def test_decorated(self):
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        copy = data[:]
        random.shuffle(data)
        data.sort(key=str.lower)
        def my_cmp(x, y):
            xlower, ylower = x.lower(), y.lower()
            return (xlower > ylower) - (xlower < ylower)
        copy.sort(key=cmp_to_key(my_cmp))

    def test_baddecorator(self):
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        self.assertRaises(TypeError, data.sort, key=lambda x,y: 0)

    def test_stability(self):
        data = [(random.randrange(100), i) for i in range(200)]
        copy = data[:]
        data.sort(key=lambda t: t[0])   # sort on the random first field
        copy.sort()                     # sort using both fields
        self.assertEqual(data, copy)    # should get the same result

    def test_key_with_exception(self):
        # Verify that the wrapper has been removed
        data = list(range(-2, 2))
        dup = data[:]
        self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x)
        self.assertEqual(data, dup)

    def test_key_with_mutation(self):
        data = list(range(10))
        def k(x):
            del data[:]
            data[:] = range(20)
            return x
        self.assertRaises(ValueError, data.sort, key=k)

    def test_key_with_mutating_del(self):
        data = list(range(10))
        class SortKiller(object):
            def __init__(self, x):
                pass
            def __del__(self):
                del data[:]
                data[:] = range(20)
            def __lt__(self, other):
                return id(self) < id(other)
        self.assertRaises(ValueError, data.sort, key=SortKiller)

    def test_key_with_mutating_del_and_exception(self):
        data = list(range(10))
        ## dup = data[:]
        class SortKiller(object):
            def __init__(self, x):
                if x > 2:
                    raise RuntimeError
            def __del__(self):
                del data[:]
                data[:] = list(range(20))
        self.assertRaises(RuntimeError, data.sort, key=SortKiller)
        ## major honking subtlety: we *can't* do:
        ##
        ## self.assertEqual(data, dup)
        ##
        ## because there is a reference to a SortKiller in the
        ## traceback and by the time it dies we're outside the call to
        ## .sort() and so the list protection gimmicks are out of
        ## date (this cost some brain cells to figure out...).

    def test_reverse(self):
        data = list(range(100))
        random.shuffle(data)
        data.sort(reverse=True)
        self.assertEqual(data, list(range(99,-1,-1)))

    def test_reverse_stability(self):
        data = [(random.randrange(100), i) for i in range(200)]
        copy1 = data[:]
        copy2 = data[:]
        def my_cmp(x, y):
            x0, y0 = x[0], y[0]
            return (x0 > y0) - (x0 < y0)
        def my_cmp_reversed(x, y):
            x0, y0 = x[0], y[0]
            return (y0 > x0) - (y0 < x0)
        data.sort(key=cmp_to_key(my_cmp), reverse=True)
        copy1.sort(key=cmp_to_key(my_cmp_reversed))
        self.assertEqual(data, copy1)
        copy2.sort(key=lambda x: x[0], reverse=True)
        self.assertEqual(data, copy2)

#==============================================================================
def check_against_PyObject_RichCompareBool(self, L):
    ## The idea here is to exploit the fact that unsafe_tuple_compare uses
    ## PyObject_RichCompareBool for the second elements of tuples. So we have,
    ## for (most) L, sorted(L) == [y[1] for y in sorted([(0,x) for x in L])]
    ## This will work as long as __eq__ => not __lt__ for all the objects in L,
    ## which holds for all the types used below.
    ##
    ## Testing this way ensures that the optimized implementation remains consistent
    ## with the naive implementation, even if changes are made to any of the
    ## richcompares.
    ##
    ## This function tests sorting for three lists (it randomly shuffles each one):
    ##     1. L
    ##     2. [(x,) for x in L]
    ##     3. [((x,),) for x in L]

    random.seed(0)
    random.shuffle(L)
    L_1 = L[:]
    L_2 = [(x,) for x in L]
    L_3 = [((x,),) for x in L]
    for L in [L_1, L_2, L_3]:
        optimized = sorted(L)
        reference = [y[1] for y in sorted([(0,x) for x in L])]
        for (opt, ref) in zip(optimized, reference):
            self.assertIs(opt, ref)
            #note: not assertEqual! We want to ensure *identical* behavior.

class TestOptimizedCompares(__TestCase):
    def test_safe_object_compare(self):
        heterogeneous_lists = [[0, 'foo'],
                               [0.0, 'foo'],
                               [('foo',), 'foo']]
        for L in heterogeneous_lists:
            self.assertRaises(TypeError, L.sort)
            self.assertRaises(TypeError, [(x,) for x in L].sort)
            self.assertRaises(TypeError, [((x,),) for x in L].sort)

        float_int_lists = [[1,1.1],
                           [1<<70,1.1],
                           [1.1,1],
                           [1.1,1<<70]]
        for L in float_int_lists:
            check_against_PyObject_RichCompareBool(self, L)

    def test_unsafe_object_compare(self):

        # This test is by ppperry. It ensures that unsafe_object_compare is
        # verifying ms->key_richcompare == tp->richcompare before comparing.

        class WackyComparator(int):
            def __lt__(self, other):
                elem.__class__ = WackyList2
                return int.__lt__(self, other)

        class WackyList1(list):
            pass

        class WackyList2(list):
            def __lt__(self, other):
                raise ValueError

        L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
        elem = L[-1]
        with self.assertRaises(ValueError):
            L.sort()

        L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
        elem = L[-1]
        with self.assertRaises(ValueError):
            [(x,) for x in L].sort()

        # The following test is also by ppperry. It ensures that
        # unsafe_object_compare handles Py_NotImplemented appropriately.
        class PointlessComparator:
            def __lt__(self, other):
                return NotImplemented
        L = [PointlessComparator(), PointlessComparator()]
        self.assertRaises(TypeError, L.sort)
        self.assertRaises(TypeError, [(x,) for x in L].sort)

        # The following tests go through various types that would trigger
        # ms->key_compare = unsafe_object_compare
        lists = [list(range(100)) + [(1<<70)],
                 [str(x) for x in range(100)] + ['\uffff'],
                 [bytes(x) for x in range(100)],
                 [cmp_to_key(lambda x,y: x<y)(x) for x in range(100)]]
        for L in lists:
            check_against_PyObject_RichCompareBool(self, L)

    def test_unsafe_latin_compare(self):
        check_against_PyObject_RichCompareBool(self, [str(x) for
                                                      x in range(100)])

    def test_unsafe_long_compare(self):
        check_against_PyObject_RichCompareBool(self, [x for
                                                      x in range(100)])

    def test_unsafe_float_compare(self):
        check_against_PyObject_RichCompareBool(self, [float(x) for
                                                      x in range(100)])

    def test_unsafe_tuple_compare(self):
        # This test was suggested by Tim Peters. It verifies that the tuple
        # comparison respects the current tuple compare semantics, which do not
        # guarantee that x < x <=> (x,) < (x,)
        #
        # Note that we don't have to put anything in tuples here, because
        # the check function does a tuple test automatically.

        check_against_PyObject_RichCompareBool(self, [float('nan')]*100)
        check_against_PyObject_RichCompareBool(self, [float('nan') for
                                                      _ in range(100)])

    def test_not_all_tuples(self):
        self.assertRaises(TypeError, [(1.0, 1.0), (False, "A"), 6].sort)
        self.assertRaises(TypeError, [('a', 1), (1, 'a')].sort)
        self.assertRaises(TypeError, [(1, 'a'), ('a', 1)].sort)

    def test_none_in_tuples(self):
        expected = [(None, 1), (None, 2)]
        actual = sorted([(None, 2), (None, 1)])
        self.assertEqual(actual, expected)

#==============================================================================

if __name__ == "__main__":
    run_tests()
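
The check_against_PyObject_RichCompareBool helper above leans on one identity: for lists where __eq__ implies not __lt__, sorted(L) must match [y[1] for y in sorted([(0, x) for x in L])], because the tuple path compares the second elements through PyObject_RichCompareBool. The same consistency check can be run outside the harness; a sketch, with reference_sort as a hypothetical name:

import random

def reference_sort(L):
    # Wrap each element so tuple comparison drives the second slot,
    # then strip the wrapper; element objects are preserved.
    return [y[1] for y in sorted([(0, x) for x in L])]

L = list(range(50)) + [1 << 70]   # mixing in a big int hits the generic compare path
random.shuffle(L)
assert all(a is b for a, b in zip(sorted(L), reference_sort(L)))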
test/dynamo/cpython/3_13/test_unittest/test_assertions.diff (new file, 89 lines)
@@ -0,0 +1,89 @@
diff --git a/test/dynamo/cpython/3_13/test_unittest/test_assertions.py b/test/dynamo/cpython/3_13/test_unittest/test_assertions.py
index 1dec947ea76..5a8c2a9d3af 100644
--- a/test/dynamo/cpython/3_13/test_unittest/test_assertions.py
+++ b/test/dynamo/cpython/3_13/test_unittest/test_assertions.py
@@ -1,3 +1,54 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch.testing._internal.common_utils import run_tests
+
+
+__TestCase = torch._dynamo.test_case.CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 import datetime
 import warnings
 import weakref
@@ -6,7 +57,7 @@ from test.support import gc_collect
 from itertools import product
 
 
-class Test_Assertions(unittest.TestCase):
+class Test_Assertions(__TestCase):
     def test_AlmostEqual(self):
         self.assertAlmostEqual(1.00000001, 1.0)
         self.assertNotAlmostEqual(1.0000001, 1.0)
@@ -141,12 +192,13 @@ class Test_Assertions(unittest.TestCase):
             self.fail('assertNotRegex should have failed.')
 
 
-class TestLongMessage(unittest.TestCase):
+class TestLongMessage(__TestCase):
     """Test that the individual asserts honour longMessage.
    This actually tests all the message behaviour for
    asserts that use longMessage."""
 
     def setUp(self):
+        super().setUp()
         class TestableTestFalse(unittest.TestCase):
             longMessage = False
             failureException = self.failureException
@@ -414,4 +466,4 @@ class TestLongMessage(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
test/dynamo/cpython/3_13/test_unittest/test_assertions.py (new file, 469 lines)
@@ -0,0 +1,469 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch.testing._internal.common_utils import run_tests


__TestCase = torch._dynamo.test_case.CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

import datetime
import warnings
import weakref
import unittest
from test.support import gc_collect
from itertools import product


class Test_Assertions(__TestCase):
    def test_AlmostEqual(self):
        self.assertAlmostEqual(1.00000001, 1.0)
        self.assertNotAlmostEqual(1.0000001, 1.0)
        self.assertRaises(self.failureException,
                          self.assertAlmostEqual, 1.0000001, 1.0)
        self.assertRaises(self.failureException,
                          self.assertNotAlmostEqual, 1.00000001, 1.0)

        self.assertAlmostEqual(1.1, 1.0, places=0)
        self.assertRaises(self.failureException,
                          self.assertAlmostEqual, 1.1, 1.0, places=1)

        self.assertAlmostEqual(0, .1+.1j, places=0)
        self.assertNotAlmostEqual(0, .1+.1j, places=1)
        self.assertRaises(self.failureException,
                          self.assertAlmostEqual, 0, .1+.1j, places=1)
        self.assertRaises(self.failureException,
                          self.assertNotAlmostEqual, 0, .1+.1j, places=0)

        self.assertAlmostEqual(float('inf'), float('inf'))
        self.assertRaises(self.failureException, self.assertNotAlmostEqual,
                          float('inf'), float('inf'))

    def test_AmostEqualWithDelta(self):
        self.assertAlmostEqual(1.1, 1.0, delta=0.5)
        self.assertAlmostEqual(1.0, 1.1, delta=0.5)
        self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
        self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)

        self.assertAlmostEqual(1.0, 1.0, delta=0.5)
        self.assertRaises(self.failureException, self.assertNotAlmostEqual,
                          1.0, 1.0, delta=0.5)

        self.assertRaises(self.failureException, self.assertAlmostEqual,
                          1.1, 1.0, delta=0.05)
        self.assertRaises(self.failureException, self.assertNotAlmostEqual,
                          1.1, 1.0, delta=0.5)

        self.assertRaises(TypeError, self.assertAlmostEqual,
                          1.1, 1.0, places=2, delta=2)
        self.assertRaises(TypeError, self.assertNotAlmostEqual,
                          1.1, 1.0, places=2, delta=2)

        first = datetime.datetime.now()
        second = first + datetime.timedelta(seconds=10)
        self.assertAlmostEqual(first, second,
                               delta=datetime.timedelta(seconds=20))
        self.assertNotAlmostEqual(first, second,
                                  delta=datetime.timedelta(seconds=5))

    def test_assertRaises(self):
        def _raise(e):
            raise e
        self.assertRaises(KeyError, _raise, KeyError)
        self.assertRaises(KeyError, _raise, KeyError("key"))
        try:
            self.assertRaises(KeyError, lambda: None)
        except self.failureException as e:
            self.assertIn("KeyError not raised", str(e))
        else:
            self.fail("assertRaises() didn't fail")
        try:
            self.assertRaises(KeyError, _raise, ValueError)
        except ValueError:
            pass
        else:
            self.fail("assertRaises() didn't let exception pass through")
        with self.assertRaises(KeyError) as cm:
            try:
                raise KeyError
            except Exception as e:
                exc = e
                raise
        self.assertIs(cm.exception, exc)

        with self.assertRaises(KeyError):
            raise KeyError("key")
        try:
            with self.assertRaises(KeyError):
                pass
        except self.failureException as e:
            self.assertIn("KeyError not raised", str(e))
        else:
            self.fail("assertRaises() didn't fail")
        try:
            with self.assertRaises(KeyError):
                raise ValueError
        except ValueError:
            pass
        else:
            self.fail("assertRaises() didn't let exception pass through")

    def test_assertRaises_frames_survival(self):
        # Issue #9815: assertRaises should avoid keeping local variables
        # in a traceback alive.
        class A:
            pass
        wr = None

        class Foo(unittest.TestCase):

            def foo(self):
                nonlocal wr
                a = A()
                wr = weakref.ref(a)
                try:
                    raise OSError
                except OSError:
                    raise ValueError

            def test_functional(self):
                self.assertRaises(ValueError, self.foo)

            def test_with(self):
                with self.assertRaises(ValueError):
                    self.foo()

        Foo("test_functional").run()
        gc_collect()  # For PyPy or other GCs.
        self.assertIsNone(wr())
        Foo("test_with").run()
        gc_collect()  # For PyPy or other GCs.
        self.assertIsNone(wr())

    def testAssertNotRegex(self):
        self.assertNotRegex('Ala ma kota', r'r+')
        try:
            self.assertNotRegex('Ala ma kota', r'k.t', 'Message')
        except self.failureException as e:
            self.assertIn('Message', e.args[0])
        else:
            self.fail('assertNotRegex should have failed.')


class TestLongMessage(__TestCase):
    """Test that the individual asserts honour longMessage.
    This actually tests all the message behaviour for
    asserts that use longMessage."""

    def setUp(self):
        super().setUp()
        class TestableTestFalse(unittest.TestCase):
            longMessage = False
            failureException = self.failureException

            def testTest(self):
                pass

        class TestableTestTrue(unittest.TestCase):
            longMessage = True
            failureException = self.failureException

            def testTest(self):
                pass

        self.testableTrue = TestableTestTrue('testTest')
        self.testableFalse = TestableTestFalse('testTest')

    def testDefault(self):
        self.assertTrue(unittest.TestCase.longMessage)

    def test_formatMsg(self):
        self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
        self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")

        self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
        self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")

        # This blows up if _formatMessage uses string concatenation
        self.testableTrue._formatMessage(object(), 'foo')

    def test_formatMessage_unicode_error(self):
        one = ''.join(chr(i) for i in range(255))
        # this used to cause a UnicodeDecodeError constructing msg
        self.testableTrue._formatMessage(one, '\uFFFD')

    def assertMessages(self, methodName, args, errors):
        """
        Check that methodName(*args) raises the correct error messages.
        errors should be a list of 4 regex that match the error when:
          1) longMessage = False and no msg passed;
          2) longMessage = False and msg passed;
          3) longMessage = True and no msg passed;
          4) longMessage = True and msg passed;
        """
        def getMethod(i):
            useTestableFalse = i < 2
            if useTestableFalse:
                test = self.testableFalse
            else:
                test = self.testableTrue
            return getattr(test, methodName)

        for i, expected_regex in enumerate(errors):
            testMethod = getMethod(i)
            kwargs = {}
            withMsg = i % 2
            if withMsg:
                kwargs = {"msg": "oops"}

            with self.assertRaisesRegex(self.failureException,
                                        expected_regex=expected_regex):
                testMethod(*args, **kwargs)

    def testAssertTrue(self):
        self.assertMessages('assertTrue', (False,),
                            ["^False is not true$", "^oops$", "^False is not true$",
                             "^False is not true : oops$"])

    def testAssertFalse(self):
        self.assertMessages('assertFalse', (True,),
                            ["^True is not false$", "^oops$", "^True is not false$",
                             "^True is not false : oops$"])

    def testNotEqual(self):
        self.assertMessages('assertNotEqual', (1, 1),
                            ["^1 == 1$", "^oops$", "^1 == 1$",
                             "^1 == 1 : oops$"])

    def testAlmostEqual(self):
        self.assertMessages(
            'assertAlmostEqual', (1, 2),
            [r"^1 != 2 within 7 places \(1 difference\)$", "^oops$",
             r"^1 != 2 within 7 places \(1 difference\)$",
             r"^1 != 2 within 7 places \(1 difference\) : oops$"])

    def testNotAlmostEqual(self):
        self.assertMessages('assertNotAlmostEqual', (1, 1),
                            ["^1 == 1 within 7 places$", "^oops$",
                             "^1 == 1 within 7 places$", "^1 == 1 within 7 places : oops$"])

    def test_baseAssertEqual(self):
        self.assertMessages('_baseAssertEqual', (1, 2),
                            ["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"])

    def testAssertSequenceEqual(self):
        # Error messages are multiline so not testing on full message
        # assertTupleEqual and assertListEqual delegate to this method
        self.assertMessages('assertSequenceEqual', ([], [None]),
                            [r"\+ \[None\]$", "^oops$", r"\+ \[None\]$",
                             r"\+ \[None\] : oops$"])

    def testAssertSetEqual(self):
        self.assertMessages('assertSetEqual', (set(), set([None])),
                            ["None$", "^oops$", "None$",
                             "None : oops$"])

    def testAssertIn(self):
        self.assertMessages('assertIn', (None, []),
                            [r'^None not found in \[\]$', "^oops$",
                             r'^None not found in \[\]$',
                             r'^None not found in \[\] : oops$'])

    def testAssertNotIn(self):
        self.assertMessages('assertNotIn', (None, [None]),
                            [r'^None unexpectedly found in \[None\]$', "^oops$",
                             r'^None unexpectedly found in \[None\]$',
                             r'^None unexpectedly found in \[None\] : oops$'])

    def testAssertDictEqual(self):
        self.assertMessages('assertDictEqual', ({}, {'key': 'value'}),
                            [r"\+ \{'key': 'value'\}$", "^oops$",
                             r"\+ \{'key': 'value'\}$",
                             r"\+ \{'key': 'value'\} : oops$"])

    def testAssertMultiLineEqual(self):
        self.assertMessages('assertMultiLineEqual', ("", "foo"),
                            [r"\+ foo\n$", "^oops$",
                             r"\+ foo\n$",
                             r"\+ foo\n : oops$"])

    def testAssertLess(self):
        self.assertMessages('assertLess', (2, 1),
                            ["^2 not less than 1$", "^oops$",
                             "^2 not less than 1$", "^2 not less than 1 : oops$"])

    def testAssertLessEqual(self):
        self.assertMessages('assertLessEqual', (2, 1),
                            ["^2 not less than or equal to 1$", "^oops$",
                             "^2 not less than or equal to 1$",
                             "^2 not less than or equal to 1 : oops$"])

    def testAssertGreater(self):
        self.assertMessages('assertGreater', (1, 2),
                            ["^1 not greater than 2$", "^oops$",
                             "^1 not greater than 2$",
                             "^1 not greater than 2 : oops$"])

    def testAssertGreaterEqual(self):
        self.assertMessages('assertGreaterEqual', (1, 2),
                            ["^1 not greater than or equal to 2$", "^oops$",
                             "^1 not greater than or equal to 2$",
                             "^1 not greater than or equal to 2 : oops$"])

    def testAssertIsNone(self):
        self.assertMessages('assertIsNone', ('not None',),
                            ["^'not None' is not None$", "^oops$",
                             "^'not None' is not None$",
                             "^'not None' is not None : oops$"])

    def testAssertIsNotNone(self):
        self.assertMessages('assertIsNotNone', (None,),
                            ["^unexpectedly None$", "^oops$",
                             "^unexpectedly None$",
                             "^unexpectedly None : oops$"])

    def testAssertIs(self):
        self.assertMessages('assertIs', (None, 'foo'),
                            ["^None is not 'foo'$", "^oops$",
                             "^None is not 'foo'$",
                             "^None is not 'foo' : oops$"])

    def testAssertIsNot(self):
        self.assertMessages('assertIsNot', (None, None),
                            ["^unexpectedly identical: None$", "^oops$",
                             "^unexpectedly identical: None$",
                             "^unexpectedly identical: None : oops$"])

    def testAssertRegex(self):
        self.assertMessages('assertRegex', ('foo', 'bar'),
                            ["^Regex didn't match:",
                             "^oops$",
                             "^Regex didn't match:",
                             "^Regex didn't match: (.*) : oops$"])

    def testAssertNotRegex(self):
        self.assertMessages('assertNotRegex', ('foo', 'foo'),
                            ["^Regex matched:",
                             "^oops$",
                             "^Regex matched:",
                             "^Regex matched: (.*) : oops$"])


    def assertMessagesCM(self, methodName, args, func, errors):
        """
        Check that the correct error messages are raised while executing:
          with method(*args):
              func()
        *errors* should be a list of 4 regex that match the error when:
          1) longMessage = False and no msg passed;
          2) longMessage = False and msg passed;
          3) longMessage = True and no msg passed;
          4) longMessage = True and msg passed;
        """
        p = product((self.testableFalse, self.testableTrue),
                    ({}, {"msg": "oops"}))
        for (cls, kwargs), err in zip(p, errors):
            method = getattr(cls, methodName)
            with self.assertRaisesRegex(cls.failureException, err):
                with method(*args, **kwargs) as cm:
                    func()

    def testAssertRaises(self):
        self.assertMessagesCM('assertRaises', (TypeError,), lambda: None,
                              ['^TypeError not raised$', '^oops$',
                               '^TypeError not raised$',
                               '^TypeError not raised : oops$'])

    def testAssertRaisesRegex(self):
        # test error not raised
        self.assertMessagesCM('assertRaisesRegex', (TypeError, 'unused regex'),
                              lambda: None,
                              ['^TypeError not raised$', '^oops$',
                               '^TypeError not raised$',
                               '^TypeError not raised : oops$'])
        # test error raised but with wrong message
        def raise_wrong_message():
            raise TypeError('foo')
        self.assertMessagesCM('assertRaisesRegex', (TypeError, 'regex'),
                              raise_wrong_message,
                              ['^"regex" does not match "foo"$', '^oops$',
                               '^"regex" does not match "foo"$',
                               '^"regex" does not match "foo" : oops$'])

    def testAssertWarns(self):
        self.assertMessagesCM('assertWarns', (UserWarning,), lambda: None,
                              ['^UserWarning not triggered$', '^oops$',
                               '^UserWarning not triggered$',
                               '^UserWarning not triggered : oops$'])

    def test_assertNotWarns(self):
        def warn_future():
            warnings.warn('xyz', FutureWarning, stacklevel=2)
        self.assertMessagesCM('_assertNotWarns', (FutureWarning,),
                              warn_future,
                              ['^FutureWarning triggered$',
                               '^oops$',
                               '^FutureWarning triggered$',
                               '^FutureWarning triggered : oops$'])

    def testAssertWarnsRegex(self):
        # test error not raised
        self.assertMessagesCM('assertWarnsRegex', (UserWarning, 'unused regex'),
                              lambda: None,
                              ['^UserWarning not triggered$', '^oops$',
                               '^UserWarning not triggered$',
                               '^UserWarning not triggered : oops$'])
        # test warning raised but with wrong message
        def raise_wrong_message():
            warnings.warn('foo')
        self.assertMessagesCM('assertWarnsRegex', (UserWarning, 'regex'),
                              raise_wrong_message,
                              ['^"regex" does not match "foo"$', '^oops$',
                               '^"regex" does not match "foo"$',
                               '^"regex" does not match "foo" : oops$'])


if __name__ == "__main__":
    run_tests()
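
TestLongMessage above pins down how unittest composes failure messages: with longMessage = True (the default) a user-supplied msg is appended to the standard message with " : ", and with longMessage = False it replaces the standard message entirely. A minimal sketch of the behavior under test:

import unittest

class LongMessageDemo(unittest.TestCase):
    def test_joined(self):
        self.longMessage = True
        try:
            self.assertEqual(1, 2, msg="oops")
        except self.failureException as e:
            assert str(e) == "1 != 2 : oops"  # standard message plus msg

    def test_replaced(self):
        self.longMessage = False
        try:
            self.assertEqual(1, 2, msg="oops")
        except self.failureException as e:
            assert str(e) == "oops"  # msg alone

if __name__ == "__main__":
    unittest.main()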
@ -130,14 +130,17 @@ class _multiply_invoke(torch.nn.Module):
|
||||
actual,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_inputs_ : list, s69: "Sym(s21)"):
|
||||
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
|
||||
l_inputs_ = L_inputs_
|
||||
l_sizes_0_ = L_sizes_0_
|
||||
|
||||
getitem: "f32[s21]" = l_inputs_[0]
|
||||
getitem_1: "f32[s21]" = l_inputs_[1]
|
||||
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None
|
||||
|
||||
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
|
||||
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
|
||||
|
||||
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
|
||||
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
|
||||
|
||||
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
|
||||
@ -160,14 +163,17 @@ class GraphModule(torch.nn.Module):
|
||||
actual,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_inputs_ : list, s69: "Sym(s21)"):
|
||||
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]"):
|
||||
l_inputs_ = L_inputs_
|
||||
l_sizes_0_ = L_sizes_0_
|
||||
|
||||
getitem: "f32[s21]" = l_inputs_[0]
|
||||
getitem_1: "f32[s21]" = l_inputs_[1]
|
||||
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None
|
||||
|
||||
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
|
||||
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
|
||||
|
||||
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
|
||||
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
|
||||
|
||||
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
|
||||
@ -242,15 +248,18 @@ class GraphModule(torch.nn.Module):
|
||||
actual,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
|
||||
def forward(self, L_inputs_ : list, s69: "Sym(s21)", L_sizes_0_: "f32[0, s21]", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
|
||||
l_inputs_ = L_inputs_
|
||||
l_sizes_0_ = L_sizes_0_
|
||||
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
|
||||
|
||||
getitem: "f32[s21]" = l_inputs_[0]
|
||||
getitem_1: "f32[s21]" = l_inputs_[1]
|
||||
getitem_2: "f32[s21]" = l_inputs_[2]; l_inputs_ = None
|
||||
|
||||
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [s69], False)]); getitem = s69 = None
|
||||
size: "Sym(s21)" = l_sizes_0_.size(1); l_sizes_0_ = None
|
||||
|
||||
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [size], False)]); getitem = size = None
|
||||
getitem_9: "f32[s21]" = validate_outputs[0]; validate_outputs = None
|
||||
|
||||
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
|
||||
|
||||
@ -474,13 +474,13 @@ class GraphModule(torch.nn.Module):
|
||||
return invoke_quant_test(inner, x, y, scheme="nf4")
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Encountered aliasing during higher order op tracing for HOP"
|
||||
RuntimeError, "Encountered aliasing during higher order op tracing"
|
||||
):
|
||||
f(inner, x, y)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Encountered input mutation during higher order op tracing for HOP",
|
||||
"Encountered input mutation during higher order op tracing",
|
||||
):
|
||||
f(inner2, x, y)
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ import torch._inductor.test_case
|
||||
import torch.onnx.operators
|
||||
import torch.utils.cpp_extension
|
||||
from torch._dynamo.bytecode_transformation import transform_code_object
|
||||
from torch._dynamo.exc import PackageError
|
||||
from torch._dynamo.guards import CheckFunctionManager, CompileId
|
||||
from torch._dynamo.symbolic_convert import (
|
||||
ExceptionStack,
|
||||
@ -235,6 +236,15 @@ pytree.register_constant(CustomConstantType)
|
||||
|
||||
|
||||
class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
def test_function_locals(self):
|
||||
def foo(x):
|
||||
return x + 1
|
||||
|
||||
def fn(x, g):
|
||||
return g(x) + 1
|
||||
|
||||
self._test_serialization("TENSOR_MATCH", fn, torch.randn(3), foo)
|
||||
|
||||
def _tracefunc(self, frame, event, arg):
|
||||
if event != "call":
|
||||
return
|
||||
@ -481,7 +491,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
# === example subclass defined locally (error) ===
|
||||
local_sub = LocalSubclass(torch.randn(3))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Please define the class at global scope"
|
||||
PackageError, "Please define the class at global scope"
|
||||
):
|
||||
self._test_serialization("TENSOR_SUBCLASS_METADATA_MATCH", fn, local_sub)
|
||||
|
||||
@ -646,7 +656,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
# we don't support NN_MODULE because it adds an ID_MATCH guard, and we don't
|
||||
# support that in serialization
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "NN_MODULE guard cannot be serialized."
|
||||
PackageError, "NN_MODULE guard cannot be serialized."
|
||||
):
|
||||
self._test_serialization("NN_MODULE", fn, m, x)
|
||||
|
||||
@ -662,7 +672,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
# we don't support FUNCTION_MATCH because it adds an ID_MATCH guard, and we don't
|
||||
# support that in serialization
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "FUNCTION_MATCH guard cannot be serialized."
|
||||
PackageError, "FUNCTION_MATCH guard cannot be serialized."
|
||||
):
|
||||
self._test_serialization("FUNCTION_MATCH", fn, x)
|
||||
|
||||
@ -676,7 +686,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
# we don't support CLOSURE_MATCH because it adds a FUNCTION_MATCH guard, and we don't
|
||||
# support that in serialization
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "CLOSURE_MATCH guard cannot be serialized."
|
||||
PackageError, "CLOSURE_MATCH guard cannot be serialized."
|
||||
):
|
||||
self._test_serialization("CLOSURE_MATCH", fn, x)
|
||||
|
||||
@ -795,7 +805,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
return pytree.tree_leaves(x)[0] + 1
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "DICT_VERSION guard cannot be serialized."
|
||||
PackageError, "DICT_VERSION guard cannot be serialized."
|
||||
):
|
||||
self._test_serialization("DICT_VERSION", fn, {"t": torch.randn(3)})
|
||||
|
||||
@ -847,7 +857,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
return x + id(x)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "ID_MATCH guard cannot be serialized."
|
||||
PackageError, "ID_MATCH guard cannot be serialized."
|
||||
):
|
||||
self._test_serialization("ID_MATCH", fn, torch.randn(3))
|
||||
|
||||
@ -1023,7 +1033,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "DUPLICATE_INPUT guard cannot be serialized"
|
||||
PackageError, "DUPLICATE_INPUT guard cannot be serialized"
|
||||
):
|
||||
self._test_serialization("DUPLICATE_INPUT", fn, x, x)
|
||||
|
||||
@ -1040,7 +1050,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
return params[0].sum()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "WEAKREF_ALIVE guard cannot be serialized"
|
||||
PackageError, "WEAKREF_ALIVE guard cannot be serialized"
|
||||
):
|
||||
with torch.set_grad_enabled(False):
|
||||
self._test_serialization("WEAKREF_ALIVE", fn)
|
||||
@ -1159,7 +1169,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
with torch._C.DisableTorchFunction():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
PackageError,
|
||||
"defined in local scope. Please define the class at global scope",
|
||||
):
|
||||
with LocalTorchFunctionMode():
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import contextlib
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
import torch._dynamo.config
|
||||
@@ -54,6 +55,104 @@ class PgoTest(torch._dynamo.test_case.TestCase):
         f(torch.randn(2, 6))
         self.assertEqual(cnts.frame_count, 1)

+    def test_whitelist_suggestion(self):
+        cnts = CompileCounter()
+
+        @torch.compile(backend=cnts, fullgraph=True)
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = torch.nn.Linear(4, 4)
+                self.attr = torch.randn(4)
+
+            def forward(self, x, y):
+                return self.lin(x) + self.attr + y
+
+        sources = [
+            "L['x']",
+            "L['self']._modules['lin']._parameters['weight']",
+            "L['self']._modules['lin']._parameters['bias']",
+            "L['self'].attr",
+            "L['y']",
+        ]
+
+        def check_whitelist(sources_):
+            state = torch._dynamo.pgo.render_code_state(
+                torch._dynamo.pgo.get_code_state()
+            )
+            whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(
+                1
+            )
+            for src in sources_:
+                self.assertTrue(src in whitelist)
+
+        # check growing whitelist
+        f = Foo()
+        f(torch.randn(2, 4), torch.randn(4))
+        # only x
+        f(torch.randn(4, 4), torch.randn(4))
+        check_whitelist(sources[:1])
+        # x, lin.weight
+        f.lin = torch.nn.Linear(8, 4)
+        f(torch.randn(8, 8), torch.randn(4))
+        check_whitelist(sources[:2])
+        # x, y, lin.weight, lin.bias, attr
+        f.lin = torch.nn.Linear(8, 8)
+        f.attr = torch.randn(8)
+        f(torch.randn(8, 8), torch.randn(8))
+        check_whitelist(sources)
+
+        # now use suggested whitelist
+        self.reset()
+        cnts.clear()
+        state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
+        whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1)
+        with torch.compiler.config.patch(dynamic_sources=whitelist):
+            f = Foo()
+            f(torch.randn(2, 4), torch.randn(4))
+            f(torch.randn(4, 4), torch.randn(4))
+            f.lin = torch.nn.Linear(8, 8)
+            f.attr = torch.randn(8)
+            f(torch.randn(8, 8), torch.randn(8))
+            self.assertEqual(cnts.frame_count, 1)
+
+    def test_pgo_dynamic_params(self):
+        cnts = CompileCounter()
+
+        @torch.compile(backend=cnts, fullgraph=True)
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = None
+
+            def forward(self, x):
+                return self.lin(x)
+
+        f = Foo()
+
+        def run():
+            self.reset()
+            cnts.clear()
+            f.lin = torch.nn.Linear(4, 4)
+            f(torch.randn(2, 4))
+            f(torch.randn(4, 4))
+            f.lin = torch.nn.Linear(8, 8)
+            f(torch.randn(8, 8))
+
+        # recompile each run
+        run()
+        self.assertEqual(cnts.frame_count, 3)
+
+        # parameter static shapes are forced static, so we recompile once
+        run()
+        self.assertEqual(cnts.frame_count, 2)
+
+        # flags are flipped, PGO records dynamism, so params are dynamically compiled to start
+        torch._dynamo.config.force_parameter_static_shapes = False
+        torch._dynamo.config.force_nn_module_property_static_shapes = False
+        run()
+        self.assertEqual(cnts.frame_count, 1)
+
     def test_njt(self):
         cnts = CompileCounter()

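test_whitelist_suggestion above exercises a round trip: PGO renders a suggested TORCH_COMPILE_DYNAMIC_SOURCES whitelist into its code-state dump, and feeding that string back makes the listed sources dynamic from the first compile. A hedged sketch of that workflow outside the test harness, using only the calls the diff itself shows (run_model is a hypothetical stand-in for compiling and invoking your model):

import re
import torch

# Render PGO's accumulated code state; it embeds a suggested whitelist line.
state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
match = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state)
if match:
    # Apply the suggestion so the listed sources start out dynamic,
    # avoiding one shape-specialized recompile per new input shape.
    with torch.compiler.config.patch(dynamic_sources=match.group(1)):
        run_model()  # hypothetical: compile and call the model here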
@@ -3226,6 +3226,25 @@ class GraphModule(torch.nn.Module):
         lengths = torch.tensor([2, 4, 3])
         self._validate_compile(fn, arg_fn=lambda: (values, lengths))

+    def test_in_graph_construction_from_input_6(self):
+        # Construct with symbolic int.
+        def fn(values, offsets, max_seqlen):
+            t = torch.nested.nested_tensor_from_jagged(
+                values, offsets, max_seqlen=max_seqlen
+            )
+            return torch.nested.nested_tensor_from_jagged(
+                values, t.offsets(), max_seqlen=t._maybe_max_seqlen
+            )
+
+        opt_fn = torch.compile(fn, fullgraph=True, dynamic=True)
+        values = torch.randn(10, 5)
+        offsets = torch.tensor([0, 2, 4, 7, 10])
+        max_seqlen = 5
+
+        ref = fn(values, offsets, max_seqlen)
+        res = opt_fn(values, offsets, max_seqlen)
+        self.assertEqualIgnoringNestedInts(ref, res)
+
     #
     # Case 2: in-graph construction where offsets are graph intermediates
     #

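For readers unfamiliar with the jagged layout used in the new test, a small eager-mode sketch of torch.nested.nested_tensor_from_jagged: offsets partitions the leading dimension of values into variable-length rows. The sketch assumes only public behavior; _maybe_max_seqlen in the diff is a private attribute, presumed here to be the cached max_seqlen.

import torch

values = torch.randn(10, 5)               # 10 rows of width 5
offsets = torch.tensor([0, 2, 4, 7, 10])  # 4 components with lengths 2, 2, 3, 3
nt = torch.nested.nested_tensor_from_jagged(values, offsets, max_seqlen=5)
print(nt.size(0))    # 4 jagged components share the underlying values buffer
print(nt.offsets())  # tensor([ 0,  2,  4,  7, 10])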
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import dataclasses
+import os
 import pprint
 import sys
 from unittest import mock
@@ -141,6 +142,69 @@ class TestUtils(TestCase):
         compilation_events = [arg[0][0] for arg in log_event.call_args_list]
         self.assertEqual(compilation_events[-1].num_graph_breaks, 2)

+    def test_frame_traced_hook(self):
+        from utils import add, break_it
+
+        traced_code_lists = []
+
+        def get_traced_code(s):
+            nonlocal traced_code_lists
+            traced_code_lists.append(s)
+
+        def get_filenames(traced_code_lists):
+            return [
+                [code.co_filename for code in code_list]
+                for code_list in traced_code_lists
+            ]
+
+        utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
+
+        # === no inlining ===
+        @torch.compile(options={"frame_traced_fn": get_traced_code})
+        def fn(x):
+            return x * 2
+
+        x = torch.randn(3)
+        traced_code_lists = []
+        fn(x)
+        # expect hook to be called once with this file
+        self.assertEqual(get_filenames(traced_code_lists), [[__file__]])
+
+        # === successful inlining ===
+        @torch.compile(options={"frame_traced_fn": get_traced_code})
+        def fn(x):
+            return add(x) * 2
+
+        x = torch.randn(3)
+        traced_code_lists = []
+        fn(x)
+        utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
+        # expect hook to be called once with both this file and the file of the inlined func
+        self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]])
+
+        # === graph break occurs during inlining ===
+        @torch.compile(options={"frame_traced_fn": get_traced_code})
+        def fn(x):
+            y = break_it(x)
+            return y * 2
+
+        x = torch.randn(3)
+        traced_code_lists = []
+        fn(x)
+        # expect hook to be called twice: once for this file, once for the file of the inlined func
+        self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]])
+
+        # === empty graph ===
+        @torch.compile(options={"frame_traced_fn": get_traced_code})
+        def fn(x):
+            return x
+
+        x = torch.randn(3)
+        traced_code_lists = []
+        fn(x)
+        # hook is not expected to be called at all for an empty graph
+        self.assertEqual(traced_code_lists, [])
+
+
 class TestModel(torch.nn.Module):
     def __init__(self):

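The new test pins down the frame_traced_fn compile option: per compiled (non-empty) graph, the hook receives the list of code objects Dynamo traced into that graph, inspectable via co_filename. A minimal standalone sketch mirroring the usage the test shows:

import torch

def on_frame_traced(code_list):
    # Called once per non-empty compiled graph with the code objects that
    # were traced: the frame itself plus any successfully inlined functions.
    print([code.co_filename for code in code_list])

@torch.compile(options={"frame_traced_fn": on_frame_traced})
def fn(x):
    return x * 2

fn(torch.randn(3))  # prints a single-element list containing this file's path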
@@ -39,6 +39,10 @@ def add(x):
     return x + 1


+def break_it(x):
+    return x.sum().item()
+
+
 def create_dummy_module_and_function():
     module = types.ModuleType("dummy_module")
     module.__spec__ = importlib.machinery.ModuleSpec(
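break_it is a deliberately graph-breaking helper: Tensor.item() extracts a Python scalar, which Dynamo cannot keep in the traced graph under default settings, so tracing splits around the call. A sketch of verifying the break with torch._dynamo.explain (field names may vary slightly across versions):

import torch

def break_it(x):
    # .item() produces a Python scalar; with default config this forces
    # a graph break at the call site.
    return x.sum().item()

explanation = torch._dynamo.explain(break_it)(torch.randn(3))
print(explanation.graph_break_count)  # expected: 1 with default config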
Some files were not shown because too many files have changed in this diff.