Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Handle size/etc accessors in FakeTensor, support accessing symbolic types from toInt/etc in IValue (#124760)
Fixes https://github.com/pytorch/pytorch/issues/122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124760
Approved by: https://github.com/albanD, https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: 9bd6e93a04
Commit: 0d58aeb73a
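At the Python level, the "size/etc accessors" in the title are ordinary metadata queries on fake tensors. A minimal, illustrative sketch of the setup involved (public APIs only, not taken from the PR's test plan; the sizes below are concrete, but under dynamic-shape tracing the same queries carry SymInts, which is the case this commit handles):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# A FakeTensorMode backed by a ShapeEnv, the configuration under which sizes
# can become symbolic (SymInt) during dynamic-shape tracing.
mode = FakeTensorMode(shape_env=ShapeEnv())
with mode:
    t = torch.empty(4, 8)                 # a FakeTensor; no real storage is allocated
    print(t.size(), t.dim(), t.numel())   # the metadata accessors the commit title refers to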
@@ -532,8 +532,13 @@ struct TORCH_API IValue final {
     return Tag::Double == tag;
   }
   double toDouble() const {
-    AT_ASSERT(isDouble());
-    return payload.u.as_double;
+    if (isDouble()) {
+      return payload.u.as_double;
+    } else if (isSymFloat()) {
+      return toSymFloat().guard_float(__FILE__, __LINE__);
+    } else {
+      TORCH_INTERNAL_ASSERT(0, "expected double");
+    }
   }
 
   // ComplexDouble
@@ -639,8 +644,13 @@ struct TORCH_API IValue final {
   }
 
   int64_t toInt() const {
-    AT_ASSERT(isInt());
-    return payload.u.as_int;
+    if (isInt()) {
+      return payload.u.as_int;
+    } else if (isSymInt()) {
+      return toSymInt().guard_int(__FILE__, __LINE__);
+    } else {
+      TORCH_INTERNAL_ASSERT(0, "expected int");
+    }
   }
 
   // Bool
@@ -658,8 +668,13 @@ struct TORCH_API IValue final {
     return Tag::Bool == tag;
   }
   bool toBool() const {
-    AT_ASSERT(isBool());
-    return payload.u.as_bool;
+    if (isBool()) {
+      return payload.u.as_bool;
+    } else if (isSymBool()) {
+      return toSymBool().guard_bool(__FILE__, __LINE__);
+    } else {
+      TORCH_INTERNAL_ASSERT(0, "expected bool");
+    }
   }
 
   // IntList
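The guard_float / guard_int / guard_bool calls above specialize a symbolic value to a concrete one, recording a guard attributed to the given file and line. A hedged Python-level sketch of the same specialization behaviour under dynamic-shape tracing (illustrative only; backend="eager" just keeps the example lightweight):

import torch

def f(x):
    n = x.size(0)               # a SymInt while tracing with dynamic shapes
    return x.new_zeros(int(n))  # int(SymInt) guards, i.e. specializes on the observed value

# The int() call forces a guard on x.size(0), which is the effect guard_int
# gives IValue::toInt() when it is handed a SymInt.
compiled = torch.compile(f, dynamic=True, backend="eager")
print(compiled(torch.randn(5)).shape)  # torch.Size([5])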
@@ -5404,9 +5404,6 @@ symbolic_aot_autograd_failures = {
    xfail(
        "nn.functional.embedding_bag", ""
    ), # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail(
        "nn.functional.fractional_max_pool2d", ""
    ), # rand() received an invalid combination of arguments - g...
    xfail(
        "nn.functional.fractional_max_pool3d", ""
    ), # rand() received an invalid combination of arguments - g...
@@ -5608,7 +5605,6 @@ symbolic_aot_autograd_module_failures = {
    torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
    torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
    # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
    torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
    torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
    torch.nn.BCELoss, # new_size = _infer_size(target.size(), weight.size())
    # RuntimeError: expected int at position 0, but got: SymInt
@@ -5598,7 +5598,7 @@ a")
         g = parse_ir(graph_str)
         m = self.createFunctionFromGraph(g)
         self.getExportImportCopy(m)
-        with self.assertRaisesRegex(RuntimeError, "isInt"):
+        with self.assertRaisesRegex(RuntimeError, "expected int"):
             m()
@@ -1905,9 +1905,7 @@ symbolic_tensor_failures = {
    xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
    xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
    xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
    xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
@@ -1958,29 +1956,17 @@ out_symbolic_tensor_failures = {
    xfail('angle', ''),
    xfail('argmax', ''),
    xfail('argmin', ''),
    xfail('bmm', ''),
    xfail('fft.fft2', ''),
    xfail('fft.fftn', ''),
    xfail('fft.ifft2', ''),
    xfail('fft.ifftn', ''),
    xfail('gather', ''),
    xfail('linalg.cholesky', ''),
    xfail('linalg.cholesky_ex', ''),
    xfail('linalg.det', ''),
    xfail('linalg.det', 'singular'),
    xfail('linalg.inv', ''),
    xfail('linalg.inv_ex', ''),
    xfail('linalg.pinv', ''),
    xfail('linalg.pinv', 'hermitian'),
    xfail('linalg.svdvals', ''),
    xfail('lu', ''),
    xfail('max', 'reduction_with_dim'),
    xfail('min', 'reduction_with_dim'),
    xfail('nn.functional.avg_pool2d', ''),
    xfail('scatter_add', ''),
    xfail('scatter', ''),
    xfail('take_along_dim', ''),
    xfail('topk', ''),
    xfail('triangular_solve', ''),
    xfail('view_copy', ''),
@@ -540,6 +540,18 @@ class FakeTensor(torch.Tensor):
         else:
             return args[0].fake_device
 
+        # this handler must be done inside FakeTensor subclass, not mode, because
+        # we can end up dispatching here when we have a fake tensor with
+        # symbolic sizes running under in_kernel_invocation_manager.
+        # The subclass is asked to handle this query because size (not
+        # sym_size) was called, but we are unable to serve it directly because
+        # there are symbolic sizes in the class. The use of
+        # in_kernel_invocation_manager means it's incorrect to activate a
+        # mode to actually handle this (this caused
+        # https://github.com/pytorch/pytorch/issues/122772).
+        if handler := _DISPATCH_META_HANDLERS.get(func):
+            return handler(args)
+
         # Because fake mode can return NotImplemented (if it sees a subclass
         # it doesn't know how to deal with), this test here is important
         # because the next dispatch after a fake mode will attempt to use
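The added lookup is a plain handler-table dispatch keyed by the incoming op. A generic toy of the idiom, to make the control flow explicit (names are invented; the real table is _DISPATCH_META_HANDLERS, defined elsewhere in fake_tensor.py):

# Purely illustrative dispatch table; the string keys stand in for op overloads.
_META_HANDLERS = {
    "size": lambda args: tuple(args[0]),
    "dim": lambda args: len(args[0]),
}

def handle(op, args):
    # Walrus operator: look up and bind the handler in one expression,
    # falling through when the op has no registered handler.
    if handler := _META_HANDLERS.get(op):
        return handler(args)
    raise NotImplementedError(op)

print(handle("size", ([4, 8],)))  # (4, 8)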
@@ -1468,6 +1480,9 @@ class FakeTensorMode(TorchDispatchMode):
                 r = func(*args, **kwargs)
             except NotImplementedError as not_implemented_error:
                 return maybe_run_unsafe_fallback(not_implemented_error)
+            except Exception:
+                log.exception("failed while attempting to run meta for %s", func)
+                raise
 
             return self.wrap_meta_outputs_with_default_device_logic(
                 r, func, flat_args, device=kwargs.get("device")
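The new except-Exception branch records which op failed while running its meta implementation, then re-raises. A minimal way to surface that record with stdlib logging (assuming log is fake_tensor.py's module-level logger, the usual logging.getLogger(__name__) pattern):

import logging

# log.exception(...) emits at ERROR level, so a root handler at ERROR (or
# lower) is enough to see "failed while attempting to run meta for <op>"
# before the original exception propagates.
logging.basicConfig(level=logging.ERROR)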