1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

fix(rax_tree): Fix crash caused by destructor in RaxTreeMap (#4228)

* fix(rax_tree): Fix double raxStop call in the SeekIterator

fixes dragonflydb#4172

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor(rax_tree): Address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
Stepan Bagritsevich 2024-12-04 21:34:15 +04:00 committed by GitHub
parent d8fda40d4d
commit 81079df0e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 63 additions and 23 deletions

View file

@ -24,30 +24,37 @@ template <typename V> struct RaxTreeMap {
// Simple seeking iterator
struct SeekIterator {
friend struct FindIterator;
SeekIterator() {
raxStart(&it_, nullptr);
it_.node = nullptr;
it_.rt = nullptr;
}
~SeekIterator() {
raxStop(&it_);
}
SeekIterator(SeekIterator&&) = delete; // self-referential
SeekIterator(const SeekIterator&) = delete; // self-referential
SeekIterator(rax* tree, const char* op, std::string_view key) {
raxStart(&it_, tree);
raxSeek(&it_, op, to_key_ptr(key), key.size());
operator++();
if (raxSeek(&it_, op, to_key_ptr(key), key.size())) { // Successfuly seeked
operator++();
} else {
InvalidateIterator();
}
}
explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) {
}
/* Remove copy/move constructors to avoid double iterator invalidation */
SeekIterator(SeekIterator&&) = delete;
SeekIterator(const SeekIterator&) = delete;
SeekIterator& operator=(SeekIterator&&) = delete;
SeekIterator& operator=(const SeekIterator&) = delete;
~SeekIterator() {
if (IsValid()) {
InvalidateIterator();
}
}
bool operator==(const SeekIterator& rhs) const {
if (!IsValid() || !rhs.IsValid())
return !IsValid() && !rhs.IsValid();
return it_.node == rhs.it_.node;
}
@ -56,31 +63,40 @@ template <typename V> struct RaxTreeMap {
}
SeekIterator& operator++() {
if (!raxNext(&it_)) {
raxStop(&it_);
it_.node = nullptr;
int next_result = raxNext(&it_);
if (!next_result) { // OOM or we reached the end of the tree
InvalidateIterator();
}
return *this;
}
/* After operator++() the first value (string_view) is invalid. So make sure your copied it to
* string */
std::pair<std::string_view, V&> operator*() const {
assert(IsValid() && it_.node && it_.node->iskey && it_.data);
return {std::string_view{reinterpret_cast<const char*>(it_.key), it_.key_len},
*reinterpret_cast<V*>(it_.data)};
}
bool IsValid() const {
return it_.rt;
}
private:
void InvalidateIterator() {
raxStop(&it_);
it_.rt = nullptr;
}
raxIterator it_;
};
// Result of find() call. Inherits from pair to mimic iterator interface, not incrementable.
struct FindIterator : public std::optional<std::pair<std::string, V&>> {
bool operator==(const SeekIterator& rhs) const {
if (this->has_value() != !bool(rhs.it_.flags & RAX_ITER_EOF))
return false;
if (!this->has_value())
return true;
return (*this)->first ==
std::string_view{reinterpret_cast<const char*>(rhs.it_.key), rhs.it_.key_len};
if (!this->has_value() || !rhs.IsValid())
return !this->has_value() && !rhs.IsValid();
return (*this)->first == (*rhs).first;
}
bool operator!=(const SeekIterator& rhs) const {
@ -160,7 +176,7 @@ std::pair<typename RaxTreeMap<V>::FindIterator, bool> RaxTreeMap<V>::try_emplace
V* old = nullptr;
raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast<void**>(&old));
assert(old == nullptr);
assert(!old);
auto it = std::make_optional(std::pair<std::string, V&>(std::string(key), *ptr));
return std::make_pair(std::move(FindIterator{it}), true);

View file

@ -104,4 +104,28 @@ TEST_F(RaxTreeTest, Find) {
EXPECT_TRUE(map.find(string_view{}) == map.end());
}
/* Run with mimalloc to make sure there is no double free */
TEST_F(RaxTreeTest, Iterate) {
const char* kKeys[] = {
"aaaaaaaaaaaaaaaaaaaa",
"bbbbbbbbbbbbbbbbbbbbbb"
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
"dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"
"eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
};
RaxTreeMap<int> map(pmr::get_default_resource());
for (const char* key : kKeys) {
map.try_emplace(key, 2);
}
for (auto it = map.begin(); it != map.end(); ++it) {
EXPECT_EQ((*it).second, 2);
}
for (auto it = map.begin(); it != map.end(); ++it) {
EXPECT_EQ((*it).second, 2);
}
}
} // namespace dfly::search