diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 34b051666..9528dca56 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -5,6 +5,7 @@ #include "server/zset_family.h" extern "C" { +#include "redis/listpack.h" #include "redis/object.h" #include "redis/zset.h" } @@ -52,6 +53,171 @@ OpResult FindZEntry(unsigned flags, const OpArgs& op_args, string_ return it; } +struct ZListParams { + uint32_t offset = 0; + uint32_t limit = UINT32_MAX; +}; + +class IntervalVisitor { + public: + IntervalVisitor(const ZListParams& params, robj* o) : params_(params), zobj_(o) { + } + + void operator()(const ZSetFamily::IndexInterval& ii); + + void operator()(const ZSetFamily::ScoreInterval& si); + + ZSetFamily::ScoredArray PopResult() { + return std::move(result_); + } + + private: + void ExtractListPack(const zrangespec& range); + void ExtractSkipList(const zrangespec& range); + + void Next(uint8_t* zl, uint8_t** eptr, uint8_t** sptr) const { + if (reverse_) { + zzlPrev(zl, eptr, sptr); + } else { + zzlNext(zl, eptr, sptr); + } + } + + bool IsUnder(double score, const zrangespec& spec) const { + return reverse_ ? zslValueGteMin(score, &spec) : zslValueLteMax(score, &spec); + } + + ZListParams params_; + robj* zobj_; + + bool reverse_ = false; + ZSetFamily::ScoredArray result_; +}; + +void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) { + LOG(FATAL) << "TBD"; +} + +void IntervalVisitor::ExtractListPack(const zrangespec& range) { + uint8_t* zl = (uint8_t*)zobj_->ptr; + uint8_t *eptr, *sptr; + uint8_t* vstr; + unsigned int vlen; + long long vlong; + unsigned rangelen = 0; + unsigned offset = params_.offset; + unsigned limit = params_.limit; + + /* If reversed, get the last node in range as starting point. */ + if (reverse_) { + eptr = zzlLastInRange(zl, &range); + } else { + eptr = zzlFirstInRange(zl, &range); + } + + /* Get score pointer for the first element. */ + if (eptr) + sptr = lpNext(zl, eptr); + + /* If there is an offset, just traverse the number of elements without + * checking the score because that is done in the next loop. */ + while (eptr && offset--) { + Next(zl, &eptr, &sptr); + } + + while (eptr && limit--) { + double score = zzlGetScore(sptr); + + /* Abort when the node is no longer in range. */ + if (!IsUnder(score, range)) + break; + + /* We know the element exists, so lpGetValue should always + * succeed */ + vstr = lpGetValue(eptr, &vlen, &vlong); + + rangelen++; + if (vstr == NULL) { + result_.emplace_back(absl::StrCat(vlong), score); + } else { + result_.emplace_back(string{reinterpret_cast(vstr), vlen}, score); + // handler->emitResultFromCBuffer(handler, vstr, vlen, score); + } + + /* Move to next node */ + Next(zl, &eptr, &sptr); + } +} + +void IntervalVisitor::ExtractSkipList(const zrangespec& range) { + zset* zs = (zset*)zobj_->ptr; + zskiplist* zsl = zs->zsl; + zskiplistNode* ln; + unsigned offset = params_.offset; + unsigned limit = params_.limit; + unsigned rangelen = 0; + + /* If reversed, get the last node in range as starting point. */ + if (reverse_) { + ln = zslLastInRange(zsl, &range); + } else { + ln = zslFirstInRange(zsl, &range); + } + + /* If there is an offset, just traverse the number of elements without + * checking the score because that is done in the next loop. */ + while (ln && offset--) { + if (reverse_) { + ln = ln->backward; + } else { + ln = ln->level[0].forward; + } + } + + while (ln && limit--) { + /* Abort when the node is no longer in range. */ + if (!IsUnder(ln->score, range)) + break; + + rangelen++; + result_.emplace_back(string{ln->ele, sdslen(ln->ele)}, ln->score); + + /* Move to next node */ + if (reverse_) { + ln = ln->backward; + } else { + ln = ln->level[0].forward; + } + } +} + +void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) { + zrangespec range; + range.min = si.first.val; + range.max = si.second.val; + range.minex = si.first.is_open; + range.maxex = si.first.is_open; + + if (zobj_->encoding == OBJ_ENCODING_LISTPACK) { + ExtractListPack(range); + } else if (zobj_->encoding == OBJ_ENCODING_SKIPLIST) { + ExtractSkipList(range); + } else { + LOG(FATAL) << "Unknown sorted set encoding " << zobj_->encoding; + } +} + +bool ParseScore(string_view src, double* d) { + if (src == "-inf") { + *d = -HUGE_VAL; + } else if (src == "+inf") { + *d = HUGE_VAL; + } else { + return absl::SimpleAtod(src, d); + } + return true; +}; + } // namespace void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) { @@ -127,7 +293,7 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) { for (; i < args.size(); i += 2) { std::string_view cur_arg = ArgS(args, i); double val; - if (!absl::SimpleAtod(cur_arg, &val) || !std::isfinite(val)) { + if (!ParseScore(cur_arg, &val)) { (*cntx)->SendError(kInvalidFloatErr); return; } @@ -161,6 +327,72 @@ void ZSetFamily::ZIncrBy(CmdArgList args, ConnectionContext* cntx) { } void ZSetFamily::ZRange(CmdArgList args, ConnectionContext* cntx) { + std::string_view key = ArgS(args, 1); + std::string_view min_s = ArgS(args, 2); + std::string_view max_s = ArgS(args, 3); + + if (min_s.empty() || max_s.empty()) { + return (*cntx)->SendError(kInvalidIntErr); + } + + ZRangeSpec range_spec; + bool parse_score = false; + + for (size_t i = 4; i < args.size(); ++i) { + ToUpper(&args[i]); + + string_view cur_arg = ArgS(args, i); + if (cur_arg == "BYSCORE") { + parse_score = true; + } else { + return cntx->reply_builder()->SendError(absl::StrCat("unsupported option ", cur_arg)); + } + } + + if (parse_score) { + ScoreInterval si; + + if (min_s[0] == '(') { + si.first.is_open = true; + min_s.remove_prefix(1); + } + + if (max_s[0] == '(') { + si.second.is_open = true; + max_s.remove_prefix(1); + } + + if (!ParseScore(min_s, &si.first.val) || !ParseScore(max_s, &si.second.val)) { + return (*cntx)->SendError("min or max is not a float"); + } + range_spec.interval = si; + } else { + IndexInterval ii; + + if (!absl::SimpleAtoi(min_s, &ii.first) || !absl::SimpleAtoi(max_s, &ii.second)) { + (*cntx)->SendError(kInvalidIntErr); + return; + } + range_spec.interval = ii; + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + OpArgs op_args{shard, t->db_index()}; + return OpRange(range_spec, op_args, key); + }; + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (result.status() == OpStatus::WRONG_TYPE) { + (*cntx)->SendError(kWrongTypeErr); + } else { + (*cntx)->StartArray(result.value().size()); + for (const auto& p : result.value()) { + (*cntx)->SendBulkString(p.first); + + if (false) { // withscores + (*cntx)->SendDouble(p.second); + } + } + } } void ZSetFamily::ZRangeByScore(CmdArgList args, ConnectionContext* cntx) { @@ -284,6 +516,21 @@ OpResult ZSetFamily::OpScore(const OpArgs& op_args, string_view key, str return score; } +auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, std::string_view key) + -> OpResult { + OpResult res_it = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_ZSET); + if (!res_it) + return res_it.status(); + + robj* zobj = res_it.value()->second.AsRObj(); + ZListParams params; + IntervalVisitor iv{params, zobj}; + + absl::visit(iv, range_spec.interval); + + return iv.PopResult(); +} + #define HFUNC(x) SetHandler(&ZSetFamily::x) void ZSetFamily::Register(CommandRegistry* registry) { diff --git a/src/server/zset_family.h b/src/server/zset_family.h index cada49849..50734cb88 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -18,6 +18,23 @@ class ZSetFamily { public: static void Register(CommandRegistry* registry); + using IndexInterval = std::pair; + + struct Bound { + double val; + bool is_open = false; + }; + + using ScoreInterval = std::pair; + + struct ZRangeSpec { + std::variant interval; + // TODO: handle open/close, inf etc. + }; + + using ScoredMember = std::pair; + using ScoredArray = std::vector; + private: static void ZCard(CmdArgList args, ConnectionContext* cntx); static void ZAdd(CmdArgList args, ConnectionContext* cntx); @@ -41,6 +58,9 @@ class ZSetFamily { static OpResult OpRem(const OpArgs& op_args, std::string_view key, ArgSlice members); static OpResult OpScore(const OpArgs& op_args, std::string_view key, std::string_view member); + static OpResult OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, + std::string_view key); + }; } // namespace dfly diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 6a0738481..12188f482 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -50,6 +50,8 @@ TEST_F(ZSetFamilyTest, ZRem) { EXPECT_THAT(resp[0], IntArg(1)); resp = Run({"zcard", "x"}); EXPECT_THAT(resp[0], IntArg(1)); + EXPECT_THAT(Run({"zrange", "x", "0", "3", "byscore"}), ElementsAre("a")); + EXPECT_THAT(Run({"zrange", "x", "(-inf", "(+inf", "byscore"}), ElementsAre("a")); } } // namespace dfly