reinplace pass: special handling for view_scatter ops (#83846)
There is already special handling in the reinplacing pass for removing `{view}_scatter` ops, but there is another case that needs special handling. In this code:

```
def f():
    a = torch.zeros(4, 4, 4)
    a[:, 2:] = torch.ones(4, 2, 4)
    return a
```

Tracing normally with `make_fx()` gives you:

```
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_1, ones); slice_tensor_1 = ones = None
    return zeros
```

Functionalizing it gives you:

```
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, 9223372036854775807); slice_tensor_2 = ones = None
    slice_scatter_default_1 = torch.ops.aten.slice_scatter.default(zeros, slice_scatter_default, 0, 0, 9223372036854775807); zeros = slice_scatter_default = None
    return slice_scatter_default_1
```

Notice that there are no functional ops to directly re-inplace! What actually happened is that functionalization turned the `copy_()` into a `copy()`, but the out-of-place `copy()` operator gets optimized away because it's a no-op (when the input and output metadata are the same, `out = copy(a, b)` just returns `b`).

What we actually want is to replace this line:

```
slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, ...);
```

with this:

```
new_slice = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, ...);
_ = torch.ops.aten.copy_.default(new_slice, ones)
```

In the above, we're taking a fresh slice of the "base" tensor and performing a `copy_()` on the slice, adding back what functionalization removed. We actually need to create a fresh "slice" node, because we're not guaranteed that one already exists in the graph (technically there should be one, but it might have been DCE'd by the time we hit re-inplacing).

I also updated the docs for re-inplacing to more closely match the order of the logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83846
Approved by: https://github.com/ezyang
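For reference, here is a minimal reproduction sketch of the scenario above. It mirrors the new test added in this PR, but the import paths are an assumption (they reflect the functorch-era entry points for `make_fx`, `functionalize`, and the `reinplace` pass, and may differ across PyTorch versions).

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace
# Assumed import path for functionalize at the time of this PR.
from functorch.experimental import functionalize

def f():
    a = torch.zeros(4, 4, 4)
    a[:, 2:] = torch.ones(4, 2, 4)
    return a

# Functionalization rewrites the copy_() into slice_scatter ops;
# the reinplace pass should then turn them back into a slice + copy_().
gm = reinplace(make_fx(functionalize(f))())
print(gm.code)
assert torch.equal(gm(), f())
```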
Committed by: PyTorch MergeBot
Parent: 75ec7b7547
Commit: 8db04c1113
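The heart of the change is the `{view}_scatter` rewrite itself. Below is a simplified, hypothetical sketch of that graph transformation: it mirrors the `_VIEW_INVERSE_MAP` lookup that appears in the diff, but the map and helper names here are illustrative, and it omits the alias/later-usage safety analysis that the real pass performs before rewriting.

```python
import torch
from torch.fx import GraphModule

# Illustrative subset of the {view}_scatter -> view mapping.
VIEW_INVERSE_MAP = {
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
}

def rewrite_view_scatter(gm: GraphModule) -> GraphModule:
    # NOTE: no safety checks here; the real pass only rewrites when the base
    # tensor has no conflicting later usages in the graph.
    for node in list(gm.graph.nodes):
        view_op = VIEW_INVERSE_MAP.get(node.target)
        if view_op is None:
            continue
        base, mutated_view, *view_args = node.args
        with gm.graph.inserting_before(node):
            # Take a fresh view of the base tensor...
            new_view = gm.graph.call_function(view_op, (base, *view_args))
            # ...and copy the mutated data back into it in place.
            gm.graph.call_function(torch.ops.aten.copy_.default, (new_view, mutated_view))
        # The scatter output is just "base after mutation", so later users
        # can read from the base once the copy_ node is in the graph.
        node.replace_all_uses_with(base)
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```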
@@ -259,8 +259,9 @@ def forward(self, a__1):
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 0); as_strided_default = add_tensor = None
return select_scatter_default
select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 0)
copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
return as_strided_default
""") # noqa: B950

def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):

@@ -291,8 +292,9 @@ def forward(self, a__1):
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0); clone_default = None
select_scatter_default = torch.ops.aten.select_scatter.default(as_strided_default, add_tensor, 0, 1); as_strided_default = add_tensor = None
return select_scatter_default
select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 1)
copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
return as_strided_default
""") # noqa: B950

@@ -322,5 +324,32 @@ def forward(self):
return [zeros]
""")

def test_reinplace_index_mutation(self):
def f():
a = torch.zeros(4, 4, 4)
a[:, 2:] = torch.ones(4, 2, 4)
return a

if not HAS_FUNCTIONALIZATION:
return
f2 = reinplace(make_fx(functionalize(f))())
expected_out = f()
actual_out = f2()
self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\


def forward(self):
zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
slice_tensor_3 = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, 9223372036854775807); slice_tensor_2 = None
copy__default = torch.ops.aten.copy_.default(slice_tensor_3, ones); slice_tensor_3 = ones = None
return zeros
""")

if __name__ == '__main__':
run_tests()
@@ -184,7 +184,7 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
usage_nodes = t.users
for n in usage_nodes:
# We only care about usages after the current node
if n.meta['node_idx'] <= op_index:
if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
continue
# We also don't care about intermediate view ops.
# They only matter if their output is then used elsewhere
@@ -263,60 +263,61 @@ def reinplace(gm, *sample_args):
In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
inputs to the program.

Given a node "b = foo(a, ...)", the algorithm for re-inplacing is as follows:
Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows:

(1a) Check if foo has a mutating variant. If not, move to the next node.
(1) Perform some initial checks on the metadata of "a" and "args..."
that can disqualify them from being reinplaced.

Note that we ignore view ops (we don't bother to turn `as_strided()`
into `as_strided_()`), as it complicates the algorithm and doesn't
provide meaningful speedups.
(1a) Check that the self argument we're attempting to reinplace
has acceptable dtype/size metadata to reinplace with.

Currently, we also only check for an inplace op, `foo_`.
Later, we should beef this up to check for out= or mutable ops.
For example, if we have:
a = torch.ones(1)
b = torch.ones(10)
out = torch.add(a, b)
We can't turn that into
a.add_(b)
Because that would require resizing "a".

(1b) Check that the self argument we're attempting to reinplace
has acceptable metadata to reinplace with.
Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
beause that would require changing a's dtype (from e.g. float32 to bool).
Note that in this specific example, we could technically do better..

For example, if we have:
a = torch.ones(1)
b = torch.ones(10)
out = torch.add(a, b)
We can't turn that into
a.add_(b)
Because that would require resizing "a".
If we see the pattern:
a_1 = a.ge(b)
a_2 = aten._to_copy(a_1, a.dtype)
Then we this should be valid to completely re-inplace
(this is exactly what functionalization will emit when it sees a.ge_(b)).

Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
beause that would require changing a's dtype (from e.g. float32 to bool).
Note that in this specific example, we could technically do better..
This optimization is only really important for user programs
that directly use inplace comparison ops though.

If we see the pattern:
a_1 = a.ge(b)
a_2 = aten._to_copy(a_1, a.dtype)
Then we this should be valid to completely re-inplace
(this is exactly what functionalization will emit when it sees a.ge_(b)).
We also cannot re-inplace on tensors that have overlapping memory,
e.g. torch.ones(1).expand(4, 4).add_(1)

This optimization is only really important for user programs
that directly use inplace comparison ops though.
(1b) Check if "a" is an alias of any of the program inputs.

We also cannot re-inplace on tensors that have overlapping memory,
e.g. torch.ones(1).expand(4, 4).add_(1)
If it is, skip and move to the next node.
Inplace'ing an op that would cause it to mutate a program is not sound,
because that would be a side effect visible to the user.

(2) Check if "a" is an alias of any of the program inputs.
NOTE: there's a future optimization that we should make:
if "a" is a (alias of a) program input, but later in the program
there is a node that looks like "a.copy_(...)",
Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
which will later be overwritten by the copy_() call.

If it is, skip and move to the next node.
Inplace'ing an op that would cause it to mutate a program is not sound,
because that would be a side effect visible to the user.
This will be an important optimization to have for programs that mutate
their inputs. It currently isn't implemented though.

NOTE: there's a future optimization that we should make:
if "a" is a (alias of a) program input, but later in the program
there is a node that looks like "a.copy_(...)",
Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
which will later be overwritten by the copy_() call.
(1c) Check if "a" and "args..." alias

This will be an important optimization to have for programs that mutate
their inputs. It currently isn't implemented though.
For example, re-inplacing to create code like the below
isn't guaranteed to be sound:

(3) Check that "a" and all of its outstanding aliases are not used anywhere
aten.mul_(a, a)

(2) Check that "a" and all of its outstanding aliases are not used anywhere
later in the graph. If this is the case, then it's safe to re-inplace
to "b = foo_(a)".
@@ -380,8 +381,53 @@ def reinplace(gm, *sample_args):
as_strided -> as_strided_scatter
(ii) "args..." are the same between the foo() and foo_scatter() calls.

(4) Finally, after converting "b = foo(a)" into "foo_(a)",
we need to find all later nodes that use "b" as an argument
(3) Perform the actual re-inplacing on foo!

(3b) is the common case, but special care is needed for {view}_scatter (3a)

(3a) {view}_scatter ops.

Consider this program:
a = torch.zeros(2, 2)
b = torch.ones(2)
a[0] = b

Post functionalization, that will look like:
a = torch.zeros(2)
b = torch.ones(1)
a_updated = torch.select_scatter(a, b, 0, 0)

In this case though, there is no "functional" op to re-inplace!
Instead, we'd like to directly remove toe select_scatter call.
We already know from (3) that this is valid,
because "a" has no later usages in the graph.

We perform the re-inplacing on the {view}_scatter op like so
Before:
a_updated = torch.select_scatter(a, b, args...)
After:
a_slice = a.select(a, args...)
a_slice.copy_(b)

(3b) Otherwise, replace the functional op with its inplace variant.
Before:
b = foo(a, args...)
After:
a.foo_(args...)

(4) Finally, after converting either:
Before:
b = foo(a)
After:
foo_(a)
or
Before:
b = {slice}_scatter(a, mutated_slice, args...)
After:
slice = {slice}(a, args...)
slice.copy_(mutated_slice)

We now need to find all later nodes that use "b" as an argument
and update them to take in "a" instead.

Note that for the majority of inplace ops, this isn't actually necessary
@@ -436,20 +482,26 @@ def reinplace(gm, *sample_args):
tree_map(_add_to_map, n.meta['fake_result'])

# inplace-ify functional ops, subject to the constraints written below.
all_later_view_inverse_node_usages = set()
all_later_view_inverse_nodes_to_delete = set()
for idx, node in enumerate(gm.graph.nodes):
if node.op == 'call_function':
# Step 1: Check to see if this operator has an inplace variant.
maybe_inplace_op = _maybe_get_inplace_op(node.target)
if maybe_inplace_op is None:
continue
# This is a proxy check for ensuring that the first argument is "tensor-like"
# (This should be the case for all ops with inplace variants in ATen,
# although we technically don't have guarantees for custom ops).
assert len(node.target._schema.arguments) > 0
assert 'Tensor' in str(node.target._schema.arguments[0].type)

# Step 1b: Check that the self argument we're attempting to reinplace
# Today, the re-inplace pass on directly acts on:
# - functional ops with an inplace variant
# - {view}_scatter ops that can be potentially removed from the graph.
# Both of these ops take in tensor first args, so filtering on this condition
# makes the later code simpler.
# We should revisit this at some point though, particularly when we also want
# the reinplacer to be able to handle out= and mutable operators
# and tensorlist first args (like `_foreach_` ops).
if not isinstance(node.target, torch._ops.OpOverload):
continue
if len(node.target._schema.arguments) < 1:
continue
if type(node.target._schema.arguments[0].type) != torch.TensorType:
continue

# Step 1a: Check that the self argument we're attempting to reinplace
# has the same size/stride as the output.
# For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
# As it would require resizing scalar_tensor.
@@ -458,38 +510,36 @@ def reinplace(gm, *sample_args):
self_arg = node.args[0]
self_flattened, _ = tree_flatten(self_arg.meta['fake_result'])
node_flattened, _ = tree_flatten(node.meta['fake_result'])
assert len(self_flattened) == len(node_flattened)
self_has_wrong_metadata = False
for self_meta, node_meta in zip(self_flattened, node_flattened):
if self_meta.numel() != node_meta.numel():
self_has_wrong_metadata = True
if self_meta.dtype != node_meta.dtype:
self_has_wrong_metadata = True
# We also cannot re-inplace on tensors that have internal memory overlap.
# e.g. torch.ones(1).expand(4, 4).add_(1)
if torch._debug_has_internal_overlap(self_meta) == 1:
self_has_wrong_metadata = True
if len(self_flattened) == len(node_flattened):
for self_meta, node_meta in zip(self_flattened, node_flattened):
if self_meta.numel() != node_meta.numel():
self_has_wrong_metadata = True
if self_meta.dtype != node_meta.dtype:
self_has_wrong_metadata = True
# We also cannot re-inplace on tensors that have internal memory overlap.
# e.g. torch.ones(1).expand(4, 4).add_(1)
if torch._debug_has_internal_overlap(self_meta) == 1:
self_has_wrong_metadata = True
# Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
# Since users should never really be calling the functional "torch.ops.aten.resize"
# op directly in their programs.
if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
continue

# Step 2: ensure that the op we're trying to re-inplace isn't a program i
self_arg = node.args[0]
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
self_arg_name = self_arg.name
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
continue
if len([x for x in node.args if x is self_arg]) > 1:
# Step (3b) in the original description.
# Step 1c:
# Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
# so we prevent re-inplacing in this case.
continue

self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage())
self_aliases = storage_to_nodes[self_arg_storage]

# First, we find all later usages of any of the aliases of self_arg.
@@ -498,27 +548,60 @@ def reinplace(gm, *sample_args):
# that are safe to fully remove.
later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)

# Step 3: Check to see if the input to the op is re-used later in the graph.
# Step 2: Check to see if the input to the op is re-used later in the graph.
# If not (same goes for its aliases), then this op is safe to re-in place.
# This is a slightly roundabout way to check that there are no later usages of the current self argument.
# (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
if not can_reinplace:
continue
# Step 4: replace the current out-of-place op with its inplace variant.
node.target = maybe_inplace_op

# Step 3a: Special handling for when we see *_scatter operators.
# When we see an operator like `b = torch.slice_scatter(a, ...)`,
# instead of trying to "inplace" it into a.slice_scatter_(..._),
# we would prefer to remove it from the graph entirely,
# and instead copy_() the slice directly into the larger tensor.
# See the description of the algorithm for a full example.
if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
view_op = _VIEW_INVERSE_MAP[node.target]
# Before:
# base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
# After:
# slice = torch.ops.aten.slice.default(base, args...)
# slice.copy_(mutated_slice)
with gm.graph.inserting_before(node):
mutated_slice_node = node.args[1]
remaining_slice_args = node.args[2:]
slice_node = gm.graph.create_node(
'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
copy_node = gm.graph.create_node(
'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
# Add the slice_scatter node to our "nodes to delete" list.
all_later_view_inverse_nodes_to_delete.add(node)


else:
# Step 3b: Check to see if this operator has an inplace variant.
maybe_inplace_op = _maybe_get_inplace_op(node.target)
if maybe_inplace_op is None:
continue
# And if so, replace it with its inplace variant.
node.target = maybe_inplace_op

# At this point, 'storage_to_nodes' will be stale.
# Now that we're inplacing `b = foo(a)`, we need to effectively
# union together the dict values for b and a's storage.
# Hmm... morally I think we also want to keep the `fake_result` metadata
# up to date here, but I'm not sure how easy it is to do.
# Maybe it's fine to wait until the end of the pass to update it.
curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage())
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])

# Need to remember the view_scatter view nodes we found so we can remove them alter.
all_later_view_inverse_node_usages.update(later_view_inverse_node_usages)
all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)

# Step 4:
# Now that we've replaced b = a.foo() with a.foo_(),
# We need to replace any later usages of "b" with "a"
for old in itertools.chain([node], later_view_inverse_node_usages):
@@ -568,10 +651,10 @@ def reinplace(gm, *sample_args):
storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])

# Step 5: delete any _scatter nodes that we de-functionalized
# Step 4: delete any _scatter nodes that we de-functionalized
# Need to take care not to delete any of these nodes until after *all* modifications
# to the graph are finished.
for to_delete in all_later_view_inverse_node_usages:
for to_delete in all_later_view_inverse_nodes_to_delete:
gm.graph.erase_node(to_delete)