Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-05 16:44:58 +08:00
Compare commits: v2.0.0-rc2 ... v2.0.0 (39 commits)
| SHA1 |
|---|
| c263bd43e8 |
| c9913cf66f |
| 2f7d8bbf17 |
| ca0cdf52ca |
| 9cfa076da8 |
| 8e05e41dbc |
| d8ffc60bc1 |
| 1483723037 |
| c4572aa1b7 |
| 82b078ba64 |
| 77f7bc5f9d |
| 0865964576 |
| f18ac1b386 |
| c04134cdb1 |
| 72d0863ab2 |
| 1bd334dc25 |
| 93e13cd429 |
| 4e4d4b0afe |
| c4fa850827 |
| 36ead09873 |
| 66d23dbad7 |
| e2fff58844 |
| 735333a7ff |
| 6017488801 |
| e51e5e721c |
| 91739a0279 |
| 531f097b6f |
| 00eb7b0d78 |
| 2180f342c4 |
| a90b4f09ac |
| 1211ceeaa4 |
| beaa5c5908 |
| 4bd5c1e4f4 |
| f3c97a4e43 |
| 30cf0e70f7 |
| 96f627dcde |
| 6f11e6d6a1 |
| fcec27f7d5 |
| cddcb1e526 |
.github/ci_commit_pins/triton.txt (vendored; 2 changed lines)

@@ -1 +1 @@
d54c04abe2c3e67b2139c68cdbda87b59e8dd01b
b8b470bc597c1c5bd03682c09fe3e6b7c53787fd

@@ -1 +1 @@
pytorch-triton-rocm>=2.0.0.dev
pytorch-triton-rocm>=2.0.0,<2.1
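The two hunks above move the Triton commit pin and tighten the ROCm Triton requirement from an open-ended pre-release specifier to a bounded release range. A minimal sketch of how such a specifier filters candidate versions, using the third-party `packaging` library (an assumption for illustration, not something this diff references):

```python
# Sketch only: how ">=2.0.0,<2.1" behaves compared to ">=2.0.0.dev".
# Assumes the `packaging` library is installed; it is not part of this diff.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

old_spec = SpecifierSet(">=2.0.0.dev")   # accepts pre-releases and any later version
new_spec = SpecifierSet(">=2.0.0,<2.1")  # accepts only the 2.0.x release line

for candidate in ["2.0.0.dev20230301", "2.0.0", "2.0.1", "2.1.0"]:
    version = Version(candidate)
    print(candidate, version in old_spec, version in new_spec)
```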
.github/scripts/build_triton_wheel.py (vendored; 2 changed lines)

@@ -38,7 +38,7 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
if build_conda:
with open(triton_basedir / "meta.yaml", "w") as meta:
print(f"package:\n name: torchtriton\n version: 2.0.0+{commit_hash[:10]}\n", file=meta)
print("package:\n name: torchtriton\n version: 2.0.0\n", file=meta)
print("source:\n path: .\n", file=meta)
print("build:\n string: py{{py}}\n number: 1\n script: cd python; "
"python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta)

@@ -226,7 +226,8 @@ def generate_wheels_matrix(os: str,
"nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'",
"nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"build_name":
f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn"
.replace(
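The `generate_wheels_matrix` hunk appends `triton==2.0.0` to the `|`-separated `PYTORCH_EXTRA_INSTALL_REQUIREMENTS` list, gated by the same Linux/x86_64 environment markers as the CUDA wheels. A hedged sketch of how such a list could be split into individual PEP 508 requirements and filtered by marker; it assumes the `packaging` library and is not code from this repository:

```python
# Sketch: split a "|"-separated requirements string and evaluate each environment
# marker against the current interpreter/platform. Assumes `packaging`; illustration only.
from packaging.requirements import Requirement

extras = (
    "nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | "
    "triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'"
)

for raw in extras.split(" | "):
    req = Requirement(raw.strip())
    applies = req.marker.evaluate() if req.marker is not None else True
    print(f"{req.name}{req.specifier}: install on this platform? {applies}")
```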
.github/workflows/_binary-test-linux.yml (vendored; 2 changed lines)

@@ -153,7 +153,7 @@ jobs:
- name: Checkout pytorch/builder to builder dir
uses: malfet/checkout@silent-checkout
with:
ref: main
ref: release/2.0
submodules: recursive
repository: pytorch/builder
path: builder
.github/workflows/build-triton-wheel.yml (vendored; 4 changed lines)

@@ -137,7 +137,7 @@ jobs:
run: |
set -ex
pip install -q awscli
s3_dir="${UPLOAD_BUCKET}/whl/nightly/"
s3_dir="${UPLOAD_BUCKET}/whl/test/"
for pkg in "${PKG_DIR}/"*.whl; do
aws s3 cp --no-progress --acl public-read "${pkg}" "${s3_dir}"
done

@@ -193,7 +193,7 @@ jobs:
if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/master' || github.event.ref == 'refs/heads/main') }}
run: |
container_name=$(docker container ps --format '{{.ID}}')
docker exec -t "${container_name}" sh -c "anaconda upload /artifacts/torch*.tar.bz2 -u pytorch-nightly --label main --no-progress --force"
docker exec -t "${container_name}" sh -c "anaconda upload /artifacts/torch*.tar.bz2 -u pytorch-test --label main --no-progress --force"

- name: Chown artifacts
run: |
.github/workflows/generated-linux-binary-manywheel-master.yml (generated, vendored; 2 changed lines)

@@ -47,7 +47,7 @@ jobs:
DESIRED_PYTHON: "3.8"
build_name: manywheel-py3_8-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/generated-linux-binary-manywheel-nightly.yml (generated, vendored; 8 changed lines)

@@ -169,7 +169,7 @@ jobs:
DESIRED_PYTHON: "3.8"
build_name: manywheel-py3_8-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

@@ -667,7 +667,7 @@ jobs:
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

@@ -1165,7 +1165,7 @@ jobs:
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

@@ -1663,7 +1663,7 @@ jobs:
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cuda11_7-with-pypi-cudnn
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.2.10.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.0.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64' | triton==2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -54,8 +54,6 @@ TORCH_LIBRARY_IMPL(aten, MPS, m) {
m.impl("embedding_renorm_", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_svd", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_fft_c2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_fft_r2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold
m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
@@ -9,7 +9,6 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_mps_max_pool2d.h>
#include <ATen/ops/adaptive_avg_pool1d_native.h>
#include <ATen/ops/adaptive_avg_pool2d.h>
#include <ATen/ops/adaptive_max_pool1d_native.h>

@@ -141,12 +140,6 @@ Tensor max_pool2d(
return at::mkldnn_max_pool2d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
#ifdef USE_MPS
if (self.is_mps()) {
return at::_mps_max_pool2d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
#endif
#if defined(C10_MOBILE)
if(xnnpack::use_max_pool2d(self, kernel_size, padding, stride,
dilation, ceil_mode)) {
@@ -1428,7 +1428,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
}
#ifdef USE_MPS
if (_input.is_mps() && !bidirectional) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
std::tuple<Tensor, Tensor, Tensor> return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output));
return return_values;
@@ -138,4 +138,7 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
constantValue:(double) constantValue
name:(NSString * _Nullable) name;
- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
@end
@@ -265,7 +265,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
bool sliceViewTensor = canSliceViewTensor(src, mpsShape);
// a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose())
if ((!src.is_contiguous() || (src.is_view() && src.storage_offset() && !sliceViewTensor)) && gatherTensorData) {
if ((!src.is_contiguous() || (src.storage_offset() && !sliceViewTensor)) && gatherTensorData) {
Tensor emptyShell = Tensor();
// use "_tensor" from Placeholder to retain view's output during its usage in other ops
_tensor = gatherViewTensor(src, emptyShell);

@@ -289,7 +289,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSS
} else {
if (!mpsShape) {
mpsShape = getMPSShape(_tensor);
}
}
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
shape:mpsShape
@@ -311,11 +311,25 @@ TORCH_IMPL_FUNC(log_softmax_mps_out) (
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* softmaxTensor = [mpsGraph softMaxWithTensor:inputTensor
axis:dim
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph logarithmWithTensor:softmaxTensor
name:nil];
MPSGraphTensor* maximumsTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
axis:dim
name:nil];
MPSGraphTensor* inputTensorSubMax = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:maximumsTensor
name:nil];
MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:inputTensorSubMax
name:nil];
MPSGraphTensor* exponentTensorReduced = [mpsGraph reductionSumWithTensor:exponentTensor
axis:dim
name:nil];
MPSGraphTensor* logSumExpTensor = [mpsGraph logarithmWithTensor:exponentTensorReduced
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensorSubMax
secondaryTensor:logSumExpTensor
name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
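The new graph above replaces `log(softmax(x))` with the max-shifted log-sum-exp form, `log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))`. A NumPy sketch of why the shifted form is preferred (NumPy is an assumption for illustration; this is not the MPS kernel itself):

```python
# Sketch of the numerically stable formulation the new graph implements:
#   log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))
# Plain log(softmax(x)) can overflow/underflow for large-magnitude inputs.
import numpy as np

def log_softmax_naive(x, axis=-1):
    e = np.exp(x)
    return np.log(e / e.sum(axis=axis, keepdims=True))

def log_softmax_stable(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

x = np.array([1000.0, 0.0, -1000.0])
print(log_softmax_naive(x))   # nan/-inf: exp(1000) overflows
print(log_softmax_stable(x))  # finite: [0., -1000., -2000.]
```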
@@ -1208,8 +1222,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *gradOutputTensor_ = nil;
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *resultTensor_ = nil;
MPSGraphTensor *selfOrResultTensor_ = nil;
MPSGraphTensor *gradInputTensor_ = nil;
};

@@ -1218,7 +1231,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" +
string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
to_string(alpha.to<double>()) + ":" +
to_string(scale.to<double>()) + ":" +
to_string(input_scale.to<double>()) + ":" +

@@ -1235,18 +1248,14 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = nil;
MPSGraphTensor* resultTensor = nil;
MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* lessThanZeroGradTensor = nil;
if(is_result) {
resultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to<double>()
shape:@[@1]
dataType:getMPSDataType(grad_output.scalar_type())];
MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:resultTensor
MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:selfOrResultTensor
secondaryTensor:alphaTensor
name:nil];
auto constMul = scale.to<double>() * input_scale.to<double>();

@@ -1258,11 +1267,10 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
name:nil];
}
else {
inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to<double>()
shape:@[@1]
dataType:getMPSDataType(grad_output.scalar_type())];
MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:selfOrResultTensor
secondaryTensor:inputScaleTensor
name:nil];
MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor

@@ -1282,7 +1290,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f
shape:@[@1]
dataType:getMPSDataType(grad_output.scalar_type())];
MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor
MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:selfOrResultTensor
secondaryTensor:zeroTensor
name:nil];
MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:predicateTensor

@@ -1294,8 +1302,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
name:nil];
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->resultTensor_ = resultTensor;
newCachedGraph->selfOrResultTensor_ = selfOrResultTensor;
newCachedGraph->gradInputTensor_ = gradInputTensor;
}
return newCachedGraph;

@@ -1304,28 +1311,14 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) (
}
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp);
Placeholder selfPlaceholder = Placeholder();
Placeholder resultPlaceholder = Placeholder();
if(is_result)
resultPlaceholder = Placeholder(cachedGraph->resultTensor_, self_or_result, nil, executeGatherOp);
else
selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp);
Placeholder selfOrResultPlaceholder = Placeholder(cachedGraph->selfOrResultTensor_, self_or_result, nil, executeGatherOp);
Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
if(is_result)
feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
};
else
feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
selfOrResultPlaceholder.getMPSGraphTensor() : selfOrResultPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};
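The rewrite folds the separate `inputTensor`/`resultTensor` placeholders into a single `selfOrResultTensor`, since ELU backward only ever needs one saved tensor: the forward input `x` when `is_result` is false, or the forward output `y` when it is true. A hedged NumPy sketch of the reference formula both branches compute (mirroring ATen's CPU semantics, not the MPS graph itself):

```python
# Reference-only sketch of elu_backward. `is_result` mirrors the kernel's flag:
# the saved tensor is either the forward input x or the forward output y.
import numpy as np

def elu_backward_ref(grad_output, self_or_result, alpha=1.0, scale=1.0,
                     input_scale=1.0, is_result=False):
    t = self_or_result
    if is_result:
        # from the output y: d/dx = input_scale * (y + alpha * scale) where y <= 0
        neg_grad = input_scale * (t + alpha * scale)
    else:
        # from the input x: d/dx = alpha * scale * input_scale * exp(input_scale * x)
        neg_grad = alpha * scale * input_scale * np.exp(input_scale * t)
    return grad_output * np.where(t > 0, scale, neg_grad)

x = np.array([-2.0, -0.5, 0.5, 2.0])
y = np.where(x > 0, x, np.exp(x) - 1.0)        # elu forward, alpha = scale = input_scale = 1
g = np.ones_like(x)
print(elu_backward_ref(g, x, is_result=False))
print(elu_backward_ref(g, y, is_result=True))  # same values, recovered from y
```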
@@ -1840,7 +1833,7 @@ std::tuple<Tensor, Tensor> prelu_backward_mps(const Tensor& grad_output, const T
using namespace mps;
Tensor grad_input = at::empty_like(self, self.suggest_memory_format());
Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous);
Tensor weight_grad = at::empty_like(self, at::MemoryFormat::Contiguous);
if (grad_output.numel() == 0) {
return std::tuple<Tensor, Tensor>{grad_input, weight_grad};
}
@@ -177,10 +177,6 @@ void div_mode_template(const Tensor& self, const Tensor& other,
c10::optional<c10::string_view> rounding_mode,
const Tensor& output, const string op_name)
{
if(rounding_mode.has_value() && *rounding_mode == "floor"){
TORCH_CHECK(self.scalar_type() != ScalarType::Long,
"MPS: does not support floor_divide op with int64 input");
}
BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
@@ -12,7 +12,7 @@ Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
}
Tensor output = self;
bool needsCopyToOutput = false;
if (!self.is_contiguous()) {
if (!self.is_contiguous() || self.storage_offset()) {
output = empty_mps(self.sizes(), self.scalar_type(), c10::nullopt, kMPS);
needsCopyToOutput = true;
}

@@ -89,7 +89,7 @@ bool fill_mps_tensor_(Tensor& self, uint8_t value) {
if (self.is_contiguous()) {
MPSStream* stream = getCurrentMPSStream();
auto storage_byte_offset = self.storage_offset() * self.itemsize();
stream->fill(mps::getMTLBufferStorage(self), 0, self.nbytes(), storage_byte_offset);
stream->fill(mps::getMTLBufferStorage(self), 0, self.storage().nbytes(), storage_byte_offset);
return true;
}
return false;
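Both fill fixes deal with tensors that view into a shared storage at a non-zero offset: such a tensor can be contiguous and still must not be treated as if it started at the beginning of its buffer. A small CPU-only illustration of `storage_offset()` (plain PyTorch, for illustration; it does not exercise the MPS path):

```python
# Sketch: a contiguous slice shares storage with its base at an offset,
# so a raw buffer fill has to respect storage_offset.
import torch

base = torch.zeros(6)
view = base[2:]               # contiguous, but starts 2 elements into base's storage
print(view.is_contiguous())   # True
print(view.storage_offset())  # 2
view.fill_(1.0)
print(base)                   # tensor([0., 0., 1., 1., 1., 1.]): base[:2] stays intact
```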
@@ -56,15 +56,17 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
descriptor_.groups = groups;
}
Tensor _mps_convolution(
Tensor _mps_convolution_impl(
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
int64_t groups,
c10::optional<IntArrayRef> input_shape) {
TORCH_CHECK(input_t.dim() < 5, "Conv3D is not supported on MPS");
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
namespace native_mps = at::native::mps;
CheckedFrom c = "mps_convolution";

@@ -83,6 +85,8 @@ Tensor _mps_convolution(
auto memory_format = input_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
auto output_t = at::empty(
input_shape.has_value() ?
input_shape.value() :
conv_output_size(input->sizes(), weight->sizes(),
padding, stride, dilation),
input->scalar_type(),

@@ -237,21 +241,30 @@ Tensor _mps_convolution(
return *output;
}
Tensor _mps_convolution(
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
}
Tensor mps_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output_, const Tensor& weight_,
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
CheckedFrom c = "mps_convolution_backward_input";
TensorArg grad_output{ grad_output_, "grad_output", 1 },
weight{ weight_, "weight", 2 };
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
weight{ weight_t, "weight", 2 };
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});
auto memory_format = grad_output_.suggest_memory_format();
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
Tensor grad_output_t = grad_output_.contiguous(memory_format);
Tensor weight_t = weight_.contiguous(memory_format);
MPSShape* weightShape = getMPSShape(weight_);
auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);
// Avoid "grad_input" when this is being used as transposed convolution

@@ -327,10 +340,10 @@ Tensor mps_convolution_backward_input(
}
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(weight_t.scalar_type()), weightShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}
MPSGraphTensor* gradInputTensor;

@@ -359,7 +372,7 @@ Tensor mps_convolution_backward_input(
}
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t, weightShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{

@@ -377,17 +390,15 @@ Tensor mps_convolution_backward_input(
}
Tensor mps_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output_, const Tensor& input_,
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
CheckedFrom c = "mps_convolution_backward_weights";
auto memory_format = input_.suggest_memory_format();
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
auto grad_output_t = grad_output_.to(memory_format);
auto input_t = input_.to(memory_format);
MPSShape* gradOutputShape = mps::getMPSShape(grad_output_t, memory_format);
// For uniformity with everything else, although it seems grad_weight

@@ -475,7 +486,7 @@ Tensor mps_convolution_backward_weights(
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
if (is_channels_last) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}

@@ -525,12 +536,9 @@ Tensor mps_convolution_backward_weights(
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,3> output_mask) {
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
Tensor grad_input, grad_weight, grad_bias;
if (input.numel() == 0) {
if (output_mask[0]) {

@@ -576,10 +584,10 @@ Tensor _mps_convolution_transpose(
Tensor mps_convolution_transpose_backward_input(
const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups)
int64_t groups, IntArrayRef input_shape)
{
return at::_mps_convolution(
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups);
return _mps_convolution_impl(
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
}
Tensor mps_convolution_transpose_backward_weight(

@@ -595,15 +603,12 @@ Tensor mps_convolution_transpose_backward_weight(
std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
const Tensor& input, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,2> output_mask) {
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
Tensor grad_input, grad_weight;
if (output_mask[0]) {
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups);
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
}
if (output_mask[1]) {
grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups);
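The refactor threads an optional `input_shape` through `_mps_convolution_impl` so the transposed-convolution backward-input path can supply the known output size (`input.sizes()`) instead of deriving it from the usual convolution size formula. A sketch of that standard formula, as documented for `torch.nn.Conv2d` (illustration only, not the ATen helper itself):

```python
# Sketch: per-dimension convolution output size,
#   out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
# For the transposed-conv backward-input path, the correct size is simply the
# original input's size, which is why input.sizes() is forwarded in the diff.
def conv_output_size(in_size: int, kernel: int, padding: int, stride: int, dilation: int = 1) -> int:
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

print(conv_output_size(in_size=32, kernel=3, padding=1, stride=1))  # 32
print(conv_output_size(in_size=32, kernel=3, padding=1, stride=2))  # 16
```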
@@ -251,8 +251,11 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
bool returnGatherOutput = dst_.is_contiguous();
Tensor src;
auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
const bool sameDataType = src_.dtype() == dst_.dtype();
if (!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
if ((!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) ||
// the copy_cast path requires storage_offset to be applied before casting
(src_.storage_offset() && !sameDataType)) {
Tensor emptyShell = Tensor();
src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell);

@@ -282,7 +285,7 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
src._set_neg(src_.is_neg());
const size_t src_size = src.nbytes();
if (src.dtype() == dst_.dtype()) {
if (sameDataType) {
MPSStream* stream = getCurrentMPSStream();
// for GPU to GPU copies we only encode to stream's command buffer (no flushing)
stream->copy(sourceBuffer, destBuffer, src_size, src_byte_offset, dst_byte_offset);

@@ -297,22 +300,27 @@ at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
TORCH_CHECK(dst.defined(), "dst is undefined");
TORCH_CHECK(src.defined(), "src is undefined");
bool needs_broadcasting = false;
if (src.numel() == 0 || dst.is_same(src)) {
return dst;
}
if (dst.numel() == 0) {
dst.resize_as_(src);
}
if (dst.dim() > src.dim()) {
needs_broadcasting = true;
}
if (src.device().type() == at::kMPS && dst.device().type() == at::kCPU) {
return copy_from_mps_(dst, src, non_blocking);
return copy_from_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
if (src.device().type() == at::kCPU && dst.device().type() == at::kMPS) {
return copy_to_mps_(dst, src, non_blocking);
return copy_to_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
return copy_kernel_mps(dst, src, non_blocking);
return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
}
TORCH_INTERNAL_ASSERT(
src.device().type() == DeviceType::MPS,
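`mps_copy_` now expands the source with `expand_as` when the destination has more dimensions, matching the broadcasting that `Tensor.copy_` already performs on other backends. A CPU-only sketch of that behavior (illustration; it does not require an MPS device):

```python
# Sketch: the broadcasting semantics the mps_copy_ change enables,
# shown with plain CPU tensors.
import torch

dst = torch.empty(2, 3)
src = torch.tensor([1.0, 2.0, 3.0])   # fewer dims than dst
dst.copy_(src)                        # comparable to copying src.expand_as(dst)
print(dst)
# tensor([[1., 2., 3.],
#         [1., 2., 3.]])
print(src.expand_as(dst).shape)       # torch.Size([2, 3]) -- a view, no data copy
```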
@@ -886,19 +886,31 @@ Tensor embedding_dense_backward_mps(
MPSGraphTensor* reshapedIndicesTensor = indicesTensor;
MPSGraphTensor* castGradTensor = incomingGradTensor;
MPSDataType dataType = mps::getMPSDataType(grad_.scalar_type());
// issue 105486100, scatterNDWithUpdatesTensor produces wrong result for float16
if (dataType == MPSDataTypeFloat16) {
castGradTensor = [mpsGraph castTensor: incomingGradTensor
toType: MPSDataTypeFloat32
name: @"castGradTensor"];
}
if (num_indices_dims != 0) {
reshapedIndicesTensor = [mpsGraph expandDimsOfTensor: indicesTensor
axes: @[@-1]
name: nil];
}
auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: incomingGradTensor
auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: castGradTensor
indicesTensor: reshapedIndicesTensor
shape: native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape))
batchDimensions: 0
mode: MPSGraphScatterModeAdd
name: @"edb"];
if (dataType == MPSDataTypeFloat16) {
outgoingGradTensor = [mpsGraph castTensor: outgoingGradTensor
toType: MPSDataTypeFloat16
name: @"castGradTensor"];
}
newCachedGraph->incomingGradTensor_ = incomingGradTensor;
newCachedGraph->indicesTensor_ = indicesTensor;
newCachedGraph->outgoingGradTensor_ = outgoingGradTensor;
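The workaround above casts a float16 incoming gradient to float32, performs the scatter-add, and casts the result back, because the float16 scatter path was producing wrong results. A NumPy sketch of the same upcast-accumulate-downcast pattern (NumPy's `np.add.at` stands in for the scatter-add here; illustration only):

```python
# Sketch: accumulate a scatter-add in float32, then cast back to float16.
import numpy as np

grad = np.full((4,), 0.1, dtype=np.float16)   # incoming gradient rows (1-D for brevity)
indices = np.array([0, 1, 1, 1])              # several updates hit index 1
out_fp32 = np.zeros(2, dtype=np.float32)
np.add.at(out_fp32, indices, grad.astype(np.float32))  # scatter-add in fp32
out = out_fp32.astype(np.float16)             # cast back, matching the gradient's dtype
print(out)                                    # ~[0.1, 0.3], accumulated in fp32 before the final cast
```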
@@ -83,6 +83,7 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
auto output_memory_format = output.suggest_memory_format();
// the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors
// by simply restriding them (instead of calling the costly Contiguous()).
if (indices.suggest_memory_format() == MemoryFormat::ChannelsLast) {

@@ -94,8 +95,9 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
outputSizes.insert(outputSizes.begin(), nbatch);
}
output.resize_(outputSizes);
} else if (output.suggest_memory_format() == MemoryFormat::ChannelsLast) {
} else if (output_memory_format == MemoryFormat::ChannelsLast) {
output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
output_memory_format = MemoryFormat::Contiguous;
}
if (output.numel() == 0 || (is_backward_pass && grad_output.numel() == 0)) {

@@ -196,6 +198,10 @@ static void pool2d_template(const Tensor& input, const Tensor& output,
}
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
if (output_memory_format != suggested_memory_format) {
const_cast<Tensor&>(output) = output.to(suggested_memory_format);
}
}
}

@@ -302,7 +308,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
} // namespace mps
Tensor _mps_max_pool2d(
Tensor mps_max_pool2d(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,

@@ -356,6 +362,8 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
const Tensor& output,
const Tensor& indices) {
auto indices_memory_format = indices.suggest_memory_format();
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor

@@ -366,6 +374,10 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)(
};
mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride,
padding, dilation, ceil_mode, false, c10::nullopt, pooling_op_block, "max_pool2d_indices");
if (indices_memory_format == MemoryFormat::ChannelsLast) {
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
}
}
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)(
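The pooling changes track the suggested memory format of the output and indices and convert the results back to channels-last when the caller passed channels-last tensors. A hedged usage sketch of the call pattern involved (it runs on CPU as written; an MPS build and device would be needed to actually hit this kernel):

```python
# Sketch: channels-last max pooling with returned indices, the case the
# memory-format bookkeeping above is concerned with.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
out, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
print(out.is_contiguous(memory_format=torch.channels_last))
print(idx.shape)
```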
@@ -139,6 +139,10 @@ void reduction_out_mps(
MPSReductionType reduction_type,
const std::string& func_name) {
// issue 103641234, reduction ops does not have int64 support
if (input_t.scalar_type() == ScalarType::Long) {
TORCH_WARN_ONCE("MPS: no support for int64 reduction ops, casting it to int32");
}
IntArrayRef input_shape = input_t.sizes();
if (opt_dim.has_value()) {

@@ -163,6 +167,9 @@ void reduction_out_mps(
if (reduction_type == MPSReductionType::PROD) {
output_t.fill_(1);
}
else if (reduction_type == MPSReductionType::SUM) {
output_t.zero_();
}
return;
}

@@ -197,7 +204,10 @@ void reduction_out_mps(
(dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt)) {
inputCastDtype = getMPSDataType(dtype.value());
} else if (input_type != MPSDataTypeInt32 &&
input_type != MPSDataTypeFloat32) {
input_type != MPSDataTypeFloat32 &&
input_type != MPSDataTypeFloat16) {
inputCastDtype = MPSDataTypeFloat32;
} else if (!is_macos_13_or_newer() && input_type == MPSDataTypeFloat16) {
inputCastDtype = MPSDataTypeFloat32;
}

@@ -241,7 +251,7 @@ void reduction_out_mps(
axes:wrappedAxes
name:nil];
} else if (reduction_type == MPSReductionType::TRACE) {
MPSGraphTensor *bandPartWithTensor = [mpsGraph bandPartWithTensor:inputTensor
MPSGraphTensor *bandPartWithTensor = [mpsGraph bandPartWithTensor:castInputTensor
numLower:0
numUpper:0
name:nil];

@@ -1257,7 +1267,9 @@ Tensor min_max_mps
(const Tensor& input_t,
MPSReductionType reduction_type,
const std::string& func_name) {
TORCH_WARN_ONCE(input_t.scalar_type() != ScalarType::Long, "MPS: no support for int64 min/max ops, casting it to int32");
if (input_t.scalar_type() == ScalarType::Long) {
TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
}
using CachedGraph = MPSUnaryCachedGraph;
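The warning fixes here (and in the `repeat_interleave` hunk below) hinge on `TORCH_WARN_ONCE` taking only a message, with every argument folded into the warning text rather than acting as a condition, so any dtype gating has to be an explicit `if` around the call. A Python sketch of the same warn-once-with-explicit-condition shape (names are hypothetical, for illustration only; these are not PyTorch APIs):

```python
# Sketch: "warn once" keyed on the message, with the condition as an explicit if.
import warnings

_warned: set[str] = set()

def warn_once(message: str) -> None:
    if message not in _warned:
        _warned.add(message)
        warnings.warn(message)

def reduce_mps(tensor_dtype: str) -> None:
    if tensor_dtype == "int64":   # explicit condition, mirroring the fix above
        warn_once("MPS: no support for int64 reduction ops, casting it to int32")

reduce_mps("int64")   # warns
reduce_mps("int64")   # silent: already warned once
```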
@@ -233,7 +233,7 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> outpu
if (repeat.scalar_type() == kLong) {
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
// which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
TORCH_WARN_ONCE(false, "MPS: no support for int64 repeats mask, casting it to int32");
TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32");
repeat = repeat.to(kInt);
}
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {

@@ -243,4 +243,4 @@ Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> outpu
return output;
}
} // namespace at::native
} // namespace at::native
@@ -23,17 +23,31 @@ std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) {
return output_dimensions;
}
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
using namespace mps;
//Projections are not currently supported, raise an error if needed
bool has_projections = (hx[0].size(2) != hx[1].size(2));
if(has_projections) {
AT_ERROR("LSTM with projections is not currently supported with MPS.");
}
TORCH_CHECK(!(!is_macos_13_or_newer() && num_layers > 1), "Multi-layer LSTM support in MPS available only on MacOS 13 onwards");
std::vector<Tensor> kernel_weights;
std::vector<Tensor> recurrent_kernel_weights;
std::vector<Tensor> biases;
std::vector<Tensor> recurrent_biases;
for (size_t i = 0; i < num_layers; i+=1) {
kernel_weights.push_back(params[i*4]);
recurrent_kernel_weights.push_back(params[i*4+1]);
biases.push_back(params[i*4+2]);
recurrent_biases.push_back(params[i*4+3]);
if (has_biases) {
kernel_weights.push_back(params[i*4]);
recurrent_kernel_weights.push_back(params[i*4+1]);
biases.push_back(params[i*4+2]);
recurrent_biases.push_back(params[i*4+3]);
} else {
kernel_weights.push_back(params[i*2]);
recurrent_kernel_weights.push_back(params[i*2+1]);
}
}
struct CachedGraph : public MPSCachedGraph {

@@ -44,8 +58,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
std::vector<MPSGraphTensor*> outputCellStateFwdVector_;
std::vector<MPSGraphTensor*> outputZStateVector_;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
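The parameter unpacking above now honors `has_biases`: with biases each layer contributes four tensors to the flat `params` list (input weights, recurrent weights, input bias, recurrent bias), without biases only the two weight tensors, so the stride into `params` is 4 or 2. A small Python sketch of that layout (hypothetical tensor names, illustration only):

```python
# Sketch of the flat LSTM parameter layout the new unpacking handles:
# 4 tensors per layer with biases, 2 without, i.e. the i*4 / i*2 indexing above.
def unpack_lstm_params(params, num_layers, has_biases):
    per_layer = 4 if has_biases else 2
    layers = []
    for i in range(num_layers):
        chunk = params[i * per_layer:(i + 1) * per_layer]
        w_ih, w_hh = chunk[0], chunk[1]
        b_ih, b_hh = (chunk[2], chunk[3]) if has_biases else (None, None)
        layers.append((w_ih, w_hh, b_ih, b_hh))
    return layers

# Hypothetical flat list for a 2-layer LSTM without biases: 2 tensors per layer.
params = ["w_ih0", "w_hh0", "w_ih1", "w_hh1"]
print(unpack_lstm_params(params, num_layers=2, has_biases=False))
```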
@@ -67,12 +79,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers];
for (size_t i = 0; i < num_layers; i += 1) {
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
if(has_biases) {
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
}
}
MPSGraphLSTMDescriptor * opDesc = [MPSGraphLSTMDescriptor descriptor];

@@ -93,25 +108,28 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
}
MPSGraphTensor* inputTensor_ = inputTensor;
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:0
length:1
name:nil];
MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:0
length:1
name:nil];
NSArray<MPSGraphTensor*>* outputs = nil;
NSMutableArray<MPSGraphTensor*>* outputStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputCellStateFwdArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
for(int i = 0; i < num_layers; i++) {
MPSGraphTensor* biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
secondaryTensor:recurrentBiasList[i]
name:nil];
MPSGraphTensor* biasTensor = nil;
if(has_biases) {
biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
secondaryTensor:recurrentBiasList[i]
name:nil];
}
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:i
length:1
name:nil];
MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:i
length:1
name:nil];
outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_
recurrentWeight:recurrentKernelWeightsList[i]
inputWeight:kernelWeightsList[i]

@@ -121,18 +139,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
descriptor:opDesc
name:nil];
stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:i
length:1
name:nil];
cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:i
length:1
name:nil];
inputTensor_ = [outputs objectAtIndex:0];
// no need to keep a final layer output copy as it is
// returned anyway and not used in backprop
if(i != num_layers - 1) {
[layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_
axis:0
name:nil]];
}
if(dropout_p>0.0 && train && (i!=num_layers-1)) {
inputTensor_ = [mpsGraph dropoutTensor:inputTensor_
rate:dropout_p

@@ -150,7 +164,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
name:nil]];
}
MPSGraphTensor* outputTensor = [outputs objectAtIndex:0];
MPSGraphTensor* outputTensor = inputTensor_;
if (batch_first) {
outputTensor = [mpsGraph transposeTensor:outputTensor
dimension:0

@@ -169,8 +183,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray
dimension:0
name:nil];
MPSGraphTensor* layersOutputs = (num_layers > 1)
? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil]
: nil;
std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd};
std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs};
newCachedGraph->inputTensors_ = inputTensors;
newCachedGraph->outputTensors_ = outputTensors;
newCachedGraph->kernelWeightsList_ = kernelWeightsList;
@@ -188,20 +205,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;
Placeholder kernelWeight;
Placeholder recurrentKernelWeight;
Placeholder bias;
Placeholder recurrentBias;
Placeholder kernelWeight, recurrentKernelWeight, bias, recurrentBias;
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
for (size_t i = 0; i < num_layers; i+=1) {
kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
if(has_biases) {
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
}
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensors_[0], input);

@@ -218,6 +235,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
Tensor cy = at::empty_like(hx[1], input.options());
Tensor zState = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[3])), input.options());
Tensor cellStateFwd = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[4])), input.options());
Tensor layerOutputs = (num_layers > 1)
? at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[5])), input.options())
: at::empty({ 1 }, input.options()); // not used if num_layers == 1
Placeholder outputPlaceholder0 = Placeholder(cachedGraph->outputTensors_[0], output);
Placeholder outputPlaceholder1 = Placeholder(cachedGraph->outputTensors_[1], hy);

@@ -225,20 +245,25 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
Placeholder outputPlaceholder3 = Placeholder(cachedGraph->outputTensors_[3], zState);
Placeholder outputPlaceholder4 = Placeholder(cachedGraph->outputTensors_[4], cellStateFwd);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [@{
outputPlaceholder0.getMPSGraphTensor() : outputPlaceholder0.getMPSGraphTensorData(),
outputPlaceholder1.getMPSGraphTensor() : outputPlaceholder1.getMPSGraphTensorData(),
outputPlaceholder2.getMPSGraphTensor() : outputPlaceholder2.getMPSGraphTensorData(),
outputPlaceholder3.getMPSGraphTensor() : outputPlaceholder3.getMPSGraphTensorData(),
outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData()
};
outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData(),
} mutableCopy];
if (num_layers > 1) {
Placeholder outputPlaceholder5 = Placeholder(cachedGraph->outputTensors_[5], layerOutputs);
[results setObject:outputPlaceholder5.getMPSGraphTensorData() forKey: outputPlaceholder5.getMPSGraphTensor()];
}
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return std::make_tuple(output, hy, cy, zState, cellStateFwd);
return std::make_tuple(output, hy, cy, zState, cellStateFwd, layerOutputs);
}
}

std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, const Tensor& layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
|
||||
using namespace mps;
|
||||
const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] {return Tensor();});
|
||||
const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] {return Tensor();});
|
||||
@@ -250,10 +275,15 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
   std::vector<Tensor> biases;
   std::vector<Tensor> recurrent_biases;
   for (size_t i = 0; i < num_layers; i+=1) {
-    kernel_weights.push_back(params[i*4]);
-    recurrent_kernel_weights.push_back(params[i*4+1]);
-    biases.push_back(params[i*4+2]);
-    recurrent_biases.push_back(params[i*4+3]);
+    if(has_biases) {
+      kernel_weights.push_back(params[i*4]);
+      recurrent_kernel_weights.push_back(params[i*4+1]);
+      biases.push_back(params[i*4+2]);
+      recurrent_biases.push_back(params[i*4+3]);
+    } else {
+      kernel_weights.push_back(params[i*2]);
+      recurrent_kernel_weights.push_back(params[i*2+1]);
+    }
   }

   struct CachedGraph : public MPSCachedGraph {
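The `i*4` / `i*2` indexing above assumes the flat parameter layout that `torch.nn.LSTM` exposes: four tensors per layer when biases are present, two otherwise. A minimal Python sketch of that assumption (variable names below are illustrative, not taken from the diff):

    import torch

    lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2, bias=True)
    params = list(lstm.parameters())
    # bias=True: layer i contributes weight_ih, weight_hh, bias_ih, bias_hh -> starts at i*4
    for i in range(lstm.num_layers):
        w_ih, w_hh, b_ih, b_hh = params[i * 4 : i * 4 + 4]

    lstm_nb = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2, bias=False)
    params_nb = list(lstm_nb.parameters())
    # bias=False: only the two weight matrices remain -> starts at i*2
    for i in range(lstm_nb.num_layers):
        w_ih, w_hh = params_nb[i * 2 : i * 2 + 2]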
@ -264,12 +294,12 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *gradOutput_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *gradRecWeights_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *gradWeights_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *gradBias_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *gradState_ = nil;
|
||||
NSMutableArray<MPSGraphTensor*> *gradCellState_ = nil;
|
||||
MPSGraphTensor* gradOutput_ = nil;
|
||||
MPSGraphTensor* gradState_ = nil;
|
||||
MPSGraphTensor* gradCellState_ = nil;
|
||||
};
|
||||
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
@ -296,8 +326,10 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
for (size_t i = 0; i < num_layers; i += 1) {
|
||||
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
|
||||
[recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
|
||||
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
|
||||
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
|
||||
if(has_biases) {
|
||||
[kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
|
||||
[recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
|
||||
}
|
||||
}
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(input));
|
||||
@ -308,8 +340,22 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
MPSGraphTensor* gradientCyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy.scalar_type()), getMPSShape(grad_cy));
|
||||
MPSGraphTensor* gradientHyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy.scalar_type()), getMPSShape(grad_hy));
|
||||
MPSGraphTensor* cellStateFwdTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd.scalar_type()), getMPSShape(cell_state_fwd));
|
||||
MPSGraphTensor* layersOutputsTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs.scalar_type()), getMPSShape(layersOutputs));
|
||||
|
||||
std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor, layersOutputsTensor};
|
||||
|
||||
if (batch_first) {
|
||||
inputTensor = [mpsGraph transposeTensor: inputTensor
|
||||
dimension: 0
|
||||
withDimension: 1
|
||||
name: nil];
|
||||
|
||||
gradientTensor = [mpsGraph transposeTensor: gradientTensor
|
||||
dimension: 0
|
||||
withDimension: 1
|
||||
name: nil];
|
||||
}
|
||||
|
||||
std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor};
|
||||
newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList;
|
||||
newCachedGraph->kernelWeightsList_ = kernelWeightsList;
|
||||
newCachedGraph->biasList_ = kernelBiasList;
|
||||
@ -325,7 +371,6 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
|
||||
NSArray<MPSGraphTensor*>* outputs = nil;
|
||||
|
||||
NSMutableArray<MPSGraphTensor*>* gradOutputArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
NSMutableArray<MPSGraphTensor*>* gradRecWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
NSMutableArray<MPSGraphTensor*>* gradWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
NSMutableArray<MPSGraphTensor*>* gradBiasArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
|
||||
@ -349,9 +394,15 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd
|
||||
axis:0
|
||||
name:nil];
|
||||
MPSGraphTensor* biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
|
||||
secondaryTensor:recurrentBiasList[i]
|
||||
name:nil];
|
||||
MPSGraphTensor* biasTensor = nil;
|
||||
if(has_biases) {
|
||||
biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
|
||||
secondaryTensor:recurrentBiasList[i]
|
||||
name:nil];
|
||||
} else {
|
||||
biasTensor = [mpsGraph constantWithScalar:0.0
|
||||
dataType:inputTensor.dataType];
|
||||
}
|
||||
|
||||
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
|
||||
dimension:0
|
||||
@ -375,7 +426,23 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
length:1
|
||||
name:nil];
|
||||
|
||||
outputs = [mpsGraph LSTMGradientsWithSourceTensor: inputTensor
|
||||
MPSGraphTensor* iterationInputTensor_ = nil;
|
||||
if (i == 0) {
|
||||
iterationInputTensor_ = inputTensor;
|
||||
} else {
|
||||
iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor
|
||||
dimension: 0
|
||||
// last element in layersOutputsTensor contains
|
||||
// **inputs** for the last layer
|
||||
start: i - num_layers
|
||||
length: 1
|
||||
name: nil];
|
||||
iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_
|
||||
axis:0
|
||||
name: nil];
|
||||
}
|
||||
|
||||
outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_
|
||||
recurrentWeight: recurrentKernelWeightsList[i]
|
||||
sourceGradient: gradientTensor_
|
||||
zState: zState
|
||||
@ -391,24 +458,31 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
descriptor: opDesc
|
||||
name: nil];
|
||||
|
||||
|
||||
gradientTensor_ = [outputs objectAtIndex:0];
|
||||
[gradOutputArray addObject:[outputs objectAtIndex:0]];
|
||||
[gradRecWeightsArray addObject:[outputs objectAtIndex:1]];
|
||||
[gradWeightsArray addObject:[outputs objectAtIndex:2]];
|
||||
[gradBiasArray addObject:[outputs objectAtIndex:3]];
|
||||
[gradStateArray addObject:[outputs objectAtIndex:4]];
|
||||
[gradCellStateArray addObject:[outputs objectAtIndex:5]];
|
||||
[gradRecWeightsArray insertObject:[outputs objectAtIndex:1] atIndex:0];
|
||||
[gradWeightsArray insertObject:[outputs objectAtIndex:2] atIndex:0];
|
||||
[gradBiasArray insertObject: [outputs objectAtIndex:3] atIndex:0];
|
||||
[gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:4] axis:0 name:nil] atIndex:0];
|
||||
[gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:5] axis:0 name:nil] atIndex:0];
|
||||
}
|
||||
std::vector<MPSGraphTensor*> outputTensors = {[outputs objectAtIndex:0],[outputs objectAtIndex:1],[outputs objectAtIndex:2],[outputs objectAtIndex:3], [outputs objectAtIndex:4], [outputs objectAtIndex:5]};
|
||||
|
||||
if (batch_first) {
|
||||
MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_
|
||||
dimension: 0
|
||||
withDimension: 1
|
||||
name:nil];
|
||||
newCachedGraph->gradOutput_ = gradientTensorTransposed;
|
||||
} else {
|
||||
newCachedGraph->gradOutput_ = gradientTensor_;
|
||||
}
|
||||
|
||||
newCachedGraph->outputTensors_ = outputTensors;
|
||||
newCachedGraph->gradOutput_ = gradOutputArray;
|
||||
newCachedGraph->gradRecWeights_ = gradRecWeightsArray;
|
||||
newCachedGraph->gradWeights_ = gradWeightsArray;
|
||||
newCachedGraph->gradBias_ = gradBiasArray;
|
||||
newCachedGraph->gradState_ = gradStateArray;
|
||||
newCachedGraph->gradCellState_ = gradCellStateArray;
|
||||
|
||||
newCachedGraph->gradState_ = [mpsGraph concatTensors:gradStateArray dimension: 0 name: nil];
|
||||
newCachedGraph->gradCellState_ = [mpsGraph concatTensors:gradCellStateArray dimension: 0 name: nil];
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
@ -423,6 +497,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
Placeholder cellStateFwdPlaceholder = Placeholder(cachedGraph->inputTensors_[5], cell_state_fwd);
|
||||
Placeholder gradientHyPlaceholder = Placeholder(cachedGraph->inputTensors_[6], grad_hy);
|
||||
Placeholder gradientCyPlaceholder = Placeholder(cachedGraph->inputTensors_[7], grad_cy);
|
||||
Placeholder layersOutputsPlaceholder = Placeholder(cachedGraph->inputTensors_[8], layersOutputs);
|
||||
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
|
||||
[feeds setObject:gradientPlaceholder.getMPSGraphTensorData() forKey:gradientPlaceholder.getMPSGraphTensor()];
|
||||
@ -433,6 +508,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
[feeds setObject:cellStatePlaceholder.getMPSGraphTensorData() forKey:cellStatePlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:zStatePlaceholder.getMPSGraphTensorData() forKey:zStatePlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData() forKey:cellStateFwdPlaceholder.getMPSGraphTensor()];
|
||||
[feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData() forKey:layersOutputsPlaceholder.getMPSGraphTensor()];
|
||||
|
||||
NSMutableArray<MPSGraphTensor*> *kernelWeightsList = cachedGraph->kernelWeightsList_;
|
||||
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_;
|
||||
@ -445,68 +521,65 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
for (size_t i = 0; i < num_layers; i+=1) {
|
||||
kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
|
||||
recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
|
||||
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
|
||||
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
|
||||
[feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
|
||||
[feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
|
||||
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
|
||||
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
|
||||
if(has_biases) {
|
||||
bias = Placeholder([biasList objectAtIndex:i], biases[i]);
|
||||
recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
|
||||
[feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
|
||||
[feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
|
||||
}
|
||||
}
|
||||
|
||||
Tensor output = at::empty_like(input);
|
||||
Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[0]);
|
||||
Tensor grad_weights = at::empty_like(kernel_weights[0]);
|
||||
Tensor grad_bias = at::empty_like(biases[0]);
|
||||
Tensor grad_state = at::empty_like(hx[0]);
|
||||
Tensor grad_cell_state = at::empty_like(hx[1]);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensors_[0], output);
|
||||
Placeholder gradRecWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[1], grad_rec_weights);
|
||||
Placeholder gradWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[2], grad_weights);
|
||||
Placeholder gradBiasPlaceholder = Placeholder(cachedGraph->outputTensors_[3], grad_bias);
|
||||
Placeholder gradStatePlaceholder = Placeholder(cachedGraph->outputTensors_[4], grad_state);
|
||||
Placeholder gradCellStatePlaceholder = Placeholder(cachedGraph->outputTensors_[5], grad_cell_state);
|
||||
Tensor output_out = at::empty_like(input);
|
||||
Tensor grad_state_out = at::empty_like(hx[0]);
|
||||
Tensor grad_cell_state_out = at::empty_like(hx[1]);
|
||||
|
||||
std::vector<Tensor> grad_hx = {grad_state, grad_cell_state};
|
||||
|
||||
std::vector<Tensor> grad_hx = {grad_state_out, grad_cell_state_out};
|
||||
|
||||
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *results = [[[NSMutableDictionary alloc] init] autorelease];
|
||||
NSMutableArray<MPSGraphTensor*> *gradOutputArray = cachedGraph->gradOutput_;
|
||||
NSMutableArray<MPSGraphTensor*> *gradRecWeightsArray = cachedGraph->gradRecWeights_;
|
||||
NSMutableArray<MPSGraphTensor*> *gradWeightsArray = cachedGraph->gradWeights_;
|
||||
NSMutableArray<MPSGraphTensor*> *gradBiasArray = cachedGraph->gradBias_;
|
||||
NSMutableArray<MPSGraphTensor*> *gradStateArray = cachedGraph->gradState_;
|
||||
NSMutableArray<MPSGraphTensor*> *gradCellStateArray = cachedGraph->gradCellState_;
|
||||
Placeholder gradOutPlaceholder;
|
||||
MPSGraphTensor* gradOutput = cachedGraph->gradOutput_;
|
||||
MPSGraphTensor* gradState = cachedGraph->gradState_;
|
||||
MPSGraphTensor* gradCellState = cachedGraph->gradCellState_;
|
||||
|
||||
Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out);
|
||||
Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out);
|
||||
Placeholder outputPlaceholder = Placeholder(gradOutput, output_out);
|
||||
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];
|
||||
|
||||
Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder;
|
||||
|
||||
std::vector<Tensor> weights;
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
Tensor output = at::empty_like(input);
|
||||
Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]);
|
||||
Tensor grad_weights = at::empty_like(kernel_weights[i]);
|
||||
Tensor grad_bias = at::empty_like(biases[i]);
|
||||
Tensor grad_state = at::empty_like(hx[0]);
|
||||
Tensor grad_cell_state = at::empty_like(hx[1]);
|
||||
Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
|
||||
weights.push_back(grad_weights);
|
||||
weights.push_back(grad_rec_weights);
|
||||
weights.push_back(grad_bias);
|
||||
weights.push_back(grad_bias);
|
||||
gradOutPlaceholder = Placeholder([gradOutputArray objectAtIndex:i], output);
|
||||
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:i], grad_rec_weights);
|
||||
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:i], grad_weights);
|
||||
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex:i], grad_bias);
|
||||
gradStatePlaceholder = Placeholder([gradStateArray objectAtIndex:i], grad_state);
|
||||
gradCellStatePlaceholder = Placeholder([gradCellStateArray objectAtIndex:i], grad_cell_state);
|
||||
|
||||
[results setObject:gradOutPlaceholder.getMPSGraphTensorData() forKey:gradOutPlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
|
||||
if(has_biases) {
|
||||
weights.push_back(grad_bias);
|
||||
weights.push_back(grad_bias);
|
||||
}
|
||||
|
||||
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights);
|
||||
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights);
|
||||
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);
|
||||
|
||||
[results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
|
||||
[results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
|
||||
}
|
||||
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output, grad_hx, weights);
|
||||
return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output_out, grad_hx, weights);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,9 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
     indices.copy_(cpu_indices);
     return;
   }
-  TORCH_WARN_ONCE(self.scalar_type() != ScalarType::Long, "MPS: no support for int64 min/max ops, casting it to int32");
+  if (self.scalar_type() == ScalarType::Long) {
+    TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
+  }

   MPSStream* stream = getCurrentMPSStream();
   struct CachedGraph : public MPSCachedGraph {
|
||||
|
||||
@ -75,15 +75,20 @@ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
|
||||
return inputTensor;
|
||||
}
|
||||
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
return [mpsGraph selectWithPredicateTensor:predicateTensor
|
||||
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
|
||||
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
|
||||
name:nil];
|
||||
if(!is_macos_13_or_newer()) {
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
|
||||
dataType:inputTensor.dataType];
|
||||
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:zeroTensor
|
||||
name:nil];
|
||||
return [mpsGraph selectWithPredicateTensor:predicateTensor
|
||||
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
|
||||
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
|
||||
name:nil];
|
||||
} else {
|
||||
return [mpsGraph truncateWithTensor:inputTensor
|
||||
name:nil];
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mps
|
||||
|
||||
@ -26,6 +26,11 @@ void upsample_out_template(const Tensor& input,
|
||||
} else {
|
||||
native::upsample_2d_common_check(input.sizes(), output_size);
|
||||
}
|
||||
Tensor out;
|
||||
if (!output.is_contiguous()) {
|
||||
out = at::empty_like(output, MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
bool centerResults = false;
|
||||
MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
|
||||
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
|
||||
@ -199,7 +204,7 @@ void upsample_out_template(const Tensor& input,
|
||||
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
|
||||
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
|
||||
@ -209,6 +214,10 @@ void upsample_out_template(const Tensor& input,
|
||||
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
|
||||
};
|
||||
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
|
||||
|
||||
if (out.has_storage()) {
|
||||
output.copy_(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -424,22 +424,54 @@ MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTen
|
||||
}
|
||||
|
||||
static
|
||||
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape) {
|
||||
std::vector<int64_t> getViewShape(const Tensor& src, MPSShape *mpsShape, const bool squeeze) {
|
||||
bool hasMPSShape = (mpsShape != nil);
|
||||
std::vector<int64_t> src_view_shape;
|
||||
if (hasMPSShape) {
|
||||
int src_ndim_view = [mpsShape count];
|
||||
src_view_shape.resize(src_ndim_view);
|
||||
for (const auto i : c10::irange(src_ndim_view)) {
|
||||
src_view_shape[i] = [mpsShape[i] intValue];
|
||||
if (squeeze) {
|
||||
for (const auto i : c10::irange(src_ndim_view)) {
|
||||
if ([mpsShape[i] intValue] == 1)
|
||||
continue;
|
||||
src_view_shape.emplace_back([mpsShape[i] intValue]);
|
||||
}
|
||||
} else {
|
||||
src_view_shape.resize(src_ndim_view);
|
||||
for (const auto i : c10::irange(src_ndim_view)) {
|
||||
src_view_shape[i] = [mpsShape[i] intValue];
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
src_view_shape = src.sizes().vec();
|
||||
if (squeeze) {
|
||||
IntArrayRef src_shape = src.sizes();
|
||||
size_t src_ndim_view = src_shape.size();
|
||||
for (const auto i : c10::irange(src_ndim_view)) {
|
||||
if (src_shape[i] == 1)
|
||||
continue;
|
||||
src_view_shape.emplace_back(src_shape[i]);
|
||||
}
|
||||
} else {
|
||||
src_view_shape = src.sizes().vec();
|
||||
}
|
||||
}
|
||||
|
||||
return src_view_shape;
|
||||
}
|
||||
|
||||
|
||||
std::vector<int64_t> getSqueezedBaseShape(const Tensor& src, IntArrayRef shape) {
|
||||
std::vector<int64_t> src_base_shape;
|
||||
for (const auto i : c10::irange(shape.size())) {
|
||||
if (shape[i] == 1)
|
||||
continue;
|
||||
src_base_shape.emplace_back(shape[i]);
|
||||
}
|
||||
|
||||
return src_base_shape;
|
||||
}
|
||||
|
||||
|
||||
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
|
||||
if (!src.is_contiguous()) {
|
||||
return false;
|
||||
@ -447,57 +479,79 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
|
||||
|
||||
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
|
||||
size_t src_ndim_base = src_base_shape.size();
|
||||
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
|
||||
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
|
||||
size_t src_ndim_view = src_view_shape.size();
|
||||
|
||||
if (src_ndim_base != src_ndim_view) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto i: c10::irange(src_ndim_base)) {
|
||||
if (src_view_shape[i] > src_base_shape[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (src_view_shape[i] > src_base_shape[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType) {
|
||||
IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data());
|
||||
int src_ndim_base = src_base_shape.size();
|
||||
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape);
|
||||
int src_ndim_view = src_view_shape.size();
|
||||
|
||||
TORCH_CHECK(src_ndim_base == src_ndim_view);
|
||||
size_t src_ndim_base = src_base_shape.size();
|
||||
std::vector<int64_t> src_view_shape = getViewShape(src, mpsShape, false);
|
||||
size_t src_ndim_view = src_view_shape.size();
|
||||
|
||||
MPSNDArray *srcTensorNDArrayView = nil;
|
||||
MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil;
|
||||
MPSNDArray *srcTensorNDArray = nil;
|
||||
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
|
||||
int64_t base_idx = 0;
|
||||
|
||||
std::vector<int64_t> src_base_shape_vec;
|
||||
|
||||
if (src_ndim_view != src_ndim_base) {
|
||||
src_base_shape_vec.reserve(src_ndim_view);
|
||||
for (const auto i : c10::irange(src_ndim_view)) {
|
||||
if (src_view_shape[i] == 1 && src_base_shape[base_idx] != 1) {
|
||||
src_base_shape_vec.emplace_back(1);
|
||||
} else {
|
||||
src_base_shape_vec.emplace_back(src_base_shape[base_idx]);
|
||||
if (base_idx < src_ndim_base - 1)
|
||||
base_idx += 1;
|
||||
}
|
||||
}
|
||||
src_base_shape = IntArrayRef(src_base_shape_vec);
|
||||
src_ndim_base = src_base_shape.size();
|
||||
}
|
||||
|
||||
srcTensorNDArray = ndArrayFromTensor(src, getMPSShape(src_base_shape), mpsDataType);
|
||||
srcTensorNDArrayDesc = srcTensorNDArray.descriptor;
|
||||
|
||||
int firstDimToSlice = 0;
|
||||
size_t firstDimToSlice = 0;
|
||||
while (src_base_shape[firstDimToSlice] == src_view_shape[firstDimToSlice]) {
|
||||
firstDimToSlice++;
|
||||
}
|
||||
|
||||
int view_numel = 1;
|
||||
int64_t view_numel = 1;
|
||||
for (const auto i : c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
|
||||
view_numel *= src_base_shape[i];
|
||||
}
|
||||
|
||||
int sliceOffset = src.storage_offset() / view_numel;
|
||||
// There are cases where both dimensions of a view can shrink
|
||||
// E.g: x = torch.randn((3,6))[1, 1:3]
|
||||
int nextSliceOffset = src.storage_offset() % view_numel;
|
||||
int64_t sliceOffset = src.storage_offset() / view_numel;
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
|
||||
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
|
||||
if (nextSliceOffset) {
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 2 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(nextSliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice+1])}];
|
||||
// Slice any remaining dimensions
|
||||
for (const auto crtSliceOffset: c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
|
||||
if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) {
|
||||
if (crtSliceOffset == src_base_shape.size() - 1) {
|
||||
sliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
|
||||
} else {
|
||||
sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
|
||||
}
|
||||
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset
|
||||
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
|
||||
}
|
||||
}
|
||||
|
||||
srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer
|
||||
descriptor:srcTensorNDArrayDesc
|
||||
aliasing:MPSAliasingStrategyShallAlias];
|
||||
@ -696,7 +750,7 @@ const std::string& getGatherScatterScalarType(const Tensor& t) {
|
||||
{c10::ScalarType::Int, "int"},
|
||||
{c10::ScalarType::Short, "short"},
|
||||
{c10::ScalarType::Char, "char"},
|
||||
{c10::ScalarType::Byte, "char"},
|
||||
{c10::ScalarType::Byte, "uchar"},
|
||||
{c10::ScalarType::Bool, "bool"},
|
||||
};
|
||||
|
||||
|
||||
@ -3567,19 +3567,14 @@
|
||||
- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
|
||||
- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
|
||||
# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
|
||||
# native_functions.yaml
|
||||
# https://github.com/pytorch/pytorch/issues/77394
|
||||
- func: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
dispatch:
|
||||
MPS: _mps_max_pool2d
|
||||
autogen: _mps_max_pool2d.out
|
||||
CompositeImplicitAutograd: max_pool2d
|
||||
MPS: mps_max_pool2d
|
||||
|
||||
- func: mps_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
dispatch:
|
||||
MPS: mps_max_pool2d_backward
|
||||
autogen: mps_max_pool2d_backward.out
|
||||
autogen: max_pool2d_backward.out
|
||||
|
||||
- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
dispatch:
|
||||
@ -7188,12 +7183,12 @@
|
||||
|
||||
# MPS LSTM implementation
|
||||
|
||||
- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
|
||||
- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
MPS: _lstm_mps
|
||||
autogen: _lstm_mps.out
|
||||
|
||||
- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
|
||||
- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
|
||||
dispatch:
|
||||
MPS: lstm_mps_backward
|
||||
autogen: lstm_mps_backward.out
|
||||
|
||||
@ -630,6 +630,7 @@ macro(cuda_unset_include_and_libraries)
|
||||
unset(CUDA_cublas_LIBRARY CACHE)
|
||||
unset(CUDA_cublas_device_LIBRARY CACHE)
|
||||
unset(CUDA_cublasemu_LIBRARY CACHE)
|
||||
unset(CUDA_cublasLt_LIBRARY CACHE)
|
||||
unset(CUDA_cufft_LIBRARY CACHE)
|
||||
unset(CUDA_cufftemu_LIBRARY CACHE)
|
||||
unset(CUDA_cupti_LIBRARY CACHE)
|
||||
@ -963,6 +964,7 @@ endif()
|
||||
|
||||
find_cuda_helper_libs(cufft)
|
||||
find_cuda_helper_libs(cublas)
|
||||
find_cuda_helper_libs(cublasLt)
|
||||
# cusparse showed up in version 3.2
|
||||
find_cuda_helper_libs(cusparse)
|
||||
find_cuda_helper_libs(curand)
|
||||
@ -993,7 +995,7 @@ if (CUDA_BUILD_EMULATION)
|
||||
set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublasemu_LIBRARY})
|
||||
else()
|
||||
set(CUDA_CUFFT_LIBRARIES ${CUDA_cufft_LIBRARY})
|
||||
set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY})
|
||||
set(CUDA_CUBLAS_LIBRARIES ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY})
|
||||
endif()
|
||||
|
||||
########################
|
||||
@ -1962,7 +1964,7 @@ macro(CUDA_ADD_CUBLAS_TO_TARGET target)
|
||||
if (CUDA_BUILD_EMULATION)
|
||||
target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublasemu_LIBRARY})
|
||||
else()
|
||||
target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY})
|
||||
target_link_libraries(${target} ${CUDA_LINK_LIBRARIES_KEYWORD} ${CUDA_cublas_LIBRARY} ${CUDA_cublas_device_LIBRARY} ${CUDA_cublasLt_LIBRARY})
|
||||
endif()
|
||||
endmacro()
|
||||
|
||||
|
||||
@ -351,7 +351,7 @@ master_doc = 'index'
|
||||
|
||||
# General information about the project.
|
||||
project = 'PyTorch'
|
||||
copyright = '2022, PyTorch Contributors'
|
||||
copyright = '2023, PyTorch Contributors'
|
||||
author = 'PyTorch Contributors'
|
||||
torch_version = str(torch.__version__)
|
||||
|
||||
|
||||
@ -6,13 +6,12 @@ significant speedups the newer your GPU is.
|
||||
|
||||
.. code:: python
|
||||
|
||||
from torch._dynamo import optimize
|
||||
import torch
|
||||
def fn(x, y):
|
||||
a = torch.cos(x).cuda()
|
||||
b = torch.sin(y).cuda()
|
||||
return a + b
|
||||
new_fn = optimize("inductor")(fn)
|
||||
new_fn = torch.compile(fn, backend="inductor")
|
||||
input_tensor = torch.randn(10000).to(device="cuda:0")
|
||||
a = new_fn(input_tensor, input_tensor)
|
||||
|
||||
@ -54,7 +53,7 @@ with the actual generated kernel being
|
||||
tmp2 = tl.sin(tmp1)
|
||||
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
|
||||
|
||||
And you can verify that fusing the two ``sins`` did actually occur
|
||||
And you can verify that fusing the two ``sin`` did actually occur
|
||||
because the two ``sin`` operations occur within a single Triton kernel
|
||||
and the temporary variables are held in registers with very fast access.
|
||||
|
||||
@ -69,13 +68,12 @@ hub.
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
|
||||
opt_model = dynamo.optimize("inductor")(model)
|
||||
opt_model = torch.compile(model, backend="inductor")
|
||||
model(torch.randn(1,3,64,64))
|
||||
|
||||
And that is not the only available backend, you can run in a REPL
|
||||
``dynamo.list_backends()`` to see all the available backends. Try out the
|
||||
``torch._dynamo.list_backends()`` to see all the available backends. Try out the
|
||||
``cudagraphs`` or ``nvfuser`` next as inspiration.
|
||||
|
||||
Let’s do something a bit more interesting now, our community frequently
|
||||
@ -92,11 +90,10 @@ HuggingFace hub and optimize it:
|
||||
|
||||
import torch
|
||||
from transformers import BertTokenizer, BertModel
|
||||
import torch._dynamo as dynamo
|
||||
# Copy pasted from here https://huggingface.co/bert-base-uncased
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
|
||||
model = dynamo.optimize("inductor")(model) # This is the only line of code that we changed
|
||||
model = torch.compile(model, backend="inductor") # This is the only line of code that we changed
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
|
||||
output = model(**encoded_input)
|
||||
@ -116,7 +113,7 @@ Similarly let’s try out a TIMM example
|
||||
import torch._dynamo as dynamo
|
||||
import torch
|
||||
model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
|
||||
opt_model = dynamo.optimize("inductor")(model)
|
||||
opt_model = torch.compile(model, backend="inductor")
|
||||
opt_model(torch.randn(64,3,7,7))
|
||||
|
||||
Our goal with Dynamo and inductor is to build the highest coverage ML compiler
|
||||
@ -132,16 +129,16 @@ or ``torch._dynamo.list_backends()`` each of which with its optional dependencie
|
||||
Some of the most commonly used backends include:
|
||||
|
||||
**Training & inference backends**:
|
||||
* ``dynamo.optimize("inductor")`` - Uses ``TorchInductor`` backend. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
|
||||
* ``dynamo.optimize("aot_ts_nvfuser")`` - nvFuser with AotAutograd/TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
|
||||
* ``dynamo.optimize("nvprims_nvfuser")`` - nvFuser with PrimTorch. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
|
||||
* ``dynamo.optimize("cudagraphs")`` - cudagraphs with AotAutograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
|
||||
* ``torch.compile(m, backend="inductor")`` - Uses ``TorchInductor`` backend. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
|
||||
* ``torch.compile(m, backend="aot_ts_nvfuser")`` - nvFuser with AotAutograd/TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
|
||||
* ``torch.compile(m, backend=""nvprims_nvfuser")`` - nvFuser with PrimTorch. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
|
||||
* ``torch.compile(m, backend="cudagraphs")`` - cudagraphs with AotAutograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
|
||||
|
||||
**Inference-only backends**:
|
||||
* ``dynamo.optimize("onnxrt")`` - Uses ONNXRT for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
|
||||
* ``dynamo.optimize("tensorrt")`` - Uses ONNXRT to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
|
||||
* ``dynamo.optimize("ipex")`` - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
|
||||
* ``dynamo.optimize("tvm")`` - Uses Apach TVM for inference optimizations. `Read more <https://tvm.apache.org/>`__
|
||||
* ``torch.compile(m, backend="onnxrt")`` - Uses ONNXRT for inference on CPU/GPU. `Read more <https://onnxruntime.ai/>`__
|
||||
* ``torch.compile(m, backend="tensorrt")`` - Uses ONNXRT to run TensorRT for inference optimizations. `Read more <https://github.com/onnx/onnx-tensorrt>`__
|
||||
* ``torch.compile(m, backend="ipex")`` - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
|
||||
* ``torch.compile(m, backend="tvm")`` - Uses Apach TVM for inference optimizations. `Read more <https://tvm.apache.org/>`__
|
||||
|
||||
Why do you need another way of optimizing PyTorch code?
|
||||
-------------------------------------------------------
|
||||
|
||||
@ -15,7 +15,7 @@ Where a complete example looks like this:
|
||||
|
||||
from typing import List
|
||||
import torch
|
||||
import torchdynamo
|
||||
from torch import _dynamo as torchdynamo
|
||||
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||
print("my_compiler() called with FX graph:")
|
||||
gm.graph.print_tabular()
|
||||
|
||||
@ -14,7 +14,7 @@ worlds — usability and performance.
|
||||
|
||||
TorchDynamo makes it easy to experiment with different compiler
|
||||
backends to make PyTorch code faster with a single line decorator
|
||||
``torch._dynamo.optimize()``
|
||||
``torch._dynamo.optimize()`` which is wrapped for convenience by ``torch.compile()``
|
||||
|
||||
.. image:: ../_static/img/dynamo/TorchDynamo.png
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ TorchDynamo dependencies (for CUDA 11.7):
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
|
||||
pip3 install numpy --pre torch --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
|
||||
|
||||
CPU requirements
|
||||
~~~~~~~~~~~~~~~~
|
||||
@ -41,16 +41,6 @@ To install, run the following command:
|
||||
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
|
||||
|
||||
Install from Local Source
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Alternatively, you can build PyTorch from `source
|
||||
<https://github.com/pytorch/pytorch#from-source>`__, which has TorchDynamo
|
||||
included.
|
||||
|
||||
To install GPU TorchDynamo dependencies, run ``make triton`` in the
|
||||
PyTorch repo root directory.
|
||||
|
||||
Verify Installation
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@ -129,6 +129,49 @@ Algorithms
|
||||
Rprop
|
||||
SGD
|
||||
|
||||
Many of our algorithms have various implementations optimized for performance,
|
||||
readability and/or generality, so we attempt to default to the generally fastest
|
||||
implementation for the current device if no particular implementation has been
|
||||
specified by the user.
|
||||
|
||||
We have 3 major categories of implementations: for-loop, foreach (multi-tensor), and
|
||||
fused. The most straightforward implementations are for-loops over the parameters with
|
||||
big chunks of computation. For-looping is usually slower than our foreach
|
||||
implementations, which combine parameters into a multi-tensor and run the big chunks
|
||||
of computation all at once, thereby saving many sequential kernel calls. A few of our
|
||||
optimizers have even faster fused implementations, which fuse the big chunks of
|
||||
computation into one kernel. We can think of foreach implementations as fusing
|
||||
horizontally and fused implementations as fusing vertically on top of that.
|
||||
|
||||
In general, the performance ordering of the 3 implementations is fused > foreach > for-loop.
|
||||
So when applicable, we default to foreach over for-loop. Applicable means the foreach
|
||||
implementation is available, the user has not specified any implementation-specific kwargs
|
||||
(e.g., fused, foreach, differentiable), and all tensors are native and on CUDA. Note that
|
||||
while fused should be even faster than foreach, the implementations are newer and we would
|
||||
like to give them more bake-in time before flipping the switch everywhere. You are welcome
|
||||
to try them out though!
|
||||
|
||||
Below is a table showing the available and default implementations of each algorithm:
|
||||
|
||||
.. csv-table::
|
||||
:header: "Algorithm", "Default", "Has foreach?", "Has fused?"
|
||||
:widths: 25, 25, 25, 25
|
||||
:delim: ;
|
||||
|
||||
:class:`Adadelta`;foreach;yes;no
|
||||
:class:`Adagrad`;foreach;yes;no
|
||||
:class:`Adam`;foreach;yes;yes
|
||||
:class:`AdamW`;foreach;yes;yes
|
||||
:class:`SparseAdam`;for-loop;no;no
|
||||
:class:`Adamax`;foreach;yes;no
|
||||
:class:`ASGD`;foreach;yes;no
|
||||
:class:`LBFGS`;for-loop;no;no
|
||||
:class:`NAdam`;foreach;yes;no
|
||||
:class:`RAdam`;foreach;yes;no
|
||||
:class:`RMSprop`;foreach;yes;no
|
||||
:class:`Rprop`;foreach;yes;no
|
||||
:class:`SGD`;foreach;yes;no
|
||||
|
||||
How to adjust learning rate
|
||||
---------------------------
|
||||
|
||||
|
||||
7
setup.py
7
setup.py
@@ -1024,17 +1024,12 @@ def main():
         'typing-extensions',
         'sympy',
         'networkx',
+        'jinja2',
     ]
 
     extras_require = {
         'opt-einsum': ['opt-einsum>=3.3']
     }
-    if platform.system() == 'Linux':
-        triton_pin_file = os.path.join(cwd, ".github", "ci_commit_pins", "triton.txt")
-        if os.path.exists(triton_pin_file):
-            with open(triton_pin_file) as f:
-                triton_pin = f.read().strip()
-                extras_require['dynamo'] = ['pytorch-triton==2.0.0+' + triton_pin[:10], 'jinja2']
 
     # Parse the command line and check the arguments before we proceed with
     # building deps and setup. We need to set values so `--help` works.
|
||||
|
||||
@ -504,7 +504,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
deterministic=True,
|
||||
)
|
||||
optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
|
||||
optim = torch.optim.Adam(fsdp_model.parameters(), foreach=False, lr=LR)
|
||||
fsdp_kwargs["use_orig_params"] = True
|
||||
fsdp_model_orig_params = TransformerWithSharedParams.init(
|
||||
self.process_group,
|
||||
@ -513,7 +513,9 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
deterministic=True,
|
||||
)
|
||||
optim_orig_params = torch.optim.Adam(fsdp_model_orig_params.parameters(), lr=LR)
|
||||
optim_orig_params = torch.optim.Adam(
|
||||
fsdp_model_orig_params.parameters(), foreach=False, lr=LR
|
||||
)
|
||||
return fsdp_model, optim, fsdp_model_orig_params, optim_orig_params
|
||||
|
||||
def _check_fsdp_parameter_parity(self, fsdp1: FSDP, fsdp2: FSDP) -> None:
|
||||
|
||||
@ -60,6 +60,11 @@ unittest.expectedFailure(
|
||||
# Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
)
|
||||
|
||||
unittest.expectedFailure(
|
||||
DynamicShapesMiscTests.test_parsing_sdpa_dynamic_shapes
|
||||
# Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
)
|
||||
|
||||
|
||||
# DynamicShapesSubGraphTests
|
||||
unittest.expectedFailure(
|
||||
|
||||
@ -3145,6 +3145,53 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(compiled.device.index, 0)
|
||||
self.assertEqual(compiled.dtype, torch.float16)
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
|
||||
"Can't run fused SDPA on this platform",
|
||||
)
|
||||
def test_parsing_sdpa(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, query, key, value):
|
||||
out = F.scaled_dot_product_attention(query, key, value, None, 0, True)
|
||||
out = F.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
)
|
||||
out = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=None,
|
||||
dropout_p=0,
|
||||
is_causal=True,
|
||||
)
|
||||
out = F.scaled_dot_product_attention(
|
||||
query, key, value, None, dropout_p=0, is_causal=True
|
||||
)
|
||||
return out
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.float16
|
||||
seq_len_q = 1
|
||||
seq_len_k = 1
|
||||
head_dim = 8
|
||||
query = torch.ones(
|
||||
1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
key = torch.ones(
|
||||
1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
value = torch.ones(
|
||||
1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
module = MyModule()
|
||||
opt_mod = torch._dynamo.optimize("inductor")(module)
|
||||
opt_mod(query, key, value)
|
||||
|
||||
def test_autocast_cpu(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
||||
@ -377,8 +377,6 @@ aten::_mps_convolution
|
||||
aten::_mps_convolution.out
|
||||
aten::_mps_convolution_transpose
|
||||
aten::_mps_convolution_transpose.out
|
||||
aten::_mps_max_pool2d
|
||||
aten::_mps_max_pool2d.out
|
||||
aten::_native_batch_norm_legit.no_stats_out
|
||||
aten::_native_batch_norm_legit.out
|
||||
aten::_native_decoder_only_multi_head_attention
|
||||
@ -857,6 +855,8 @@ aten::max
|
||||
aten::max.dim
|
||||
aten::max.dim_max
|
||||
aten::max.unary_out
|
||||
aten::max_pool2d_backward
|
||||
aten::max_pool2d_backward.out
|
||||
aten::max_pool2d_with_indices
|
||||
aten::max_pool2d_with_indices.out
|
||||
aten::max_pool2d_with_indices_backward
|
||||
@ -930,8 +930,6 @@ aten::mps_convolution_backward
|
||||
aten::mps_convolution_backward.out
|
||||
aten::mps_convolution_transpose_backward
|
||||
aten::mps_convolution_transpose_backward.out
|
||||
aten::mps_max_pool2d_backward
|
||||
aten::mps_max_pool2d_backward.out
|
||||
aten::multi_margin_loss
|
||||
aten::multi_margin_loss.out
|
||||
aten::multi_margin_loss_backward
|
||||
|
||||
@ -150,6 +150,10 @@ ALLOW_LIST = [
|
||||
("aten::sum.SymInt", datetime.date(2022, 11, 30)),
|
||||
("aten::mps_linear", datetime.date(9999, 1, 1)),
|
||||
("aten::_mps_linear", datetime.date(9999, 1, 1)),
|
||||
("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
|
||||
("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
|
||||
("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
|
||||
("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
|
||||
("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
|
||||
("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
|
||||
("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),
|
||||
@ -269,7 +273,10 @@ ALLOW_LIST = [
|
||||
("aten::dsplit.int", datetime.date(2022, 9, 1)),
|
||||
("aten::hsplit.array", datetime.date(2022, 9, 1)),
|
||||
("aten::hsplit.int", datetime.date(2022, 9, 1)),
|
||||
("aten::lstm_mps_backward.out", datetime.date(2022, 9, 1)),
|
||||
("aten::lstm_mps_backward.out", datetime.date(2023, 9, 1)),
|
||||
("aten::lstm_mps_backward", datetime.date(2023, 9, 1)),
|
||||
("aten::_lstm_mps.out", datetime.date(2023, 9, 1)),
|
||||
("aten::_lstm_mps", datetime.date(2023, 9, 1)),
|
||||
("aten::miopen_rnn_backward.out", datetime.date(2022, 9, 1)),
|
||||
("aten::quantize_per_tensor.tensors_out", datetime.date(2022, 9, 1)),
|
||||
("aten::split", datetime.date(2022, 9, 1)),
|
||||
|
||||
@ -989,6 +989,34 @@ def forward(self, primals_1, primals_2):
|
||||
self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True)
|
||||
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
|
||||
def test_mem_leak_from_save_for_bw(self):
|
||||
# See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990
|
||||
# Note [Detaching saved tensors in AOTAutograd]
|
||||
# This program creates a ref-cycle. Long term, we should fix this ref cycle
|
||||
# (since it can arise, naturally albeit rarely, from uses of autograd.Function).
|
||||
# But AOTAutograd makes it more likely to show up from tracing user programs,
|
||||
# so we deal with it by manually detaching the tensors that we save for backward.
|
||||
# This is completely wrong and would give wrong results if we were to do double backward.
|
||||
# Fortunately today, double backward is explicitly banned in AOTAutograd.
|
||||
def f(a, b):
|
||||
add = a + a
|
||||
split = torch.functional.split(add, [4, 4], dim=1)
|
||||
getitem_2 = split[1]
|
||||
unsqueeze = getitem_2.unsqueeze(-1)
|
||||
mul = unsqueeze * b
|
||||
return (getitem_2, mul)
|
||||
|
||||
f_compiled = aot_function(f, nop)
|
||||
inps = [
|
||||
torch.ones(8, 8, device='cuda', requires_grad=True),
|
||||
torch.ones(1, 4, 1, device='cuda', requires_grad=True),
|
||||
]
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
f_compiled(*inps)
|
||||
mem_after = torch.cuda.memory_allocated()
|
||||
self.assertTrue(mem_after == mem_before)
|
||||
|
||||
@patch("functorch.compile.config.use_fake_tensor", True)
|
||||
def test_output_aliases_multiple_inputs_get_correct_one(self):
|
||||
# a and b are aliased, but have different shapes
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import contextlib
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import functorch
|
||||
@ -10,6 +11,7 @@ from torch._dynamo.backends.registry import register_backend
|
||||
from torch._inductor import metrics
|
||||
from torch._inductor.compile_fx import compile_fx, count_bytes_inner
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
TEST_WITH_ROCM,
|
||||
TestCase as TorchTestCase,
|
||||
)
|
||||
@ -23,9 +25,17 @@ def count_bytes_inductor(gm, example_inputs):
|
||||
return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)
|
||||
|
||||
|
||||
@torch._dynamo.optimize("count_bytes_inductor")
|
||||
def f(x):
|
||||
return torch.cat([x, x.cos()])
|
||||
# TODO remove version check once dynamo supports 3.11
|
||||
if sys.version_info < (3, 11) and not IS_WINDOWS:
|
||||
|
||||
@torch._dynamo.optimize("count_bytes_inductor")
|
||||
def f(x):
|
||||
return torch.cat([x, x.cos()])
|
||||
|
||||
else:
|
||||
|
||||
def f(x):
|
||||
return torch.cat([x, x.cos()])
|
||||
|
||||
|
||||
def count_numel(f, *args):
|
||||
|
||||
@ -62,11 +62,30 @@ class TestSelectAlgorithm(TestCase):
|
||||
def foo(input, weight, bias):
|
||||
return torch.addmm(bias, input, weight)
|
||||
|
||||
foo(
|
||||
inps = (
|
||||
torch.randn(20, 33, device="cuda"),
|
||||
torch.randn(33, 16, device="cuda"),
|
||||
torch.randn(20, 16, device="cuda"),
|
||||
)
|
||||
|
||||
foo(*inps)
|
||||
# Autotuning checks correctness of each version
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
@patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
|
||||
@patches
|
||||
def test_addmm_fp16(self):
|
||||
@torch.compile
|
||||
def foo(input, weight, bias):
|
||||
return torch.addmm(bias, input, weight)
|
||||
|
||||
inps = (
|
||||
torch.randn(2, 320, device="cuda", dtype=torch.half),
|
||||
torch.randn(320, 320, device="cuda", dtype=torch.half).t(),
|
||||
torch.empty(320, device="cuda", dtype=torch.half),
|
||||
)
|
||||
|
||||
foo(*inps)
|
||||
# Autotuning checks correctness of each version
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
|
||||
|
||||
@ -4127,18 +4127,38 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5])))
|
||||
|
||||
@requires_decomp(aten.nll_loss_forward)
|
||||
def test_nll_loss_forward(self):
|
||||
def fn(a, b):
|
||||
return aten.nll_loss_forward(a, b, None, 1, -100)
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
(
|
||||
torch.randn([5, 5]),
|
||||
torch.zeros([5], dtype=torch.int64),
|
||||
),
|
||||
labels = (
|
||||
torch.zeros([5], dtype=torch.int64),
|
||||
torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
|
||||
)
|
||||
inps = (torch.randn(5, 5), torch.randn(5, 5))
|
||||
for a, b in zip(inps, labels):
|
||||
self.common(
|
||||
fn,
|
||||
(a, b),
|
||||
)
|
||||
|
||||
def test_nll_loss_backward(self):
|
||||
def fn(a, b, c):
|
||||
return aten.nll_loss_backward(
|
||||
a, b, c, None, 1, -100, torch.tensor(1.0, device=self.device)
|
||||
)
|
||||
|
||||
labels = (
|
||||
torch.zeros([5], dtype=torch.int64),
|
||||
torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
|
||||
)
|
||||
inps = (torch.randn(5, 5), torch.randn(5, 5))
|
||||
grad_outs = (torch.randn(()), torch.randn(()))
|
||||
for a, b, c in zip(grad_outs, inps, labels):
|
||||
self.common(
|
||||
fn,
|
||||
(a, b, c),
|
||||
)
|
||||
|
||||
def test_isinf(self):
|
||||
def fn(x):
|
||||
@ -5613,6 +5633,22 @@ class CommonTemplate:
|
||||
eager_out = eager_mod(*eager_args)
|
||||
self.assertEqual(inductor_out, eager_out)
|
||||
|
||||
def test_where_with_logical_op(self):
|
||||
def fn_and(x, y):
|
||||
return torch.where(torch.logical_and(x, y), 1.0, 0.0)
|
||||
|
||||
def fn_or(x, y):
|
||||
return torch.where(torch.logical_or(x, y), 1.0, 0.0)
|
||||
|
||||
self.common(
|
||||
fn_and,
|
||||
(torch.randn(32), torch.randn(32)),
|
||||
)
|
||||
self.common(
|
||||
fn_or,
|
||||
(torch.randn(32), torch.randn(32)),
|
||||
)
|
||||
|
||||
|
||||
test_skips = {
|
||||
"test_alexnet_prefix_dynamic_shapes": ("cuda",),
|
||||
@ -5956,6 +5992,8 @@ if HAS_CPU:
|
||||
"randn",
|
||||
"isnan",
|
||||
"rand",
|
||||
"logical_and",
|
||||
"logical_or",
|
||||
]
|
||||
union = {*cpp_vec_op_list, *diff}
|
||||
self.assertTrue(set(cpp_op_list).issubset(union))
|
||||
|
||||
@ -448,6 +448,7 @@ inductor_all_samples = {
|
||||
"mT",
|
||||
"mH",
|
||||
"rsub",
|
||||
"triu",
|
||||
}
|
||||
|
||||
|
||||
|
||||
491
test/test_mps.py
@ -435,6 +435,27 @@ class TestMPS(TestCaseMPS):
|
||||
helper(0, [1024])
|
||||
helper(0.2, [2, 3])
|
||||
|
||||
def test_fill_storage_offset(self):
|
||||
shape = [2, 10]
|
||||
val = 0.2
|
||||
tensor = torch.ones(shape, device="mps")
|
||||
tensor_mps = tensor[:][1].fill_(val)
|
||||
tensor_0 = torch.ones(shape, device="cpu")
|
||||
tensor_cpu = tensor_0[:][1].fill_(val)
|
||||
|
||||
self.assertEqual(tensor_mps, tensor_cpu)
|
||||
|
||||
shape = [1, 10]
|
||||
val = 0.0
|
||||
tensor = torch.ones(shape, device="mps")
|
||||
val_tensor_mps = torch.tensor(val, device="mps")
|
||||
tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
|
||||
tensor_0 = torch.ones(shape, device="cpu")
|
||||
val_tensor_cpu = torch.tensor(val, device="cpu")
|
||||
tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
|
||||
|
||||
self.assertEqual(tensor_mps, tensor_cpu)
|
||||
|
||||
def test_cdist_large(self, device="mps"):
|
||||
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
|
||||
x = torch.randn(100, 10, device=device)
|
||||
@ -1786,6 +1807,87 @@ class TestMPS(TestCaseMPS):
|
||||
x_cpu = x_cpu + 2
|
||||
self.assertEqual(x, x_cpu)
|
||||
|
||||
def test_reshape_storage_offset(self):
|
||||
# https://github.com/pytorch/pytorch/issues/95883
|
||||
B = 4
|
||||
T = 1
|
||||
|
||||
lin_cpu = nn.Linear(10, 256)
|
||||
lin_mps = nn.Linear(10, 256, device="mps")
|
||||
|
||||
# Use the same weights and bias as the ones from the cpu
|
||||
lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
|
||||
lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
|
||||
|
||||
x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
|
||||
x_cpu = x_mps.detach().clone().cpu().requires_grad_()
|
||||
x_mps = lin_mps(x_mps)
|
||||
x_cpu = lin_cpu(x_cpu)
|
||||
|
||||
self.assertEqual(x_mps.shape, (B, T, 256))
|
||||
self.assertEqual(x_cpu.shape, (B, T, 256))
|
||||
|
||||
cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
|
||||
cls_token_cpu = cls_token_mps.detach().clone().cpu()
|
||||
x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
|
||||
x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
|
||||
|
||||
x_mps = x_mps.transpose(0, 1)
|
||||
x_cpu = x_cpu.transpose(0, 1)
|
||||
|
||||
target_mps = torch.rand_like(x_mps)
|
||||
target_cpu = target_mps.detach().clone().cpu()
|
||||
loss_mps = F.mse_loss(x_mps, target_mps)
|
||||
loss_cpu = F.mse_loss(x_cpu, target_cpu)
|
||||
self.assertEqual(loss_mps, loss_cpu)
|
||||
|
||||
loss_mps.backward()
|
||||
loss_cpu.backward()
|
||||
self.assertEqual(x_mps.grad, x_cpu.grad)
|
||||
|
||||
def test_stack(self):
|
||||
# https://github.com/pytorch/pytorch/issues/87856
|
||||
x_cpu = torch.tensor([[1, 2]])
|
||||
x_mps = x_cpu.detach().clone().to("mps")
|
||||
|
||||
y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
|
||||
y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
|
||||
|
||||
self.assertEqual(y_cpu, y_mps)
|
||||
|
||||
t_mps = torch.tensor([1, 2, 3, 4], device="mps")
|
||||
t_cpu = t_mps.detach().cpu().detach()
|
||||
|
||||
x_mps = t_mps[2:]
|
||||
y_mps = t_mps[:2]
|
||||
|
||||
x_cpu = t_cpu[2:]
|
||||
y_cpu = t_cpu[:2]
|
||||
|
||||
res_mps = torch.stack((y_mps, x_mps), dim=-1)
|
||||
res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
|
||||
|
||||
self.assertEqual(res_mps, res_cpu)
|
||||
|
||||
def test_unsafe_chunk(self):
|
||||
# https://github.com/pytorch/pytorch/issues/91065
|
||||
a = torch.rand(5, dtype=torch.float32, device="cpu")
|
||||
ret = a.unsafe_chunk(4, 0)
|
||||
y = ret[0] * ret[2]
|
||||
a_mps = a.to("mps")
|
||||
ret_mps = a_mps.unsafe_chunk(4, 0)
|
||||
y_mps = ret_mps[0] * ret_mps[2]
|
||||
self.assertEqual(y, y_mps)
|
||||
|
||||
def test_slice_casting(self):
|
||||
# generate random binary numbers
|
||||
cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
|
||||
mps_in = cpu_in.detach().clone().to("mps")
|
||||
# check copy_cast(uint8 -> bool) on tensors with storage offset
|
||||
cpu_out = cpu_in[:, :, 11 : 12, :12].to(torch.bool)
|
||||
mps_out = mps_in[:, :, 11 : 12, :12].to(torch.bool)
|
||||
self.assertEqual(cpu_out, mps_out)
|
||||
|
||||
def test_slice_reshape_contg_view(self):
|
||||
import torch
|
||||
|
||||
@ -1797,6 +1899,72 @@ class TestMPS(TestCaseMPS):
|
||||
|
||||
self.assertEqual(r_mps, r_cpu)
|
||||
|
||||
def test_contiguous_slice_2d(self):
|
||||
def helper(shape):
|
||||
for i in range(0, shape[0]):
|
||||
for j in range(0, shape[1]):
|
||||
t_mps = torch.randn(shape, device="mps")
|
||||
t_cpu = t_mps.detach().clone().cpu()
|
||||
|
||||
y_mps = t_mps[i:, :j]
|
||||
y_cpu = t_cpu[i:, :j]
|
||||
self.assertEqual(y_mps + 1, y_cpu + 1)
|
||||
|
||||
y_mps = t_mps[i:, j]
|
||||
y_cpu = t_cpu[i:, j]
|
||||
self.assertEqual(y_mps + 1, y_cpu + 1)
|
||||
|
||||
y_mps = t_mps[i, :j]
|
||||
y_cpu = t_cpu[i, :j]
|
||||
self.assertEqual(y_mps + 1, y_cpu + 1)
|
||||
|
||||
y_mps = t_mps[:i, :j]
|
||||
y_cpu = t_cpu[:i, :j]
|
||||
self.assertEqual(y_mps + 1, y_cpu + 1)
|
||||
|
||||
y_mps = t_mps[:i, j]
|
||||
y_cpu = t_cpu[:i, j]
|
||||
self.assertEqual(y_mps + 1, y_cpu + 1)
|
||||
|
||||
y_mps = t_mps[:i, j:]
|
||||
y_cpu = t_cpu[:i, j:]
|
||||
self.assertEqual(y_mps + 1, y_cpu + 1)
|
||||
|
||||
l = []
|
||||
for N in range(1, 3):
|
||||
l.append(N)
|
||||
for C in range(1, 3):
|
||||
l.append(C)
|
||||
helper(l)
|
||||
for D in range(1, 3):
|
||||
l.append(D)
|
||||
helper(l)
|
||||
for H in range(1, 3):
|
||||
l.append(H)
|
||||
helper(l)
|
||||
for W in range(1, 3):
|
||||
l.append(W)
|
||||
helper(l)
|
||||
l.pop()
|
||||
l.pop()
|
||||
l.pop()
|
||||
l.pop()
|
||||
l.pop()
|
||||
|
||||
helper([9, 15, 4])
|
||||
helper([9, 3, 2])
|
||||
helper([3, 4, 18, 22])
|
||||
helper([3, 4, 18, 22, 150])
|
||||
|
||||
def test_contiguous_slice_3d(self):
|
||||
x = torch.randn(2, 3, 3, device="mps")
|
||||
x_cpu = x.detach().clone().cpu()
|
||||
x = x[:1]
|
||||
x_cpu = x_cpu[:1]
|
||||
out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
|
||||
out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
|
||||
self.assertEqual(out, out_cpu)
|
||||
|
||||
def test_view_slice(self):
|
||||
# https://github.com/pytorch/pytorch/issues/83995
|
||||
NUM_SAMPLES = 60
|
||||
@ -1890,25 +2058,28 @@ class TestMPS(TestCaseMPS):
|
||||
if operator == "<=":
|
||||
res_mps = x_mps <= y_mps
|
||||
res_cpu = x_cpu <= y_cpu
|
||||
if operator == "<":
|
||||
elif operator == "<":
|
||||
res_mps = x_mps < y_mps
|
||||
res_cpu = x_cpu < y_cpu
|
||||
if operator == ">=":
|
||||
elif operator == ">=":
|
||||
res_mps = x_mps >= y_mps
|
||||
res_cpu = x_cpu >= y_cpu
|
||||
if operator == ">":
|
||||
elif operator == ">":
|
||||
res_mps = x_mps >= y_mps
|
||||
res_cpu = x_cpu >= y_cpu
|
||||
if operator == "==":
|
||||
elif operator == "==":
|
||||
res_mps = x_mps == y_mps
|
||||
res_cpu = x_cpu == y_cpu
|
||||
if operator == "!=":
|
||||
elif operator == "!=":
|
||||
res_mps = x_mps != y_mps
|
||||
res_cpu = x_cpu != y_cpu
|
||||
elif operator == "stack":
|
||||
res_mps = torch.stack((y_mps, x_mps), dim=-1)
|
||||
res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
|
||||
|
||||
self.assertEqual(res_mps, res_cpu)
|
||||
|
||||
for op in ["<=", "<", ">=", ">", "==", "!="]:
|
||||
for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
|
||||
helper(op)
|
||||
|
||||
def test_slice_of_slice(self):
|
||||
@ -2245,10 +2416,9 @@ class TestMPS(TestCaseMPS):
|
||||
# See https://github.com/pytorch/pytorch/issues/84995
|
||||
def test_div_bugs(self):
|
||||
for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
|
||||
if dtype != torch.int64:
|
||||
x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
|
||||
y = torch.div(x, 101, rounding_mode=mode)
|
||||
self.assertEqual(y.sum(), 0)
|
||||
x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
|
||||
y = torch.div(x, 101, rounding_mode=mode)
|
||||
self.assertEqual(y.sum(), 0)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/82663
|
||||
def test_bool_expand(self):
|
||||
@ -3358,6 +3528,26 @@ class TestNLLLoss(TestCaseMPS):
|
||||
|
||||
self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
|
||||
|
||||
def test_log_softmax_large_numbers(self):
|
||||
values = [
|
||||
[10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
|
||||
[-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
|
||||
]
|
||||
cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
|
||||
mps_x = torch.tensor(values, device='mps', requires_grad=True)
|
||||
|
||||
cpu_log_softmax = F.log_softmax(cpu_x, dim=-1)
|
||||
mps_log_softmax = F.log_softmax(mps_x, dim=-1)
|
||||
self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))
|
||||
|
||||
cpu_grad = torch.ones_like(cpu_log_softmax)
|
||||
mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
|
||||
|
||||
cpu_log_softmax.backward(gradient=cpu_grad)
|
||||
mps_log_softmax.backward(gradient=mps_grad)
|
||||
|
||||
self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
|
||||
|
||||
def test_eq(self):
|
||||
values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
|
||||
values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
|
||||
@ -4545,9 +4735,9 @@ class TestNLLLoss(TestCaseMPS):
|
||||
)
|
||||
|
||||
def test_upsample_nearest2d(self):
|
||||
def helper(N, C, H, W):
|
||||
def helper(N, C, H, W, memory_format):
|
||||
inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
|
||||
requires_grad=True).reshape(N, C, H, W)
|
||||
requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
|
||||
inputCPU.retain_grad()
|
||||
inputMPS = inputCPU.detach().to('mps').requires_grad_()
|
||||
|
||||
@ -4573,8 +4763,9 @@ class TestNLLLoss(TestCaseMPS):
|
||||
|
||||
self.assertEqual(inputCPU.grad, inputMPS.grad)
|
||||
|
||||
helper(1, 1, 4, 4)
|
||||
helper(7, 5, 3, 2)
|
||||
for memory_format in [torch.channels_last, torch.contiguous_format]:
|
||||
helper(1, 1, 4, 4, memory_format=memory_format)
|
||||
helper(7, 5, 3, 2, memory_format=memory_format)
|
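The helper above now parametrizes over memory_format; as a hedged standalone refresher (independent of the MPS test), converting a 4-D tensor to channels_last keeps the values but changes the strides:

import torch

x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
x_cl = x.to(memory_format=torch.channels_last)   # NHWC strides, same values
assert torch.equal(x, x_cl)
assert x_cl.is_contiguous(memory_format=torch.channels_last)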
||||
|
||||
def test_upsample_bilinear2d(self):
|
||||
def helper(N, C, H, W):
|
||||
@ -7716,7 +7907,8 @@ class TestConvolutionMPS(TestCaseMPS):
|
||||
def test_conv_backward_1d_channels_last(self):
|
||||
def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
|
||||
# https://github.com/pytorch/pytorch/issues/84511
|
||||
conv_cpu = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups)
|
||||
conv_cpu = torch.nn.Conv1d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
|
||||
conv_mps = torch.nn.Conv1d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
|
||||
conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
|
||||
@ -7756,15 +7948,89 @@ class TestConvolutionMPS(TestCaseMPS):
|
||||
|
||||
def test_conv2d_all_strides_paddings(self):
|
||||
# https://github.com/pytorch/pytorch/issues/83180
|
||||
y_cpu = torch.randn(2, 2, 3, 6)
|
||||
y_gpu = y_cpu.to(device='mps')
|
||||
for strideX in range(1, 4):
|
||||
for strideY in range(1, 4):
|
||||
conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=(strideX, strideY))
|
||||
conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
|
||||
x_cpu = conv_cpu(y_cpu)
|
||||
x_gpu = conv_gpu(y_gpu)
|
||||
self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
|
||||
def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
|
||||
x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
|
||||
x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()
|
||||
|
||||
if permute_data:
|
||||
x_cpu.permute(0, 2, 3, 1)
|
||||
x_mps.permute(0, 2, 3, 1)
|
||||
|
||||
for strideX in range(1, 4):
|
||||
for strideY in range(1, 4):
|
||||
conv_cpu = torch.nn.Conv2d(
|
||||
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
|
||||
conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()
|
||||
|
||||
conv_mps = torch.nn.Conv2d(
|
||||
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
|
||||
conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
|
||||
conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()
|
||||
|
||||
res_cpu = conv_cpu(x_cpu)
|
||||
res_mps = conv_mps(x_mps)
|
||||
self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)
|
||||
|
||||
res_cpu = res_cpu.sum().backward()
|
||||
res_mps = res_mps.sum().backward()
|
||||
self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04)
|
||||
self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
|
||||
self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
|
||||
self.assertEqual(x_cpu.grad, x_mps.grad)
|
||||
|
||||
for mem_format_input in [torch.contiguous_format, torch.channels_last]:
|
||||
for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
|
||||
for permute_data in [True, False]:
|
||||
helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
|
||||
helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
|
||||
helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
|
||||
|
||||
def test_conv_transpose_2d_strided(self):
|
||||
def helper(m_cpu, memory_format):
|
||||
m_mps = copy.deepcopy(m_cpu).requires_grad_()
|
||||
m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
|
||||
m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()
|
||||
|
||||
input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
|
||||
input_mps = input_cpu.detach().clone().to("mps")
|
||||
|
||||
output_cpu = m_cpu(input_cpu)
|
||||
output_mps = m_mps(input_mps)
|
||||
self.assertEqual(output_cpu, output_mps)
|
||||
|
||||
for mem_format_input in [torch.contiguous_format, torch.channels_last]:
|
||||
# With square kernels and equal stride
|
||||
helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input)
|
||||
|
||||
# non-square kernels and unequal stride and with padding
|
||||
helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input)
|
||||
|
||||
def test_conv_transpose_2d_specified_output(self):
|
||||
input_cpu = torch.randn(1, 16, 12, 12)
|
||||
input_mps = input_cpu.detach().clone().to("mps")
|
||||
|
||||
downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
|
||||
downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
|
||||
downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
|
||||
downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
|
||||
|
||||
upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
|
||||
upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
|
||||
upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
|
||||
upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
|
||||
|
||||
h_cpu = downsample_cpu(input_cpu)
|
||||
h_mps = downsample_mps(input_mps)
|
||||
self.assertEqual(h_cpu, h_mps)
|
||||
|
||||
size_cpu = h_cpu.size()
|
||||
size_mps = h_mps.size()
|
||||
self.assertEqual(size_cpu, size_mps)
|
||||
|
||||
output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
|
||||
output_mps = upsample_mps(h_mps, output_size=input_mps.size())
|
||||
self.assertEqual(output_cpu, output_mps)
|
||||
self.assertEqual(output_cpu.size(), output_mps.size())
|
||||
|
||||
def test_conv2d_single_stride(self):
|
||||
y_cpu = torch.randn(2, 2, 3, 6)
|
||||
@ -8822,64 +9088,148 @@ class TestAdvancedIndexing(TestCaseMPS):
|
||||
|
||||
class TestRNNMPS(TestCaseMPS):
|
||||
def test_lstm_1(self, device="mps", dtype=torch.float32):
|
||||
for layers in [1] if product_version < 13.0 else [1, 2, 5]:
|
||||
torch.random.manual_seed(42)
|
||||
rnn = nn.LSTM(7, 4, layers, device="cpu")
|
||||
input = torch.randn(2, 3, 7, device="cpu")
|
||||
hx = torch.randn(layers, 3, 4, device="cpu")
|
||||
cx = torch.randn(layers, 3, 4, device="cpu")
|
||||
|
||||
rnn = nn.LSTM(1, 4, 2, device="cpu")
|
||||
input = torch.randn(2, 3, 1, device="cpu")
|
||||
hx = torch.zeros(2, 3, 4, device="cpu")
|
||||
cx = torch.zeros(2, 3, 4, device="cpu")
|
||||
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
|
||||
|
||||
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
|
||||
rnn = rnn.to(device)
|
||||
input = input.to(device)
|
||||
hx = hx.to(device)
|
||||
cx = cx.to(device)
|
||||
output, (hn, cn) = rnn(input, (hx, cx))
|
||||
|
||||
rnn = rnn.to(device)
|
||||
input = input.to(device)
|
||||
hx = hx.to(device)
|
||||
cx = cx.to(device)
|
||||
output, (hn, cn) = rnn(input, (hx, cx))
|
||||
self.assertEqual(cpu_output, output)
|
||||
self.assertEqual(cpu_hn, hn)
|
||||
self.assertEqual(cpu_cn, cn)
|
||||
|
||||
self.assertEqual(cpu_output, output)
|
||||
self.assertEqual(cpu_hn, hn)
|
||||
self.assertEqual(cpu_cn, cn)
|
||||
# test batch_first
|
||||
rnn = nn.LSTM(7, 4, layers, device="cpu", batch_first=True)
|
||||
input = torch.randn(3, 2, 7, device="cpu")
|
||||
hx = torch.randn(layers, 3, 4, device="cpu")
|
||||
cx = torch.randn(layers, 3, 4, device="cpu")
|
||||
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
|
||||
|
||||
# test batch_first
|
||||
rnn = nn.LSTM(1, 4, 2, device="cpu", batch_first=True)
|
||||
input = torch.randn(3, 2, 1, device="cpu")
|
||||
hx = torch.zeros(2, 3, 4, device="cpu")
|
||||
cx = torch.zeros(2, 3, 4, device="cpu")
|
||||
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
|
||||
rnn = rnn.to(device)
|
||||
input = input.to(device)
|
||||
hx = hx.to(device)
|
||||
cx = cx.to(device)
|
||||
output, (hn, cn) = rnn(input, (hx, cx))
|
||||
|
||||
rnn = rnn.to(device)
|
||||
input = input.to(device)
|
||||
hx = hx.to(device)
|
||||
cx = cx.to(device)
|
||||
output, (hn, cn) = rnn(input, (hx, cx))
|
||||
self.assertEqual(cpu_output, output)
|
||||
self.assertEqual(cpu_hn, hn)
|
||||
self.assertEqual(cpu_cn, cn)
|
||||
|
||||
self.assertEqual(cpu_output, output)
|
||||
self.assertEqual(cpu_hn, hn)
|
||||
self.assertEqual(cpu_cn, cn)
|
||||
def test_lstm_backward(self, device="mps", dtype=torch.float32):
|
||||
for layers in [1] if product_version < 13.0 else [1, 2, 5]:
|
||||
lstm = nn.LSTM(2, 4, layers) # initialized globally for consistent parameters init
|
||||
lstm.train()
|
||||
|
||||
@unittest.skipIf(True, "Backward of lstm returns wrong result")
|
||||
def test_lstm_2(self, device="mps", dtype=torch.float32):
|
||||
def get_results(device):
|
||||
rnn = nn.LSTM(1, 4, 1, device=device)
|
||||
inp = torch.randn(2, 3, 1, device=device, requires_grad=True)
|
||||
hx = torch.zeros(1, 3, 4, device=device)
|
||||
cx = torch.zeros(1, 3, 4, device=device)
|
||||
def get_results(device, inp, hx, cx):
|
||||
rnn = lstm.to(device)
|
||||
inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
|
||||
|
||||
output, _ = rnn(inp, (hx, cx))
|
||||
output.sum().backward()
|
||||
output, _ = rnn(inp, (hx, cx))
|
||||
f = output.sum()
|
||||
|
||||
weight_grad = rnn.weight_ih_l0.grad.clone()
|
||||
input_grad = inp.grad.clone()
|
||||
param_names, params = zip(*rnn.named_parameters())
|
||||
param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
|
||||
|
||||
return output, weight_grad, input_grad
|
||||
input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
|
||||
return output, param_grads, input_grad, hx_grad, cx_grad
|
||||
|
||||
inp = torch.randn((5, 3, 2), requires_grad=True, dtype=dtype, device=device)
|
||||
hx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
|
||||
cx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
|
||||
|
||||
cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
|
||||
mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
|
||||
|
||||
self.assertEqual(cpu_hx_grad, mps_hx_grad)
|
||||
self.assertEqual(cpu_cx_grad, mps_cx_grad)
|
||||
self.assertEqual(cpu_output, mps_output)
|
||||
self.assertEqual(cpu_input_grad, mps_input_grad)
|
||||
for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
|
||||
self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
|
||||
|
||||
# test batch_first backward
|
||||
lstm = nn.LSTM(2, 4, layers, batch_first=True)
|
||||
lstm.train()
|
||||
|
||||
hx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
|
||||
cx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
|
||||
|
||||
cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
|
||||
mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
|
||||
|
||||
self.assertEqual(cpu_hx_grad, mps_hx_grad)
|
||||
self.assertEqual(cpu_cx_grad, mps_cx_grad)
|
||||
self.assertEqual(cpu_output, mps_output)
|
||||
self.assertEqual(cpu_input_grad, mps_input_grad)
|
||||
for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
|
||||
self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
|
||||
|
||||
|
||||
cpu_output, cpu_weight_grad, cpu_input_grad = get_results("cpu")
|
||||
mps_output, mps_weight_grad, mps_input_grad = get_results("mps")
|
||||
def test_RNN_cell_no_broadcasting(self):
|
||||
def test(cell_module, input, hx, input_size, hidden_size):
|
||||
cell = cell_module(input_size, hidden_size, device='mps')
|
||||
self.assertRaises(RuntimeError, lambda: cell(input, hx))
|
||||
|
||||
def test_all(hidden_size, bad_hx, good_hx, input_size, input):
|
||||
test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
|
||||
test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
|
||||
test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
|
||||
test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)
|
||||
|
||||
hidden_size = 20
|
||||
input_size = 10
|
||||
input = torch.randn(3, input_size, device='mps')
|
||||
bad_hx = torch.randn(1, hidden_size, device='mps')
|
||||
good_hx = torch.randn(3, hidden_size, device='mps')
|
||||
|
||||
# Test hidden/input batch size broadcasting
|
||||
test_all(hidden_size, bad_hx, good_hx, input_size, input)
|
||||
|
||||
# Test hx's hidden_size vs module's hidden_size broadcasting
|
||||
bad_hx = torch.randn(3, 1)
|
||||
test_all(hidden_size, bad_hx, good_hx, input_size, input)
|
||||
|
||||
# Test input's input_size vs module's input_size broadcasting
|
||||
bad_input = torch.randn(3, 1)
|
||||
test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
|
||||
|
||||
def test_LSTM_cell(self):
|
||||
# this is just a smoke test; these modules are implemented through
|
||||
# autograd so no Jacobian test is needed
|
||||
for bias in (True, False):
|
||||
input = torch.randn(3, 10, device='mps')
|
||||
hx = torch.randn(3, 20, device='mps')
|
||||
cx = torch.randn(3, 20, device='mps')
|
||||
lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
|
||||
for _ in range(6):
|
||||
hx, cx = lstm(input, (hx, cx))
|
||||
|
||||
(hx + cx).sum().backward()
|
||||
|
||||
def test_LSTM_cell_forward_input_size(self):
|
||||
input = torch.randn(3, 11, device='mps')
|
||||
hx = torch.randn(3, 20, device='mps')
|
||||
cx = torch.randn(3, 20, device='mps')
|
||||
lstm = nn.LSTMCell(10, 20, device='mps')
|
||||
self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
|
||||
|
||||
def test_LSTM_cell_forward_hidden_size(self):
|
||||
input = torch.randn(3, 10, device='mps')
|
||||
hx = torch.randn(3, 21, device='mps')
|
||||
cx = torch.randn(3, 20, device='mps')
|
||||
lstm = nn.LSTMCell(10, 20, device='mps')
|
||||
self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
|
||||
self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
|
||||
|
||||
self.assertEqual(cpu_output, mps_output)
|
||||
self.assertEqual(cpu_input_grad, mps_input_grad)
|
||||
self.assertEqual(cpu_weight_grad, mps_weight_grad)
|
||||
|
||||
class TestFallbackWarning(TestCase):
|
||||
# TODO: Remove once test_testing.py is running on MPS devices
|
||||
@ -9152,6 +9502,7 @@ class TestConsistency(TestCaseMPS):
|
||||
'isreal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'kron': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'linalg.matrix_norm': ['f16'],
|
||||
'linalg.matrix_power': ['f32'],
|
||||
'linalg.svd': ['f32'],
|
||||
'linalg.vector_norm': ['f16', 'f32'],
|
||||
'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
@ -9315,6 +9666,7 @@ class TestConsistency(TestCaseMPS):
|
||||
'nn.functional.bilinear': ['f32'],
|
||||
'linalg.solve_triangular': ['f32'],
|
||||
'triangular_solve': ['f32'],
|
||||
'trace': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'_native_batch_norm_legit': ['f32'],
|
||||
'native_batch_norm': ['f32'],
|
||||
'minreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
@ -9511,6 +9863,8 @@ class TestConsistency(TestCaseMPS):
|
||||
'native_batch_norm': ['f32'],
|
||||
'native_layer_norm': ['f32'],
|
||||
'nn.functional.gelu': ['f32'],
|
||||
'nn.functional.bilinear': ['f32'],
|
||||
'nn.functional.prelu': ['f32'],
|
||||
}
|
||||
|
||||
# These ops are problematic, so never run them even when
|
||||
@ -9530,7 +9884,6 @@ class TestConsistency(TestCaseMPS):
|
||||
'stft': [torch.float32], 'var': [torch.float16],
|
||||
# + forward when requires_grad=True or running backward
|
||||
'nn.functional.embedding': [torch.float32, torch.float16],
|
||||
'__rpow__': [torch.int64],
|
||||
|
||||
'as_strided_scatter': [torch.uint8],
|
||||
'atan2': [torch.int64],
|
||||
|
||||
@ -46,9 +46,10 @@ from torch.testing._internal.common_utils import (
|
||||
skipIfRocm,
|
||||
skipIfTorchDynamo
|
||||
)
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
|
||||
from typing import Dict, Any, Tuple
|
||||
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
|
||||
from unittest.mock import patch
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
@ -252,21 +253,26 @@ class TestOptim(TestCase):
|
||||
)
|
||||
|
||||
# Make sure that optimizers that support maximize can load older models
|
||||
state_dict = optimizer.state_dict()
|
||||
if "maximize" in state_dict["param_groups"][0]:
|
||||
for group in state_dict["param_groups"]:
|
||||
old_state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_no_maximize = deepcopy(optimizer.state_dict())
|
||||
if "maximize" in state_dict_no_maximize["param_groups"][0]:
|
||||
for group in state_dict_no_maximize["param_groups"]:
|
||||
del group["maximize"]
|
||||
optimizer.load_state_dict(state_dict)
|
||||
optimizer.load_state_dict(state_dict_no_maximize)
|
||||
# Make sure we can still step
|
||||
optimizer.step()
|
||||
# Undo these changes before proceeding!
|
||||
optimizer.load_state_dict(old_state_dict)
|
||||
# Make sure that optimizers that support foreach can load older models
|
||||
state_dict = optimizer.state_dict()
|
||||
if "foreach" in state_dict["param_groups"][0]:
|
||||
for group in state_dict["param_groups"]:
|
||||
state_dict_no_foreach = deepcopy(optimizer.state_dict())
|
||||
if "foreach" in state_dict_no_foreach["param_groups"][0]:
|
||||
for group in state_dict_no_foreach["param_groups"]:
|
||||
del group["foreach"]
|
||||
optimizer.load_state_dict(state_dict)
|
||||
optimizer.load_state_dict(state_dict_no_foreach)
|
||||
# Make sure we can still step
|
||||
optimizer.step()
|
||||
# Undo these changes before proceeding!
|
||||
optimizer.load_state_dict(old_state_dict)
|
||||
|
||||
# Make sure that loading optimizers with step not wrapped in tensor can work
|
||||
state_dict = optimizer.state_dict()
|
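As a hedged, self-contained illustration of what this test exercises (loading a checkpoint written before the maximize flag existed), the snippet below deletes that key from a fresh state dict and reloads it; the optimizer and parameter are arbitrary.

import copy
import torch

param = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.SGD([param], lr=0.1)

old_style = copy.deepcopy(opt.state_dict())
for group in old_style["param_groups"]:
    group.pop("maximize", None)       # simulate a pre-maximize checkpoint

opt.load_state_dict(old_style)        # missing keys fall back to defaults
param.grad = torch.zeros_like(param)
opt.step()                            # stepping still works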
||||
@ -4535,5 +4541,39 @@ class TestDifferentiableOptimizer(TestCase):
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
|
||||
def test_defaults_changed_to_foreach(self):
|
||||
from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
|
||||
asgd, adamax, adadelta, adagrad)
|
||||
multi_optims = ((optim.Adam, adam, "_multi_tensor_adam"),
|
||||
(optim.AdamW, adamw, "_multi_tensor_adamw"),
|
||||
(optim.NAdam, nadam, "_multi_tensor_nadam"),
|
||||
(optim.SGD, sgd, "_multi_tensor_sgd"),
|
||||
(optim.RAdam, radam, "_multi_tensor_radam"),
|
||||
(optim.RMSprop, rmsprop, "_multi_tensor_rmsprop"),
|
||||
(optim.Rprop, rprop, "_multi_tensor_rprop"),
|
||||
(optim.ASGD, asgd, "_multi_tensor_asgd"),
|
||||
(optim.Adamax, adamax, "_multi_tensor_adamax"),
|
||||
(optim.Adadelta, adadelta, "_multi_tensor_adadelta"),
|
||||
(optim.Adagrad, adagrad, "_multi_tensor_adagrad"),)
|
||||
|
||||
model = torch.nn.Linear(5, 5)
|
||||
model.to(dtype=torch.float64, device="cuda")
|
||||
input = torch.rand(2, 5, dtype=torch.float64, device="cuda")
|
||||
|
||||
for opt, mod, func in multi_optims:
|
||||
defaults = {}
|
||||
if opt == optim.SGD:
|
||||
defaults["lr"] = 1e-2
|
||||
optimizer = opt(model.parameters(), **defaults)
|
||||
optimizer.zero_grad()
|
||||
output = model(input)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
with patch.object(mod, func) as mocked_foreach_impl:
|
||||
optimizer.step()
|
||||
self.assertTrue(mocked_foreach_impl.called)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -736,6 +736,15 @@ class SerializationMixin:
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.save([a.storage(), s_bytes], f)
|
||||
|
||||
def test_safe_load_basic_types(self):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
data = {"int": 123, "str": "world", "float": 3.14, "bool": False}
|
||||
torch.save(data, f)
|
||||
f.seek(0)
|
||||
loaded_data = torch.load(f, weights_only=True)
|
||||
self.assertEqual(data, loaded_data)
|
||||
|
||||
|
||||
class serialization_method:
|
||||
def __init__(self, use_zip):
|
||||
self.use_zip = use_zip
|
||||
|
||||
@ -2170,8 +2170,8 @@
|
||||
input, weight, bias: linear_backward(input, grad, weight, grad_input_mask)
|
||||
|
||||
#mps
|
||||
- name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
|
||||
self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
|
||||
- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
|
||||
self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
@ -2564,11 +2564,11 @@
|
||||
input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
#LSTM MPS
|
||||
- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
|
||||
output_differentiability: [True, True, True, False, False]
|
||||
input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
|
||||
- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
|
||||
output_differentiability: [True, True, True, False, False, False]
|
||||
input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
|
||||
|
||||
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
|
||||
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
|
||||
|
||||
|
||||
|
||||
|
||||
@ -215,6 +215,10 @@ def main():
|
||||
f"ROCM version: {rocm_ver}\n"
|
||||
)
|
||||
for args in _SANITY_CHECK_ARGS:
|
||||
# TODO remove check when 3.11 is supported
|
||||
if sys.version_info >= (3, 11):
|
||||
warnings.warn("Dynamo not yet supported in Python 3.11. Skipping check.")
|
||||
continue
|
||||
check_dynamo(*args)
|
||||
print("All required checks passed")
|
||||
|
||||
|
||||
@ -308,6 +308,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
|
||||
aten.trace,
|
||||
aten.transpose.int,
|
||||
aten.tril.default,
|
||||
aten.triu.default,
|
||||
aten.unfold,
|
||||
aten.unfold_backward,
|
||||
aten.upsample_bilinear2d,
|
||||
|
||||
@ -405,8 +405,9 @@ def _nll_loss_backward(
|
||||
grad_output = grad_output / total_weight
|
||||
|
||||
target = target.unsqueeze(channel_dim)
|
||||
safe_target = torch.where(target != ignore_index, target, 0)
|
||||
grad_input = torch.zeros_like(self)
|
||||
grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
|
||||
grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
|
||||
|
||||
if grad_input.dim() > grad_output.dim() > 0:
|
||||
grad_output = grad_output.unsqueeze(channel_dim)
|
||||
@ -417,9 +418,7 @@ def _nll_loss_backward(
|
||||
weight = weight.reshape(new_shape)
|
||||
grad_output = grad_output * weight
|
||||
|
||||
has_ignore_index = ignore_index >= 0
|
||||
if has_ignore_index:
|
||||
grad_output = torch.where(target != ignore_index, grad_output, 0)
|
||||
grad_output = torch.where(target != ignore_index, grad_output, 0)
|
||||
|
||||
return grad_input * grad_output
|
||||
|
||||
@ -2798,14 +2797,13 @@ def nll_loss_forward(
|
||||
if weight is not None:
|
||||
w = weight.unsqueeze(0) if n_dims > 1 else weight
|
||||
self = self * w
|
||||
|
||||
target_ = target.unsqueeze(channel_dim)
|
||||
safe_target = torch.where(target != ignore_index, target, 0)
|
||||
safe_target_ = safe_target.unsqueeze(channel_dim)
|
||||
# target can be [N, 1] or [1]
|
||||
|
||||
result = -torch.gather(self, channel_dim, target_).squeeze(channel_dim)
|
||||
result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
|
||||
|
||||
if ignore_index >= 0:
|
||||
result = torch.where(target != ignore_index, result, 0)
|
||||
result = torch.where(target != ignore_index, result, 0)
|
||||
|
||||
if reduction == Reduction.NONE.value and n_dims > 1:
|
||||
total_weight = self.new_full((), 0.0)
|
||||
@ -2813,22 +2811,16 @@ def nll_loss_forward(
|
||||
|
||||
if weight is not None:
|
||||
w = weight.unsqueeze(0).expand(self.shape) if n_dims > 1 else weight
|
||||
wsum = torch.gather(w, channel_dim, target_).squeeze(channel_dim)
|
||||
if ignore_index >= 0:
|
||||
wsum = torch.where(target != ignore_index, wsum, 0)
|
||||
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
|
||||
wsum = torch.where(target != ignore_index, wsum, 0)
|
||||
total_weight = wsum.sum()
|
||||
elif ignore_index >= 0:
|
||||
total_weight = (target != ignore_index).sum().to(self)
|
||||
else:
|
||||
total_weight = self.new_full((), 1.0 * result.numel())
|
||||
total_weight = (target != ignore_index).sum().to(self)
|
||||
|
||||
if reduction == Reduction.SUM.value:
|
||||
result = result.sum()
|
||||
elif reduction == Reduction.MEAN.value:
|
||||
if weight is None:
|
||||
result = result.sum() / total_weight if ignore_index >= 0 else result.mean()
|
||||
else:
|
||||
result = result.sum() / total_weight
|
||||
result = result.sum() / total_weight
|
||||
|
||||
return result, total_weight
|
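A hedged, worked miniature of the safe_target / ignore_index masking used in this decomposition (shapes and values chosen only for illustration):

import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.log_softmax(torch.randn(4, 3), dim=1)
target = torch.tensor([0, 2, ignore_index, 1])

# Gather with a "safe" target so ignored rows index a valid class ...
safe_target = torch.where(target != ignore_index, target, 0)
picked = -logits.gather(1, safe_target.unsqueeze(1)).squeeze(1)
# ... then zero those rows out and normalize by the number of kept rows.
picked = torch.where(target != ignore_index, picked, torch.zeros_like(picked))
total_weight = (target != ignore_index).sum().to(logits)
mean_loss = picked.sum() / total_weight

assert torch.allclose(mean_loss, F.nll_loss(logits, target, ignore_index=ignore_index))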
||||
|
||||
|
||||
@ -116,8 +116,3 @@ def onnxrt(gm, example_inputs, *, filename=None, provider=None):
|
||||
return outputs
|
||||
|
||||
return _call
|
||||
|
||||
|
||||
@register_backend
|
||||
def tensorrt(gm, example_inputs):
|
||||
return onnxrt(gm, example_inputs, provider="TensorrtExecutionProvider")
|
||||
|
||||
12
torch/_dynamo/backends/tensorrt.py
Normal file
@ -0,0 +1,12 @@
|
||||
# import torch # type: ignore[import]
|
||||
# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import]
|
||||
# from .registry import register_backend # type: ignore[import]
|
||||
|
||||
"""
|
||||
Placeholder for TensorRT backend for dynamo via torch-tensorrt
|
||||
"""
|
||||
|
||||
# @register_backend
|
||||
# def tensorrt(gm, example_inputs):
|
||||
# import torch_tensorrt # type: ignore[import]
|
||||
# pass
|
||||
@ -370,6 +370,13 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
|
||||
return fn
|
||||
|
||||
|
||||
def check_if_dynamo_supported():
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("Windows not yet supported for torch.compile")
|
||||
if sys.version_info >= (3, 11):
|
||||
raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
|
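A hedged sketch of how user code might mirror these guards before opting into torch.compile (the helper name is invented; the checks simply restate the two conditions above):

import sys
import warnings
import torch

def compile_if_supported(fn):
    # Same conditions check_if_dynamo_supported() raises on: Windows and Python 3.11+.
    if sys.platform == "win32" or sys.version_info >= (3, 11):
        warnings.warn("torch.compile is not supported here; running the function eagerly")
        return fn
    return torch.compile(fn)

compiled = compile_if_supported(lambda x: x.sin() + 1)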
||||
|
||||
|
||||
def optimize(
|
||||
backend="inductor",
|
||||
*,
|
||||
@ -403,6 +410,7 @@ def optimize(
|
||||
def toy_example(a, b):
|
||||
...
|
||||
"""
|
||||
check_if_dynamo_supported()
|
||||
# Note: The hooks object could be global instead of passed around, *however* that would make
|
||||
# for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
|
||||
# There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
|
||||
@ -412,14 +420,6 @@ def optimize(
|
||||
torch._C._log_api_usage_once("torch._dynamo.optimize")
|
||||
if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
|
||||
return _NullDecorator()
|
||||
if sys.platform == "win32":
|
||||
warnings.warn(
|
||||
"Windows is not currently supported, torch.compile() will do nothing"
|
||||
)
|
||||
return _NullDecorator()
|
||||
if sys.version_info >= (3, 11):
|
||||
warnings.warn("Python 3.11+ not yet supported, torch.compile() will do nothing")
|
||||
return _NullDecorator()
|
||||
|
||||
backend = get_compiler_fn(backend)
|
||||
|
||||
@ -521,6 +521,7 @@ def explain(f, *args, **kwargs):
|
||||
def export(
|
||||
f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
|
||||
):
|
||||
check_if_dynamo_supported()
|
||||
torch._C._log_api_usage_once("torch._dynamo.export")
|
||||
if decomposition_table is not None or tracing_mode != "real":
|
||||
assert (
|
||||
|
||||
@ -481,9 +481,34 @@ For now, dynamo will explicitly graph break when it encounters user code with th
|
||||
if self.value == torch._C._nn.scaled_dot_product_attention:
|
||||
# See:[Note] SDPA_flash's meta function returns incorrect Philox seed and offset
|
||||
# in pytorch/torch/_meta_registrations.py
|
||||
fake_query = args[0].as_proxy().node.meta["example_value"]
|
||||
fake_key = args[1].as_proxy().node.meta["example_value"]
|
||||
fake_value = args[2].as_proxy().node.meta["example_value"]
|
||||
all_kwargs = kwargs.copy()
|
||||
all_kwargs.update(
|
||||
dict(
|
||||
zip(
|
||||
(
|
||||
"query",
|
||||
"key",
|
||||
"value",
|
||||
"attn_mask",
|
||||
"dropout_p",
|
||||
"is_causal",
|
||||
),
|
||||
args,
|
||||
)
|
||||
)
|
||||
)
|
||||
fake_query = all_kwargs["query"].as_proxy().node.meta["example_value"]
|
||||
fake_key = all_kwargs["key"].as_proxy().node.meta["example_value"]
|
||||
fake_value = all_kwargs["value"].as_proxy().node.meta["example_value"]
|
||||
fake_mask = all_kwargs.get("attn_mask")
|
||||
if isinstance(fake_mask, TensorVariable):
|
||||
fake_mask = fake_mask.as_proxy().node.meta["example_value"]
|
||||
else:
|
||||
fake_mask = None
|
||||
dropout_p = kwargs.get("dropout_p")
|
||||
dropout_p = dropout_p.value if dropout_p is not None else 0.0
|
||||
is_causal = kwargs.get("is_causal")
|
||||
is_causal = is_causal.value if is_causal is not None else False
|
||||
# We look through the stack to find a cuda autocast context
|
||||
# If we do, we will convert the fake tensors to torch.float16
|
||||
is_cuda_autocast_context = False
|
||||
@ -502,15 +527,10 @@ For now, dynamo will explicitly graph break when it encounters user code with th
|
||||
fake_value = fake_value.clone().to(amp_dtype)
|
||||
|
||||
backend_choice = torch._fused_sdp_choice(
|
||||
fake_query, fake_key, fake_value
|
||||
fake_query, fake_key, fake_value, fake_mask, dropout_p, is_causal
|
||||
)
|
||||
if backend_choice == torch.backends.cuda.SDPBackend.FLASH_ATTENTION:
|
||||
dropout_p = kwargs.get("dropout_p")
|
||||
# Lets see if they passed it in as not an arg
|
||||
if len(args) >= 5:
|
||||
dropout_p = args[4]
|
||||
|
||||
if dropout_p is not None and dropout_p.value != 0.0:
|
||||
if dropout_p is not None and dropout_p != 0.0:
|
||||
unimplemented(
|
||||
"FlashAttention with dropout is not supported in cuda graphs"
|
||||
)
|
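The change above folds positional SDPA arguments into a single kwargs dict before inspecting attn_mask and dropout_p; a hedged standalone version of that normalization pattern looks like this (names are illustrative):

SDPA_ARG_NAMES = ("query", "key", "value", "attn_mask", "dropout_p", "is_causal")

def normalize_sdpa_args(args, kwargs):
    # zip() truncates to the shorter sequence, so a short positional call
    # simply leaves the trailing names unset.
    all_kwargs = dict(kwargs)
    all_kwargs.update(zip(SDPA_ARG_NAMES, args))
    return all_kwargs

# e.g. normalize_sdpa_args((q, k, v), {"is_causal": True})
#      -> {"is_causal": True, "query": q, "key": k, "value": v}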
||||
|
||||
@ -1870,6 +1870,9 @@ def create_runtime_wrapper(
|
||||
trace_joint: bool,
|
||||
keep_input_mutations: bool,
|
||||
):
|
||||
if not hasattr(compiled_fn, "_boxed_call"):
|
||||
compiled_fn = make_boxed_func(compiled_fn)
|
||||
|
||||
def runtime_wrapper(*args):
|
||||
# Step 2: remove aliased inputs that are mutated, replace with synthetic bases
|
||||
# Only happens if our graph mutates an input that aliases another input.
|
||||
@ -2180,7 +2183,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
|
||||
assert all(
|
||||
[isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards]
|
||||
)
|
||||
ctx.save_for_backward(*tensors_saved_for_backwards)
|
||||
# See Note [Detaching saved tensors in AOTAutograd]
|
||||
ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
|
||||
symint_outs = fw_outs[-num_symints_saved_for_bw:]
|
||||
assert all(
|
||||
[
|
||||
@ -2190,7 +2194,9 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
|
||||
)
|
||||
ctx.symints = symint_outs
|
||||
else:
|
||||
ctx.save_for_backward(*fw_outs[num_forward_returns:])
|
||||
tensors_saved_for_backwards = fw_outs[num_forward_returns:]
|
||||
# See Note [Detaching saved tensors in AOTAutograd]
|
||||
ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
|
||||
ctx.symints = []
|
||||
|
||||
raw_returns = fw_outs[0:num_forward_returns]
|
||||
@ -2299,6 +2305,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
|
||||
contiguous_args = [
|
||||
t.contiguous() if torch.is_tensor(t) else t for t in flat_bw_args
|
||||
]
|
||||
|
||||
all_args = (
|
||||
list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args)
|
||||
)
|
||||
|
||||
@ -361,6 +361,8 @@ class CppVecOverrides(OpOverrides):
|
||||
def lgamma(x):
|
||||
return f"{x}.lgamma()"
|
||||
|
||||
"""
|
||||
#TODO: support logical_and and logical_or vectorization
|
||||
@staticmethod
|
||||
def logical_and(a, b):
|
||||
return f"{a} && {b}"
|
||||
@ -368,6 +370,7 @@ class CppVecOverrides(OpOverrides):
|
||||
@staticmethod
|
||||
def logical_or(a, b):
|
||||
return f"{a} || {b}"
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def tan(a):
|
||||
|
||||
@ -294,15 +294,6 @@ class WrapperCodeGen(CodeGen):
|
||||
"""
|
||||
)
|
||||
|
||||
if config.triton.convolution != "aten":
|
||||
self.header.splice(
|
||||
"""
|
||||
from torch._inductor.triton_ops.conv_perf_model import early_config_prune
|
||||
from torch._inductor.triton_ops.conv_perf_model import estimate_conv_time
|
||||
from torch._inductor.triton_ops.autotune import conv_heuristics
|
||||
"""
|
||||
)
|
||||
|
||||
self.write_prefix()
|
||||
|
||||
for name, value in V.graph.constants.items():
|
||||
|
||||
@ -153,9 +153,6 @@ class triton:
|
||||
# Synchronize after every kernel launch, to help pinpoint bugs
|
||||
debug_sync_kernel = False
|
||||
|
||||
# choose conv backend, "aten" or "triton"
|
||||
convolution = "aten"
|
||||
|
||||
# Always load full blocks (rather than broadcasting inside the block)
|
||||
dense_indexing = False
|
||||
|
||||
|
||||
@ -2184,20 +2184,10 @@ class ComputedBuffer(Buffer):
|
||||
for reads_name in body.reads_name2expr.keys()
|
||||
]
|
||||
priority_idx = []
|
||||
if config.triton.convolution == "aten":
|
||||
memory_addrs = [
|
||||
*body.reads_name2expr.values(),
|
||||
*body.writes_name2expr.values(),
|
||||
]
|
||||
else:
|
||||
# prioritize reads layout/loop_ordering over writes
|
||||
if len(body.reads_name2expr.values()) > 0:
|
||||
memory_addrs = [*body.reads_name2expr.values()]
|
||||
else:
|
||||
memory_addrs = [*body.writes_name2expr.values()]
|
||||
for i, reads_buf in enumerate(reads_bufs):
|
||||
if isinstance(reads_buf, Convolution):
|
||||
priority_idx.append(i)
|
||||
memory_addrs = [
|
||||
*body.reads_name2expr.values(),
|
||||
*body.writes_name2expr.values(),
|
||||
]
|
||||
index_vars = []
|
||||
reduce_vars = []
|
||||
index_size = []
|
||||
@ -3140,12 +3130,8 @@ class Convolution(ExternKernelAlloc):
|
||||
)
|
||||
req_stride_order = get_stride_order(output.stride())
|
||||
|
||||
if config.triton.convolution == "aten":
|
||||
weight = cls.require_stride_order(weight, req_stride_order)
|
||||
x = cls.require_stride_order(x, req_stride_order)
|
||||
else:
|
||||
x = cls.require_stride1(cls.realize_input(x))
|
||||
weight = cls.require_stride1(cls.realize_input(weight))
|
||||
weight = cls.require_stride_order(weight, req_stride_order)
|
||||
x = cls.require_stride_order(x, req_stride_order)
|
||||
|
||||
stride = tuple(stride_)
|
||||
padding = tuple(padding_)
|
||||
@ -3163,7 +3149,7 @@ class Convolution(ExternKernelAlloc):
|
||||
_, _, *kernel_size = weight_shape
|
||||
|
||||
# choose runtime kernel
|
||||
config_conv = config.triton.convolution
|
||||
config_conv = "aten"
|
||||
if (
|
||||
config_conv == "aten"
|
||||
or len(kernel_size) != 2 # triton conv only supports conv2d
|
||||
@ -3196,7 +3182,7 @@ class Convolution(ExternKernelAlloc):
|
||||
)
|
||||
|
||||
# for conv2d or conv3d, prefer channels last format
|
||||
transform_x_layout = config.triton.convolution != "aten"
|
||||
transform_x_layout = False
|
||||
if kernel == "triton_ops.conv":
|
||||
output_layout_str = "torch.channels_last"
|
||||
else:
|
||||
|
||||
@ -45,7 +45,7 @@ def mm_configs():
|
||||
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=8
|
||||
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4
|
||||
|
||||
@ -1505,30 +1505,6 @@ def iota(
|
||||
)
|
||||
|
||||
|
||||
@register_lowering(aten.triu)
|
||||
def triu(x, diagonal=0):
|
||||
x_loader = x.make_loader()
|
||||
dtype = x.get_dtype()
|
||||
|
||||
def inner_fn(index):
|
||||
*_, i, j = index
|
||||
return ops.where(
|
||||
ops.ge(
|
||||
ops.index_expr(j - i - diagonal, torch.int32),
|
||||
ops.constant(0, torch.int32),
|
||||
),
|
||||
x_loader(index),
|
||||
ops.constant(0, dtype),
|
||||
)
|
||||
|
||||
return Pointwise.create(
|
||||
device=x.get_device(),
|
||||
dtype=dtype,
|
||||
inner_fn=inner_fn,
|
||||
ranges=list(x.get_size()),
|
||||
)
|
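The removed lowering computed triu with the predicate j - i - diagonal >= 0; a hedged eager-mode check of that same formula:

import torch

x = torch.arange(16.0).reshape(4, 4)
diagonal = 1
i = torch.arange(4).unsqueeze(1)   # row index, broadcast down columns
j = torch.arange(4).unsqueeze(0)   # column index, broadcast across rows
mask = (j - i - diagonal) >= 0
assert torch.equal(torch.where(mask, x, torch.zeros_like(x)), torch.triu(x, diagonal))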
||||
|
||||
|
||||
@register_lowering(aten.select_scatter, type_promotion_kind=None)
|
||||
def select_scatter(x, src, dim: int, index: int):
|
||||
assert x.get_dtype() == src.get_dtype()
|
||||
|
||||
@ -22,11 +22,14 @@ def dl_open_guard():
|
||||
Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
|
||||
shared library to load custom operators.
|
||||
"""
|
||||
if _SET_GLOBAL_FLAGS:
|
||||
old_flags = sys.getdlopenflags()
|
||||
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
|
||||
yield
|
||||
if _SET_GLOBAL_FLAGS:
|
||||
if not _SET_GLOBAL_FLAGS:
|
||||
yield
|
||||
return
|
||||
old_flags = sys.getdlopenflags()
|
||||
sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.setdlopenflags(old_flags)
|
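The rewrite above makes the flag restoration exception-safe; a hedged equivalent written with contextlib (a sketch, not the code PyTorch ships) would be:

import contextlib
import ctypes
import sys

@contextlib.contextmanager
def rtld_global_flags():
    # Some platforms have no dlopen flags at all.
    if not hasattr(sys, "getdlopenflags"):
        yield
        return
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)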
||||
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ from collections import OrderedDict
|
||||
from pickle import (
|
||||
APPEND,
|
||||
APPENDS,
|
||||
BINFLOAT,
|
||||
BINGET,
|
||||
BININT,
|
||||
BININT1,
|
||||
@ -226,6 +227,8 @@ class Unpickler:
|
||||
self.append(self.read(1)[0])
|
||||
elif key[0] == BININT2[0]:
|
||||
self.append(unpack("<H", read(2))[0])
|
||||
elif key[0] == BINFLOAT[0]:
|
||||
self.append(unpack(">d", self.read(8))[0])
|
||||
elif key[0] == BINUNICODE[0]:
|
||||
strlen = unpack("<I", read(4))[0]
|
||||
if strlen > maxsize:
|
||||
|
||||
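The new BINFLOAT branch reads an 8-byte big-endian IEEE-754 double, which is how pickle's binary opcodes encode floats; a hedged one-liner check of that encoding:

import struct

payload = struct.pack(">d", 3.14)            # what follows the BINFLOAT opcode
assert struct.unpack(">d", payload)[0] == 3.14
assert len(payload) == 8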
@ -320,7 +320,9 @@ void Logger::set_runtime_stats_and_log() {
|
||||
"Cuda time stats are not collected for multi-device modules.");
|
||||
return;
|
||||
}
|
||||
if (!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu()) {
|
||||
|
||||
if (!reducer_->timer_ &&
|
||||
(!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu())) {
|
||||
TORCH_WARN_ONCE(
|
||||
"Time stats are currently only collected for CPU and CUDA devices. "
|
||||
"Please refer to CpuTimer or CudaTimer for how to register timer "
|
||||
|
||||
@ -193,8 +193,7 @@ def adadelta(
|
||||
|
||||
# We still respect when the user inputs False for foreach.
|
||||
if foreach is None:
|
||||
_, foreach = _default_to_fused_or_foreach([params, grads, square_avgs, acc_deltas],
|
||||
differentiable, has_fused=False)
|
||||
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
|
||||
|
||||
if foreach and torch.jit.is_scripting():
|
||||
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
|
||||
|
||||
@ -210,8 +210,7 @@ def adagrad(
)

if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, state_sums, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

@ -4,7 +4,7 @@ import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
_dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
_differentiable_doc, _foreach_doc, _maximize_doc)
_differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['Adam', 'adam']
@ -218,28 +218,14 @@ Adam.__doc__ = r"""Implements Adam algorithm.
{maximize}
{capturable}
{differentiable}
fused (bool, optional): whether the fused implementation (CUDA only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. Since the fused implementation is usually significantly faster than
the for-loop implementation, we try to use it whenever possible (all parameters
are on CUDA and are of a supported type). Else, we attempt to use the foreach
implementation and lastly fall back to the for-loop implementation. (default: None)

.. note:: The foreach and fused implementations are typically faster than the for-loop,
single-tensor implementation, so we will try to default to them IF the user has
not specified either flag (i.e., when foreach = fused = None). For example, if
the user specifies True for foreach but nothing for fused, we will run the foreach
implementation. If the user specifies False for fused but nothing for foreach, we will
run the for-loop implementation. If the user specifies True for both foreach and
fused, we will prioritize fused over foreach. We attempt to use the fastest, so the
hierarchy goes fused -> foreach -> for-loop.
{fused}
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ

""".format(foreach=_foreach_doc, maximize=_maximize_doc, capturable=_capturable_doc,
differentiable=_differentiable_doc)
differentiable=_differentiable_doc, fused=_fused_doc)


def adam(params: List[Tensor],
@ -268,10 +254,12 @@ def adam(params: List[Tensor],
See :class:`~torch.optim.Adam` for details.
"""

# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
fused, foreach = _default_to_fused_or_foreach(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
differentiable, has_fused=True)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if fused is None:
fused = False
if foreach is None:

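Because the release keeps foreach as the silent default, the fused Adam path still has to be requested explicitly. A minimal usage sketch, assuming a CUDA device is available (the model and shapes here are made up for illustration):

import torch

model = torch.nn.Linear(8, 8).cuda()
# fused=True opts in to the fused CUDA kernels; leaving both fused and foreach
# unset would resolve to the foreach implementation, per the comment above.
opt = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)

loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()
opt.step()
opt.zero_grad()
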
@ -206,8 +206,7 @@ def adamax(
)

if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_infs, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

@ -2,7 +2,7 @@ import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
_stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
_maximize_doc, _default_to_fused_or_foreach)
_fused_doc, _maximize_doc, _default_to_fused_or_foreach)
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

@ -248,13 +248,7 @@ AdamW.__doc__ = r"""Implements AdamW algorithm.
{foreach}
{capturable}
{differentiable}
fused (bool, optional): whether the fused implementation (CUDA only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. Since the fused implementation is usually significantly faster than
the for-loop implementation, we try to use it whenever possible (all parameters
are on CUDA and are of a supported type). Else, we continue with the for-loop
implementation. (default: None)

{fused}
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
@ -262,6 +256,7 @@ AdamW.__doc__ = r"""Implements AdamW algorithm.

""".format(maximize=_maximize_doc,
foreach=_foreach_doc,
fused=_fused_doc,
capturable=_capturable_doc,
differentiable=_differentiable_doc)

@ -300,11 +295,12 @@ def adamw(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)

# Respect when the user inputs False/True for foreach.
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
fused, foreach = _default_to_fused_or_foreach(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
if fused is None:
fused = False
if foreach is None:

@ -185,8 +185,7 @@ def asgd(
"""

if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, axs, mus, etas, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

@ -187,8 +187,7 @@ def nadam(params: List[Tensor],
raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")

if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

if foreach and torch.jit.is_scripting():
raise RuntimeError('torch.jit.script not supported with foreach optimizers')

@ -15,6 +15,7 @@ from torch._utils import is_compiling
__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
_global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]

class _RequiredParameter:
"""Singleton class representing a required parameter for an Optimizer."""
@ -55,24 +56,21 @@ def _dispatch_sqrt(x: float): # float annotation is needed because of torchscri
return math.sqrt(x)

# For any optimizer with a faster implementation, we attempt to default to the
# fastest whenever possible. For foreach, the requirements are to have native
# tensors all on CUDA. For fused, there's currently the additional requirement
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
# that the tensors' dtypes must be floating point. Neither alternative supports
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
def _default_to_fused_or_foreach(tensorlists: List[List[torch.Tensor]],
def _default_to_fused_or_foreach(params: List[torch.Tensor],
differentiable: bool,
has_fused: bool = False) -> Tuple[bool, bool]:
use_fused: bool = False) -> Tuple[bool, bool]:
if torch.jit.is_scripting() or differentiable:
return False, False
all_tensors = []
for tensorlist in tensorlists:
all_tensors.extend(tensorlist)
fused = has_fused and all(
p is None or (type(p) == torch.Tensor and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
fused = use_fused and all(
p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
)
foreach = not fused and all(
p is None or (type(p) == torch.Tensor and p.is_cuda) for p in all_tensors
p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
)
return fused, foreach

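A hedged sketch of how the new single-list signature behaves from a caller's point of view. _default_to_fused_or_foreach is a private helper, so the import path below is an assumption that may change between releases:

import torch
from torch.optim.optimizer import _default_to_fused_or_foreach  # private helper

cpu_params = [torch.randn(4, requires_grad=True)]
# CPU tensors satisfy neither branch, so both flags come back False (for-loop path).
print(_default_to_fused_or_foreach(cpu_params, differentiable=False, use_fused=False))

if torch.cuda.is_available():
    cuda_params = [torch.randn(4, device="cuda", requires_grad=True)]
    # With use_fused=False the helper can only propose the foreach path.
    print(_default_to_fused_or_foreach(cuda_params, differentiable=False, use_fused=False))
    # differentiable=True always forces the single-tensor path.
    print(_default_to_fused_or_foreach(cuda_params, differentiable=True, use_fused=False))
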
@ -83,6 +81,23 @@ _foreach_doc = r"""foreach (bool, optional): whether foreach implementation of o
foreach over the for-loop implementation on CUDA, since it is usually
significantly more performant. (default: None)"""

_fused_doc = r"""fused (bool, optional): whether the fused implementation (CUDA only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. (default: None)

.. note:: The foreach and fused implementations are typically faster than the for-loop,
single-tensor implementation. Thus, if the user has not specified BOTH flags
(i.e., when foreach = fused = None), we will attempt defaulting to the foreach
implementation when the tensors are all on CUDA. For example, if the user specifies
True for fused but nothing for foreach, we will run the fused implementation. If
the user specifies False for foreach but nothing for fused (or False for fused but
nothing for foreach), we will run the for-loop implementation. If the user specifies
True for both foreach and fused, we will prioritize fused over foreach, as it is
typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
we want to give it sufficient bake-in time, so we default to foreach and NOT
fused when the user has not specified either flag."""

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
capture in a CUDA graph. Passing True can impair ungraphed performance,
so if you don't intend to graph capture this instance, leave it False

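The precedence described in _fused_doc can be restated as a tiny decision table. This is an illustrative toy, not the code torch.optim runs internally; the flag handling is simplified down to the user-facing behavior the note documents:

def pick_impl(fused, foreach, all_on_cuda=True):
    # explicit requests win, with fused ranked above foreach
    if fused:
        return "fused"
    if foreach:
        return "foreach"
    # neither flag given: default to foreach (not fused) on CUDA, else for-loop
    if fused is None and foreach is None and all_on_cuda:
        return "foreach"
    return "for-loop"

assert pick_impl(True, True) == "fused"        # fused outranks foreach
assert pick_impl(None, None) == "foreach"      # unspecified: foreach, for bake-in time
assert pick_impl(False, None) == "for-loop"    # an explicit False falls back to for-loop
assert pick_impl(None, None, all_on_cuda=False) == "for-loop"
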
@ -209,8 +209,7 @@ def radam(
)

if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_avg_sqs, state_steps],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

@ -220,8 +220,7 @@ def rmsprop(
"""

if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, square_avgs, grad_avgs, momentum_buffer_list],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

@ -192,8 +192,7 @@ def rprop(
"""

if foreach is None:
_, foreach = _default_to_fused_or_foreach([params, grads, prevs, step_sizes],
differentiable, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

@ -207,8 +207,7 @@ def sgd(params: List[Tensor],
# why must we be explicit about an if statement for torch.jit.is_scripting here?
# because JIT can't handle Optionals nor fancy conditionals when scripting
if not torch.jit.is_scripting():
_, foreach = _default_to_fused_or_foreach([params, d_p_list, momentum_buffer_list],
differentiable=False, has_fused=False)
_, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
else:
foreach = False

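The comment explains why the defaulting call sits behind an explicit eager-mode branch: under torch.jit.script the helper cannot be compiled, so the scripted path simply pins foreach to False. A small sketch of the same pattern outside the library, with a hypothetical resolve_foreach function and an inline stand-in for the real helper:

import torch
from typing import List, Optional

def resolve_foreach(params: List[torch.Tensor], foreach: Optional[bool]) -> bool:
    # Keep the eager-only logic behind an is_scripting() check so the TorchScript
    # compiler never sees it; scripted callers fall through to a plain False.
    if foreach is None:
        if not torch.jit.is_scripting():
            foreach = all(p.is_cuda for p in params)  # stand-in for _default_to_fused_or_foreach
        else:
            foreach = False
    return bool(foreach)
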
@ -2130,7 +2130,8 @@ class TestCase(expecttest.TestCase):
errors_before = 0 if result is None else len(result.errors)
skipped_before = 0 if result is None else len(result.skipped)

if TEST_WITH_TORCHDYNAMO:
# TODO remove version check once dynamo supports 3.11
if TEST_WITH_TORCHDYNAMO and sys.version_info < (3, 11):
# TorchDynamo optimize annotation
if TEST_WITH_TORCHINDUCTOR:
super_run = torch._dynamo.optimize("inductor")(super().run)

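The added guard simply skips the TorchDynamo wrapping on Python 3.11, which dynamo did not support at this release. A toy restatement of the pattern; the flag names are stand-ins for the env-driven globals in common_utils.py, not the real test harness:

import sys
import torch
import torch._dynamo

TEST_WITH_TORCHDYNAMO = True    # stand-in flag
TEST_WITH_TORCHINDUCTOR = False

def wrap_for_dynamo(fn):
    # dynamo lacked 3.11 support at the 2.0 release, so return the plain callable there
    if TEST_WITH_TORCHDYNAMO and sys.version_info < (3, 11):
        backend = "inductor" if TEST_WITH_TORCHINDUCTOR else "eager"
        return torch._dynamo.optimize(backend)(fn)
    return fn

run = wrap_for_dynamo(lambda x: x + 1)
print(run(torch.ones(2)))
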
@ -380,7 +380,7 @@ def make_histogram(values, bins, max_bins=None):
limits = new_limits

# Find the first and the last bin defining the support of the histogram:
cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
cum_counts = np.cumsum(np.greater(counts, 0))
start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
start = int(start)
end = int(end) + 1

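For context, a small standalone check of the changed line (assuming NumPy is installed): np.greater already returns a boolean mask and np.cumsum of booleans yields integers, so the explicit dtype=np.int32 argument, dropped above presumably for compatibility with newer NumPy, is unnecessary here:

import numpy as np

counts = np.array([0, 2, 3, 0, 1])
cum_counts = np.cumsum(np.greater(counts, 0))   # boolean mask summed as integers
print(cum_counts)                               # [0 1 2 2 3]
start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
print(int(start), int(end) + 1)                 # first and last bins with nonzero counts
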
@ -1109,6 +1109,7 @@ SUPPORTED_RETURN_TYPES = {
"::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
"::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",

@ -638,6 +638,7 @@ class NativeFunction:
raw_dispatch = e.pop("dispatch", None)
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
dispatch: Dict[DispatchKey, BackendMetadata] = {}
num_dispatch_keys: int = 0
if raw_dispatch is not None:
assert not manual_kernel_registration, (
"cannot specify both manual_kernel_registration and dispatch; with "
@ -650,6 +651,8 @@ class NativeFunction:
assert isinstance(ks, str), e
for k in ks.split(","):
dispatch_key = DispatchKey.parse(k.strip())
num_dispatch_keys += 1

if ignore_keys and dispatch_key in ignore_keys:
continue
assert dispatch_key in dispatch_keys, (
@ -677,7 +680,12 @@ class NativeFunction:
):
redundant_composite_implicit_autograd = True

assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), (
# We count the number of dispatch keys which have not been ignored to prevent a dispatch table
# in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
# from being treated as redundant.
assert not (
num_dispatch_keys == 1 and redundant_composite_implicit_autograd
), (
"unnecessary dispatch table for this function; just delete the dispatch "
"key entirely"
)
@ -687,6 +695,7 @@ class NativeFunction:
structured_delegate
or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
or num_dispatch_keys != 1
), (
f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "

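A toy illustration of why the new counter matters; this is not torchgen itself, and the key names and redundancy test are simplified. A dispatch table whose only surviving entry is CompositeImplicitAutograd should be flagged as unnecessary only when it was written with a single key in the first place, not when every backend key was filtered out via ignore_keys:

def table_is_unnecessary(raw_keys, ignore_keys=frozenset()):
    num_dispatch_keys = 0
    kept = []
    for k in raw_keys:
        num_dispatch_keys += 1          # counted before the ignore filter, as in the diff
        if k in ignore_keys:
            continue
        kept.append(k)
    redundant_composite = kept == ["CompositeImplicitAutograd"]
    return num_dispatch_keys == 1 and redundant_composite

# a yaml entry that only repeats the default composite kernel: flagged
assert table_is_unnecessary(["CompositeImplicitAutograd"])
# all backend keys ignored, composite necessarily kept: no longer mis-flagged
assert not table_is_unnecessary(["CPU", "CUDA", "CompositeImplicitAutograd"],
                                ignore_keys={"CPU", "CUDA"})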