Revert "Add sym_size/stride/numel/storage_offset to native_function.yaml (#91919)"

This reverts commit 0388400f3f8a8ecae2f809ba40ca3ddd5a8b9028.

Reverted https://github.com/pytorch/pytorch/pull/91919 on behalf of https://github.com/atalman due to Break internal build
Author: PyTorch MergeBot
Date:   2023-01-17 21:03:18 +00:00
parent 88942a3199
commit befe815466
9 changed files with 25 additions and 64 deletions
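
For context (not part of this commit): the four operators named in the reverted PR are schema-level counterparts of SymInt accessors that at::Tensor already exposes in C++. A minimal sketch of those accessors, assuming a libtorch build that has the SymInt API:

    #include <torch/torch.h>

    int main() {
      at::Tensor t = torch::randn({2, 3});
      int64_t n = t.numel();                     // plain integer accessors
      int64_t s0 = t.size(0);
      c10::SymInt sym_n = t.sym_numel();         // SymInt accessors: concrete here,
      c10::SymInt sym_s0 = t.sym_size(0);        // but they can carry symbolic values
      c10::SymInt off = t.sym_storage_offset();  // under symbolic shape tracing
      (void)n; (void)s0; (void)sym_n; (void)sym_s0; (void)off;
      return 0;
    }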


@@ -19,9 +19,6 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
 }
 FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
-  auto alwaysCloneWithRealTypes = [&](const Argument& a) {
-    return a.cloneWithType(a.real_type());
-  };
   auto cloneWithRealTypes = [&](const Argument& a) {
     if (with_symint) {
       return a.cloneWithType(a.real_type());
@@ -42,8 +39,7 @@ FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
   };
   std::vector<Argument> new_arguments, new_returns;
   std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
-  // NB: SymInt returns are always SymInt
-  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), alwaysCloneWithRealTypes);
+  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
   return FunctionSchema(
       name(),
       overload_name(),
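
Not part of the commit: cloneWithRealTypes above swaps each argument's BC-facing type for its real type (an argument declared as SymInt is presented as int by default and as SymInt only in the "real" schema). A heavily hedged sketch of exercising it, assuming parseSchema is available and the method is publicly callable; the schema string is made up:

    #include <torch/csrc/jit/frontend/function_schema_parser.h>

    void clone_with_real_types_sketch() {
      // Hypothetical schema: "length" is declared SymInt, so its real type and
      // its BC (int) type differ.
      c10::FunctionSchema s = torch::jit::parseSchema(
          "example::op(Tensor self, SymInt length) -> Tensor");
      c10::FunctionSchema real = s.cloneWithRealTypes(/*with_symint=*/true);
      (void)real;
    }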


@@ -49,22 +49,6 @@ int64_t stride(const Tensor& self, int64_t dim) {
   return self.stride(dim);
 }
-c10::SymInt sym_size(const Tensor& self, int64_t dim) {
-  return self.sym_size(dim);
-}
-c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
-  return self.sym_stride(dim);
-}
-c10::SymInt sym_numel(const Tensor& self) {
-  return self.sym_numel();
-}
-c10::SymInt sym_storage_offset(const Tensor& self) {
-  return self.sym_storage_offset();
-}
 int64_t size(const Tensor& self, Dimname dim) {
   size_t pos_dim = dimname_to_position(self, dim);
   return self.sizes()[pos_dim];


@@ -5020,27 +5020,6 @@
   device_check: NoCheck
   device_guard: False
-- func: sym_size.int(Tensor self, int dim) -> SymInt
-  variants: function
-  device_check: NoCheck
-  device_guard: False
-  tags: canonical
-  manual_cpp_binding: True
-- func: sym_numel(Tensor self) -> SymInt
-  variants: function
-  device_check: NoCheck
-  device_guard: False
-  tags: canonical
-  manual_cpp_binding: True
-- func: sym_storage_offset(Tensor self) -> SymInt
-  variants: function
-  device_check: NoCheck
-  device_guard: False
-  tags: canonical
-  manual_cpp_binding: True
 - func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
   variants: function, method
   device_check: NoCheck
@@ -5315,13 +5294,6 @@
   device_check: NoCheck
   device_guard: False
-- func: sym_stride.int(Tensor self, int dim) -> SymInt
-  variants: function
-  device_check: NoCheck
-  device_guard: False
-  tags: canonical
-  manual_cpp_binding: True
 - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
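
All of the schemas removed above return SymInt rather than int. As background (not part of the commit), c10::SymInt is a value type that holds either a concrete integer or a symbolic expression and supports ordinary integer arithmetic; a minimal sketch:

    #include <c10/core/SymInt.h>

    void symint_sketch() {
      c10::SymInt rows{3};              // concrete SymInts wrap plain int64_t values
      c10::SymInt cols{4};
      c10::SymInt total = rows * cols;  // stays symbolic when the inputs are symbolic;
      (void)total;                      // here it is simply 12
    }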


@@ -294,10 +294,6 @@ xfail_not_implemented = {
     "aten::subtract_.Scalar",
     "aten::subtract_.Tensor",
     "aten::svd.U",
-    "aten::sym_size.int",
-    "aten::sym_stride.int",
-    "aten::sym_numel",
-    "aten::sym_storage_offset",
     "aten::tensor_split.indices",
     "aten::tensor_split.sections",
     "aten::tensor_split.tensor_indices_or_sections",


@@ -88,10 +88,6 @@ _SKIP_PYTHON_BINDINGS = [
     "is_sparse_csr",
     "size",
     "stride",
-    "sym_size",
-    "sym_stride",
-    "sym_storage_offset",
-    "sym_numel",
     ".*_backward",
     ".*_backward_(out|input|weight|bias)",
     ".*_forward",


@@ -415,6 +415,16 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
         TORCH_SELECTIVE_SCHEMA("aten::sym_size(Tensor self) -> SymInt[]"),
         sym_size,
         aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::sym_size.int(Tensor self, int dim) -> SymInt"),
+        sym_size_int,
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::sym_stride.int(Tensor self, int dim) -> SymInt"),
+        sym_stride_int,
+        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::stride(Tensor self) -> int[]"),
         [](Stack& stack) {
@@ -422,6 +432,15 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
           push(stack, arg.strides());
         },
         aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA("aten::sym_numel(Tensor self) -> SymInt"),
+        sym_numel,
+        aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::sym_storage_offset(Tensor self) -> SymInt"),
+        sym_storage_offset,
+        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::sym_stride(Tensor self) -> SymInt[]"),
         sym_stride,
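
The registrations restored above point at stack-convention handlers (sym_size_int, sym_stride_int, sym_numel, sym_storage_offset) that live elsewhere in the same file. A minimal sketch of what a handler like sym_size_int plausibly looks like; the actual definition may differ:

    #include <ATen/core/Tensor.h>
    #include <ATen/core/stack.h>

    // Arguments are pushed left-to-right (self, dim), so dim is popped first.
    void sym_size_int_sketch(torch::jit::Stack& stack) {
      int64_t dim = torch::jit::pop(stack).toInt();
      at::Tensor self = torch::jit::pop(stack).toTensor();
      torch::jit::push(stack, self.sym_size(dim));
    }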


@@ -226,9 +226,7 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> Named
 # and a function with a return type of 'std::tuple' has >1 return name.
 def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
     # placeholder is ignored
-    # NB: symint is ALWAYS respected for return types. So symint argument
-    # here is IGNORED
-    r = valuetype_type(t, binds="__placeholder__", symint=True)
+    r = valuetype_type(t, binds="__placeholder__", symint=symint)
     if r is not None:
         return r.type
@@ -251,7 +249,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
         assert (
             not mutable
         ), "Native functions should never return a mutable tensor list. They should return void."
-        elem = returntype_type(t.elem, mutable=False)
+        elem = returntype_type(t.elem, mutable=False, symint=symint)
         assert t.size is None, f"fixed size list returns not supported: {t}"
         return VectorCType(elem)
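
The torchgen change above means the symint flag of the requested signature once again decides how a SymInt-typed return is rendered in generated C++: int64_t for the plain signature, c10::SymInt for the symint one. Illustrative declarations only; the operator name is invented and the exact overload naming is decided elsewhere in torchgen:

    // symint=False signature: a SymInt return lowers to a plain integer
    int64_t example_op(const at::Tensor& self);
    // symint=True signature: the return stays a c10::SymInt
    c10::SymInt example_op_symint(const at::Tensor& self);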


@@ -35,8 +35,6 @@ class CppSignature:
     # Is this a symint C++ signature. For BC reasons, functions that take
     # SymInts still present as int64_t in C++, and the SymInt variant is
     # offered at a different overload name
-    #
-    # NB: If a function RETURNS a SymInt, this is ALWAYS false
     symint: bool
     # The set of C++ arguments which should not have defaults applied to them

@@ -1628,7 +1628,9 @@ class FunctionSchema:
         return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
     def has_symint(self) -> bool:
-        return self.arguments.has_symint_arg()
+        return self.arguments.has_symint_arg() or any(
+            r.type.is_symint_like() for r in self.returns
+        )
     def __str__(self) -> str:
         all_arguments_str = str(self.arguments)