add str comparisons (#20761)

Summary:
add string comparisons
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20761

Differential Revision: D15434616

Pulled By: eellison

fbshipit-source-id: c00c7bac6308dbcc6a9e46b92421f49fb2d5a81c
This commit is contained in:
Elias Ellison
2019-05-21 12:17:45 -07:00
committed by Facebook Github Bot
parent cca923c481
commit 47dc65fe76
2 changed files with 58 additions and 39 deletions

View File

@ -12029,6 +12029,15 @@ a")
self.checkScript(fn, ("abcde",))
def test_str_cmp(self):
def test(a, b):
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
return a != b, a == b, a < b, a > b, a <= b, a >= b
self.checkScript(test, ("1", "2"))
self.checkScript(test, ("2", "1"))
self.checkScript(test, ("1", "1"))
def test_ord(self):
def fn(x):
# type: (str) -> int

View File

@ -8,9 +8,8 @@
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/script/logging.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/script/error_report.h>
@ -20,20 +19,20 @@
#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/Dict.h>
#include <ATen/core/ivalue.h>
#include <c10/core/thread_pool.h>
#include <c10/util/SmallVector.h>
#include <algorithm>
#include <cmath>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <ostream>
#include <fstream>
#include <stdexcept>
#include <string>
#include <typeinfo>
@ -99,13 +98,13 @@ static int64_t floordiv(int64_t a, int64_t b) {
}
static int gcd(int a, int b) {
while (b != 0) {
int r = a % b;
a = b;
b = r;
}
// in python gcd returns non-negative values
return std::abs(a);
while (b != 0) {
int r = a % b;
a = b;
b = r;
}
// in python gcd returns non-negative values
return std::abs(a);
}
// reference function THPVariable_to in python_variable_methods.cpp
@ -1050,12 +1049,21 @@ RegisterOperators logging_operators(
return 0; \
})
#define DEFINE_STR_CMP_OP(aten_op, op) \
Operator(#aten_op "(str a, str b) -> bool", [](Stack& stack) { \
auto b = pop(stack).toStringRef(); \
auto a = pop(stack).toStringRef(); \
push(stack, op); \
return 0; \
})
#define DEFINE_BINARY_OP(aten_op, op) \
DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
DEFINE_INT_FLOAT_OP(aten_op, op, float)
#define DEFINE_COMPARISON_OP(aten_op, op) \
DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
DEFINE_INT_FLOAT_OP(aten_op, op, bool)
DEFINE_INT_FLOAT_OP(aten_op, op, bool), DEFINE_STR_CMP_OP(aten_op, op)
#define DEFINE_BOOL_OP(aten_op, op) \
Operator(#aten_op "(bool a, bool b) -> bool", [](Stack& stack) { \
bool a, b; \
@ -1736,8 +1744,7 @@ RegisterOperators reg2({
}
push(stack, chars);
return 0;
}
),
}),
// Mutable ops for lists containing mutable types.
#define CREATE_MUTABLE_LIST_OPS(decl_type, c_type) \
Operator( \
@ -1916,13 +1923,13 @@ RegisterOperators reg2({
return 0;
}),
Operator(
"prim::str(t elem) -> str",
[](Stack& stack) {
std::stringstream ss;
ss << pop(stack);
push(stack, ss.str());
return 0;
}),
"prim::str(t elem) -> str",
[](Stack& stack) {
std::stringstream ss;
ss << pop(stack);
push(stack, ss.str());
return 0;
}),
Operator(
"aten::ord(str string) -> int",
[](Stack& stack) {
@ -2153,24 +2160,28 @@ RegisterOperators reg2({
DEFINE_INT_OP(aten::gcd, gcd(a, b)),
DEFINE_GENERIC_OP(aten::copysign, std::copysign(a, b), std::copysign(a, b), float, float),
DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a,b), float),
DEFINE_GENERIC_OP(
aten::copysign,
std::copysign(a, b),
std::copysign(a, b),
float,
float),
DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a, b), float),
#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
Operator( \
#aten_op "(int a) -> " #int_result, \
[](Stack& stack) { \
int64_t a; \
pop(stack, a); \
push(stack, op); \
return 0; \
}), \
Operator(#aten_op "(float a) -> " #float_result, \
[](Stack& stack) { \
double a; \
pop(stack, a); \
push(stack, op); \
return 0; \
#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
Operator( \
#aten_op "(int a) -> " #int_result, \
[](Stack& stack) { \
int64_t a; \
pop(stack, a); \
push(stack, op); \
return 0; \
}), \
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
double a; \
pop(stack, a); \
push(stack, op); \
return 0; \
})
DEFINE_MATH_OP(aten::gamma, std::tgamma(a), float, float),
@ -2186,7 +2197,6 @@ RegisterOperators reg2({
DEFINE_COMPARISON_OP(aten::gt, a > b),
DEFINE_COMPARISON_OP(aten::le, a <= b),
DEFINE_COMPARISON_OP(aten::ge, a >= b),
DEFINE_BOOL_OP(aten::__and__, a&& b),
DEFINE_BOOL_OP(aten::__or__, a || b),
DEFINE_BOOL_OP(aten::__xor__, a != b),