diff --git a/docs/source/fx.experimental.rst b/docs/source/fx.experimental.rst index e3496df82716..128c744940dd 100644 --- a/docs/source/fx.experimental.rst +++ b/docs/source/fx.experimental.rst @@ -51,6 +51,7 @@ torch.fx.experimental.symbolic_shapes compute_unbacked_bindings rebind_unbacked resolve_unbacked_bindings + is_accessor_node torch.fx.experimental.proxy_tensor ------------------------------------- @@ -65,3 +66,5 @@ torch.fx.experimental.proxy_tensor make_fx handle_sym_dispatch get_proxy_mode + maybe_enable_thunkify + maybe_disable_thunkify diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 086de9bff4e9..815cb2107578 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -3480,8 +3480,8 @@ def forward(self, x): def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) arg0_1 = arg0 - slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3) sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) + slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3) sub = sym_size_int - 1 slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 1c10fee5f716..a3dbb95c643a 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -2148,7 +2148,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase): self._validate_compile(fn, arg_fn) - @unittest.expectedFailure def test_return_shape(self): nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) diff --git a/test/export/test_export.py b/test/export/test_export.py index 82118f65ecf7..966a49044a6e 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6037,14 +6037,14 @@ def forward(self, x): def forward(self, x): item = torch.ops.aten.item.default(x); x = None sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None - ge = item >= 3 - _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 3 on node 'ge'"); ge = _assert_scalar_default = None + ge_1 = item >= 3 + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 3 on node 'ge_1'"); ge_1 = _assert_scalar_default = None le = item <= 5 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'"); le = _assert_scalar_default_1 = None - gt = item > 2 - _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 2 < u1 on node 'gt'"); gt = _assert_scalar_default_2 = None - lt = item < 6 - _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression u1 < 6 on node 'lt'"); lt = _assert_scalar_default_3 = None + gt_1 = item > 2 + _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 2 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_2 = None + lt_1 = item < 6 + _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'"); lt_1 = _assert_scalar_default_3 = None foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None return (foo_unbacked,)""", ) diff --git 
a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 119064e1dd56..8b766f07dcb2 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -1320,8 +1320,8 @@ def forward(self, token, obj, x): with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None - add = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None - return (getitem, add)""", # noqa: B950 + add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None + return (getitem, add_3)""", # noqa: B950 ) @parametrize("backend", ["eager", "aot_eager"]) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 2ab6503897fe..1a35e6823a18 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1056,8 +1056,8 @@ def forward(self, s0_1, s1_1, x_1, y_1): self.assertExpectedInline(r, """\ def forward(self, x_1, y_1): sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None - empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False) sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0); y_1 = None + empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False) return ((sym_size_int, sym_size_int_1), empty)""") def test_unary(self): @@ -1355,8 +1355,8 @@ def forward(self, crop_camera_1, mask_1): view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None - mul = sym_size_int * 3 - view_3 = torch.ops.aten.view.default(view_2, [mul, 3]); view_2 = mul = None + mul_4 = sym_size_int * 3 + view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 83c52e595b19..05fa3a7d1b42 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1536,29 +1536,7 @@ class OutputGraph: return False return True - # NB: You could try to expand this to cover more cases by simply - # detecting whenever you have an int output, but this is a bit - # dangerous in case someone adds a function that returns an int but is - # mutating. So manually whitelist for now. 
- def is_accessor_node(node): - if ( - node.op == "call_method" - and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) - and node.target in ["size", "stride", "storage_offset", "item"] - ): - return True - if node.op == "call_function" and node.target in [ - torch.ops.aten.sym_size, - torch.ops.aten.sym_size.default, - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_stride, - torch.ops.aten.sym_stride.default, - torch.ops.aten.sym_stride.int, - torch.ops.aten.sym_storage_offset, - torch.ops.aten.sym_storage_offset.default, - ]: - return True - return False + from torch.fx.experimental.symbolic_shapes import is_accessor_node for node in reversed(list(self.graph.nodes)): if len(list(node.users)) == 0: diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index c629ffce30dc..63d372db774d 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -24,6 +24,10 @@ from torch import Tensor from torch._decomp.decompositions_for_rng import PhiloxStateTracker from torch._guards import detect_fake_mode from torch._prims_common import CUDARngStateHelper +from torch.fx.experimental.proxy_tensor import ( + maybe_disable_thunkify, + maybe_enable_thunkify, +) from torch.fx.experimental.symbolic_shapes import ( definitely_false, PropagateUnbackedSymInts, @@ -188,6 +192,7 @@ def fn_prepped_for_autograd( def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any: def inner_fn(primals: List[Any], tangents: List[Any]): outs, tangent_mask = fn(*primals) + assert len(tangent_mask) == len(outs) outs_to_grad = [ o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent @@ -365,266 +370,280 @@ def create_functionalized_fn( ) -> Any: @wraps(fn) def _functionalized_f_helper(*args): - # See Note [Disabling Functionalize TLS Above Python Functionalization] - disable_above = torch._C._ExcludeDispatchKeyGuard( - torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) - ) - - # See Note [Side-Effectful Tokens in AOTAutograd] - if trace_joint: - assert ( - isinstance(args, tuple) - and len(args) == 2 - and isinstance(args[0], (list, tuple)) + with maybe_enable_thunkify(): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) ) - tokens = args[0][: len(meta.tokens)] - actual_args = args[0][len(meta.tokens) :] - args = (actual_args, args[1]) - else: - tokens = args[: len(meta.tokens)] - args = args[len(meta.tokens) :] - assert all(token.numel() == 0 for token in tokens) - with disable_above: - # Wrap inputs into functional wrappers - f_args = pytree.tree_map(to_fun, args) - f_tokens = pytree.tree_map(to_fun, tokens) - - # Populate the current FunctionalTensorMode with the tokens per - # operator. 
See Note [FunctionalTensorMode is Stateful] - functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( - torch._C._TorchDispatchModeKey.FUNCTIONAL - ) - assert functional_tensor_mode is not None - for i, k in enumerate(meta.tokens.keys()): - functional_tensor_mode._tokens[k] = f_tokens[i] - - # Run the joint - f_outs = fn(*f_args) - - # Return both the tokens and the outputs # See Note [Side-Effectful Tokens in AOTAutograd] - f_outs = (*functional_tensor_mode._tokens.values(), *f_outs) - - if trace_joint: - # We support a limited amount of mutation of graph inputs during the backward pass. - # (This is used e.g. by Float8, which needs to update buffers during the backward pass) - # Here, we perform extra checks for primals that were mutated in the **backward** - # We're doing the checks here instead of doing them with the rest of the input mutation handling because: - # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened - # during the forward, because the handling is different: some input mutations from the the forward - # can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same - # types of mutations in the backward we would need a bw-only runtime epilogue. - # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in - # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would - # require an extra round of tracing though, so it's more efficient to do in-line here. - assert ( - isinstance(args, tuple) - and len(args) == 2 - and isinstance(args[0], (list, tuple)) - ) - # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw) - primals_before = args[0] - primals_after = pytree.tree_map(from_fun, f_args[0]) - for idx, (f_inpt, before, after, inpt_info) in enumerate( - zip(f_args[0], primals_before, primals_after, meta.input_info) - ): - # Store information about mutations in joint(for backward analysis) - joint_mutates_data = has_data_mutation(f_inpt) - - joint_mutates_metadata = has_metadata_mutation( - f_inpt, before, check_only_storage_mutation=False + if trace_joint: + assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) ) + tokens = args[0][: len(meta.tokens)] + actual_args = args[0][len(meta.tokens) :] + args = (actual_args, args[1]) + else: + tokens = args[: len(meta.tokens)] + args = args[len(meta.tokens) :] + assert all(token.numel() == 0 for token in tokens) - # Ban metadata mutations on fw inputs during the bw - if not inpt_info.mutates_metadata: - assert ( - not joint_mutates_metadata - ), "Found a graph input that had its metadata mutated in the backward. This is not supported" + with disable_above: + # The functionalization code here can potentially trigger traces + # into the graph, but we'd prefer to NOT do this, because if we + # trace them now, we will end up with FX nodes that don't have + # module stack annotations, which makes unflattener unhappy. + # Wrap inputs into functional wrappers + f_args = pytree.tree_map(to_fun, args) + f_tokens = pytree.tree_map(to_fun, tokens) - # Ban storage resizing on fw inputs during the bw - if not inpt_info.mutation_inductor_storage_resize: - assert not was_inductor_storage_resized( - f_inpt - ), "Found a graph input that had storage resizing in the backward. 
This is not supported" - - # Allow data mutations on fw inputs during the bw, but only if they do not require grad - # So we can guarantee that we can keep the mutations in the graph - if ( - joint_mutates_data - and not inpt_info.mutates_data - and not inpt_info.mutates_storage_metadata - ): - # Not banning here mutations on inpt_info.requires_grad - - # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) - # Add node meta for copy_ for partitioner that this node should be in backward graph. - with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag( - "must_be_in_backward" - ): - before.copy_(after) - meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( - idx + # Populate the current FunctionalTensorMode with the tokens per + # operator. See Note [FunctionalTensorMode is Stateful] + functional_tensor_mode = ( + torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL ) - # Now that we covered mutations to *forward* inputs during the backward, - # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). - # Today, we will just error in all cases of this happening unless someone needs us to support it. - tangents_before = args[1] - tangents_after = pytree.tree_map(from_fun, f_args[1]) - for f_inpt, before, after in zip( - f_args[1], tangents_before, tangents_after - ): - assert not has_metadata_mutation( - f_inpt, before, check_only_storage_mutation=False - ) and not has_data_mutation( - f_inpt - ), "Found an input to the backward that was mutated during the backward pass. This is not supported" + ) + assert functional_tensor_mode is not None + for i, k in enumerate(meta.tokens.keys()): + functional_tensor_mode._tokens[k] = f_tokens[i] - if aot_config.keep_inference_input_mutations: - # Note: This is a bit annoying. There's a layering issue here, where: - # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. - # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. - # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, - # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). - # This makes it pretty difficult for this logic to operate on synthetic bases. - # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual - # (unpacked) input aliases, instead of the synthetic base. - # Example case where (3) could be important: - # - # def f(x, y): - # x.mul_(2) - # y.mul_(3) - # return x, y - # a = torch.ones(1'000'000) - # x, y = out(a[0:9], a[1:10]) - # - # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing - # a giant "updated synthetic base" and copying into a's entire storage. - # - # For now, we are pessimistically not performing the optimization from (3); - # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. - # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry - # about synthetic bases. 
- for i, (inpt_old, inpt_f) in enumerate( - zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) - ): - if not isinstance(inpt_f, torch.Tensor): - continue - assert is_fun(inpt_f) - inpt_new = from_fun(inpt_f) - if meta.input_info[i].mutation_type == MutationType.MUTATED_IN_GRAPH: - # See Note [set_() Input Mutations in AOTAutograd] - # all mutations on the input must be under no_grad, so it is safe to put in the graph - # Here, we're saying that if an input experienced a set call, inp.set_(other), - # then we can effectively not have to worry about whether its data was mutated. - # There are 3 cases: - # (1) We mutate inp *after* the set_() call. other is a graph intermediate. - # In this case, we're not really mutating the input storage of "inp"; - # we're mutating the storage of an intermdiate value (other), - # and slamming that storage into the input tensor. So no data mutation is necessary. - # (2) We mutate inp *after* the set_() call. other is a graph *input*. - # In this case, the data mutation will be properly handled in the runtime - # epilogue during the processing of "other" - # (3) We mutate inp *before* the set_() call. - # This case is *not* currently handled. - if meta.input_info[i].mutates_storage_metadata: - with torch.no_grad(): - inpt_old.set_(inpt_new) + # Run the joint + f_outs = fn(*f_args) - # Note [Ordering of resize_() and set_()] - # Importantly: the common usage in FSDP is that we have a dummy parameter - # that sees a set_() and **Then** a resize_(). - # We must put those mutations into the graph in the same order, - # Since running them in the opposite order will have different behavior. - # We fully ban resize_() followed by set_() for now, although in principal - # we could support this - if meta.input_info[i].mutation_inductor_storage_resize: - # resizing is not supported on subclasses (we error earlier if this happens) - from torch._subclasses.functional_tensor import FunctionalTensor + # Return both the tokens and the outputs + # See Note [Side-Effectful Tokens in AOTAutograd] + f_outs = (*functional_tensor_mode._tokens.values(), *f_outs) - assert isinstance(inpt_f, FunctionalTensor) - old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] - inpt_f.elem, before=True - ) - new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] - inpt_f.elem, before=False - ) - if old_storage_size != new_storage_size: - assert ( - old_storage_size == 0 or new_storage_size == 0 - ), f"""\ -Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size} -We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0 -(the case for FSDP)""" - torch.ops.inductor.resize_storage_bytes_( - inpt_old, new_storage_size - ) - if new_storage_size == 0: - # Even if we marked the input as having a data mutation (thus needing a copy_()), - # We should **ignore** it if our input has no storage - # (this can happen if, e.g. we temporarily resize our input, copy data into it, - # and resize it back down to zero) - continue - # Optimization: if the copy_() is a no-op then don't include it in the graph. - # In theory inductor could optimize this away, however in fsdp, we end up with - # param.copy_(param), where param is a zero-storage-size tensor, - # and running this op in eager mode (using the aot_eager backend) will result in a segfault. - # So we may as well optimize it away here. 
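To make the input-mutation handling described in the comments above concrete, here is a minimal, hedged sketch of the user-visible behavior this machinery preserves. It assumes the default torch.compile path with the aot_eager backend; whether the copy_ ends up inside the compiled graph or in a runtime epilogue depends on the keep_input_mutations handling discussed above, but the eager semantics are the same either way.

import torch

def f(x):
    x.mul_(2)          # data-only mutation of a graph input
    return x + 1

compiled = torch.compile(f, backend="aot_eager")
x = torch.ones(3)
out = compiled(x)
print(out)             # tensor([3., 3., 3.])
print(x)               # tensor([2., 2., 2.]) -- the input mutation is preserved
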
- if inpt_old is inpt_new: - # (This check needs to be done after putting resize_() in the graph, - # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) - continue - # We found an input that had a (data-only) mutation. - # Since keep_input_mutations is set, we need to faithfully apply a copy_() - # so the compiler will see the input mutation in the graph. - if ( - meta.input_info[i].mutates_data - and meta.input_info[i].mutations_hidden_from_autograd - ): - # Hidden from autograd = run under no_grad, **and** don't bump VC - # (although if the tensor was created in inference mode, it has no VC) - if inpt_old.is_inference(): - maybe_preserve_vc = nullcontext() - else: - maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( - inpt_old # type: ignore[assignment] - ) - with torch.no_grad(), maybe_preserve_vc: - inpt_old.copy_(inpt_new) - elif ( - meta.input_info[i].mutates_data - and meta.input_info[i].mutations_under_no_grad_or_inference_mode - ): - # Under no_grad = run under no_grad (we still bump the VC though) - # (inference_mode will also bump the VC, as long as the tensor in question - # was created outside of inference_mode) - with torch.no_grad(): - inpt_old.copy_(inpt_new) - elif meta.input_info[i].mutates_data: - inpt_old.copy_(inpt_new) - - # When an output tensor is a functionalized mutated input, and we - # were able to move the mutation in to the graph then we can return - # the mutated input directly. This prevents duplicating the - # tensors contents. - flat_outs, outs_spec = pytree.tree_flatten(f_outs) - flat_outs = [from_fun(o) for o in flat_outs] - num_outs = len(meta.output_info) - - for i, outp in enumerate(flat_outs[:num_outs]): - info = meta.output_info[i] - if info.output_type != OutputType.is_input: - continue - - assert info.base_idx is not None - if ( - meta.input_info[info.base_idx].mutation_type - == MutationType.MUTATED_IN_GRAPH + if trace_joint: + # We support a limited amount of mutation of graph inputs during the backward pass. + # (This is used e.g. by Float8, which needs to update buffers during the backward pass) + # Here, we perform extra checks for primals that were mutated in the **backward** + # We're doing the checks here instead of doing them with the rest of the input mutation handling because: + # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened + # during the forward, because the handling is different: some input mutations from the the forward + # can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same + # types of mutations in the backward we would need a bw-only runtime epilogue. + # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in + # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would + # require an extra round of tracing though, so it's more efficient to do in-line here. + assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) + ) + # Only look at mutations that happened to forward inputs (e.g. 
fw buffers that were saved for bw) + primals_before = args[0] + primals_after = pytree.tree_map(from_fun, f_args[0]) + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip(f_args[0], primals_before, primals_after, meta.input_info) ): - fw_args = args[0] if trace_joint else args - flat_outs[i] = fw_args[info.base_idx] - return pytree.tree_unflatten(flat_outs, outs_spec) + # Store information about mutations in joint(for backward analysis) + joint_mutates_data = has_data_mutation(f_inpt) - return pytree.tree_map(from_fun, f_outs) + joint_mutates_metadata = has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ) + + # Ban metadata mutations on fw inputs during the bw + if not inpt_info.mutates_metadata: + assert ( + not joint_mutates_metadata + ), "Found a graph input that had its metadata mutated in the backward. This is not supported" + + # Ban storage resizing on fw inputs during the bw + if not inpt_info.mutation_inductor_storage_resize: + assert not was_inductor_storage_resized( + f_inpt + ), "Found a graph input that had storage resizing in the backward. This is not supported" + + # Allow data mutations on fw inputs during the bw, but only if they do not require grad + # So we can guarantee that we can keep the mutations in the graph + if ( + joint_mutates_data + and not inpt_info.mutates_data + and not inpt_info.mutates_storage_metadata + ): + # Not banning here mutations on inpt_info.requires_grad - + # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) + # Add node meta for copy_ for partitioner that this node should be in backward graph. + with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag( + "must_be_in_backward" + ): + before.copy_(after) + meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( + idx + ) + # Now that we covered mutations to *forward* inputs during the backward, + # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). + # Today, we will just error in all cases of this happening unless someone needs us to support it. + tangents_before = args[1] + tangents_after = pytree.tree_map(from_fun, f_args[1]) + for f_inpt, before, after in zip( + f_args[1], tangents_before, tangents_after + ): + assert not has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ) and not has_data_mutation( + f_inpt + ), "Found an input to the backward that was mutated during the backward pass. This is not supported" + + if aot_config.keep_inference_input_mutations: + # Note: This is a bit annoying. There's a layering issue here, where: + # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. + # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. + # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, + # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). + # This makes it pretty difficult for this logic to operate on synthetic bases. + # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual + # (unpacked) input aliases, instead of the synthetic base. 
+ # Example case where (3) could be important: + # + # def f(x, y): + # x.mul_(2) + # y.mul_(3) + # return x, y + # a = torch.ones(1'000'000) + # x, y = out(a[0:9], a[1:10]) + # + # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing + # a giant "updated synthetic base" and copying into a's entire storage. + # + # For now, we are pessimistically not performing the optimization from (3); + # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. + # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry + # about synthetic bases. + for i, (inpt_old, inpt_f) in enumerate( + zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) + ): + if not isinstance(inpt_f, torch.Tensor): + continue + assert is_fun(inpt_f) + inpt_new = from_fun(inpt_f) + if ( + meta.input_info[i].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + # See Note [set_() Input Mutations in AOTAutograd] + # all mutations on the input must be under no_grad, so it is safe to put in the graph + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. + # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + if meta.input_info[i].mutates_storage_metadata: + with torch.no_grad(): + inpt_old.set_(inpt_new) + + # Note [Ordering of resize_() and set_()] + # Importantly: the common usage in FSDP is that we have a dummy parameter + # that sees a set_() and **Then** a resize_(). + # We must put those mutations into the graph in the same order, + # Since running them in the opposite order will have different behavior. + # We fully ban resize_() followed by set_() for now, although in principal + # we could support this + if meta.input_info[i].mutation_inductor_storage_resize: + # resizing is not supported on subclasses (we error earlier if this happens) + from torch._subclasses.functional_tensor import ( + FunctionalTensor, + ) + + assert isinstance(inpt_f, FunctionalTensor) + old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + inpt_f.elem, before=True + ) + new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + inpt_f.elem, before=False + ) + if old_storage_size != new_storage_size: + assert ( + old_storage_size == 0 or new_storage_size == 0 + ), f"""\ + Encountered a storage resize during tracing on input {i}. 
Old nbytes={old_storage_size}, new nbytes={new_storage_size} + We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0 + (the case for FSDP)""" + torch.ops.inductor.resize_storage_bytes_( + inpt_old, new_storage_size + ) + if new_storage_size == 0: + # Even if we marked the input as having a data mutation (thus needing a copy_()), + # We should **ignore** it if our input has no storage + # (this can happen if, e.g. we temporarily resize our input, copy data into it, + # and resize it back down to zero) + continue + # Optimization: if the copy_() is a no-op then don't include it in the graph. + # In theory inductor could optimize this away, however in fsdp, we end up with + # param.copy_(param), where param is a zero-storage-size tensor, + # and running this op in eager mode (using the aot_eager backend) will result in a segfault. + # So we may as well optimize it away here. + if inpt_old is inpt_new: + # (This check needs to be done after putting resize_() in the graph, + # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) + continue + # We found an input that had a (data-only) mutation. + # Since keep_input_mutations is set, we need to faithfully apply a copy_() + # so the compiler will see the input mutation in the graph. + if ( + meta.input_info[i].mutates_data + and meta.input_info[i].mutations_hidden_from_autograd + ): + # Hidden from autograd = run under no_grad, **and** don't bump VC + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: + inpt_old.copy_(inpt_new) + elif ( + meta.input_info[i].mutates_data + and meta.input_info[ + i + ].mutations_under_no_grad_or_inference_mode + ): + # Under no_grad = run under no_grad (we still bump the VC though) + # (inference_mode will also bump the VC, as long as the tensor in question + # was created outside of inference_mode) + with torch.no_grad(): + inpt_old.copy_(inpt_new) + elif meta.input_info[i].mutates_data: + inpt_old.copy_(inpt_new) + + # When an output tensor is a functionalized mutated input, and we + # were able to move the mutation in to the graph then we can return + # the mutated input directly. This prevents duplicating the + # tensors contents. 
+ flat_outs, outs_spec = pytree.tree_flatten(f_outs) + flat_outs = [from_fun(o) for o in flat_outs] + num_outs = len(meta.output_info) + + for i, outp in enumerate(flat_outs[:num_outs]): + info = meta.output_info[i] + if info.output_type != OutputType.is_input: + continue + + assert info.base_idx is not None + if ( + meta.input_info[info.base_idx].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + fw_args = args[0] if trace_joint else args + flat_outs[i] = fw_args[info.base_idx] + return pytree.tree_unflatten(flat_outs, outs_spec) + + return pytree.tree_map(from_fun, f_outs) # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals" # and "tangents" as its input names (which are special-cased by the partitioner) @@ -709,10 +728,14 @@ def aot_dispatch_subclass( return unwrapped_outs def joint_fn(primals, tangents): - return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True) + with maybe_enable_thunkify(): + return inner_fn( + flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True + ) def fw_fn(*primals): - return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False) + with maybe_enable_thunkify(): + return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False) def metadata_fn(*primals): return inner_fn(fw_only, primals, use_trace_joint=False) @@ -771,7 +794,7 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False): def functional_call(*args, **kwargs): with stateless._reparametrize_module( mod, pytree.tree_unflatten(args[:params_len], params_spec) - ): + ), maybe_disable_thunkify(): if isinstance(mod, torch.fx.GraphModule): with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): warnings.filterwarnings( diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index ce8625716798..786efd8c5392 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -54,6 +54,7 @@ from weakref import WeakKeyDictionary if TYPE_CHECKING: import types + import sympy from torch._ops import OpOverload from torch.fx._symbolic_trace import PHBase @@ -61,7 +62,8 @@ if TYPE_CHECKING: __all__ = [ "PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", - "py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch" + "py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch", + "maybe_enable_thunkify", "maybe_disable_thunkify", ] _ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"] @@ -157,6 +159,7 @@ def set_proxy_slot( tracer: _ProxyTracer, proxy: object ) -> None: + log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy) if isinstance(obj, Tensor): # We DO want to clobber proxies whenever we run an inplace operation # on a tensor, and it affects the metadata on the proxy. @@ -175,6 +178,22 @@ def set_proxy_slot( if obj not in tracer.symnode_tracker: tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy) + # WAR: python test/dynamo/test_subclasses.py + # TestNestedTensor.test_basic_autograd + # + # AOTAutograd doesn't pass the "outer sizes" as an actual argument + # to make_fx, but it is made use of internally in AOTAutograd's + # call to tensor unflatten. Because the outer sizes isn't passed + # as an argument, it is therefore untracked. However, it turns + # out you luck out, because *Dynamo* will manually add the outer + # sizes as an argument so you can fix up the proxy'ness. 
+ # + # This is probably fixed in + # https://github.com/pytorch/pytorch/pull/125941/ + import sympy + if isinstance(obj.node.expr, sympy.Symbol): + tracer.sympy_expr_tracker[obj.node.expr] = proxy + def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: assert isinstance(obj, (Tensor, SymNode)), type(obj) return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) @@ -276,10 +295,15 @@ def get_proxy_slot( tracker = tracer.symnode_tracker if obj not in tracker: - if isinstance(default, _NoDefault): - raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}") - return default - value = tracker[obj] + # Last ditch + if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker: + value = tracer.sympy_expr_tracker[obj.node.expr] + else: + if isinstance(default, _NoDefault): + raise RuntimeError(f"{obj} ({id(obj)})is not tracked with proxy for {tracer}") + return default + else: + value = tracker[obj] res = transform(value) return res @@ -330,6 +354,54 @@ def extract_val(val: _ExtractValType) -> _ExtractValType: typing_extensions.assert_never(val) +@contextmanager +def _enable_thunkify(tracer: _ProxyTracer, *, enable: bool = True) -> Generator[None, None, None]: + """ + Enable thunkification inside the context manager. Thunkification prevents + SymNode computation from directly being traced into an FX graph; instead, + the compute is only added to the graph if it is actually used. This helps + us track SymNode compute when it is computed (since we need /something/ + to put in the tracker) even if it is unlikely to be used. + """ + old = tracer.enable_thunkify + tracer.enable_thunkify = enable + try: + yield + finally: + tracer.enable_thunkify = old + +@contextmanager +def maybe_disable_thunkify() -> Generator[None, None, None]: + """Within a context, disable thunkification. See :func:`maybe_enable_thunkify` + for more details. This is helpful if you have a wrapper function which + you want to enable thunkification on, but in some segment on the inside (say, + the original user function), you want to disable thunkification as you know + it is not needed there. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer, enable=False): + yield + else: + yield + +@contextmanager +def maybe_enable_thunkify() -> Generator[None, None, None]: + """Within this context manager, if you are doing make_fx tracing, we will thunkify + all SymNode compute and avoid tracing it into the graph unless it is actually needed. + You should prefer to avoid using this as much as possible, as lazy evaluation of + SymNode tracing can lead to long chains of thunks which will stack overflow + if you evaluate them. However, this is currently sometimes necessary as there + are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error + due to insufficient tracing of SymNode computation. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer): + yield + else: + yield + # Note [invariants for node meta 'val'] # What invariants do we have for the 'val' set on the FX node? It has accurate # metadata... but only for metadata that exists "below" all other subsystems @@ -340,19 +412,24 @@ def extract_val(val: _ExtractValType) -> _ExtractValType: def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy: proxy.node.meta['val'] = extract_val(val) - # Best effort tensor_meta setting; prefer using val! 
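A minimal usage sketch of the maybe_enable_thunkify / maybe_disable_thunkify helpers introduced above. The wrapper/user_fn split is illustrative only (it mirrors how create_functional_call enables thunkification around the traced wrapper while disabling it for the user's own code); both context managers are no-ops outside of make_fx tracing.

import torch
from torch.fx.experimental.proxy_tensor import (
    make_fx,
    maybe_disable_thunkify,
    maybe_enable_thunkify,
)

def user_fn(x):
    # Inside the user function, trace SymNode compute eagerly again.
    with maybe_disable_thunkify():
        return x * x.shape[0]

def wrapper(x):
    # Defer (thunkify) SymNode compute; it is only traced into the graph if used.
    with maybe_enable_thunkify():
        return user_fn(x)

gm = make_fx(wrapper, tracing_mode="symbolic")(torch.randn(4))
print(gm.graph)
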
- if is_fake(val): - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) - elif isinstance(val, Tensor) and not val.is_sparse: - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) + with _enable_thunkify(proxy.tracer): # type: ignore[arg-type] + # Best effort tensor_meta setting; prefer using val! + if is_fake(val): + proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) + elif isinstance(val, Tensor) and not val.is_sparse: + proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) return proxy -def thunkify(f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]: +def thunkify(tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs) -> Thunk[R]: """ Delays computation of f until it's called again Also caches the result """ - return Thunk(functools.partial(f, *args, **kwargs)) + if tracer.enable_thunkify: + return Thunk(functools.partial(f, *args, **kwargs)) + else: + r = f(*args, **kwargs) + return Thunk(lambda: r) def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer) -> None: def try_set_proxy_slot( @@ -363,7 +440,8 @@ def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tr ) -> None: assert callable(proxy_callable) if isinstance(outer_s, SymInt): - set_proxy_slot(outer_s, tracer, thunkify(proxy_callable, outer_s, *args, **kwargs)) + with _enable_thunkify(tracer): + set_proxy_slot(outer_s, tracer, thunkify(tracer, proxy_callable, outer_s, *args, **kwargs)) # The basic idea is that we need to associate each tensor/SymInt # with a Proxy. How do we setup this association? We just store # the proxy on the proxy slot of the object, keyed on the tracer @@ -411,7 +489,7 @@ def track_tensor_tree( assert isinstance(proxy, Proxy) # NB: eagerly set meta here, so that the numbering is in order set_meta(proxy, e) - set_proxy_slot(e, tracer, thunkify(lambda: proxy)) + set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) elif isinstance(e, _AnyScriptObject): assert isinstance(proxy, Proxy) set_proxy_slot(e, tracer, proxy) @@ -700,7 +778,8 @@ def proxy_call( # Adding an undefined attribute to Tensor? args[0].proxy = proxy_out # type: ignore[attr-defined] - out = func(*args, **kwargs) + with _enable_thunkify(proxy_mode.tracer): + out = func(*args, **kwargs) # In some circumstances, we will be tracing in a situation where a tensor # is *statically* known to be a constant (currently, this only happens if @@ -804,12 +883,14 @@ class PythonKeyTracer(Tracer): self.tensor_tracker = WeakTensorKeyDictionary() self.symnode_tracker = _SymNodeDict() self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef) + self.sympy_expr_tracker: Dict[sympy.Symbol, object] = dict() # Stores the torch function that was called during tracing self.torch_fn_metadata = None # Stores the counts for every torch function called. This is to help # distinguish between different calls to the same torch function. self.torch_fn_counts = {} + self.enable_thunkify = False # In general, we don't want to make modules leaves. 
In principle, users of # this tracer might want to override this in order to turn a couple specific @@ -901,6 +982,32 @@ def dispatch_trace( concrete_args: Optional[Tuple[Any, ...]] = None, ) -> GraphModule: graph = tracer.trace(root, concrete_args) + + # NB: be careful not to DCE .item() calls + def impure_pred(n: fx.Node) -> bool: + from .symbolic_shapes import is_accessor_node + + # Always defer to the built-in notion of impure + if n.is_impure(): + return True + + # Accessors always OK to DCE + if is_accessor_node(n): + return False + + # If the operator in question takes SymInt args to SymInt output, + # we assume it's pure and OK to DCE + if ( + isinstance(n.meta.get('val'), py_sym_types) and + # NB: constant args ok + all(isinstance(a.meta.get('val'), py_sym_types) for a in n.args if isinstance(a, fx.Node)) + ): + return False + + # No idea, just assume it's not OK + return True + + graph.eliminate_dead_code(impure_pred) from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints dedupe_symints(graph) name = root.__class__.__name__ if isinstance(root, Module) else root.__name__ @@ -945,6 +1052,7 @@ def wrap_key(f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dis return wrapped +# TODO: Make downstream users of this work with OperatorBase ORIGINAL_ATEN: Optional[object] = None @contextmanager def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]: @@ -1129,7 +1237,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode): # were symbolic) and it is no longer necessary to trace the # computation. This could occur if func triggered some guards. if isinstance(out, py_sym_types): - p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out) + p_out_thunk = thunkify(self.tracer, self._compute_proxy, func=func, args=args, out=out) set_proxy_slot(out, self.tracer, p_out_thunk) return out @@ -1139,8 +1247,10 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): script_object_tracker: WeakKeyDictionary symnode_tracker: WeakKeyDictionary tensor_tracker: WeakTensorKeyDictionary + sympy_expr_tracker: Dict[sympy.Symbol, object] torch_fn_metadata: Optional[OpOverload] torch_fn_counts: Dict[OpOverload, int] + enable_thunkify: bool = False # TODO: I'm not sure what the point of this class is; you can just @@ -1159,6 +1269,7 @@ class DecompositionInterpreter(fx.Interpreter): # Blegh self.tracer.tensor_tracker = WeakTensorKeyDictionary() self.tracer.symnode_tracker = weakref.WeakKeyDictionary() + self.tracer.sympy_expr_tracker = dict() self.decomposition_table = decomposition_table or {} self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") @@ -1358,6 +1469,7 @@ class _ModuleStackTracer(PythonKeyTracer): concrete_args: Optional[Dict[str, object]] ) -> fx.Graph: res = super().trace(root, concrete_args) + # Since we are making _AttrProxy mimic the original # submodule, when someone registers a module directly # to the tracer while tracing, the proxy object gets registered diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index a3e6d0e97339..1ec35fe27f51 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -110,7 +110,7 @@ __all__ = [ "StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true", "guard_size_oblivious", "check_consistent", "compute_unbacked_bindings", "ConvertIntKey", - "rebind_unbacked", "resolve_unbacked_bindings", + "rebind_unbacked", "resolve_unbacked_bindings", "is_accessor_node", ] # FX 
node metadata keys for symbolic shape FX graph. @@ -337,6 +337,32 @@ def rebind_unbacked(shape_env, n: torch.fx.Node, result): # Reuse the OLD symbol name shape_env._rename_unbacked_to(raw_u1, raw_u0) +# NB: You could try to expand this to cover more cases by simply +# detecting whenever you have an int output, but this is a bit +# dangerous in case someone adds a function that returns an int but is +# mutating. So manually whitelist for now. +def is_accessor_node(node: torch.fx.Node) -> bool: + # Dynamo only exercised condition + if ( + node.op == "call_method" + and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) + and node.target in ["size", "stride", "storage_offset", "item"] + ): + return True + if node.op == "call_function" and node.target in [ + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.default, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_stride.default, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_storage_offset, + torch.ops.aten.sym_storage_offset.default, + torch.ops.aten.sym_numel.default, + ]: + return True + return False + def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: r""" Canonicalize a boolean expression by transforming it into a lt / le inequality and moving all the non-constant terms to the rhs. diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 166082f44f54..6d33461595c6 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -8,6 +8,7 @@ import torch import inspect import operator import collections +import logging from dataclasses import is_dataclass, fields @@ -25,6 +26,9 @@ __all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', 'ScopeContextManager'] +log = logging.getLogger(__name__) + + @compatibility(is_backward_compatible=False) class Scope: """ Scope object that records the module path and the module type @@ -136,6 +140,7 @@ class TracerBase: modification of values used in node creation. For example, one might want to disallow in-place operations from being recorded. """ + if kind == 'call_function' and self.check_mutable_operations: check_for_mutable_operation(target, args, kwargs) @@ -175,6 +180,8 @@ class TracerBase: elif self.module_stack: node.meta['nn_module_stack'] = copy.copy(self.module_stack) + + log.debug("create_node %s", node) return node @compatibility(is_backward_compatible=True) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 5c6f880608b8..215769ca91ff 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -302,14 +302,19 @@ class NestedTensor(torch.Tensor): if kwargs is None: kwargs = {} + from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify + from .ops import jagged_torch_function - try: - return jagged_torch_function(func, *args, **kwargs) - except NotImplementedError: - pass - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) + # This should be removed after + # https://github.com/pytorch/pytorch/pull/125941/ lands + with maybe_enable_thunkify(): + try: + return jagged_torch_function(func, *args, **kwargs) + except NotImplementedError: + pass + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) # NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
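
To illustrate the relocated is_accessor_node helper together with the new DCE predicate in dispatch_trace, a small hedged sketch (the function f is illustrative; it assumes a build that contains this change):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import is_accessor_node

def f(x):
    # Under symbolic tracing, x.shape[0] becomes an aten.sym_size.int node,
    # which is_accessor_node classifies as a pure metadata accessor.
    return x.new_empty(x.shape[0])

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.target, is_accessor_node(node))

dispatch_trace now runs eliminate_dead_code with the impure_pred shown earlier, so unused accessor nodes and pure SymInt-to-SymInt compute can be dropped from the traced graph instead of lingering as dead nodes.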