mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Compare commits
	
		
			11 Commits
		
	
	
		
			desertfire
			...
			v1.8.0-rc2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| c79decdbba | |||
| c307a3f336 | |||
| f071020756 | |||
| 4f436f8570 | |||
| ae11589710 | |||
| 9e5bcc1020 | |||
| fa8578241d | |||
| 1368809532 | |||
| 4073248fc2 | |||
| 75153cb730 | |||
| 5bb69b080c | 
@ -153,11 +153,11 @@ commands:
 | 
			
		||||
                git config --global user.email "circleci.ossci@gmail.com"
 | 
			
		||||
                git config --global user.name "CircleCI"
 | 
			
		||||
                git config remote.origin.url https://github.com/pytorch/pytorch.git
 | 
			
		||||
                git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
 | 
			
		||||
                git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet
 | 
			
		||||
                git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8
 | 
			
		||||
                git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet
 | 
			
		||||
                # PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
 | 
			
		||||
                if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
 | 
			
		||||
                  CIRCLE_PR_BASE_BRANCH=master
 | 
			
		||||
                  CIRCLE_PR_BASE_BRANCH=release/1.8
 | 
			
		||||
                fi
 | 
			
		||||
                export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
 | 
			
		||||
                echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
 | 
			
		||||
 | 
			
		||||
@ -111,11 +111,11 @@ commands:
 | 
			
		||||
                git config --global user.email "circleci.ossci@gmail.com"
 | 
			
		||||
                git config --global user.name "CircleCI"
 | 
			
		||||
                git config remote.origin.url https://github.com/pytorch/pytorch.git
 | 
			
		||||
                git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master
 | 
			
		||||
                git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet
 | 
			
		||||
                git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8
 | 
			
		||||
                git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet
 | 
			
		||||
                # PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base
 | 
			
		||||
                if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then
 | 
			
		||||
                  CIRCLE_PR_BASE_BRANCH=master
 | 
			
		||||
                  CIRCLE_PR_BASE_BRANCH=release/1.8
 | 
			
		||||
                fi
 | 
			
		||||
                export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH`
 | 
			
		||||
                echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET}
 | 
			
		||||
 | 
			
		||||
@ -182,7 +182,7 @@ fi
 | 
			
		||||
 | 
			
		||||
# Patch required to build xla
 | 
			
		||||
if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then
 | 
			
		||||
  git clone --recursive https://github.com/pytorch/xla.git
 | 
			
		||||
  git clone --recursive -b r1.8 https://github.com/pytorch/xla.git
 | 
			
		||||
  ./xla/scripts/apply_patches.sh
 | 
			
		||||
fi
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -54,7 +54,7 @@ function file_diff_from_base() {
 | 
			
		||||
  set +e
 | 
			
		||||
  git fetch origin master --quiet
 | 
			
		||||
  set -e
 | 
			
		||||
  git diff --name-only "$(git merge-base origin/master HEAD)" > "$1"
 | 
			
		||||
  git diff --name-only "$(git merge-base origin/release/1.8 HEAD)" > "$1"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function get_bazel() {
 | 
			
		||||
 | 
			
		||||
@ -300,7 +300,7 @@ test_backward_compatibility() {
 | 
			
		||||
  pushd test/backward_compatibility
 | 
			
		||||
  python -m venv venv
 | 
			
		||||
  . venv/bin/activate
 | 
			
		||||
  pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 | 
			
		||||
  pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_test.html
 | 
			
		||||
  pip show torch
 | 
			
		||||
  python dump_all_function_schemas.py --filename nightly_schemas.txt
 | 
			
		||||
  deactivate
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,27 @@ namespace {
 | 
			
		||||
 | 
			
		||||
using namespace vec256;
 | 
			
		||||
 | 
			
		||||
// Note: Explicit implementation of copysign for Half and BFloat16
 | 
			
		||||
// is needed to workaround g++-7/8 crash on aarch64, but also makes
 | 
			
		||||
// copysign faster for the half-precision types
 | 
			
		||||
template<typename T>
 | 
			
		||||
T copysign(T a, T b) {
 | 
			
		||||
  return std::copysign(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Implement copysign for half precision floats using bit ops
 | 
			
		||||
// Sign is the most significant bit for both half and bfloat16 types
 | 
			
		||||
template<>
 | 
			
		||||
c10::Half copysign(c10::Half a, c10::Half b) {
 | 
			
		||||
  return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
 | 
			
		||||
   return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// Note: Undefined behavior when performing addition is intentionally
 | 
			
		||||
// ignored.
 | 
			
		||||
void add_kernel(TensorIteratorBase& iter, Scalar alpha_scalar) {
 | 
			
		||||
@ -180,7 +201,7 @@ void div_floor_kernel(TensorIterator& iter) {
 | 
			
		||||
                floordiv += scalar_t(1.0);
 | 
			
		||||
              }
 | 
			
		||||
            } else {
 | 
			
		||||
              floordiv = std::copysign(scalar_t(0), a / b);
 | 
			
		||||
              floordiv = copysign(scalar_t(0), a / b);
 | 
			
		||||
            }
 | 
			
		||||
            return floordiv;
 | 
			
		||||
          });
 | 
			
		||||
@ -889,23 +910,6 @@ void heaviside_kernel(TensorIterator& iter) {
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<typename T>
 | 
			
		||||
T copysign(T a, T b) {
 | 
			
		||||
  return std::copysign(a, b);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Implement copysign for half precision floats using bit ops
 | 
			
		||||
// Sign is the most significant bit for both half and bfloat16 types
 | 
			
		||||
template<>
 | 
			
		||||
c10::Half copysign(c10::Half a, c10::Half b) {
 | 
			
		||||
  return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
 | 
			
		||||
   return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void copysign_kernel(TensorIterator& iter) {
 | 
			
		||||
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
 | 
			
		||||
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
 | 
			
		||||
 | 
			
		||||
@ -133,7 +133,9 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) {
 | 
			
		||||
    ASSERT_EQ(buffer1[i].z, buffer2[i].z);
 | 
			
		||||
    ASSERT_EQ(buffer1[i].w, buffer2[i].w);
 | 
			
		||||
  }
 | 
			
		||||
// Skipping this part until https://github.com/pytorch/pytorch/issues/51863 is resolved
 | 
			
		||||
 | 
			
		||||
#if 0
 | 
			
		||||
  // unaligned
 | 
			
		||||
  for (int i = 0; i < 16; i++) {
 | 
			
		||||
    for (int j = 0; j < 16; j++) {
 | 
			
		||||
@ -151,4 +153,5 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) {
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,7 @@ int32_t driver_version() {
 | 
			
		||||
  return driver_version;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int device_count_impl() {
 | 
			
		||||
int device_count_impl(bool fail_if_no_driver) {
 | 
			
		||||
  int count;
 | 
			
		||||
  auto err = cudaGetDeviceCount(&count);
 | 
			
		||||
  if (err == cudaSuccess) {
 | 
			
		||||
@ -34,6 +34,11 @@ int device_count_impl() {
 | 
			
		||||
    case cudaErrorInsufficientDriver: {
 | 
			
		||||
      auto version = driver_version();
 | 
			
		||||
      if (version <= 0) {
 | 
			
		||||
        if (!fail_if_no_driver) {
 | 
			
		||||
          // No CUDA driver means no devices
 | 
			
		||||
          count = 0;
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
        TORCH_CHECK(
 | 
			
		||||
            false,
 | 
			
		||||
            "Found no NVIDIA driver on your system. Please check that you "
 | 
			
		||||
@ -95,9 +100,9 @@ DeviceIndex device_count() noexcept {
 | 
			
		||||
  // initialize number of devices only once
 | 
			
		||||
  static int count = []() {
 | 
			
		||||
    try {
 | 
			
		||||
      auto result = device_count_impl();
 | 
			
		||||
      auto result = device_count_impl(/*fail_if_no_driver=*/false);
 | 
			
		||||
      TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
 | 
			
		||||
      return device_count_impl();
 | 
			
		||||
      return result;
 | 
			
		||||
    } catch (const c10::Error& ex) {
 | 
			
		||||
      // We don't want to fail, but still log the warning
 | 
			
		||||
      // msg() returns the message without the stack trace
 | 
			
		||||
@ -110,7 +115,7 @@ DeviceIndex device_count() noexcept {
 | 
			
		||||
 | 
			
		||||
DeviceIndex device_count_ensure_non_zero() {
 | 
			
		||||
  // Call the implementation every time to throw the exception
 | 
			
		||||
  int count = device_count_impl();
 | 
			
		||||
  int count = device_count_impl(/*fail_if_no_driver=*/true);
 | 
			
		||||
  // Zero gpus doesn't produce a warning in `device_count` but we fail here
 | 
			
		||||
  TORCH_CHECK(count, "No CUDA GPUs are available");
 | 
			
		||||
  return static_cast<DeviceIndex>(count);
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										74
									
								
								docs/source/ddp_comm_hooks.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								docs/source/ddp_comm_hooks.rst
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,74 @@
 | 
			
		||||
DDP Communication Hooks
 | 
			
		||||
=======================
 | 
			
		||||
 | 
			
		||||
DDP communication hook is a generic interface to control how to communicate
 | 
			
		||||
gradients across workers by overriding the vanilla allreduce in
 | 
			
		||||
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.>`_.
 | 
			
		||||
A few built-in communication hooks are provided,
 | 
			
		||||
and users can easily apply any of these hooks to optimize communication.
 | 
			
		||||
Besides, the hook interface can also support user-defined communication
 | 
			
		||||
strategies for more advanced use cases.
 | 
			
		||||
 | 
			
		||||
.. warning ::
 | 
			
		||||
    DDP communication hook is experimental and subject to change.
 | 
			
		||||
 | 
			
		||||
.. warning ::
 | 
			
		||||
    DDP communication hooks can only support single process single device mode
 | 
			
		||||
    on NCCL backend.
 | 
			
		||||
 | 
			
		||||
How to Use a Communication Hook?
 | 
			
		||||
--------------------------------
 | 
			
		||||
 | 
			
		||||
To use a communication hook, the user just needs to let the DDP model register
 | 
			
		||||
the hook before the training loop as below.
 | 
			
		||||
 | 
			
		||||
:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`.
 | 
			
		||||
    :noindex:
 | 
			
		||||
 | 
			
		||||
Default Communication Hooks
 | 
			
		||||
---------------------------
 | 
			
		||||
 | 
			
		||||
Default communication hooks are simple **stateless** hooks, so the input state
 | 
			
		||||
in ``register_comm_hook`` is either a process group or ``None``.
 | 
			
		||||
 | 
			
		||||
.. automodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
 | 
			
		||||
    :members:
 | 
			
		||||
 | 
			
		||||
PowerSGD Communication Hook
 | 
			
		||||
---------------------------
 | 
			
		||||
 | 
			
		||||
PowerSGD (`Vogels et al., NeurIPS 2019 <https://arxiv.org/abs/1905.13727>`_)
 | 
			
		||||
is a gradient compression algorithm, which can provide very high compression
 | 
			
		||||
rates and accelerate bandwidth-bound distributed training.
 | 
			
		||||
This algorithm needs to maintain both some hyperparameters and the internal
 | 
			
		||||
state. Therefore, PowerSGD communication hook is a **stateful** hook,
 | 
			
		||||
and the user needs to provide a state object defined as below.
 | 
			
		||||
 | 
			
		||||
PowerSGD State
 | 
			
		||||
^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
 | 
			
		||||
.. autoclass:: PowerSGDState
 | 
			
		||||
 | 
			
		||||
PowerSGD Hooks
 | 
			
		||||
^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
.. warning ::
 | 
			
		||||
    PowerSGD typically requires extra memory of the same size as the model's
 | 
			
		||||
    gradients to enable error feedback, which can compensate for biased
 | 
			
		||||
    compressed communication and improve accuracy.
 | 
			
		||||
 | 
			
		||||
.. warning ::
 | 
			
		||||
    The current implementation may cause gradient overflow for FP16 input.
 | 
			
		||||
 | 
			
		||||
.. autofunction:: powerSGD_hook
 | 
			
		||||
.. autofunction:: batched_powerSGD_hook
 | 
			
		||||
 | 
			
		||||
Acknowledgements
 | 
			
		||||
----------------
 | 
			
		||||
 | 
			
		||||
Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on
 | 
			
		||||
PowerSGD communication hook, as well as the
 | 
			
		||||
`comparison experiments <https://observablehq.com/@tvogels/powersgd-benchmark>`_,
 | 
			
		||||
which show that the performance of PowerSGD communication hook is on par with
 | 
			
		||||
the implementation in the original `paper <https://arxiv.org/abs/1905.13727>`_.
 | 
			
		||||
@ -71,6 +71,7 @@ Features described in this documentation are classified by release status:
 | 
			
		||||
   onnx
 | 
			
		||||
   optim
 | 
			
		||||
   complex_numbers
 | 
			
		||||
   ddp_comm_hooks
 | 
			
		||||
   pipeline
 | 
			
		||||
   quantization
 | 
			
		||||
   rpc
 | 
			
		||||
 | 
			
		||||
@ -484,6 +484,7 @@ Sparse tensor functions
 | 
			
		||||
+++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
.. autofunction:: torch.sparse_coo_tensor
 | 
			
		||||
   :noindex:
 | 
			
		||||
.. autofunction:: torch.sparse.sum
 | 
			
		||||
.. autofunction:: torch.sparse.addmm
 | 
			
		||||
.. autofunction:: torch.sparse.mm
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										45
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								setup.py
									
									
									
									
									
								
							@ -552,6 +552,50 @@ class build_ext(setuptools.command.build_ext.build_ext):
 | 
			
		||||
            with open('compile_commands.json', 'w') as f:
 | 
			
		||||
                f.write(new_contents)
 | 
			
		||||
 | 
			
		||||
class concat_license_files():
 | 
			
		||||
    """Merge LICENSE and LICENSES_BUNDLED.txt as a context manager
 | 
			
		||||
 | 
			
		||||
    LICENSE is the main PyTorch license, LICENSES_BUNDLED.txt is auto-generated
 | 
			
		||||
    from all the licenses found in ./third_party/. We concatenate them so there
 | 
			
		||||
    is a single license file in the sdist and wheels with all of the necessary
 | 
			
		||||
    licensing info.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.f1 = 'LICENSE'
 | 
			
		||||
        self.f2 = 'third_party/LICENSES_BUNDLED.txt'
 | 
			
		||||
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
        """Concatenate files"""
 | 
			
		||||
        with open(self.f1, 'r') as f1:
 | 
			
		||||
            self.bsd_text = f1.read()
 | 
			
		||||
 | 
			
		||||
        with open(self.f1, 'a') as f1:
 | 
			
		||||
            with open(self.f2, 'r') as f2:
 | 
			
		||||
                self.bundled_text = f2.read()
 | 
			
		||||
                f1.write('\n\n')
 | 
			
		||||
                f1.write(self.bundled_text)
 | 
			
		||||
 | 
			
		||||
    def __exit__(self, exception_type, exception_value, traceback):
 | 
			
		||||
        """Restore content of f1"""
 | 
			
		||||
        with open(self.f1, 'w') as f:
 | 
			
		||||
            f.write(self.bsd_text)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from wheel.bdist_wheel import bdist_wheel
 | 
			
		||||
except ImportError:
 | 
			
		||||
    # This is useful when wheel is not installed and bdist_wheel is not
 | 
			
		||||
    # specified on the command line. If it _is_ specified, parsing the command
 | 
			
		||||
    # line will fail before wheel_concatenate is needed
 | 
			
		||||
    wheel_concatenate = None
 | 
			
		||||
else:
 | 
			
		||||
    # Need to create the proper LICENSE.txt for the wheel
 | 
			
		||||
    class wheel_concatenate(bdist_wheel):
 | 
			
		||||
        """ check submodules on sdist to prevent incomplete tarballs """
 | 
			
		||||
        def run(self):
 | 
			
		||||
            with concat_license_files():
 | 
			
		||||
                super().run()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class install(setuptools.command.install.install):
 | 
			
		||||
    def run(self):
 | 
			
		||||
@ -724,6 +768,7 @@ def configure_extension_build():
 | 
			
		||||
        'build_ext': build_ext,
 | 
			
		||||
        'clean': clean,
 | 
			
		||||
        'install': install,
 | 
			
		||||
        'bdist_wheel': wheel_concatenate,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    entry_points = {
 | 
			
		||||
 | 
			
		||||
@ -872,7 +872,7 @@ class TestFakeQuantize(TestCase):
 | 
			
		||||
            scale, zero_point = float(scale), int(zero_point)
 | 
			
		||||
            quant_min, quant_max = obs._calculate_qmin_qmax()
 | 
			
		||||
 | 
			
		||||
            Y_test, _mask = torch.fake_quantize_per_tensor_affine_cachemask(
 | 
			
		||||
            Y_test = torch.fake_quantize_per_tensor_affine(
 | 
			
		||||
                X, scale, zero_point, quant_min, quant_max)
 | 
			
		||||
            Y_ref = _fake_quantize_per_tensor_affine_reference(
 | 
			
		||||
                X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
 | 
			
		||||
@ -899,7 +899,7 @@ class TestFakeQuantize(TestCase):
 | 
			
		||||
            quant_min, quant_max = obs._calculate_qmin_qmax()
 | 
			
		||||
 | 
			
		||||
            # forward pass
 | 
			
		||||
            Y_test, mask = torch.fake_quantize_per_tensor_affine_cachemask(
 | 
			
		||||
            Y_test = torch.fake_quantize_per_tensor_affine(
 | 
			
		||||
                X, scale, zero_point, quant_min, quant_max)
 | 
			
		||||
            Y_ref = _fake_quantize_per_tensor_affine_reference(
 | 
			
		||||
                X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
 | 
			
		||||
@ -1246,7 +1246,7 @@ class TestFakeQuantize(TestCase):
 | 
			
		||||
 | 
			
		||||
            Y = _fake_quantize_per_channel_affine_reference(
 | 
			
		||||
                X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
 | 
			
		||||
            Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
 | 
			
		||||
            Y_prime = torch.fake_quantize_per_channel_affine(
 | 
			
		||||
                X, scale, zero_point, axis, quant_min, quant_max)
 | 
			
		||||
            np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
 | 
			
		||||
 | 
			
		||||
@ -1339,7 +1339,7 @@ class TestFakeQuantize(TestCase):
 | 
			
		||||
            zero_point = zero_point.to(torch.int64)
 | 
			
		||||
            quant_min, quant_max = obs._calculate_qmin_qmax()
 | 
			
		||||
            X.requires_grad_()
 | 
			
		||||
            Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
 | 
			
		||||
            Y_prime = torch.fake_quantize_per_channel_affine(
 | 
			
		||||
                X, scale, zero_point, axis, quant_min, quant_max)
 | 
			
		||||
            dout = torch.rand(X.shape, dtype=torch.float).to(device)
 | 
			
		||||
            dX = _fake_quantize_per_channel_affine_grad_reference(
 | 
			
		||||
 | 
			
		||||
@ -108,6 +108,7 @@ TESTS = [
 | 
			
		||||
    'test_fx_experimental',
 | 
			
		||||
    'test_functional_autograd_benchmark',
 | 
			
		||||
    'test_package',
 | 
			
		||||
    'test_license',
 | 
			
		||||
    'distributed/pipeline/sync/skip/test_api',
 | 
			
		||||
    'distributed/pipeline/sync/skip/test_gpipe',
 | 
			
		||||
    'distributed/pipeline/sync/skip/test_inspect_skip_layout',
 | 
			
		||||
 | 
			
		||||
@ -14,7 +14,7 @@ from math import sqrt
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from torch.multiprocessing import Process
 | 
			
		||||
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap
 | 
			
		||||
from torch.fx.node import Target
 | 
			
		||||
from torch.fx.node import Target, Argument
 | 
			
		||||
from torch.fx.passes import shape_prop
 | 
			
		||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
@ -187,7 +187,7 @@ class TestFX(JitTestCase):
 | 
			
		||||
        # Custom delegate to disallow in-place tensor operations
 | 
			
		||||
        class NoMutableCallTracer(Tracer):
 | 
			
		||||
            def create_node(self, kind : str, target : Union[str, Callable],
 | 
			
		||||
                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
 | 
			
		||||
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
 | 
			
		||||
                            type_expr : Optional[Any] = None) -> Node:
 | 
			
		||||
                name = target if isinstance(target, str) else torch.typename(target)
 | 
			
		||||
                if name[-1] == '_':
 | 
			
		||||
@ -539,7 +539,7 @@ class TestFX(JitTestCase):
 | 
			
		||||
    def test_node_tagging(self):
 | 
			
		||||
        class TaggingTracer(Tracer):
 | 
			
		||||
            def create_node(self, kind : str, target : Union[str, Callable],
 | 
			
		||||
                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
 | 
			
		||||
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
 | 
			
		||||
                            type_expr : Optional[Any] = None) -> Node:
 | 
			
		||||
                n = super().create_node(kind, target, args, kwargs, name)
 | 
			
		||||
                n.tag = 'foo'
 | 
			
		||||
@ -1057,6 +1057,13 @@ class TestFX(JitTestCase):
 | 
			
		||||
        result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
 | 
			
		||||
        self.assertEqual(result, torch.ones(3, 4) * 2.0)
 | 
			
		||||
 | 
			
		||||
    @skipIfNoTorchVision
 | 
			
		||||
    def test_interpreter_noop_resnet18(self):
 | 
			
		||||
        rn18 = resnet18()
 | 
			
		||||
        transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
 | 
			
		||||
        inp = torch.randn(5, 3, 224, 224)
 | 
			
		||||
        self.assertEqual(transformed(inp), rn18(inp))
 | 
			
		||||
 | 
			
		||||
    def test_transformer_noop(self):
 | 
			
		||||
        class MyModule(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
@ -1377,6 +1384,45 @@ class TestFX(JitTestCase):
 | 
			
		||||
        x, y = torch.randn(3, 4), torch.randn(3, 4)
 | 
			
		||||
        self.checkGraphModule(foo, (x, y))
 | 
			
		||||
 | 
			
		||||
    def test_trace_dict_int_keys(self):
 | 
			
		||||
        class ModWithDictArg(torch.nn.Module):
 | 
			
		||||
            def forward(self, d : Dict[int, torch.Tensor]):
 | 
			
		||||
                return d[42]
 | 
			
		||||
 | 
			
		||||
        class CallsModWithDict(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
                self.m = ModWithDictArg()
 | 
			
		||||
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                return self.m({42: x})
 | 
			
		||||
 | 
			
		||||
        class MyTracer(torch.fx.Tracer):
 | 
			
		||||
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
 | 
			
		||||
                return isinstance(m, ModWithDictArg)
 | 
			
		||||
 | 
			
		||||
        traced_graph = MyTracer().trace(CallsModWithDict())
 | 
			
		||||
 | 
			
		||||
    def test_trace_dict_proxy_keys(self):
 | 
			
		||||
        class ModWithDictArg(torch.nn.Module):
 | 
			
		||||
            def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
 | 
			
		||||
                return d[42]
 | 
			
		||||
 | 
			
		||||
        class CallsModWithDict(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
                self.m = ModWithDictArg()
 | 
			
		||||
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                return self.m({x: x})
 | 
			
		||||
 | 
			
		||||
        class MyTracer(torch.fx.Tracer):
 | 
			
		||||
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
 | 
			
		||||
                return isinstance(m, ModWithDictArg)
 | 
			
		||||
 | 
			
		||||
        with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
 | 
			
		||||
            traced_graph = MyTracer().trace(CallsModWithDict())
 | 
			
		||||
 | 
			
		||||
    def test_direct_param_use(self):
 | 
			
		||||
        class TransposeTest(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
 | 
			
		||||
@ -5,14 +5,14 @@ from typing import Callable, Dict, Union, List
 | 
			
		||||
from torch.fx.symbolic_trace import symbolic_trace
 | 
			
		||||
from torch.fx.graph_module import GraphModule
 | 
			
		||||
from torch.fx.node import Node
 | 
			
		||||
from torch.fx.experimental import graph_manipulation
 | 
			
		||||
from torch.fx.experimental.accelerator_partitioner import Partitioner
 | 
			
		||||
from torch.fx.experimental.rewriter import RewritingTracer
 | 
			
		||||
from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
 | 
			
		||||
from torch.fx._experimental import graph_manipulation
 | 
			
		||||
from torch.fx._experimental.accelerator_partitioner import Partitioner
 | 
			
		||||
from torch.fx._experimental.rewriter import RewritingTracer
 | 
			
		||||
from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes
 | 
			
		||||
from torch.testing._internal.common_utils import run_tests
 | 
			
		||||
from torch.testing._internal.jit_utils import JitTestCase
 | 
			
		||||
from torch.fx.passes.split_module import split_module
 | 
			
		||||
from torch.fx.experimental.partitioner_utils import (
 | 
			
		||||
from torch.fx._experimental.partitioner_utils import (
 | 
			
		||||
    NodeLatency,
 | 
			
		||||
    get_partition_to_latency_mapping,
 | 
			
		||||
    get_latency_of_partitioned_graph,
 | 
			
		||||
@ -20,8 +20,8 @@ from torch.fx.experimental.partitioner_utils import (
 | 
			
		||||
    PartitionerConfig,
 | 
			
		||||
    PartitionMode
 | 
			
		||||
)
 | 
			
		||||
from torch.fx.experimental.fuser import fuse
 | 
			
		||||
from torch.fx.experimental import merge_matmul
 | 
			
		||||
from torch.fx._experimental.fuser import fuse
 | 
			
		||||
from torch.fx._experimental import merge_matmul
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from torchvision.models import resnet18
 | 
			
		||||
@ -849,7 +849,7 @@ terrible spacing
 | 
			
		||||
 | 
			
		||||
    def test_merge_matmuls(self):
 | 
			
		||||
        """
 | 
			
		||||
        A collection of test cases for torch.fx.experimental.merge_matmul,
 | 
			
		||||
        A collection of test cases for torch.fx._experimental.merge_matmul,
 | 
			
		||||
        a graph transformation that merges matrix multiplication operations.
 | 
			
		||||
        """
 | 
			
		||||
        # Utility function for counting matmuls for test assertions.
 | 
			
		||||
 | 
			
		||||
@ -6503,6 +6503,38 @@ a")
 | 
			
		||||
            self.checkModule(module().train(), ())
 | 
			
		||||
            self.checkModule(module().eval(), ())
 | 
			
		||||
 | 
			
		||||
    def test_ternary_static_if(self):
 | 
			
		||||
        # Test for True branch when condition variable
 | 
			
		||||
        # is annotated as Final
 | 
			
		||||
        class M1(torch.nn.Module):
 | 
			
		||||
            flag: torch.jit.Final[bool]
 | 
			
		||||
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
                self.flag = True
 | 
			
		||||
 | 
			
		||||
            def forward(self) -> torch.Tensor:
 | 
			
		||||
                return torch.ones(3) if self.flag else {}
 | 
			
		||||
 | 
			
		||||
        # Test for True branch when condition variable
 | 
			
		||||
        # is annotated as Final
 | 
			
		||||
        class M2(torch.nn.Module):
 | 
			
		||||
            flag: torch.jit.Final[bool]
 | 
			
		||||
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
                self.flag = False
 | 
			
		||||
 | 
			
		||||
            def forward(self) -> torch.Tensor:
 | 
			
		||||
                return {} if self.flag else torch.ones(3)
 | 
			
		||||
 | 
			
		||||
        model1 = M1()
 | 
			
		||||
        model2 = M2()
 | 
			
		||||
        script_model_1 = torch.jit.script(model1)
 | 
			
		||||
        script_model_2 = torch.jit.script(model2)
 | 
			
		||||
        self.assertEqual(model1.forward(), script_model_1.forward())
 | 
			
		||||
        self.assertEqual(model2.forward(), script_model_2.forward())
 | 
			
		||||
 | 
			
		||||
    def test_print(self):
 | 
			
		||||
        def func(x, y):
 | 
			
		||||
            q = (x + y).sigmoid()
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,9 @@
 | 
			
		||||
import glob
 | 
			
		||||
import io
 | 
			
		||||
import os
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.testing._internal.common_utils import TestCase, run_tests
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -10,11 +13,14 @@ except ImportError:
 | 
			
		||||
    create_bundled = None
 | 
			
		||||
 | 
			
		||||
license_file = 'third_party/LICENSES_BUNDLED.txt'
 | 
			
		||||
starting_txt = 'The Pytorch repository and source distributions bundle'
 | 
			
		||||
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
 | 
			
		||||
distinfo = glob.glob(os.path.join(site_packages, 'torch-*dist-info'))
 | 
			
		||||
 | 
			
		||||
class TestLicense(TestCase):
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
 | 
			
		||||
    def test_license_in_wheel(self):
 | 
			
		||||
    def test_license_for_wheel(self):
 | 
			
		||||
        current = io.StringIO()
 | 
			
		||||
        create_bundled('third_party', current)
 | 
			
		||||
        with open(license_file) as fid:
 | 
			
		||||
@ -25,6 +31,18 @@ class TestLicense(TestCase):
 | 
			
		||||
                'match the current state of the third_party files. Use '
 | 
			
		||||
                '"python third_party/build_bundled.py" to regenerate it')
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test")
 | 
			
		||||
    def test_distinfo_license(self):
 | 
			
		||||
        """If run when pytorch is installed via a wheel, the license will be in
 | 
			
		||||
        site-package/torch-*dist-info/LICENSE. Make sure it contains the third
 | 
			
		||||
        party bundle of licenses"""
 | 
			
		||||
 | 
			
		||||
        if len(distinfo) > 1:
 | 
			
		||||
            raise AssertionError('Found too many "torch-*dist-info" directories '
 | 
			
		||||
                                 f'in "{site_packages}, expected only one')
 | 
			
		||||
        with open(os.path.join(os.path.join(distinfo[0], 'LICENSE'))) as fid:
 | 
			
		||||
            txt = fid.read()
 | 
			
		||||
            self.assertTrue(starting_txt in txt)
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    run_tests()
 | 
			
		||||
 | 
			
		||||
@ -82,7 +82,8 @@ SKIP_PYTHON_BINDINGS = [
 | 
			
		||||
    'set_data',
 | 
			
		||||
    '.*_overrideable',  # overrideable functions for backend extension
 | 
			
		||||
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
 | 
			
		||||
    '_fw_primal'
 | 
			
		||||
    '_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
 | 
			
		||||
    'fake_quantize_per_channel_affine_cachemask',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# These function signatures are not exposed to Python. Note that this signature
 | 
			
		||||
 | 
			
		||||
@ -1258,6 +1258,15 @@ struct to_ir {
 | 
			
		||||
      const TernaryIf& expr,
 | 
			
		||||
      const TypePtr& type_hint = nullptr) {
 | 
			
		||||
    CondValue cond_value = emitCondExpr(expr.cond());
 | 
			
		||||
    // If the cond expr is a static value, then we metacompile the `if`
 | 
			
		||||
    // statemement and only emit true or false branch
 | 
			
		||||
    if (cond_value.staticIf()) {
 | 
			
		||||
        if (*cond_value.staticIf()) {
 | 
			
		||||
            return emitExpr(expr.true_expr(), type_hint);
 | 
			
		||||
        } else {
 | 
			
		||||
            return emitExpr(expr.false_expr(), type_hint);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); };
 | 
			
		||||
    auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); };
 | 
			
		||||
    return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
 | 
			
		||||
 | 
			
		||||
@ -33,39 +33,38 @@ def _orthogonalize(matrix, epsilon=1e-8):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PowerSGDState(object):
 | 
			
		||||
    """
 | 
			
		||||
    Stores both the gradient compression configs and the internal states for all the gradients during the training.
 | 
			
		||||
    Particularly, `matrix_approximation_rank` and `start_powerSGD_iter` are the main configs that need to be tuned by the user.
 | 
			
		||||
    Although `use_error_feedback` and `warm_start` can also be tuned by the user,
 | 
			
		||||
    they are typically turned on for performance.
 | 
			
		||||
    r"""
 | 
			
		||||
    Stores both the algorithm's hyperparameters and the internal state for all the gradients during the training.
 | 
			
		||||
    Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
 | 
			
		||||
    For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.
 | 
			
		||||
 | 
			
		||||
    Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`]
 | 
			
		||||
    ~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
    1) To tune `matrix_approximation_rank`, the user can increase it from 1 by factors of 2,
 | 
			
		||||
    until a satisfying accuracy can be reached.
 | 
			
		||||
    The increase of `matrix_approximation_rank` can substantially increase the computation costs of the compression.
 | 
			
		||||
    However, the accuracy may not be futher improved beyond a certain `matrix_approximation_rank` value.
 | 
			
		||||
    2) To tune `start_powerSGD_iter`, the user can typically start with 10% of total training steps,
 | 
			
		||||
    and increase it until a satisfying accuracy can be reached.
 | 
			
		||||
    Deferrring PowerSGD can effectively improve the accuracy,
 | 
			
		||||
    even a relatively small `matrix_approximation_rank` is used.
 | 
			
		||||
    This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients,
 | 
			
		||||
    and compressing gradients too early may make the training quickly take a suboptimal trajectory,
 | 
			
		||||
    which can result in an irrecoverable impact on the accuracy.
 | 
			
		||||
    The minimum value allowed in DDP is 2, if error feedback or warm-up is enabled.
 | 
			
		||||
    This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
 | 
			
		||||
    and this can conflict with any tensor memorized before the rebuild process.
 | 
			
		||||
    """
 | 
			
		||||
    1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.
 | 
			
		||||
 | 
			
		||||
        1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy.
 | 
			
		||||
 | 
			
		||||
        1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be futher improved beyond a certain ``matrix_approximation_rank`` threshold.
 | 
			
		||||
 | 
			
		||||
    To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an expoential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.
 | 
			
		||||
 | 
			
		||||
    2. ``start_powerSGD_iter`` defers PowerSGD compression util step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even a relatively small ``matrix_approximation_rank`` is used. This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.
 | 
			
		||||
 | 
			
		||||
    To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached.
 | 
			
		||||
 | 
			
		||||
    .. warning ::
 | 
			
		||||
        If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
 | 
			
		||||
        This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
 | 
			
		||||
        and this can conflict with any tensor memorized before the rebuild process.
 | 
			
		||||
    """  # noqa
 | 
			
		||||
 | 
			
		||||
    __slots__ = [
 | 
			
		||||
        "process_group",
 | 
			
		||||
        # The two fields below are the configs that usually need to be tuned by the user.
 | 
			
		||||
        # The two fields below are the hyperparameters that should be tuned by the user.
 | 
			
		||||
        "matrix_approximation_rank",
 | 
			
		||||
        "start_powerSGD_iter",
 | 
			
		||||
        # The two fields below are the configs that usually need to be turned on for performance.
 | 
			
		||||
        # The two fields below are the binary hyperparameters recommended to be turned on for performance.
 | 
			
		||||
        "use_error_feedback",
 | 
			
		||||
        "warm_start",
 | 
			
		||||
        # The fields below are not configs.
 | 
			
		||||
        # The fields below are internal state.
 | 
			
		||||
        "rng",
 | 
			
		||||
        "error_dict",
 | 
			
		||||
        "p_memory_dict",
 | 
			
		||||
@ -93,21 +92,12 @@ class PowerSGDState(object):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.process_group = process_group
 | 
			
		||||
        # The low rank for matrix approximation controls the size of compressed low-rank tensors,
 | 
			
		||||
        # which determines the computation ratio.
 | 
			
		||||
        # Typically only a small value 1-4 is used.
 | 
			
		||||
        # For some NLP tasks (as shown in Appendix D of the original paper
 | 
			
		||||
        # https://arxiv.org/pdf/1905.13727.pdf, the rank value has been increased to 32.
 | 
			
		||||
        # A high rank value will increase the computation costs of compression exponentially.
 | 
			
		||||
        # A good choice depends on how much extra computation can be hidden by the dominating communication costs.
 | 
			
		||||
        self.matrix_approximation_rank = matrix_approximation_rank
 | 
			
		||||
        # This defers PowerSGD compression util step 'start_powerSGD_iter',
 | 
			
		||||
        # and vanilla allreduce runs before step 'start_powerSGD_iter'.
 | 
			
		||||
        # This hybrid scheme of vanilla allreduce + PowerSGD can have two advantages:
 | 
			
		||||
        # Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages:
 | 
			
		||||
        # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
 | 
			
		||||
        # even if the matrix approximation rank is increased to a large value.
 | 
			
		||||
        # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
 | 
			
		||||
        # (or a more convervative compression such as FP16 compression) with PowerSGD.
 | 
			
		||||
        # (or a more conservative compression such as FP16 compression) with PowerSGD.
 | 
			
		||||
        # 2) There is an internal optimization of rebuilding buckets process in DDP,
 | 
			
		||||
        # in order to save the memory space.
 | 
			
		||||
        # This step takes place after the first iteration.
 | 
			
		||||
@ -162,38 +152,44 @@ class PowerSGDState(object):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 | 
			
		||||
    """
 | 
			
		||||
    This DDP communication hook implements the original PowerSGD gradient compression
 | 
			
		||||
    algorithm described in https://arxiv.org/abs/1905.13727.
 | 
			
		||||
    r"""
 | 
			
		||||
    This DDP communication hook implements PowerSGD gradient compression
 | 
			
		||||
    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
 | 
			
		||||
    Once gradient tensors are aggregated across all workers, this hook applies
 | 
			
		||||
    compression as follows:
 | 
			
		||||
    1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
 | 
			
		||||
    high-rank tensors and vector-like rank-1 tensors (for biases).
 | 
			
		||||
    2) Handles rank-1 tensors by allreducing them without compression:
 | 
			
		||||
        2.1) Allocate contiguous memory for those rank-1 tensors,
 | 
			
		||||
        and allreduces all the rank-1 tensors as a batch, without compression;
 | 
			
		||||
        2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
 | 
			
		||||
    3) Handles high-rank tensors by PowerSGD compression:
 | 
			
		||||
        3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
 | 
			
		||||
 | 
			
		||||
    1. Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: high-rank tensors and vector-like rank-1 tensors (for biases).
 | 
			
		||||
 | 
			
		||||
    2. Handles rank-1 tensors by allreducing them without compression:
 | 
			
		||||
 | 
			
		||||
        2.1. Allocate contiguous memory for those rank-1 tensors, and allreduces all the rank-1 tensors as a batch, without compression;
 | 
			
		||||
 | 
			
		||||
        2.2. Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
 | 
			
		||||
 | 
			
		||||
    3. Handles high-rank tensors by PowerSGD compression:
 | 
			
		||||
 | 
			
		||||
        3.1. For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
 | 
			
		||||
        such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
 | 
			
		||||
        3.2) Computes each P in Ps, which is equal to MQ;
 | 
			
		||||
        3.3) Allreduces Ps as a batch;
 | 
			
		||||
        3.4) Orthogonalizes each P in Ps;
 | 
			
		||||
        3.5) Computes each Q in Qs, which is approximately equal to M^TP;
 | 
			
		||||
        3.6) Allreduces Qs as a batch;
 | 
			
		||||
        3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
 | 
			
		||||
 | 
			
		||||
    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
 | 
			
		||||
    This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
 | 
			
		||||
    but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
 | 
			
		||||
        3.2. Computes each P in Ps, which is equal to MQ;
 | 
			
		||||
 | 
			
		||||
    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
 | 
			
		||||
    one left multiplication and one right multiplication.
 | 
			
		||||
    For warm-start, can take one such step at a time, and alternate between them.
 | 
			
		||||
        3.3. Allreduces Ps as a batch;
 | 
			
		||||
 | 
			
		||||
        3.4. Orthogonalizes each P in Ps;
 | 
			
		||||
 | 
			
		||||
        3.5. Computes each Q in Qs, which is approximately equal to M^TP;
 | 
			
		||||
 | 
			
		||||
        3.6. Allreduces Qs as a batch;
 | 
			
		||||
 | 
			
		||||
        3.7. Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
 | 
			
		||||
 | 
			
		||||
    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
 | 
			
		||||
    This not only gives the user more control over the tradeoff between speedup and accuracy,
 | 
			
		||||
    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
 | 
			
		||||
            To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
 | 
			
		||||
            To tune the compression configs, mainly need to tune `matrix_approximation_rank`` and ``start_powerSGD_iter``.
 | 
			
		||||
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
 | 
			
		||||
            Note that since DDP comm hook only supports single process single device mode at this time,
 | 
			
		||||
            only exactly one tensor is stored in this bucket.
 | 
			
		||||
@ -202,9 +198,9 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 | 
			
		||||
        Future handler of the communication, which updates the gradients in place.
 | 
			
		||||
 | 
			
		||||
    Example::
 | 
			
		||||
        state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
 | 
			
		||||
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
 | 
			
		||||
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
 | 
			
		||||
    """
 | 
			
		||||
    """  # noqa
 | 
			
		||||
    process_group = state.process_group
 | 
			
		||||
    group_to_use = process_group if process_group is not None else dist.group.WORLD
 | 
			
		||||
    world_size = group_to_use.size()
 | 
			
		||||
@ -374,6 +370,10 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 | 
			
		||||
        for tensor, p, q in zip(high_rank_tensors, ps, qs):
 | 
			
		||||
            torch.matmul(tensor.t(), p, out=q)
 | 
			
		||||
 | 
			
		||||
        # TODO: The above procedure does two matmul+allreduce steps per iteration --
 | 
			
		||||
        # one left multiplication and one right multiplication.
 | 
			
		||||
        # For warm-start, can take one such step at a time, and alternate between them.
 | 
			
		||||
 | 
			
		||||
        # Allreduce Qs.
 | 
			
		||||
        return [
 | 
			
		||||
            dist.all_reduce(
 | 
			
		||||
@ -412,40 +412,48 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 | 
			
		||||
    """
 | 
			
		||||
    r"""
 | 
			
		||||
    This DDP communication hook implements a simplified PowerSGD gradient compression
 | 
			
		||||
    algorithm described in https://arxiv.org/abs/1905.13727.
 | 
			
		||||
    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
 | 
			
		||||
    This variant does not compress the gradients layer by layer,
 | 
			
		||||
    but instead compresses the flattened input tensor that batches all the gradients.
 | 
			
		||||
    Therefore, it is **faster** than :meth:`powerSGD_hook`,
 | 
			
		||||
    but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.
 | 
			
		||||
 | 
			
		||||
    .. warning ::
 | 
			
		||||
        Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
 | 
			
		||||
        because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
 | 
			
		||||
        Therefore, the user should always consider :meth:`powerSGD_hook` first,
 | 
			
		||||
        and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.
 | 
			
		||||
 | 
			
		||||
    Once gradient tensors are aggregated across all workers, this hook applies
 | 
			
		||||
    compression to the flattened input tensor that batches per-parameter tensors as follows:
 | 
			
		||||
    1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
 | 
			
		||||
    2) Creates two low-rank tensors P and Q for decomposing M,
 | 
			
		||||
    such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
 | 
			
		||||
    2) Computes P, which is equal to MQ;
 | 
			
		||||
    3) Allreduces P;
 | 
			
		||||
    4) Orthogonalizes P;
 | 
			
		||||
    5) Computes Q, which is approximately equal to M^TP;
 | 
			
		||||
    6) Allreduces Q;
 | 
			
		||||
    7) Computes M, which is approximately equal to PQ^T.
 | 
			
		||||
    8) Truncates the input tensor to the original length.
 | 
			
		||||
    compression as follows:
 | 
			
		||||
 | 
			
		||||
    This variant is faster than `powerSGD_hook` that runs layer-wise gradient compression,
 | 
			
		||||
    but it usually results in a much lower accuracy, unless `matrix_approximation_rank` in the state is 1.
 | 
			
		||||
    Increasing `matrix_approximation_rank` may not necessarily increase the accuracy,
 | 
			
		||||
    because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
 | 
			
		||||
    Therefore, the user shoud always consider `powerSGD_hook` first,
 | 
			
		||||
    and only consider this variant when a satisfying accuracy can be achieved when `matrix_approximation_rank` is 1.
 | 
			
		||||
    1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
 | 
			
		||||
 | 
			
		||||
    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
 | 
			
		||||
    This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
 | 
			
		||||
    but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.
 | 
			
		||||
    2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
 | 
			
		||||
 | 
			
		||||
    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
 | 
			
		||||
    one left multiplication and one right multiplication.
 | 
			
		||||
    For warm-start, can take one such step at a time, and alternate between them.
 | 
			
		||||
    3. Computes P, which is equal to MQ;
 | 
			
		||||
 | 
			
		||||
    4. Allreduces P;
 | 
			
		||||
 | 
			
		||||
    5. Orthogonalizes P;
 | 
			
		||||
 | 
			
		||||
    6. Computes Q, which is approximately equal to M^TP;
 | 
			
		||||
 | 
			
		||||
    7. Allreduces Q;
 | 
			
		||||
 | 
			
		||||
    8. Computes M, which is approximately equal to PQ^T.
 | 
			
		||||
 | 
			
		||||
    9. Truncates the input tensor to the original length.
 | 
			
		||||
 | 
			
		||||
    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
 | 
			
		||||
    This not only gives the user more control over the tradeoff between speedup and accuracy,
 | 
			
		||||
    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
 | 
			
		||||
            To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
 | 
			
		||||
            To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
 | 
			
		||||
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
 | 
			
		||||
            Note that since DDP comm hook only supports single process single device mode at this time,
 | 
			
		||||
            only exactly one tensor is stored in this bucket.
 | 
			
		||||
@ -454,9 +462,9 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 | 
			
		||||
        Future handler of the communication, which updates the gradients in place.
 | 
			
		||||
 | 
			
		||||
    Example::
 | 
			
		||||
        state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
 | 
			
		||||
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
 | 
			
		||||
        >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
 | 
			
		||||
    """
 | 
			
		||||
    """  # noqa
 | 
			
		||||
    process_group = state.process_group
 | 
			
		||||
    group_to_use = process_group if process_group is not None else dist.group.WORLD
 | 
			
		||||
    world_size = group_to_use.size()
 | 
			
		||||
@ -563,6 +571,10 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 | 
			
		||||
            out=state.q_memory_dict[bucket_index],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # TODO: The above procedure does two matmul+allreduce steps per iteration --
 | 
			
		||||
        # one left multiplication and one right multiplication.
 | 
			
		||||
        # For warm-start, can take one such step at a time, and alternate between them.
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            dist.all_reduce(
 | 
			
		||||
                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ from typing import Dict, List, Set, NamedTuple, Tuple
 | 
			
		||||
import torch
 | 
			
		||||
from torch.fx.passes.split_module import split_module
 | 
			
		||||
import operator
 | 
			
		||||
from torch.fx.experimental.partitioner_utils import Partition, \
 | 
			
		||||
from torch.fx._experimental.partitioner_utils import Partition, \
 | 
			
		||||
    Device, PartitionerConfig, get_partition_to_latency_mapping,\
 | 
			
		||||
    get_latency_of_partitioned_graph, NodeLatency, get_extra_size_of, \
 | 
			
		||||
    PartitionMode
 | 
			
		||||
@ -2,7 +2,7 @@ from typing import Dict, List, NamedTuple, Any
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.fx.passes.shape_prop import ShapeProp
 | 
			
		||||
from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
 | 
			
		||||
from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes
 | 
			
		||||
from torch.fx.graph import Graph, get_qualified_name
 | 
			
		||||
from torch.fx.graph_module import GraphModule
 | 
			
		||||
from torch.fx.node import Node, Target, map_arg
 | 
			
		||||
@ -116,7 +116,7 @@ class Interpreter:
 | 
			
		||||
 | 
			
		||||
    # Main Node running APIs
 | 
			
		||||
 | 
			
		||||
    def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        Execute a ``placeholder`` node. Note that this is stateful:
 | 
			
		||||
        ``Interpreter`` maintains an internal iterator over
 | 
			
		||||
@ -141,7 +141,7 @@ class Interpreter:
 | 
			
		||||
        else:
 | 
			
		||||
            return next(self.args_iter)
 | 
			
		||||
 | 
			
		||||
    def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        Execute a ``get_attr`` node. Will retrieve an attribute
 | 
			
		||||
        value from the ``Module`` hierarchy of ``self.module``.
 | 
			
		||||
@ -159,7 +159,7 @@ class Interpreter:
 | 
			
		||||
        assert isinstance(target, str)
 | 
			
		||||
        return self.fetch_attr(target)
 | 
			
		||||
 | 
			
		||||
    def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        Execute a ``call_function`` node and return the result.
 | 
			
		||||
 | 
			
		||||
@ -178,7 +178,7 @@ class Interpreter:
 | 
			
		||||
        # Execute the function and return the result
 | 
			
		||||
        return target(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        Execute a ``call_method`` node and return the result.
 | 
			
		||||
 | 
			
		||||
@ -199,7 +199,7 @@ class Interpreter:
 | 
			
		||||
        assert isinstance(target, str)
 | 
			
		||||
        return getattr(self_obj, target)(*args_tail, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def call_module(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        Execute a ``call_module`` node and return the result.
 | 
			
		||||
 | 
			
		||||
@ -221,7 +221,7 @@ class Interpreter:
 | 
			
		||||
 | 
			
		||||
        return submod(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def output(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
        """
 | 
			
		||||
        Execute an ``output`` node. This really just retrieves
 | 
			
		||||
        the value referenced by the ``output`` node and returns it.
 | 
			
		||||
@ -307,12 +307,12 @@ class Transformer(Interpreter):
 | 
			
		||||
        method equivalents). We could subclass ``Transformer`` like so::
 | 
			
		||||
 | 
			
		||||
            class NegSigmSwapXformer(Transformer):
 | 
			
		||||
                def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
                    if target == torch.sigmoid:
 | 
			
		||||
                        return torch.neg(*args, **kwargs)
 | 
			
		||||
                    return super().call_function(n)
 | 
			
		||||
 | 
			
		||||
                def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
                    if target == 'neg':
 | 
			
		||||
                        call_self, *args_tail = args
 | 
			
		||||
                        return call_self.sigmoid(*args_tail, **kwargs)
 | 
			
		||||
@ -344,7 +344,7 @@ class Transformer(Interpreter):
 | 
			
		||||
        self.tracer = TransformerTracer(self.new_graph)
 | 
			
		||||
        self.tracer.root = module
 | 
			
		||||
 | 
			
		||||
    def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy:
 | 
			
		||||
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
 | 
			
		||||
        """
 | 
			
		||||
        Execute a ``placeholder`` node. In ``Transformer``, this is
 | 
			
		||||
        overridden to insert a new ``placeholder`` into the output
 | 
			
		||||
@ -360,7 +360,7 @@ class Transformer(Interpreter):
 | 
			
		||||
        assert isinstance(target, str)
 | 
			
		||||
        return Proxy(self.new_graph.placeholder(target), self.tracer)
 | 
			
		||||
 | 
			
		||||
    def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy:
 | 
			
		||||
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
 | 
			
		||||
        """
 | 
			
		||||
        Execute a ``get_attr`` node. In ``Transformer``, this is
 | 
			
		||||
        overridden to insert a new ``get_attr`` node into the output
 | 
			
		||||
@ -376,6 +376,12 @@ class Transformer(Interpreter):
 | 
			
		||||
        assert isinstance(target, str)
 | 
			
		||||
        return Proxy(self.new_graph.get_attr(target), self.tracer)
 | 
			
		||||
 | 
			
		||||
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
 | 
			
		||||
        # Override so that the leaf module policy from `self.tracer` is respected.
 | 
			
		||||
        assert isinstance(target, str)
 | 
			
		||||
        submod = self.fetch_attr(target)
 | 
			
		||||
        return self.tracer.call_module(submod, submod.forward, args, kwargs)
 | 
			
		||||
 | 
			
		||||
    def transform(self) -> GraphModule:
 | 
			
		||||
        """
 | 
			
		||||
        Transform ``self.module`` and return the transformed
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,7 @@ import operator
 | 
			
		||||
 | 
			
		||||
from .graph import magic_methods, reflectable_magic_methods, Graph
 | 
			
		||||
from typing import Tuple, Dict, Optional, Iterable, Any, Iterator
 | 
			
		||||
from .node import Target, Node, Argument, base_types
 | 
			
		||||
from .node import Target, Node, Argument, base_types, map_aggregate
 | 
			
		||||
 | 
			
		||||
class TracerBase:
 | 
			
		||||
    graph: Graph
 | 
			
		||||
@ -61,8 +61,17 @@ class TracerBase:
 | 
			
		||||
        elif isinstance(a, dict):
 | 
			
		||||
            r = {}
 | 
			
		||||
            for k, v in a.items():
 | 
			
		||||
                if not isinstance(k, str):
 | 
			
		||||
                    raise NotImplementedError(f"dictionaries with non-string keys: {a}")
 | 
			
		||||
                # Check for invalid dict keys. We do not want a Proxy to appear
 | 
			
		||||
                # anywhere within the key. Since keys can be collection types,
 | 
			
		||||
                # we iterate through the key with map_aggregate
 | 
			
		||||
                k = self.create_arg(k)
 | 
			
		||||
 | 
			
		||||
                def no_node(arg):
 | 
			
		||||
                    if isinstance(arg, Node):
 | 
			
		||||
                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
 | 
			
		||||
                                           "Node. Got key: {k}")
 | 
			
		||||
                map_aggregate(k, no_node)
 | 
			
		||||
 | 
			
		||||
                r[k] = self.create_arg(v)
 | 
			
		||||
            return r
 | 
			
		||||
        elif isinstance(a, slice):
 | 
			
		||||
 | 
			
		||||
@ -1021,13 +1021,14 @@ class DistributedDataParallel(Module):
 | 
			
		||||
        parameter syncs while running Distributed DataParallel training.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            state (object): state is passed to the hook and can be used to maintain
 | 
			
		||||
                            and update any state information that users would like to
 | 
			
		||||
                            maintain as part of the training process. Examples: error
 | 
			
		||||
                            feedback in gradient compression, peers to communicate with
 | 
			
		||||
                            next in GossipGrad etc.
 | 
			
		||||
            hook (callable): is defined as:
 | 
			
		||||
                             hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
 | 
			
		||||
            state (object): Passed to the hook to maintain any state information during the training process.
 | 
			
		||||
                            Examples include error feedback in gradient compression,
 | 
			
		||||
                            peers to communicate with next in GossipGrad, etc.
 | 
			
		||||
 | 
			
		||||
                            It is locally stored by each worker
 | 
			
		||||
                            and shared by all the gradient tensors on the worker.
 | 
			
		||||
            hook (callable): Averages gradient tensors across workers and defined as:
 | 
			
		||||
                             ``hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future``:
 | 
			
		||||
 | 
			
		||||
                             This function is called once the bucket is ready. The
 | 
			
		||||
                             hook can perform whatever processing is needed and return
 | 
			
		||||
@ -1067,7 +1068,7 @@ class DistributedDataParallel(Module):
 | 
			
		||||
            DDP communication hook is experimental and subject to change.
 | 
			
		||||
 | 
			
		||||
        Example::
 | 
			
		||||
            Below is an example of a noop hook that returns back the same tensors:
 | 
			
		||||
            Below is an example of a noop hook that returns the same tensors.
 | 
			
		||||
 | 
			
		||||
            >>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future
 | 
			
		||||
            >>>     fut = torch.futures.Future()
 | 
			
		||||
@ -1091,7 +1092,6 @@ class DistributedDataParallel(Module):
 | 
			
		||||
            >>>     return fut.then(decode)
 | 
			
		||||
 | 
			
		||||
            >>> ddp.register_comm_hook(state = None, hook = encode_and_decode)
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        self._check_comm_hook(hook)
 | 
			
		||||
        dist._register_comm_hook(self.reducer, state, hook)
 | 
			
		||||
 | 
			
		||||
@ -391,9 +391,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
 | 
			
		||||
        torch.exp2: lambda input, out=None: -1,
 | 
			
		||||
        torch.expm1: lambda input, out=None: -1,
 | 
			
		||||
        torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
 | 
			
		||||
        torch.fake_quantize_per_channel_affine_cachemask: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
 | 
			
		||||
        torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
 | 
			
		||||
        torch.fake_quantize_per_tensor_affine_cachemask: lambda input, scale, zero_point, quant_min, quant_max: -1,
 | 
			
		||||
        torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
 | 
			
		||||
        torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
 | 
			
		||||
        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user