mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Reland] Add sym_size/stride/numel/storage_offset to native_function.… (#100749)
…yaml (#91… (#91919) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/91919 Approved by: https://github.com/ezyang Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/92402 Reviewed By: ezyang Differential Revision: D42565586 Pulled By: SherlockNoMad fbshipit-source-id: 1c2986e45307e076d239836a1b45441a9fa3c9d9 ghstack-source-id: 969f4928486e04c57aaf98e20e3c3ca946c51613 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/100749 Approved by: https://github.com/zhxchen17, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
4dbab17edb
commit
bb454891ed
@ -19,6 +19,9 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
|
||||
}
|
||||
|
||||
FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
|
||||
auto alwaysCloneWithRealTypes = [&](const Argument& a) {
|
||||
return a.cloneWithType(a.real_type());
|
||||
};
|
||||
auto cloneWithRealTypes = [&](const Argument& a) {
|
||||
if (with_symint) {
|
||||
return a.cloneWithType(a.real_type());
|
||||
@ -39,7 +42,8 @@ FunctionSchema FunctionSchema::cloneWithRealTypes(bool with_symint) const {
|
||||
};
|
||||
std::vector<Argument> new_arguments, new_returns;
|
||||
std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
|
||||
std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
|
||||
// NB: SymInt returns are always SymInt
|
||||
std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), alwaysCloneWithRealTypes);
|
||||
return FunctionSchema(
|
||||
name(),
|
||||
overload_name(),
|
||||
|
@ -49,6 +49,22 @@ int64_t stride(const Tensor& self, int64_t dim) {
|
||||
return self.stride(dim);
|
||||
}
|
||||
|
||||
c10::SymInt sym_size(const Tensor& self, int64_t dim) {
|
||||
return self.sym_size(dim);
|
||||
}
|
||||
|
||||
c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
|
||||
return self.sym_stride(dim);
|
||||
}
|
||||
|
||||
c10::SymInt sym_numel(const Tensor& self) {
|
||||
return self.sym_numel();
|
||||
}
|
||||
|
||||
c10::SymInt sym_storage_offset(const Tensor& self) {
|
||||
return self.sym_storage_offset();
|
||||
}
|
||||
|
||||
int64_t size(const Tensor& self, Dimname dim) {
|
||||
size_t pos_dim = dimname_to_position(self, dim);
|
||||
return self.sizes()[pos_dim];
|
||||
|
@ -5103,6 +5103,27 @@
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
|
||||
- func: sym_size.int(Tensor self, int dim) -> SymInt
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
tags: core
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: sym_numel(Tensor self) -> SymInt
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
tags: core
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: sym_storage_offset(Tensor self) -> SymInt
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
tags: core
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
|
||||
variants: function, method
|
||||
device_check: NoCheck
|
||||
@ -5378,6 +5399,13 @@
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
|
||||
- func: sym_stride.int(Tensor self, int dim) -> SymInt
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
tags: core
|
||||
manual_cpp_binding: True
|
||||
|
||||
- func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
|
@ -262,6 +262,10 @@ xfail_not_implemented = {
|
||||
"aten::subtract_.Scalar",
|
||||
"aten::subtract_.Tensor",
|
||||
"aten::svd.U",
|
||||
"aten::sym_size.int",
|
||||
"aten::sym_stride.int",
|
||||
"aten::sym_numel",
|
||||
"aten::sym_storage_offset",
|
||||
"aten::tensor_split.indices",
|
||||
"aten::tensor_split.sections",
|
||||
"aten::tensor_split.tensor_indices_or_sections",
|
||||
|
@ -89,6 +89,10 @@ _SKIP_PYTHON_BINDINGS = [
|
||||
"is_sparse_csr",
|
||||
"size",
|
||||
"stride",
|
||||
"sym_size",
|
||||
"sym_stride",
|
||||
"sym_storage_offset",
|
||||
"sym_numel",
|
||||
".*_backward",
|
||||
".*_backward_(out|input|weight|bias)",
|
||||
".*_forward",
|
||||
|
@ -924,10 +924,21 @@ def export(
|
||||
):
|
||||
super().__init__(m)
|
||||
arg_len = len(flat_args)
|
||||
self.new_args = [
|
||||
super(ChangeInputOutputSignature, self).placeholder(f"arg{i}", (), {})
|
||||
for i in range(0, arg_len)
|
||||
]
|
||||
self.new_args = []
|
||||
for i in range(0, arg_len):
|
||||
arg = super(ChangeInputOutputSignature, self).placeholder(
|
||||
f"arg{i}", (), {}
|
||||
)
|
||||
# Fill node.mata["val"] with faketensolintrunner from the input,
|
||||
# if it's not found in matched_input_elements_positions
|
||||
if (
|
||||
i not in matched_input_elements_positions
|
||||
and fake_mode is not None
|
||||
and isinstance(flat_args[i], torch.Tensor)
|
||||
):
|
||||
arg.node.meta["val"] = fake_mode.from_tensor(flat_args[i])
|
||||
self.new_args.append(arg)
|
||||
|
||||
self.old_args_gen = (
|
||||
self.new_args[i] for i in matched_input_elements_positions
|
||||
)
|
||||
|
@ -415,16 +415,6 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
TORCH_SELECTIVE_SCHEMA("aten::sym_size(Tensor self) -> SymInt[]"),
|
||||
sym_size,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::sym_size.int(Tensor self, int dim) -> SymInt"),
|
||||
sym_size_int,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::sym_stride.int(Tensor self, int dim) -> SymInt"),
|
||||
sym_stride_int,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::stride(Tensor self) -> int[]"),
|
||||
[](Stack& stack) {
|
||||
@ -432,15 +422,6 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
push(stack, arg.strides());
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::sym_numel(Tensor self) -> SymInt"),
|
||||
sym_numel,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::sym_storage_offset(Tensor self) -> SymInt"),
|
||||
sym_storage_offset,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA("aten::sym_stride(Tensor self) -> SymInt[]"),
|
||||
sym_stride,
|
||||
|
@ -226,7 +226,9 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> Named
|
||||
# and a function with a return type of 'std::tuple' has >1 return name.
|
||||
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
|
||||
# placeholder is ignored
|
||||
r = valuetype_type(t, binds="__placeholder__", symint=symint)
|
||||
# NB: symint is ALWAYS respected for return types. So symint argument
|
||||
# here is IGNORED
|
||||
r = valuetype_type(t, binds="__placeholder__", symint=True)
|
||||
if r is not None:
|
||||
return r.type
|
||||
|
||||
@ -249,7 +251,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
|
||||
assert (
|
||||
not mutable
|
||||
), "Native functions should never return a mutable tensor list. They should return void."
|
||||
elem = returntype_type(t.elem, mutable=False, symint=symint)
|
||||
elem = returntype_type(t.elem, mutable=False)
|
||||
assert t.size is None, f"fixed size list returns not supported: {t}"
|
||||
return VectorCType(elem)
|
||||
|
||||
|
@ -35,6 +35,8 @@ class CppSignature:
|
||||
# Is this a symint C++ signature. For BC reasons, functions that take
|
||||
# SymInts still present as int64_t in C++, and the SymInt variant is
|
||||
# offered at a different overload name
|
||||
#
|
||||
# NB: If a function RETURNS a SymInt, this is ALWAYS false
|
||||
symint: bool
|
||||
|
||||
# The set of C++ arguments which should not have defaults applied to them
|
||||
|
@ -1649,9 +1649,7 @@ class FunctionSchema:
|
||||
return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
|
||||
|
||||
def has_symint(self) -> bool:
|
||||
return self.arguments.has_symint_arg() or any(
|
||||
r.type.is_symint_like() for r in self.returns
|
||||
)
|
||||
return self.arguments.has_symint_arg()
|
||||
|
||||
def __str__(self) -> str:
|
||||
all_arguments_str = str(self.arguments)
|
||||
|
Reference in New Issue
Block a user