import math
from typing import Any

import sympy

import torch
from torch.utils._sympy.value_ranges import ValueRanges

from .loop_body import LoopBody
from .utils import dominated_nodes


def val_expressable_in_32_bits(val: Any) -> bool:
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_number
        if val.is_Integer or val.is_Boolean:
            val = int(val)
        else:
            val = float(val)

    # bound within the mantissa: float32 has 24 bits of mantissa precision,
    # so integers with magnitude above 2**24 are not all exactly representable
    if isinstance(val, float):
        return val <= (2**24) and val >= -(2**24)

    if isinstance(val, int):
        iinfo = torch.iinfo(torch.int32)
        return val <= iinfo.max and val >= iinfo.min

    raise TypeError(f"Unexpected value {val}")
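
# Illustrative behavior (assuming sympy number inputs, as used below):
#   val_expressable_in_32_bits(sympy.Integer(2**31 - 1))  -> True   (int32 max)
#   val_expressable_in_32_bits(sympy.Integer(2**31))      -> False  (overflows int32)
#   val_expressable_in_32_bits(sympy.Float(2**24))        -> True   (exactly representable)
#   val_expressable_in_32_bits(sympy.Float(2**24 + 1))    -> False  (beyond the mantissa bound)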


def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool:
    return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
        range.upper
    )
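
# For instance, a value range such as ValueRanges(0, 2**31 - 1) is expressable
# in 32 bits, while ValueRanges(0, 2**31) is not, since its upper bound
# overflows int32: both endpoints must pass the check.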


def try_to_reduce_precision(
    node: Any,
    bounds: dict[Any, Any],
    indirect_vars: list[Any],
    indices: dict[Any, sympy.Expr],
    replacement_vals: dict[Any, ValueRanges[sympy.Expr]],
) -> None:
    # if a downstream use of a node explicitly converts to int32, or
    # float16/float32/float64, then its precision is set for that chain of
    # uses, and we don't need to consider those dominated values
    def skip_filter(node: Any) -> bool:
        return node.target == "to_dtype" and node.args[2] in (
            torch.int32,
            torch.float32,
            torch.float64,
        )
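
    # For example, a dominated node like to_dtype(..., torch.float32) ends the
    # traversal at that point: its output precision is already pinned, so
    # nothing past it constrains whether `node` can be narrowed to int32.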

    # TODO - there are dominated uses whose dtype does not depend on whether
    # we reduce the precision here, e.g. in add(int64, int64) one of the args
    # can be reduced to int32 without changing the output precision of the
    # node. this case hasn't shown up
    for dominated in dominated_nodes([node], skip_filter):
        if dominated.target in ["store", "output"]:
            continue

        if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
            idx = int(dominated.target[len("set_indirect") :])
            indirect_var = indirect_vars[idx]

            # We check that we can compute all the indices it's involved in
            # with int32
            for index, expr in indices.items():
                if indirect_var in expr.free_symbols:
                    index_val = replacement_vals[index]

                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                        return

                    # all indices are integers, so make sure that we
                    # use the bounds of integers instead of floats.
                    # TODO - not sure if we should be doing int/float casts
                    # while tracing, might interfere with sympy.
                    index_val_int = ValueRanges[sympy.Expr](
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

        if not range_expressable_in_32_bits(bounds[dominated]):
            return

    # every dominated use fits in int32, so rewrite the dtype argument in place
    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)
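
# Sketch of the overall effect (hypothetical values): given a node representing
# to_dtype(x, torch.int64) whose every dominated use is bounded within, say,
# [0, 1000], try_to_reduce_precision rewrites its dtype argument so the node
# becomes to_dtype(x, torch.int32); any unbounded or overflowing use along the
# way makes it return without changes.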


def indexing_dtype_strength_reduction(loop_body: LoopBody) -> None:
    """
    Performs value range analysis on LoopBody's fx graph to reduce the
    precision of intermediaries from int64 to int32.
    """
    bv = loop_body.bounds()

    # candidates: explicit casts to int64 whose value ranges are known
    int64_dtype_nodes = [
        node
        for node in loop_body.get_nodes()
        if (
            node.target == "to_dtype"
            and node.args[2] == torch.int64
            and node not in bv.unbounded_vars
        )
    ]
    if not int64_dtype_nodes:
        return

    bounds = bv.get_bounds()

    # TODO - if a dominated node of one to_dtype is not expressible in int32,
    # we should short-circuit another to_dtype node if that node also dominates
    for node in int64_dtype_nodes:
        try_to_reduce_precision(
            node,
            bounds,
            loop_body.indirect_vars,
            loop_body.indexing_exprs,
            bv.replacement_vals,
        )
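
# A minimal usage sketch (an assumption for illustration: inductor normally
# runs this pass itself while lowering kernels, rather than users calling it):
#
#   body: LoopBody = ...  # built by inductor for a fused kernel
#   indexing_dtype_strength_reduction(body)
#   # int64 `to_dtype` nodes whose dominated uses all fit in int32 now
#   # cast to int32 instead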