mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 15:44:58 +08:00
Compare commits
17 Commits
revert-fai
...
v1.10.0-rc
| Author | SHA1 | Date | |
|---|---|---|---|
| 9509e8a3d6 | |||
| 1774a6a2f4 | |||
| a27906c250 | |||
| 49f52b6c07 | |||
| 5f1a434599 | |||
| ecbf5a7439 | |||
| 4e3ebebcff | |||
| 2b46c95e7c | |||
| 5f3eee1ca5 | |||
| 4731f33d02 | |||
| ecfcb8ff5a | |||
| 6aadfda9e2 | |||
| 13666d20fd | |||
| 1fa17a20fc | |||
| c05547fa6c | |||
| 0e857bf109 | |||
| ad22804b95 |
14
.circleci/config.yml
generated
14
.circleci/config.yml
generated
@ -632,7 +632,8 @@ jobs:
|
||||
}
|
||||
|
||||
if is_vanilla_build; then
|
||||
echo "apt-get update && apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
|
||||
echo "apt-get update || apt-get install libgnutls30" | docker exec -u root -i "$id" bash
|
||||
echo "apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
|
||||
echo "cd workspace/build; qemu-x86_64 -g 2345 -cpu Broadwell -E ATEN_CPU_CAPABILITY=default ./bin/basic --gtest_filter=BasicTest.BasicTestCPU & gdb ./bin/basic -ex 'set pagination off' -ex 'target remote :2345' -ex 'continue' -ex 'bt' -ex='set confirm off' -ex 'quit \$_isvoid(\$_exitcode)'" | docker exec -u jenkins -i "$id" bash
|
||||
else
|
||||
echo "Skipping for ${BUILD_ENVIRONMENT}"
|
||||
@ -1734,16 +1735,17 @@ jobs:
|
||||
# install fastlane
|
||||
sudo gem install bundler && bundle install
|
||||
# install certificates
|
||||
echo ${IOS_CERT_KEY} >> cert.txt
|
||||
echo ${IOS_CERT_KEY_2022} >> cert.txt
|
||||
base64 --decode cert.txt -o Certificates.p12
|
||||
rm cert.txt
|
||||
bundle exec fastlane install_cert
|
||||
bundle exec fastlane install_root_cert
|
||||
bundle exec fastlane install_dev_cert
|
||||
# install the provisioning profile
|
||||
PROFILE=PyTorch_CI_2021.mobileprovision
|
||||
PROFILE=PyTorch_CI_2022.mobileprovision
|
||||
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
|
||||
mkdir -pv "${PROVISIONING_PROFILES}"
|
||||
cd "${PROVISIONING_PROFILES}"
|
||||
echo ${IOS_SIGN_KEY} >> cert.txt
|
||||
echo ${IOS_SIGN_KEY_2022} >> cert.txt
|
||||
base64 --decode cert.txt -o ${PROFILE}
|
||||
rm cert.txt
|
||||
- run:
|
||||
@ -1802,7 +1804,7 @@ jobs:
|
||||
command: |
|
||||
set -e
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
PROFILE=PyTorch_CI_2021
|
||||
PROFILE=PyTorch_CI_2022
|
||||
# run the ruby build script
|
||||
if ! [ -x "$(command -v xcodebuild)" ]; then
|
||||
echo 'Error: xcodebuild is not installed.'
|
||||
|
||||
@ -61,7 +61,7 @@ git --no-pager log --max-count 1
|
||||
popd
|
||||
|
||||
# Clone the Builder master repo
|
||||
retry git clone -q https://github.com/pytorch/builder.git "$BUILDER_ROOT"
|
||||
retry git clone -q https://github.com/pytorch/builder.git -b release/1.10 "$BUILDER_ROOT"
|
||||
pushd "$BUILDER_ROOT"
|
||||
echo "Using builder from "
|
||||
git --no-pager log --max-count 1
|
||||
|
||||
@ -8,16 +8,17 @@ cd ${PROJ_ROOT}/ios/TestApp
|
||||
# install fastlane
|
||||
sudo gem install bundler && bundle install
|
||||
# install certificates
|
||||
echo "${IOS_CERT_KEY}" >> cert.txt
|
||||
echo "${IOS_CERT_KEY_2022}" >> cert.txt
|
||||
base64 --decode cert.txt -o Certificates.p12
|
||||
rm cert.txt
|
||||
bundle exec fastlane install_cert
|
||||
bundle exec fastlane install_root_cert
|
||||
bundle exec fastlane install_dev_cert
|
||||
# install the provisioning profile
|
||||
PROFILE=PyTorch_CI_2021.mobileprovision
|
||||
PROFILE=PyTorch_CI_2022.mobileprovision
|
||||
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
|
||||
mkdir -pv "${PROVISIONING_PROFILES}"
|
||||
cd "${PROVISIONING_PROFILES}"
|
||||
echo "${IOS_SIGN_KEY}" >> cert.txt
|
||||
echo "${IOS_SIGN_KEY_2022}" >> cert.txt
|
||||
base64 --decode cert.txt -o ${PROFILE}
|
||||
rm cert.txt
|
||||
# run the ruby build script
|
||||
@ -25,5 +26,5 @@ if ! [ -x "$(command -v xcodebuild)" ]; then
|
||||
echo 'Error: xcodebuild is not installed.'
|
||||
exit 1
|
||||
fi
|
||||
PROFILE=PyTorch_CI_2021
|
||||
PROFILE=PyTorch_CI_2022
|
||||
ruby ${PROJ_ROOT}/scripts/xcode_build.rb -i ${PROJ_ROOT}/build_ios/install -x ${PROJ_ROOT}/ios/TestApp/TestApp.xcodeproj -p ${IOS_PLATFORM} -c ${PROFILE} -t ${IOS_DEV_TEAM_ID} -f Accelerate,MetalPerformanceShaders,CoreML
|
||||
|
||||
@ -467,16 +467,17 @@
|
||||
# install fastlane
|
||||
sudo gem install bundler && bundle install
|
||||
# install certificates
|
||||
echo ${IOS_CERT_KEY} >> cert.txt
|
||||
echo ${IOS_CERT_KEY_2022} >> cert.txt
|
||||
base64 --decode cert.txt -o Certificates.p12
|
||||
rm cert.txt
|
||||
bundle exec fastlane install_cert
|
||||
bundle exec fastlane install_root_cert
|
||||
bundle exec fastlane install_dev_cert
|
||||
# install the provisioning profile
|
||||
PROFILE=PyTorch_CI_2021.mobileprovision
|
||||
PROFILE=PyTorch_CI_2022.mobileprovision
|
||||
PROVISIONING_PROFILES=~/Library/MobileDevice/Provisioning\ Profiles
|
||||
mkdir -pv "${PROVISIONING_PROFILES}"
|
||||
cd "${PROVISIONING_PROFILES}"
|
||||
echo ${IOS_SIGN_KEY} >> cert.txt
|
||||
echo ${IOS_SIGN_KEY_2022} >> cert.txt
|
||||
base64 --decode cert.txt -o ${PROFILE}
|
||||
rm cert.txt
|
||||
- run:
|
||||
@ -535,7 +536,7 @@
|
||||
command: |
|
||||
set -e
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
PROFILE=PyTorch_CI_2021
|
||||
PROFILE=PyTorch_CI_2022
|
||||
# run the ruby build script
|
||||
if ! [ -x "$(command -v xcodebuild)" ]; then
|
||||
echo 'Error: xcodebuild is not installed.'
|
||||
|
||||
@ -158,7 +158,8 @@ jobs:
|
||||
}
|
||||
|
||||
if is_vanilla_build; then
|
||||
echo "apt-get update && apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
|
||||
echo "apt-get update || apt-get install libgnutls30" | docker exec -u root -i "$id" bash
|
||||
echo "apt-get install -y qemu-user gdb" | docker exec -u root -i "$id" bash
|
||||
echo "cd workspace/build; qemu-x86_64 -g 2345 -cpu Broadwell -E ATEN_CPU_CAPABILITY=default ./bin/basic --gtest_filter=BasicTest.BasicTestCPU & gdb ./bin/basic -ex 'set pagination off' -ex 'target remote :2345' -ex 'continue' -ex 'bt' -ex='set confirm off' -ex 'quit \$_isvoid(\$_exitcode)'" | docker exec -u jenkins -i "$id" bash
|
||||
else
|
||||
echo "Skipping for ${BUILD_ENVIRONMENT}"
|
||||
|
||||
@ -63,9 +63,9 @@ function get_pr_change_files() {
|
||||
function file_diff_from_base() {
|
||||
# The fetch may fail on Docker hosts, this fetch is necessary for GHA
|
||||
set +e
|
||||
git fetch origin master --quiet
|
||||
git fetch origin release/1.10 --quiet
|
||||
set -e
|
||||
git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
|
||||
git diff --name-only "$(git merge-base origin/release/1.10 HEAD)" > "$1"
|
||||
}
|
||||
|
||||
function get_bazel() {
|
||||
@ -99,5 +99,5 @@ function checkout_install_torchvision() {
|
||||
}
|
||||
|
||||
function clone_pytorch_xla() {
|
||||
git clone --recursive https://github.com/pytorch/xla.git
|
||||
git clone --recursive -b r1.10 https://github.com/pytorch/xla.git
|
||||
}
|
||||
|
||||
@ -424,7 +424,7 @@ test_backward_compatibility() {
|
||||
python -m venv venv
|
||||
# shellcheck disable=SC1091
|
||||
. venv/bin/activate
|
||||
pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_nightly.html
|
||||
pip show torch
|
||||
python dump_all_function_schemas.py --filename nightly_schemas.txt
|
||||
deactivate
|
||||
|
||||
@ -240,14 +240,11 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
|
||||
}
|
||||
|
||||
Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
|
||||
TORCH_CHECK(x1.ndimension() == x2.ndimension(), "cosine_similarity requires both inputs to have the same number of dimensions, but x1 has ",
|
||||
x1.ndimension(), " and x2 has ", x2.ndimension());
|
||||
TORCH_CHECK(x1.ndimension() == 0 || x1.size(dim) == x2.size(dim), "cosine_similarity requires both inputs to have the same size at dimension ", dim, "but x1 has ",
|
||||
x1.size(dim), " and x2 has ", x2.size(dim));
|
||||
auto common_size = at::infer_size_dimvector(x1.sizes(), x2.sizes());
|
||||
auto commonDtype = at::result_type(x1, x2);
|
||||
TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype);
|
||||
Tensor x1_ = x1.to(commonDtype);
|
||||
Tensor x2_ = x2.to(commonDtype);
|
||||
Tensor x1_ = x1.to(commonDtype).expand(common_size);
|
||||
Tensor x2_ = x2.to(commonDtype).expand(common_size);
|
||||
// Follow scipy impl to improve numerical precision
|
||||
// Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
|
||||
Tensor w12 = at::sum(x1_ * x2_, dim);
|
||||
|
||||
@ -307,11 +307,11 @@ If the operator is an ATen operator (shows up in the TorchScript graph with the
|
||||
|
||||
* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
|
||||
`torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
|
||||
Make sure the function has the same name as the ATen function, which may declared in
|
||||
Make sure the function has the same name as the ATen function, which may be declared in
|
||||
``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
|
||||
build time, so will not appear in your checkout until you build PyTorch).
|
||||
* The first arg is always the ONNX graph that is being built for export.
|
||||
Other arg names must EXACTLY match the names in ``_VariableFunctions.pyi``,
|
||||
Other arg names must EXACTLY match the names in the ``.pyi`` file,
|
||||
because dispatch is done with keyword arguments.
|
||||
* In the symbolic function, if the operator is in the
|
||||
`ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
|
||||
@ -365,8 +365,8 @@ See the ``symbolic_opset*.py`` files for more examples.
|
||||
torch.autograd.Functions
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
If the operator is defined in a sub-class of :class:`torch.autograd.Function`,
|
||||
there are two ways to export it.
|
||||
If the operator is a sub-class of :class:`torch.autograd.Function`, there are two ways
|
||||
to export it.
|
||||
|
||||
Static Symbolic Method
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -388,11 +388,11 @@ PythonOp Symbolic
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
Alternatively, you can register a custom symbolic function.
|
||||
This gives the symoblic function access to more info through the
|
||||
This gives the symbolic function access to more info through the
|
||||
TorchScript ``Node`` object for the original operation, which gets passed in as the second
|
||||
argument (after the ``Graph`` object).
|
||||
|
||||
All autograd ``Function``s are emitted in the TorchScript graph as ``prim::PythonOp`` nodes.
|
||||
All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
|
||||
In order to differentiate between different ``Function`` subclasses, the
|
||||
symbolic function should use the ``name`` kwarg which gets set to the name of the class.
|
||||
|
||||
@ -400,11 +400,12 @@ symbolic function should use the ``name`` kwarg which gets set to the name of th
|
||||
the ``prim`` namespace, so for this use case, there's a back door: register the
|
||||
symbolic for ``"::prim_PythonOp"``.
|
||||
|
||||
Please also consider adding shape inference logic when you regiester a custom symbolic function
|
||||
via setType API. This can help the exporter to obtain correct shape inference.
|
||||
An example of setType is test_aten_embedding_2 in test_operators.py.
|
||||
Although it is not required to add shape inference logic,
|
||||
the exporter emits a warning message if it is not added.
|
||||
Custom symbolic functions should add type and shape information by calling ``setType(...)``
|
||||
on Value objects before returning them (implemented in C++ by
|
||||
``torch::jit::Value::setType``). This is not required, but it can help the exporter's
|
||||
shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
|
||||
``test_aten_embedding_2`` in
|
||||
`test_operators.py <https://github.com/pytorch/pytorch/blob/master/test/onnx/test_operators.py>`_.
|
||||
|
||||
The example below shows how you can access ``requires_grad`` via the ``Node`` object::
|
||||
|
||||
@ -430,13 +431,17 @@ The example below shows how you can access ``requires_grad`` via the ``Node`` ob
|
||||
print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))
|
||||
|
||||
name = kwargs["name"]
|
||||
ret = None
|
||||
if name == "MyClip":
|
||||
return g.op("Clip", args[0], min_f=args[1])
|
||||
ret = g.op("Clip", args[0], min_f=args[1])
|
||||
elif name == "MyRelu":
|
||||
return g.op("Relu", args[0])
|
||||
ret = g.op("Relu", args[0])
|
||||
else:
|
||||
# Logs a warning and returns None
|
||||
return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
|
||||
# Copy type and shape from original node.
|
||||
ret.setType(n.type())
|
||||
return ret
|
||||
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 1)
|
||||
|
||||
@ -597,5 +597,4 @@ Utilities
|
||||
are_deterministic_algorithms_enabled
|
||||
set_warn_always
|
||||
is_warn_always_enabled
|
||||
vmap
|
||||
_assert
|
||||
|
||||
BIN
ios/TestApp/AppleWWDRCAG3.cer
Normal file
BIN
ios/TestApp/AppleWWDRCAG3.cer
Normal file
Binary file not shown.
@ -4,7 +4,14 @@ platform :ios do
|
||||
before_all do
|
||||
setup_circle_ci
|
||||
end
|
||||
lane :install_cert do
|
||||
lane :install_root_cert do
|
||||
import_certificate(
|
||||
certificate_path: "AppleWWDRCAG3.cer",
|
||||
keychain_path: "/Users/distiller/Library/Keychains/fastlane_tmp_keychain-db",
|
||||
keychain_password: ""
|
||||
)
|
||||
end
|
||||
lane :install_dev_cert do
|
||||
puts "Installing Certificates.p12"
|
||||
import_certificate(
|
||||
keychain_name: ENV["MATCH_KEYCHAIN_NAME"],
|
||||
|
||||
@ -64,7 +64,7 @@ class TestExportModes(JitTestCase):
|
||||
return (a, a)
|
||||
f = io.BytesIO()
|
||||
x = torch.ones(3)
|
||||
torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
|
||||
torch.onnx._export(foo, (x,), f)
|
||||
|
||||
@skipIfNoLapack
|
||||
def test_aten_fallback(self):
|
||||
@ -76,9 +76,8 @@ class TestExportModes(JitTestCase):
|
||||
|
||||
x = torch.rand(3, 4)
|
||||
y = torch.rand(3, 4)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export_to_pretty_string(
|
||||
ModelWithAtenNotONNXOp(), (x, y), f,
|
||||
ModelWithAtenNotONNXOp(), (x, y), None,
|
||||
add_node_names=False,
|
||||
do_constant_folding=False,
|
||||
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||
@ -91,11 +90,10 @@ class TestExportModes(JitTestCase):
|
||||
def forward(self, x, y):
|
||||
return torch.fmod(x, y)
|
||||
|
||||
f = io.BytesIO()
|
||||
x = torch.randn(3, 4, dtype=torch.float32)
|
||||
y = torch.randn(3, 4, dtype=torch.float32)
|
||||
torch.onnx.export_to_pretty_string(
|
||||
ModelWithAtenFmod(), (x, y), f,
|
||||
ModelWithAtenFmod(), (x, y), None,
|
||||
add_node_names=False,
|
||||
do_constant_folding=False,
|
||||
operator_export_type=OperatorExportTypes.ONNX_ATEN)
|
||||
|
||||
@ -49,20 +49,17 @@ class TestONNXExport(JitTestCase):
|
||||
|
||||
tm = TraceMe()
|
||||
tm = torch.jit.trace(tm, torch.rand(3, 4))
|
||||
example_outputs = (tm(torch.rand(3, 4)),)
|
||||
f = io.BytesIO()
|
||||
torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
|
||||
torch.onnx._export(tm, (torch.rand(3, 4),), f)
|
||||
|
||||
def test_export_tensoroption_to(self):
|
||||
def foo(x):
|
||||
return x[0].clone().detach().cpu() + x
|
||||
|
||||
traced = torch.jit.trace(foo, (torch.rand([2])))
|
||||
example_outputs = traced(torch.rand([2]))
|
||||
|
||||
f = io.BytesIO()
|
||||
torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
|
||||
example_outputs=example_outputs)
|
||||
torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f)
|
||||
|
||||
def test_onnx_export_script_module(self):
|
||||
class ModuleToExport(torch.jit.ScriptModule):
|
||||
@ -75,10 +72,8 @@ class TestONNXExport(JitTestCase):
|
||||
return x + x
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = mte(torch.zeros(1, 2, 3))
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
|
||||
|
||||
@suppress_warnings
|
||||
def test_onnx_export_func_with_warnings(self):
|
||||
@ -93,11 +88,9 @@ class TestONNXExport(JitTestCase):
|
||||
def forward(self, x):
|
||||
return func_with_warning(x)
|
||||
|
||||
outputs = WarningTest()(torch.randn(42))
|
||||
# no exception
|
||||
torch.onnx.export_to_pretty_string(
|
||||
WarningTest(), torch.randn(42), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
WarningTest(), torch.randn(42), None, verbose=False)
|
||||
|
||||
def test_onnx_export_script_python_fail(self):
|
||||
class PythonModule(torch.jit.ScriptModule):
|
||||
@ -119,11 +112,9 @@ class TestONNXExport(JitTestCase):
|
||||
return y + y
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = mte(torch.zeros(1, 2, 3))
|
||||
f = io.BytesIO()
|
||||
with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
|
||||
torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
|
||||
example_outputs=outputs)
|
||||
torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
|
||||
|
||||
def test_onnx_export_script_inline_trace(self):
|
||||
class ModuleToInline(torch.nn.Module):
|
||||
@ -144,10 +135,8 @@ class TestONNXExport(JitTestCase):
|
||||
return y + y
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = mte(torch.zeros(1, 2, 3))
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
|
||||
|
||||
def test_onnx_export_script_inline_script(self):
|
||||
class ModuleToInline(torch.jit.ScriptModule):
|
||||
@ -169,10 +158,8 @@ class TestONNXExport(JitTestCase):
|
||||
return y + y
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = mte(torch.zeros(1, 2, 3))
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
|
||||
|
||||
def test_onnx_export_script_module_loop(self):
|
||||
class ModuleToExport(torch.jit.ScriptModule):
|
||||
@ -189,10 +176,8 @@ class TestONNXExport(JitTestCase):
|
||||
return x
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = mte(torch.zeros(1, 2, 3))
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
|
||||
|
||||
@suppress_warnings
|
||||
def test_onnx_export_script_truediv(self):
|
||||
@ -206,11 +191,9 @@ class TestONNXExport(JitTestCase):
|
||||
return x + z
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = mte(torch.zeros(1, 2, 3))
|
||||
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False)
|
||||
|
||||
def test_onnx_export_script_non_alpha_add_sub(self):
|
||||
class ModuleToExport(torch.jit.ScriptModule):
|
||||
@ -223,10 +206,8 @@ class TestONNXExport(JitTestCase):
|
||||
return bs - 1
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = torch.LongTensor([mte(torch.rand(3, 4))])
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.rand(3, 4),), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
mte, (torch.rand(3, 4),), None, verbose=False)
|
||||
|
||||
def test_onnx_export_script_module_if(self):
|
||||
class ModuleToExport(torch.jit.ScriptModule):
|
||||
@ -240,10 +221,8 @@ class TestONNXExport(JitTestCase):
|
||||
return x
|
||||
|
||||
mte = ModuleToExport()
|
||||
outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
||||
example_outputs=outputs)
|
||||
mte, (torch.zeros(1, 2, 3),), None, verbose=False)
|
||||
|
||||
def test_onnx_export_script_inline_params(self):
|
||||
class ModuleToInline(torch.jit.ScriptModule):
|
||||
@ -272,8 +251,7 @@ class TestONNXExport(JitTestCase):
|
||||
reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
|
||||
self.assertEqual(result, reference)
|
||||
torch.onnx.export_to_pretty_string(
|
||||
mte, (torch.ones(2, 3),), None, verbose=False,
|
||||
example_outputs=result)
|
||||
mte, (torch.ones(2, 3),), None, verbose=False)
|
||||
|
||||
def test_onnx_export_speculate(self):
|
||||
|
||||
@ -305,18 +283,16 @@ class TestONNXExport(JitTestCase):
|
||||
return x.t()
|
||||
|
||||
f1 = Foo(transpose)
|
||||
outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
|
||||
f2 = Foo(linear)
|
||||
outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
|
||||
|
||||
torch.onnx.export_to_pretty_string(
|
||||
f1,
|
||||
(torch.ones(1, 10, dtype=torch.float), ),
|
||||
None, verbose=False, example_outputs=outputs_f1)
|
||||
None, verbose=False)
|
||||
torch.onnx.export_to_pretty_string(
|
||||
f2,
|
||||
(torch.ones(1, 10, dtype=torch.float), ),
|
||||
None, verbose=False, example_outputs=outputs_f2)
|
||||
None, verbose=False)
|
||||
|
||||
def test_onnx_export_shape_reshape(self):
|
||||
class Foo(torch.nn.Module):
|
||||
@ -328,10 +304,8 @@ class TestONNXExport(JitTestCase):
|
||||
return reshaped
|
||||
|
||||
foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
|
||||
outputs = foo(torch.zeros(1, 2, 3))
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
|
||||
example_outputs=outputs)
|
||||
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f)
|
||||
|
||||
def test_listconstruct_erasure(self):
|
||||
class FooMod(torch.nn.Module):
|
||||
@ -360,11 +334,10 @@ class TestONNXExport(JitTestCase):
|
||||
mod = DynamicSliceExportMod()
|
||||
|
||||
input = torch.rand(3, 4, 5)
|
||||
example_outs = mod(input)
|
||||
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export_to_pretty_string(
|
||||
DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10)
|
||||
DynamicSliceExportMod(), (input,), f, opset_version=10)
|
||||
|
||||
def test_export_dict(self):
|
||||
class DictModule(torch.nn.Module):
|
||||
@ -380,4 +353,4 @@ class TestONNXExport(JitTestCase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
|
||||
torch.onnx.export_to_pretty_string(
|
||||
torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),))
|
||||
torch.jit.script(mod), (x_in,), f)
|
||||
|
||||
@ -1115,9 +1115,8 @@ class TestTracer(JitTestCase):
|
||||
def forward(self, x, w):
|
||||
return torch.matmul(x, w).detach()
|
||||
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export_to_pretty_string(
|
||||
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)
|
||||
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), None)
|
||||
|
||||
def test_trace_slice_full_dim(self):
|
||||
def foo(x):
|
||||
|
||||
@ -19,7 +19,7 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
|
||||
|
||||
outputs = model(inputs)
|
||||
script_model = torch.jit.script(model)
|
||||
run_model_test(self, script_model, False, example_outputs=outputs,
|
||||
run_model_test(self, script_model, False,
|
||||
input=inputs, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
|
||||
@ -40,14 +40,13 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi
|
||||
assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
|
||||
|
||||
|
||||
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None,
|
||||
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL,
|
||||
input_names=None, dynamic_axes=None):
|
||||
for opset_version in opset_versions:
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(module, x, f,
|
||||
opset_version=opset_version,
|
||||
training=training,
|
||||
example_outputs=example_outputs,
|
||||
input_names=input_names,
|
||||
dynamic_axes=dynamic_axes)
|
||||
model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
@ -91,10 +90,8 @@ class TestONNXOpset(TestCase):
|
||||
x = torch.arange(1., 6., requires_grad=True)
|
||||
k = torch.tensor(3)
|
||||
module = MyModuleDynamic()
|
||||
example_output = module(x, k)
|
||||
check_onnx_opsets_operator(module, [x, k], ops,
|
||||
opset_versions=[10],
|
||||
example_outputs=example_output)
|
||||
opset_versions=[10])
|
||||
|
||||
def test_maxpool(self):
|
||||
module = torch.nn.MaxPool1d(2, stride=1)
|
||||
@ -191,7 +188,6 @@ class TestONNXOpset(TestCase):
|
||||
|
||||
module = DynamicSliceModel()
|
||||
x = torch.rand(1, 2)
|
||||
example_output = module(x)
|
||||
ops_10 = [{"op_name" : "Shape"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Gather",
|
||||
@ -202,7 +198,7 @@ class TestONNXOpset(TestCase):
|
||||
{"op_name" : "Slice",
|
||||
"attributes" : []}]
|
||||
ops = {10 : ops_10}
|
||||
check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output,
|
||||
check_onnx_opsets_operator(module, x, ops, opset_versions=[10],
|
||||
input_names=['x'], dynamic_axes={"x": [0, 1]})
|
||||
|
||||
ops_10 = [{"op_name" : "Constant"},
|
||||
@ -212,7 +208,7 @@ class TestONNXOpset(TestCase):
|
||||
{"op_name" : "Slice",
|
||||
"attributes" : []}]
|
||||
ops = {10 : ops_10}
|
||||
check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output)
|
||||
check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
|
||||
|
||||
def test_flip(self):
|
||||
class MyModule(Module):
|
||||
|
||||
@ -279,12 +279,11 @@ class TestOperators(TestCase):
|
||||
def test_conv_variable_length(self):
|
||||
x = torch.ones(5, 3, 6, 6, requires_grad=True)
|
||||
model = torch.nn.Conv2d(3, 2, 3)
|
||||
y = model(x)
|
||||
|
||||
dynamic_axes = {"input_1": [0, 2, 3], "output_1": {0: "output_1_variable_dim_0", 1: "output_1_variable_dim_1"}}
|
||||
model_proto_name = "conv2d.onnx"
|
||||
torch.onnx.export(model, x, model_proto_name, verbose=True, input_names=["input_1"], output_names=["output_1"],
|
||||
example_outputs=y, dynamic_axes=dynamic_axes)
|
||||
dynamic_axes=dynamic_axes)
|
||||
|
||||
import onnx
|
||||
onnx_model = onnx.load(model_proto_name)
|
||||
@ -729,22 +728,6 @@ class TestOperators(TestCase):
|
||||
x = torch.randn(2, 3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
|
||||
|
||||
def test_retain_param_name_disabled(self):
|
||||
class MyModule(Module):
|
||||
def __init__(self):
|
||||
super(MyModule, self).__init__()
|
||||
self.fc1 = nn.Linear(4, 5, bias=False)
|
||||
self.fc1.weight.data.fill_(2.)
|
||||
self.fc2 = nn.Linear(5, 6, bias=False)
|
||||
self.fc2.weight.data.fill_(3.)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.fc1(x))
|
||||
|
||||
x = torch.randn(3, 4).float()
|
||||
self.assertONNX(MyModule(), (x,), _retain_param_name=False,
|
||||
keep_initializers_as_inputs=True)
|
||||
|
||||
def test_c2_op(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
||||
@ -134,7 +134,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return cuda_model, cuda_input
|
||||
|
||||
def run_debug_test(self, model, train, batch_size, state_dict=None,
|
||||
input=None, use_gpu=True, example_outputs=None,
|
||||
input=None, use_gpu=True,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX):
|
||||
"""
|
||||
# TODO: remove this from the final release version
|
||||
@ -153,7 +153,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
model, input = self.convert_cuda(model, input)
|
||||
|
||||
onnxir, torch_out = do_export(model, input, export_params=self.embed_params, verbose=False,
|
||||
example_outputs=example_outputs,
|
||||
do_constant_folding=False,
|
||||
opset_version=self.opset_version,
|
||||
keep_initializers_as_inputs=True,
|
||||
@ -168,7 +167,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
def run_actual_test(self, model, train, batch_size, state_dict=None,
|
||||
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
|
||||
example_outputs=None, do_constant_folding=True,
|
||||
do_constant_folding=True,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
|
||||
input_names=None, dynamic_axes=None,
|
||||
remained_onnx_input_idx=None):
|
||||
@ -191,7 +190,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
# Verify the model runs the same in Caffe2
|
||||
verify.verify(model, input, c2, rtol=rtol, atol=atol,
|
||||
example_outputs=example_outputs,
|
||||
do_constant_folding=do_constant_folding,
|
||||
opset_version=self.opset_version,
|
||||
keep_initializers_as_inputs=True,
|
||||
@ -202,7 +200,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
def run_model_test(self, model, train, batch_size, state_dict=None,
|
||||
input=None, use_gpu=True, rtol=0.001, atol=1e-7,
|
||||
example_outputs=None, do_constant_folding=True,
|
||||
do_constant_folding=True,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
|
||||
input_names=None, dynamic_axes=None,
|
||||
remained_onnx_input_idx=None):
|
||||
@ -214,7 +212,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
if self.embed_params:
|
||||
self.run_actual_test(model, train, batch_size, state_dict, input,
|
||||
use_gpu=use_gpu_, rtol=rtol, atol=atol,
|
||||
example_outputs=example_outputs,
|
||||
do_constant_folding=do_constant_folding,
|
||||
operator_export_type=operator_export_type,
|
||||
input_names=input_names,
|
||||
@ -222,8 +219,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
remained_onnx_input_idx=remained_onnx_input_idx)
|
||||
else:
|
||||
self.run_debug_test(model, train, batch_size, state_dict, input,
|
||||
use_gpu=use_gpu_, example_outputs=example_outputs,
|
||||
operator_export_type=operator_export_type)
|
||||
use_gpu=use_gpu_, operator_export_type=operator_export_type)
|
||||
|
||||
def test_linear(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
@ -289,7 +285,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
# This test checks that given this edge condition, the model can be loaded and executed
|
||||
# in Caffe2 backend correctly.
|
||||
torch.onnx._export(model, input, f, verbose=True, export_type=ExportTypes.ZIP_ARCHIVE,
|
||||
input_names=["input1", "fc1.bias"], _retain_param_name=False,
|
||||
input_names=["input1", "fc1.bias"],
|
||||
keep_initializers_as_inputs=True)
|
||||
|
||||
f.seek(0)
|
||||
@ -1401,9 +1397,8 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return x[1:x.size(0)]
|
||||
module = DynamicSliceModel()
|
||||
x = torch.rand(1, 2)
|
||||
example_output = module(x)
|
||||
self.run_model_test(DynamicSliceModel(), train=False, input=(x,),
|
||||
batch_size=BATCH_SIZE, use_gpu=False, example_outputs=example_output)
|
||||
batch_size=BATCH_SIZE, use_gpu=False)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_dynamic_slice_to_the_end(self):
|
||||
@ -1472,7 +1467,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
y = torch.ones(2, 3, 4) * 2
|
||||
self.run_model_test(Arithmetic(),
|
||||
train=False, input=(), batch_size=BATCH_SIZE,
|
||||
use_gpu=False, example_outputs=(x + 3, y * (x + 3)))
|
||||
use_gpu=False)
|
||||
|
||||
def test_tensor_factories(self):
|
||||
class TensorFactory(torch.nn.Module):
|
||||
@ -1493,11 +1488,9 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
x = torch.randn(2, 3, 4)
|
||||
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
use_gpu=False, example_outputs=(torch.ones(x.size()),),
|
||||
input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
|
||||
use_gpu=False, input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
|
||||
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
use_gpu=False, example_outputs=(torch.ones(x.size()),),
|
||||
remained_onnx_input_idx=[])
|
||||
use_gpu=False, remained_onnx_input_idx=[])
|
||||
|
||||
def test_tensor_like_factories_script(self):
|
||||
class TensorFactory(torch.jit.ScriptModule):
|
||||
@ -1509,12 +1502,10 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
x = torch.randn(2, 3, 4)
|
||||
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
use_gpu=False, example_outputs=(torch.ones(x.size()),),
|
||||
input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
|
||||
use_gpu=False, input_names=['x'], dynamic_axes={'x': [0, 1, 2]})
|
||||
remained_onnx_input_idx = None if self.opset_version < 9 else []
|
||||
self.run_model_test(TensorFactory(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
use_gpu=False, example_outputs=(torch.ones(x.size()),),
|
||||
remained_onnx_input_idx=remained_onnx_input_idx)
|
||||
use_gpu=False, remained_onnx_input_idx=remained_onnx_input_idx)
|
||||
|
||||
def test_full(self):
|
||||
class FullModel(torch.nn.Module):
|
||||
@ -1532,8 +1523,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return torch.full((4, 5), x, dtype=torch.long)
|
||||
|
||||
x = torch.tensor(12)
|
||||
self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
use_gpu=False, example_outputs=FullClass()(x))
|
||||
self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
|
||||
|
||||
def test_clamp(self):
|
||||
class ClampModel(torch.nn.Module):
|
||||
@ -2021,8 +2011,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return a
|
||||
|
||||
x = (torch.randn(3, 4), torch.randn(4, 3))
|
||||
self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(x,))
|
||||
self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
|
||||
|
||||
def test_nested_tuple_input_output(self):
|
||||
class NestedTupleModel(torch.jit.ScriptModule):
|
||||
@ -2032,8 +2021,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
x = torch.randn(4, 5)
|
||||
y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
|
||||
self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
|
||||
example_outputs=x + y[0] + y[1][0] + y[1][1])
|
||||
self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE)
|
||||
|
||||
def test_topk(self):
|
||||
class TopKModel(torch.nn.Module):
|
||||
@ -2050,7 +2038,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return torch.topk(input, 3, dim=0)
|
||||
|
||||
x = torch.randn(4, 3, requires_grad=True)
|
||||
self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE, example_outputs=torch.topk(x, 3, dim=0))
|
||||
self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE)
|
||||
|
||||
def test_floor(self):
|
||||
class FloorModel(torch.nn.Module):
|
||||
@ -2086,9 +2074,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a
|
||||
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
outputs = ArangeScript()(x)
|
||||
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
|
||||
|
||||
class ArangeModel(torch.nn.Module):
|
||||
def forward(self, a):
|
||||
@ -2104,9 +2090,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a
|
||||
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
outputs = ArangeScript()(x)
|
||||
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
|
||||
|
||||
class ArangeModel(torch.nn.Module):
|
||||
def forward(self, a):
|
||||
@ -2122,9 +2106,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
return torch.arange(2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float).view(-1, 1) + a
|
||||
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
outputs = ArangeScript()(x)
|
||||
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE)
|
||||
|
||||
class ArangeModel(torch.nn.Module):
|
||||
def forward(self, a):
|
||||
@ -2249,9 +2231,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
model = WhileModel()
|
||||
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
|
||||
|
||||
def test_while_cond(self):
|
||||
class WhileModel(torch.jit.ScriptModule):
|
||||
@ -2266,9 +2246,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
model = WhileModel()
|
||||
x = torch.zeros(1, 2, 3, dtype=torch.long)
|
||||
a = torch.tensor([0], dtype=torch.long)
|
||||
outputs = model(x, a)
|
||||
self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE)
|
||||
|
||||
@unittest.skip("Disabled due to onnx optimizer deprecation")
|
||||
def test_loop(self):
|
||||
@ -2281,9 +2259,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
model = LoopModel()
|
||||
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
|
||||
|
||||
@unittest.skip("Disabled due to onnx optimizer deprecation")
|
||||
def test_dynamic_loop(self):
|
||||
@ -2296,9 +2272,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
model = LoopModel()
|
||||
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
|
||||
|
||||
@unittest.skip("Disabled due to onnx optimizer deprecation")
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
@ -2317,9 +2291,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
model = NestedLoopsModel()
|
||||
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
|
||||
|
||||
def test_select(self):
|
||||
class SelectModel(torch.nn.Module):
|
||||
@ -2337,9 +2309,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
model = StandardDeviation()
|
||||
inputs = torch.randn(2, 3, 4)
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
|
||||
|
||||
def test_std_along_dims(self):
|
||||
class StandardDeviationAlongDims(torch.nn.Module):
|
||||
@ -2348,9 +2318,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
model = StandardDeviationAlongDims()
|
||||
inputs = torch.randn(2, 3, 4)
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
def test_masked_fill(self):
|
||||
@ -2379,9 +2347,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
y = torch.zeros(4, requires_grad=True)
|
||||
z = torch.ones(5, requires_grad=True)
|
||||
model = MeshgridModel()
|
||||
outputs = model(x, y, z)
|
||||
self.run_model_test(model, train=False, input=(x, y, z), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(x, y, z), batch_size=BATCH_SIZE)
|
||||
|
||||
def test_remainder(self):
|
||||
class RemainderModel(torch.nn.Module):
|
||||
@ -2391,9 +2357,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
x = torch.randn(4, 2, 3)
|
||||
y = torch.randn(1, 2, 1)
|
||||
model = RemainderModel()
|
||||
outputs = model(x, y)
|
||||
self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE)
|
||||
|
||||
def test_remainder_scalar(self):
|
||||
class RemainderModel(torch.nn.Module):
|
||||
@ -2402,9 +2366,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
inputs = torch.randint(10, (2, 3))
|
||||
model = RemainderModel()
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,)
|
||||
|
||||
def test_baddbmm(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
@ -2423,9 +2385,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
|
||||
model = GeluModel()
|
||||
inputs = torch.randn(2, 4, 5, 6, requires_grad=True)
|
||||
outputs = model(inputs)
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE,
|
||||
example_outputs=(outputs,))
|
||||
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
def test_index_fill(self):
|
||||
|
||||
@ -28,7 +28,7 @@ class TestQuantizedOps(unittest.TestCase):
|
||||
output = q_model(*pt_inputs)
|
||||
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(q_model, pt_inputs, f, input_names=input_names, example_outputs=output,
|
||||
torch.onnx.export(q_model, pt_inputs, f, input_names=input_names,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||
f.seek(0)
|
||||
onnx_model = onnx.load(f)
|
||||
@ -84,8 +84,6 @@ class TestQuantizedOps(unittest.TestCase):
|
||||
self.generic_unary_test(torch.nn.ReLU())
|
||||
|
||||
def export_to_onnx(self, model, input, input_names):
|
||||
outputs = model(input)
|
||||
|
||||
traced = torch.jit.trace(model, input)
|
||||
buf = io.BytesIO()
|
||||
torch.jit.save(traced, buf)
|
||||
@ -93,7 +91,7 @@ class TestQuantizedOps(unittest.TestCase):
|
||||
|
||||
model = torch.jit.load(buf)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model, input, f, input_names=input_names, example_outputs=outputs,
|
||||
torch.onnx.export(model, input, f, input_names=input_names,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||
f.seek(0)
|
||||
|
||||
|
||||
@ -45,9 +45,9 @@ def to_numpy(tensor):
|
||||
else:
|
||||
return tensor.cpu().numpy()
|
||||
|
||||
def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
|
||||
do_constant_folding=True, keep_initializers_as_inputs=True,
|
||||
dynamic_axes=None, input_names=None, output_names=None,
|
||||
def convert_to_onnx(model, input=None, opset_version=9, do_constant_folding=True,
|
||||
keep_initializers_as_inputs=True, dynamic_axes=None,
|
||||
input_names=None, output_names=None,
|
||||
fixed_batch_size=False, training=None,
|
||||
onnx_shape_inference=True):
|
||||
# export the model to ONNX
|
||||
@ -55,7 +55,6 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
|
||||
input_copy = copy.deepcopy(input)
|
||||
torch.onnx._export(model, input_copy, f,
|
||||
opset_version=opset_version,
|
||||
example_outputs=example_outputs,
|
||||
do_constant_folding=do_constant_folding,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
dynamic_axes=dynamic_axes,
|
||||
@ -132,7 +131,7 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
|
||||
input = input + ({},)
|
||||
|
||||
ort_sess = convert_to_onnx(model, input=input, opset_version=self.opset_version,
|
||||
example_outputs=output, do_constant_folding=do_constant_folding,
|
||||
do_constant_folding=do_constant_folding,
|
||||
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
|
||||
dynamic_axes=dynamic_axes, input_names=input_names,
|
||||
output_names=output_names, fixed_batch_size=fixed_batch_size, training=training,
|
||||
@ -284,9 +283,9 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
_run_test(model, tracing_remained_onnx_input_idx)
|
||||
|
||||
def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
|
||||
example_outputs=None, do_constant_folding=True,
|
||||
dynamic_axes=None, input_names=None, output_names=None,
|
||||
ort_optim_on=True, training=None):
|
||||
do_constant_folding=True, dynamic_axes=None,
|
||||
input_names=None, output_names=None,
|
||||
ort_optim_on=True, training=None, use_external_data_format=None):
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
@ -310,13 +309,12 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
input_copy = copy.deepcopy(input)
|
||||
torch.onnx.export(model, input_copy, model_file_name,
|
||||
opset_version=self.opset_version,
|
||||
example_outputs=output,
|
||||
verbose=False,
|
||||
do_constant_folding=do_constant_folding,
|
||||
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
|
||||
dynamic_axes=dynamic_axes,
|
||||
input_names=input_names, output_names=output_names,
|
||||
use_external_data_format=True)
|
||||
use_external_data_format=use_external_data_format)
|
||||
# compute onnxruntime output prediction
|
||||
ort_sess_opt = onnxruntime.SessionOptions()
|
||||
ort_sess_opt.graph_optimization_level = \
|
||||
@ -366,19 +364,55 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
return x + torch.ones(2, 1024)
|
||||
|
||||
x = torch.randn(2, 1)
|
||||
self.run_model_test_with_external_data(LargeModel(), x)
|
||||
self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=None)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
|
||||
@unittest.skip("Enable this once large model with subgraph is supported in ORT")
|
||||
def test_subgraph_with_external_data(self):
|
||||
def test_largemodel_without_use_external_data_format_param(self):
|
||||
class LargeModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
for i in range(x.size(0)):
|
||||
x = x + torch.ones(2, 1024)
|
||||
return x
|
||||
def __init__(self):
|
||||
super(LargeModel, self).__init__()
|
||||
dim = 5
|
||||
n = 40 * 4 * 10 ** 6
|
||||
self.emb = torch.nn.Embedding(n, dim)
|
||||
self.lin1 = torch.nn.Linear(dim, 1)
|
||||
self.seq = torch.nn.Sequential(
|
||||
self.emb,
|
||||
self.lin1,
|
||||
)
|
||||
|
||||
x = torch.randn(2, 1)
|
||||
self.run_model_test_with_external_data(torch.jit.script(LargeModel()), x)
|
||||
def forward(self, input):
|
||||
return self.seq(input)
|
||||
|
||||
model = LargeModel()
|
||||
x = torch.tensor([2], dtype=torch.long)
|
||||
self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=None)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9.
|
||||
def test_largemodel_with_use_external_data_format_False(self):
|
||||
class LargeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(LargeModel, self).__init__()
|
||||
dim = 5
|
||||
n = 30 * 4 * 10 ** 6
|
||||
self.emb = torch.nn.Embedding(n, dim)
|
||||
self.lin1 = torch.nn.Linear(dim, 1)
|
||||
self.seq = torch.nn.Sequential(
|
||||
self.emb,
|
||||
self.lin1,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.seq(input)
|
||||
|
||||
model = LargeModel()
|
||||
x = torch.tensor([3], dtype=torch.long)
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
self.run_model_test_with_external_data(LargeModel(), x, use_external_data_format=False)
|
||||
|
||||
the_exception = cm.exception
|
||||
self.assertEqual("RuntimeError: Exporting model exceed maximum protobuf size of 2GB. " +
|
||||
"Please call torch.onnx.export without setting use_external_data_format parameter.")
|
||||
|
||||
def test_fuse_conv_bn1d(self):
|
||||
class Fuse(torch.nn.Module):
|
||||
@ -7784,7 +7818,6 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
script_model = torch.jit.script(model)
|
||||
output = model(x)
|
||||
ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
|
||||
example_outputs=output,
|
||||
training=torch.onnx.TrainingMode.TRAINING)
|
||||
ort_outs = run_ort(ort_sess, input=(x,))
|
||||
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))
|
||||
@ -7828,7 +7861,6 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
y = model(input)
|
||||
output = y.cpu().numpy()
|
||||
ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
|
||||
example_outputs=y,
|
||||
training=torch.onnx.TrainingMode.TRAINING)
|
||||
ort_outs = run_ort(ort_sess, input=(x,))
|
||||
ort_mask = np.where(ort_outs[0] != 0, 1, 0)
|
||||
@ -7913,11 +7945,9 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
model = torch.jit.script(MyModule())
|
||||
box_regression = torch.randn([4, 4])
|
||||
proposal = [torch.randn(2, 4), torch.randn(2, 4)]
|
||||
outputs = model(box_regression, proposal)
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
convert_to_onnx(model, input=(box_regression, proposal),
|
||||
example_outputs=outputs)
|
||||
convert_to_onnx(model, input=(box_regression, proposal))
|
||||
|
||||
def test_initializer_sequence(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
@ -7939,7 +7969,7 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
|
||||
x = torch.randn(32, 3)
|
||||
f = io.BytesIO()
|
||||
torch.onnx._export(test_model, (x,), f, _retain_param_name=True, do_constant_folding=False)
|
||||
torch.onnx._export(test_model, (x,), f, do_constant_folding=False)
|
||||
loaded_model = onnx.load_from_string(f.getvalue())
|
||||
|
||||
actual_list = [p.name for p in loaded_model.graph.initializer]
|
||||
@ -7986,10 +8016,9 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
|
||||
x = torch.ones(2, 3, dtype=torch.float)
|
||||
y = torch.tensor(5, dtype=torch.long)
|
||||
example_output = (test_model(x, y),)
|
||||
f = io.BytesIO()
|
||||
|
||||
torch.onnx.export(test_model, (x, y), f, example_outputs=example_output, _retain_param_name=True, do_constant_folding=False)
|
||||
torch.onnx.export(test_model, (x, y), f, do_constant_folding=False)
|
||||
loaded_model = onnx.load_from_string(f.getvalue())
|
||||
|
||||
actual_list = [p.name for p in loaded_model.graph.initializer]
|
||||
|
||||
@ -32,7 +32,6 @@ class TestUtilityFuns(TestCase):
|
||||
|
||||
def _model_to_graph(self, model, input,
|
||||
do_constant_folding=True,
|
||||
example_outputs=None,
|
||||
training=TrainingMode.EVAL,
|
||||
operator_export_type=OperatorExportTypes.ONNX,
|
||||
input_names=None,
|
||||
@ -49,7 +48,6 @@ class TestUtilityFuns(TestCase):
|
||||
_disable_torch_constant_prop=True,
|
||||
operator_export_type=operator_export_type,
|
||||
training=training,
|
||||
example_outputs=example_outputs,
|
||||
input_names=input_names,
|
||||
dynamic_axes=dynamic_axes)
|
||||
_set_onnx_shape_inference(True)
|
||||
@ -116,11 +114,11 @@ class TestUtilityFuns(TestCase):
|
||||
example_output = model(input_t, n)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.onnx.export(model,
|
||||
(input_t, n),
|
||||
"test.onnx",
|
||||
opset_version=self.opset_version,
|
||||
example_outputs=[example_output])
|
||||
torch.onnx._export(model,
|
||||
(input_t, n),
|
||||
"test.onnx",
|
||||
opset_version=self.opset_version,
|
||||
example_outputs=[example_output])
|
||||
|
||||
def test_constant_fold_transpose(self):
|
||||
class TransposeModule(torch.nn.Module):
|
||||
@ -558,27 +556,27 @@ class TestUtilityFuns(TestCase):
|
||||
assert node.kind() != "onnx::Shape"
|
||||
assert len(list(graph.nodes())) == 1
|
||||
|
||||
def test_strip_doc_string(self):
|
||||
def test_verbose(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return torch.exp(input)
|
||||
x = torch.randn(3, 4)
|
||||
|
||||
def is_model_stripped(f, strip_doc_string=None):
|
||||
if strip_doc_string is None:
|
||||
def is_model_stripped(f, verbose=None):
|
||||
if verbose is None:
|
||||
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
|
||||
else:
|
||||
torch.onnx.export(MyModule(), x, f, strip_doc_string=strip_doc_string,
|
||||
torch.onnx.export(MyModule(), x, f, verbose=verbose,
|
||||
opset_version=self.opset_version)
|
||||
model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
model_strip = copy.copy(model)
|
||||
onnx.helper.strip_doc_string(model_strip)
|
||||
return model == model_strip
|
||||
|
||||
# test strip_doc_string=True (default)
|
||||
# test verbose=False (default)
|
||||
self.assertTrue(is_model_stripped(io.BytesIO()))
|
||||
# test strip_doc_string=False
|
||||
self.assertFalse(is_model_stripped(io.BytesIO(), False))
|
||||
# test verbose=True
|
||||
self.assertFalse(is_model_stripped(io.BytesIO(), True))
|
||||
|
||||
# NB: remove this test once DataParallel can be correctly handled
|
||||
def test_error_on_data_parallel(self):
|
||||
@ -720,9 +718,8 @@ class TestUtilityFuns(TestCase):
|
||||
q_model = torch.quantization.convert(q_model, inplace=False)
|
||||
|
||||
q_model.eval()
|
||||
output = q_model(*pt_inputs)
|
||||
|
||||
graph, _, __ = self._model_to_graph(q_model, pt_inputs, example_outputs=output,
|
||||
graph, _, __ = self._model_to_graph(q_model, pt_inputs,
|
||||
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
|
||||
input_names=['pt_inputs'],
|
||||
dynamic_axes={'pt_inputs': [0, 1, 2, 3]})
|
||||
@ -749,9 +746,8 @@ class TestUtilityFuns(TestCase):
|
||||
|
||||
x = torch.tensor([2])
|
||||
model = PrimModule()
|
||||
output = model(x)
|
||||
model.eval()
|
||||
graph, _, __ = self._model_to_graph(model, (x,), example_outputs=output,
|
||||
graph, _, __ = self._model_to_graph(model, (x,),
|
||||
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
|
||||
input_names=['x'], dynamic_axes={'x': [0]})
|
||||
iter = graph.nodes()
|
||||
@ -814,10 +810,9 @@ class TestUtilityFuns(TestCase):
|
||||
|
||||
model = torch.jit.script(MyModule())
|
||||
x = torch.randn(10, 3, 128, 128)
|
||||
example_outputs = model(x)
|
||||
_set_opset_version(self.opset_version)
|
||||
_set_operator_export_type(OperatorExportTypes.ONNX)
|
||||
graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
|
||||
graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True,
|
||||
operator_export_type=OperatorExportTypes.ONNX,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]})
|
||||
|
||||
@ -1233,8 +1233,6 @@ class TestQuantizeONNXExport(JitTestCase):
|
||||
input_names = ["x"]
|
||||
|
||||
def export_to_onnx(model, input, input_names):
|
||||
outputs = model(input)
|
||||
|
||||
traced = torch.jit.trace(model, input)
|
||||
buf = io.BytesIO()
|
||||
torch.jit.save(traced, buf)
|
||||
@ -1242,7 +1240,7 @@ class TestQuantizeONNXExport(JitTestCase):
|
||||
|
||||
model = torch.jit.load(buf)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model, input, f, input_names=input_names, example_outputs=outputs,
|
||||
torch.onnx.export(model, input, f, input_names=input_names,
|
||||
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
||||
onnx_model = export_to_onnx(model, data, input_names)
|
||||
|
||||
|
||||
@ -1524,6 +1524,28 @@ except RuntimeError as e:
|
||||
):
|
||||
self.assertEqual(list(fn()), list(fn()))
|
||||
|
||||
for sampler in (
|
||||
RandomSampler(self.dataset, num_samples=5, replacement=True),
|
||||
RandomSampler(self.dataset, replacement=False),
|
||||
WeightedRandomSampler(weights, num_samples=5, replacement=True),
|
||||
WeightedRandomSampler(weights, num_samples=5, replacement=False),
|
||||
SubsetRandomSampler(range(10)),
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
l1 = list(sampler) + list(sampler)
|
||||
|
||||
torch.manual_seed(0)
|
||||
l2 = list(sampler) + list(sampler)
|
||||
self.assertEqual(l1, l2)
|
||||
|
||||
its = (iter(sampler), iter(sampler))
|
||||
ls = ([], [])
|
||||
for idx in range(len(sampler)):
|
||||
for i in range(2):
|
||||
if idx == 0:
|
||||
torch.manual_seed(0)
|
||||
ls[i].append(next(its[i]))
|
||||
self.assertEqual(ls[0], ls[1])
|
||||
|
||||
def _test_sampler(self, **kwargs):
|
||||
indices = range(2, 12) # using a regular iterable
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import http.server
|
||||
import itertools
|
||||
import os
|
||||
@ -414,13 +415,30 @@ class TestIterableDataPipeBasic(TestCase):
|
||||
|
||||
# Test Case: Uneven DataPipes
|
||||
source_numbers = list(range(0, 10)) + [10, 12]
|
||||
numbers_dp = IDP(source_numbers)
|
||||
numbers_dp = dp.iter.IterableWrapper(source_numbers)
|
||||
n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
|
||||
self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
|
||||
self.assertEqual([1, 3, 5, 7, 9], list(n2))
|
||||
n = n1.mux(n2)
|
||||
self.assertEqual(source_numbers, list(n))
|
||||
|
||||
@suppress_warnings # Suppress warning for lambda fn
|
||||
def test_map_with_col_file_handle_datapipe(self):
|
||||
temp_dir = self.temp_dir.name
|
||||
datapipe1 = dp.iter.FileLister(temp_dir, '')
|
||||
datapipe2 = dp.iter.FileLoader(datapipe1)
|
||||
|
||||
def _helper(datapipe):
|
||||
dp1 = datapipe.map(lambda x: x.read(), input_col=1)
|
||||
dp2 = datapipe.map(lambda x: (x[0], x[1].read()))
|
||||
self.assertEqual(list(dp1), list(dp2))
|
||||
|
||||
# tuple
|
||||
_helper(datapipe2)
|
||||
# list
|
||||
datapipe3 = datapipe2.map(lambda x: list(x))
|
||||
_helper(datapipe3)
|
||||
|
||||
|
||||
class TestDataFramesPipes(TestCase):
|
||||
"""
|
||||
@ -619,25 +637,13 @@ class IDP_NoLen(IterDataPipe):
|
||||
super().__init__()
|
||||
self.input_dp = input_dp
|
||||
|
||||
# Prevent in-place modification
|
||||
def __iter__(self):
|
||||
for i in self.input_dp:
|
||||
input_dp = self.input_dp if isinstance(self.input_dp, IterDataPipe) else copy.deepcopy(self.input_dp)
|
||||
for i in input_dp:
|
||||
yield i
|
||||
|
||||
|
||||
class IDP(IterDataPipe):
|
||||
def __init__(self, input_dp):
|
||||
super().__init__()
|
||||
self.input_dp = input_dp
|
||||
self.length = len(input_dp)
|
||||
|
||||
def __iter__(self):
|
||||
for i in self.input_dp:
|
||||
yield i
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
class MDP(MapDataPipe):
|
||||
def __init__(self, input_dp):
|
||||
super().__init__()
|
||||
@ -669,19 +675,19 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
def _test_picklable(self):
|
||||
arr = range(10)
|
||||
picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Mapper, IDP(arr), (), {}),
|
||||
(dp.iter.Mapper, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
|
||||
(dp.iter.Collator, IDP(arr), (), {}),
|
||||
(dp.iter.Collator, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
|
||||
(dp.iter.Filter, IDP(arr), (_fake_filter_fn, (0, ), {'test': True}), {}),
|
||||
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (), {}),
|
||||
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (_fake_fn, (0, ), {'test': True}), {}),
|
||||
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (), {}),
|
||||
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (_fake_fn, (0, ), {'test': True}), {}),
|
||||
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (_fake_filter_fn, (0, ), {'test': True}), {}),
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
|
||||
p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
|
||||
|
||||
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
|
||||
(dp.iter.Mapper, IDP(arr), (lambda x: x, ), {}),
|
||||
(dp.iter.Collator, IDP(arr), (lambda x: x, ), {}),
|
||||
(dp.iter.Filter, IDP(arr), (lambda x: x >= 5, ), {}),
|
||||
(dp.iter.Mapper, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
|
||||
(dp.iter.Collator, dp.iter.IterableWrapper(arr), (lambda x: x, ), {}),
|
||||
(dp.iter.Filter, dp.iter.IterableWrapper(arr), (lambda x: x >= 5, ), {}),
|
||||
]
|
||||
for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
@ -692,8 +698,8 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
p = pickle.dumps(datapipe)
|
||||
|
||||
def test_concat_datapipe(self):
|
||||
input_dp1 = IDP(range(10))
|
||||
input_dp2 = IDP(range(5))
|
||||
input_dp1 = dp.iter.IterableWrapper(range(10))
|
||||
input_dp2 = dp.iter.IterableWrapper(range(5))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
|
||||
dp.iter.Concater()
|
||||
@ -718,7 +724,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
|
||||
def test_fork_datapipe(self):
|
||||
input_dp = IDP(range(10))
|
||||
input_dp = dp.iter.IterableWrapper(range(10))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
input_dp.fork(num_instances=0)
|
||||
@ -836,7 +842,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
self.assertEqual(len(input_dp), len(dp3))
|
||||
|
||||
def test_demux_datapipe(self):
|
||||
input_dp = IDP(range(10))
|
||||
input_dp = dp.iter.IterableWrapper(range(10))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)
|
||||
@ -882,8 +888,8 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
self.assertEqual(list(range(0, 5)), output2)
|
||||
|
||||
# Test Case: classifer returns a value outside of [0, num_instance - 1]
|
||||
dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
|
||||
it = iter(dp[0])
|
||||
dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
|
||||
it = iter(dp0[0])
|
||||
with self.assertRaises(ValueError):
|
||||
next(it)
|
||||
next(it)
|
||||
@ -960,7 +966,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
@suppress_warnings # Suppress warning for lambda fn
|
||||
def test_map_datapipe(self):
|
||||
input_dp = IDP(range(10))
|
||||
input_dp = dp.iter.IterableWrapper(range(10))
|
||||
|
||||
def fn(item, dtype=torch.float, *, sum=False):
|
||||
data = torch.tensor(item, dtype=dtype)
|
||||
@ -1005,7 +1011,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
def _helper(ref_fn, fn, input_col=None, output_col=None):
|
||||
for constr in (list, tuple):
|
||||
datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
|
||||
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
|
||||
res_dp = datapipe.map(fn, input_col, output_col)
|
||||
ref_dp = datapipe.map(ref_fn)
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
@ -1072,9 +1078,11 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
return _data
|
||||
|
||||
def _helper(ref_fn, fn, input_col=None, output_col=None):
|
||||
datapipe = IDP([{"x": 0, "y": 1, "z": 2},
|
||||
{"x": 3, "y": 4, "z": 5},
|
||||
{"x": 6, "y": 7, "z": 8}])
|
||||
datapipe = dp.iter.IterableWrapper(
|
||||
[{"x": 0, "y": 1, "z": 2},
|
||||
{"x": 3, "y": 4, "z": 5},
|
||||
{"x": 6, "y": 7, "z": 8}]
|
||||
)
|
||||
res_dp = datapipe.map(fn, input_col, output_col)
|
||||
ref_dp = datapipe.map(ref_fn)
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
@ -1117,7 +1125,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
# TODO(VitalyFedyunin): If dill installed this test fails
|
||||
def _test_map_datapipe_nested_level(self):
|
||||
|
||||
input_dp = IDP([list(range(10)) for _ in range(3)])
|
||||
input_dp = dp.iter.IterableWrapper([list(range(10)) for _ in range(3)])
|
||||
|
||||
def fn(item, *, dtype=torch.float):
|
||||
return torch.tensor(item, dtype=dtype)
|
||||
@ -1153,7 +1161,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
def test_collate_datapipe(self):
|
||||
arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
||||
input_dp = IDP(arrs)
|
||||
input_dp = dp.iter.IterableWrapper(arrs)
|
||||
|
||||
def _collate_fn(batch):
|
||||
return torch.tensor(sum(batch), dtype=torch.float)
|
||||
@ -1172,7 +1180,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
def test_batch_datapipe(self):
|
||||
arrs = list(range(10))
|
||||
input_dp = IDP(arrs)
|
||||
input_dp = dp.iter.IterableWrapper(arrs)
|
||||
with self.assertRaises(AssertionError):
|
||||
input_dp.batch(batch_size=0)
|
||||
|
||||
@ -1200,7 +1208,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
def test_unbatch_datapipe(self):
|
||||
|
||||
target_length = 6
|
||||
prebatch_dp = IDP(range(target_length))
|
||||
prebatch_dp = dp.iter.IterableWrapper(range(target_length))
|
||||
|
||||
input_dp = prebatch_dp.batch(3)
|
||||
unbatch_dp = input_dp.unbatch()
|
||||
@ -1208,13 +1216,13 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
for i, res in zip(prebatch_dp, unbatch_dp):
|
||||
self.assertEqual(i, res)
|
||||
|
||||
input_dp = IDP([[0, 1, 2], [3, 4, 5]])
|
||||
input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
|
||||
unbatch_dp = input_dp.unbatch()
|
||||
self.assertEqual(len(list(unbatch_dp)), target_length)
|
||||
for i, res in zip(prebatch_dp, unbatch_dp):
|
||||
self.assertEqual(i, res)
|
||||
|
||||
input_dp = IDP([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
|
||||
input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
|
||||
|
||||
unbatch_dp = input_dp.unbatch()
|
||||
expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
|
||||
@ -1233,7 +1241,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
for i, res in zip(expected_dp2, unbatch_dp):
|
||||
self.assertEqual(i, res)
|
||||
|
||||
input_dp = IDP([[0, 1, 2], [3, 4, 5]])
|
||||
input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
|
||||
with self.assertRaises(ValueError):
|
||||
unbatch_dp = input_dp.unbatch(unbatch_level=-2)
|
||||
for i in unbatch_dp:
|
||||
@ -1245,7 +1253,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
print(i)
|
||||
|
||||
def test_bucket_batch_datapipe(self):
|
||||
input_dp = IDP(range(20))
|
||||
input_dp = dp.iter.IterableWrapper(range(20))
|
||||
with self.assertRaises(AssertionError):
|
||||
dp.iter.BucketBatcher(input_dp, batch_size=0)
|
||||
|
||||
@ -1258,7 +1266,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
data_len = 100
|
||||
arrs = list(range(data_len))
|
||||
random.shuffle(arrs)
|
||||
input_dp = IDP(arrs)
|
||||
input_dp = dp.iter.IterableWrapper(arrs)
|
||||
bucket_dp = dp.iter.BucketBatcher(input_dp, **kwargs)
|
||||
|
||||
self.assertEqual(len(bucket_dp), data_len // 3 if kwargs['drop_last'] else data_len // 3 + 1)
|
||||
@ -1291,7 +1299,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
|
||||
def test_filter_datapipe(self):
|
||||
input_ds = IDP(range(10))
|
||||
input_ds = dp.iter.IterableWrapper(range(10))
|
||||
|
||||
def _filter_fn(data, val, clip=False):
|
||||
if clip:
|
||||
@ -1318,7 +1326,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
def test_filter_datapipe_nested_list(self):
|
||||
|
||||
input_ds = IDP(range(10)).batch(5)
|
||||
input_ds = dp.iter.IterableWrapper(range(10)).batch(5)
|
||||
|
||||
def _filter_fn(data, val):
|
||||
return data >= val
|
||||
@ -1340,7 +1348,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
filter_dp = input_ds.filter(nesting_level=5, filter_fn=_filter_fn, fn_kwargs={'val': 5})
|
||||
temp = list(filter_dp)
|
||||
|
||||
input_ds = IDP(range(10)).batch(3)
|
||||
input_ds = dp.iter.IterableWrapper(range(10)).batch(3)
|
||||
|
||||
filter_dp = input_ds.filter(lambda ls: len(ls) >= 3)
|
||||
expected_dp3: List[List[int]] = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
||||
@ -1348,21 +1356,21 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
for data, exp in zip(filter_dp, expected_dp3):
|
||||
self.assertEqual(data, exp)
|
||||
|
||||
input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
|
||||
input_ds = dp.iter.IterableWrapper([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
|
||||
filter_dp = input_ds.filter(lambda x: x > 3, nesting_level=-1)
|
||||
expected_dp4 = [[[4, 5]], [[6, 7, 8]]]
|
||||
self.assertEqual(len(list(filter_dp)), len(expected_dp4))
|
||||
for data2, exp2 in zip(filter_dp, expected_dp4):
|
||||
self.assertEqual(data2, exp2)
|
||||
|
||||
input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
|
||||
input_ds = dp.iter.IterableWrapper([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]])
|
||||
filter_dp = input_ds.filter(lambda x: x > 7, nesting_level=-1)
|
||||
expected_dp5 = [[[8]]]
|
||||
self.assertEqual(len(list(filter_dp)), len(expected_dp5))
|
||||
for data2, exp2 in zip(filter_dp, expected_dp5):
|
||||
self.assertEqual(data2, exp2)
|
||||
|
||||
input_ds = IDP([[[0, 1], [3, 4]], [[6, 7, 8], [1, 2, 3]]])
|
||||
input_ds = dp.iter.IterableWrapper([[[0, 1], [3, 4]], [[6, 7, 8], [1, 2, 3]]])
|
||||
filter_dp = input_ds.filter(lambda ls: len(ls) >= 3, nesting_level=1)
|
||||
expected_dp6 = [[[6, 7, 8], [1, 2, 3]]]
|
||||
self.assertEqual(len(list(filter_dp)), len(expected_dp6))
|
||||
@ -1370,7 +1378,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
self.assertEqual(data2, exp2)
|
||||
|
||||
def test_sampler_datapipe(self):
|
||||
input_dp = IDP(range(10))
|
||||
input_dp = dp.iter.IterableWrapper(range(10))
|
||||
# Default SequentialSampler
|
||||
sampled_dp = dp.iter.Sampler(input_dp) # type: ignore[var-annotated]
|
||||
self.assertEqual(len(sampled_dp), 10)
|
||||
@ -1387,7 +1395,7 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
def test_shuffle_datapipe(self):
|
||||
exp = list(range(20))
|
||||
input_ds = IDP(exp)
|
||||
input_ds = dp.iter.IterableWrapper(exp)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
shuffle_dp = input_ds.shuffle(buffer_size=0)
|
||||
@ -1413,15 +1421,15 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
|
||||
def test_zip_datapipe(self):
|
||||
with self.assertRaises(TypeError):
|
||||
dp.iter.Zipper(IDP(range(10)), list(range(10))) # type: ignore[arg-type]
|
||||
dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10))) # type: ignore[arg-type]
|
||||
|
||||
zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP_NoLen(range(5))) # type: ignore[var-annotated]
|
||||
zipped_dp = dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5))) # type: ignore[var-annotated]
|
||||
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
|
||||
len(zipped_dp)
|
||||
exp = list((i, i) for i in range(5))
|
||||
self.assertEqual(list(zipped_dp), exp)
|
||||
|
||||
zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP(range(5)))
|
||||
zipped_dp = dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5)))
|
||||
self.assertEqual(len(zipped_dp), 5)
|
||||
self.assertEqual(list(zipped_dp), exp)
|
||||
# Reset
|
||||
@ -1506,32 +1514,32 @@ class TestFunctionalMapDataPipe(TestCase):
|
||||
def test_mux_datapipe(self):
|
||||
|
||||
# Test Case: Elements are yielded one at a time from each DataPipe, until they are all exhausted
|
||||
input_dp1 = IDP(range(4))
|
||||
input_dp2 = IDP(range(4, 8))
|
||||
input_dp3 = IDP(range(8, 12))
|
||||
input_dp1 = dp.iter.IterableWrapper(range(4))
|
||||
input_dp2 = dp.iter.IterableWrapper(range(4, 8))
|
||||
input_dp3 = dp.iter.IterableWrapper(range(8, 12))
|
||||
output_dp = input_dp1.mux(input_dp2, input_dp3)
|
||||
expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
|
||||
self.assertEqual(len(expected_output), len(output_dp))
|
||||
self.assertEqual(expected_output, list(output_dp))
|
||||
|
||||
# Test Case: Uneven input Data Pipes
|
||||
input_dp1 = IDP([1, 2, 3, 4])
|
||||
input_dp2 = IDP([10])
|
||||
input_dp3 = IDP([100, 200, 300])
|
||||
input_dp1 = dp.iter.IterableWrapper([1, 2, 3, 4])
|
||||
input_dp2 = dp.iter.IterableWrapper([10])
|
||||
input_dp3 = dp.iter.IterableWrapper([100, 200, 300])
|
||||
output_dp = input_dp1.mux(input_dp2, input_dp3)
|
||||
expected_output = [1, 10, 100, 2, 200, 3, 300, 4]
|
||||
self.assertEqual(len(expected_output), len(output_dp))
|
||||
self.assertEqual(expected_output, list(output_dp))
|
||||
|
||||
# Test Case: Empty Data Pipe
|
||||
input_dp1 = IDP([0, 1, 2, 3])
|
||||
input_dp2 = IDP([])
|
||||
input_dp1 = dp.iter.IterableWrapper([0, 1, 2, 3])
|
||||
input_dp2 = dp.iter.IterableWrapper([])
|
||||
output_dp = input_dp1.mux(input_dp2)
|
||||
self.assertEqual(len(input_dp1), len(output_dp))
|
||||
self.assertEqual(list(input_dp1), list(output_dp))
|
||||
|
||||
# Test Case: raises TypeError when __len__ is called and an input doesn't have __len__
|
||||
input_dp1 = IDP(range(10))
|
||||
input_dp1 = dp.iter.IterableWrapper(range(10))
|
||||
input_dp_no_len = IDP_NoLen(range(10))
|
||||
output_dp = input_dp1.mux(input_dp_no_len)
|
||||
with self.assertRaises(TypeError):
|
||||
@ -1665,8 +1673,8 @@ class TestTyping(TestCase):
|
||||
self.assertTrue(issubclass(DP1, IterDataPipe))
|
||||
dp1 = DP1(10)
|
||||
self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))
|
||||
dp2 = DP1(5)
|
||||
self.assertEqual(dp1.type, dp2.type)
|
||||
dp1_ = DP1(5)
|
||||
self.assertEqual(dp1.type, dp1_.type)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r"is not a generic class"):
|
||||
class InvalidDP5(DP1[tuple]): # type: ignore[type-arg]
|
||||
@ -1679,10 +1687,10 @@ class TestTyping(TestCase):
|
||||
yield d # type: ignore[misc]
|
||||
|
||||
self.assertTrue(issubclass(DP2, IterDataPipe))
|
||||
dp1 = DP2() # type: ignore[assignment]
|
||||
self.assertTrue(DP2.type.issubtype(dp1.type) and dp1.type.issubtype(DP2.type))
|
||||
dp2 = DP2() # type: ignore[assignment]
|
||||
self.assertEqual(dp1.type, dp2.type)
|
||||
dp2 = DP2() # type: ignore[var-annotated]
|
||||
self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type))
|
||||
dp2_ = DP2() # type: ignore[var-annotated]
|
||||
self.assertEqual(dp2.type, dp2_.type)
|
||||
|
||||
class DP3(IterDataPipe[Tuple[T_co, str]]):
|
||||
r""" DataPipe without fixed type with __init__ function"""
|
||||
@ -1695,10 +1703,10 @@ class TestTyping(TestCase):
|
||||
yield d, str(d)
|
||||
|
||||
self.assertTrue(issubclass(DP3, IterDataPipe))
|
||||
dp1 = DP3(range(10)) # type: ignore[assignment]
|
||||
self.assertTrue(DP3.type.issubtype(dp1.type) and dp1.type.issubtype(DP3.type))
|
||||
dp2 = DP3(5) # type: ignore[assignment]
|
||||
self.assertEqual(dp1.type, dp2.type)
|
||||
dp3 = DP3(range(10)) # type: ignore[var-annotated]
|
||||
self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type))
|
||||
dp3_ = DP3(5) # type: ignore[var-annotated]
|
||||
self.assertEqual(dp3.type, dp3_.type)
|
||||
|
||||
class DP4(IterDataPipe[tuple]):
|
||||
r""" DataPipe without __iter__ annotation"""
|
||||
@ -1707,8 +1715,8 @@ class TestTyping(TestCase):
|
||||
raise NotImplementedError
|
||||
|
||||
self.assertTrue(issubclass(DP4, IterDataPipe))
|
||||
dp = DP4()
|
||||
self.assertTrue(dp.type.param == tuple)
|
||||
dp4 = DP4()
|
||||
self.assertTrue(dp4.type.param == tuple)
|
||||
|
||||
class DP5(IterDataPipe):
|
||||
r""" DataPipe without type annotation"""
|
||||
@ -1717,9 +1725,9 @@ class TestTyping(TestCase):
|
||||
raise NotImplementedError
|
||||
|
||||
self.assertTrue(issubclass(DP5, IterDataPipe))
|
||||
dp = DP5() # type: ignore[assignment]
|
||||
dp5 = DP5()
|
||||
from torch.utils.data._typing import issubtype
|
||||
self.assertTrue(issubtype(dp.type.param, Any) and issubtype(Any, dp.type.param))
|
||||
self.assertTrue(issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param))
|
||||
|
||||
class DP6(IterDataPipe[int]):
|
||||
r""" DataPipe with plain Iterator"""
|
||||
@ -1728,13 +1736,13 @@ class TestTyping(TestCase):
|
||||
raise NotImplementedError
|
||||
|
||||
self.assertTrue(issubclass(DP6, IterDataPipe))
|
||||
dp = DP6() # type: ignore[assignment]
|
||||
self.assertTrue(dp.type.param == int)
|
||||
dp6 = DP6()
|
||||
self.assertTrue(dp6.type.param == int)
|
||||
|
||||
class DP7(IterDataPipe[Awaitable[T_co]]):
|
||||
r""" DataPipe with abstract base class"""
|
||||
|
||||
self.assertTrue(issubclass(DP6, IterDataPipe))
|
||||
self.assertTrue(issubclass(DP7, IterDataPipe))
|
||||
self.assertTrue(DP7.type.param == Awaitable[T_co])
|
||||
|
||||
class DP8(DP7[str]):
|
||||
@ -1765,11 +1773,11 @@ class TestTyping(TestCase):
|
||||
# Non-DataPipe input with DataPipe hint
|
||||
datasource = [(1, '1'), (2, '2'), (3, '3')]
|
||||
with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
|
||||
dp = DP0(datasource)
|
||||
dp0 = DP0(datasource)
|
||||
|
||||
dp = DP0(IDP(range(10)))
|
||||
dp0 = DP0(dp.iter.IterableWrapper(range(10)))
|
||||
with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
|
||||
dp = DP1(dp)
|
||||
dp1 = DP1(dp0)
|
||||
|
||||
def test_runtime(self):
|
||||
class DP(IterDataPipe[Tuple[int, T_co]]):
|
||||
@ -1784,26 +1792,26 @@ class TestTyping(TestCase):
|
||||
dss = ([(1, '1'), (2, '2')],
|
||||
[(1, 1), (2, '2')])
|
||||
for ds in dss:
|
||||
dp = DP(ds) # type: ignore[var-annotated]
|
||||
self.assertEqual(list(dp), ds)
|
||||
dp0 = DP(ds) # type: ignore[var-annotated]
|
||||
self.assertEqual(list(dp0), ds)
|
||||
# Reset __iter__
|
||||
self.assertEqual(list(dp), ds)
|
||||
self.assertEqual(list(dp0), ds)
|
||||
|
||||
dss = ([(1, 1), ('2', 2)], # type: ignore[assignment, list-item]
|
||||
[[1, '1'], [2, '2']], # type: ignore[list-item]
|
||||
[1, '1', 2, '2'])
|
||||
for ds in dss:
|
||||
dp = DP(ds)
|
||||
dp0 = DP(ds)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
|
||||
list(dp)
|
||||
list(dp0)
|
||||
|
||||
with runtime_validation_disabled():
|
||||
self.assertEqual(list(dp), ds)
|
||||
self.assertEqual(list(dp0), ds)
|
||||
with runtime_validation_disabled():
|
||||
self.assertEqual(list(dp), ds)
|
||||
self.assertEqual(list(dp0), ds)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
|
||||
list(dp)
|
||||
list(dp0)
|
||||
|
||||
def test_reinforce(self):
|
||||
T = TypeVar('T', int, str)
|
||||
@ -1819,26 +1827,26 @@ class TestTyping(TestCase):
|
||||
|
||||
ds = list(range(10))
|
||||
# Valid type reinforcement
|
||||
dp = DP(ds).reinforce_type(int)
|
||||
self.assertTrue(dp.type, int)
|
||||
self.assertEqual(list(dp), ds)
|
||||
dp0 = DP(ds).reinforce_type(int)
|
||||
self.assertTrue(dp0.type, int)
|
||||
self.assertEqual(list(dp0), ds)
|
||||
|
||||
# Invalid type
|
||||
with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"):
|
||||
dp = DP(ds).reinforce_type(1)
|
||||
dp1 = DP(ds).reinforce_type(1)
|
||||
|
||||
# Type is not subtype
|
||||
with self.assertRaisesRegex(TypeError, r"Expected 'expected_type' as subtype of"):
|
||||
dp = DP(ds).reinforce_type(float)
|
||||
dp2 = DP(ds).reinforce_type(float)
|
||||
|
||||
# Invalid data at runtime
|
||||
dp = DP(ds).reinforce_type(str)
|
||||
dp3 = DP(ds).reinforce_type(str)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
|
||||
list(dp)
|
||||
list(dp3)
|
||||
|
||||
# Context Manager to disable the runtime validation
|
||||
with runtime_validation_disabled():
|
||||
self.assertEqual(list(d for d in dp), ds)
|
||||
self.assertEqual(list(d for d in dp3), ds)
|
||||
|
||||
|
||||
class NumbersDataset(IterDataPipe):
|
||||
@ -1900,7 +1908,7 @@ class TestSharding(TestCase):
|
||||
self.assertEqual(sorted(all_items), sorted(items))
|
||||
|
||||
def test_sharding_length(self):
|
||||
numbers_dp = IDP(range(13))
|
||||
numbers_dp = dp.iter.IterableWrapper(range(13))
|
||||
sharded_dp0 = numbers_dp.sharding_filter()
|
||||
torch.utils.data.sharding.apply_sharding(sharded_dp0, 3, 0)
|
||||
sharded_dp1 = numbers_dp.sharding_filter()
|
||||
@ -1912,7 +1920,7 @@ class TestSharding(TestCase):
|
||||
self.assertEqual(4, len(sharded_dp1))
|
||||
self.assertEqual(4, len(sharded_dp2))
|
||||
|
||||
numbers_dp = IDP(range(1))
|
||||
numbers_dp = dp.iter.IterableWrapper(range(1))
|
||||
sharded_dp0 = numbers_dp.sharding_filter()
|
||||
torch.utils.data.sharding.apply_sharding(sharded_dp0, 2, 0)
|
||||
sharded_dp1 = numbers_dp.sharding_filter()
|
||||
@ -1922,11 +1930,11 @@ class TestSharding(TestCase):
|
||||
|
||||
@skipIfNoDill
|
||||
def test_old_dataloader(self):
|
||||
dp = self._get_pipeline()
|
||||
expected = list(dp)
|
||||
dp0 = self._get_pipeline()
|
||||
expected = list(dp0)
|
||||
|
||||
dp = self._get_pipeline().sharding_filter()
|
||||
dl = DataLoader(dp, batch_size=1, shuffle=False, num_workers=2,
|
||||
dp0 = self._get_pipeline().sharding_filter()
|
||||
dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2,
|
||||
worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn)
|
||||
items = []
|
||||
for i in dl:
|
||||
|
||||
@ -9704,12 +9704,6 @@ class TestNN(NNTestCase):
|
||||
self.assertEqual(input1.grad, torch.zeros_like(input1))
|
||||
self.assertEqual(input2.grad, input1 * 1e8)
|
||||
|
||||
# Check error when inputs are not the same shape
|
||||
input1 = torch.randn(2, 2, 1)
|
||||
input2 = torch.randn(2, 1, 3)
|
||||
with self.assertRaises(RuntimeError):
|
||||
F.cosine_similarity(input1, input2)
|
||||
|
||||
# Check type promotion, issue #61454
|
||||
input = torch.tensor(12.)
|
||||
out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, vmap
|
||||
from torch import Tensor
|
||||
from torch._vmap_internals import vmap
|
||||
import functools
|
||||
import itertools
|
||||
import warnings
|
||||
|
||||
@ -387,7 +387,7 @@ class CMake:
|
||||
# then. Until then, we use "--" to pass parameters to the
|
||||
# underlying build system.
|
||||
build_args += ['--']
|
||||
if IS_WINDOWS:
|
||||
if IS_WINDOWS and not USE_NINJA:
|
||||
# We are likely using msbuild here
|
||||
build_args += ['/p:CL_MPCount={}'.format(max_jobs)]
|
||||
else:
|
||||
|
||||
@ -108,7 +108,7 @@ def plural(n: int) -> str:
|
||||
|
||||
def get_base_commit(sha1: str) -> str:
|
||||
return subprocess.check_output(
|
||||
["git", "merge-base", sha1, "origin/master"],
|
||||
["git", "merge-base", sha1, "origin/release/1.10"],
|
||||
encoding="ascii",
|
||||
).strip()
|
||||
|
||||
|
||||
@ -22,7 +22,11 @@ class TestCMake(unittest.TestCase):
|
||||
# MAX_JOBS, USE_NINJA, IS_WINDOWS, want
|
||||
(( '8', True, False), ['-j', '8']), # noqa: E201,E241
|
||||
(( None, True, False), None), # noqa: E201,E241
|
||||
(( '7', False, False), ['-j', '7']), # noqa: E201,E241
|
||||
(( None, False, False), ['-j', '13']), # noqa: E201,E241
|
||||
(( '6', True, True), ['-j', '6']), # noqa: E201,E241
|
||||
(( None, True, True), None), # noqa: E201,E241
|
||||
(( '11', False, True), ['/p:CL_MPCount=11']), # noqa: E201,E241
|
||||
(( None, False, True), ['/p:CL_MPCount=13']), # noqa: E201,E241
|
||||
]
|
||||
for (max_jobs, use_ninja, is_windows), want in cases:
|
||||
|
||||
@ -9,6 +9,13 @@ if(NOT CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
endif()
|
||||
|
||||
if(BUILD_BINARY)
|
||||
add_library(aot_compiler SHARED
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
|
||||
)
|
||||
install(TARGETS aot_compiler DESTINATION lib)
|
||||
endif()
|
||||
|
||||
if(NOT BUILD_PYTHON)
|
||||
return()
|
||||
endif()
|
||||
@ -430,9 +437,3 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
# Pybind11 requires explicit linking of the torch_python library
|
||||
target_link_libraries(nnapi_backend torch torch_python)
|
||||
endif()
|
||||
|
||||
if(BUILD_BINARY)
|
||||
add_library(aot_compiler SHARED
|
||||
${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -756,8 +756,6 @@ del register_after_fork
|
||||
# torch.jit.script as a decorator, for instance):
|
||||
from ._lobpcg import lobpcg as lobpcg
|
||||
|
||||
from ._vmap_internals import vmap as vmap
|
||||
|
||||
# These were previously defined in native_functions.yaml and appeared on the
|
||||
# `torch` namespace, but we moved them to c10 dispatch to facilitate custom
|
||||
# class usage. We add these lines here to preserve backward compatibility.
|
||||
|
||||
@ -362,19 +362,21 @@ void ConvertGraphToONNXProto(
|
||||
SymbolDimMap& symbol_map,
|
||||
int opset_version) {
|
||||
RawDataExportMap export_map;
|
||||
std::tie(model_proto, export_map, symbol_map) = export_onnx(
|
||||
graph,
|
||||
{},
|
||||
opset_version,
|
||||
{},
|
||||
false,
|
||||
onnx_torch::OperatorExportTypes::ONNX,
|
||||
true,
|
||||
true,
|
||||
{},
|
||||
true,
|
||||
false,
|
||||
std::string());
|
||||
bool val_use_external_data_format;
|
||||
std::tie(model_proto, export_map, symbol_map, val_use_external_data_format) =
|
||||
export_onnx(
|
||||
graph,
|
||||
{},
|
||||
opset_version,
|
||||
{},
|
||||
false,
|
||||
onnx_torch::OperatorExportTypes::ONNX,
|
||||
true,
|
||||
true,
|
||||
{},
|
||||
true,
|
||||
false,
|
||||
std::string());
|
||||
for (int i = 0; i < model_proto->graph().output_size(); ++i) {
|
||||
model_proto->mutable_graph()->mutable_output(i)->clear_type();
|
||||
}
|
||||
|
||||
@ -263,19 +263,25 @@ void initPythonIRBindings(PyObject* module_) {
|
||||
std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
|
||||
RawDataExportMap export_map;
|
||||
SymbolDimMap symbol_map;
|
||||
std::tie(model_proto, export_map, symbol_map) = export_onnx(
|
||||
g,
|
||||
initializers,
|
||||
onnx_opset_version,
|
||||
dynamic_axes,
|
||||
defer_weight_export,
|
||||
operator_export_type,
|
||||
strip_doc_string,
|
||||
keep_initializers_as_inputs,
|
||||
custom_opsets,
|
||||
add_node_names,
|
||||
use_external_data_format,
|
||||
onnx_file_path);
|
||||
bool val_use_external_data_format = false;
|
||||
std::tie(
|
||||
model_proto,
|
||||
export_map,
|
||||
symbol_map,
|
||||
val_use_external_data_format) =
|
||||
export_onnx(
|
||||
g,
|
||||
initializers,
|
||||
onnx_opset_version,
|
||||
dynamic_axes,
|
||||
defer_weight_export,
|
||||
operator_export_type,
|
||||
strip_doc_string,
|
||||
keep_initializers_as_inputs,
|
||||
custom_opsets,
|
||||
add_node_names,
|
||||
use_external_data_format,
|
||||
onnx_file_path);
|
||||
std::unordered_map<std::string, py::bytes>
|
||||
python_serialized_export_map;
|
||||
for (auto& kv : export_map) {
|
||||
@ -289,7 +295,9 @@ void initPythonIRBindings(PyObject* module_) {
|
||||
}
|
||||
graph = serialize_model_proto_to_string(model_proto);
|
||||
return std::make_tuple(
|
||||
py::bytes(graph), python_serialized_export_map);
|
||||
py::bytes(graph),
|
||||
python_serialized_export_map,
|
||||
val_use_external_data_format);
|
||||
},
|
||||
py::arg("initializers"),
|
||||
py::arg("onnx_opset_version") = 0,
|
||||
|
||||
@ -231,6 +231,12 @@ class EncoderBase {
|
||||
bool use_external_data_format = false,
|
||||
const std::string& onnx_file_path = std::string());
|
||||
|
||||
unsigned long long int GetGraphProtoSize(
|
||||
onnx::GraphProto* graph_proto,
|
||||
const std::shared_ptr<Graph>& graph,
|
||||
const std::map<std::string, at::Tensor>& initializers =
|
||||
std::map<std::string, at::Tensor>());
|
||||
|
||||
virtual void EncodeTensor(
|
||||
onnx::TensorProto* tensor_proto,
|
||||
const at::Tensor& tensor,
|
||||
@ -581,7 +587,6 @@ void EncoderBase::AddInitializersIntoGraphProto(
|
||||
bool use_external_data_format,
|
||||
const std::string& onnx_file_path) {
|
||||
AT_ASSERT(block->inputs().size() >= initializers.size());
|
||||
|
||||
for (auto input : block->inputs()) {
|
||||
auto name_tensor_pair = initializers.find(input->debugName());
|
||||
if (name_tensor_pair == initializers.end()) {
|
||||
@ -598,6 +603,38 @@ void EncoderBase::AddInitializersIntoGraphProto(
|
||||
}
|
||||
}
|
||||
|
||||
unsigned long long int EncoderBase::GetGraphProtoSize(
|
||||
onnx::GraphProto* graph_proto,
|
||||
const std::shared_ptr<Graph>& graph,
|
||||
const std::map<std::string, at::Tensor>& initializers) {
|
||||
unsigned long long int sizes = 0;
|
||||
for (auto input : graph->inputs()) {
|
||||
auto name_tensor_pair = initializers.find(input->debugName());
|
||||
if (name_tensor_pair == initializers.end()) {
|
||||
continue;
|
||||
}
|
||||
onnx::GraphProto* graph_proto_copy = new onnx::GraphProto(*graph_proto);
|
||||
auto tensor_proto = graph_proto_copy->add_initializer();
|
||||
const at::Tensor tensor = name_tensor_pair->second;
|
||||
for (auto d : tensor.sizes()) {
|
||||
tensor_proto->add_dims(d);
|
||||
}
|
||||
tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
|
||||
at::Tensor t;
|
||||
if (tensor.is_quantized()) {
|
||||
t = tensor.contiguous();
|
||||
} else {
|
||||
t = tensor.contiguous().cpu();
|
||||
}
|
||||
tensor_proto->set_raw_data(std::string(
|
||||
static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
|
||||
sizes += tensor_proto->ByteSizeLong();
|
||||
delete graph_proto_copy;
|
||||
graph_proto_copy = nullptr;
|
||||
}
|
||||
return sizes;
|
||||
}
|
||||
|
||||
void EncoderBase::AddAttribute(
|
||||
onnx::NodeProto* node_proto,
|
||||
const jit::Node* node,
|
||||
@ -726,6 +763,10 @@ class GraphEncoder : public EncoderBase {
|
||||
return raw_data_export_map_;
|
||||
}
|
||||
|
||||
bool get_use_external_data_format() {
|
||||
return use_external_data_format_;
|
||||
}
|
||||
|
||||
private:
|
||||
void EncodeTensor(
|
||||
onnx::TensorProto* tensor_proto,
|
||||
@ -736,6 +777,7 @@ class GraphEncoder : public EncoderBase {
|
||||
|
||||
RawDataExportMap raw_data_export_map_;
|
||||
bool defer_weight_export_;
|
||||
bool use_external_data_format_;
|
||||
};
|
||||
|
||||
GraphEncoder::GraphEncoder(
|
||||
@ -754,8 +796,21 @@ GraphEncoder::GraphEncoder(
|
||||
bool use_external_data_format,
|
||||
const std::string& onnx_file_path)
|
||||
: EncoderBase(operator_export_type, strip_doc),
|
||||
defer_weight_export_(defer_weight_export) {
|
||||
defer_weight_export_(defer_weight_export),
|
||||
use_external_data_format_(use_external_data_format) {
|
||||
validateGraph(graph, operator_export_type);
|
||||
// If graph proto size exceed maximum protobuf size of 2GB, set
|
||||
// use_external_data_format to true.
|
||||
if (!use_external_data_format && !onnx_file_path.empty() &&
|
||||
GetGraphProtoSize(model_proto_.mutable_graph(), graph, initializers) >
|
||||
INT_MAX) {
|
||||
GRAPH_DEBUG(
|
||||
"Exporting model exceed maximum protobuf size of 2GB. Storing model parameters in external data files");
|
||||
use_external_data_format = true;
|
||||
// use_external_data_format_ is one of graph_encoder private variable set
|
||||
// for return `use_external_data_format` value.
|
||||
use_external_data_format_ = use_external_data_format;
|
||||
}
|
||||
|
||||
if (use_external_data_format) {
|
||||
TORCH_CHECK(
|
||||
@ -895,7 +950,8 @@ std::string pretty_print_onnx(
|
||||
std::tuple<
|
||||
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
|
||||
RawDataExportMap,
|
||||
SymbolDimMap>
|
||||
SymbolDimMap,
|
||||
bool>
|
||||
export_onnx(
|
||||
const std::shared_ptr<Graph>& graph,
|
||||
const std::map<std::string, at::Tensor>& initializers,
|
||||
@ -929,7 +985,8 @@ export_onnx(
|
||||
std::make_shared<::ONNX_NAMESPACE::ModelProto>(
|
||||
graph_encoder.get_model_proto()),
|
||||
graph_encoder.get_raw_data_export_map(),
|
||||
graph_encoder.get_symbol_dim_param_map());
|
||||
graph_encoder.get_symbol_dim_param_map(),
|
||||
graph_encoder.get_use_external_data_format());
|
||||
}
|
||||
|
||||
std::string serialize_model_proto_to_string(
|
||||
@ -938,7 +995,7 @@ std::string serialize_model_proto_to_string(
|
||||
TORCH_CHECK(
|
||||
proto_size <= INT_MAX,
|
||||
"Exporting model exceed maximum protobuf size of 2GB. "
|
||||
"Please call torch.onnx.export with use_external_data_format=True.");
|
||||
"Please call torch.onnx.export without setting use_external_data_format parameter.");
|
||||
return model_proto->SerializeAsString();
|
||||
}
|
||||
|
||||
|
||||
@ -33,7 +33,8 @@ using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
|
||||
TORCH_API std::tuple<
|
||||
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
|
||||
RawDataExportMap,
|
||||
SymbolDimMap>
|
||||
SymbolDimMap,
|
||||
bool>
|
||||
export_onnx(
|
||||
const std::shared_ptr<Graph>& graph,
|
||||
const std::map<std::string, at::Tensor>& initializers,
|
||||
|
||||
@ -1519,8 +1519,11 @@ def _object_to_tensor(obj):
|
||||
f = io.BytesIO()
|
||||
_pickler(f).dump(obj)
|
||||
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
|
||||
byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)
|
||||
local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long)
|
||||
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
|
||||
# Otherwise, it will casue 100X slowdown.
|
||||
# See: https://github.com/pytorch/pytorch/issues/65696
|
||||
byte_tensor = torch.ByteTensor(byte_storage)
|
||||
local_size = torch.LongTensor([byte_tensor.numel()])
|
||||
return byte_tensor, local_size
|
||||
|
||||
|
||||
|
||||
@ -4256,7 +4256,10 @@ cosine_similarity = _add_docstr(
|
||||
r"""
|
||||
cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor
|
||||
|
||||
Returns cosine similarity between x1 and x2, computed along dim.
|
||||
Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable
|
||||
to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
|
||||
squeezed (see :func:`torch.squeeze`), resulting in the
|
||||
output tensor having 1 fewer dimension.
|
||||
|
||||
.. math ::
|
||||
\text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
|
||||
@ -4265,16 +4268,11 @@ Supports :ref:`type promotion <type-promotion-doc>`.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): First input.
|
||||
x2 (Tensor): Second input (with the same number of dimensions as x1, matching x1 size at dimension `dim`,
|
||||
and broadcastable with x1 at other dimensions).
|
||||
dim (int, optional): Dimension of vectors. Default: 1
|
||||
x2 (Tensor): Second input.
|
||||
dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
|
||||
eps (float, optional): Small value to avoid division by zero.
|
||||
Default: 1e-8
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
|
||||
- Output: :math:`(\ast_1, \ast_2)`
|
||||
|
||||
Example::
|
||||
|
||||
>>> input1 = torch.randn(100, 128)
|
||||
|
||||
@ -31,10 +31,10 @@ def _export(*args, **kwargs):
|
||||
|
||||
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
|
||||
input_names=None, output_names=None, operator_export_type=None,
|
||||
opset_version=None, _retain_param_name=True, do_constant_folding=True,
|
||||
example_outputs=None, strip_doc_string=True, dynamic_axes=None,
|
||||
keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True,
|
||||
use_external_data_format=False):
|
||||
opset_version=None, _retain_param_name=None, do_constant_folding=True,
|
||||
example_outputs=None, strip_doc_string=None, dynamic_axes=None,
|
||||
keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None,
|
||||
use_external_data_format=None):
|
||||
r"""
|
||||
Exports a model into ONNX format. If ``model`` is not a
|
||||
:class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`, this runs
|
||||
@ -105,7 +105,9 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
|
||||
In this case, the exported model will first take all of its parameters
|
||||
as arguments, with the ordering as specified by ``model.state_dict().values()``
|
||||
verbose (bool, default False): if True, prints a description of the
|
||||
model being exported to stdout.
|
||||
model being exported to stdout. In addition, the final ONNX graph will include the
|
||||
field ``doc_string``` from the exported model which mentions the source code locations
|
||||
for ``model``.
|
||||
training (enum, default TrainingMode.EVAL):
|
||||
* ``TrainingMode.EVAL``: export the model in inference mode.
|
||||
* ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
|
||||
@ -181,16 +183,17 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
|
||||
opset_version (int, default 9):
|
||||
Must be ``== _onnx_main_opset or in _onnx_stable_opsets``,
|
||||
defined in torch/onnx/symbolic_helper.py.
|
||||
_retain_param_name (bool, default True): [Deprecated and ignored. Will be removed in next PyTorch
|
||||
release]
|
||||
do_constant_folding (bool, default False): Apply the constant-folding optimization.
|
||||
Constant-folding will replace some of the ops that have all constant inputs
|
||||
with pre-computed constant nodes.
|
||||
example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None):
|
||||
[Deprecated and ignored. Will be removed in next PyTorch release],
|
||||
Must be provided when exporting a ScriptModule or ScriptFunction, ignored otherwise.
|
||||
Used to determine the type and shape of the outputs without tracing the execution of
|
||||
the model. A single object is treated as equivalent to a tuple of one element.
|
||||
strip_doc_string (bool, default True): do not include the field
|
||||
``doc_string``` from the exported model. Otherwise the field will mention the source
|
||||
code locations for ``model``.
|
||||
strip_doc_string (bool, default True): [Deprecated and ignored. Will be removed in next PyTorch release]
|
||||
dynamic_axes (dict<string, dict<int, string>> or dict<string, list(int)>, default empty dict):
|
||||
|
||||
By default the exported model will have the shapes of all input and output tensors
|
||||
@ -293,10 +296,11 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
|
||||
|
||||
enable_onnx_checker (bool, default True): Deprecated and ignored. Will be removed in next
|
||||
Pytorch release.
|
||||
use_external_data_format (bool, default False): If True, then some of the model
|
||||
parameters are stored in external data files and not in the ONNX model file itself.
|
||||
Models larger than 2GB cannot be exported in one file because of size limits imposed
|
||||
by Protocol Buffers.
|
||||
use_external_data_format (bool, default False): [Deprecated and ignored. Will be removed in
|
||||
next Pytorch release.]
|
||||
If True, then some of the model parameters are stored in external data files and not in
|
||||
the ONNX model file itself. Models larger than 2GB cannot be exported in one file because
|
||||
of size limits imposed by Protocol Buffers.
|
||||
For details see
|
||||
`onnx.proto <https://github.com/onnx/onnx/blob/32c7cb66/onnx/onnx.proto#L562>`_.
|
||||
If True, argument ``f`` must be a string specifying the location of the model.
|
||||
@ -316,9 +320,24 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
|
||||
custom_opsets, enable_onnx_checker, use_external_data_format)
|
||||
|
||||
|
||||
def export_to_pretty_string(*args, **kwargs):
|
||||
def export_to_pretty_string(*args, **kwargs) -> str:
|
||||
r"""
|
||||
Same as :func:`export`, but returns a text representation of the exported model.
|
||||
Similar to :func:`export`, but returns a text representation of the ONNX
|
||||
model. Only differences in args listed below. All other args are the same
|
||||
as :func:`export`.
|
||||
|
||||
Args:
|
||||
f: Deprecated and ignored. Will be removed in the next release of
|
||||
PyTorch.
|
||||
add_node_names (bool, default True): Whether or not to set
|
||||
NodeProto.name. This makes no difference unless
|
||||
``google_printer=True``.
|
||||
google_printer (bool, default False): If False, will return a custom,
|
||||
compact representation of the model. If True will return the
|
||||
protobuf's `Message::DebugString()`, which is more verbose.
|
||||
|
||||
Returns:
|
||||
A UTF-8 str containing a human-readable representation of the ONNX model.
|
||||
"""
|
||||
from torch.onnx import utils
|
||||
return utils.export_to_pretty_string(*args, **kwargs)
|
||||
|
||||
@ -12,6 +12,7 @@ import torch.serialization
|
||||
import re
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
import numbers
|
||||
import warnings
|
||||
from torch._six import string_classes
|
||||
@ -75,24 +76,37 @@ def select_model_mode_for_export(model, mode):
|
||||
|
||||
def export(model, args, f, export_params=True, verbose=False, training=None,
|
||||
input_names=None, output_names=None, operator_export_type=None,
|
||||
opset_version=None, _retain_param_name=True, do_constant_folding=True,
|
||||
example_outputs=None, strip_doc_string=True, dynamic_axes=None,
|
||||
opset_version=None, _retain_param_name=None, do_constant_folding=True,
|
||||
example_outputs=None, strip_doc_string=None, dynamic_axes=None,
|
||||
keep_initializers_as_inputs=None, custom_opsets=None,
|
||||
enable_onnx_checker=None, use_external_data_format=False):
|
||||
enable_onnx_checker=None, use_external_data_format=None):
|
||||
if operator_export_type is None:
|
||||
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
|
||||
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
else:
|
||||
operator_export_type = OperatorExportTypes.ONNX
|
||||
|
||||
if enable_onnx_checker is not None:
|
||||
warnings.warn("`enable_onnx_checker' is deprecated and ignored. It will be removed in"
|
||||
"the next PyTorch release. To proceed despite ONNX checker failures, you"
|
||||
"can catch torch.onnx.ONNXCheckerError.")
|
||||
warnings.warn("'enable_onnx_checker' is deprecated and ignored. It will be removed in "
|
||||
"the next PyTorch release. To proceed despite ONNX checker failures, "
|
||||
"catch torch.onnx.ONNXCheckerError.")
|
||||
if _retain_param_name is not None:
|
||||
warnings.warn("'_retain_param_name' is deprecated and ignored. "
|
||||
"It will be removed in the next PyTorch release.")
|
||||
if strip_doc_string is not None:
|
||||
warnings.warn("`strip_doc_string' is deprecated and ignored. Will be removed in "
|
||||
"next PyTorch release. It's combined with `verbose' argument now. ")
|
||||
if example_outputs is not None:
|
||||
warnings.warn("`example_outputs' is deprecated and ignored. Will be removed in "
|
||||
"next PyTorch release.")
|
||||
if use_external_data_format is not None:
|
||||
warnings.warn("`use_external_data_format' is deprecated and ignored. Will be removed in next "
|
||||
"PyTorch release. The code will work as it is False if models are not larger than 2GB, "
|
||||
"Otherwise set to False because of size limits imposed by Protocol Buffers.")
|
||||
|
||||
_export(model, args, f, export_params, verbose, training, input_names, output_names,
|
||||
operator_export_type=operator_export_type, opset_version=opset_version,
|
||||
_retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
|
||||
example_outputs=example_outputs, strip_doc_string=strip_doc_string,
|
||||
do_constant_folding=do_constant_folding, example_outputs=example_outputs,
|
||||
dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
|
||||
|
||||
@ -310,7 +324,10 @@ def _decide_external_data_format(use_external_data_format, operator_export_type,
|
||||
# string specifying the location of the model. For large model cases, if f is not a non-empty string,
|
||||
# then this method returns an empty string, which is an error condition for the large model export code
|
||||
# path later (but not for regular model export code path).
|
||||
model_file_location = f if val_use_external_data_format and isinstance(f, str) else str()
|
||||
if (val_use_external_data_format is None or val_use_external_data_format is True) and isinstance(f, str):
|
||||
model_file_location = f
|
||||
else:
|
||||
model_file_location = str()
|
||||
return val_use_external_data_format, model_file_location
|
||||
|
||||
def _decide_input_format(model, args):
|
||||
@ -389,7 +406,7 @@ def _get_param_count_list(method_graph, args_params):
|
||||
return param_count_list
|
||||
|
||||
|
||||
def _create_jit_graph(model, args, _retain_param_name):
|
||||
def _create_jit_graph(model, args):
|
||||
torch_out = None
|
||||
params: Union[List, Tuple]
|
||||
if isinstance(model, torch.jit.ScriptModule):
|
||||
@ -420,13 +437,12 @@ def _create_jit_graph(model, args, _retain_param_name):
|
||||
graph, torch_out = _trace_and_get_graph_from_model(model, args)
|
||||
state_dict = _unique_state_dict(model)
|
||||
params = list(state_dict.values())
|
||||
if _retain_param_name:
|
||||
graph_inputs = list(graph.inputs())
|
||||
user_input_num = len(graph_inputs) - len(state_dict)
|
||||
param_names = list(state_dict.keys())
|
||||
for i, inp in enumerate(graph_inputs):
|
||||
if i >= user_input_num:
|
||||
inp.setDebugName(param_names[i - user_input_num])
|
||||
graph_inputs = list(graph.inputs())
|
||||
user_input_num = len(graph_inputs) - len(state_dict)
|
||||
param_names = list(state_dict.keys())
|
||||
for i, inp in enumerate(graph_inputs):
|
||||
if i >= user_input_num:
|
||||
inp.setDebugName(param_names[i - user_input_num])
|
||||
torch._C._jit_pass_onnx_function_substitution(graph)
|
||||
return graph, params, torch_out, None
|
||||
|
||||
@ -437,12 +453,25 @@ def _get_named_param_dict(graph, params):
|
||||
_params_dict = dict(zip(param_names, params))
|
||||
return _params_dict
|
||||
|
||||
def _get_example_outputs(model, args):
|
||||
input_args = copy.deepcopy(args)
|
||||
input_kwargs = {}
|
||||
if input_args and isinstance(input_args[-1], dict):
|
||||
input_kwargs = input_args[-1]
|
||||
input_args = input_args[:-1]
|
||||
|
||||
example_outputs = model(*input_args, **input_kwargs)
|
||||
if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
|
||||
example_outputs = (example_outputs,)
|
||||
|
||||
if isinstance(example_outputs, list):
|
||||
example_outputs = [example_outputs]
|
||||
return example_outputs
|
||||
|
||||
def _model_to_graph(model, args, verbose=False,
|
||||
input_names=None, output_names=None,
|
||||
operator_export_type=OperatorExportTypes.ONNX,
|
||||
example_outputs=None,
|
||||
_retain_param_name=False, do_constant_folding=True,
|
||||
example_outputs=None, do_constant_folding=True,
|
||||
_disable_torch_constant_prop=False, fixed_batch_size=False,
|
||||
training=None, dynamic_axes=None):
|
||||
r"""Converts model into an ONNX graph.
|
||||
@ -461,11 +490,7 @@ def _model_to_graph(model, args, verbose=False,
|
||||
if isinstance(args, (torch.Tensor, int, float, bool)):
|
||||
args = (args, )
|
||||
|
||||
if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
|
||||
example_outputs = (example_outputs,)
|
||||
|
||||
graph, params, torch_out, module = _create_jit_graph(model, args,
|
||||
_retain_param_name)
|
||||
graph, params, torch_out, module = _create_jit_graph(model, args)
|
||||
|
||||
params_dict = _get_named_param_dict(graph, params)
|
||||
|
||||
@ -476,13 +501,22 @@ def _model_to_graph(model, args, verbose=False,
|
||||
module=module)
|
||||
from torch.onnx.symbolic_helper import _onnx_shape_inference
|
||||
if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):
|
||||
assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \
|
||||
"ScriptFunction."
|
||||
if isinstance(example_outputs, list):
|
||||
example_outputs = [example_outputs]
|
||||
if example_outputs is None:
|
||||
example_outputs = _get_example_outputs(model, args)
|
||||
else:
|
||||
# example_outpus specified
|
||||
if isinstance(example_outputs, (torch.Tensor, int, float, bool)):
|
||||
example_outputs = (example_outputs,)
|
||||
|
||||
if isinstance(example_outputs, list):
|
||||
example_outputs = [example_outputs]
|
||||
|
||||
out_vars, desc = torch.jit._flatten(tuple(example_outputs))
|
||||
torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, desc, _onnx_shape_inference)
|
||||
else:
|
||||
flatten_args, _ = torch._C._jit_flatten(args)
|
||||
# make sure that the param dict and the graph match each other
|
||||
assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs())
|
||||
|
||||
# NB: ONNX requires complete information about output types, which might be
|
||||
# erased by some optimizations, so we need to set it explicitly again.
|
||||
@ -533,14 +567,18 @@ def _model_to_graph(model, args, verbose=False,
|
||||
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
|
||||
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
|
||||
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
|
||||
google_printer=False, opset_version=None, _retain_param_name=True,
|
||||
google_printer=False, opset_version=None, _retain_param_name=None,
|
||||
keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True,
|
||||
do_constant_folding=True, dynamic_axes=None):
|
||||
if f is not None:
|
||||
warnings.warn("'f' is deprecated and ignored. It will be removed in the next PyTorch release.")
|
||||
if _retain_param_name is not None:
|
||||
warnings.warn("'_retain_param_name' is deprecated and ignored. "
|
||||
"It will be removed in the next PyTorch release.")
|
||||
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
|
||||
input_names, output_names, operator_export_type,
|
||||
export_type, example_outputs, google_printer,
|
||||
opset_version, _retain_param_name,
|
||||
do_constant_folding=do_constant_folding,
|
||||
opset_version, do_constant_folding=do_constant_folding,
|
||||
add_node_names=add_node_names,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
custom_opsets=custom_opsets, dynamic_axes=dynamic_axes)
|
||||
@ -549,7 +587,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
|
||||
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
|
||||
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
|
||||
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
|
||||
google_printer=False, opset_version=None, _retain_param_name=False,
|
||||
google_printer=False, opset_version=None,
|
||||
do_constant_folding=True, keep_initializers_as_inputs=None,
|
||||
fixed_batch_size=False, custom_opsets=None, add_node_names=True,
|
||||
onnx_shape_inference=True, dynamic_axes=None):
|
||||
@ -572,8 +610,8 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
|
||||
args = _decide_input_format(model, args)
|
||||
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
|
||||
output_names, operator_export_type,
|
||||
example_outputs, _retain_param_name,
|
||||
val_do_constant_folding, fixed_batch_size=fixed_batch_size,
|
||||
example_outputs, val_do_constant_folding,
|
||||
fixed_batch_size=fixed_batch_size,
|
||||
training=training, dynamic_axes=dynamic_axes)
|
||||
|
||||
return graph._pretty_print_onnx(params_dict, opset_version, False,
|
||||
@ -633,10 +671,10 @@ def _find_missing_ops_onnx_export(model, args, f, verbose=False, training=Traini
|
||||
def _export(model, args, f, export_params=True, verbose=False, training=None,
|
||||
input_names=None, output_names=None, operator_export_type=None,
|
||||
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
|
||||
opset_version=None, _retain_param_name=False, do_constant_folding=True,
|
||||
strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None,
|
||||
opset_version=None, do_constant_folding=True,
|
||||
dynamic_axes=None, keep_initializers_as_inputs=None,
|
||||
fixed_batch_size=False, custom_opsets=None, add_node_names=True,
|
||||
use_external_data_format=False, onnx_shape_inference=True):
|
||||
use_external_data_format=None, onnx_shape_inference=True):
|
||||
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
raise ValueError("torch.nn.DataParallel is not supported by ONNX "
|
||||
@ -685,8 +723,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
|
||||
graph, params_dict, torch_out = \
|
||||
_model_to_graph(model, args, verbose, input_names,
|
||||
output_names, operator_export_type,
|
||||
example_outputs, _retain_param_name,
|
||||
val_do_constant_folding,
|
||||
example_outputs, val_do_constant_folding,
|
||||
fixed_batch_size=fixed_batch_size,
|
||||
training=training,
|
||||
dynamic_axes=dynamic_axes)
|
||||
@ -697,14 +734,14 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
|
||||
custom_opsets = {}
|
||||
|
||||
if export_params:
|
||||
proto, export_map = graph._export_onnx(
|
||||
proto, export_map, val_use_external_data_format = graph._export_onnx(
|
||||
params_dict, opset_version, dynamic_axes, defer_weight_export,
|
||||
operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
|
||||
operator_export_type, not verbose, val_keep_init_as_ip, custom_opsets,
|
||||
val_add_node_names, val_use_external_data_format, model_file_location)
|
||||
else:
|
||||
proto, export_map = graph._export_onnx(
|
||||
proto, export_map, val_use_external_data_format = graph._export_onnx(
|
||||
{}, opset_version, dynamic_axes, False, operator_export_type,
|
||||
strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
|
||||
not verbose, val_keep_init_as_ip, custom_opsets, val_add_node_names,
|
||||
val_use_external_data_format, model_file_location)
|
||||
|
||||
if export_type == ExportTypes.PROTOBUF_FILE:
|
||||
|
||||
@ -1256,6 +1256,8 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa
|
||||
yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
|
||||
# Test for Broadcasting
|
||||
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
|
||||
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
|
||||
yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
|
||||
|
||||
return list(generator())
|
||||
|
||||
|
||||
@ -21,7 +21,9 @@ class MapperIterDataPipe(IterDataPipe):
|
||||
self.dp = dp
|
||||
self.fn = fn
|
||||
```
|
||||
Note: Avoid loading data from the source DataPipe in `__init__` function, in order to support lazy data loading and save memory.
|
||||
Note:
|
||||
- Avoid loading data from the source DataPipe in `__init__` function, in order to support lazy data loading and save memory.
|
||||
- If `IterDataPipe` instance holds data in memory, please be ware of the in-place modification of data. When second iterator is created from the instance, the data may have already changed. Please take [`IterableWrapper`](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/datapipes/iter/utils.py) class as reference to `deepcopy` data for each iterator.
|
||||
|
||||
### Iterator
|
||||
For `IterDataPipe`, an `__iter__` function is needed to consume data from the source `IterDataPipe` then apply operation over the data before yield.
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import warnings
|
||||
from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
|
||||
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
|
||||
@ -99,8 +98,6 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
|
||||
data = list(data)
|
||||
else:
|
||||
t_flag = False
|
||||
# Deepcopy data to prevent the original data modified. E.g. list, dict
|
||||
data = copy.deepcopy(data)
|
||||
|
||||
if self.output_col is None:
|
||||
if isinstance(self.input_col, (list, tuple)):
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import random
|
||||
import warnings
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
|
||||
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata
|
||||
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
@ -185,8 +185,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
|
||||
assert batch_size > 0, "Batch size is required to be larger than 0!"
|
||||
assert batch_num > 0, "Number of batches is required to be larger than 0!"
|
||||
assert bucket_num > 0, "Number of buckets is required to be larger than 0!"
|
||||
|
||||
warnings.warn("`BucketBatcher` is going to be removed from PyTorch Core")
|
||||
deprecation_warning_torchdata(type(self).__name__)
|
||||
super().__init__()
|
||||
|
||||
# TODO: Verify _datapippe is not going to be serialized twice
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Sized, Tuple
|
||||
from urllib.error import HTTPError, URLError
|
||||
import urllib.request as urllib
|
||||
from torch.utils.data import IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata
|
||||
|
||||
|
||||
class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
|
||||
@ -19,6 +20,7 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
|
||||
def __init__(self, datapipe, timeout=None):
|
||||
self.datapipe = datapipe
|
||||
self.timeout = timeout
|
||||
deprecation_warning_torchdata(type(self).__name__)
|
||||
|
||||
def __iter__(self):
|
||||
for furl in self.datapipe:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import Tuple
|
||||
from torch.utils.data import IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import deprecation_warning_torchdata
|
||||
|
||||
|
||||
class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]):
|
||||
@ -14,6 +15,7 @@ class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]):
|
||||
|
||||
def __init__(self, datapipe):
|
||||
self.datapipe = datapipe
|
||||
deprecation_warning_torchdata(type(self).__name__)
|
||||
|
||||
def __iter__(self):
|
||||
for file_name, stream in self.datapipe:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from torch.utils.data import IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
|
||||
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple, deprecation_warning_torchdata
|
||||
from typing import Iterable, Iterator, Tuple, Optional, IO, cast
|
||||
from io import BufferedIOBase
|
||||
|
||||
@ -34,11 +34,13 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
|
||||
self.mode: str = mode
|
||||
self.length: int = length
|
||||
deprecation_warning_torchdata(type(self).__name__)
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
|
||||
for data in self.datapipe:
|
||||
validate_pathname_binary_tuple(data)
|
||||
pathname, data_stream = data
|
||||
folder_name = os.path.dirname(pathname)
|
||||
try:
|
||||
# typing.cast is used here to silence mypy's type checker
|
||||
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
|
||||
@ -49,14 +51,12 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
if extracted_fobj is None:
|
||||
warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
|
||||
raise tarfile.ExtractError
|
||||
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
|
||||
inner_pathname = os.path.normpath(os.path.join(folder_name, tarinfo.name))
|
||||
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
|
||||
raise e
|
||||
finally:
|
||||
data_stream.close()
|
||||
|
||||
def __len__(self):
|
||||
if self.length == -1:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import copy
|
||||
import warnings
|
||||
from torch.utils.data import IterDataPipe
|
||||
|
||||
|
||||
@ -8,12 +10,34 @@ class IterableWrapperIterDataPipe(IterDataPipe):
|
||||
|
||||
Args:
|
||||
iterable: Iterable object to be wrapped into an IterDataPipe
|
||||
deepcopy: Option to deepcopy input iterable object for each
|
||||
iteration.
|
||||
|
||||
.. note::
|
||||
If `deepcopy` is set to False explicitly, users should ensure
|
||||
that data pipeline doesn't contain any in-place operations over
|
||||
the iterable instance, in order to prevent data inconsistency
|
||||
across iterations.
|
||||
"""
|
||||
def __init__(self, iterable):
|
||||
def __init__(self, iterable, deepcopy=True):
|
||||
self.iterable = iterable
|
||||
self.deepcopy = deepcopy
|
||||
|
||||
def __iter__(self):
|
||||
for data in self.iterable:
|
||||
source_data = self.iterable
|
||||
if self.deepcopy:
|
||||
try:
|
||||
source_data = copy.deepcopy(self.iterable)
|
||||
# For the case that data cannot be deep-copied,
|
||||
# all in-place operations will affect iterable variable.
|
||||
# When this DataPipe is iterated second time, it will
|
||||
# yield modified items.
|
||||
except TypeError:
|
||||
warnings.warn(
|
||||
"The input iterable can not be deepcopied, "
|
||||
"please be aware of in-place modification would affect source data"
|
||||
)
|
||||
for data in source_data:
|
||||
yield data
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from torch.utils.data import IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
|
||||
from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple, deprecation_warning_torchdata
|
||||
from typing import Iterable, Iterator, Tuple, IO, cast
|
||||
from io import BufferedIOBase
|
||||
|
||||
@ -31,11 +31,13 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
super().__init__()
|
||||
self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
|
||||
self.length: int = length
|
||||
deprecation_warning_torchdata(type(self).__name__)
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
|
||||
for data in self.datapipe:
|
||||
validate_pathname_binary_tuple(data)
|
||||
pathname, data_stream = data
|
||||
folder_name = os.path.dirname(pathname)
|
||||
try:
|
||||
# typing.cast is used here to silence mypy's type checker
|
||||
zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
|
||||
@ -47,7 +49,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
elif zipinfo.filename.endswith('/'):
|
||||
continue
|
||||
extracted_fobj = zips.open(zipinfo)
|
||||
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
|
||||
inner_pathname = os.path.normpath(os.path.join(folder_name, zipinfo.filename))
|
||||
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
|
||||
@ -63,3 +63,9 @@ def validate_pathname_binary_tuple(data):
|
||||
raise TypeError("pathname binary tuple should have string type pathname, but got {}".format(type(data[0])))
|
||||
if not isinstance(data[1], BufferedIOBase):
|
||||
raise TypeError("pathname binary tuple should have BufferedIOBase based binary type, but got {}".format(type(data[1])))
|
||||
|
||||
# Warns user that the DataPipe has been moved to TorchData and will be removed from `torch`
|
||||
def deprecation_warning_torchdata(name):
|
||||
warnings.warn(f"{name} and its functional API are deprecated and will be removed from the package `torch`. "
|
||||
f"Please import those features from the new package TorchData: https://github.com/pytorch/data",
|
||||
DeprecationWarning)
|
||||
|
||||
@ -112,15 +112,18 @@ class RandomSampler(Sampler[int]):
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
n = len(self.data_source)
|
||||
if self.generator is None:
|
||||
self.generator = torch.Generator()
|
||||
self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
|
||||
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
generator = self.generator
|
||||
|
||||
if self.replacement:
|
||||
for _ in range(self.num_samples // 32):
|
||||
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
|
||||
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
|
||||
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
|
||||
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
|
||||
else:
|
||||
yield from torch.randperm(n, generator=self.generator).tolist()
|
||||
yield from torch.randperm(n, generator=generator).tolist()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples
|
||||
@ -140,7 +143,8 @@ class SubsetRandomSampler(Sampler[int]):
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
@ -183,7 +187,7 @@ class WeightedRandomSampler(Sampler[int]):
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
|
||||
return iter(rand_tensor.tolist())
|
||||
yield from iter(rand_tensor.tolist())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples
|
||||
|
||||
Reference in New Issue
Block a user