Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[BE][CI] bump ruff to 0.9.0: string quote styles (#144569)
Reference: https://docs.astral.sh/ruff/formatter/#f-string-formatting

- Change the outer quotes to double quotes for nested f-strings

```diff
- f'{", ".join(args)}'
+ f"{', '.join(args)}"
```

- Change the inner quotes to double quotes for triple f-strings

```diff
  string = """
- {', '.join(args)}
+ {", ".join(args)}
  """
```

- Join implicitly concatenated strings

```diff
- string = "short string " "short string " f"{var}"
+ string = f"short string short string {var}"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144569
Approved by: https://github.com/Skylion007
ghstack dependencies: #146509
committed by PyTorch MergeBot
parent 52f6d4aa30
commit 754fb834db
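As a rough illustration of the formatter behavior this commit adopts (a minimal sketch, not code from the PR; `args` and `var` are placeholder values), the old and new quote styles produce identical runtime strings, so the diff below is purely stylistic:

```python
# Placeholder values purely for illustration; not taken from the PR.
args = ["x", "y", "z"]
var = 42

# Old style kept by earlier ruff versions: single quotes outside, double quotes inside.
old_style = f'{", ".join(args)}'

# Style enforced by ruff 0.9.0: double quotes outside, single quotes inside.
new_style = f"{', '.join(args)}"

assert old_style == new_style == "x, y, z"

# Implicitly concatenated fragments are joined into a single f-string.
concatenated = "short string " "short string " f"{var}"
joined = f"short string short string {var}"
assert concatenated == joined == "short string short string 42"
```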
@@ -405,7 +405,7 @@ class PythonSignature:
 if len(schema_formals) > positional_argc:
 schema_formals.insert(positional_argc, "*")

-return f'{self.name}({", ".join(schema_formals)})'
+return f"{self.name}({', '.join(schema_formals)})"

 def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
 args = self.arguments(skip_outputs=skip_outputs)
@@ -421,7 +421,7 @@ class PythonSignature:
 # pyi also includes self (with no typing/defaults) for methods
 if self.method:
 schema_formals.insert(0, "self")
-return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..."

 def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
 # only pyi uses vararg signatures
@@ -457,7 +457,7 @@ class PythonSignature:
 # pyi also includes self (with no typing/defaults) for methods
 if self.method:
 schema_formals.insert(0, "self")
-return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..."

 # The deprecated python signature involves some special logic, so create a
@@ -498,7 +498,7 @@ class PythonSignatureDeprecated(PythonSignature):
 schema_formals.insert(positional_argc, "*")

 returns_str = returns_str_pyi(self)
-return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..."

 def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
 # the codegen doesn't include vararg variants for deprecated signatures
@@ -1474,11 +1474,11 @@ def dispatch_lambda_exprs(
 inits.append(
 f"""\
 const auto options = TensorOptions()
-.dtype({arg_parser_outputs['dtype'].expr})
-.device({arg_parser_outputs['device'].expr})
-.layout({arg_parser_outputs['layout'].expr})
-.requires_grad({arg_parser_outputs['requires_grad'].expr})
-.pinned_memory({arg_parser_outputs['pin_memory'].expr});
+.dtype({arg_parser_outputs["dtype"].expr})
+.device({arg_parser_outputs["device"].expr})
+.layout({arg_parser_outputs["layout"].expr})
+.requires_grad({arg_parser_outputs["requires_grad"].expr})
+.pinned_memory({arg_parser_outputs["pin_memory"].expr});
 torch::utils::maybe_initialize_device(options);
 """
 )
@@ -1500,9 +1500,9 @@ torch::utils::maybe_initialize_device(options);

 inits.append(
 f"""\
-check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
-{arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
-{arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
+check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr},
+{arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr},
+{arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr});
 """
 )
 # we'll set requires_grad on outgoing tensor
@@ -366,9 +366,9 @@ class FunctionalizationLambda:
 e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
 ]
 if not self.is_reverse and maybe_index is not None:
-return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
+return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
 else:
-return f'{inner_call_name}({", ".join(call_exprs)});'
+return f"{inner_call_name}({', '.join(call_exprs)});"

 @staticmethod
 def from_func(
@@ -131,7 +131,7 @@ class TupleCType(CType):

 def cpp_type(self, *, strip_ref: bool = False) -> str:
 # Do not pass `strip_ref` recursively.
-return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
+return f"::std::tuple<{','.join([e.cpp_type() for e in self.elems])}>"

 def remove_const_ref(self) -> CType:
 return TupleCType([e.remove_const_ref() for e in self.elems])
@@ -543,7 +543,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
 aten_name += "_symint"
 shape_str = f"""\
 {meta_conversion_str}
-auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
+auto out_meta = at::{dispatch_ns}::{aten_name}({", ".join(meta_call_args)});
 {meta_out}"""
 else:
 shape_sig = ComputeShapeSignature(
@@ -559,7 +559,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
 func_schema_str = "aten::" + str(func.func)
 shape_str += f"""
 if(torch::lazy::symbolicShapeEnabled()){{
-std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
+std::vector<torch::jit::IValue> inputs = {{ {", ".join(str(a.name) for a in all_args)} }};
 const char* schema_str = "{func_schema_str}";
 applySymbolicShapesOnLT(schema_str, inputs, shapes);
 }}
@@ -53,7 +53,7 @@ def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list
 return [
 f"""\
 struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
-void impl({', '.join(a.decl() for a in out_args)});
+void impl({", ".join(a.decl() for a in out_args)});
 }};
 """
 ]
@@ -332,7 +332,7 @@ class RegisterDispatchKey:
 f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
 for i, ret_name in enumerate(return_names)
 )
-returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
+returns = f"{sig.returns_type().cpp_type()}({', '.join(return_names)})"
 elif len(return_names) == 1:
 ret_name = return_names[0]
 updates = f"{copy_op}({func_res}, {ret_name});"
@@ -448,7 +448,7 @@ class RegisterDispatchKey:
 def generate_defn(cpp_sig: CppSignature) -> str:
 return f"""
 {cpp_sig.defn()} {{
-return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
 }}
 """

@@ -802,7 +802,7 @@ resize_out(out, sizes, strides, options);
 def generate_defn(cpp_sig: CppSignature) -> str:
 return f"""
 {cpp_sig.defn()} {{
-return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
 }}
 """

@@ -986,12 +986,15 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
 # For an overview of what this template code looks like, see
 # https://github.com/pytorch/rfcs/pull/9
 return f"""\
-{self.gen_class(
-f, k,
-class_name=class_name,
-parent_class=parent_class,
-generate_super=self.g.out.structured_inherits is not None
-)}
+{
+self.gen_class(
+f,
+k,
+class_name=class_name,
+parent_class=parent_class,
+generate_super=self.g.out.structured_inherits is not None,
+)
+}

 {sig.defn()} {{
 {sig_body_str}
@@ -477,15 +477,15 @@ def compute_ufunc_cpu_dtype_body(
 return f"""
 {body_str}
 cpu_kernel_vec(iter,
-[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
-[=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
+[=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
+[=]({", ".join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
 );
 """
 else:
 return f"""
 {body_str}
 cpu_kernel(iter,
-[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
+[=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
 );
 """

@@ -499,7 +499,7 @@ def generate_static_dispatch_fallback_call(
 return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
 else:
 return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
-{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
+{", ".join([str(index.dispatch_key) for index in backend_indices])} ");"""

 def static_dispatch(
@@ -552,7 +552,7 @@ def static_dispatch(
 )
 if tensor_args != "":
 subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
-stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
+stmts.append(f"""DispatchKeySet _dk_set = {" | ".join(subexprs)};""")
 stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

 dispatch_code = []
@@ -1016,7 +1016,7 @@ C10_ALWAYS_INLINE
 {sig.defn(name)} {{
 {compute_dk}
 return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
-_dk, {', '.join(a.expr for a in dispatcher_exprs)});
+_dk, {", ".join(a.expr for a in dispatcher_exprs)});
 }}
 """
 elif self.target is Target.REGISTRATION:
@@ -299,7 +299,7 @@ def gen_declaration_and_definition(
 {declaration} {{
 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
 {tmp_result}{backend_call}(
-{textwrap.indent(', '.join(callsite_exprs), " ")}
+{textwrap.indent(", ".join(callsite_exprs), " ")}
 );{textwrap.indent(ret_assignments_str, " ")}
 }});
 }}
@@ -119,10 +119,10 @@ def parse_backend_yaml(
 # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
 yaml_values.pop("ir_gen", {})

-assert (
-len(yaml_values.keys()) == 0
-), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
-Only the following keys are supported: {", ".join(valid_keys)}'
+assert len(yaml_values.keys()) == 0, (
+f"{backend_yaml_path} contains unexpected keys: {', '.join(yaml_values.keys())}. "
+f"Only the following keys are supported: {', '.join(valid_keys)}"
+)

 def create_backend_index(
 backend_ops: list[str],
@@ -280,7 +280,7 @@ class ComputeCodegenUnboxedKernels:
 [
 f"""
 Kernel(
-"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != 'default' else ''}
+"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != "default" else ""}
 []({contextArg.defn()}, EValue** stack) {{
 {code_connector.join(code_list)}

@@ -407,7 +407,7 @@ def emit_view_functionalization_body(
 // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
 {unwrap_tensor_args_str}
 at::AutoDispatchSkipFunctionalize guard;
-return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+return at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
 }}
 auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
 auto inverse_return_mode = (
@@ -436,7 +436,7 @@ def emit_view_functionalization_body(
 {meta_conversion_str}
 at::AutoDispatchSkipFunctionalize func_guard;
 c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
-reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
+reference_tensor_output = at::_ops::{noop_api_name}::call({", ".join(meta_call_args)});
 }}
 // This function adds the above view meta to the current tensor and replays them off the base,
 // mutating the size/stride info of the current FunctionalTensorWrapper.
@@ -462,7 +462,7 @@ def emit_view_functionalization_body(
 if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
 // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
 at::AutoDispatchSkipFunctionalize guard;
-return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+return at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
 }}
 auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
 auto inverse_return_mode = (
@@ -477,15 +477,15 @@ def emit_view_functionalization_body(
 {meta_conversion_str}
 at::AutoDispatchSkipFunctionalize func_guard;
 c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
-reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
+reference_tensor_output = at::_ops::{noop_api_name}::call({", ".join(meta_call_args)});
 }}
 {return_type} tmp_output;
 {{
 at::AutoDispatchSkipFunctionalize guard;
 if (reapply_views) {{
-tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+tmp_output = at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
 }} else {{
-tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)});
+tmp_output = at::_ops::{api_name}::call({", ".join(view_redispatch_args)});
 }}
 }}
 {symbolic_inputs_check}
@@ -502,7 +502,7 @@ def emit_view_functionalization_body(
 }},
 /*has_symbolic_inputs=*/{symbolic_inputs_varname},
 /*is_multi_output=*/{str(is_multi_output_view).lower()},
-/*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
+/*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()}
 );
 auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
 // See Note [Propagating strides in the functionalization pass]
@@ -686,7 +686,7 @@ def emit_inplace_functionalization_body(
 [
 f"""
 at::functionalization::impl::replace_(
-{a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
+{a.name}, {"std::get<" + str(i) + ">(tmp_output)" if len(f.func.returns) > 1 else "tmp_output"});
 at::functionalization::impl::commit_update({a.name});"""
 for (i, a) in enumerate(f.func.arguments.out)
 if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
@@ -722,7 +722,7 @@ def emit_inplace_functionalization_body(
 {meta_conversion_str}
 at::AutoDispatchSkipFunctionalize func_guard;
 c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
-at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
+at::_ops::{f.func.name.unambiguous_name()}::call({", ".join(a.name for a in meta_call_ctx)});
 }}
 {unwrap_tensor_args_str}
 if (!({check_all_mutated_args_are_functional})) {{
@@ -736,16 +736,16 @@ def emit_inplace_functionalization_body(
 }} else {{
 // case 2: arguments are not functional tensors, so we no-op and redispatch.
 at::AutoDispatchSkipFunctionalize guard;
-{maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
-{return_from_mutable_noop_redispatch(f, 'tmp_output')}
+{maybe_create_output(f, "tmp_output")}at::_ops::{f.func.name.unambiguous_name()}::call({", ".join(inplace_exprs)});
+{return_from_mutable_noop_redispatch(f, "tmp_output")}
 }}
 }} else {{
 {return_type} tmp_output;
 {{
 at::AutoDispatchSkipFunctionalize guard;
-tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
+tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({", ".join(functional_exprs)});
 }}
-{wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
+{wrap_propagate_mutations_and_return(f, g.functional, "tmp_output")}
 }}
 }}"""

@@ -97,7 +97,7 @@ def gen_case_where_all_bdims_are_none(
 e.expr for e in translate(outer_sig.arguments(), sig.arguments())
 )
 return f"""\
-if ({' && '.join(conditions)}) {{
+if ({" && ".join(conditions)}) {{
 return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
 }}"""

@@ -124,7 +124,7 @@ def gen_returns(
 if len(wrapped_returns) == 1:
 result = f"return {wrapped_returns[0]};"
 else:
-result = f'return std::make_tuple({", ".join(wrapped_returns)});'
+result = f"return std::make_tuple({', '.join(wrapped_returns)});"
 return result

@@ -168,14 +168,14 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:

 return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
-{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
 auto maybe_layer = maybeCurrentDynamicLayer();
 vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
 int64_t {cur_level_var} = maybe_layer->layerId();
 {textwrap.indent(bdims_all_none_case, " ")}
 {textwrap.indent(unwraps, " ")}
-batch_rule({', '.join(unwrapped_arg_list)});
+batch_rule({", ".join(unwrapped_arg_list)});
 return {schema.arguments.flat_all[0].name};
 }}"""

@@ -190,14 +190,14 @@ def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:

 return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
-{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
 auto maybe_layer = maybeCurrentDynamicLayer();
 vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
 int64_t {cur_level_var} = maybe_layer->layerId();
 {textwrap.indent(bdims_all_none_case, " ")}
 {textwrap.indent(unwraps, " ")}
-batch_rule({', '.join(unwrapped_arg_list)});
+batch_rule({", ".join(unwrapped_arg_list)});
 }}"""

@@ -240,14 +240,14 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
 wrapped_returns = gen_returns(returns, cur_level_var, results_var)
 return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
-{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
 auto maybe_layer = maybeCurrentDynamicLayer();
 vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
 int64_t {cur_level_var} = maybe_layer->layerId();
 {textwrap.indent(bdims_all_none_case, " ")}
 {textwrap.indent(unwraps, " ")}
-auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
+auto {results_var} = batch_rule({", ".join(unwrapped_arg_list)});
 {wrapped_returns}
 }}"""

@@ -1822,7 +1822,7 @@ class Annotation:
 alias_set = f"{alias_set}!"
 alias_set_after = "|".join(self.alias_set_after)
 if alias_set_after:
-alias_set = f'{alias_set}{" -> "}{alias_set_after}'
+alias_set = f"{alias_set} -> {alias_set_after}"
 return alias_set

@@ -534,7 +534,7 @@ def generate_non_out_variant_call(
 kernel_name = get_kernel_name(g, backend_index)
 arg_names = (arg.name for arg in schema.schema_order_arguments())
 namespace_name = "cpu" if g.structured else "native"
-return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
+return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"

 def generate_call_to_view_ops(
@@ -547,7 +547,7 @@ def generate_call_to_view_ops(
 kernel_name = kernel.kernel
 arg_names = (arg.name for arg in schema.schema_order_arguments())
 namespace_name = "native"
-return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
+return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"

 def generate_out_variant_call(