mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Ref https://github.com/pytorch/pytorch/issues/61492#issuecomment-1413003480 The array API specifies correction to be `Union[int, float]`, while we currently only support integers. https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html As std/var are currently calculated, the final count of elements is already computed in floating point, so we can make the correction floating point without any loss of precision or generality. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94073 Approved by: https://github.com/ezyang
49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
from torchgen.api.lazy import LazyArgument, LazyIrSchema
|
|
from torchgen.api.types import OptionalCType
|
|
|
|
|
|
def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Generate the C++ body of a TorchScript lowering for *schema*.

    The emitted snippet collects the op's positional arguments into
    ``arguments`` and its keyword arguments into ``kwarguments``, then
    dispatches through ``torch::lazy::LowerTSBuiltin`` and checks that the
    number of outputs matches the schema's declared returns.

    Args:
        schema: the lazy IR schema describing the op being lowered.

    Returns:
        The C++ source text (as a string) for the lowering method body.
    """
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.

    def get_value(arg: LazyArgument) -> str:
        # Optional lazy values are fetched behind a has_<name> guard so that
        # absent optionals lower to nullptr; plain values are fetched directly.
        # Both forms consume one operand slot via the shared `i` counter.
        if isinstance(arg.lazy_type, OptionalCType):
            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
        return "loctx->GetOutputOp(operand(i++))"

    # Lazy (tensor) positionals are passed positionally via GetOutputOp;
    # everything else is passed as a named value.
    emplace_arguments = [
        get_value(arg) if arg.is_lazy_value else f'"{arg.name}", {arg.name}'
        for arg in schema.positional_args
    ]
    emplace_arguments_str = "\n    ".join(
        f"arguments.emplace_back({a});" for a in emplace_arguments
    )

    # Keyword args: lazy values first, then scalars — all passed by name.
    emplace_kwargs = [
        f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
    ] + [f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars]
    emplace_kwarguments = "\n    ".join(
        f"kwarguments.emplace_back({a});" for a in emplace_kwargs
    )

    return f"""\
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve({len(emplace_arguments)});
    kwarguments.reserve({len(emplace_kwargs)});
    size_t i = 0;
    {emplace_arguments_str}
    {emplace_kwarguments}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

    return {schema.aten_name}_out;
"""
|