mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2024-12-14 11:58:02 +00:00
chore: better error reporting when connecting to tls with plain socket (#2740)
* chore: better error reporting when connecting to tls with plain socket --------- Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
parent
30c3f63ca2
commit
2d246adbbb
4 changed files with 37 additions and 31 deletions
2
helio
2
helio
|
@ -1 +1 @@
|
||||||
Subproject commit 91c4f48d025bccd6fc3ff14e471accf1c7801f38
|
Subproject commit 8985263c3acca038752e8f9fdd8e9f61d2ec2b6f
|
|
@ -6,12 +6,12 @@
|
||||||
|
|
||||||
#include <absl/container/flat_hash_map.h>
|
#include <absl/container/flat_hash_map.h>
|
||||||
#include <absl/strings/match.h>
|
#include <absl/strings/match.h>
|
||||||
|
#include <absl/strings/str_cat.h>
|
||||||
#include <mimalloc.h>
|
#include <mimalloc.h>
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "base/flags.h"
|
#include "base/flags.h"
|
||||||
#include "base/io_buf.h"
|
#include "base/io_buf.h"
|
||||||
#include "base/logging.h"
|
#include "base/logging.h"
|
||||||
|
@ -613,10 +613,25 @@ void Connection::HandleRequests() {
|
||||||
if (!(IsPrivileged() && no_tls_on_admin_port)) {
|
if (!(IsPrivileged() && no_tls_on_admin_port)) {
|
||||||
// Must be done atomically before the premption point in Accept so that at any
|
// Must be done atomically before the premption point in Accept so that at any
|
||||||
// point in time, the socket_ is defined.
|
// point in time, the socket_ is defined.
|
||||||
|
uint8_t buf[2];
|
||||||
|
auto read_sz = socket_->Read(io::MutableBytes(buf));
|
||||||
|
if (!read_sz || *read_sz < sizeof(buf)) {
|
||||||
|
VLOG(1) << "Error reading from peer " << remote_ep << " " << read_sz.error().message();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (buf[0] != 0x16 || buf[1] != 0x03) {
|
||||||
|
VLOG(1) << "Bad TLS header "
|
||||||
|
<< absl::StrCat(absl::Hex(buf[0], absl::kZeroPad2),
|
||||||
|
absl::Hex(buf[1], absl::kZeroPad2));
|
||||||
|
peer->Write(
|
||||||
|
io::Buffer("-ERR Bad TLS header, double check "
|
||||||
|
"if you enabled TLS for your client.\r\n"));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
FiberAtomicGuard fg;
|
FiberAtomicGuard fg;
|
||||||
unique_ptr<tls::TlsSocket> tls_sock = make_unique<tls::TlsSocket>(std::move(socket_));
|
unique_ptr<tls::TlsSocket> tls_sock = make_unique<tls::TlsSocket>(std::move(socket_));
|
||||||
tls_sock->InitSSL(ssl_ctx_);
|
tls_sock->InitSSL(ssl_ctx_, buf);
|
||||||
SetSocket(tls_sock.release());
|
SetSocket(tls_sock.release());
|
||||||
}
|
}
|
||||||
FiberSocketBase::AcceptResult aresult = socket_->Accept();
|
FiberSocketBase::AcceptResult aresult = socket_->Accept();
|
||||||
|
|
|
@ -4,12 +4,12 @@ import pytest
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from redis import asyncio as aioredis
|
from redis import asyncio as aioredis
|
||||||
from redis.exceptions import ConnectionError as redis_conn_error
|
from redis.exceptions import ConnectionError as redis_conn_error, ResponseError
|
||||||
import async_timeout
|
import async_timeout
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from . import dfly_args
|
from . import dfly_args
|
||||||
from .instance import DflyInstance
|
from .instance import DflyInstance, DflyInstanceFactory
|
||||||
|
|
||||||
BASE_PORT = 1111
|
BASE_PORT = 1111
|
||||||
|
|
||||||
|
@ -564,7 +564,7 @@ async def test_large_cmd(async_client: aioredis.Redis):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
|
async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
|
||||||
server = df_local_factory.create(
|
server: DflyInstance = df_local_factory.create(
|
||||||
no_tls_on_admin_port="true",
|
no_tls_on_admin_port="true",
|
||||||
admin_port=1111,
|
admin_port=1111,
|
||||||
port=1211,
|
port=1211,
|
||||||
|
@ -573,13 +573,12 @@ async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_
|
||||||
)
|
)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
||||||
client = aioredis.Redis(port=server.port, password="XXX")
|
client = server.client(password="XXX")
|
||||||
try:
|
with pytest.raises((ResponseError)):
|
||||||
await client.execute_command("DBSIZE")
|
await client.dbsize()
|
||||||
except redis_conn_error:
|
await client.close()
|
||||||
pass
|
|
||||||
|
|
||||||
client = aioredis.Redis(port=server.admin_port, password="XXX")
|
client = server.admin_client(password="XXX")
|
||||||
assert await client.dbsize() == 0
|
assert await client.dbsize() == 0
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
|
@ -605,27 +604,19 @@ async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, d
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tls_reject(with_ca_tls_server_args, with_tls_client_args, df_local_factory):
|
async def test_tls_reject(
|
||||||
server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
|
with_ca_tls_server_args, with_tls_client_args, df_local_factory: DflyInstanceFactory
|
||||||
|
):
|
||||||
|
server: DflyInstance = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
|
||||||
server.start()
|
server.start()
|
||||||
|
|
||||||
client = aioredis.Redis(port=server.port, **with_tls_client_args, ssl_cert_reqs=None)
|
client = server.client(**with_tls_client_args, ssl_cert_reqs=None)
|
||||||
try:
|
|
||||||
await client.ping()
|
await client.ping()
|
||||||
except redis_conn_error:
|
await client.close()
|
||||||
pass
|
|
||||||
|
|
||||||
client = aioredis.Redis(port=server.port, **with_tls_client_args)
|
client = server.client(**with_tls_client_args)
|
||||||
try:
|
with pytest.raises(redis_conn_error):
|
||||||
assert await client.dbsize() != 0
|
await client.ping()
|
||||||
except redis_conn_error:
|
|
||||||
pass
|
|
||||||
|
|
||||||
client = aioredis.Redis(port=server.port, ssl_cert_reqs=None)
|
|
||||||
try:
|
|
||||||
assert await client.dbsize() != 0
|
|
||||||
except redis_conn_error:
|
|
||||||
pass
|
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -141,7 +141,7 @@ async def test_config_enable_tls(
|
||||||
await client.ping()
|
await client.ping()
|
||||||
|
|
||||||
# Connecting without TLS should fail.
|
# Connecting without TLS should fail.
|
||||||
with pytest.raises(redis.exceptions.ConnectionError):
|
with pytest.raises(redis.exceptions.ResponseError):
|
||||||
async with server.client() as client_unauth:
|
async with server.client() as client_unauth:
|
||||||
await client_unauth.ping()
|
await client_unauth.ping()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue