From bc25fea1c00b911eac1bc1cac45d5e5f9e8cbf48 Mon Sep 17 00:00:00 2001 From: TwinProduction Date: Thu, 30 Sep 2021 20:45:47 -0400 Subject: [PATCH] Minor improvements --- client/client.go | 24 ++++------- client/client_test.go | 94 +++++++++++++++++++++---------------------- 2 files changed, 55 insertions(+), 63 deletions(-) diff --git a/client/client.go b/client/client.go index 3e46ca88..ab34b9bb 100644 --- a/client/client.go +++ b/client/client.go @@ -40,7 +40,7 @@ func CanPerformStartTLS(address string, config *Config) (connected bool, certifi } conn, err := net.DialTimeout("tcp", address, config.Timeout) if err != nil { - return + return } smtpClient, err := smtp.NewClient(conn, hostAndPort[0]) if err != nil { @@ -63,24 +63,16 @@ func CanPerformStartTLS(address string, config *Config) (connected bool, certifi // CanPerformTLS checks whether a connection can be established to an address using the TLS protocol func CanPerformTLS(address string, config *Config) (connected bool, certificate *x509.Certificate, err error) { - conn, err := tls.DialWithDialer(&net.Dialer{Timeout: config.Timeout}, "tcp", address, nil) + connection, err := tls.DialWithDialer(&net.Dialer{Timeout: config.Timeout}, "tcp", address, nil) if err != nil { - return + return } - defer conn.Close() - - verifiedChains := conn.ConnectionState().VerifiedChains - if len(verifiedChains) == 0 { - return + defer connection.Close() + verifiedChains := connection.ConnectionState().VerifiedChains + if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 { + return } - - chain := verifiedChains[0] // VerifiedChains[0] == PeerCertificates[0] - if len(chain) == 0 { - return - } - - certificate = chain[0] - return true, certificate, nil + return true, verifiedChains[0][0], nil } // Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged diff --git a/client/client_test.go b/client/client_test.go index 5bbb4d99..76191e63 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -92,53 +92,53 @@ func TestCanPerformStartTLS(t *testing.T) { } func TestCanPerformTLS(t *testing.T) { - type args struct { - address string - insecure bool - } - tests := []struct { - name string - args args - wantConnected bool - wantErr bool - }{ - { - name: "invalid address", - args: args{ - address: "test", - }, - wantConnected: false, - wantErr: true, - }, - { - name: "error dial", - args: args{ - address: "test:1234", - }, - wantConnected: false, - wantErr: true, - }, - { - name: "valid tls", - args: args{ - address: "smtp.gmail.com:465", - }, - wantConnected: true, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - connected, _, err := CanPerformTLS(tt.args.address, &Config{Insecure: tt.args.insecure, Timeout: 5 * time.Second}) - if (err != nil) != tt.wantErr { - t.Errorf("CanPerformTLS() err=%v, wantErr=%v", err, tt.wantErr) - return - } - if connected != tt.wantConnected { - t.Errorf("CanPerformTLS() connected=%v, wantConnected=%v", connected, tt.wantConnected) - } - }) - } + type args struct { + address string + insecure bool + } + tests := []struct { + name string + args args + wantConnected bool + wantErr bool + }{ + { + name: "invalid address", + args: args{ + address: "test", + }, + wantConnected: false, + wantErr: true, + }, + { + name: "error dial", + args: args{ + address: "test:1234", + }, + wantConnected: false, + wantErr: true, + }, + { + name: "valid tls", + args: args{ + address: "smtp.gmail.com:465", + }, + wantConnected: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + connected, _, err := CanPerformTLS(tt.args.address, &Config{Insecure: tt.args.insecure, Timeout: 5 * time.Second}) + if (err != nil) != tt.wantErr { + t.Errorf("CanPerformTLS() err=%v, wantErr=%v", err, tt.wantErr) + return + } + if connected != tt.wantConnected { + t.Errorf("CanPerformTLS() connected=%v, wantConnected=%v", connected, tt.wantConnected) + } + }) + } } func TestCanCreateTCPConnection(t *testing.T) {