Add support for nonzero, some improvements to reduce guards (#95387)
This takes the strategy described in https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
It is essentially https://github.com/pytorch/pytorch/pull/95222, but squashed and with the changes that become unnecessary once we assume nonzero returns a size > 1.

What's in the PR:

* nonzero now supports meta propagation. When `capture_dynamic_output_shape_ops` is set, it returns a tensor whose size in the data-dependent dimension is an unbacked SymInt.
* The unbacked SymInt is UNSOUNDLY assumed to be not equal to 0/1. We will still error if you guard otherwise.
* PrimTorch pointwise operators are updated to use empty_permuted, to avoid guarding on unbacked SymInts from empty_strided (tested in `test_dynamic_pointwise_scalar`).
* Convolution is updated to skip backend selection if the batch size is unbacked, to avoid guarding on an unbacked SymInt (tested in `test_unbacked_batch_resnet`).
* I kept the helper utilities like `definitely_true` for working with possibly unbacked SymInts. They're not used right now, but maybe someone will find them useful.
* Added `constrain_unify` to let you specify that two unbacked SymInts must have the same value.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95387
Approved by: https://github.com/voznesenskym
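To make the first two bullets concrete, here is a minimal sketch of the behavior from the user side. It is not code from the PR; the call pattern is assumed from the pieces the diff wires together (the `ShapeEnv` flags and the fake-tensor `nonzero` rule), and the printed symbol name is illustrative.

```python
# Sketch (assumed usage, not part of the diff): under FakeTensorMode with a
# ShapeEnv that allows dynamic output shape ops, nonzero no longer raises
# DynamicOutputShapeException; instead its row count is an unbacked SymInt.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True)
with FakeTensorMode(shape_env=shape_env):
    x = torch.randn(8)        # fake tensor: no real data available
    nz = torch.nonzero(x)     # data-dependent output shape
    print(nz.shape)           # e.g. torch.Size([i0, 1]); i0 is unbacked, unsoundly assumed != 0/1
```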
committed by PyTorch MergeBot
parent 627282fa6c
commit 4833e47feb
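As a quick sketch of the user-facing `constrain_unify` piece, condensed from the new `test_boolean_index` added in the diff below (only APIs exercised by that test are used):

```python
# Condensed from test_boolean_index in this diff: two boolean indexes produce
# two unbacked sizes; constrain_unify tells the ShapeEnv they are equal, so the
# trace does not guard on them.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import constrain_unify

def f(images, handedness, valid):
    images = images[valid]            # size(0) becomes an unbacked SymInt
    handedness = handedness[valid]    # a second unbacked SymInt with the same value
    constrain_unify(images.shape[0], handedness.shape[0])
    right_hand_mask = handedness == 1
    images[right_hand_mask] = images[right_hand_mask].flip(-1)

gm = make_fx(f, tracing_mode="symbolic")(
    torch.randint(0, 256, (512, 1, 96, 96)),
    torch.randint(0, 1, (512,)),
    torch.randint(0, 2, (512,), dtype=torch.bool),
)
print(gm.code)
```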
@@ -49,7 +49,7 @@ from common_utils import (
 )
 from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import is_sym_node
-from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, GuardOnDataDependentSymNode
 
 USE_TORCHVISION = False
 try:
@@ -2412,7 +2412,6 @@ symbolic_aot_autograd_failures = {
     xfail('gradient', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('hsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
-    xfail('index_put', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('inner', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kron', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -2613,7 +2612,16 @@ def _test_aot_autograd_helper(self, device, dtype, op):
         return op.op(*c_args, **c_kwargs)
 
     compiled_f = compiled_function(f, nop, nop)
-    _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args)
+    try:
+        _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args)
+    except GuardOnDataDependentSymNode:
+        # Carveout for getitem; I don't want to xfail the entire test
+        # because that will reject known to be good tests see
+        # https://github.com/pytorch/pytorch/issues/94705
+        if op.name == "__getitem__":
+            self.skipTest("Dynamic output shape operation in trace")
+        else:
+            raise
 
 def _test_aot_autograd_module_helper(self, device, dtype, training, module_info):
     module_cls = module_info.module_cls
@@ -1733,6 +1733,7 @@ class TestRefsOpsInfo(TestCase):
     skip_ref_ops = {
         '_refs.bitwise_right_shift',
         '_refs.copy_to',
+        '_refs.empty_permuted',
         '_refs.empty_strided',
         '_refs.equal',
         '_refs.full',
@@ -1846,6 +1847,7 @@ class TestRefsOpsInfo(TestCase):
         '_refs.scalar_tensor',  # missing "layout"
         # other
         '_refs.empty',  # intentional; direct empty is faster and has less guards
+        '_refs.empty_permuted',  # intentional; direct empty is faster and has less guards
         '_refs.expand_as',
         '_refs.as_strided',  # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
         '_refs.copy_to',  # torch._C._jit_get_operation: No such operator aten::copy_to
@@ -13,7 +13,7 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDepen
 from torch._decomp import decomposition_table
 from torch.fx.experimental.symbolic_shapes import (
     sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
-    constrain_range
+    constrain_range, constrain_unify, guard_int
 )
 from torch.testing._internal.common_device_type import ops
 from torch._C import _disabled_torch_function_impl
@@ -912,6 +912,115 @@ def forward(self, a_1):
     return empty"""  # noqa: B950
         )
 
+    def test_dynamic_pointwise_scalar(self):
+        def f(gravity, mask):
+            gravity[mask, 0] = gravity[mask, 0] * -1
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.randn((12, 4)),
+            torch.randint(0, 2, (12,), dtype=torch.bool)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, gravity_1, mask_1):
+    select = torch.ops.aten.select.int(gravity_1, 1, 0)
+    index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None
+    mul = torch.ops.aten.mul.Tensor(index, -1); index = None
+    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None
+    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = None
+    return None""")
+
+    def test_reflect_r_over_x(self):
+        def reflect_R_over_x(R):
+            reflect = torch.eye(3, device=R.device)
+            reflect[0, 0] = -1
+            return reflect @ R @ reflect
+
+        def f(crop_camera, mask):
+            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.randn((12, 3, 3)),
+            torch.randint(0, 2, (12,), dtype=torch.bool)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, crop_camera_1, mask_1):
+    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
+    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
+    _tensor_constant0 = self._tensor_constant0
+    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
+    select = torch.ops.aten.select.int(eye, 0, 0)
+    select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
+    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
+    transpose = torch.ops.aten.transpose.int(index, -2, -1)
+    t = torch.ops.aten.t.default(eye)
+    clone = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format); transpose = None
+    sym_size = torch.ops.aten.sym_size(index, 0); index = None
+    sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 2)
+    mul = sym_size * sym_size_1
+    sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 1)
+    _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [mul, sym_size_2]); clone = mul = sym_size_2 = None
+    mm = torch.ops.aten.mm.default(_unsafe_view, t); _unsafe_view = t = None
+    view = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 3]); mm = sym_size_1 = None
+    transpose_1 = torch.ops.aten.transpose.int(view, -2, -1)
+    clone_1 = torch.ops.aten.clone.default(transpose_1, memory_format = torch.contiguous_format); transpose_1 = None
+    mul_1 = sym_size * 3
+    sym_size_3 = torch.ops.aten.sym_size(view, 1); view = None
+    view_1 = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_3]); clone_1 = mul_1 = sym_size_3 = None
+    mm_1 = torch.ops.aten.mm.default(view_1, eye); view_1 = eye = None
+    view_2 = torch.ops.aten.view.default(mm_1, [sym_size, 3, 3]); mm_1 = sym_size = None
+    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None
+    return None""")
+
+    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
+    def test_unbacked_batch_resnet(self):
+        mod = torchvision.models.resnet18()
+
+        def f(x, mask, params, buffers):
+            for p in itertools.chain([x, mask], params.values(), buffers.values()):
+                for s in p.shape:
+                    guard_int(s)
+            x = x[mask]
+            constrain_range(x.shape[0], min=1)
+            for p in params.values():
+                p.grad = None
+            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
+
+        make_fx(f, tracing_mode="symbolic")(
+            torch.randn(3, 3, 250, 250),
+            torch.randint(0, 2, (3,), dtype=torch.bool),
+            dict(mod.named_parameters()),
+            dict(mod.named_buffers()),
+        )
+
+    def test_boolean_index(self):
+        def f(images, handedness, valid):
+            images = images[valid]
+            handedness = handedness[valid]
+            zi = images.shape[0]
+            zh = handedness.shape[0]
+            # NB: We wouldn't actually need this if we could cache
+            # the result of running valid.nonzero() and assign the same
+            # SymInt in both cases.  This is a workaround in lieu of
+            # that memoization.
+            constrain_unify(zi, zh)
+            right_hand_mask = handedness == 1
+            images[right_hand_mask] = images[right_hand_mask].flip(-1)
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.randint(0, 256, (512, 1, 96, 96)),
+            torch.randint(0, 1, (512,)),
+            torch.randint(0, 2, (512,), dtype=torch.bool)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, images_1, handedness_1, valid_1):
+    index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None
+    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None
+    eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None
+    index_2 = torch.ops.aten.index.Tensor(index, [eq])
+    flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None
+    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = None
+    return None""")
+
     def test_neg_shape(self):
         def f(a):
             return torch.empty(-a.shape[0] + 10)
@@ -1202,7 +1311,6 @@ symbolic_tensor_failures = {
     xfail('masked.cumprod', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('addmv', ''),  # aten.addmv.default - couldn't find symbolic meta function/decomposition
     xfail('aminmax', ''),  # aten.aminmax.default - couldn't find symbolic meta function/decomposition
-    xfail('argwhere', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
     xfail('baddbmm', ''),  # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
     xfail('cdist', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
@@ -1317,7 +1425,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.pdist', ''),  # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
     xfail('nn.functional.smooth_l1_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('nonzero', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
     xfail('ormqr', ''),  # aten.ormqr.default - couldn't find symbolic meta function/decomposition
     xfail('pca_lowrank', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition
@@ -161,6 +161,13 @@ repro_tolerance = 1e-3
 # This requires dynamic_shapes to be True.
 capture_scalar_outputs = False
 
+# Not all backends support operators that have dynamic output shape (e.g.,
+# nonzero, unique).  When this flag is set to False, we introduce a graph
+# break instead of capturing.  This requires dynamic_shapes to be True.
+# If you set this to True, you probably also want capture_scalar_outputs
+# (these are separated for historical reasons).
+capture_dynamic_output_shape_ops = False
+
 # Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
 # false_fn produces code with identical guards.
 enforce_cond_guards_match = True
@@ -191,6 +191,7 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
         fake_mode = torch._subclasses.FakeTensorMode(
             shape_env=ShapeEnv(
                 allow_scalar_outputs=config.capture_scalar_outputs,
+                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
                 strict_mark_dyn=export,
                 assume_static_by_default=config.assume_static_by_default,
             )
@@ -20,7 +20,6 @@ from torch._prims_common import (
 from torch._prims_common.wrappers import out_wrapper
 from torch._refs import _broadcast_shapes
 
-from torch._subclasses.fake_tensor import check_no_bool_index_tensors
 from torch.utils._pytree import tree_map
 
 
@@ -996,7 +995,6 @@ def vdot(self, other):
 # get shape inference through structured kernels
 @register_meta(aten.index.Tensor)
 def meta_index_Tensor(self, indices):
-    check_no_bool_index_tensors(aten.index.Tensor, self, indices)
     check(indices, lambda: "at least one index must be provided")
     # aten::index is the internal advanced indexing implementation
     # checkIndexTensorTypes and expandTensors
@@ -347,7 +347,7 @@ def _elementwise_meta(
     utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
     utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
 
-    strides = utils.compute_elementwise_output_strides(*args_)
+    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
     shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
 
     # Acquires the dtype
@@ -398,7 +398,8 @@ def _elementwise_meta(
         else:
             dtype = dtype
 
-        return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype)
+        assert shape is not None
+        return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype)  # type: ignore[return-value]
 
     # Number case
     # TODO: fix number type promotion (bool, complex->float)
@@ -77,6 +77,7 @@ torch_function_passthrough = {
     torch.Tensor.device.__get__,  # type: ignore[attr-defined]
     torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
     torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
+    torch.Tensor.is_contiguous,
     # For TorchRefsMode only
     torch.Tensor.__format__,
     torch.Tensor.__repr__,
@@ -346,33 +347,41 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
 # non overlapping and dense strides.
 # This is also INCORRECT because it does not model TensorIterator's
 # short-circuit, which can cause different strides.
-def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
-    """
-    Computes the output strides for elementwise operations.
-    """
-
-    if len(tensors) == 0:
+def compute_elementwise_output_logical_to_physical_perm(*tensors, _skip_checks=False) -> List[int]:
+    if not _skip_checks and len(tensors) == 0:
         msg = "Can't compute elementwise output strides for zero tensors!"
         raise ValueError(msg)
 
-    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
+    if not _skip_checks:
+        check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
 
     # Filters the tensors to actual tensors
-    tensors = tuple(
-        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
-    )
+    if not _skip_checks:
+        tensors = tuple(
+            a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
+        )
 
     # Short-circuits for CPU scalar case
     if len(tensors) == 0:
-        return ()
+        return []
 
     # Short-circuits for shapes with zero or one dimensions
     # TODO: are these necessary?
     ndim = tensors[0].ndim
     if ndim == 0:
-        return ()
+        return []
     if ndim == 1:
-        return (1,)
+        return [0]
+
+    # Short-circuits if contiguous, following the fake fast path.
+    # This reduces the number of guards we end up making
+    # TODO: do channels last too
+    is_contiguous = True
+    for t in tensors:
+        is_contiguous = is_contiguous and t.is_contiguous(memory_format=torch.contiguous_format)
+
+    if is_contiguous:
+        return list(range(ndim))
 
     shape = tensors[0].shape
 
@@ -398,6 +407,11 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
         # or all strides are equal and all dimensions have the same length
         return 0
 
+    # The "sort" order for the permutation is back-to-front, but
+    # the natural order for permutations is front-to-back.  Do the
+    # sorting back-to-front and then reverse it on output.
+    #
+    # also, note this returns the logical to physical shape permutation
     perm = list(reversed(range(ndim)))
 
     # insertion sort with support for ambiguous comparisons
@@ -411,18 +425,64 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
             elif comparison < 0:
                 break
 
-    permuted_shape = [-1] * ndim
-    for idx, x in enumerate(reversed(perm)):
-        permuted_shape[idx] = shape[x]
+    return list(reversed(perm))
+
+
+def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
+    """
+    Computes the output strides for elementwise operations.
+    """
+    if len(tensors) == 0:
+        msg = "Can't compute elementwise output strides for zero tensors!"
+        raise ValueError(msg)
+
+    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
+
+    # Filters the tensors to actual tensors
+    tensors = tuple(
+        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
+    )
+
+    # Short-circuits for CPU scalar case
+    if len(tensors) == 0:
+        return ()
+
+    ndim = tensors[0].ndim
+    shape = tensors[0].shape
+
+    if ndim == 0:
+        return ()
+    if ndim == 1:
+        return (1,)
+
+    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
+        *tensors, _skip_checks=True
+    )
+    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical
 
     new_strides = make_contiguous_strides_for(permuted_shape)
-    permuted_strides = [-1] * ndim
-    for idx, x in enumerate(reversed(perm)):
-        permuted_strides[x] = new_strides[idx]
+    permuted_strides = apply_perm(new_strides, invert_perm(logical_to_physical_perm))  # to logical
 
     return tuple(permuted_strides)
 
 
+# Identity permutation is [0, 1, 2]
+def apply_perm(inp, perm):
+    ndim = len(inp)
+    permuted_inp = [-1] * ndim
+    for idx, x in enumerate(perm):
+        permuted_inp[idx] = inp[x]
+    return permuted_inp
+
+
+def invert_perm(perm):
+    ndim = len(perm)
+    new_perm = [-1] * ndim
+    for idx, x in enumerate(perm):
+        new_perm[x] = idx
+    return new_perm
+
+
 #
 # Common helper functions
 #
@@ -276,6 +276,7 @@ __all__ = [
     "arange",
     "empty",
     "empty_like",
+    "empty_permuted",
     "empty_strided",
     "eye",
     "full",
@@ -4055,9 +4056,7 @@ def empty_permuted(
         shape,
         physical_layout,
         dtype=dtype,
-        layout=layout,
         device=device,
-        pin_memory=pin_memory,
         requires_grad=requires_grad,
     )
 
@@ -4274,10 +4273,13 @@ def empty_like(
         )
 
     # memory_format == torch.preserve_format
-    strides = utils.compute_elementwise_output_strides(a)
-    return torch.empty_strided(
+    logical_to_physical_perm = (
+        utils.compute_elementwise_output_logical_to_physical_perm(a)
+    )
+    # identity perm is [2, 1, 0]
+    return torch.empty_permuted(
         a.shape,
-        strides,
+        logical_to_physical_perm,
         dtype=dtype,
         layout=layout,
         device=device,
@@ -397,7 +397,7 @@ def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
 # index.Tensor data-dependent in only some conditions
 @register_op_impl(
     lambda func: torch.Tag.dynamic_output_shape in func.tags  # type: ignore[attr-defined]
-    and func != aten.index.Tensor
+    and func not in [aten.index.Tensor, aten.nonzero.default]
 )
 def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
@@ -405,11 +405,9 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
 
 @register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
 def local_scalar_dense(fake_mode, func, arg):
-    if fake_mode.shape_env is None:
+    if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
         # Without symints/symfloats, cannot handle this
         raise DataDependentOutputException(func)
-    if not fake_mode.shape_env.allow_scalar_outputs:
-        raise DataDependentOutputException(func)
     if is_float_dtype(arg.dtype):
         return fake_mode.shape_env.create_unbacked_symfloat()
     elif is_integer_dtype(arg.dtype):
@@ -418,6 +416,36 @@ def local_scalar_dense(fake_mode, func, arg):
         raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
 
 
+@register_op_impl(lambda func: func is torch.ops.aten.nonzero.default)
+def nonzero(fake_mode, func, arg):
+    if (
+        fake_mode.shape_env is None
+        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
+    ):
+        # Without symints/symfloats, cannot handle this
+        raise DynamicOutputShapeException(func)
+    nnz = fake_mode.shape_env.create_unbacked_symint()
+
+    from torch.fx.experimental.symbolic_shapes import (
+        constrain_range,
+        definitely_true,
+        guard_int,
+    )
+
+    # This is unsound, but it works well in practice
+    # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
+    # TODO: Add a config knob to turn off this unsound behavior
+    lower = 2
+    upper = None
+    # But don't give totally unsatisfiable bounds if we know it's too small!
+    if definitely_true(arg.numel() < 2):
+        lower = 0
+        upper = guard_int(arg.numel())
+    constrain_range(nnz, min=lower, max=upper)
+
+    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
+
+
 # NB: this must be ordered after local_scalar_dense
 @register_op_impl(
     lambda func: torch.Tag.data_dependent_output in func.tags  # type: ignore[attr-defined]
@@ -451,10 +479,17 @@ def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
 # index tensors with cuda self
 @register_op_impl(aten.index.Tensor)
 def index_tensor(fake_mode, func, *args, **kwargs):
-    # dynamic shape op if indices are bool/uint8
-    check_no_bool_index_tensors(func, *args, **kwargs)
-
-    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
+    from torch._meta_registrations import meta_index_Tensor
+
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    out_device = new_kwargs["input"].device
+    # ensure nonzero call goes to fake tensor
+    with fake_mode:
+        out = meta_index_Tensor(*args, **kwargs)
+        return out.to(out_device)
 
 
 # takes in multiple-devices, dont default to default device handling
|
||||||
@ -493,7 +528,15 @@ def conv(fake_mode, func, *args, **kwargs):
|
|||||||
with fake_mode:
|
with fake_mode:
|
||||||
# if the input is unsqueezed is done in Convolution.cpp we get segfault
|
# if the input is unsqueezed is done in Convolution.cpp we get segfault
|
||||||
k = kwargs["weight"].ndim
|
k = kwargs["weight"].ndim
|
||||||
if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
|
batch = kwargs["input"].shape[0]
|
||||||
|
|
||||||
|
from torch.fx.experimental.symbolic_shapes import has_hint
|
||||||
|
|
||||||
|
if not has_hint(batch):
|
||||||
|
# TODO: We can make this a little more faithful with best effort
|
||||||
|
# channels last detection (but only if it's statically obvious!)
|
||||||
|
mem_fmt = None
|
||||||
|
elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
|
||||||
mem_fmt = None
|
mem_fmt = None
|
||||||
else:
|
else:
|
||||||
if func is aten.convolution.default:
|
if func is aten.convolution.default:
|
||||||
|
@@ -108,6 +108,81 @@ def hint_int(a):
     assert type(a) is int, a
     return a
 
+def has_hint(a):
+    if isinstance(a, torch.SymInt):
+        return a.node.has_hint()
+    return True
+
+# Returns True if every size dim on the tensor has a hint
+# TODO: Should this include strides too?  For now it doesn't matter,
+# that's quite an obscure case
+def tensor_has_hints(t):
+    return all(has_hint(s) for s in t.size())
+
+def definitely_true(a):
+    """
+    Returns True only if we can tell that a is True, possibly introducing
+    a guard in the process.  If a depends on some unbacked SymInt, we may
+    return False even though there may exist a possible value of the SymInt
+    that would cause the expression to return True.
+
+    When is it appropriate to use definitely_true?  First, if you can use
+    a higher level combinator like parallel_or/parallel_and, prefer using
+    those instead, they are definitely safe (modulo short-circuiting).
+    Second, it can be used if the program would behave equivalently if
+    definitely_true always returned False (parallel_or/parallel_and are
+    examples of this pattern, modulo short-circuiting).  Finally, it can
+    even be OK if the program wouldn't behave equivalently, so long as the
+    change is semantics preserving.  It can be semantics preserving if
+    the program errors in more cases than it did previously (but otherwise
+    behaves identically), or if it changes some quantity in a way that
+    doesn't matter (e.g., strides often fall in this bucket.)
+    """
+    if isinstance(a, SymBool):
+        if a.node.has_hint():
+            return guard_bool(a)
+        else:
+            return False
+    return bool(a)
+
+def definitely_false(a):
+    """
+    Returns True only if we can tell that a is False, possibly introducing
+    a guard in the process.  If a depends on some unbacked SymInt, we may
+    return False even though there may exist a possible value of the SymInt
+    that would cause the expression a to be False.  See definitely_true
+    for more usage guidance.
+    """
+    if isinstance(a, SymBool):
+        if a.node.has_hint():
+            return not guard_bool(a)
+        else:
+            return False
+    return not bool(a)
+
+# TODO: could improve parallel_or/parallel_and by avoiding guards
+# if there exists a quantity that can be handled un-guardedly.  However,
+# for backed SymInts, avoiding guards doesn't really matter in practice,
+# so I chose not to do it.
+
+def parallel_or(*args):
+    """
+    Evaluate the logical OR of several arguments, avoiding guarding on
+    unbacked SymInts if another argument is definitely True.
+    """
+    if any(definitely_true(a) for a in args):
+        return True
+    return any(args)
+
+def parallel_and(*args):
+    """
+    Evaluate the logical AND of several arguments, avoiding guarding on
+    unbacked SymInts if another argument is definitely False.
+    """
+    if any(definitely_false(a) for a in args):
+        return False
+    return all(args)
+
 def guard_scalar(a):
     if isinstance(a, (SymBool, bool)):
         return guard_bool(a)
@@ -138,6 +213,34 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
     )
 
 
+def constrain_unify(a, b):
+    """
+    Given two SymInts, constrain them so that they must be equal.  NB:
+    this will not work with SymInts that represent nontrivial expressions
+    (yet!)
+    """
+    # TODO: Maybe dedupe this with _maybe_guard_eq?
+    if not isinstance(a, SymInt):
+        if not isinstance(b, SymInt):
+            assert a == b
+        else:
+            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            shape_env = b.node.shape_env
+            shape_env.replacements[b.node.expr] = sympy.Integer(a)
+    else:
+        # TODO: Actually, we can support this as long as one of them is a symbol.
+        # NB: We can't actually do "unification" as our operators are not
+        # injective
+        assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+        shape_env = a.node.shape_env
+        if not isinstance(b, SymInt):
+            shape_env.replacements[a.node.expr] = sympy.Integer(b)
+        else:
+            assert a.node.shape_env is b.node.shape_env
+            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            new_var = shape_env._find(a.node.expr)
+            shape_env.replacements[b.node.expr] = new_var
+
+
 def guard_bool(a):
     if isinstance(a, SymBool):
         return a.node.guard_bool("", 0)  # NB: uses Python backtrace
@@ -242,7 +345,13 @@ class SymNode:
     # simplify it into a hint
     def _update_hint(self):
         if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
-            self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
+            new_hint = self.shape_env.replace(self._hint_expr)
+            # NB: unification constraints could result in a replacement that
+            # doesn't actually solve the hint!  Check for this.
+            if new_hint.free_symbols:
+                self._hint_expr = new_hint
+                return
+            self._hint = self.pytype(new_hint)
             self._hint_expr = None
 
     @property
@@ -1076,6 +1185,7 @@ class ShapeEnv:
     def __init__(
         self, *,
         allow_scalar_outputs=True,
+        allow_dynamic_output_shape_ops=True,
        strict_mark_dyn=False,
        assume_static_by_default=False,
        # The following options affect decisions we make about eager
@@ -1093,6 +1203,7 @@ class ShapeEnv:
    ):
        # Not directly used by ShapeEnv; indirectly used by FakeTensor
        self.allow_scalar_outputs = allow_scalar_outputs
+        self.allow_dynamic_output_shape_ops = allow_dynamic_output_shape_ops
        self.guards: List[ShapeGuard] = []
        # Maps symbolic ints to their original concrete values
        # Currently populated from tensors
@@ -1244,12 +1355,10 @@ class ShapeEnv:
         if not dyn:
             # Non explicitly marked dynamic dims register to val_to_var to get duck shaped
             self.val_to_var[val] = sympy_expr
-            # We also infer that they must not be 0/1
-            lower = 2 if self.specialize_zero_one else 0
-            self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
-        else:
-            # Avoid up front 0/1 specializing dynamic dims
-            self.var_to_range[sympy_expr] = ValueRanges(0, sympy.oo)
+        # We also infer that it must be not 0/1
+        lower = 2 if self.specialize_zero_one else 0
+        self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
 
         if not dyn and self.duck_shape:
             # This implements duck-shaping: input sizes that match are assigned
@@ -1556,15 +1665,29 @@ class ShapeEnv:
         Tries to evaluate expr without introducing guards
         """
         expr = self.simplify(expr)
-        # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
+
+        # Simplify making use of value range lower bound
         symbols = list(expr.free_symbols)
-        new_shape_env = {
-            k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1
-            for idx, k in enumerate(symbols)
-            # Do not assume unbacked symints are > 1
-            # If we didn't specialize 0/1, this shape env is empty
-            if k in self.var_to_val and self.specialize_zero_one
-        }
+        new_shape_env = {}
+        new_range_env = {}
+        for idx, k in enumerate(symbols):
+            vr = self.var_to_range[k]
+            # Don't do anything if we don't have a nontrivial lower bound
+            if vr.lower == -sympy.oo:
+                new_range_env[k] = vr
+                continue
+            # Positive means >= 1
+            # Positive - 1 means >= 0
+            # Positive + lower - 1 means >= lower
+            # The new symbol 's' is "too low", so when we substitute it in
+            # we have to increase it by offset (and conversely, the new
+            # variables have to have their value range bounds adjusted as
+            # well)
+            s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
+            offset = vr.lower - 1
+            new_shape_env[k] = s + offset
+            new_range_env[s] = ValueRangeAnalysis.sub(vr, offset)
+
         new_expr = expr.xreplace(new_shape_env)
         floor_div_replace = {}
         for atom in new_expr.atoms(FloorDiv):
@@ -1574,17 +1697,7 @@ class ShapeEnv:
             return new_expr
 
         # Check if the range can solve it statically
-        range_env = {
-            s: self.var_to_range[s]
-            for s in expr.free_symbols
-            if not (s in self.var_to_val and self.specialize_zero_one)
-        }
-        range_env.update({
-            new_shape_env[s] - 1: ValueRangeAnalysis.sub(self.var_to_range[s], 1)
-            for s in expr.free_symbols
-            if s in self.var_to_val and self.specialize_zero_one
-        })
-        out = sympy_interp(ValueRangeAnalysis, range_env, new_expr)
+        out = sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)
         if out.is_singleton():
             return out.lower
 
@@ -1652,13 +1765,9 @@ class ShapeEnv:
         """
         result_expr = safe_expand(expr).xreplace(self.var_to_val)
         if len(result_expr.free_symbols) != 0:
-            range_env = {
-                s: self.var_to_range[s]
-                for s in result_expr.free_symbols
-            }
-            out = sympy_interp(ValueRangeAnalysis, range_env, result_expr)
-            if out.is_singleton():
-                return out.lower
+            r = self._maybe_evaluate_static(result_expr)
+            if r is not None:
+                return r
             raise self._make_data_dependent_error(result_expr)
         return result_expr
 