1
0
Fork 0
mirror of https://github.com/dragonflydb/dragonfly.git synced 2024-12-14 11:58:02 +00:00

chore: better error reporting when connecting to tls with plain socket (#2740)

* chore: better error reporting when connecting to tls with plain socket

---------

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2024-03-19 17:20:23 +02:00 committed by GitHub
parent 30c3f63ca2
commit 2d246adbbb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 37 additions and 31 deletions

2
helio

@ -1 +1 @@
Subproject commit 91c4f48d025bccd6fc3ff14e471accf1c7801f38
Subproject commit 8985263c3acca038752e8f9fdd8e9f61d2ec2b6f

View file

@ -6,12 +6,12 @@
#include <absl/container/flat_hash_map.h>
#include <absl/strings/match.h>
#include <absl/strings/str_cat.h>
#include <mimalloc.h>
#include <numeric>
#include <variant>
#include "absl/strings/str_cat.h"
#include "base/flags.h"
#include "base/io_buf.h"
#include "base/logging.h"
@ -613,10 +613,25 @@ void Connection::HandleRequests() {
if (!(IsPrivileged() && no_tls_on_admin_port)) {
// Must be done atomically before the premption point in Accept so that at any
// point in time, the socket_ is defined.
uint8_t buf[2];
auto read_sz = socket_->Read(io::MutableBytes(buf));
if (!read_sz || *read_sz < sizeof(buf)) {
VLOG(1) << "Error reading from peer " << remote_ep << " " << read_sz.error().message();
return;
}
if (buf[0] != 0x16 || buf[1] != 0x03) {
VLOG(1) << "Bad TLS header "
<< absl::StrCat(absl::Hex(buf[0], absl::kZeroPad2),
absl::Hex(buf[1], absl::kZeroPad2));
peer->Write(
io::Buffer("-ERR Bad TLS header, double check "
"if you enabled TLS for your client.\r\n"));
}
{
FiberAtomicGuard fg;
unique_ptr<tls::TlsSocket> tls_sock = make_unique<tls::TlsSocket>(std::move(socket_));
tls_sock->InitSSL(ssl_ctx_);
tls_sock->InitSSL(ssl_ctx_, buf);
SetSocket(tls_sock.release());
}
FiberSocketBase::AcceptResult aresult = socket_->Accept();

View file

@ -4,12 +4,12 @@ import pytest
import asyncio
import time
from redis import asyncio as aioredis
from redis.exceptions import ConnectionError as redis_conn_error
from redis.exceptions import ConnectionError as redis_conn_error, ResponseError
import async_timeout
from dataclasses import dataclass
from . import dfly_args
from .instance import DflyInstance
from .instance import DflyInstance, DflyInstanceFactory
BASE_PORT = 1111
@ -564,7 +564,7 @@ async def test_large_cmd(async_client: aioredis.Redis):
@pytest.mark.asyncio
async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_factory):
server = df_local_factory.create(
server: DflyInstance = df_local_factory.create(
no_tls_on_admin_port="true",
admin_port=1111,
port=1211,
@ -573,13 +573,12 @@ async def test_reject_non_tls_connections_on_tls(with_tls_server_args, df_local_
)
server.start()
client = aioredis.Redis(port=server.port, password="XXX")
try:
await client.execute_command("DBSIZE")
except redis_conn_error:
pass
client = server.client(password="XXX")
with pytest.raises((ResponseError)):
await client.dbsize()
await client.close()
client = aioredis.Redis(port=server.admin_port, password="XXX")
client = server.admin_client(password="XXX")
assert await client.dbsize() == 0
await client.close()
@ -605,27 +604,19 @@ async def test_tls_full_auth(with_ca_tls_server_args, with_ca_tls_client_args, d
@pytest.mark.asyncio
async def test_tls_reject(with_ca_tls_server_args, with_tls_client_args, df_local_factory):
server = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
async def test_tls_reject(
with_ca_tls_server_args, with_tls_client_args, df_local_factory: DflyInstanceFactory
):
server: DflyInstance = df_local_factory.create(port=BASE_PORT, **with_ca_tls_server_args)
server.start()
client = aioredis.Redis(port=server.port, **with_tls_client_args, ssl_cert_reqs=None)
try:
client = server.client(**with_tls_client_args, ssl_cert_reqs=None)
await client.ping()
await client.close()
client = server.client(**with_tls_client_args)
with pytest.raises(redis_conn_error):
await client.ping()
except redis_conn_error:
pass
client = aioredis.Redis(port=server.port, **with_tls_client_args)
try:
assert await client.dbsize() != 0
except redis_conn_error:
pass
client = aioredis.Redis(port=server.port, ssl_cert_reqs=None)
try:
assert await client.dbsize() != 0
except redis_conn_error:
pass
await client.close()

View file

@ -141,7 +141,7 @@ async def test_config_enable_tls(
await client.ping()
# Connecting without TLS should fail.
with pytest.raises(redis.exceptions.ConnectionError):
with pytest.raises(redis.exceptions.ResponseError):
async with server.client() as client_unauth:
await client_unauth.ping()