1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

feat(server): support restore command - refactor load class (#343) (#386)

feat(server): adding support for the restore command (#343)

Signed-off-by: Boaz Sade <boaz@dragonflydb.io>
Co-authored-by: Boaz Sade <boaz@dragonflydb.io>
This commit is contained in:
Boaz Sade 2022-10-18 11:13:16 +03:00 committed by GitHub
parent ce964f103a
commit b1470ba047
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 1076 additions and 700 deletions

View file

@ -179,7 +179,7 @@ with respect to Memcached and Redis APIs.
- [ ] OBJECT
- [x] PERSIST
- [X] PTTL
- [ ] RESTORE
- [x] RESTORE
- [X] SCRIPT LOAD/EXISTS
- [ ] SCRIPT DEBUG/KILL/FLUSH
- [X] Set Family

View file

@ -19,6 +19,7 @@ extern "C" {
#include "server/container_utils.h"
#include "server/engine_shard_set.h"
#include "server/error.h"
#include "server/rdb_load.h"
#include "server/rdb_save.h"
#include "server/transaction.h"
#include "util/varz.h"
@ -31,8 +32,16 @@ using namespace std;
using namespace facade;
namespace {
using VersionBuffer = std::array<char, 2>;
using VersionBuffer = std::array<char, sizeof(uint16_t)>;
using CrcBuffer = std::array<char, sizeof(uint64_t)>;
constexpr size_t DUMP_FOOTER_SIZE = sizeof(uint64_t) + sizeof(uint16_t); // version number and crc
int64_t CalculateExpirationTime(bool seconds, bool absolute, int64_t ts, int64_t now_msec) {
int64_t msec = seconds ? ts * 1000 : ts;
int64_t rel_msec = absolute ? msec - now_msec : msec;
return rel_msec;
}
VersionBuffer MakeRdbVersion() {
VersionBuffer buf;
@ -61,6 +70,200 @@ void AppendFooter(std::string* dump_res) {
dump_res->append(crc.data(), crc.size());
}
bool VerifyFooter(std::string_view msg) {
if (msg.size() <= DUMP_FOOTER_SIZE) {
LOG(WARNING) << "got restore payload that is too short - " << msg.size();
return false;
}
const uint8_t* footer =
reinterpret_cast<const uint8_t*>(msg.data()) + (msg.size() - DUMP_FOOTER_SIZE);
uint16_t version = (*(footer + 1) << 8 | (*footer));
if (version > RDB_VERSION) {
LOG(WARNING) << "got restore payload with illegal version - supporting version up to "
<< RDB_VERSION << " got version " << version;
return false;
}
uint64_t expected_cs =
crc64(0, reinterpret_cast<const uint8_t*>(msg.data()), msg.size() - sizeof(uint64_t));
uint64_t actual_cs = absl::little_endian::Load64(footer + sizeof(version));
if (actual_cs != expected_cs) {
LOG(WARNING) << "CRC check failed for restore command, expecting: " << expected_cs << " got "
<< actual_cs;
return false;
}
return true;
}
class InMemSource : public ::io::Source {
public:
InMemSource(std::string_view buf) : buf_(buf) {
}
::io::Result<size_t> ReadSome(const iovec* v, uint32_t len) final;
protected:
std::string_view buf_;
off_t offs_ = 0;
};
::io::Result<size_t> InMemSource::ReadSome(const iovec* v, uint32_t len) {
ssize_t read_total = 0;
while (size_t(offs_) < buf_.size() && len > 0) {
size_t read_sz = min(buf_.size() - offs_, v->iov_len);
memcpy(v->iov_base, buf_.data() + offs_, read_sz);
read_total += read_sz;
offs_ += read_sz;
++v;
--len;
}
return read_total;
}
class RdbRestoreValue : protected RdbLoaderBase {
public:
bool Add(std::string_view payload, std::string_view key, DbSlice& db_slice, DbIndex index,
uint64_t expire_ms);
private:
std::optional<OpaqueObj> Parse(std::string_view payload);
};
std::optional<RdbLoaderBase::OpaqueObj> RdbRestoreValue::Parse(std::string_view payload) {
InMemSource source(payload);
src_ = &source;
if (auto type_id = FetchType(); type_id && rdbIsObjectType(type_id.value())) {
io::Result<OpaqueObj> io_res = ReadObj(type_id.value()); // load the type from the input stream
if (!io_res) {
LOG(ERROR) << "failed to load data for type id " << (unsigned int)type_id.value();
return std::nullopt;
}
return std::optional<OpaqueObj>(std::move(io_res.value()));
} else {
LOG(ERROR) << "failed to load type id from the input stream or type id is invalid";
return std::nullopt;
}
}
bool RdbRestoreValue::Add(std::string_view data, std::string_view key, DbSlice& db_slice,
DbIndex index, uint64_t expire_ms) {
auto value_to_load = Parse(data);
if (!value_to_load) {
return false;
}
Item item{
.key = std::string(key), .val = std::move(value_to_load.value()), .expire_ms = expire_ms};
PrimeValue pv;
if (auto ec = Visit(item, &pv); ec) {
// we failed - report and exit
LOG(WARNING) << "error while trying to save data: " << ec;
return false;
}
DbContext context{.db_index = index, .time_now_ms = GetCurrentTimeMs()};
auto [it, added] = db_slice.AddEntry(context, key, std::move(pv), item.expire_ms);
return added;
}
class RestoreArgs {
static constexpr int64_t NO_EXPIRATION = 0;
int64_t expiration_ = NO_EXPIRATION;
bool abs_time_ = false;
bool replace_ = false; // if true, over-ride existing key
public:
constexpr bool Replace() const {
return replace_;
}
constexpr int64_t ExpirationTime() const {
return expiration_;
}
[[nodiscard]] constexpr bool Expired() const {
return ExpirationTime() < 0;
}
[[nodiscard]] constexpr bool HasExpiration() const {
return expiration_ != NO_EXPIRATION;
}
[[nodiscard]] bool UpdateExpiration(int64_t now_msec);
static OpResult<RestoreArgs> TryFrom(const CmdArgList& args);
};
[[nodiscard]] bool RestoreArgs::UpdateExpiration(int64_t now_msec) {
if (HasExpiration()) {
auto new_ttl = CalculateExpirationTime(!abs_time_, abs_time_, expiration_, now_msec);
if (new_ttl > kMaxExpireDeadlineSec * 1000) {
return false;
}
expiration_ = new_ttl;
if (new_ttl > 0) {
expiration_ += now_msec;
}
}
return true;
}
// The structure that we are expecting is:
// args[0] == "RESTORE"
// args[1] == "key"
// args[2] == "ttl"
// args[3] == serialized value (list of chars that are used for the actual restore).
// args[4] .. args[n]: optional arguments that can be [REPLACE] [ABSTTL] [IDLETIME seconds]
// [FREQ frequency], in any order
OpResult<RestoreArgs> RestoreArgs::TryFrom(const CmdArgList& args) {
RestoreArgs out_args;
std::string_view cur_arg = ArgS(args, 2); // extract ttl
if (!absl::SimpleAtoi(cur_arg, &out_args.expiration_) || (out_args.expiration_ < 0)) {
return OpStatus::INVALID_INT;
}
// the 3rd arg is the serialized value, so we are starting from one pass it
// Note that all these are actually optional
// note about the redis doc for this command: https://redis.io/commands/restore/
// the IDLETIME and FREQ are not required, but to make this the same as in redis
// we would parse them and ensure that they are correct, maybe later they will be used
int64_t idle_time = 0;
for (size_t i = 4; i < args.size(); ++i) {
ToUpper(&args[i]);
cur_arg = ArgS(args, i);
bool additional = args.size() - i - 1 >= 1;
if (cur_arg == "REPLACE") {
out_args.replace_ = true;
} else if (cur_arg == "ABSTTL") {
out_args.abs_time_ = true;
} else if (cur_arg == "IDLETIME" && additional) {
++i;
cur_arg = ArgS(args, i);
if (!absl::SimpleAtoi(cur_arg, &idle_time)) {
return OpStatus::INVALID_INT;
}
if (idle_time < 0) {
return OpStatus::SYNTAX_ERR;
}
} else if (cur_arg == "FREQ" && additional) {
++i;
cur_arg = ArgS(args, i);
int freq = 0;
if (!absl::SimpleAtoi(cur_arg, &freq)) {
return OpStatus::INVALID_INT;
}
if (freq < 0 || freq > 255) {
return OpStatus::OUT_OF_RANGE; // need to translate in this case
}
} else {
LOG(WARNING) << "Got unknown command line option for restore '" << cur_arg << "'";
return OpStatus::SYNTAX_ERR;
}
}
return out_args;
}
OpStatus OpPersist(const OpArgs& op_args, string_view key);
class Renamer {
@ -259,6 +462,40 @@ OpResult<std::string> OpDump(const OpArgs& op_args, string_view key) {
return OpStatus::KEY_NOTFOUND;
}
OpResult<bool> OnRestore(const OpArgs& op_args, std::string_view key, std::string_view payload,
RestoreArgs restore_args) {
if (!restore_args.UpdateExpiration(op_args.db_cntx.time_now_ms)) {
return OpStatus::OUT_OF_RANGE;
}
auto& db_slice = op_args.shard->db_slice();
// The redis impl (see cluster.c function restoreCommand), remove the old key if
// the replace option is set, so lets do the same here
auto [from_it, from_expire] = db_slice.FindExt(op_args.db_cntx, key);
if (restore_args.Replace()) {
if (IsValid(from_it)) {
VLOG(1) << "restore command is running with replace, found old key '" << key
<< "' and removing it";
CHECK(db_slice.Del(op_args.db_cntx.db_index, from_it));
}
} else {
// we are not allowed to replace it, so make sure it doesn't exist
if (IsValid(from_it)) {
return OpStatus::KEY_EXISTS;
}
}
if (restore_args.Expired()) {
VLOG(1) << "the new key '" << key << "' already expired, will not save the value";
return true;
}
RdbRestoreValue loader{};
return loader.Add(payload, key, db_slice, op_args.db_cntx.db_index,
restore_args.ExpirationTime());
}
bool ScanCb(const OpArgs& op_args, PrimeIterator it, const ScanOpts& opts, StringVec* res) {
auto& db_slice = op_args.shard->db_slice();
if (it->second.HasExpire()) {
@ -728,6 +965,47 @@ void GenericFamily::Sort(CmdArgList args, ConnectionContext* cntx) {
std::visit(std::move(sort_call), entries.value());
}
void GenericFamily::Restore(CmdArgList args, ConnectionContext* cntx) {
std::string_view key = ArgS(args, 1);
std::string_view serialized_value = ArgS(args, 3);
if (!VerifyFooter(serialized_value)) {
return (*cntx)->SendError("ERR DUMP payload version or checksum are wrong");
}
OpResult<RestoreArgs> restore_args = RestoreArgs::TryFrom(args);
if (!restore_args) {
if (restore_args.status() == OpStatus::OUT_OF_RANGE) {
return (*cntx)->SendError("Invalid IDLETIME value, must be >= 0");
} else {
return (*cntx)->SendError(restore_args.status());
}
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OnRestore(t->GetOpArgs(shard), key, serialized_value, restore_args.value());
};
OpResult<bool> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result) {
if (result.value()) {
return (*cntx)->SendOk();
} else {
return (*cntx)->SendError("Bad data format");
}
} else {
switch (result.status()) {
case OpStatus::KEY_EXISTS:
return (*cntx)->SendError("BUSYKEY: key name already exists.");
case OpStatus::WRONG_TYPE:
return (*cntx)->SendError("Bad data format");
default:
return (*cntx)->SendError(result.status());
}
}
}
void GenericFamily::Move(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);
int64_t target_db;
@ -1144,7 +1422,8 @@ void GenericFamily::Register(CommandRegistry* registry) {
<< CI{"UNLINK", CO::WRITE, -2, 1, -1, 1}.HFUNC(Del)
<< CI{"STICK", CO::WRITE, -2, 1, -1, 1}.HFUNC(Stick)
<< CI{"SORT", CO::READONLY, -2, 1, 1, 1}.HFUNC(Sort)
<< CI{"MOVE", CO::WRITE | CO::GLOBAL_TRANS, 3, 1, 1, 1}.HFUNC(Move);
<< CI{"MOVE", CO::WRITE | CO::GLOBAL_TRANS, 3, 1, 1, 1}.HFUNC(Move)
<< CI{"RESTORE", CO::WRITE, -4, 1, 1, 1}.HFUNC(Restore);
}
} // namespace dfly

View file

@ -56,6 +56,7 @@ class GenericFamily {
static void Time(CmdArgList args, ConnectionContext* cntx);
static void Type(CmdArgList args, ConnectionContext* cntx);
static void Dump(CmdArgList args, ConnectionContext* cntx);
static void Restore(CmdArgList args, ConnectionContext* cntx);
static OpResult<void> RenameGeneric(CmdArgList args, bool skip_exist_dest,
ConnectionContext* cntx);

View file

@ -417,4 +417,80 @@ TEST_F(GenericFamilyTest, Dump) {
resp = Run({"dump", "foo"});
EXPECT_EQ(resp.type, RespExpr::NIL);
}
TEST_F(GenericFamilyTest, Restore) {
using std::chrono::duration_cast;
using std::chrono::milliseconds;
using std::chrono::seconds;
using std::chrono::system_clock;
uint8_t STRING_DUMP_REDIS[] = {0x00, 0xc1, 0xd2, 0x04, 0x09, 0x00, 0xd0,
0x75, 0x59, 0x6d, 0x10, 0x04, 0x3f, 0x5c};
auto resp = Run({"set", "exiting-key", "1234"});
EXPECT_EQ(resp, "OK");
// try to restore into existing key - this should failed
ASSERT_THAT(Run({"restore", "exiting-key", "0", ToSV(STRING_DUMP_REDIS)}),
ArgType(RespExpr::ERROR));
// Try restore while setting expiration into the pass
// note that value for expiration is just some valid unix time stamp from the pass
resp = Run(
{"restore", "exiting-key", "1665476212900", ToSV(STRING_DUMP_REDIS), "ABSTTL", "REPLACE"});
CHECK_EQ(resp, "OK");
resp = Run({"get", "exiting-key"});
EXPECT_EQ(resp.type, RespExpr::NIL); // it was deleted as a result of restore action
// Test for string that we can successfully load the dumped data and read it back
resp = Run({"restore", "new-key", "0", ToSV(STRING_DUMP_REDIS)});
EXPECT_EQ(resp, "OK");
resp = Run({"get", "new-key"});
EXPECT_EQ("1234", resp);
resp = Run({"dump", "new-key"});
auto dump = resp.GetBuf();
CHECK_EQ(ToSV(dump), ToSV(STRING_DUMP_REDIS));
// test for list
EXPECT_EQ(1, CheckedInt({"rpush", "orig-list", "20"}));
resp = Run({"dump", "orig-list"});
dump = resp.GetBuf();
resp = Run({"restore", "new-list", "10", ToSV(dump)});
EXPECT_EQ(resp, "OK");
resp = Run({"lpop", "new-list"});
EXPECT_EQ("20", resp);
// run with hash type
EXPECT_EQ(1, CheckedInt({"hset", "orig-hash", "123", "45678"}));
resp = Run({"dump", "orig-hash"});
dump = resp.GetBuf();
resp = Run({"restore", "new-hash", "1", ToSV(dump)});
EXPECT_EQ(resp, "OK");
EXPECT_EQ(1, CheckedInt({"hexists", "new-hash", "123"}));
// test with replace and no TTL
resp = Run({"set", "string-key", "hello world"});
EXPECT_EQ(resp, "OK");
resp = Run({"dump", "string-key"});
dump = resp.GetBuf();
// this will change the value from "hello world" to "1234"
resp = Run({"restore", "string-key", "7", ToSV(STRING_DUMP_REDIS), "REPLACE"});
resp = Run({"get", "string-key"});
EXPECT_EQ("1234", resp);
// check TTL validity
EXPECT_EQ(CheckedInt({"ttl", "string-key"}), 7);
// Make check about ttl with abs time, restoring back to "hello world"
resp = Run({"restore", "string-key", absl::StrCat(TEST_current_time_ms + 2000), ToSV(dump),
"ABSTTL", "REPLACE"});
resp = Run({"get", "string-key"});
EXPECT_EQ("hello world", resp);
EXPECT_EQ(CheckedInt({"pttl", "string-key"}), 2000);
// Last but not least - just make sure that we are good without TTL as well
resp = Run({"restore", "string-key", "0", ToSV(STRING_DUMP_REDIS), "REPLACE"});
resp = Run({"get", "string-key"});
EXPECT_EQ("1234", resp);
EXPECT_EQ(CheckedInt({"ttl", "string-key"}), -1);
}
} // namespace dfly

File diff suppressed because it is too large Load diff

View file

@ -19,48 +19,23 @@ namespace dfly {
class EngineShardSet;
class ScriptMgr;
class CompactObj;
class RdbLoader {
public:
explicit RdbLoader(ScriptMgr* script_mgr);
class RdbLoaderBase {
protected:
RdbLoaderBase();
~RdbLoader();
std::error_code Load(::io::Source* src);
void set_source_limit(size_t n) {
source_limit_ = n;
}
::io::Bytes Leftover() const {
return mem_buf_.InputBuffer();
}
size_t bytes_read() const {
return bytes_read_;
}
size_t keys_loaded() const {
return keys_loaded_;
}
// returns time in seconds.
double load_time() const {
return load_time_;
}
private:
struct LoadTrace;
using MutableBytes = ::io::MutableBytes;
struct ObjSettings;
struct LzfString {
base::PODArray<uint8_t> compressed_blob;
uint64_t uncompressed_len;
};
struct LoadTrace;
using RdbVariant =
std::variant<long long, base::PODArray<char>, LzfString, std::unique_ptr<LoadTrace>>;
struct OpaqueObj {
RdbVariant obj;
int rdb_type;
@ -116,15 +91,14 @@ class RdbLoader {
};
using ItemsBuf = std::vector<Item>;
void ResizeDb(size_t key_num, size_t expire_num);
std::error_code HandleAux();
::io::Result<uint8_t> FetchType() {
return FetchInt<uint8_t>();
}
template <typename T> io::Result<T> FetchInt();
std::error_code Visit(const Item& item, CompactObj* pv);
io::Result<uint64_t> LoadLen(bool* is_encoded);
std::error_code FetchBuf(size_t size, void* dest);
@ -151,6 +125,8 @@ class RdbLoader {
::io::Result<OpaqueObj> ReadListQuicklist(int rdbtype);
::io::Result<OpaqueObj> ReadStreams();
static size_t StrLen(const RdbVariant& tset);
std::error_code EnsureRead(size_t min_sz) {
if (mem_buf_.InputLen() >= min_sz)
return std::error_code{};
@ -159,21 +135,57 @@ class RdbLoader {
}
std::error_code EnsureReadInternal(size_t min_sz);
protected:
base::IoBuf mem_buf_;
::io::Source* src_ = nullptr;
size_t bytes_read_ = 0;
size_t source_limit_ = SIZE_MAX;
base::PODArray<uint8_t> compr_buf_;
};
class RdbLoader : protected RdbLoaderBase {
public:
explicit RdbLoader(ScriptMgr* script_mgr);
~RdbLoader();
std::error_code Load(::io::Source* src);
void set_source_limit(size_t n) {
source_limit_ = n;
}
::io::Bytes Leftover() const {
return mem_buf_.InputBuffer();
}
size_t bytes_read() const {
return bytes_read_;
}
size_t keys_loaded() const {
return keys_loaded_;
}
// returns time in seconds.
double load_time() const {
return load_time_;
}
private:
struct ObjSettings;
std::error_code LoadKeyValPair(int type, ObjSettings* settings);
void ResizeDb(size_t key_num, size_t expire_num);
std::error_code HandleAux();
std::error_code VerifyChecksum();
void FlushShardAsync(ShardId sid);
void LoadItemsBuffer(DbIndex db_ind, const ItemsBuf& ib);
static size_t StrLen(const RdbVariant& tset);
ScriptMgr* script_mgr_;
base::IoBuf mem_buf_;
base::PODArray<uint8_t> compr_buf_;
std::unique_ptr<ItemsBuf[]> shard_buf_;
::io::Source* src_ = nullptr;
size_t bytes_read_ = 0;
size_t source_limit_ = SIZE_MAX;
size_t keys_loaded_ = 0;
double load_time_ = 0;