mirror of
https://github.com/TwiN/gatus.git
synced 2024-12-14 11:58:04 +00:00
Minor improvements
This commit is contained in:
parent
d07d3434a6
commit
ca977fefa8
2 changed files with 28 additions and 37 deletions
|
@ -3,7 +3,7 @@ package client
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
|
@ -78,34 +78,29 @@ func CanCreateTCPConnection(address string) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func CanPerformStartTls(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) {
|
// CanPerformStartTLS checks whether a connection can be established to an address using the STARTTLS protocol
|
||||||
tokens := strings.Split(address, ":")
|
func CanPerformStartTLS(address string, insecure bool) (connected bool, certificate *x509.Certificate, err error) {
|
||||||
if len(tokens) != 2 {
|
hostAndPort := strings.Split(address, ":")
|
||||||
err = fmt.Errorf("invalid address for starttls, must HOST:PORT")
|
if len(hostAndPort) != 2 {
|
||||||
|
return false, nil, errors.New("invalid address for starttls, format must be host:port")
|
||||||
|
}
|
||||||
|
smtpClient, err := smtp.Dial(address)
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tlsconfig := &tls.Config{
|
err = smtpClient.StartTLS(&tls.Config{
|
||||||
InsecureSkipVerify: insecure,
|
InsecureSkipVerify: insecure,
|
||||||
ServerName: tokens[0],
|
ServerName: hostAndPort[0],
|
||||||
}
|
})
|
||||||
|
|
||||||
c, err := smtp.Dial(address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if state, ok := smtpClient.TLSConnectionState(); ok {
|
||||||
err = c.StartTLS(tlsconfig)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if state, ok := c.TLSConnectionState(); ok {
|
|
||||||
certificate = state.PeerCertificates[0]
|
certificate = state.PeerCertificates[0]
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("could not get TLS connection state")
|
return false, nil, errors.New("could not get TLS connection state")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
connected = true
|
return true, certificate, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged
|
// Ping checks if an address can be pinged and returns the round-trip time if the address can be pinged
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -51,35 +50,32 @@ func TestPing(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCanPerformStartTls(t *testing.T) {
|
func TestCanPerformStartTLS(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
address string
|
address string
|
||||||
insecure bool
|
insecure bool
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
wantConnected bool
|
wantConnected bool
|
||||||
wantCertificate *x509.Certificate
|
wantErr bool
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "invalid address",
|
name: "invalid address",
|
||||||
args: args{
|
args: args{
|
||||||
address: "test",
|
address: "test",
|
||||||
},
|
},
|
||||||
wantConnected: false,
|
wantConnected: false,
|
||||||
wantCertificate: nil,
|
wantErr: true,
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error dial",
|
name: "error dial",
|
||||||
args: args{
|
args: args{
|
||||||
address: "test:1234",
|
address: "test:1234",
|
||||||
},
|
},
|
||||||
wantConnected: false,
|
wantConnected: false,
|
||||||
wantCertificate: nil,
|
wantErr: true,
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid starttls",
|
name: "valid starttls",
|
||||||
|
@ -92,13 +88,13 @@ func TestCanPerformStartTls(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
gotConnected, _, err := CanPerformStartTls(tt.args.address, tt.args.insecure)
|
connected, _, err := CanPerformStartTLS(tt.args.address, tt.args.insecure)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("CanPerformStartTls() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("CanPerformStartTLS() err=%v, wantErr=%v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if gotConnected != tt.wantConnected {
|
if connected != tt.wantConnected {
|
||||||
t.Errorf("CanPerformStartTls() gotConnected = %v, want %v", gotConnected, tt.wantConnected)
|
t.Errorf("CanPerformStartTLS() connected=%v, wantConnected=%v", connected, tt.wantConnected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue