mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Migrate param_group testing to OptimizerInfo (#117675)
Today, our param_group testing does the equivalent of pitting weight and bias with different optimizer hyperparams and then check that the overall result is going the right direction based on maximize. This PR introduces two tests to encompass coverage: 1. For every optimizer input (no differentiable), always force bias to have 0 weight_decay, and then check that the direction is expected. This is basically a replica to today's tests, but is more methodical as the test is a real use case. 2. To ensure that the different groups have distinct behavior, I added another test where lr is basically 0 in default group, and ensure that the param in the default group doesn't move while loss does. Together, these tests do a better job of testing param groups than today's tests, **though we do lose some flavors**. For example, RMSProp also pits centered=True vs False across the param_groups, Adadelta has a variation on rho, and ASGD has a variation for t0. I don't think this is really a loss, as the previous test was just testing for direction and our new tests test stronger guarantees. The leftover param group configs are used in conjunction with LRSchedulers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117675 Approved by: https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							d280b6ae58
						
					
				
				
					commit
					c6be5d55a5
				
			| @ -325,9 +325,6 @@ class TestOptim(TestCase): | ||||
|     def _build_params_dict(self, weight, bias, **kwargs): | ||||
|         return [{"params": [weight]}, dict(params=[bias], **kwargs)] | ||||
|  | ||||
|     def _build_params_dict_single(self, weight, bias, **kwargs): | ||||
|         return [dict(params=bias, **kwargs)] | ||||
|  | ||||
|     def test_sgd(self): | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: SGD( | ||||
| @ -336,35 +333,6 @@ class TestOptim(TestCase): | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: SGD( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                 lr=1e-3, | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: SGD( | ||||
|                 self._build_params_dict_single(weight, bias, lr=1e-2), | ||||
|                 lr=1e-3, | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: SGD( | ||||
|                 self._build_params_dict_single(weight, bias, lr=1e-2), | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: SGD( | ||||
|                 [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach | ||||
| @ -530,16 +498,6 @@ class TestOptim(TestCase): | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adam( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                 lr=1e-3, | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adam( | ||||
|                 [weight, bias], | ||||
| @ -562,17 +520,6 @@ class TestOptim(TestCase): | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adam( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                 lr=1e-3, | ||||
|                 amsgrad=True, | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adam( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
| @ -694,16 +641,6 @@ class TestOptim(TestCase): | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: AdamW( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                 lr=1e-3, | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: AdamW( | ||||
|                 [weight, bias], | ||||
| @ -774,15 +711,6 @@ class TestOptim(TestCase): | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adadelta( | ||||
|                 self._build_params_dict(weight, bias, rho=0.95), | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adadelta( | ||||
|                 self._build_params_dict(weight, bias, rho=0.95), | ||||
| @ -815,12 +743,6 @@ class TestOptim(TestCase): | ||||
|             ) | ||||
|  | ||||
|     def test_nadam(self): | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, foreach: NAdam( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach | ||||
|             ), | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, foreach: NAdam( | ||||
|                 [weight, bias], lr=1e-3, foreach=foreach | ||||
| @ -915,16 +837,6 @@ class TestOptim(TestCase): | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adagrad( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                 lr=1e-1, | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adagrad( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
| @ -989,16 +901,6 @@ class TestOptim(TestCase): | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adamax( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                 lr=1e-1, | ||||
|                 maximize=maximize, | ||||
|                 foreach=foreach, | ||||
|             ), | ||||
|             constructor_accepts_maximize=True, | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, maximize, foreach: Adamax( | ||||
|                 [weight, bias], | ||||
| @ -1021,12 +923,6 @@ class TestOptim(TestCase): | ||||
|             ), | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, foreach: RAdam( | ||||
|                 self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach | ||||
|             ), | ||||
|             constructor_accepts_foreach=True, | ||||
|         ) | ||||
|         self._test_basic_cases( | ||||
|             lambda weight, bias, foreach: RAdam( | ||||
|                 [weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach | ||||
| @ -1094,62 +990,6 @@ class TestOptim(TestCase): | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: RMSprop( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-3), | ||||
|                     lr=1e-2, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: RMSprop( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-3), | ||||
|                     lr=1e-2, | ||||
|                     centered=True, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: RMSprop( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-3), | ||||
|                     lr=1e-2, | ||||
|                     centered=True, | ||||
|                     momentum=0.1, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: RMSprop( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-3), | ||||
|                     lr=1e-2, | ||||
|                     momentum=0.1, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: RMSprop( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-3), | ||||
|                     lr=1e-2, | ||||
|                     momentum=0.1, | ||||
|                     weight_decay=1, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach)) | ||||
|             self._test_complex_2d( | ||||
|                 lambda param: RMSprop(param, centered=True, foreach=foreach) | ||||
| @ -1183,28 +1023,6 @@ class TestOptim(TestCase): | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: ASGD( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                     lr=1e-3, | ||||
|                     t0=100, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: ASGD( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                     lr=1e-3, | ||||
|                     weight_decay=1, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             # Ref: https://github.com/pytorch/pytorch/issues/84560 | ||||
|             # self._test_complex_2d(optimizer) | ||||
|             self._test_complex_optimizer( | ||||
| @ -1239,18 +1057,6 @@ class TestOptim(TestCase): | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|             ) | ||||
|             self._test_basic_cases( | ||||
|                 lambda weight, bias, maximize, foreach: Rprop( | ||||
|                     self._build_params_dict(weight, bias, lr=1e-2), | ||||
|                     lr=2e-4, | ||||
|                     maximize=maximize, | ||||
|                     foreach=foreach, | ||||
|                 ), | ||||
|                 constructor_accepts_maximize=True, | ||||
|                 constructor_accepts_foreach=True, | ||||
|                 atol=4e-5 if is_cuda_sm86 else None, | ||||
|                 rtol=3e-5 if is_cuda_sm86 else None, | ||||
|             ) | ||||
|             self._test_complex_2d(lambda param: Rprop(param, foreach=foreach)) | ||||
|             self._test_complex_optimizer( | ||||
|                 lambda param: Rprop([param], lr=0.001, foreach=foreach) | ||||
|  | ||||
| @ -359,6 +359,90 @@ class TestOptimRenewed(TestCase): | ||||
|             optimizer_cuda.step() | ||||
|  | ||||
|  | ||||
|     @optims(optim_db, dtypes=[torch.float32]) | ||||
|     def test_param_groups_weight_decay(self, device, dtype, optim_info): | ||||
|         optim_cls = optim_info.optim_cls | ||||
|         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 | ||||
|         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",)) | ||||
|         for optim_input in all_optim_inputs: | ||||
|             weight_kwargs = optim_input.kwargs | ||||
|             bias_kwargs = deepcopy(optim_input.kwargs) | ||||
|             bias_kwargs["weight_decay"] = 0.0 | ||||
|  | ||||
|             weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) | ||||
|             bias = Parameter(torch.randn((10), device=device, dtype=dtype)) | ||||
|             input = torch.randn(5, device=device, dtype=dtype) | ||||
|  | ||||
|             optimizer = optim_cls([dict(params=[weight], **weight_kwargs), dict(params=[bias], **bias_kwargs)]) | ||||
|  | ||||
|             loss = (weight.mv(input) + bias).pow(2).sum() | ||||
|             initial_value = loss.item() | ||||
|             for _ in range(20): | ||||
|                 optimizer.zero_grad() | ||||
|                 loss = (weight.mv(input) + bias).pow(2).sum() | ||||
|                 loss.backward() | ||||
|                 if optim_cls.__name__ == "SparseAdam": | ||||
|                     # SparseAdam requires sparse gradients. For this test, we convert the Tensor layout, | ||||
|                     # which we know does NOT represent the expected use case! | ||||
|                     weight.grad = weight.grad.to_sparse() | ||||
|                     bias.grad = bias.grad.to_sparse() | ||||
|                 optimizer.step() | ||||
|  | ||||
|             # Test that the direction of loss moved appropriately | ||||
|             if optim_input.kwargs.get("maximize", False): | ||||
|                 self.assertGreater(loss.item(), initial_value) | ||||
|             else: | ||||
|                 self.assertLess(loss.item(), initial_value) | ||||
|  | ||||
|  | ||||
|     @optims(optim_db, dtypes=[torch.float32]) | ||||
|     def test_param_groups_lr(self, device, dtype, optim_info): | ||||
|         optim_cls = optim_info.optim_cls | ||||
|         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 | ||||
|         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",)) | ||||
|         for optim_input in all_optim_inputs: | ||||
|             # optim_input.kwargs will be the param group kwargs, which should have >0 lr | ||||
|             if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0: | ||||
|                 optim_input.kwargs["lr"] = 1e-3 | ||||
|             outer_kwargs = {"lr": 1e-28} | ||||
|             if optim_cls.__name__ == "Rprop": | ||||
|                 # Allow min step size to be 0 | ||||
|                 outer_kwargs["step_sizes"] = (0, 50) | ||||
|  | ||||
|             weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) | ||||
|             bias = Parameter(torch.randn((10), device=device, dtype=dtype)) | ||||
|             irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype)) | ||||
|             irrelevant_clone = irrelevant.clone() | ||||
|             input = torch.randn(5, device=device, dtype=dtype) | ||||
|             optimizer = optim_cls( | ||||
|                 [dict(params=[weight, bias], **optim_input.kwargs), dict(params=[irrelevant])], | ||||
|                 **outer_kwargs) | ||||
|  | ||||
|             loss = (weight.mv(input) + bias).pow(2).sum() | ||||
|             initial_value = loss.item() | ||||
|             for _ in range(20): | ||||
|                 optimizer.zero_grad() | ||||
|                 loss = (weight.mv(input) + bias).pow(2).sum() | ||||
|                 loss.backward() | ||||
|                 irrelevant.grad = torch.rand_like(irrelevant) | ||||
|                 if optim_cls.__name__ == "SparseAdam": | ||||
|                     # SparseAdam requires sparse gradients. For this test, we convert the Tensor layout, | ||||
|                     # which we know does NOT represent the expected use case! | ||||
|                     weight.grad = weight.grad.to_sparse() | ||||
|                     bias.grad = bias.grad.to_sparse() | ||||
|                     irrelevant.grad = irrelevant.grad.to_sparse() | ||||
|                 optimizer.step() | ||||
|  | ||||
|             # Test that the direction of loss moved appropriately | ||||
|             if optim_input.kwargs.get("maximize", False): | ||||
|                 self.assertGreater(loss.item(), initial_value) | ||||
|             else: | ||||
|                 self.assertLess(loss.item(), initial_value) | ||||
|  | ||||
|             # Test that irrelevant parameters were not updated since lr was almost 0 | ||||
|             self.assertEqual(irrelevant, irrelevant_clone) | ||||
|  | ||||
|  | ||||
|     @optims(optim_db, dtypes=[torch.float32]) | ||||
|     def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info): | ||||
|         optim_cls = optim_info.optim_cls | ||||
|  | ||||
| @ -1191,6 +1191,16 @@ optim_db: List[OptimizerInfo] = [ | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_deepcopy_copies_all_public_attrs", | ||||
|             ), | ||||
|             DecorateInfo( | ||||
|                 unittest.skip("Does not support param groups"), | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_param_groups_lr", | ||||
|             ), | ||||
|             DecorateInfo( | ||||
|                 unittest.skip("Does not support param groups"), | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_param_groups_weight_decay", | ||||
|             ), | ||||
|         ), | ||||
|     ), | ||||
|     OptimizerInfo( | ||||
| @ -1436,6 +1446,22 @@ optim_db: List[OptimizerInfo] = [ | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_state_dict_deterministic", | ||||
|             ), | ||||
|             DecorateInfo( | ||||
|                 skipIfTorchDynamo( | ||||
|                     "Errors with list out of range, see https://github.com/pytorch/pytorch/issues/116061" | ||||
|                 ), | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_param_groups_weight_decay", | ||||
|                 device_type="cpu", | ||||
|             ), | ||||
|             DecorateInfo( | ||||
|                 skipIfTorchDynamo( | ||||
|                     "Errors with list out of range, see https://github.com/pytorch/pytorch/issues/116061" | ||||
|                 ), | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_param_groups_lr", | ||||
|                 device_type="cpu", | ||||
|             ), | ||||
|             DecorateInfo( | ||||
|                 skipIfTorchDynamo( | ||||
|                     "Errors with list out of range, see https://github.com/pytorch/pytorch/issues/116061" | ||||
| @ -1478,6 +1504,11 @@ optim_db: List[OptimizerInfo] = [ | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_state_dict_deterministic", | ||||
|             ), | ||||
|             DecorateInfo( | ||||
|                 skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"), | ||||
|                 "TestOptimRenewed", | ||||
|                 "test_param_groups_lr", | ||||
|             ), | ||||
|             DecorateInfo( | ||||
|                 unittest.skip( | ||||
|                     "SparseAdam does not support dense gradients, see #116507" | ||||
|  | ||||
		Reference in New Issue
	
	Block a user