Mirror of https://github.com/pytorch/pytorch.git
	This is a big bang PR, merge conflicts are probably expected and will be addressed at merge. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81643 Approved by: https://github.com/ezyang
22 lines · 815 B · Python
"""
 | 
						|
This global flag controls whether to assign new tensors to the parameters
 | 
						|
instead of changing the existing parameters in-place when converting an `nn.Module`
 | 
						|
using the following methods:
 | 
						|
1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
 | 
						|
2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
 | 
						|
3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
 | 
						|
4. `module._apply(fn)` (for generic functions applied to `module`)
 | 
						|
 | 
						|
Default: False
 | 
						|
"""
 | 
						|
_overwrite_module_params_on_conversion = False
 | 
						|
 | 
						|
 | 
						|
def set_overwrite_module_params_on_conversion(value):
 | 
						|
    global _overwrite_module_params_on_conversion
 | 
						|
    _overwrite_module_params_on_conversion = value
 | 
						|
 | 
						|
 | 
						|
def get_overwrite_module_params_on_conversion():
 | 
						|
    return _overwrite_module_params_on_conversion
 |
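A minimal usage sketch of the flag's effect, assuming the standard PyTorch layout where this file is importable as `torch.__future__` (the `linear` module and `before` variable below are illustrative, not part of the file above): with the default `False`, a dtype conversion such as `module.double()` rewrites each parameter's data in place, so the existing `Parameter` objects survive the conversion; with the flag set to `True`, the conversion assigns newly created tensors to the parameters instead.

import torch.nn as nn
from torch.__future__ import (
    get_overwrite_module_params_on_conversion,
    set_overwrite_module_params_on_conversion,
)

linear = nn.Linear(2, 2)

# Default (False): conversion mutates the existing parameters in place,
# so the same Parameter objects remain after the dtype change.
before = linear.weight
linear.double()
print(before is linear.weight)  # expected: True

# Flag on (True): conversion assigns new tensors to the parameters,
# so `linear.weight` is a different Parameter object afterwards.
set_overwrite_module_params_on_conversion(True)
before = linear.weight
linear.float()
print(before is linear.weight)  # expected: False

set_overwrite_module_params_on_conversion(False)  # restore the default
assert get_overwrite_module_params_on_conversion() is False

Since the flag is process-global rather than per-module, code that flips it (as in the sketch) should restore the previous value when done.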