1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

feat(server): Enable overriding --requirepass form env var (#792)

Signed-off-by: ashotland <ari@dragonflydb.io>
This commit is contained in:
ashotland 2023-02-14 13:19:33 +02:00 committed by GitHub
parent 25db011afc
commit 49b1ba5b6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 12 deletions

View file

@ -43,8 +43,6 @@ using namespace std;
ABSL_FLAG(uint32_t, port, 6379, "Redis port");
ABSL_FLAG(uint32_t, memcache_port, 0, "Memcached port");
ABSL_DECLARE_FLAG(string, requirepass);
namespace dfly {
#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
@ -813,7 +811,7 @@ void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view va
facade::ConnectionContext* Service::CreateContext(util::FiberSocketBase* peer,
facade::Connection* owner) {
ConnectionContext* res = new ConnectionContext{peer, owner};
res->req_auth = IsPassProtected();
res->req_auth = !GetPassword().empty();
// a bit of a hack. I set up breaker callback here for the owner.
// Should work though it's confusing to have it here.
@ -853,10 +851,6 @@ bool Service::IsShardSetLocked() const {
return res.load() != 0;
}
bool Service::IsPassProtected() const {
return !GetFlag(FLAGS_requirepass).empty();
}
absl::flat_hash_map<std::string, unsigned> Service::UknownCmdMap() const {
lock_guard lk(mu_);
return unknown_cmds_;

View file

@ -59,8 +59,6 @@ class Service : public facade::ServiceInterface {
return pp_;
}
bool IsPassProtected() const;
absl::flat_hash_map<std::string, unsigned> UknownCmdMap() const;
const CommandId* FindCmd(std::string_view cmd) const {

View file

@ -51,7 +51,9 @@ using namespace std;
ABSL_FLAG(string, dir, "", "working directory");
ABSL_FLAG(string, dbfilename, "dump", "the filename to save/load the DB");
ABSL_FLAG(string, requirepass, "", "password for AUTH authentication");
ABSL_FLAG(string, requirepass, "",
"password for AUTH authentication. "
"If empty can also be set with DFLY_PASSWORD environment variable.");
ABSL_FLAG(string, save_schedule, "",
"glob spec for the UTC time to save a snapshot which matches HH:MM 24h time");
@ -1047,6 +1049,20 @@ void ServerFamily::BreakOnShutdown() {
dfly_cmd_->BreakOnShutdown();
}
string GetPassword() {
string flag = GetFlag(FLAGS_requirepass);
if (!flag.empty()) {
return flag;
}
const char* env_var = getenv("DFLY_PASSWORD");
if (env_var) {
return env_var;
}
return "";
}
void ServerFamily::FlushDb(CmdArgList args, ConnectionContext* cntx) {
DCHECK(cntx->transaction);
Drakarys(cntx->transaction, cntx->transaction->GetDbIndex());
@ -1080,7 +1096,7 @@ void ServerFamily::Auth(CmdArgList args, ConnectionContext* cntx) {
}
string_view pass = ArgS(args, 1);
if (pass == GetFlag(FLAGS_requirepass)) {
if (pass == GetPassword()) {
cntx->authenticated = true;
(*cntx)->SendOk();
} else {

View file

@ -5,6 +5,7 @@
#pragma once
#include <boost/fiber/future.hpp>
#include <string>
#include "facade/conn_context.h"
#include "facade/redis_parser.h"
@ -20,6 +21,8 @@ class HttpListenerBase;
namespace dfly {
std::string GetPassword();
namespace journal {
class Journal;
} // namespace journal

View file

@ -1,3 +1,6 @@
import os
import aioredis
import pytest
from . import dfly_multi_test_args
from .utility import batch_fill_data, gen_test_data
@ -11,3 +14,36 @@ class TestKeys:
keys = client.keys()
assert len(keys) in range(max_keys, max_keys+512)
@pytest.fixture(scope="function")
def export_dfly_password() -> str:
pwd = 'flypwd'
os.environ['DFLY_PASSWORD'] = pwd
yield pwd
del os.environ['DFLY_PASSWORD']
@pytest.mark.asyncio
async def test_password(df_local_factory, export_dfly_password):
dfly = df_local_factory.create()
dfly.start()
# Expect password form environment variable
with pytest.raises(aioredis.exceptions.AuthenticationError):
client = aioredis.Redis()
await client.ping()
client = aioredis.Redis(password=export_dfly_password)
await client.ping()
dfly.stop()
# --requirepass should take precedence over environment variable
requirepass = 'requirepass'
dfly = df_local_factory.create(requirepass=requirepass)
dfly.start()
# Expect password form flag
with pytest.raises(aioredis.exceptions.ResponseError):
client = aioredis.Redis(password=export_dfly_password)
await client.ping()
client = aioredis.Redis(password=requirepass)
await client.ping()
dfly.stop()