mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Sync changes from pytorch/torchdynamo, enable tests (#86950)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86950 Approved by: https://github.com/Chillee
This commit is contained in:
committed by
PyTorch MergeBot
parent
78ef40973c
commit
8f71e8de7e
@ -198,12 +198,22 @@ class OperatorInputsMode(TorchDispatchMode):
|
||||
|
||||
|
||||
def map_to_device(e, device):
|
||||
return e.to(device) if isinstance(e, torch.Tensor) else e
|
||||
if isinstance(e, torch.Tensor):
|
||||
return e.to(device)
|
||||
elif isinstance(e, torch.device):
|
||||
return device
|
||||
elif isinstance(e, str):
|
||||
if e == "cuda" or e == "cpu":
|
||||
return device.type
|
||||
else:
|
||||
return e
|
||||
|
||||
|
||||
def map_to_dtype(e, dtype):
|
||||
if isinstance(e, torch.Tensor) and e.is_floating_point():
|
||||
return e.to(dtype)
|
||||
elif isinstance(e, torch.dtype):
|
||||
return dtype
|
||||
else:
|
||||
return e
|
||||
|
||||
|
Reference in New Issue
Block a user