Files
pytorch/torchgen/dest/lazy_ts_lowering.py
Peter Bell bc438af6fe std/var: support floating point correction value (#94073)
Ref https://github.com/pytorch/pytorch/issues/61492#issuecomment-1413003480

The array API specifies `correction` as `Union[int, float]`, whereas we currently support only integers.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html

As std/var are currently calculated, the final element count is already computed
in floating point, so we can make the correction floating point without any loss
of precision or generality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94073
Approved by: https://github.com/ezyang
2023-02-23 05:50:45 +00:00

49 lines
1.8 KiB
Python

from torchgen.api.lazy import LazyArgument, LazyIrSchema
from torchgen.api.types import OptionalCType
def ts_lowering_body(schema: LazyIrSchema) -> str:
    """Return the C++ body of the TorchScript lowering for the op described by ``schema``.

    The generated code:

    * declares ``arguments`` and ``kwarguments``
      (``std::vector<torch::jit::NamedValue>``) and reserves their sizes;
    * declares a running operand counter ``size_t i = 0``;
    * emplaces one entry per positional argument (``schema.positional_args``)
      and per keyword argument (``schema.keyword_values`` followed by
      ``schema.keyword_scalars``);
    * calls ``torch::lazy::LowerTSBuiltin(function, op().op, arguments,
      kwarguments)``, checks with ``TORCH_CHECK_EQ`` that it produced
      ``len(schema.returns)`` outputs, and returns them.

    Arguments with ``is_lazy_value`` set (and all keyword values) are fetched
    with ``loctx->GetOutputOp(operand(i++))``; all other arguments are passed
    by name as ``"<name>", <name>``.

    Args:
        schema: the lazy IR schema of the op being lowered.

    Returns:
        The generated C++ statements as a single string.
    """
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.

    def get_value(arg: LazyArgument) -> str:
        # ``i`` is the C++ counter declared in the generated body; each emitted
        # ``operand(i++)`` takes the next index, so the order in which these
        # expressions are emitted decides which operand each argument gets.
        if isinstance(arg.lazy_type, OptionalCType):
            # When ``has_<name>`` is false the ternary yields nullptr and the
            # ``i++`` branch is not evaluated, so the counter does not advance
            # (presumably no operand was recorded for an absent optional --
            # confirm against the IR node constructor).
            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
        return "loctx->GetOutputOp(operand(i++))"

    # Positional arguments first, then keyword values: this fixes the order in
    # which ``operand(i++)`` is consumed in the generated code.
    emplace_arguments = [
        get_value(arg) if arg.is_lazy_value else f'"{arg.name}", {arg.name}'
        for arg in schema.positional_args
    ]
    # Build the combined kwarg list once so the reserve() count and the
    # emitted emplace_back statements are derived from the same list.
    kwarg_exprs = [
        f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
    ] + [f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars]

    # The separator's indentation must equal the template's line indentation
    # below so that every emplace_back line of the generated C++ lines up.
    separator = "\n    "
    emplace_arguments_str = separator.join(
        f"arguments.emplace_back({a});" for a in emplace_arguments
    )
    emplace_kwarguments_str = separator.join(
        f"kwarguments.emplace_back({a});" for a in kwarg_exprs
    )
    return f"""\
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve({len(emplace_arguments)});
    kwarguments.reserve({len(kwarg_exprs)});
    size_t i = 0;
    {emplace_arguments_str}
    {emplace_kwarguments_str}
    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
    return {schema.aten_name}_out;
"""