mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-14 11:58:02 +00:00
feat(server): Enable overriding --requirepass form env var (#792)
Signed-off-by: ashotland <ari@dragonflydb.io>
This commit is contained in:
parent
25db011afc
commit
49b1ba5b6d
5 changed files with 59 additions and 12 deletions
|
@ -43,8 +43,6 @@ using namespace std;
|
|||
ABSL_FLAG(uint32_t, port, 6379, "Redis port");
|
||||
ABSL_FLAG(uint32_t, memcache_port, 0, "Memcached port");
|
||||
|
||||
ABSL_DECLARE_FLAG(string, requirepass);
|
||||
|
||||
namespace dfly {
|
||||
|
||||
#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
|
||||
|
@ -813,7 +811,7 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
|
|||
facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer,
|
||||
facade::Connection* owner) {
|
||||
ConnectionContext* res = new ConnectionContext{peer, owner};
|
||||
res->req_auth = IsPassProtected();
|
||||
res->req_auth = !GetPassword().empty();
|
||||
|
||||
// a bit of a hack. I set up breaker callback here for the owner.
|
||||
// Should work though it's confusing to have it here.
|
||||
|
@ -853,10 +851,6 @@ bool Service::IsShardSetLocked() const {
|
|||
return res.load() != 0;
|
||||
}
|
||||
|
||||
bool Service::IsPassProtected() const {
|
||||
return !GetFlag(FLAGS_requirepass).empty();
|
||||
}
|
||||
|
||||
absl::flat_hash_map<std::string, unsigned> Service::UknownCmdMap() const {
|
||||
lock_guard lk(mu_);
|
||||
return unknown_cmds_;
|
||||
|
|
|
@ -59,8 +59,6 @@ class Service : public facade::ServiceInterface {
|
|||
return pp_;
|
||||
}
|
||||
|
||||
bool IsPassProtected() const;
|
||||
|
||||
absl::flat_hash_map<std::string, unsigned> UknownCmdMap() const;
|
||||
|
||||
const CommandId* FindCmd(std::string_view cmd) const {
|
||||
|
|
|
@ -51,7 +51,9 @@ using namespace std;
|
|||
|
||||
ABSL_FLAG(string, dir, "", "working directory");
|
||||
ABSL_FLAG(string, dbfilename, "dump", "the filename to save/load the DB");
|
||||
ABSL_FLAG(string, requirepass, "", "password for AUTH authentication");
|
||||
ABSL_FLAG(string, requirepass, "",
|
||||
"password for AUTH authentication. "
|
||||
"If empty can also be set with DFLY_PASSWORD environment variable.");
|
||||
ABSL_FLAG(string, save_schedule, "",
|
||||
"glob spec for the UTC time to save a snapshot which matches HH:MM 24h time");
|
||||
|
||||
|
@ -1047,6 +1049,20 @@ void ServerFamily::BreakOnShutdown() {
|
|||
dfly_cmd_->BreakOnShutdown();
|
||||
}
|
||||
|
||||
string GetPassword() {
|
||||
string flag = GetFlag(FLAGS_requirepass);
|
||||
if (!flag.empty()) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
const char* env_var = getenv("DFLY_PASSWORD");
|
||||
if (env_var) {
|
||||
return env_var;
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
void ServerFamily::FlushDb(CmdArgList args, ConnectionContext* cntx) {
|
||||
DCHECK(cntx->transaction);
|
||||
Drakarys(cntx->transaction, cntx->transaction->GetDbIndex());
|
||||
|
@ -1080,7 +1096,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
|
|||
}
|
||||
|
||||
string_view pass = ArgS(args, 1);
|
||||
if (pass == GetFlag(FLAGS_requirepass)) {
|
||||
if (pass == GetPassword()) {
|
||||
cntx->authenticated = true;
|
||||
(*cntx)->SendOk();
|
||||
} else {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <boost/fiber/future.hpp>
|
||||
#include <string>
|
||||
|
||||
#include "facade/conn_context.h"
|
||||
#include "facade/redis_parser.h"
|
||||
|
@ -20,6 +21,8 @@ class HttpListenerBase;
|
|||
|
||||
namespace dfly {
|
||||
|
||||
std::string GetPassword();
|
||||
|
||||
namespace journal {
|
||||
class Journal;
|
||||
} // namespace journal
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import os
|
||||
import aioredis
|
||||
import pytest
|
||||
from . import dfly_multi_test_args
|
||||
from .utility import batch_fill_data, gen_test_data
|
||||
|
||||
|
@ -11,3 +14,36 @@ class TestKeys:
|
|||
|
||||
keys = client.keys()
|
||||
assert len(keys) in range(max_keys, max_keys+512)
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def export_dfly_password() -> str:
|
||||
pwd = 'flypwd'
|
||||
os.environ['DFLY_PASSWORD'] = pwd
|
||||
yield pwd
|
||||
del os.environ['DFLY_PASSWORD']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password(df_local_factory, export_dfly_password):
|
||||
dfly = df_local_factory.create()
|
||||
dfly.start()
|
||||
|
||||
# Expect password form environment variable
|
||||
with pytest.raises(aioredis.exceptions.AuthenticationError):
|
||||
client = aioredis.Redis()
|
||||
await client.ping()
|
||||
client = aioredis.Redis(password=export_dfly_password)
|
||||
await client.ping()
|
||||
dfly.stop()
|
||||
|
||||
# --requirepass should take precedence over environment variable
|
||||
requirepass = 'requirepass'
|
||||
dfly = df_local_factory.create(requirepass=requirepass)
|
||||
dfly.start()
|
||||
|
||||
# Expect password form flag
|
||||
with pytest.raises(aioredis.exceptions.ResponseError):
|
||||
client = aioredis.Redis(password=export_dfly_password)
|
||||
await client.ping()
|
||||
client = aioredis.Redis(password=requirepass)
|
||||
await client.ping()
|
||||
dfly.stop()
|
||||
|
|
Loading…
Reference in a new issue