Add support for nonzero, some improvements to reduce guards (#95387)

This takes the strategy described in https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#

It is essentially https://github.com/pytorch/pytorch/pull/95222, but squashed and without the changes that become unnecessary once we assume nonzero returns more than one element.

What's in the PR:

* nonzero now supports meta propagation. When `capture_dynamic_output_shape_ops` is set, it returns a tensor whose first dimension is an unbacked SymInt representing the data-dependent size (see the first sketch after this list).
* The unbacked SymInt is UNSOUNDLY assumed to be not equal to 0/1. We will still raise an error if you try to guard on it in any other way.
* PrimTorch pointwise operators are updated to use empty_permuted, to avoid guarding on unbacked SymInts the way empty_strided would (tested in `test_dynamic_pointwise_scalar`; see the third sketch after this list).
* Convolution is updated to skip backend selection if the batch size is unbacked, to avoid guarding on an unbacked SymInt (tested in `test_unbacked_batch_resnet`).
* I kept the helper utilities like `definitely_true` for working with possibly unbacked SymInts. They're not used right now, but maybe someone will find them useful.
* Added `constrain_unify` to let you specify that two unbacked SymInts must have the same value (see the second sketch after this list).
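
As a rough illustration of the first two points (this snippet is not part of the commit; the function and shapes are made up), tracing a call to `nonzero` under `make_fx` in symbolic mode should now produce a graph whose output size is an unbacked SymInt rather than raising `DynamicOutputShapeException`:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    nz = torch.nonzero(x)  # output size is data-dependent
    return nz * 2          # pointwise op over the unbacked dimension

# The fake nonzero returns a tensor whose first dimension is an unbacked
# SymInt (unsoundly assumed != 0/1), so no data-dependent guard is raised here.
gm = make_fx(f, tracing_mode="symbolic")(torch.randint(0, 2, (8,), dtype=torch.bool))
print(gm.code)
```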
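
And a hypothetical sketch of how `constrain_range` and `constrain_unify` are intended to be used, loosely modeled on the new `test_unbacked_batch_resnet` and `test_boolean_index` tests (again, the function and shapes are invented):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import constrain_range, constrain_unify

def f(x, mask):
    a = x[mask]  # boolean indexing goes through nonzero,
    b = x[mask]  # so each selection gets its own unbacked SymInt
    constrain_range(a.shape[0], min=1)       # promise the selection is nonempty
    constrain_unify(a.shape[0], b.shape[0])  # promise both selections agree
    # With the unification in place, the add should not need a
    # data-dependent guard on the two sizes being equal.
    return a + b

make_fx(f, tracing_mode="symbolic")(
    torch.randn(8),
    torch.randint(0, 2, (8,), dtype=torch.bool),
)
```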
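
Finally, a small sketch (with concrete sizes, not from the commit) of why `empty_permuted` helps: the output layout is specified as a permutation of dimensions rather than as explicit strides, so nothing ever has to compute, and therefore guard on, stride values derived from an unbacked size:

```python
import torch

# Allocate a (2, 3, 4) tensor whose physical layout is dim 1 outermost,
# then dim 2, then dim 0 innermost.
y = torch.empty_permuted((2, 3, 4), (1, 2, 0))

# Same result via explicit strides: allocate the physical shape
# contiguously and permute back to logical order.
z = torch.empty((3, 4, 2)).permute(2, 0, 1)

assert y.shape == z.shape and y.stride() == z.stride()
```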

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95387
Approved by: https://github.com/voznesenskym
Author: Edward Z. Yang
Date: 2023-02-23 11:51:25 -08:00
Committed by: PyTorch MergeBot
Commit: 4833e47feb (parent 627282fa6c)
11 changed files with 413 additions and 75 deletions


@@ -49,7 +49,7 @@ from common_utils import (
 )
 from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import is_sym_node
-from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, GuardOnDataDependentSymNode
 
 USE_TORCHVISION = False
 try:
@@ -2412,7 +2412,6 @@ symbolic_aot_autograd_failures = {
     xfail('gradient', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('hsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
-    xfail('index_put', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('inner', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kron', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -2613,7 +2612,16 @@ def _test_aot_autograd_helper(self, device, dtype, op):
         return op.op(*c_args, **c_kwargs)
 
     compiled_f = compiled_function(f, nop, nop)
-    _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args)
+    try:
+        _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args)
+    except GuardOnDataDependentSymNode:
+        # Carveout for getitem; I don't want to xfail the entire test
+        # because that will reject known to be good tests see
+        # https://github.com/pytorch/pytorch/issues/94705
+        if op.name == "__getitem__":
+            self.skipTest("Dynamic output shape operation in trace")
+        else:
+            raise
 
 def _test_aot_autograd_module_helper(self, device, dtype, training, module_info):
     module_cls = module_info.module_cls


@@ -1733,6 +1733,7 @@ class TestRefsOpsInfo(TestCase):
     skip_ref_ops = {
         '_refs.bitwise_right_shift',
         '_refs.copy_to',
+        '_refs.empty_permuted',
         '_refs.empty_strided',
         '_refs.equal',
         '_refs.full',
@@ -1846,6 +1847,7 @@ class TestRefsOpsInfo(TestCase):
         '_refs.scalar_tensor',  # missing "layout"
         # other
         '_refs.empty',  # intentional; direct empty is faster and has less guards
+        '_refs.empty_permuted',  # intentional; direct empty is faster and has less guards
         '_refs.expand_as',
         '_refs.as_strided',  # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
         '_refs.copy_to',  # torch._C._jit_get_operation: No such operator aten::copy_to


@@ -13,7 +13,7 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDepen
 from torch._decomp import decomposition_table
 from torch.fx.experimental.symbolic_shapes import (
     sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
-    constrain_range
+    constrain_range, constrain_unify, guard_int
 )
 from torch.testing._internal.common_device_type import ops
 from torch._C import _disabled_torch_function_impl
@@ -912,6 +912,115 @@ def forward(self, a_1):
     return empty""" # noqa: B950
         )
 
+    def test_dynamic_pointwise_scalar(self):
+        def f(gravity, mask):
+            gravity[mask, 0] = gravity[mask, 0] * -1
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.randn((12, 4)),
+            torch.randint(0, 2, (12,), dtype=torch.bool)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, gravity_1, mask_1):
+    select = torch.ops.aten.select.int(gravity_1, 1, 0)
+    index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None
+    mul = torch.ops.aten.mul.Tensor(index, -1); index = None
+    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None
+    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = None
+    return None""")
+
+    def test_reflect_r_over_x(self):
+        def reflect_R_over_x(R):
+            reflect = torch.eye(3, device=R.device)
+            reflect[0, 0] = -1
+            return reflect @ R @ reflect
+
+        def f(crop_camera, mask):
+            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.randn((12, 3, 3)),
+            torch.randint(0, 2, (12,), dtype=torch.bool)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, crop_camera_1, mask_1):
+    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
+    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
+    _tensor_constant0 = self._tensor_constant0
+    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
+    select = torch.ops.aten.select.int(eye, 0, 0)
+    select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
+    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
+    transpose = torch.ops.aten.transpose.int(index, -2, -1)
+    t = torch.ops.aten.t.default(eye)
+    clone = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format); transpose = None
+    sym_size = torch.ops.aten.sym_size(index, 0); index = None
+    sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 2)
+    mul = sym_size * sym_size_1
+    sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 1)
+    _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [mul, sym_size_2]); clone = mul = sym_size_2 = None
+    mm = torch.ops.aten.mm.default(_unsafe_view, t); _unsafe_view = t = None
+    view = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 3]); mm = sym_size_1 = None
+    transpose_1 = torch.ops.aten.transpose.int(view, -2, -1)
+    clone_1 = torch.ops.aten.clone.default(transpose_1, memory_format = torch.contiguous_format); transpose_1 = None
+    mul_1 = sym_size * 3
+    sym_size_3 = torch.ops.aten.sym_size(view, 1); view = None
+    view_1 = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_3]); clone_1 = mul_1 = sym_size_3 = None
+    mm_1 = torch.ops.aten.mm.default(view_1, eye); view_1 = eye = None
+    view_2 = torch.ops.aten.view.default(mm_1, [sym_size, 3, 3]); mm_1 = sym_size = None
+    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None
+    return None""")
+
+    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
+    def test_unbacked_batch_resnet(self):
+        mod = torchvision.models.resnet18()
+
+        def f(x, mask, params, buffers):
+            for p in itertools.chain([x, mask], params.values(), buffers.values()):
+                for s in p.shape:
+                    guard_int(s)
+            x = x[mask]
+            constrain_range(x.shape[0], min=1)
+            for p in params.values():
+                p.grad = None
+            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
+
+        make_fx(f, tracing_mode="symbolic")(
+            torch.randn(3, 3, 250, 250),
+            torch.randint(0, 2, (3,), dtype=torch.bool),
+            dict(mod.named_parameters()),
+            dict(mod.named_buffers()),
+        )
+
+    def test_boolean_index(self):
+        def f(images, handedness, valid):
+            images = images[valid]
+            handedness = handedness[valid]
+            zi = images.shape[0]
+            zh = handedness.shape[0]
+            # NB: We wouldn't actually need this if we could cache
+            # the result of running valid.nonzero() and assign the same
+            # SymInt in both cases. This is a workaround in lieu of
+            # that memoization.
+            constrain_unify(zi, zh)
+            right_hand_mask = handedness == 1
+            images[right_hand_mask] = images[right_hand_mask].flip(-1)
+
+        r = str(make_fx(f, tracing_mode="symbolic")(
+            torch.randint(0, 256, (512, 1, 96, 96)),
+            torch.randint(0, 1, (512,)),
+            torch.randint(0, 2, (512,), dtype=torch.bool)
+        ).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, images_1, handedness_1, valid_1):
+    index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None
+    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None
+    eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None
+    index_2 = torch.ops.aten.index.Tensor(index, [eq])
+    flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None
+    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = None
+    return None""")
+
     def test_neg_shape(self):
         def f(a):
             return torch.empty(-a.shape[0] + 10)
@@ -1202,7 +1311,6 @@ symbolic_tensor_failures = {
     xfail('masked.cumprod', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('addmv', ''),  # aten.addmv.default - couldn't find symbolic meta function/decomposition
     xfail('aminmax', ''),  # aten.aminmax.default - couldn't find symbolic meta function/decomposition
-    xfail('argwhere', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
     xfail('baddbmm', ''),  # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
     xfail('cdist', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
@@ -1317,7 +1425,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.pdist', ''),  # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
     xfail('nn.functional.smooth_l1_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('nonzero', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
     xfail('ormqr', ''),  # aten.ormqr.default - couldn't find symbolic meta function/decomposition
     xfail('pca_lowrank', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition


@@ -161,6 +161,13 @@ repro_tolerance = 1e-3
 # This requires dynamic_shapes to be True.
 capture_scalar_outputs = False
 
+# Not all backends support operators that have dynamic output shape (e.g.,
+# nonzero, unique). When this flag is set to False, we introduce a graph
+# break instead of capturing. This requires dynamic_shapes to be True.
+# If you set this to True, you probably also want capture_scalar_outputs
+# (these are separated for historical reasons).
+capture_dynamic_output_shape_ops = False
+
 # Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
 # false_fn produces code with identical guards.
 enforce_cond_guards_match = True


@@ -191,6 +191,7 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
         fake_mode = torch._subclasses.FakeTensorMode(
             shape_env=ShapeEnv(
                 allow_scalar_outputs=config.capture_scalar_outputs,
+                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
                 strict_mark_dyn=export,
                 assume_static_by_default=config.assume_static_by_default,
             )


@@ -20,7 +20,6 @@ from torch._prims_common import (
 from torch._prims_common.wrappers import out_wrapper
 from torch._refs import _broadcast_shapes
 
-from torch._subclasses.fake_tensor import check_no_bool_index_tensors
 from torch.utils._pytree import tree_map
 
@@ -996,7 +995,6 @@ def vdot(self, other):
 # get shape inference through structured kernels
 @register_meta(aten.index.Tensor)
 def meta_index_Tensor(self, indices):
-    check_no_bool_index_tensors(aten.index.Tensor, self, indices)
     check(indices, lambda: "at least one index must be provided")
     # aten::index is the internal advanced indexing implementation
     # checkIndexTensorTypes and expandTensors


@@ -347,7 +347,7 @@ def _elementwise_meta(
     utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
     utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
 
-    strides = utils.compute_elementwise_output_strides(*args_)
+    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
     shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
 
     # Acquires the dtype
@@ -398,7 +398,8 @@ def _elementwise_meta(
     else:
         dtype = dtype
 
-    return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype)
+    assert shape is not None
+    return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype)  # type: ignore[return-value]
 
     # Number case
     # TODO: fix number type promotion (bool, complex->float)


@@ -77,6 +77,7 @@ torch_function_passthrough = {
     torch.Tensor.device.__get__,  # type: ignore[attr-defined]
     torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
     torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
+    torch.Tensor.is_contiguous,
     # For TorchRefsMode only
     torch.Tensor.__format__,
     torch.Tensor.__repr__,
@@ -346,33 +347,41 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
 # non overlapping and dense strides.
 # This is also INCORRECT because it does not model TensorIterator's
 # short-circuit, which can cause different strides.
-def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
-    """
-    Computes the output strides for elementwise operations.
-    """
-    if len(tensors) == 0:
+def compute_elementwise_output_logical_to_physical_perm(*tensors, _skip_checks=False) -> List[int]:
+    if not _skip_checks and len(tensors) == 0:
         msg = "Can't compute elementwise output strides for zero tensors!"
         raise ValueError(msg)
 
-    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
+    if not _skip_checks:
+        check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
 
     # Filters the tensors to actual tensors
-    tensors = tuple(
-        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
-    )
+    if not _skip_checks:
+        tensors = tuple(
+            a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
+        )
 
     # Short-circuits for CPU scalar case
     if len(tensors) == 0:
-        return ()
+        return []
 
     # Short-circuits for shapes with zero or one dimensions
    # TODO: are these necessary?
     ndim = tensors[0].ndim
     if ndim == 0:
-        return ()
+        return []
     if ndim == 1:
-        return (1,)
+        return [0]
+
+    # Short-circuits if contiguous, following the fake fast path.
+    # This reduces the number of guards we end up making
+    # TODO: do channels last too
+    is_contiguous = True
+    for t in tensors:
+        is_contiguous = is_contiguous and t.is_contiguous(memory_format=torch.contiguous_format)
+
+    if is_contiguous:
+        return list(range(ndim))
 
     shape = tensors[0].shape
@@ -398,6 +407,11 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
         # or all strides are equal and all dimensions have the same length
         return 0
 
+    # The "sort" order for the permutation is back-to-front, but
+    # the natural order for permutations is front-to-back. Do the
+    # sorting back-to-front and then reverse it on output.
+    #
+    # also, note this returns the logical to physical shape permutation
     perm = list(reversed(range(ndim)))
 
     # insertion sort with support for ambiguous comparisons
@@ -411,18 +425,64 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
         elif comparison < 0:
             break
 
-    permuted_shape = [-1] * ndim
-    for idx, x in enumerate(reversed(perm)):
-        permuted_shape[idx] = shape[x]
+    return list(reversed(perm))
+
+
+def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
+    """
+    Computes the output strides for elementwise operations.
+    """
+    if len(tensors) == 0:
+        msg = "Can't compute elementwise output strides for zero tensors!"
+        raise ValueError(msg)
+
+    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
+
+    # Filters the tensors to actual tensors
+    tensors = tuple(
+        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
+    )
+
+    # Short-circuits for CPU scalar case
+    if len(tensors) == 0:
+        return ()
+
+    ndim = tensors[0].ndim
+    shape = tensors[0].shape
+
+    if ndim == 0:
+        return ()
+    if ndim == 1:
+        return (1,)
+
+    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
+        *tensors, _skip_checks=True
+    )
+    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical
 
     new_strides = make_contiguous_strides_for(permuted_shape)
-    permuted_strides = [-1] * ndim
-    for idx, x in enumerate(reversed(perm)):
-        permuted_strides[x] = new_strides[idx]
+    permuted_strides = apply_perm(new_strides, invert_perm(logical_to_physical_perm))  # to logical
 
     return tuple(permuted_strides)
 
+
+# Identity permutation is [0, 1, 2]
+def apply_perm(inp, perm):
+    ndim = len(inp)
+    permuted_inp = [-1] * ndim
+    for idx, x in enumerate(perm):
+        permuted_inp[idx] = inp[x]
+    return permuted_inp
+
+
+def invert_perm(perm):
+    ndim = len(perm)
+    new_perm = [-1] * ndim
+    for idx, x in enumerate(perm):
+        new_perm[x] = idx
+    return new_perm
+
+
 #
 # Common helper functions
 #


@@ -276,6 +276,7 @@ __all__ = [
     "arange",
     "empty",
     "empty_like",
+    "empty_permuted",
     "empty_strided",
     "eye",
     "full",
@@ -4055,9 +4056,7 @@ def empty_permuted(
         shape,
         physical_layout,
         dtype=dtype,
-        layout=layout,
         device=device,
-        pin_memory=pin_memory,
         requires_grad=requires_grad,
     )
@@ -4274,10 +4273,13 @@ def empty_like(
     )
 
     # memory_format == torch.preserve_format
-    strides = utils.compute_elementwise_output_strides(a)
-    return torch.empty_strided(
+    logical_to_physical_perm = (
+        utils.compute_elementwise_output_logical_to_physical_perm(a)
+    )
+    # identity perm is [2, 1, 0]
+    return torch.empty_permuted(
         a.shape,
-        strides,
+        logical_to_physical_perm,
         dtype=dtype,
         layout=layout,
         device=device,


@@ -397,7 +397,7 @@ def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
 # index.Tensor data-dependent in only some conditions
 @register_op_impl(
     lambda func: torch.Tag.dynamic_output_shape in func.tags  # type: ignore[attr-defined]
-    and func != aten.index.Tensor
+    and func not in [aten.index.Tensor, aten.nonzero.default]
 )
 def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
@@ -405,11 +405,9 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
 
 @register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
 def local_scalar_dense(fake_mode, func, arg):
-    if fake_mode.shape_env is None:
+    if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
         # Without symints/symfloats, cannot handle this
         raise DataDependentOutputException(func)
-    if not fake_mode.shape_env.allow_scalar_outputs:
-        raise DataDependentOutputException(func)
     if is_float_dtype(arg.dtype):
         return fake_mode.shape_env.create_unbacked_symfloat()
     elif is_integer_dtype(arg.dtype):
@@ -418,6 +416,36 @@ def local_scalar_dense(fake_mode, func, arg):
         raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
 
 
+@register_op_impl(lambda func: func is torch.ops.aten.nonzero.default)
+def nonzero(fake_mode, func, arg):
+    if (
+        fake_mode.shape_env is None
+        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
+    ):
+        # Without symints/symfloats, cannot handle this
+        raise DynamicOutputShapeException(func)
+
+    nnz = fake_mode.shape_env.create_unbacked_symint()
+
+    from torch.fx.experimental.symbolic_shapes import (
+        constrain_range,
+        definitely_true,
+        guard_int,
+    )
+
+    # This is unsound, but it works well in practice
+    # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
+    # TODO: Add a config knob to turn off this unsound behavior
+    lower = 2
+    upper = None
+
+    # But don't give totally unsatisfiable bounds if we know it's too small!
+    if definitely_true(arg.numel() < 2):
+        lower = 0
+        upper = guard_int(arg.numel())
+
+    constrain_range(nnz, min=lower, max=upper)
+    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
+
+
 # NB: this must be ordered after local_scalar_dense
 @register_op_impl(
     lambda func: torch.Tag.data_dependent_output in func.tags  # type: ignore[attr-defined]
@@ -451,10 +479,17 @@ def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
 
 # index tensors with cuda self
 @register_op_impl(aten.index.Tensor)
 def index_tensor(fake_mode, func, *args, **kwargs):
-    # dynamic shape op if indices are bool/uint8
-    check_no_bool_index_tensors(func, *args, **kwargs)
-    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
+    from torch._meta_registrations import meta_index_Tensor
+
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    out_device = new_kwargs["input"].device
+    # ensure nonzero call goes to fake tensor
+    with fake_mode:
+        out = meta_index_Tensor(*args, **kwargs)
+        return out.to(out_device)
 
 
 # takes in multiple-devices, dont default to default device handling
@@ -493,7 +528,15 @@ def conv(fake_mode, func, *args, **kwargs):
     with fake_mode:
         # if the input is unsqueezed is done in Convolution.cpp we get segfault
         k = kwargs["weight"].ndim
-        if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
+        batch = kwargs["input"].shape[0]
+
+        from torch.fx.experimental.symbolic_shapes import has_hint
+        if not has_hint(batch):
+            # TODO: We can make this a little more faithful with best effort
+            # channels last detection (but only if it's statically obvious!)
+            mem_fmt = None
+        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
             mem_fmt = None
         else:
             if func is aten.convolution.default:


@@ -108,6 +108,81 @@ def hint_int(a):
     assert type(a) is int, a
     return a
 
+def has_hint(a):
+    if isinstance(a, torch.SymInt):
+        return a.node.has_hint()
+    return True
+
+# Returns True if every size dim on the tensor has a hint
+# TODO: Should this include strides too? For now it doesn't matter,
+# that's quite an obscure case
+def tensor_has_hints(t):
+    return all(has_hint(s) for s in t.size())
+
+def definitely_true(a):
+    """
+    Returns True only if we can tell that a is True, possibly introducing
+    a guard in the process. If a depends on some unbacked SymInt, we may
+    return False even though there may exist a possible value of the SymInt
+    that would cause the expression to return True.
+
+    When is it appropriate to use definitely_true? First, if you can use
+    a higher level combinator like parallel_or/parallel_and, prefer using
+    those instead, as they are definitely safe (modulo short-circuiting).
+    Second, it can be used if the program would behave equivalently if
+    definitely_true always returned False (parallel_or/parallel_and are
+    examples of this pattern, modulo short-circuiting). Finally, it can
+    even be OK if the program wouldn't behave equivalently, so long as the
+    change is semantics preserving. It can be semantics preserving if
+    the program errors in more cases than it did previously (but otherwise
+    behaves identically), or if it changes some quantity in a way that
+    doesn't matter (e.g., strides often fall in this bucket.)
+    """
+    if isinstance(a, SymBool):
+        if a.node.has_hint():
+            return guard_bool(a)
+        else:
+            return False
+    return bool(a)
+
+def definitely_false(a):
+    """
+    Returns True only if we can tell that a is False, possibly introducing
+    a guard in the process. If a depends on some unbacked SymInt, we may
+    return False even though there may exist a possible value of the SymInt
+    that would cause the expression a to be False. See definitely_true
+    for more usage guidance.
+    """
+    if isinstance(a, SymBool):
+        if a.node.has_hint():
+            return not guard_bool(a)
+        else:
+            return False
+    return not bool(a)
+
+# TODO: could improve parallel_or/parallel_and by avoiding guards
+# if there exists a quantity that can be handled un-guardedly. However,
+# for backed SymInts, avoiding guards doesn't really matter in practice,
+# so I chose not to do it.
+
+def parallel_or(*args):
+    """
+    Evaluate the logical OR of several arguments, avoiding guarding on
+    unbacked SymInts if another argument is definitely True.
+    """
+    if any(definitely_true(a) for a in args):
+        return True
+    return any(args)
+
+def parallel_and(*args):
+    """
+    Evaluate the logical AND of several arguments, avoiding guarding on
+    unbacked SymInts if another argument is definitely False.
+    """
+    if any(definitely_false(a) for a in args):
+        return False
+    return all(args)
+
 def guard_scalar(a):
     if isinstance(a, (SymBool, bool)):
         return guard_bool(a)
@@ -138,6 +213,34 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
     )
 
 
+def constrain_unify(a, b):
+    """
+    Given two SymInts, constrain them so that they must be equal. NB:
+    this will not work with SymInts that represent nontrivial expressions
+    (yet!)
+    """
+    # TODO: Maybe dedupe this with _maybe_guard_eq?
+    if not isinstance(a, SymInt):
+        if not isinstance(b, SymInt):
+            assert a == b
+        else:
+            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            shape_env = b.node.shape_env
+            shape_env.replacements[b.node.expr] = sympy.Integer(a)
+    else:
+        # TODO: Actually, we can support this as long as one of them is a symbol.
+        # NB: We can't actually do "unification" as our operators are not
+        # injective
+        assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+        shape_env = a.node.shape_env
+        if not isinstance(b, SymInt):
+            shape_env.replacements[a.node.expr] = sympy.Integer(b)
+        else:
+            assert a.node.shape_env is b.node.shape_env
+            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            new_var = shape_env._find(a.node.expr)
+            shape_env.replacements[b.node.expr] = new_var
+
+
 def guard_bool(a):
     if isinstance(a, SymBool):
         return a.node.guard_bool("", 0)  # NB: uses Python backtrace
@@ -242,7 +345,13 @@ class SymNode:
     # simplify it into a hint
     def _update_hint(self):
         if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
-            self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
+            new_hint = self.shape_env.replace(self._hint_expr)
+            # NB: unification constraints could result in a replacement that
+            # doesn't actually solve the hint! Check for this.
+            if new_hint.free_symbols:
+                self._hint_expr = new_hint
+                return
+            self._hint = self.pytype(new_hint)
             self._hint_expr = None
 
     @property
@@ -1076,6 +1185,7 @@ class ShapeEnv:
     def __init__(
         self, *,
         allow_scalar_outputs=True,
+        allow_dynamic_output_shape_ops=True,
        strict_mark_dyn=False,
         assume_static_by_default=False,
         # The following options affect decisions we make about eager
@@ -1093,6 +1203,7 @@ class ShapeEnv:
     ):
         # Not directly used by ShapeEnv; indirectly used by FakeTensor
         self.allow_scalar_outputs = allow_scalar_outputs
+        self.allow_dynamic_output_shape_ops = allow_dynamic_output_shape_ops
         self.guards: List[ShapeGuard] = []
         # Maps symbolic ints to their original concrete values
         # Currently populated from tensors
@@ -1244,12 +1355,10 @@ class ShapeEnv:
         if not dyn:
             # Non explicitly marked dynamic dims register to val_to_var to get duck shaped
             self.val_to_var[val] = sympy_expr
-            # We also infer that they must not be 0/1
-            lower = 2 if self.specialize_zero_one else 0
-            self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
-        else:
-            # Avoid up front 0/1 specializing dynamic dims
-            self.var_to_range[sympy_expr] = ValueRanges(0, sympy.oo)
+
+        # We also infer that it must not be 0/1
+        lower = 2 if self.specialize_zero_one else 0
+        self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
 
         if not dyn and self.duck_shape:
             # This implements duck-shaping: input sizes that match are assigned
@@ -1556,15 +1665,29 @@
         Tries to evaluate expr without introducing guards
         """
         expr = self.simplify(expr)
-        # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
+
+        # Simplify making use of value range lower bound
         symbols = list(expr.free_symbols)
-        new_shape_env = {
-            k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1
-            for idx, k in enumerate(symbols)
-            # Do not assume unbacked symints are > 1
-            # If we didn't specialize 0/1, this shape env is empty
-            if k in self.var_to_val and self.specialize_zero_one
-        }
+        new_shape_env = {}
+        new_range_env = {}
+        for idx, k in enumerate(symbols):
+            vr = self.var_to_range[k]
+            # Don't do anything if we don't have a nontrivial lower bound
+            if vr.lower == -sympy.oo:
+                new_range_env[k] = vr
+                continue
+            # Positive means >= 1
+            # Positive - 1 means >= 0
+            # Positive + lower - 1 means >= lower
+            # The new symbol 's' is "too low", so when we substitute it in
+            # we have to increase it by offset (and conversely, the new
+            # variables have to have their value range bounds adjusted as
+            # well)
+            s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
+            offset = vr.lower - 1
+            new_shape_env[k] = s + offset
+            new_range_env[s] = ValueRangeAnalysis.sub(vr, offset)
+
         new_expr = expr.xreplace(new_shape_env)
         floor_div_replace = {}
         for atom in new_expr.atoms(FloorDiv):
@@ -1574,17 +1697,7 @@
             return new_expr
 
         # Check if the range can solve it statically
-        range_env = {
-            s: self.var_to_range[s]
-            for s in expr.free_symbols
-            if not (s in self.var_to_val and self.specialize_zero_one)
-        }
-        range_env.update({
-            new_shape_env[s] - 1: ValueRangeAnalysis.sub(self.var_to_range[s], 1)
-            for s in expr.free_symbols
-            if s in self.var_to_val and self.specialize_zero_one
-        })
-        out = sympy_interp(ValueRangeAnalysis, range_env, new_expr)
+        out = sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)
         if out.is_singleton():
             return out.lower
@@ -1652,13 +1765,9 @@
         """
         result_expr = safe_expand(expr).xreplace(self.var_to_val)
         if len(result_expr.free_symbols) != 0:
-            range_env = {
-                s: self.var_to_range[s]
-                for s in result_expr.free_symbols
-            }
-            out = sympy_interp(ValueRangeAnalysis, range_env, result_expr)
-            if out.is_singleton():
-                return out.lower
+            r = self._maybe_evaluate_static(result_expr)
+            if r is not None:
+                return r
             raise self._make_data_dependent_error(result_expr)
         return result_expr