Enable max.unary_out (#86855)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86855
Approved by: https://github.com/jerryzh168, https://github.com/bdhirsh
Author: Tugsbayasgalan Manlaibaatar
Date: 2022-10-12 17:24:38 -07:00
Committed by: PyTorch MergeBot
Parent: 25811663af
Commit: cff333bdb5
8 changed files with 55 additions and 18 deletions


@@ -34,9 +34,16 @@ Tensor max(const Tensor &self) {
 }
 
 Tensor& max_unary_out(const Tensor &self, Tensor& out) {
-  Tensor tmp_output = at::max(self);
-  at::native::resize_output(out, tmp_output.sizes());
-  out.copy_(tmp_output);
+  // First check if the devices match (CPU vs GPU)
+  TORCH_CHECK(self.device() == out.device());
+  TORCH_CHECK(canCast(
+      typeMetaToScalarType(self.dtype()),
+      typeMetaToScalarType(out.dtype())));
+  at::native::resize_output(out, {});
+  max_all_stub(self.device().type(), out, self.contiguous());
   return out;
 }
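
With the kernel above in place, the overall-reduction `max` gains a working `out=` path from Python. A minimal sketch of how the new overload is exercised, assuming a build that includes this change; note that `resize_output(out, {})` means the result is always 0-dimensional:

```python
import torch

x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
out = torch.empty(())        # overall max is 0-dim, matching resize_output(out, {})
torch.max(x, out=out)        # binds to max.unary_out and runs max_all_stub
print(out)                   # tensor(5.)
```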


@@ -8789,13 +8789,6 @@
     MPS: max_mps
     QuantizedCPU: max_quantized_cpu
 
-# Not to be confused with binary op `max.out`. Commented because of failed CI
-# FIXME: enable this
-#- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
-#  device_check: NoCheck   # TensorIterator
-#  dispatch:
-#    CompositeExplicitAutograd: max_unary_out
-
 - func: fmax(Tensor self, Tensor other) -> Tensor
   structured_delegate: fmax.out
   device_check: NoCheck   # TensorIterator
@@ -8831,6 +8824,12 @@
 - func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
 
+- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: max_unary_out
+    QuantizedCPU: max_quantized_unary_out
+
 - func: minimum(Tensor self, Tensor other) -> Tensor
   structured_delegate: minimum.out
   device_check: NoCheck   # TensorIterator
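
The two `TORCH_CHECK`s in `max_unary_out` gate the schema registered above: `self` and `out` must live on the same device, and `self`'s dtype must be castable to `out`'s. A sketch of the rejection path, assuming the checks surface as `RuntimeError` the way `TORCH_CHECK` failures normally do:

```python
import torch

x = torch.rand(4)                         # float32
bad = torch.empty((), dtype=torch.long)   # canCast(float32 -> int64) is False
try:
    torch.max(x, out=bad)
except RuntimeError as e:
    print("rejected:", e)
```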


@@ -14,6 +14,19 @@ Tensor max_quantized_cpu(const Tensor& self) {
   return std::get<0>(self.reshape({-1}).max(/*dim=*/0));
 }
 
+Tensor& max_quantized_unary_out(const Tensor& self, Tensor& out) {
+  // TODO: this implementation is inefficient for now.
+  TORCH_CHECK(self.device() == out.device());
+  TORCH_CHECK(canCast(
+      typeMetaToScalarType(self.dtype()),
+      typeMetaToScalarType(out.dtype())));
+  Tensor temp = max_quantized_cpu(self);
+  at::native::resize_output(out, temp.sizes());
+  out.copy_(temp);
+  return out;
+}
+
 Tensor min_quantized_cpu(const Tensor& self) {
   return std::get<0>(self.reshape({-1}).min(/*dim=*/0));
 }
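
`max_quantized_unary_out` simply materializes `max_quantized_cpu(self)` and copies it into `out`, and `max_quantized_cpu` itself reduces via `reshape({-1}).max(/*dim=*/0)`. A sketch mirroring that reduction from Python on a per-tensor-quantized input (the values and quantization parameters here are illustrative):

```python
import torch

x = torch.tensor([[0.10, 0.90], [0.40, 0.20]])
xq = torch.quantize_per_tensor(x, scale=0.01, zero_point=0, dtype=torch.quint8)

# QuantizedCPU dispatch: max_quantized_cpu == reshape(-1).max(dim=0).values
print(torch.max(xq))                      # 0-dim quantized tensor holding 0.90
print(xq.reshape(-1).max(dim=0).values)   # the same reduction, spelled out
```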


@@ -115,11 +115,17 @@ def meta_index_select_out(self, dim, index, out):
     return out.copy_(torch.index_select(self, dim, index))
 
 
-@register_meta([aten.max.default, aten.min.default])
+@register_meta([aten.max.default, aten.max.unary_out])
+@out_wrapper()
 def meta_max(self):
     return self.new_empty(())
 
 
+@register_meta([aten.min.default])
+def meta_min(self):
+    return self.new_empty(())
+
+
 @register_meta(aten.angle.default)
 def meta_angle(self):
     if self.is_complex():
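
Registering `aten.max.unary_out` alongside `aten.max.default` lets shape inference run without real data: on the `meta` device the kernel only reports the 0-dim output shape. A small sketch, assuming the registration above is active in the build:

```python
import torch

x = torch.empty(4, 5, device="meta")   # no storage, shapes only
y = torch.max(x)                       # meta_max returns self.new_empty(())
print(y.device, y.shape)               # meta torch.Size([])
```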


@@ -324,16 +324,23 @@ static bool varargsCanBeUsedAsList(
       !typevar_list;
 }
 
-// Note (@zasdfgbnm):
-// This is a workaround for https://github.com/pytorch/pytorch/issues/47964
-// Currently JIT does not distinguish ScalarType vs int, so there is really
-// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode
-// the aten::view.dtype here to block this overload. This blocklist should be
-// removed when JIT fully supports ScalarType as its own type.
 bool isBlockListedSchema(const FunctionSchema& schema) {
+  // Note (@zasdfgbnm):
+  // This is a workaround for https://github.com/pytorch/pytorch/issues/47964
+  // Currently JIT does not distinguish ScalarType vs int, so there is really
+  // no way to distinguish x.view(1) vs x.view(torch.int8). So we have to
+  // hardcode the aten::view.dtype here to block this overload. This blocklist
+  // should be removed when JIT fully supports ScalarType as its own type.
   if (schema.name() == "aten::view" && schema.overload_name() == "dtype") {
     return true;
   }
+
+  // Note (@tugsbayasgalan):
+  // TorchScript doesn't support kwargs, so this op collides with aten::max.other
+  // since both overloads take two Tensor inputs. Since we don't expect users to
+  // call this op from TorchScript, we simply skip it here.
+  if (schema.name() == "aten::max" && schema.overload_name() == "unary_out") {
+    return true;
+  }
+
   return false;
 }
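
Blocklisting the schema keeps TorchScript's overload resolution unchanged: `torch.max(x)` still matches `aten::max`, and `torch.max(x, y)` still matches `aten::max.other`, with `unary_out` never considered. A quick sketch of both call sites compiling as before:

```python
import torch

@torch.jit.script
def unary_max(x: torch.Tensor) -> torch.Tensor:
    return torch.max(x)        # resolves to aten::max; unary_out is skipped

@torch.jit.script
def binary_max(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.max(x, y)     # still resolves to aten::max.other

print(unary_max(torch.rand(3)), binary_max(torch.rand(3), torch.rand(3)))
```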


@@ -20,6 +20,8 @@ struct MatchedSchema {
   std::string schema_name;
 };
 
+TORCH_API bool isBlockListedSchema(const FunctionSchema& schema);
+
 TORCH_API MatchedSchema matchSchema(
     const ::c10::FunctionSchema& schema,
     const SourceRange& loc,


@@ -1009,6 +1009,9 @@ Value* Node::namedInput(Symbol name) const {
 }
 
 bool Node::matches(const FunctionSchema& schema) const {
+  if (isBlockListedSchema(schema)) {
+    return false;
+  }
   // wrong name
   if (kind().toQualString() != schema.name()) {
     return false;


@@ -9901,7 +9901,7 @@ op_db: List[OpInfo] = [
     OpInfo('max',
            variant_test_name='reduction_no_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-           supports_out=False,
+           supports_out=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
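
Flipping `supports_out=True` makes the OpInfo suite exercise the `out=` path, including the resize behavior inherited from `at::native::resize_output`: a non-empty `out` of the wrong shape is resized to 0-dim with a warning. A sketch of that behavior, assuming the warning surfaces as a Python `UserWarning` as `TORCH_WARN` normally does:

```python
import torch
import warnings

x = torch.rand(3)
out = torch.empty(5)            # wrong shape for a full reduction
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    torch.max(x, out=out)       # resize_output shrinks out to ()
print(out.shape, len(w))        # expect torch.Size([]) and one captured warning
```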