mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 16:14:55 +08:00 
			
		
		
		
	Compare commits
	
		
			11 Commits
		
	
	
		
			fixflashgi
			...
			v1.8.0-rc2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| c79decdbba | |||
| c307a3f336 | |||
| f071020756 | |||
| 4f436f8570 | |||
| ae11589710 | |||
| 9e5bcc1020 | |||
| fa8578241d | |||
| 1368809532 | |||
| 4073248fc2 | |||
| 75153cb730 | |||
| 5bb69b080c | 
| @ -153,11 +153,11 @@ commands: | ||||
|                 git config --global user.email "circleci.ossci@gmail.com" | ||||
|                 git config --global user.name "CircleCI" | ||||
|                 git config remote.origin.url https://github.com/pytorch/pytorch.git | ||||
|                 git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master | ||||
|                 git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet | ||||
|                 git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8 | ||||
|                 git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet | ||||
|                 # PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base | ||||
|                 if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then | ||||
|                   CIRCLE_PR_BASE_BRANCH=master | ||||
|                   CIRCLE_PR_BASE_BRANCH=release/1.8 | ||||
|                 fi | ||||
|                 export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH` | ||||
|                 echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET} | ||||
|  | ||||
| @ -111,11 +111,11 @@ commands: | ||||
|                 git config --global user.email "circleci.ossci@gmail.com" | ||||
|                 git config --global user.name "CircleCI" | ||||
|                 git config remote.origin.url https://github.com/pytorch/pytorch.git | ||||
|                 git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master | ||||
|                 git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=100 --quiet | ||||
|                 git config --add remote.origin.fetch +refs/heads/release/1.8:refs/remotes/origin/release/1.8 | ||||
|                 git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/release/1.8:refs/remotes/origin/release/1.8 --depth=100 --quiet | ||||
|                 # PRs generated from ghstack has format CIRCLE_PR_BASE_BRANCH=gh/xxx/1234/base | ||||
|                 if [[ "${CIRCLE_PR_BASE_BRANCH}" == "gh/"* ]]; then | ||||
|                   CIRCLE_PR_BASE_BRANCH=master | ||||
|                   CIRCLE_PR_BASE_BRANCH=release/1.8 | ||||
|                 fi | ||||
|                 export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/$CIRCLE_PR_BASE_BRANCH` | ||||
|                 echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET} | ||||
|  | ||||
| @ -182,7 +182,7 @@ fi | ||||
|  | ||||
| # Patch required to build xla | ||||
| if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then | ||||
|   git clone --recursive https://github.com/pytorch/xla.git | ||||
|   git clone --recursive -b r1.8 https://github.com/pytorch/xla.git | ||||
|   ./xla/scripts/apply_patches.sh | ||||
| fi | ||||
|  | ||||
|  | ||||
| @ -54,7 +54,7 @@ function file_diff_from_base() { | ||||
|   set +e | ||||
|   git fetch origin master --quiet | ||||
|   set -e | ||||
|   git diff --name-only "$(git merge-base origin/master HEAD)" > "$1" | ||||
|   git diff --name-only "$(git merge-base origin/release/1.8 HEAD)" > "$1" | ||||
| } | ||||
|  | ||||
| function get_bazel() { | ||||
|  | ||||
| @ -300,7 +300,7 @@ test_backward_compatibility() { | ||||
|   pushd test/backward_compatibility | ||||
|   python -m venv venv | ||||
|   . venv/bin/activate | ||||
|   pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html | ||||
|   pip_install --pre torch -f https://download.pytorch.org/whl/test/cpu/torch_test.html | ||||
|   pip show torch | ||||
|   python dump_all_function_schemas.py --filename nightly_schemas.txt | ||||
|   deactivate | ||||
|  | ||||
| @ -19,6 +19,27 @@ namespace { | ||||
|  | ||||
| using namespace vec256; | ||||
|  | ||||
| // Note: Explicit implementation of copysign for Half and BFloat16 | ||||
| // is needed to workaround g++-7/8 crash on aarch64, but also makes | ||||
| // copysign faster for the half-precision types | ||||
| template<typename T> | ||||
| T copysign(T a, T b) { | ||||
|   return std::copysign(a, b); | ||||
| } | ||||
|  | ||||
| // Implement copysign for half precision floats using bit ops | ||||
| // Sign is the most significant bit for both half and bfloat16 types | ||||
| template<> | ||||
| c10::Half copysign(c10::Half a, c10::Half b) { | ||||
|   return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits()); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) { | ||||
|    return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits()); | ||||
| } | ||||
|  | ||||
|  | ||||
| // Note: Undefined behavior when performing addition is intentionally | ||||
| // ignored. | ||||
| void add_kernel(TensorIteratorBase& iter, Scalar alpha_scalar) { | ||||
| @ -180,7 +201,7 @@ void div_floor_kernel(TensorIterator& iter) { | ||||
|                 floordiv += scalar_t(1.0); | ||||
|               } | ||||
|             } else { | ||||
|               floordiv = std::copysign(scalar_t(0), a / b); | ||||
|               floordiv = copysign(scalar_t(0), a / b); | ||||
|             } | ||||
|             return floordiv; | ||||
|           }); | ||||
| @ -889,23 +910,6 @@ void heaviside_kernel(TensorIterator& iter) { | ||||
|   }); | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| T copysign(T a, T b) { | ||||
|   return std::copysign(a, b); | ||||
| } | ||||
|  | ||||
| // Implement copysign for half precision floats using bit ops | ||||
| // Sign is the most significant bit for both half and bfloat16 types | ||||
| template<> | ||||
| c10::Half copysign(c10::Half a, c10::Half b) { | ||||
|   return c10::Half((a.x&0x7fff) | (b.x&0x8000), c10::Half::from_bits()); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) { | ||||
|    return c10::BFloat16((a.x&0x7fff) | (b.x&0x8000), c10::BFloat16::from_bits()); | ||||
| } | ||||
|  | ||||
| void copysign_kernel(TensorIterator& iter) { | ||||
|   AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() { | ||||
|     cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { | ||||
|  | ||||
| @ -133,7 +133,9 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) { | ||||
|     ASSERT_EQ(buffer1[i].z, buffer2[i].z); | ||||
|     ASSERT_EQ(buffer1[i].w, buffer2[i].w); | ||||
|   } | ||||
| // Skipping this part until https://github.com/pytorch/pytorch/issues/51863 is resolved | ||||
|  | ||||
| #if 0 | ||||
|   // unaligned | ||||
|   for (int i = 0; i < 16; i++) { | ||||
|     for (int j = 0; j < 16; j++) { | ||||
| @ -151,4 +153,5 @@ TEST(TestVectorizedMemoryAccess, CopyKernel) { | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| @ -16,7 +16,7 @@ int32_t driver_version() { | ||||
|   return driver_version; | ||||
| } | ||||
|  | ||||
| int device_count_impl() { | ||||
| int device_count_impl(bool fail_if_no_driver) { | ||||
|   int count; | ||||
|   auto err = cudaGetDeviceCount(&count); | ||||
|   if (err == cudaSuccess) { | ||||
| @ -34,6 +34,11 @@ int device_count_impl() { | ||||
|     case cudaErrorInsufficientDriver: { | ||||
|       auto version = driver_version(); | ||||
|       if (version <= 0) { | ||||
|         if (!fail_if_no_driver) { | ||||
|           // No CUDA driver means no devices | ||||
|           count = 0; | ||||
|           break; | ||||
|         } | ||||
|         TORCH_CHECK( | ||||
|             false, | ||||
|             "Found no NVIDIA driver on your system. Please check that you " | ||||
| @ -95,9 +100,9 @@ DeviceIndex device_count() noexcept { | ||||
|   // initialize number of devices only once | ||||
|   static int count = []() { | ||||
|     try { | ||||
|       auto result = device_count_impl(); | ||||
|       auto result = device_count_impl(/*fail_if_no_driver=*/false); | ||||
|       TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed"); | ||||
|       return device_count_impl(); | ||||
|       return result; | ||||
|     } catch (const c10::Error& ex) { | ||||
|       // We don't want to fail, but still log the warning | ||||
|       // msg() returns the message without the stack trace | ||||
| @ -110,7 +115,7 @@ DeviceIndex device_count() noexcept { | ||||
|  | ||||
| DeviceIndex device_count_ensure_non_zero() { | ||||
|   // Call the implementation every time to throw the exception | ||||
|   int count = device_count_impl(); | ||||
|   int count = device_count_impl(/*fail_if_no_driver=*/true); | ||||
|   // Zero gpus doesn't produce a warning in `device_count` but we fail here | ||||
|   TORCH_CHECK(count, "No CUDA GPUs are available"); | ||||
|   return static_cast<DeviceIndex>(count); | ||||
|  | ||||
							
								
								
									
										74
									
								
								docs/source/ddp_comm_hooks.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								docs/source/ddp_comm_hooks.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,74 @@ | ||||
| DDP Communication Hooks | ||||
| ======================= | ||||
|  | ||||
| DDP communication hook is a generic interface to control how to communicate | ||||
| gradients across workers by overriding the vanilla allreduce in | ||||
| `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.>`_. | ||||
| A few built-in communication hooks are provided, | ||||
| and users can easily apply any of these hooks to optimize communication. | ||||
| Besides, the hook interface can also support user-defined communication | ||||
| strategies for more advanced use cases. | ||||
|  | ||||
| .. warning :: | ||||
|     DDP communication hook is experimental and subject to change. | ||||
|  | ||||
| .. warning :: | ||||
|     DDP communication hooks can only support single process single device mode | ||||
|     on NCCL backend. | ||||
|  | ||||
| How to Use a Communication Hook? | ||||
| -------------------------------- | ||||
|  | ||||
| To use a communication hook, the user just needs to let the DDP model register | ||||
| the hook before the training loop as below. | ||||
|  | ||||
| :func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`. | ||||
|     :noindex: | ||||
|  | ||||
| Default Communication Hooks | ||||
| --------------------------- | ||||
|  | ||||
| Default communication hooks are simple **stateless** hooks, so the input state | ||||
| in ``register_comm_hook`` is either a process group or ``None``. | ||||
|  | ||||
| .. automodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks | ||||
|     :members: | ||||
|  | ||||
| PowerSGD Communication Hook | ||||
| --------------------------- | ||||
|  | ||||
| PowerSGD (`Vogels et al., NeurIPS 2019 <https://arxiv.org/abs/1905.13727>`_) | ||||
| is a gradient compression algorithm, which can provide very high compression | ||||
| rates and accelerate bandwidth-bound distributed training. | ||||
| This algorithm needs to maintain both some hyperparameters and the internal | ||||
| state. Therefore, PowerSGD communication hook is a **stateful** hook, | ||||
| and the user needs to provide a state object defined as below. | ||||
|  | ||||
| PowerSGD State | ||||
| ^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook | ||||
| .. autoclass:: PowerSGDState | ||||
|  | ||||
| PowerSGD Hooks | ||||
| ^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. warning :: | ||||
|     PowerSGD typically requires extra memory of the same size as the model's | ||||
|     gradients to enable error feedback, which can compensate for biased | ||||
|     compressed communication and improve accuracy. | ||||
|  | ||||
| .. warning :: | ||||
|     The current implementation may cause gradient overflow for FP16 input. | ||||
|  | ||||
| .. autofunction:: powerSGD_hook | ||||
| .. autofunction:: batched_powerSGD_hook | ||||
|  | ||||
| Acknowledgements | ||||
| ---------------- | ||||
|  | ||||
| Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on | ||||
| PowerSGD communication hook, as well as the | ||||
| `comparison experiments <https://observablehq.com/@tvogels/powersgd-benchmark>`_, | ||||
| which show that the performance of PowerSGD communication hook is on par with | ||||
| the implementation in the original `paper <https://arxiv.org/abs/1905.13727>`_. | ||||
| @ -71,6 +71,7 @@ Features described in this documentation are classified by release status: | ||||
|    onnx | ||||
|    optim | ||||
|    complex_numbers | ||||
|    ddp_comm_hooks | ||||
|    pipeline | ||||
|    quantization | ||||
|    rpc | ||||
|  | ||||
| @ -484,6 +484,7 @@ Sparse tensor functions | ||||
| +++++++++++++++++++++++ | ||||
|  | ||||
| .. autofunction:: torch.sparse_coo_tensor | ||||
|    :noindex: | ||||
| .. autofunction:: torch.sparse.sum | ||||
| .. autofunction:: torch.sparse.addmm | ||||
| .. autofunction:: torch.sparse.mm | ||||
|  | ||||
							
								
								
									
										45
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								setup.py
									
									
									
									
									
								
							| @ -552,6 +552,50 @@ class build_ext(setuptools.command.build_ext.build_ext): | ||||
|             with open('compile_commands.json', 'w') as f: | ||||
|                 f.write(new_contents) | ||||
|  | ||||
| class concat_license_files(): | ||||
|     """Merge LICENSE and LICENSES_BUNDLED.txt as a context manager | ||||
|  | ||||
|     LICENSE is the main PyTorch license, LICENSES_BUNDLED.txt is auto-generated | ||||
|     from all the licenses found in ./third_party/. We concatenate them so there | ||||
|     is a single license file in the sdist and wheels with all of the necessary | ||||
|     licensing info. | ||||
|     """ | ||||
|     def __init__(self): | ||||
|         self.f1 = 'LICENSE' | ||||
|         self.f2 = 'third_party/LICENSES_BUNDLED.txt' | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """Concatenate files""" | ||||
|         with open(self.f1, 'r') as f1: | ||||
|             self.bsd_text = f1.read() | ||||
|  | ||||
|         with open(self.f1, 'a') as f1: | ||||
|             with open(self.f2, 'r') as f2: | ||||
|                 self.bundled_text = f2.read() | ||||
|                 f1.write('\n\n') | ||||
|                 f1.write(self.bundled_text) | ||||
|  | ||||
|     def __exit__(self, exception_type, exception_value, traceback): | ||||
|         """Restore content of f1""" | ||||
|         with open(self.f1, 'w') as f: | ||||
|             f.write(self.bsd_text) | ||||
|  | ||||
|  | ||||
| try: | ||||
|     from wheel.bdist_wheel import bdist_wheel | ||||
| except ImportError: | ||||
|     # This is useful when wheel is not installed and bdist_wheel is not | ||||
|     # specified on the command line. If it _is_ specified, parsing the command | ||||
|     # line will fail before wheel_concatenate is needed | ||||
|     wheel_concatenate = None | ||||
| else: | ||||
|     # Need to create the proper LICENSE.txt for the wheel | ||||
|     class wheel_concatenate(bdist_wheel): | ||||
|         """ check submodules on sdist to prevent incomplete tarballs """ | ||||
|         def run(self): | ||||
|             with concat_license_files(): | ||||
|                 super().run() | ||||
|  | ||||
|  | ||||
| class install(setuptools.command.install.install): | ||||
|     def run(self): | ||||
| @ -724,6 +768,7 @@ def configure_extension_build(): | ||||
|         'build_ext': build_ext, | ||||
|         'clean': clean, | ||||
|         'install': install, | ||||
|         'bdist_wheel': wheel_concatenate, | ||||
|     } | ||||
|  | ||||
|     entry_points = { | ||||
|  | ||||
| @ -872,7 +872,7 @@ class TestFakeQuantize(TestCase): | ||||
|             scale, zero_point = float(scale), int(zero_point) | ||||
|             quant_min, quant_max = obs._calculate_qmin_qmax() | ||||
|  | ||||
|             Y_test, _mask = torch.fake_quantize_per_tensor_affine_cachemask( | ||||
|             Y_test = torch.fake_quantize_per_tensor_affine( | ||||
|                 X, scale, zero_point, quant_min, quant_max) | ||||
|             Y_ref = _fake_quantize_per_tensor_affine_reference( | ||||
|                 X.cpu(), scale, zero_point, quant_min, quant_max).to(device) | ||||
| @ -899,7 +899,7 @@ class TestFakeQuantize(TestCase): | ||||
|             quant_min, quant_max = obs._calculate_qmin_qmax() | ||||
|  | ||||
|             # forward pass | ||||
|             Y_test, mask = torch.fake_quantize_per_tensor_affine_cachemask( | ||||
|             Y_test = torch.fake_quantize_per_tensor_affine( | ||||
|                 X, scale, zero_point, quant_min, quant_max) | ||||
|             Y_ref = _fake_quantize_per_tensor_affine_reference( | ||||
|                 X.cpu(), scale, zero_point, quant_min, quant_max).to(device) | ||||
| @ -1246,7 +1246,7 @@ class TestFakeQuantize(TestCase): | ||||
|  | ||||
|             Y = _fake_quantize_per_channel_affine_reference( | ||||
|                 X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max) | ||||
|             Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask( | ||||
|             Y_prime = torch.fake_quantize_per_channel_affine( | ||||
|                 X, scale, zero_point, axis, quant_min, quant_max) | ||||
|             np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance) | ||||
|  | ||||
| @ -1339,7 +1339,7 @@ class TestFakeQuantize(TestCase): | ||||
|             zero_point = zero_point.to(torch.int64) | ||||
|             quant_min, quant_max = obs._calculate_qmin_qmax() | ||||
|             X.requires_grad_() | ||||
|             Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask( | ||||
|             Y_prime = torch.fake_quantize_per_channel_affine( | ||||
|                 X, scale, zero_point, axis, quant_min, quant_max) | ||||
|             dout = torch.rand(X.shape, dtype=torch.float).to(device) | ||||
|             dX = _fake_quantize_per_channel_affine_grad_reference( | ||||
|  | ||||
| @ -108,6 +108,7 @@ TESTS = [ | ||||
|     'test_fx_experimental', | ||||
|     'test_functional_autograd_benchmark', | ||||
|     'test_package', | ||||
|     'test_license', | ||||
|     'distributed/pipeline/sync/skip/test_api', | ||||
|     'distributed/pipeline/sync/skip/test_gpipe', | ||||
|     'distributed/pipeline/sync/skip/test_inspect_skip_layout', | ||||
|  | ||||
| @ -14,7 +14,7 @@ from math import sqrt | ||||
| from pathlib import Path | ||||
| from torch.multiprocessing import Process | ||||
| from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap | ||||
| from torch.fx.node import Target | ||||
| from torch.fx.node import Target, Argument | ||||
| from torch.fx.passes import shape_prop | ||||
| from torch.fx.immutable_collections import immutable_dict, immutable_list | ||||
| from copy import deepcopy | ||||
| @ -187,7 +187,7 @@ class TestFX(JitTestCase): | ||||
|         # Custom delegate to disallow in-place tensor operations | ||||
|         class NoMutableCallTracer(Tracer): | ||||
|             def create_node(self, kind : str, target : Union[str, Callable], | ||||
|                             args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, | ||||
|                             args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, | ||||
|                             type_expr : Optional[Any] = None) -> Node: | ||||
|                 name = target if isinstance(target, str) else torch.typename(target) | ||||
|                 if name[-1] == '_': | ||||
| @ -539,7 +539,7 @@ class TestFX(JitTestCase): | ||||
|     def test_node_tagging(self): | ||||
|         class TaggingTracer(Tracer): | ||||
|             def create_node(self, kind : str, target : Union[str, Callable], | ||||
|                             args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, | ||||
|                             args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, | ||||
|                             type_expr : Optional[Any] = None) -> Node: | ||||
|                 n = super().create_node(kind, target, args, kwargs, name) | ||||
|                 n.tag = 'foo' | ||||
| @ -1057,6 +1057,13 @@ class TestFX(JitTestCase): | ||||
|         result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4)) | ||||
|         self.assertEqual(result, torch.ones(3, 4) * 2.0) | ||||
|  | ||||
|     @skipIfNoTorchVision | ||||
|     def test_interpreter_noop_resnet18(self): | ||||
|         rn18 = resnet18() | ||||
|         transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform() | ||||
|         inp = torch.randn(5, 3, 224, 224) | ||||
|         self.assertEqual(transformed(inp), rn18(inp)) | ||||
|  | ||||
|     def test_transformer_noop(self): | ||||
|         class MyModule(torch.nn.Module): | ||||
|             def __init__(self): | ||||
| @ -1377,6 +1384,45 @@ class TestFX(JitTestCase): | ||||
|         x, y = torch.randn(3, 4), torch.randn(3, 4) | ||||
|         self.checkGraphModule(foo, (x, y)) | ||||
|  | ||||
|     def test_trace_dict_int_keys(self): | ||||
|         class ModWithDictArg(torch.nn.Module): | ||||
|             def forward(self, d : Dict[int, torch.Tensor]): | ||||
|                 return d[42] | ||||
|  | ||||
|         class CallsModWithDict(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.m = ModWithDictArg() | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 return self.m({42: x}) | ||||
|  | ||||
|         class MyTracer(torch.fx.Tracer): | ||||
|             def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: | ||||
|                 return isinstance(m, ModWithDictArg) | ||||
|  | ||||
|         traced_graph = MyTracer().trace(CallsModWithDict()) | ||||
|  | ||||
|     def test_trace_dict_proxy_keys(self): | ||||
|         class ModWithDictArg(torch.nn.Module): | ||||
|             def forward(self, d : Dict[torch.Tensor, torch.Tensor]): | ||||
|                 return d[42] | ||||
|  | ||||
|         class CallsModWithDict(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.m = ModWithDictArg() | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 return self.m({x: x}) | ||||
|  | ||||
|         class MyTracer(torch.fx.Tracer): | ||||
|             def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: | ||||
|                 return isinstance(m, ModWithDictArg) | ||||
|  | ||||
|         with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'): | ||||
|             traced_graph = MyTracer().trace(CallsModWithDict()) | ||||
|  | ||||
|     def test_direct_param_use(self): | ||||
|         class TransposeTest(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|  | ||||
| @ -5,14 +5,14 @@ from typing import Callable, Dict, Union, List | ||||
| from torch.fx.symbolic_trace import symbolic_trace | ||||
| from torch.fx.graph_module import GraphModule | ||||
| from torch.fx.node import Node | ||||
| from torch.fx.experimental import graph_manipulation | ||||
| from torch.fx.experimental.accelerator_partitioner import Partitioner | ||||
| from torch.fx.experimental.rewriter import RewritingTracer | ||||
| from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes | ||||
| from torch.fx._experimental import graph_manipulation | ||||
| from torch.fx._experimental.accelerator_partitioner import Partitioner | ||||
| from torch.fx._experimental.rewriter import RewritingTracer | ||||
| from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
| from torch.testing._internal.jit_utils import JitTestCase | ||||
| from torch.fx.passes.split_module import split_module | ||||
| from torch.fx.experimental.partitioner_utils import ( | ||||
| from torch.fx._experimental.partitioner_utils import ( | ||||
|     NodeLatency, | ||||
|     get_partition_to_latency_mapping, | ||||
|     get_latency_of_partitioned_graph, | ||||
| @ -20,8 +20,8 @@ from torch.fx.experimental.partitioner_utils import ( | ||||
|     PartitionerConfig, | ||||
|     PartitionMode | ||||
| ) | ||||
| from torch.fx.experimental.fuser import fuse | ||||
| from torch.fx.experimental import merge_matmul | ||||
| from torch.fx._experimental.fuser import fuse | ||||
| from torch.fx._experimental import merge_matmul | ||||
|  | ||||
| try: | ||||
|     from torchvision.models import resnet18 | ||||
| @ -849,7 +849,7 @@ terrible spacing | ||||
|  | ||||
|     def test_merge_matmuls(self): | ||||
|         """ | ||||
|         A collection of test cases for torch.fx.experimental.merge_matmul, | ||||
|         A collection of test cases for torch.fx._experimental.merge_matmul, | ||||
|         a graph transformation that merges matrix multiplication operations. | ||||
|         """ | ||||
|         # Utility function for counting matmuls for test assertions. | ||||
|  | ||||
| @ -6503,6 +6503,38 @@ a") | ||||
|             self.checkModule(module().train(), ()) | ||||
|             self.checkModule(module().eval(), ()) | ||||
|  | ||||
|     def test_ternary_static_if(self): | ||||
|         # Test for True branch when condition variable | ||||
|         # is annotated as Final | ||||
|         class M1(torch.nn.Module): | ||||
|             flag: torch.jit.Final[bool] | ||||
|  | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.flag = True | ||||
|  | ||||
|             def forward(self) -> torch.Tensor: | ||||
|                 return torch.ones(3) if self.flag else {} | ||||
|  | ||||
|         # Test for True branch when condition variable | ||||
|         # is annotated as Final | ||||
|         class M2(torch.nn.Module): | ||||
|             flag: torch.jit.Final[bool] | ||||
|  | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.flag = False | ||||
|  | ||||
|             def forward(self) -> torch.Tensor: | ||||
|                 return {} if self.flag else torch.ones(3) | ||||
|  | ||||
|         model1 = M1() | ||||
|         model2 = M2() | ||||
|         script_model_1 = torch.jit.script(model1) | ||||
|         script_model_2 = torch.jit.script(model2) | ||||
|         self.assertEqual(model1.forward(), script_model_1.forward()) | ||||
|         self.assertEqual(model2.forward(), script_model_2.forward()) | ||||
|  | ||||
|     def test_print(self): | ||||
|         def func(x, y): | ||||
|             q = (x + y).sigmoid() | ||||
|  | ||||
| @ -1,6 +1,9 @@ | ||||
| import glob | ||||
| import io | ||||
| import os | ||||
| import unittest | ||||
|  | ||||
| import torch | ||||
| from torch.testing._internal.common_utils import TestCase, run_tests | ||||
|  | ||||
|  | ||||
| @ -10,11 +13,14 @@ except ImportError: | ||||
|     create_bundled = None | ||||
|  | ||||
| license_file = 'third_party/LICENSES_BUNDLED.txt' | ||||
| starting_txt = 'The Pytorch repository and source distributions bundle' | ||||
| site_packages = os.path.dirname(os.path.dirname(torch.__file__)) | ||||
| distinfo = glob.glob(os.path.join(site_packages, 'torch-*dist-info')) | ||||
|  | ||||
| class TestLicense(TestCase): | ||||
|  | ||||
|     @unittest.skipIf(not create_bundled, "can only be run in a source tree") | ||||
|     def test_license_in_wheel(self): | ||||
|     def test_license_for_wheel(self): | ||||
|         current = io.StringIO() | ||||
|         create_bundled('third_party', current) | ||||
|         with open(license_file) as fid: | ||||
| @ -25,6 +31,18 @@ class TestLicense(TestCase): | ||||
|                 'match the current state of the third_party files. Use ' | ||||
|                 '"python third_party/build_bundled.py" to regenerate it') | ||||
|  | ||||
|     @unittest.skipIf(len(distinfo) == 0, "no installation in site-package to test") | ||||
|     def test_distinfo_license(self): | ||||
|         """If run when pytorch is installed via a wheel, the license will be in | ||||
|         site-package/torch-*dist-info/LICENSE. Make sure it contains the third | ||||
|         party bundle of licenses""" | ||||
|  | ||||
|         if len(distinfo) > 1: | ||||
|             raise AssertionError('Found too many "torch-*dist-info" directories ' | ||||
|                                  f'in "{site_packages}, expected only one') | ||||
|         with open(os.path.join(os.path.join(distinfo[0], 'LICENSE'))) as fid: | ||||
|             txt = fid.read() | ||||
|             self.assertTrue(starting_txt in txt) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     run_tests() | ||||
|  | ||||
| @ -82,7 +82,8 @@ SKIP_PYTHON_BINDINGS = [ | ||||
|     'set_data', | ||||
|     '.*_overrideable',  # overrideable functions for backend extension | ||||
|     'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_', | ||||
|     '_fw_primal' | ||||
|     '_fw_primal', 'fake_quantize_per_tensor_affine_cachemask', | ||||
|     'fake_quantize_per_channel_affine_cachemask', | ||||
| ] | ||||
|  | ||||
| # These function signatures are not exposed to Python. Note that this signature | ||||
|  | ||||
| @ -1258,6 +1258,15 @@ struct to_ir { | ||||
|       const TernaryIf& expr, | ||||
|       const TypePtr& type_hint = nullptr) { | ||||
|     CondValue cond_value = emitCondExpr(expr.cond()); | ||||
|     // If the cond expr is a static value, then we metacompile the `if` | ||||
|     // statemement and only emit true or false branch | ||||
|     if (cond_value.staticIf()) { | ||||
|         if (*cond_value.staticIf()) { | ||||
|             return emitExpr(expr.true_expr(), type_hint); | ||||
|         } else { | ||||
|             return emitExpr(expr.false_expr(), type_hint); | ||||
|         } | ||||
|     } | ||||
|     auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); }; | ||||
|     auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); }; | ||||
|     return emitIfExpr(expr.range(), cond_value, true_expr, false_expr); | ||||
|  | ||||
| @ -33,39 +33,38 @@ def _orthogonalize(matrix, epsilon=1e-8): | ||||
|  | ||||
|  | ||||
| class PowerSGDState(object): | ||||
|     """ | ||||
|     Stores both the gradient compression configs and the internal states for all the gradients during the training. | ||||
|     Particularly, `matrix_approximation_rank` and `start_powerSGD_iter` are the main configs that need to be tuned by the user. | ||||
|     Although `use_error_feedback` and `warm_start` can also be tuned by the user, | ||||
|     they are typically turned on for performance. | ||||
|     r""" | ||||
|     Stores both the algorithm's hyperparameters and the internal state for all the gradients during the training. | ||||
|     Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user. | ||||
|     For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on. | ||||
|  | ||||
|     Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`] | ||||
|     ~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|     1) To tune `matrix_approximation_rank`, the user can increase it from 1 by factors of 2, | ||||
|     until a satisfying accuracy can be reached. | ||||
|     The increase of `matrix_approximation_rank` can substantially increase the computation costs of the compression. | ||||
|     However, the accuracy may not be futher improved beyond a certain `matrix_approximation_rank` value. | ||||
|     2) To tune `start_powerSGD_iter`, the user can typically start with 10% of total training steps, | ||||
|     and increase it until a satisfying accuracy can be reached. | ||||
|     Deferrring PowerSGD can effectively improve the accuracy, | ||||
|     even a relatively small `matrix_approximation_rank` is used. | ||||
|     This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, | ||||
|     and compressing gradients too early may make the training quickly take a suboptimal trajectory, | ||||
|     which can result in an irrecoverable impact on the accuracy. | ||||
|     The minimum value allowed in DDP is 2, if error feedback or warm-up is enabled. | ||||
|     This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, | ||||
|     and this can conflict with any tensor memorized before the rebuild process. | ||||
|     """ | ||||
|     1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression. | ||||
|  | ||||
|         1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy. | ||||
|  | ||||
|         1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be futher improved beyond a certain ``matrix_approximation_rank`` threshold. | ||||
|  | ||||
|     To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an expoential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32. | ||||
|  | ||||
|     2. ``start_powerSGD_iter`` defers PowerSGD compression util step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even a relatively small ``matrix_approximation_rank`` is used. This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy. | ||||
|  | ||||
|     To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. | ||||
|  | ||||
|     .. warning :: | ||||
|         If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2. | ||||
|         This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, | ||||
|         and this can conflict with any tensor memorized before the rebuild process. | ||||
|     """  # noqa | ||||
|  | ||||
|     __slots__ = [ | ||||
|         "process_group", | ||||
|         # The two fields below are the configs that usually need to be tuned by the user. | ||||
|         # The two fields below are the hyperparameters that should be tuned by the user. | ||||
|         "matrix_approximation_rank", | ||||
|         "start_powerSGD_iter", | ||||
|         # The two fields below are the configs that usually need to be turned on for performance. | ||||
|         # The two fields below are the binary hyperparameters recommended to be turned on for performance. | ||||
|         "use_error_feedback", | ||||
|         "warm_start", | ||||
|         # The fields below are not configs. | ||||
|         # The fields below are internal state. | ||||
|         "rng", | ||||
|         "error_dict", | ||||
|         "p_memory_dict", | ||||
| @ -93,21 +92,12 @@ class PowerSGDState(object): | ||||
|         ) | ||||
|  | ||||
|         self.process_group = process_group | ||||
|         # The low rank for matrix approximation controls the size of compressed low-rank tensors, | ||||
|         # which determines the computation ratio. | ||||
|         # Typically only a small value 1-4 is used. | ||||
|         # For some NLP tasks (as shown in Appendix D of the original paper | ||||
|         # https://arxiv.org/pdf/1905.13727.pdf, the rank value has been increased to 32. | ||||
|         # A high rank value will increase the computation costs of compression exponentially. | ||||
|         # A good choice depends on how much extra computation can be hidden by the dominating communication costs. | ||||
|         self.matrix_approximation_rank = matrix_approximation_rank | ||||
|         # This defers PowerSGD compression util step 'start_powerSGD_iter', | ||||
|         # and vanilla allreduce runs before step 'start_powerSGD_iter'. | ||||
|         # This hybrid scheme of vanilla allreduce + PowerSGD can have two advantages: | ||||
|         # Deferring PowerSGD compression util step 'start_powerSGD_iter' can have two advantages: | ||||
|         # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss, | ||||
|         # even if the matrix approximation rank is increased to a large value. | ||||
|         # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce | ||||
|         # (or a more convervative compression such as FP16 compression) with PowerSGD. | ||||
|         # (or a more conservative compression such as FP16 compression) with PowerSGD. | ||||
|         # 2) There is an internal optimization of rebuilding buckets process in DDP, | ||||
|         # in order to save the memory space. | ||||
|         # This step takes place after the first iteration. | ||||
| @ -162,38 +152,44 @@ class PowerSGDState(object): | ||||
|  | ||||
|  | ||||
| def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: | ||||
|     """ | ||||
|     This DDP communication hook implements the original PowerSGD gradient compression | ||||
|     algorithm described in https://arxiv.org/abs/1905.13727. | ||||
|     r""" | ||||
|     This DDP communication hook implements PowerSGD gradient compression | ||||
|     algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_. | ||||
|     Once gradient tensors are aggregated across all workers, this hook applies | ||||
|     compression as follows: | ||||
|     1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: | ||||
|     high-rank tensors and vector-like rank-1 tensors (for biases). | ||||
|     2) Handles rank-1 tensors by allreducing them without compression: | ||||
|         2.1) Allocate contiguous memory for those rank-1 tensors, | ||||
|         and allreduces all the rank-1 tensors as a batch, without compression; | ||||
|         2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor. | ||||
|     3) Handles high-rank tensors by PowerSGD compression: | ||||
|         3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M, | ||||
|  | ||||
|     1. Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: high-rank tensors and vector-like rank-1 tensors (for biases). | ||||
|  | ||||
|     2. Handles rank-1 tensors by allreducing them without compression: | ||||
|  | ||||
|         2.1. Allocate contiguous memory for those rank-1 tensors, and allreduces all the rank-1 tensors as a batch, without compression; | ||||
|  | ||||
|         2.2. Copies the individual rank-1 tensors from the contiguous memory back to the input tensor. | ||||
|  | ||||
|     3. Handles high-rank tensors by PowerSGD compression: | ||||
|  | ||||
|         3.1. For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M, | ||||
|         such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; | ||||
|         3.2) Computes each P in Ps, which is equal to MQ; | ||||
|         3.3) Allreduces Ps as a batch; | ||||
|         3.4) Orthogonalizes each P in Ps; | ||||
|         3.5) Computes each Q in Qs, which is approximately equal to M^TP; | ||||
|         3.6) Allreduces Qs as a batch; | ||||
|         3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T. | ||||
|  | ||||
|     Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations. | ||||
|     This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy, | ||||
|     but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers. | ||||
|         3.2. Computes each P in Ps, which is equal to MQ; | ||||
|  | ||||
|     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration -- | ||||
|     one left multiplication and one right multiplication. | ||||
|     For warm-start, can take one such step at a time, and alternate between them. | ||||
|         3.3. Allreduces Ps as a batch; | ||||
|  | ||||
|         3.4. Orthogonalizes each P in Ps; | ||||
|  | ||||
|         3.5. Computes each Q in Qs, which is approximately equal to M^TP; | ||||
|  | ||||
|         3.6. Allreduces Qs as a batch; | ||||
|  | ||||
|         3.7. Computes each M among all the high-rank tensors, which is approximately equal to PQ^T. | ||||
|  | ||||
|     Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. | ||||
|     This not only gives the user more control over the tradeoff between speedup and accuracy, | ||||
|     but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. | ||||
|  | ||||
|     Args: | ||||
|         state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. | ||||
|             To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`]. | ||||
|             To tune the compression configs, mainly need to tune `matrix_approximation_rank`` and ``start_powerSGD_iter``. | ||||
|         bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. | ||||
|             Note that since DDP comm hook only supports single process single device mode at this time, | ||||
|             only exactly one tensor is stored in this bucket. | ||||
| @ -202,9 +198,9 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: | ||||
|         Future handler of the communication, which updates the gradients in place. | ||||
|  | ||||
|     Example:: | ||||
|         state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) | ||||
|         >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) | ||||
|         >>> ddp_model.register_comm_hook(state, powerSGD_hook) | ||||
|     """ | ||||
|     """  # noqa | ||||
|     process_group = state.process_group | ||||
|     group_to_use = process_group if process_group is not None else dist.group.WORLD | ||||
|     world_size = group_to_use.size() | ||||
| @ -374,6 +370,10 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: | ||||
|         for tensor, p, q in zip(high_rank_tensors, ps, qs): | ||||
|             torch.matmul(tensor.t(), p, out=q) | ||||
|  | ||||
|         # TODO: The above procedure does two matmul+allreduce steps per iteration -- | ||||
|         # one left multiplication and one right multiplication. | ||||
|         # For warm-start, can take one such step at a time, and alternate between them. | ||||
|  | ||||
|         # Allreduce Qs. | ||||
|         return [ | ||||
|             dist.all_reduce( | ||||
| @ -412,40 +412,48 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: | ||||
|  | ||||
|  | ||||
| def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: | ||||
|     """ | ||||
|     r""" | ||||
|     This DDP communication hook implements a simplified PowerSGD gradient compression | ||||
|     algorithm described in https://arxiv.org/abs/1905.13727. | ||||
|     algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_. | ||||
|     This variant does not compress the gradients layer by layer, | ||||
|     but instead compresses the flattened input tensor that batches all the gradients. | ||||
|     Therefore, it is **faster** than :meth:`powerSGD_hook`, | ||||
|     but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1. | ||||
|  | ||||
|     .. warning :: | ||||
|         Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy, | ||||
|         because batching per-parameter tensors without column/row alignment can destroy low-rank structure. | ||||
|         Therefore, the user should always consider :meth:`powerSGD_hook` first, | ||||
|         and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1. | ||||
|  | ||||
|     Once gradient tensors are aggregated across all workers, this hook applies | ||||
|     compression to the flattened input tensor that batches per-parameter tensors as follows: | ||||
|     1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings; | ||||
|     2) Creates two low-rank tensors P and Q for decomposing M, | ||||
|     such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; | ||||
|     2) Computes P, which is equal to MQ; | ||||
|     3) Allreduces P; | ||||
|     4) Orthogonalizes P; | ||||
|     5) Computes Q, which is approximately equal to M^TP; | ||||
|     6) Allreduces Q; | ||||
|     7) Computes M, which is approximately equal to PQ^T. | ||||
|     8) Truncates the input tensor to the original length. | ||||
|     compression as follows: | ||||
|  | ||||
|     This variant is faster than `powerSGD_hook` that runs layer-wise gradient compression, | ||||
|     but it usually results in a much lower accuracy, unless `matrix_approximation_rank` in the state is 1. | ||||
|     Increasing `matrix_approximation_rank` may not necessarily increase the accuracy, | ||||
|     because batching per-parameter tensors without column/row alignment can destroy low-rank structure. | ||||
|     Therefore, the user shoud always consider `powerSGD_hook` first, | ||||
|     and only consider this variant when a satisfying accuracy can be achieved when `matrix_approximation_rank` is 1. | ||||
|     1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings; | ||||
|  | ||||
|     Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations. | ||||
|     This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy, | ||||
|     but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers. | ||||
|     2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; | ||||
|  | ||||
|     TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration -- | ||||
|     one left multiplication and one right multiplication. | ||||
|     For warm-start, can take one such step at a time, and alternate between them. | ||||
|     3. Computes P, which is equal to MQ; | ||||
|  | ||||
|     4. Allreduces P; | ||||
|  | ||||
|     5. Orthogonalizes P; | ||||
|  | ||||
|     6. Computes Q, which is approximately equal to M^TP; | ||||
|  | ||||
|     7. Allreduces Q; | ||||
|  | ||||
|     8. Computes M, which is approximately equal to PQ^T. | ||||
|  | ||||
|     9. Truncates the input tensor to the original length. | ||||
|  | ||||
|     Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations. | ||||
|     This not only gives the user more control over the tradeoff between speedup and accuracy, | ||||
|     but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers. | ||||
|  | ||||
|     Args: | ||||
|         state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. | ||||
|             To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`]. | ||||
|             To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``. | ||||
|         bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. | ||||
|             Note that since DDP comm hook only supports single process single device mode at this time, | ||||
|             only exactly one tensor is stored in this bucket. | ||||
| @ -454,9 +462,9 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: | ||||
|         Future handler of the communication, which updates the gradients in place. | ||||
|  | ||||
|     Example:: | ||||
|         state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) | ||||
|         >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) | ||||
|         >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook) | ||||
|     """ | ||||
|     """  # noqa | ||||
|     process_group = state.process_group | ||||
|     group_to_use = process_group if process_group is not None else dist.group.WORLD | ||||
|     world_size = group_to_use.size() | ||||
| @ -563,6 +571,10 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: | ||||
|             out=state.q_memory_dict[bucket_index], | ||||
|         ) | ||||
|  | ||||
|         # TODO: The above procedure does two matmul+allreduce steps per iteration -- | ||||
|         # one left multiplication and one right multiplication. | ||||
|         # For warm-start, can take one such step at a time, and alternate between them. | ||||
|  | ||||
|         return [ | ||||
|             dist.all_reduce( | ||||
|                 state.q_memory_dict[bucket_index], group=group_to_use, async_op=True | ||||
|  | ||||
| @ -4,7 +4,7 @@ from typing import Dict, List, Set, NamedTuple, Tuple | ||||
| import torch | ||||
| from torch.fx.passes.split_module import split_module | ||||
| import operator | ||||
| from torch.fx.experimental.partitioner_utils import Partition, \ | ||||
| from torch.fx._experimental.partitioner_utils import Partition, \ | ||||
|     Device, PartitionerConfig, get_partition_to_latency_mapping,\ | ||||
|     get_latency_of_partitioned_graph, NodeLatency, get_extra_size_of, \ | ||||
|     PartitionMode | ||||
| @ -2,7 +2,7 @@ from typing import Dict, List, NamedTuple, Any | ||||
| 
 | ||||
| import torch | ||||
| from torch.fx.passes.shape_prop import ShapeProp | ||||
| from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes | ||||
| from torch.fx._experimental.param_fetch import lift_lowering_attrs_to_nodes | ||||
| from torch.fx.graph import Graph, get_qualified_name | ||||
| from torch.fx.graph_module import GraphModule | ||||
| from torch.fx.node import Node, Target, map_arg | ||||
| @ -116,7 +116,7 @@ class Interpreter: | ||||
|  | ||||
|     # Main Node running APIs | ||||
|  | ||||
|     def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|     def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|         """ | ||||
|         Execute a ``placeholder`` node. Note that this is stateful: | ||||
|         ``Interpreter`` maintains an internal iterator over | ||||
| @ -141,7 +141,7 @@ class Interpreter: | ||||
|         else: | ||||
|             return next(self.args_iter) | ||||
|  | ||||
|     def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|     def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|         """ | ||||
|         Execute a ``get_attr`` node. Will retrieve an attribute | ||||
|         value from the ``Module`` hierarchy of ``self.module``. | ||||
| @ -159,7 +159,7 @@ class Interpreter: | ||||
|         assert isinstance(target, str) | ||||
|         return self.fetch_attr(target) | ||||
|  | ||||
|     def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|     def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|         """ | ||||
|         Execute a ``call_function`` node and return the result. | ||||
|  | ||||
| @ -178,7 +178,7 @@ class Interpreter: | ||||
|         # Execute the function and return the result | ||||
|         return target(*args, **kwargs) | ||||
|  | ||||
|     def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|     def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|         """ | ||||
|         Execute a ``call_method`` node and return the result. | ||||
|  | ||||
| @ -199,7 +199,7 @@ class Interpreter: | ||||
|         assert isinstance(target, str) | ||||
|         return getattr(self_obj, target)(*args_tail, **kwargs) | ||||
|  | ||||
|     def call_module(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|     def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|         """ | ||||
|         Execute a ``call_module`` node and return the result. | ||||
|  | ||||
| @ -221,7 +221,7 @@ class Interpreter: | ||||
|  | ||||
|         return submod(*args, **kwargs) | ||||
|  | ||||
|     def output(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|     def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|         """ | ||||
|         Execute an ``output`` node. This really just retrieves | ||||
|         the value referenced by the ``output`` node and returns it. | ||||
| @ -307,12 +307,12 @@ class Transformer(Interpreter): | ||||
|         method equivalents). We could subclass ``Transformer`` like so:: | ||||
|  | ||||
|             class NegSigmSwapXformer(Transformer): | ||||
|                 def call_function(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|                 def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|                     if target == torch.sigmoid: | ||||
|                         return torch.neg(*args, **kwargs) | ||||
|                     return super().call_function(n) | ||||
|  | ||||
|                 def call_method(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Any: | ||||
|                 def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|                     if target == 'neg': | ||||
|                         call_self, *args_tail = args | ||||
|                         return call_self.sigmoid(*args_tail, **kwargs) | ||||
| @ -344,7 +344,7 @@ class Transformer(Interpreter): | ||||
|         self.tracer = TransformerTracer(self.new_graph) | ||||
|         self.tracer.root = module | ||||
|  | ||||
|     def placeholder(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy: | ||||
|     def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: | ||||
|         """ | ||||
|         Execute a ``placeholder`` node. In ``Transformer``, this is | ||||
|         overridden to insert a new ``placeholder`` into the output | ||||
| @ -360,7 +360,7 @@ class Transformer(Interpreter): | ||||
|         assert isinstance(target, str) | ||||
|         return Proxy(self.new_graph.placeholder(target), self.tracer) | ||||
|  | ||||
|     def get_attr(self, target : 'Target', args : Tuple[Any], kwargs : Dict[str, Any]) -> Proxy: | ||||
|     def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: | ||||
|         """ | ||||
|         Execute a ``get_attr`` node. In ``Transformer``, this is | ||||
|         overridden to insert a new ``get_attr`` node into the output | ||||
| @ -376,6 +376,12 @@ class Transformer(Interpreter): | ||||
|         assert isinstance(target, str) | ||||
|         return Proxy(self.new_graph.get_attr(target), self.tracer) | ||||
|  | ||||
|     def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: | ||||
|         # Override so that the leaf module policy from `self.tracer` is respected. | ||||
|         assert isinstance(target, str) | ||||
|         submod = self.fetch_attr(target) | ||||
|         return self.tracer.call_module(submod, submod.forward, args, kwargs) | ||||
|  | ||||
|     def transform(self) -> GraphModule: | ||||
|         """ | ||||
|         Transform ``self.module`` and return the transformed | ||||
|  | ||||
| @ -5,7 +5,7 @@ import operator | ||||
|  | ||||
| from .graph import magic_methods, reflectable_magic_methods, Graph | ||||
| from typing import Tuple, Dict, Optional, Iterable, Any, Iterator | ||||
| from .node import Target, Node, Argument, base_types | ||||
| from .node import Target, Node, Argument, base_types, map_aggregate | ||||
|  | ||||
| class TracerBase: | ||||
|     graph: Graph | ||||
| @ -61,8 +61,17 @@ class TracerBase: | ||||
|         elif isinstance(a, dict): | ||||
|             r = {} | ||||
|             for k, v in a.items(): | ||||
|                 if not isinstance(k, str): | ||||
|                     raise NotImplementedError(f"dictionaries with non-string keys: {a}") | ||||
|                 # Check for invalid dict keys. We do not want a Proxy to appear | ||||
|                 # anywhere within the key. Since keys can be collection types, | ||||
|                 # we iterate through the key with map_aggregate | ||||
|                 k = self.create_arg(k) | ||||
|  | ||||
|                 def no_node(arg): | ||||
|                     if isinstance(arg, Node): | ||||
|                         raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " | ||||
|                                            "Node. Got key: {k}") | ||||
|                 map_aggregate(k, no_node) | ||||
|  | ||||
|                 r[k] = self.create_arg(v) | ||||
|             return r | ||||
|         elif isinstance(a, slice): | ||||
|  | ||||
| @ -1021,13 +1021,14 @@ class DistributedDataParallel(Module): | ||||
|         parameter syncs while running Distributed DataParallel training. | ||||
|  | ||||
|         Args: | ||||
|             state (object): state is passed to the hook and can be used to maintain | ||||
|                             and update any state information that users would like to | ||||
|                             maintain as part of the training process. Examples: error | ||||
|                             feedback in gradient compression, peers to communicate with | ||||
|                             next in GossipGrad etc. | ||||
|             hook (callable): is defined as: | ||||
|                              hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future: | ||||
|             state (object): Passed to the hook to maintain any state information during the training process. | ||||
|                             Examples include error feedback in gradient compression, | ||||
|                             peers to communicate with next in GossipGrad, etc. | ||||
|  | ||||
|                             It is locally stored by each worker | ||||
|                             and shared by all the gradient tensors on the worker. | ||||
|             hook (callable): Averages gradient tensors across workers and defined as: | ||||
|                              ``hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future``: | ||||
|  | ||||
|                              This function is called once the bucket is ready. The | ||||
|                              hook can perform whatever processing is needed and return | ||||
| @ -1067,7 +1068,7 @@ class DistributedDataParallel(Module): | ||||
|             DDP communication hook is experimental and subject to change. | ||||
|  | ||||
|         Example:: | ||||
|             Below is an example of a noop hook that returns back the same tensors: | ||||
|             Below is an example of a noop hook that returns the same tensors. | ||||
|  | ||||
|             >>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future | ||||
|             >>>     fut = torch.futures.Future() | ||||
| @ -1091,7 +1092,6 @@ class DistributedDataParallel(Module): | ||||
|             >>>     return fut.then(decode) | ||||
|  | ||||
|             >>> ddp.register_comm_hook(state = None, hook = encode_and_decode) | ||||
|  | ||||
|         """ | ||||
|         self._check_comm_hook(hook) | ||||
|         dist._register_comm_hook(self.reducer, state, hook) | ||||
|  | ||||
| @ -391,9 +391,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: | ||||
|         torch.exp2: lambda input, out=None: -1, | ||||
|         torch.expm1: lambda input, out=None: -1, | ||||
|         torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1, | ||||
|         torch.fake_quantize_per_channel_affine_cachemask: lambda input, scale, zero_point, axis, quant_min, quant_max: -1, | ||||
|         torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1, | ||||
|         torch.fake_quantize_per_tensor_affine_cachemask: lambda input, scale, zero_point, quant_min, quant_max: -1, | ||||
|         torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1, | ||||
|         torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1, | ||||
|         torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	