diff --git a/src/facade/redis_parser.cc b/src/facade/redis_parser.cc index b5e6c2a6b..beaa4ac0d 100644 --- a/src/facade/redis_parser.cc +++ b/src/facade/redis_parser.cc @@ -21,55 +21,56 @@ auto RedisParser::Parse(Buffer str, uint32_t* consumed, RespExpr::Vec* res) -> R } if (state_ == CMD_COMPLETE_S) { - state_ = INIT_S; - } - - if (state_ == INIT_S) { InitStart(str[0], res); + } else { + // We continue parsing in the middle. + if (!cached_expr_) + cached_expr_ = res; } + DCHECK(state_ != CMD_COMPLETE_S); - if (!cached_expr_) - cached_expr_ = res; + ResultConsumed resultc{OK, 0}; + + do { + if (str.empty()) { + resultc.first = INPUT_PENDING; + break; + } - while (state_ != CMD_COMPLETE_S) { - last_consumed_ = 0; switch (state_) { case MAP_LEN_S: case ARRAY_LEN_S: - last_result_ = ConsumeArrayLen(str); + resultc = ConsumeArrayLen(str); break; case PARSE_ARG_S: if (str.size() == 0 || (str.size() < 4 && str[0] != '_')) { - last_result_ = INPUT_PENDING; + resultc.first = INPUT_PENDING; } else { - last_result_ = ParseArg(str); + resultc = ParseArg(str); } break; case INLINE_S: DCHECK(parse_stack_.empty()); - last_result_ = ParseInline(str); + resultc = ParseInline(str); break; case BULK_STR_S: - last_result_ = ConsumeBulk(str); - break; - case FINISH_ARG_S: - HandleFinishArg(); + resultc = ConsumeBulk(str); break; default: LOG(FATAL) << "Unexpected state " << int(state_); } - *consumed += last_consumed_; + *consumed += resultc.second; - if (last_result_ != OK) { + if (resultc.first != OK) { break; } - str.remove_prefix(last_consumed_); - } + str.remove_prefix(exchange(resultc.second, 0)); + } while (state_ != CMD_COMPLETE_S); - if (last_result_ == INPUT_PENDING) { + if (resultc.first == INPUT_PENDING) { StashState(res); - } else if (last_result_ == OK) { + } else if (resultc.first == OK) { DCHECK(cached_expr_); if (res != cached_expr_) { DCHECK(!stash_.empty()); @@ -78,7 +79,7 @@ auto RedisParser::Parse(Buffer str, uint32_t* consumed, RespExpr::Vec* res) -> R } } - return last_result_; + return resultc.first; } void RedisParser::InitStart(uint8_t prefix_b, RespExpr::Vec* res) { @@ -150,17 +151,20 @@ void RedisParser::StashState(RespExpr::Vec* res) { } } -auto RedisParser::ParseInline(Buffer str) -> Result { +auto RedisParser::ParseInline(Buffer str) -> ResultConsumed { DCHECK(!str.empty()); uint8_t* ptr = str.begin(); uint8_t* end = str.end(); uint8_t* token_start = ptr; - if (is_broken_token_) { + auto find_token_end = [&] { while (ptr != end && *ptr > 32) ++ptr; + }; + if (is_broken_token_) { + find_token_end(); size_t len = ptr - token_start; ExtendLastString(Buffer(token_start, len)); @@ -182,80 +186,69 @@ auto RedisParser::ParseInline(Buffer str) -> Result { DCHECK(!is_broken_token_); token_start = ptr; - while (ptr != end && *ptr > 32) - ++ptr; + find_token_end(); cached_expr_->emplace_back(RespExpr::STRING); cached_expr_->back().u = Buffer{token_start, size_t(ptr - token_start)}; } - last_consumed_ = ptr - str.data(); + uint32_t last_consumed = ptr - str.data(); if (ptr == end) { // we have not finished parsing. if (ptr[-1] > 32) { // we stopped in the middle of the token. is_broken_token_ = true; } - return INPUT_PENDING; - } else { - ++last_consumed_; // consume the delimiter as well. + return {INPUT_PENDING, last_consumed}; } + + ++last_consumed; // consume the delimiter as well. state_ = CMD_COMPLETE_S; - return OK; + return {OK, last_consumed}; } -auto RedisParser::ParseNum(Buffer str, int64_t* res) -> Result { - if (str.size() < 4) { - return INPUT_PENDING; - } - DCHECK(str[0] == '$' || str[0] == '*' || str[0] == '%' || str[0] == '~'); +// Parse lines like:'$5\r\n' or '*2\r\n' +auto RedisParser::ParseLen(Buffer str, int64_t* res) -> ResultConsumed { + DCHECK(!str.empty()); char* s = reinterpret_cast(str.data() + 1); char* pos = reinterpret_cast(memchr(s, '\n', str.size() - 1)); if (!pos) { - return str.size() < 32 ? INPUT_PENDING : BAD_INT; + Result r = str.size() < 32 ? INPUT_PENDING : BAD_ARRAYLEN; + return {r, 0}; } + if (pos[-1] != '\r') { - return BAD_INT; + return {BAD_ARRAYLEN, 0}; } bool success = absl::SimpleAtoi(std::string_view{s, size_t(pos - s - 1)}, res); - if (!success) { - return BAD_INT; - } - last_consumed_ = (pos - s) + 2; - - return OK; + return ResultConsumed{success ? OK : BAD_ARRAYLEN, (pos - s) + 2}; } -auto RedisParser::ConsumeArrayLen(Buffer str) -> Result { +auto RedisParser::ConsumeArrayLen(Buffer str) -> ResultConsumed { int64_t len; - Result res = ParseNum(str, &len); + ResultConsumed res = ParseLen(str, &len); + if (res.first != OK) { + return res; + } + if (state_ == MAP_LEN_S) { // Map starts with %N followed by an array of 2*N elements. // Even elements are keys, odd elements are values. len *= 2; } - switch (res) { - case INPUT_PENDING: - return INPUT_PENDING; - case BAD_INT: - return BAD_ARRAYLEN; - case OK: - if (len < -1 || len > max_arr_len_) { - LOG_IF(WARNING, len > max_arr_len_) << "Multibulk len is too large " << len; - return BAD_ARRAYLEN; - } - break; - default: - LOG(ERROR) << "Unexpected result " << res; + if (len < -1 || len > max_arr_len_) { + LOG_IF(WARNING, len > max_arr_len_) << "Multibulk len is too large " << len; + + return {BAD_ARRAYLEN, res.second}; } if (server_mode_ && (parse_stack_.size() > 0 || !cached_expr_->empty())) - return BAD_STRING; + return {BAD_STRING, res.second}; if (len <= 0) { cached_expr_->emplace_back(len == -1 ? RespExpr::NIL_ARRAY : RespExpr::ARRAY); @@ -265,9 +258,13 @@ auto RedisParser::ConsumeArrayLen(Buffer str) -> Result { static RespVec empty_vec; cached_expr_->back().u = &empty_vec; } - state_ = (parse_stack_.empty()) ? CMD_COMPLETE_S : FINISH_ARG_S; + if (parse_stack_.empty()) { + state_ = CMD_COMPLETE_S; + } else { + HandleFinishArg(); + } - return OK; + return {OK, res.second}; } if (state_ == PARSE_ARG_S) { @@ -286,54 +283,49 @@ auto RedisParser::ConsumeArrayLen(Buffer str) -> Result { DVLOG(1) << "PushStack: (" << len << ", " << cached_expr_ << ")"; parse_stack_.emplace_back(len, cached_expr_); - return OK; + return {OK, res.second}; } -auto RedisParser::ParseArg(Buffer str) -> Result { +auto RedisParser::ParseArg(Buffer str) -> ResultConsumed { char c = str[0]; if (c == '$') { int64_t len; - Result res = ParseNum(str, &len); - switch (res) { - case INPUT_PENDING: - return INPUT_PENDING; - case BAD_INT: - return BAD_ARRAYLEN; - case OK: - if (len < -1 || len > kMaxBulkLen) - return BAD_ARRAYLEN; - break; - default: - LOG(ERROR) << "Unexpected result " << res; + ResultConsumed res = ParseLen(str, &len); + if (res.first != OK) { + return res; } - if (len < 0) { // Resp2 NIL - state_ = FINISH_ARG_S; + if (len < -1 || len > kMaxBulkLen) + return {BAD_ARRAYLEN, res.second}; + + if (len == -1) { // Resp2 NIL cached_expr_->emplace_back(RespExpr::NIL); + cached_expr_->back().u = Buffer{}; + HandleFinishArg(); } else { DVLOG(1) << "String(" << len << ")"; cached_expr_->emplace_back(RespExpr::STRING); + cached_expr_->back().u = Buffer{}; bulk_len_ = len; state_ = BULK_STR_S; } - cached_expr_->back().u = Buffer{}; - return OK; + return {OK, res.second}; } if (server_mode_) { - return BAD_BULKLEN; + return {BAD_BULKLEN, 0}; } if (c == '_') { // Resp3 NIL // TODO: Do we need to validate that str[1:2] == "\r\n"? - state_ = FINISH_ARG_S; + cached_expr_->emplace_back(RespExpr::NIL); cached_expr_->back().u = Buffer{}; - last_consumed_ += 3; // '_','\r','\n' - return OK; + HandleFinishArg(); + return {OK, 3}; // // '_','\r','\n' } if (c == '*') { @@ -346,54 +338,60 @@ auto RedisParser::ParseArg(Buffer str) -> Result { if (c == '+' || c == '-') { // Simple string or error. DCHECK(!server_mode_); if (!eol) { - return str.size() < 256 ? INPUT_PENDING : BAD_STRING; + Result r = str.size() < 256 ? INPUT_PENDING : BAD_STRING; + return {r, 0}; } + if (eol[-1] != '\r') - return BAD_STRING; + return {BAD_STRING, 0}; cached_expr_->emplace_back(c == '+' ? RespExpr::STRING : RespExpr::ERROR); cached_expr_->back().u = Buffer{reinterpret_cast(s), size_t((eol - 1) - s)}; } else if (c == ':') { DCHECK(!server_mode_); if (!eol) { - return str.size() < 32 ? INPUT_PENDING : BAD_INT; + Result r = str.size() < 32 ? INPUT_PENDING : BAD_INT; + return {r, 0}; } int64_t ival; std::string_view tok{s, size_t((eol - s) - 1)}; if (eol[-1] != '\r' || !absl::SimpleAtoi(tok, &ival)) - return BAD_INT; + return {BAD_INT, 0}; cached_expr_->emplace_back(RespExpr::INT64); cached_expr_->back().u = ival; } else if (c == ',') { DCHECK(!server_mode_); if (!eol) { - return str.size() < 32 ? INPUT_PENDING : BAD_DOUBLE; + Result r = str.size() < 32 ? INPUT_PENDING : BAD_DOUBLE; + return {r, 0}; } double_t dval; std::string_view tok{s, size_t((eol - s) - 1)}; if (eol[-1] != '\r' || !absl::SimpleAtod(tok, &dval)) - return BAD_INT; + return {BAD_INT, 0}; cached_expr_->emplace_back(RespExpr::DOUBLE); cached_expr_->back().u = dval; } else { - return BAD_STRING; + return {BAD_STRING, 0}; } - last_consumed_ = (eol - s) + 2; - state_ = FINISH_ARG_S; - return OK; + HandleFinishArg(); + + return {OK, (eol - s) + 2}; } -auto RedisParser::ConsumeBulk(Buffer str) -> Result { +auto RedisParser::ConsumeBulk(Buffer str) -> ResultConsumed { auto& bulk_str = get(cached_expr_->back().u); + uint32_t consumed = 0; + if (str.size() >= bulk_len_ + 2) { if (str[bulk_len_] != '\r' || str[bulk_len_ + 1] != '\n') { - return BAD_STRING; + return {BAD_STRING, 0}; } if (bulk_len_) { @@ -405,11 +403,11 @@ auto RedisParser::ConsumeBulk(Buffer str) -> Result { } } is_broken_token_ = false; - state_ = FINISH_ARG_S; - last_consumed_ = bulk_len_ + 2; + consumed = bulk_len_ + 2; bulk_len_ = 0; + HandleFinishArg(); - return OK; + return {OK, consumed}; } if (str.size() >= 32) { @@ -429,11 +427,11 @@ auto RedisParser::ConsumeBulk(Buffer str) -> Result { is_broken_token_ = true; cached_expr_->back().has_support = true; } - last_consumed_ = len; + consumed = len; bulk_len_ -= len; } - return INPUT_PENDING; + return {INPUT_PENDING, consumed}; } void RedisParser::HandleFinishArg() { diff --git a/src/facade/redis_parser.h b/src/facade/redis_parser.h index 9f899c672..41c3838ef 100644 --- a/src/facade/redis_parser.h +++ b/src/facade/redis_parser.h @@ -24,23 +24,31 @@ class RedisParser { public: constexpr static long kMaxBulkLen = 256 * (1ul << 20); // 256MB. - enum Result { OK, INPUT_PENDING, BAD_ARRAYLEN, BAD_BULKLEN, BAD_STRING, BAD_INT, BAD_DOUBLE }; + enum Result : uint8_t { + OK, + INPUT_PENDING, + BAD_ARRAYLEN, + BAD_BULKLEN, + BAD_STRING, + BAD_INT, + BAD_DOUBLE + }; using Buffer = RespExpr::Buffer; explicit RedisParser(uint32_t max_arr_len = UINT32_MAX, bool server_mode = true) - : max_arr_len_(max_arr_len), server_mode_(server_mode) { + : server_mode_(server_mode), max_arr_len_(max_arr_len) { } /** * @brief Parses str into res. "consumed" stores number of bytes consumed from str. * * A caller should not invalidate str if the parser returns RESP_OK as long as he continues - * accessing res. However, if parser returns MORE_INPUT a caller may discard consumed + * accessing res. However, if parser returns INPUT_PENDING a caller may discard consumed * part of str because parser caches the intermediate state internally according to 'consumed' * result. * * Note: A parser does not always guarantee progress, i.e. if a small buffer was passed it may - * returns MORE_INPUT with consumed == 0. + * returns INPUT_PENDING with consumed == 0. * */ @@ -64,49 +72,49 @@ class RedisParser { size_t UsedMemory() const; private: + using ResultConsumed = std::pair; + void InitStart(uint8_t prefix_b, RespVec* res); void StashState(RespVec* res); // Skips the first character (*). - Result ConsumeArrayLen(Buffer str); - Result ParseArg(Buffer str); - Result ConsumeBulk(Buffer str); - Result ParseInline(Buffer str); + ResultConsumed ConsumeArrayLen(Buffer str); + ResultConsumed ParseArg(Buffer str); + ResultConsumed ConsumeBulk(Buffer str); + ResultConsumed ParseInline(Buffer str); + + ResultConsumed ParseLen(Buffer str, int64_t* res); - // Updates last_consumed_ - Result ParseNum(Buffer str, int64_t* res); void HandleFinishArg(); void ExtendLastString(Buffer str); enum State : uint8_t { - INIT_S = 0, INLINE_S, ARRAY_LEN_S, MAP_LEN_S, PARSE_ARG_S, // Parse [$:+-]string\r\n BULK_STR_S, - FINISH_ARG_S, CMD_COMPLETE_S, }; - State state_ = INIT_S; - Result last_result_ = OK; + State state_ = CMD_COMPLETE_S; + bool is_broken_token_ = false; // whether the last inline string was broken in the middle. + bool server_mode_ = true; - uint32_t last_consumed_ = 0; uint32_t bulk_len_ = 0; uint32_t last_stashed_level_ = 0, last_stashed_index_ = 0; + uint32_t max_arr_len_; + + // Points either to the result passed by the caller or to the stash. + RespVec* cached_expr_ = nullptr; // expected expression length, pointer to expression vector. + // For server mode, the length is at most 1. absl::InlinedVector, 4> parse_stack_; std::vector> stash_; using Blob = std::vector; std::vector buf_stash_; - RespVec* cached_expr_ = nullptr; - uint32_t max_arr_len_; - - bool is_broken_token_ = false; - bool server_mode_ = true; }; } // namespace facade