remove _is_foreach_op codegen special cases, clean up mutable return type checks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76190

Approved by: https://github.com/ezyang
Brian Hirsh
2022-04-25 12:01:18 -07:00
committed by PyTorch MergeBot
parent ea5209c9fd
commit 74e93f727a
2 changed files with 38 additions and 62 deletions


@@ -849,59 +849,6 @@ class NativeFunctionsGroup:
        )


def is_foreach_op(name: str) -> bool:
    return str(name) in set(
        [
            "_amp_foreach_non_finite_check_and_unscale_",
            "_foreach_add_.ScalarList",
            "_foreach_sub_.ScalarList",
            "_foreach_mul_.ScalarList",
            "_foreach_div_.ScalarList",
            "_foreach_add_.Scalar",
            "_foreach_sub_.Scalar",
            "_foreach_mul_.Scalar",
            "_foreach_div_.Scalar",
            "_foreach_add_.List",
            "_foreach_sub_.List",
            "_foreach_mul_.List",
            "_foreach_div_.List",
            "_foreach_exp_",
            "_foreach_sqrt_",
            "_foreach_abs_",
            "_foreach_acos_",
            "_foreach_asin_",
            "_foreach_atan_",
            "_foreach_ceil_",
            "_foreach_cos_",
            "_foreach_cosh_",
            "_foreach_erf_",
            "_foreach_erfc_",
            "_foreach_expm1_",
            "_foreach_floor_",
            "_foreach_log_",
            "_foreach_log10_",
            "_foreach_log1p_",
            "_foreach_log2_",
            "_foreach_neg_",
            "_foreach_tan_",
            "_foreach_tanh_",
            "_foreach_sin_",
            "_foreach_sinh_",
            "_foreach_round_",
            "_foreach_lgamma_",
            "_foreach_frac_",
            "_foreach_reciprocal_",
            "_foreach_sigmoid_",
            "_foreach_trunc_",
            "_foreach_addcmul_.Scalar",
            "_foreach_addcdiv_.Scalar",
            "_foreach_addcmul_.ScalarList",
            "_foreach_addcdiv_.ScalarList",
            "_foreach_zero_",
        ]
    )


@dataclass(frozen=True)
class BackendMetadata:
    # The name of the backend kernel, for a given operator
@@ -1120,13 +1067,39 @@ class FunctionSchema:
                "Did you forget to mark an out argument as keyword-only?"
            )
        if self.arguments.out:
            assert (
                len(self.arguments.out) == len(self.returns) or len(self.returns) == 0
            ), "Must return as many arguments as there are out arguments, or no return at all"
            # out= ops that return their mutable inputs are only really useful for method chaining.
            # And method chaining is only really useful if the thing you're returning is a plain Tensor.
            # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
            # and all other types of out= op schemas should return void.
            # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
            if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
                assert len(self.returns) == 0, (
                    "out= ops that accept tensor lists as out arguments "
                    "are expected to have no return type (since you can't do method chaining on them)"
                )
            else:
                assert len(self.arguments.out) == len(
                    self.returns
                ), "Must return as many arguments as there are out arguments, or no return at all"
        if self.name.name.inplace:
            # TODO: fixme
            if not is_foreach_op(str(self.name)):
                assert len(self.returns) == 1
            self_a = self.arguments.self_arg
            assert (
                self_a
                and self_a.argument.annotation
                and self_a.argument.annotation.is_write
            )
            if self_a.argument.type == BaseType(BaseTy.Tensor):
                # All inplace ops with an ordinary `Tensor self` argument should return self,
                # to allow for method chaining.
                assert (
                    len(self.returns) == 1
                    and self.returns[0].annotation == self_a.argument.annotation
                )
            else:
                # You can't method chain on non-tensor self arguments though (like a List[Tensor])
                # so in all other cases we expect the return type to be none.
                assert len(self.returns) == 0
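        # Illustrative sketch of the two inplace branches above, using simplified
        # schema strings (assumed, not copied verbatim from native_functions.yaml):
        #   add_(Tensor(a!) self, Tensor other) -> Tensor(a!)
        #     plain `Tensor self` -> must return the aliased self for method chaining
        #   _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
        #     `Tensor[] self`, as with the foreach ops formerly special-cased -> must return nothing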

    def is_out_fn(self) -> bool:
        # Note [is_out_fn]