mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for nonzero, some improvements to reduce guards (#95387)
This takes the strategy described in https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit# It is essentially https://github.com/pytorch/pytorch/pull/95222 but squashed and with changes that are unnecessary given that we assume nonzero returns > 1. What's in the PR: * nonzero now supports meta propagation. When `capture_dynamic_output_shape_ops`, it will return a tensor with an unbacked SymInt representing the size in question. * The unbacked SymInt is UNSOUNDLY assumed to be not equal to 0/1. We will still error if you guard otherwise. * PrimTorch pointwise operators are updated to use empty_permuted, to avoid guarding on unbacked SymInt from empty_strided (tested in `test_dynamic_pointwise_scalar`) * Convolution is updated to skip backend selection if batch is unbacked, to avoid guarding on unbacked SymInt (tested in `test_unbacked_batch_resnet`) * I kept the helper utilities like `definitely_true` for working with possibly unbacked SymInts. They're not used right now but maybe someone will find them useful. * Added `constrain_unify` to let you specify two unbacked SymInts must have the same value Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/95387 Approved by: https://github.com/voznesenskym
This commit is contained in:
committed by
PyTorch MergeBot
parent
627282fa6c
commit
4833e47feb
@ -49,7 +49,7 @@ from common_utils import (
|
||||
)
|
||||
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import is_sym_node
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, GuardOnDataDependentSymNode
|
||||
|
||||
USE_TORCHVISION = False
|
||||
try:
|
||||
@ -2412,7 +2412,6 @@ symbolic_aot_autograd_failures = {
|
||||
xfail('gradient', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('index_put', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
@ -2613,7 +2612,16 @@ def _test_aot_autograd_helper(self, device, dtype, op):
|
||||
return op.op(*c_args, **c_kwargs)
|
||||
|
||||
compiled_f = compiled_function(f, nop, nop)
|
||||
_test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args)
|
||||
try:
|
||||
_test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args)
|
||||
except GuardOnDataDependentSymNode:
|
||||
# Carveout for getitem; I don't want to xfail the entire test
|
||||
# because that will reject known to be good tests see
|
||||
# https://github.com/pytorch/pytorch/issues/94705
|
||||
if op.name == "__getitem__":
|
||||
self.skipTest("Dynamic output shape operation in trace")
|
||||
else:
|
||||
raise
|
||||
|
||||
def _test_aot_autograd_module_helper(self, device, dtype, training, module_info):
|
||||
module_cls = module_info.module_cls
|
||||
|
||||
@ -1733,6 +1733,7 @@ class TestRefsOpsInfo(TestCase):
|
||||
skip_ref_ops = {
|
||||
'_refs.bitwise_right_shift',
|
||||
'_refs.copy_to',
|
||||
'_refs.empty_permuted',
|
||||
'_refs.empty_strided',
|
||||
'_refs.equal',
|
||||
'_refs.full',
|
||||
@ -1846,6 +1847,7 @@ class TestRefsOpsInfo(TestCase):
|
||||
'_refs.scalar_tensor', # missing "layout"
|
||||
# other
|
||||
'_refs.empty', # intentional; direct empty is faster and has less guards
|
||||
'_refs.empty_permuted', # intentional; direct empty is faster and has less guards
|
||||
'_refs.expand_as',
|
||||
'_refs.as_strided', # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
|
||||
'_refs.copy_to', # torch._C._jit_get_operation: No such operator aten::copy_to
|
||||
|
||||
@ -13,7 +13,7 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDepen
|
||||
from torch._decomp import decomposition_table
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
|
||||
constrain_range
|
||||
constrain_range, constrain_unify, guard_int
|
||||
)
|
||||
from torch.testing._internal.common_device_type import ops
|
||||
from torch._C import _disabled_torch_function_impl
|
||||
@ -912,6 +912,115 @@ def forward(self, a_1):
|
||||
return empty""" # noqa: B950
|
||||
)
|
||||
|
||||
def test_dynamic_pointwise_scalar(self):
|
||||
def f(gravity, mask):
|
||||
gravity[mask, 0] = gravity[mask, 0] * -1
|
||||
|
||||
r = str(make_fx(f, tracing_mode="symbolic")(
|
||||
torch.randn((12, 4)),
|
||||
torch.randint(0, 2, (12,), dtype=torch.bool)
|
||||
).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, gravity_1, mask_1):
|
||||
select = torch.ops.aten.select.int(gravity_1, 1, 0)
|
||||
index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None
|
||||
mul = torch.ops.aten.mul.Tensor(index, -1); index = None
|
||||
select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None
|
||||
index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = None
|
||||
return None""")
|
||||
|
||||
def test_reflect_r_over_x(self):
|
||||
def reflect_R_over_x(R):
|
||||
reflect = torch.eye(3, device=R.device)
|
||||
reflect[0, 0] = -1
|
||||
return reflect @ R @ reflect
|
||||
|
||||
def f(crop_camera, mask):
|
||||
crop_camera[mask] = reflect_R_over_x(crop_camera[mask])
|
||||
|
||||
r = str(make_fx(f, tracing_mode="symbolic")(
|
||||
torch.randn((12, 3, 3)),
|
||||
torch.randint(0, 2, (12,), dtype=torch.bool)
|
||||
).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, crop_camera_1, mask_1):
|
||||
index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
|
||||
eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
|
||||
select = torch.ops.aten.select.int(eye, 0, 0)
|
||||
select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
|
||||
copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
|
||||
transpose = torch.ops.aten.transpose.int(index, -2, -1)
|
||||
t = torch.ops.aten.t.default(eye)
|
||||
clone = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format); transpose = None
|
||||
sym_size = torch.ops.aten.sym_size(index, 0); index = None
|
||||
sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 2)
|
||||
mul = sym_size * sym_size_1
|
||||
sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 1)
|
||||
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [mul, sym_size_2]); clone = mul = sym_size_2 = None
|
||||
mm = torch.ops.aten.mm.default(_unsafe_view, t); _unsafe_view = t = None
|
||||
view = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 3]); mm = sym_size_1 = None
|
||||
transpose_1 = torch.ops.aten.transpose.int(view, -2, -1)
|
||||
clone_1 = torch.ops.aten.clone.default(transpose_1, memory_format = torch.contiguous_format); transpose_1 = None
|
||||
mul_1 = sym_size * 3
|
||||
sym_size_3 = torch.ops.aten.sym_size(view, 1); view = None
|
||||
view_1 = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_3]); clone_1 = mul_1 = sym_size_3 = None
|
||||
mm_1 = torch.ops.aten.mm.default(view_1, eye); view_1 = eye = None
|
||||
view_2 = torch.ops.aten.view.default(mm_1, [sym_size, 3, 3]); mm_1 = sym_size = None
|
||||
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None
|
||||
return None""")
|
||||
|
||||
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
|
||||
def test_unbacked_batch_resnet(self):
|
||||
mod = torchvision.models.resnet18()
|
||||
|
||||
def f(x, mask, params, buffers):
|
||||
for p in itertools.chain([x, mask], params.values(), buffers.values()):
|
||||
for s in p.shape:
|
||||
guard_int(s)
|
||||
x = x[mask]
|
||||
constrain_range(x.shape[0], min=1)
|
||||
for p in params.values():
|
||||
p.grad = None
|
||||
return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
|
||||
|
||||
make_fx(f, tracing_mode="symbolic")(
|
||||
torch.randn(3, 3, 250, 250),
|
||||
torch.randint(0, 2, (3,), dtype=torch.bool),
|
||||
dict(mod.named_parameters()),
|
||||
dict(mod.named_buffers()),
|
||||
)
|
||||
|
||||
def test_boolean_index(self):
|
||||
def f(images, handedness, valid):
|
||||
images = images[valid]
|
||||
handedness = handedness[valid]
|
||||
zi = images.shape[0]
|
||||
zh = handedness.shape[0]
|
||||
# NB: We wouldn't actually need this if we could cache
|
||||
# the result of running valid.nonzero() and assign the same
|
||||
# SymInt in both cases. This is a workaround in lieu of
|
||||
# that memoization.
|
||||
constrain_unify(zi, zh)
|
||||
right_hand_mask = handedness == 1
|
||||
images[right_hand_mask] = images[right_hand_mask].flip(-1)
|
||||
|
||||
r = str(make_fx(f, tracing_mode="symbolic")(
|
||||
torch.randint(0, 256, (512, 1, 96, 96)),
|
||||
torch.randint(0, 1, (512,)),
|
||||
torch.randint(0, 2, (512,), dtype=torch.bool)
|
||||
).code).strip()
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, images_1, handedness_1, valid_1):
|
||||
index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None
|
||||
index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None
|
||||
eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None
|
||||
index_2 = torch.ops.aten.index.Tensor(index, [eq])
|
||||
flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None
|
||||
index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = None
|
||||
return None""")
|
||||
|
||||
def test_neg_shape(self):
|
||||
def f(a):
|
||||
return torch.empty(-a.shape[0] + 10)
|
||||
@ -1202,7 +1311,6 @@ symbolic_tensor_failures = {
|
||||
xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('argwhere', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
|
||||
@ -1317,7 +1425,6 @@ symbolic_tensor_failures = {
|
||||
xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
|
||||
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
|
||||
xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
|
||||
xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('pca_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
|
||||
|
||||
@ -161,6 +161,13 @@ repro_tolerance = 1e-3
|
||||
# This requires dynamic_shapes to be True.
|
||||
capture_scalar_outputs = False
|
||||
|
||||
# Not all backends support operators that have dynamic output shape (e.g.,
|
||||
# nonzero, unique). When this flag is set to False, we introduce a graph
|
||||
# break instead of capturing. This requires dynamic_shapes to be True.
|
||||
# If you set this to True, you probably also want capture_scalar_outputs
|
||||
# (these are separated for historical reasons).
|
||||
capture_dynamic_output_shape_ops = False
|
||||
|
||||
# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
|
||||
# false_fn produces code with identical guards.
|
||||
enforce_cond_guards_match = True
|
||||
|
||||
@ -191,6 +191,7 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
||||
fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=ShapeEnv(
|
||||
allow_scalar_outputs=config.capture_scalar_outputs,
|
||||
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
|
||||
strict_mark_dyn=export,
|
||||
assume_static_by_default=config.assume_static_by_default,
|
||||
)
|
||||
|
||||
@ -20,7 +20,6 @@ from torch._prims_common import (
|
||||
from torch._prims_common.wrappers import out_wrapper
|
||||
from torch._refs import _broadcast_shapes
|
||||
|
||||
from torch._subclasses.fake_tensor import check_no_bool_index_tensors
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
@ -996,7 +995,6 @@ def vdot(self, other):
|
||||
# get shape inference through structured kernels
|
||||
@register_meta(aten.index.Tensor)
|
||||
def meta_index_Tensor(self, indices):
|
||||
check_no_bool_index_tensors(aten.index.Tensor, self, indices)
|
||||
check(indices, lambda: "at least one index must be provided")
|
||||
# aten::index is the internal advanced indexing implementation
|
||||
# checkIndexTensorTypes and expandTensors
|
||||
|
||||
@ -347,7 +347,7 @@ def _elementwise_meta(
|
||||
utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
|
||||
utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
|
||||
|
||||
strides = utils.compute_elementwise_output_strides(*args_)
|
||||
l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
|
||||
shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
|
||||
|
||||
# Acquires the dtype
|
||||
@ -398,7 +398,8 @@ def _elementwise_meta(
|
||||
else:
|
||||
dtype = dtype
|
||||
|
||||
return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype)
|
||||
assert shape is not None
|
||||
return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype) # type: ignore[return-value]
|
||||
|
||||
# Number case
|
||||
# TODO: fix number type promotion (bool, complex->float)
|
||||
|
||||
@ -77,6 +77,7 @@ torch_function_passthrough = {
|
||||
torch.Tensor.device.__get__, # type: ignore[attr-defined]
|
||||
torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined]
|
||||
torch.Tensor.layout.__get__, # type: ignore[attr-defined]
|
||||
torch.Tensor.is_contiguous,
|
||||
# For TorchRefsMode only
|
||||
torch.Tensor.__format__,
|
||||
torch.Tensor.__repr__,
|
||||
@ -346,33 +347,41 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
||||
# non overlapping and dense strides.
|
||||
# This is also INCORRECT because it does not model TensorIterator's
|
||||
# short-circuit, which can cause different strides.
|
||||
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
|
||||
"""
|
||||
Computes the output strides for elementwise operations.
|
||||
"""
|
||||
|
||||
if len(tensors) == 0:
|
||||
def compute_elementwise_output_logical_to_physical_perm(*tensors, _skip_checks=False) -> List[int]:
|
||||
if not _skip_checks and len(tensors) == 0:
|
||||
msg = "Can't compute elementwise output strides for zero tensors!"
|
||||
raise ValueError(msg)
|
||||
|
||||
check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
|
||||
if not _skip_checks:
|
||||
check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
|
||||
|
||||
# Filters the tensors to actual tensors
|
||||
tensors = tuple(
|
||||
a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
|
||||
)
|
||||
if not _skip_checks:
|
||||
tensors = tuple(
|
||||
a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
|
||||
)
|
||||
|
||||
# Short-circuits for CPU scalar case
|
||||
if len(tensors) == 0:
|
||||
return ()
|
||||
return []
|
||||
|
||||
# Short-circuits for shapes with zero or one dimensions
|
||||
# TODO: are these necessary?
|
||||
ndim = tensors[0].ndim
|
||||
if ndim == 0:
|
||||
return ()
|
||||
return []
|
||||
if ndim == 1:
|
||||
return (1,)
|
||||
return [0]
|
||||
|
||||
# Short-circuits if contiguous, following the fake fast path.
|
||||
# This reduces the number of guards we end up making
|
||||
# TODO: do channels last too
|
||||
is_contiguous = True
|
||||
for t in tensors:
|
||||
is_contiguous = is_contiguous and t.is_contiguous(memory_format=torch.contiguous_format)
|
||||
|
||||
if is_contiguous:
|
||||
return list(range(ndim))
|
||||
|
||||
shape = tensors[0].shape
|
||||
|
||||
@ -398,6 +407,11 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
|
||||
# or all strides are equal and all dimensions have the same length
|
||||
return 0
|
||||
|
||||
# The "sort" order for the permutation is back-to-front, but
|
||||
# the natural order for permutations is front-to-back. Do the
|
||||
# sorting back-to-front and then reverse it on output.
|
||||
#
|
||||
# also, note this returns the logical to physical shape permutation
|
||||
perm = list(reversed(range(ndim)))
|
||||
|
||||
# insertion sort with support for ambiguous comparisons
|
||||
@ -411,18 +425,64 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
|
||||
elif comparison < 0:
|
||||
break
|
||||
|
||||
permuted_shape = [-1] * ndim
|
||||
for idx, x in enumerate(reversed(perm)):
|
||||
permuted_shape[idx] = shape[x]
|
||||
return list(reversed(perm))
|
||||
|
||||
|
||||
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
|
||||
"""
|
||||
Computes the output strides for elementwise operations.
|
||||
"""
|
||||
if len(tensors) == 0:
|
||||
msg = "Can't compute elementwise output strides for zero tensors!"
|
||||
raise ValueError(msg)
|
||||
|
||||
check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
|
||||
|
||||
# Filters the tensors to actual tensors
|
||||
tensors = tuple(
|
||||
a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
|
||||
)
|
||||
|
||||
# Short-circuits for CPU scalar case
|
||||
if len(tensors) == 0:
|
||||
return ()
|
||||
|
||||
ndim = tensors[0].ndim
|
||||
shape = tensors[0].shape
|
||||
|
||||
if ndim == 0:
|
||||
return ()
|
||||
if ndim == 1:
|
||||
return (1,)
|
||||
|
||||
logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
|
||||
*tensors, _skip_checks=True
|
||||
)
|
||||
permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical
|
||||
|
||||
new_strides = make_contiguous_strides_for(permuted_shape)
|
||||
permuted_strides = [-1] * ndim
|
||||
for idx, x in enumerate(reversed(perm)):
|
||||
permuted_strides[x] = new_strides[idx]
|
||||
permuted_strides = apply_perm(new_strides, invert_perm(logical_to_physical_perm)) # to logical
|
||||
|
||||
return tuple(permuted_strides)
|
||||
|
||||
|
||||
# Identity permutation is [0, 1, 2]
|
||||
def apply_perm(inp, perm):
|
||||
ndim = len(inp)
|
||||
permuted_inp = [-1] * ndim
|
||||
for idx, x in enumerate(perm):
|
||||
permuted_inp[idx] = inp[x]
|
||||
return permuted_inp
|
||||
|
||||
|
||||
def invert_perm(perm):
|
||||
ndim = len(perm)
|
||||
new_perm = [-1] * ndim
|
||||
for idx, x in enumerate(perm):
|
||||
new_perm[x] = idx
|
||||
return new_perm
|
||||
|
||||
|
||||
#
|
||||
# Common helper functions
|
||||
#
|
||||
|
||||
@ -276,6 +276,7 @@ __all__ = [
|
||||
"arange",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"empty_permuted",
|
||||
"empty_strided",
|
||||
"eye",
|
||||
"full",
|
||||
@ -4055,9 +4056,7 @@ def empty_permuted(
|
||||
shape,
|
||||
physical_layout,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
|
||||
@ -4274,10 +4273,13 @@ def empty_like(
|
||||
)
|
||||
|
||||
# memory_format == torch.preserve_format
|
||||
strides = utils.compute_elementwise_output_strides(a)
|
||||
return torch.empty_strided(
|
||||
logical_to_physical_perm = (
|
||||
utils.compute_elementwise_output_logical_to_physical_perm(a)
|
||||
)
|
||||
# identity perm is [2, 1, 0]
|
||||
return torch.empty_permuted(
|
||||
a.shape,
|
||||
strides,
|
||||
logical_to_physical_perm,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
|
||||
@ -397,7 +397,7 @@ def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
|
||||
# index.Tensor data-dependent in only some conditions
|
||||
@register_op_impl(
|
||||
lambda func: torch.Tag.dynamic_output_shape in func.tags # type: ignore[attr-defined]
|
||||
and func != aten.index.Tensor
|
||||
and func not in [aten.index.Tensor, aten.nonzero.default]
|
||||
)
|
||||
def dyn_shape(fake_mode, func, *args, **kwargs):
|
||||
raise DynamicOutputShapeException(func)
|
||||
@ -405,11 +405,9 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
|
||||
|
||||
@register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
|
||||
def local_scalar_dense(fake_mode, func, arg):
|
||||
if fake_mode.shape_env is None:
|
||||
if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
|
||||
# Without symints/symfloats, cannot handle this
|
||||
raise DataDependentOutputException(func)
|
||||
if not fake_mode.shape_env.allow_scalar_outputs:
|
||||
raise DataDependentOutputException(func)
|
||||
if is_float_dtype(arg.dtype):
|
||||
return fake_mode.shape_env.create_unbacked_symfloat()
|
||||
elif is_integer_dtype(arg.dtype):
|
||||
@ -418,6 +416,36 @@ def local_scalar_dense(fake_mode, func, arg):
|
||||
raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
|
||||
|
||||
|
||||
@register_op_impl(lambda func: func is torch.ops.aten.nonzero.default)
|
||||
def nonzero(fake_mode, func, arg):
|
||||
if (
|
||||
fake_mode.shape_env is None
|
||||
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
|
||||
):
|
||||
# Without symints/symfloats, cannot handle this
|
||||
raise DynamicOutputShapeException(func)
|
||||
nnz = fake_mode.shape_env.create_unbacked_symint()
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
constrain_range,
|
||||
definitely_true,
|
||||
guard_int,
|
||||
)
|
||||
|
||||
# This is unsound, but it works well in practice
|
||||
# See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
|
||||
# TODO: Add a config knob to turn off this unsound behavior
|
||||
lower = 2
|
||||
upper = None
|
||||
# But don't give totally unsatisfiable bounds if we know it's too small!
|
||||
if definitely_true(arg.numel() < 2):
|
||||
lower = 0
|
||||
upper = guard_int(arg.numel())
|
||||
constrain_range(nnz, min=lower, max=upper)
|
||||
|
||||
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
|
||||
|
||||
|
||||
# NB: this must be ordered after local_scalar_dense
|
||||
@register_op_impl(
|
||||
lambda func: torch.Tag.data_dependent_output in func.tags # type: ignore[attr-defined]
|
||||
@ -451,10 +479,17 @@ def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
|
||||
# index tensors with cuda self
|
||||
@register_op_impl(aten.index.Tensor)
|
||||
def index_tensor(fake_mode, func, *args, **kwargs):
|
||||
# dynamic shape op if indices are bool/uint8
|
||||
check_no_bool_index_tensors(func, *args, **kwargs)
|
||||
from torch._meta_registrations import meta_index_Tensor
|
||||
|
||||
return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
|
||||
_, new_kwargs = normalize_function(
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
out_device = new_kwargs["input"].device
|
||||
# ensure nonzero call goes to fake tensor
|
||||
with fake_mode:
|
||||
out = meta_index_Tensor(*args, **kwargs)
|
||||
return out.to(out_device)
|
||||
|
||||
|
||||
# takes in multiple-devices, dont default to default device handling
|
||||
@ -493,7 +528,15 @@ def conv(fake_mode, func, *args, **kwargs):
|
||||
with fake_mode:
|
||||
# if the input is unsqueezed is done in Convolution.cpp we get segfault
|
||||
k = kwargs["weight"].ndim
|
||||
if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
|
||||
batch = kwargs["input"].shape[0]
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import has_hint
|
||||
|
||||
if not has_hint(batch):
|
||||
# TODO: We can make this a little more faithful with best effort
|
||||
# channels last detection (but only if it's statically obvious!)
|
||||
mem_fmt = None
|
||||
elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
|
||||
mem_fmt = None
|
||||
else:
|
||||
if func is aten.convolution.default:
|
||||
|
||||
@ -108,6 +108,81 @@ def hint_int(a):
|
||||
assert type(a) is int, a
|
||||
return a
|
||||
|
||||
def has_hint(a):
|
||||
if isinstance(a, torch.SymInt):
|
||||
return a.node.has_hint()
|
||||
return True
|
||||
|
||||
# Returns True if every size dim on the tensor has a hint
|
||||
# TODO: Should this include strides too? For now it doesn't matter,
|
||||
# that's quite an obscure case
|
||||
def tensor_has_hints(t):
|
||||
return all(has_hint(s) for s in t.size())
|
||||
|
||||
def definitely_true(a):
|
||||
"""
|
||||
Returns True only if we can tell that a is True, possibly introducing
|
||||
a guard in the process. If a depends on some unbacked SymInt, we may
|
||||
return False even though there may exist a possible value of the SymInt
|
||||
that would cause the expression to return True.
|
||||
|
||||
When is it appropriate to use definitely_true? First, if you can use
|
||||
a higher level combinator like parallel_or/parallel_and, prefer using
|
||||
those instead, they are definitely safe (modulo short-circuiting).
|
||||
Second, it can be used if the program would behave equivalently if
|
||||
definitely_true always returned False (parallel_or/parallel_and are
|
||||
examples of this pattern, modulo short-circuiting). Finally, it even
|
||||
be OK if the program wouldn't behave equivalently, so long as the
|
||||
change is semantics preserving. It can be semantics preserving if
|
||||
the program errors in more cases than it did previously (but otherwise
|
||||
behaves identically), or if it changes some quantity in a way that
|
||||
doesn't matter (e.g., strides often fall in this bucket.)
|
||||
"""
|
||||
if isinstance(a, SymBool):
|
||||
if a.node.has_hint():
|
||||
return guard_bool(a)
|
||||
else:
|
||||
return False
|
||||
return bool(a)
|
||||
|
||||
def definitely_false(a):
|
||||
"""
|
||||
Returns True only if we can tell that a is False, possibly introducing
|
||||
a guard in the process. If a depends on some unbacked SymInt, we may
|
||||
return False even though there may exist a possible value of the SymInt
|
||||
that would cause the expression a to be False. See definitely_true
|
||||
for more usage guidance.
|
||||
"""
|
||||
if isinstance(a, SymBool):
|
||||
if a.node.has_hint():
|
||||
return not guard_bool(a)
|
||||
else:
|
||||
return False
|
||||
return not bool(a)
|
||||
|
||||
# TODO: could improve parallel_or/parallel_and by avoiding guards
|
||||
# if there exists a quantity that can be handled un-guardedly. However,
|
||||
# for backed SymInts, avoiding guards doesn't really matter in practice,
|
||||
# so I chose not to do it.
|
||||
|
||||
def parallel_or(*args):
|
||||
"""
|
||||
Evaluate the logical OR of several arguments, avoiding guarding on
|
||||
unbacked SymInts if another argument is definitely True.
|
||||
"""
|
||||
if any(definitely_true(args) for a in args):
|
||||
return True
|
||||
return any(args)
|
||||
|
||||
def parallel_and(*args):
|
||||
"""
|
||||
Evaluate the logical FALSE of several arguments, avoiding guarding on
|
||||
unbacked SymInts if another argument is definitely False.
|
||||
"""
|
||||
if any(definitely_false(args) for a in args):
|
||||
return False
|
||||
return all(args)
|
||||
|
||||
def guard_scalar(a):
|
||||
if isinstance(a, (SymBool, bool)):
|
||||
return guard_bool(a)
|
||||
@ -138,6 +213,34 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
|
||||
)
|
||||
|
||||
|
||||
def constrain_unify(a, b):
|
||||
"""
|
||||
Given two SymInts, constrain them so that they must be equal. NB:
|
||||
this will not work with SymInts that represent nontrivial expressions
|
||||
(yet!)
|
||||
"""
|
||||
# TODO: Maybe dedupe this with _maybe_guard_eq?
|
||||
if not isinstance(a, SymInt):
|
||||
if not isinstance(b, SymInt):
|
||||
assert a == b
|
||||
else:
|
||||
assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
|
||||
shape_env = b.node.shape_env
|
||||
shape_env.replacements[b.node.expr] = sympy.Integer(a)
|
||||
else:
|
||||
# TODO: Actually, we can support this as long as one of them is a symbol.
|
||||
# NB: We can't actually do "unification" as our operators are not
|
||||
# injective
|
||||
assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
|
||||
shape_env = a.node.shape_env
|
||||
if not isinstance(b, SymInt):
|
||||
shape_env.replacements[a.node.expr] = sympy.Integer(b)
|
||||
else:
|
||||
assert a.node.shape_env is b.node.shape_env
|
||||
assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
|
||||
new_var = shape_env._find(a.node.expr)
|
||||
shape_env.replacements[b.node.expr] = new_var
|
||||
|
||||
def guard_bool(a):
|
||||
if isinstance(a, SymBool):
|
||||
return a.node.guard_bool("", 0) # NB: uses Python backtrace
|
||||
@ -242,7 +345,13 @@ class SymNode:
|
||||
# simplify it into a hint
|
||||
def _update_hint(self):
|
||||
if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
|
||||
self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
|
||||
new_hint = self.shape_env.replace(self._hint_expr)
|
||||
# NB: unification constraints could result in a replacement that
|
||||
# doesn't actually solve the hint! Check for this.
|
||||
if new_hint.free_symbols:
|
||||
self._hint_expr = new_hint
|
||||
return
|
||||
self._hint = self.pytype(new_hint)
|
||||
self._hint_expr = None
|
||||
|
||||
@property
|
||||
@ -1076,6 +1185,7 @@ class ShapeEnv:
|
||||
def __init__(
|
||||
self, *,
|
||||
allow_scalar_outputs=True,
|
||||
allow_dynamic_output_shape_ops=True,
|
||||
strict_mark_dyn=False,
|
||||
assume_static_by_default=False,
|
||||
# The following options affect decisions we make about eager
|
||||
@ -1093,6 +1203,7 @@ class ShapeEnv:
|
||||
):
|
||||
# Not directly used by ShapeEnv; indirectly used by FakeTensor
|
||||
self.allow_scalar_outputs = allow_scalar_outputs
|
||||
self.allow_dynamic_output_shape_ops = allow_dynamic_output_shape_ops
|
||||
self.guards: List[ShapeGuard] = []
|
||||
# Maps symbolic ints to their original concrete values
|
||||
# Currently populated from tensors
|
||||
@ -1244,12 +1355,10 @@ class ShapeEnv:
|
||||
if not dyn:
|
||||
# Non explicitly marked dynamic dims register to val_to_var to get duck shaped
|
||||
self.val_to_var[val] = sympy_expr
|
||||
# We also infer that they must not be 0/1
|
||||
lower = 2 if self.specialize_zero_one else 0
|
||||
self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
|
||||
else:
|
||||
# Avoid up front 0/1 specializing dynamic dims
|
||||
self.var_to_range[sympy_expr] = ValueRanges(0, sympy.oo)
|
||||
|
||||
# We also infer that it must be not 0/1
|
||||
lower = 2 if self.specialize_zero_one else 0
|
||||
self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
|
||||
|
||||
if not dyn and self.duck_shape:
|
||||
# This implements duck-shaping: input sizes that match are assigned
|
||||
@ -1556,15 +1665,29 @@ class ShapeEnv:
|
||||
Tries to evaluate expr without introducing guards
|
||||
"""
|
||||
expr = self.simplify(expr)
|
||||
# Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
|
||||
|
||||
# Simplify making use of value range lower bound
|
||||
symbols = list(expr.free_symbols)
|
||||
new_shape_env = {
|
||||
k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1
|
||||
for idx, k in enumerate(symbols)
|
||||
# Do not assume unbacked symints are > 1
|
||||
# If we didn't specialize 0/1, this shape env is empty
|
||||
if k in self.var_to_val and self.specialize_zero_one
|
||||
}
|
||||
new_shape_env = {}
|
||||
new_range_env = {}
|
||||
for idx, k in enumerate(symbols):
|
||||
vr = self.var_to_range[k]
|
||||
# Don't do anything if we don't have a nontrivial lower bound
|
||||
if vr.lower == -sympy.oo:
|
||||
new_range_env[k] = vr
|
||||
continue
|
||||
# Positive means >= 1
|
||||
# Positive - 1 means >= 0
|
||||
# Positive + lower - 1 means >= lower
|
||||
# The new symbol 's' is "too low", so when we substitute it in
|
||||
# we have to increase it by offset (and conversely, the new
|
||||
# variables have to have their value range bounds adjusted as
|
||||
# well)
|
||||
s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
|
||||
offset = vr.lower - 1
|
||||
new_shape_env[k] = s + offset
|
||||
new_range_env[s] = ValueRangeAnalysis.sub(vr, offset)
|
||||
|
||||
new_expr = expr.xreplace(new_shape_env)
|
||||
floor_div_replace = {}
|
||||
for atom in new_expr.atoms(FloorDiv):
|
||||
@ -1574,17 +1697,7 @@ class ShapeEnv:
|
||||
return new_expr
|
||||
|
||||
# Check if the range can solve it statically
|
||||
range_env = {
|
||||
s: self.var_to_range[s]
|
||||
for s in expr.free_symbols
|
||||
if not (s in self.var_to_val and self.specialize_zero_one)
|
||||
}
|
||||
range_env.update({
|
||||
new_shape_env[s] - 1: ValueRangeAnalysis.sub(self.var_to_range[s], 1)
|
||||
for s in expr.free_symbols
|
||||
if s in self.var_to_val and self.specialize_zero_one
|
||||
})
|
||||
out = sympy_interp(ValueRangeAnalysis, range_env, new_expr)
|
||||
out = sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)
|
||||
if out.is_singleton():
|
||||
return out.lower
|
||||
|
||||
@ -1652,13 +1765,9 @@ class ShapeEnv:
|
||||
"""
|
||||
result_expr = safe_expand(expr).xreplace(self.var_to_val)
|
||||
if len(result_expr.free_symbols) != 0:
|
||||
range_env = {
|
||||
s: self.var_to_range[s]
|
||||
for s in result_expr.free_symbols
|
||||
}
|
||||
out = sympy_interp(ValueRangeAnalysis, range_env, result_expr)
|
||||
if out.is_singleton():
|
||||
return out.lower
|
||||
r = self._maybe_evaluate_static(result_expr)
|
||||
if r is not None:
|
||||
return r
|
||||
raise self._make_data_dependent_error(result_expr)
|
||||
return result_expr
|
||||
|
||||
|
||||
Reference in New Issue
Block a user