mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
remove _is_foreach_op codegen special cases, clean up mutable return type checks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76190 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ea5209c9fd
commit
74e93f727a
@ -849,59 +849,6 @@ class NativeFunctionsGroup:
|
||||
)
|
||||
|
||||
|
||||
def is_foreach_op(name: str) -> bool:
|
||||
return str(name) in set(
|
||||
[
|
||||
"_amp_foreach_non_finite_check_and_unscale_",
|
||||
"_foreach_add_.ScalarList",
|
||||
"_foreach_sub_.ScalarList",
|
||||
"_foreach_mul_.ScalarList",
|
||||
"_foreach_div_.ScalarList",
|
||||
"_foreach_add_.Scalar",
|
||||
"_foreach_sub_.Scalar",
|
||||
"_foreach_mul_.Scalar",
|
||||
"_foreach_div_.Scalar",
|
||||
"_foreach_add_.List",
|
||||
"_foreach_sub_.List",
|
||||
"_foreach_mul_.List",
|
||||
"_foreach_div_.List",
|
||||
"_foreach_exp_",
|
||||
"_foreach_sqrt_",
|
||||
"_foreach_abs_",
|
||||
"_foreach_acos_",
|
||||
"_foreach_asin_",
|
||||
"_foreach_atan_",
|
||||
"_foreach_ceil_",
|
||||
"_foreach_cos_",
|
||||
"_foreach_cosh_",
|
||||
"_foreach_erf_",
|
||||
"_foreach_erfc_",
|
||||
"_foreach_expm1_",
|
||||
"_foreach_floor_",
|
||||
"_foreach_log_",
|
||||
"_foreach_log10_",
|
||||
"_foreach_log1p_",
|
||||
"_foreach_log2_",
|
||||
"_foreach_neg_",
|
||||
"_foreach_tan_",
|
||||
"_foreach_tanh_",
|
||||
"_foreach_sin_",
|
||||
"_foreach_sinh_",
|
||||
"_foreach_round_",
|
||||
"_foreach_lgamma_",
|
||||
"_foreach_frac_",
|
||||
"_foreach_reciprocal_",
|
||||
"_foreach_sigmoid_",
|
||||
"_foreach_trunc_",
|
||||
"_foreach_addcmul_.Scalar",
|
||||
"_foreach_addcdiv_.Scalar",
|
||||
"_foreach_addcmul_.ScalarList",
|
||||
"_foreach_addcdiv_.ScalarList",
|
||||
"_foreach_zero_",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BackendMetadata:
|
||||
# The name of the backend kernel, for a given operator
|
||||
@ -1120,13 +1067,39 @@ class FunctionSchema:
|
||||
"Did you forget to mark an out argument as keyword-only?"
|
||||
)
|
||||
if self.arguments.out:
|
||||
assert (
|
||||
len(self.arguments.out) == len(self.returns) or len(self.returns) == 0
|
||||
), "Must return as many arguments as there are out arguments, or no return at all"
|
||||
# out= ops that return their mutable inputs are only really useful for method chaining.
|
||||
# And method chaining is only really useful if the thing you're returning is a plain Tensor.
|
||||
# So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
|
||||
# and all other types of out= op schemas should return void.
|
||||
# There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
|
||||
if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
|
||||
assert (
|
||||
len(self.returns) == 0
|
||||
), "out= ops that accept tensor lists as out arguments "
|
||||
"are expected to have no return type (since you can't do method chaining on them)"
|
||||
else:
|
||||
assert len(self.arguments.out) == len(
|
||||
self.returns
|
||||
), "Must return as many arguments as there are out arguments, or no return at all"
|
||||
|
||||
if self.name.name.inplace:
|
||||
# TODO: fixme
|
||||
if not is_foreach_op(str(self.name)):
|
||||
assert len(self.returns) == 1
|
||||
self_a = self.arguments.self_arg
|
||||
assert (
|
||||
self_a
|
||||
and self_a.argument.annotation
|
||||
and self_a.argument.annotation.is_write
|
||||
)
|
||||
if self_a.argument.type == BaseType(BaseTy.Tensor):
|
||||
# All inplace ops with an ordinary `Tensor self` argument should return self,
|
||||
# to allow for method chaining.
|
||||
assert (
|
||||
len(self.returns) == 1
|
||||
and self.returns[0].annotation == self_a.argument.annotation
|
||||
)
|
||||
else:
|
||||
# You can't method chain on non-tensor self arguments though (like a List[Tensor])
|
||||
# so in all other cases we expect the return type to be none.
|
||||
assert len(self.returns) == 0
|
||||
|
||||
def is_out_fn(self) -> bool:
|
||||
# Note [is_out_fn]
|
||||
|
||||
Reference in New Issue
Block a user