mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add str comparisons (#20761)
Summary: add string comparisons Pull Request resolved: https://github.com/pytorch/pytorch/pull/20761 Differential Revision: D15434616 Pulled By: eellison fbshipit-source-id: c00c7bac6308dbcc6a9e46b92421f49fb2d5a81c
This commit is contained in:
committed by
Facebook Github Bot
parent
cca923c481
commit
47dc65fe76
@ -12029,6 +12029,15 @@ a")
|
||||
|
||||
self.checkScript(fn, ("abcde",))
|
||||
|
||||
def test_str_cmp(self):
|
||||
def test(a, b):
|
||||
# type: (str, str) -> Tuple[bool, bool, bool, bool, bool, bool]
|
||||
return a != b, a == b, a < b, a > b, a <= b, a >= b
|
||||
|
||||
self.checkScript(test, ("1", "2"))
|
||||
self.checkScript(test, ("2", "1"))
|
||||
self.checkScript(test, ("1", "1"))
|
||||
|
||||
def test_ord(self):
|
||||
def fn(x):
|
||||
# type: (str) -> int
|
||||
|
@ -8,9 +8,8 @@
|
||||
#include <torch/csrc/jit/fuser/interface.h>
|
||||
#include <torch/csrc/jit/graph_executor.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/script/logging.h>
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/profiling_record.h>
|
||||
#include <torch/csrc/jit/script/compilation_unit.h>
|
||||
#include <torch/csrc/jit/script/error_report.h>
|
||||
@ -20,20 +19,20 @@
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/Dict.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <ostream>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <typeinfo>
|
||||
@ -99,13 +98,13 @@ static int64_t floordiv(int64_t a, int64_t b) {
|
||||
}
|
||||
|
||||
static int gcd(int a, int b) {
|
||||
while (b != 0) {
|
||||
int r = a % b;
|
||||
a = b;
|
||||
b = r;
|
||||
}
|
||||
// in python gcd returns non-negative values
|
||||
return std::abs(a);
|
||||
while (b != 0) {
|
||||
int r = a % b;
|
||||
a = b;
|
||||
b = r;
|
||||
}
|
||||
// in python gcd returns non-negative values
|
||||
return std::abs(a);
|
||||
}
|
||||
|
||||
// reference function THPVariable_to in python_variable_methods.cpp
|
||||
@ -1050,12 +1049,21 @@ RegisterOperators logging_operators(
|
||||
return 0; \
|
||||
})
|
||||
|
||||
#define DEFINE_STR_CMP_OP(aten_op, op) \
|
||||
Operator(#aten_op "(str a, str b) -> bool", [](Stack& stack) { \
|
||||
auto b = pop(stack).toStringRef(); \
|
||||
auto a = pop(stack).toStringRef(); \
|
||||
push(stack, op); \
|
||||
return 0; \
|
||||
})
|
||||
|
||||
#define DEFINE_BINARY_OP(aten_op, op) \
|
||||
DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
|
||||
DEFINE_INT_FLOAT_OP(aten_op, op, float)
|
||||
#define DEFINE_COMPARISON_OP(aten_op, op) \
|
||||
DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
|
||||
DEFINE_INT_FLOAT_OP(aten_op, op, bool)
|
||||
DEFINE_INT_FLOAT_OP(aten_op, op, bool), DEFINE_STR_CMP_OP(aten_op, op)
|
||||
|
||||
#define DEFINE_BOOL_OP(aten_op, op) \
|
||||
Operator(#aten_op "(bool a, bool b) -> bool", [](Stack& stack) { \
|
||||
bool a, b; \
|
||||
@ -1736,8 +1744,7 @@ RegisterOperators reg2({
|
||||
}
|
||||
push(stack, chars);
|
||||
return 0;
|
||||
}
|
||||
),
|
||||
}),
|
||||
// Mutable ops for lists containing mutable types.
|
||||
#define CREATE_MUTABLE_LIST_OPS(decl_type, c_type) \
|
||||
Operator( \
|
||||
@ -1916,13 +1923,13 @@ RegisterOperators reg2({
|
||||
return 0;
|
||||
}),
|
||||
Operator(
|
||||
"prim::str(t elem) -> str",
|
||||
[](Stack& stack) {
|
||||
std::stringstream ss;
|
||||
ss << pop(stack);
|
||||
push(stack, ss.str());
|
||||
return 0;
|
||||
}),
|
||||
"prim::str(t elem) -> str",
|
||||
[](Stack& stack) {
|
||||
std::stringstream ss;
|
||||
ss << pop(stack);
|
||||
push(stack, ss.str());
|
||||
return 0;
|
||||
}),
|
||||
Operator(
|
||||
"aten::ord(str string) -> int",
|
||||
[](Stack& stack) {
|
||||
@ -2153,24 +2160,28 @@ RegisterOperators reg2({
|
||||
|
||||
DEFINE_INT_OP(aten::gcd, gcd(a, b)),
|
||||
|
||||
DEFINE_GENERIC_OP(aten::copysign, std::copysign(a, b), std::copysign(a, b), float, float),
|
||||
DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a,b), float),
|
||||
DEFINE_GENERIC_OP(
|
||||
aten::copysign,
|
||||
std::copysign(a, b),
|
||||
std::copysign(a, b),
|
||||
float,
|
||||
float),
|
||||
DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a, b), float),
|
||||
|
||||
#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
|
||||
Operator( \
|
||||
#aten_op "(int a) -> " #int_result, \
|
||||
[](Stack& stack) { \
|
||||
int64_t a; \
|
||||
pop(stack, a); \
|
||||
push(stack, op); \
|
||||
return 0; \
|
||||
}), \
|
||||
Operator(#aten_op "(float a) -> " #float_result, \
|
||||
[](Stack& stack) { \
|
||||
double a; \
|
||||
pop(stack, a); \
|
||||
push(stack, op); \
|
||||
return 0; \
|
||||
#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
|
||||
Operator( \
|
||||
#aten_op "(int a) -> " #int_result, \
|
||||
[](Stack& stack) { \
|
||||
int64_t a; \
|
||||
pop(stack, a); \
|
||||
push(stack, op); \
|
||||
return 0; \
|
||||
}), \
|
||||
Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
|
||||
double a; \
|
||||
pop(stack, a); \
|
||||
push(stack, op); \
|
||||
return 0; \
|
||||
})
|
||||
|
||||
DEFINE_MATH_OP(aten::gamma, std::tgamma(a), float, float),
|
||||
@ -2186,7 +2197,6 @@ RegisterOperators reg2({
|
||||
DEFINE_COMPARISON_OP(aten::gt, a > b),
|
||||
DEFINE_COMPARISON_OP(aten::le, a <= b),
|
||||
DEFINE_COMPARISON_OP(aten::ge, a >= b),
|
||||
|
||||
DEFINE_BOOL_OP(aten::__and__, a&& b),
|
||||
DEFINE_BOOL_OP(aten::__or__, a || b),
|
||||
DEFINE_BOOL_OP(aten::__xor__, a != b),
|
||||
|
Reference in New Issue
Block a user