mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Compare commits
11 Commits
mlazos/use
...
v1.8.0-rc2
| Author | SHA1 | Date | |
|---|---|---|---|
| c79decdbba | |||
| c307a3f336 | |||
| f071020756 | |||
| 4f436f8570 | |||
| ae11589710 | |||
| 9e5bcc1020 | |||
| fa8578241d | |||
| 1368809532 | |||
| 4073248fc2 | |||
| 75153cb730 | |||
| 5bb69b080c |
@ -153,11 +153,11 @@ commands:
|
|||||||
git config --global user.email "circleci.ossci@gmail.com"
|
git config --global user.email "circleci.ossci@gmail.com"
|
||||||
git config --global user.name "CircleCI"
|
git config --global user.name "CircleCI"
|
||||||
git config remote.origin.url https://github.com/pytorch/pytorch.git
|
git config remote.origin.url https://github.com/pytorch/pytorch.git
|
||||||
git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
|
git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8
|
||||||
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet
|
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet
|
||||||
# PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
|
# PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
|
||||||
if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
|
if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
|
||||||
CIRCLE_PR_BASE_BRANCH=master
|
CIRCLE_PR_BASE_BRANCH=release/1.8
|
||||||
fi
|
fi
|
||||||
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
|
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
|
||||||
echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
|
echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
|
||||||
|
|||||||
@ -111,11 +111,11 @@ commands:
|
|||||||
git config --global user.email "circleci.ossci@gmail.com"
|
git config --global user.email "circleci.ossci@gmail.com"
|
||||||
git config --global user.name "CircleCI"
|
git config --global user.name "CircleCI"
|
||||||
git config remote.origin.url https://github.com/pytorch/pytorch.git
|
git config remote.origin.url https://github.com/pytorch/pytorch.git
|
||||||
git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
|
git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8
|
||||||
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet
|
git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet
|
||||||
# PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
|
# PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
|
||||||
if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
|
if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
|
||||||
CIRCLE_PR_BASE_BRANCH=master
|
CIRCLE_PR_BASE_BRANCH=release/1.8
|
||||||
fi
|
fi
|
||||||
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
|
export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
|
||||||
echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
|
echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
|
||||||
|
|||||||
@ -182,7 +182,7 @@ fi
|
|||||||
|
|
||||||
# Patch required to build xla
|
# Patch required to build xla
|
||||||
if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
|
if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
|
||||||
git clone --recursive https://github.com/pytorch/xla.git
|
git clone --recursive -b r1.8 https://github.com/pytorch/xla.git
|
||||||
./xla/scripts/apply_patches.sh
|
./xla/scripts/apply_patches.sh
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@ -54,7 +54,7 @@ function file_diff_from_base() {
|
|||||||
set +e
|
set +e
|
||||||
git fetch origin master --quiet
|
git fetch origin master --quiet
|
||||||
set -e
|
set -e
|
||||||
git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
|
git diff --name-only "$(git merge-base origin/release/1.8 HEAD)" > "$1"
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_bazel() {
|
function get_bazel() {
|
||||||
|
|||||||
@ -300,7 +300,7 @@ test_backward_compatibility() {
|
|||||||
pushd test/backward_compatibility
|
pushd test/backward_compatibility
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
. venv/bin/activate
|
. venv/bin/activate
|
||||||
pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_test.html
|
||||||
pip show torch
|
pip show torch
|
||||||
python dump_all_function_schemas.py --filename nightly_schemas.txt
|
python dump_all_function_schemas.py --filename nightly_schemas.txt
|
||||||
deactivate
|
deactivate
|
||||||
|
|||||||
@ -19,6 +19,27 @@ namespace {
|
|||||||
|
|
||||||
using namespace vec256;
|
using namespace vec256;
|
||||||
|
|
||||||
|
// Note: Explicit implementation of copysign for Half and BFloat16
|
||||||
|
// is needed to workaround g++-7/8 crash on aarch64, but also makes
|
||||||
|
// copysign faster for the half-precision types
|
||||||
|
template<typename T>
|
||||||
|
T copysign(T a, T b) {
|
||||||
|
return std::copysign(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement copysign for half precision floats using bit ops
|
||||||
|
// Sign is the most significant bit for both half and bfloat16 types
|
||||||
|
template<>
|
||||||
|
c10::Half copysign(c10::Half a, c10::Half b) {
|
||||||
|
return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
|
||||||
|
return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Note: Undefined behavior when performing addition is intentionally
|
// Note: Undefined behavior when performing addition is intentionally
|
||||||
// ignored.
|
// ignored.
|
||||||
void add_kernel(TensorIteratorBase& iter, Scalar alpha_scalar) {
|
void add_kernel(TensorIteratorBase& iter, Scalar alpha_scalar) {
|
||||||
@ -180,7 +201,7 @@ void div_floor_kernel(TensorIterator& iter) {
|
|||||||
floordiv += scalar_t(1.0);
|
floordiv += scalar_t(1.0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
floordiv = std::copysign(scalar_t(0), a / b);
|
floordiv = copysign(scalar_t(0), a / b);
|
||||||
}
|
}
|
||||||
return floordiv;
|
return floordiv;
|
||||||
});
|
});
|
||||||
@ -889,23 +910,6 @@ void heaviside_kernel(TensorIterator& iter) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
T copysign(T a, T b) {
|
|
||||||
return std::copysign(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement copysign for half precision floats using bit ops
|
|
||||||
// Sign is the most significant bit for both half and bfloat16 types
|
|
||||||
template<>
|
|
||||||
c10::Half copysign(c10::Half a, c10::Half b) {
|
|
||||||
return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits());
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
|
|
||||||
return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits());
|
|
||||||
}
|
|
||||||
|
|
||||||
void copysign_kernel(TensorIterator& iter) {
|
void copysign_kernel(TensorIterator& iter) {
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
|
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
|
||||||
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
|
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
|
||||||
|
|||||||
@ -133,7 +133,9 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) {
|
|||||||
ASSERT_EQ(buffer1[i].z, buffer2[i].z);
|
ASSERT_EQ(buffer1[i].z, buffer2[i].z);
|
||||||
ASSERT_EQ(buffer1[i].w, buffer2[i].w);
|
ASSERT_EQ(buffer1[i].w, buffer2[i].w);
|
||||||
}
|
}
|
||||||
|
// Skipping this part until https://github.com/pytorch/pytorch/issues/51863 is resolved
|
||||||
|
|
||||||
|
#if 0
|
||||||
// unaligned
|
// unaligned
|
||||||
for (int i = 0; i < 16; i++) {
|
for (int i = 0; i < 16; i++) {
|
||||||
for (int j = 0; j < 16; j++) {
|
for (int j = 0; j < 16; j++) {
|
||||||
@ -151,4 +153,5 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,7 +16,7 @@ int32_t driver_version() {
|
|||||||
return driver_version;
|
return driver_version;
|
||||||
}
|
}
|
||||||
|
|
||||||
int device_count_impl() {
|
int device_count_impl(bool fail_if_no_driver) {
|
||||||
int count;
|
int count;
|
||||||
auto err = cudaGetDeviceCount(&count);
|
auto err = cudaGetDeviceCount(&count);
|
||||||
if (err == cudaSuccess) {
|
if (err == cudaSuccess) {
|
||||||
@ -34,6 +34,11 @@ int device_count_impl() {
|
|||||||
case cudaErrorInsufficientDriver: {
|
case cudaErrorInsufficientDriver: {
|
||||||
auto version = driver_version();
|
auto version = driver_version();
|
||||||
if (version <= 0) {
|
if (version <= 0) {
|
||||||
|
if (!fail_if_no_driver) {
|
||||||
|
// No CUDA driver means no devices
|
||||||
|
count = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
false,
|
false,
|
||||||
"Found no NVIDIA driver on your system. Please check that you "
|
"Found no NVIDIA driver on your system. Please check that you "
|
||||||
@ -95,9 +100,9 @@ DeviceIndex device_count() noexcept {
|
|||||||
// initialize number of devices only once
|
// initialize number of devices only once
|
||||||
static int count = []() {
|
static int count = []() {
|
||||||
try {
|
try {
|
||||||
auto result = device_count_impl();
|
auto result = device_count_impl(/*fail_if_no_driver=*/false);
|
||||||
TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
|
TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
|
||||||
return device_count_impl();
|
return result;
|
||||||
} catch (const c10::Error& ex) {
|
} catch (const c10::Error& ex) {
|
||||||
// We don't want to fail, but still log the warning
|
// We don't want to fail, but still log the warning
|
||||||
// msg() returns the message without the stack trace
|
// msg() returns the message without the stack trace
|
||||||
@ -110,7 +115,7 @@ DeviceIndex device_count() noexcept {
|
|||||||
|
|
||||||
DeviceIndex device_count_ensure_non_zero() {
|
DeviceIndex device_count_ensure_non_zero() {
|
||||||
// Call the implementation every time to throw the exception
|
// Call the implementation every time to throw the exception
|
||||||
int count = device_count_impl();
|
int count = device_count_impl(/*fail_if_no_driver=*/true);
|
||||||
// Zero gpus doesn't produce a warning in `device_count` but we fail here
|
// Zero gpus doesn't produce a warning in `device_count` but we fail here
|
||||||
TORCH_CHECK(count, "No CUDA GPUs are available");
|
TORCH_CHECK(count, "No CUDA GPUs are available");
|
||||||
return static_cast<DeviceIndex>(count);
|
return static_cast<DeviceIndex>(count);
|
||||||
|
|||||||
74
docs/source/ddp_comm_hooks.rst
Normal file
74
docs/source/ddp_comm_hooks.rst
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
DDP Communication Hooks
|
||||||
|
=======================
|
||||||
|
|
||||||
|
DDP communication hook is a generic interface to control how to communicate
|
||||||
|
gradients across workers by overriding the vanilla allreduce in
|
||||||
|
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.>`_.
|
||||||
|
A few built-in communication hooks are provided,
|
||||||
|
and users can easily apply any of these hooks to optimize communication.
|
||||||
|
Besides, the hook interface can also support user-defined communication
|
||||||
|
strategies for more advanced use cases.
|
||||||
|
|
||||||
|
.. warning ::
|
||||||
|
DDP communication hook is experimental and subject to change.
|
||||||
|
|
||||||
|
.. warning ::
|
||||||
|
DDP communication hooks can only support single process single device mode
|
||||||
|
on NCCL backend.
|
||||||
|
|
||||||
|
How to Use a Communication Hook?
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
To use a communication hook, the user just needs to let the DDP model register
|
||||||
|
the hook before the training loop as below.
|
||||||
|
|
||||||
|
:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`.
|
||||||
|
:noindex:
|
||||||
|
|
||||||
|
Default Communication Hooks
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
Default communication hooks are simple **stateless** hooks, so the input state
|
||||||
|
in ``register_comm_hook`` is either a process group or ``None``.
|
||||||
|
|
||||||
|
.. automodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
|
||||||
|
:members:
|
||||||
|
|
||||||
|
PowerSGD Communication Hook
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
PowerSGD (`Vogels et al., NeurIPS 2019 <https://arxiv.org/abs/1905.13727>`_)
|
||||||
|
is a gradient compression algorithm, which can provide very high compression
|
||||||
|
rates and accelerate bandwidth-bound distributed training.
|
||||||
|
This algorithm needs to maintain both some hyperparameters and the internal
|
||||||
|
state. Therefore, PowerSGD communication hook is a **stateful** hook,
|
||||||
|
and the user needs to provide a state object defined as below.
|
||||||
|
|
||||||
|
PowerSGD State
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
|
||||||
|
.. autoclass:: PowerSGDState
|
||||||
|
|
||||||
|
PowerSGD Hooks
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. warning ::
|
||||||
|
PowerSGD typically requires extra memory of the same size as the model's
|
||||||
|
gradients to enable error feedback, which can compensate for biased
|
||||||
|
compressed communication and improve accuracy.
|
||||||
|
|
||||||
|
.. warning ::
|
||||||
|
The current implementation may cause gradient overflow for FP16 input.
|
||||||
|
|
||||||
|
.. autofunction:: powerSGD_hook
|
||||||
|
.. autofunction:: batched_powerSGD_hook
|
||||||
|
|
||||||
|
Acknowledgements
|
||||||
|
----------------
|
||||||
|
|
||||||
|
Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on
|
||||||
|
PowerSGD communication hook, as well as the
|
||||||
|
`comparison experiments <https://observablehq.com/@tvogels/powersgd-benchmark>`_,
|
||||||
|
which show that the performance of PowerSGD communication hook is on par with
|
||||||
|
the implementation in the original `paper <https://arxiv.org/abs/1905.13727>`_.
|
||||||
@ -71,6 +71,7 @@ Features described in this documentation are classified by release status:
|
|||||||
onnx
|
onnx
|
||||||
optim
|
optim
|
||||||
complex_numbers
|
complex_numbers
|
||||||
|
ddp_comm_hooks
|
||||||
pipeline
|
pipeline
|
||||||
quantization
|
quantization
|
||||||
rpc
|
rpc
|
||||||
|
|||||||
@ -484,6 +484,7 @@ Sparse tensor functions
|
|||||||
+++++++++++++++++++++++
|
+++++++++++++++++++++++
|
||||||
|
|
||||||
.. autofunction:: torch.sparse_coo_tensor
|
.. autofunction:: torch.sparse_coo_tensor
|
||||||
|
:noindex:
|
||||||
.. autofunction:: torch.sparse.sum
|
.. autofunction:: torch.sparse.sum
|
||||||
.. autofunction:: torch.sparse.addmm
|
.. autofunction:: torch.sparse.addmm
|
||||||
.. autofunction:: torch.sparse.mm
|
.. autofunction:: torch.sparse.mm
|
||||||
|
|||||||
45
setup.py
45
setup.py
@ -552,6 +552,50 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
with open('compile_commands.json', 'w') as f:
|
with open('compile_commands.json', 'w') as f:
|
||||||
f.write(new_contents)
|
f.write(new_contents)
|
||||||
|
|
||||||
|
class concat_license_files():
|
||||||
|
"""Merge LICENSE and LICENSES_BUNDLED.txt as a context manager
|
||||||
|
|
||||||
|
LICENSE is the main PyTorch license, LICENSES_BUNDLED.txt is auto-generated
|
||||||
|
from all the licenses found in ./third_party/. We concatenate them so there
|
||||||
|
is a single license file in the sdist and wheels with all of the necessary
|
||||||
|
licensing info.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.f1 = 'LICENSE'
|
||||||
|
self.f2 = 'third_party/LICENSES_BUNDLED.txt'
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Concatenate files"""
|
||||||
|
with open(self.f1, 'r') as f1:
|
||||||
|
self.bsd_text = f1.read()
|
||||||
|
|
||||||
|
with open(self.f1, 'a') as f1:
|
||||||
|
with open(self.f2, 'r') as f2:
|
||||||
|
self.bundled_text = f2.read()
|
||||||
|
f1.write('\n\n')
|
||||||
|
f1.write(self.bundled_text)
|
||||||
|
|
||||||
|
def __exit__(self, exception_type, exception_value, traceback):
|
||||||
|
"""Restore content of f1"""
|
||||||
|
with open(self.f1, 'w') as f:
|
||||||
|
f.write(self.bsd_text)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from wheel.bdist_wheel import bdist_wheel
|
||||||
|
except ImportError:
|
||||||
|
# This is useful when wheel is not installed and bdist_wheel is not
|
||||||
|
# specified on the command line. If it _is_ specified, parsing the command
|
||||||
|
# line will fail before wheel_concatenate is needed
|
||||||
|
wheel_concatenate = None
|
||||||
|
else:
|
||||||
|
# Need to create the proper LICENSE.txt for the wheel
|
||||||
|
class wheel_concatenate(bdist_wheel):
|
||||||
|
""" check submodules on sdist to prevent incomplete tarballs """
|
||||||
|
def run(self):
|
||||||
|
with concat_license_files():
|
||||||
|
super().run()
|
||||||
|
|
||||||
|
|
||||||
class install(setuptools.command.install.install):
|
class install(setuptools.command.install.install):
|
||||||
def run(self):
|
def run(self):
|
||||||
@ -724,6 +768,7 @@ def configure_extension_build():
|
|||||||
'build_ext': build_ext,
|
'build_ext': build_ext,
|
||||||
'clean': clean,
|
'clean': clean,
|
||||||
'install': install,
|
'install': install,
|
||||||
|
'bdist_wheel': wheel_concatenate,
|
||||||
}
|
}
|
||||||
|
|
||||||
entry_points = {
|
entry_points = {
|
||||||
|
|||||||
@ -872,7 +872,7 @@ class TestFakeQuantize(TestCase):
|
|||||||
scale, zero_point = float(scale), int(zero_point)
|
scale, zero_point = float(scale), int(zero_point)
|
||||||
quant_min, quant_max = obs._calculate_qmin_qmax()
|
quant_min, quant_max = obs._calculate_qmin_qmax()
|
||||||
|
|
||||||
Y_test, _mask = torch.fake_quantize_per_tensor_affine_cachemask(
|
Y_test = torch.fake_quantize_per_tensor_affine(
|
||||||
X, scale, zero_point, quant_min, quant_max)
|
X, scale, zero_point, quant_min, quant_max)
|
||||||
Y_ref = _fake_quantize_per_tensor_affine_reference(
|
Y_ref = _fake_quantize_per_tensor_affine_reference(
|
||||||
X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
|
X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
|
||||||
@ -899,7 +899,7 @@ class TestFakeQuantize(TestCase):
|
|||||||
quant_min, quant_max = obs._calculate_qmin_qmax()
|
quant_min, quant_max = obs._calculate_qmin_qmax()
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
Y_test, mask = torch.fake_quantize_per_tensor_affine_cachemask(
|
Y_test = torch.fake_quantize_per_tensor_affine(
|
||||||
X, scale, zero_point, quant_min, quant_max)
|
X, scale, zero_point, quant_min, quant_max)
|
||||||
Y_ref = _fake_quantize_per_tensor_affine_reference(
|
Y_ref = _fake_quantize_per_tensor_affine_reference(
|
||||||
X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
|
X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
|
||||||
@ -1246,7 +1246,7 @@ class TestFakeQuantize(TestCase):
|
|||||||
|
|
||||||
Y = _fake_quantize_per_channel_affine_reference(
|
Y = _fake_quantize_per_channel_affine_reference(
|
||||||
X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
|
X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
|
||||||
Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
|
Y_prime = torch.fake_quantize_per_channel_affine(
|
||||||
X, scale, zero_point, axis, quant_min, quant_max)
|
X, scale, zero_point, axis, quant_min, quant_max)
|
||||||
np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
|
np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
|
||||||
|
|
||||||
@ -1339,7 +1339,7 @@ class TestFakeQuantize(TestCase):
|
|||||||
zero_point = zero_point.to(torch.int64)
|
zero_point = zero_point.to(torch.int64)
|
||||||
quant_min, quant_max = obs._calculate_qmin_qmax()
|
quant_min, quant_max = obs._calculate_qmin_qmax()
|
||||||
X.requires_grad_()
|
X.requires_grad_()
|
||||||
Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
|
Y_prime = torch.fake_quantize_per_channel_affine(
|
||||||
X, scale, zero_point, axis, quant_min, quant_max)
|
X, scale, zero_point, axis, quant_min, quant_max)
|
||||||
dout = torch.rand(X.shape, dtype=torch.float).to(device)
|
dout = torch.rand(X.shape, dtype=torch.float).to(device)
|
||||||
dX = _fake_quantize_per_channel_affine_grad_reference(
|
dX = _fake_quantize_per_channel_affine_grad_reference(
|
||||||
|
|||||||
@ -108,6 +108,7 @@ TESTS = [
|
|||||||
'test_fx_experimental',
|
'test_fx_experimental',
|
||||||
'test_functional_autograd_benchmark',
|
'test_functional_autograd_benchmark',
|
||||||
'test_package',
|
'test_package',
|
||||||
|
'test_license',
|
||||||
'distributed/pipeline/sync/skip/test_api',
|
'distributed/pipeline/sync/skip/test_api',
|
||||||
'distributed/pipeline/sync/skip/test_gpipe',
|
'distributed/pipeline/sync/skip/test_gpipe',
|
||||||
'distributed/pipeline/sync/skip/test_inspect_skip_layout',
|
'distributed/pipeline/sync/skip/test_inspect_skip_layout',
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from math import sqrt
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torch.multiprocessing import Process
|
from torch.multiprocessing import Process
|
||||||
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap
|
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap
|
||||||
from torch.fx.node import Target
|
from torch.fx.node import Target, Argument
|
||||||
from torch.fx.passes import shape_prop
|
from torch.fx.passes import shape_prop
|
||||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -187,7 +187,7 @@ class TestFX(JitTestCase):
|
|||||||
# Custom delegate to disallow in-place tensor operations
|
# Custom delegate to disallow in-place tensor operations
|
||||||
class NoMutableCallTracer(Tracer):
|
class NoMutableCallTracer(Tracer):
|
||||||
def create_node(self, kind : str, target : Union[str, Callable],
|
def create_node(self, kind : str, target : Union[str, Callable],
|
||||||
args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
|
args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
|
||||||
type_expr : Optional[Any] = None) -> Node:
|
type_expr : Optional[Any] = None) -> Node:
|
||||||
name = target if isinstance(target, str) else torch.typename(target)
|
name = target if isinstance(target, str) else torch.typename(target)
|
||||||
if name[-1] == '_':
|
if name[-1] == '_':
|
||||||
@ -539,7 +539,7 @@ class TestFX(JitTestCase):
|
|||||||
def test_node_tagging(self):
|
def test_node_tagging(self):
|
||||||
class TaggingTracer(Tracer):
|
class TaggingTracer(Tracer):
|
||||||
def create_node(self, kind : str, target : Union[str, Callable],
|
def create_node(self, kind : str, target : Union[str, Callable],
|
||||||
args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
|
args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
|
||||||
type_expr : Optional[Any] = None) -> Node:
|
type_expr : Optional[Any] = None) -> Node:
|
||||||
n = super().create_node(kind, target, args, kwargs, name)
|
n = super().create_node(kind, target, args, kwargs, name)
|
||||||
n.tag = 'foo'
|
n.tag = 'foo'
|
||||||
@ -1057,6 +1057,13 @@ class TestFX(JitTestCase):
|
|||||||
result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
|
result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
|
||||||
self.assertEqual(result, torch.ones(3, 4) * 2.0)
|
self.assertEqual(result, torch.ones(3, 4) * 2.0)
|
||||||
|
|
||||||
|
@skipIfNoTorchVision
|
||||||
|
def test_interpreter_noop_resnet18(self):
|
||||||
|
rn18 = resnet18()
|
||||||
|
transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
|
||||||
|
inp = torch.randn(5, 3, 224, 224)
|
||||||
|
self.assertEqual(transformed(inp), rn18(inp))
|
||||||
|
|
||||||
def test_transformer_noop(self):
|
def test_transformer_noop(self):
|
||||||
class MyModule(torch.nn.Module):
|
class MyModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -1377,6 +1384,45 @@ class TestFX(JitTestCase):
|
|||||||
x, y = torch.randn(3, 4), torch.randn(3, 4)
|
x, y = torch.randn(3, 4), torch.randn(3, 4)
|
||||||
self.checkGraphModule(foo, (x, y))
|
self.checkGraphModule(foo, (x, y))
|
||||||
|
|
||||||
|
def test_trace_dict_int_keys(self):
|
||||||
|
class ModWithDictArg(torch.nn.Module):
|
||||||
|
def forward(self, d : Dict[int, torch.Tensor]):
|
||||||
|
return d[42]
|
||||||
|
|
||||||
|
class CallsModWithDict(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.m = ModWithDictArg()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.m({42: x})
|
||||||
|
|
||||||
|
class MyTracer(torch.fx.Tracer):
|
||||||
|
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
|
||||||
|
return isinstance(m, ModWithDictArg)
|
||||||
|
|
||||||
|
traced_graph = MyTracer().trace(CallsModWithDict())
|
||||||
|
|
||||||
|
def test_trace_dict_proxy_keys(self):
|
||||||
|
class ModWithDictArg(torch.nn.Module):
|
||||||
|
def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
|
||||||
|
return d[42]
|
||||||
|
|
||||||
|
class CallsModWithDict(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.m = ModWithDictArg()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.m({x: x})
|
||||||
|
|
||||||
|
class MyTracer(torch.fx.Tracer):
|
||||||
|
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
|
||||||
|
return isinstance(m, ModWithDictArg)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
|
||||||
|
traced_graph = MyTracer().trace(CallsModWithDict())
|
||||||
|
|
||||||
def test_direct_param_use(self):
|
def test_direct_param_use(self):
|
||||||
class TransposeTest(torch.nn.Module):
|
class TransposeTest(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@ -5,14 +5,14 @@ from typing import Callable, Dict, Union, List
|
|||||||
from torch.fx.symbolic_trace import symbolic_trace
|
from torch.fx.symbolic_trace import symbolic_trace
|
||||||
from torch.fx.graph_module import GraphModule
|
from torch.fx.graph_module import GraphModule
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
from torch.fx.experimental import graph_manipulation
|
from torch.fx._experimental import graph_manipulation
|
||||||
from torch.fx.experimental.accelerator_partitioner import Partitioner
|
from torch.fx._experimental.accelerator_partitioner import Partitioner
|
||||||
from torch.fx.experimental.rewriter import RewritingTracer
|
from torch.fx._experimental.rewriter import RewritingTracer
|
||||||
from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
|
from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes
|
||||||
from torch.testing._internal.common_utils import run_tests
|
from torch.testing._internal.common_utils import run_tests
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
from torch.fx.passes.split_module import split_module
|
from torch.fx.passes.split_module import split_module
|
||||||
from torch.fx.experimental.partitioner_utils import (
|
from torch.fx._experimental.partitioner_utils import (
|
||||||
NodeLatency,
|
NodeLatency,
|
||||||
get_partition_to_latency_mapping,
|
get_partition_to_latency_mapping,
|
||||||
get_latency_of_partitioned_graph,
|
get_latency_of_partitioned_graph,
|
||||||
@ -20,8 +20,8 @@ from torch.fx.experimental.partitioner_utils import (
|
|||||||
PartitionerConfig,
|
PartitionerConfig,
|
||||||
PartitionMode
|
PartitionMode
|
||||||
)
|
)
|
||||||
from torch.fx.experimental.fuser import fuse
|
from torch.fx._experimental.fuser import fuse
|
||||||
from torch.fx.experimental import merge_matmul
|
from torch.fx._experimental import merge_matmul
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torchvision.models import resnet18
|
from torchvision.models import resnet18
|
||||||
@ -849,7 +849,7 @@ terrible spacing
|
|||||||
|
|
||||||
def test_merge_matmuls(self):
|
def test_merge_matmuls(self):
|
||||||
"""
|
"""
|
||||||
A collection of test cases for torch.fx.experimental.merge_matmul,
|
A collection of test cases for torch.fx._experimental.merge_matmul,
|
||||||
a graph transformation that merges matrix multiplication operations.
|
a graph transformation that merges matrix multiplication operations.
|
||||||
"""
|
"""
|
||||||
# Utility function for counting matmuls for test assertions.
|
# Utility function for counting matmuls for test assertions.
|
||||||
|
|||||||
@ -6503,6 +6503,38 @@ a")
|
|||||||
self.checkModule(module().train(), ())
|
self.checkModule(module().train(), ())
|
||||||
self.checkModule(module().eval(), ())
|
self.checkModule(module().eval(), ())
|
||||||
|
|
||||||
|
def test_ternary_static_if(self):
|
||||||
|
# Test for True branch when condition variable
|
||||||
|
# is annotated as Final
|
||||||
|
class M1(torch.nn.Module):
|
||||||
|
flag: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.flag = True
|
||||||
|
|
||||||
|
def forward(self) -> torch.Tensor:
|
||||||
|
return torch.ones(3) if self.flag else {}
|
||||||
|
|
||||||
|
# Test for True branch when condition variable
|
||||||
|
# is annotated as Final
|
||||||
|
class M2(torch.nn.Module):
|
||||||
|
flag: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.flag = False
|
||||||
|
|
||||||
|
def forward(self) -> torch.Tensor:
|
||||||
|
return {} if self.flag else torch.ones(3)
|
||||||
|
|
||||||
|
model1 = M1()
|
||||||
|
model2 = M2()
|
||||||
|
script_model_1 = torch.jit.script(model1)
|
||||||
|
script_model_2 = torch.jit.script(model2)
|
||||||
|
self.assertEqual(model1.forward(), script_model_1.forward())
|
||||||
|
self.assertEqual(model2.forward(), script_model_2.forward())
|
||||||
|
|
||||||
def test_print(self):
|
def test_print(self):
|
||||||
def func(x, y):
|
def func(x, y):
|
||||||
q = (x + y).sigmoid()
|
q = (x + y).sigmoid()
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
|
import glob
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||||
|
|
||||||
|
|
||||||
@ -10,11 +13,14 @@ except ImportError:
|
|||||||
create_bundled = None
|
create_bundled = None
|
||||||
|
|
||||||
license_file = 'third_party/LICENSES_BUNDLED.txt'
|
license_file = 'third_party/LICENSES_BUNDLED.txt'
|
||||||
|
starting_txt = 'The Pytorch repository and source distributions bundle'
|
||||||
|
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
|
||||||
|
distinfo = glob.glob(os.path.join(site_packages, 'torch-*dist-info'))
|
||||||
|
|
||||||
class TestLicense(TestCase):
|
class TestLicense(TestCase):
|
||||||
|
|
||||||
@unittest.skipIf(not create_bundled, "can only be run in a source tree")
|
@unittest.skipIf(not create_bundled, "can only be run in a source tree")
|
||||||
def test_license_in_wheel(self):
|
def test_license_for_wheel(self):
|
||||||
current = io.StringIO()
|
current = io.StringIO()
|
||||||
create_bundled('third_party', current)
|
create_bundled('third_party', current)
|
||||||
with open(license_file) as fid:
|
with open(license_file) as fid:
|
||||||
@ -25,6 +31,18 @@ class TestLicense(TestCase):
|
|||||||
'match the current state of the third_party files. Use '
|
'match the current state of the third_party files. Use '
|
||||||
'"python third_party/build_bundled.py" to regenerate it')
|
'"python third_party/build_bundled.py" to regenerate it')
|
||||||
|
|
||||||
|
@unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test")
|
||||||
|
def test_distinfo_license(self):
|
||||||
|
"""If run when pytorch is installed via a wheel, the license will be in
|
||||||
|
site-package/torch-*dist-info/LICENSE. Make sure it contains the third
|
||||||
|
party bundle of licenses"""
|
||||||
|
|
||||||
|
if len(distinfo) > 1:
|
||||||
|
raise AssertionError('Found too many "torch-*dist-info" directories '
|
||||||
|
f'in "{site_packages}, expected only one')
|
||||||
|
with open(os.path.join(os.path.join(distinfo[0], 'LICENSE'))) as fid:
|
||||||
|
txt = fid.read()
|
||||||
|
self.assertTrue(starting_txt in txt)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|||||||
@ -82,7 +82,8 @@ SKIP_PYTHON_BINDINGS = [
|
|||||||
'set_data',
|
'set_data',
|
||||||
'.*_overrideable', # overrideable functions for backend extension
|
'.*_overrideable', # overrideable functions for backend extension
|
||||||
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
|
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
|
||||||
'_fw_primal'
|
'_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
|
||||||
|
'fake_quantize_per_channel_affine_cachemask',
|
||||||
]
|
]
|
||||||
|
|
||||||
# These function signatures are not exposed to Python. Note that this signature
|
# These function signatures are not exposed to Python. Note that this signature
|
||||||
|
|||||||
@ -1258,6 +1258,15 @@ struct to_ir {
|
|||||||
const TernaryIf& expr,
|
const TernaryIf& expr,
|
||||||
const TypePtr& type_hint = nullptr) {
|
const TypePtr& type_hint = nullptr) {
|
||||||
CondValue cond_value = emitCondExpr(expr.cond());
|
CondValue cond_value = emitCondExpr(expr.cond());
|
||||||
|
// If the cond expr is a static value, then we metacompile the `if`
|
||||||
|
// statemement and only emit true or false branch
|
||||||
|
if (cond_value.staticIf()) {
|
||||||
|
if (*cond_value.staticIf()) {
|
||||||
|
return emitExpr(expr.true_expr(), type_hint);
|
||||||
|
} else {
|
||||||
|
return emitExpr(expr.false_expr(), type_hint);
|
||||||
|
}
|
||||||
|
}
|
||||||
auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); };
|
auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); };
|
||||||
auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); };
|
auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); };
|
||||||
return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
|
return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
|
||||||
|
|||||||
@ -33,39 +33,38 @@ def _orthogonalize(matrix, epsilon=1e-8):
|
|||||||
|
|
||||||
|
|
||||||
class PowerSGDState(object):
|
class PowerSGDState(object):
|
||||||
"""
|
r"""
|
||||||
Stores both the gradient compression configs and the internal states for all the gradients during the training.
|
Stores both the algorithm's hyperparameters and the internal state for all the gradients during the training.
|
||||||
Particularly, `matrix_approximation_rank` and `start_powerSGD_iter` are the main configs that need to be tuned by the user.
|
Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
|
||||||
Although `use_error_feedback` and `warm_start` can also be tuned by the user,
|
For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.
|
||||||
they are typically turned on for performance.
|
|
||||||
|
|
||||||
Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`]
|
1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
1) To tune `matrix_approximation_rank`, the user can increase it from 1 by factors of 2,
|
1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy.
|
||||||
until a satisfying accuracy can be reached.
|
|
||||||
The increase of `matrix_approximation_rank` can substantially increase the computation costs of the compression.
|
1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be futher improved beyond a certain ``matrix_approximation_rank`` threshold.
|
||||||
However, the accuracy may not be futher improved beyond a certain `matrix_approximation_rank` value.
|
|
||||||
2) To tune `start_powerSGD_iter`, the user can typically start with 10% of total training steps,
|
To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an expoential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.
|
||||||
and increase it until a satisfying accuracy can be reached.
|
|
||||||
Deferrring PowerSGD can effectively improve the accuracy,
|
2. ``start_powerSGD_iter`` defers PowerSGD compression util step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even a relatively small ``matrix_approximation_rank`` is used. This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.
|
||||||
even a relatively small `matrix_approximation_rank` is used.
|
|
||||||
This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients,
|
To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached.
|
||||||
and compressing gradients too early may make the training quickly take a suboptimal trajectory,
|
|
||||||
which can result in an irrecoverable impact on the accuracy.
|
.. warning ::
|
||||||
The minimum value allowed in DDP is 2, if error feedback or warm-up is enabled.
|
If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
|
||||||
This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
|
This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
|
||||||
and this can conflict with any tensor memorized before the rebuild process.
|
and this can conflict with any tensor memorized before the rebuild process.
|
||||||
"""
|
""" # noqa
|
||||||
|
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"process_group",
|
"process_group",
|
||||||
# The two fields below are the configs that usually need to be tuned by the user.
|
# The two fields below are the hyperparameters that should be tuned by the user.
|
||||||
"matrix_approximation_rank",
|
"matrix_approximation_rank",
|
||||||
"start_powerSGD_iter",
|
"start_powerSGD_iter",
|
||||||
# The two fields below are the configs that usually need to be turned on for performance.
|
# The two fields below are the binary hyperparameters recommended to be turned on for performance.
|
||||||
"use_error_feedback",
|
"use_error_feedback",
|
||||||
"warm_start",
|
"warm_start",
|
||||||
# The fields below are not configs.
|
# The fields below are internal state.
|
||||||
"rng",
|
"rng",
|
||||||
"error_dict",
|
"error_dict",
|
||||||
"p_memory_dict",
|
"p_memory_dict",
|
||||||
@ -93,21 +92,12 @@ class PowerSGDState(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
# The low rank for matrix approximation controls the size of compressed low-rank tensors,
|
|
||||||
# which determines the computation ratio.
|
|
||||||
# Typically only a small value 1-4 is used.
|
|
||||||
# For some NLP tasks (as shown in Appendix D of the original paper
|
|
||||||
# https://arxiv.org/pdf/1905.13727.pdf, the rank value has been increased to 32.
|
|
||||||
# A high rank value will increase the computation costs of compression exponentially.
|
|
||||||
# A good choice depends on how much extra computation can be hidden by the dominating communication costs.
|
|
||||||
self.matrix_approximation_rank = matrix_approximation_rank
|
self.matrix_approximation_rank = matrix_approximation_rank
|
||||||
# This defers PowerSGD compression util step 'start_powerSGD_iter',
|
# Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages:
|
||||||
# and vanilla allreduce runs before step 'start_powerSGD_iter'.
|
|
||||||
# This hybrid scheme of vanilla allreduce + PowerSGD can have two advantages:
|
|
||||||
# 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
|
# 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
|
||||||
# even if the matrix approximation rank is increased to a large value.
|
# even if the matrix approximation rank is increased to a large value.
|
||||||
# To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
|
# To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
|
||||||
# (or a more convervative compression such as FP16 compression) with PowerSGD.
|
# (or a more conservative compression such as FP16 compression) with PowerSGD.
|
||||||
# 2) There is an internal optimization of rebuilding buckets process in DDP,
|
# 2) There is an internal optimization of rebuilding buckets process in DDP,
|
||||||
# in order to save the memory space.
|
# in order to save the memory space.
|
||||||
# This step takes place after the first iteration.
|
# This step takes place after the first iteration.
|
||||||
@ -162,38 +152,44 @@ class PowerSGDState(object):
|
|||||||
|
|
||||||
|
|
||||||
def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
||||||
"""
|
r"""
|
||||||
This DDP communication hook implements the original PowerSGD gradient compression
|
This DDP communication hook implements PowerSGD gradient compression
|
||||||
algorithm described in https://arxiv.org/abs/1905.13727.
|
algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
|
||||||
Once gradient tensors are aggregated across all workers, this hook applies
|
Once gradient tensors are aggregated across all workers, this hook applies
|
||||||
compression as follows:
|
compression as follows:
|
||||||
1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
|
|
||||||
high-rank tensors and vector-like rank-1 tensors (for biases).
|
1. Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: high-rank tensors and vector-like rank-1 tensors (for biases).
|
||||||
2) Handles rank-1 tensors by allreducing them without compression:
|
|
||||||
2.1) Allocate contiguous memory for those rank-1 tensors,
|
2. Handles rank-1 tensors by allreducing them without compression:
|
||||||
and allreduces all the rank-1 tensors as a batch, without compression;
|
|
||||||
2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
|
2.1. Allocate contiguous memory for those rank-1 tensors, and allreduces all the rank-1 tensors as a batch, without compression;
|
||||||
3) Handles high-rank tensors by PowerSGD compression:
|
|
||||||
3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
|
2.2. Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
|
||||||
|
|
||||||
|
3. Handles high-rank tensors by PowerSGD compression:
|
||||||
|
|
||||||
|
3.1. For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
|
||||||
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
|
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
|
||||||
3.2) Computes each P in Ps, which is equal to MQ;
|
|
||||||
3.3) Allreduces Ps as a batch;
|
|
||||||
3.4) Orthogonalizes each P in Ps;
|
|
||||||
3.5) Computes each Q in Qs, which is approximately equal to M^TP;
|
|
||||||
3.6) Allreduces Qs as a batch;
|
|
||||||
3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
|
|
||||||
|
|
||||||
Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
|
3.2. Computes each P in Ps, which is equal to MQ;
|
||||||
This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
|
|
||||||
but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
|
|
||||||
|
|
||||||
TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
|
3.3. Allreduces Ps as a batch;
|
||||||
one left multiplication and one right multiplication.
|
|
||||||
For warm-start, can take one such step at a time, and alternate between them.
|
3.4. Orthogonalizes each P in Ps;
|
||||||
|
|
||||||
|
3.5. Computes each Q in Qs, which is approximately equal to M^TP;
|
||||||
|
|
||||||
|
3.6. Allreduces Qs as a batch;
|
||||||
|
|
||||||
|
3.7. Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
|
||||||
|
|
||||||
|
Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
|
||||||
|
This not only gives the user more control over the tradeoff between speedup and accuracy,
|
||||||
|
but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
|
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
|
||||||
To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
|
To tune the compression configs, mainly need to tune `matrix_approximation_rank`` and ``start_powerSGD_iter``.
|
||||||
bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
|
bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
|
||||||
Note that since DDP comm hook only supports single process single device mode at this time,
|
Note that since DDP comm hook only supports single process single device mode at this time,
|
||||||
only exactly one tensor is stored in this bucket.
|
only exactly one tensor is stored in this bucket.
|
||||||
@ -202,9 +198,9 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
|||||||
Future handler of the communication, which updates the gradients in place.
|
Future handler of the communication, which updates the gradients in place.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
|
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
|
||||||
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
|
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
|
||||||
"""
|
""" # noqa
|
||||||
process_group = state.process_group
|
process_group = state.process_group
|
||||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||||
world_size = group_to_use.size()
|
world_size = group_to_use.size()
|
||||||
@ -374,6 +370,10 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
|||||||
for tensor, p, q in zip(high_rank_tensors, ps, qs):
|
for tensor, p, q in zip(high_rank_tensors, ps, qs):
|
||||||
torch.matmul(tensor.t(), p, out=q)
|
torch.matmul(tensor.t(), p, out=q)
|
||||||
|
|
||||||
|
# TODO: The above procedure does two matmul+allreduce steps per iteration --
|
||||||
|
# one left multiplication and one right multiplication.
|
||||||
|
# For warm-start, can take one such step at a time, and alternate between them.
|
||||||
|
|
||||||
# Allreduce Qs.
|
# Allreduce Qs.
|
||||||
return [
|
return [
|
||||||
dist.all_reduce(
|
dist.all_reduce(
|
||||||
@ -412,40 +412,48 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
|||||||
|
|
||||||
|
|
||||||
def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
||||||
"""
|
r"""
|
||||||
This DDP communication hook implements a simplified PowerSGD gradient compression
|
This DDP communication hook implements a simplified PowerSGD gradient compression
|
||||||
algorithm described in https://arxiv.org/abs/1905.13727.
|
algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
|
||||||
|
This variant does not compress the gradients layer by layer,
|
||||||
|
but instead compresses the flattened input tensor that batches all the gradients.
|
||||||
|
Therefore, it is **faster** than :meth:`powerSGD_hook`,
|
||||||
|
but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.
|
||||||
|
|
||||||
|
.. warning ::
|
||||||
|
Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
|
||||||
|
because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
|
||||||
|
Therefore, the user should always consider :meth:`powerSGD_hook` first,
|
||||||
|
and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.
|
||||||
|
|
||||||
Once gradient tensors are aggregated across all workers, this hook applies
|
Once gradient tensors are aggregated across all workers, this hook applies
|
||||||
compression to the flattened input tensor that batches per-parameter tensors as follows:
|
compression as follows:
|
||||||
1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
|
|
||||||
2) Creates two low-rank tensors P and Q for decomposing M,
|
|
||||||
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
|
|
||||||
2) Computes P, which is equal to MQ;
|
|
||||||
3) Allreduces P;
|
|
||||||
4) Orthogonalizes P;
|
|
||||||
5) Computes Q, which is approximately equal to M^TP;
|
|
||||||
6) Allreduces Q;
|
|
||||||
7) Computes M, which is approximately equal to PQ^T.
|
|
||||||
8) Truncates the input tensor to the original length.
|
|
||||||
|
|
||||||
This variant is faster than `powerSGD_hook` that runs layer-wise gradient compression,
|
1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
|
||||||
but it usually results in a much lower accuracy, unless `matrix_approximation_rank` in the state is 1.
|
|
||||||
Increasing `matrix_approximation_rank` may not necessarily increase the accuracy,
|
|
||||||
because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
|
|
||||||
Therefore, the user shoud always consider `powerSGD_hook` first,
|
|
||||||
and only consider this variant when a satisfying accuracy can be achieved when `matrix_approximation_rank` is 1.
|
|
||||||
|
|
||||||
Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
|
2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
|
||||||
This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
|
|
||||||
but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
|
|
||||||
|
|
||||||
TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
|
3. Computes P, which is equal to MQ;
|
||||||
one left multiplication and one right multiplication.
|
|
||||||
For warm-start, can take one such step at a time, and alternate between them.
|
4. Allreduces P;
|
||||||
|
|
||||||
|
5. Orthogonalizes P;
|
||||||
|
|
||||||
|
6. Computes Q, which is approximately equal to M^TP;
|
||||||
|
|
||||||
|
7. Allreduces Q;
|
||||||
|
|
||||||
|
8. Computes M, which is approximately equal to PQ^T.
|
||||||
|
|
||||||
|
9. Truncates the input tensor to the original length.
|
||||||
|
|
||||||
|
Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
|
||||||
|
This not only gives the user more control over the tradeoff between speedup and accuracy,
|
||||||
|
but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
|
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
|
||||||
To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
|
To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
|
||||||
bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
|
bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
|
||||||
Note that since DDP comm hook only supports single process single device mode at this time,
|
Note that since DDP comm hook only supports single process single device mode at this time,
|
||||||
only exactly one tensor is stored in this bucket.
|
only exactly one tensor is stored in this bucket.
|
||||||
@ -454,9 +462,9 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
|||||||
Future handler of the communication, which updates the gradients in place.
|
Future handler of the communication, which updates the gradients in place.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
|
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
|
||||||
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
|
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
|
||||||
"""
|
""" # noqa
|
||||||
process_group = state.process_group
|
process_group = state.process_group
|
||||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||||
world_size = group_to_use.size()
|
world_size = group_to_use.size()
|
||||||
@ -563,6 +571,10 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
|
|||||||
out=state.q_memory_dict[bucket_index],
|
out=state.q_memory_dict[bucket_index],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: The above procedure does two matmul+allreduce steps per iteration --
|
||||||
|
# one left multiplication and one right multiplication.
|
||||||
|
# For warm-start, can take one such step at a time, and alternate between them.
|
||||||
|
|
||||||
return [
|
return [
|
||||||
dist.all_reduce(
|
dist.all_reduce(
|
||||||
state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
|
state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from typing import Dict, List, Set, NamedTuple, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch.fx.passes.split_module import split_module
|
from torch.fx.passes.split_module import split_module
|
||||||
import operator
|
import operator
|
||||||
from torch.fx.experimental.partitioner_utils import Partition, \
|
from torch.fx._experimental.partitioner_utils import Partition, \
|
||||||
Device, PartitionerConfig, get_partition_to_latency_mapping,\
|
Device, PartitionerConfig, get_partition_to_latency_mapping,\
|
||||||
get_latency_of_partitioned_graph, NodeLatency, get_extra_size_of, \
|
get_latency_of_partitioned_graph, NodeLatency, get_extra_size_of, \
|
||||||
PartitionMode
|
PartitionMode
|
||||||
@ -2,7 +2,7 @@ from typing import Dict, List, NamedTuple, Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.passes.shape_prop import ShapeProp
|
from torch.fx.passes.shape_prop import ShapeProp
|
||||||
from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
|
from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes
|
||||||
from torch.fx.graph import Graph, get_qualified_name
|
from torch.fx.graph import Graph, get_qualified_name
|
||||||
from torch.fx.graph_module import GraphModule
|
from torch.fx.graph_module import GraphModule
|
||||||
from torch.fx.node import Node, Target, map_arg
|
from torch.fx.node import Node, Target, map_arg
|
||||||
@ -116,7 +116,7 @@ class Interpreter:
|
|||||||
|
|
||||||
# Main Node running APIs
|
# Main Node running APIs
|
||||||
|
|
||||||
def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute a ``placeholder`` node. Note that this is stateful:
|
Execute a ``placeholder`` node. Note that this is stateful:
|
||||||
``Interpreter`` maintains an internal iterator over
|
``Interpreter`` maintains an internal iterator over
|
||||||
@ -141,7 +141,7 @@ class Interpreter:
|
|||||||
else:
|
else:
|
||||||
return next(self.args_iter)
|
return next(self.args_iter)
|
||||||
|
|
||||||
def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute a ``get_attr`` node. Will retrieve an attribute
|
Execute a ``get_attr`` node. Will retrieve an attribute
|
||||||
value from the ``Module`` hierarchy of ``self.module``.
|
value from the ``Module`` hierarchy of ``self.module``.
|
||||||
@ -159,7 +159,7 @@ class Interpreter:
|
|||||||
assert isinstance(target, str)
|
assert isinstance(target, str)
|
||||||
return self.fetch_attr(target)
|
return self.fetch_attr(target)
|
||||||
|
|
||||||
def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute a ``call_function`` node and return the result.
|
Execute a ``call_function`` node and return the result.
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class Interpreter:
|
|||||||
# Execute the function and return the result
|
# Execute the function and return the result
|
||||||
return target(*args, **kwargs)
|
return target(*args, **kwargs)
|
||||||
|
|
||||||
def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute a ``call_method`` node and return the result.
|
Execute a ``call_method`` node and return the result.
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ class Interpreter:
|
|||||||
assert isinstance(target, str)
|
assert isinstance(target, str)
|
||||||
return getattr(self_obj, target)(*args_tail, **kwargs)
|
return getattr(self_obj, target)(*args_tail, **kwargs)
|
||||||
|
|
||||||
def call_module(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute a ``call_module`` node and return the result.
|
Execute a ``call_module`` node and return the result.
|
||||||
|
|
||||||
@ -221,7 +221,7 @@ class Interpreter:
|
|||||||
|
|
||||||
return submod(*args, **kwargs)
|
return submod(*args, **kwargs)
|
||||||
|
|
||||||
def output(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute an ``output`` node. This really just retrieves
|
Execute an ``output`` node. This really just retrieves
|
||||||
the value referenced by the ``output`` node and returns it.
|
the value referenced by the ``output`` node and returns it.
|
||||||
@ -307,12 +307,12 @@ class Transformer(Interpreter):
|
|||||||
method equivalents). We could subclass ``Transformer`` like so::
|
method equivalents). We could subclass ``Transformer`` like so::
|
||||||
|
|
||||||
class NegSigmSwapXformer(Transformer):
|
class NegSigmSwapXformer(Transformer):
|
||||||
def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
if target == torch.sigmoid:
|
if target == torch.sigmoid:
|
||||||
return torch.neg(*args, **kwargs)
|
return torch.neg(*args, **kwargs)
|
||||||
return super().call_function(n)
|
return super().call_function(n)
|
||||||
|
|
||||||
def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
|
def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
if target == 'neg':
|
if target == 'neg':
|
||||||
call_self, *args_tail = args
|
call_self, *args_tail = args
|
||||||
return call_self.sigmoid(*args_tail, **kwargs)
|
return call_self.sigmoid(*args_tail, **kwargs)
|
||||||
@ -344,7 +344,7 @@ class Transformer(Interpreter):
|
|||||||
self.tracer = TransformerTracer(self.new_graph)
|
self.tracer = TransformerTracer(self.new_graph)
|
||||||
self.tracer.root = module
|
self.tracer.root = module
|
||||||
|
|
||||||
def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy:
|
def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
|
||||||
"""
|
"""
|
||||||
Execute a ``placeholder`` node. In ``Transformer``, this is
|
Execute a ``placeholder`` node. In ``Transformer``, this is
|
||||||
overridden to insert a new ``placeholder`` into the output
|
overridden to insert a new ``placeholder`` into the output
|
||||||
@ -360,7 +360,7 @@ class Transformer(Interpreter):
|
|||||||
assert isinstance(target, str)
|
assert isinstance(target, str)
|
||||||
return Proxy(self.new_graph.placeholder(target), self.tracer)
|
return Proxy(self.new_graph.placeholder(target), self.tracer)
|
||||||
|
|
||||||
def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy:
|
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
|
||||||
"""
|
"""
|
||||||
Execute a ``get_attr`` node. In ``Transformer``, this is
|
Execute a ``get_attr`` node. In ``Transformer``, this is
|
||||||
overridden to insert a new ``get_attr`` node into the output
|
overridden to insert a new ``get_attr`` node into the output
|
||||||
@ -376,6 +376,12 @@ class Transformer(Interpreter):
|
|||||||
assert isinstance(target, str)
|
assert isinstance(target, str)
|
||||||
return Proxy(self.new_graph.get_attr(target), self.tracer)
|
return Proxy(self.new_graph.get_attr(target), self.tracer)
|
||||||
|
|
||||||
|
def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||||
|
# Override so that the leaf module policy from `self.tracer` is respected.
|
||||||
|
assert isinstance(target, str)
|
||||||
|
submod = self.fetch_attr(target)
|
||||||
|
return self.tracer.call_module(submod, submod.forward, args, kwargs)
|
||||||
|
|
||||||
def transform(self) -> GraphModule:
|
def transform(self) -> GraphModule:
|
||||||
"""
|
"""
|
||||||
Transform ``self.module`` and return the transformed
|
Transform ``self.module`` and return the transformed
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import operator
|
|||||||
|
|
||||||
from .graph import magic_methods, reflectable_magic_methods, Graph
|
from .graph import magic_methods, reflectable_magic_methods, Graph
|
||||||
from typing import Tuple, Dict, Optional, Iterable, Any, Iterator
|
from typing import Tuple, Dict, Optional, Iterable, Any, Iterator
|
||||||
from .node import Target, Node, Argument, base_types
|
from .node import Target, Node, Argument, base_types, map_aggregate
|
||||||
|
|
||||||
class TracerBase:
|
class TracerBase:
|
||||||
graph: Graph
|
graph: Graph
|
||||||
@ -61,8 +61,17 @@ class TracerBase:
|
|||||||
elif isinstance(a, dict):
|
elif isinstance(a, dict):
|
||||||
r = {}
|
r = {}
|
||||||
for k, v in a.items():
|
for k, v in a.items():
|
||||||
if not isinstance(k, str):
|
# Check for invalid dict keys. We do not want a Proxy to appear
|
||||||
raise NotImplementedError(f"dictionaries with non-string keys: {a}")
|
# anywhere within the key. Since keys can be collection types,
|
||||||
|
# we iterate through the key with map_aggregate
|
||||||
|
k = self.create_arg(k)
|
||||||
|
|
||||||
|
def no_node(arg):
|
||||||
|
if isinstance(arg, Node):
|
||||||
|
raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
|
||||||
|
"Node. Got key: {k}")
|
||||||
|
map_aggregate(k, no_node)
|
||||||
|
|
||||||
r[k] = self.create_arg(v)
|
r[k] = self.create_arg(v)
|
||||||
return r
|
return r
|
||||||
elif isinstance(a, slice):
|
elif isinstance(a, slice):
|
||||||
|
|||||||
@ -1021,13 +1021,14 @@ class DistributedDataParallel(Module):
|
|||||||
parameter syncs while running Distributed DataParallel training.
|
parameter syncs while running Distributed DataParallel training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state (object): state is passed to the hook and can be used to maintain
|
state (object): Passed to the hook to maintain any state information during the training process.
|
||||||
and update any state information that users would like to
|
Examples include error feedback in gradient compression,
|
||||||
maintain as part of the training process. Examples: error
|
peers to communicate with next in GossipGrad, etc.
|
||||||
feedback in gradient compression, peers to communicate with
|
|
||||||
next in GossipGrad etc.
|
It is locally stored by each worker
|
||||||
hook (callable): is defined as:
|
and shared by all the gradient tensors on the worker.
|
||||||
hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
|
hook (callable): Averages gradient tensors across workers and defined as:
|
||||||
|
``hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future``:
|
||||||
|
|
||||||
This function is called once the bucket is ready. The
|
This function is called once the bucket is ready. The
|
||||||
hook can perform whatever processing is needed and return
|
hook can perform whatever processing is needed and return
|
||||||
@ -1067,7 +1068,7 @@ class DistributedDataParallel(Module):
|
|||||||
DDP communication hook is experimental and subject to change.
|
DDP communication hook is experimental and subject to change.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
Below is an example of a noop hook that returns back the same tensors:
|
Below is an example of a noop hook that returns the same tensors.
|
||||||
|
|
||||||
>>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future
|
>>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future
|
||||||
>>> fut = torch.futures.Future()
|
>>> fut = torch.futures.Future()
|
||||||
@ -1091,7 +1092,6 @@ class DistributedDataParallel(Module):
|
|||||||
>>> return fut.then(decode)
|
>>> return fut.then(decode)
|
||||||
|
|
||||||
>>> ddp.register_comm_hook(state = None, hook = encode_and_decode)
|
>>> ddp.register_comm_hook(state = None, hook = encode_and_decode)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self._check_comm_hook(hook)
|
self._check_comm_hook(hook)
|
||||||
dist._register_comm_hook(self.reducer, state, hook)
|
dist._register_comm_hook(self.reducer, state, hook)
|
||||||
|
|||||||
@ -391,9 +391,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||||||
torch.exp2: lambda input, out=None: -1,
|
torch.exp2: lambda input, out=None: -1,
|
||||||
torch.expm1: lambda input, out=None: -1,
|
torch.expm1: lambda input, out=None: -1,
|
||||||
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
|
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
|
||||||
torch.fake_quantize_per_channel_affine_cachemask: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
|
|
||||||
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
|
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
|
||||||
torch.fake_quantize_per_tensor_affine_cachemask: lambda input, scale, zero_point, quant_min, quant_max: -1,
|
|
||||||
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
|
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
|
||||||
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
|
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
|
||||||
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
|
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
|
||||||
|
|||||||
Reference in New Issue
Block a user