1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-15 17:51:06 +00:00

Add ZREMRANGEBYSCORE and ZREMRANGEBYRANK commands

This commit is contained in:
Roman Gershman 2022-03-19 15:22:40 +02:00
parent cb0d8dfee2
commit 5bce920308
6 changed files with 217 additions and 28 deletions

View file

@ -109,6 +109,21 @@ size_t MallocUsedHSet(unsigned encoding, void* ptr) {
return 0;
}
size_t MallocUsedZSet(unsigned encoding, void* ptr) {
switch (encoding) {
case OBJ_ENCODING_LISTPACK:
return lpBytes(reinterpret_cast<uint8_t*>(ptr));
case OBJ_ENCODING_SKIPLIST: {
zset* zs = (zset*)ptr;
return DictMallocSize(zs->dict);
}
break;
default:
LOG(FATAL) << "Unknown set encoding type " << encoding;
}
return 0;
}
inline void FreeObjHash(unsigned encoding, void* ptr) {
switch (encoding) {
case OBJ_ENCODING_HT:
@ -232,6 +247,9 @@ size_t RobjWrapper::MallocUsed() const {
return MallocUsedSet(encoding_, inner_obj_);
case OBJ_HASH:
return MallocUsedHSet(encoding_, inner_obj_);
case OBJ_ZSET:
return MallocUsedZSet(encoding_, inner_obj_);
default:
LOG(FATAL) << "Not supported " << type_;
}
@ -265,6 +283,7 @@ size_t RobjWrapper::Size() const {
void RobjWrapper::Free(pmr::memory_resource* mr) {
if (!inner_obj_)
return;
DVLOG(1) << "RobjWrapper::Free " << inner_obj_;
switch (type_) {
case OBJ_STRING:

View file

@ -398,7 +398,7 @@ zskiplistNode *zslLastInRange(zskiplist *zsl, const zrangespec *range) {
* range->maxex). When inclusive a score >= min && score <= max is deleted.
* Note that this function takes the reference to the hash table view of the
* sorted set, in order to remove the elements from the hash table too. */
unsigned long zslDeleteRangeByScore(zskiplist *zsl, zrangespec *range, dict *dict) {
unsigned long zslDeleteRangeByScore(zskiplist *zsl, const zrangespec *range, dict *dict) {
zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x;
unsigned long removed = 0;
int i;
@ -1098,7 +1098,7 @@ unsigned char *zzlInsert(unsigned char *zl, sds ele, double score) {
return zl;
}
unsigned char *zzlDeleteRangeByScore(unsigned char *zl, zrangespec *range, unsigned long *deleted) {
unsigned char *zzlDeleteRangeByScore(unsigned char *zl, const zrangespec *range, unsigned long *deleted) {
unsigned char *eptr, *sptr;
double score;
unsigned long num = 0;

View file

@ -97,5 +97,9 @@ int zslLexValueGteMin(sds value, const zlexrangespec* spec);
int zslLexValueLteMax(sds value, const zlexrangespec* spec);
int zsetZiplistValidateIntegrity(unsigned char* zl, size_t size, int deep);
zskiplistNode* zslGetElementByRank(zskiplist *zsl, unsigned long rank);
unsigned long zslDeleteRangeByRank(zskiplist *zsl, unsigned int start, unsigned int end,
dict *dict);
unsigned long zslDeleteRangeByScore(zskiplist *zsl, const zrangespec *range, dict *dict);
unsigned char *zzlDeleteRangeByScore(unsigned char *zl, const zrangespec *range, unsigned long *deleted);
#endif

View file

@ -27,7 +27,7 @@ namespace {
using CI = CommandId;
static const char kNxXxErr[] = "XX and NX options at the same time are not compatible";
constexpr unsigned kMaxZiplistValue = 64;
constexpr unsigned kMaxListPackValue = 64;
OpResult<MainIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_view key,
size_t member_len) {
@ -40,11 +40,13 @@ OpResult<MainIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_
if (inserted) {
robj* zobj = nullptr;
if (member_len > kMaxZiplistValue) {
if (member_len > kMaxListPackValue) {
zobj = createZsetObject();
} else {
zobj = createZsetListpackObject();
}
DVLOG(2) << "Created zset " << zobj->ptr;
it->second.ImportRObj(zobj);
} else {
if (it->second.ObjType() != OBJ_ZSET)
@ -53,9 +55,15 @@ OpResult<MainIterator> FindZEntry(unsigned flags, const OpArgs& op_args, string_
return it;
}
enum class Action {
RANGE = 0,
REM = 1,
};
class IntervalVisitor {
public:
IntervalVisitor(const ZSetFamily::RangeParams& params, robj* o) : params_(params), zobj_(o) {
IntervalVisitor(Action action, const ZSetFamily::RangeParams& params, robj* o)
: action_(action), params_(params), zobj_(o) {
}
void operator()(const ZSetFamily::IndexInterval& ii);
@ -66,9 +74,17 @@ class IntervalVisitor {
return std::move(result_);
}
unsigned removed() const {
return removed_;
}
private:
void ExtractListPack(const zrangespec& range);
void ExtractSkipList(const zrangespec& range);
void ActionRange(unsigned start, unsigned end); // rank
void ActionRange(const zrangespec& range); // score
void ActionRem(unsigned start, unsigned end); // rank
void ActionRem(const zrangespec& range); // score
void Next(uint8_t* zl, uint8_t** eptr, uint8_t** sptr) const {
if (reverse_) {
@ -84,11 +100,13 @@ class IntervalVisitor {
void AddResult(const uint8_t* vstr, unsigned vlen, long long vlon, double score);
Action action_;
ZSetFamily::RangeParams params_;
robj* zobj_;
bool reverse_ = false;
ZSetFamily::ScoredArray result_;
unsigned removed_ = 0;
};
void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) {
@ -109,7 +127,34 @@ void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) {
if (unsigned(end) >= llen)
end = llen - 1;
switch (action_) {
case Action::RANGE:
ActionRange(start, end);
break;
case Action::REM:
ActionRem(start, end);
break;
}
}
void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) {
zrangespec range;
range.min = si.first.val;
range.max = si.second.val;
range.minex = si.first.is_open;
range.maxex = si.second.is_open;
switch (action_) {
case Action::RANGE:
ActionRange(range);
break;
case Action::REM:
ActionRem(range);
break;
}
}
void IntervalVisitor::ActionRange(unsigned start, unsigned end) {
unsigned rangelen = (end - start) + 1;
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
@ -146,6 +191,7 @@ void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) {
/* Check if starting point is trivial, before doing log(N) lookup. */
if (reverse_) {
ln = zsl->tail;
unsigned long llen = zsetLength(zobj_);
if (start > 0)
ln = zslGetElementByRank(zsl, llen - start);
} else {
@ -165,6 +211,46 @@ void IntervalVisitor::operator()(const ZSetFamily::IndexInterval& ii) {
}
}
void IntervalVisitor::ActionRange(const zrangespec& range) {
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
ExtractListPack(range);
} else if (zobj_->encoding == OBJ_ENCODING_SKIPLIST) {
ExtractSkipList(range);
} else {
LOG(FATAL) << "Unknown sorted set encoding " << zobj_->encoding;
}
}
void IntervalVisitor::ActionRem(unsigned start, unsigned end) {
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
uint8_t* zl = (uint8_t*)zobj_->ptr;
removed_ = (end - start) + 1;
zl = lpDeleteRange(zl, 2 * start, 2 * removed_);
zobj_->ptr = zl;
} else if (zobj_->encoding == OBJ_ENCODING_SKIPLIST) {
zset* zs = (zset*)zobj_->ptr;
removed_ = zslDeleteRangeByRank(zs->zsl, start + 1, end + 1, zs->dict);
} else {
LOG(FATAL) << "Unknown sorted set encoding" << zobj_->encoding;
}
}
void IntervalVisitor::ActionRem(const zrangespec& range) {
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
uint8_t* zl = (uint8_t*)zobj_->ptr;
unsigned long deleted = 0;
zl = zzlDeleteRangeByScore(zl, &range, &deleted);
zobj_->ptr = zl;
removed_ = deleted;
} else if (zobj_->encoding == OBJ_ENCODING_SKIPLIST) {
zset* zs = (zset*)zobj_->ptr;
removed_ = zslDeleteRangeByScore(zs->zsl, &range, zs->dict);
} else {
LOG(FATAL) << "Unknown sorted set encoding" << zobj_->encoding;
}
}
void IntervalVisitor::ExtractListPack(const zrangespec& range) {
uint8_t* zl = (uint8_t*)zobj_->ptr;
uint8_t *eptr, *sptr;
@ -253,22 +339,6 @@ void IntervalVisitor::ExtractSkipList(const zrangespec& range) {
}
}
void IntervalVisitor::operator()(const ZSetFamily::ScoreInterval& si) {
zrangespec range;
range.min = si.first.val;
range.max = si.second.val;
range.minex = si.first.is_open;
range.maxex = si.second.is_open;
if (zobj_->encoding == OBJ_ENCODING_LISTPACK) {
ExtractListPack(range);
} else if (zobj_->encoding == OBJ_ENCODING_SKIPLIST) {
ExtractSkipList(range);
} else {
LOG(FATAL) << "Unknown sorted set encoding " << zobj_->encoding;
}
}
void IntervalVisitor::AddResult(const uint8_t* vstr, unsigned vlen, long long vlong, double score) {
if (vstr == NULL) {
result_.emplace_back(absl::StrCat(vlong), score);
@ -475,6 +545,38 @@ void ZSetFamily::ZRangeByScore(CmdArgList args, ConnectionContext* cntx) {
ZRangeByScoreInternal(key, min_s, max_s, range_params, cntx);
}
void ZSetFamily::ZRemRangeByRank(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
IndexInterval ii;
if (!absl::SimpleAtoi(min_s, &ii.first) || !absl::SimpleAtoi(max_s, &ii.second)) {
return (*cntx)->SendError(kInvalidIntErr);
}
ZRangeSpec range_spec;
range_spec.interval = ii;
ZRemRangeGeneric(key, range_spec, cntx);
}
void ZSetFamily::ZRemRangeByScore(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view min_s = ArgS(args, 2);
std::string_view max_s = ArgS(args, 3);
ScoreInterval si;
if (!ParseBound(min_s, &si.first) || !ParseBound(max_s, &si.second)) {
return (*cntx)->SendError("min or max is not a float");
}
ZRangeSpec range_spec;
range_spec.interval = si;
ZRemRangeGeneric(key, range_spec, cntx);
}
void ZSetFamily::ZRem(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
@ -492,7 +594,7 @@ void ZSetFamily::ZRem(CmdArgList args, ConnectionContext* cntx) {
if (result.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
} else {
(*cntx)->SendLong(result.value());
(*cntx)->SendLong(*result);
}
}
@ -511,7 +613,7 @@ void ZSetFamily::ZScore(CmdArgList args, ConnectionContext* cntx) {
} else if (!result) {
(*cntx)->SendNull();
} else {
(*cntx)->SendDouble(result.value());
(*cntx)->SendDouble(*result);
}
}
@ -522,8 +624,7 @@ void ZSetFamily::ZRangeByScoreInternal(std::string_view key, std::string_view mi
range_spec.params = params;
ScoreInterval si;
if (!ParseBound(min_s, &si.first) ||
!ParseBound(max_s, &si.second)) {
if (!ParseBound(min_s, &si.first) || !ParseBound(max_s, &si.second)) {
return (*cntx)->SendError("min or max is not a float");
}
range_spec.interval = si;
@ -556,6 +657,21 @@ void ZSetFamily::OutputScoredArrayResult(const OpResult<ScoredArray>& result, bo
}
}
void ZSetFamily::ZRemRangeGeneric(std::string_view key, const ZRangeSpec& range_spec,
ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
OpArgs op_args{shard, t->db_index()};
return OpRemRange(op_args, key, range_spec);
};
OpResult<unsigned> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
} else {
(*cntx)->SendLong(*result);
}
}
OpResult<unsigned> ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_args, string_view key,
ScoredMemberSpan members) {
DCHECK(!members.empty());
@ -591,6 +707,9 @@ OpResult<unsigned> ZSetFamily::OpAdd(const ZParams& zparams, const OpArgs& op_ar
if (!(retflags & ZADD_OUT_NOP))
processed++;
}
DVLOG(2) << "ZAdd " << zobj->ptr;
res_it.value()->second.SyncRObj();
return zparams.ch ? added + updated : added;
@ -641,13 +760,33 @@ auto ZSetFamily::OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args, st
return res_it.status();
robj* zobj = res_it.value()->second.AsRObj();
IntervalVisitor iv{range_spec.params, zobj};
IntervalVisitor iv{Action::RANGE, range_spec.params, zobj};
absl::visit(iv, range_spec.interval);
std::visit(iv, range_spec.interval);
return iv.PopResult();
}
OpResult<unsigned> ZSetFamily::OpRemRange(const OpArgs& op_args, string_view key,
const ZRangeSpec& range_spec) {
OpResult<MainIterator> res_it = op_args.shard->db_slice().Find(op_args.db_ind, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
robj* zobj = res_it.value()->second.AsRObj();
IntervalVisitor iv{Action::REM, range_spec.params, zobj};
std::visit(iv, range_spec.interval);
res_it.value()->second.SyncRObj();
auto zlen = zsetLength(zobj);
if (zlen == 0) {
CHECK(op_args.shard->db_slice().Del(op_args.db_ind, res_it.value()));
}
return iv.removed();
}
#define HFUNC(x) SetHandler(&ZSetFamily::x)
void ZSetFamily::Register(CommandRegistry* registry) {
@ -657,7 +796,9 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZREM", CO::FAST | CO::WRITE, -3, 1, 1, 1}.HFUNC(ZRem)
<< CI{"ZRANGE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRange)
<< CI{"ZRANGEBYSCORE", CO::READONLY, -4, 1, 1, 1}.HFUNC(ZRangeByScore)
<< CI{"ZSCORE", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZScore);
<< CI{"ZSCORE", CO::READONLY | CO::FAST, 3, 1, 1, 1}.HFUNC(ZScore)
<< CI{"ZREMRANGEBYRANK", CO::WRITE, 4, 1, 1, 1}.HFUNC(ZRemRangeByRank)
<< CI{"ZREMRANGEBYSCORE", CO::WRITE, 4, 1, 1, 1}.HFUNC(ZRemRangeByScore);
}
} // namespace dfly

View file

@ -51,12 +51,16 @@ class ZSetFamily {
static void ZRem(CmdArgList args, ConnectionContext* cntx);
static void ZScore(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScore(CmdArgList args, ConnectionContext* cntx);
static void ZRemRangeByRank(CmdArgList args, ConnectionContext* cntx);
static void ZRemRangeByScore(CmdArgList args, ConnectionContext* cntx);
static void ZRangeByScoreInternal(std::string_view key, std::string_view min_s,
std::string_view max_s, const RangeParams& params,
ConnectionContext* cntx);
static void OutputScoredArrayResult(const OpResult<ScoredArray>& arr, bool with_scores,
ConnectionContext* cntx);
static void ZRemRangeGeneric(std::string_view key, const ZRangeSpec& range_spec,
ConnectionContext* cntx);
struct ZParams {
unsigned flags = 0; // mask of ZADD_IN_ macros.
@ -73,6 +77,8 @@ class ZSetFamily {
std::string_view member);
static OpResult<ScoredArray> OpRange(const ZRangeSpec& range_spec, const OpArgs& op_args,
std::string_view key);
static OpResult<unsigned> OpRemRange(const OpArgs& op_args, std::string_view key,
const ZRangeSpec& spec);
};
} // namespace dfly

View file

@ -58,6 +58,25 @@ TEST_F(ZSetFamilyTest, ZRem) {
TEST_F(ZSetFamilyTest, ZRange) {
Run({"zadd", "x", "1.1", "a", "2.1", "b"});
EXPECT_THAT(Run({"zrangebyscore", "x", "0", "(1.1"}), ElementsAre(ArrLen(0)));
EXPECT_THAT(Run({"zrangebyscore", "x", "-inf", "1.1"}), ElementsAre("a"));
}
TEST_F(ZSetFamilyTest, ZRemRangeRank) {
Run({"zadd", "x", "1.1", "a", "2.1", "b"});
EXPECT_THAT(Run({"ZREMRANGEBYRANK", "y", "0", "1"}), ElementsAre(IntArg(0)));
EXPECT_THAT(Run({"ZREMRANGEBYRANK", "x", "0", "0"}), ElementsAre(IntArg(1)));
EXPECT_THAT(Run({"zrange", "x", "0", "5"}), ElementsAre("b"));
EXPECT_THAT(Run({"ZREMRANGEBYRANK", "x", "0", "1"}), ElementsAre(IntArg(1)));
EXPECT_THAT(Run({"type", "x"}), ElementsAre("none"));
}
TEST_F(ZSetFamilyTest, ZRemRangeScore) {
Run({"zadd", "x", "1.1", "a", "2.1", "b"});
EXPECT_THAT(Run({"ZREMRANGEBYSCORE", "y", "0", "1"}), ElementsAre(IntArg(0)));
EXPECT_THAT(Run({"ZREMRANGEBYSCORE", "x", "-inf", "1.1"}), ElementsAre(IntArg(1)));
EXPECT_THAT(Run({"zrange", "x", "0", "5"}), ElementsAre("b"));
EXPECT_THAT(Run({"ZREMRANGEBYSCORE", "x", "(2.0", "+inf"}), ElementsAre(IntArg(1)));
EXPECT_THAT(Run({"type", "x"}), ElementsAre("none"));
}
} // namespace dfly