1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

chore: get rid of ToUpper call and use AsciiStrToUpper (#3944)

Also remove std:: in bitops family to reduce noise.
No functional changes.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2024-10-18 12:47:40 +03:00 committed by GitHub
parent 5ab32b97d9
commit a7c9fde38e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 103 additions and 126 deletions

View file

@ -28,14 +28,14 @@ using namespace std;
namespace { namespace {
using ShardStringResults = std::vector<OpResult<std::string>>; using ShardStringResults = vector<OpResult<string>>;
const int32_t OFFSET_FACTOR = 8; // number of bits in byte const int32_t OFFSET_FACTOR = 8; // number of bits in byte
const char* OR_OP_NAME = "OR"; const char* OR_OP_NAME = "OR";
const char* XOR_OP_NAME = "XOR"; const char* XOR_OP_NAME = "XOR";
const char* AND_OP_NAME = "AND"; const char* AND_OP_NAME = "AND";
const char* NOT_OP_NAME = "NOT"; const char* NOT_OP_NAME = "NOT";
using BitsStrVec = std::vector<std::string>; using BitsStrVec = vector<string>;
// The following is the list of the functions that would handle the // The following is the list of the functions that would handle the
// commands that handle the bit operations // commands that handle the bit operations
@ -47,48 +47,30 @@ void BitOp(CmdArgList args, ConnectionContext* cntx);
void GetBit(CmdArgList args, ConnectionContext* cntx); void GetBit(CmdArgList args, ConnectionContext* cntx);
void SetBit(CmdArgList args, ConnectionContext* cntx); void SetBit(CmdArgList args, ConnectionContext* cntx);
OpResult<std::string> ReadValue(const DbContext& context, std::string_view key, EngineShard* shard); OpResult<string> ReadValue(const DbContext& context, string_view key, EngineShard* shard);
OpResult<bool> ReadValueBitsetAt(const OpArgs& op_args, std::string_view key, uint32_t offset); OpResult<bool> ReadValueBitsetAt(const OpArgs& op_args, string_view key, uint32_t offset);
OpResult<std::size_t> CountBitsForValue(const OpArgs& op_args, std::string_view key, int64_t start, OpResult<std::size_t> CountBitsForValue(const OpArgs& op_args, string_view key, int64_t start,
int64_t end, bool bit_value); int64_t end, bool bit_value);
OpResult<int64_t> FindFirstBitWithValue(const OpArgs& op_args, std::string_view key, bool value, OpResult<int64_t> FindFirstBitWithValue(const OpArgs& op_args, string_view key, bool value,
int64_t start, int64_t end, bool as_bit); int64_t start, int64_t end, bool as_bit);
std::string GetString(const PrimeValue& pv); string GetString(const PrimeValue& pv);
bool SetBitValue(uint32_t offset, bool bit_value, std::string* entry); bool SetBitValue(uint32_t offset, bool bit_value, string* entry);
std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std::size_t end); std::size_t CountBitSetByByteIndices(string_view at, std::size_t start, std::size_t end);
std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool bits); std::size_t CountBitSet(string_view str, int64_t start, int64_t end, bool bits);
std::size_t CountBitSetByBitIndices(std::string_view at, std::size_t start, std::size_t end); std::size_t CountBitSetByBitIndices(string_view at, std::size_t start, std::size_t end);
std::string RunBitOperationOnValues(std::string_view op, const BitsStrVec& values); string RunBitOperationOnValues(string_view op, const BitsStrVec& values);
// ------------------------------------------------------------------------- // // ------------------------------------------------------------------------- //
// Converts `args[i] to uppercase, then sets `*as_bit` to true if `args[i]` equals "BIT", false if
// `args[i]` equals "BYTE", or returns false if `args[i]` has some other invalid value.
bool ToUpperAndGetAsBit(CmdArgList args, size_t i, bool* as_bit) {
CHECK_NOTNULL(as_bit);
ToUpper(&args[i]);
std::string_view arg = ArgS(args, i);
if (arg == "BIT") {
*as_bit = true;
return true;
} else if (arg == "BYTE") {
*as_bit = false;
return true;
} else {
return false;
}
}
// This function can be used for any case where we allowing out of bound // This function can be used for any case where we allowing out of bound
// access where the default in this case would be 0 -such as bitop // access where the default in this case would be 0 -such as bitop
uint8_t GetByteAt(std::string_view s, std::size_t at) { uint8_t GetByteAt(string_view s, std::size_t at) {
return at >= s.size() ? 0 : s[at]; return at >= s.size() ? 0 : s[at];
} }
// For XOR, OR, AND operations on a collection of bytes // For XOR, OR, AND operations on a collection of bytes
template <typename BitOp, typename SkipOp> template <typename BitOp, typename SkipOp>
std::string BitOpString(BitOp operation_f, SkipOp skip_f, const BitsStrVec& values, string BitOpString(BitOp operation_f, SkipOp skip_f, const BitsStrVec& values, string new_value) {
std::string new_value) {
// at this point, values are not empty // at this point, values are not empty
std::size_t max_size = new_value.size(); std::size_t max_size = new_value.size();
@ -137,7 +119,7 @@ constexpr uint8_t XorOp(uint8_t left, uint8_t right) {
return left ^ right; return left ^ right;
} }
std::string BitOpNotString(std::string from) { string BitOpNotString(string from) {
std::transform(from.begin(), from.end(), from.begin(), [](auto c) { return ~c; }); std::transform(from.begin(), from.end(), from.begin(), [](auto c) { return ~c; });
return from; return from;
} }
@ -155,7 +137,7 @@ constexpr int32_t GetByteIndex(uint32_t offset) noexcept {
return offset / OFFSET_FACTOR; return offset / OFFSET_FACTOR;
} }
uint8_t GetByteValue(std::string_view str, uint32_t offset) { uint8_t GetByteValue(string_view str, uint32_t offset) {
return static_cast<uint8_t>(str[GetByteIndex(offset)]); return static_cast<uint8_t>(str[GetByteIndex(offset)]);
} }
@ -173,7 +155,7 @@ constexpr std::uint8_t CountBitsRange(std::uint8_t byte, std::uint8_t from, uint
// Count the number of bits that are on, on bytes boundaries: i.e. Start and end are the indices for // Count the number of bits that are on, on bytes boundaries: i.e. Start and end are the indices for
// bytes locations inside str CountBitSetByByteIndices // bytes locations inside str CountBitSetByByteIndices
std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std::size_t end) { std::size_t CountBitSetByByteIndices(string_view at, std::size_t start, std::size_t end) {
if (start >= end) { if (start >= end) {
return 0; return 0;
} }
@ -186,7 +168,7 @@ std::size_t CountBitSetByByteIndices(std::string_view at, std::size_t start, std
// Count the number of bits that are on, on bits boundaries: i.e. Start and end are the indices for // Count the number of bits that are on, on bits boundaries: i.e. Start and end are the indices for
// bits locations inside str // bits locations inside str
std::size_t CountBitSetByBitIndices(std::string_view at, std::size_t start, std::size_t end) { std::size_t CountBitSetByBitIndices(string_view at, std::size_t start, std::size_t end) {
auto first_byte_index = GetByteIndex(start); auto first_byte_index = GetByteIndex(start);
auto last_byte_index = GetByteIndex(end); auto last_byte_index = GetByteIndex(end);
if (start % OFFSET_FACTOR == 0 && end % OFFSET_FACTOR == 0) { if (start % OFFSET_FACTOR == 0 && end % OFFSET_FACTOR == 0) {
@ -219,7 +201,7 @@ int64_t NormalizedOffset(int64_t size, int64_t offset) {
// The parameters for start, end and bits are defaulted to the start of the string, // The parameters for start, end and bits are defaulted to the start of the string,
// end of the string and bits are false. // end of the string and bits are false.
// Note that when bits is false, it means that we are looking on byte boundaries. // Note that when bits is false, it means that we are looking on byte boundaries.
std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool bits) { std::size_t CountBitSet(string_view str, int64_t start, int64_t end, bool bits) {
const int64_t strlen = bits ? str.size() * OFFSET_FACTOR : str.size(); const int64_t strlen = bits ? str.size() * OFFSET_FACTOR : str.size();
if (start < 0) if (start < 0)
@ -241,13 +223,13 @@ std::size_t CountBitSet(std::string_view str, int64_t start, int64_t end, bool b
} }
// return true if bit is on // return true if bit is on
bool GetBitValue(const std::string& entry, uint32_t offset) { bool GetBitValue(const string& entry, uint32_t offset) {
const auto byte_val{GetByteValue(entry, offset)}; const auto byte_val{GetByteValue(entry, offset)};
const auto index{GetNormalizedBitIndex(offset)}; const auto index{GetNormalizedBitIndex(offset)};
return CheckBitStatus(byte_val, index); return CheckBitStatus(byte_val, index);
} }
bool GetBitValueSafe(const std::string& entry, uint32_t offset) { bool GetBitValueSafe(const string& entry, uint32_t offset) {
return ((entry.size() * OFFSET_FACTOR) > offset) ? GetBitValue(entry, offset) : false; return ((entry.size() * OFFSET_FACTOR) > offset) ? GetBitValue(entry, offset) : false;
} }
@ -259,7 +241,7 @@ constexpr uint8_t TurnBitOff(uint8_t on, uint32_t offset) {
return on &= ~(1 << offset); return on &= ~(1 << offset);
} }
bool SetBitValue(uint32_t offset, bool bit_value, std::string* entry) { bool SetBitValue(uint32_t offset, bool bit_value, string* entry) {
// we need to return the old value after setting the value for offset // we need to return the old value after setting the value for offset
const auto old_value{GetBitValue(*entry, offset)}; // save this as the return value const auto old_value{GetBitValue(*entry, offset)}; // save this as the return value
auto byte{GetByteValue(*entry, offset)}; auto byte{GetByteValue(*entry, offset)};
@ -274,7 +256,7 @@ bool SetBitValue(uint32_t offset, bool bit_value, std::string* entry) {
class ElementAccess { class ElementAccess {
bool added_ = false; bool added_ = false;
DbSlice::Iterator element_iter_; DbSlice::Iterator element_iter_;
std::string_view key_; string_view key_;
DbContext context_; DbContext context_;
EngineShard* shard_ = nullptr; EngineShard* shard_ = nullptr;
mutable DbSlice::AutoUpdater post_updater_; mutable DbSlice::AutoUpdater post_updater_;
@ -282,7 +264,7 @@ class ElementAccess {
void SetFields(EngineShard* shard, DbSlice::AddOrFindResult res); void SetFields(EngineShard* shard, DbSlice::AddOrFindResult res);
public: public:
ElementAccess(std::string_view key, const OpArgs& args) : key_{key}, context_{args.db_cntx} { ElementAccess(string_view key, const OpArgs& args) : key_{key}, context_{args.db_cntx} {
} }
OpStatus Find(EngineShard* shard); OpStatus Find(EngineShard* shard);
@ -299,9 +281,9 @@ class ElementAccess {
return context_.db_index; return context_.db_index;
} }
std::string Value() const; string Value() const;
void Commit(std::string_view new_value) const; void Commit(string_view new_value) const;
// return nullopt when key exists but it's not encoded as string // return nullopt when key exists but it's not encoded as string
// return true if key exists and false if it doesn't // return true if key exists and false if it doesn't
@ -345,16 +327,16 @@ OpStatus ElementAccess::FindAllowWrongType(EngineShard* shard) {
return OpStatus::OK; return OpStatus::OK;
} }
std::string ElementAccess::Value() const { string ElementAccess::Value() const {
CHECK_NOTNULL(shard_); CHECK_NOTNULL(shard_);
if (!added_) { // Exist entry - return it if (!added_) { // Exist entry - return it
return GetString(element_iter_->second); return GetString(element_iter_->second);
} else { // we only have reference to the new entry but no value } else { // we only have reference to the new entry but no value
return std::string{}; return string{};
} }
} }
void ElementAccess::Commit(std::string_view new_value) const { void ElementAccess::Commit(string_view new_value) const {
if (shard_) { if (shard_) {
if (new_value.empty()) { if (new_value.empty()) {
if (!IsNewEntry()) { if (!IsNewEntry()) {
@ -374,8 +356,7 @@ void ElementAccess::Commit(std::string_view new_value) const {
// ============================================= // =============================================
// Set a new value to a given bit // Set a new value to a given bit
OpResult<bool> BitNewValue(const OpArgs& args, std::string_view key, uint32_t offset, OpResult<bool> BitNewValue(const OpArgs& args, string_view key, uint32_t offset, bool bit_value) {
bool bit_value) {
EngineShard* shard = args.shard; EngineShard* shard = args.shard;
ElementAccess element_access{key, args}; ElementAccess element_access{key, args};
auto& db_slice = args.GetDbSlice(); auto& db_slice = args.GetDbSlice();
@ -389,12 +370,12 @@ OpResult<bool> BitNewValue(const OpArgs& args, std::string_view key, uint32_t of
} }
if (element_access.IsNewEntry()) { if (element_access.IsNewEntry()) {
std::string new_entry(GetByteIndex(offset) + 1, 0); string new_entry(GetByteIndex(offset) + 1, 0);
old_value = SetBitValue(offset, bit_value, &new_entry); old_value = SetBitValue(offset, bit_value, &new_entry);
element_access.Commit(new_entry); element_access.Commit(new_entry);
} else { } else {
bool reset = false; bool reset = false;
std::string existing_entry{element_access.Value()}; string existing_entry{element_access.Value()};
if ((existing_entry.size() * OFFSET_FACTOR) <= offset) { if ((existing_entry.size() * OFFSET_FACTOR) <= offset) {
existing_entry.resize(GetByteIndex(offset) + 1, 0); existing_entry.resize(GetByteIndex(offset) + 1, 0);
reset = true; reset = true;
@ -409,7 +390,7 @@ OpResult<bool> BitNewValue(const OpArgs& args, std::string_view key, uint32_t of
// --------------------------------------------------------- // ---------------------------------------------------------
std::string RunBitOperationOnValues(std::string_view op, const BitsStrVec& values) { string RunBitOperationOnValues(string_view op, const BitsStrVec& values) {
// This function accept an operation (either OR, XOR, NOT or OR), and run bit operation // This function accept an operation (either OR, XOR, NOT or OR), and run bit operation
// on all the values we got from the database. Note that in case that one of the values // on all the values we got from the database. Note that in case that one of the values
// is shorter than the other it would return a 0 and the operation would continue // is shorter than the other it would return a 0 and the operation would continue
@ -419,22 +400,22 @@ std::string RunBitOperationOnValues(std::string_view op, const BitsStrVec& value
const auto BitOperation = [&]() { const auto BitOperation = [&]() {
if (op == OR_OP_NAME) { if (op == OR_OP_NAME) {
std::string default_str{values[max_len_index]}; string default_str{values[max_len_index]};
return BitOpString(OrOp, SkipOr, std::move(values), std::move(default_str)); return BitOpString(OrOp, SkipOr, std::move(values), std::move(default_str));
} else if (op == XOR_OP_NAME) { } else if (op == XOR_OP_NAME) {
return BitOpString(XorOp, SkipXor, std::move(values), std::string(max_len, 0)); return BitOpString(XorOp, SkipXor, std::move(values), string(max_len, 0));
} else if (op == AND_OP_NAME) { } else if (op == AND_OP_NAME) {
return BitOpString(AndOp, SkipAnd, std::move(values), std::string(max_len, 0)); return BitOpString(AndOp, SkipAnd, std::move(values), string(max_len, 0));
} else if (op == NOT_OP_NAME) { } else if (op == NOT_OP_NAME) {
return BitOpNotString(values[0]); return BitOpNotString(values[0]);
} else { } else {
LOG(FATAL) << "Operation not supported '" << op << "'"; LOG(FATAL) << "Operation not supported '" << op << "'";
return std::string{}; // otherwise we will have warning of not returning value return string{}; // otherwise we will have warning of not returning value
} }
}; };
if (values.empty()) { // this is ok in case we don't have the src keys if (values.empty()) { // this is ok in case we don't have the src keys
return std::string{}; return string{};
} }
// The new result is the max length input // The new result is the max length input
max_len = values[0].size(); max_len = values[0].size();
@ -447,7 +428,7 @@ std::string RunBitOperationOnValues(std::string_view op, const BitsStrVec& value
return BitOperation(); return BitOperation();
} }
OpResult<std::string> CombineResultOp(ShardStringResults result, std::string_view op) { OpResult<string> CombineResultOp(ShardStringResults result, string_view op) {
// take valid result for each shard // take valid result for each shard
BitsStrVec values; BitsStrVec values;
for (auto&& res : result) { for (auto&& res : result) {
@ -467,7 +448,7 @@ OpResult<std::string> CombineResultOp(ShardStringResults result, std::string_vie
} }
// For bitop not - we cannot accumulate // For bitop not - we cannot accumulate
OpResult<std::string> RunBitOpNot(const OpArgs& op_args, string_view key) { OpResult<string> RunBitOpNot(const OpArgs& op_args, string_view key) {
// if we found the value, just return, if not found then skip, otherwise report an error // if we found the value, just return, if not found then skip, otherwise report an error
DbSlice& db_slice = op_args.GetDbSlice(); DbSlice& db_slice = op_args.GetDbSlice();
auto find_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING); auto find_res = db_slice.FindReadOnly(op_args.db_cntx, key, OBJ_STRING);
@ -480,8 +461,8 @@ OpResult<std::string> RunBitOpNot(const OpArgs& op_args, string_view key) {
// Read only operation where we are running the bit operation on all the // Read only operation where we are running the bit operation on all the
// values that belong to same shard. // values that belong to same shard.
OpResult<std::string> RunBitOpOnShard(std::string_view op, const OpArgs& op_args, OpResult<string> RunBitOpOnShard(string_view op, const OpArgs& op_args, ShardArgs::Iterator start,
ShardArgs::Iterator start, ShardArgs::Iterator end) { ShardArgs::Iterator end) {
DCHECK(start != end); DCHECK(start != end);
if (op == NOT_OP_NAME) { if (op == NOT_OP_NAME) {
return RunBitOpNot(op_args, *start); return RunBitOpNot(op_args, *start);
@ -504,7 +485,7 @@ OpResult<std::string> RunBitOpOnShard(std::string_view op, const OpArgs& op_args
} }
} }
// Run the operation on all the values that we found // Run the operation on all the values that we found
std::string op_result = RunBitOperationOnValues(op, values); string op_result = RunBitOperationOnValues(op, values);
return op_result; return op_result;
} }
@ -539,7 +520,7 @@ void BitPos(CmdArgList args, ConnectionContext* cntx) {
return cntx->SendError(kSyntaxErr); return cntx->SendError(kSyntaxErr);
} }
std::string_view key = ArgS(args, 0); string_view key = ArgS(args, 0);
int32_t value{0}; int32_t value{0};
int64_t start = 0; int64_t start = 0;
@ -556,13 +537,19 @@ void BitPos(CmdArgList args, ConnectionContext* cntx) {
if (!absl::SimpleAtoi(ArgS(args, 2), &start)) { if (!absl::SimpleAtoi(ArgS(args, 2), &start)) {
return cntx->SendError(kInvalidIntErr); return cntx->SendError(kInvalidIntErr);
} }
if (args.size() >= 4) { if (args.size() >= 4) {
if (!absl::SimpleAtoi(ArgS(args, 3), &end)) { if (!absl::SimpleAtoi(ArgS(args, 3), &end)) {
return cntx->SendError(kInvalidIntErr); return cntx->SendError(kInvalidIntErr);
} }
if (args.size() >= 5) { if (args.size() >= 5) {
if (!ToUpperAndGetAsBit(args, 4, &as_bit)) { string arg = absl::AsciiStrToUpper(ArgS(args, 4));
if (arg == "BIT") {
as_bit = true;
} else if (arg == "BYTE") {
as_bit = false;
} else {
return cntx->SendError(kSyntaxErr); return cntx->SendError(kSyntaxErr);
} }
} }
@ -730,13 +717,13 @@ class Get {
// Apply the GET subcommand to the bitfield bytes. // Apply the GET subcommand to the bitfield bytes.
// Return either the subcommand result (int64_t) or empty optional if failed because of // Return either the subcommand result (int64_t) or empty optional if failed because of
// Policy:FAIL // Policy:FAIL
ResultType ApplyTo(Overflow ov, const std::string* bitfield); ResultType ApplyTo(Overflow ov, const string* bitfield);
private: private:
CommonAttributes attr_; CommonAttributes attr_;
}; };
ResultType Get::ApplyTo(Overflow ov, const std::string* bitfield) { ResultType Get::ApplyTo(Overflow ov, const string* bitfield) {
const auto& bytes = *bitfield; const auto& bytes = *bitfield;
const int32_t total_bytes = static_cast<int32_t>(bytes.size()); const int32_t total_bytes = static_cast<int32_t>(bytes.size());
const size_t offset = attr_.offset; const size_t offset = attr_.offset;
@ -774,7 +761,7 @@ class Set {
// Apply the SET subcommand to the bitfield value. // Apply the SET subcommand to the bitfield value.
// Return either the subcommand result (int64_t) or empty optional if failed because of // Return either the subcommand result (int64_t) or empty optional if failed because of
// Policy:FAIL Updates the bitfield to contain the new value // Policy:FAIL Updates the bitfield to contain the new value
ResultType ApplyTo(Overflow ov, std::string* bitfield); ResultType ApplyTo(Overflow ov, string* bitfield);
private: private:
// Helper function that delegates overflow checking to the Overflow object // Helper function that delegates overflow checking to the Overflow object
@ -784,8 +771,8 @@ class Set {
int64_t set_value_; int64_t set_value_;
}; };
ResultType Set::ApplyTo(Overflow ov, std::string* bitfield) { ResultType Set::ApplyTo(Overflow ov, string* bitfield) {
std::string& bytes = *bitfield; string& bytes = *bitfield;
const int32_t total_bytes = static_cast<int32_t>(bytes.size()); const int32_t total_bytes = static_cast<int32_t>(bytes.size());
auto last_byte_offset = GetByteIndex(attr_.offset + attr_.encoding_bit_size - 1) + 1; auto last_byte_offset = GetByteIndex(attr_.offset + attr_.encoding_bit_size - 1) + 1;
if (last_byte_offset > total_bytes) { if (last_byte_offset > total_bytes) {
@ -830,7 +817,7 @@ class IncrBy {
// Apply the INCRBY subcommand to the bitfield value. // Apply the INCRBY subcommand to the bitfield value.
// Return either the subcommand result (int64_t) or empty optional if failed because of // Return either the subcommand result (int64_t) or empty optional if failed because of
// Policy:FAIL Updates the bitfield to contain the new incremented value // Policy:FAIL Updates the bitfield to contain the new incremented value
ResultType ApplyTo(Overflow ov, std::string* bitfield); ResultType ApplyTo(Overflow ov, string* bitfield);
private: private:
// Helper function that delegates overflow checking to the Overflow object // Helper function that delegates overflow checking to the Overflow object
@ -840,8 +827,8 @@ class IncrBy {
int64_t incr_value_; int64_t incr_value_;
}; };
ResultType IncrBy::ApplyTo(Overflow ov, std::string* bitfield) { ResultType IncrBy::ApplyTo(Overflow ov, string* bitfield) {
std::string& bytes = *bitfield; string& bytes = *bitfield;
Get get(attr_); Get get(attr_);
auto res = get.ApplyTo(ov, &bytes); auto res = get.ApplyTo(ov, &bytes);
@ -876,7 +863,7 @@ using Result = std::optional<ResultType>;
// Visitor for all the subcommand variants. Calls ApplyTo, to execute the subcommand // Visitor for all the subcommand variants. Calls ApplyTo, to execute the subcommand
class CommandApplyVisitor { class CommandApplyVisitor {
public: public:
explicit CommandApplyVisitor(std::string bitfield) : bitfield_(std::move(bitfield)) { explicit CommandApplyVisitor(string bitfield) : bitfield_(std::move(bitfield)) {
} }
Result operator()(Get get) { Result operator()(Get get) {
@ -893,7 +880,7 @@ class CommandApplyVisitor {
return {}; return {};
} }
std::string_view Bitfield() const { string_view Bitfield() const {
return bitfield_; return bitfield_;
} }
@ -906,14 +893,14 @@ class CommandApplyVisitor {
// policy changes stick among different subcommands // policy changes stick among different subcommands
Overflow overflow_; Overflow overflow_;
// This will be commited if it was updated // This will be commited if it was updated
std::string bitfield_; string bitfield_;
// If either of the subcommands SET|INCRBY is used we should persist the changes. // If either of the subcommands SET|INCRBY is used we should persist the changes.
// Otherwise, we only used a read only subcommand (GET) // Otherwise, we only used a read only subcommand (GET)
bool should_commit_ = false; bool should_commit_ = false;
}; };
// A lit of subcommands used in BITFIELD command // A lit of subcommands used in BITFIELD command
using CommandList = std::vector<Command>; using CommandList = vector<Command>;
// Helper class used in the shard cb that abstracts away the iteration and execution of subcommands // Helper class used in the shard cb that abstracts away the iteration and execution of subcommands
class StateExecutor { class StateExecutor {
@ -925,25 +912,25 @@ class StateExecutor {
// Iterates over all of the parsed subcommands and executes them one by one. At the end, // Iterates over all of the parsed subcommands and executes them one by one. At the end,
// if an update subcommand SET|INCRBY was used, commit back the changes via the ElementAccess // if an update subcommand SET|INCRBY was used, commit back the changes via the ElementAccess
// object // object
OpResult<std::vector<ResultType>> Execute(const CommandList& commands); OpResult<vector<ResultType>> Execute(const CommandList& commands);
private: private:
ElementAccess access_; ElementAccess access_;
EngineShard* shard_; EngineShard* shard_;
}; };
OpResult<std::vector<ResultType>> StateExecutor::Execute(const CommandList& commands) { OpResult<vector<ResultType>> StateExecutor::Execute(const CommandList& commands) {
auto res = access_.Exists(shard_); auto res = access_.Exists(shard_);
if (!res) { if (!res) {
return {OpStatus::WRONG_TYPE}; return {OpStatus::WRONG_TYPE};
} }
std::string value; string value;
if (*res) { if (*res) {
access_.Find(shard_); access_.Find(shard_);
value = access_.Value(); value = access_.Value();
} }
std::vector<ResultType> results; vector<ResultType> results;
CommandApplyVisitor visitor(std::move(value)); CommandApplyVisitor visitor(std::move(value));
for (auto& command : commands) { for (auto& command : commands) {
auto res = std::visit(visitor, command); auto res = std::visit(visitor, command);
@ -960,7 +947,7 @@ OpResult<std::vector<ResultType>> StateExecutor::Execute(const CommandList& comm
return results; return results;
} }
nonstd::expected<CommonAttributes, std::string> ParseCommonAttr(CmdArgParser* parser) { nonstd::expected<CommonAttributes, string> ParseCommonAttr(CmdArgParser* parser) {
CommonAttributes parsed; CommonAttributes parsed;
using nonstd::make_unexpected; using nonstd::make_unexpected;
@ -977,7 +964,7 @@ nonstd::expected<CommonAttributes, std::string> ParseCommonAttr(CmdArgParser* pa
return make_unexpected(kSyntaxErr); return make_unexpected(kSyntaxErr);
} }
std::string_view bits = encoding.substr(1); string_view bits = encoding.substr(1);
if (!absl::SimpleAtoi(bits, &parsed.encoding_bit_size)) { if (!absl::SimpleAtoi(bits, &parsed.encoding_bit_size)) {
return make_unexpected(kSyntaxErr); return make_unexpected(kSyntaxErr);
@ -1010,9 +997,9 @@ nonstd::expected<CommonAttributes, std::string> ParseCommonAttr(CmdArgParser* pa
} }
// Parses a list of arguments (without key) to a CommandList. // Parses a list of arguments (without key) to a CommandList.
// Returns the CommandList if the parsing completed succefully or std::string // Returns the CommandList if the parsing completed succefully or string
// to indicate an error // to indicate an error
nonstd::expected<CommandList, std::string> ParseToCommandList(CmdArgList args, bool read_only) { nonstd::expected<CommandList, string> ParseToCommandList(CmdArgList args, bool read_only) {
enum class Cmds { OVERFLOW_OPT, GET_OPT, SET_OPT, INCRBY_OPT }; enum class Cmds { OVERFLOW_OPT, GET_OPT, SET_OPT, INCRBY_OPT };
CommandList result; CommandList result;
@ -1076,7 +1063,7 @@ nonstd::expected<CommandList, std::string> ParseToCommandList(CmdArgList args, b
return result; return result;
} }
void SendResults(const std::vector<ResultType>& results, ConnectionContext* cntx) { void SendResults(const vector<ResultType>& results, ConnectionContext* cntx) {
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder()); auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
const size_t total = results.size(); const size_t total = results.size();
if (total == 0) { if (total == 0) {
@ -1110,14 +1097,13 @@ void BitFieldGeneric(CmdArgList args, bool read_only, ConnectionContext* cntx) {
} }
CommandList cmd_list = std::move(maybe_ops_list.value()); CommandList cmd_list = std::move(maybe_ops_list.value());
auto cb = [&cmd_list, &key](Transaction* t, auto cb = [&cmd_list, &key](Transaction* t, EngineShard* shard) -> OpResult<vector<ResultType>> {
EngineShard* shard) -> OpResult<std::vector<ResultType>> {
StateExecutor executor(ElementAccess(key, t->GetOpArgs(shard)), shard); StateExecutor executor(ElementAccess(key, t->GetOpArgs(shard)), shard);
return executor.Execute(cmd_list); return executor.Execute(cmd_list);
}; };
Transaction* trans = cntx->transaction; Transaction* trans = cntx->transaction;
OpResult<std::vector<ResultType>> res = trans->ScheduleSingleHopT(std::move(cb)); OpResult<vector<ResultType>> res = trans->ScheduleSingleHopT(std::move(cb));
if (res == OpStatus::WRONG_TYPE) { if (res == OpStatus::WRONG_TYPE) {
cntx->SendError(kWrongTypeErr); cntx->SendError(kWrongTypeErr);
@ -1140,11 +1126,10 @@ void BitFieldRo(CmdArgList args, ConnectionContext* cntx) {
#endif #endif
void BitOp(CmdArgList args, ConnectionContext* cntx) { void BitOp(CmdArgList args, ConnectionContext* cntx) {
static const std::array<std::string_view, 4> BITOP_OP_NAMES{OR_OP_NAME, XOR_OP_NAME, AND_OP_NAME, static const std::array<string_view, 4> BITOP_OP_NAMES{OR_OP_NAME, XOR_OP_NAME, AND_OP_NAME,
NOT_OP_NAME}; NOT_OP_NAME};
ToUpper(&args[0]); string op = absl::AsciiStrToUpper(ArgS(args, 0));
std::string_view op = ArgS(args, 0); string_view dest_key = ArgS(args, 1);
std::string_view dest_key = ArgS(args, 1);
bool illegal = std::none_of(BITOP_OP_NAMES.begin(), BITOP_OP_NAMES.end(), bool illegal = std::none_of(BITOP_OP_NAMES.begin(), BITOP_OP_NAMES.end(),
[&op](auto val) { return op == val; }); [&op](auto val) { return op == val; });
@ -1218,7 +1203,7 @@ void GetBit(CmdArgList args, ConnectionContext* cntx) {
// see https://redis.io/commands/getbit/ // see https://redis.io/commands/getbit/
uint32_t offset{0}; uint32_t offset{0};
std::string_view key = ArgS(args, 0); string_view key = ArgS(args, 0);
if (!absl::SimpleAtoi(ArgS(args, 1), &offset)) { if (!absl::SimpleAtoi(ArgS(args, 1), &offset)) {
return cntx->SendError(kInvalidIntErr); return cntx->SendError(kInvalidIntErr);
@ -1253,14 +1238,14 @@ void SetBit(CmdArgList args, ConnectionContext* cntx) {
// ------------------------------------------------------------------------- // // ------------------------------------------------------------------------- //
// This are the "callbacks" that we're using from above // This are the "callbacks" that we're using from above
std::string GetString(const PrimeValue& pv) { string GetString(const PrimeValue& pv) {
std::string res; string res;
pv.GetString(&res); pv.GetString(&res);
return res; return res;
} }
OpResult<bool> ReadValueBitsetAt(const OpArgs& op_args, std::string_view key, uint32_t offset) { OpResult<bool> ReadValueBitsetAt(const OpArgs& op_args, string_view key, uint32_t offset) {
OpResult<std::string> result = ReadValue(op_args.db_cntx, key, op_args.shard); OpResult<string> result = ReadValue(op_args.db_cntx, key, op_args.shard);
if (result) { if (result) {
return GetBitValueSafe(result.value(), offset); return GetBitValueSafe(result.value(), offset);
} else { } else {
@ -1268,8 +1253,7 @@ OpResult<bool> ReadValueBitsetAt(const OpArgs& op_args, std::string_view key, ui
} }
} }
OpResult<std::string> ReadValue(const DbContext& context, std::string_view key, OpResult<string> ReadValue(const DbContext& context, string_view key, EngineShard* shard) {
EngineShard* shard) {
DbSlice& db_slice = context.GetDbSlice(shard->shard_id()); DbSlice& db_slice = context.GetDbSlice(shard->shard_id());
auto it_res = db_slice.FindReadOnly(context, key, OBJ_STRING); auto it_res = db_slice.FindReadOnly(context, key, OBJ_STRING);
if (!it_res.ok()) { if (!it_res.ok()) {
@ -1281,9 +1265,9 @@ OpResult<std::string> ReadValue(const DbContext& context, std::string_view key,
return GetString(pv); return GetString(pv);
} }
OpResult<std::size_t> CountBitsForValue(const OpArgs& op_args, std::string_view key, int64_t start, OpResult<std::size_t> CountBitsForValue(const OpArgs& op_args, string_view key, int64_t start,
int64_t end, bool bit_value) { int64_t end, bool bit_value) {
OpResult<std::string> result = ReadValue(op_args.db_cntx, key, op_args.shard); OpResult<string> result = ReadValue(op_args.db_cntx, key, op_args.shard);
if (result) { // if this is not found, just return 0 - per Redis if (result) { // if this is not found, just return 0 - per Redis
return CountBitSet(result.value(), start, end, bit_value); return CountBitSet(result.value(), start, end, bit_value);
@ -1302,7 +1286,7 @@ std::size_t GetFirstBitWithValueInByte(uint8_t byte, bool value) {
} }
} }
int64_t FindFirstBitWithValueAsBit(std::string_view value_str, bool bit_value, int64_t start, int64_t FindFirstBitWithValueAsBit(string_view value_str, bool bit_value, int64_t start,
int64_t end) { int64_t end) {
for (int64_t i = start; i <= end; ++i) { for (int64_t i = start; i <= end; ++i) {
if (static_cast<size_t>(GetByteIndex(i)) >= value_str.size()) { if (static_cast<size_t>(GetByteIndex(i)) >= value_str.size()) {
@ -1320,7 +1304,7 @@ int64_t FindFirstBitWithValueAsBit(std::string_view value_str, bool bit_value, i
return -1; return -1;
} }
int64_t FindFirstBitWithValueAsByte(std::string_view value_str, bool bit_value, int64_t start, int64_t FindFirstBitWithValueAsByte(string_view value_str, bool bit_value, int64_t start,
int64_t end) { int64_t end) {
for (int64_t i = start; i <= end; ++i) { for (int64_t i = start; i <= end; ++i) {
if (static_cast<size_t>(i) >= value_str.size()) { if (static_cast<size_t>(i) >= value_str.size()) {
@ -1338,9 +1322,9 @@ int64_t FindFirstBitWithValueAsByte(std::string_view value_str, bool bit_value,
return -1; return -1;
} }
OpResult<int64_t> FindFirstBitWithValue(const OpArgs& op_args, std::string_view key, bool bit_value, OpResult<int64_t> FindFirstBitWithValue(const OpArgs& op_args, string_view key, bool bit_value,
int64_t start, int64_t end, bool as_bit) { int64_t start, int64_t end, bool as_bit) {
OpResult<std::string> value = ReadValue(op_args.db_cntx, key, op_args.shard); OpResult<string> value = ReadValue(op_args.db_cntx, key, op_args.shard);
// non-existent keys are handled exactly as in Redis's implementation, // non-existent keys are handled exactly as in Redis's implementation,
// even though it contradicts its docs: // even though it contradicts its docs:
@ -1350,7 +1334,7 @@ OpResult<int64_t> FindFirstBitWithValue(const OpArgs& op_args, std::string_view
return bit_value ? -1 : 0; return bit_value ? -1 : 0;
} }
std::string_view value_str = value.value(); string_view value_str = value.value();
int64_t size = value_str.size(); int64_t size = value_str.size();
if (as_bit) { if (as_bit) {
size *= OFFSET_FACTOR; size *= OFFSET_FACTOR;

View file

@ -481,8 +481,7 @@ void DebugCmd::Reload(CmdArgList args) {
bool save = true; bool save = true;
for (size_t i = 1; i < args.size(); ++i) { for (size_t i = 1; i < args.size(); ++i) {
ToUpper(&args[i]); string_view opt = absl::AsciiStrToUpper(ArgS(args, i));
string_view opt = ArgS(args, i);
VLOG(1) << "opt " << opt; VLOG(1) << "opt " << opt;
if (opt == "NOSAVE") { if (opt == "NOSAVE") {
@ -520,8 +519,8 @@ void DebugCmd::Reload(CmdArgList args) {
void DebugCmd::Replica(CmdArgList args) { void DebugCmd::Replica(CmdArgList args) {
args.remove_prefix(1); args.remove_prefix(1);
ToUpper(&args[0]);
string_view opt = ArgS(args, 0); string opt = absl::AsciiStrToUpper(ArgS(args, 0));
auto* rb = static_cast<RedisReplyBuilder*>(cntx_->reply_builder()); auto* rb = static_cast<RedisReplyBuilder*>(cntx_->reply_builder());
if (opt == "PAUSE" || opt == "RESUME") { if (opt == "PAUSE" || opt == "RESUME") {
@ -568,8 +567,7 @@ optional<DebugCmd::PopulateOptions> DebugCmd::ParsePopulateArgs(CmdArgList args)
} }
for (size_t index = 4; args.size() > index; ++index) { for (size_t index = 4; args.size() > index; ++index) {
ToUpper(&args[index]); string str = absl::AsciiStrToUpper(ArgS(args, index));
std::string_view str = ArgS(args, index);
if (str == "RAND") { if (str == "RAND") {
options.populate_random_values = true; options.populate_random_values = true;
} else if (str == "TYPE") { } else if (str == "TYPE") {
@ -577,8 +575,8 @@ optional<DebugCmd::PopulateOptions> DebugCmd::ParsePopulateArgs(CmdArgList args)
cntx_->SendError(kSyntaxErr); cntx_->SendError(kSyntaxErr);
return nullopt; return nullopt;
} }
ToUpper(&args[++index]); ++index;
options.type = ArgS(args, index); options.type = absl::AsciiStrToUpper(ArgS(args, index));
} else if (str == "ELEMENTS") { } else if (str == "ELEMENTS") {
if (args.size() < index + 2) { if (args.size() < index + 2) {
cntx_->SendError(kSyntaxErr); cntx_->SendError(kSyntaxErr);

View file

@ -19,7 +19,7 @@ class DebugCmd {
std::string_view prefix{"key"}; std::string_view prefix{"key"};
uint32_t val_size = 16; uint32_t val_size = 16;
bool populate_random_values = false; bool populate_random_values = false;
std::string_view type{"STRING"}; std::string type{"STRING"};
uint32_t elements = 1; uint32_t elements = 1;
std::optional<cluster::SlotRange> slot_range; std::optional<cluster::SlotRange> slot_range;

View file

@ -134,8 +134,7 @@ DflyCmd::DflyCmd(ServerFamily* server_family) : sf_(server_family) {
void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) { void DflyCmd::Run(CmdArgList args, ConnectionContext* cntx) {
DCHECK_GE(args.size(), 1u); DCHECK_GE(args.size(), 1u);
ToUpper(&args[0]); string sub_cmd = absl::AsciiStrToUpper(ArgS(args, 0));
string_view sub_cmd = ArgS(args, 0);
if (sub_cmd == "THREAD") { if (sub_cmd == "THREAD") {
return Thread(args, cntx); return Thread(args, cntx);

View file

@ -223,7 +223,7 @@ std::optional<DbSlice::ItAndUpdater> RdbRestoreValue::Add(std::string_view data,
// [FREQ frequency], in any order // [FREQ frequency], in any order
OpResult<RestoreArgs> RestoreArgs::TryFrom(const CmdArgList& args) { OpResult<RestoreArgs> RestoreArgs::TryFrom(const CmdArgList& args) {
RestoreArgs out_args; RestoreArgs out_args;
std::string_view cur_arg = ArgS(args, 1); // extract ttl string cur_arg{ArgS(args, 1)}; // extract ttl
if (!absl::SimpleAtoi(cur_arg, &out_args.expiration_) || (out_args.expiration_ < 0)) { if (!absl::SimpleAtoi(cur_arg, &out_args.expiration_) || (out_args.expiration_ < 0)) {
return OpStatus::INVALID_INT; return OpStatus::INVALID_INT;
} }
@ -236,8 +236,7 @@ OpResult<RestoreArgs> RestoreArgs::TryFrom(const CmdArgList& args) {
int64_t idle_time = 0; int64_t idle_time = 0;
for (size_t i = 3; i < args.size(); ++i) { for (size_t i = 3; i < args.size(); ++i) {
ToUpper(&args[i]); cur_arg = absl::AsciiStrToUpper(ArgS(args, i));
cur_arg = ArgS(args, i);
bool additional = args.size() - i - 1 >= 1; bool additional = args.size() - i - 1 >= 1;
if (cur_arg == "REPLACE") { if (cur_arg == "REPLACE") {
out_args.replace_ = true; out_args.replace_ = true;
@ -982,8 +981,7 @@ void GenericFamily::Persist(CmdArgList args, ConnectionContext* cntx) {
std::optional<int32_t> ParseExpireOptionsOrReply(const CmdArgList args, ConnectionContext* cntx) { std::optional<int32_t> ParseExpireOptionsOrReply(const CmdArgList args, ConnectionContext* cntx) {
int32_t flags = ExpireFlags::EXPIRE_ALWAYS; int32_t flags = ExpireFlags::EXPIRE_ALWAYS;
for (auto& arg : args) { for (auto& arg : args) {
ToUpper(&arg); string arg_sv = absl::AsciiStrToUpper(ToSV(arg));
auto arg_sv = ToSV(arg);
if (arg_sv == "NX") { if (arg_sv == "NX") {
flags |= ExpireFlags::EXPIRE_NX; flags |= ExpireFlags::EXPIRE_NX;
} else if (arg_sv == "XX") { } else if (arg_sv == "XX") {
@ -1304,9 +1302,7 @@ void GenericFamily::Sort(CmdArgList args, ConnectionContext* cntx) {
std::optional<std::pair<size_t, size_t>> bounds; std::optional<std::pair<size_t, size_t>> bounds;
for (size_t i = 1; i < args.size(); i++) { for (size_t i = 1; i < args.size(); i++) {
ToUpper(&args[i]); string arg = absl::AsciiStrToUpper(ArgS(args, i));
std::string_view arg = ArgS(args, i);
if (arg == "ALPHA") { if (arg == "ALPHA") {
alpha = true; alpha = true;
} else if (arg == "DESC") { } else if (arg == "DESC") {