mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
removed zero dim cpu logic from fake_tensor.py (#147501)
Fixes #144748
In #144748, an inconsistency between eager mode and inductor mode was reported as a bug.
The root cause is fake_tensor.py's find-common-device method, 0b0da81021/torch/_subclasses/fake_tensor.py (L833)
, which takes zero-dim CPU tensors into account, while the device check in adaption.h does not.
This fix adds a list of ops that bypass the zero-dim-CPU-tensor check, aligning fake-tensor device propagation with eager mode.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147501
Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
5e17932c22
commit
9e0473b566
@ -211,6 +211,22 @@ class FakeTensorTest(TestCase):
|
||||
self.assertEqual(out.device, y.device)
|
||||
self.assertTrue(isinstance(out, FakeTensor))
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_op_with_zero_dim_bypassed(self):
|
||||
if torch._functorch.config.fake_tensor_propagate_real_tensors:
|
||||
return
|
||||
shape_env = ShapeEnv()
|
||||
mode = FakeTensorMode(shape_env=shape_env)
|
||||
x = torch.tensor(1.0, device="cuda")
|
||||
y = torch.tensor(2.0)
|
||||
fake_x = mode.from_tensor(x)
|
||||
fake_y = mode.from_tensor(y)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Unhandled FakeTensor Device Propagation for.*"
|
||||
) as exc:
|
||||
torch.nextafter(fake_x, fake_y)
|
||||
|
||||
def test_nan_to_num(self):
|
||||
with FakeTensorMode():
|
||||
for dtype in [torch.float16, torch.float32]:
|
||||
|
@ -889,6 +889,11 @@ class FakeTensor(Tensor):
|
||||
aten._foreach_copy.default,
|
||||
)
|
||||
|
||||
# list of ops not using zero dim cpu tensor logic to align with the eager mode.
|
||||
bypass_zero_dim_cpu_tensor_check_ops = ordered_set(
|
||||
aten.nextafter.default,
|
||||
)
|
||||
|
||||
def check_cpu_device(device: torch.device) -> bool:
|
||||
return device.type == "cpu"
|
||||
|
||||
@ -912,13 +917,17 @@ class FakeTensor(Tensor):
|
||||
is_cpu_zero_dim = t_is_cpu_zero_dim
|
||||
return
|
||||
|
||||
is_bypass_zero_dim_cpu_tensor_check_op = (
|
||||
func in bypass_zero_dim_cpu_tensor_check_ops
|
||||
)
|
||||
|
||||
# mismatching devices !
|
||||
# if current tensor is cpu 0 dim, defer to existing device
|
||||
if t_is_cpu_zero_dim:
|
||||
if t_is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op:
|
||||
return
|
||||
|
||||
# current device is from cpu 0 dim tensor, overwrite
|
||||
if is_cpu_zero_dim:
|
||||
if is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op:
|
||||
common_device = t.device
|
||||
is_cpu_zero_dim = t_is_cpu_zero_dim
|
||||
return
|
||||
|
Reference in New Issue
Block a user