mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Simplify BFLOAT16_AVAILABLE (#163445)
Simplify `BFLOAT16_AVAILABLE` by using `torch.cuda.is_bf16_supported()` and `torch.xpu.is_bf16_supported()`. Outdated comments are also removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163445 Approved by: https://github.com/Skylion007, https://github.com/kwen2501
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							edafc902d7
						
					
				
				
					commit
					96a3afb8ec
				
			| @ -34,11 +34,7 @@ device_type = ( | ||||
|     acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" | ||||
| ) | ||||
|  | ||||
| # bfloat16 is only supported by CUDA 11+ or XPU | ||||
| BFLOAT16_AVAILABLE = ( | ||||
|     torch.cuda.is_available() | ||||
|     and (torch.version.cuda is not None or torch.version.hip is not None) | ||||
| ) or torch.xpu.is_available() | ||||
| BFLOAT16_AVAILABLE = torch.cuda.is_bf16_supported() or torch.xpu.is_bf16_supported() | ||||
|  | ||||
|  | ||||
| class Net(nn.Module): | ||||
|  | ||||
| @ -83,7 +83,6 @@ if TEST_WITH_DEV_DBG_ASAN: | ||||
|     ) | ||||
|     sys.exit(0) | ||||
|  | ||||
| # bfloat16 is only supported by CUDA 11+ | ||||
| BFLOAT16_AVAILABLE = torch.cuda.is_available() and ( | ||||
|     torch.version.cuda is not None or torch.version.hip is not None | ||||
| ) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user