mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Handle size/etc accessors in FakeTensor, support accessing symbolic types from toInt/etc in IValue (#124760)
Fixes https://github.com/pytorch/pytorch/issues/122772 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/124760 Approved by: https://github.com/albanD, https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
9bd6e93a04
commit
0d58aeb73a
@ -532,8 +532,13 @@ struct TORCH_API IValue final {
|
|||||||
return Tag::Double == tag;
|
return Tag::Double == tag;
|
||||||
}
|
}
|
||||||
double toDouble() const {
|
double toDouble() const {
|
||||||
AT_ASSERT(isDouble());
|
if (isDouble()) {
|
||||||
return payload.u.as_double;
|
return payload.u.as_double;
|
||||||
|
} else if (isSymFloat()) {
|
||||||
|
return toSymFloat().guard_float(__FILE__, __LINE__);
|
||||||
|
} else {
|
||||||
|
TORCH_INTERNAL_ASSERT(0, "expected double");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ComplexDouble
|
// ComplexDouble
|
||||||
@ -639,8 +644,13 @@ struct TORCH_API IValue final {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64_t toInt() const {
|
int64_t toInt() const {
|
||||||
AT_ASSERT(isInt());
|
if (isInt()) {
|
||||||
return payload.u.as_int;
|
return payload.u.as_int;
|
||||||
|
} else if (isSymInt()) {
|
||||||
|
return toSymInt().guard_int(__FILE__, __LINE__);
|
||||||
|
} else {
|
||||||
|
TORCH_INTERNAL_ASSERT(0, "expected int");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bool
|
// Bool
|
||||||
@ -658,8 +668,13 @@ struct TORCH_API IValue final {
|
|||||||
return Tag::Bool == tag;
|
return Tag::Bool == tag;
|
||||||
}
|
}
|
||||||
bool toBool() const {
|
bool toBool() const {
|
||||||
AT_ASSERT(isBool());
|
if (isBool()) {
|
||||||
return payload.u.as_bool;
|
return payload.u.as_bool;
|
||||||
|
} else if (isSymBool()) {
|
||||||
|
return toSymBool().guard_bool(__FILE__, __LINE__);
|
||||||
|
} else {
|
||||||
|
TORCH_INTERNAL_ASSERT(0, "expected bool");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IntList
|
// IntList
|
||||||
|
@ -5404,9 +5404,6 @@ symbolic_aot_autograd_failures = {
|
|||||||
xfail(
|
xfail(
|
||||||
"nn.functional.embedding_bag", ""
|
"nn.functional.embedding_bag", ""
|
||||||
), # Cannot call sizes() on tensor with symbolic sizes/strides
|
), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||||
xfail(
|
|
||||||
"nn.functional.fractional_max_pool2d", ""
|
|
||||||
), # rand() received an invalid combination of arguments - g...
|
|
||||||
xfail(
|
xfail(
|
||||||
"nn.functional.fractional_max_pool3d", ""
|
"nn.functional.fractional_max_pool3d", ""
|
||||||
), # rand() received an invalid combination of arguments - g...
|
), # rand() received an invalid combination of arguments - g...
|
||||||
@ -5608,7 +5605,6 @@ symbolic_aot_autograd_module_failures = {
|
|||||||
torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
|
torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
|
||||||
torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
|
torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
|
||||||
# TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
|
# TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int'
|
||||||
torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
|
|
||||||
torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
|
torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat'
|
||||||
torch.nn.BCELoss, # new_size = _infer_size(target.size(), weight.size())
|
torch.nn.BCELoss, # new_size = _infer_size(target.size(), weight.size())
|
||||||
# RuntimeError: expected int at position 0, but got: SymInt
|
# RuntimeError: expected int at position 0, but got: SymInt
|
||||||
|
@ -5598,7 +5598,7 @@ a")
|
|||||||
g = parse_ir(graph_str)
|
g = parse_ir(graph_str)
|
||||||
m = self.createFunctionFromGraph(g)
|
m = self.createFunctionFromGraph(g)
|
||||||
self.getExportImportCopy(m)
|
self.getExportImportCopy(m)
|
||||||
with self.assertRaisesRegex(RuntimeError, "isInt"):
|
with self.assertRaisesRegex(RuntimeError, "expected int"):
|
||||||
m()
|
m()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1905,9 +1905,7 @@ symbolic_tensor_failures = {
|
|||||||
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
|
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
|
||||||
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
|
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
|
||||||
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
|
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
|
||||||
xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
|
|
||||||
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
|
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
|
||||||
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
|
|
||||||
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
|
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
|
||||||
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
|
xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
|
||||||
|
|
||||||
@ -1958,29 +1956,17 @@ out_symbolic_tensor_failures = {
|
|||||||
xfail('angle', ''),
|
xfail('angle', ''),
|
||||||
xfail('argmax', ''),
|
xfail('argmax', ''),
|
||||||
xfail('argmin', ''),
|
xfail('argmin', ''),
|
||||||
xfail('bmm', ''),
|
|
||||||
xfail('fft.fft2', ''),
|
xfail('fft.fft2', ''),
|
||||||
xfail('fft.fftn', ''),
|
xfail('fft.fftn', ''),
|
||||||
xfail('fft.ifft2', ''),
|
xfail('fft.ifft2', ''),
|
||||||
xfail('fft.ifftn', ''),
|
xfail('fft.ifftn', ''),
|
||||||
xfail('gather', ''),
|
xfail('gather', ''),
|
||||||
xfail('linalg.cholesky', ''),
|
|
||||||
xfail('linalg.cholesky_ex', ''),
|
|
||||||
xfail('linalg.det', ''),
|
|
||||||
xfail('linalg.det', 'singular'),
|
|
||||||
xfail('linalg.inv', ''),
|
|
||||||
xfail('linalg.inv_ex', ''),
|
|
||||||
xfail('linalg.pinv', ''),
|
xfail('linalg.pinv', ''),
|
||||||
xfail('linalg.pinv', 'hermitian'),
|
xfail('linalg.pinv', 'hermitian'),
|
||||||
xfail('linalg.svdvals', ''),
|
|
||||||
xfail('lu', ''),
|
xfail('lu', ''),
|
||||||
xfail('max', 'reduction_with_dim'),
|
|
||||||
xfail('min', 'reduction_with_dim'),
|
|
||||||
xfail('nn.functional.avg_pool2d', ''),
|
|
||||||
xfail('scatter_add', ''),
|
xfail('scatter_add', ''),
|
||||||
xfail('scatter', ''),
|
xfail('scatter', ''),
|
||||||
xfail('take_along_dim', ''),
|
xfail('take_along_dim', ''),
|
||||||
xfail('topk', ''),
|
|
||||||
xfail('triangular_solve', ''),
|
xfail('triangular_solve', ''),
|
||||||
xfail('view_copy', ''),
|
xfail('view_copy', ''),
|
||||||
|
|
||||||
|
@ -540,6 +540,18 @@ class FakeTensor(torch.Tensor):
|
|||||||
else:
|
else:
|
||||||
return args[0].fake_device
|
return args[0].fake_device
|
||||||
|
|
||||||
|
# this handler must be done inside FakeTensor subclass, not mode, because
|
||||||
|
# we can end up dispatching here when we have a fake tensor with
|
||||||
|
# symbolic sizes running under in_kernel_invocation_manager.
|
||||||
|
# The subclass is asked to handle this query because size (not
|
||||||
|
# sym_size) was called, but we are unable to serve it directly because
|
||||||
|
# there are symbolic sizes in the class. The use of
|
||||||
|
# in_kernel_invocation_manager means it's incorrect to activate a
|
||||||
|
# mode to actually handle this (this caused
|
||||||
|
# https://github.com/pytorch/pytorch/issues/122772).
|
||||||
|
if handler := _DISPATCH_META_HANDLERS.get(func):
|
||||||
|
return handler(args)
|
||||||
|
|
||||||
# Because fake mode can return NotImplemented (if it sees a subclass
|
# Because fake mode can return NotImplemented (if it sees a subclass
|
||||||
# it doesn't know how to deal with), this test here is important
|
# it doesn't know how to deal with), this test here is important
|
||||||
# because the next dispatch after a fake mode will attempt to use
|
# because the next dispatch after a fake mode will attempt to use
|
||||||
@ -1468,6 +1480,9 @@ class FakeTensorMode(TorchDispatchMode):
|
|||||||
r = func(*args, **kwargs)
|
r = func(*args, **kwargs)
|
||||||
except NotImplementedError as not_implemented_error:
|
except NotImplementedError as not_implemented_error:
|
||||||
return maybe_run_unsafe_fallback(not_implemented_error)
|
return maybe_run_unsafe_fallback(not_implemented_error)
|
||||||
|
except Exception:
|
||||||
|
log.exception("failed while attempting to run meta for %s", func)
|
||||||
|
raise
|
||||||
|
|
||||||
return self.wrap_meta_outputs_with_default_device_logic(
|
return self.wrap_meta_outputs_with_default_device_logic(
|
||||||
r, func, flat_args, device=kwargs.get("device")
|
r, func, flat_args, device=kwargs.get("device")
|
||||||
|
Reference in New Issue
Block a user