From cb0d8dfee2cc27e54b78498ddfba9e416c5c8742 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Fri, 18 Mar 2022 05:12:22 +0200 Subject: [PATCH] Add ZRANGEBYSCORE. Cover rank case for ZRANGE --- README.md | 4 +- src/redis/zset.h | 1 + src/server/zset_family.cc | 247 ++++++++++++++++++++++++--------- src/server/zset_family.h | 18 ++- src/server/zset_family_test.cc | 6 + 5 files changed, 206 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index d5761bd5d..db0f5f426 100644 --- a/README.md +++ b/README.md @@ -108,8 +108,8 @@ API 1.0 - [X] ZADD - [X] ZCARD - [ ] ZINCRBY - - [ ] ZRANGE - - [ ] ZRANGEBYSCORE + - [X] ZRANGE + - [X] ZRANGEBYSCORE - [X] ZREM - [ ] ZREMRANGEBYSCORE - [ ] ZREVRANGE diff --git a/src/redis/zset.h b/src/redis/zset.h index 53d86d00d..314ca31cb 100644 --- a/src/redis/zset.h +++ b/src/redis/zset.h @@ -96,5 +96,6 @@ int zzlLexValueLteMax(unsigned char* p, const zlexrangespec* spec); int zslLexValueGteMin(sds value, const zlexrangespec* spec); int zslLexValueLteMax(sds value, const zlexrangespec* spec); int zsetZiplistValidateIntegrity(unsigned char* zl, size_t size, int deep); +zskiplistNode* zslGetElementByRank(zskiplist *zsl, unsigned long rank); #endif diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 9528dca56..9d7c70242 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -53,14 +53,9 @@ OpResult FindZEntry(unsigned flags, const OpArgs& op_args, string_ return it; } -struct ZListParams { - uint32_t offset = 0; - uint32_t limit = UINT32_MAX; -}; - class IntervalVisitor { public: - IntervalVisitor(const ZListParams& params, robj* o) : params_(params), zobj_(o) { + IntervalVisitor(const ZSetFamily::RangeParams& params, robj* o) : params_(params), zobj_(o) { } void operator()(const ZSetFamily::IndexInterval& ii); @@ -87,7 +82,9 @@ class IntervalVisitor { return reverse_ ? zslValueGteMin(score, &spec) : zslValueLteMax(score, &spec); } - ZListParams params_; + void AddResult(const uint8_t* vstr, unsigned vlen, long long vlon, double score); + + ZSetFamily::RangeParams params_; robj* zobj_; bool reverse_ = false; @@ -95,7 +92,77 @@ class IntervalVisitor { }; void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) { - LOG(FATAL) << "TBD"; + unsigned long llen = zsetLength(zobj_); + int32_t start = ii.first; + int32_t end = ii.second; + + if (start < 0) + start = llen + start; + if (end < 0) + end = llen + end; + if (start < 0) + start = 0; + + if (start > end || unsigned(start) >= llen) { + return; + } + + if (unsigned(end) >= llen) + end = llen - 1; + + unsigned rangelen = (end - start) + 1; + + if (zobj_->encoding == OBJ_ENCODING_LISTPACK) { + unsigned char* zl = (uint8_t*)zobj_->ptr; + unsigned char *eptr, *sptr; + unsigned char* vstr; + unsigned int vlen; + long long vlong; + double score = 0.0; + + if (reverse_) + eptr = lpSeek(zl, -2 - (2 * start)); + else + eptr = lpSeek(zl, 2 * start); + + sptr = lpNext(zl, eptr); + + while (rangelen--) { + DCHECK(eptr != NULL && sptr != NULL); + vstr = lpGetValue(eptr, &vlen, &vlong); + + if (params_.with_scores) /* don't bother to extract the score if it's gonna be ignored. */ + score = zzlGetScore(sptr); + + AddResult(vstr, vlen, vlong, score); + + Next(zl, &eptr, &sptr); + } + } else if (zobj_->encoding == OBJ_ENCODING_SKIPLIST) { + zset* zs = (zset*)zobj_->ptr; + zskiplist* zsl = zs->zsl; + zskiplistNode* ln; + + /* Check if starting point is trivial, before doing log(N) lookup. */ + if (reverse_) { + ln = zsl->tail; + if (start > 0) + ln = zslGetElementByRank(zsl, llen - start); + } else { + ln = zsl->header->level[0].forward; + if (start > 0) + ln = zslGetElementByRank(zsl, start + 1); + } + + while (rangelen--) { + DCHECK(ln != NULL); + sds ele = ln->ele; + result_.emplace_back(string(ele, sdslen(ele)), ln->score); + ln = reverse_ ? ln->backward : ln->level[0].forward; + } + } else { + LOG(FATAL) << "Unknown sorted set encoding" << zobj_->encoding; + } } void IntervalVisitor::ExtractListPack(const zrangespec& range) { @@ -136,14 +203,9 @@ void IntervalVisitor::ExtractListPack(const zrangespec& range) { * succeed */ vstr = lpGetValue(eptr, &vlen, &vlong); - rangelen++; - if (vstr == NULL) { - result_.emplace_back(absl::StrCat(vlong), score); - } else { - result_.emplace_back(string{reinterpret_cast(vstr), vlen}, score); - // handler->emitResultFromCBuffer(handler, vstr, vlen, score); - } + AddResult(vstr, vlen, vlong, score); + rangelen++; /* Move to next node */ Next(zl, &eptr, &sptr); } @@ -196,7 +258,7 @@ void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) { range.min = si.first.val; range.max = si.second.val; range.minex = si.first.is_open; - range.maxex = si.first.is_open; + range.maxex = si.second.is_open; if (zobj_->encoding == OBJ_ENCODING_LISTPACK) { ExtractListPack(range); @@ -207,17 +269,37 @@ void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) { } } -bool ParseScore(string_view src, double* d) { - if (src == "-inf") { - *d = -HUGE_VAL; - } else if (src == "+inf") { - *d = HUGE_VAL; +void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vlong, double score) { + if (vstr == NULL) { + result_.emplace_back(absl::StrCat(vlong), score); } else { - return absl::SimpleAtod(src, d); + result_.emplace_back(string{reinterpret_cast(vstr), vlen}, score); + } +} + +bool ParseScore(string_view src, double* score) { + if (src == "-inf") { + *score = -HUGE_VAL; + } else if (src == "+inf") { + *score = HUGE_VAL; + } else { + return absl::SimpleAtod(src, score); } return true; }; +bool ParseBound(string_view src, ZSetFamily::Bound* bound) { + if (src.empty()) + return false; + + if (src[0] == '(') { + bound->is_open = true; + src.remove_prefix(1); + } + + return ParseScore(src, &bound->val); +} + } // namespace void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) { @@ -331,12 +413,8 @@ void ZSetFamily::ZRange(CmdArgList args, ConnectionContext* cntx) { std::string_view min_s = ArgS(args, 2); std::string_view max_s = ArgS(args, 3); - if (min_s.empty() || max_s.empty()) { - return (*cntx)->SendError(kInvalidIntErr); - } - - ZRangeSpec range_spec; bool parse_score = false; + RangeParams range_params; for (size_t i = 4; i < args.size(); ++i) { ToUpper(&args[i]); @@ -344,58 +422,57 @@ void ZSetFamily::ZRange(CmdArgList args, ConnectionContext* cntx) { string_view cur_arg = ArgS(args, i); if (cur_arg == "BYSCORE") { parse_score = true; + } else if (cur_arg == "WITHSCORES") { + range_params.with_scores = true; } else { return cntx->reply_builder()->SendError(absl::StrCat("unsupported option ", cur_arg)); } } if (parse_score) { - ScoreInterval si; - - if (min_s[0] == '(') { - si.first.is_open = true; - min_s.remove_prefix(1); - } - - if (max_s[0] == '(') { - si.second.is_open = true; - max_s.remove_prefix(1); - } - - if (!ParseScore(min_s, &si.first.val) || !ParseScore(max_s, &si.second.val)) { - return (*cntx)->SendError("min or max is not a float"); - } - range_spec.interval = si; - } else { - IndexInterval ii; - - if (!absl::SimpleAtoi(min_s, &ii.first) || !absl::SimpleAtoi(max_s, &ii.second)) { - (*cntx)->SendError(kInvalidIntErr); - return; - } - range_spec.interval = ii; + ZRangeByScoreInternal(key, min_s, max_s, range_params, cntx); + return; } + IndexInterval ii; + + if (!absl::SimpleAtoi(min_s, &ii.first) || !absl::SimpleAtoi(max_s, &ii.second)) { + (*cntx)->SendError(kInvalidIntErr); + return; + } + + ZRangeSpec range_spec; + range_spec.params = range_params; + range_spec.interval = ii; + auto cb = [&](Transaction* t, EngineShard* shard) { OpArgs op_args{shard, t->db_index()}; return OpRange(range_spec, op_args, key); }; - OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); - if (result.status() == OpStatus::WRONG_TYPE) { - (*cntx)->SendError(kWrongTypeErr); - } else { - (*cntx)->StartArray(result.value().size()); - for (const auto& p : result.value()) { - (*cntx)->SendBulkString(p.first); - if (false) { // withscores - (*cntx)->SendDouble(p.second); - } - } - } + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + OutputScoredArrayResult(result, range_params.with_scores, cntx); } void ZSetFamily::ZRangeByScore(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view min_s = ArgS(args, 2); + std::string_view max_s = ArgS(args, 3); + + RangeParams range_params; + + for (size_t i = 4; i < args.size(); ++i) { + ToUpper(&args[i]); + + string_view cur_arg = ArgS(args, i); + if (cur_arg == "WITHSCORES") { + range_params.with_scores = true; + } else { + return cntx->reply_builder()->SendError(absl::StrCat("unsupported option ", cur_arg)); + } + } + + ZRangeByScoreInternal(key, min_s, max_s, range_params, cntx); } void ZSetFamily::ZRem(CmdArgList args, ConnectionContext* cntx) { @@ -438,6 +515,47 @@ void ZSetFamily::ZScore(CmdArgList args, ConnectionContext* cntx) { } } +void ZSetFamily::ZRangeByScoreInternal(std::string_view key, std::string_view min_s, + std::string_view max_s, const RangeParams& params, + ConnectionContext* cntx) { + ZRangeSpec range_spec; + range_spec.params = params; + + ScoreInterval si; + if (!ParseBound(min_s, &si.first) || + !ParseBound(max_s, &si.second)) { + return (*cntx)->SendError("min or max is not a float"); + } + range_spec.interval = si; + + auto cb = [&](Transaction* t, EngineShard* shard) { + OpArgs op_args{shard, t->db_index()}; + return OpRange(range_spec, op_args, key); + }; + + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + OutputScoredArrayResult(result, params.with_scores, cntx); +} + +void ZSetFamily::OutputScoredArrayResult(const OpResult& result, bool with_scores, + ConnectionContext* cntx) { + if (result.status() == OpStatus::WRONG_TYPE) { + return (*cntx)->SendError(kWrongTypeErr); + } + + LOG_IF(WARNING, !result && result.status() != OpStatus::KEY_NOTFOUND) + << "Unexpected status " << result.status(); + + (*cntx)->StartArray(result->size() * (with_scores ? 2 : 1)); + for (const auto& p : result.value()) { + (*cntx)->SendBulkString(p.first); + + if (with_scores) { + (*cntx)->SendDouble(p.second); + } + } +} + OpResult ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key, ScoredMemberSpan members) { DCHECK(!members.empty()); @@ -523,8 +641,7 @@ auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, st return res_it.status(); robj* zobj = res_it.value()->second.AsRObj(); - ZListParams params; - IntervalVisitor iv{params, zobj}; + IntervalVisitor iv{range_spec.params, zobj}; absl::visit(iv, range_spec.interval); diff --git a/src/server/zset_family.h b/src/server/zset_family.h index 50734cb88..1f27dd533 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -27,15 +27,23 @@ class ZSetFamily { using ScoreInterval = std::pair; + struct RangeParams { + uint32_t offset = 0; + uint32_t limit = UINT32_MAX; + bool with_scores = false; + }; + struct ZRangeSpec { std::variant interval; - // TODO: handle open/close, inf etc. + RangeParams params; }; using ScoredMember = std::pair; using ScoredArray = std::vector; private: + template using OpResult = facade::OpResult; + static void ZCard(CmdArgList args, ConnectionContext* cntx); static void ZAdd(CmdArgList args, ConnectionContext* cntx); static void ZIncrBy(CmdArgList args, ConnectionContext* cntx); @@ -44,6 +52,12 @@ class ZSetFamily { static void ZScore(CmdArgList args, ConnectionContext* cntx); static void ZRangeByScore(CmdArgList args, ConnectionContext* cntx); + static void ZRangeByScoreInternal(std::string_view key, std::string_view min_s, + std::string_view max_s, const RangeParams& params, + ConnectionContext* cntx); + static void OutputScoredArrayResult(const OpResult& arr, bool with_scores, + ConnectionContext* cntx); + struct ZParams { unsigned flags = 0; // mask of ZADD_IN_ macros. bool ch = false; // Corresponds to CH option. @@ -51,7 +65,6 @@ class ZSetFamily { using ScoredMemberView = std::pair; using ScoredMemberSpan = absl::Span; - template using OpResult = facade::OpResult; static OpResult OpAdd(const ZParams& zparams, const OpArgs& op_args, std::string_view key, ScoredMemberSpan members); @@ -60,7 +73,6 @@ class ZSetFamily { std::string_view member); static OpResult OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, std::string_view key); - }; } // namespace dfly diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 12188f482..1cf8b923f 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -48,10 +48,16 @@ TEST_F(ZSetFamilyTest, ZRem) { resp = Run({"zrem", "x", "b", "c"}); EXPECT_THAT(resp[0], IntArg(1)); + resp = Run({"zcard", "x"}); EXPECT_THAT(resp[0], IntArg(1)); EXPECT_THAT(Run({"zrange", "x", "0", "3", "byscore"}), ElementsAre("a")); EXPECT_THAT(Run({"zrange", "x", "(-inf", "(+inf", "byscore"}), ElementsAre("a")); } +TEST_F(ZSetFamilyTest, ZRange) { + Run({"zadd", "x", "1.1", "a", "2.1", "b"}); + EXPECT_THAT(Run({"zrangebyscore", "x", "0", "(1.1"}), ElementsAre(ArrLen(0))); +} + } // namespace dfly