mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Speed up segmented sort with large nsort
Follow-up to gh-77100. Instead of calling `at::arange`, I repurpose the existing kernels to achieve the same effect. I've also fixed a bug in the 2d case where the pointer was advanced by `n` (which equals `nsegment * nsort`) after processing only `nsort` elements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77188
Approved by: https://github.com/ngimel
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							99339fddd9
						
					
				
				
					commit
					d9351ed520
				
			| @ -130,6 +130,22 @@ class TestSortAndSelect(TestCase): | ||||
|             self.assertIsOrdered('descending', x, res2val, res2ind, | ||||
|                                  'random with NaNs') | ||||
|  | ||||
|     @onlyCUDA | ||||
|     def test_sort_large_slice(self, device): | ||||
|         # tests direct cub path | ||||
|         x = torch.randn(4, 1024000, device=device) | ||||
|         res1val, res1ind = torch.sort(x, stable=True) | ||||
|         torch.cuda.synchronize() | ||||
|         # assertIsOrdered is too slow, so just compare to cpu | ||||
|         res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), stable=True) | ||||
|         self.assertEqual(res1val, res1val_cpu.cuda()) | ||||
|         self.assertEqual(res1ind, res1ind_cpu.cuda()) | ||||
|         res1val, res1ind = torch.sort(x, descending=True, stable=True) | ||||
|         torch.cuda.synchronize() | ||||
|         res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), descending=True, stable=True) | ||||
|         self.assertEqual(res1val, res1val_cpu.cuda()) | ||||
|         self.assertEqual(res1ind, res1ind_cpu.cuda()) | ||||
|  | ||||
|     # FIXME: remove torch.bool from unsupported types once support is added for cub sort | ||||
|     @dtypes(*all_types_and(torch.half, torch.bfloat16)) | ||||
|     def test_stable_sort(self, device, dtype): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user