Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 16:04:58 +08:00)
Compare commits: mlazos/dyn ... v1.10.0-rc (17 commits)
| SHA1 |
|---|
| 9509e8a3d6 |
| 1774a6a2f4 |
| a27906c250 |
| 49f52b6c07 |
| 5f1a434599 |
| ecbf5a7439 |
| 4e3ebebcff |
| 2b46c95e7c |
| 5f3eee1ca5 |
| 4731f33d02 |
| ecfcb8ff5a |
| 6aadfda9e2 |
| 13666d20fd |
| 1fa17a20fc |
| c05547fa6c |
| 0e857bf109 |
| ad22804b95 |

.circleci/config.yml (generated, 14 changes)
@@ -632,7 +632,8 @@ jobs:
           }
 
           if is_vanilla_build; then
-            echo "apt-get update && apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
+            echo "apt-get update || apt-get install libgnutls30" | docker exec -u root -i "$id" bash
+            echo "apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
             echo "cd workspace/build; qemu-x86_64 -g 2345 -cpu Broadwell -E ATEN_CPU_CAPABILITY=default ./bin/basic --gtest_filter=BasicTest.BasicTestCPU & gdb ./bin/basic -ex 'set pagination off' -ex 'target remote :2345' -ex 'continue' -ex 'bt' -ex='set confirm off' -ex 'quit \$_isvoid(\$_exitcode)'" | docker exec -u jenkins -i "$id" bash
           else
             echo "Skipping for ${BUILD_ENVIRONMENT}"
@@ -1734,16 +1735,17 @@ jobs:
             # install fastlane
             sudo gem install bundler && bundle install
             # install certificates
-            echo ${IOS_CERT_KEY} >> cert.txt
+            echo ${IOS_CERT_KEY_2022} >> cert.txt
             base64 --decode cert.txt -o Certificates.p12
             rm cert.txt
-            bundle exec fastlane install_cert
+            bundle exec fastlane install_root_cert
+            bundle exec fastlane install_dev_cert
             # install the provisioning profile
-            PROFILE=PyTorch_CI_2021.mobileprovision
+            PROFILE=PyTorch_CI_2022.mobileprovision
             PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
             mkdir -pv "${PROVISIONING_PROFILES}"
             cd "${PROVISIONING_PROFILES}"
-            echo ${IOS_SIGN_KEY} >> cert.txt
+            echo ${IOS_SIGN_KEY_2022} >> cert.txt
             base64 --decode cert.txt -o ${PROFILE}
             rm cert.txt
       - run:
@@ -1802,7 +1804,7 @@ jobs:
           command: |
             set -e
             PROJ_ROOT=/Users/distiller/project
-            PROFILE=PyTorch_CI_2021
+            PROFILE=PyTorch_CI_2022
             # run the ruby build script
             if ! [ -x "$(command -v xcodebuild)" ]; then
               echo 'Error: xcodebuild is not installed.'
@@ -61,7 +61,7 @@ git --no-pager log --max-count 1
 popd
 
 # Clone the Builder master repo
-retry git clone -q https://github.com/pytorch/builder.git "$BUILDER_ROOT"
+retry git clone -q https://github.com/pytorch/builder.git -b release/1.10 "$BUILDER_ROOT"
 pushd "$BUILDER_ROOT"
 echo "Using builder from "
 git --no-pager log --max-count 1
@@ -8,16 +8,17 @@ cd ${PROJ_ROOT}/ios/TestApp
 # install fastlane
 sudo gem install bundler && bundle install
 # install certificates
-echo "${IOS_CERT_KEY}" >> cert.txt
+echo "${IOS_CERT_KEY_2022}" >> cert.txt
 base64 --decode cert.txt -o Certificates.p12
 rm cert.txt
-bundle exec fastlane install_cert
+bundle exec fastlane install_root_cert
+bundle exec fastlane install_dev_cert
 # install the provisioning profile
-PROFILE=PyTorch_CI_2021.mobileprovision
+PROFILE=PyTorch_CI_2022.mobileprovision
 PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
 mkdir -pv "${PROVISIONING_PROFILES}"
 cd "${PROVISIONING_PROFILES}"
-echo "${IOS_SIGN_KEY}" >> cert.txt
+echo "${IOS_SIGN_KEY_2022}" >> cert.txt
 base64 --decode cert.txt -o ${PROFILE}
 rm cert.txt
 # run the ruby build script
@@ -25,5 +26,5 @@ if ! [ -x "$(command -v xcodebuild)" ]; then
     echo 'Error: xcodebuild is not installed.'
     exit 1
 fi
-PROFILE=PyTorch_CI_2021
+PROFILE=PyTorch_CI_2022
 ruby ${PROJ_ROOT}/scripts/xcode_build.rb -i ${PROJ_ROOT}/build_ios/install -x ${PROJ_ROOT}/ios/TestApp/TestApp.xcodeproj -p ${IOS_PLATFORM} -c ${PROFILE} -t ${IOS_DEV_TEAM_ID} -f Accelerate,MetalPerformanceShaders,CoreML
@@ -467,16 +467,17 @@
             # install fastlane
             sudo gem install bundler && bundle install
             # install certificates
-            echo ${IOS_CERT_KEY} >> cert.txt
+            echo ${IOS_CERT_KEY_2022} >> cert.txt
             base64 --decode cert.txt -o Certificates.p12
             rm cert.txt
-            bundle exec fastlane install_cert
+            bundle exec fastlane install_root_cert
+            bundle exec fastlane install_dev_cert
             # install the provisioning profile
-            PROFILE=PyTorch_CI_2021.mobileprovision
+            PROFILE=PyTorch_CI_2022.mobileprovision
             PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
             mkdir -pv "${PROVISIONING_PROFILES}"
             cd "${PROVISIONING_PROFILES}"
-            echo ${IOS_SIGN_KEY} >> cert.txt
+            echo ${IOS_SIGN_KEY_2022} >> cert.txt
             base64 --decode cert.txt -o ${PROFILE}
             rm cert.txt
       - run:
@@ -535,7 +536,7 @@
           command: |
             set -e
             PROJ_ROOT=/Users/distiller/project
-            PROFILE=PyTorch_CI_2021
+            PROFILE=PyTorch_CI_2022
             # run the ruby build script
             if ! [ -x "$(command -v xcodebuild)" ]; then
               echo 'Error: xcodebuild is not installed.'
@@ -158,7 +158,8 @@ jobs:
           }
 
           if is_vanilla_build; then
-            echo "apt-get update && apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
+            echo "apt-get update || apt-get install libgnutls30" | docker exec -u root -i "$id" bash
+            echo "apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
             echo "cd workspace/build; qemu-x86_64 -g 2345 -cpu Broadwell -E ATEN_CPU_CAPABILITY=default ./bin/basic --gtest_filter=BasicTest.BasicTestCPU & gdb ./bin/basic -ex 'set pagination off' -ex 'target remote :2345' -ex 'continue' -ex 'bt' -ex='set confirm off' -ex 'quit \$_isvoid(\$_exitcode)'" | docker exec -u jenkins -i "$id" bash
           else
             echo "Skipping for ${BUILD_ENVIRONMENT}"
@@ -63,9 +63,9 @@ function get_pr_change_files() {
 function file_diff_from_base() {
   # The fetch may fail on Docker hosts, this fetch is necessary for GHA
   set +e
-  git fetch origin master --quiet
+  git fetch origin release/1.10 --quiet
   set -e
-  git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
+  git diff --name-only "$(git merge-base origin/release/1.10 HEAD)" > "$1"
 }
 
 function get_bazel() {
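For readers following along: `file_diff_from_base` collects the files a change touches relative to the release branch by diffing against the merge-base. A rough Python sketch of the same logic (a hypothetical helper, shelling out to the git CLI; it is not part of the diff above):

    import subprocess

    def file_diff_from_base(base_ref: str = "origin/release/1.10") -> list:
        # The fetch may fail (e.g. on Docker hosts), so ignore its exit code.
        subprocess.run(["git", "fetch", "origin", "release/1.10", "--quiet"], check=False)
        # Find the common ancestor of HEAD and the base branch...
        merge_base = subprocess.run(
            ["git", "merge-base", base_ref, "HEAD"],
            check=True, capture_output=True, text=True,
        ).stdout.strip()
        # ...and list every file that differs from it.
        out = subprocess.run(
            ["git", "diff", "--name-only", merge_base],
            check=True, capture_output=True, text=True,
        ).stdout
        return out.splitlines()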
@@ -99,5 +99,5 @@ function checkout_install_torchvision() {
 }
 
 function clone_pytorch_xla() {
-  git clone --recursive https://github.com/pytorch/xla.git
+  git clone --recursive -b r1.10 https://github.com/pytorch/xla.git
 }
@@ -424,7 +424,7 @@ test_backward_compatibility() {
   python -m venv venv
   # shellcheck disable=SC1091
   . venv/bin/activate
-  pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+  pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_nightly.html
   pip show torch
   python dump_all_function_schemas.py --filename nightly_schemas.txt
   deactivate
@@ -240,14 +240,11 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
 }
 
 Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
-  TORCH_CHECK(x1.ndimension() == x2.ndimension(), "cosine_similarity requires both inputs to have the same number of dimensions, but x1 has ",
-              x1.ndimension(), " and x2 has ", x2.ndimension());
-  TORCH_CHECK(x1.ndimension() == 0 || x1.size(dim) == x2.size(dim), "cosine_similarity requires both inputs to have the same size at dimension ", dim, "but x1 has ",
-  x1.size(dim), " and x2 has ", x2.size(dim));
+  auto common_size = at::infer_size_dimvector(x1.sizes(), x2.sizes());
   auto commonDtype = at::result_type(x1, x2);
   TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype);
-  Tensor x1_ = x1.to(commonDtype);
-  Tensor x2_ = x2.to(commonDtype);
+  Tensor x1_ = x1.to(commonDtype).expand(common_size);
+  Tensor x2_ = x2.to(commonDtype).expand(common_size);
   // Follow scipy impl to improve numerical precision
   // Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
   Tensor w12 = at::sum(x1_ * x2_, dim);
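The change above drops the strict same-shape checks in favor of computing a broadcast `common_size` and expanding both inputs to it, which restores NumPy-style broadcasting between `x1` and `x2`. A small illustration of the resulting behavior (shapes are illustrative, assuming the broadcast semantics shown above):

    import torch
    import torch.nn.functional as F

    x1 = torch.randn(5, 10)  # a batch of 5 vectors
    x2 = torch.randn(1, 10)  # a single vector, broadcast against the batch

    # Both inputs are expanded to the common broadcast shape (5, 10)
    # before the similarity is reduced along dim=1.
    sim = F.cosine_similarity(x1, x2, dim=1)
    print(sim.shape)  # torch.Size([5])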
@@ -307,11 +307,11 @@ If the operator is an ATen operator (shows up in the TorchScript graph with the
 
 * Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
   `torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
-  Make sure the function has the same name as the ATen function, which may declared in
+  Make sure the function has the same name as the ATen function, which may be declared in
   ``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
   build time, so will not appear in your checkout until you build PyTorch).
 * The first arg is always the ONNX graph that is being built for export.
-  Other arg names must EXACTLY match the names in ``_VariableFunctions.pyi``,
+  Other arg names must EXACTLY match the names in the ``.pyi`` file,
   because dispatch is done with keyword arguments.
 * In the symbolic function, if the operator is in the
   `ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
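To make the documentation change above concrete, here is a minimal sketch of the kind of symbolic function those bullet points describe (``relu`` is used purely as an illustration; a real implementation already ships in ``symbolic_opset9.py``):

    # Sketch of an entry in torch/onnx/symbolic_opset<version>.py.
    # The function name matches the ATen function, the first argument is the
    # ONNX graph being built, and the remaining argument names must match the
    # generated .pyi declaration because dispatch uses keyword arguments.
    def relu(g, self):
        # Emit the corresponding operator from the ONNX standard operator set.
        return g.op("Relu", self)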
@@ -365,8 +365,8 @@ See the ``symbolic_opset*.py`` files for more examples.
 torch.autograd.Functions
 ^^^^^^^^^^^^^^^^^^^^^^^^
 
-If the operator is defined in a sub-class of :class:`torch.autograd.Function`,
-there are two ways to export it.
+If the operator is a sub-class of :class:`torch.autograd.Function`, there are two ways
+to export it.
 
 Static Symbolic Method
 ~~~~~~~~~~~~~~~~~~~~~~
@@ -388,11 +388,11 @@ PythonOp Symbolic
 ~~~~~~~~~~~~~~~~~
 
 Alternatively, you can register a custom symbolic function.
-This gives the symoblic function access to more info through the
+This gives the symbolic function access to more info through the
 TorchScript ``Node`` object for the original operation, which gets passed in as the second
 argument (after the ``Graph`` object).
 
-All autograd ``Function``s are emitted in the TorchScript graph as ``prim::PythonOp`` nodes.
+All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
 In order to differentiate between different ``Function`` subclasses, the
 symbolic function should use the ``name`` kwarg which gets set to the name of the class.
@@ -400,11 +400,12 @@ symbolic function should use the ``name`` kwarg which gets set to the name of th
 the ``prim`` namespace, so for this use case, there's a back door: register the
 symbolic for ``"::prim_PythonOp"``.
 
-Please also consider adding shape inference logic when you regiester a custom symbolic function
-via setType API. This can help the exporter to obtain correct shape inference.
-An example of setType is test_aten_embedding_2 in test_operators.py.
-Although it is not required to add shape inference logic,
-the exporter emits a warning message if it is not added.
+Custom symbolic functions should add type and shape information by calling ``setType(...)``
+on Value objects before returning them (implemented in C++ by
+``torch::jit::Value::setType``). This is not required, but it can help the exporter's
+shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
+``test_aten_embedding_2`` in
+`test_operators.py <https://github.com/pytorch/pytorch/blob/master/test/onnx/test_operators.py>`_.
 
 The example below shows how you can access ``requires_grad`` via the ``Node`` object::
@@ -430,13 +431,17 @@ The example below shows how you can access ``requires_grad`` via the ``Node`` ob
             print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))
 
         name = kwargs["name"]
+        ret = None
         if name == "MyClip":
-            return g.op("Clip", args[0], min_f=args[1])
+            ret = g.op("Clip", args[0], min_f=args[1])
         elif name == "MyRelu":
-            return g.op("Relu", args[0])
+            ret = g.op("Relu", args[0])
         else:
             # Logs a warning and returns None
             return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
+        # Copy type and shape from original node.
+        ret.setType(n.type())
+        return ret
 
     from torch.onnx import register_custom_op_symbolic
     register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 1)
@@ -597,5 +597,4 @@ Utilities
     are_deterministic_algorithms_enabled
     set_warn_always
     is_warn_always_enabled
-    vmap
     _assert

BIN  ios/TestApp/AppleWWDRCAG3.cer (new file; binary not shown)
@@ -4,7 +4,14 @@ platform :ios do
   before_all do
     setup_circle_ci
   end
-  lane :install_cert do
+  lane :install_root_cert do
+    import_certificate(
+      certificate_path: "AppleWWDRCAG3.cer",
+      keychain_path: "/Users/distiller/Library/Keychains/fastlane_tmp_keychain-db",
+      keychain_password: ""
+    )
+  end
+  lane :install_dev_cert do
     puts "Installing Certificates.p12"
     import_certificate(
       keychain_name: ENV["MATCH_KEYCHAIN_NAME"],
@@ -64,7 +64,7 @@ class TestExportModes(JitTestCase):
             return (a, a)
         f = io.BytesIO()
         x = torch.ones(3)
-        torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
+        torch.onnx._export(foo, (x,), f)
 
     @skipIfNoLapack
     def test_aten_fallback(self):
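Most of the test changes that follow repeat this one pattern: the `example_outputs` argument is dropped from `torch.onnx.export`, `_export`, and `export_to_pretty_string` calls, since the exporter in this release obtains output shapes and types by running the model itself. A minimal sketch of the updated call style (toy module for illustration):

    import io
    import torch

    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1

    f = io.BytesIO()
    # No example_outputs: the exporter runs the model on the supplied
    # arguments to determine the output schema on its own.
    torch.onnx.export(AddOne(), (torch.ones(3),), f, opset_version=11)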
@@ -76,9 +76,8 @@ class TestExportModes(JitTestCase):
 
         x = torch.rand(3, 4)
         y = torch.rand(3, 4)
-        f = io.BytesIO()
         torch.onnx.export_to_pretty_string(
-            ModelWithAtenNotONNXOp(), (x, y), f,
+            ModelWithAtenNotONNXOp(), (x, y), None,
             add_node_names=False,
             do_constant_folding=False,
             operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
@@ -91,11 +90,10 @@ class TestExportModes(JitTestCase):
             def forward(self, x, y):
                 return torch.fmod(x, y)
 
-        f = io.BytesIO()
         x = torch.randn(3, 4, dtype=torch.float32)
         y = torch.randn(3, 4, dtype=torch.float32)
         torch.onnx.export_to_pretty_string(
-            ModelWithAtenFmod(), (x, y), f,
+            ModelWithAtenFmod(), (x, y), None,
             add_node_names=False,
             do_constant_folding=False,
             operator_export_type=OperatorExportTypes.ONNX_ATEN)
@@ -49,20 +49,17 @@ class TestONNXExport(JitTestCase):
 
         tm = TraceMe()
         tm = torch.jit.trace(tm, torch.rand(3, 4))
-        example_outputs = (tm(torch.rand(3, 4)),)
         f = io.BytesIO()
-        torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
+        torch.onnx._export(tm, (torch.rand(3, 4),), f)
 
     def test_export_tensoroption_to(self):
         def foo(x):
             return x[0].clone().detach().cpu() + x
 
         traced = torch.jit.trace(foo, (torch.rand([2])))
-        example_outputs = traced(torch.rand([2]))
 
         f = io.BytesIO()
-        torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
-                                            example_outputs=example_outputs)
+        torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f)
 
     def test_onnx_export_script_module(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -75,10 +72,8 @@ class TestONNXExport(JitTestCase):
                 return x + x
 
         mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
         torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs)
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False)
 
     @suppress_warnings
     def test_onnx_export_func_with_warnings(self):
@@ -93,11 +88,9 @@ class TestONNXExport(JitTestCase):
             def forward(self, x):
                 return func_with_warning(x)
 
-        outputs = WarningTest()(torch.randn(42))
         # no exception
         torch.onnx.export_to_pretty_string(
-            WarningTest(), torch.randn(42), None, verbose=False,
-            example_outputs=outputs)
+            WarningTest(), torch.randn(42), None, verbose=False)
 
     def test_onnx_export_script_python_fail(self):
         class PythonModule(torch.jit.ScriptModule):
@@ -119,11 +112,9 @@ class TestONNXExport(JitTestCase):
                 return y + y
 
         mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
         f = io.BytesIO()
         with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
-            torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
-                               example_outputs=outputs)
+            torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
 
     def test_onnx_export_script_inline_trace(self):
         class ModuleToInline(torch.nn.Module):
@@ -144,10 +135,8 @@ class TestONNXExport(JitTestCase):
                 return y + y
 
         mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
         torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs)
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False)
 
     def test_onnx_export_script_inline_script(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -169,10 +158,8 @@ class TestONNXExport(JitTestCase):
                 return y + y
 
         mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
         torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs)
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False)
 
     def test_onnx_export_script_module_loop(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -189,10 +176,8 @@ class TestONNXExport(JitTestCase):
                 return x
 
         mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
         torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs)
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False)
 
     @suppress_warnings
     def test_onnx_export_script_truediv(self):
@@ -206,11 +191,9 @@ class TestONNXExport(JitTestCase):
                 return x + z
 
         mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
 
         torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False,
-            example_outputs=outputs)
+            mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False)
 
     def test_onnx_export_script_non_alpha_add_sub(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -223,10 +206,8 @@ class TestONNXExport(JitTestCase):
                 return bs - 1
 
         mte = ModuleToExport()
-        outputs = torch.LongTensor([mte(torch.rand(3, 4))])
         torch.onnx.export_to_pretty_string(
-            mte, (torch.rand(3, 4),), None, verbose=False,
-            example_outputs=outputs)
+            mte, (torch.rand(3, 4),), None, verbose=False)
 
     def test_onnx_export_script_module_if(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -240,10 +221,8 @@ class TestONNXExport(JitTestCase):
                 return x
 
         mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
         torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs)
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False)
 
     def test_onnx_export_script_inline_params(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -272,8 +251,7 @@ class TestONNXExport(JitTestCase):
         reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
         self.assertEqual(result, reference)
         torch.onnx.export_to_pretty_string(
-            mte, (torch.ones(2, 3),), None, verbose=False,
-            example_outputs=result)
+            mte, (torch.ones(2, 3),), None, verbose=False)
 
     def test_onnx_export_speculate(self):
@@ -305,18 +283,16 @@ class TestONNXExport(JitTestCase):
             return x.t()
 
         f1 = Foo(transpose)
-        outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
         f2 = Foo(linear)
-        outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
 
         torch.onnx.export_to_pretty_string(
             f1,
             (torch.ones(1, 10, dtype=torch.float), ),
-            None, verbose=False, example_outputs=outputs_f1)
+            None, verbose=False)
         torch.onnx.export_to_pretty_string(
             f2,
             (torch.ones(1, 10, dtype=torch.float), ),
-            None, verbose=False, example_outputs=outputs_f2)
+            None, verbose=False)
 
     def test_onnx_export_shape_reshape(self):
         class Foo(torch.nn.Module):
@@ -328,10 +304,8 @@ class TestONNXExport(JitTestCase):
                 return reshaped
 
         foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
-        outputs = foo(torch.zeros(1, 2, 3))
         f = io.BytesIO()
-        torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
-                                           example_outputs=outputs)
+        torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f)
 
     def test_listconstruct_erasure(self):
         class FooMod(torch.nn.Module):
@@ -360,11 +334,10 @@ class TestONNXExport(JitTestCase):
         mod = DynamicSliceExportMod()
 
         input = torch.rand(3, 4, 5)
-        example_outs = mod(input)
 
         f = io.BytesIO()
         torch.onnx.export_to_pretty_string(
-            DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10)
+            DynamicSliceExportMod(), (input,), f, opset_version=10)
 
     def test_export_dict(self):
         class DictModule(torch.nn.Module):
@@ -380,4 +353,4 @@ class TestONNXExport(JitTestCase):
 
         with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
             torch.onnx.export_to_pretty_string(
-                torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),))
+                torch.jit.script(mod), (x_in,), f)
@@ -1115,9 +1115,8 @@ class TestTracer(JitTestCase):
             def forward(self, x, w):
                 return torch.matmul(x, w).detach()
 
-        f = io.BytesIO()
         torch.onnx.export_to_pretty_string(
-            Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)
+            Mod(), (torch.rand(3, 4), torch.rand(4, 5)), None)
 
     def test_trace_slice_full_dim(self):
         def foo(x):
@@ -19,7 +19,7 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
 
             outputs = model(inputs)
             script_model = torch.jit.script(model)
-            run_model_test(self, script_model, False, example_outputs=outputs,
+            run_model_test(self, script_model, False,
                            input=inputs, rtol=rtol, atol=atol)
@@ -40,14 +40,13 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi
                     assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
 
 
-def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None,
+def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL,
                                input_names=None, dynamic_axes=None):
     for opset_version in opset_versions:
         f = io.BytesIO()
         torch.onnx.export(module, x, f,
                           opset_version=opset_version,
                           training=training,
-                          example_outputs=example_outputs,
                           input_names=input_names,
                           dynamic_axes=dynamic_axes)
         model = onnx.load(io.BytesIO(f.getvalue()))
@@ -91,10 +90,8 @@ class TestONNXOpset(TestCase):
         x = torch.arange(1., 6., requires_grad=True)
         k = torch.tensor(3)
         module = MyModuleDynamic()
-        example_output = module(x, k)
         check_onnx_opsets_operator(module, [x, k], ops,
-                                   opset_versions=[10],
-                                   example_outputs=example_output)
+                                   opset_versions=[10])
 
     def test_maxpool(self):
         module = torch.nn.MaxPool1d(2, stride=1)
@@ -191,7 +188,6 @@ class TestONNXOpset(TestCase):
 
         module = DynamicSliceModel()
         x = torch.rand(1, 2)
-        example_output = module(x)
         ops_10 = [{"op_name" : "Shape"},
                   {"op_name" : "Constant"},
                   {"op_name" : "Gather",
@@ -202,7 +198,7 @@ class TestONNXOpset(TestCase):
                   {"op_name" : "Slice",
                    "attributes" : []}]
         ops = {10 : ops_10}
-        check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output,
+        check_onnx_opsets_operator(module, x, ops, opset_versions=[10],
                                    input_names=['x'], dynamic_axes={"x": [0, 1]})
 
         ops_10 = [{"op_name" : "Constant"},
@@ -212,7 +208,7 @@ class TestONNXOpset(TestCase):
                   {"op_name" : "Slice",
                    "attributes" : []}]
         ops = {10 : ops_10}
-        check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output)
+        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
 
     def test_flip(self):
         class MyModule(Module):
@@ -279,12 +279,11 @@ class TestOperators(TestCase):
     def test_conv_variable_length(self):
         x = torch.ones(5, 3, 6, 6, requires_grad=True)
         model = torch.nn.Conv2d(3, 2, 3)
-        y = model(x)
 
         dynamic_axes = {"input_1": [0, 2, 3], "output_1": {0: "output_1_variable_dim_0", 1: "output_1_variable_dim_1"}}
         model_proto_name = "conv2d.onnx"
         torch.onnx.export(model, x, model_proto_name, verbose=True, input_names=["input_1"], output_names=["output_1"],
-                          example_outputs=y, dynamic_axes=dynamic_axes)
+                          dynamic_axes=dynamic_axes)
 
         import onnx
         onnx_model = onnx.load(model_proto_name)
@@ -729,22 +728,6 @@ class TestOperators(TestCase):
         x = torch.randn(2, 3, 4, requires_grad=True)
         self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
 
-    def test_retain_param_name_disabled(self):
-        class MyModule(Module):
-            def __init__(self):
-                super(MyModule, self).__init__()
-                self.fc1 = nn.Linear(4, 5, bias=False)
-                self.fc1.weight.data.fill_(2.)
-                self.fc2 = nn.Linear(5, 6, bias=False)
-                self.fc2.weight.data.fill_(3.)
-
-            def forward(self, x):
-                return self.fc2(self.fc1(x))
-
-        x = torch.randn(3, 4).float()
-        self.assertONNX(MyModule(), (x,), _retain_param_name=False,
-                        keep_initializers_as_inputs=True)
-
     def test_c2_op(self):
         class MyModel(torch.nn.Module):
             def __init__(self):
@@ -134,7 +134,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         return cuda_model, cuda_input
 
     def run_debug_test(self, model, train, batch_size, state_dict=None,
-                       input=None, use_gpu=True, example_outputs=None,
+                       input=None, use_gpu=True,
                        operator_export_type=torch.onnx.OperatorExportTypes.ONNX):
         """
         # TODO: remove this from the final release version
@@ -153,7 +153,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
             model, input = self.convert_cuda(model, input)
 
         onnxir, torch_out = do_export(model, input, export_params=self.embed_params, verbose=False,
-                                      example_outputs=example_outputs,
                                       do_constant_folding=False,
                                       opset_version=self.opset_version,
                                       keep_initializers_as_inputs=True,
 | 
			
		||||
 | 
			
		||||
    def run_actual_test(self, model, train, batch_size, state_dict=None,
 | 
			
		||||
                        input=None, use_gpu=True, rtol=0.001, atol=1e-7,
 | 
			
		||||
                        example_outputs=None, do_constant_folding=True,
 | 
			
		||||
                        do_constant_folding=True,
 | 
			
		||||
                        operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
 | 
			
		||||
                        input_names=None, dynamic_axes=None,
 | 
			
		||||
                        remained_onnx_input_idx=None):
 | 
			
		||||
@@ -191,7 +190,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         # Verify the model runs the same in Caffe2
         verify.verify(model, input, c2, rtol=rtol, atol=atol,
-                      example_outputs=example_outputs,
                       do_constant_folding=do_constant_folding,
                       opset_version=self.opset_version,
                       keep_initializers_as_inputs=True,
@@ -202,7 +200,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
     def run_model_test(self, model, train, batch_size, state_dict=None,
                        input=None, use_gpu=True, rtol=0.001, atol=1e-7,
-                       example_outputs=None, do_constant_folding=True,
+                       do_constant_folding=True,
                        operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
                        input_names=None, dynamic_axes=None,
                        remained_onnx_input_idx=None):
@@ -214,7 +212,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         if self.embed_params:
             self.run_actual_test(model, train, batch_size, state_dict, input,
                                  use_gpu=use_gpu_, rtol=rtol, atol=atol,
-                                 example_outputs=example_outputs,
                                  do_constant_folding=do_constant_folding,
                                  operator_export_type=operator_export_type,
                                  input_names=input_names,
@@ -222,8 +219,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
                                  remained_onnx_input_idx=remained_onnx_input_idx)
         else:
             self.run_debug_test(model, train, batch_size, state_dict, input,
-                                use_gpu=use_gpu_, example_outputs=example_outputs,
-                                operator_export_type=operator_export_type)
+                                use_gpu=use_gpu_, operator_export_type=operator_export_type)
 
     def test_linear(self):
         class MyModel(torch.nn.Module):
@@ -289,7 +285,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         # This test checks that given this edge condition, the model can be loaded and executed
         # in Caffe2 backend correctly.
         torch.onnx._export(model, input, f, verbose=True, export_type=ExportTypes.ZIP_ARCHIVE,
-                           input_names=["input1", "fc1.bias"], _retain_param_name=False,
+                           input_names=["input1", "fc1.bias"],
                            keep_initializers_as_inputs=True)
 
         f.seek(0)
 | 
			
		||||
                return x[1:x.size(0)]
 | 
			
		||||
        module = DynamicSliceModel()
 | 
			
		||||
        x = torch.rand(1, 2)
 | 
			
		||||
        example_output = module(x)
 | 
			
		||||
        self.run_model_test(DynamicSliceModel(), train=False, input=(x,),
 | 
			
		||||
                            batch_size=BATCH_SIZE, use_gpu=False, example_outputs=example_output)
 | 
			
		||||
                            batch_size=BATCH_SIZE, use_gpu=False)
 | 
			
		||||
 | 
			
		||||
    @skipIfUnsupportedMinOpsetVersion(11)
 | 
			
		||||
    def test_dynamic_slice_to_the_end(self):
 | 
			
		||||
@@ -1472,7 +1467,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         y = torch.ones(2, 3, 4) * 2
         self.run_model_test(Arithmetic(),
                             train=False, input=(), batch_size=BATCH_SIZE,
-                            use_gpu=False, example_outputs=(x + 3, y * (x + 3)))
+                            use_gpu=False)
 
     def test_tensor_factories(self):
         class TensorFactory(torch.nn.Module):
@@ -1493,11 +1488,9 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         x = torch.randn(2, 3, 4)
         self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
-                            use_gpu=False, example_outputs=(torch.ones(x.size()),),
-                            input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
+                            use_gpu=False, input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
         self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
-                            use_gpu=False, example_outputs=(torch.ones(x.size()),),
-                            remained_onnx_input_idx=[])
+                            use_gpu=False, remained_onnx_input_idx=[])
 
     def test_tensor_like_factories_script(self):
         class TensorFactory(torch.jit.ScriptModule):
 | 
			
		||||
 | 
			
		||||
        x = torch.randn(2, 3, 4)
 | 
			
		||||
        self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
 | 
			
		||||
                            use_gpu=False, example_outputs=(torch.ones(x.size()),),
 | 
			
		||||
                            input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
 | 
			
		||||
                            use_gpu=False, input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
 | 
			
		||||
        remained_onnx_input_idx = None if self.opset_version < 9 else []
 | 
			
		||||
        self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
 | 
			
		||||
                            use_gpu=False, example_outputs=(torch.ones(x.size()),),
 | 
			
		||||
                            remained_onnx_input_idx=remained_onnx_input_idx)
 | 
			
		||||
                            use_gpu=False, remained_onnx_input_idx=remained_onnx_input_idx)
 | 
			
		||||
 | 
			
		||||
    def test_full(self):
 | 
			
		||||
        class FullModel(torch.nn.Module):
 | 
			
		||||
@@ -1532,8 +1523,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
                 return torch.full((4, 5), x, dtype=torch.long)
 
         x = torch.tensor(12)
-        self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE,
-                            use_gpu=False, example_outputs=FullClass()(x))
+        self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
 
     def test_clamp(self):
         class ClampModel(torch.nn.Module):
@@ -2021,8 +2011,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
                 return a
 
         x = (torch.randn(3, 4), torch.randn(4, 3))
-        self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
-                            example_outputs=(x,))
+        self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
 
     def test_nested_tuple_input_output(self):
         class NestedTupleModel(torch.jit.ScriptModule):
@@ -2032,8 +2021,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         x = torch.randn(4, 5)
         y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
-        self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
-                            example_outputs=x + y[0] + y[1][0] + y[1][1])
+        self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
 
     def test_topk(self):
         class TopKModel(torch.nn.Module):
@@ -2050,7 +2038,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
                 return torch.topk(input, 3, dim=0)
 
         x = torch.randn(4, 3, requires_grad=True)
-        self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))
+        self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
 
     def test_floor(self):
         class FloorModel(torch.nn.Module):
@@ -2086,9 +2074,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
                 return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
 
         x = torch.randn(3, 4, requires_grad=True)
-        outputs = ArangeScript()(x)
-        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
 
         class ArangeModel(torch.nn.Module):
             def forward(self, a):
@@ -2104,9 +2090,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
                 return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
 
         x = torch.randn(3, 4, requires_grad=True)
-        outputs = ArangeScript()(x)
-        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
 
         class ArangeModel(torch.nn.Module):
             def forward(self, a):
@@ -2122,9 +2106,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
                 return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
 
         x = torch.randn(3, 4, requires_grad=True)
-        outputs = ArangeScript()(x)
-        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
 
         class ArangeModel(torch.nn.Module):
             def forward(self, a):
@@ -2249,9 +2231,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         model = WhileModel()
         inputs = torch.zeros(1, 2, 3, dtype=torch.long)
-        outputs = model(inputs)
-        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
 
     def test_while_cond(self):
         class WhileModel(torch.jit.ScriptModule):
@@ -2266,9 +2246,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         model = WhileModel()
         x = torch.zeros(1, 2, 3, dtype=torch.long)
         a = torch.tensor([0], dtype=torch.long)
-        outputs = model(x, a)
-        self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE)
 
     @unittest.skip("Disabled due to onnx optimizer deprecation")
     def test_loop(self):
@@ -2281,9 +2259,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         model = LoopModel()
         inputs = torch.zeros(1, 2, 3, dtype=torch.long)
-        outputs = model(inputs)
-        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
 
     @unittest.skip("Disabled due to onnx optimizer deprecation")
     def test_dynamic_loop(self):
@@ -2296,9 +2272,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         model = LoopModel()
         inputs = torch.zeros(1, 2, 3, dtype=torch.long)
-        outputs = model(inputs)
-        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
 
     @unittest.skip("Disabled due to onnx optimizer deprecation")
     @skipIfUnsupportedMinOpsetVersion(9)
@@ -2317,9 +2291,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         model = NestedLoopsModel()
         inputs = torch.zeros(1, 2, 3, dtype=torch.long)
-        outputs = model(inputs)
-        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
 
     def test_select(self):
         class SelectModel(torch.nn.Module):
@@ -2337,9 +2309,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         model = StandardDeviation()
         inputs = torch.randn(2, 3, 4)
-        outputs = model(inputs)
-        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
 
     def test_std_along_dims(self):
         class StandardDeviationAlongDims(torch.nn.Module):
@@ -2348,9 +2318,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 
         model = StandardDeviationAlongDims()
         inputs = torch.randn(2, 3, 4)
-        outputs = model(inputs)
-        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
 
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_masked_fill(self):
@@ -2379,9 +2347,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         y = torch.zeros(4, requires_grad=True)
         z = torch.ones(5, requires_grad=True)
         model = MeshgridModel()
-        outputs = model(x, y, z)
-        self.run_model_test(model, train=False, input=(x, y, z), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(x, y, z), batch_size=BATCH_SIZE)
 
     def test_remainder(self):
         class RemainderModel(torch.nn.Module):
@@ -2391,9 +2357,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         x = torch.randn(4, 2, 3)
         y = torch.randn(1, 2, 1)
         model = RemainderModel()
-        outputs = model(x, y)
-        self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE,
-                            example_outputs=(outputs,))
+        self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE)
 
     def test_remainder_scalar(self):
         class RemainderModel(torch.nn.Module):
@ -2402,9 +2366,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        inputs = torch.randint(10, (2, 3))
 | 
			
		||||
        model = RemainderModel()
 | 
			
		||||
        outputs = model(inputs)
 | 
			
		||||
        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
 | 
			
		||||
                            example_outputs=(outputs,))
 | 
			
		||||
        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
 | 
			
		||||
 | 
			
		||||
    def test_baddbmm(self):
 | 
			
		||||
        class MyModule(torch.nn.Module):
 | 
			
		||||
@ -2423,9 +2385,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        model = GeluModel()
 | 
			
		||||
        inputs = torch.randn(2, 4, 5, 6, requires_grad=True)
 | 
			
		||||
        outputs = model(inputs)
 | 
			
		||||
        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
 | 
			
		||||
                            example_outputs=(outputs,))
 | 
			
		||||
        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
 | 
			
		||||
 | 
			
		||||
    @skipIfUnsupportedMinOpsetVersion(9)
 | 
			
		||||
    def test_index_fill(self):
 | 
			
		||||
 | 
			
		||||
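The hunks above all make the same mechanical change: the deprecated `example_outputs` argument is dropped from `torch.onnx.export` and from the `run_model_test` helper built on it, because the exporter now infers output metadata on its own. A minimal before/after sketch (module and shapes hypothetical):

    import io
    import torch

    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1

    model = torch.jit.script(AddOne())
    x = torch.randn(2, 3)
    f = io.BytesIO()
    # Before 1.10, exporting a ScriptModule required example outputs:
    #   torch.onnx.export(model, (x,), f, example_outputs=(model(x),))
    # Now the same export works without them:
    torch.onnx.export(model, (x,), f, opset_version=9)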
@@ -28,7 +28,7 @@ class TestQuantizedOps(unittest.TestCase):
         output = q_model(*pt_inputs)
 
         f = io.BytesIO()
-        torch.onnx.export(q_model, pt_inputs, f, input_names=input_names, example_outputs=output,
+        torch.onnx.export(q_model, pt_inputs, f, input_names=input_names,
                           operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
         f.seek(0)
         onnx_model = onnx.load(f)
@@ -84,8 +84,6 @@ class TestQuantizedOps(unittest.TestCase):
         self.generic_unary_test(torch.nn.ReLU())
 
     def export_to_onnx(self, model, input, input_names):
-        outputs = model(input)
-
         traced = torch.jit.trace(model, input)
         buf = io.BytesIO()
         torch.jit.save(traced, buf)
@@ -93,7 +91,7 @@ class TestQuantizedOps(unittest.TestCase):
 
         model = torch.jit.load(buf)
         f = io.BytesIO()
-        torch.onnx.export(model, input, f, input_names=input_names, example_outputs=outputs,
+        torch.onnx.export(model, input, f, input_names=input_names,
                           operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
         f.seek(0)
 
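`export_to_onnx` keeps its trace/save/load round trip so the exporter always sees a freshly loaded ScriptModule; only the `example_outputs` plumbing is gone. A standalone sketch of that round trip (module hypothetical):

    import io
    import torch

    class Small(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    x = torch.randn(4)
    traced = torch.jit.trace(Small(), x)

    # Round-trip through an in-memory buffer, as the test helper does.
    buf = io.BytesIO()
    torch.jit.save(traced, buf)
    buf.seek(0)
    model = torch.jit.load(buf)

    f = io.BytesIO()
    torch.onnx.export(model, x, f, input_names=["x"],
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)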
@@ -45,9 +45,9 @@ def to_numpy(tensor):
     else:
         return tensor.cpu().numpy()
 
-def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
-                    do_constant_folding=True, keep_initializers_as_inputs=True,
-                    dynamic_axes=None, input_names=None, output_names=None,
+def convert_to_onnx(model, input=None, opset_version=9, do_constant_folding=True,
+                    keep_initializers_as_inputs=True, dynamic_axes=None,
+                    input_names=None, output_names=None,
                     fixed_batch_size=False, training=None,
                     onnx_shape_inference=True):
     # export the model to ONNX
@@ -55,7 +55,6 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
     input_copy = copy.deepcopy(input)
     torch.onnx._export(model, input_copy, f,
                        opset_version=opset_version,
-                       example_outputs=example_outputs,
                        do_constant_folding=do_constant_folding,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        dynamic_axes=dynamic_axes,
@@ -132,7 +131,7 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
             input = input + ({},)
 
         ort_sess = convert_to_onnx(model, input=input, opset_version=self.opset_version,
-                                   example_outputs=output, do_constant_folding=do_constant_folding,
+                                   do_constant_folding=do_constant_folding,
                                    keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                                    dynamic_axes=dynamic_axes, input_names=input_names,
                                    output_names=output_names, fixed_batch_size=fixed_batch_size, training=training,
@@ -284,9 +283,9 @@ class TestONNXRuntime(unittest.TestCase):
         _run_test(model, tracing_remained_onnx_input_idx)
 
     def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
-                                          example_outputs=None, do_constant_folding=True,
-                                          dynamic_axes=None, input_names=None, output_names=None,
-                                          ort_optim_on=True, training=None):
+                                          do_constant_folding=True, dynamic_axes=None,
+                                          input_names=None, output_names=None,
+                                          ort_optim_on=True, training=None, use_external_data_format=None):
         import os
         import tempfile
 
@@ -310,13 +309,12 @@ class TestONNXRuntime(unittest.TestCase):
                 input_copy = copy.deepcopy(input)
                 torch.onnx.export(model, input_copy, model_file_name,
                                   opset_version=self.opset_version,
-                                  example_outputs=output,
                                   verbose=False,
                                   do_constant_folding=do_constant_folding,
                                   keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                                   dynamic_axes=dynamic_axes,
                                   input_names=input_names, output_names=output_names,
-                                  use_external_data_format=True)
+                                  use_external_data_format=use_external_data_format)
                 # compute onnxruntime output prediction
                 ort_sess_opt = onnxruntime.SessionOptions()
                 ort_sess_opt.graph_optimization_level = \
@@ -366,19 +364,55 @@ class TestONNXRuntime(unittest.TestCase):
                 return x + torch.ones(2, 1024)
 
         x = torch.randn(2, 1)
-        self.run_model_test_with_external_data(LargeModel(), x)
+        self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=None)
 
     @skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
-    @unittest.skip("Enable this once large model with subgraph is supported in ORT")
-    def test_subgraph_with_external_data(self):
+    def test_largemodel_without_use_external_data_format_param(self):
         class LargeModel(torch.nn.Module):
-            def forward(self, x):
-                for i in range(x.size(0)):
-                    x = x + torch.ones(2, 1024)
-                return x
+            def __init__(self):
+                super(LargeModel, self).__init__()
+                dim = 5
+                n = 40 * 4 * 10 ** 6
+                self.emb = torch.nn.Embedding(n, dim)
+                self.lin1 = torch.nn.Linear(dim, 1)
+                self.seq = torch.nn.Sequential(
+                    self.emb,
+                    self.lin1,
+                )
 
-        x = torch.randn(2, 1)
-        self.run_model_test_with_external_data(torch.jit.script(LargeModel()), x)
+            def forward(self, input):
+                return self.seq(input)
+
+        model = LargeModel()
+        x = torch.tensor([2], dtype=torch.long)
+        self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=None)
+
+    @skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
+    def test_largemodel_with_use_external_data_format_False(self):
+        class LargeModel(torch.nn.Module):
+            def __init__(self):
+                super(LargeModel, self).__init__()
+                dim = 5
+                n = 30 * 4 * 10 ** 6
+                self.emb = torch.nn.Embedding(n, dim)
+                self.lin1 = torch.nn.Linear(dim, 1)
+                self.seq = torch.nn.Sequential(
+                    self.emb,
+                    self.lin1,
+                )
+
+            def forward(self, input):
+                return self.seq(input)
+
+        model = LargeModel()
+        x = torch.tensor([3], dtype=torch.long)
+
+        with self.assertRaises(RuntimeError) as cm:
+            self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=False)
+
+            the_exception = cm.exception
+            self.assertEqual("RuntimeError: Exporting model exceed maximum protobuf size of 2GB. " +
+                             "Please call torch.onnx.export without setting use_external_data_format parameter.")
 
     def test_fuse_conv_bn1d(self):
         class Fuse(torch.nn.Module):
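The rewritten test and the new `..._False` variant pin down `use_external_data_format`: passing `None` lets the exporter decide (necessary for the 160M-row embedding above, whose 800M floats alone exceed the 2GB protobuf limit), while forcing `False` on such a model must raise. A sketch of the user-facing call (model size hypothetical; external data needs a file path rather than a buffer):

    import torch

    # Weights alone are ~3.2GB, past the 2GB protobuf ceiling.
    big = torch.nn.Embedding(40 * 4 * 10 ** 6, 5)
    x = torch.tensor([2], dtype=torch.long)
    torch.onnx.export(big, x, "big_model.onnx", opset_version=9,
                      use_external_data_format=True)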
@@ -7784,7 +7818,6 @@ class TestONNXRuntime(unittest.TestCase):
         script_model = torch.jit.script(model)
         output = model(x)
         ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
-                                   example_outputs=output,
                                    training=torch.onnx.TrainingMode.TRAINING)
         ort_outs = run_ort(ort_sess, input=(x,))
         assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
@@ -7828,7 +7861,6 @@ class TestONNXRuntime(unittest.TestCase):
         y = model(input)
         output = y.cpu().numpy()
         ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
-                                   example_outputs=y,
                                    training=torch.onnx.TrainingMode.TRAINING)
         ort_outs = run_ort(ort_sess, input=(x,))
         ort_mask = np.where(ort_outs[0] != 0, 1, 0)
@@ -7913,11 +7945,9 @@ class TestONNXRuntime(unittest.TestCase):
         model = torch.jit.script(MyModule())
         box_regression = torch.randn([4, 4])
         proposal = [torch.randn(2, 4), torch.randn(2, 4)]
-        outputs = model(box_regression, proposal)
 
         with self.assertRaises(RuntimeError) as cm:
-            convert_to_onnx(model, input=(box_regression, proposal),
-                            example_outputs=outputs)
+            convert_to_onnx(model, input=(box_regression, proposal))
 
     def test_initializer_sequence(self):
         class MyModule(torch.nn.Module):
@@ -7939,7 +7969,7 @@ class TestONNXRuntime(unittest.TestCase):
 
         x = torch.randn(32, 3)
         f = io.BytesIO()
-        torch.onnx._export(test_model, (x,), f, _retain_param_name=True, do_constant_folding=False)
+        torch.onnx._export(test_model, (x,), f, do_constant_folding=False)
         loaded_model = onnx.load_from_string(f.getvalue())
 
         actual_list = [p.name for p in loaded_model.graph.initializer]
@@ -7986,10 +8016,9 @@ class TestONNXRuntime(unittest.TestCase):
 
         x = torch.ones(2, 3, dtype=torch.float)
         y = torch.tensor(5, dtype=torch.long)
-        example_output = (test_model(x, y),)
         f = io.BytesIO()
 
-        torch.onnx.export(test_model, (x, y), f, example_outputs=example_output, _retain_param_name=True, do_constant_folding=False)
+        torch.onnx.export(test_model, (x, y), f, do_constant_folding=False)
         loaded_model = onnx.load_from_string(f.getvalue())
 
         actual_list = [p.name for p in loaded_model.graph.initializer]
@@ -32,7 +32,6 @@ class TestUtilityFuns(TestCase):
 
     def _model_to_graph(self, model, input,
                         do_constant_folding=True,
-                        example_outputs=None,
                         training=TrainingMode.EVAL,
                         operator_export_type=OperatorExportTypes.ONNX,
                         input_names=None,
@@ -49,7 +48,6 @@ class TestUtilityFuns(TestCase):
                                                               _disable_torch_constant_prop=True,
                                                               operator_export_type=operator_export_type,
                                                               training=training,
-                                                              example_outputs=example_outputs,
                                                               input_names=input_names,
                                                               dynamic_axes=dynamic_axes)
         _set_onnx_shape_inference(True)
@@ -116,11 +114,11 @@ class TestUtilityFuns(TestCase):
         example_output = model(input_t, n)
 
         with self.assertRaises(RuntimeError):
-            torch.onnx.export(model,
-                              (input_t, n),
-                              "test.onnx",
-                              opset_version=self.opset_version,
-                              example_outputs=[example_output])
+            torch.onnx._export(model,
+                               (input_t, n),
+                               "test.onnx",
+                               opset_version=self.opset_version,
+                               example_outputs=[example_output])
 
     def test_constant_fold_transpose(self):
         class TransposeModule(torch.nn.Module):
@@ -558,27 +556,27 @@ class TestUtilityFuns(TestCase):
             assert node.kind() != "onnx::Shape"
         assert len(list(graph.nodes())) == 1
 
-    def test_strip_doc_string(self):
+    def test_verbose(self):
         class MyModule(torch.nn.Module):
             def forward(self, input):
                 return torch.exp(input)
         x = torch.randn(3, 4)
 
-        def is_model_stripped(f, strip_doc_string=None):
-            if strip_doc_string is None:
+        def is_model_stripped(f, verbose=None):
+            if verbose is None:
                 torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
             else:
-                torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string,
+                torch.onnx.export(MyModule(), x, f, verbose=verbose,
                                   opset_version=self.opset_version)
             model = onnx.load(io.BytesIO(f.getvalue()))
             model_strip = copy.copy(model)
             onnx.helper.strip_doc_string(model_strip)
             return model == model_strip
 
-        # test strip_doc_string=True (default)
+        # test verbose=False (default)
         self.assertTrue(is_model_stripped(io.BytesIO()))
-        # test strip_doc_string=False
-        self.assertFalse(is_model_stripped(io.BytesIO(), False))
+        # test verbose=True
+        self.assertFalse(is_model_stripped(io.BytesIO(), True))
 
     # NB: remove this test once DataParallel can be correctly handled
     def test_error_on_data_parallel(self):
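The renamed test reflects a behavior change: doc strings are now always stripped from the exported graph unless `verbose=True` is given, so the standalone `strip_doc_string` flag is folded into `verbose`. A sketch of the new switch (module hypothetical):

    import io
    import torch

    class Exp(torch.nn.Module):
        def forward(self, x):
            return torch.exp(x)

    x = torch.randn(3, 4)
    # Default export: doc strings stripped (the old strip_doc_string=True).
    torch.onnx.export(Exp(), x, io.BytesIO())
    # verbose=True keeps doc strings in the graph (the old strip_doc_string=False).
    torch.onnx.export(Exp(), x, io.BytesIO(), verbose=True)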
@@ -720,9 +718,8 @@ class TestUtilityFuns(TestCase):
         q_model = torch.quantization.convert(q_model, inplace=False)
 
         q_model.eval()
-        output = q_model(*pt_inputs)
 
-        graph, _, __ = self._model_to_graph(q_model, pt_inputs, example_outputs=output,
+        graph, _, __ = self._model_to_graph(q_model, pt_inputs,
                                             operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
                                             input_names=['pt_inputs'],
                                             dynamic_axes={'pt_inputs': [0, 1, 2, 3]})
@@ -749,9 +746,8 @@ class TestUtilityFuns(TestCase):
 
         x = torch.tensor([2])
         model = PrimModule()
-        output = model(x)
         model.eval()
-        graph, _, __ = self._model_to_graph(model, (x,), example_outputs=output,
+        graph, _, __ = self._model_to_graph(model, (x,),
                                             operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
                                             input_names=['x'], dynamic_axes={'x': [0]})
         iter = graph.nodes()
@@ -814,10 +810,9 @@ class TestUtilityFuns(TestCase):
 
         model = torch.jit.script(MyModule())
         x = torch.randn(10, 3, 128, 128)
-        example_outputs = model(x)
         _set_opset_version(self.opset_version)
         _set_operator_export_type(OperatorExportTypes.ONNX)
-        graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
+        graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True,
                                             operator_export_type=OperatorExportTypes.ONNX,
                                             training=torch.onnx.TrainingMode.TRAINING,
                                             input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]})
@@ -1233,8 +1233,6 @@ class TestQuantizeONNXExport(JitTestCase):
         input_names = ["x"]
 
         def export_to_onnx(model, input, input_names):
-            outputs = model(input)
-
             traced = torch.jit.trace(model, input)
             buf = io.BytesIO()
             torch.jit.save(traced, buf)
@@ -1242,7 +1240,7 @@ class TestQuantizeONNXExport(JitTestCase):
 
             model = torch.jit.load(buf)
             f = io.BytesIO()
-            torch.onnx.export(model, input, f, input_names=input_names, example_outputs=outputs,
+            torch.onnx.export(model, input, f, input_names=input_names,
                               operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
         onnx_model = export_to_onnx(model, data, input_names)
 
@@ -1524,6 +1524,28 @@ except RuntimeError as e:
         ):
             self.assertEqual(list(fn()), list(fn()))
 
+        for sampler in (
+            RandomSampler(self.dataset, num_samples=5, replacement=True),
+            RandomSampler(self.dataset, replacement=False),
+            WeightedRandomSampler(weights, num_samples=5, replacement=True),
+            WeightedRandomSampler(weights, num_samples=5, replacement=False),
+            SubsetRandomSampler(range(10)),
+        ):
+            torch.manual_seed(0)
+            l1 = list(sampler) + list(sampler)
+
+            torch.manual_seed(0)
+            l2 = list(sampler) + list(sampler)
+            self.assertEqual(l1, l2)
+
+            its = (iter(sampler), iter(sampler))
+            ls = ([], [])
+            for idx in range(len(sampler)):
+                for i in range(2):
+                    if idx == 0:
+                        torch.manual_seed(0)
+                    ls[i].append(next(its[i]))
+            self.assertEqual(ls[0], ls[1])
+
     def _test_sampler(self, **kwargs):
         indices = range(2, 12)  # using a regular iterable
 
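The added loop asserts that every stock sampler is reproducible under the global RNG: seeding with `torch.manual_seed` immediately before iteration must yield identical sequences, even across two interleaved iterators. A minimal sketch of the property:

    import torch
    from torch.utils.data import RandomSampler

    data = list(range(10))
    sampler = RandomSampler(data, num_samples=5, replacement=True)

    torch.manual_seed(0)
    first = list(sampler)
    torch.manual_seed(0)
    second = list(sampler)
    assert first == second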
@@ -1,3 +1,4 @@
+import copy
 import http.server
 import itertools
 import os
@@ -414,13 +415,30 @@ class TestIterableDataPipeBasic(TestCase):
 
         # Test Case: Uneven DataPipes
         source_numbers = list(range(0, 10)) + [10, 12]
-        numbers_dp = IDP(source_numbers)
+        numbers_dp = dp.iter.IterableWrapper(source_numbers)
         n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
         self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
         self.assertEqual([1, 3, 5, 7, 9], list(n2))
         n = n1.mux(n2)
         self.assertEqual(source_numbers, list(n))
 
+    @suppress_warnings  # Suppress warning for lambda fn
+    def test_map_with_col_file_handle_datapipe(self):
+        temp_dir = self.temp_dir.name
+        datapipe1 = dp.iter.FileLister(temp_dir, '')
+        datapipe2 = dp.iter.FileLoader(datapipe1)
+
+        def _helper(datapipe):
+            dp1 = datapipe.map(lambda x: x.read(), input_col=1)
+            dp2 = datapipe.map(lambda x: (x[0], x[1].read()))
+            self.assertEqual(list(dp1), list(dp2))
+
+        # tuple
+        _helper(datapipe2)
+        # list
+        datapipe3 = datapipe2.map(lambda x: list(x))
+        _helper(datapipe3)
+
 
 class TestDataFramesPipes(TestCase):
     """
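A sketch of the two APIs the new tests exercise (same `dp` alias as the test file; the exact column semantics are an assumption drawn from the assertions above): `demux` splits a pipe with a classifier and `mux` interleaves pipes element by element, so the pair round-trips the source, while `map(..., input_col=1)` applies the function to one column and keeps the element structure:

    import torch.utils.data.datapipes as dp

    numbers = dp.iter.IterableWrapper(list(range(10)) + [10, 12])
    evens, odds = numbers.demux(2, lambda x: x % 2)
    assert list(evens) == [0, 2, 4, 6, 8, 10, 12]
    assert list(odds) == [1, 3, 5, 7, 9]
    # mux yields one element from each pipe in turn until all are exhausted.
    assert list(evens.mux(odds)) == list(range(10)) + [10, 12]

    pairs = dp.iter.IterableWrapper([("a", 1), ("b", 2)])
    doubled = pairs.map(lambda v: v * 10, input_col=1)
    assert list(doubled) == [("a", 10), ("b", 20)]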
@@ -619,25 +637,13 @@ class IDP_NoLen(IterDataPipe):
         super().__init__()
         self.input_dp = input_dp
 
+    # Prevent in-place modification
     def __iter__(self):
-        for i in self.input_dp:
+        input_dp = self.input_dp if isinstance(self.input_dp, IterDataPipe) else copy.deepcopy(self.input_dp)
+        for i in input_dp:
             yield i
 
 
-class IDP(IterDataPipe):
-    def __init__(self, input_dp):
-        super().__init__()
-        self.input_dp = input_dp
-        self.length = len(input_dp)
-
-    def __iter__(self):
-        for i in self.input_dp:
-            yield i
-
-    def __len__(self):
-        return self.length
-
-
 class MDP(MapDataPipe):
     def __init__(self, input_dp):
         super().__init__()
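With the built-in `dp.iter.IterableWrapper` available, the tests' hand-rolled `IDP` pipe is deleted outright, and the remaining `IDP_NoLen` now deep-copies plain iterables before yielding (hence the new `import copy`) so iteration cannot mutate the caller's data. Basic usage of the replacement:

    import torch.utils.data.datapipes as dp

    pipe = dp.iter.IterableWrapper(range(10))
    assert list(pipe) == list(range(10))
    assert len(pipe) == 10  # unlike IDP_NoLen, the wrapper reports a length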
@@ -669,19 +675,19 @@ class TestFunctionalIterDataPipe(TestCase):
     def _test_picklable(self):
         arr = range(10)
         picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
-            (dp.iter.Mapper, IDP(arr), (), {}),
-            (dp.iter.Mapper, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
-            (dp.iter.Collator, IDP(arr), (), {}),
-            (dp.iter.Collator, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
-            (dp.iter.Filter, IDP(arr), (_fake_filter_fn, (0, ), {'test': True}), {}),
+            (dp.iter.Mapper, dp.iter.IterableWrapper(arr), (), {}),
+            (dp.iter.Mapper, dp.iter.IterableWrapper(arr), (_fake_fn, (0, ), {'test': True}), {}),
+            (dp.iter.Collator, dp.iter.IterableWrapper(arr), (), {}),
+            (dp.iter.Collator, dp.iter.IterableWrapper(arr), (_fake_fn, (0, ), {'test': True}), {}),
+            (dp.iter.Filter, dp.iter.IterableWrapper(arr), (_fake_filter_fn, (0, ), {'test': True}), {}),
         ]
         for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
             p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs))  # type: ignore[call-arg]
 
         unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
-            (dp.iter.Mapper, IDP(arr), (lambda x: x, ), {}),
-            (dp.iter.Collator, IDP(arr), (lambda x: x, ), {}),
-            (dp.iter.Filter, IDP(arr), (lambda x: x >= 5, ), {}),
+            (dp.iter.Mapper, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
+            (dp.iter.Collator, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
+            (dp.iter.Filter, dp.iter.IterableWrapper(arr), (lambda x: x >= 5, ), {}),
         ]
         for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
             with warnings.catch_warnings(record=True) as wa:
@@ -692,8 +698,8 @@ class TestFunctionalIterDataPipe(TestCase):
                     p = pickle.dumps(datapipe)
 
     def test_concat_datapipe(self):
-        input_dp1 = IDP(range(10))
-        input_dp2 = IDP(range(5))
+        input_dp1 = dp.iter.IterableWrapper(range(10))
+        input_dp2 = dp.iter.IterableWrapper(range(5))
 
         with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
             dp.iter.Concater()
@@ -718,7 +724,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
 
     def test_fork_datapipe(self):
-        input_dp = IDP(range(10))
+        input_dp = dp.iter.IterableWrapper(range(10))
 
         with self.assertRaises(ValueError):
             input_dp.fork(num_instances=0)
@@ -836,7 +842,7 @@ class TestFunctionalIterDataPipe(TestCase):
         self.assertEqual(len(input_dp), len(dp3))
 
     def test_demux_datapipe(self):
-        input_dp = IDP(range(10))
+        input_dp = dp.iter.IterableWrapper(range(10))
 
         with self.assertRaises(ValueError):
             input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)
@@ -882,8 +888,8 @@ class TestFunctionalIterDataPipe(TestCase):
         self.assertEqual(list(range(0, 5)), output2)
 
         # Test Case: classifer returns a value outside of [0, num_instance - 1]
-        dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
-        it = iter(dp[0])
+        dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
+        it = iter(dp0[0])
         with self.assertRaises(ValueError):
             next(it)
             next(it)
@@ -960,7 +966,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
     @suppress_warnings  # Suppress warning for lambda fn
     def test_map_datapipe(self):
-        input_dp = IDP(range(10))
+        input_dp = dp.iter.IterableWrapper(range(10))
 
         def fn(item, dtype=torch.float, *, sum=False):
             data = torch.tensor(item, dtype=dtype)
@@ -1005,7 +1011,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
         def _helper(ref_fn, fn, input_col=None, output_col=None):
             for constr in (list, tuple):
-                datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
+                datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
                 res_dp = datapipe.map(fn, input_col, output_col)
                 ref_dp = datapipe.map(ref_fn)
                 self.assertEqual(list(res_dp), list(ref_dp))
@@ -1072,9 +1078,11 @@ class TestFunctionalIterDataPipe(TestCase):
             return _data
 
         def _helper(ref_fn, fn, input_col=None, output_col=None):
-            datapipe = IDP([{"x": 0, "y": 1, "z": 2},
-                            {"x": 3, "y": 4, "z": 5},
-                            {"x": 6, "y": 7, "z": 8}])
+            datapipe = dp.iter.IterableWrapper(
+                [{"x": 0, "y": 1, "z": 2},
+                 {"x": 3, "y": 4, "z": 5},
+                 {"x": 6, "y": 7, "z": 8}]
+            )
             res_dp = datapipe.map(fn, input_col, output_col)
             ref_dp = datapipe.map(ref_fn)
             self.assertEqual(list(res_dp), list(ref_dp))
@@ -1117,7 +1125,7 @@ class TestFunctionalIterDataPipe(TestCase):
     # TODO(VitalyFedyunin): If dill installed this test fails
    def _test_map_datapipe_nested_level(self):
 
-        input_dp = IDP([list(range(10)) for _ in range(3)])
+        input_dp = dp.iter.IterableWrapper([list(range(10)) for _ in range(3)])
 
         def fn(item, *, dtype=torch.float):
             return torch.tensor(item, dtype=dtype)
@@ -1153,7 +1161,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
     def test_collate_datapipe(self):
         arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
-        input_dp = IDP(arrs)
+        input_dp = dp.iter.IterableWrapper(arrs)
 
         def _collate_fn(batch):
             return torch.tensor(sum(batch), dtype=torch.float)
@@ -1172,7 +1180,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
     def test_batch_datapipe(self):
         arrs = list(range(10))
-        input_dp = IDP(arrs)
+        input_dp = dp.iter.IterableWrapper(arrs)
         with self.assertRaises(AssertionError):
             input_dp.batch(batch_size=0)
 
@@ -1200,7 +1208,7 @@ class TestFunctionalIterDataPipe(TestCase):
     def test_unbatch_datapipe(self):
 
         target_length = 6
-        prebatch_dp = IDP(range(target_length))
+        prebatch_dp = dp.iter.IterableWrapper(range(target_length))
 
         input_dp = prebatch_dp.batch(3)
         unbatch_dp = input_dp.unbatch()
@@ -1208,13 +1216,13 @@ class TestFunctionalIterDataPipe(TestCase):
         for i, res in zip(prebatch_dp, unbatch_dp):
             self.assertEqual(i, res)
 
-        input_dp = IDP([[0, 1, 2], [3, 4, 5]])
+        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
         unbatch_dp = input_dp.unbatch()
         self.assertEqual(len(list(unbatch_dp)), target_length)
         for i, res in zip(prebatch_dp, unbatch_dp):
             self.assertEqual(i, res)
 
-        input_dp = IDP([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
+        input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
 
         unbatch_dp = input_dp.unbatch()
         expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
@@ -1233,7 +1241,7 @@ class TestFunctionalIterDataPipe(TestCase):
         for i, res in zip(expected_dp2, unbatch_dp):
             self.assertEqual(i, res)
 
-        input_dp = IDP([[0, 1, 2], [3, 4, 5]])
+        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
         with self.assertRaises(ValueError):
             unbatch_dp = input_dp.unbatch(unbatch_level=-2)
             for i in unbatch_dp:
@@ -1245,7 +1253,7 @@ class TestFunctionalIterDataPipe(TestCase):
                 print(i)
 
     def test_bucket_batch_datapipe(self):
-        input_dp = IDP(range(20))
+        input_dp = dp.iter.IterableWrapper(range(20))
         with self.assertRaises(AssertionError):
             dp.iter.BucketBatcher(input_dp, batch_size=0)
 
@@ -1258,7 +1266,7 @@ class TestFunctionalIterDataPipe(TestCase):
             data_len = 100
             arrs = list(range(data_len))
             random.shuffle(arrs)
-            input_dp = IDP(arrs)
+            input_dp = dp.iter.IterableWrapper(arrs)
             bucket_dp = dp.iter.BucketBatcher(input_dp, **kwargs)
 
             self.assertEqual(len(bucket_dp), data_len // 3 if kwargs['drop_last'] else data_len // 3 + 1)
@@ -1291,7 +1299,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
 
     def test_filter_datapipe(self):
-        input_ds = IDP(range(10))
+        input_ds = dp.iter.IterableWrapper(range(10))
 
         def _filter_fn(data, val, clip=False):
             if clip:
@@ -1318,7 +1326,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
     def test_filter_datapipe_nested_list(self):
 
-        input_ds = IDP(range(10)).batch(5)
+        input_ds = dp.iter.IterableWrapper(range(10)).batch(5)
 
         def _filter_fn(data, val):
             return data >= val
@@ -1340,7 +1348,7 @@ class TestFunctionalIterDataPipe(TestCase):
             filter_dp = input_ds.filter(nesting_level=5, filter_fn=_filter_fn, fn_kwargs={'val': 5})
             temp = list(filter_dp)
 
-        input_ds = IDP(range(10)).batch(3)
+        input_ds = dp.iter.IterableWrapper(range(10)).batch(3)
 
         filter_dp = input_ds.filter(lambda ls: len(ls) >= 3)
         expected_dp3: List[List[int]] = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
@@ -1348,21 +1356,21 @@ class TestFunctionalIterDataPipe(TestCase):
         for data, exp in zip(filter_dp, expected_dp3):
             self.assertEqual(data, exp)
 
-        input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
+        input_ds = dp.iter.IterableWrapper([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
         filter_dp = input_ds.filter(lambda x: x > 3, nesting_level=-1)
         expected_dp4 = [[[4, 5]], [[6, 7, 8]]]
         self.assertEqual(len(list(filter_dp)), len(expected_dp4))
         for data2, exp2 in zip(filter_dp, expected_dp4):
             self.assertEqual(data2, exp2)
 
-        input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
+        input_ds = dp.iter.IterableWrapper([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
         filter_dp = input_ds.filter(lambda x: x > 7, nesting_level=-1)
         expected_dp5 = [[[8]]]
         self.assertEqual(len(list(filter_dp)), len(expected_dp5))
         for data2, exp2 in zip(filter_dp, expected_dp5):
             self.assertEqual(data2, exp2)
 
-        input_ds = IDP([[[0, 1], [3, 4]], [[6, 7, 8], [1, 2, 3]]])
+        input_ds = dp.iter.IterableWrapper([[[0, 1], [3, 4]], [[6, 7, 8], [1, 2, 3]]])
         filter_dp = input_ds.filter(lambda ls: len(ls) >= 3, nesting_level=1)
         expected_dp6 = [[[6, 7, 8], [1, 2, 3]]]
         self.assertEqual(len(list(filter_dp)), len(expected_dp6))
@@ -1370,7 +1378,7 @@ class TestFunctionalIterDataPipe(TestCase):
             self.assertEqual(data2, exp2)
 
     def test_sampler_datapipe(self):
-        input_dp = IDP(range(10))
+        input_dp = dp.iter.IterableWrapper(range(10))
         # Default SequentialSampler
         sampled_dp = dp.iter.Sampler(input_dp)  # type: ignore[var-annotated]
         self.assertEqual(len(sampled_dp), 10)
@@ -1387,7 +1395,7 @@ class TestFunctionalIterDataPipe(TestCase):
 
     def test_shuffle_datapipe(self):
         exp = list(range(20))
-        input_ds = IDP(exp)
+        input_ds = dp.iter.IterableWrapper(exp)
 
         with self.assertRaises(AssertionError):
             shuffle_dp = input_ds.shuffle(buffer_size=0)
@@ -1413,15 +1421,15 @@ class TestFunctionalIterDataPipe(TestCase):
 
     def test_zip_datapipe(self):
         with self.assertRaises(TypeError):
-            dp.iter.Zipper(IDP(range(10)), list(range(10)))  # type: ignore[arg-type]
+            dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10)))  # type: ignore[arg-type]
 
-        zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP_NoLen(range(5)))  # type: ignore[var-annotated]
+        zipped_dp = dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5)))  # type: ignore[var-annotated]
         with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
             len(zipped_dp)
         exp = list((i, i) for i in range(5))
         self.assertEqual(list(zipped_dp), exp)
 
-        zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP(range(5)))
+        zipped_dp = dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5)))
         self.assertEqual(len(zipped_dp), 5)
         self.assertEqual(list(zipped_dp), exp)
         # Reset
@@ -1506,32 +1514,32 @@ class TestFunctionalMapDataPipe(TestCase):
     def test_mux_datapipe(self):
 
         # Test Case: Elements are yielded one at a time from each DataPipe, until they are all exhausted
-        input_dp1 = IDP(range(4))
-        input_dp2 = IDP(range(4, 8))
-        input_dp3 = IDP(range(8, 12))
+        input_dp1 = dp.iter.IterableWrapper(range(4))
+        input_dp2 = dp.iter.IterableWrapper(range(4, 8))
+        input_dp3 = dp.iter.IterableWrapper(range(8, 12))
         output_dp = input_dp1.mux(input_dp2, input_dp3)
         expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
         self.assertEqual(len(expected_output), len(output_dp))
         self.assertEqual(expected_output, list(output_dp))
 
         # Test Case: Uneven input Data Pipes
-        input_dp1 = IDP([1, 2, 3, 4])
-        input_dp2 = IDP([10])
-        input_dp3 = IDP([100, 200, 300])
+        input_dp1 = dp.iter.IterableWrapper([1, 2, 3, 4])
+        input_dp2 = dp.iter.IterableWrapper([10])
+        input_dp3 = dp.iter.IterableWrapper([100, 200, 300])
         output_dp = input_dp1.mux(input_dp2, input_dp3)
         expected_output = [1, 10, 100, 2, 200, 3, 300, 4]
         self.assertEqual(len(expected_output), len(output_dp))
         self.assertEqual(expected_output, list(output_dp))
 
         # Test Case: Empty Data Pipe
-        input_dp1 = IDP([0, 1, 2, 3])
-        input_dp2 = IDP([])
+        input_dp1 = dp.iter.IterableWrapper([0, 1, 2, 3])
+        input_dp2 = dp.iter.IterableWrapper([])
         output_dp = input_dp1.mux(input_dp2)
         self.assertEqual(len(input_dp1), len(output_dp))
         self.assertEqual(list(input_dp1), list(output_dp))
 
         # Test Case: raises TypeError when __len__ is called and an input doesn't have __len__
-        input_dp1 = IDP(range(10))
+        input_dp1 = dp.iter.IterableWrapper(range(10))
         input_dp_no_len = IDP_NoLen(range(10))
         output_dp = input_dp1.mux(input_dp_no_len)
         with self.assertRaises(TypeError):
@@ -1665,8 +1673,8 @@ class TestTyping(TestCase):
         self.assertTrue(issubclass(DP1, IterDataPipe))
         dp1 = DP1(10)
         self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))
-        dp2 = DP1(5)
-        self.assertEqual(dp1.type, dp2.type)
+        dp1_ = DP1(5)
+        self.assertEqual(dp1.type, dp1_.type)
 
         with self.assertRaisesRegex(TypeError, r"is not a generic class"):
             class InvalidDP5(DP1[tuple]):  # type: ignore[type-arg]
@@ -1679,10 +1687,10 @@ class TestTyping(TestCase):
                     yield d  # type: ignore[misc]
 
         self.assertTrue(issubclass(DP2, IterDataPipe))
-        dp1 = DP2()  # type: ignore[assignment]
-        self.assertTrue(DP2.type.issubtype(dp1.type) and dp1.type.issubtype(DP2.type))
-        dp2 = DP2()  # type: ignore[assignment]
-        self.assertEqual(dp1.type, dp2.type)
+        dp2 = DP2()  # type: ignore[var-annotated]
+        self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type))
+        dp2_ = DP2()  # type: ignore[var-annotated]
+        self.assertEqual(dp2.type, dp2_.type)
 
         class DP3(IterDataPipe[Tuple[T_co, str]]):
             r""" DataPipe without fixed type with __init__ function"""
@@ -1695,10 +1703,10 @@ class TestTyping(TestCase):
                     yield d, str(d)
 
         self.assertTrue(issubclass(DP3, IterDataPipe))
-        dp1 = DP3(range(10))  # type: ignore[assignment]
-        self.assertTrue(DP3.type.issubtype(dp1.type) and dp1.type.issubtype(DP3.type))
-        dp2 = DP3(5)  # type: ignore[assignment]
-        self.assertEqual(dp1.type, dp2.type)
+        dp3 = DP3(range(10))  # type: ignore[var-annotated]
+        self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type))
+        dp3_ = DP3(5)  # type: ignore[var-annotated]
+        self.assertEqual(dp3.type, dp3_.type)
 
         class DP4(IterDataPipe[tuple]):
             r""" DataPipe without __iter__ annotation"""
@@ -1707,8 +1715,8 @@ class TestTyping(TestCase):
                 raise NotImplementedError
 
         self.assertTrue(issubclass(DP4, IterDataPipe))
-        dp = DP4()
-        self.assertTrue(dp.type.param == tuple)
+        dp4 = DP4()
+        self.assertTrue(dp4.type.param == tuple)
 
         class DP5(IterDataPipe):
             r""" DataPipe without type annotation"""
@@ -1717,9 +1725,9 @@ class TestTyping(TestCase):
                 raise NotImplementedError
 
         self.assertTrue(issubclass(DP5, IterDataPipe))
-        dp = DP5()  # type: ignore[assignment]
+        dp5 = DP5()
         from torch.utils.data._typing import issubtype
-        self.assertTrue(issubtype(dp.type.param, Any) and issubtype(Any, dp.type.param))
+        self.assertTrue(issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param))
 
         class DP6(IterDataPipe[int]):
             r""" DataPipe with plain Iterator"""
@@ -1728,13 +1736,13 @@ class TestTyping(TestCase):
                 raise NotImplementedError
 
         self.assertTrue(issubclass(DP6, IterDataPipe))
-        dp = DP6()  # type: ignore[assignment]
-        self.assertTrue(dp.type.param == int)
+        dp6 = DP6()
+        self.assertTrue(dp6.type.param == int)
 
         class DP7(IterDataPipe[Awaitable[T_co]]):
             r""" DataPipe with abstract base class"""
 
-        self.assertTrue(issubclass(DP6, IterDataPipe))
+        self.assertTrue(issubclass(DP7, IterDataPipe))
         self.assertTrue(DP7.type.param == Awaitable[T_co])
 
         class DP8(DP7[str]):
@@ -1765,11 +1773,11 @@ class TestTyping(TestCase):
         # Non-DataPipe input with DataPipe hint
         datasource = [(1, '1'), (2, '2'), (3, '3')]
         with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
-            dp = DP0(datasource)
+            dp0 = DP0(datasource)
 
-        dp = DP0(IDP(range(10)))
+        dp0 = DP0(dp.iter.IterableWrapper(range(10)))
         with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
-            dp = DP1(dp)
+            dp1 = DP1(dp0)
 
     def test_runtime(self):
         class DP(IterDataPipe[Tuple[int, T_co]]):
@@ -1784,26 +1792,26 @@ class TestTyping(TestCase):
         dss = ([(1, '1'), (2, '2')],
               [(1, 1), (2, '2')])
        for ds in dss:
-            dp = DP(ds)  # type: ignore[var-annotated]
-            self.assertEqual(list(dp), ds)
+            dp0 = DP(ds)  # type: ignore[var-annotated]
+            self.assertEqual(list(dp0), ds)
            # Reset __iter__
-            self.assertEqual(list(dp), ds)
+            self.assertEqual(list(dp0), ds)
 
        dss = ([(1, 1), ('2', 2)],  # type: ignore[assignment, list-item]
               [[1, '1'], [2, '2']],  # type: ignore[list-item]
               [1, '1', 2, '2'])
        for ds in dss:
-            dp = DP(ds)
+            dp0 = DP(ds)
            with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
-                list(dp)
+                list(dp0)
 
            with runtime_validation_disabled():
-                self.assertEqual(list(dp), ds)
+                self.assertEqual(list(dp0), ds)
                with runtime_validation_disabled():
-                    self.assertEqual(list(dp), ds)
+                    self.assertEqual(list(dp0), ds)
 
            with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
-                list(dp)
+                list(dp0)
 
    def test_reinforce(self):
        T = TypeVar('T', int, str)
@@ -1819,26 +1827,26 @@ class TestTyping(TestCase):
 
        ds = list(range(10))
        # Valid type reinforcement
-        dp = DP(ds).reinforce_type(int)
-        self.assertTrue(dp.type, int)
-        self.assertEqual(list(dp), ds)
+        dp0 = DP(ds).reinforce_type(int)
+        self.assertTrue(dp0.type, int)
+        self.assertEqual(list(dp0), ds)
 
        # Invalid type
        with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"):
-            dp = DP(ds).reinforce_type(1)
+            dp1 = DP(ds).reinforce_type(1)
 
        # Type is not subtype
        with self.assertRaisesRegex(TypeError, r"Expected 'expected_type' as subtype of"):
-            dp = DP(ds).reinforce_type(float)
+            dp2 = DP(ds).reinforce_type(float)
 
        # Invalid data at runtime
-        dp = DP(ds).reinforce_type(str)
+        dp3 = DP(ds).reinforce_type(str)
        with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
-            list(dp)
+            list(dp3)
 
        # Context Manager to disable the runtime validation
        with runtime_validation_disabled():
-            self.assertEqual(list(d for d in dp), ds)
+            self.assertEqual(list(d for d in dp3), ds)
 
 
 class NumbersDataset(IterDataPipe):
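The renames above (`dp` becoming `dp0`, `dp3`, and so on) free the bare name `dp` for the datapipes module alias; the assertions themselves are unchanged. For reference, a sketch of the `reinforce_type` API they exercise, narrowing a pipe's declared element type at construction time (pipe class hypothetical):

    from torch.utils.data import IterDataPipe

    class Numbers(IterDataPipe):
        def __iter__(self):
            yield from range(10)

    # int must be a subtype of the pipe's declared type (Any here).
    pipe = Numbers().reinforce_type(int)
    assert list(pipe) == list(range(10))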
@@ -1900,7 +1908,7 @@ class TestSharding(TestCase):
         self.assertEqual(sorted(all_items), sorted(items))
 
     def test_sharding_length(self):
-        numbers_dp = IDP(range(13))
+        numbers_dp = dp.iter.IterableWrapper(range(13))
         sharded_dp0 = numbers_dp.sharding_filter()
         torch.utils.data.sharding.apply_sharding(sharded_dp0, 3, 0)
         sharded_dp1 = numbers_dp.sharding_filter()
@@ -1912,7 +1920,7 @@ class TestSharding(TestCase):
         self.assertEqual(4, len(sharded_dp1))
         self.assertEqual(4, len(sharded_dp2))
 
-        numbers_dp = IDP(range(1))
+        numbers_dp = dp.iter.IterableWrapper(range(1))
         sharded_dp0 = numbers_dp.sharding_filter()
         torch.utils.data.sharding.apply_sharding(sharded_dp0, 2, 0)
         sharded_dp1 = numbers_dp.sharding_filter()
@@ -1922,11 +1930,11 @@ class TestSharding(TestCase):
 
     @skipIfNoDill
     def test_old_dataloader(self):
-        dp = self._get_pipeline()
-        expected = list(dp)
+        dp0 = self._get_pipeline()
+        expected = list(dp0)
 
-        dp = self._get_pipeline().sharding_filter()
-        dl = DataLoader(dp, batch_size=1, shuffle=False, num_workers=2,
+        dp0 = self._get_pipeline().sharding_filter()
+        dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2,
                         worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn)
         items = []
         for i in dl:
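A sketch of the sharding calls the updated test uses (import path taken from the test itself): `sharding_filter` makes a pipe shardable, and `apply_sharding(pipe, world_size, rank)` keeps every world_size-th element starting at the rank, so 13 elements over 3 shards give lengths 5, 4, and 4:

    import torch.utils.data.datapipes as dp
    import torch.utils.data.sharding

    pipe = dp.iter.IterableWrapper(range(13)).sharding_filter()
    torch.utils.data.sharding.apply_sharding(pipe, 3, 0)  # 3 shards, rank 0
    assert len(pipe) == 5
    assert list(pipe) == [0, 3, 6, 9, 12]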
@@ -9704,12 +9704,6 @@ class TestNN(NNTestCase):
         self.assertEqual(input1.grad, torch.zeros_like(input1))
         self.assertEqual(input2.grad, input1 * 1e8)
 
-        # Check error when inputs are not the same shape
-        input1 = torch.randn(2, 2, 1)
-        input2 = torch.randn(2, 1, 3)
-        with self.assertRaises(RuntimeError):
-            F.cosine_similarity(input1, input2)
-
         # Check type promotion, issue #61454
         input = torch.tensor(12.)
         out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
@ -1,7 +1,8 @@
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn.functional as F
from torch import Tensor, vmap
from torch import Tensor
from torch._vmap_internals import vmap
import functools
import itertools
import warnings

@ -387,7 +387,7 @@ class CMake:
            # then. Until then, we use "--" to pass parameters to the
            # underlying build system.
            build_args += ['--']
            if IS_WINDOWS:
            if IS_WINDOWS and not USE_NINJA:
                # We are likely using msbuild here
                build_args += ['/p:CL_MPCount={}'.format(max_jobs)]
            else:

@ -108,7 +108,7 @@ def plural(n: int) -> str:

def get_base_commit(sha1: str) -> str:
    return subprocess.check_output(
        ["git", "merge-base", sha1, "origin/master"],
        ["git", "merge-base", sha1, "origin/release/1.10"],
        encoding="ascii",
    ).strip()

@ -22,7 +22,11 @@ class TestCMake(unittest.TestCase):
            # MAX_JOBS, USE_NINJA, IS_WINDOWS,         want
            ((     '8',      True,     False),          ['-j', '8']),  # noqa: E201,E241
            ((    None,      True,     False),                 None),  # noqa: E201,E241
            ((     '7',     False,     False),          ['-j', '7']),  # noqa: E201,E241
            ((    None,     False,     False),         ['-j', '13']),  # noqa: E201,E241
            ((     '6',      True,      True),          ['-j', '6']),  # noqa: E201,E241
            ((    None,      True,      True),                 None),  # noqa: E201,E241
            ((    '11',     False,      True), ['/p:CL_MPCount=11']),  # noqa: E201,E241
            ((    None,     False,      True), ['/p:CL_MPCount=13']),  # noqa: E201,E241
        ]
        for (max_jobs, use_ninja, is_windows), want in cases:

@ -9,6 +9,13 @@ if(NOT CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
  set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
endif()

if(BUILD_BINARY)
  add_library(aot_compiler SHARED
          ${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
          )
  install(TARGETS aot_compiler DESTINATION lib)
endif()

if(NOT BUILD_PYTHON)
  return()
endif()
@ -430,9 +437,3 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  # Pybind11 requires explicit linking of the torch_python library
  target_link_libraries(nnapi_backend torch torch_python)
endif()

if(BUILD_BINARY)
  add_library(aot_compiler SHARED
          ${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
          )
endif()

@ -756,8 +756,6 @@ del register_after_fork
# torch.jit.script as a decorator, for instance):
from ._lobpcg import lobpcg as lobpcg

from ._vmap_internals import vmap as vmap

# These were previously defined in native_functions.yaml and appeared on the
# `torch` namespace, but we moved them to c10 dispatch to facilitate custom
# class usage. We add these lines here to preserve backward compatibility.

@ -362,19 +362,21 @@ void ConvertGraphToONNXProto(
    SymbolDimMap& symbol_map,
    int opset_version) {
  RawDataExportMap export_map;
  std::tie(model_proto, export_map, symbol_map) = export_onnx(
      graph,
      {},
      opset_version,
      {},
      false,
      onnx_torch::OperatorExportTypes::ONNX,
      true,
      true,
      {},
      true,
      false,
      std::string());
  bool val_use_external_data_format;
  std::tie(model_proto, export_map, symbol_map, val_use_external_data_format) =
      export_onnx(
          graph,
          {},
          opset_version,
          {},
          false,
          onnx_torch::OperatorExportTypes::ONNX,
          true,
          true,
          {},
          true,
          false,
          std::string());
  for (int i = 0; i < model_proto->graph().output_size(); ++i) {
    model_proto->mutable_graph()->mutable_output(i)->clear_type();
  }

@ -263,19 +263,25 @@ void initPythonIRBindings(PyObject* module_) {
            std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
            RawDataExportMap export_map;
            SymbolDimMap symbol_map;
            std::tie(model_proto, export_map, symbol_map) = export_onnx(
                g,
                initializers,
                onnx_opset_version,
                dynamic_axes,
                defer_weight_export,
                operator_export_type,
                strip_doc_string,
                keep_initializers_as_inputs,
                custom_opsets,
                add_node_names,
                use_external_data_format,
                onnx_file_path);
            bool val_use_external_data_format = false;
            std::tie(
                model_proto,
                export_map,
                symbol_map,
                val_use_external_data_format) =
                export_onnx(
                    g,
                    initializers,
                    onnx_opset_version,
                    dynamic_axes,
                    defer_weight_export,
                    operator_export_type,
                    strip_doc_string,
                    keep_initializers_as_inputs,
                    custom_opsets,
                    add_node_names,
                    use_external_data_format,
                    onnx_file_path);
            std::unordered_map<std::string, py::bytes>
                python_serialized_export_map;
            for (auto& kv : export_map) {
@ -289,7 +295,9 @@ void initPythonIRBindings(PyObject* module_) {
            }
            graph = serialize_model_proto_to_string(model_proto);
            return std::make_tuple(
                py::bytes(graph), python_serialized_export_map);
                py::bytes(graph),
                python_serialized_export_map,
                val_use_external_data_format);
          },
          py::arg("initializers"),
          py::arg("onnx_opset_version") = 0,

@ -231,6 +231,12 @@ class EncoderBase {
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  unsigned long long int GetGraphProtoSize(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>());

  virtual void EncodeTensor(
      onnx::TensorProto* tensor_proto,
      const at::Tensor& tensor,
@ -581,7 +587,6 @@ void EncoderBase::AddInitializersIntoGraphProto(
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  AT_ASSERT(block->inputs().size() >= initializers.size());

  for (auto input : block->inputs()) {
    auto name_tensor_pair = initializers.find(input->debugName());
    if (name_tensor_pair == initializers.end()) {
@ -598,6 +603,38 @@ void EncoderBase::AddInitializersIntoGraphProto(
  }
}

unsigned long long int EncoderBase::GetGraphProtoSize(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers) {
  unsigned long long int sizes = 0;
  for (auto input : graph->inputs()) {
    auto name_tensor_pair = initializers.find(input->debugName());
    if (name_tensor_pair == initializers.end()) {
      continue;
    }
    onnx::GraphProto* graph_proto_copy = new onnx::GraphProto(*graph_proto);
    auto tensor_proto = graph_proto_copy->add_initializer();
    const at::Tensor tensor = name_tensor_pair->second;
    for (auto d : tensor.sizes()) {
      tensor_proto->add_dims(d);
    }
    tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
    at::Tensor t;
    if (tensor.is_quantized()) {
      t = tensor.contiguous();
    } else {
      t = tensor.contiguous().cpu();
    }
    tensor_proto->set_raw_data(std::string(
        static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
    sizes += tensor_proto->ByteSizeLong();
    delete graph_proto_copy;
    graph_proto_copy = nullptr;
  }
  return sizes;
}

void EncoderBase::AddAttribute(
    onnx::NodeProto* node_proto,
    const jit::Node* node,
@ -726,6 +763,10 @@ class GraphEncoder : public EncoderBase {
    return raw_data_export_map_;
  }

  bool get_use_external_data_format() {
    return use_external_data_format_;
  }

 private:
  void EncodeTensor(
      onnx::TensorProto* tensor_proto,
@ -736,6 +777,7 @@ class GraphEncoder : public EncoderBase {

  RawDataExportMap raw_data_export_map_;
  bool defer_weight_export_;
  bool use_external_data_format_;
};

GraphEncoder::GraphEncoder(
@ -754,8 +796,21 @@ GraphEncoder::GraphEncoder(
    bool use_external_data_format,
    const std::string& onnx_file_path)
    : EncoderBase(operator_export_type, strip_doc),
      defer_weight_export_(defer_weight_export) {
      defer_weight_export_(defer_weight_export),
      use_external_data_format_(use_external_data_format) {
  validateGraph(graph, operator_export_type);
  // If the graph proto size exceeds the maximum protobuf size of 2GB, set
  // use_external_data_format to true.
  if (!use_external_data_format && !onnx_file_path.empty() &&
      GetGraphProtoSize(model_proto_.mutable_graph(), graph, initializers) >
          INT_MAX) {
    GRAPH_DEBUG(
        "Exported model exceeds the maximum protobuf size of 2GB. Storing model parameters in external data files");
    use_external_data_format = true;
    // use_external_data_format_ is a private member of GraphEncoder, updated
    // here so the final `use_external_data_format` value can be returned.
    use_external_data_format_ = use_external_data_format;
  }

  if (use_external_data_format) {
    TORCH_CHECK(
@ -895,7 +950,8 @@ std::string pretty_print_onnx(
std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap>
    SymbolDimMap,
    bool>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
@ -929,7 +985,8 @@ export_onnx(
      std::make_shared<::ONNX_NAMESPACE::ModelProto>(
          graph_encoder.get_model_proto()),
      graph_encoder.get_raw_data_export_map(),
      graph_encoder.get_symbol_dim_param_map());
      graph_encoder.get_symbol_dim_param_map(),
      graph_encoder.get_use_external_data_format());
}

std::string serialize_model_proto_to_string(
@ -938,7 +995,7 @@ std::string serialize_model_proto_to_string(
  TORCH_CHECK(
      proto_size <= INT_MAX,
      "Exported model exceeds the maximum protobuf size of 2GB. "
      "Please call torch.onnx.export with use_external_data_format=True.");
      "Please call torch.onnx.export without setting the use_external_data_format parameter.");
  return model_proto->SerializeAsString();
}

@ -33,7 +33,8 @@ using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
TORCH_API std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap>
    SymbolDimMap,
    bool>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,

@ -1519,8 +1519,11 @@ def _object_to_tensor(obj):
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
    byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)
    local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long)
    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and an explicit dtype.
    # Otherwise, it will cause a 100X slowdown.
    # See: https://github.com/pytorch/pytorch/issues/65696
    byte_tensor = torch.ByteTensor(byte_storage)
    local_size = torch.LongTensor([byte_tensor.numel()])
    return byte_tensor, local_size

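The new comment points at a real pitfall: `torch.tensor(...)` treats the storage as a Python sequence and copies element by element, while the legacy `torch.ByteTensor(...)` constructor wraps the storage directly. A micro-benchmark sketch (illustrative only; the timings are not from the source):

```python
import io
import pickle
import timeit
import torch

buf = io.BytesIO()
pickle.Pickler(buf).dump(list(range(100000)))
byte_storage = torch.ByteStorage.from_buffer(buf.getvalue())

# Legacy constructor: wraps the existing storage, no per-element copy.
fast = timeit.timeit(lambda: torch.ByteTensor(byte_storage), number=10)
# Generic constructor: iterates over the storage and copies item by item.
slow = timeit.timeit(lambda: torch.tensor(byte_storage, dtype=torch.uint8), number=10)
print(f"ByteTensor: {fast:.4f}s  torch.tensor: {slow:.4f}s")
```
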
@ -4256,7 +4256,10 @@ cosine_similarity = _add_docstr(
    r"""
cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor

Returns cosine similarity between x1 and x2, computed along dim.
Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable
to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
squeezed (see :func:`torch.squeeze`), resulting in the
output tensor having 1 fewer dimension.

.. math ::
    \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
@ -4265,16 +4268,11 @@ Supports :ref:`type promotion <type-promotion-doc>`.

Args:
    x1 (Tensor): First input.
    x2 (Tensor): Second input (with the same number of dimensions as x1, matching x1 size at dimension `dim`,
        and broadcastable with x1 at other dimensions).
    dim (int, optional): Dimension of vectors. Default: 1
    x2 (Tensor): Second input.
    dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
    eps (float, optional): Small value to avoid division by zero.
        Default: 1e-8

Shape:
    - Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
    - Output: :math:`(\ast_1, \ast_2)`

Example::

    >>> input1 = torch.randn(100, 128)

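A short sketch of the broadcasting and type-promotion behaviour the updated docstring and tests describe (shapes chosen for illustration):

```python
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 2, 3)
x2 = torch.randn(2, 1, 3)
# x1 and x2 broadcast to the common shape (2, 2, 3); dim=-1 is reduced,
# so the result has one fewer dimension: (2, 2).
out = F.cosine_similarity(x1, x2, dim=-1)
print(out.shape)  # torch.Size([2, 2])

# Type promotion (issue #61454): int8 and float inputs promote before the op.
scalar = torch.tensor(12.)
print(F.cosine_similarity(scalar.to(torch.int8), scalar, dim=-1))
```
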
@ -31,10 +31,10 @@ def _export(*args, **kwargs):

def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
           input_names=None, output_names=None, operator_export_type=None,
           opset_version=None, _retain_param_name=True, do_constant_folding=True,
           example_outputs=None, strip_doc_string=True, dynamic_axes=None,
           keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True,
           use_external_data_format=False):
           opset_version=None, _retain_param_name=None, do_constant_folding=True,
           example_outputs=None, strip_doc_string=None, dynamic_axes=None,
           keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None,
           use_external_data_format=None):
    r"""
    Exports a model into ONNX format. If ``model`` is not a
    :class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`, this runs
@ -105,7 +105,9 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
            In this case, the exported model will first take all of its parameters
            as arguments, with the ordering as specified by ``model.state_dict().values()``
        verbose (bool, default False): if True, prints a description of the
            model being exported to stdout.
            model being exported to stdout. In addition, the final ONNX graph will include the
            field ``doc_string`` from the exported model which mentions the source code locations
            for ``model``.
        training (enum, default TrainingMode.EVAL):
            * ``TrainingMode.EVAL``: export the model in inference mode.
            * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
@ -181,16 +183,17 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
        opset_version (int, default 9):
            Must be ``== _onnx_main_opset or in _onnx_stable_opsets``,
            defined in torch/onnx/symbolic_helper.py.
        _retain_param_name (bool, default True): [Deprecated and ignored. Will be removed in next PyTorch
            release]
        do_constant_folding (bool, default True): Apply the constant-folding optimization.
            Constant-folding will replace some of the ops that have all constant inputs
            with pre-computed constant nodes.
        example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None):
            [Deprecated and ignored. Will be removed in next PyTorch release],
            Must be provided when exporting a ScriptModule or ScriptFunction, ignored otherwise.
            Used to determine the type and shape of the outputs without tracing the execution of
            the model. A single object is treated as equivalent to a tuple of one element.
        strip_doc_string (bool, default True): do not include the field
            ``doc_string`` from the exported model. Otherwise the field will mention the source
            code locations for ``model``.
        strip_doc_string (bool, default True): [Deprecated and ignored. Will be removed in next PyTorch release]
        dynamic_axes (dict<string, dict<int, string>> or dict<string, list(int)>, default empty dict):

            By default the exported model will have the shapes of all input and output tensors
@ -293,10 +296,11 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM

        enable_onnx_checker (bool, default True): Deprecated and ignored. Will be removed in next
            Pytorch release.
        use_external_data_format (bool, default False): If True, then some of the model
            parameters are stored in external data files and not in the ONNX model file itself.
            Models larger than 2GB cannot be exported in one file because of size limits imposed
            by Protocol Buffers.
        use_external_data_format (bool, default False): [Deprecated and ignored. Will be removed in
            next Pytorch release.]
            If True, then some of the model parameters are stored in external data files and not in
            the ONNX model file itself. Models larger than 2GB cannot be exported in one file because
            of size limits imposed by Protocol Buffers.
            For details see
            `onnx.proto <https://github.com/onnx/onnx/blob/32c7cb66/onnx/onnx.proto#L562>`_.
            If True, argument ``f`` must be a string specifying the location of the model.
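To make the deprecations above concrete, a hedged sketch of an export call that avoids every deprecated argument (the module and shapes are illustrative, not from the source):

```python
import torch

model = torch.nn.Linear(4, 2)
dummy = torch.randn(1, 4)

# As of this release the deprecated knobs (_retain_param_name, strip_doc_string,
# example_outputs, enable_onnx_checker, use_external_data_format) default to None
# and warn if passed; a plain call only needs the essentials.
torch.onnx.export(
    model,
    dummy,
    "linear.onnx",
    opset_version=9,
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}},
)
```
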
@ -316,9 +320,24 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
                        custom_opsets, enable_onnx_checker, use_external_data_format)


def export_to_pretty_string(*args, **kwargs):
def export_to_pretty_string(*args, **kwargs) -> str:
    r"""
    Same as :func:`export`, but returns a text representation of the exported model.
    Similar to :func:`export`, but returns a text representation of the ONNX
    model. Only the args listed below differ; all other args are the same
    as :func:`export`.

    Args:
      f:  Deprecated and ignored. Will be removed in the next release of
          PyTorch.
      add_node_names (bool, default True): Whether or not to set
          NodeProto.name. This makes no difference unless
          ``google_printer=True``.
      google_printer (bool, default False): If False, will return a custom,
          compact representation of the model. If True will return the
          protobuf's `Message::DebugString()`, which is more verbose.

    Returns:
      A UTF-8 str containing a human-readable representation of the ONNX model.
    """
    from torch.onnx import utils
    return utils.export_to_pretty_string(*args, **kwargs)

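A minimal usage sketch of the documented behaviour (since ``f`` is deprecated above, it can simply be passed as None; the module is illustrative):

```python
import torch

model = torch.nn.Linear(4, 2)
dummy = torch.randn(1, 4)

# Returns a UTF-8 str; google_printer=True switches to protobuf's
# more verbose Message::DebugString() representation.
text = torch.onnx.export_to_pretty_string(model, dummy, None, google_printer=False)
print(text[:200])
```
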
@ -12,6 +12,7 @@ import torch.serialization
import re
import collections
import contextlib
import copy
import numbers
import warnings
from torch._six import string_classes
@ -75,24 +76,37 @@ def select_model_mode_for_export(model, mode):

def export(model, args, f, export_params=True, verbose=False, training=None,
           input_names=None, output_names=None, operator_export_type=None,
           opset_version=None, _retain_param_name=True, do_constant_folding=True,
           example_outputs=None, strip_doc_string=True, dynamic_axes=None,
           opset_version=None, _retain_param_name=None, do_constant_folding=True,
           example_outputs=None, strip_doc_string=None, dynamic_axes=None,
           keep_initializers_as_inputs=None, custom_opsets=None,
           enable_onnx_checker=None, use_external_data_format=False):
           enable_onnx_checker=None, use_external_data_format=None):
    if operator_export_type is None:
        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
            operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = OperatorExportTypes.ONNX

    if enable_onnx_checker is not None:
        warnings.warn("`enable_onnx_checker' is deprecated and ignored. It will be removed in"
                      "the next PyTorch release. To proceed despite ONNX checker failures, you"
                      "can catch torch.onnx.ONNXCheckerError.")
        warnings.warn("'enable_onnx_checker' is deprecated and ignored. It will be removed in "
                      "the next PyTorch release. To proceed despite ONNX checker failures, "
                      "catch torch.onnx.ONNXCheckerError.")
    if _retain_param_name is not None:
        warnings.warn("'_retain_param_name' is deprecated and ignored. "
                      "It will be removed in the next PyTorch release.")
    if strip_doc_string is not None:
        warnings.warn("`strip_doc_string' is deprecated and ignored. Will be removed in "
                      "next PyTorch release. It's combined with `verbose' argument now.")
    if example_outputs is not None:
        warnings.warn("`example_outputs' is deprecated and ignored. Will be removed in "
                      "next PyTorch release.")
    if use_external_data_format is not None:
        warnings.warn("`use_external_data_format' is deprecated and ignored. Will be removed in next "
                      "PyTorch release. External data files will be used automatically for models "
                      "larger than 2GB, because of size limits imposed by Protocol Buffers.")

    _export(model, args, f, export_params, verbose, training, input_names, output_names,
            operator_export_type=operator_export_type, opset_version=opset_version,
            _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
            example_outputs=example_outputs, strip_doc_string=strip_doc_string,
            do_constant_folding=do_constant_folding, example_outputs=example_outputs,
            dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
            custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)

@ -310,7 +324,10 @@ def _decide_external_data_format(use_external_data_format, operator_export_type,
    # string specifying the location of the model. For large model cases, if f is not a non-empty string,
    # then this method returns an empty string, which is an error condition for the large model export code
    # path later (but not for regular model export code path).
    model_file_location = f if val_use_external_data_format and isinstance(f, str) else str()
    if (val_use_external_data_format is None or val_use_external_data_format is True) and isinstance(f, str):
        model_file_location = f
    else:
        model_file_location = str()
    return val_use_external_data_format, model_file_location

def _decide_input_format(model, args):
@ -389,7 +406,7 @@ def _get_param_count_list(method_graph, args_params):
    return param_count_list


def _create_jit_graph(model, args, _retain_param_name):
def _create_jit_graph(model, args):
    torch_out = None
    params: Union[List, Tuple]
    if isinstance(model, torch.jit.ScriptModule):
@ -420,13 +437,12 @@ def _create_jit_graph(model, args):
        graph, torch_out = _trace_and_get_graph_from_model(model, args)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            graph_inputs = list(graph.inputs())
            user_input_num = len(graph_inputs) - len(state_dict)
            param_names = list(state_dict.keys())
            for i, inp in enumerate(graph_inputs):
                if i >= user_input_num:
                    inp.setDebugName(param_names[i - user_input_num])
        graph_inputs = list(graph.inputs())
        user_input_num = len(graph_inputs) - len(state_dict)
        param_names = list(state_dict.keys())
        for i, inp in enumerate(graph_inputs):
            if i >= user_input_num:
                inp.setDebugName(param_names[i - user_input_num])
        torch._C._jit_pass_onnx_function_substitution(graph)
        return graph, params, torch_out, None

@ -437,12 +453,25 @@ def _get_named_param_dict(graph, params):
    _params_dict = dict(zip(param_names, params))
    return _params_dict

def _get_example_outputs(model, args):
    input_args = copy.deepcopy(args)
    input_kwargs = {}
    if input_args and isinstance(input_args[-1], dict):
        input_kwargs = input_args[-1]
        input_args = input_args[:-1]

    example_outputs = model(*input_args, **input_kwargs)
    if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
        example_outputs = (example_outputs,)

    if isinstance(example_outputs, list):
        example_outputs = [example_outputs]
    return example_outputs

def _model_to_graph(model, args, verbose=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None,
                    _retain_param_name=False, do_constant_folding=True,
                    example_outputs=None, do_constant_folding=True,
                    _disable_torch_constant_prop=False, fixed_batch_size=False,
                    training=None, dynamic_axes=None):
    r"""Converts model into an ONNX graph.
@ -461,11 +490,7 @@ def _model_to_graph(model, args, verbose=False,
    if isinstance(args, (torch.Tensor, int, float, bool)):
        args = (args, )

    if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
        example_outputs = (example_outputs,)

    graph, params, torch_out, module = _create_jit_graph(model, args,
                                                         _retain_param_name)
    graph, params, torch_out, module = _create_jit_graph(model, args)

    params_dict = _get_named_param_dict(graph, params)

@ -476,13 +501,22 @@ def _model_to_graph(model, args, verbose=False,
                            module=module)
    from torch.onnx.symbolic_helper import _onnx_shape_inference
    if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \
                                            "ScriptFunction."
        if isinstance(example_outputs, list):
            example_outputs = [example_outputs]
        if example_outputs is None:
            example_outputs = _get_example_outputs(model, args)
        else:
            # example_outputs specified
            if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
                example_outputs = (example_outputs,)

            if isinstance(example_outputs, list):
                example_outputs = [example_outputs]

        out_vars, desc = torch.jit._flatten(tuple(example_outputs))
        torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, desc, _onnx_shape_inference)
    else:
        flatten_args, _ = torch._C._jit_flatten(args)
        # make sure that the param dict and the graph match each other
        assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs())

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
@ -533,14 +567,18 @@ def _model_to_graph(model, args, verbose=False,
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
                            input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
                            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
                            google_printer=False, opset_version=None, _retain_param_name=True,
                            google_printer=False, opset_version=None, _retain_param_name=None,
                            keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True,
                            do_constant_folding=True, dynamic_axes=None):
    if f is not None:
        warnings.warn("'f' is deprecated and ignored. It will be removed in the next PyTorch release.")
    if _retain_param_name is not None:
        warnings.warn("'_retain_param_name' is deprecated and ignored. "
                      "It will be removed in the next PyTorch release.")
    return _export_to_pretty_string(model, args, f, export_params, verbose, training,
                                    input_names, output_names, operator_export_type,
                                    export_type, example_outputs, google_printer,
                                    opset_version, _retain_param_name,
                                    do_constant_folding=do_constant_folding,
                                    opset_version, do_constant_folding=do_constant_folding,
                                    add_node_names=add_node_names,
                                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                                    custom_opsets=custom_opsets, dynamic_axes=dynamic_axes)
@ -549,7 +587,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
                             input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
                             google_printer=False, opset_version=None, _retain_param_name=False,
                             google_printer=False, opset_version=None,
                             do_constant_folding=True, keep_initializers_as_inputs=None,
                             fixed_batch_size=False, custom_opsets=None, add_node_names=True,
                             onnx_shape_inference=True, dynamic_axes=None):
@ -572,8 +610,8 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
        args = _decide_input_format(model, args)
        graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
                                                        output_names, operator_export_type,
                                                        example_outputs, _retain_param_name,
                                                        val_do_constant_folding, fixed_batch_size=fixed_batch_size,
                                                        example_outputs, val_do_constant_folding,
                                                        fixed_batch_size=fixed_batch_size,
                                                        training=training, dynamic_axes=dynamic_axes)

        return graph._pretty_print_onnx(params_dict, opset_version, False,
@ -633,10 +671,10 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini
def _export(model, args, f, export_params=True, verbose=False, training=None,
            input_names=None, output_names=None, operator_export_type=None,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
            opset_version=None, _retain_param_name=False, do_constant_folding=True,
            strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None,
            opset_version=None, do_constant_folding=True,
            dynamic_axes=None, keep_initializers_as_inputs=None,
            fixed_batch_size=False, custom_opsets=None, add_node_names=True,
            use_external_data_format=False, onnx_shape_inference=True):
            use_external_data_format=None, onnx_shape_inference=True):

    if isinstance(model, torch.nn.DataParallel):
        raise ValueError("torch.nn.DataParallel is not supported by ONNX "
@ -685,8 +723,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
            graph, params_dict, torch_out = \
                _model_to_graph(model, args, verbose, input_names,
                                output_names, operator_export_type,
                                example_outputs, _retain_param_name,
                                val_do_constant_folding,
                                example_outputs, val_do_constant_folding,
                                fixed_batch_size=fixed_batch_size,
                                training=training,
                                dynamic_axes=dynamic_axes)
@ -697,14 +734,14 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
                custom_opsets = {}

            if export_params:
                proto, export_map = graph._export_onnx(
                proto, export_map, val_use_external_data_format = graph._export_onnx(
                    params_dict, opset_version, dynamic_axes, defer_weight_export,
                    operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
                    operator_export_type, not verbose, val_keep_init_as_ip, custom_opsets,
                    val_add_node_names, val_use_external_data_format, model_file_location)
            else:
                proto, export_map = graph._export_onnx(
                proto, export_map, val_use_external_data_format = graph._export_onnx(
                    {}, opset_version, dynamic_axes, False, operator_export_type,
                    strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
                    not verbose, val_keep_init_as_ip, custom_opsets, val_add_node_names,
                    val_use_external_data_format, model_file_location)

            if export_type == ExportTypes.PROTOBUF_FILE:

@ -1256,6 +1256,8 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa
            yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
        # Test for Broadcasting
        yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
        yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
        yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})

    return list(generator())

@ -21,7 +21,9 @@ class MapperIterDataPipe(IterDataPipe):
        self.dp = dp
        self.fn = fn
```
Note: Avoid loading data from the source DataPipe in the `__init__` function, in order to support lazy data loading and save memory.
Note:
- Avoid loading data from the source DataPipe in the `__init__` function, in order to support lazy data loading and save memory.
- If an `IterDataPipe` instance holds data in memory, please be aware of in-place modification of the data. When a second iterator is created from the instance, the data may have already changed. Please take the [`IterableWrapper`](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/datapipes/iter/utils.py) class as reference to `deepcopy` data for each iterator.

### Iterator
For `IterDataPipe`, an `__iter__` function is needed to consume data from the source `IterDataPipe` and then apply the operation over the data before yielding.

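A minimal sketch of such an `__iter__`, reusing the `self.dp` and `self.fn` fields from the `MapperIterDataPipe` example above:

```python
    def __iter__(self):
        # Consume data from the source DataPipe lazily and apply `fn`
        # to each element before yielding it.
        for data in self.dp:
            yield self.fn(data)
```
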
@ -1,4 +1,3 @@
import copy
import warnings
from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
@ -99,8 +98,6 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
            data = list(data)
        else:
            t_flag = False
            # Deepcopy data to prevent the original data modified. E.g. list, dict
            data = copy.deepcopy(data)

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):

@ -1,9 +1,9 @@
import random
import warnings

from collections import defaultdict

from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar

T_co = TypeVar('T_co', covariant=True)
@ -185,8 +185,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        assert batch_num > 0, "Number of batches is required to be larger than 0!"
        assert bucket_num > 0, "Number of buckets is required to be larger than 0!"

        warnings.warn("`BucketBatcher` is going to be removed from PyTorch Core")
        deprecation_warning_torchdata(type(self).__name__)
        super().__init__()

        # TODO: Verify _datapipe is not going to be serialized twice

@ -3,6 +3,7 @@ from typing import Sized, Tuple
from urllib.error import HTTPError, URLError
import urllib.request as urllib
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata


class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
@ -19,6 +20,7 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
    def __init__(self, datapipe, timeout=None):
        self.datapipe = datapipe
        self.timeout = timeout
        deprecation_warning_torchdata(type(self).__name__)

    def __iter__(self):
        for furl in self.datapipe:

@ -1,5 +1,6 @@
from typing import Tuple
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata


class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]):
@ -14,6 +15,7 @@ class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]):

    def __init__(self, datapipe):
        self.datapipe = datapipe
        deprecation_warning_torchdata(type(self).__name__)

    def __iter__(self):
        for file_name, stream in self.datapipe:

@ -1,5 +1,5 @@
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple, deprecation_warning_torchdata
from typing import Iterable, Iterator, Tuple, Optional, IO, cast
from io import BufferedIOBase

@ -34,11 +34,13 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
        self.mode: str = mode
        self.length: int = length
        deprecation_warning_torchdata(type(self).__name__)

    def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
        for data in self.datapipe:
            validate_pathname_binary_tuple(data)
            pathname, data_stream = data
            folder_name = os.path.dirname(pathname)
            try:
                # typing.cast is used here to silence mypy's type checker
                tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
@ -49,14 +51,12 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
                    if extracted_fobj is None:
                        warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
                        raise tarfile.ExtractError
                    inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
                    inner_pathname = os.path.normpath(os.path.join(folder_name, tarinfo.name))
                    yield (inner_pathname, extracted_fobj)  # type: ignore[misc]
            except Exception as e:
                warnings.warn(
                    "Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
                raise e
            finally:
                data_stream.close()

    def __len__(self):
        if self.length == -1:

@ -1,3 +1,5 @@
import copy
import warnings
from torch.utils.data import IterDataPipe


@ -8,12 +10,34 @@ class IterableWrapperIterDataPipe(IterDataPipe):

    Args:
        iterable: Iterable object to be wrapped into an IterDataPipe
        deepcopy: Option to deepcopy input iterable object for each
            iteration.

    .. note::
      If `deepcopy` is set to False explicitly, users should ensure
      that the data pipeline doesn't contain any in-place operations over
      the iterable instance, in order to prevent data inconsistency
      across iterations.
    """
    def __init__(self, iterable):
    def __init__(self, iterable, deepcopy=True):
        self.iterable = iterable
        self.deepcopy = deepcopy

    def __iter__(self):
        for data in self.iterable:
        source_data = self.iterable
        if self.deepcopy:
            try:
                source_data = copy.deepcopy(self.iterable)
            # For the case that data cannot be deep-copied,
            # all in-place operations will affect the iterable variable.
            # When this DataPipe is iterated a second time, it will
            # yield modified items.
            except TypeError:
                warnings.warn(
                    "The input iterable can not be deepcopied, "
                    "please be aware that in-place modification would affect the source data"
                )
        for data in source_data:
            yield data

    def __len__(self):

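A short sketch of why the new `deepcopy` flag matters (illustrative data; assumes the `IterableWrapper` shown above):

```python
import torch.utils.data.datapipes as dp

source = [[0], [1], [2]]
pipe = dp.iter.IterableWrapper(source)          # deepcopy=True by default

for item in pipe:
    item.append(99)                             # mutates the copy, not `source`
assert source == [[0], [1], [2]]

pipe = dp.iter.IterableWrapper(source, deepcopy=False)
for item in pipe:
    item.append(99)                             # now mutates `source` in place
assert source == [[0, 99], [1, 99], [2, 99]]    # a second pass would see changed data
```
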
@ -1,5 +1,5 @@
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple, deprecation_warning_torchdata
from typing import Iterable, Iterator, Tuple, IO, cast
from io import BufferedIOBase

@ -31,11 +31,13 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
        super().__init__()
        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
        self.length: int = length
        deprecation_warning_torchdata(type(self).__name__)

    def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
        for data in self.datapipe:
            validate_pathname_binary_tuple(data)
            pathname, data_stream = data
            folder_name = os.path.dirname(pathname)
            try:
                # typing.cast is used here to silence mypy's type checker
                zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
@ -47,7 +49,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
                    elif zipinfo.filename.endswith('/'):
                        continue
                    extracted_fobj = zips.open(zipinfo)
                    inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
                    inner_pathname = os.path.normpath(os.path.join(folder_name, zipinfo.filename))
                    yield (inner_pathname, extracted_fobj)  # type: ignore[misc]
            except Exception as e:
                warnings.warn(

@ -63,3 +63,9 @@ def validate_pathname_binary_tuple(data):
        raise TypeError("pathname binary tuple should have string type pathname, but got {}".format(type(data[0])))
    if not isinstance(data[1], BufferedIOBase):
        raise TypeError("pathname binary tuple should have BufferedIOBase based binary type, but got {}".format(type(data[1])))

# Warns user that the DataPipe has been moved to TorchData and will be removed from `torch`
def deprecation_warning_torchdata(name):
    warnings.warn(f"{name} and its functional API are deprecated and will be removed from the package `torch`. "
                  f"Please import those features from the new package TorchData: https://github.com/pytorch/data",
                  DeprecationWarning)

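For illustration, a hedged sketch of how this helper surfaces to users (assuming `LineReader` is the name exported by the `iter` namespace; the warning filter is shown because `DeprecationWarning` is hidden by default):

```python
import warnings
import torch.utils.data.datapipes as dp

warnings.simplefilter("always", DeprecationWarning)
pipe = dp.iter.LineReader(dp.iter.IterableWrapper([]))
# -> DeprecationWarning: LineReaderIterDataPipe and its functional API are
#    deprecated and will be removed from the package `torch`. ...
```
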
@ -112,15 +112,18 @@ class RandomSampler(Sampler[int]):
    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            self.generator = torch.Generator()
            self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            yield from torch.randperm(n, generator=self.generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()

    def __len__(self) -> int:
        return self.num_samples
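The restructuring above stops `__iter__` from mutating `self.generator` when it is None, seeding a fresh local generator instead. A usage sketch of the visible effect (assuming this sampler implementation):

```python
import torch
from torch.utils.data import RandomSampler

data = list(range(10))

# With an explicit generator, each full iteration consumes its state,
# so two passes give different (but reproducible) permutations.
g = torch.Generator()
g.manual_seed(0)
sampler = RandomSampler(data, generator=g)
print(list(sampler), list(sampler))

# With generator=None, a freshly seeded local generator is used per __iter__,
# and self.generator stays None rather than being overwritten.
sampler = RandomSampler(data)
print(list(sampler))
```
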
@ -140,7 +143,8 @@ class SubsetRandomSampler(Sampler[int]):
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))
        for i in torch.randperm(len(self.indices), generator=self.generator):
            yield self.indices[i]

    def __len__(self) -> int:
        return len(self.indices)
@ -183,7 +187,7 @@ class WeightedRandomSampler(Sampler[int]):

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        return iter(rand_tensor.tolist())
        yield from iter(rand_tensor.tolist())

    def __len__(self) -> int:
        return self.num_samples