mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-14 11:58:02 +00:00
chore: better error reporting when connecting to tls with plain socket (#2740)
* chore: better error reporting when connecting to tls with plain socket --------- Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
parent
30c3f63ca2
commit
2d246adbbb
4 changed files with 37 additions and 31 deletions
2
helio
2
helio
|
@ -1 +1 @@
|
|||
Subproject commit 91c4f48d025bccd6fc3ff14e471accf1c7801f38
|
||||
Subproject commit 8985263c3acca038752e8f9fdd8e9f61d2ec2b6f
|
|
@ -6,12 +6,12 @@
|
|||
|
||||
#include <absl/container/flat_hash_map.h>
|
||||
#include <absl/strings/match.h>
|
||||
#include <absl/strings/str_cat.h>
|
||||
#include <mimalloc.h>
|
||||
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "base/flags.h"
|
||||
#include "base/io_buf.h"
|
||||
#include "base/logging.h"
|
||||
|
@ -613,10 +613,25 @@ void Connection::HandleRequests() {
|
|||
if (!(IsPrivileged() && no_tls_on_admin_port)) {
|
||||
// Must be done atomically before the premption point in Accept so that at any
|
||||
// point in time, the socket_ is defined.
|
||||
uint8_t buf[2];
|
||||
auto read_sz = socket_->Read(io::MutableBytes(buf));
|
||||
if (!read_sz || *read_sz < sizeof(buf)) {
|
||||
VLOG(1) << "Error reading from peer " << remote_ep << " " << read_sz.error().message();
|
||||
return;
|
||||
}
|
||||
if (buf[0] != 0x16 || buf[1] != 0x03) {
|
||||
VLOG(1) << "Bad TLS header "
|
||||
<< absl::StrCat(absl::Hex(buf[0], absl::kZeroPad2),
|
||||
absl::Hex(buf[1], absl::kZeroPad2));
|
||||
peer->Write(
|
||||
io::Buffer("-ERR Bad TLS header, double check "
|
||||
"if you enabled TLS for your client.\r\n"));
|
||||
}
|
||||
|
||||
{
|
||||
FiberAtomicGuard fg;
|
||||
unique_ptr<tls::TlsSocket> tls_sock = make_unique<tls::TlsSocket>(std::move(socket_));
|
||||
tls_sock->InitSSL(ssl_ctx_);
|
||||
tls_sock->InitSSL(ssl_ctx_, buf);
|
||||
SetSocket(tls_sock.release());
|
||||
}
|
||||
FiberSocketBase::AcceptResult aresult = socket_->Accept();
|
||||
|
|
|
@ -4,12 +4,12 @@ import pytest
|
|||
import asyncio
|
||||
import time
|
||||
from redis import asyncio as aioredis
|
||||
from redis.exceptions import ConnectionError as redis_conn_error
|
||||
from redis.exceptions import ConnectionError as redis_conn_error, ResponseError
|
||||
import async_timeout
|
||||
from dataclasses import dataclass
|
||||
|
||||
from . import dfly_args
|
||||
from .instance import DflyInstance
|
||||
from .instance import DflyInstance, DflyInstanceFactory
|
||||
|
||||
BASE_PORT = 1111
|
||||
|
||||
|
@ -564,7 +564,7 @@ async def test_large_cmd(async_client: aioredis.Redis):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
|
||||
server = df_local_factory.create(
|
||||
server: DflyInstance = df_local_factory.create(
|
||||
no_tls_on_admin_port="true",
|
||||
admin_port=1111,
|
||||
port=1211,
|
||||
|
@ -573,13 +573,12 @@ async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_
|
|||
)
|
||||
server.start()
|
||||
|
||||
client = aioredis.Redis(port=server.port, password="XXX")
|
||||
try:
|
||||
await client.execute_command("DBSIZE")
|
||||
except redis_conn_error:
|
||||
pass
|
||||
client = server.client(password="XXX")
|
||||
with pytest.raises((ResponseError)):
|
||||
await client.dbsize()
|
||||
await client.close()
|
||||
|
||||
client = aioredis.Redis(port=server.admin_port, password="XXX")
|
||||
client = server.admin_client(password="XXX")
|
||||
assert await client.dbsize() == 0
|
||||
await client.close()
|
||||
|
||||
|
@ -605,27 +604,19 @@ async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, d
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tls_reject(with_ca_tls_server_args, with_tls_client_args, df_local_factory):
|
||||
server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
|
||||
async def test_tls_reject(
|
||||
with_ca_tls_server_args, with_tls_client_args, df_local_factory: DflyInstanceFactory
|
||||
):
|
||||
server: DflyInstance = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
|
||||
server.start()
|
||||
|
||||
client = aioredis.Redis(port=server.port, **with_tls_client_args, ssl_cert_reqs=None)
|
||||
try:
|
||||
client = server.client(**with_tls_client_args, ssl_cert_reqs=None)
|
||||
await client.ping()
|
||||
await client.close()
|
||||
|
||||
client = server.client(**with_tls_client_args)
|
||||
with pytest.raises(redis_conn_error):
|
||||
await client.ping()
|
||||
except redis_conn_error:
|
||||
pass
|
||||
|
||||
client = aioredis.Redis(port=server.port, **with_tls_client_args)
|
||||
try:
|
||||
assert await client.dbsize() != 0
|
||||
except redis_conn_error:
|
||||
pass
|
||||
|
||||
client = aioredis.Redis(port=server.port, ssl_cert_reqs=None)
|
||||
try:
|
||||
assert await client.dbsize() != 0
|
||||
except redis_conn_error:
|
||||
pass
|
||||
await client.close()
|
||||
|
||||
|
||||
|
|
|
@ -141,7 +141,7 @@ async def test_config_enable_tls(
|
|||
await client.ping()
|
||||
|
||||
# Connecting without TLS should fail.
|
||||
with pytest.raises(redis.exceptions.ConnectionError):
|
||||
with pytest.raises(redis.exceptions.ResponseError):
|
||||
async with server.client() as client_unauth:
|
||||
await client_unauth.ping()
|
||||
|
||||
|
|
Loading…
Reference in a new issue