Handle size/etc accessors in FakeTensor, support accessing symbolic types from toInt/etc in IValue (#124760)

Fixes https://github.com/pytorch/pytorch/issues/122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124760
Approved by: https://github.com/albanD, https://github.com/eellison
This commit is contained in:
Edward Z. Yang
2024-04-24 16:29:37 -04:00
committed by PyTorch MergeBot
parent 9bd6e93a04
commit 0d58aeb73a
10 changed files with 37 additions and 25 deletions

View File

@ -532,8 +532,13 @@ struct TORCH_API IValue final {
return Tag::Double == tag;
}
double toDouble() const {
AT_ASSERT(isDouble());
if (isDouble()) {
return payload.u.as_double;
} else if (isSymFloat()) {
return toSymFloat().guard_float(__FILE__, __LINE__);
} else {
TORCH_INTERNAL_ASSERT(0, "expected double");
}
}
// ComplexDouble
@ -639,8 +644,13 @@ struct TORCH_API IValue final {
}
int64_t toInt() const {
AT_ASSERT(isInt());
if (isInt()) {
return payload.u.as_int;
} else if (isSymInt()) {
return toSymInt().guard_int(__FILE__, __LINE__);
} else {
TORCH_INTERNAL_ASSERT(0, "expected int");
}
}
// Bool
@ -658,8 +668,13 @@ struct TORCH_API IValue final {
return Tag::Bool == tag;
}
bool toBool() const {
AT_ASSERT(isBool());
if (isBool()) {
return payload.u.as_bool;
} else if (isSymBool()) {
return toSymBool().guard_bool(__FILE__, __LINE__);
} else {
TORCH_INTERNAL_ASSERT(0, "expected bool");
}
}
// IntList

View File

@ -5404,9 +5404,6 @@ symbolic_aot_autograd_failures = {
xfail(
"nn.functional.embedding_bag", ""
), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail(
"nn.functional.fractional_max_pool2d", ""
), # rand() received an invalid combination of arguments - g...
xfail(
"nn.functional.fractional_max_pool3d", ""
), # rand() received an invalid combination of arguments - g...
@ -5608,7 +5605,6 @@ symbolic_aot_autograd_module_failures = {
torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
# TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
torch.nn.BCELoss, # new_size = _infer_size(target.size(), weight.size())
# RuntimeError: expected int at position 0, but got: SymInt

View File

@ -5598,7 +5598,7 @@ a")
g = parse_ir(graph_str)
m = self.createFunctionFromGraph(g)
self.getExportImportCopy(m)
with self.assertRaisesRegex(RuntimeError, "isInt"):
with self.assertRaisesRegex(RuntimeError, "expected int"):
m()

View File

@ -1905,9 +1905,7 @@ symbolic_tensor_failures = {
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
@ -1958,29 +1956,17 @@ out_symbolic_tensor_failures = {
xfail('angle', ''),
xfail('argmax', ''),
xfail('argmin', ''),
xfail('bmm', ''),
xfail('fft.fft2', ''),
xfail('fft.fftn', ''),
xfail('fft.ifft2', ''),
xfail('fft.ifftn', ''),
xfail('gather', ''),
xfail('linalg.cholesky', ''),
xfail('linalg.cholesky_ex', ''),
xfail('linalg.det', ''),
xfail('linalg.det', 'singular'),
xfail('linalg.inv', ''),
xfail('linalg.inv_ex', ''),
xfail('linalg.pinv', ''),
xfail('linalg.pinv', 'hermitian'),
xfail('linalg.svdvals', ''),
xfail('lu', ''),
xfail('max', 'reduction_with_dim'),
xfail('min', 'reduction_with_dim'),
xfail('nn.functional.avg_pool2d', ''),
xfail('scatter_add', ''),
xfail('scatter', ''),
xfail('take_along_dim', ''),
xfail('topk', ''),
xfail('triangular_solve', ''),
xfail('view_copy', ''),

View File

@ -540,6 +540,18 @@ class FakeTensor(torch.Tensor):
else:
return args[0].fake_device
# this handler must be done inside FakeTensor subclass, not mode, because
# we can end up dispatching here when we have a fake tensor with
# symbolic sizes running under in_kernel_invocation_manager.
# The subclass is asked to handle this query because size (not
# sym_size) was called, but we are unable to serve it directly because
# there are symbolic sizes in the class. The use of
# in_kernel_invocation_manager means it's incorrect to activate a
# mode to actually handle this (this caused
# https://github.com/pytorch/pytorch/issues/122772).
if handler := _DISPATCH_META_HANDLERS.get(func):
return handler(args)
# Because fake mode can return NotImplemented (if it sees a subclass
# it doesn't know how to deal with), this test here is important
# because the next dispatch after a fake mode will attempt to use
@ -1468,6 +1480,9 @@ class FakeTensorMode(TorchDispatchMode):
r = func(*args, **kwargs)
except NotImplementedError as not_implemented_error:
return maybe_run_unsafe_fallback(not_implemented_error)
except Exception:
log.exception("failed while attempting to run meta for %s", func)
raise
return self.wrap_meta_outputs_with_default_device_logic(
r, func, flat_args, device=kwargs.get("device")