Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-22 22:25:10 +08:00)

Compare commits: 11 Commits (copilot/co... / v1.8.0-rc2)

SHA1:
c79decdbba
c307a3f336
f071020756
4f436f8570
ae11589710
9e5bcc1020
fa8578241d
1368809532
4073248fc2
75153cb730
5bb69b080c
@@ -153,11 +153,11 @@ commands:
 git config --global user.email "circleci.ossci@gmail.com"
 git config --global user.name "CircleCI"
 git config remote.origin.url https://github.com/pytorch/pytorch.git
-git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
-git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet
+git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8
+git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet
 # PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
 if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
-CIRCLE_PR_BASE_BRANCH=master
+CIRCLE_PR_BASE_BRANCH=release/1.8
 fi
 export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
 echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
@@ -111,11 +111,11 @@ commands:
 git config --global user.email "circleci.ossci@gmail.com"
 git config --global user.name "CircleCI"
 git config remote.origin.url https://github.com/pytorch/pytorch.git
-git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
-git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet
+git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8
+git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet
 # PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
 if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
-CIRCLE_PR_BASE_BRANCH=master
+CIRCLE_PR_BASE_BRANCH=release/1.8
 fi
 export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
 echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
@@ -182,7 +182,7 @@ fi

 # Patch required to build xla
 if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
-git clone --recursive https://github.com/pytorch/xla.git
+git clone --recursive -b r1.8 https://github.com/pytorch/xla.git
 ./xla/scripts/apply_patches.sh
 fi

@@ -54,7 +54,7 @@ function file_diff_from_base() {
 set +e
 git fetch origin master --quiet
 set -e
-git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
+git diff --name-only "$(git merge-base origin/release/1.8 HEAD)" > "$1"
 }

 function get_bazel() {
@@ -300,7 +300,7 @@ test_backward_compatibility() {
 pushd test/backward_compatibility
 python -m venv venv
 . venv/bin/activate
-pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_test.html
 pip show torch
 python dump_all_function_schemas.py --filename nightly_schemas.txt
 deactivate
@@ -19,6 +19,27 @@ namespace {

 using namespace vec256;

+// Note: Explicit implementation of copysign for Half and BFloat16
+// is needed to workaround g++-7/8 crash on aarch64, but also makes
+// copysign faster for the half-precision types
+template<typename T>
+T copysign(T a, T b) {
+return std::copysign(a, b);
+}
+
+// Implement copysign for half precision floats using bit ops
+// Sign is the most significant bit for both half and bfloat16 types
+template<>
+c10::Half copysign(c10::Half a, c10::Half b) {
+return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits());
+}
+
+template<>
+c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
+return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits());
+}
+
 // Note: Undefined behavior when performing addition is intentionally
 // ignored.
 void add_kernel(TensorIteratorBase& iter, Scalar alpha_scalar) {
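Editor's note: the bit trick above works because the sign is the most significant bit of both 16-bit floating-point formats. The snippet below is an editorial sketch that reproduces the same masking in Python; the example values are arbitrary, and it assumes a torch build that already provides `torch.copysign` (introduced in this release).

```python
import torch

# Arbitrary example values; only the sign bit of `b` is transferred onto `a`.
a = torch.tensor([1.5, -2.25, 0.0], dtype=torch.float16)
b = torch.tensor([-1.0, 4.0, -3.0], dtype=torch.float16)

sign_mask = torch.tensor(-0x8000, dtype=torch.int16)  # int16 bit pattern 0x8000
# Mirror (a.x & 0x7fff) | (b.x & 0x8000) from the C++ specialization above.
bits = (a.view(torch.int16) & 0x7fff) | (b.view(torch.int16) & sign_mask)
result = bits.view(torch.float16)

assert torch.equal(result, torch.copysign(a, b))
```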
@@ -180,7 +201,7 @@ void div_floor_kernel(TensorIterator& iter) {
 floordiv += scalar_t(1.0);
 }
 } else {
-floordiv = std::copysign(scalar_t(0), a / b);
+floordiv = copysign(scalar_t(0), a / b);
 }
 return floordiv;
 });
@@ -889,23 +910,6 @@ void heaviside_kernel(TensorIterator& iter) {
 });
 }

-template<typename T>
-T copysign(T a, T b) {
-return std::copysign(a, b);
-}
-
-// Implement copysign for half precision floats using bit ops
-// Sign is the most significant bit for both half and bfloat16 types
-template<>
-c10::Half copysign(c10::Half a, c10::Half b) {
-return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits());
-}
-
-template<>
-c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
-return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits());
-}
-
 void copysign_kernel(TensorIterator& iter) {
 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
 cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
@@ -133,7 +133,9 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) {
 ASSERT_EQ(buffer1[i].z, buffer2[i].z);
 ASSERT_EQ(buffer1[i].w, buffer2[i].w);
 }
+// Skipping this part until https://github.com/pytorch/pytorch/issues/51863 is resolved

+#if 0
 // unaligned
 for (int i = 0; i < 16; i++) {
 for (int j = 0; j < 16; j++) {
@@ -151,4 +153,5 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) {
 }
 }
 }
+#endif
 }
@@ -16,7 +16,7 @@ int32_t driver_version() {
 return driver_version;
 }

-int device_count_impl() {
+int device_count_impl(bool fail_if_no_driver) {
 int count;
 auto err = cudaGetDeviceCount(&count);
 if (err == cudaSuccess) {
@@ -34,6 +34,11 @@ int device_count_impl() {
 case cudaErrorInsufficientDriver: {
 auto version = driver_version();
 if (version <= 0) {
+if (!fail_if_no_driver) {
+// No CUDA driver means no devices
+count = 0;
+break;
+}
 TORCH_CHECK(
 false,
 "Found no NVIDIA driver on your system. Please check that you "
@@ -95,9 +100,9 @@ DeviceIndex device_count() noexcept {
 // initialize number of devices only once
 static int count = []() {
 try {
-auto result = device_count_impl();
+auto result = device_count_impl(/*fail_if_no_driver=*/false);
 TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
-return device_count_impl();
+return result;
 } catch (const c10::Error& ex) {
 // We don't want to fail, but still log the warning
 // msg() returns the message without the stack trace
@@ -110,7 +115,7 @@ DeviceIndex device_count() noexcept {

 DeviceIndex device_count_ensure_non_zero() {
 // Call the implementation every time to throw the exception
-int count = device_count_impl();
+int count = device_count_impl(/*fail_if_no_driver=*/true);
 // Zero gpus doesn't produce a warning in `device_count` but we fail here
 TORCH_CHECK(count, "No CUDA GPUs are available");
 return static_cast<DeviceIndex>(count);
docs/source/ddp_comm_hooks.rst (new file, 74 lines)
@@ -0,0 +1,74 @@
+DDP Communication Hooks
+=======================
+
+DDP communication hook is a generic interface to control how to communicate
+gradients across workers by overriding the vanilla allreduce in
+`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.>`_.
+A few built-in communication hooks are provided,
+and users can easily apply any of these hooks to optimize communication.
+Besides, the hook interface can also support user-defined communication
+strategies for more advanced use cases.
+
+.. warning ::
+    DDP communication hook is experimental and subject to change.
+
+.. warning ::
+    DDP communication hooks can only support single process single device mode
+    on NCCL backend.
+
+How to Use a Communication Hook?
+--------------------------------
+
+To use a communication hook, the user just needs to let the DDP model register
+the hook before the training loop as below.
+
+:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`.
+    :noindex:
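Editor's note: as an illustration of the registration pattern described above (not part of the new documentation file), a minimal sketch might look like the following. The process-group setup, the `model` variable, and the use of `bucket.get_tensors()` follow the noop example shown later in this comparison; treat them as assumptions rather than a prescribed API.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Assumes dist.init_process_group(...) has already run and `model` is an nn.Module.
ddp_model = DistributedDataParallel(model)

def noop_hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
    # Return the bucket's tensors unchanged; a real hook would compress or
    # otherwise transform the gradients before completing the future.
    fut = torch.futures.Future()
    fut.set_result(bucket.get_tensors())
    return fut

# Stateless hooks take a process group or None as the state argument.
ddp_model.register_comm_hook(state=None, hook=noop_hook)
```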
+Default Communication Hooks
+---------------------------
+
+Default communication hooks are simple **stateless** hooks, so the input state
+in ``register_comm_hook`` is either a process group or ``None``.
+
+.. automodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
+    :members:
+
+PowerSGD Communication Hook
+---------------------------
+
+PowerSGD (`Vogels et al., NeurIPS 2019 <https://arxiv.org/abs/1905.13727>`_)
+is a gradient compression algorithm, which can provide very high compression
+rates and accelerate bandwidth-bound distributed training.
+This algorithm needs to maintain both some hyperparameters and the internal
+state. Therefore, PowerSGD communication hook is a **stateful** hook,
+and the user needs to provide a state object defined as below.
+
+PowerSGD State
+^^^^^^^^^^^^^^^^
+
+.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
+.. autoclass:: PowerSGDState
+
+PowerSGD Hooks
+^^^^^^^^^^^^^^^^
+
+.. warning ::
+    PowerSGD typically requires extra memory of the same size as the model's
+    gradients to enable error feedback, which can compensate for biased
+    compressed communication and improve accuracy.
+
+.. warning ::
+    The current implementation may cause gradient overflow for FP16 input.
+
+.. autofunction:: powerSGD_hook
+.. autofunction:: batched_powerSGD_hook
+
+Acknowledgements
+----------------
+
+Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on
+PowerSGD communication hook, as well as the
+`comparison experiments <https://observablehq.com/@tvogels/powersgd-benchmark>`_,
+which show that the performance of PowerSGD communication hook is on par with
+the implementation in the original `paper <https://arxiv.org/abs/1905.13727>`_.
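Editor's note: for the stateful PowerSGD hook described above, registration mirrors the example that appears later in the powerSGD_hook docstring in this comparison. The sketch below assumes `ddp_model` is an existing DistributedDataParallel instance and that the default process group is already initialized.

```python
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

state = powerSGD.PowerSGDState(
    process_group=None,           # None falls back to the default (WORLD) process group
    matrix_approximation_rank=1,  # lower rank means stronger compression
    start_powerSGD_iter=10,       # vanilla allreduce for the first 10 iterations
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
```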
@@ -71,6 +71,7 @@ Features described in this documentation are classified by release status:
 onnx
 optim
 complex_numbers
+ddp_comm_hooks
 pipeline
 quantization
 rpc
@@ -484,6 +484,7 @@ Sparse tensor functions
 +++++++++++++++++++++++

 .. autofunction:: torch.sparse_coo_tensor
+   :noindex:
 .. autofunction:: torch.sparse.sum
 .. autofunction:: torch.sparse.addmm
 .. autofunction:: torch.sparse.mm
setup.py (45)
@@ -552,6 +552,50 @@ class build_ext(setuptools.command.build_ext.build_ext):
 with open('compile_commands.json', 'w') as f:
 f.write(new_contents)

+class concat_license_files():
+"""Merge LICENSE and LICENSES_BUNDLED.txt as a context manager
+
+LICENSE is the main PyTorch license, LICENSES_BUNDLED.txt is auto-generated
+from all the licenses found in ./third_party/. We concatenate them so there
+is a single license file in the sdist and wheels with all of the necessary
+licensing info.
+"""
+def __init__(self):
+self.f1 = 'LICENSE'
+self.f2 = 'third_party/LICENSES_BUNDLED.txt'
+
+def __enter__(self):
+"""Concatenate files"""
+with open(self.f1, 'r') as f1:
+self.bsd_text = f1.read()
+
+with open(self.f1, 'a') as f1:
+with open(self.f2, 'r') as f2:
+self.bundled_text = f2.read()
+f1.write('\n\n')
+f1.write(self.bundled_text)
+
+def __exit__(self, exception_type, exception_value, traceback):
+"""Restore content of f1"""
+with open(self.f1, 'w') as f:
+f.write(self.bsd_text)
+
+
+try:
+from wheel.bdist_wheel import bdist_wheel
+except ImportError:
+# This is useful when wheel is not installed and bdist_wheel is not
+# specified on the command line. If it _is_ specified, parsing the command
+# line will fail before wheel_concatenate is needed
+wheel_concatenate = None
+else:
+# Need to create the proper LICENSE.txt for the wheel
+class wheel_concatenate(bdist_wheel):
+""" check submodules on sdist to prevent incomplete tarballs """
+def run(self):
+with concat_license_files():
+super().run()
+
+
 class install(setuptools.command.install.install):
 def run(self):
@@ -724,6 +768,7 @@ def configure_extension_build():
 'build_ext': build_ext,
 'clean': clean,
 'install': install,
+'bdist_wheel': wheel_concatenate,
 }

 entry_points = {
@@ -872,7 +872,7 @@ class TestFakeQuantize(TestCase):
 scale, zero_point = float(scale), int(zero_point)
 quant_min, quant_max = obs._calculate_qmin_qmax()

-Y_test, _mask = torch.fake_quantize_per_tensor_affine_cachemask(
+Y_test = torch.fake_quantize_per_tensor_affine(
 X, scale, zero_point, quant_min, quant_max)
 Y_ref = _fake_quantize_per_tensor_affine_reference(
 X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
@@ -899,7 +899,7 @@ class TestFakeQuantize(TestCase):
 quant_min, quant_max = obs._calculate_qmin_qmax()

 # forward pass
-Y_test, mask = torch.fake_quantize_per_tensor_affine_cachemask(
+Y_test = torch.fake_quantize_per_tensor_affine(
 X, scale, zero_point, quant_min, quant_max)
 Y_ref = _fake_quantize_per_tensor_affine_reference(
 X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
@@ -1246,7 +1246,7 @@ class TestFakeQuantize(TestCase):

 Y = _fake_quantize_per_channel_affine_reference(
 X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
-Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
+Y_prime = torch.fake_quantize_per_channel_affine(
 X, scale, zero_point, axis, quant_min, quant_max)
 np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

@@ -1339,7 +1339,7 @@ class TestFakeQuantize(TestCase):
 zero_point = zero_point.to(torch.int64)
 quant_min, quant_max = obs._calculate_qmin_qmax()
 X.requires_grad_()
-Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
+Y_prime = torch.fake_quantize_per_channel_affine(
 X, scale, zero_point, axis, quant_min, quant_max)
 dout = torch.rand(X.shape, dtype=torch.float).to(device)
 dX = _fake_quantize_per_channel_affine_grad_reference(
@@ -108,6 +108,7 @@ TESTS = [
 'test_fx_experimental',
 'test_functional_autograd_benchmark',
 'test_package',
+'test_license',
 'distributed/pipeline/sync/skip/test_api',
 'distributed/pipeline/sync/skip/test_gpipe',
 'distributed/pipeline/sync/skip/test_inspect_skip_layout',
@@ -14,7 +14,7 @@ from math import sqrt
 from pathlib import Path
 from torch.multiprocessing import Process
 from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap
-from torch.fx.node import Target
+from torch.fx.node import Target, Argument
 from torch.fx.passes import shape_prop
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from copy import deepcopy
@@ -187,7 +187,7 @@ class TestFX(JitTestCase):
 # Custom delegate to disallow in-place tensor operations
 class NoMutableCallTracer(Tracer):
 def create_node(self, kind : str, target : Union[str, Callable],
-args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
+args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
 type_expr : Optional[Any] = None) -> Node:
 name = target if isinstance(target, str) else torch.typename(target)
 if name[-1] == '_':
@@ -539,7 +539,7 @@ class TestFX(JitTestCase):
 def test_node_tagging(self):
 class TaggingTracer(Tracer):
 def create_node(self, kind : str, target : Union[str, Callable],
-args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
+args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
 type_expr : Optional[Any] = None) -> Node:
 n = super().create_node(kind, target, args, kwargs, name)
 n.tag = 'foo'
@@ -1057,6 +1057,13 @@ class TestFX(JitTestCase):
 result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
 self.assertEqual(result, torch.ones(3, 4) * 2.0)

+@skipIfNoTorchVision
+def test_interpreter_noop_resnet18(self):
+rn18 = resnet18()
+transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
+inp = torch.randn(5, 3, 224, 224)
+self.assertEqual(transformed(inp), rn18(inp))
+
 def test_transformer_noop(self):
 class MyModule(torch.nn.Module):
 def __init__(self):
@@ -1377,6 +1384,45 @@ class TestFX(JitTestCase):
 x, y = torch.randn(3, 4), torch.randn(3, 4)
 self.checkGraphModule(foo, (x, y))

+def test_trace_dict_int_keys(self):
+class ModWithDictArg(torch.nn.Module):
+def forward(self, d : Dict[int, torch.Tensor]):
+return d[42]
+
+class CallsModWithDict(torch.nn.Module):
+def __init__(self):
+super().__init__()
+self.m = ModWithDictArg()
+
+def forward(self, x):
+return self.m({42: x})
+
+class MyTracer(torch.fx.Tracer):
+def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
+return isinstance(m, ModWithDictArg)
+
+traced_graph = MyTracer().trace(CallsModWithDict())
+
+def test_trace_dict_proxy_keys(self):
+class ModWithDictArg(torch.nn.Module):
+def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
+return d[42]
+
+class CallsModWithDict(torch.nn.Module):
+def __init__(self):
+super().__init__()
+self.m = ModWithDictArg()
+
+def forward(self, x):
+return self.m({x: x})
+
+class MyTracer(torch.fx.Tracer):
+def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
+return isinstance(m, ModWithDictArg)
+
+with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
+traced_graph = MyTracer().trace(CallsModWithDict())
+
 def test_direct_param_use(self):
 class TransposeTest(torch.nn.Module):
 def __init__(self):
@@ -5,14 +5,14 @@ from typing import Callable, Dict, Union, List
 from torch.fx.symbolic_trace import symbolic_trace
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
-from torch.fx.experimental import graph_manipulation
-from torch.fx.experimental.accelerator_partitioner import Partitioner
-from torch.fx.experimental.rewriter import RewritingTracer
-from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
+from torch.fx._experimental import graph_manipulation
+from torch.fx._experimental.accelerator_partitioner import Partitioner
+from torch.fx._experimental.rewriter import RewritingTracer
+from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
 from torch.fx.passes.split_module import split_module
-from torch.fx.experimental.partitioner_utils import (
+from torch.fx._experimental.partitioner_utils import (
 NodeLatency,
 get_partition_to_latency_mapping,
 get_latency_of_partitioned_graph,
@@ -20,8 +20,8 @@ from torch.fx.experimental.partitioner_utils import (
 PartitionerConfig,
 PartitionMode
 )
-from torch.fx.experimental.fuser import fuse
-from torch.fx.experimental import merge_matmul
+from torch.fx._experimental.fuser import fuse
+from torch.fx._experimental import merge_matmul

 try:
 from torchvision.models import resnet18
@@ -849,7 +849,7 @@ terrible spacing

 def test_merge_matmuls(self):
 """
-A collection of test cases for torch.fx.experimental.merge_matmul,
+A collection of test cases for torch.fx._experimental.merge_matmul,
 a graph transformation that merges matrix multiplication operations.
 """
 # Utility function for counting matmuls for test assertions.
@@ -6503,6 +6503,38 @@ a")
 self.checkModule(module().train(), ())
 self.checkModule(module().eval(), ())

+def test_ternary_static_if(self):
+# Test for True branch when condition variable
+# is annotated as Final
+class M1(torch.nn.Module):
+flag: torch.jit.Final[bool]
+
+def __init__(self):
+super().__init__()
+self.flag = True
+
+def forward(self) -> torch.Tensor:
+return torch.ones(3) if self.flag else {}
+
+# Test for True branch when condition variable
+# is annotated as Final
+class M2(torch.nn.Module):
+flag: torch.jit.Final[bool]
+
+def __init__(self):
+super().__init__()
+self.flag = False
+
+def forward(self) -> torch.Tensor:
+return {} if self.flag else torch.ones(3)
+
+model1 = M1()
+model2 = M2()
+script_model_1 = torch.jit.script(model1)
+script_model_2 = torch.jit.script(model2)
+self.assertEqual(model1.forward(), script_model_1.forward())
+self.assertEqual(model2.forward(), script_model_2.forward())
+
 def test_print(self):
 def func(x, y):
 q = (x + y).sigmoid()
@@ -1,6 +1,9 @@
+import glob
 import io
+import os
 import unittest

+import torch
 from torch.testing._internal.common_utils import TestCase, run_tests


@@ -10,11 +13,14 @@ except ImportError:
 create_bundled = None

 license_file = 'third_party/LICENSES_BUNDLED.txt'
+starting_txt = 'The Pytorch repository and source distributions bundle'
+site_packages = os.path.dirname(os.path.dirname(torch.__file__))
+distinfo = glob.glob(os.path.join(site_packages, 'torch-*dist-info'))

 class TestLicense(TestCase):

 @unittest.skipIf(not create_bundled, "can only be run in a source tree")
-def test_license_in_wheel(self):
+def test_license_for_wheel(self):
 current = io.StringIO()
 create_bundled('third_party', current)
 with open(license_file) as fid:
@@ -25,6 +31,18 @@ class TestLicense(TestCase):
 'match the current state of the third_party files. Use '
 '"python third_party/build_bundled.py" to regenerate it')

+@unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test")
+def test_distinfo_license(self):
+"""If run when pytorch is installed via a wheel, the license will be in
+site-package/torch-*dist-info/LICENSE. Make sure it contains the third
+party bundle of licenses"""
+
+if len(distinfo) > 1:
+raise AssertionError('Found too many "torch-*dist-info" directories '
+f'in "{site_packages}, expected only one')
+with open(os.path.join(os.path.join(distinfo[0], 'LICENSE'))) as fid:
+txt = fid.read()
+self.assertTrue(starting_txt in txt)

 if __name__ == '__main__':
 run_tests()
@@ -82,7 +82,8 @@ SKIP_PYTHON_BINDINGS = [
 'set_data',
 '.*_overrideable',  # overrideable functions for backend extension
 'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
-'_fw_primal'
+'_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
+'fake_quantize_per_channel_affine_cachemask',
 ]

 # These function signatures are not exposed to Python. Note that this signature
@@ -1258,6 +1258,15 @@ struct to_ir {
 const TernaryIf& expr,
 const TypePtr& type_hint = nullptr) {
 CondValue cond_value = emitCondExpr(expr.cond());
+// If the cond expr is a static value, then we metacompile the `if`
+// statemement and only emit true or false branch
+if (cond_value.staticIf()) {
+if (*cond_value.staticIf()) {
+return emitExpr(expr.true_expr(), type_hint);
+} else {
+return emitExpr(expr.false_expr(), type_hint);
+}
+}
 auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); };
 auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); };
 return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
@@ -33,39 +33,38 @@ def _orthogonalize(matrix, epsilon=1e-8):


 class PowerSGDState(object):
-"""
-Stores both the gradient compression configs and the internal states for all the gradients during the training.
-Particularly, `matrix_approximation_rank` and `start_powerSGD_iter` are the main configs that need to be tuned by the user.
-Although `use_error_feedback` and `warm_start` can also be tuned by the user,
-they are typically turned on for performance.
+r"""
+Stores both the algorithm's hyperparameters and the internal state for all the gradients during the training.
+Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
+For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.

-Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`]
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-1) To tune `matrix_approximation_rank`, the user can increase it from 1 by factors of 2,
-until a satisfying accuracy can be reached.
-The increase of `matrix_approximation_rank` can substantially increase the computation costs of the compression.
-However, the accuracy may not be futher improved beyond a certain `matrix_approximation_rank` value.
-2) To tune `start_powerSGD_iter`, the user can typically start with 10% of total training steps,
-and increase it until a satisfying accuracy can be reached.
-Deferrring PowerSGD can effectively improve the accuracy,
-even a relatively small `matrix_approximation_rank` is used.
-This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients,
-and compressing gradients too early may make the training quickly take a suboptimal trajectory,
-which can result in an irrecoverable impact on the accuracy.
-The minimum value allowed in DDP is 2, if error feedback or warm-up is enabled.
-This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
-and this can conflict with any tensor memorized before the rebuild process.
-"""
+1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.
+
+1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy.
+
+1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold.
+
+To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.
+
+2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even if a relatively small ``matrix_approximation_rank`` is used. This is because the beginning of the training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.
+
+To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached.
+
+.. warning ::
+    If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
+    This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
+    and this can conflict with any tensor memorized before the rebuild process.
+""" # noqa

 __slots__ = [
 "process_group",
-# The two fields below are the configs that usually need to be tuned by the user.
+# The two fields below are the hyperparameters that should be tuned by the user.
 "matrix_approximation_rank",
 "start_powerSGD_iter",
-# The two fields below are the configs that usually need to be turned on for performance.
+# The two fields below are the binary hyperparameters recommended to be turned on for performance.
 "use_error_feedback",
 "warm_start",
-# The fields below are not configs.
+# The fields below are internal state.
 "rng",
 "error_dict",
 "p_memory_dict",
@@ -93,21 +92,12 @@ class PowerSGDState(object):
 )

 self.process_group = process_group
-# The low rank for matrix approximation controls the size of compressed low-rank tensors,
-# which determines the computation ratio.
-# Typically only a small value 1-4 is used.
-# For some NLP tasks (as shown in Appendix D of the original paper
-# https://arxiv.org/pdf/1905.13727.pdf, the rank value has been increased to 32.
-# A high rank value will increase the computation costs of compression exponentially.
-# A good choice depends on how much extra computation can be hidden by the dominating communication costs.
 self.matrix_approximation_rank = matrix_approximation_rank
-# This defers PowerSGD compression util step 'start_powerSGD_iter',
-# and vanilla allreduce runs before step 'start_powerSGD_iter'.
-# This hybrid scheme of vanilla allreduce + PowerSGD can have two advantages:
+# Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages:
 # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
 # even if the matrix approximation rank is increased to a large value.
 # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
-# (or a more convervative compression such as FP16 compression) with PowerSGD.
+# (or a more conservative compression such as FP16 compression) with PowerSGD.
 # 2) There is an internal optimization of rebuilding buckets process in DDP,
 # in order to save the memory space.
 # This step takes place after the first iteration.
@@ -162,38 +152,44 @@ class PowerSGDState(object):


 def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
-"""
-This DDP communication hook implements the original PowerSGD gradient compression
-algorithm described in https://arxiv.org/abs/1905.13727.
+r"""
+This DDP communication hook implements PowerSGD gradient compression
+algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
 Once gradient tensors are aggregated across all workers, this hook applies
 compression as follows:
-1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
-high-rank tensors and vector-like rank-1 tensors (for biases).
-2) Handles rank-1 tensors by allreducing them without compression:
-2.1) Allocate contiguous memory for those rank-1 tensors,
-and allreduces all the rank-1 tensors as a batch, without compression;
-2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
-3) Handles high-rank tensors by PowerSGD compression:
-3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
+
+1. Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: high-rank tensors and vector-like rank-1 tensors (for biases).
+
+2. Handles rank-1 tensors by allreducing them without compression:
+
+2.1. Allocate contiguous memory for those rank-1 tensors, and allreduces all the rank-1 tensors as a batch, without compression;
+
+2.2. Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
+
+3. Handles high-rank tensors by PowerSGD compression:
+
+3.1. For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
 such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
-3.2) Computes each P in Ps, which is equal to MQ;
-3.3) Allreduces Ps as a batch;
-3.4) Orthogonalizes each P in Ps;
-3.5) Computes each Q in Qs, which is approximately equal to M^TP;
-3.6) Allreduces Qs as a batch;
-3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
-
-Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
-This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
-but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+3.2. Computes each P in Ps, which is equal to MQ;
+
-TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
-one left multiplication and one right multiplication.
-For warm-start, can take one such step at a time, and alternate between them.
+3.3. Allreduces Ps as a batch;
+
+3.4. Orthogonalizes each P in Ps;
+
+3.5. Computes each Q in Qs, which is approximately equal to M^TP;
+
+3.6. Allreduces Qs as a batch;
+
+3.7. Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
+
+Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
+This not only gives the user more control over the tradeoff between speedup and accuracy,
+but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

 Args:
 state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
-To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
+To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
 bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
 Note that since DDP comm hook only supports single process single device mode at this time,
 only exactly one tensor is stored in this bucket.
@@ -202,9 +198,9 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 Future handler of the communication, which updates the gradients in place.

 Example::
-state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
+>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
 >>> ddp_model.register_comm_hook(state, powerSGD_hook)
-"""
+""" # noqa
 process_group = state.process_group
 group_to_use = process_group if process_group is not None else dist.group.WORLD
 world_size = group_to_use.size()
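Editor's note: the per-tensor compression in steps 3.1-3.7 above can be illustrated in isolation. The sketch below is an editorial illustration of the math only for a single 2D gradient tensor, with the batched allreduces omitted; it is not the hook's implementation, and the shapes and rank are arbitrary example values.

```python
import torch

torch.manual_seed(0)
M = torch.randn(64, 32)            # a "high-rank" per-parameter gradient matrix
rank = 1                           # matrix_approximation_rank

Q = torch.randn(M.size(1), rank)   # initialized from a standard normal distribution
Q, _ = torch.qr(Q)                 # 3.1: orthogonalize Q

P = M @ Q                          # 3.2: P = MQ        (allreduced as a batch in 3.3)
P, _ = torch.qr(P)                 # 3.4: orthogonalize P
Q = M.t() @ P                      # 3.5: Q ~= M^T P    (allreduced as a batch in 3.6)
M_approx = P @ Q.t()               # 3.7: M ~= PQ^T

# Relative error of the rank-1 approximation for this example tensor.
print((M - M_approx).norm() / M.norm())
```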
@@ -374,6 +370,10 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 for tensor, p, q in zip(high_rank_tensors, ps, qs):
 torch.matmul(tensor.t(), p, out=q)

+# TODO: The above procedure does two matmul+allreduce steps per iteration --
+# one left multiplication and one right multiplication.
+# For warm-start, can take one such step at a time, and alternate between them.
+
 # Allreduce Qs.
 return [
 dist.all_reduce(
@@ -412,40 +412,48 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:


 def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
-"""
+r"""
 This DDP communication hook implements a simplified PowerSGD gradient compression
-algorithm described in https://arxiv.org/abs/1905.13727.
+algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
+This variant does not compress the gradients layer by layer,
+but instead compresses the flattened input tensor that batches all the gradients.
+Therefore, it is **faster** than :meth:`powerSGD_hook`,
+but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.
+
+.. warning ::
+    Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
+    because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
+    Therefore, the user should always consider :meth:`powerSGD_hook` first,
+    and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.

 Once gradient tensors are aggregated across all workers, this hook applies
-compression to the flattened input tensor that batches per-parameter tensors as follows:
-1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
-2) Creates two low-rank tensors P and Q for decomposing M,
-such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
-2) Computes P, which is equal to MQ;
-3) Allreduces P;
-4) Orthogonalizes P;
-5) Computes Q, which is approximately equal to M^TP;
-6) Allreduces Q;
-7) Computes M, which is approximately equal to PQ^T.
-8) Truncates the input tensor to the original length.
+compression as follows:
+
-This variant is faster than `powerSGD_hook` that runs layer-wise gradient compression,
-but it usually results in a much lower accuracy, unless `matrix_approximation_rank` in the state is 1.
-Increasing `matrix_approximation_rank` may not necessarily increase the accuracy,
-because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
-Therefore, the user shoud always consider `powerSGD_hook` first,
-and only consider this variant when a satisfying accuracy can be achieved when `matrix_approximation_rank` is 1.
+1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
+
-Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
-This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
-but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
+
-TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
-one left multiplication and one right multiplication.
-For warm-start, can take one such step at a time, and alternate between them.
+3. Computes P, which is equal to MQ;
+
+4. Allreduces P;
+
+5. Orthogonalizes P;
+
+6. Computes Q, which is approximately equal to M^TP;
+
+7. Allreduces Q;
+
+8. Computes M, which is approximately equal to PQ^T.
+
+9. Truncates the input tensor to the original length.
+
+Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
+This not only gives the user more control over the tradeoff between speedup and accuracy,
+but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

 Args:
 state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
-To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
+To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
 bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
 Note that since DDP comm hook only supports single process single device mode at this time,
 only exactly one tensor is stored in this bucket.
@@ -454,9 +462,9 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 Future handler of the communication, which updates the gradients in place.

 Example::
-state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
+>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
 >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
-"""
+""" # noqa
 process_group = state.process_group
 group_to_use = process_group if process_group is not None else dist.group.WORLD
 world_size = group_to_use.size()
@@ -563,6 +571,10 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 out=state.q_memory_dict[bucket_index],
 )

+# TODO: The above procedure does two matmul+allreduce steps per iteration --
+# one left multiplication and one right multiplication.
+# For warm-start, can take one such step at a time, and alternate between them.
+
 return [
 dist.all_reduce(
 state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
@@ -4,7 +4,7 @@ from typing import Dict, List, Set, NamedTuple, Tuple
 import torch
 from torch.fx.passes.split_module import split_module
 import operator
-from torch.fx.experimental.partitioner_utils import Partition, \
+from torch.fx._experimental.partitioner_utils import Partition, \
 Device, PartitionerConfig, get_partition_to_latency_mapping,\
 get_latency_of_partitioned_graph, NodeLatency, get_extra_size_of, \
 PartitionMode
@@ -2,7 +2,7 @@ from typing import Dict, List, NamedTuple, Any

 import torch
 from torch.fx.passes.shape_prop import ShapeProp
-from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
+from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes
 from torch.fx.graph import Graph, get_qualified_name
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node, Target, map_arg
@@ -116,7 +116,7 @@ class Interpreter:

 # Main Node running APIs

-def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 """
 Execute a ``placeholder`` node. Note that this is stateful:
 ``Interpreter`` maintains an internal iterator over
@@ -141,7 +141,7 @@ class Interpreter:
 else:
 return next(self.args_iter)

-def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 """
 Execute a ``get_attr`` node. Will retrieve an attribute
 value from the ``Module`` hierarchy of ``self.module``.
@@ -159,7 +159,7 @@ class Interpreter:
 assert isinstance(target, str)
 return self.fetch_attr(target)

-def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 """
 Execute a ``call_function`` node and return the result.

@@ -178,7 +178,7 @@ class Interpreter:
 # Execute the function and return the result
 return target(*args, **kwargs)

-def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 """
 Execute a ``call_method`` node and return the result.

@@ -199,7 +199,7 @@ class Interpreter:
 assert isinstance(target, str)
 return getattr(self_obj, target)(*args_tail, **kwargs)

-def call_module(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 """
 Execute a ``call_module`` node and return the result.

@@ -221,7 +221,7 @@ class Interpreter:

 return submod(*args, **kwargs)

-def output(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 """
 Execute an ``output`` node. This really just retrieves
 the value referenced by the ``output`` node and returns it.
@@ -307,12 +307,12 @@ class Transformer(Interpreter):
 method equivalents). We could subclass ``Transformer`` like so::

 class NegSigmSwapXformer(Transformer):
-def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 if target == torch.sigmoid:
 return torch.neg(*args, **kwargs)
 return super().call_function(n)

-def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
+def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 if target == 'neg':
 call_self, *args_tail = args
 return call_self.sigmoid(*args_tail, **kwargs)
@@ -344,7 +344,7 @@ class Transformer(Interpreter):
 self.tracer = TransformerTracer(self.new_graph)
 self.tracer.root = module

-def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy:
+def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
 """
 Execute a ``placeholder`` node. In ``Transformer``, this is
 overridden to insert a new ``placeholder`` into the output
@@ -360,7 +360,7 @@ class Transformer(Interpreter):
 assert isinstance(target, str)
 return Proxy(self.new_graph.placeholder(target), self.tracer)

-def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy:
+def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
 """
 Execute a ``get_attr`` node. In ``Transformer``, this is
 overridden to insert a new ``get_attr`` node into the output
@@ -376,6 +376,12 @@ class Transformer(Interpreter):
 assert isinstance(target, str)
 return Proxy(self.new_graph.get_attr(target), self.tracer)

+def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+# Override so that the leaf module policy from `self.tracer` is respected.
+assert isinstance(target, str)
+submod = self.fetch_attr(target)
+return self.tracer.call_module(submod, submod.forward, args, kwargs)
+
 def transform(self) -> GraphModule:
 """
 Transform ``self.module`` and return the transformed
@@ -5,7 +5,7 @@ import operator

 from .graph import magic_methods, reflectable_magic_methods, Graph
 from typing import Tuple, Dict, Optional, Iterable, Any, Iterator
-from .node import Target, Node, Argument, base_types
+from .node import Target, Node, Argument, base_types, map_aggregate

 class TracerBase:
 graph: Graph
@@ -61,8 +61,17 @@ class TracerBase:
 elif isinstance(a, dict):
 r = {}
 for k, v in a.items():
-if not isinstance(k, str):
-raise NotImplementedError(f"dictionaries with non-string keys: {a}")
+# Check for invalid dict keys. We do not want a Proxy to appear
+# anywhere within the key. Since keys can be collection types,
+# we iterate through the key with map_aggregate
+k = self.create_arg(k)
+
+def no_node(arg):
+if isinstance(arg, Node):
+raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
+"Node. Got key: {k}")
+map_aggregate(k, no_node)
+
 r[k] = self.create_arg(v)
 return r
 elif isinstance(a, slice):
@@ -1021,13 +1021,14 @@ class DistributedDataParallel(Module):
 parameter syncs while running Distributed DataParallel training.

 Args:
-state (object): state is passed to the hook and can be used to maintain
-and update any state information that users would like to
-maintain as part of the training process. Examples: error
-feedback in gradient compression, peers to communicate with
-next in GossipGrad etc.
-hook (callable): is defined as:
-hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
+state (object): Passed to the hook to maintain any state information during the training process.
+Examples include error feedback in gradient compression,
+peers to communicate with next in GossipGrad, etc.
+
+It is locally stored by each worker
+and shared by all the gradient tensors on the worker.
+hook (callable): Averages gradient tensors across workers and defined as:
+``hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future``:

 This function is called once the bucket is ready. The
 hook can perform whatever processing is needed and return
@@ -1067,7 +1068,7 @@ class DistributedDataParallel(Module):
 DDP communication hook is experimental and subject to change.

 Example::
-Below is an example of a noop hook that returns back the same tensors:
+Below is an example of a noop hook that returns the same tensors.

 >>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future
 >>> fut = torch.futures.Future()
@@ -1091,7 +1092,6 @@ class DistributedDataParallel(Module):
 >>> return fut.then(decode)

 >>> ddp.register_comm_hook(state = None, hook = encode_and_decode)
-
 """
 self._check_comm_hook(hook)
 dist._register_comm_hook(self.reducer, state, hook)
@@ -391,9 +391,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
 torch.exp2: lambda input, out=None: -1,
 torch.expm1: lambda input, out=None: -1,
 torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
-torch.fake_quantize_per_channel_affine_cachemask: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
 torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
-torch.fake_quantize_per_tensor_affine_cachemask: lambda input, scale, zero_point, quant_min, quant_max: -1,
 torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
 torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
 torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,