Handle size/etc accessors in FakeTensor, support accessing symbolic types from toInt/etc in IValue (#124760)

Fixes https://github.com/pytorch/pytorch/issues/122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124760
Approved by: https://github.com/albanD, https://github.com/eellison
Author:       Edward Z. Yang
Date:         2024-04-24 16:29:37 -04:00
Committed by: PyTorch MergeBot
Parent:       9bd6e93a04
Commit:       0d58aeb73a
10 changed files with 37 additions and 25 deletions
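For orientation, a minimal sketch (not from the PR) of the kind of setup this change is about, assuming the usual dynamic-shapes entry points (FakeTensorMode with a ShapeEnv, from_tensor with static_shapes=False). Which ops the _DISPATCH_META_HANDLERS table introduced below actually covers is not shown in these hunks, so that part is an assumption.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Build a fake tensor whose sizes are symbolic (backed by a ShapeEnv).
    mode = FakeTensorMode(shape_env=ShapeEnv())
    x = mode.from_tensor(torch.randn(4, 8), static_shapes=False)

    # x.size() now contains SymInts rather than plain ints. Metadata queries
    # that reach the FakeTensor subclass instead of the mode are what the new
    # handler lookup below serves; previously this path failed (issue #122772).
    print(x.size())
    print(x.stride())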

@@ -532,8 +532,13 @@ struct TORCH_API IValue final {
     return Tag::Double == tag;
   }
   double toDouble() const {
-    AT_ASSERT(isDouble());
-    return payload.u.as_double;
+    if (isDouble()) {
+      return payload.u.as_double;
+    } else if (isSymFloat()) {
+      return toSymFloat().guard_float(__FILE__, __LINE__);
+    } else {
+      TORCH_INTERNAL_ASSERT(0, "expected double");
+    }
   }
   // ComplexDouble
@@ -639,8 +644,13 @@ struct TORCH_API IValue final {
   }
   int64_t toInt() const {
-    AT_ASSERT(isInt());
-    return payload.u.as_int;
+    if (isInt()) {
+      return payload.u.as_int;
+    } else if (isSymInt()) {
+      return toSymInt().guard_int(__FILE__, __LINE__);
+    } else {
+      TORCH_INTERNAL_ASSERT(0, "expected int");
+    }
   }
   // Bool
@@ -658,8 +668,13 @@ struct TORCH_API IValue final {
     return Tag::Bool == tag;
   }
   bool toBool() const {
-    AT_ASSERT(isBool());
-    return payload.u.as_bool;
+    if (isBool()) {
+      return payload.u.as_bool;
+    } else if (isSymBool()) {
+      return toSymBool().guard_bool(__FILE__, __LINE__);
+    } else {
+      TORCH_INTERNAL_ASSERT(0, "expected bool");
+    }
   }
   // IntList
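The net effect of the three hunks above: toDouble/toInt/toBool no longer hard-assert on the payload tag; a SymFloat/SymInt/SymBool payload is guarded down to a concrete value via guard_float/guard_int/guard_bool. A small Python-level analogue of that guarding behavior (a sketch for illustration only, not code from this PR):

    import torch

    @torch.compile(dynamic=True)
    def f(x):
        # Under dynamic shapes, x.size(0) is a SymInt; int() installs a guard
        # and specializes it to the backing value, which is the same semantics
        # as toSymInt().guard_int(__FILE__, __LINE__) in the new IValue::toInt.
        n = int(x.size(0))
        return x.new_zeros(n)

    print(f(torch.randn(5)).shape)  # expected: torch.Size([5])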

@@ -5404,9 +5404,6 @@ symbolic_aot_autograd_failures = {
     xfail(
         "nn.functional.embedding_bag", ""
     ), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail(
-        "nn.functional.fractional_max_pool2d", ""
-    ), # rand() received an invalid combination of arguments - g...
     xfail(
         "nn.functional.fractional_max_pool3d", ""
     ), # rand() received an invalid combination of arguments - g...
@@ -5608,7 +5605,6 @@ symbolic_aot_autograd_module_failures = {
     torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
     torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
     # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
-    torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
     torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
     torch.nn.BCELoss, # new_size = _infer_size(target.size(), weight.size())
     # RuntimeError: expected int at position 0, but got: SymInt

@ -5598,7 +5598,7 @@ a")
g = parse_ir(graph_str) g = parse_ir(graph_str)
m = self.createFunctionFromGraph(g) m = self.createFunctionFromGraph(g)
self.getExportImportCopy(m) self.getExportImportCopy(m)
with self.assertRaisesRegex(RuntimeError, "isInt"): with self.assertRaisesRegex(RuntimeError, "expected int"):
m() m()

@@ -1905,9 +1905,7 @@ symbolic_tensor_failures = {
     xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
-    xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
     xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
-    xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
     xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
@@ -1958,29 +1956,17 @@ out_symbolic_tensor_failures = {
     xfail('angle', ''),
     xfail('argmax', ''),
     xfail('argmin', ''),
-    xfail('bmm', ''),
     xfail('fft.fft2', ''),
     xfail('fft.fftn', ''),
     xfail('fft.ifft2', ''),
     xfail('fft.ifftn', ''),
     xfail('gather', ''),
-    xfail('linalg.cholesky', ''),
-    xfail('linalg.cholesky_ex', ''),
-    xfail('linalg.det', ''),
-    xfail('linalg.det', 'singular'),
-    xfail('linalg.inv', ''),
-    xfail('linalg.inv_ex', ''),
     xfail('linalg.pinv', ''),
     xfail('linalg.pinv', 'hermitian'),
-    xfail('linalg.svdvals', ''),
     xfail('lu', ''),
-    xfail('max', 'reduction_with_dim'),
-    xfail('min', 'reduction_with_dim'),
-    xfail('nn.functional.avg_pool2d', ''),
     xfail('scatter_add', ''),
     xfail('scatter', ''),
     xfail('take_along_dim', ''),
-    xfail('topk', ''),
     xfail('triangular_solve', ''),
     xfail('view_copy', ''),

@@ -540,6 +540,18 @@ class FakeTensor(torch.Tensor):
             else:
                 return args[0].fake_device

+        # this handler must be done inside FakeTensor subclass, not mode, because
+        # we can end up dispatching here when we have a fake tensor with
+        # symbolic sizes running under in_kernel_invocation_manager.
+        # The subclass is asked to handle this query because size (not
+        # sym_size) was called, but we are unable to serve it directly because
+        # there are symbolic sizes in the class. The use of
+        # in_kernel_invocation_manager means it's incorrect to activate a
+        # mode to actually handle this (this caused
+        # https://github.com/pytorch/pytorch/issues/122772).
+        if handler := _DISPATCH_META_HANDLERS.get(func):
+            return handler(args)
+
         # Because fake mode can return NotImplemented (if it sees a subclass
         # it doesn't know how to deal with), this test here is important
         # because the next dispatch after a fake mode will attempt to use
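A hedged sketch of the dispatch path the comment in the hunk above describes. Assumptions not confirmed by the diff: that aten.size.default is among the ops _DISPATCH_META_HANDLERS covers, and that calling it directly from Python reaches the subclass handler.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode, in_kernel_invocation_manager
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    mode = FakeTensorMode(shape_env=ShapeEnv())
    x = mode.from_tensor(torch.randn(2, 3), static_shapes=False)  # symbolic sizes

    # Inside in_kernel_invocation_manager the fake mode is not the active
    # handler, so a raw metadata query on x dispatches to the FakeTensor
    # subclass's __torch_dispatch__. Before this change that path had no
    # answer when the sizes are symbolic (issue #122772); the handler lookup
    # added above is intended to serve it.
    with in_kernel_invocation_manager(mode):
        sizes = torch.ops.aten.size.default(x)  # assumed to hit the new handler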
@@ -1468,6 +1480,9 @@ class FakeTensorMode(TorchDispatchMode):
             r = func(*args, **kwargs)
         except NotImplementedError as not_implemented_error:
             return maybe_run_unsafe_fallback(not_implemented_error)
+        except Exception:
+            log.exception("failed while attempting to run meta for %s", func)
+            raise

         return self.wrap_meta_outputs_with_default_device_logic(
             r, func, flat_args, device=kwargs.get("device")