from torchgen.model import (
    Argument,
    DispatchKey,
    FunctionSchema,
    BaseType,
    BaseTy,
    Return,
    Annotation,
    NativeFunction,
    OperatorName,
    BackendIndex,
    BackendMetadata,
    DeviceCheckType,
    SchemaKind,
    Variant,
)
from torchgen.utils import (
    concatMap,
)

from typing import List, Tuple, Sequence, Dict
from collections import defaultdict

# See Note: [Out ops with functional variants that don't get grouped properly]
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
    # This has a functional variant, but it's currently marked private.
    # This function should be marked private as well (*_backward ops aren't exposed to python anyway).
    "adaptive_avg_pool3d_backward.grad_input",
    # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
    # Maybe we can kill this operator in favor of convolution_backward?
    "_slow_conv2d_backward.grad_input",
]

# See Note: [Mutable ops that cannot get an out variant]
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
    # should be out=?
    "_cummax_helper",
    # should be out=?
    "_cummin_helper",
]

INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
    # polygamma and polygamma.out both exist, but have a
    # pre-self arg (while polygamma_ does not)
    # We should either fix this schema so it can be grouped properly,
    # or allow the codegen to generate new functional/out= NativeFunctions for this op
    # (which would require changing its overload name to prevent overload ambiguity).
    "polygamma_",
]


# Groups "similar" NativeFunctions together.
# Example: add.Tensor, add_.Tensor, add.out
# "similar" NativeFunctions are all expected to have an identical `signature()`,
# but have differing SchemaKinds.
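# For example (illustrative sketch of the result shape, not an exact repr):
# the three add overloads above share one signature(), so they land in a single inner dict:
#   {<signature of add>: {SchemaKind.functional: add.Tensor,
#                         SchemaKind.inplace:    add_.Tensor,
#                         SchemaKind.out:        add.out}}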
def pre_group_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
    pre_grouped_native_functions: Dict[
        FunctionSchema, Dict[SchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        d = pre_grouped_native_functions[f.func.signature()]
        assert f.func.kind() not in d
        d[f.func.kind()] = f
    return pre_grouped_native_functions


# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
# Example before:
#   _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
# Example after:
#   _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out)
def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    # Generating an out= schema from an inplace schema.
    assert func.kind() == SchemaKind.inplace
    assert func.arguments.self_arg is not None
    # The new out= schema has:
    # - a new out argument with the same type as "self" (but with a mutable annotation)
    # - The returns (if any) now alias the out= argument instead of "self"
    # - an "out" overload name
    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            "out" if not func.name.overload_name else f"{func.name.overload_name}_out"
        ),
        arguments=func.arguments.remove_self_annotation().with_out_args(
            [
                Argument(
                    name="out",
                    type=func.arguments.self_arg.argument.type,
                    default=None,
                    annotation=func.arguments.self_arg.argument.annotation,
                )
            ]
        ),
        returns=func.returns,
    )


# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
# Example before:
#   _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950
# Example after:
#   _fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))  # noqa: B950
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    # Generating an out= schema from a mutable schema.
    assert func.kind() == SchemaKind.mutable
    # The new out= schema has:
    # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
    #   (if the argument is a tensor then we also return it for method chaining,
    #   otherwise we return nothing)
    # - an "out" overload name
    #
    # Note that:
    # (1) This also means that we can *only* generate an out= variant from a mutable schema
    #     if the mutable schema has at least one tensor-like non-aliasing return.
    # (2) The generated out= variant still has mutable positional arguments,
    #     but if necessary we could probably add another out= variant that also
    #     functionalizes the mutable arguments (a functional_out variant)

    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    used_annotations = concatMap(
        lambda a: [] if a.annotation is None else a.annotation.alias_set,
        func.arguments.flat_all,
    )
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]
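    # For example (illustrative): in the schema above, the inputs already claim the
    # annotations (a!) through (d!), so the generated out= arguments pick up the next
    # free letters, (e!) and (f!).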

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
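    # For example (illustrative): returns of (Tensor, Tensor) stay as two Tensor returns,
    # each aliased to its new out= argument, while returns of (Tensor, bool) become just
    # (bool,), with the Tensor converted to an out= argument that is not returned.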
    new_returns: List[Return] = []
    for (i, r) in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                name=f"out{i}",
                type=r.type,
                default=None,
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(
                    name=None, type=new_out.type, annotation=new_out.annotation
                )
                new_returns.append(new_ret)
        else:
            new_returns.append(r)

    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            "out" if not func.name.overload_name else f"{func.name.overload_name}_out"
        ),
        arguments=func.arguments.with_out_args(new_out_args),
        returns=tuple(new_returns),
    )


# This function, given a function of one SchemaKind as well as a target SchemaKind,
# generates a new NativeFunction with the same properties, but using the target SchemaKind.
# We only actually generate functions for either functional or out= SchemaKinds.
# This function returns a tuple, with:
# - The generated NativeFunction
# - a dictionary of backend metadata (keyed by dispatch key), describing which dispatch keys
#   we will generate kernels for, for the new NativeFunction.
#   Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        gets_composite_kernel = True
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - a "functional" overload name.
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
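        # For example (illustrative): an inplace op add_.Tensor would be renamed here to
        # add.Tensor_functional, and an op with no overload name would get a plain
        # ".functional" overload.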
        func = f.func.signature(keep_return_names=True).with_name(
            f.func.name.remove_inplace().with_overload(
                "functional"
                if not f.func.name.overload_name
                else f"{f.func.name.overload_name}_functional"
            )
        )
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        gets_composite_kernel = False
        if f.func.kind() == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    if gets_composite_kernel:
        backend_metadata = {
            DispatchKey.CompositeExplicitAutograd: {
                func.name: BackendMetadata(cpp.name(func), structured=False)
            }
        }
    else:
        backend_metadata = {}

    return (
        NativeFunction(
            func=func,
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
            # These generated fn's aren't meant to be user friendly - don't generate methods.
            variants=set([Variant.function]),
            structured=False,
            structured_delegate=None,
            structured_inherits=None,
            precomputed=None,
            autogen=[],
            ufunc_inner_loop={},
            manual_kernel_registration=False,
            manual_cpp_binding=False,
            python_module=None,
            category_override=None,
            device_guard=False,
            device_check=DeviceCheckType.NoCheck,
            loc=f.loc,
            cpp_no_default_args=set(),
            is_abstract=f.is_abstract,
            has_composite_implicit_autograd_kernel=False,
            has_composite_explicit_autograd_kernel=gets_composite_kernel,
            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
            # which NativeFunction objects did not come directly from native_functions.yaml.
            tags=set(["generated"]),
        ),
        backend_metadata,
    )


# This function is responsible for adding generated NativeFunctions which don't appear
# explicitly in the codegen.
# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
# (Maybe we should make a friendly API for this)
#
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
    rs: List[NativeFunction],
    indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
    # The main code for generating new NativeFunctions.
    # First we group NativeFunctions by schema kind,
    # then we detect which ones are missing and generate them.
    pre_grouped_native_functions = pre_group_native_functions(rs)
    for k, d in pre_grouped_native_functions.items():
        has_functional = SchemaKind.functional in d
        has_inplace = SchemaKind.inplace in d
        has_mutable = SchemaKind.mutable in d
        has_out = SchemaKind.out in d

        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
        #     a simple functional variant that the functionalization pass can consume.
        # (2) If an operator has an inplace and functional but no out= variant, we generate an out=
        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
        #     while maintaining the constraint that the out= variant is "required".
        #
        # For now, we don't bother generating NativeFunctions for existing operators
        # that only have a functional variant.
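        # For example (illustrative, with a hypothetical op): if the yaml only declared an
        # inplace foo_(Tensor(a!) self) -> Tensor(a!), the logic below would generate both a
        # foo.out and a foo.functional variant for it.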
        if has_mutable or has_inplace or has_out:

            # Don't bother generating function trios for native functions that bypass the dispatcher.
            are_manual = all(f.manual_cpp_binding for f in d.values())
            # Don't bother generating functional + out= variants for view operators
            has_view_ops = (
                has_inplace and "inplace_view" in d[SchemaKind.inplace].tags
            ) or any(f.is_view_op for f in d.values())
            # Don't generate the other variants for CompositeImplicitAutograd operators.
            # We could probably do this, but the main benefit of generating the function triplets
            # is for transforms that need them, and transforms don't need to act directly
            # on CompositeImplicitAutograd operators (since we let them decompose).
            are_composite_implicit = all(
                f.has_composite_implicit_autograd_kernel for f in d.values()
            )
            if are_manual or has_view_ops or are_composite_implicit:
                continue
            if has_out and len(d.values()) == 1:
                # Note: [Out ops with functional variants that don't get grouped properly]
                # In theory we could validly have an out= operator in native_functions.yaml
                # that has no other variants.
                # But today, all of the operators where that's the case actually do have
                # functional variants that we are just unable to pair up properly.
                # I think banning this altogether is probably safer
                # (you can always add a functional variant yourself if you want to add a new out= operator).
                #
                # We should probably fix the existing cases; this check is to prevent us from adding more over time.
                if (
                    str(d[SchemaKind.out].func.name)
                    not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
                ):
                    raise AssertionError(
                        f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
                    )
                continue

            # Some inplace ops have problematic schemas (that we should fix), which prevent us
            # from generating out= and functional variants
            if (
                has_inplace
                and str(d[SchemaKind.inplace].func.name)
                in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
            ):
                continue

            base_fn = (
                d[SchemaKind.inplace]
                if has_inplace
                else d[SchemaKind.mutable]
                if has_mutable
                else d[SchemaKind.out]
            )

            # Note: [Mutable ops that cannot get an out variant]
            # We can only generate an out= variant if either:
            # - the original function has tensor-like returns (since we can convert them to out kwargs)
            # - or it's inplace (since we can convert `self` to an out kwarg)
            # There are only two functions that don't fit these criteria today,
            # and they both look like they should be fixed to be out= variants,
            # so it feels safer to ban this schema altogether.
            gets_out_variant = not has_out and (
                base_fn.func.kind() == SchemaKind.inplace
                or any(r.type.is_tensor_like() for r in base_fn.func.returns)
            )
            if not has_out and not gets_out_variant:
                if (
                    str(base_fn.func.name)
                    not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
                ):
                    raise AssertionError(
                        f"""Found a mutable operator that we could not generate an out= variant for: {str(base_fn.func)}.
These operators are problematic, because we can't easily auto-generate functionalization code for them. If you really need
the operator to have this schema, then add the name of the operator to the allow-list. Otherwise, if possible,
please convert it to an inplace operator."""
                    )

            # Generate an out= variant
            if gets_out_variant:
                fn, metadata = generate_function(base_fn, SchemaKind.out)
                d[SchemaKind.out] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)

            # Generate a functional variant, but only do it if the operator got an out= variant
            # (Functional variants are only useful if we can group up the variants,
            # which we can only do if they have an out= variant)
            if not has_functional and (has_out or gets_out_variant):
                fn, metadata = generate_function(base_fn, SchemaKind.functional)
                d[SchemaKind.functional] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)