Compare commits

...

17 Commits

Author SHA1 Message Date
9509e8a3d6 Fix cosine similarity dim checks (#66214)
* fix cosine similarity dimensionality check

* fix shapes in the doc
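
As a rough illustration (not part of the commit), the relaxed check lets inputs with the same number of dimensions broadcast to a common size instead of requiring identical sizes at `dim`; a minimal sketch, assuming a build containing this fix:

```
import torch
import torch.nn.functional as F

x1 = torch.randn(10, 128)   # a batch of vectors
x2 = torch.randn(1, 128)    # a single vector, broadcast across the batch
sim = F.cosine_similarity(x1, x2, dim=1)
print(sim.shape)            # torch.Size([10])
```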
2021-10-08 07:22:40 -07:00
1774a6a2f4 [ONNX] Deprecate various args (#65962)
* [ONNX] Remove argument _retain_param_name from torch.onnx.export() function. (#61702) (#64370)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64370

As of now, the "_retain_param_name" parameter has no description on the PyTorch docs website. According to the code, this argument determines whether we keep the original parameter names of the PyTorch model in the final ONNX graph. If it is False, those original parameter names are replaced with a series of integers starting from 1.

Since numeric parameter names are meaningless to users, we remove this argument from the torch.onnx.export() function to improve the experience of calling it.

This PR still keeps the argument in the torch.onnx.export() function for backward compatibility, while all backend logic has been changed to behave as if _retain_param_name were set to True.
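
A minimal sketch (not from the PR) of the resulting behavior, assuming the `onnx` package is available; parameter names in the exported graph now match the PyTorch names by default:

```
import io
import torch
import onnx

model = torch.nn.Linear(4, 2)
f = io.BytesIO()
torch.onnx.export(model, torch.randn(1, 4), f)  # no _retain_param_name needed

graph = onnx.load_from_string(f.getvalue()).graph
print([init.name for init in graph.initializer])  # e.g. ['weight', 'bias']
```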

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30905270

Pulled By: malfet

fbshipit-source-id: ca60757ca17daaff937e9f08da42596086795f4a

Co-authored-by: fatcat-z <zhang-ji@outlook.com>

* [ONNX] Remove strip_doc_string param from torch.onnx.export() function. (#61712) (#64371)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64371

As of now, the "strip_doc_string" parameter is described as follows:

strip_doc_string (bool, default True): do not include the field
``doc_string`` from the exported model. Otherwise the field will mention the source code locations for ``model``.

This is usually of no use to users who simply want to convert a PyTorch model to an ONNX one; the source code locations only help when the user wants to debug the export process.

To make the export() function friendlier by taking fewer parameters, we fold "strip_doc_string" into the "verbose" parameter. When a user sets verbose to True, it means they want log information for debugging the export process, which is close to the purpose of the strip_doc_string parameter.

But these two arguments run in opposite directions: setting verbose to True means we want to print log information to help debugging, which implies strip_doc_string should be False. That is how strip_doc_string is replaced by the verbose argument in this PR.

This PR still keeps strip_doc_string in the torch.onnx.export() function for backward compatibility, while its behavior has been folded into the verbose argument.
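
A minimal before/after sketch (not from the PR) of the replacement, assuming a model that exports cleanly:

```
import io
import torch

model = torch.nn.ReLU()
x = torch.randn(3, 4)
f = io.BytesIO()

# Before: torch.onnx.export(model, x, f, strip_doc_string=False)
# After: verbose=True keeps the doc_string fields (source locations) and
# also prints a description of the exported graph for debugging.
torch.onnx.export(model, x, f, verbose=True)
```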

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30905268

Pulled By: malfet

fbshipit-source-id: 2f06eb805c01fe15ff7a1b4f6595c937ba716d60

Co-authored-by: fatcat-z <zhang-ji@outlook.com>

* [ONNX] minor doc improvements and cleanup (#62514) (#64373)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64373

* Fix some bad formatting and clarify things in onnx.rst.
* In `export_to_pretty_string`:
    * Add documentation for previously undocumented args.
    * Document that `f` arg is ignored and mark it deprecated.
    * Update tests to stop setting `f`.
    * Warn if `_retain_param_name` is set.
* Use double quotes for string literals in test_operators.py.

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30905271

Pulled By: malfet

fbshipit-source-id: 3627eeabf40b9516c4a83cfab424ce537b36e4b3

* [ONNX] Deprecated the example_outputs param from torch.onnx.export() function. (#62815) (#64380)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64380

* `example_outputs` was used to determine the type and shape of the outputs without tracing the execution of the model, and it had to be provided when exporting a ScriptModule or ScriptFunction through the export() function.

* Since we can work out `example_outputs` internally instead of asking the user to provide it, we deprecate this argument in the export() function to improve the experience of calling it.
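
A minimal sketch (not from the PR) of exporting a ScriptFunction without example_outputs; the exporter now derives the output types and shapes itself:

```
import io
import torch

@torch.jit.script
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

f = io.BytesIO()
x, y = torch.randn(3), torch.randn(3)
torch.onnx.export(add, (x, y), f)  # no example_outputs argument
```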

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30905266

Pulled By: malfet

fbshipit-source-id: d00b00d7d02b365d165028288ad915678caa51f2

Co-authored-by: hwangdeyu <dejack953@outlook.com>

* [ONNX] Deprecate use_external_data_format param from torch.onnx.export() function. (#62257) (#64382)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64382

* The `use_external_data_format` parameter is used for large models that cannot be exported because of the 2GB protobuf limit.

* When `use_external_data_format` is set to True, the model is exported in the ONNX external data format, in which case some of the model parameters are stored in external binary files rather than in the ONNX model file itself.

* This PR marks the parameter as DEPRECATED and checks the model proto size in code instead of relying on the user: if the size is larger than 2GB, `use_external_data_format = True` is applied automatically.
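
A hedged sketch (not from the PR): `MyLargeModel` and the input shape are hypothetical stand-ins for a model whose serialized proto would exceed 2GB; the call no longer needs the flag:

```
import torch

model = MyLargeModel()            # hypothetical model larger than 2GB when serialized
x = torch.randn(1, 3, 224, 224)   # hypothetical example input

# The exporter measures the proto size itself and falls back to the ONNX
# external data format automatically when the 2GB protobuf limit is exceeded.
torch.onnx.export(model, x, "large_model.onnx")
```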

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30905265

Pulled By: malfet

fbshipit-source-id: 82b4e17bfa6a8de2bfd700a5282c12f6835603cb

Co-authored-by: hwangdeyu <dejack953@outlook.com>

* fix clang-tidy error introduced by #64382 (#65977)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65977

Reviewed By: ngimel

Differential Revision: D31423174

Pulled By: malfet

fbshipit-source-id: 0ea560b9a6ddd6431f70bd3ac10ace68e26ab352

Co-authored-by: BowenBao <bowbao@microsoft.com>
Co-authored-by: fatcat-z <zhang-ji@outlook.com>
Co-authored-by: hwangdeyu <dejack953@outlook.com>
2021-10-08 07:21:29 -07:00
a27906c250 Convert Sampler back to lazy construction (#63646) (#65926)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63646

Fixes #63609

Test Plan: Imported from OSS

Reviewed By: NivekT

Differential Revision: D30451774

Pulled By: ejguan

fbshipit-source-id: 550d77494326446d1a42b5da0559e0d384c47413
2021-10-08 07:20:03 -07:00
49f52b6c07 Revert "Added option to update parameters using state_dict in AveragedModel (#65495) (#65755)" (#66308)
This reverts commit 5f1a434599b46afd99607839d15892e09269a1c4.
2021-10-08 07:17:47 -07:00
5f1a434599 Added option to update parameters using state_dict in AveragedModel (#65495) (#65755)
* Added option to update parameters using state_dict in AveragedModel (#65495)

Summary:
While implementing [EMA](https://github.com/pytorch/vision/pull/4381) (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used because it did not handle the state_dict(), so a custom update_parameters() had to be defined in the [EMA class](https://github.com/pytorch/vision/pull/4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation.

Discussion: https://github.com/pytorch/vision/pull/4406#pullrequestreview-753734102
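
For context (not from the PR), the torchvision EMA use case builds on AveragedModel with a custom avg_fn; a minimal sketch of that pattern, with the optimizer step omitted. The PR extends update_parameters() so that entries coming from the module's state_dict() (such as buffers) are also handled, removing the need for the custom override:

```
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(10, 10)

# Exponential moving average of the parameters (decay 0.9).
ema_avg = lambda avg_p, p, num_averaged: 0.9 * avg_p + 0.1 * p
ema_model = AveragedModel(model, avg_fn=ema_avg)

for _ in range(3):                    # training loop body omitted
    ema_model.update_parameters(model)
```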

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
(cherry picked from commit 2ea724b1fd543304e3be7bd223cac451cd093e16)

* Added validation of mode parameter in AveragedModel (#65921)

Summary:
Discussion: https://github.com/pytorch/pytorch/pull/65495#issuecomment-930460469

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
(cherry picked from commit c7748fc172553da66368fd0b7fea3fe5661e2dc1)
2021-10-06 11:13:31 -07:00
ecbf5a7439 Tweak file_diff_from_base for release/1.10 branch (#66202) 2021-10-06 08:34:46 -07:00
4e3ebebcff [DataPipe] DataPipe Fix and Deprecation Warnings for Release 1.10 (#65932)
* Unify the output pathname of archive reader and extractor (#65424)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65424

This PR is a re-implementation of https://github.com/facebookexternal/torchdata/pull/93
The same PR has landed in torchdata: https://github.com/facebookexternal/torchdata/pull/157

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D31090447

Pulled By: ejguan

fbshipit-source-id: 45af1ad9b24310bebfd6e010f41cff398946ba65

* [DataPipe] add deprecation warnings for DataPipes that will exist solely in TorchData (#65827)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65827

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D31272794

Pulled By: NivekT

fbshipit-source-id: 8da8266184b4df050422904cbc5fca6d7c3d2e02

* [DataPipe] Fixes an issue where TarArchiveReader closes stream when read into a buffer (#65877)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65877

Fixes #65808

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D31296041

Pulled By: NivekT

fbshipit-source-id: cdcad3a333ae9781d6063678a122a128955b0ff4

Co-authored-by: Erjia Guan <erjia@fb.com>
2021-10-05 20:54:40 -07:00
2b46c95e7c [iOS][CI] Update dev certs (#66004) (#66188)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/65988

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66004

Reviewed By: xta0

Differential Revision: D31340893

Pulled By: malfet

fbshipit-source-id: 3bf0be266e9686a73d62e86c5cf0bebeb0416260

Co-authored-by: Tao Xu <taox@fb.com>
2021-10-05 20:12:40 -07:00
5f3eee1ca5 Fix backward compatibility tests (#66186)
Compare operator list against RC1 build rather than against nightly
2021-10-05 20:12:13 -07:00
4731f33d02 Fix Windows ninja builds when MAX_JOBS is specified (#65444) (#66155)
Summary:
Reported by cloudhan in https://github.com/pytorch/pytorch/pull/64733#issuecomment-924545463

Fixes regression introduced by 047e68235f

cc malfet seemethere

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65444

Reviewed By: dagitses, seemethere

Differential Revision: D31103260

Pulled By: malfet

fbshipit-source-id: 9d5454a64cb8a0b96264119cf16582cc5afed284
2021-10-05 12:03:27 -07:00
ecfcb8ff5a Binary building without python fix (#66031) (#66117)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/66030

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66031

Reviewed By: VitalyFedyunin

Differential Revision: D31356243

Pulled By: malfet

fbshipit-source-id: d1537bc65bbba5d6497ecb8db7160a397eca81fd
2021-10-05 12:02:51 -07:00
6aadfda9e2 [ci] try installing libgnutls to fix cert error (#65934) (#65979)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65934

see https://github.com/pytorch/pytorch/issues/65931; this was a
suggested remediation on the linked issue

Test Plan: Imported from OSS

Reviewed By: malfet, zhouzhuojie

Differential Revision: D31313040

Pulled By: suo

fbshipit-source-id: a9e2b82a1e879962af768ed3049c73ab77394738

Co-authored-by: Michael Suo <suo@fb.com>
2021-09-30 18:55:44 -07:00
13666d20fd [DataPipe] Fix deepcopy filehandle for Mapper and in-place modification for IterableWrapper (#65220) (#65924)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65220

Fixes #65221

- Remove deepcopy from Mapper to support file handles
- Convert `IterableWrapper` to deep-copy the wrapped iterable within each iterator, preventing in-place modification from leaking across epochs (which otherwise yields different data per epoch); see the sketch below
- Convert `IDP` to `IterableWrapper` in test_datapipe.py
- Refine the variable names (avoid using `dp`, which is the module reference)
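
A minimal sketch of the intended IterableWrapper behavior (not from the PR), assuming the datapipe is importable from torch.utils.data.datapipes.iter in this release:

```
from torch.utils.data.datapipes.iter import IterableWrapper

source = [{"label": 0}, {"label": 1}]
pipe = IterableWrapper(source)   # the wrapped iterable is deep-copied per iterator

for item in pipe:
    item["label"] += 10          # in-place edits touch only this epoch's copy

print(source)      # the original list is unchanged
print(list(pipe))  # a fresh copy is produced for the next iteration
```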

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D31021886

Pulled By: ejguan

fbshipit-source-id: 72a9eee66c758e2717d591cd0942892bddedc223
2021-09-30 18:36:49 -07:00
1fa17a20fc Fix the slowdown of _object_to_tensor since 1.9 (#65721) (#65835)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65721

Closes: https://github.com/pytorch/pytorch/issues/65696

The bug was introduced in https://github.com/pytorch/pytorch/pull/55861 and causes a 100X slowdown since 1.9.
ghstack-source-id: 139128267

Test Plan:
Performance test:
```
import time

from torch.distributed.distributed_c10d import _object_to_tensor

start = time.time()
_object_to_tensor("x" * 50_000_000)
print("Time:", time.time() - start)
```

Reviewed By: rohan-varma

Differential Revision: D31219794

fbshipit-source-id: 1abec38f9d51361c1eab6ad5efd87b589322e208

Co-authored-by: Yi Wang <wayi@fb.com>
2021-09-29 14:38:54 -07:00
c05547fa6c Fix test reporting git merge-base (#65787) 2021-09-28 15:48:32 -07:00
0e857bf109 [1.10] Remove torch.vmap (#65496)
torch.vmap is a prototype feature and should not be in the stable
binary. This PR:
- Removes the torch.vmap API
- Removes the documentation entry for torch.vmap
- Changes the vmap tests to use an internal API instead of torch.vmap.
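
For reference (not from the PR), a hedged sketch of the direction the tests take; torch._vmap_internals is assumed to be the internal entry point and is not a public or stable API:

```
import torch
from torch._vmap_internals import vmap  # internal, unstable; module path assumed

x = torch.randn(3, 5)
out = vmap(torch.dot)(x, x)  # batched dot product of each row with itself
print(out.shape)             # torch.Size([3])
```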

Test Plan:
- Tested locally (test_torch, test_autograd, test_type_hints, test_vmap),
but also waiting for CI.
2021-09-24 10:29:08 -07:00
ad22804b95 [release/1.10] Pin builder and xla repo (#65433)
Pin builder to https://github.com/pytorch/builder/commits/release/1.10
Pin xla to https://github.com/pytorch/xla/tree/r1.10
2021-09-21 16:16:22 -07:00
51 changed files with 624 additions and 489 deletions

.circleci/config.yml generated
View File

@ -632,7 +632,8 @@ jobs:
}
if is_vanilla_build; then
echo "apt-get update && apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
echo "apt-get update || apt-get install libgnutls30" | docker exec -u root -i "$id" bash
echo "apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
echo "cd workspace/build; qemu-x86_64 -g 2345 -cpu Broadwell -E ATEN_CPU_CAPABILITY=default ./bin/basic --gtest_filter=BasicTest.BasicTestCPU & gdb ./bin/basic -ex 'set pagination off' -ex 'target remote :2345' -ex 'continue' -ex 'bt' -ex='set confirm off' -ex 'quit \$_isvoid(\$_exitcode)'" | docker exec -u jenkins -i "$id" bash
else
echo "Skipping for ${BUILD_ENVIRONMENT}"
@ -1734,16 +1735,17 @@ jobs:
# install fastlane
sudo gem install bundler && bundle install
# install certificates
echo ${IOS_CERT_KEY} >> cert.txt
echo ${IOS_CERT_KEY_2022} >> cert.txt
base64 --decode cert.txt -o Certificates.p12
rm cert.txt
bundle exec fastlane install_cert
bundle exec fastlane install_root_cert
bundle exec fastlane install_dev_cert
# install the provisioning profile
PROFILE=PyTorch_CI_2021.mobileprovision
PROFILE=PyTorch_CI_2022.mobileprovision
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
mkdir -pv "${PROVISIONING_PROFILES}"
cd "${PROVISIONING_PROFILES}"
echo ${IOS_SIGN_KEY} >> cert.txt
echo ${IOS_SIGN_KEY_2022} >> cert.txt
base64 --decode cert.txt -o ${PROFILE}
rm cert.txt
- run:
@ -1802,7 +1804,7 @@ jobs:
command: |
set -e
PROJ_ROOT=/Users/distiller/project
PROFILE=PyTorch_CI_2021
PROFILE=PyTorch_CI_2022
# run the ruby build script
if ! [ -x "$(command -v xcodebuild)" ]; then
echo 'Error: xcodebuild is not installed.'

View File

@ -61,7 +61,7 @@ git --no-pager log --max-count 1
popd
# Clone the Builder master repo
retry git clone -q https://github.com/pytorch/builder.git "$BUILDER_ROOT"
retry git clone -q https://github.com/pytorch/builder.git -b release/1.10 "$BUILDER_ROOT"
pushd "$BUILDER_ROOT"
echo "Using builder from "
git --no-pager log --max-count 1

View File

@ -8,16 +8,17 @@ cd ${PROJ_ROOT}/ios/TestApp
# install fastlane
sudo gem install bundler && bundle install
# install certificates
echo "${IOS_CERT_KEY}" >> cert.txt
echo "${IOS_CERT_KEY_2022}" >> cert.txt
base64 --decode cert.txt -o Certificates.p12
rm cert.txt
bundle exec fastlane install_cert
bundle exec fastlane install_root_cert
bundle exec fastlane install_dev_cert
# install the provisioning profile
PROFILE=PyTorch_CI_2021.mobileprovision
PROFILE=PyTorch_CI_2022.mobileprovision
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
mkdir -pv "${PROVISIONING_PROFILES}"
cd "${PROVISIONING_PROFILES}"
echo "${IOS_SIGN_KEY}" >> cert.txt
echo "${IOS_SIGN_KEY_2022}" >> cert.txt
base64 --decode cert.txt -o ${PROFILE}
rm cert.txt
# run the ruby build script
@ -25,5 +26,5 @@ if ! [ -x "$(command -v xcodebuild)" ]; then
echo 'Error: xcodebuild is not installed.'
exit 1
fi
PROFILE=PyTorch_CI_2021
PROFILE=PyTorch_CI_2022
ruby ${PROJ_ROOT}/scripts/xcode_build.rb -i ${PROJ_ROOT}/build_ios/install -x ${PROJ_ROOT}/ios/TestApp/TestApp.xcodeproj -p ${IOS_PLATFORM} -c ${PROFILE} -t ${IOS_DEV_TEAM_ID} -f Accelerate,MetalPerformanceShaders,CoreML

View File

@ -467,16 +467,17 @@
# install fastlane
sudo gem install bundler && bundle install
# install certificates
echo ${IOS_CERT_KEY} >> cert.txt
echo ${IOS_CERT_KEY_2022} >> cert.txt
base64 --decode cert.txt -o Certificates.p12
rm cert.txt
bundle exec fastlane install_cert
bundle exec fastlane install_root_cert
bundle exec fastlane install_dev_cert
# install the provisioning profile
PROFILE=PyTorch_CI_2021.mobileprovision
PROFILE=PyTorch_CI_2022.mobileprovision
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
mkdir -pv "${PROVISIONING_PROFILES}"
cd "${PROVISIONING_PROFILES}"
echo ${IOS_SIGN_KEY} >> cert.txt
echo ${IOS_SIGN_KEY_2022} >> cert.txt
base64 --decode cert.txt -o ${PROFILE}
rm cert.txt
- run:
@ -535,7 +536,7 @@
command: |
set -e
PROJ_ROOT=/Users/distiller/project
PROFILE=PyTorch_CI_2021
PROFILE=PyTorch_CI_2022
# run the ruby build script
if ! [ -x "$(command -v xcodebuild)" ]; then
echo 'Error: xcodebuild is not installed.'

View File

@ -158,7 +158,8 @@ jobs:
}
if is_vanilla_build; then
echo "apt-get update && apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
echo "apt-get update || apt-get install libgnutls30" | docker exec -u root -i "$id" bash
echo "apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
echo "cd workspace/build; qemu-x86_64 -g 2345 -cpu Broadwell -E ATEN_CPU_CAPABILITY=default ./bin/basic --gtest_filter=BasicTest.BasicTestCPU & gdb ./bin/basic -ex 'set pagination off' -ex 'target remote :2345' -ex 'continue' -ex 'bt' -ex='set confirm off' -ex 'quit \$_isvoid(\$_exitcode)'" | docker exec -u jenkins -i "$id" bash
else
echo "Skipping for ${BUILD_ENVIRONMENT}"

View File

@ -63,9 +63,9 @@ function get_pr_change_files() {
function file_diff_from_base() {
# The fetch may fail on Docker hosts, this fetch is necessary for GHA
set +e
git fetch origin master --quiet
git fetch origin release/1.10 --quiet
set -e
git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
git diff --name-only "$(git merge-base origin/release/1.10 HEAD)" > "$1"
}
function get_bazel() {
@ -99,5 +99,5 @@ function checkout_install_torchvision() {
}
function clone_pytorch_xla() {
git clone --recursive https://github.com/pytorch/xla.git
git clone --recursive -b r1.10 https://github.com/pytorch/xla.git
}

View File

@ -424,7 +424,7 @@ test_backward_compatibility() {
python -m venv venv
# shellcheck disable=SC1091
. venv/bin/activate
pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_nightly.html
pip show torch
python dump_all_function_schemas.py --filename nightly_schemas.txt
deactivate

View File

@ -240,14 +240,11 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
}
Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
TORCH_CHECK(x1.ndimension() == x2.ndimension(), "cosine_similarity requires both inputs to have the same number of dimensions, but x1 has ",
x1.ndimension(), " and x2 has ", x2.ndimension());
TORCH_CHECK(x1.ndimension() == 0 || x1.size(dim) == x2.size(dim), "cosine_similarity requires both inputs to have the same size at dimension ", dim, "but x1 has ",
x1.size(dim), " and x2 has ", x2.size(dim));
auto common_size = at::infer_size_dimvector(x1.sizes(), x2.sizes());
auto commonDtype = at::result_type(x1, x2);
TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype);
Tensor x1_ = x1.to(commonDtype);
Tensor x2_ = x2.to(commonDtype);
Tensor x1_ = x1.to(commonDtype).expand(common_size);
Tensor x2_ = x2.to(commonDtype).expand(common_size);
// Follow scipy impl to improve numerical precision
// Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
Tensor w12 = at::sum(x1_ * x2_, dim);

View File

@ -307,11 +307,11 @@ If the operator is an ATen operator (shows up in the TorchScript graph with the
* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
`torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
Make sure the function has the same name as the ATen function, which may declared in
Make sure the function has the same name as the ATen function, which may be declared in
``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
build time, so will not appear in your checkout until you build PyTorch).
* The first arg is always the ONNX graph that is being built for export.
Other arg names must EXACTLY match the names in ``_VariableFunctions.pyi``,
Other arg names must EXACTLY match the names in the ``.pyi`` file,
because dispatch is done with keyword arguments.
* In the symbolic function, if the operator is in the
`ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
@ -365,8 +365,8 @@ See the ``symbolic_opset*.py`` files for more examples.
torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^
If the operator is defined in a sub-class of :class:`torch.autograd.Function`,
there are two ways to export it.
If the operator is a sub-class of :class:`torch.autograd.Function`, there are two ways
to export it.
Static Symbolic Method
~~~~~~~~~~~~~~~~~~~~~~
@ -388,11 +388,11 @@ PythonOp Symbolic
~~~~~~~~~~~~~~~~~
Alternatively, you can register a custom symbolic function.
This gives the symoblic function access to more info through the
This gives the symbolic function access to more info through the
TorchScript ``Node`` object for the original operation, which gets passed in as the second
argument (after the ``Graph`` object).
All autograd ``Function``s are emitted in the TorchScript graph as ``prim::PythonOp`` nodes.
All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
In order to differentiate between different ``Function`` subclasses, the
symbolic function should use the ``name`` kwarg which gets set to the name of the class.
@ -400,11 +400,12 @@ symbolic function should use the ``name`` kwarg which gets set to the name of th
the ``prim`` namespace, so for this use case, there's a back door: register the
symbolic for ``"::prim_PythonOp"``.
Please also consider adding shape inference logic when you regiester a custom symbolic function
via setType API. This can help the exporter to obtain correct shape inference.
An example of setType is test_aten_embedding_2 in test_operators.py.
Although it is not required to add shape inference logic,
the exporter emits a warning message if it is not added.
Custom symbolic functions should add type and shape information by calling ``setType(...)``
on Value objects before returning them (implemented in C++ by
``torch::jit::Value::setType``). This is not required, but it can help the exporter's
shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
``test_aten_embedding_2`` in
`test_operators.py <https://github.com/pytorch/pytorch/blob/master/test/onnx/test_operators.py>`_.
The example below shows how you can access ``requires_grad`` via the ``Node`` object::
@ -430,13 +431,17 @@ The example below shows how you can access ``requires_grad`` via the ``Node`` ob
print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))
name = kwargs["name"]
ret = None
if name == "MyClip":
return g.op("Clip", args[0], min_f=args[1])
ret = g.op("Clip", args[0], min_f=args[1])
elif name == "MyRelu":
return g.op("Relu", args[0])
ret = g.op("Relu", args[0])
else:
# Logs a warning and returns None
return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
# Copy type and shape from original node.
ret.setType(n.type())
return ret
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 1)

View File

@ -597,5 +597,4 @@ Utilities
are_deterministic_algorithms_enabled
set_warn_always
is_warn_always_enabled
vmap
_assert

Binary file not shown.

View File

@ -4,7 +4,14 @@ platform :ios do
before_all do
setup_circle_ci
end
lane :install_cert do
lane :install_root_cert do
import_certificate(
certificate_path: "AppleWWDRCAG3.cer",
keychain_path: "/Users/distiller/Library/Keychains/fastlane_tmp_keychain-db",
keychain_password: ""
)
end
lane :install_dev_cert do
puts "Installing Certificates.p12"
import_certificate(
keychain_name: ENV["MATCH_KEYCHAIN_NAME"],

View File

@ -64,7 +64,7 @@ class TestExportModes(JitTestCase):
return (a, a)
f = io.BytesIO()
x = torch.ones(3)
torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
torch.onnx._export(foo, (x,), f)
@skipIfNoLapack
def test_aten_fallback(self):
@ -76,9 +76,8 @@ class TestExportModes(JitTestCase):
x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export_to_pretty_string(
ModelWithAtenNotONNXOp(), (x, y), f,
ModelWithAtenNotONNXOp(), (x, y), None,
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
@ -91,11 +90,10 @@ class TestExportModes(JitTestCase):
def forward(self, x, y):
return torch.fmod(x, y)
f = io.BytesIO()
x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
torch.onnx.export_to_pretty_string(
ModelWithAtenFmod(), (x, y), f,
ModelWithAtenFmod(), (x, y), None,
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN)

View File

@ -49,20 +49,17 @@ class TestONNXExport(JitTestCase):
tm = TraceMe()
tm = torch.jit.trace(tm, torch.rand(3, 4))
example_outputs = (tm(torch.rand(3, 4)),)
f = io.BytesIO()
torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
torch.onnx._export(tm, (torch.rand(3, 4),), f)
def test_export_tensoroption_to(self):
def foo(x):
return x[0].clone().detach().cpu() + x
traced = torch.jit.trace(foo, (torch.rand([2])))
example_outputs = traced(torch.rand([2]))
f = io.BytesIO()
torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
example_outputs=example_outputs)
torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f)
def test_onnx_export_script_module(self):
class ModuleToExport(torch.jit.ScriptModule):
@ -75,10 +72,8 @@ class TestONNXExport(JitTestCase):
return x + x
mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
@suppress_warnings
def test_onnx_export_func_with_warnings(self):
@ -93,11 +88,9 @@ class TestONNXExport(JitTestCase):
def forward(self, x):
return func_with_warning(x)
outputs = WarningTest()(torch.randn(42))
# no exception
torch.onnx.export_to_pretty_string(
WarningTest(), torch.randn(42), None, verbose=False,
example_outputs=outputs)
WarningTest(), torch.randn(42), None, verbose=False)
def test_onnx_export_script_python_fail(self):
class PythonModule(torch.jit.ScriptModule):
@ -119,11 +112,9 @@ class TestONNXExport(JitTestCase):
return y + y
mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
f = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
example_outputs=outputs)
torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
def test_onnx_export_script_inline_trace(self):
class ModuleToInline(torch.nn.Module):
@ -144,10 +135,8 @@ class TestONNXExport(JitTestCase):
return y + y
mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
def test_onnx_export_script_inline_script(self):
class ModuleToInline(torch.jit.ScriptModule):
@ -169,10 +158,8 @@ class TestONNXExport(JitTestCase):
return y + y
mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
def test_onnx_export_script_module_loop(self):
class ModuleToExport(torch.jit.ScriptModule):
@ -189,10 +176,8 @@ class TestONNXExport(JitTestCase):
return x
mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
@suppress_warnings
def test_onnx_export_script_truediv(self):
@ -206,11 +191,9 @@ class TestONNXExport(JitTestCase):
return x + z
mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False)
def test_onnx_export_script_non_alpha_add_sub(self):
class ModuleToExport(torch.jit.ScriptModule):
@ -223,10 +206,8 @@ class TestONNXExport(JitTestCase):
return bs - 1
mte = ModuleToExport()
outputs = torch.LongTensor([mte(torch.rand(3, 4))])
torch.onnx.export_to_pretty_string(
mte, (torch.rand(3, 4),), None, verbose=False,
example_outputs=outputs)
mte, (torch.rand(3, 4),), None, verbose=False)
def test_onnx_export_script_module_if(self):
class ModuleToExport(torch.jit.ScriptModule):
@ -240,10 +221,8 @@ class TestONNXExport(JitTestCase):
return x
mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
def test_onnx_export_script_inline_params(self):
class ModuleToInline(torch.jit.ScriptModule):
@ -272,8 +251,7 @@ class TestONNXExport(JitTestCase):
reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
self.assertEqual(result, reference)
torch.onnx.export_to_pretty_string(
mte, (torch.ones(2, 3),), None, verbose=False,
example_outputs=result)
mte, (torch.ones(2, 3),), None, verbose=False)
def test_onnx_export_speculate(self):
@ -305,18 +283,16 @@ class TestONNXExport(JitTestCase):
return x.t()
f1 = Foo(transpose)
outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
f2 = Foo(linear)
outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
torch.onnx.export_to_pretty_string(
f1,
(torch.ones(1, 10, dtype=torch.float), ),
None, verbose=False, example_outputs=outputs_f1)
None, verbose=False)
torch.onnx.export_to_pretty_string(
f2,
(torch.ones(1, 10, dtype=torch.float), ),
None, verbose=False, example_outputs=outputs_f2)
None, verbose=False)
def test_onnx_export_shape_reshape(self):
class Foo(torch.nn.Module):
@ -328,10 +304,8 @@ class TestONNXExport(JitTestCase):
return reshaped
foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
outputs = foo(torch.zeros(1, 2, 3))
f = io.BytesIO()
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
example_outputs=outputs)
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f)
def test_listconstruct_erasure(self):
class FooMod(torch.nn.Module):
@ -360,11 +334,10 @@ class TestONNXExport(JitTestCase):
mod = DynamicSliceExportMod()
input = torch.rand(3, 4, 5)
example_outs = mod(input)
f = io.BytesIO()
torch.onnx.export_to_pretty_string(
DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10)
DynamicSliceExportMod(), (input,), f, opset_version=10)
def test_export_dict(self):
class DictModule(torch.nn.Module):
@ -380,4 +353,4 @@ class TestONNXExport(JitTestCase):
with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
torch.onnx.export_to_pretty_string(
torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),))
torch.jit.script(mod), (x_in,), f)

View File

@ -1115,9 +1115,8 @@ class TestTracer(JitTestCase):
def forward(self, x, w):
return torch.matmul(x, w).detach()
f = io.BytesIO()
torch.onnx.export_to_pretty_string(
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), None)
def test_trace_slice_full_dim(self):
def foo(x):

View File

@ -19,7 +19,7 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
outputs = model(inputs)
script_model = torch.jit.script(model)
run_model_test(self, script_model, False, example_outputs=outputs,
run_model_test(self, script_model, False,
input=inputs, rtol=rtol, atol=atol)

View File

@ -40,14 +40,13 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi
assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None,
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL,
input_names=None, dynamic_axes=None):
for opset_version in opset_versions:
f = io.BytesIO()
torch.onnx.export(module, x, f,
opset_version=opset_version,
training=training,
example_outputs=example_outputs,
input_names=input_names,
dynamic_axes=dynamic_axes)
model = onnx.load(io.BytesIO(f.getvalue()))
@ -91,10 +90,8 @@ class TestONNXOpset(TestCase):
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
module = MyModuleDynamic()
example_output = module(x, k)
check_onnx_opsets_operator(module, [x, k], ops,
opset_versions=[10],
example_outputs=example_output)
opset_versions=[10])
def test_maxpool(self):
module = torch.nn.MaxPool1d(2, stride=1)
@ -191,7 +188,6 @@ class TestONNXOpset(TestCase):
module = DynamicSliceModel()
x = torch.rand(1, 2)
example_output = module(x)
ops_10 = [{"op_name" : "Shape"},
{"op_name" : "Constant"},
{"op_name" : "Gather",
@ -202,7 +198,7 @@ class TestONNXOpset(TestCase):
{"op_name" : "Slice",
"attributes" : []}]
ops = {10 : ops_10}
check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output,
check_onnx_opsets_operator(module, x, ops, opset_versions=[10],
input_names=['x'], dynamic_axes={"x": [0, 1]})
ops_10 = [{"op_name" : "Constant"},
@ -212,7 +208,7 @@ class TestONNXOpset(TestCase):
{"op_name" : "Slice",
"attributes" : []}]
ops = {10 : ops_10}
check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output)
check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
def test_flip(self):
class MyModule(Module):

View File

@ -279,12 +279,11 @@ class TestOperators(TestCase):
def test_conv_variable_length(self):
x = torch.ones(5, 3, 6, 6, requires_grad=True)
model = torch.nn.Conv2d(3, 2, 3)
y = model(x)
dynamic_axes = {"input_1": [0, 2, 3], "output_1": {0: "output_1_variable_dim_0", 1: "output_1_variable_dim_1"}}
model_proto_name = "conv2d.onnx"
torch.onnx.export(model, x, model_proto_name, verbose=True, input_names=["input_1"], output_names=["output_1"],
example_outputs=y, dynamic_axes=dynamic_axes)
dynamic_axes=dynamic_axes)
import onnx
onnx_model = onnx.load(model_proto_name)
@ -729,22 +728,6 @@ class TestOperators(TestCase):
x = torch.randn(2, 3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
def test_retain_param_name_disabled(self):
class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
self.fc1 = nn.Linear(4, 5, bias=False)
self.fc1.weight.data.fill_(2.)
self.fc2 = nn.Linear(5, 6, bias=False)
self.fc2.weight.data.fill_(3.)
def forward(self, x):
return self.fc2(self.fc1(x))
x = torch.randn(3, 4).float()
self.assertONNX(MyModule(), (x,), _retain_param_name=False,
keep_initializers_as_inputs=True)
def test_c2_op(self):
class MyModel(torch.nn.Module):
def __init__(self):

View File

@ -134,7 +134,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return cuda_model, cuda_input
def run_debug_test(self, model, train, batch_size, state_dict=None,
input=None, use_gpu=True, example_outputs=None,
input=None, use_gpu=True,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX):
"""
# TODO: remove this from the final release version
@ -153,7 +153,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model, input = self.convert_cuda(model, input)
onnxir, torch_out = do_export(model, input, export_params=self.embed_params, verbose=False,
example_outputs=example_outputs,
do_constant_folding=False,
opset_version=self.opset_version,
keep_initializers_as_inputs=True,
@ -168,7 +167,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
def run_actual_test(self, model, train, batch_size, state_dict=None,
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True,
do_constant_folding=True,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
input_names=None, dynamic_axes=None,
remained_onnx_input_idx=None):
@ -191,7 +190,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
# Verify the model runs the same in Caffe2
verify.verify(model, input, c2, rtol=rtol, atol=atol,
example_outputs=example_outputs,
do_constant_folding=do_constant_folding,
opset_version=self.opset_version,
keep_initializers_as_inputs=True,
@ -202,7 +200,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
def run_model_test(self, model, train, batch_size, state_dict=None,
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True,
do_constant_folding=True,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
input_names=None, dynamic_axes=None,
remained_onnx_input_idx=None):
@ -214,7 +212,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
if self.embed_params:
self.run_actual_test(model, train, batch_size, state_dict, input,
use_gpu=use_gpu_, rtol=rtol, atol=atol,
example_outputs=example_outputs,
do_constant_folding=do_constant_folding,
operator_export_type=operator_export_type,
input_names=input_names,
@ -222,8 +219,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
remained_onnx_input_idx=remained_onnx_input_idx)
else:
self.run_debug_test(model, train, batch_size, state_dict, input,
use_gpu=use_gpu_, example_outputs=example_outputs,
operator_export_type=operator_export_type)
use_gpu=use_gpu_, operator_export_type=operator_export_type)
def test_linear(self):
class MyModel(torch.nn.Module):
@ -289,7 +285,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
# This test checks that given this edge condition, the model can be loaded and executed
# in Caffe2 backend correctly.
torch.onnx._export(model, input, f, verbose=True, export_type=ExportTypes.ZIP_ARCHIVE,
input_names=["input1", "fc1.bias"], _retain_param_name=False,
input_names=["input1", "fc1.bias"],
keep_initializers_as_inputs=True)
f.seek(0)
@ -1401,9 +1397,8 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return x[1:x.size(0)]
module = DynamicSliceModel()
x = torch.rand(1, 2)
example_output = module(x)
self.run_model_test(DynamicSliceModel(), train=False, input=(x,),
batch_size=BATCH_SIZE, use_gpu=False, example_outputs=example_output)
batch_size=BATCH_SIZE, use_gpu=False)
@skipIfUnsupportedMinOpsetVersion(11)
def test_dynamic_slice_to_the_end(self):
@ -1472,7 +1467,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
y = torch.ones(2, 3, 4) * 2
self.run_model_test(Arithmetic(),
train=False, input=(), batch_size=BATCH_SIZE,
use_gpu=False, example_outputs=(x + 3, y * (x + 3)))
use_gpu=False)
def test_tensor_factories(self):
class TensorFactory(torch.nn.Module):
@ -1493,11 +1488,9 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(2, 3, 4)
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
use_gpu=False, example_outputs=(torch.ones(x.size()),),
input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
use_gpu=False, input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
use_gpu=False, example_outputs=(torch.ones(x.size()),),
remained_onnx_input_idx=[])
use_gpu=False, remained_onnx_input_idx=[])
def test_tensor_like_factories_script(self):
class TensorFactory(torch.jit.ScriptModule):
@ -1509,12 +1502,10 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(2, 3, 4)
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
use_gpu=False, example_outputs=(torch.ones(x.size()),),
input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
use_gpu=False, input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
remained_onnx_input_idx = None if self.opset_version < 9 else []
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
use_gpu=False, example_outputs=(torch.ones(x.size()),),
remained_onnx_input_idx=remained_onnx_input_idx)
use_gpu=False, remained_onnx_input_idx=remained_onnx_input_idx)
def test_full(self):
class FullModel(torch.nn.Module):
@ -1532,8 +1523,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return torch.full((4, 5), x, dtype=torch.long)
x = torch.tensor(12)
self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE,
use_gpu=False, example_outputs=FullClass()(x))
self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
def test_clamp(self):
class ClampModel(torch.nn.Module):
@ -2021,8 +2011,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return a
x = (torch.randn(3, 4), torch.randn(4, 3))
self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
example_outputs=(x,))
self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
def test_nested_tuple_input_output(self):
class NestedTupleModel(torch.jit.ScriptModule):
@ -2032,8 +2021,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(4, 5)
y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
example_outputs=x + y[0] + y[1][0] + y[1][1])
self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
def test_topk(self):
class TopKModel(torch.nn.Module):
@ -2050,7 +2038,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return torch.topk(input, 3, dim=0)
x = torch.randn(4, 3, requires_grad=True)
self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))
self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
def test_floor(self):
class FloorModel(torch.nn.Module):
@ -2086,9 +2074,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
class ArangeModel(torch.nn.Module):
def forward(self, a):
@ -2104,9 +2090,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
class ArangeModel(torch.nn.Module):
def forward(self, a):
@ -2122,9 +2106,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
x = torch.randn(3, 4, requires_grad=True)
outputs = ArangeScript()(x)
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
class ArangeModel(torch.nn.Module):
def forward(self, a):
@ -2249,9 +2231,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = WhileModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
def test_while_cond(self):
class WhileModel(torch.jit.ScriptModule):
@ -2266,9 +2246,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = WhileModel()
x = torch.zeros(1, 2, 3, dtype=torch.long)
a = torch.tensor([0], dtype=torch.long)
outputs = model(x, a)
self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE)
@unittest.skip("Disabled due to onnx optimizer deprecation")
def test_loop(self):
@ -2281,9 +2259,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = LoopModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
@unittest.skip("Disabled due to onnx optimizer deprecation")
def test_dynamic_loop(self):
@ -2296,9 +2272,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = LoopModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
@unittest.skip("Disabled due to onnx optimizer deprecation")
@skipIfUnsupportedMinOpsetVersion(9)
@ -2317,9 +2291,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = NestedLoopsModel()
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
def test_select(self):
class SelectModel(torch.nn.Module):
@ -2337,9 +2309,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = StandardDeviation()
inputs = torch.randn(2, 3, 4)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
def test_std_along_dims(self):
class StandardDeviationAlongDims(torch.nn.Module):
@ -2348,9 +2318,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = StandardDeviationAlongDims()
inputs = torch.randn(2, 3, 4)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(9)
def test_masked_fill(self):
@ -2379,9 +2347,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
y = torch.zeros(4, requires_grad=True)
z = torch.ones(5, requires_grad=True)
model = MeshgridModel()
outputs = model(x, y, z)
self.run_model_test(model, train=False, input=(x, y, z), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(x, y, z), batch_size=BATCH_SIZE)
def test_remainder(self):
class RemainderModel(torch.nn.Module):
@ -2391,9 +2357,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(4, 2, 3)
y = torch.randn(1, 2, 1)
model = RemainderModel()
outputs = model(x, y)
self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE)
def test_remainder_scalar(self):
class RemainderModel(torch.nn.Module):
@ -2402,9 +2366,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
inputs = torch.randint(10, (2, 3))
model = RemainderModel()
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
def test_baddbmm(self):
class MyModule(torch.nn.Module):
@ -2423,9 +2385,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
model = GeluModel()
inputs = torch.randn(2, 4, 5, 6, requires_grad=True)
outputs = model(inputs)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
example_outputs=(outputs,))
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(9)
def test_index_fill(self):

View File

@ -28,7 +28,7 @@ class TestQuantizedOps(unittest.TestCase):
output = q_model(*pt_inputs)
f = io.BytesIO()
torch.onnx.export(q_model, pt_inputs, f, input_names=input_names, example_outputs=output,
torch.onnx.export(q_model, pt_inputs, f, input_names=input_names,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
f.seek(0)
onnx_model = onnx.load(f)
@ -84,8 +84,6 @@ class TestQuantizedOps(unittest.TestCase):
self.generic_unary_test(torch.nn.ReLU())
def export_to_onnx(self, model, input, input_names):
outputs = model(input)
traced = torch.jit.trace(model, input)
buf = io.BytesIO()
torch.jit.save(traced, buf)
@ -93,7 +91,7 @@ class TestQuantizedOps(unittest.TestCase):
model = torch.jit.load(buf)
f = io.BytesIO()
torch.onnx.export(model, input, f, input_names=input_names, example_outputs=outputs,
torch.onnx.export(model, input, f, input_names=input_names,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
f.seek(0)

View File

@ -45,9 +45,9 @@ def to_numpy(tensor):
else:
return tensor.cpu().numpy()
def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
do_constant_folding=True, keep_initializers_as_inputs=True,
dynamic_axes=None, input_names=None, output_names=None,
def convert_to_onnx(model, input=None, opset_version=9, do_constant_folding=True,
keep_initializers_as_inputs=True, dynamic_axes=None,
input_names=None, output_names=None,
fixed_batch_size=False, training=None,
onnx_shape_inference=True):
# export the model to ONNX
@ -55,7 +55,6 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
input_copy = copy.deepcopy(input)
torch.onnx._export(model, input_copy, f,
opset_version=opset_version,
example_outputs=example_outputs,
do_constant_folding=do_constant_folding,
keep_initializers_as_inputs=keep_initializers_as_inputs,
dynamic_axes=dynamic_axes,
@ -132,7 +131,7 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
input = input + ({},)
ort_sess = convert_to_onnx(model, input=input, opset_version=self.opset_version,
example_outputs=output, do_constant_folding=do_constant_folding,
do_constant_folding=do_constant_folding,
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
dynamic_axes=dynamic_axes, input_names=input_names,
output_names=output_names, fixed_batch_size=fixed_batch_size, training=training,
@ -284,9 +283,9 @@ class TestONNXRuntime(unittest.TestCase):
_run_test(model, tracing_remained_onnx_input_idx)
def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
example_outputs=None, do_constant_folding=True,
dynamic_axes=None, input_names=None, output_names=None,
ort_optim_on=True, training=None):
do_constant_folding=True, dynamic_axes=None,
input_names=None, output_names=None,
ort_optim_on=True, training=None, use_external_data_format=None):
import os
import tempfile
@ -310,13 +309,12 @@ class TestONNXRuntime(unittest.TestCase):
input_copy = copy.deepcopy(input)
torch.onnx.export(model, input_copy, model_file_name,
opset_version=self.opset_version,
example_outputs=output,
verbose=False,
do_constant_folding=do_constant_folding,
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
dynamic_axes=dynamic_axes,
input_names=input_names, output_names=output_names,
use_external_data_format=True)
use_external_data_format=use_external_data_format)
# compute onnxruntime output prediction
ort_sess_opt = onnxruntime.SessionOptions()
ort_sess_opt.graph_optimization_level = \
@ -366,19 +364,55 @@ class TestONNXRuntime(unittest.TestCase):
return x + torch.ones(2, 1024)
x = torch.randn(2, 1)
self.run_model_test_with_external_data(LargeModel(), x)
self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=None)
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
@unittest.skip("Enable this once large model with subgraph is supported in ORT")
def test_subgraph_with_external_data(self):
def test_largemodel_without_use_external_data_format_param(self):
class LargeModel(torch.nn.Module):
def forward(self, x):
for i in range(x.size(0)):
x = x + torch.ones(2, 1024)
return x
def __init__(self):
super(LargeModel, self).__init__()
dim = 5
n = 40 * 4 * 10 ** 6
self.emb = torch.nn.Embedding(n, dim)
self.lin1 = torch.nn.Linear(dim, 1)
self.seq = torch.nn.Sequential(
self.emb,
self.lin1,
)
x = torch.randn(2, 1)
self.run_model_test_with_external_data(torch.jit.script(LargeModel()), x)
def forward(self, input):
return self.seq(input)
model = LargeModel()
x = torch.tensor([2], dtype=torch.long)
self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=None)
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
def test_largemodel_with_use_external_data_format_False(self):
class LargeModel(torch.nn.Module):
def __init__(self):
super(LargeModel, self).__init__()
dim = 5
n = 30 * 4 * 10 ** 6
self.emb = torch.nn.Embedding(n, dim)
self.lin1 = torch.nn.Linear(dim, 1)
self.seq = torch.nn.Sequential(
self.emb,
self.lin1,
)
def forward(self, input):
return self.seq(input)
model = LargeModel()
x = torch.tensor([3], dtype=torch.long)
with self.assertRaises(RuntimeError) as cm:
self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=False)
the_exception = cm.exception
self.assertEqual("RuntimeError: Exporting model exceed maximum protobuf size of 2GB. " +
"Please call torch.onnx.export without setting use_external_data_format parameter.")
def test_fuse_conv_bn1d(self):
class Fuse(torch.nn.Module):
@ -7784,7 +7818,6 @@ class TestONNXRuntime(unittest.TestCase):
script_model = torch.jit.script(model)
output = model(x)
ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
example_outputs=output,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs = run_ort(ort_sess, input=(x,))
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
@ -7828,7 +7861,6 @@ class TestONNXRuntime(unittest.TestCase):
y = model(input)
output = y.cpu().numpy()
ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
example_outputs=y,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs = run_ort(ort_sess, input=(x,))
ort_mask = np.where(ort_outs[0] != 0, 1, 0)
@ -7913,11 +7945,9 @@ class TestONNXRuntime(unittest.TestCase):
model = torch.jit.script(MyModule())
box_regression = torch.randn([4, 4])
proposal = [torch.randn(2, 4), torch.randn(2, 4)]
outputs = model(box_regression, proposal)
with self.assertRaises(RuntimeError) as cm:
convert_to_onnx(model, input=(box_regression, proposal),
example_outputs=outputs)
convert_to_onnx(model, input=(box_regression, proposal))
def test_initializer_sequence(self):
class MyModule(torch.nn.Module):
@ -7939,7 +7969,7 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.randn(32, 3)
f = io.BytesIO()
torch.onnx._export(test_model, (x,), f, _retain_param_name=True, do_constant_folding=False)
torch.onnx._export(test_model, (x,), f, do_constant_folding=False)
loaded_model = onnx.load_from_string(f.getvalue())
actual_list = [p.name for p in loaded_model.graph.initializer]
@ -7986,10 +8016,9 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.ones(2, 3, dtype=torch.float)
y = torch.tensor(5, dtype=torch.long)
example_output = (test_model(x, y),)
f = io.BytesIO()
torch.onnx.export(test_model, (x, y), f, example_outputs=example_output, _retain_param_name=True, do_constant_folding=False)
torch.onnx.export(test_model, (x, y), f, do_constant_folding=False)
loaded_model = onnx.load_from_string(f.getvalue())
actual_list = [p.name for p in loaded_model.graph.initializer]

View File

@ -32,7 +32,6 @@ class TestUtilityFuns(TestCase):
def _model_to_graph(self, model, input,
do_constant_folding=True,
example_outputs=None,
training=TrainingMode.EVAL,
operator_export_type=OperatorExportTypes.ONNX,
input_names=None,
@ -49,7 +48,6 @@ class TestUtilityFuns(TestCase):
_disable_torch_constant_prop=True,
operator_export_type=operator_export_type,
training=training,
example_outputs=example_outputs,
input_names=input_names,
dynamic_axes=dynamic_axes)
_set_onnx_shape_inference(True)
@ -116,11 +114,11 @@ class TestUtilityFuns(TestCase):
example_output = model(input_t, n)
with self.assertRaises(RuntimeError):
torch.onnx.export(model,
(input_t, n),
"test.onnx",
opset_version=self.opset_version,
example_outputs=[example_output])
torch.onnx._export(model,
(input_t, n),
"test.onnx",
opset_version=self.opset_version,
example_outputs=[example_output])
def test_constant_fold_transpose(self):
class TransposeModule(torch.nn.Module):
@ -558,27 +556,27 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Shape"
assert len(list(graph.nodes())) == 1
def test_strip_doc_string(self):
def test_verbose(self):
class MyModule(torch.nn.Module):
def forward(self, input):
return torch.exp(input)
x = torch.randn(3, 4)
def is_model_stripped(f, strip_doc_string=None):
if strip_doc_string is None:
def is_model_stripped(f, verbose=None):
if verbose is None:
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
else:
torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string,
torch.onnx.export(MyModule(), x, f, verbose=verbose,
opset_version=self.opset_version)
model = onnx.load(io.BytesIO(f.getvalue()))
model_strip = copy.copy(model)
onnx.helper.strip_doc_string(model_strip)
return model == model_strip
# test strip_doc_string=True (default)
# test verbose=False (default)
self.assertTrue(is_model_stripped(io.BytesIO()))
# test strip_doc_string=False
self.assertFalse(is_model_stripped(io.BytesIO(), False))
# test verbose=True
self.assertFalse(is_model_stripped(io.BytesIO(), True))
# NB: remove this test once DataParallel can be correctly handled
def test_error_on_data_parallel(self):
@ -720,9 +718,8 @@ class TestUtilityFuns(TestCase):
q_model = torch.quantization.convert(q_model, inplace=False)
q_model.eval()
output = q_model(*pt_inputs)
graph, _, __ = self._model_to_graph(q_model, pt_inputs, example_outputs=output,
graph, _, __ = self._model_to_graph(q_model, pt_inputs,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=['pt_inputs'],
dynamic_axes={'pt_inputs': [0, 1, 2, 3]})
@ -749,9 +746,8 @@ class TestUtilityFuns(TestCase):
x = torch.tensor([2])
model = PrimModule()
output = model(x)
model.eval()
graph, _, __ = self._model_to_graph(model, (x,), example_outputs=output,
graph, _, __ = self._model_to_graph(model, (x,),
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=['x'], dynamic_axes={'x': [0]})
iter = graph.nodes()
@ -814,10 +810,9 @@ class TestUtilityFuns(TestCase):
model = torch.jit.script(MyModule())
x = torch.randn(10, 3, 128, 128)
example_outputs = model(x)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True,
operator_export_type=OperatorExportTypes.ONNX,
training=torch.onnx.TrainingMode.TRAINING,
input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]})

View File

@ -1233,8 +1233,6 @@ class TestQuantizeONNXExport(JitTestCase):
input_names = ["x"]
def export_to_onnx(model, input, input_names):
outputs = model(input)
traced = torch.jit.trace(model, input)
buf = io.BytesIO()
torch.jit.save(traced, buf)
@ -1242,7 +1240,7 @@ class TestQuantizeONNXExport(JitTestCase):
model = torch.jit.load(buf)
f = io.BytesIO()
torch.onnx.export(model, input, f, input_names=input_names, example_outputs=outputs,
torch.onnx.export(model, input, f, input_names=input_names,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
onnx_model = export_to_onnx(model, data, input_names)


@ -1524,6 +1524,28 @@ except RuntimeError as e:
):
self.assertEqual(list(fn()), list(fn()))
for sampler in (
RandomSampler(self.dataset, num_samples=5, replacement=True),
RandomSampler(self.dataset, replacement=False),
WeightedRandomSampler(weights, num_samples=5, replacement=True),
WeightedRandomSampler(weights, num_samples=5, replacement=False),
SubsetRandomSampler(range(10)),
):
torch.manual_seed(0)
l1 = list(sampler) + list(sampler)
torch.manual_seed(0)
l2 = list(sampler) + list(sampler)
self.assertEqual(l1, l2)
its = (iter(sampler), iter(sampler))
ls = ([], [])
for idx in range(len(sampler)):
for i in range(2):
if idx == 0:
torch.manual_seed(0)
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])
def _test_sampler(self, **kwargs):
indices = range(2, 12) # using a regular iterable


@ -1,3 +1,4 @@
import copy
import http.server
import itertools
import os
@ -414,13 +415,30 @@ class TestIterableDataPipeBasic(TestCase):
# Test Case: Uneven DataPipes
source_numbers = list(range(0, 10)) + [10, 12]
numbers_dp = IDP(source_numbers)
numbers_dp = dp.iter.IterableWrapper(source_numbers)
n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
self.assertEqual([1, 3, 5, 7, 9], list(n2))
n = n1.mux(n2)
self.assertEqual(source_numbers, list(n))
@suppress_warnings # Suppress warning for lambda fn
def test_map_with_col_file_handle_datapipe(self):
temp_dir = self.temp_dir.name
datapipe1 = dp.iter.FileLister(temp_dir, '')
datapipe2 = dp.iter.FileLoader(datapipe1)
def _helper(datapipe):
dp1 = datapipe.map(lambda x: x.read(), input_col=1)
dp2 = datapipe.map(lambda x: (x[0], x[1].read()))
self.assertEqual(list(dp1), list(dp2))
# tuple
_helper(datapipe2)
# list
datapipe3 = datapipe2.map(lambda x: list(x))
_helper(datapipe3)
class TestDataFramesPipes(TestCase):
"""
@ -619,25 +637,13 @@ class IDP_NoLen(IterDataPipe):
super().__init__()
self.input_dp = input_dp
# Prevent in-place modification
def __iter__(self):
for i in self.input_dp:
input_dp = self.input_dp if isinstance(self.input_dp, IterDataPipe) else copy.deepcopy(self.input_dp)
for i in input_dp:
yield i
class IDP(IterDataPipe):
def __init__(self, input_dp):
super().__init__()
self.input_dp = input_dp
self.length = len(input_dp)
def __iter__(self):
for i in self.input_dp:
yield i
def __len__(self):
return self.length
class MDP(MapDataPipe):
def __init__(self, input_dp):
super().__init__()
@ -669,19 +675,19 @@ class TestFunctionalIterDataPipe(TestCase):
def _test_picklable(self):
arr = range(10)
picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
(dp.iter.Mapper, IDP(arr), (), {}),
(dp.iter.Mapper, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
(dp.iter.Collator, IDP(arr), (), {}),
(dp.iter.Collator, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
(dp.iter.Filter, IDP(arr), (_fake_filter_fn, (0, ), {'test': True}), {}),
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (), {}),
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (_fake_fn, (0, ), {'test': True}), {}),
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (), {}),
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (_fake_fn, (0, ), {'test': True}), {}),
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (_fake_filter_fn, (0, ), {'test': True}), {}),
]
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
(dp.iter.Mapper, IDP(arr), (lambda x: x, ), {}),
(dp.iter.Collator, IDP(arr), (lambda x: x, ), {}),
(dp.iter.Filter, IDP(arr), (lambda x: x >= 5, ), {}),
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (lambda x: x >= 5, ), {}),
]
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
@ -692,8 +698,8 @@ class TestFunctionalIterDataPipe(TestCase):
p = pickle.dumps(datapipe)
def test_concat_datapipe(self):
input_dp1 = IDP(range(10))
input_dp2 = IDP(range(5))
input_dp1 = dp.iter.IterableWrapper(range(10))
input_dp2 = dp.iter.IterableWrapper(range(5))
with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
dp.iter.Concater()
@ -718,7 +724,7 @@ class TestFunctionalIterDataPipe(TestCase):
def test_fork_datapipe(self):
input_dp = IDP(range(10))
input_dp = dp.iter.IterableWrapper(range(10))
with self.assertRaises(ValueError):
input_dp.fork(num_instances=0)
@ -836,7 +842,7 @@ class TestFunctionalIterDataPipe(TestCase):
self.assertEqual(len(input_dp), len(dp3))
def test_demux_datapipe(self):
input_dp = IDP(range(10))
input_dp = dp.iter.IterableWrapper(range(10))
with self.assertRaises(ValueError):
input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)
@ -882,8 +888,8 @@ class TestFunctionalIterDataPipe(TestCase):
self.assertEqual(list(range(0, 5)), output2)
# Test Case: classifer returns a value outside of [0, num_instance - 1]
dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
it = iter(dp[0])
dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
it = iter(dp0[0])
with self.assertRaises(ValueError):
next(it)
next(it)
@ -960,7 +966,7 @@ class TestFunctionalIterDataPipe(TestCase):
@suppress_warnings # Suppress warning for lambda fn
def test_map_datapipe(self):
input_dp = IDP(range(10))
input_dp = dp.iter.IterableWrapper(range(10))
def fn(item, dtype=torch.float, *, sum=False):
data = torch.tensor(item, dtype=dtype)
@ -1005,7 +1011,7 @@ class TestFunctionalIterDataPipe(TestCase):
def _helper(ref_fn, fn, input_col=None, output_col=None):
for constr in (list, tuple):
datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
@ -1072,9 +1078,11 @@ class TestFunctionalIterDataPipe(TestCase):
return _data
def _helper(ref_fn, fn, input_col=None, output_col=None):
datapipe = IDP([{"x": 0, "y": 1, "z": 2},
{"x": 3, "y": 4, "z": 5},
{"x": 6, "y": 7, "z": 8}])
datapipe = dp.iter.IterableWrapper(
[{"x": 0, "y": 1, "z": 2},
{"x": 3, "y": 4, "z": 5},
{"x": 6, "y": 7, "z": 8}]
)
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
@ -1117,7 +1125,7 @@ class TestFunctionalIterDataPipe(TestCase):
# TODO(VitalyFedyunin): If dill installed this test fails
def _test_map_datapipe_nested_level(self):
input_dp = IDP([list(range(10)) for _ in range(3)])
input_dp = dp.iter.IterableWrapper([list(range(10)) for _ in range(3)])
def fn(item, *, dtype=torch.float):
return torch.tensor(item, dtype=dtype)
@ -1153,7 +1161,7 @@ class TestFunctionalIterDataPipe(TestCase):
def test_collate_datapipe(self):
arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
input_dp = IDP(arrs)
input_dp = dp.iter.IterableWrapper(arrs)
def _collate_fn(batch):
return torch.tensor(sum(batch), dtype=torch.float)
@ -1172,7 +1180,7 @@ class TestFunctionalIterDataPipe(TestCase):
def test_batch_datapipe(self):
arrs = list(range(10))
input_dp = IDP(arrs)
input_dp = dp.iter.IterableWrapper(arrs)
with self.assertRaises(AssertionError):
input_dp.batch(batch_size=0)
@ -1200,7 +1208,7 @@ class TestFunctionalIterDataPipe(TestCase):
def test_unbatch_datapipe(self):
target_length = 6
prebatch_dp = IDP(range(target_length))
prebatch_dp = dp.iter.IterableWrapper(range(target_length))
input_dp = prebatch_dp.batch(3)
unbatch_dp = input_dp.unbatch()
@ -1208,13 +1216,13 @@ class TestFunctionalIterDataPipe(TestCase):
for i, res in zip(prebatch_dp, unbatch_dp):
self.assertEqual(i, res)
input_dp = IDP([[0, 1, 2], [3, 4, 5]])
input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
unbatch_dp = input_dp.unbatch()
self.assertEqual(len(list(unbatch_dp)), target_length)
for i, res in zip(prebatch_dp, unbatch_dp):
self.assertEqual(i, res)
input_dp = IDP([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
unbatch_dp = input_dp.unbatch()
expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
@ -1233,7 +1241,7 @@ class TestFunctionalIterDataPipe(TestCase):
for i, res in zip(expected_dp2, unbatch_dp):
self.assertEqual(i, res)
input_dp = IDP([[0, 1, 2], [3, 4, 5]])
input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
with self.assertRaises(ValueError):
unbatch_dp = input_dp.unbatch(unbatch_level=-2)
for i in unbatch_dp:
@ -1245,7 +1253,7 @@ class TestFunctionalIterDataPipe(TestCase):
print(i)
def test_bucket_batch_datapipe(self):
input_dp = IDP(range(20))
input_dp = dp.iter.IterableWrapper(range(20))
with self.assertRaises(AssertionError):
dp.iter.BucketBatcher(input_dp, batch_size=0)
@ -1258,7 +1266,7 @@ class TestFunctionalIterDataPipe(TestCase):
data_len = 100
arrs = list(range(data_len))
random.shuffle(arrs)
input_dp = IDP(arrs)
input_dp = dp.iter.IterableWrapper(arrs)
bucket_dp = dp.iter.BucketBatcher(input_dp, **kwargs)
self.assertEqual(len(bucket_dp), data_len // 3 if kwargs['drop_last'] else data_len // 3 + 1)
@ -1291,7 +1299,7 @@ class TestFunctionalIterDataPipe(TestCase):
def test_filter_datapipe(self):
input_ds = IDP(range(10))
input_ds = dp.iter.IterableWrapper(range(10))
def _filter_fn(data, val, clip=False):
if clip:
@ -1318,7 +1326,7 @@ class TestFunctionalIterDataPipe(TestCase):
def test_filter_datapipe_nested_list(self):
input_ds = IDP(range(10)).batch(5)
input_ds = dp.iter.IterableWrapper(range(10)).batch(5)
def _filter_fn(data, val):
return data >= val
@ -1340,7 +1348,7 @@ class TestFunctionalIterDataPipe(TestCase):
filter_dp = input_ds.filter(nesting_level=5, filter_fn=_filter_fn, fn_kwargs={'val': 5})
temp = list(filter_dp)
input_ds = IDP(range(10)).batch(3)
input_ds = dp.iter.IterableWrapper(range(10)).batch(3)
filter_dp = input_ds.filter(lambda ls: len(ls) >= 3)
expected_dp3: List[List[int]] = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
@ -1348,21 +1356,21 @@ class TestFunctionalIterDataPipe(TestCase):
for data, exp in zip(filter_dp, expected_dp3):
self.assertEqual(data, exp)
input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
input_ds = dp.iter.IterableWrapper([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
filter_dp = input_ds.filter(lambda x: x > 3, nesting_level=-1)
expected_dp4 = [[[4, 5]], [[6, 7, 8]]]
self.assertEqual(len(list(filter_dp)), len(expected_dp4))
for data2, exp2 in zip(filter_dp, expected_dp4):
self.assertEqual(data2, exp2)
input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
input_ds = dp.iter.IterableWrapper([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
filter_dp = input_ds.filter(lambda x: x > 7, nesting_level=-1)
expected_dp5 = [[[8]]]
self.assertEqual(len(list(filter_dp)), len(expected_dp5))
for data2, exp2 in zip(filter_dp, expected_dp5):
self.assertEqual(data2, exp2)
input_ds = IDP([[[0, 1], [3, 4]], [[6, 7, 8], [1, 2, 3]]])
input_ds = dp.iter.IterableWrapper([[[0, 1], [3, 4]], [[6, 7, 8], [1, 2, 3]]])
filter_dp = input_ds.filter(lambda ls: len(ls) >= 3, nesting_level=1)
expected_dp6 = [[[6, 7, 8], [1, 2, 3]]]
self.assertEqual(len(list(filter_dp)), len(expected_dp6))
@ -1370,7 +1378,7 @@ class TestFunctionalIterDataPipe(TestCase):
self.assertEqual(data2, exp2)
def test_sampler_datapipe(self):
input_dp = IDP(range(10))
input_dp = dp.iter.IterableWrapper(range(10))
# Default SequentialSampler
sampled_dp = dp.iter.Sampler(input_dp) # type: ignore[var-annotated]
self.assertEqual(len(sampled_dp), 10)
@ -1387,7 +1395,7 @@ class TestFunctionalIterDataPipe(TestCase):
def test_shuffle_datapipe(self):
exp = list(range(20))
input_ds = IDP(exp)
input_ds = dp.iter.IterableWrapper(exp)
with self.assertRaises(AssertionError):
shuffle_dp = input_ds.shuffle(buffer_size=0)
@ -1413,15 +1421,15 @@ class TestFunctionalIterDataPipe(TestCase):
def test_zip_datapipe(self):
with self.assertRaises(TypeError):
dp.iter.Zipper(IDP(range(10)), list(range(10))) # type: ignore[arg-type]
dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10))) # type: ignore[arg-type]
zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP_NoLen(range(5))) # type: ignore[var-annotated]
zipped_dp = dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5))) # type: ignore[var-annotated]
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(zipped_dp)
exp = list((i, i) for i in range(5))
self.assertEqual(list(zipped_dp), exp)
zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP(range(5)))
zipped_dp = dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5)))
self.assertEqual(len(zipped_dp), 5)
self.assertEqual(list(zipped_dp), exp)
# Reset
@ -1506,32 +1514,32 @@ class TestFunctionalMapDataPipe(TestCase):
def test_mux_datapipe(self):
# Test Case: Elements are yielded one at a time from each DataPipe, until they are all exhausted
input_dp1 = IDP(range(4))
input_dp2 = IDP(range(4, 8))
input_dp3 = IDP(range(8, 12))
input_dp1 = dp.iter.IterableWrapper(range(4))
input_dp2 = dp.iter.IterableWrapper(range(4, 8))
input_dp3 = dp.iter.IterableWrapper(range(8, 12))
output_dp = input_dp1.mux(input_dp2, input_dp3)
expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
self.assertEqual(len(expected_output), len(output_dp))
self.assertEqual(expected_output, list(output_dp))
# Test Case: Uneven input Data Pipes
input_dp1 = IDP([1, 2, 3, 4])
input_dp2 = IDP([10])
input_dp3 = IDP([100, 200, 300])
input_dp1 = dp.iter.IterableWrapper([1, 2, 3, 4])
input_dp2 = dp.iter.IterableWrapper([10])
input_dp3 = dp.iter.IterableWrapper([100, 200, 300])
output_dp = input_dp1.mux(input_dp2, input_dp3)
expected_output = [1, 10, 100, 2, 200, 3, 300, 4]
self.assertEqual(len(expected_output), len(output_dp))
self.assertEqual(expected_output, list(output_dp))
# Test Case: Empty Data Pipe
input_dp1 = IDP([0, 1, 2, 3])
input_dp2 = IDP([])
input_dp1 = dp.iter.IterableWrapper([0, 1, 2, 3])
input_dp2 = dp.iter.IterableWrapper([])
output_dp = input_dp1.mux(input_dp2)
self.assertEqual(len(input_dp1), len(output_dp))
self.assertEqual(list(input_dp1), list(output_dp))
# Test Case: raises TypeError when __len__ is called and an input doesn't have __len__
input_dp1 = IDP(range(10))
input_dp1 = dp.iter.IterableWrapper(range(10))
input_dp_no_len = IDP_NoLen(range(10))
output_dp = input_dp1.mux(input_dp_no_len)
with self.assertRaises(TypeError):
@ -1665,8 +1673,8 @@ class TestTyping(TestCase):
self.assertTrue(issubclass(DP1, IterDataPipe))
dp1 = DP1(10)
self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))
dp2 = DP1(5)
self.assertEqual(dp1.type, dp2.type)
dp1_ = DP1(5)
self.assertEqual(dp1.type, dp1_.type)
with self.assertRaisesRegex(TypeError, r"is not a generic class"):
class InvalidDP5(DP1[tuple]): # type: ignore[type-arg]
@ -1679,10 +1687,10 @@ class TestTyping(TestCase):
yield d # type: ignore[misc]
self.assertTrue(issubclass(DP2, IterDataPipe))
dp1 = DP2() # type: ignore[assignment]
self.assertTrue(DP2.type.issubtype(dp1.type) and dp1.type.issubtype(DP2.type))
dp2 = DP2() # type: ignore[assignment]
self.assertEqual(dp1.type, dp2.type)
dp2 = DP2() # type: ignore[var-annotated]
self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type))
dp2_ = DP2() # type: ignore[var-annotated]
self.assertEqual(dp2.type, dp2_.type)
class DP3(IterDataPipe[Tuple[T_co, str]]):
r""" DataPipe without fixed type with __init__ function"""
@ -1695,10 +1703,10 @@ class TestTyping(TestCase):
yield d, str(d)
self.assertTrue(issubclass(DP3, IterDataPipe))
dp1 = DP3(range(10)) # type: ignore[assignment]
self.assertTrue(DP3.type.issubtype(dp1.type) and dp1.type.issubtype(DP3.type))
dp2 = DP3(5) # type: ignore[assignment]
self.assertEqual(dp1.type, dp2.type)
dp3 = DP3(range(10)) # type: ignore[var-annotated]
self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type))
dp3_ = DP3(5) # type: ignore[var-annotated]
self.assertEqual(dp3.type, dp3_.type)
class DP4(IterDataPipe[tuple]):
r""" DataPipe without __iter__ annotation"""
@ -1707,8 +1715,8 @@ class TestTyping(TestCase):
raise NotImplementedError
self.assertTrue(issubclass(DP4, IterDataPipe))
dp = DP4()
self.assertTrue(dp.type.param == tuple)
dp4 = DP4()
self.assertTrue(dp4.type.param == tuple)
class DP5(IterDataPipe):
r""" DataPipe without type annotation"""
@ -1717,9 +1725,9 @@ class TestTyping(TestCase):
raise NotImplementedError
self.assertTrue(issubclass(DP5, IterDataPipe))
dp = DP5() # type: ignore[assignment]
dp5 = DP5()
from torch.utils.data._typing import issubtype
self.assertTrue(issubtype(dp.type.param, Any) and issubtype(Any, dp.type.param))
self.assertTrue(issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param))
class DP6(IterDataPipe[int]):
r""" DataPipe with plain Iterator"""
@ -1728,13 +1736,13 @@ class TestTyping(TestCase):
raise NotImplementedError
self.assertTrue(issubclass(DP6, IterDataPipe))
dp = DP6() # type: ignore[assignment]
self.assertTrue(dp.type.param == int)
dp6 = DP6()
self.assertTrue(dp6.type.param == int)
class DP7(IterDataPipe[Awaitable[T_co]]):
r""" DataPipe with abstract base class"""
self.assertTrue(issubclass(DP6, IterDataPipe))
self.assertTrue(issubclass(DP7, IterDataPipe))
self.assertTrue(DP7.type.param == Awaitable[T_co])
class DP8(DP7[str]):
@ -1765,11 +1773,11 @@ class TestTyping(TestCase):
# Non-DataPipe input with DataPipe hint
datasource = [(1, '1'), (2, '2'), (3, '3')]
with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
dp = DP0(datasource)
dp0 = DP0(datasource)
dp = DP0(IDP(range(10)))
dp0 = DP0(dp.iter.IterableWrapper(range(10)))
with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
dp = DP1(dp)
dp1 = DP1(dp0)
def test_runtime(self):
class DP(IterDataPipe[Tuple[int, T_co]]):
@ -1784,26 +1792,26 @@ class TestTyping(TestCase):
dss = ([(1, '1'), (2, '2')],
[(1, 1), (2, '2')])
for ds in dss:
dp = DP(ds) # type: ignore[var-annotated]
self.assertEqual(list(dp), ds)
dp0 = DP(ds) # type: ignore[var-annotated]
self.assertEqual(list(dp0), ds)
# Reset __iter__
self.assertEqual(list(dp), ds)
self.assertEqual(list(dp0), ds)
dss = ([(1, 1), ('2', 2)], # type: ignore[assignment, list-item]
[[1, '1'], [2, '2']], # type: ignore[list-item]
[1, '1', 2, '2'])
for ds in dss:
dp = DP(ds)
dp0 = DP(ds)
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
list(dp)
list(dp0)
with runtime_validation_disabled():
self.assertEqual(list(dp), ds)
self.assertEqual(list(dp0), ds)
with runtime_validation_disabled():
self.assertEqual(list(dp), ds)
self.assertEqual(list(dp0), ds)
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
list(dp)
list(dp0)
def test_reinforce(self):
T = TypeVar('T', int, str)
@ -1819,26 +1827,26 @@ class TestTyping(TestCase):
ds = list(range(10))
# Valid type reinforcement
dp = DP(ds).reinforce_type(int)
self.assertTrue(dp.type, int)
self.assertEqual(list(dp), ds)
dp0 = DP(ds).reinforce_type(int)
self.assertTrue(dp0.type, int)
self.assertEqual(list(dp0), ds)
# Invalid type
with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"):
dp = DP(ds).reinforce_type(1)
dp1 = DP(ds).reinforce_type(1)
# Type is not subtype
with self.assertRaisesRegex(TypeError, r"Expected 'expected_type' as subtype of"):
dp = DP(ds).reinforce_type(float)
dp2 = DP(ds).reinforce_type(float)
# Invalid data at runtime
dp = DP(ds).reinforce_type(str)
dp3 = DP(ds).reinforce_type(str)
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
list(dp)
list(dp3)
# Context Manager to disable the runtime validation
with runtime_validation_disabled():
self.assertEqual(list(d for d in dp), ds)
self.assertEqual(list(d for d in dp3), ds)
class NumbersDataset(IterDataPipe):
@ -1900,7 +1908,7 @@ class TestSharding(TestCase):
self.assertEqual(sorted(all_items), sorted(items))
def test_sharding_length(self):
numbers_dp = IDP(range(13))
numbers_dp = dp.iter.IterableWrapper(range(13))
sharded_dp0 = numbers_dp.sharding_filter()
torch.utils.data.sharding.apply_sharding(sharded_dp0, 3, 0)
sharded_dp1 = numbers_dp.sharding_filter()
@ -1912,7 +1920,7 @@ class TestSharding(TestCase):
self.assertEqual(4, len(sharded_dp1))
self.assertEqual(4, len(sharded_dp2))
numbers_dp = IDP(range(1))
numbers_dp = dp.iter.IterableWrapper(range(1))
sharded_dp0 = numbers_dp.sharding_filter()
torch.utils.data.sharding.apply_sharding(sharded_dp0, 2, 0)
sharded_dp1 = numbers_dp.sharding_filter()
@ -1922,11 +1930,11 @@ class TestSharding(TestCase):
@skipIfNoDill
def test_old_dataloader(self):
dp = self._get_pipeline()
expected = list(dp)
dp0 = self._get_pipeline()
expected = list(dp0)
dp = self._get_pipeline().sharding_filter()
dl = DataLoader(dp, batch_size=1, shuffle=False, num_workers=2,
dp0 = self._get_pipeline().sharding_filter()
dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2,
worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn)
items = []
for i in dl:


@ -9704,12 +9704,6 @@ class TestNN(NNTestCase):
self.assertEqual(input1.grad, torch.zeros_like(input1))
self.assertEqual(input2.grad, input1 * 1e8)
# Check error when inputs are not the same shape
input1 = torch.randn(2, 2, 1)
input2 = torch.randn(2, 1, 3)
with self.assertRaises(RuntimeError):
F.cosine_similarity(input1, input2)
# Check type promotion, issue #61454
input = torch.tensor(12.)
out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)


@ -1,7 +1,8 @@
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn.functional as F
from torch import Tensor, vmap
from torch import Tensor
from torch._vmap_internals import vmap
import functools
import itertools
import warnings


@ -387,7 +387,7 @@ class CMake:
# then. Until then, we use "--" to pass parameters to the
# underlying build system.
build_args += ['--']
if IS_WINDOWS:
if IS_WINDOWS and not USE_NINJA:
# We are likely using msbuild here
build_args += ['/p:CL_MPCount={}'.format(max_jobs)]
else:


@ -108,7 +108,7 @@ def plural(n: int) -> str:
def get_base_commit(sha1: str) -> str:
return subprocess.check_output(
["git", "merge-base", sha1, "origin/master"],
["git", "merge-base", sha1, "origin/release/1.10"],
encoding="ascii",
).strip()


@ -22,7 +22,11 @@ class TestCMake(unittest.TestCase):
# MAX_JOBS, USE_NINJA, IS_WINDOWS, want
(( '8', True, False), ['-j', '8']), # noqa: E201,E241
(( None, True, False), None), # noqa: E201,E241
(( '7', False, False), ['-j', '7']), # noqa: E201,E241
(( None, False, False), ['-j', '13']), # noqa: E201,E241
(( '6', True, True), ['-j', '6']), # noqa: E201,E241
(( None, True, True), None), # noqa: E201,E241
(( '11', False, True), ['/p:CL_MPCount=11']), # noqa: E201,E241
(( None, False, True), ['/p:CL_MPCount=13']), # noqa: E201,E241
]
for (max_jobs, use_ninja, is_windows), want in cases:


@ -9,6 +9,13 @@ if(NOT CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
endif()
if(BUILD_BINARY)
add_library(aot_compiler SHARED
${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
)
install(TARGETS aot_compiler DESTINATION lib)
endif()
if(NOT BUILD_PYTHON)
return()
endif()
@ -430,9 +437,3 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
# Pybind11 requires explicit linking of the torch_python library
target_link_libraries(nnapi_backend torch torch_python)
endif()
if(BUILD_BINARY)
add_library(aot_compiler SHARED
${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
)
endif()


@ -756,8 +756,6 @@ del register_after_fork
# torch.jit.script as a decorator, for instance):
from ._lobpcg import lobpcg as lobpcg
from ._vmap_internals import vmap as vmap
# These were previously defined in native_functions.yaml and appeared on the
# `torch` namespace, but we moved them to c10 dispatch to facilitate custom
# class usage. We add these lines here to preserve backward compatibility.


@ -362,19 +362,21 @@ void ConvertGraphToONNXProto(
SymbolDimMap& symbol_map,
int opset_version) {
RawDataExportMap export_map;
std::tie(model_proto, export_map, symbol_map) = export_onnx(
graph,
{},
opset_version,
{},
false,
onnx_torch::OperatorExportTypes::ONNX,
true,
true,
{},
true,
false,
std::string());
bool val_use_external_data_format;
std::tie(model_proto, export_map, symbol_map, val_use_external_data_format) =
export_onnx(
graph,
{},
opset_version,
{},
false,
onnx_torch::OperatorExportTypes::ONNX,
true,
true,
{},
true,
false,
std::string());
for (int i = 0; i < model_proto->graph().output_size(); ++i) {
model_proto->mutable_graph()->mutable_output(i)->clear_type();
}


@ -263,19 +263,25 @@ void initPythonIRBindings(PyObject* module_) {
std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
RawDataExportMap export_map;
SymbolDimMap symbol_map;
std::tie(model_proto, export_map, symbol_map) = export_onnx(
g,
initializers,
onnx_opset_version,
dynamic_axes,
defer_weight_export,
operator_export_type,
strip_doc_string,
keep_initializers_as_inputs,
custom_opsets,
add_node_names,
use_external_data_format,
onnx_file_path);
bool val_use_external_data_format = false;
std::tie(
model_proto,
export_map,
symbol_map,
val_use_external_data_format) =
export_onnx(
g,
initializers,
onnx_opset_version,
dynamic_axes,
defer_weight_export,
operator_export_type,
strip_doc_string,
keep_initializers_as_inputs,
custom_opsets,
add_node_names,
use_external_data_format,
onnx_file_path);
std::unordered_map<std::string, py::bytes>
python_serialized_export_map;
for (auto& kv : export_map) {
@ -289,7 +295,9 @@ void initPythonIRBindings(PyObject* module_) {
}
graph = serialize_model_proto_to_string(model_proto);
return std::make_tuple(
py::bytes(graph), python_serialized_export_map);
py::bytes(graph),
python_serialized_export_map,
val_use_external_data_format);
},
py::arg("initializers"),
py::arg("onnx_opset_version") = 0,


@ -231,6 +231,12 @@ class EncoderBase {
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());
unsigned long long int GetGraphProtoSize(
onnx::GraphProto* graph_proto,
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers =
std::map<std::string, at::Tensor>());
virtual void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
@ -581,7 +587,6 @@ void EncoderBase::AddInitializersIntoGraphProto(
bool use_external_data_format,
const std::string& onnx_file_path) {
AT_ASSERT(block->inputs().size() >= initializers.size());
for (auto input : block->inputs()) {
auto name_tensor_pair = initializers.find(input->debugName());
if (name_tensor_pair == initializers.end()) {
@ -598,6 +603,38 @@ void EncoderBase::AddInitializersIntoGraphProto(
}
}
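// Estimates how large graph_proto would grow once the given initializers are
// serialized into it: each initializer is encoded into a scratch copy of the
// proto and the encoded tensor's byte size is accumulated, so callers can
// check the 2GB protobuf limit up front.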
unsigned long long int EncoderBase::GetGraphProtoSize(
onnx::GraphProto* graph_proto,
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers) {
unsigned long long int sizes = 0;
for (auto input : graph->inputs()) {
auto name_tensor_pair = initializers.find(input->debugName());
if (name_tensor_pair == initializers.end()) {
continue;
}
onnx::GraphProto* graph_proto_copy = new onnx::GraphProto(*graph_proto);
auto tensor_proto = graph_proto_copy->add_initializer();
const at::Tensor tensor = name_tensor_pair->second;
for (auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
at::Tensor t;
if (tensor.is_quantized()) {
t = tensor.contiguous();
} else {
t = tensor.contiguous().cpu();
}
tensor_proto->set_raw_data(std::string(
static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
sizes += tensor_proto->ByteSizeLong();
delete graph_proto_copy;
graph_proto_copy = nullptr;
}
return sizes;
}
void EncoderBase::AddAttribute(
onnx::NodeProto* node_proto,
const jit::Node* node,
@ -726,6 +763,10 @@ class GraphEncoder : public EncoderBase {
return raw_data_export_map_;
}
bool get_use_external_data_format() {
return use_external_data_format_;
}
private:
void EncodeTensor(
onnx::TensorProto* tensor_proto,
@ -736,6 +777,7 @@ class GraphEncoder : public EncoderBase {
RawDataExportMap raw_data_export_map_;
bool defer_weight_export_;
bool use_external_data_format_;
};
GraphEncoder::GraphEncoder(
@ -754,8 +796,21 @@ GraphEncoder::GraphEncoder(
bool use_external_data_format,
const std::string& onnx_file_path)
: EncoderBase(operator_export_type, strip_doc),
defer_weight_export_(defer_weight_export) {
defer_weight_export_(defer_weight_export),
use_external_data_format_(use_external_data_format) {
validateGraph(graph, operator_export_type);
// If the graph proto size exceeds the maximum protobuf size of 2GB, set
// use_external_data_format to true.
if (!use_external_data_format && !onnx_file_path.empty() &&
GetGraphProtoSize(model_proto_.mutable_graph(), graph, initializers) >
INT_MAX) {
GRAPH_DEBUG(
"Exporting model exceed maximum protobuf size of 2GB. Storing model parameters in external data files");
use_external_data_format = true;
// use_external_data_format_ is a private member of GraphEncoder, updated here
// so that the final `use_external_data_format` value can be returned.
use_external_data_format_ = use_external_data_format;
}
if (use_external_data_format) {
TORCH_CHECK(
@ -895,7 +950,8 @@ std::string pretty_print_onnx(
std::tuple<
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
RawDataExportMap,
SymbolDimMap>
SymbolDimMap,
bool>
export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
@ -929,7 +985,8 @@ export_onnx(
std::make_shared<::ONNX_NAMESPACE::ModelProto>(
graph_encoder.get_model_proto()),
graph_encoder.get_raw_data_export_map(),
graph_encoder.get_symbol_dim_param_map());
graph_encoder.get_symbol_dim_param_map(),
graph_encoder.get_use_external_data_format());
}
std::string serialize_model_proto_to_string(
@ -938,7 +995,7 @@ std::string serialize_model_proto_to_string(
TORCH_CHECK(
proto_size <= INT_MAX,
"Exporting model exceed maximum protobuf size of 2GB. "
"Please call torch.onnx.export with use_external_data_format=True.");
"Please call torch.onnx.export without setting use_external_data_format parameter.");
return model_proto->SerializeAsString();
}


@ -33,7 +33,8 @@ using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
TORCH_API std::tuple<
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
RawDataExportMap,
SymbolDimMap>
SymbolDimMap,
bool>
export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,


@ -1519,8 +1519,11 @@ def _object_to_tensor(obj):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)
local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long)
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and an explicit dtype.
# Otherwise, it will cause a ~100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
return byte_tensor, local_size
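A rough way to see the slowdown the comment refers to (a hedged sketch; the `timeit` harness and payload size are illustrative, not part of the patch):

```python
import pickle
import timeit

import torch

payload = pickle.dumps(list(range(100_000)))
byte_storage = torch.ByteStorage.from_buffer(payload)

# Path kept by the patch: wraps the storage directly.
fast = timeit.timeit(lambda: torch.ByteTensor(byte_storage), number=100)
# Path removed by the patch: torch.tensor() with an explicit dtype copies element by element.
slow = timeit.timeit(lambda: torch.tensor(byte_storage, dtype=torch.uint8), number=100)
print(f"ByteTensor: {fast:.4f}s  torch.tensor: {slow:.4f}s")
```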


@ -4256,7 +4256,10 @@ cosine_similarity = _add_docstr(
r"""
cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor
Returns cosine similarity between x1 and x2, computed along dim.
Returns cosine similarity between ``x1`` and ``x2``, computed along ``dim``. ``x1`` and ``x2`` must be broadcastable
to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
squeezed (see :func:`torch.squeeze`), resulting in the
output tensor having 1 fewer dimension.
.. math ::
\text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
@ -4265,16 +4268,11 @@ Supports :ref:`type promotion <type-promotion-doc>`.
Args:
x1 (Tensor): First input.
x2 (Tensor): Second input (with the same number of dimensions as x1, matching x1 size at dimension `dim`,
and broadcastable with x1 at other dimensions).
dim (int, optional): Dimension of vectors. Default: 1
x2 (Tensor): Second input.
dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
eps (float, optional): Small value to avoid division by zero.
Default: 1e-8
Shape:
- Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
- Output: :math:`(\ast_1, \ast_2)`
Example::
>>> input1 = torch.randn(100, 128)
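A small illustration of the broadcasting and ``dim``-squeezing behavior described above (a hedged sketch; the shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 2, 3)
x2 = torch.randn(2, 1, 3)
# The inputs broadcast to the common shape (2, 2, 3); dim=-1 is reduced,
# so the output has one fewer dimension: (2, 2).
print(F.cosine_similarity(x1, x2, dim=-1).shape)  # torch.Size([2, 2])
```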


@ -31,10 +31,10 @@ def _export(*args, **kwargs):
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, operator_export_type=None,
opset_version=None, _retain_param_name=True, do_constant_folding=True,
example_outputs=None, strip_doc_string=True, dynamic_axes=None,
keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True,
use_external_data_format=False):
opset_version=None, _retain_param_name=None, do_constant_folding=True,
example_outputs=None, strip_doc_string=None, dynamic_axes=None,
keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None,
use_external_data_format=None):
r"""
Exports a model into ONNX format. If ``model`` is not a
:class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`, this runs
@ -105,7 +105,9 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
In this case, the exported model will first take all of its parameters
as arguments, with the ordering as specified by ``model.state_dict().values()``
verbose (bool, default False): if True, prints a description of the
model being exported to stdout.
model being exported to stdout. In addition, the final ONNX graph will include the
field ``doc_string`` from the exported model, which mentions the source code locations
for ``model``.
training (enum, default TrainingMode.EVAL):
* ``TrainingMode.EVAL``: export the model in inference mode.
* ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
@ -181,16 +183,17 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
opset_version (int, default 9):
Must be ``== _onnx_main_opset or in _onnx_stable_opsets``,
defined in torch/onnx/symbolic_helper.py.
_retain_param_name (bool, default True): [Deprecated and ignored. Will be removed in next PyTorch
release]
do_constant_folding (bool, default False): Apply the constant-folding optimization.
Constant-folding will replace some of the ops that have all constant inputs
with pre-computed constant nodes.
example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None):
[Deprecated and ignored. Will be removed in next PyTorch release],
Must be provided when exporting a ScriptModule or ScriptFunction, ignored otherwise.
Used to determine the type and shape of the outputs without tracing the execution of
the model. A single object is treated as equivalent to a tuple of one element.
strip_doc_string (bool, default True): do not include the field
``doc_string``` from the exported model. Otherwise the field will mention the source
code locations for ``model``.
strip_doc_string (bool, default True): [Deprecated and ignored. Will be removed in next PyTorch release]
dynamic_axes (dict<string, dict<int, string>> or dict<string, list(int)>, default empty dict):
By default the exported model will have the shapes of all input and output tensors
@ -293,10 +296,11 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
enable_onnx_checker (bool, default True): Deprecated and ignored. Will be removed in the next
PyTorch release.
use_external_data_format (bool, default False): If True, then some of the model
parameters are stored in external data files and not in the ONNX model file itself.
Models larger than 2GB cannot be exported in one file because of size limits imposed
by Protocol Buffers.
use_external_data_format (bool, default False): [Deprecated and ignored. Will be removed in the
next PyTorch release.]
If True, then some of the model parameters are stored in external data files and not in
the ONNX model file itself. Models larger than 2GB cannot be exported in one file because
of size limits imposed by Protocol Buffers.
For details see
`onnx.proto <https://github.com/onnx/onnx/blob/32c7cb66/onnx/onnx.proto#L562>`_.
If True, argument ``f`` must be a string specifying the location of the model.
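With the deprecations above, a typical call no longer passes `_retain_param_name`, `example_outputs`, `strip_doc_string`, `enable_onnx_checker`, or `use_external_data_format`; a hedged sketch of a new-style invocation (the model and shapes are illustrative):

```python
import torch

model = torch.nn.Linear(3, 2)
x = torch.randn(1, 3)

# verbose=True now also keeps the doc_string field (source locations) that
# strip_doc_string=False used to preserve; oversized models spill into
# external data files automatically when f is a file path.
torch.onnx.export(model, (x,), "model.onnx",
                  opset_version=13,
                  verbose=True,
                  do_constant_folding=True)
```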
@ -316,9 +320,24 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
custom_opsets, enable_onnx_checker, use_external_data_format)
def export_to_pretty_string(*args, **kwargs):
def export_to_pretty_string(*args, **kwargs) -> str:
r"""
Same as :func:`export`, but returns a text representation of the exported model.
Similar to :func:`export`, but returns a text representation of the ONNX
model. Only the arguments that differ are listed below; all other arguments
are the same as for :func:`export`.
Args:
f: Deprecated and ignored. Will be removed in the next release of
PyTorch.
add_node_names (bool, default True): Whether or not to set
NodeProto.name. This makes no difference unless
``google_printer=True``.
google_printer (bool, default False): If False, will return a custom,
compact representation of the model. If True will return the
protobuf's `Message::DebugString()`, which is more verbose.
Returns:
A UTF-8 str containing a human-readable representation of the ONNX model.
"""
from torch.onnx import utils
return utils.export_to_pretty_string(*args, **kwargs)
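For comparison, a hedged sketch of the pretty-string path; `f` is still accepted positionally but deprecated and ignored, so `None` is passed here:

```python
import torch

model = torch.nn.Linear(3, 2)
x = torch.randn(1, 3)

# Returns a human-readable text form of the ONNX graph instead of writing a file.
text = torch.onnx.export_to_pretty_string(model, (x,), None, opset_version=13)
print(text)
```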


@ -12,6 +12,7 @@ import torch.serialization
import re
import collections
import contextlib
import copy
import numbers
import warnings
from torch._six import string_classes
@ -75,24 +76,37 @@ def select_model_mode_for_export(model, mode):
def export(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=None,
opset_version=None, _retain_param_name=True, do_constant_folding=True,
example_outputs=None, strip_doc_string=True, dynamic_axes=None,
opset_version=None, _retain_param_name=None, do_constant_folding=True,
example_outputs=None, strip_doc_string=None, dynamic_axes=None,
keep_initializers_as_inputs=None, custom_opsets=None,
enable_onnx_checker=None, use_external_data_format=False):
enable_onnx_checker=None, use_external_data_format=None):
if operator_export_type is None:
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
operator_export_type = OperatorExportTypes.ONNX
if enable_onnx_checker is not None:
warnings.warn("`enable_onnx_checker' is deprecated and ignored. It will be removed in"
"the next PyTorch release. To proceed despite ONNX checker failures, you"
"can catch torch.onnx.ONNXCheckerError.")
warnings.warn("'enable_onnx_checker' is deprecated and ignored. It will be removed in "
"the next PyTorch release. To proceed despite ONNX checker failures, "
"catch torch.onnx.ONNXCheckerError.")
if _retain_param_name is not None:
warnings.warn("'_retain_param_name' is deprecated and ignored. "
"It will be removed in the next PyTorch release.")
if strip_doc_string is not None:
warnings.warn("`strip_doc_string' is deprecated and ignored. Will be removed in "
"next PyTorch release. It's combined with `verbose' argument now. ")
if example_outputs is not None:
warnings.warn("`example_outputs' is deprecated and ignored. Will be removed in "
"next PyTorch release.")
if use_external_data_format is not None:
warnings.warn("`use_external_data_format' is deprecated and ignored. Will be removed in next "
"PyTorch release. The code will work as it is False if models are not larger than 2GB, "
"Otherwise set to False because of size limits imposed by Protocol Buffers.")
_export(model, args, f, export_params, verbose, training, input_names, output_names,
operator_export_type=operator_export_type, opset_version=opset_version,
_retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
example_outputs=example_outputs, strip_doc_string=strip_doc_string,
do_constant_folding=do_constant_folding, example_outputs=example_outputs,
dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
@ -310,7 +324,10 @@ def _decide_external_data_format(use_external_data_format, operator_export_type,
# string specifying the location of the model. For large model cases, if f is not a non-empty string,
# then this method returns an empty string, which is an error condition for the large model export code
# path later (but not for regular model export code path).
model_file_location = f if val_use_external_data_format and isinstance(f, str) else str()
if (val_use_external_data_format is None or val_use_external_data_format is True) and isinstance(f, str):
model_file_location = f
else:
model_file_location = str()
return val_use_external_data_format, model_file_location
def _decide_input_format(model, args):
@ -389,7 +406,7 @@ def _get_param_count_list(method_graph, args_params):
return param_count_list
def _create_jit_graph(model, args, _retain_param_name):
def _create_jit_graph(model, args):
torch_out = None
params: Union[List, Tuple]
if isinstance(model, torch.jit.ScriptModule):
@ -420,13 +437,12 @@ def _create_jit_graph(model, args, _retain_param_name):
graph, torch_out = _trace_and_get_graph_from_model(model, args)
state_dict = _unique_state_dict(model)
params = list(state_dict.values())
if _retain_param_name:
graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for i, inp in enumerate(graph_inputs):
if i >= user_input_num:
inp.setDebugName(param_names[i - user_input_num])
graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for i, inp in enumerate(graph_inputs):
if i >= user_input_num:
inp.setDebugName(param_names[i - user_input_num])
torch._C._jit_pass_onnx_function_substitution(graph)
return graph, params, torch_out, None
@ -437,12 +453,25 @@ def _get_named_param_dict(graph, params):
_params_dict = dict(zip(param_names, params))
return _params_dict
def _get_example_outputs(model, args):
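# Runs the model on a deep copy of the export args (splitting off a trailing
# dict of keyword arguments, if present) so that example outputs can be
# inferred when the caller no longer passes the deprecated example_outputs.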
input_args = copy.deepcopy(args)
input_kwargs = {}
if input_args and isinstance(input_args[-1], dict):
input_kwargs = input_args[-1]
input_args = input_args[:-1]
example_outputs = model(*input_args, **input_kwargs)
if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
example_outputs = (example_outputs,)
if isinstance(example_outputs, list):
example_outputs = [example_outputs]
return example_outputs
def _model_to_graph(model, args, verbose=False,
input_names=None, output_names=None,
operator_export_type=OperatorExportTypes.ONNX,
example_outputs=None,
_retain_param_name=False, do_constant_folding=True,
example_outputs=None, do_constant_folding=True,
_disable_torch_constant_prop=False, fixed_batch_size=False,
training=None, dynamic_axes=None):
r"""Converts model into an ONNX graph.
@ -461,11 +490,7 @@ def _model_to_graph(model, args, verbose=False,
if isinstance(args, (torch.Tensor, int, float, bool)):
args = (args, )
if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
example_outputs = (example_outputs,)
graph, params, torch_out, module = _create_jit_graph(model, args,
_retain_param_name)
graph, params, torch_out, module = _create_jit_graph(model, args)
params_dict = _get_named_param_dict(graph, params)
@ -476,13 +501,22 @@ def _model_to_graph(model, args, verbose=False,
module=module)
from torch.onnx.symbolic_helper import _onnx_shape_inference
if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):
assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \
"ScriptFunction."
if isinstance(example_outputs, list):
example_outputs = [example_outputs]
if example_outputs is None:
example_outputs = _get_example_outputs(model, args)
else:
# example_outputs specified
if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
example_outputs = (example_outputs,)
if isinstance(example_outputs, list):
example_outputs = [example_outputs]
out_vars, desc = torch.jit._flatten(tuple(example_outputs))
torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, desc, _onnx_shape_inference)
else:
flatten_args, _ = torch._C._jit_flatten(args)
# make sure that the param dict and the graph match each other
assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs())
# NB: ONNX requires complete information about output types, which might be
# erased by some optimizations, so we need to set it explicitly again.
@ -533,14 +567,18 @@ def _model_to_graph(model, args, verbose=False,
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
google_printer=False, opset_version=None, _retain_param_name=True,
google_printer=False, opset_version=None, _retain_param_name=None,
keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True,
do_constant_folding=True, dynamic_axes=None):
if f is not None:
warnings.warn("'f' is deprecated and ignored. It will be removed in the next PyTorch release.")
if _retain_param_name is not None:
warnings.warn("'_retain_param_name' is deprecated and ignored. "
"It will be removed in the next PyTorch release.")
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
input_names, output_names, operator_export_type,
export_type, example_outputs, google_printer,
opset_version, _retain_param_name,
do_constant_folding=do_constant_folding,
opset_version, do_constant_folding=do_constant_folding,
add_node_names=add_node_names,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets, dynamic_axes=dynamic_axes)
@ -549,7 +587,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
google_printer=False, opset_version=None, _retain_param_name=False,
google_printer=False, opset_version=None,
do_constant_folding=True, keep_initializers_as_inputs=None,
fixed_batch_size=False, custom_opsets=None, add_node_names=True,
onnx_shape_inference=True, dynamic_axes=None):
@ -572,8 +610,8 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
args = _decide_input_format(model, args)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
output_names, operator_export_type,
example_outputs, _retain_param_name,
val_do_constant_folding, fixed_batch_size=fixed_batch_size,
example_outputs, val_do_constant_folding,
fixed_batch_size=fixed_batch_size,
training=training, dynamic_axes=dynamic_axes)
return graph._pretty_print_onnx(params_dict, opset_version, False,
@ -633,10 +671,10 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini
def _export(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=None,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
opset_version=None, _retain_param_name=False, do_constant_folding=True,
strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None,
opset_version=None, do_constant_folding=True,
dynamic_axes=None, keep_initializers_as_inputs=None,
fixed_batch_size=False, custom_opsets=None, add_node_names=True,
use_external_data_format=False, onnx_shape_inference=True):
use_external_data_format=None, onnx_shape_inference=True):
if isinstance(model, torch.nn.DataParallel):
raise ValueError("torch.nn.DataParallel is not supported by ONNX "
@ -685,8 +723,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
graph, params_dict, torch_out = \
_model_to_graph(model, args, verbose, input_names,
output_names, operator_export_type,
example_outputs, _retain_param_name,
val_do_constant_folding,
example_outputs, val_do_constant_folding,
fixed_batch_size=fixed_batch_size,
training=training,
dynamic_axes=dynamic_axes)
@ -697,14 +734,14 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
custom_opsets = {}
if export_params:
proto, export_map = graph._export_onnx(
proto, export_map, val_use_external_data_format = graph._export_onnx(
params_dict, opset_version, dynamic_axes, defer_weight_export,
operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
operator_export_type, not verbose, val_keep_init_as_ip, custom_opsets,
val_add_node_names, val_use_external_data_format, model_file_location)
else:
proto, export_map = graph._export_onnx(
proto, export_map, val_use_external_data_format = graph._export_onnx(
{}, opset_version, dynamic_axes, False, operator_export_type,
strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
not verbose, val_keep_init_as_ip, custom_opsets, val_add_node_names,
val_use_external_data_format, model_file_location)
if export_type == ExportTypes.PROTOBUF_FILE:


@ -1256,6 +1256,8 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa
yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
# Test for Broadcasting
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
return list(generator())


@ -21,7 +21,9 @@ class MapperIterDataPipe(IterDataPipe):
self.dp = dp
self.fn = fn
```
Note: Avoid loading data from the source DataPipe in `__init__` function, in order to support lazy data loading and save memory.
Note:
- Avoid loading data from the source DataPipe in `__init__` function, in order to support lazy data loading and save memory.
- If an `IterDataPipe` instance holds data in memory, please be aware of in-place modification of that data. When a second iterator is created from the instance, the data may have already changed. Please take the [`IterableWrapper`](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/datapipes/iter/utils.py) class as a reference for `deepcopy`-ing the data for each iterator (see the sketch below).
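A minimal sketch of the pattern this note recommends (the `InMemoryIterDataPipe` name is made up for illustration):

```python
import copy
from torch.utils.data import IterDataPipe

class InMemoryIterDataPipe(IterDataPipe):
    def __init__(self, data):
        super().__init__()
        self.data = data  # data is already held in memory

    def __iter__(self):
        # Yield from a deep copy so in-place changes made downstream do not
        # leak into later iterations over this DataPipe.
        for item in copy.deepcopy(self.data):
            yield item
```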
### Iterator
For `IterDataPipe`, an `__iter__` function is needed to consume data from the source `IterDataPipe`, then apply the operation over the data before yielding.


@ -1,4 +1,3 @@
import copy
import warnings
from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
@ -99,8 +98,6 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
data = list(data)
else:
t_flag = False
# Deepcopy data to prevent the original data modified. E.g. list, dict
data = copy.deepcopy(data)
if self.output_col is None:
if isinstance(self.input_col, (list, tuple)):


@ -1,9 +1,9 @@
import random
import warnings
from collections import defaultdict
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
T_co = TypeVar('T_co', covariant=True)
@ -185,8 +185,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
assert batch_size > 0, "Batch size is required to be larger than 0!"
assert batch_num > 0, "Number of batches is required to be larger than 0!"
assert bucket_num > 0, "Number of buckets is required to be larger than 0!"
warnings.warn("`BucketBatcher` is going to be removed from PyTorch Core")
deprecation_warning_torchdata(type(self).__name__)
super().__init__()
# TODO: Verify _datapippe is not going to be serialized twice


@ -3,6 +3,7 @@ from typing import Sized, Tuple
from urllib.error import HTTPError, URLError
import urllib.request as urllib
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata
class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
@ -19,6 +20,7 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
def __init__(self, datapipe, timeout=None):
self.datapipe = datapipe
self.timeout = timeout
deprecation_warning_torchdata(type(self).__name__)
def __iter__(self):
for furl in self.datapipe:


@ -1,5 +1,6 @@
from typing import Tuple
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata
class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]):
@@ -14,6 +15,7 @@ class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]):
def __init__(self, datapipe):
self.datapipe = datapipe
deprecation_warning_torchdata(type(self).__name__)
def __iter__(self):
for file_name, stream in self.datapipe:


@@ -1,5 +1,5 @@
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple, deprecation_warning_torchdata
from typing import Iterable, Iterator, Tuple, Optional, IO, cast
from io import BufferedIOBase
@@ -34,11 +34,13 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
self.mode: str = mode
self.length: int = length
deprecation_warning_torchdata(type(self).__name__)
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
for data in self.datapipe:
validate_pathname_binary_tuple(data)
pathname, data_stream = data
folder_name = os.path.dirname(pathname)
try:
# typing.cast is used here to silence mypy's type checker
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
@@ -49,14 +51,12 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
if extracted_fobj is None:
warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
raise tarfile.ExtractError
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
inner_pathname = os.path.normpath(os.path.join(folder_name, tarinfo.name))
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
except Exception as e:
warnings.warn(
"Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
raise e
finally:
data_stream.close()
def __len__(self):
if self.length == -1:
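The `inner_pathname` change above joins the member name to the archive's directory rather than to the archive file itself. A small sketch with hypothetical paths:

```python
import os

pathname = "datasets/images.tar"  # path of the tar archive (hypothetical)
member = "cat/001.jpg"            # tarinfo.name inside the archive (hypothetical)

# Previous behaviour: the archive file name leaked into the inner path.
print(os.path.normpath(os.path.join(pathname, member)))                   # datasets/images.tar/cat/001.jpg

# New behaviour: join against the archive's folder instead.
print(os.path.normpath(os.path.join(os.path.dirname(pathname), member)))  # datasets/cat/001.jpg
```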


@@ -1,3 +1,5 @@
import copy
import warnings
from torch.utils.data import IterDataPipe
@@ -8,12 +10,34 @@ class IterableWrapperIterDataPipe(IterDataPipe):
Args:
iterable: Iterable object to be wrapped into an IterDataPipe
deepcopy: Option to deepcopy the input iterable object for each
iteration.
.. note::
If `deepcopy` is set to False explicitly, users should ensure
that the data pipeline doesn't contain any in-place operations over
the iterable instance, in order to prevent data inconsistency
across iterations.
"""
def __init__(self, iterable):
def __init__(self, iterable, deepcopy=True):
self.iterable = iterable
self.deepcopy = deepcopy
def __iter__(self):
for data in self.iterable:
source_data = self.iterable
if self.deepcopy:
try:
source_data = copy.deepcopy(self.iterable)
# If the data cannot be deep-copied,
# any in-place operations will affect the wrapped iterable.
# When this DataPipe is iterated a second time, it will
# yield the modified items.
except TypeError:
warnings.warn(
"The input iterable can not be deepcopied, "
"please be aware of in-place modification would affect source data"
)
for data in source_data:
yield data
def __len__(self):
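A short usage sketch of the `deepcopy` option added above, assuming `IterableWrapper` is exposed from `torch.utils.data.datapipes.iter` as in the linked `utils.py`:

```python
from torch.utils.data.datapipes.iter import IterableWrapper

source = [{"label": 0}, {"label": 1}]

dp = IterableWrapper(source)      # deepcopy=True by default
for item in dp:
    item["label"] += 10           # mutates the per-iteration copy only
print(source)                     # [{'label': 0}, {'label': 1}]

dp = IterableWrapper(source, deepcopy=False)
for item in dp:
    item["label"] += 10           # now the source list is modified in place
print(source)                     # [{'label': 10}, {'label': 11}]
```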


@@ -1,5 +1,5 @@
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple, deprecation_warning_torchdata
from typing import Iterable, Iterator, Tuple, IO, cast
from io import BufferedIOBase
@@ -31,11 +31,13 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
super().__init__()
self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
self.length: int = length
deprecation_warning_torchdata(type(self).__name__)
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
for data in self.datapipe:
validate_pathname_binary_tuple(data)
pathname, data_stream = data
folder_name = os.path.dirname(pathname)
try:
# typing.cast is used here to silence mypy's type checker
zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
@@ -47,7 +49,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
elif zipinfo.filename.endswith('/'):
continue
extracted_fobj = zips.open(zipinfo)
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
inner_pathname = os.path.normpath(os.path.join(folder_name, zipinfo.filename))
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
except Exception as e:
warnings.warn(


@@ -63,3 +63,9 @@ def validate_pathname_binary_tuple(data):
raise TypeError("pathname binary tuple should have string type pathname, but got {}".format(type(data[0])))
if not isinstance(data[1], BufferedIOBase):
raise TypeError("pathname binary tuple should have BufferedIOBase based binary type, but got {}".format(type(data[1])))
# Warns the user that the DataPipe has been moved to TorchData and will be removed from `torch`
def deprecation_warning_torchdata(name):
warnings.warn(f"{name} and its functional API are deprecated and will be removed from the package `torch`. "
f"Please import those features from the new package TorchData: https://github.com/pytorch/data",
DeprecationWarning)
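A small sketch of the helper in use; `DeprecationWarning` is hidden by default outside `__main__`, so a filter is needed to surface the message (the name passed in is illustrative):

```python
import warnings
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # surface DeprecationWarning
    deprecation_warning_torchdata("LineReader")

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(caught[-1].message)
```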


@@ -112,15 +112,18 @@ class RandomSampler(Sampler[int]):
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
self.generator = torch.Generator()
self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
yield from torch.randperm(n, generator=self.generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()
def __len__(self) -> int:
return self.num_samples
@@ -140,7 +143,8 @@ class SubsetRandomSampler(Sampler[int]):
self.generator = generator
def __iter__(self) -> Iterator[int]:
return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
def __len__(self) -> int:
return len(self.indices)
@@ -183,7 +187,7 @@ class WeightedRandomSampler(Sampler[int]):
def __iter__(self) -> Iterator[int]:
rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
return iter(rand_tensor.tolist())
yield from iter(rand_tensor.tolist())
def __len__(self) -> int:
return self.num_samples
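To illustrate the `RandomSampler` change above: when no generator is supplied, a fresh seed is drawn on every `__iter__` call instead of caching a generator on `self`, so repeated iteration keeps producing new permutations, while passing an explicit generator keeps runs reproducible. A rough sketch:

```python
import torch
from torch.utils.data import RandomSampler

sampler = RandomSampler(range(5))
print(list(sampler))  # e.g. [3, 0, 4, 1, 2]
print(list(sampler))  # a different permutation; a new seed is drawn per __iter__

g = torch.Generator()
g.manual_seed(0)
sampler = RandomSampler(range(5), generator=g)
print(list(sampler))  # reproducible across runs given the fixed seed
```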