Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits: whc_flight ... release/1.9 (70 commits)

| SHA1 |
|---|
| 880c811aeb | |||
| dfbd030854 | |||
| e2cb357367 | |||
| eccdfffa95 | |||
| 3c81d48e29 | |||
| 61e9e88c30 | |||
| 04dd41d7cb | |||
| 9cc96ad811 | |||
| 7a9f23c030 | |||
| 3a2876d98e | |||
| 81efb9fd1a | |||
| 568bc668e5 | |||
| 088ca37c5c | |||
| 544c8bd791 | |||
| 5057e20247 | |||
| 81ddd719b4 | |||
| dce3a4047f | |||
| 4d7401c2f7 | |||
| e7cd59c7a0 | |||
| 1a7c23c4c0 | |||
| d69c22dd61 | |||
| 4ad4f6db7f | |||
| 90e67738b1 | |||
| 43c581aa62 | |||
| bc446f6a54 | |||
| abe996a7fb | |||
| 795df76568 | |||
| 3b9cd08901 | |||
| 226c274f70 | |||
| ce24cab257 | |||
| d98d113810 | |||
| 17a44c2bb5 | |||
| 26e6fa380e | |||
| bf16699cc8 | |||
| 6d4fe05502 | |||
| b046542f8a | |||
| 5d57b9392c | |||
| f6a9351776 | |||
| 3071601491 | |||
| d417a094f3 | |||
| 1fdbbc96ae | |||
| e761f16ad5 | |||
| 0544a765d3 | |||
| 1ea8ae5d93 | |||
| 97ca7303b0 | |||
| ac94547143 | |||
| e2027acebe | |||
| 0896c6b1f0 | |||
| 43f6675363 | |||
| 450f5c6f4d | |||
| 310e528a0d | |||
| e4161d0b2b | |||
| 016dc8cb68 | |||
| a3ea5cee52 | |||
| dfc58f4faa | |||
| b5e2635281 | |||
| 9dfd2e7b56 | |||
| f0bdbb4ce1 | |||
| bc4471c8c9 | |||
| 317fd72526 | |||
| b8d36033f0 | |||
| 47507259b9 | |||
| e77e8d52da | |||
| b9fb6d1c7e | |||
| 1ea310bc8e | |||
| 8e6b8d8d46 | |||
| 87c46a5e32 | |||
| 5092364d78 | |||
| 085a3bcb77 | |||
| 5f0bbb38ec | 

@@ -97,13 +97,13 @@ WORKFLOW_DATA = [
        is_master_only=False,
        is_pr_only=True),
    AndroidGradleJob(
        "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single_lite_interpreter",
        "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit",
        "pytorch_android_gradle_custom_build_single",
        [DOCKER_REQUIREMENT_NDK],
        is_master_only=False,
        is_pr_only=True,
        extra_props=tuple({
            "lite_interpreter": miniutils.quote(str(int(True)))
            "lite_interpreter": miniutils.quote(str(int(False)))
        }.items())),
    AndroidGradleJob(
        "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build",

@@ -61,14 +61,20 @@ class IOSJob:

WORKFLOW_DATA = [
    IOSJob(XCODE_VERSION, ArchVariant("x86_64"), is_org_member_context=False),
    IOSJob(XCODE_VERSION, ArchVariant("x86_64", "lite_interpreter"), is_org_member_context=False, extra_props={
    IOSJob(XCODE_VERSION, ArchVariant("x86_64"), is_org_member_context=False, extra_props={
        "lite_interpreter": miniutils.quote(str(int(True)))}),
    IOSJob(XCODE_VERSION, ArchVariant("arm64")),
    IOSJob(XCODE_VERSION, ArchVariant("arm64", "metal"), extra_props={"use_metal": miniutils.quote(str(int(True)))}),
    IOSJob(XCODE_VERSION, ArchVariant("arm64", "lite_interpreter"), extra_props={
    IOSJob(XCODE_VERSION, ArchVariant("x86_64", "full_jit"), is_org_member_context=False, extra_props={
        "lite_interpreter": miniutils.quote(str(int(False)))}),
    IOSJob(XCODE_VERSION, ArchVariant("arm64"), extra_props={
        "lite_interpreter": miniutils.quote(str(int(True)))}),
    IOSJob(XCODE_VERSION, ArchVariant("arm64", "metal"), extra_props={
        "use_metal": miniutils.quote(str(int(True))),
        "lite_interpreter": miniutils.quote(str(int(True)))}),
    IOSJob(XCODE_VERSION, ArchVariant("arm64", "full_jit"), extra_props={
        "lite_interpreter": miniutils.quote(str(int(False)))}),
    IOSJob(XCODE_VERSION, ArchVariant("arm64", "custom"), extra_props={
        "op_list": "mobilenetv2.yaml",
        "lite_interpreter": miniutils.quote(str(int(True)))}),
    IOSJob(XCODE_VERSION, ArchVariant("arm64", "custom"), extra_props={"op_list": "mobilenetv2.yaml"}),
]

@@ -297,7 +297,7 @@ pytorch_android_params: &pytorch_android_params
      default: ""
    lite_interpreter:
      type: string
      default: "0"
      default: "1"
  environment:
    BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single
    DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c"

@@ -324,7 +324,7 @@ pytorch_ios_params: &pytorch_ios_params
      default: "0"
    lite_interpreter:
      type: string
      default: "0"
      default: "1"
  environment:
    BUILD_ENVIRONMENT: << parameters.build_environment >>
    IOS_ARCH: << parameters.ios_arch >>

@@ -1759,8 +1759,8 @@ jobs:
          no_output_timeout: "30m"
          command: |
            set -e
            if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
              echo "Run Build Test is not for BUILD_LITE_INTERPRETER, skipping."
            if [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
              echo "Run Build Test is not for full jit, skipping."
              exit 0
            fi
            PROJ_ROOT=/Users/distiller/project

@@ -1788,8 +1788,8 @@ jobs:
            if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then
              echo "not SIMULATOR build, skip it."
              exit 0
            elif [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
              echo "Run Simulator Tests is not for BUILD_LITE_INTERPRETER, skipping."
            elif [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
              echo "Run Simulator Tests is not for full jit, skipping."
              exit 0
            fi
            WORKSPACE=/Users/distiller/workspace

@@ -6534,8 +6534,8 @@ workflows:
              only:
                - /gh\/.*\/head/
                - /pull\/.*/
          lite_interpreter: "1"
          name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single_lite_interpreter
          lite_interpreter: "0"
          name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit
          requires:
            - docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c
      - pytorch_android_gradle_build:

@@ -6555,38 +6555,42 @@ workflows:
          build_environment: pytorch-ios-12.0.0-x86_64_build
          ios_arch: x86_64
          ios_platform: SIMULATOR
          lite_interpreter: "1"
          name: pytorch_ios_12_0_0_x86_64_build
      - pytorch_ios_build:
          build_environment: pytorch-ios-12.0.0-x86_64_lite_interpreter_build
          build_environment: pytorch-ios-12.0.0-x86_64_full_jit_build
          ios_arch: x86_64
          ios_platform: SIMULATOR
          lite_interpreter: "1"
          name: pytorch_ios_12_0_0_x86_64_lite_interpreter_build
          lite_interpreter: "0"
          name: pytorch_ios_12_0_0_x86_64_full_jit_build
      - pytorch_ios_build:
          build_environment: pytorch-ios-12.0.0-arm64_build
          context: org-member
          ios_arch: arm64
          ios_platform: OS
          lite_interpreter: "1"
          name: pytorch_ios_12_0_0_arm64_build
      - pytorch_ios_build:
          build_environment: pytorch-ios-12.0.0-arm64_metal_build
          context: org-member
          ios_arch: arm64
          ios_platform: OS
          lite_interpreter: "1"
          name: pytorch_ios_12_0_0_arm64_metal_build
          use_metal: "1"
      - pytorch_ios_build:
          build_environment: pytorch-ios-12.0.0-arm64_lite_interpreter_build
          build_environment: pytorch-ios-12.0.0-arm64_full_jit_build
          context: org-member
          ios_arch: arm64
          ios_platform: OS
          lite_interpreter: "1"
          name: pytorch_ios_12_0_0_arm64_lite_interpreter_build
          lite_interpreter: "0"
          name: pytorch_ios_12_0_0_arm64_full_jit_build
      - pytorch_ios_build:
          build_environment: pytorch-ios-12.0.0-arm64_custom_build
          context: org-member
          ios_arch: arm64
          ios_platform: OS
          lite_interpreter: "1"
          name: pytorch_ios_12_0_0_arm64_custom_build
          op_list: mobilenetv2.yaml
      - pytorch_linux_build:

@@ -88,6 +88,7 @@ case "$image" in
    DB=yes
    VISION=yes
    KATEX=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-py3.6-gcc7.2)
    ANACONDA_PYTHON_VERSION=3.6

@@ -100,6 +101,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7)
    CUDA_VERSION=10.0

@@ -109,6 +111,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7)
    CUDA_VERSION=10.1

@@ -119,6 +122,7 @@ case "$image" in
    DB=yes
    VISION=yes
    KATEX=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7)
    CUDA_VERSION=10.2

@@ -129,6 +133,7 @@ case "$image" in
    DB=yes
    VISION=yes
    KATEX=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7)
    CUDA_VERSION=11.1

@@ -139,6 +144,7 @@ case "$image" in
    DB=yes
    VISION=yes
    KATEX=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7)
    CUDA_VERSION=11.3.0 # Deviating from major.minor to conform to nvidia's Docker image names

@@ -149,6 +155,7 @@ case "$image" in
    DB=yes
    VISION=yes
    KATEX=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-py3-clang5-asan)
    ANACONDA_PYTHON_VERSION=3.6

@@ -156,6 +163,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-py3-clang7-onnx)
    ANACONDA_PYTHON_VERSION=3.6

@@ -163,6 +171,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-xenial-py3-clang5-android-ndk-r19c)
    ANACONDA_PYTHON_VERSION=3.6

@@ -181,6 +190,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-bionic-py3.6-clang9)
    ANACONDA_PYTHON_VERSION=3.6

@@ -188,6 +198,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    VULKAN_SDK_VERSION=1.2.162.1
    SWIFTSHADER=yes
    ;;

@@ -198,6 +209,7 @@ case "$image" in
    DB=yes
    VISION=yes
    BREAKPAD=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-bionic-cuda10.2-cudnn7-py3.6-clang9)
    CUDA_VERSION=10.2

@@ -207,6 +219,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-bionic-cuda10.2-cudnn7-py3.8-gcc9)
    CUDA_VERSION=10.2

@@ -216,6 +229,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7)
    CUDA_VERSION=10.2

@@ -225,6 +239,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ;;
  pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9)
    CUDA_VERSION=11.0

@@ -234,6 +249,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ROCM_VERSION=3.9
    ;;
  pytorch-linux-bionic-rocm4.0.1-py3.6)

@@ -242,6 +258,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ROCM_VERSION=4.0.1
    ;;
  pytorch-linux-bionic-rocm4.1-py3.6)

@@ -250,6 +267,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ROCM_VERSION=4.1
    ;;
  pytorch-linux-bionic-rocm4.2-py3.6)

@@ -258,6 +276,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    ROCM_VERSION=4.2
    ;;
  *)

@@ -265,6 +284,7 @@ case "$image" in
    PROTOBUF=yes
    DB=yes
    VISION=yes
    BREAKPAD=yes
    echo "image '$image' did not match an existing build configuration"
    if [[ "$image" == *py* ]]; then
      extract_version_from_image_name py ANACONDA_PYTHON_VERSION

@@ -2,12 +2,18 @@

set -ex

git clone https://github.com/google/breakpad.git
cd breakpad
git clone https://github.com/malfet/breakpad.git -b pytorch/release-1.9
pushd breakpad

git clone https://chromium.googlesource.com/linux-syscall-support src/third_party/lss
pushd src/third_party/lss
# same as with breakpad, there are no real releases for this repo so use a
# commit as the pin
git checkout e1e7b0ad8ee99a875b272c8e33e308472e897660
popd

./configure
make
make install
cd ..
popd
rm -rf breakpad

@@ -61,6 +61,10 @@ RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi
RUN rm install_vision.sh
ENV INSTALLED_VISION ${VISION}

ADD ./common/install_openssl.sh install_openssl.sh
ENV OPENSSL_ROOT_DIR /opt/openssl
RUN bash ./install_openssl.sh

# Install ccache/sccache (do this last, so we get priority in PATH)
ADD ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH

@@ -93,9 +97,5 @@ ENV TORCH_NVCC_FLAGS "-Xfatbin -compress-all"
# Install LLVM dev version (Defined in the pytorch/builder github repository)
COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm

ADD ./common/install_openssl.sh install_openssl.sh
ENV OPENSSL_ROOT_DIR /opt/openssl
RUN bash ./install_openssl.sh

USER jenkins
CMD ["bash"]

@@ -113,6 +113,10 @@ ADD ./common/install_ninja.sh install_ninja.sh
RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi
RUN rm install_ninja.sh

ADD ./common/install_openssl.sh install_openssl.sh
RUN bash ./install_openssl.sh
ENV OPENSSL_ROOT_DIR /opt/openssl

# Install ccache/sccache (do this last, so we get priority in PATH)
ADD ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH

@@ -130,9 +134,5 @@ ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT}
# Install LLVM dev version (Defined in the pytorch/builder github repository)
COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm

ADD ./common/install_openssl.sh install_openssl.sh
RUN bash ./install_openssl.sh
ENV OPENSSL_ROOT_DIR /opt/openssl

USER jenkins
CMD ["bash"]

@@ -63,6 +63,7 @@ popd
# Clone the Builder master repo
retry git clone -q https://github.com/pytorch/builder.git "$BUILDER_ROOT"
pushd "$BUILDER_ROOT"
git checkout release/1.9
echo "Using builder from "
git --no-pager log --max-count 1
popd

@@ -9,10 +9,6 @@ python_nodot="\$(echo $DESIRED_PYTHON | tr -d m.u)"

# Set up Python
if [[ "$PACKAGE_TYPE" == conda ]]; then
  # There was a bug that was introduced in conda-package-handling >= 1.6.1 that makes archives
  # above a certain size fail out when attempting to extract
  # see: https://github.com/conda/conda-package-handling/issues/71
  conda install -y conda-package-handling=1.6.0
  retry conda create -qyn testenv python="$DESIRED_PYTHON"
  source activate testenv >/dev/null
elif [[ "$PACKAGE_TYPE" != libtorch ]]; then

@@ -38,6 +34,10 @@ if [[ "$DESIRED_CUDA" == "cu112" ]]; then
  EXTRA_CONDA_FLAGS="-c=conda-forge"
fi

# Move debug wheels out of the the package dir so they don't get installed
mkdir -p /tmp/debug_final_pkgs
mv /final_pkgs/debug-*.zip /tmp/debug_final_pkgs || echo "no debug packages to move"

# Install the package
# These network calls should not have 'retry's because they are installing
# locally and aren't actually network calls

@@ -68,6 +68,18 @@ if [[ -z "$DOCKER_IMAGE" ]]; then
  fi
fi

USE_GOLD_LINKER="OFF"
# GOLD linker can not be used if CUPTI is statically linked into PyTorch, see https://github.com/pytorch/pytorch/issues/57744
if [[ ${DESIRED_CUDA} == "cpu" ]]; then
  USE_GOLD_LINKER="ON"
fi

USE_WHOLE_CUDNN="OFF"
# Link whole cuDNN for CUDA-11.1 to include fp16 fast kernels
if [[  "$(uname)" == "Linux" && "${DESIRED_CUDA}" == "cu111" ]]; then
  USE_WHOLE_CUDNN="ON"
fi

# Default to nightly, since that's where this normally uploads to
PIP_UPLOAD_FOLDER='nightly/'
# We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it

@@ -169,7 +181,9 @@ export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}"
export CIRCLE_BRANCH="$CIRCLE_BRANCH"
export CIRCLE_WORKFLOW_ID="$CIRCLE_WORKFLOW_ID"

export USE_GOLD_LINKER=1
export USE_GOLD_LINKER="${USE_GOLD_LINKER}"
export USE_GLOO_WITH_OPENSSL="ON"
export USE_WHOLE_CUDNN="${USE_WHOLE_CUDNN}"
# =================== The above code will be executed inside Docker container ===================
EOL

@@ -50,43 +50,50 @@ else
fi

add_to_env_file() {
  local content
  content=$1
  # BASH_ENV should be set by CircleCI
  echo "${content}" >> "${BASH_ENV:-/tmp/env}"
  local name=$1
  local value=$2
  case "$value" in
    *\ *)
      # BASH_ENV should be set by CircleCI
      echo "${name}='${value}'" >> "${BASH_ENV:-/tmp/env}"
      ;;
    *)
      echo "${name}=${value}" >> "${BASH_ENV:-/tmp/env}"
      ;;
  esac
}

add_to_env_file "IN_CI=1"
add_to_env_file "COMMIT_SOURCE=${CIRCLE_BRANCH:-}"
add_to_env_file "BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}"
add_to_env_file "CIRCLE_PULL_REQUEST=${CIRCLE_PULL_REQUEST}"
add_to_env_file IN_CI 1
add_to_env_file COMMIT_SOURCE "${CIRCLE_BRANCH:-}"
add_to_env_file BUILD_ENVIRONMENT "${BUILD_ENVIRONMENT}"
add_to_env_file CIRCLE_PULL_REQUEST "${CIRCLE_PULL_REQUEST}"

if [[ "${BUILD_ENVIRONMENT}" == *-build ]]; then
  add_to_env_file "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2"
  add_to_env_file SCCACHE_BUCKET ossci-compiler-cache-circleci-v2

  SCCACHE_MAX_JOBS=$(( $(nproc) - 1 ))
  MEMORY_LIMIT_MAX_JOBS=8  # the "large" resource class on CircleCI has 32 CPU cores, if we use all of them we'll OOM
  MAX_JOBS=$(( ${SCCACHE_MAX_JOBS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${SCCACHE_MAX_JOBS} ))
  add_to_env_file "MAX_JOBS=${MAX_JOBS}"
  add_to_env_file MAX_JOBS "${MAX_JOBS}"

  if [ -n "${USE_CUDA_DOCKER_RUNTIME:-}" ]; then
    add_to_env_file "TORCH_CUDA_ARCH_LIST=5.2"
    add_to_env_file TORCH_CUDA_ARCH_LIST 5.2
  fi

  if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
    # This IAM user allows write access to S3 bucket for sccache & bazels3cache
    set +x
    add_to_env_file "XLA_CLANG_CACHE_S3_BUCKET_NAME=${XLA_CLANG_CACHE_S3_BUCKET_NAME:-}"
    add_to_env_file "AWS_ACCESS_KEY_ID=${CIRCLECI_AWS_ACCESS_KEY_FOR_SCCACHE_AND_XLA_BAZEL_S3_BUCKET_V2:-}"
    add_to_env_file "AWS_SECRET_ACCESS_KEY=${CIRCLECI_AWS_SECRET_KEY_FOR_SCCACHE_AND_XLA_BAZEL_S3_BUCKET_V2:-}"
    add_to_env_file XLA_CLANG_CACHE_S3_BUCKET_NAME "${XLA_CLANG_CACHE_S3_BUCKET_NAME:-}"
    add_to_env_file AWS_ACCESS_KEY_ID "${CIRCLECI_AWS_ACCESS_KEY_FOR_SCCACHE_AND_XLA_BAZEL_S3_BUCKET_V2:-}"
    add_to_env_file AWS_SECRET_ACCESS_KEY "${CIRCLECI_AWS_SECRET_KEY_FOR_SCCACHE_AND_XLA_BAZEL_S3_BUCKET_V2:-}"
    set -x
  else
    # This IAM user allows write access to S3 bucket for sccache
    set +x
    add_to_env_file "XLA_CLANG_CACHE_S3_BUCKET_NAME=${XLA_CLANG_CACHE_S3_BUCKET_NAME:-}"
    add_to_env_file "AWS_ACCESS_KEY_ID=${CIRCLECI_AWS_ACCESS_KEY_FOR_SCCACHE_S3_BUCKET_V4:-}"
    add_to_env_file "AWS_SECRET_ACCESS_KEY=${CIRCLECI_AWS_SECRET_KEY_FOR_SCCACHE_S3_BUCKET_V4:-}"
    add_to_env_file XLA_CLANG_CACHE_S3_BUCKET_NAME "${XLA_CLANG_CACHE_S3_BUCKET_NAME:-}"
    add_to_env_file AWS_ACCESS_KEY_ID "${CIRCLECI_AWS_ACCESS_KEY_FOR_SCCACHE_S3_BUCKET_V4:-}"
    add_to_env_file AWS_SECRET_ACCESS_KEY "${CIRCLECI_AWS_SECRET_KEY_FOR_SCCACHE_S3_BUCKET_V4:-}"
    set -x
  fi
fi

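The hunk above reworks add_to_env_file to take the variable name and value as separate arguments and to single-quote the value only when it contains a space. A minimal, self-contained sketch of that behavior (the values passed here are illustrative only, not taken from CI):

```bash
#!/usr/bin/env bash
# Illustrative only: mirrors the add_to_env_file behaviour shown in the diff.
BASH_ENV=/tmp/env

add_to_env_file() {
  local name=$1
  local value=$2
  case "$value" in
    *\ *) echo "${name}='${value}'" >> "${BASH_ENV:-/tmp/env}" ;;  # value with spaces gets quoted
    *)    echo "${name}=${value}"   >> "${BASH_ENV:-/tmp/env}" ;;  # plain value written as-is
  esac
}

add_to_env_file MAX_JOBS 8                      # appends: MAX_JOBS=8
add_to_env_file COMMIT_SOURCE "release/1.9"     # appends: COMMIT_SOURCE=release/1.9
add_to_env_file BUILD_ENVIRONMENT "linux build" # appends: BUILD_ENVIRONMENT='linux build'
```
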
@@ -32,7 +32,7 @@ pytorch_android_params: &pytorch_android_params
      default: ""
    lite_interpreter:
      type: string
      default: "0"
      default: "1"
  environment:
    BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single
    DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c"

@@ -59,7 +59,7 @@ pytorch_ios_params: &pytorch_ios_params
      default: "0"
    lite_interpreter:
      type: string
      default: "0"
      default: "1"
  environment:
    BUILD_ENVIRONMENT: << parameters.build_environment >>
    IOS_ARCH: << parameters.ios_arch >>

@@ -528,8 +528,8 @@
          no_output_timeout: "30m"
          command: |
            set -e
            if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
              echo "Run Build Test is not for BUILD_LITE_INTERPRETER, skipping."
            if [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
              echo "Run Build Test is not for full jit, skipping."
              exit 0
            fi
            PROJ_ROOT=/Users/distiller/project

@@ -557,8 +557,8 @@
            if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then
              echo "not SIMULATOR build, skip it."
              exit 0
            elif [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
              echo "Run Simulator Tests is not for BUILD_LITE_INTERPRETER, skipping."
            elif [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
              echo "Run Simulator Tests is not for full jit, skipping."
              exit 0
            fi
            WORKSPACE=/Users/distiller/workspace

@@ -200,7 +200,7 @@ fi

# Patch required to build xla
if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
  git clone --recursive https://github.com/pytorch/xla.git
  git clone --recursive -b r1.9 https://github.com/pytorch/xla.git
  ./xla/scripts/apply_patches.sh
fi

@@ -52,9 +52,9 @@ function get_exit_code() {
function file_diff_from_base() {
  # The fetch may fail on Docker hosts, this fetch is necessary for GHA
  set +e
  git fetch origin master --quiet
  git fetch origin release/1.9 --quiet
  set -e
  git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
  git diff --name-only "$(git merge-base origin/release/1.9 HEAD)" > "$1"
}

function get_bazel() {

@@ -28,7 +28,13 @@ fi
export PATH="${WORKSPACE_DIR}/miniconda3/bin:$PATH"
# shellcheck disable=SC1091
source "${WORKSPACE_DIR}"/miniconda3/bin/activate
retry conda install -y mkl mkl-include numpy=1.18.5 pyyaml=5.3 setuptools=46.0.0 cmake cffi ninja typing_extensions dataclasses pip

# NOTE: mkl 2021.3.0+ cmake requires sub-command PREPEND, may break the build
retry conda install -y \
  mkl=2021.2.0 mkl-include=2021.2.0 \
  numpy=1.18.5 pyyaml=5.3 setuptools=46.0.0 \
  cmake cffi ninja typing_extensions dataclasses pip

# The torch.hub tests make requests to GitHub.
#
# The certifi package from conda-forge is new enough to make the

@@ -368,7 +368,7 @@ test_backward_compatibility() {
  python -m venv venv
  # shellcheck disable=SC1091
  . venv/bin/activate
  pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
  pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_test.html
  pip show torch
  python dump_all_function_schemas.py --filename nightly_schemas.txt
  deactivate

@@ -194,6 +194,9 @@ cmake_dependent_option(
cmake_dependent_option(
    USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
    "USE_CUDNN" OFF)
cmake_dependent_option(
  USE_WHOLE_CUDNN "Use whole-library linking for cuDNN" OFF
    "USE_STATIC_CUDNN" OFF)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" OFF)

@@ -266,6 +269,9 @@ option(USE_ZSTD "Use ZSTD" OFF)
cmake_dependent_option(
  USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, and AArch64." "${CPU_INTEL}"
  "CPU_INTEL OR CPU_AARCH64" OFF)
cmake_dependent_option(
  USE_MKLDNN_ACL "Use Compute Library for the Arm architecture." OFF
  "USE_MKLDNN AND CPU_AARCH64" OFF)
set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN})
cmake_dependent_option(
    USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF

@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.4.1)
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" OFF)
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" ON)
message(
  STATUS
  "BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}")

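For context on the new USE_WHOLE_CUDNN option above: cmake_dependent_option only exposes an option when its dependency expression is true, otherwise it is forced to the trailing value, so USE_WHOLE_CUDNN can only be enabled when USE_STATIC_CUDNN is on. A hypothetical configure invocation that exercises it (the flag combination is illustrative, not taken from this diff's CI scripts):

```bash
# Illustrative only: whole-archive cuDNN linking requires static cuDNN,
# which in turn requires cuDNN and CUDA to be enabled.
cmake -DUSE_CUDA=ON -DUSE_CUDNN=ON -DUSE_STATIC_CUDNN=ON -DUSE_WHOLE_CUDNN=ON ..
```
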
@@ -17,8 +17,8 @@ android {
        }
        externalNativeBuild {
            cmake {
              if(System.env.BUILD_LITE_INTERPRETER == '1') {
                arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=ON"
              if(System.env.BUILD_LITE_INTERPRETER == '0') {
                arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=OFF"
              } else {
                arguments "-DANDROID_STL=c++_shared"
              }

@@ -37,12 +37,12 @@ android {
    sourceSets {
        main {
            java {
              if(System.env.BUILD_LITE_INTERPRETER == '1') {
                println 'Build pytorch_jni_lite'
              } else {
              if(System.env.BUILD_LITE_INTERPRETER == '0') {
                println 'Build pytorch_jni'
                exclude 'org/pytorch/LiteModuleLoader.java'
                exclude 'org/pytorch/LiteNativePeer.java'
              } else {
                println 'Build pytorch_jni_lite'
              }
            }
            jniLibs.srcDirs = ['src/main/jniLibs']

@@ -10,7 +10,7 @@ public final class PyTorchAndroid {
    if (!NativeLoader.isInitialized()) {
      NativeLoader.init(new SystemDelegate());
    }
    NativeLoader.loadLibrary("pytorch_jni");
    NativeLoader.loadLibrary("pytorch_jni_lite");
    PyTorchCodegenLoader.loadNativeLibs();
  }

@@ -14,11 +14,11 @@ namespace at {
//    OptionalDeviceGuard guard(device_of(tensor));

/// Return the Device of a Tensor, if the Tensor is defined.
inline optional<Device> device_of(const Tensor& t) {
inline c10::optional<Device> device_of(const Tensor& t) {
  if (t.defined()) {
    return make_optional(t.device());
    return c10::make_optional(t.device());
  } else {
    return nullopt;
    return c10::nullopt;
  }
}

@@ -29,11 +29,11 @@ inline optional<Device> device_of(const optional<Tensor>& t) {
/// Return the Device of a TensorList, if the list is non-empty and
/// the first Tensor is defined.  (This function implicitly assumes
/// that all tensors in the list have the same device.)
inline optional<Device> device_of(TensorList t) {
inline c10::optional<Device> device_of(TensorList t) {
  if (!t.empty()) {
    return device_of(t.front());
  } else {
    return nullopt;
    return c10::nullopt;
  }
}

@@ -372,6 +372,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  KERNEL(ADD_NS(pdist), "pdist", Tensor (const Tensor &, double), fp32)
  KERNEL(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
  KERNEL(ADD_NS(renorm), "renorm", Tensor (const Tensor &, const Scalar&, int64_t, const Scalar&), fp32)
  KERNEL(ADD_NS(grid_sampler), "grid_sampler", Tensor (const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
  // fp32_set_opt_dtype
  KERNEL(ADD_NS(prod), "prod", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype)
  KERNEL(ADD_NS(prod), "prod.dim_int", Tensor (const Tensor &, int64_t, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)

@@ -133,7 +133,6 @@ TORCH_API DimnameList get_names(const TensorImpl* impl);
// tensor is constructed with names=None.
TORCH_API c10::optional<DimnameList> get_opt_names(const TensorImpl* impl);

} // namespace impl

} // namespace at

@@ -495,6 +495,7 @@ _(aten, miopen_depthwise_convolution_backward_input) \
_(aten, miopen_depthwise_convolution_backward_weight) \
_(aten, miopen_rnn) \
_(aten, miopen_rnn_backward) \
_(aten, mish) \
_(aten, mkldnn_convolution) \
_(aten, mkldnn_convolution_backward) \
_(aten, mkldnn_convolution_backward_input) \

@@ -479,23 +479,12 @@ public:
        vsqrtq_f32(values.val[1]));
  }
  Vec256<float> reciprocal() const {
    float32x4_t r0 = vrecpeq_f32(values.val[0]);
    float32x4_t r1 = vrecpeq_f32(values.val[1]);
    // Run two more Netwon's method iterations to get more accurate results
    r0 = vmulq_f32(vrecpsq_f32(values.val[0], r0), r0);
    r0 = vmulq_f32(vrecpsq_f32(values.val[0], r0), r0);
    r1 = vmulq_f32(vrecpsq_f32(values.val[1], r1), r1);
    r1 = vmulq_f32(vrecpsq_f32(values.val[1], r1), r1);
    auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]);
    auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]);
    return Vec256<float>(r0, r1);
  }
  Vec256<float> rsqrt() const {
    float32x4_t r0 =  vrsqrteq_f32(values.val[0]);
    float32x4_t r1 =  vrsqrteq_f32(values.val[1]);
    r0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[0], r0), r0), r0);
    r0 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[0], r0), r0), r0);
    r1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[1], r1), r1), r1);
    r1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(values.val[1], r1), r1), r1);
    return Vec256<float>(r0, r1);
    return this->sqrt().reciprocal();
  }
  Vec256<float> pow(const Vec256<float> &exp) const {
    __at_align32__ float tmp[size()];

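For context on the NEON hunk above: vrecpeq_f32/vrsqrteq_f32 produce rough estimates, and each vrecpsq_f32 or vrsqrtsq_f32 step combined with a multiply performs one Newton-Raphson refinement; the replacement simply computes the exact quotient with vdivq_f32 and derives rsqrt as the reciprocal of sqrt. A sketch of the iterations the removed lines implement, with x the input lane and r_k the current estimate:

$$r_{k+1} = r_k\,(2 - x\,r_k) \;\to\; \tfrac{1}{x}, \qquad
r_{k+1} = \tfrac{1}{2}\,r_k\,(3 - x\,r_k^{2}) \;\to\; \tfrac{1}{\sqrt{x}}.$$
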
@@ -43,6 +43,8 @@ namespace at {
namespace cuda {
namespace detail {

std::function<void(void)> THCMagma_init;

// NB: deleter is dynamic, because we need it to live in a separate
// compilation unit (alt is to have another method in hooks, but
// let's not if we don't need to!)

@@ -51,9 +53,8 @@ std::unique_ptr<THCState, void (*)(THCState*)> CUDAHooks::initCUDA() const {
  THCState* thc_state = THCState_alloc();

  THCudaInit(thc_state);
#ifdef USE_MAGMA
  THCMagma_init(thc_state);
#endif
  if (THCMagma_init)
    THCMagma_init();
  return std::unique_ptr<THCState, void (*)(THCState*)>(
      thc_state, [](THCState* p) {
        if (p)

@@ -8,6 +8,9 @@

namespace at { namespace cuda { namespace detail {

// Callback to initialize THC Magma, which is implemented in torch_cuda_cu
TORCH_CUDA_CPP_API extern std::function<void()> THCMagma_init;

// The real implementation of CUDAHooksInterface
struct CUDAHooks : public at::CUDAHooksInterface {
  CUDAHooks(at::CUDAHooksArgs) {}

@@ -54,6 +54,10 @@ TORCH_META_FUNC(silu) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(mish) (const Tensor& self) {
  build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(softplus) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold
) {

@@ -127,6 +131,10 @@ DEFINE_DISPATCH(leaky_relu_backward_stub);
DEFINE_DISPATCH(silu_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(silu_backward_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(mish_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(mish_backward_stub);

TORCH_IMPL_FUNC(elu_out) (
  const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result

@@ -140,6 +148,12 @@ TORCH_IMPL_FUNC(silu_out) (
  silu_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(mish_out) (
  const Tensor& self, const Tensor& result
) {
  mish_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(softplus_out) (
  const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result
) {

@@ -306,6 +320,23 @@ Tensor math_silu_backward(
  return grad_output * (input_sigmoid * (1 + input * (1 - input_sigmoid)));
}

Tensor mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  Tensor grad_input = at::empty({0}, input.options());
  auto iter = TensorIterator::binary_op(grad_input, grad_output, input);
  mish_backward_stub(iter.device_type(), iter);
  return grad_input;
}

Tensor math_mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  auto input_tanh_softplus = at::tanh(at::softplus(input));
  auto input_sigmoid = at::sigmoid(input);
  return grad_output * (input_tanh_softplus + (input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)));
}

template <typename scalar_t>
inline void _rrelu_with_noise_train(
    Tensor& output,

@@ -54,6 +54,8 @@ DECLARE_DISPATCH(activation_fn, glu_stub);
DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, silu_stub);
DECLARE_DISPATCH(activation_backward_fn, silu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, mish_stub);
DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);

} // namespace native

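For reference, the forward and backward formulas that the new mish operator, math_mish_backward, and the CPU/CUDA kernels later in this diff encode are the standard Mish definition, restated here from the code:

$$\mathrm{mish}(x) = x \tanh\big(\mathrm{softplus}(x)\big) = x \tanh\big(\ln(1 + e^{x})\big)$$

$$\frac{d}{dx}\,\mathrm{mish}(x) = \tanh\big(\mathrm{softplus}(x)\big) + x\,\sigma(x)\,\big(1 - \tanh^{2}(\mathrm{softplus}(x))\big), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}$$
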
@@ -66,14 +66,15 @@ bool ceil_mode) {
    outputHeight, outputWidth, memory_format);

  /* resize output and indices */
  DimnameList maybe_names = input.has_names() ? input.names() : DimnameList{};
  if (input.ndimension() == 3) {
    set_output(0, {nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), input.names());
    set_output(0, {nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), maybe_names);
    /* indices will contain the locations for each output point */
    set_output(1, {nInputPlane, outputHeight, outputWidth}, {}, input.options().dtype(kLong), input.names());
    set_output(1, {nInputPlane, outputHeight, outputWidth}, {}, input.options().dtype(kLong), maybe_names);
  } else {
    set_output(0, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), input.names());
    set_output(0, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().memory_format(memory_format), maybe_names);
    /* indices will contain the locations for each output point */
    set_output(1, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().dtype(kLong), input.names());
    set_output(1, {nbatch, nInputPlane, outputHeight, outputWidth}, {}, input.options().dtype(kLong), maybe_names);
  }
}

@@ -152,7 +153,8 @@ const Tensor& indices) {
    outputHeight_for_shape_check, outputWidth_for_shape_check,
    memory_format);

  set_output(0, input.sizes(), {}, input.options().memory_format(memory_format), input.names());
  set_output(0, input.sizes(), {}, input.options().memory_format(memory_format),
             input.has_names() ? input.names() : DimnameList{});
}
} // namespace meta

@@ -30,6 +30,9 @@ namespace meta {
TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self);
  set_output(0, IntArrayRef({mat1.sizes()[0], mat2.sizes()[1]}), {}, self.options(), names);

@@ -932,12 +935,6 @@ static void addmm_impl_cpu_(
  auto m2_strides = m2.strides();
  auto m2_sizes = m2.sizes();

  // keeping TORCH_CHECKs here because othe mm methods also utilize this impl.
  // TODO move this to meta once all methods have migrated to structured kernel.
  TORCH_CHECK(
      m1_sizes[1] == m2_sizes[0], "mat1 and mat2 shapes cannot be multiplied (",
      m1_sizes[0], "x", m1_sizes[1], " and ", m2_sizes[0], "x", m2_sizes[1], ")");

  TORCH_CHECK(
      self_sizes[0] == m1_sizes[0] && self_sizes[1] == m2_sizes[1],
      "input shape is incompatible with matrix multiplication (",

@@ -1111,6 +1108,9 @@ TORCH_IMPL_FUNC(addmm_out_cpu)(const Tensor& self, const Tensor& mat1, const Ten
Tensor& mm_cpu_out(const Tensor & self, const Tensor & mat2, Tensor & result) {
  TORCH_CHECK(self.dim() == 2, "self must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
  TORCH_CHECK(
      self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
  native::resize_(result, {self.sizes()[0], mat2.sizes()[1]});
  addmm_impl_cpu_(result, result, self, mat2, 0, 1);
  auto names = at::namedinference::propagate_names_for_addmm(self, mat2, result);

@@ -47,12 +47,6 @@ bool cudnn_is_acceptable(const Tensor& self) {
  return true;
}

Tensor detach(const Tensor& self) {
  // this just exists to give us a hook in VariableType and an entry in Declarations.yaml
  //AT_ERROR("detach is not implemented for Tensor");
  return self;
}

Tensor & detach_(Tensor & self) {
  // this just exists to give us a hook in VariableType and an entry in Declarations.yaml
  //AT_ERROR("detach_ is not implemented for Tensor");

@@ -2170,6 +2170,12 @@ Tensor alias(const Tensor& self) {
    return alias_with_sizes_and_strides(self, self.sizes(), self.strides());
}

Tensor detach(const Tensor& self) {
  // this just exists to give us a hook in VariableType and an entry in Declarations.yaml
  //AT_ERROR("detach is not implemented for Tensor");
  return native::alias(self);
}

Tensor unfold(const Tensor& self, int64_t dimension, int64_t size, int64_t step) {
  // some special handling to deal with allow dimension == 0 when self.dim() == 0
  dimension = at::maybe_wrap_dim(dimension, self.dim(), /*wrap_scalar=*/true);

@@ -632,6 +632,40 @@ void silu_backward_kernel(TensorIterator& iter) {
      });
}

void mish_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() {
        using Vec = Vec256<scalar_t>;
        cpu_kernel_vec(
            iter,
            [](scalar_t x) -> scalar_t{
              return static_cast<scalar_t>(x * std::tanh(std::log1p(std::exp(x))));
            },
            [](Vec x_vec) -> Vec {
              return x_vec * x_vec.exp().log1p().tanh();
            });
      });
}

void mish_backward_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() {
        using Vec = Vec256<scalar_t>;
        const Vec kOneVec(scalar_t(1));
        cpu_kernel_vec(
            iter,
            [](scalar_t dy, scalar_t x) -> scalar_t {
              const scalar_t sigmoid =
                  scalar_t(1) / (scalar_t(1) + std::exp(-x));
              const scalar_t tanh_softplus = std::tanh(std::log1p(std::exp(x)));
              return dy * (tanh_softplus + x * sigmoid * (scalar_t(1) - tanh_softplus * tanh_softplus));
            },
            [kOneVec](Vec dy_vec, Vec x_vec) -> Vec {
              const Vec sigmoid = kOneVec / (kOneVec + x_vec.neg().exp());
              const Vec tanh_softplus = x_vec.exp().log1p().tanh();
              return dy_vec * (tanh_softplus + x_vec * sigmoid * (kOneVec - tanh_softplus * tanh_softplus));
            });
      });
}

} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

@@ -680,6 +714,10 @@ REGISTER_DISPATCH(glu_backward_stub, &glu_backward_kernel);
REGISTER_DISPATCH(silu_stub, &silu_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(mish_stub, &mish_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel);

} // namespace native
} // namespace at

@@ -496,6 +496,45 @@ void silu_backward_kernel(TensorIterator& iter) {
      });
}

void mish_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "mish_cuda",
      [&]() {
        gpu_kernel(
            iter,
            [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          using T_ACC = acc_type<scalar_t, true>;
          const T_ACC x_acc = static_cast<T_ACC>(x);
          return x_acc * c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc)));
      });
      });
}

void mish_backward_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "mish_backward_cuda",
      [&]() {
        gpu_kernel(
            iter,
            [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
          using T_ACC = acc_type<scalar_t, true>;
          const T_ACC dy_acc = static_cast<T_ACC>(dy);
          const T_ACC x_acc = static_cast<T_ACC>(x);
          const T_ACC s_acc =
              T_ACC(1) / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
          const T_ACC t_acc =
              c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc)));
          return dy_acc * (t_acc + x_acc * s_acc * (T_ACC(1) - t_acc * t_acc));
      });
      });
}

} // namespace

Tensor gelu_cuda(const Tensor& self) {

@@ -540,6 +579,8 @@ REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
REGISTER_DISPATCH(silu_stub, &silu_kernel);
REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
REGISTER_DISPATCH(mish_stub, &mish_kernel);
REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel);
REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda);

} // namespace native

@@ -75,7 +75,11 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
  // Make sure to keep addmm_cuda below in sync with this code; it
  // preflights a check to try to avoid actually needing to call
  // expand().
  TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D");
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
  checkAllSameGPU("addmm", args);

@@ -24,20 +24,27 @@ struct TensorDims {
  index_t sizes[MAX_DIMS];
};

template<typename index_t>
__global__ void write_indices(int64_t * inp, TensorDims<index_t> dims, int ndim, index_t n){
    CUDA_KERNEL_LOOP(index, n) { // this assumed int (not int64_t) index
      index_t div = 1;
      int64_t idx_flat = inp[index];
      for (int dim = ndim-1; dim >= 0; dim--){
        auto dim_size = dims.sizes[dim];
        inp[index + dim*n] = (idx_flat/div) % dim_size;
        div *= dim_size;
      }
template <typename index_t>
__global__ void write_indices(
    int64_t* inp,
    TensorDims<index_t> dims,
    int ndim,
    index_t n) {
  auto index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < n) {
    index_t div = 1;
    int64_t idx_flat = inp[index];
#pragma unroll
    for (int dim = MAX_DIMS; dim >= 0; dim--) {
      if (dim > ndim - 1)
        continue;
      auto dim_size = dims.sizes[dim];
      inp[index + dim * n] = (idx_flat / div) % dim_size;
      div *= dim_size;
    }
  }
}
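
// A host-side sketch of the same unravel step (illustrative only, not part of
// the patch): each flat index is decomposed into per-dimension coordinates by
// repeated div/mod, innermost dimension first, exactly as the kernel above
// writes them back into inp.
#include <cstdint>
#include <vector>

std::vector<int64_t> unravel_index(int64_t idx_flat, const std::vector<int64_t>& sizes) {
  std::vector<int64_t> coords(sizes.size());
  int64_t div = 1;
  for (int dim = static_cast<int>(sizes.size()) - 1; dim >= 0; --dim) {
    coords[dim] = (idx_flat / div) % sizes[dim];
    div *= sizes[dim];
  }
  return coords;  // e.g. sizes {2, 3, 4}, idx_flat 17 -> {1, 1, 1}
}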


} //anonymous namespace

template<typename scalar_t>

@@ -47,9 +47,9 @@ template <int N> struct alignas(N) OpaqueType { char data[N]; };

Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor& result) {
  TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
  TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), "Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
  TORCH_CHECK(n <= std::numeric_limits<int>::max(),
    "randperm of tensors larger than INT_MAX is not supported yet in pytorch");

  check_supported_max_int_with_precision(n, result);

  result.resize_({n});
@@ -73,13 +73,15 @@ Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor&
  const double log_threshold_12 = std::log(0.9) * 12;
  double nd = static_cast<double>(n);

  constexpr bool is_reduced_bits = true;
  int bits = std::min(64,
    static_cast<int>(std::ceil(std::log2(nd - (6 * nd * nd + 1) / log_threshold_12))));

  if (n == 0) {
    return result;
  } else if (bits <= 32) {
    // For asserting device type match of the generator and result,
    // we delegate that to the 'random_' function below.

    auto keys = at::empty(result.sizes(), opt.dtype(kInt)).random_(
      std::numeric_limits<int>::min(), std::numeric_limits<int>::max(), generator);
    auto keys_tmp = at::empty_like(keys);
			
		||||
							
								
								
									
										527
									
								
								aten/src/ATen/native/cuda/SpectralOps.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										527
									
								
								aten/src/ATen/native/cuda/SpectralOps.cpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,527 @@
 | 
			
		||||
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/cuda/CuFFTUtils.h>
#include <ATen/native/cuda/CuFFTPlanCache.h>

#include <cufft.h>
#include <cufftXt.h>

#include <cmath>
#include <vector>


namespace at { namespace native {

using namespace at::native::detail;

// Execute a pre-planned transform
static void exec_cufft_plan(
    const CuFFTConfig &config, void* in_data, void* out_data, bool forward) {
  auto& plan = config.plan();
#ifdef __HIP_PLATFORM_HCC__
  auto value_type = config.data_type();
  if (value_type == kFloat) {
    switch (config.transform_type()) {
      case CuFFTTransformType::C2C: {
        CUFFT_CHECK(hipfftExecC2C(plan, static_cast<hipfftComplex*>(in_data),
                                  static_cast<hipfftComplex*>(out_data),
                                  forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
        return;
      }
      case CuFFTTransformType::R2C: {
        CUFFT_CHECK(hipfftExecR2C(plan, static_cast<hipfftReal*>(in_data),
                                  static_cast<hipfftComplex*>(out_data)));
        return;
      }
      case CuFFTTransformType::C2R: {
        CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(in_data),
                                  static_cast<hipfftReal*>(out_data)));
        return;
      }
    }
  } else if (value_type == kDouble) {
    switch (config.transform_type()) {
      case CuFFTTransformType::C2C: {
        CUFFT_CHECK(hipfftExecZ2Z(plan, static_cast<hipfftDoubleComplex*>(in_data),
                                  static_cast<hipfftDoubleComplex*>(out_data),
                                  forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD));
        return;
      }
      case CuFFTTransformType::R2C: {
        CUFFT_CHECK(hipfftExecD2Z(plan, static_cast<hipfftDoubleReal*>(in_data),
                                  static_cast<hipfftDoubleComplex*>(out_data)));
        return;
      }
      case CuFFTTransformType::C2R: {
        CUFFT_CHECK(hipfftExecZ2D(plan, static_cast<hipfftDoubleComplex*>(in_data),
                                  static_cast<hipfftDoubleReal*>(out_data)));
        return;
      }
    }
  }
  TORCH_CHECK(false, "hipFFT doesn't support transforms on type: ", value_type);
#else
  CUFFT_CHECK(cufftXtExec(plan, in_data, out_data,
                          forward ? CUFFT_FORWARD : CUFFT_INVERSE));
#endif
}


// NOTE [ cuFFT Embedded Strides ]
//
// cuFFT supports a subset of arbitrary strides via their "advanced data layout"
// option (http://docs.nvidia.com/cuda/cufft/index.html#advanced-data-layout).
// Specifically, these are tensors that can be viewed as subtensors resulted
// from slicing a larger contiguous tensors. For such input tensors, let the
// sizes of the enclosing tensor be `inembed`, and we can have in 3d case:
//
//     input[x, y, z] = input[((x * inembed[1] + y) * inembed[2] + z)]
//
// Above is the simplified formula ignoring the batch dimension. In fact, the
// last dimension of the enclosing tensor doesn't have to be contiguous, i.e.,
// it can be greater than 1. Then one can set the base stride for the enclosing
// tensor with `istride`. Then we have
//
//     input[x, y, z] = input[((x * inembed[1] + y) * inembed[2] + z) * istride]
//
// For example, consider
//
//     enclosing = torch.zeros(6, 8, 10)  # contiguous
//     input = enclosing[:4, 2:6, 6:]
//     input.size()                       # [ 4,  4,  4]
//     input.stride()                     # [80, 10,  1]
//     # inembed = [6, 8, 10]
//     input[2, 1, 3] = input[((2 * 8) + 1) * 10 + 3]   # using above formula
//                    = input[173]
//                    = input[2 * 80 + 1 * 10 + 1 * 3]  # using strides directly
//
// Generally, the embedded strides can be computed as
//
//     embed[i] = stride[i - 1] / stride[i].
//
// Note that the value of embed[0] isn't used to compute indices and doesn't
// matter.
//
// Contrary to advanced data layout, simple layout means that *embeds have
// unit-strides. In particular, unit-stride refers to that the input and output
// tensors being contiguous, and that the strides at the innermost signal
// dimension being unit (1) w.r.t. the corresponding data type.
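
// A small sketch of the formula in the note above (illustrative only, and
// assuming the strides come from a sliced but otherwise contiguous tensor):
// embed[i] is recovered as stride[i - 1] / stride[i], so strides [80, 10, 1]
// yield embedded sizes [_, 8, 10], matching the worked example.
#include <cstdint>
#include <vector>

std::vector<int64_t> embedded_sizes(const std::vector<int64_t>& strides) {
  std::vector<int64_t> embed(strides.size(), 1);
  for (size_t i = 1; i < strides.size(); ++i) {
    embed[i] = strides[i - 1] / strides[i];
  }
  return embed;  // embed[0] is unused by cuFFT and left as 1 here
}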

#pragma push
#pragma diag_suppress 177   // Function was declared but never referenced
static inline Tensor _run_cufft(
    const CuFFTConfig &config, Tensor& input, int64_t signal_ndim,
    bool complex_input, bool complex_output, bool inverse,
    IntArrayRef checked_signal_sizes, fft_norm_mode norm, bool onesided,
    IntArrayRef output_sizes, bool input_was_cloned
) {
  if (config.should_clone_input() && !input_was_cloned) {
    input = input.clone(at::MemoryFormat::Contiguous);
  }

  auto& plan = config.plan();
  auto& ctx = at::globalContext();

  // set output
  auto output = at::empty(output_sizes, input.options());

  // set to current stream
  CUFFT_CHECK(cufftSetStream(plan, at::cuda::getCurrentCUDAStream()));

  auto ws = at::empty({ config.workspace_size() }, at::device(at::kCUDA).dtype(at::kByte));
  CUFFT_CHECK(cufftSetWorkArea(plan, ws.data_ptr()));

  // run
  exec_cufft_plan(config, input.data_ptr(), output.data_ptr(), !inverse);

  // rescale if requested
  auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1];
  if (norm != fft_norm_mode::none) {
    auto signal_numel = c10::multiply_integers(checked_signal_sizes);
    double scale_denom;
    if (norm == fft_norm_mode::by_root_n) {
      scale_denom = std::sqrt(static_cast<double>(signal_numel));
    } else {
      scale_denom = static_cast<double>(signal_numel);
    }
    if (!complex_input && complex_output && !onesided) {
      auto end_data_slice = infer_ft_real_to_complex_onesided_size(size_last_signal_dim);
      output.narrow(signal_ndim, 0, end_data_slice).div_(scale_denom);
    } else {
      output.div_(scale_denom);
    }
  }

  // if needed, fill out the other half using conjugate symmetry
  if (!complex_input && complex_output && !onesided) {
    DimVector signal_dims(signal_ndim);
    std::iota(signal_dims.begin(), signal_dims.end(), 1);
    auto out_as_complex = at::view_as_complex(output);
    at::native::_fft_fill_with_conjugate_symmetry_(out_as_complex, signal_dims);
  }
  return output;
}
#pragma pop

// The cuFFT plan cache
// unique_ptr for nullability and to avoid reference invalidation on vector resize
static std::vector<std::unique_ptr<CuFFTParamsLRUCache>> plan_caches;
static std::mutex plan_caches_mutex;

static inline
CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {
  std::lock_guard<std::mutex> guard(plan_caches_mutex);

  AT_ASSERT(device_index >= 0);

  if (device_index >= plan_caches.size()) {
    plan_caches.resize(device_index + 1);
  }

  if (!plan_caches[device_index]) {
    plan_caches[device_index] = std::make_unique<CuFFTParamsLRUCache>();
  }

  return *plan_caches[device_index];
}


namespace detail {

int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).max_size();
}

void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).resize(max_size);
}

int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_get_plan_cache_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).size();
}

void cufft_clear_plan_cache_impl(int64_t device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_clear_plan_cache: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).clear();
}

} // namespace at::native::detail

namespace {
constexpr int64_t cufft_max_ndim = 3;

// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static const Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
                         IntArrayRef dim, bool forward) {
  const auto ndim = self.dim();
  const int64_t signal_ndim = dim.size();
  const auto batch_dims = ndim - signal_ndim;

  // Permute dimensions so batch dimensions come first, and in stride order
  // This maximizes data locality when collapsing to a single batch dimension
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});

  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
                                  [&](int64_t d) {return !is_transformed_dim[d]; });
  auto self_strides = self.strides();
  std::sort(dim_permute.begin(), batch_end,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  std::copy(dim.cbegin(), dim.cend(), batch_end);
  auto input = self.permute(dim_permute);

  // Collapse batch dimensions into a single dimension
  DimVector batched_sizes(signal_ndim + 1);
  batched_sizes[0] = -1;
  std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1);
  input = input.reshape(batched_sizes);

  const auto batch_size = input.sizes()[0];
  DimVector signal_size(signal_ndim + 1);
  signal_size[0] = batch_size;
  for (int64_t i = 0; i < signal_ndim; ++i) {
    auto in_size = input.sizes()[i + 1];
    auto out_size = out_sizes[dim[i]];
    signal_size[i + 1] = std::max(in_size, out_size);
    TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] ||
                          in_size == (signal_size[i + 1] / 2) + 1);
    TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] ||
                          out_size == (signal_size[i + 1] / 2) + 1);
  }

  batched_sizes[0] = batch_size;
  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
  for (size_t i = 0; i < dim.size(); ++i) {
    batched_out_sizes[i + 1] = out_sizes[dim[i]];
  }
  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);

  // Create the transform plan (either from cache or locally)
  const auto value_type = c10::toValueType(input.scalar_type());
  auto fft_type = GetCuFFTTransformType(input.is_complex(), out.is_complex());
  CuFFTParams Params(input.strides(), out.strides(), signal_size, fft_type, value_type);
  CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(input.device().index());
  std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
  c10::optional<CuFFTConfig> uncached_plan;
  const CuFFTConfig * config = nullptr;

  if (plan_cache.max_size() > 0) {
    guard.lock();
    if (plan_cache.max_size() > 0) {  // check again after acquiring the lock
      config = &plan_cache.lookup(Params);
    }
  }

  if (config == nullptr) {
    uncached_plan.emplace(Params);
    config = &uncached_plan.value();
  }

  auto & plan = config->plan();

  if (config->should_clone_input()) {
    input = input.clone(MemoryFormat::Contiguous);
  }

  // prepare cufft for execution
  CUFFT_CHECK(cufftSetStream(plan, at::cuda::getCurrentCUDAStream()));
  auto workspace = at::empty({ config->workspace_size() }, at::device(at::kCUDA).dtype(at::kByte));
  CUFFT_CHECK(cufftSetWorkArea(plan, workspace.data_ptr()));

  // execute transform plan
  exec_cufft_plan(*config, input.data_ptr(), out.data_ptr(), forward);

  // Inplace reshaping to original batch shape and inverting the dimension permutation
  DimVector out_strides(ndim);
  int64_t batch_numel = 1;
  for (int64_t i = batch_dims - 1; i >= 0; --i) {
    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
    batch_numel *= out_sizes[dim_permute[i]];
  }
  for (int64_t i = batch_dims; i < ndim; ++i) {
    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
  }
  return out.as_strided_(out_sizes, out_strides, out.storage_offset());
}
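
// A self-contained sketch of the dimension bookkeeping in _exec_fft above
// (illustrative only): batch dims are partitioned to the front, sorted by
// decreasing stride, and the transformed dims are appended, mirroring how
// dim_permute is built before the batch dims are collapsed into one.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> make_dim_permute(const std::vector<int64_t>& strides,
                                      const std::vector<int64_t>& dim) {
  const int64_t ndim = static_cast<int64_t>(strides.size());
  std::vector<bool> is_transformed(ndim, false);
  for (auto d : dim) is_transformed[d] = true;

  std::vector<int64_t> perm(ndim);
  std::iota(perm.begin(), perm.end(), int64_t{0});
  auto batch_end = std::partition(perm.begin(), perm.end(),
                                  [&](int64_t d) { return !is_transformed[d]; });
  std::sort(perm.begin(), batch_end,
            [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });
  std::copy(dim.begin(), dim.end(), batch_end);
  return perm;  // e.g. strides {12, 4, 1}, dim {2} -> perm {0, 1, 2}
}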

// Calculates the normalization constant and applies it in-place to self
// sizes is the sizes of a twosided tensor and dims are all transformed dims
double _fft_normalization_scale(int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto norm = static_cast<fft_norm_mode>(normalization);
  if (norm == fft_norm_mode::none) {
    return 1.0;
  }

  int64_t signal_numel = 1;
  for (auto dim : dims) {
    signal_numel *= sizes[dim];
  }
  const double scale_denom = (norm == fft_norm_mode::by_root_n) ?
    std::sqrt(signal_numel) : static_cast<double>(signal_numel);
  return 1.0 / scale_denom;
}

const Tensor& _fft_apply_normalization(const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto scale = _fft_normalization_scale(normalization, sizes, dims);
  return (scale == 1.0) ? self : self.mul_(scale);
}

Tensor& _fft_apply_normalization_out(Tensor& out, const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto scale = _fft_normalization_scale(normalization, sizes, dims);
  return at::mul_out(out, self, c10::scalar_to_tensor(scale));
}
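
// Illustrative sketch (not from the patch) of the scale computed above: the
// result is 1, 1/sqrt(n), or 1/n depending on the fft_norm_mode (none,
// by_root_n, by_n), where n is the product of the transformed sizes.
#include <cmath>
#include <cstdint>
#include <vector>

double normalization_scale_ref(bool by_n, bool by_root_n,
                               const std::vector<int64_t>& sizes,
                               const std::vector<int64_t>& dims) {
  if (!by_n && !by_root_n) return 1.0;
  double numel = 1.0;
  for (auto d : dims) numel *= static_cast<double>(sizes[d]);
  return 1.0 / (by_root_n ? std::sqrt(numel) : numel);
}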

}  // namespace (anonymous)

// n-dimensional real to complex FFT
Tensor _fft_r2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector onesided_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  onesided_sizes[last_dim] = last_dim_halfsize;
  IntArrayRef out_sizes = onesided ? onesided_sizes : input_sizes;

  const auto out_options = self.options().dtype(c10::toComplexType(self.scalar_type()));
  auto output = at::empty(out_sizes, out_options);

  // CuFFT requires real input to be over-aligned, as if it were complex
  const auto complex_size = 2 * self.element_size();
  const bool complex_aligned = (
      reinterpret_cast<std::uintptr_t>(self.data_ptr()) % complex_size == 0);
  auto working_tensor = self;
  if (!complex_aligned) {
    working_tensor = self.movedim(last_dim, -1)
                         .clone(MemoryFormat::Contiguous)
                         .movedim(-1, last_dim);
  }

  // First do the R2C transform on the last dimension
  {
    auto target_sizes = dim.size() == 1 ? out_sizes : onesided_sizes;
    _exec_fft(output, working_tensor, target_sizes, last_dim, /*forward=*/true);
    if (dim.size() > 1) {
      working_tensor = at::empty(out_sizes, out_options);
    }
  }

  // Then any remaining C2C transforms
  DimVector sorted_dims(dim.begin(), dim.end() - 1);
  while (!sorted_dims.empty()) {
    std::swap(output, working_tensor);

    // Resort dimensions every time as _exec_fft re-strides the output
    auto strides = working_tensor.strides();
    std::sort(sorted_dims.begin(), sorted_dims.end(),
              [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });

    const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());
    auto last_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);

    // Intermediate results are always onesided
    _exec_fft(output, working_tensor, onesided_sizes, last_dims, /*forward=*/true);
    sorted_dims.resize(sorted_dims.size() - max_dims);
  }

  // Only need to normalize the onesided slice since data in the other half is overwritten
  auto out_slice = output.slice(last_dim, 0, last_dim_halfsize);
  _fft_apply_normalization(out_slice, normalization, input_sizes, dim);

  if (!onesided) {
    if (output.sizes()[last_dim] != out_sizes[last_dim]) {
      working_tensor.resize_(out_sizes, MemoryFormat::Contiguous);
      working_tensor.slice(last_dim, 0, last_dim_halfsize).copy_(output);
      output = std::move(working_tensor);
    }
    at::native::_fft_fill_with_conjugate_symmetry_(output, dim);
  }
  return output;
}

Tensor& _fft_r2c_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, bool onesided, Tensor& out) {
  auto result = _fft_r2c_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);
  if (onesided) {
    return _fft_apply_normalization_out(out, result, normalization, self.sizes(), dim);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  _fft_apply_normalization_out(out_slice, result, normalization, self.sizes(), dim);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

// n-dimensional complex to real IFFT
Tensor _fft_c2r_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t lastdim) {
  TORCH_CHECK(self.is_complex());
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = lastdim;

  // First complete any C2C transforms
  Tensor temp;
  if (dim.size() > 1) {
    temp = _fft_c2c_cufft(
        self, dim.slice(0, dim.size() - 1),
        static_cast<int64_t>(fft_norm_mode::none), /*forward=*/false);
  } else {
    // Complex to real FFTs may overwrite the input buffer, so must always clone (gh-34551)
    temp = self.clone(MemoryFormat::Contiguous);
  }

  // Finally, do a 1D C2R transform
  // TODO: could transform up to 2 other dims in the same cuFFT operation
  auto output = at::empty(out_sizes, self.options().dtype(c10::toValueType(self.scalar_type())));
  _exec_fft(output, temp, out_sizes, dim.back(), /*forward=*/false);
  return _fft_apply_normalization(output, normalization, out_sizes, dim);
}

Tensor& _fft_c2r_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, int64_t lastdim, Tensor& out) {
  auto result = _fft_c2r_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), lastdim);
  return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim);
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }

  auto out_sizes = self.sizes();
  auto output = at::empty(out_sizes, self.options());

  // Perform any number of C2C transforms
  DimVector sorted_dims(dim.begin(), dim.end());
  auto self_strides = self.strides();
  auto working_tensor = self;
  while (true) {
    // Sort dimensions every time as _exec_fft re-strides the output
    auto strides = working_tensor.strides();
    std::sort(sorted_dims.begin(), sorted_dims.end(),
              [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });

    const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());
    auto first_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);

    _exec_fft(output, working_tensor, out_sizes, first_dims, forward);
    sorted_dims.resize(sorted_dims.size() - max_dims);

    if (sorted_dims.empty()) {
      break;
    }

    if (working_tensor.is_same(self)) {
      working_tensor = std::move(output);
      output = at::empty(out_sizes, self.options());
    } else {
      std::swap(output, working_tensor);
    }
  }

  return _fft_apply_normalization(output, normalization, out_sizes, dim);
}

Tensor& _fft_c2c_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, bool forward, Tensor& out) {
  auto result = _fft_c2c_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
  return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim);
}


}} // at::native
@@ -16,10 +16,6 @@
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <cufft.h>
#include <cufftXt.h>

#include <cmath>
#include <vector>
@@ -83,6 +79,7 @@ struct HermitianSymmetryOffsetCalculator {
  }
};


// out[:] = conj(in[:]) where in and out ordering is generalized by offset calculators
template <typename scalar_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(cuda::detail::CUDA_NUM_THREADS)
@@ -135,504 +132,4 @@ void _fft_fill_with_conjugate_symmetry_cuda_(

REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);

 | 
			
		||||
  return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}} // at::native
 | 
			
		||||
 | 
			
		||||
@ -141,8 +141,9 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
 | 
			
		||||
      num_inp,
 | 
			
		||||
      " is too big to be handled by cub");
 | 
			
		||||
  Tensor sorted;
 | 
			
		||||
  Tensor self_c = self.contiguous();
 | 
			
		||||
  if (consecutive) {
 | 
			
		||||
    sorted = self;
 | 
			
		||||
    sorted = self_c;
 | 
			
		||||
  } else {
 | 
			
		||||
    sorted = at::empty({num_inp}, self.options());
 | 
			
		||||
  }
 | 
			
		||||
@ -151,14 +152,14 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
 | 
			
		||||
  Tensor sorted_indices;
 | 
			
		||||
  if (!return_inverse) {
 | 
			
		||||
    if (!consecutive) {
 | 
			
		||||
      cuda::cub::sort_keys(self.data_ptr<scalar_t>(), sorted_data, num_inp);
 | 
			
		||||
      cuda::cub::sort_keys(self_c.data_ptr<scalar_t>(), sorted_data, num_inp);
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    if (!consecutive) {
 | 
			
		||||
      Tensor range = at::arange(0, num_inp, options);
 | 
			
		||||
      sorted_indices = at::empty({num_inp}, options);
 | 
			
		||||
      cuda::cub::sort_pairs(
 | 
			
		||||
          self.data_ptr<scalar_t>(),
 | 
			
		||||
          self_c.data_ptr<scalar_t>(),
 | 
			
		||||
          sorted_data,
 | 
			
		||||
          range.data_ptr<int64_t>(),
 | 
			
		||||
          sorted_indices.data_ptr<int64_t>(),
 | 
			
		||||
 | 
			
		||||
@ -2148,6 +2148,7 @@
 | 
			
		||||
- func: _cufft_clear_plan_cache(int device_index) -> ()
 | 
			
		||||
 | 
			
		||||
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
 | 
			
		||||
  device_check: NoCheck   # TensorIterator
 | 
			
		||||
  variants: function, method
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: index
 | 
			
		||||
@ -2172,6 +2173,7 @@
 | 
			
		||||
  variants: function, method
 | 
			
		||||
 | 
			
		||||
- func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
 | 
			
		||||
  device_check: NoCheck   # delegate to _index_put_impl_, which leverages TensorIterator
 | 
			
		||||
  variants: function, method
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CompositeExplicitAutograd: index_put_
 | 
			
		||||
@ -2182,6 +2184,7 @@
 | 
			
		||||
  # - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Scalar v)
 | 
			
		||||
 | 
			
		||||
- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
 | 
			
		||||
  device_check: NoCheck   # delegate to _index_put_impl_ after clone, which leverages TensorIterator
 | 
			
		||||
  variants: function, method
 | 
			
		||||
 | 
			
		||||
- func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
 | 
			
		||||
@ -3515,6 +3518,31 @@
 | 
			
		||||
    CPU, CUDA: silu_backward
 | 
			
		||||
    CompositeImplicitAutograd: math_silu_backward
 | 
			
		||||
 | 
			
		||||
- func: mish(Tensor self) -> Tensor
 | 
			
		||||
  structured_delegate: mish.out
 | 
			
		||||
  python_module: nn
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CompositeExplicitAutograd: mish
 | 
			
		||||
 | 
			
		||||
- func: mish_(Tensor(a!) self) -> Tensor(a!)
 | 
			
		||||
  structured_delegate: mish.out
 | 
			
		||||
  python_module: nn
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CompositeExplicitAutograd: mish_
 | 
			
		||||
 | 
			
		||||
- func: mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
 | 
			
		||||
  structured: True
 | 
			
		||||
  structured_inherits: TensorIteratorBase
 | 
			
		||||
  python_module: nn
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: mish_out
 | 
			
		||||
 | 
			
		||||
- func: mish_backward(Tensor grad_output, Tensor self) -> Tensor
 | 
			
		||||
  python_module: nn
 | 
			
		||||
  dispatch:
 | 
			
		||||
    CPU, CUDA: mish_backward
 | 
			
		||||
    CompositeImplicitAutograd: math_mish_backward
 | 
			
		||||
 | 
			
		||||
- func: sigmoid(Tensor self) -> Tensor
 | 
			
		||||
  device_check: NoCheck   # TensorIterator
 | 
			
		||||
  structured_delegate: sigmoid.out
 | 
			
		||||
@ -4780,9 +4808,9 @@
 | 
			
		||||
# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
 | 
			
		||||
# the default would never make sense.
 | 
			
		||||
 | 
			
		||||
- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 | 
			
		||||
- func: _sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 | 
			
		||||
 | 
			
		||||
- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 | 
			
		||||
- func: _sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 | 
			
		||||
 | 
			
		||||
- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,7 @@ SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
 | 
			
		||||
// TODO: This constructor should probably use an ATen abstract method in order
 | 
			
		||||
// to make autograd dispatch available for the CSR constructor. See the relevant
 | 
			
		||||
// note in native_functions.yaml.
 | 
			
		||||
Tensor sparse_csr_tensor(
 | 
			
		||||
Tensor _sparse_csr_tensor(
 | 
			
		||||
    const Tensor& crow_indices,
 | 
			
		||||
    const Tensor& col_indices,
 | 
			
		||||
    const Tensor& values,
 | 
			
		||||
@ -86,7 +86,7 @@ Tensor sparse_csr_tensor(
 | 
			
		||||
  return self;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor sparse_csr_tensor(
 | 
			
		||||
Tensor _sparse_csr_tensor(
 | 
			
		||||
    const Tensor& crow_indices,
 | 
			
		||||
    const Tensor& col_indices,
 | 
			
		||||
    const Tensor& values,
 | 
			
		||||
@ -125,7 +125,7 @@ Tensor sparse_csr_tensor(
 | 
			
		||||
    size[1] = 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return at::sparse_csr_tensor(
 | 
			
		||||
  return at::_sparse_csr_tensor(
 | 
			
		||||
      crow_indices, col_indices, values, size, options);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -110,6 +110,15 @@ TEST(MathKernelTest, SiluBackward) {
 | 
			
		||||
  ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 | 
			
		||||
TEST(MathKernelTest, MishBackward) {
 | 
			
		||||
  const auto input = rand({20, 10});
 | 
			
		||||
  const auto grad_output = rand({20, 10});
 | 
			
		||||
  auto out = at::native::mish_backward(grad_output, input);
 | 
			
		||||
  auto math_out = at::native::math_mish_backward(grad_output, input);
 | 
			
		||||
  ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 | 
			
		||||
TEST(MathKernelTest, NarrowCopy)  {
 | 
			
		||||
  auto x = rand({5, 8, 7});
 | 
			
		||||
 | 
			
		||||
@ -47,8 +47,6 @@ TORCH_CUDA_CPP_API int THCState_getPeerToPeerAccess(THCState* state, int dev, in
 | 
			
		||||
 | 
			
		||||
TORCH_CUDA_CPP_API c10::Allocator* THCState_getCudaHostAllocator(THCState* state);
 | 
			
		||||
 | 
			
		||||
TORCH_CUDA_CPP_API void THCMagma_init(THCState *state);
 | 
			
		||||
 | 
			
		||||
/* For the current device and stream, returns the allocated scratch space */
 | 
			
		||||
TORCH_CUDA_CPP_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@
 | 
			
		||||
#include <THC/THCStorage.hpp>
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <ATen/native/cuda/MiscUtils.h>
 | 
			
		||||
#include <ATen/cuda/detail/CUDAHooks.h>
 | 
			
		||||
 | 
			
		||||
#ifdef USE_MAGMA
 | 
			
		||||
#include <magma_v2.h>
 | 
			
		||||
@ -17,12 +18,19 @@
 | 
			
		||||
 | 
			
		||||
#define NoMagma(name) "No CUDA implementation of '" #name "'. Install MAGMA and rebuild cutorch (http://icl.cs.utk.edu/magma/)"
 | 
			
		||||
 | 
			
		||||
void THCMagma_init(THCState *state)
 | 
			
		||||
{
 | 
			
		||||
namespace {
 | 
			
		||||
void _THCMagma_init() {
 | 
			
		||||
#ifdef USE_MAGMA
 | 
			
		||||
  magma_init();
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct Initializer {
 | 
			
		||||
  Initializer() {
 | 
			
		||||
    ::at::cuda::detail::THCMagma_init = _THCMagma_init;
 | 
			
		||||
  };
 | 
			
		||||
} initializer;
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
#include <THC/generic/THCTensorMathMagma.cu>
 | 
			
		||||
#include <THC/THCGenerateAllTypes.h>
 | 
			
		||||
 | 
			
		||||
@ -27,7 +27,7 @@ def gen_sparse_csr(shape, nnz):
 | 
			
		||||
        dense[f] = fill_value
 | 
			
		||||
    dense = torch.from_numpy(dense.reshape(shape))
 | 
			
		||||
 | 
			
		||||
    return dense.to_sparse_csr()
 | 
			
		||||
    return dense._to_sparse_csr()
 | 
			
		||||
 | 
			
		||||
def gen_sparse_coo(shape, nnz):
 | 
			
		||||
    dense = np.random.randn(*shape)
 | 
			
		||||
@ -51,4 +51,4 @@ def gen_sparse_coo_and_csr(shape, nnz):
 | 
			
		||||
        dense[f] = 0
 | 
			
		||||
 | 
			
		||||
    dense = torch.from_numpy(dense.reshape(shape))
 | 
			
		||||
    return dense.to_sparse(), dense.to_sparse_csr()
 | 
			
		||||
    return dense.to_sparse(), dense._to_sparse_csr()
 | 
			
		||||
 | 
			
		||||
@ -33,17 +33,21 @@ class C10_CUDA_API CUDAError : public c10::Error {
 | 
			
		||||
    }                                                            \
 | 
			
		||||
  } while (0)
 | 
			
		||||
#else
 | 
			
		||||
#define C10_CUDA_CHECK(EXPR)                                              \
 | 
			
		||||
  do {                                                                    \
 | 
			
		||||
    cudaError_t __err = EXPR;                                             \
 | 
			
		||||
    if (__err != cudaSuccess) {                                           \
 | 
			
		||||
      auto error_unused C10_UNUSED = cudaGetLastError();                  \
 | 
			
		||||
      auto _cuda_check_prefix = c10::cuda::get_cuda_check_prefix();       \
 | 
			
		||||
      throw c10::CUDAError(                                               \
 | 
			
		||||
          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},          \
 | 
			
		||||
          TORCH_CHECK_MSG(                                                \
 | 
			
		||||
              false, "", _cuda_check_prefix, cudaGetErrorString(__err))); \
 | 
			
		||||
    }                                                                     \
 | 
			
		||||
#define C10_CUDA_CHECK(EXPR)                                        \
 | 
			
		||||
  do {                                                              \
 | 
			
		||||
    cudaError_t __err = EXPR;                                       \
 | 
			
		||||
    if (__err != cudaSuccess) {                                     \
 | 
			
		||||
      auto error_unused C10_UNUSED = cudaGetLastError();            \
 | 
			
		||||
      auto _cuda_check_suffix = c10::cuda::get_cuda_check_suffix(); \
 | 
			
		||||
      throw c10::CUDAError(                                         \
 | 
			
		||||
          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},    \
 | 
			
		||||
          TORCH_CHECK_MSG(                                          \
 | 
			
		||||
              false,                                                \
 | 
			
		||||
              "",                                                   \
 | 
			
		||||
              "CUDA error: ",                                       \
 | 
			
		||||
              cudaGetErrorString(__err),                            \
 | 
			
		||||
              _cuda_check_suffix));                                 \
 | 
			
		||||
    }                                                               \
 | 
			
		||||
  } while (0)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -141,17 +141,16 @@ void device_synchronize() {
 | 
			
		||||
  C10_CUDA_CHECK(cudaDeviceSynchronize());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const char* get_cuda_check_prefix() noexcept {
 | 
			
		||||
const char* get_cuda_check_suffix() noexcept {
 | 
			
		||||
  static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING");
 | 
			
		||||
  static bool blocking_enabled =
 | 
			
		||||
      (device_blocking_flag && atoi(device_blocking_flag));
 | 
			
		||||
  if (blocking_enabled) {
 | 
			
		||||
    return "CUDA error: ";
 | 
			
		||||
    return "";
 | 
			
		||||
  } else {
 | 
			
		||||
    return "CUDA kernel errors might be "
 | 
			
		||||
           "asynchronously reported at some other API call,so the "
 | 
			
		||||
           "stacktrace below might be incorrect. For debugging "
 | 
			
		||||
           "consider passing CUDA_LAUNCH_BLOCKING=1. CUDA error: ";
 | 
			
		||||
    return "\nCUDA kernel errors might be asynchronously reported at some"
 | 
			
		||||
           " other API call,so the stacktrace below might be incorrect."
 | 
			
		||||
           "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -30,7 +30,7 @@ C10_CUDA_API void set_device(DeviceIndex device);
 | 
			
		||||
 | 
			
		||||
C10_CUDA_API void device_synchronize();
 | 
			
		||||
 | 
			
		||||
C10_CUDA_API const char* get_cuda_check_prefix() noexcept;
 | 
			
		||||
C10_CUDA_API const char* get_cuda_check_suffix() noexcept;
 | 
			
		||||
 | 
			
		||||
} // namespace cuda
 | 
			
		||||
} // namespace c10
 | 
			
		||||
 | 
			
		||||
@ -176,7 +176,7 @@ endif()
 | 
			
		||||
if(BUILD_SPLIT_CUDA)
 | 
			
		||||
  # Splitting the source files that'll be in torch_cuda between torch_cuda_cu and torch_cuda_cpp
 | 
			
		||||
  foreach(tmp ${Caffe2_GPU_SRCS})
 | 
			
		||||
    if("${tmp}" MATCHES "(.*aten.*\\.cu|.*(b|B)las.*|.*((s|S)olver|Register.*CUDA|Legacy|THC|CUDAHooks|detail/|TensorShapeCUDA).*\\.cpp)" AND NOT "${tmp}" MATCHES ".*(THC((CachingHost)?Allocator|General)).*")
 | 
			
		||||
    if("${tmp}" MATCHES "(.*aten.*\\.cu|.*(b|B)las.*|.*((s|S)olver|Register.*CUDA|Legacy|THC|TensorShapeCUDA).*\\.cpp)" AND NOT "${tmp}" MATCHES ".*(THC((CachingHost)?Allocator|General)).*")
 | 
			
		||||
      # Currently, torch_cuda_cu will have all the .cu files in aten, as well as some others that depend on those files
 | 
			
		||||
      list(APPEND Caffe2_GPU_SRCS_CU ${tmp})
 | 
			
		||||
    else()
 | 
			
		||||
@ -738,6 +738,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
 | 
			
		||||
      ${TORCH_SRC_DIR}/csrc/api/src/optim/schedulers/step_lr.cpp
 | 
			
		||||
      ${TORCH_SRC_DIR}/csrc/api/src/serialize/input-archive.cpp
 | 
			
		||||
      ${TORCH_SRC_DIR}/csrc/api/src/serialize/output-archive.cpp
 | 
			
		||||
      ${TORCH_SRC_DIR}/csrc/utils/crash_handler.cpp
 | 
			
		||||
    )
 | 
			
		||||
  endif()
 | 
			
		||||
 | 
			
		||||
@ -1020,9 +1021,10 @@ endif()
 | 
			
		||||
if(LINUX)
 | 
			
		||||
  find_library(BREAKPAD_LIB breakpad_client)
 | 
			
		||||
  find_path(BREAKPAD_INCLUDE_DIR breakpad)
 | 
			
		||||
  if(BREAKPAD_LIB_FOUND AND BREAKPAD_INCLUDE_DIR_FOUND)
 | 
			
		||||
    target_link_libraries(torch_cpu PRIVATE ${BREAKPAD_LIB})
 | 
			
		||||
    add_compile_definitions(ADD_BREAKPAD_SIGNAL_HANDLER)
 | 
			
		||||
  if(BREAKPAD_LIB AND BREAKPAD_INCLUDE_DIR)
 | 
			
		||||
    message(STATUS "found breakpad library")
 | 
			
		||||
    target_link_libraries(torch_cpu PUBLIC ${BREAKPAD_LIB})
 | 
			
		||||
    target_compile_definitions(torch_cpu PRIVATE ADD_BREAKPAD_SIGNAL_HANDLER)
 | 
			
		||||
    target_include_directories(torch_cpu PUBLIC ${BREAKPAD_INCLUDE_DIR}/breakpad)
 | 
			
		||||
  else()
 | 
			
		||||
    message(STATUS "breakpad library not found")
 | 
			
		||||
@ -1417,8 +1419,9 @@ target_link_libraries(torch PUBLIC torch_cpu_library)
 | 
			
		||||
if(USE_CUDA)
 | 
			
		||||
  target_link_libraries(torch PUBLIC torch_cuda_library)
 | 
			
		||||
  if(BUILD_SPLIT_CUDA)
 | 
			
		||||
    target_link_libraries(torch_cuda PUBLIC torch_cuda_cu_library)
 | 
			
		||||
    # NS: Library order is important here to prevent cudnn double linking
 | 
			
		||||
    target_link_libraries(torch_cuda PUBLIC torch_cuda_cpp_library)
 | 
			
		||||
    target_link_libraries(torch_cuda PUBLIC torch_cuda_cu_library)
 | 
			
		||||
  endif()
 | 
			
		||||
elseif(USE_ROCM)
 | 
			
		||||
  target_link_libraries(torch PUBLIC torch_hip_library)
 | 
			
		||||
@ -1465,6 +1468,10 @@ if(BUILD_SPLIT_CUDA)
 | 
			
		||||
  target_link_libraries(
 | 
			
		||||
      torch_cuda_cpp PRIVATE ${Caffe2_CUDA_DEPENDENCY_LIBS})
 | 
			
		||||
  target_link_libraries(torch_cuda_cu PRIVATE torch_cuda_cpp)
 | 
			
		||||
  if(USE_CUDNN)
 | 
			
		||||
    target_link_libraries(
 | 
			
		||||
        torch_cuda_cpp PRIVATE  caffe2::cudnn-private)
 | 
			
		||||
  endif()
 | 
			
		||||
 | 
			
		||||
  # These public dependencies must go after the previous dependencies, as the
 | 
			
		||||
  # order of the libraries in the linker call matters here when statically
 | 
			
		||||
@ -1481,6 +1488,10 @@ elseif(USE_CUDA)
 | 
			
		||||
      torch_cuda PRIVATE ${Caffe2_GPU_INCLUDE})
 | 
			
		||||
  target_link_libraries(
 | 
			
		||||
      torch_cuda PRIVATE ${Caffe2_CUDA_DEPENDENCY_LIBS})
 | 
			
		||||
  if(USE_CUDNN)
 | 
			
		||||
    target_link_libraries(
 | 
			
		||||
        torch_cuda PRIVATE  caffe2::cudnn-private)
 | 
			
		||||
  endif()
 | 
			
		||||
 | 
			
		||||
  # These public dependencies must go after the previous dependencies, as the
 | 
			
		||||
  # order of the libraries in the linker call matters here when statically
 | 
			
		||||
 | 
			
		||||
@ -65,9 +65,13 @@ constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 | 
			
		||||
//  0x1L: Initial version
 | 
			
		||||
//  0x2L: (Comment missing)
 | 
			
		||||
//  0x3L: (Comment missing)
 | 
			
		||||
//  0x4L: (Comment missing)
 | 
			
		||||
//  0x4L: (update) Added schema to function tuple. Forward-compatible change.
 | 
			
		||||
constexpr uint64_t kProducedBytecodeVersion = 0x4L;
 | 
			
		||||
//  0x5L: (update) Update bytecode is sharing constant tensor files from torchscript, and only serialize
 | 
			
		||||
//  extra tensors that are not in the torchscript constant table. Also update tensor storage schema adapting
 | 
			
		||||
//  to the unify format, the root key of tensor storage is updated from {index} to
 | 
			
		||||
//  {the_pointer_value_the_tensor.storage}, for example: `140245072983168.storage`
 | 
			
		||||
//  Forward-compatibility change.
 | 
			
		||||
constexpr uint64_t kProducedBytecodeVersion = 0x5L;
 | 
			
		||||
 | 
			
		||||
static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
 | 
			
		||||
    "kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion.");
 | 
			
		||||
 | 
			
		||||
@ -87,7 +87,6 @@ PThreadPool* pthreadpool() {
 | 
			
		||||
      auto num_threads = leaked->get_thread_count();
 | 
			
		||||
      // NOLINTNEXTLINE(modernize-make-unique)
 | 
			
		||||
      threadpool.reset(new PThreadPool(num_threads));
 | 
			
		||||
      TORCH_WARN("Leaking Caffe2 thread-pool after fork.");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return threadpool.get();
 | 
			
		||||
 | 
			
		||||
@ -1165,7 +1165,7 @@ if(USE_CUDA)
 | 
			
		||||
      caffe2_update_option(USE_NVRTC OFF)
 | 
			
		||||
    endif()
 | 
			
		||||
    if(CAFFE2_USE_CUDNN)
 | 
			
		||||
      list(APPEND Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS caffe2::cudnn)
 | 
			
		||||
      list(APPEND Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS caffe2::cudnn-public)
 | 
			
		||||
    else()
 | 
			
		||||
      caffe2_update_option(USE_CUDNN OFF)
 | 
			
		||||
    endif()
 | 
			
		||||
@ -1216,8 +1216,8 @@ if(USE_ROCM)
 | 
			
		||||
    list(APPEND HIP_CXX_FLAGS -Wno-implicit-int-float-conversion)
 | 
			
		||||
    list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
 | 
			
		||||
    list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
 | 
			
		||||
    list(APPEND HIP_CXX_FLAGS -DROCM_VERSION=${ROCM_VERSION_DEV_INT})
 | 
			
		||||
    list(APPEND HIP_CXX_FLAGS -std=c++14)
 | 
			
		||||
    add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
 | 
			
		||||
 | 
			
		||||
    if(CMAKE_BUILD_TYPE MATCHES Debug)
 | 
			
		||||
       list(APPEND HIP_CXX_FLAGS -g2)
 | 
			
		||||
 | 
			
		||||
@ -67,7 +67,7 @@ MACRO(Check_Fortran_Libraries LIBRARIES _prefix _name _flags _list)
 | 
			
		||||
      else ( APPLE )
 | 
			
		||||
        find_library(${_prefix}_${_library}_LIBRARY
 | 
			
		||||
          NAMES ${_library}
 | 
			
		||||
          PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 /opt/OpenBLAS/lib /usr/lib/aarch64-linux-gnu
 | 
			
		||||
          PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 /opt/OpenBLAS/lib /usr/lib/aarch64-linux-gnu ${CMAKE_C_IMPLICIT_LINK_DIRECTORIES}
 | 
			
		||||
          ENV LD_LIBRARY_PATH )
 | 
			
		||||
      endif( APPLE )
 | 
			
		||||
      mark_as_advanced(${_prefix}_${_library}_LIBRARY)
 | 
			
		||||
@ -178,6 +178,19 @@ if((NOT BLAS_LIBRARIES)
 | 
			
		||||
  endif(BLAS_LIBRARIES)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if((NOT BLAS_LIBRARIES)
 | 
			
		||||
    AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "open")))
 | 
			
		||||
  check_fortran_libraries(
 | 
			
		||||
  BLAS_LIBRARIES
 | 
			
		||||
  BLAS
 | 
			
		||||
  sgemm
 | 
			
		||||
  ""
 | 
			
		||||
  "openblas;pthread;m;gomp")
 | 
			
		||||
  if(BLAS_LIBRARIES)
 | 
			
		||||
    set(BLAS_INFO "open")
 | 
			
		||||
  endif(BLAS_LIBRARIES)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
if((NOT BLAS_LIBRARIES) AND (WIN32)
 | 
			
		||||
    AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "open")))
 | 
			
		||||
  check_fortran_libraries(
 | 
			
		||||
 | 
			
		||||
@ -128,13 +128,7 @@ if(BLAS_FOUND)
 | 
			
		||||
      if(NOT LAPACK_CGESDD_WORKS)
 | 
			
		||||
        find_library(GFORTRAN_LIBRARY
 | 
			
		||||
          NAMES libgfortran.a gfortran
 | 
			
		||||
          PATHS /usr/lib/gcc/aarch64-linux-gnu/9/
 | 
			
		||||
                /usr/lib/gcc/x86_64-redhat-linux/9/
 | 
			
		||||
                /usr/lib/gcc/aarch64-linux-gnu/8/
 | 
			
		||||
                /usr/lib/gcc/x86_64-redhat-linux/8/
 | 
			
		||||
                /usr/lib/gcc/aarch64-linux-gnu/7/
 | 
			
		||||
                /usr/lib/gcc/x86_64-redhat-linux/7/
 | 
			
		||||
                )
 | 
			
		||||
          PATHS ${CMAKE_C_IMPLICIT_LINK_DIRECTORIES})
 | 
			
		||||
       list(APPEND CMAKE_REQUIRED_LIBRARIES "${GFORTRAN_LIBRARY}")
 | 
			
		||||
       unset(LAPACK_CGESDD_WORKS CACHE)
 | 
			
		||||
       check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS)
 | 
			
		||||
 | 
			
		||||
@ -90,8 +90,12 @@ function(caffe2_print_configuration_summary)
 | 
			
		||||
    get_target_property(__tmp caffe2::curand IMPORTED_LOCATION)
 | 
			
		||||
    message(STATUS "    curand library      : ${__tmp}")
 | 
			
		||||
    if(${USE_CUDNN})
 | 
			
		||||
      get_target_property(__tmp caffe2::cudnn IMPORTED_LOCATION)
 | 
			
		||||
      get_target_property(__tmp caffe2::cudnn-public INTERFACE_LINK_LIBRARIES)
 | 
			
		||||
      message(STATUS "    cuDNN library       : ${__tmp}")
 | 
			
		||||
      if(${CUDNN_STATIC})
 | 
			
		||||
        get_target_property(__tmp caffe2::cudnn-private INTERFACE_LINK_LIBRARIES)
 | 
			
		||||
        message(STATUS "    cuDNN static library: ${__tmp}")
 | 
			
		||||
      endif()
 | 
			
		||||
    endif()
 | 
			
		||||
    get_target_property(__tmp caffe2::nvrtc IMPORTED_LOCATION)
 | 
			
		||||
    message(STATUS "    nvrtc               : ${__tmp}")
 | 
			
		||||
@ -130,6 +134,7 @@ function(caffe2_print_configuration_summary)
 | 
			
		||||
  message(STATUS "  USE_MKL               : ${CAFFE2_USE_MKL}")
 | 
			
		||||
  message(STATUS "  USE_MKLDNN            : ${USE_MKLDNN}")
 | 
			
		||||
  if(${CAFFE2_USE_MKLDNN})
 | 
			
		||||
    message(STATUS "  USE_MKLDNN_ACL        : ${USE_MKLDNN_ACL}")
 | 
			
		||||
    message(STATUS "  USE_MKLDNN_CBLAS      : ${USE_MKLDNN_CBLAS}")
 | 
			
		||||
  endif()
 | 
			
		||||
  message(STATUS "  USE_NCCL              : ${USE_NCCL}")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										34
									
								
								cmake/public/ComputeLibrary.cmake
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								cmake/public/ComputeLibrary.cmake
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,34 @@
 | 
			
		||||
# Build with Compute Library backend for the Arm architecture
 | 
			
		||||
# Note: Compute Library is available from: https://github.com/ARM-software/ComputeLibrary
 | 
			
		||||
#   and must be built separately. The location of the Compute Library build
 | 
			
		||||
#   must be set with the env var ACL_ROOT_DIR. This path will be checked later
 | 
			
		||||
#   as part of FindACL.cmake in oneDNN.
 | 
			
		||||
 | 
			
		||||
if(NOT USE_MKLDNN_ACL)
 | 
			
		||||
    RETURN()
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
set(DNNL_AARCH64_USE_ACL ON CACHE BOOL "" FORCE)
 | 
			
		||||
 | 
			
		||||
# Check the Compute Library version number.
 | 
			
		||||
# Note: oneDNN / MKL-DNN v2.2 onwards will check the Compute Library version
 | 
			
		||||
#   the version check here can be removed once PyTorch transitions to v2.2.
 | 
			
		||||
set(ACL_MINIMUM_VERSION "21.02")
 | 
			
		||||
 | 
			
		||||
file(GLOB_RECURSE ACL_VERSION_FILE $ENV{ACL_ROOT_DIR}/*/arm_compute_version.embed)
 | 
			
		||||
 | 
			
		||||
if("${ACL_VERSION_FILE}" STREQUAL "")
 | 
			
		||||
  message(WARNING "Build may fail: Could not determine ACL version (minimum required is ${ACL_MINIMUM_VERSION})")
 | 
			
		||||
else()
 | 
			
		||||
  file(READ ${ACL_VERSION_FILE} ACL_VERSION_STRING)
 | 
			
		||||
  string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION ${ACL_VERSION_STRING})
 | 
			
		||||
  set(ACL_VERSION "${CMAKE_MATCH_1}")
 | 
			
		||||
 | 
			
		||||
  if(${ACL_VERSION} VERSION_EQUAL "0.0")
 | 
			
		||||
    # Unreleased ACL versions come with version string "v0.0-unreleased", and may not be compatible with oneDNN.
 | 
			
		||||
    # It is recommended to use the latest release of ACL.
 | 
			
		||||
    message(WARNING "Build may fail: Using unreleased ACL version (minimum required is ${ACL_MINIMUM_VERSION})")
 | 
			
		||||
  elseif(${ACL_VERSION} VERSION_LESS ${ACL_MINIMUM_VERSION})
 | 
			
		||||
    message(FATAL_ERROR "Detected ACL version ${ACL_VERSION}, but minimum required is ${ACL_MINIMUM_VERSION}")
 | 
			
		||||
  endif()
 | 
			
		||||
endif()
 | 
			
		||||
@ -272,20 +272,66 @@ else()
 | 
			
		||||
      ${LIBNVTOOLSEXT})
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
# cudnn
 | 
			
		||||
# static linking is handled by USE_STATIC_CUDNN environment variable
 | 
			
		||||
if(CAFFE2_USE_CUDNN)
 | 
			
		||||
  add_library(caffe2::cudnn UNKNOWN IMPORTED)
 | 
			
		||||
  set_property(
 | 
			
		||||
      TARGET caffe2::cudnn PROPERTY IMPORTED_LOCATION
 | 
			
		||||
      ${CUDNN_LIBRARY_PATH})
 | 
			
		||||
  set_property(
 | 
			
		||||
      TARGET caffe2::cudnn PROPERTY INTERFACE_INCLUDE_DIRECTORIES
 | 
			
		||||
      ${CUDNN_INCLUDE_PATH})
 | 
			
		||||
  if(CUDNN_STATIC AND NOT WIN32)
 | 
			
		||||
# cublas. CUDA_CUBLAS_LIBRARIES is actually a list, so we will make an
 | 
			
		||||
# interface library similar to cudart.
 | 
			
		||||
add_library(caffe2::cublas INTERFACE IMPORTED)
 | 
			
		||||
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
 | 
			
		||||
    set_property(
 | 
			
		||||
        TARGET caffe2::cudnn PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libculibos.a" dl)
 | 
			
		||||
        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a")
 | 
			
		||||
    if(CUDA_VERSION VERSION_GREATER_EQUAL 10.1)
 | 
			
		||||
      set_property(
 | 
			
		||||
        TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublasLt_static.a")
 | 
			
		||||
      # Add explicit dependency to cudart_static to fix
 | 
			
		||||
      # libcublasLt_static.a.o): undefined reference to symbol 'cudaStreamWaitEvent'
 | 
			
		||||
      # error adding symbols: DSO missing from command line
 | 
			
		||||
      set_property(
 | 
			
		||||
        TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "${CUDA_cudart_static_LIBRARY}" rt dl)
 | 
			
		||||
    endif()
 | 
			
		||||
else()
 | 
			
		||||
    set_property(
 | 
			
		||||
        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        ${CUDA_CUBLAS_LIBRARIES})
 | 
			
		||||
endif()
 | 
			
		||||
set_property(
 | 
			
		||||
    TARGET caffe2::cublas PROPERTY INTERFACE_INCLUDE_DIRECTORIES
 | 
			
		||||
    ${CUDA_INCLUDE_DIRS})
 | 
			
		||||
 | 
			
		||||
# cudnn public and private interfaces
 | 
			
		||||
# static linking is handled by USE_STATIC_CUDNN environment variable
 | 
			
		||||
# If library is linked dynamically, than private interface is no-op
 | 
			
		||||
# If library is linked statically:
 | 
			
		||||
#  - public interface would only reference headers
 | 
			
		||||
#  - private interface will contain the actual link instructions
 | 
			
		||||
if(CAFFE2_USE_CUDNN)
 | 
			
		||||
  add_library(caffe2::cudnn-public INTERFACE IMPORTED)
 | 
			
		||||
  set_property(
 | 
			
		||||
    TARGET caffe2::cudnn-public PROPERTY INTERFACE_INCLUDE_DIRECTORIES
 | 
			
		||||
    ${CUDNN_INCLUDE_PATH})
 | 
			
		||||
  add_library(caffe2::cudnn-private INTERFACE IMPORTED)
 | 
			
		||||
  set_property(
 | 
			
		||||
    TARGET caffe2::cudnn-private PROPERTY INTERFACE_INCLUDE_DIRECTORIES
 | 
			
		||||
    ${CUDNN_INCLUDE_PATH})
 | 
			
		||||
  if(CUDNN_STATIC AND NOT WIN32)
 | 
			
		||||
    if(USE_WHOLE_CUDNN)
 | 
			
		||||
      set_property(
 | 
			
		||||
        TARGET caffe2::cudnn-private PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "-Wl,--whole-archive,\"${CUDNN_LIBRARY_PATH}\" -Wl,--no-whole-archive")
 | 
			
		||||
    else()
 | 
			
		||||
      set_property(
 | 
			
		||||
        TARGET caffe2::cudnn-private PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        ${CUDNN_LIBRARY_PATH})
 | 
			
		||||
    endif()
 | 
			
		||||
    set_property(
 | 
			
		||||
      TARGET caffe2::cudnn-private APPEND PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
      "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libculibos.a" dl)
 | 
			
		||||
    # Add explicit dependency on cublas to cudnn
 | 
			
		||||
    get_target_property(__tmp caffe2::cublas INTERFACE_LINK_LIBRARIES)
 | 
			
		||||
    set_property(
 | 
			
		||||
      TARGET caffe2::cudnn-private APPEND PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
      "${__tmp}")
 | 
			
		||||
    # Lines below use target_link_libraries because we support cmake 3.5+.
 | 
			
		||||
    # For cmake 3.13+, target_link_options to set INTERFACE_LINK_OPTIONS would be better.
 | 
			
		||||
    # https://cmake.org/cmake/help/v3.5/command/target_link_libraries.html warns
 | 
			
		||||
@ -295,8 +341,12 @@ if(CAFFE2_USE_CUDNN)
 | 
			
		||||
    #  link items that will not propagate to dependents."
 | 
			
		||||
    # Propagating to a dependent (torch_cuda) is exactly what we want here, so we are
 | 
			
		||||
    # flouting the warning, but I can't think of a better (3.5+ compatible) way.
 | 
			
		||||
    target_link_libraries(caffe2::cudnn INTERFACE
 | 
			
		||||
    target_link_libraries(caffe2::cudnn-private INTERFACE
 | 
			
		||||
        "-Wl,--exclude-libs,libcudnn_static.a")
 | 
			
		||||
  else()
 | 
			
		||||
  set_property(
 | 
			
		||||
    TARGET caffe2::cudnn-public PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
    ${CUDNN_LIBRARY_PATH})
 | 
			
		||||
  endif()
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
@ -346,33 +396,6 @@ if(CAFFE2_USE_TENSORRT)
 | 
			
		||||
      ${TENSORRT_INCLUDE_DIR})
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
# cublas. CUDA_CUBLAS_LIBRARIES is actually a list, so we will make an
 | 
			
		||||
# interface library similar to cudart.
 | 
			
		||||
add_library(caffe2::cublas INTERFACE IMPORTED)
 | 
			
		||||
if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32)
 | 
			
		||||
    set_property(
 | 
			
		||||
        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a")
 | 
			
		||||
    if(CUDA_VERSION VERSION_GREATER_EQUAL 10.1)
 | 
			
		||||
      set_property(
 | 
			
		||||
        TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublasLt_static.a")
 | 
			
		||||
      # Add explicit dependency to cudart_static to fix
 | 
			
		||||
      # libcublasLt_static.a.o): undefined reference to symbol 'cudaStreamWaitEvent'
 | 
			
		||||
      # error adding symbols: DSO missing from command line
 | 
			
		||||
      set_property(
 | 
			
		||||
        TARGET caffe2::cublas APPEND PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        "${CUDA_cudart_static_LIBRARY}" rt dl)
 | 
			
		||||
    endif()
 | 
			
		||||
else()
 | 
			
		||||
    set_property(
 | 
			
		||||
        TARGET caffe2::cublas PROPERTY INTERFACE_LINK_LIBRARIES
 | 
			
		||||
        ${CUDA_CUBLAS_LIBRARIES})
 | 
			
		||||
endif()
 | 
			
		||||
set_property(
 | 
			
		||||
    TARGET caffe2::cublas PROPERTY INTERFACE_INCLUDE_DIRECTORIES
 | 
			
		||||
    ${CUDA_INCLUDE_DIRS})
 | 
			
		||||
 | 
			
		||||
# nvrtc
 | 
			
		||||
add_library(caffe2::nvrtc UNKNOWN IMPORTED)
 | 
			
		||||
set_property(
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,9 @@
 | 
			
		||||
set(MKLDNN_USE_NATIVE_ARCH ${USE_NATIVE_ARCH})
 | 
			
		||||
 | 
			
		||||
if(CPU_AARCH64)
 | 
			
		||||
  include(${CMAKE_CURRENT_LIST_DIR}/ComputeLibrary.cmake)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
find_package(MKLDNN QUIET)
 | 
			
		||||
 | 
			
		||||
if(NOT TARGET caffe2::mkldnn)
 | 
			
		||||
 | 
			
		||||
@ -30,8 +30,6 @@ Inside an ``InferenceMode`` block, we make the following performance guarantees:
 | 
			
		||||
- Inplace operations on inference tensors are guaranteed not to do a version bump.
 | 
			
		||||
 | 
			
		||||
For more implementation details of ``InferenceMode`` please see the `RFC-0011-InferenceMode <https://github.com/pytorch/rfcs/pull/17>`_.
 | 
			
		||||
Currently this guard is only available in C++ frontend, adding python frontend support
 | 
			
		||||
is tracked in #56608.
 | 
			
		||||
 | 
			
		||||
Migration guide from ``AutoNonVariableTypeMode``
 | 
			
		||||
------------------------------------------------
 | 
			
		||||
 | 
			
		||||
@ -142,6 +142,7 @@ Ops that can autocast to ``float32``
 | 
			
		||||
``exp``,
 | 
			
		||||
``expm1``,
 | 
			
		||||
``gelu``,
 | 
			
		||||
``grid_sample``,
 | 
			
		||||
``group_norm``,
 | 
			
		||||
``hinge_embedding_loss``,
 | 
			
		||||
``kl_div``,
 | 
			
		||||
 | 
			
		||||
@ -50,6 +50,10 @@ you can use it as ``functional.jacobian(lambda x: f(x, constant, flag=flag), inp
 | 
			
		||||
Locally disabling gradient computation
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
See :ref:`locally-disable-grad-doc` for more information on the differences
 | 
			
		||||
between no-grad and inference mode as well as other related mechanisms that
 | 
			
		||||
may be confused with the two.
 | 
			
		||||
 | 
			
		||||
.. autosummary::
 | 
			
		||||
    :toctree: generated
 | 
			
		||||
    :nosignatures:
 | 
			
		||||
 | 
			
		||||
@ -166,9 +166,10 @@ if RELEASE:
 | 
			
		||||
    version_end = torch.__version__.find('a')
 | 
			
		||||
    if version_end == -1:
 | 
			
		||||
        html_title = " ".join((project, torch.__version__, "documentation"))
 | 
			
		||||
        version = torch.__version__
 | 
			
		||||
    else:
 | 
			
		||||
        html_title = " ".join((project, torch.__version__[:version_end], "documentation"))
 | 
			
		||||
    version = torch.__version__
 | 
			
		||||
        version = torch.__version__[:version_end]
 | 
			
		||||
    release = version
 | 
			
		||||
 | 
			
		||||
# The language for content autogenerated by Sphinx. Refer to documentation
 | 
			
		||||
 | 
			
		||||
@ -180,6 +180,8 @@ joined.
 | 
			
		||||
 | 
			
		||||
.. autofunction:: is_nccl_available
 | 
			
		||||
 | 
			
		||||
.. autofunction:: is_torchelastic_launched
 | 
			
		||||
 | 
			
		||||
--------------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
Currently three initialization methods are supported:
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,5 @@
 | 
			
		||||
.. _elastic_errors-api:
 | 
			
		||||
 | 
			
		||||
Error Propagation
 | 
			
		||||
==================
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,6 @@
 | 
			
		||||
.. _launcher-api:
 | 
			
		||||
 | 
			
		||||
Elastic Launch
 | 
			
		||||
============================
 | 
			
		||||
 | 
			
		||||
torch.distributed.run
 | 
			
		||||
----------------------
 | 
			
		||||
torch.distributed.run (Elastic Launch)
 | 
			
		||||
======================================
 | 
			
		||||
 | 
			
		||||
.. automodule:: torch.distributed.run
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,5 @@
 | 
			
		||||
.. _elastic_train_script:
 | 
			
		||||
 | 
			
		||||
Train script
 | 
			
		||||
-------------
 | 
			
		||||
 | 
			
		||||
@ -7,18 +9,20 @@ working with ``torch.distributed.run`` with these differences:
 | 
			
		||||
1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
 | 
			
		||||
   ``MASTER_ADDR``, and ``MASTER_PORT``.
 | 
			
		||||
 | 
			
		||||
2. ``rdzv_backend`` and ``rdzv_endpoint`` must be provided. For most users
 | 
			
		||||
   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_).
 | 
			
		||||
2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
 | 
			
		||||
   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The default
 | 
			
		||||
   ``rdzv_backend`` creates a non-elastic rendezvous where ``rdzv_endpoint`` holds
 | 
			
		||||
   the master address.
 | 
			
		||||
 | 
			
		||||
3. Make sure you have a ``load_checkpoint(path)`` and
 | 
			
		||||
   ``save_checkpoint(path)`` logic in your script. When workers fail
 | 
			
		||||
   we restart all the workers with the same program arguments so you will
 | 
			
		||||
   lose progress up to the most recent checkpoint
 | 
			
		||||
   (see `elastic launch <distributed.html>`_).
 | 
			
		||||
   ``save_checkpoint(path)`` logic in your script. When any number of
 | 
			
		||||
   workers fail we restart all the workers with the same program
 | 
			
		||||
   arguments so you will lose progress up to the most recent checkpoint
 | 
			
		||||
   (see `elastic launch <run.html>`_).
 | 
			
		||||
 | 
			
		||||
4. ``use_env`` flag has been removed. If you were parsing local rank by parsing
 | 
			
		||||
   the ``--local_rank`` option, you need to get the local rank from the
 | 
			
		||||
   environment variable ``LOCAL_RANK`` (e.g. ``os.environ["LOCAL_RANK"]``).
 | 
			
		||||
   environment variable ``LOCAL_RANK`` (e.g. ``int(os.environ["LOCAL_RANK"])``).
 | 
			
		||||
 | 
			
		||||
Below is an expository example of a training script that checkpoints on each
 | 
			
		||||
epoch, hence the worst-case progress lost on failure is one full epoch worth
 | 
			
		||||
@ -31,7 +35,7 @@ of training.
 | 
			
		||||
       state = load_checkpoint(args.checkpoint_path)
 | 
			
		||||
       initialize(state)
 | 
			
		||||
 | 
			
		||||
       # torch.distributed.run ensure that this will work
 | 
			
		||||
       # torch.distributed.run ensures that this will work
 | 
			
		||||
       # by exporting all the env vars needed to initialize the process group
 | 
			
		||||
       torch.distributed.init_process_group(backend=args.backend)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -89,6 +89,7 @@ Non-linear activation functions
    sigmoid
    hardsigmoid
    silu
    mish
    batch_norm
    group_norm
    instance_norm

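The hunk above adds ``mish`` to the list of functional activations; a minimal usage sketch, assuming the matching ``torch.nn.functional.mish`` and ``torch.nn.Mish`` entry points added by the commits in this comparison:

.. code::

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    y = F.mish(x)                    # mish(x) = x * tanh(softplus(x))
    m = torch.nn.Mish()              # module form of the same activation
    assert torch.allclose(y, m(x))
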
@ -22,6 +22,7 @@ These are the basic building blocks for graphs:
 | 
			
		||||
 | 
			
		||||
    ~parameter.Parameter
 | 
			
		||||
    ~parameter.UninitializedParameter
 | 
			
		||||
    ~parameter.UninitializedBuffer
 | 
			
		||||
 | 
			
		||||
Containers
 | 
			
		||||
----------------------------------
 | 
			
		||||
@ -145,6 +146,7 @@ Non-linear Activations (weighted sum, nonlinearity)
 | 
			
		||||
    nn.GELU
 | 
			
		||||
    nn.Sigmoid
 | 
			
		||||
    nn.SiLU
 | 
			
		||||
    nn.Mish
 | 
			
		||||
    nn.Softplus
 | 
			
		||||
    nn.Softshrink
 | 
			
		||||
    nn.Softsign
 | 
			
		||||
 | 
			
		||||
@ -8,56 +8,6 @@ operations. It's not strictly necessary to understand all this, but we recommend
 | 
			
		||||
getting familiar with it, as it will help you write more efficient, cleaner
 | 
			
		||||
programs, and can aid you in debugging.
 | 
			
		||||
 | 
			
		||||
.. _excluding-subgraphs:
 | 
			
		||||
 | 
			
		||||
Excluding subgraphs from backward
 | 
			
		||||
---------------------------------
 | 
			
		||||
 | 
			
		||||
Every Tensor has a flag: :attr:`requires_grad` that allows for fine grained
 | 
			
		||||
exclusion of subgraphs from gradient computation and can increase efficiency.
 | 
			
		||||
 | 
			
		||||
.. _excluding-requires_grad:
 | 
			
		||||
 | 
			
		||||
``requires_grad``
 | 
			
		||||
^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
If there's a single input to an operation that requires gradient, its output
 | 
			
		||||
will also require gradient. Conversely, only if all inputs don't require
 | 
			
		||||
gradient, the output also won't require it. Backward computation is never
 | 
			
		||||
performed in the subgraphs, where all Tensors didn't require gradients.
 | 
			
		||||
 | 
			
		||||
.. code::
 | 
			
		||||
 | 
			
		||||
    >>> x = torch.randn(5, 5)  # requires_grad=False by default
 | 
			
		||||
    >>> y = torch.randn(5, 5)  # requires_grad=False by default
 | 
			
		||||
    >>> z = torch.randn((5, 5), requires_grad=True)
 | 
			
		||||
    >>> a = x + y
 | 
			
		||||
    >>> a.requires_grad
 | 
			
		||||
    False
 | 
			
		||||
    >>> b = a + z
 | 
			
		||||
    >>> b.requires_grad
 | 
			
		||||
    True
 | 
			
		||||
 | 
			
		||||
This is especially useful when you want to freeze part of your model, or you
 | 
			
		||||
know in advance that you're not going to use gradients w.r.t. some parameters.
 | 
			
		||||
For example if you want to finetune a pretrained CNN, it's enough to switch the
 | 
			
		||||
:attr:`requires_grad` flags in the frozen base, and no intermediate buffers will
 | 
			
		||||
be saved, until the computation gets to the last layer, where the affine
 | 
			
		||||
transform will use weights that require gradient, and the output of the network
 | 
			
		||||
will also require them.
 | 
			
		||||
 | 
			
		||||
.. code::
 | 
			
		||||
 | 
			
		||||
    model = torchvision.models.resnet18(pretrained=True)
 | 
			
		||||
    for param in model.parameters():
 | 
			
		||||
        param.requires_grad = False
 | 
			
		||||
    # Replace the last fully-connected layer
 | 
			
		||||
    # Parameters of newly constructed modules have requires_grad=True by default
 | 
			
		||||
    model.fc = nn.Linear(512, 100)
 | 
			
		||||
 | 
			
		||||
    # Optimize only the classifier
 | 
			
		||||
    optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
 | 
			
		||||
 | 
			
		||||
.. _how-autograd-encodes-history:
 | 
			
		||||
 | 
			
		||||
How autograd encodes the history
 | 
			
		||||
@ -86,6 +36,157 @@ flow statements, that can change the overall shape and size of the graph at
 | 
			
		||||
every iteration. You don't have to encode all possible paths before you
 | 
			
		||||
launch the training - what you run is what you differentiate.
 | 
			
		||||
 | 
			
		||||
.. _locally-disable-grad-doc:
 | 
			
		||||
 | 
			
		||||
Locally disabling gradient computation
 | 
			
		||||
--------------------------------------
 | 
			
		||||
 | 
			
		||||
There are several mechanisms available from Python to locally disable gradient
 | 
			
		||||
computation:
 | 
			
		||||
 | 
			
		||||
To disable gradients across entire blocks of code, there are context managers
 | 
			
		||||
like no-grad mode and inference mode.
 | 
			
		||||
For more fine-grained exclusion of subgraphs from gradient computation,
 | 
			
		||||
there is setting the ``requires_grad`` field of a tensor.
 | 
			
		||||
 | 
			
		||||
Below, in addition to discussing the mechanisms above, we also describe
 | 
			
		||||
evaluation mode (:meth:`nn.Module.eval()`), a method that is not actually used
 | 
			
		||||
to disable gradient computation but, because of its name, is often mixed up with the three.
 | 
			
		||||
 | 
			
		||||
Setting ``requires_grad``
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
:attr:`requires_grad` is a flag that allows for fine-grained exclusion of
 | 
			
		||||
subgraphs from gradient computation. It takes effect in both the forward
 | 
			
		||||
and backward passes:
 | 
			
		||||
 | 
			
		||||
During the forward pass, an operation is only recorded in the backward graph if
 | 
			
		||||
at least one of its input tensors require grad.
 | 
			
		||||
During the backward pass (``.backward()``), only leaf tensors with
 | 
			
		||||
``requires_grad=True`` will have gradients accumulated into their ``.grad``
 | 
			
		||||
fields.
 | 
			
		||||
 | 
			
		||||
It is important to note that even though every tensor has this flag,
 | 
			
		||||
*setting* it only makes sense for leaf tensors (tensors that do not have a
 | 
			
		||||
``grad_fn``, e.g., a ``nn.Module``'s parameters).
 | 
			
		||||
Non-leaf tensors (tensors that do have ``grad_fn``) are tensors that have a
 | 
			
		||||
backward graph associated with them. Thus their gradients will be needed
 | 
			
		||||
as an intermediary result to compute the gradient for a leaf tensor that
 | 
			
		||||
requires grad. From this definition, it is clear that all non-leaf tensors
 | 
			
		||||
will automatically have ``require_grad=True``.
 | 
			
		||||
 | 
			
		||||
Setting ``requires_grad`` should be the main way you control which parts
 | 
			
		||||
of the model are part of the gradient computation, for example, if you need to
 | 
			
		||||
freeze parts of your pretrained model during model fine-tuning.
 | 
			
		||||
 | 
			
		||||
To freeze parts of your model, simply apply ``.requires_grad_(False)`` to
 | 
			
		||||
the parameters that you don't want updated. And as described above,
 | 
			
		||||
since computations that use these parameters as inputs would not be recorded in
 | 
			
		||||
the forward pass, they won't have their ``.grad`` fields updated in the backward
 | 
			
		||||
pass because they won't be part of the backward graph in the first place, as
 | 
			
		||||
desired.
 | 
			
		||||
 | 
			
		||||
Because this is such a common pattern, ``requires_grad`` can also be set at
 | 
			
		||||
the module level with :meth:`nn.Module.requires_grad_()`.
 | 
			
		||||
When applied to a module, ``.requires_grad_()`` takes effect on all
 | 
			
		||||
of the module's parameters (which have ``requires_grad=True`` by default).
 | 
			
		||||
 | 
			
		||||
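For instance, the freezing pattern described above, in a short sketch:

.. code::

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))

    # Freeze the first layer; only the last layer keeps requires_grad=True.
    model[0].requires_grad_(False)

    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    assert model[0].weight.grad is None      # frozen: nothing accumulated
    assert model[2].weight.grad is not None  # still part of the backward graph
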
Grad Modes
 | 
			
		||||
^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
Apart from setting ``requires_grad`` there are also three possible modes
 | 
			
		||||
enableable from Python that can affect how computations in PyTorch are
 | 
			
		||||
processed by autograd internally: default mode (grad mode), no-grad mode,
 | 
			
		||||
and inference mode, all of which can be togglable via context managers and
 | 
			
		||||
decorators.
 | 
			
		||||
 | 
			
		||||
Default Mode (Grad Mode)
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
The "default mode" is actually the mode we are implicitly in when no other modes like
 | 
			
		||||
no-grad and inference mode are enabled. To be contrasted with
 | 
			
		||||
"no-grad mode" the default mode is also sometimes called "grad mode".
 | 
			
		||||
 | 
			
		||||
The most important thing to know about the default mode is that it is the only
 | 
			
		||||
mode in which ``requires_grad`` takes effect. ``requires_grad`` is always overridden
 | 
			
		||||
to be ``False`` in both the two other modes.
 | 
			
		||||
 | 
			
		||||
No-grad Mode
 | 
			
		||||
^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
Computations in no-grad mode behave as if none of the inputs require grad.
 | 
			
		||||
In other words, computations in no-grad mode are never recorded in the backward graph
 | 
			
		||||
even if there are inputs that have ``require_grad=True``.
 | 
			
		||||
 | 
			
		||||
Enable no-grad mode when you need to perform operations that should not be
 | 
			
		||||
recorded by autograd, but you’d still like to use the outputs of these
 | 
			
		||||
computations in grad mode later. This context manager makes it convenient to
 | 
			
		||||
disable gradients for a block of code or function without
 | 
			
		||||
having to temporarily set tensors to have ``requires_grad=False``, and then
 | 
			
		||||
back to ``True``.
 | 
			
		||||
 | 
			
		||||
For example, no-grad mode might be useful when writing an optimizer: when
 | 
			
		||||
performing the training update you’d like to update parameters
 | 
			
		||||
in-place without the update being recorded by autograd.
 | 
			
		||||
You also intend to use the updated parameters for computations in
 | 
			
		||||
grad mode in the next forward pass.
 | 
			
		||||
 | 
			
		||||
The implementations in :ref:`nn-init-doc` also
 | 
			
		||||
rely on no-grad mode when initializing the parameters as to avoid
 | 
			
		||||
autograd tracking when updating the intialized parameters in-place.
 | 
			
		||||
 | 
			
		||||
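For instance, the optimizer-style in-place update mentioned above, as a small sketch:

.. code::

    import torch

    w = torch.randn(3, requires_grad=True)
    loss = (w * 2).sum()
    loss.backward()

    with torch.no_grad():
        w -= 0.1 * w.grad   # in-place update, not recorded by autograd
        w.grad.zero_()

    # w keeps requires_grad=True and can be used in grad mode in the next pass.
    assert w.requires_grad
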
Inference Mode
 | 
			
		||||
^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
Inference mode is the extreme version of no-grad mode. Just like in no-grad
 | 
			
		||||
mode, computations in inference mode are not recorded in the backward graph, but
 | 
			
		||||
enabling inference mode will allow PyTorch to speed up your model even more.
 | 
			
		||||
This better runtime comes with a drawback: tensors created in inference mode
 | 
			
		||||
will not be able to be used in computations to be recorded by autograd after
 | 
			
		||||
exiting inference mode.
 | 
			
		||||
 | 
			
		||||
Enable inference mode when you are performing computations that don’t need
 | 
			
		||||
to be recorded in the backward graph, AND you don’t plan on using the tensors
 | 
			
		||||
created in inference mode in any computation that is to be recorded by autograd later.
 | 
			
		||||
 | 
			
		||||
It is recommended that you try out inference mode in the parts of your code
 | 
			
		||||
that do not require autograd tracking (e.g., data processing and model evaluation).
 | 
			
		||||
If it works out of the box
 | 
			
		||||
for your use case it’s a free performance win. If you run into errors after
 | 
			
		||||
enabling inference mode, check that you are not using tensors created in
 | 
			
		||||
inference mode in computations that are recorded by autograd after exiting inference
 | 
			
		||||
mode. If you cannot avoid such use in your case, you can always switch back
 | 
			
		||||
to no-grad mode.
 | 
			
		||||
 | 
			
		||||
For details on inference mode please see
 | 
			
		||||
`Inference Mode <https://pytorch.org/cppdocs/notes/inference_mode.html>`_.
 | 
			
		||||
 | 
			
		||||
For implementation details of inference mode see
 | 
			
		||||
`RFC-0011-InferenceMode <https://github.com/pytorch/rfcs/pull/17>`_.
 | 
			
		||||
 | 
			
		||||
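A rough usage sketch, assuming the Python-side ``torch.inference_mode`` context manager referenced above:

.. code::

    import torch

    model = torch.nn.Linear(8, 2)
    data = torch.randn(4, 8)

    with torch.inference_mode():
        out = model(data)

    # ``out`` is an inference tensor: it does not require grad and cannot be
    # used later in computations that autograd records.
    assert not out.requires_grad
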
Evaluation Mode (``nn.Module.eval()``)
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
Evaluation mode is not actually a mechanism to locally disable gradient computation.
 | 
			
		||||
It is included here anyway because it is sometimes confused to be such a mechanism.
 | 
			
		||||
 | 
			
		||||
Functionally, ``module.eval()`` (or equivalently ``module.train()``) are completely
 | 
			
		||||
orthogonal to no-grad mode and inference mode. How ``model.eval()`` affects
 | 
			
		||||
your model depends entirely on the specific modules used in your model and
 | 
			
		||||
whether they define any training-mode specific behavior.
 | 
			
		||||
 | 
			
		||||
You are responsible for calling ``model.eval()`` and ``model.train()`` if your
 | 
			
		||||
model relies on modules such as :class:`torch.nn.Dropout` and
 | 
			
		||||
:class:`torch.nn.BatchNorm2d` that may behave
 | 
			
		||||
differently depending on training mode, for example, to avoid updating your
 | 
			
		||||
BatchNorm running statistics on validation data.
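
A minimal sketch of the usual training/validation toggle (the model and data below are
placeholders):

::

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))

    model.train()                 # dropout is active; batch norm (if any) updates running stats
    out = model(torch.randn(8, 4))

    model.eval()                  # dropout becomes a no-op; batch norm uses running stats
    with torch.no_grad():         # eval() does not disable autograd; combine with no-grad mode
        val_out = model(torch.randn(8, 4))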

It is recommended that you always use ``model.train()`` when
training and ``model.eval()`` when evaluating your model (validation/testing) even
if you aren’t sure your model has training-mode specific behavior, because a
module you are using might be updated to behave differently in training and
eval modes.

In-place operations with autograd
---------------------------------

@ -2,10 +2,787 @@

torch.package
=============
``torch.package`` adds support for creating hermetic packages containing arbitrary
PyTorch code. These packages can be saved, shared, used to load and execute models
at a later date or on a different machine, and can even be deployed to production using
``torch::deploy``.

This document contains tutorials, how-to guides, explanations, and an API reference that
will help you learn more about ``torch.package`` and how to use it.


.. warning::

    This module is experimental and has not yet been publicly released.
    This module depends on the ``pickle`` module, which is not secure. Only unpackage data you trust.

    It is possible to construct malicious pickle data which will **execute arbitrary code during unpickling**.
    Never unpackage data that could have come from an untrusted source, or that could have been tampered with.

    For more information, review the `documentation <https://docs.python.org/3/library/pickle.html>`_ for the ``pickle`` module.


.. contents:: :local:
    :depth: 2


Tutorials
---------
Packaging your first model
^^^^^^^^^^^^^^^^^^^^^^^^^^
A tutorial that guides you through packaging and unpackaging a simple model is available
`on Colab <https://colab.research.google.com/drive/1lFZkLyViGfXxB-m3jqlyTQuYToo3XLo->`_.
After completing this exercise, you will be familiar with the basic API for creating and using
Torch packages.

How do I...
-----------
See what is inside a package?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Treat the package like a ZIP archive
""""""""""""""""""""""""""""""""""""
The container format for a ``torch.package`` is ZIP, so any tools that work with standard ZIP files should
work for exploring the contents. Some common ways to interact with ZIP files:

* ``unzip my_package.pt`` will unzip the ``torch.package`` archive to disk, where you can freely inspect its contents.


::

    $ unzip my_package.pt && tree my_package
    my_package
    ├── .data
    │   ├── 94304870911616.storage
    │   ├── 94304900784016.storage
    │   ├── extern_modules
    │   └── version
    ├── models
    │   └── model_1.pkl
    └── torchvision
        └── models
            ├── resnet.py
            └── utils.py
    ~ cd my_package && cat torchvision/models/resnet.py
    ...


* The Python ``zipfile`` module provides a standard way to read and write ZIP archive contents.


::

    from zipfile import ZipFile
    # open in append mode so that the archive can also be written to
    with ZipFile("my_package.pt", "a") as myzip:
        file_bytes = myzip.read("torchvision/models/resnet.py")
        # edit file_bytes in some way, e.g.:
        new_file_bytes = file_bytes.replace(b"old", b"new")
        myzip.writestr("torchvision/models/resnet.py", new_file_bytes)


* vim has the ability to natively read ZIP archives. You can even edit files and ``:write`` them back into the archive!


::

    # add this to your .vimrc to treat `*.pt` files as zip files
    au BufReadCmd *.pt call zip#Browse(expand("<amatch>"))

    ~ vi my_package.pt

Use the ``file_structure()`` API
""""""""""""""""""""""""""""""""
:class:`PackageImporter` and :class:`PackageExporter` provide a ``file_structure()`` method, which will return a printable
and queryable ``Folder`` object. The ``Folder`` object is a simple directory structure that you can use to explore the
current contents of a ``torch.package``.

The ``Folder`` object itself is directly printable and will print out a file tree representation. To filter what is returned,
use the glob-style ``include`` and ``exclude`` filtering arguments.


::

    with PackageExporter('my_package.pt', verbose=False) as pe:
        pe.save_pickle('models', 'model_1.pkl', mod)
        # can limit printed items with include/exclude args
        print(pe.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storages"))

    importer = PackageImporter('my_package.pt')
    print(importer.file_structure()) # will print out all files


Output:


::

    # filtered with glob pattern:
    #    include=["**/utils.py", "**/*.pkl"], exclude="**/*.storages"
    ─── my_package.pt
        ├── models
        │   └── model_1.pkl
        └── torchvision
            └── models
                └── utils.py

    # all files
    ─── my_package.pt
        ├── .data
        │   ├── 94304870911616.storage
        │   ├── 94304900784016.storage
        │   ├── extern_modules
        │   └── version
        ├── models
        │   └── model_1.pkl
        └── torchvision
            └── models
                ├── resnet.py
                └── utils.py


You can also query ``Folder`` objects with the ``has_file()`` method.


::

    exporter_file_structure = exporter.file_structure()
    found: bool = exporter_file_structure.has_file("package_a/subpackage.py")

Include arbitrary resources with my package and access them later?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`PackageExporter` exposes three methods, ``save_pickle``, ``save_text`` and ``save_binary`` that allow you to save
Python objects, text, and binary data to a package.


::

    with torch.PackageExporter("package.pt") as exporter:
        # Pickles the object and saves to `my_resources/tensor.pkl` in the archive.
        exporter.save_pickle("my_resources", "tensor.pkl", torch.randn(4))
        exporter.save_text("config_stuff", "words.txt", "a sample string")
        exporter.save_binary("raw_data", "binary", my_bytes)


:class:`PackageImporter` exposes complementary methods named ``load_pickle``, ``load_text`` and ``load_binary`` that allow you to load
Python objects, text and binary data from a package.


::

    importer = torch.PackageImporter("package.pt")
    my_tensor = importer.load_pickle("my_resources", "tensor.pkl")
    text = importer.load_text("config_stuff", "words.txt")
    binary = importer.load_binary("raw_data", "binary")

Customize how a class is packaged?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``torch.package`` allows for the customization of how classes are packaged. This behavior is accessed through defining the method
``__reduce_package__`` on a class and by defining a corresponding de-packaging function. This is similar to defining ``__reduce__`` for
Python’s normal pickling process.

Steps:

1. Define the method ``__reduce_package__(self, exporter: PackageExporter)`` on the target class. This method should do the work to save the class instance inside of the package, and should return a tuple of the corresponding de-packaging function with the arguments needed to invoke the de-packaging function. This method is called by the ``PackageExporter`` when it encounters an instance of the target class.
2. Define a de-packaging function for the class. This de-packaging function should do the work to reconstruct and return an instance of the class. The function signature’s first parameter should be a ``PackageImporter`` instance, and the rest of the parameters are user defined.


::

    # foo.py [Example of customizing how class Foo is packaged]
    from torch.package import PackageExporter, PackageImporter
    import time


    class Foo:
        def __init__(self, my_string: str):
            super().__init__()
            self.my_string = my_string
            self.time_imported = 0
            self.time_exported = 0

        def __reduce_package__(self, exporter: PackageExporter):
            """
            Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when
            saving an instance of this object. This method should do the work to save this
            object inside of the ``torch.package`` archive.

            Returns function w/ arguments to load the object from a
            ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function.
            """

            # use this pattern to ensure no naming conflicts with normal dependencies,
            # anything saved under this module name shouldn't conflict with other
            # items in the package
            generated_module_name = f"foo-generated._{exporter.get_unique_id()}"
            exporter.save_text(
                generated_module_name,
                "foo.txt",
                self.my_string + ", with exporter modification!",
            )
            time_exported = time.clock_gettime(1)

            # returns de-packaging function w/ arguments to invoke with
            return (unpackage_foo, (generated_module_name, time_exported,))


    def unpackage_foo(
        importer: PackageImporter, generated_module_name: str, time_exported: float
    ) -> Foo:
        """
        Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function
        when depickling a Foo object.
        Performs work of loading and returning a Foo instance from a ``torch.package`` archive.
        """
        time_imported = time.clock_gettime(1)
        foo = Foo(importer.load_text(generated_module_name, "foo.txt"))
        foo.time_imported = time_imported
        foo.time_exported = time_exported
        return foo


::

    # example of saving instances of class Foo

    import torch
    from torch.package import PackageImporter, PackageExporter
    import foo

    foo_1 = foo.Foo("foo_1 initial string")
    foo_2 = foo.Foo("foo_2 initial string")
    with PackageExporter('foo_package.pt', verbose=False) as pe:
        # save as normal, no extra work necessary
        pe.save_pickle('foo_collection', 'foo1.pkl', foo_1)
        pe.save_pickle('foo_collection', 'foo2.pkl', foo_2)
        print(pe.file_structure())

    pi = PackageImporter('foo_package.pt')
    imported_foo = pi.load_pickle('foo_collection', 'foo1.pkl')
    print(f"foo_1 string: '{imported_foo.my_string}'")
    print(f"foo_1 export time: {imported_foo.time_exported}")
    print(f"foo_1 import time: {imported_foo.time_imported}")


::

    # output of running above script
    ─── foo_package
        ├── foo-generated
        │   ├── _0
        │   │   └── foo.txt
        │   └── _1
        │       └── foo.txt
        ├── foo_collection
        │   ├── foo1.pkl
        │   └── foo2.pkl
        └── foo.py

    foo_1 string: 'foo_1 initial string, with exporter modification!'
    foo_1 export time: 9857706.650140837
    foo_1 import time: 9857706.652698385

Test in my source code whether or not it is executing inside a package?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
A :class:`PackageImporter` will add the attribute ``__torch_package__`` to every module that it initializes. Your code can check for the
presence of this attribute to determine whether it is executing in a packaged context or not.


::

    # In foo/bar.py:

    if "__torch_package__" in dir():  # true if the code is being loaded from a package
        def is_in_package():
            return True

        UserException = Exception
    else:
        def is_in_package():
            return False

        UserException = UnpackageableException


Now, the code will behave differently depending on whether it’s imported normally through your Python environment or imported from a
``torch.package``.


::

    from foo.bar import is_in_package

    print(is_in_package())  # False

    loaded_module = PackageImporter(my_package).import_module("foo.bar")
    loaded_module.is_in_package()  # True


**Warning**: in general, it’s bad practice to have code that behaves differently depending on whether it’s packaged or not. This can lead to
hard-to-debug issues that are sensitive to how you imported your code. If your package is intended to be heavily used, consider restructuring
your code so that it behaves the same way no matter how it was loaded.

Patch code into a package?
^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`PackageExporter` offers a ``save_source_string()`` method that allows one to save arbitrary Python source code to a module of your choosing.


::

    with PackageExporter(f) as exporter:
        # Save the my_module.foo available in your current Python environment.
        exporter.save_module("my_module.foo")

        # This saves the provided string to my_module/foo.py in the package archive.
        # It will override the my_module.foo that was previously saved.
        exporter.save_source_string("my_module.foo", textwrap.dedent(
            """\
            def my_function():
                print('hello world')
            """
        ))

        # If you want to treat my_module.bar as a package
        # (e.g. save to `my_module/bar/__init__.py` instead of `my_module/bar.py`)
        # pass is_package=True,
        exporter.save_source_string("my_module.bar",
                                    "def foo(): print('hello')\n",
                                    is_package=True)

    importer = PackageImporter(f)
    importer.import_module("my_module.foo").my_function()  # prints 'hello world'

Access package contents from packaged code?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`PackageImporter` implements the
`importlib.resources <https://docs.python.org/3/library/importlib.html#module-importlib.resources>`_
API for accessing resources from inside a package.


::

    with PackageExporter(f) as exporter:
        # saves text to my_resource/a.txt in the archive
        exporter.save_text("my_resource", "a.txt", "hello world!")
        # saves the tensor to my_pickle/obj.pkl
        exporter.save_pickle("my_pickle", "obj.pkl", torch.ones(2, 2))

        # see below for module contents
        exporter.save_module("foo")
        exporter.save_module("bar")


The ``importlib.resources`` API allows access to resources from within packaged code.


::

    # foo.py:
    import importlib.resources
    import my_resource

    # returns "hello world!"
    def get_my_resource():
        return importlib.resources.read_text(my_resource, "a.txt")


Using ``importlib.resources`` is the recommended way to access package contents from within packaged code, since it complies
with the Python standard. However, it is also possible to access the parent :class:`PackageImporter` instance itself from within
packaged code.


::

    # bar.py:
    import torch_package_importer # this is the PackageImporter that imported this module.

    # Prints "hello world!", equivalent to importlib.resources.read_text
    def get_my_resource():
        return torch_package_importer.load_text("my_resource", "a.txt")

    # You can also do things that the importlib.resources API does not support, like loading
    # a pickled object from the package.
    def get_my_pickle():
        return torch_package_importer.load_pickle("my_pickle", "obj.pkl")

Distinguish between packaged code and non-packaged code?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To tell if an object’s code is from a ``torch.package``, use the ``torch.package.is_from_package()`` function.
Note: if an object is from a package but its definition is from a module marked ``extern`` or from ``stdlib``,
this check will return ``False``.


::

    importer = PackageImporter(f)
    mod = importer.import_module('foo')
    obj = importer.load_pickle('model', 'model.pkl')
    txt = importer.load_text('text', 'my_test.txt')

    assert is_from_package(mod)
    assert is_from_package(obj)
    assert not is_from_package(txt) # str is from stdlib, so this will return False

Re-export an imported object?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To re-export an object that was previously imported by a :class:`PackageImporter`, you must make the new :class:`PackageExporter`
aware of the original :class:`PackageImporter` so that it can find source code for your object’s dependencies.


::

    importer = PackageImporter(f)
    obj = importer.load_pickle("model", "model.pkl")

    # re-export obj in a new package
    with PackageExporter(f2, importer=(importer, sys_importer)) as exporter:
        exporter.save_pickle("model", "model.pkl", obj)

Package a TorchScript module?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To package a TorchScript model, use the same ``save_pickle`` and ``load_pickle`` APIs as you would with any other object.
Saving TorchScript objects that are attributes or submodules is supported as well with no extra work.


::

    # save TorchScript just like any other object
    with PackageExporter(file_name, verbose=True) as e:
        e.save_pickle("res", "script_model.pkl", scripted_model)
        e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule)
    # load as normal
    importer = PackageImporter(file_name)
    loaded_script = importer.load_pickle("res", "script_model.pkl")
    loaded_mixed = importer.load_pickle("res", "mixed_model.pkl")

Explanation
-----------
``torch.package`` Format Overview
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
A ``torch.package`` file is a ZIP archive which conventionally uses the ``.pt`` extension. Inside the ZIP archive, there are two kinds of files:

* Framework files, which are placed in the ``.data/`` directory.
* User files, which are everything else.

As an example, this is what a fully packaged ResNet model from ``torchvision`` looks like:


::

    resnet
    ├── .data  # All framework-specific data is stored here.
    │   │      # It's named to avoid conflicts with user-serialized code.
    │   ├── 94286146172688.storage  # tensor data
    │   ├── 94286146172784.storage
    │   ├── extern_modules  # text file with names of extern modules (e.g. 'torch')
    │   ├── version         # version metadata
    │   ├── ...
    ├── model  # the pickled model
    │   └── model.pkl
    └── torchvision  # all code dependencies are captured as source files
        └── models
            ├── resnet.py
            └── utils.py

Framework files
"""""""""""""""
The ``.data/`` directory is owned by torch.package, and its contents are considered to be a private implementation detail.
The ``torch.package`` format makes no guarantees about the contents of ``.data/``, but any changes made will be backward compatible
(that is, newer versions of PyTorch will always be able to load older ``torch.packages``).

Currently, the ``.data/`` directory contains the following items:

* ``version``: a version number for the serialized format, so that the ``torch.package`` import infrastructure knows how to load this package.
* ``extern_modules``: a list of modules that are considered ``extern`` by :class:`PackageImporter`. ``extern`` modules will be imported using the loading environment’s system importer.
* ``*.storage``: serialized tensor data.


::

    .data
    ├── 94286146172688.storage
    ├── 94286146172784.storage
    ├── extern_modules
    ├── version
    ├── ...

User files
""""""""""
All other files in the archive were put there by a user. The layout is identical to a Python
`regular package <https://docs.python.org/3/reference/import.html#regular-packages>`_. For a deeper dive into how Python packaging works,
please consult `this essay <https://www.python.org/doc/essays/packages/>`_ (it’s slightly out of date, so double-check implementation details
with the `Python reference documentation <https://docs.python.org/3/library/importlib.html>`_).


::

    <package root>
    ├── model  # the pickled model
    │   └── model.pkl
    ├── another_package
    │   ├── __init__.py
    │   ├── foo.txt         # a resource file, see importlib.resources
    │   └── ...
    └── torchvision
        └── models
            ├── resnet.py   # torchvision.models.resnet
            └── utils.py    # torchvision.models.utils

How ``torch.package`` finds your code's dependencies
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Analyzing an object's dependencies
""""""""""""""""""""""""""""""""""
When you issue a ``save_pickle(obj, ...)`` call, :class:`PackageExporter` will pickle the object normally. Then, it uses the
``pickletools`` standard library module to parse the pickle bytecode.

In a pickle, an object is saved along with a ``GLOBAL`` opcode that describes where to find the implementation of the object’s type, like:


::

    GLOBAL 'torchvision.models.resnet Resnet'


The dependency resolver will gather up all ``GLOBAL`` ops and mark them as dependencies of your pickled object.
For more information about pickling and the pickle format, please consult `the Python docs <https://docs.python.org/3/library/pickle.html>`_.
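
To get a feel for this step, here is a small, self-contained sketch that uses ``pickletools``
to list the modules referenced by ``GLOBAL`` opcodes. It is an illustration of the idea, not
``torch.package``'s actual resolver, and the ``Point`` class is a placeholder:

::

    import pickle
    import pickletools

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    # protocol 2 encodes class references with the GLOBAL opcode
    # (protocols >= 4 use STACK_GLOBAL instead)
    data = pickle.dumps(Point(1, 2), protocol=2)

    modules = set()
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name == "GLOBAL":
            module, _qualname = arg.split(" ", 1)
            modules.add(module)

    print(modules)  # {'__main__'} -- each module found this way is treated as a dependency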

Analyzing a module's dependencies
"""""""""""""""""""""""""""""""""
When a Python module is identified as a dependency, ``torch.package`` walks the module’s Python AST representation and looks for import statements with
full support for the standard forms: ``from x import y``, ``import z``, ``from w import v as u``, etc. When one of these import statements is
encountered, ``torch.package`` registers the imported modules as dependencies that are then themselves parsed in the same AST walking way.

**Note**: AST parsing has limited support for the ``__import__(...)`` syntax and does not support ``importlib.import_module`` calls. In general, you should
not expect dynamic imports to be detected by ``torch.package``.
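
The idea behind this AST walk can be illustrated with a short sketch using the standard
``ast`` module (again, an illustration of the approach rather than the actual implementation):

::

    import ast

    source = "import z\nfrom x import y\nfrom w import v as u\n"

    found = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            found.update(alias.name for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module is not None:
            found.add(node.module)

    print(sorted(found))  # ['w', 'x', 'z'] -- these modules would be registered as dependencies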


Dependency Management
^^^^^^^^^^^^^^^^^^^^^
``torch.package`` automatically finds the Python modules that your code and objects depend on. This process is called dependency resolution.
For each module that the dependency resolver finds, you must specify an *action* to take.

The allowed actions are:

* ``intern``: put this module into the package.
* ``extern``: declare this module as an external dependency of the package.
* ``mock``: stub out this module.
* ``deny``: depending on this module will raise an error during package export.

Finally, there is one more important action that is not technically part of ``torch.package``:

* Refactoring: remove or change the dependencies in your code.

Note that actions are only defined on entire Python modules. There is no way to package “just” a function or class from a module and leave the rest out.
This is by design. Python does not offer clean boundaries between objects defined in a module. The only defined unit of dependency organization is a
module, so that’s what ``torch.package`` uses.

Actions are applied to modules using patterns. Patterns can either be module names (``"foo.bar"``) or globs (like ``"foo.**"``). You associate a pattern
with an action using methods on :class:`PackageExporter`, e.g.


::

    my_exporter.intern("torchvision.**")
    my_exporter.extern("numpy")


If a module matches a pattern, the corresponding action is applied to it. For a given module, patterns will be checked in the order that they were defined,
and the first action will be taken.


``intern``
""""""""""
If a module is ``intern``-ed, it will be placed into the package.

This action is for your model code, or any related code you want to package. For example, if you are trying to package a ResNet from ``torchvision``,
you will need to ``intern`` the module torchvision.models.resnet.
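
For instance, a minimal sketch of packaging such a model (the file name and the choice to
``extern`` numpy are illustrative, not required):

::

    from torch.package import PackageExporter
    from torchvision.models import resnet18

    model = resnet18()
    with PackageExporter("resnet.pt") as exporter:
        exporter.intern("torchvision.**")   # package the torchvision source the model needs
        exporter.extern("numpy.**")         # leave numpy as an external dependency
        exporter.save_pickle("model", "model.pkl", model)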

On package import, when your packaged code tries to import an ``intern``-ed module, PackageImporter will look inside your package for that module.
If it can’t find that module, an error will be raised. This ensures that each :class:`PackageImporter` is isolated from the loading environment—even
if you have ``my_interned_module`` available in both your package and the loading environment, :class:`PackageImporter` will only use the version in your
package.

**Note**: Only Python source modules can be ``intern``-ed. Other kinds of modules, like C extension modules and bytecode modules, will raise an error if
you attempt to ``intern`` them. These kinds of modules need to be ``mock``-ed or ``extern``-ed.


``extern``
""""""""""
If a module is ``extern``-ed, it will not be packaged. Instead, it will be added to a list of external dependencies for this package. You can find this
list on ``package_exporter.extern_modules``.

On package import, when packaged code tries to import an ``extern``-ed module, :class:`PackageImporter` will use the default Python importer to find
that module, as if you did ``importlib.import_module("my_externed_module")``. If it can’t find that module, an error will be raised.

In this way, you can depend on third-party libraries like ``numpy`` and ``scipy`` from within your package without having to package them too.
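
A short sketch of this (``my_module`` and ``my_model`` are placeholders):

::

    with PackageExporter("my_package.pt") as exporter:
        exporter.extern(["numpy.**", "scipy.**"])   # recorded as external dependencies
        exporter.intern("my_module.**")             # your own code goes into the package
        exporter.save_pickle("model", "model.pkl", my_model)

    # At load time, `import numpy` inside packaged code resolves through the
    # loading environment's normal Python importer.
    importer = PackageImporter("my_package.pt")
    model = importer.load_pickle("model", "model.pkl")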

**Warning**: If any external library changes in a backwards-incompatible way, your package may fail to load. If you need long-term reproducibility
for your package, try to limit your use of ``extern``.


``mock``
""""""""
If a module is ``mock``-ed, it will not be packaged. Instead a stub module will be packaged in its place. The stub module will allow you to retrieve
objects from it (so that ``from my_mocked_module import foo`` will not error), but any use of that object will raise a ``NotImplementedError``.

``mock`` should be used for code that you “know” will not be needed in the loaded package, but that you still want available for use in non-packaged contexts.
For example, initialization/configuration code, or code only used for debugging/training.
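
For example, a small sketch (the module names are placeholders):

::

    with PackageExporter("my_package.pt") as exporter:
        exporter.mock("my_training_utils.**")   # packaged as a stub
        exporter.intern("my_model.**")
        exporter.save_pickle("model", "model.pkl", my_model)

    importer = PackageImporter("my_package.pt")
    stub = importer.import_module("my_training_utils")
    helper = stub.some_helper    # retrieving an attribute from the stub works
    helper()                     # but any use of it raises NotImplementedError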

**Warning**: In general, ``mock`` should be used as a last resort. It introduces behavioral differences between packaged code and non-packaged code,
which may lead to later confusion. Prefer instead to refactor your code to remove unwanted dependencies.


Refactoring
"""""""""""
The best way to manage dependencies is to not have dependencies at all! Often, code can be refactored to remove unnecessary dependencies. Here are some
guidelines for writing code with clean dependencies (which are also generally good practices!):

**Include only what you use**. Do not leave unused imports in your code. The dependency resolver is not smart enough to tell that they are indeed unused,
and will try to process them.

**Qualify your imports**. For example, instead of writing ``import foo`` and later using ``foo.bar.baz``, prefer to write ``from foo.bar import baz``. This more
precisely specifies your real dependency (``foo.bar``) and lets the dependency resolver know you don’t need all of ``foo``.

**Split up large files with unrelated functionality into smaller ones**. If your ``utils`` module contains a hodge-podge of unrelated functionality, any module
that depends on ``utils`` will need to pull in lots of unrelated dependencies, even if you only needed a small part of it. Prefer instead to define
single-purpose modules that can be packaged independently of one another.


Patterns
""""""""
Patterns allow you to specify groups of modules with a convenient syntax. The syntax and behavior of patterns follows the Bazel/Buck
`glob() <https://docs.bazel.build/versions/master/be/functions.html#glob>`_.

A module that we are trying to match against a pattern is called a candidate. A candidate is composed of a list of segments separated by a
separator string, e.g. ``foo.bar.baz``.

A pattern contains one or more segments. Segments can be:

* A literal string (e.g. ``foo``), which matches exactly.
* A string containing a wildcard (e.g. ``torch*``, or ``foo*baz*``). The wildcard matches any string, including the empty string.
* A double wildcard (``**``). This matches against zero or more complete segments.

Examples:

* ``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``.
* ``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional`` or ``torch``
* ``torch*.**``: matches ``torch``, ``torchvision``, and all of their submodules

When specifying actions, you can pass multiple patterns, e.g.


::

    exporter.intern(["torchvision.models.**", "torchvision.utils.**"])


A module will match against this action if it matches any of the patterns.

You can also specify patterns to exclude, e.g.


::

    exporter.mock("**", exclude=["torchvision.**"])


A module will not match against this action if it matches any of the exclude patterns. In this example, we are mocking all modules except
``torchvision`` and its submodules.

When a module could potentially match against multiple actions, the first action defined will be taken.


``torch.package`` sharp edges
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Avoid global state in your modules
""""""""""""""""""""""""""""""""""
Python makes it really easy to bind objects and run code at module-level scope. This is generally fine—after all, functions and classes are bound to
names this way. However, things become more complicated when you define an object at module scope with the intention of mutating it, introducing mutable
global state.

Mutable global state is quite useful—it can reduce boilerplate, allow for open registration into tables, etc. But unless employed very carefully, it can
cause complications when used with ``torch.package``.

Every :class:`PackageImporter` creates an independent environment for its contents. This is nice because it means we can load multiple packages and ensure
they are isolated from each other, but when modules are written in a way that assumes shared mutable global state, this behavior can create hard-to-debug
errors.
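
As a sketch of the kind of surprise this causes, imagine a hypothetical module ``registry``
that keeps a module-level table and is both installed normally and shipped inside a package:

::

    # registry.py (hypothetical) -- module-level mutable state
    HANDLERS = {}

    def register(name, fn):
        HANDLERS[name] = fn

    # In the loading environment:
    import registry                                        # the "normal" copy of the module
    packaged_registry = PackageImporter("pkg.pt").import_module("registry")

    registry.register("resize", lambda x: x)
    # The packaged copy has its own, separate HANDLERS dict:
    print("resize" in packaged_registry.HANDLERS)          # False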

Types are not shared between packages and the loading environment
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
Any class that you import from a :class:`PackageImporter` will be a version of the class specific to that importer. For example:


::

    from foo import MyClass

    my_class_instance = MyClass()

    with PackageExporter(f) as exporter:
        exporter.save_module("foo")

    importer = PackageImporter(f)
    imported_MyClass = importer.import_module("foo").MyClass

    assert isinstance(my_class_instance, MyClass)  # works
    assert isinstance(my_class_instance, imported_MyClass)  # ERROR!


In this example, ``MyClass`` and ``imported_MyClass`` are *not the same type*. In this specific example, ``MyClass`` and ``imported_MyClass`` have exactly the
same implementation, so you might think it’s okay to consider them the same class. But consider the situation where ``imported_MyClass`` is coming from an
older package with an entirely different implementation of ``MyClass`` — in that case, it’s unsafe to consider them the same class.

Under the hood, each importer has a prefix that allows it to uniquely identify classes:


::

    print(MyClass.__module__)  # prints "foo"
    print(imported_MyClass.__module__)  # prints "<torch_package_0>.foo"


That means you should not expect ``isinstance`` checks to work when one of the arguments is from a package and the other is not. If you need this
functionality, consider the following options:

* Doing duck typing (just using the class instead of explicitly checking that it is of a given type).
* Making the typing relationship an explicit part of the class contract. For example, you can add an attribute tag ``self.handler = "handle_me_this_way"`` and have client code check for the value of ``handler`` instead of checking the type directly.


How ``torch.package`` keeps packages isolated from each other
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each :class:`PackageImporter` instance creates an independent, isolated environment for its modules and objects. Modules in a package can only import
other packaged modules, or modules marked ``extern``. If you use multiple :class:`PackageImporter` instances to load a single package, you will get
multiple independent environments that do not interact.

This is achieved by extending Python’s import infrastructure with a custom importer. :class:`PackageImporter` provides the same core API as the
``importlib`` importer; namely, it implements the ``import_module`` and ``__import__`` methods.

When you invoke :meth:`PackageImporter.import_module`, :class:`PackageImporter` will construct and return a new module, much as the system importer does.
However, :class:`PackageImporter` patches the returned module to use ``self`` (i.e. that :class:`PackageImporter` instance) to fulfill future import
requests by looking in the package rather than searching the user’s Python environment.
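
A short sketch of the observable consequence (the file and module names are placeholders):

::

    importer_a = PackageImporter("my_package.pt")
    importer_b = PackageImporter("my_package.pt")

    mod_a = importer_a.import_module("foo")
    mod_b = importer_b.import_module("foo")

    # Two loads of the same package give two independent module objects;
    # mutating state in one has no effect on the other.
    assert mod_a is not mod_b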

Mangling
""""""""
To avoid confusion (“is this ``foo.bar`` object the one from my package, or the one from my Python environment?”), :class:`PackageImporter` mangles the
``__name__`` and ``__file__`` of all imported modules, by adding a *mangle prefix* to them.

For ``__name__``, a name like ``torchvision.models.resnet18`` becomes ``<torch_package_0>.torchvision.models.resnet18``.

For ``__file__``, a name like ``torchvision/models/resnet18.py`` becomes ``<torch_package_0>.torchvision/models/resnet18.py``.

Name mangling helps avoid inadvertent punning of module names between different packages, and helps you debug by making stack traces and print
statements more clearly show whether they are referring to packaged code or not. For developer-facing details about mangling, consult
``mangling.md`` in ``torch/package/``.


API Reference
-------------

@ -37,6 +37,7 @@ functions = [
    'RReLU',
    'SELU',
    'SiLU',
    'Mish',
    'CELU',
    'GELU',
    'Sigmoid',

@ -397,7 +397,7 @@ and ``values``:
Construction of CSR tensors
---------------------------

Sparse CSR matrices can be directly constructed by using the :func:`torch.sparse_csr_tensor`
Sparse CSR matrices can be directly constructed by using the :func:`torch._sparse_csr_tensor`
method. The user must supply the row and column indices and values tensors separately.
The ``size`` argument is optional and will be deduced from the ``crow_indices``
and ``col_indices`` if it is not present.
@ -405,7 +405,7 @@ and ``col_indices`` if it is not present.
    >>> crow_indices = torch.tensor([0, 2, 4])
    >>> col_indices = torch.tensor([0, 1, 0, 1])
    >>> values = torch.tensor([1, 2, 3, 4])
    >>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double)
    >>> csr = torch._sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double)
    >>> csr
    tensor(crow_indices=tensor([0, 2, 4]),
          col_indices=tensor([0, 1, 0, 1]),
@ -419,11 +419,11 @@ CSR Tensor Operations
---------------------

The simplest way of constructing a sparse CSR tensor from a strided or sparse COO
tensor is to use :meth:`tensor.to_sparse_csr`. Any zeros in the (strided) tensor will
tensor is to use :meth:`tensor._to_sparse_csr`. Any zeros in the (strided) tensor will
be interpreted as missing values in the sparse tensor:

    >>> a = torch.tensor([[0, 0, 1, 0], [1, 2, 0, 0], [0, 0, 0, 0]], dtype = torch.float64)
    >>> sp = a.to_sparse_csr()
    >>> sp = a._to_sparse_csr()
    >>> sp
    tensor(crow_indices=tensor([0, 1, 3, 3]),
          col_indices=tensor([2, 0, 1]),
@ -496,7 +496,7 @@ The following Tensor methods are related to sparse tensors:
    Tensor.sparse_dim
    Tensor.sparse_mask
    Tensor.to_sparse
    Tensor.to_sparse_csr
    Tensor._to_sparse_csr
    Tensor.indices
    Tensor.values

@ -581,7 +581,7 @@ Torch functions specific to sparse Tensors
    :nosignatures:

    sparse_coo_tensor
    sparse_csr_tensor
    _sparse_csr_tensor
    sparse.sum
    sparse.addmm
    sparse.mm

@ -253,7 +253,9 @@ Examples::
    no_grad
    enable_grad
    set_grad_enabled
    is_grad_enabled
    inference_mode
    is_inference_mode_enabled

Math operations
---------------
@ -575,5 +577,4 @@ Utilities
    are_deterministic_algorithms_enabled
    set_warn_always
    is_warn_always_enabled
    vmap
    _assert

@ -1,14 +1,19 @@
#import "Benchmark.h"
#include <string>
#include <vector>
#include "torch/script.h"

#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include "ATen/ATen.h"
#include "caffe2/core/timer.h"
#include "caffe2/utils/string_utils.h"
#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/jit/serialization/import.h"
#include "torch/script.h"

static std::string model = "model.pt";
static std::string model = "model.ptl";
static std::string input_dims = "1,3,224,224";
static std::string input_type = "float";
static BOOL print_output = false;
@ -18,9 +23,9 @@ static int iter = 10;
@implementation Benchmark

+ (BOOL)setup:(NSDictionary*)config {
  NSString* modelPath = [[NSBundle mainBundle] pathForResource:@"model" ofType:@"pt"];
  NSString* modelPath = [[NSBundle mainBundle] pathForResource:@"model" ofType:@"ptl"];
  if (![[NSFileManager defaultManager] fileExistsAtPath:modelPath]) {
    NSLog(@"model.pt doesn't exist!");
    NSLog(@"model.ptl doesn't exist!");
    return NO;
  }
  model = std::string(modelPath.UTF8String);
@ -66,10 +71,9 @@ static int iter = 10;
  }

  c10::InferenceMode mode;
  torch::jit::GraphOptimizerEnabledGuard opguard(false);
  auto module = torch::jit::load(model);
  auto module = torch::jit::_load_for_mobile(model);

  module.eval();
//  module.eval();
  if (print_output) {
    std::cout << module.forward(inputs) << std::endl;
  }

@ -1,13 +1,22 @@
#import <XCTest/XCTest.h>

#include <torch/script.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include "ATen/ATen.h"
#include "caffe2/core/timer.h"
#include "caffe2/utils/string_utils.h"
#include "torch/csrc/autograd/grad_mode.h"

@interface TestAppTests : XCTestCase

@end

@implementation TestAppTests {
  torch::jit::Module _module;
  torch::jit::mobile::Module _module;
}

+ (void)setUp {
@ -17,14 +26,14 @@
- (void)setUp {
  [super setUp];
  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model"
                                                                         ofType:@"pt"];
                                                                         ofType:@"ptl"];
  XCTAssertTrue([NSFileManager.defaultManager fileExistsAtPath:modelPath],
                @"model.pt doesn't exist!");
  _module = torch::jit::load(modelPath.UTF8String);
                @"model.ptl doesn't exist!");
  _module = torch::jit::_load_for_mobile(modelPath.UTF8String);
}

- (void)testForward {
  _module.eval();
//  _module.eval();
  c10::InferenceMode mode;
  std::vector<c10::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}, at::ScalarType::Float));

@ -38,7 +38,7 @@ targets.each do |target|
    end
end
puts "Installing the testing model..."
model_path = File.expand_path("./model.pt")
model_path = File.expand_path("./model.ptl")
if not File.exist?(model_path)
   raise "model.pt can't be found!"
end

@ -1,8 +1,10 @@
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
traced_script_module = torch.jit.script(model, example)
optimized_scripted_module = optimize_for_mobile(traced_script_module)
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter("model.ptl")

@ -1,7 +1,7 @@
# Python dependencies required for development
astunparse
future
numpy
numpy<1.21
psutil
pyyaml
requests

@ -104,12 +104,13 @@ fi
CMAKE_ARGS+=("-DBUILD_TEST=OFF")
CMAKE_ARGS+=("-DBUILD_BINARY=OFF")

# If there exists env variable and it equals to 1, build lite interpreter.
# cmd:  BUILD_LITE_INTERPRETER=1 ./scripts/build_android.sh
if [ "${BUILD_LITE_INTERPRETER}" == 1 ]; then
  CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
else
# If there exists env variable and it equals to 0, build full jit interpreter.
# Default behavior is to build lite interpreter
# cmd:  BUILD_LITE_INTERPRETER=0 ./scripts/build_android.sh
if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then
  CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF")
else
  CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
fi
CMAKE_ARGS+=("-DBUILD_MOBILE_BENCHMARK=$BUILD_MOBILE_BENCHMARK")
CMAKE_ARGS+=("-DBUILD_MOBILE_TEST=$BUILD_MOBILE_TEST")

@ -78,10 +78,10 @@ if [ -n "${IOS_ARCH:-}" ]; then
  CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}")
fi

if [ "${BUILD_LITE_INTERPRETER}" == 1 ]; then
  CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
else
if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then
  CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF")
else
  CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
fi

# Don't build binaries or tests (only the library)
11 setup.py
@@ -43,6 +43,10 @@
#   USE_MKLDNN=0
#     disables use of MKLDNN
#
#   USE_MKLDNN_ACL
#     enables use of Compute Library backend for MKLDNN on Arm;
#     USE_MKLDNN must be explicitly enabled.
#
#   MKLDNN_CPU_RUNTIME
#     MKL-DNN threading mode: TBB or OMP (default)
#
@@ -156,6 +160,9 @@
#   NVTOOLSEXT_PATH (Windows only)
#     specify where nvtoolsext is installed
#
#   ACL_ROOT_DIR
#     specify where Compute Library is installed
#
#   LIBRARY_PATH
#   LD_LIBRARY_PATH
#     we will search for libraries in these paths
@@ -460,6 +467,10 @@ class build_ext(setuptools.command.build_ext.build_ext):
            report('-- Not using CUDA')
        if cmake_cache_vars['USE_MKLDNN']:
            report('-- Using MKLDNN')
            if cmake_cache_vars['USE_MKLDNN_ACL']:
                report('-- Using Compute Library for the Arm architecture with MKLDNN')
            else:
                report('-- Not using Compute Library for the Arm architecture with MKLDNN')
            if cmake_cache_vars['USE_MKLDNN_CBLAS']:
                report('-- Using CBLAS in MKLDNN')
            else:

@@ -89,6 +89,7 @@ allow_list = [
    ("aten::_amp_update_scale", datetime.date(2021, 6, 1)),
    ("aten::randperm", datetime.date(9999, 1, 1)),
    ("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
    ("aten::sparse_csr_tensor", datetime.date(9999, 1, 1)),
]

def allow_listed(schema, allow_list):

@@ -1749,6 +1749,15 @@ TEST_F(FunctionalTest, Softsign) {
  ASSERT_TRUE(torch::allclose(y, y_exp));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(FunctionalTest, Mish) {
  auto x = torch::randn(100) * 10;
  auto y_exp = x * x.exp().log1p().tanh();
  auto y = F::mish(x);

  ASSERT_TRUE(torch::allclose(y, y_exp));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(FunctionalTest, Tanhshrink) {
  auto x = torch::randn(100) * 10;

@@ -2958,6 +2958,16 @@ TEST_F(ModulesTest, GELU) {
  ASSERT_TRUE(torch::allclose(y, y_exp));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModulesTest, Mish) {
  Mish model;
  auto x = torch::randn(100) * 10;
  auto y_exp = x * x.exp().log1p().tanh();
  auto y = model(x);

  ASSERT_TRUE(torch::allclose(y, y_exp));
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModulesTest, Sigmoid) {
  Sigmoid model;

@@ -49,6 +49,7 @@ torch::nn::Hardshrink|Yes|No
torch::nn::Hardtanh|Yes|No
torch::nn::LeakyReLU|Yes|No
torch::nn::LogSigmoid|Yes|No
torch::nn::Mish|Yes|No
torch::nn::MultiheadAttention|No|No
torch::nn::PReLU|Yes|No
torch::nn::ReLU|Yes|No
@@ -187,6 +188,7 @@ F::rrelu|Yes|No
F::glu|Yes|No
F::gelu|Yes|No
F::silu|Yes|No
F::mish|Yes|No
F::logsigmoid|Yes|No
F::hardshrink|Yes|No
F::tanhshrink|Yes|No

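The new C++ tests above exercise F::mish and torch::nn::Mish against the identity mish(x) = x * tanh(softplus(x)) = x * tanh(log(1 + exp(x))). A small Python sketch of the same numerical check, using the Python-side APIs added alongside them:

import torch
import torch.nn.functional as F

x = torch.randn(100) * 10
reference = x * torch.tanh(torch.log1p(torch.exp(x)))  # mish(x) = x * tanh(softplus(x))
assert torch.allclose(F.mish(x), reference, atol=1e-6)
assert torch.allclose(torch.nn.Mish()(x), reference, atol=1e-6)
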
							
								
								
									
76 test/distributed/launcher/bin/test_script_init_method.py (Executable file)
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F


def parse_args():
    parser = argparse.ArgumentParser(description="test script")

    parser.add_argument(
        "--init_method",
        type=str,
        required=True,
        help="init_method to pass to `dist.init_process_group()` (e.g. env://)",
    )
    parser.add_argument(
        "--world_size",
        type=int,
        default=os.getenv("WORLD_SIZE", -1),
        help="world_size to pass to `dist.init_process_group()`",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=os.getenv("RANK", -1),
        help="rank to pass to `dist.init_process_group()`",
    )

    return parser.parse_args()


def main():
    args = parse_args()

    dist.init_process_group(
        backend="gloo",
        init_method=args.init_method,
        world_size=args.world_size,
        rank=args.rank,
    )

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # one hot (by rank) tensor of size world_size
    # example:
    # rank 0, world_size 4 => [1, 0, 0, 0]
    # rank 1, world_size 4 => [0, 1, 0, 0]
    # ...
    t = F.one_hot(torch.tensor(rank), num_classes=world_size)

    # after all_reduce t = tensor.ones(size=world_size)
    dist.all_reduce(t)

    # adding all elements in t should equal world_size
    derived_world_size = torch.sum(t).item()
    if derived_world_size != world_size:
        raise RuntimeError(
            f"Wrong world size derived. Expected: {world_size}, Got: {derived_world_size}"
        )

    print("Done")


if __name__ == "__main__":
    main()
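The new test_script_init_method.py validates the rendezvous by all-reducing one-hot tensors: each rank contributes a one-hot vector, so the reduced tensor is all ones and its element sum equals the world size. A local, single-process sketch of that arithmetic, simulating the all_reduce by summing every rank's contribution:

import torch
import torch.nn.functional as F

world_size = 4
contributions = [F.one_hot(torch.tensor(rank), num_classes=world_size) for rank in range(world_size)]
reduced = torch.stack(contributions).sum(dim=0)  # what dist.all_reduce would leave on every rank
assert torch.equal(reduced, torch.ones(world_size, dtype=torch.long))
assert reduced.sum().item() == world_size
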
							
								
								
									
42 test/distributed/launcher/bin/test_script_is_torchelastic_launched.py (Executable file)
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This is a test script that launches as part of the test cases in
run_test.py, to validate the correctness of
the method ``torch.distributed.is_torchelastic_launched()``. To do so,
we run this script with and without torchelastic and validate that the
boolean value written to the out_file is indeed what we expect (e.g.
should be False when not launched with torchelastic, True when launched with)
The script itself is not a test case hence no assertions are made in this script.

see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched()
     - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched()
"""
import argparse

import torch.distributed as dist


def parse_args():
    parser = argparse.ArgumentParser(description="test script")
    parser.add_argument(
        "--out_file",
        help="file to write indicating whether this script was launched with torchelastic",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    with open(args.out_file, "w") as out:
        out.write(f"{dist.is_torchelastic_launched()}")


if __name__ == "__main__":
    main()
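The script above only records the return value of torch.distributed.is_torchelastic_launched(). A hedged sketch of how a training script might branch on it; it assumes the elastic launcher has populated RANK and WORLD_SIZE in the environment, which are absent in a plain python invocation:

import os
import torch.distributed as dist

if dist.is_torchelastic_launched():
    # started via the elastic launcher; rank and world size come from the environment
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
else:
    # plain `python script.py` invocation; fall back to a single-process default
    rank, world_size = 0, 1
print(f"rank={rank} world_size={world_size}")
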
@@ -7,8 +7,10 @@
# LICENSE file in the root directory of this source tree.
import multiprocessing as mp
import os
import runpy
import shutil
import subprocess
import sys
import tempfile
import unittest
import uuid
@@ -21,6 +23,7 @@ from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
    TEST_WITH_ASAN,
    TEST_WITH_TSAN,
@@ -169,6 +172,35 @@ class ElasticLaunchTest(unittest.TestCase):
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan")
    def test_launch_user_script_default_nproc(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        world_size = 1
        args = [
            f"--nnodes={nnodes}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan")
    def test_launch_with_env_vars(self):
        run_id = str(uuid.uuid4().int)
@@ -447,3 +479,117 @@ class ElasticLaunchTest(unittest.TestCase):
            param_mock.return_value = rdzv_handler_mock
            launch.main(args)
            rdzv_handler_mock.shutdown.assert_called_once()

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_is_torchelastic_launched(self):
        # launch test script with torchelastic and validate that
        # torch.distributed.is_torchelastic_launched() returns True

        out_file = f"{os.path.join(self.test_dir, 'out')}"

        launch.main(
            [
                "--run_path",
                "--nnodes=1",
                "--nproc_per_node=1",
                "--monitor_interval=1",
                path("bin/test_script_is_torchelastic_launched.py"),
                f"--out_file={out_file}",
            ]
        )

        with open(out_file, "r") as fp:
            is_torchelastic_launched = fp.readline()
            self.assertEqual("True", is_torchelastic_launched)

    def test_is_not_torchelastic_launched(self):
        # launch test script without torchelastic and validate that
        # torch.distributed.is_torchelastic_launched() returns False

        out_file = f"{os.path.join(self.test_dir, 'out')}"

        # need to run the script with runpy in the same interpreter
        # as the test because otherwise (depending on the environment)
        # it will not find torch as a dependency
        with patch.object(
            sys,
            "argv",
            [
                path("bin/test_script_is_torchelastic_launched.py"),
                f"--out_file={out_file}",
            ],
        ):
            runpy.run_path(sys.argv[0], run_name="__main__")
            with open(out_file, "r") as fp:
                is_torchelastic_launched = fp.readline()
                self.assertEqual("False", is_torchelastic_launched)

    def test_init_method_tcp(self):
        port = get_free_port()
        with patch.object(
            sys,
            "argv",
            [
                path("bin/test_script_init_method.py"),
                f"--init_method=tcp://localhost:{port}",
                "--rank=0",
                "--world_size=1",
            ],
        ):
            runpy.run_path(sys.argv[0], run_name="__main__")
            # nothing to validate, just make sure it runs

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_init_method_tcp_with_torchelastic(self):
        port = get_free_port()
        launch.main(
            [
                "--run_path",
                "--nnodes=1",
                "--nproc_per_node=4",
                "--master_addr=localhost",
                f"--master_port={port}",
                "--monitor_interval=1",
                path("bin/test_script_init_method.py"),
                f"--init_method=tcp://localhost:{port}",
            ]
        )
        # nothing to validate, just make sure it runs

    def test_init_method_env(self):
        port = get_free_port()
        with patch.dict(
            os.environ,
            {
                "RANK": "0",
                "WORLD_SIZE": "1",
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": str(port),
            },
        ), patch.object(
            sys,
            "argv",
            [
                path("bin/test_script_init_method.py"),
                "--init_method=env://",
            ],
        ):
            runpy.run_path(sys.argv[0], run_name="__main__")
            # nothing to validate, just make sure it runs

    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
    def test_init_method_env_with_torchelastic(self):
        port = get_free_port()
        launch.main(
            [
                "--run_path",
                "--nnodes=1",
                "--nproc_per_node=4",
                "--master_addr=localhost",
                f"--master_port={port}",
                "--monitor_interval=1",
                path("bin/test_script_init_method.py"),
                "--init_method=env://",
            ]
        )
        # nothing to validate, just make sure it runs

@@ -5373,6 +5373,18 @@ class TestONNXRuntime(unittest.TestCase):
        x = torch.randn(2, 3, 4)
        self.run_test(SiLUModel(), (x))

    def test_mish(self):
        class MishModel(torch.nn.Module):
            def __init__(self):
                super(MishModel, self).__init__()
                self.mish = torch.nn.Mish()

            def forward(self, x):
                return self.mish(x)

        x = torch.randn(2, 3, 4)
        self.run_test(MishModel(), (x))

    def test_remainder(self):
        class RemainderModel(torch.nn.Module):
            def forward(self, input, other):

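The ONNX test above routes MishModel through run_test, which exports the module and checks it against ONNX Runtime over several opsets. A hedged sketch of the plain export call it builds on; the opset value here is an assumption, not the exact set exercised by run_test:

import torch

model = torch.nn.Mish()
x = torch.randn(2, 3, 4)
# relies on the aten::mish ONNX symbolic added alongside this test
torch.onnx.export(model, x, "mish.onnx", opset_version=11)
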
@@ -3,7 +3,6 @@ import sys
from tempfile import NamedTemporaryFile

from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
import torch.package.package_exporter


class PackageTestCase(TestCase):
@@ -29,8 +28,6 @@ class PackageTestCase(TestCase):
        self.package_test_dir = os.path.dirname(os.path.realpath(__file__))
        self.orig_sys_path = sys.path.copy()
        sys.path.append(self.package_test_dir)
        torch.package.package_exporter._gate_torchscript_serialization = False


    def tearDown(self):
        super().tearDown()

@@ -6,7 +6,7 @@ from unittest import skipIf

from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

try:
    from .common import PackageTestCase
@@ -224,19 +224,27 @@ class TestDependencyAPI(PackageTestCase):

        buffer = BytesIO()

        try:
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, verbose=False) as he:
                he.save_pickle("obj", "obj.pkl", obj2)
        except PackagingError as e:
            self.assertEqual(e.unhandled, set(["package_a", "package_a.subpackage"]))
        else:
            self.fail("PackagingError should have been raised")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage
                """
            ),
        )

        # Interning all dependencies should work
        with PackageExporter(buffer, verbose=False) as he:
            he.intern(["package_a", "package_a.subpackage"])
            he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """A unpackageable dependency should raise a PackagingError."""

@@ -262,16 +270,42 @@ class TestDependencyAPI(PackageTestCase):

        buffer = BytesIO()

        try:
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(
                buffer, verbose=False, importer=BrokenImporter()
            ) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")
        except PackagingError as e:
            self.assertEqual(set(e.broken.keys()), set(["foo", "bar"]))
        else:
            self.fail("PackagingError should have been raised")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, verbose=False) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package
                """
            ),
        )


if __name__ == "__main__":

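The reworked tests above assert that PackageExporter now raises a single PackagingError, with a structured message, whenever a pickled object or source string depends on modules that no intern/extern/mock rule matches. A minimal self-contained sketch of that rule-based workflow (extern-ing the one dependency so the export succeeds); the verbose flag mirrors the call sites in the diff:

from io import BytesIO
from torch.package import PackageExporter

buffer = BytesIO()
with PackageExporter(buffer, verbose=False) as exporter:
    # every module the saved source depends on must match an intern/extern/mock
    # rule, otherwise PackagingError is raised when the exporter block exits
    exporter.extern(["sys"])
    exporter.save_source_string("my_module", "import sys")
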
@@ -3272,6 +3272,24 @@ class TestQuantizeFxOps(QuantizationTestCase):
        self._test_default_node_quant_handler_ops(
            module, functional, qconfig, is_reference, node_list)

    def test_mish_reference(self):
        module = torch.nn.Mish
        functional = torch.nn.functional.mish
        qconfig = float16_static_qconfig
        is_reference = True
        node_list = [
            ns.call_method("to"),
            ns.call_method("dequantize"),
            ns.call_module(module),
            ns.call_method("to"),
            ns.call_method('dequantize'),
            ns.call_function(functional),
            ns.call_method("to"),
            ns.call_method('dequantize')
        ]
        self._test_default_node_quant_handler_ops(
            module, functional, qconfig, is_reference, node_list)

    def test_bmm_int_reference(self):
        class M(torch.nn.Module):
            def __init__(self):

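The new test_mish_reference case expects a reference-mode graph in which activations are cast to float16 ("to") and back ("dequantize") around both the Mish module and the functional call. A hedged sketch of producing such a graph with FX graph-mode quantization; the prepare_fx/convert_fx names, the qconfig dict format, and the is_reference flag follow the 1.9-era torch.quantization.quantize_fx API and should be treated as assumptions:

import torch
from torch.quantization import float16_static_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class MishOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mish = torch.nn.Mish()

    def forward(self, x):
        return self.mish(x)

model = MishOnly().eval()
prepared = prepare_fx(model, {"": float16_static_qconfig})
prepared(torch.randn(2, 3))                       # one calibration pass
reference = convert_fx(prepared, is_reference=True)
print(reference.graph)                            # shows the to/dequantize pattern around Mish
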
@@ -197,7 +197,7 @@ class TestStaticQuantizedModule(QuantizationTestCase):
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Make sure `from_float` works for all linear variants
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias]
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]

        for mut in modules_under_test:
            # Test from_float.
@@ -978,7 +978,7 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
        # Test JIT
        self.checkScriptable(qlinear, [[X]], check_save_load=True)

        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear._LinearWithBias]
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        for mut in modules_under_test:
            # Test from_float
            float_linear = mut(in_features, out_features).float()

@@ -886,7 +886,7 @@ class TestCppExtensionJIT(common.TestCase):
            #include <torch/torch.h>

            int fail() {{
                torch::crash_handler::_enable_minidump_collection("{destination}");
                torch::crash_handler::enable_minidumps("{destination}");

                volatile int* bad = nullptr;
                return *bad;

@@ -14,13 +14,15 @@ from unittest import skipIf
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import \
    (IterDataPipe, RandomSampler, DataLoader,
     argument_validation, runtime_validation_disabled, runtime_validation)
from torch.utils.data import (
    IterDataPipe, RandomSampler, DataLoader,
    argument_validation, runtime_validation_disabled, runtime_validation
)

from typing import \
    (Any, Dict, Generic, Iterator, List, NamedTuple, Optional, Tuple, Type,
     TypeVar, Set, Union)
from typing import (
    Any, Awaitable, Dict, Generic, Iterator, List, NamedTuple, Optional, Tuple,
    Type, TypeVar, Set, Union
)

import torch.utils.data.datapipes as dp
from torch.utils.data.datapipes.utils.decoder import (
@@ -707,7 +709,7 @@ class TestTyping(TestCase):
        dp2 = DP1(5)
        self.assertEqual(dp1.type, dp2.type)

        with self.assertRaisesRegex(TypeError, r"Can not subclass a DataPipe"):
        with self.assertRaisesRegex(TypeError, r"is not a generic class"):
            class InvalidDP5(DP1[tuple]):  # type: ignore[type-arg]
                def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
                    yield (0, )
@@ -754,7 +756,8 @@ class TestTyping(TestCase):

        self.assertTrue(issubclass(DP5, IterDataPipe))
        dp = DP5()  # type: ignore[assignment]
        self.assertTrue(dp.type.param == Any)
        from torch.utils.data._typing import issubtype
        self.assertTrue(issubtype(dp.type.param, Any) and issubtype(Any, dp.type.param))

        class DP6(IterDataPipe[int]):
            r""" DataPipe with plain Iterator"""
@@ -765,6 +768,19 @@ class TestTyping(TestCase):
        dp = DP6()  # type: ignore[assignment]
        self.assertTrue(dp.type.param == int)

        class DP7(IterDataPipe[Awaitable[T_co]]):
            r""" DataPipe with abstract base class"""

        self.assertTrue(issubclass(DP6, IterDataPipe))
        self.assertTrue(DP7.type.param == Awaitable[T_co])

        class DP8(DP7[str]):
            r""" DataPipe subclass from a DataPipe with abc type"""

        self.assertTrue(issubclass(DP8, IterDataPipe))
        self.assertTrue(DP8.type.param == Awaitable[str])


    def test_construct_time(self):
        class DP0(IterDataPipe[Tuple]):
            @argument_validation

@@ -2728,6 +2728,7 @@ class TestFunctionalTracing(JitTestCase):
        "rrelu": CONTROL_FLOW,
        "selu": CONTROL_FLOW,
        "silu": CONTROL_FLOW,
        "mish": CONTROL_FLOW,
        "smooth_l1_loss": CONTROL_FLOW,
        "soft_margin_loss": CONTROL_FLOW,
        "threshold": CONTROL_FLOW,

@@ -1699,10 +1699,10 @@ graph(%Ra, %Rb):
        self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))

    @suppress_warnings
    def test_sparse_csr_tensors(self):
    def test__sparse_csr_tensors(self):
        @torch.jit.ignore
        def get_sparse_csr():
            return torch.randn(3, 3).to_sparse_csr()
            return torch.randn(3, 3)._to_sparse_csr()

        @torch.jit.script
        def test_is_sparse_csr(input):

Some files were not shown because too many files have changed in this diff.