diff --git a/api/api.go b/api/api.go new file mode 100644 index 00000000..2f38ac65 --- /dev/null +++ b/api/api.go @@ -0,0 +1,106 @@ +package api + +import ( + "io/fs" + "log" + "net/http" + "os" + + "github.com/TwiN/gatus/v5/config" + static "github.com/TwiN/gatus/v5/web" + "github.com/TwiN/health" + fiber "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" + "github.com/gofiber/fiber/v2/middleware/compress" + "github.com/gofiber/fiber/v2/middleware/cors" + fiberfs "github.com/gofiber/fiber/v2/middleware/filesystem" + "github.com/gofiber/fiber/v2/middleware/recover" + "github.com/gofiber/fiber/v2/middleware/redirect" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +type API struct { + router *fiber.App +} + +func New(cfg *config.Config) *API { + api := &API{} + api.router = api.createRouter(cfg) + return api +} + +func (a *API) Router() *fiber.App { + return a.router +} + +func (a *API) createRouter(cfg *config.Config) *fiber.App { + app := fiber.New(fiber.Config{ + ErrorHandler: func(c *fiber.Ctx, err error) error { + log.Printf("[api.ErrorHandler] %s", err.Error()) + return fiber.DefaultErrorHandler(c, err) + }, + }) + if os.Getenv("ENVIRONMENT") == "dev" { + app.Use(cors.New(cors.Config{ + AllowOrigins: "http://localhost:8081", + AllowCredentials: true, + })) + } + apiRouter := app.Group("/api") + protectedAPIRouter := apiRouter.Group("/") + unprotectedAPIRouter := apiRouter.Group("/") + if cfg.Metrics { + metricsHandler := promhttp.InstrumentMetricHandler(prometheus.DefaultRegisterer, promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ + DisableCompression: true, + })) + app.Get("/metrics", adaptor.HTTPHandler(metricsHandler)) + } + // Security (ORDER IS IMPORTANT: middlewares must be applied before the routes are registered) + if cfg.Security != nil { + if err := cfg.Security.RegisterHandlers(app); err != nil { + panic(err) + } + if err := cfg.Security.ApplySecurityMiddleware(protectedAPIRouter); err != nil { + panic(err) + } + } + // Middlewares + app.Use(recover.New()) + app.Use(compress.New()) + // Routes + handler := ConfigHandler{securityConfig: cfg.Security} + unprotectedAPIRouter.Get("/v1/config", handler.GetConfig) + protectedAPIRouter.Get("/v1/endpoints/statuses", EndpointStatuses(cfg)) + protectedAPIRouter.Get("/v1/endpoints/:key/statuses", EndpointStatus) + unprotectedAPIRouter.Get("/v1/endpoints/:key/health/badge.svg", HealthBadge) + unprotectedAPIRouter.Get("/v1/endpoints/:key/uptimes/:duration/badge.svg", UptimeBadge) + unprotectedAPIRouter.Get("/v1/endpoints/:key/response-times/:duration/badge.svg", ResponseTimeBadge(cfg)) + unprotectedAPIRouter.Get("/v1/endpoints/:key/response-times/:duration/chart.svg", ResponseTimeChart) + // SPA + app.Get("/", SinglePageApplication(cfg.UI)) + app.Get("/endpoints/:name", SinglePageApplication(cfg.UI)) + // Health endpoint + healthHandler := health.Handler().WithJSON(true) + app.Get("/health", func(c *fiber.Ctx) error { + statusCode, body := healthHandler.GetResponseStatusCodeAndBody() + return c.Status(statusCode).Send(body) + }) + // Everything else falls back on static content + app.Use(redirect.New(redirect.Config{ + Rules: map[string]string{ + "/index.html": "/", + }, + StatusCode: 301, + })) + staticFileSystem, err := fs.Sub(static.FileSystem, static.RootPath) + if err != nil { + panic(err) + } + app.Use("/", fiberfs.New(fiberfs.Config{ + Root: http.FS(staticFileSystem), + Index: "index.html", + Browse: true, + })) + return app +} diff --git a/api/api_test.go b/api/api_test.go new file mode 100644 index 00000000..ab4f16ce --- /dev/null +++ b/api/api_test.go @@ -0,0 +1,97 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/TwiN/gatus/v5/config" + "github.com/TwiN/gatus/v5/config/ui" + "github.com/TwiN/gatus/v5/security" + "github.com/gofiber/fiber/v2" +) + +func TestNew(t *testing.T) { + type Scenario struct { + Name string + Path string + ExpectedCode int + Gzip bool + WithSecurity bool + } + scenarios := []Scenario{ + { + Name: "health", + Path: "/health", + ExpectedCode: fiber.StatusOK, + }, + { + Name: "metrics", + Path: "/metrics", + ExpectedCode: fiber.StatusOK, + }, + { + Name: "favicon.ico", + Path: "/favicon.ico", + ExpectedCode: fiber.StatusOK, + }, + { + Name: "app.js", + Path: "/js/app.js", + ExpectedCode: fiber.StatusOK, + }, + { + Name: "app.js-gzipped", + Path: "/js/app.js", + ExpectedCode: fiber.StatusOK, + Gzip: true, + }, + { + Name: "chunk-vendors.js", + Path: "/js/chunk-vendors.js", + ExpectedCode: fiber.StatusOK, + }, + { + Name: "chunk-vendors.js-gzipped", + Path: "/js/chunk-vendors.js", + ExpectedCode: fiber.StatusOK, + Gzip: true, + }, + { + Name: "index", + Path: "/", + ExpectedCode: fiber.StatusOK, + }, + { + Name: "index-html-redirect", + Path: "/index.html", + ExpectedCode: fiber.StatusMovedPermanently, + }, + } + for _, scenario := range scenarios { + t.Run(scenario.Name, func(t *testing.T) { + cfg := &config.Config{Metrics: true, UI: &ui.Config{}} + if scenario.WithSecurity { + cfg.Security = &security.Config{ + Basic: &security.BasicConfig{ + Username: "john.doe", + PasswordBcryptHashBase64Encoded: "JDJhJDA4JDFoRnpPY1hnaFl1OC9ISlFsa21VS09wOGlPU1ZOTDlHZG1qeTFvb3dIckRBUnlHUmNIRWlT", + }, + } + } + api := New(cfg) + router := api.Router() + request := httptest.NewRequest("GET", scenario.Path, http.NoBody) + if scenario.Gzip { + request.Header.Set("Accept-Encoding", "gzip") + } + response, err := router.Test(request) + if err != nil { + t.Fatal(err) + } + if response.StatusCode != scenario.ExpectedCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, response.StatusCode) + } + }) + } +} diff --git a/controller/handler/badge.go b/api/badge.go similarity index 78% rename from controller/handler/badge.go rename to api/badge.go index 7d9af244..24cb3002 100644 --- a/controller/handler/badge.go +++ b/api/badge.go @@ -1,8 +1,7 @@ -package handler +package api import ( "fmt" - "net/http" "strconv" "strings" "time" @@ -11,7 +10,7 @@ import ( "github.com/TwiN/gatus/v5/storage/store" "github.com/TwiN/gatus/v5/storage/store/common" "github.com/TwiN/gatus/v5/storage/store/common/paging" - "github.com/gorilla/mux" + "github.com/gofiber/fiber/v2" ) const ( @@ -35,10 +34,9 @@ var ( // UptimeBadge handles the automatic generation of badge based on the group name and endpoint name passed. // -// Valid values for {duration}: 7d, 24h, 1h -func UptimeBadge(writer http.ResponseWriter, request *http.Request) { - variables := mux.Vars(request) - duration := variables["duration"] +// Valid values for :duration -> 7d, 24h, 1h +func UptimeBadge(c *fiber.Ctx) error { + duration := c.Params("duration") var from time.Time switch duration { case "7d": @@ -48,35 +46,30 @@ func UptimeBadge(writer http.ResponseWriter, request *http.Request) { case "1h": from = time.Now().Add(-2 * time.Hour) // Because uptime metrics are stored by hour, we have to cheat a little default: - http.Error(writer, "Durations supported: 7d, 24h, 1h", http.StatusBadRequest) - return + return c.Status(400).SendString("Durations supported: 7d, 24h, 1h") } - key := variables["key"] + key := c.Params("key") uptime, err := store.Get().GetUptimeByKey(key, from, time.Now()) if err != nil { if err == common.ErrEndpointNotFound { - http.Error(writer, err.Error(), http.StatusNotFound) + return c.Status(404).SendString(err.Error()) } else if err == common.ErrInvalidTimeRange { - http.Error(writer, err.Error(), http.StatusBadRequest) - } else { - http.Error(writer, err.Error(), http.StatusInternalServerError) + return c.Status(400).SendString(err.Error()) } - return + return c.Status(500).SendString(err.Error()) } - writer.Header().Set("Content-Type", "image/svg+xml") - writer.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") - writer.Header().Set("Expires", "0") - writer.WriteHeader(http.StatusOK) - _, _ = writer.Write(generateUptimeBadgeSVG(duration, uptime)) + c.Set("Content-Type", "image/svg+xml") + c.Set("Cache-Control", "no-cache, no-store, must-revalidate") + c.Set("Expires", "0") + return c.Status(200).Send(generateUptimeBadgeSVG(duration, uptime)) } // ResponseTimeBadge handles the automatic generation of badge based on the group name and endpoint name passed. // -// Valid values for {duration}: 7d, 24h, 1h -func ResponseTimeBadge(config *config.Config) http.HandlerFunc { - return func(writer http.ResponseWriter, request *http.Request) { - variables := mux.Vars(request) - duration := variables["duration"] +// Valid values for :duration -> 7d, 24h, 1h +func ResponseTimeBadge(config *config.Config) fiber.Handler { + return func(c *fiber.Ctx) error { + duration := c.Params("duration") var from time.Time switch duration { case "7d": @@ -86,44 +79,37 @@ func ResponseTimeBadge(config *config.Config) http.HandlerFunc { case "1h": from = time.Now().Add(-2 * time.Hour) // Because response time metrics are stored by hour, we have to cheat a little default: - http.Error(writer, "Durations supported: 7d, 24h, 1h", http.StatusBadRequest) - return + return c.Status(400).SendString("Durations supported: 7d, 24h, 1h") } - key := variables["key"] + key := c.Params("key") averageResponseTime, err := store.Get().GetAverageResponseTimeByKey(key, from, time.Now()) if err != nil { if err == common.ErrEndpointNotFound { - http.Error(writer, err.Error(), http.StatusNotFound) + return c.Status(404).SendString(err.Error()) } else if err == common.ErrInvalidTimeRange { - http.Error(writer, err.Error(), http.StatusBadRequest) - } else { - http.Error(writer, err.Error(), http.StatusInternalServerError) + return c.Status(400).SendString(err.Error()) } - return + return c.Status(500).SendString(err.Error()) } - writer.Header().Set("Content-Type", "image/svg+xml") - writer.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") - writer.Header().Set("Expires", "0") - writer.WriteHeader(http.StatusOK) - _, _ = writer.Write(generateResponseTimeBadgeSVG(duration, averageResponseTime, key, config)) + c.Set("Content-Type", "image/svg+xml") + c.Set("Cache-Control", "no-cache, no-store, must-revalidate") + c.Set("Expires", "0") + return c.Status(200).Send(generateResponseTimeBadgeSVG(duration, averageResponseTime, key, config)) } } // HealthBadge handles the automatic generation of badge based on the group name and endpoint name passed. -func HealthBadge(writer http.ResponseWriter, request *http.Request) { - variables := mux.Vars(request) - key := variables["key"] +func HealthBadge(c *fiber.Ctx) error { + key := c.Params("key") pagingConfig := paging.NewEndpointStatusParams() status, err := store.Get().GetEndpointStatusByKey(key, pagingConfig.WithResults(1, 1)) if err != nil { if err == common.ErrEndpointNotFound { - http.Error(writer, err.Error(), http.StatusNotFound) + return c.Status(404).SendString(err.Error()) } else if err == common.ErrInvalidTimeRange { - http.Error(writer, err.Error(), http.StatusBadRequest) - } else { - http.Error(writer, err.Error(), http.StatusInternalServerError) + return c.Status(400).SendString(err.Error()) } - return + return c.Status(500).SendString(err.Error()) } healthStatus := HealthStatusUnknown if len(status.Results) > 0 { @@ -133,11 +119,10 @@ func HealthBadge(writer http.ResponseWriter, request *http.Request) { healthStatus = HealthStatusDown } } - writer.Header().Set("Content-Type", "image/svg+xml") - writer.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") - writer.Header().Set("Expires", "0") - writer.WriteHeader(http.StatusOK) - _, _ = writer.Write(generateHealthBadgeSVG(healthStatus)) + c.Set("Content-Type", "image/svg+xml") + c.Set("Cache-Control", "no-cache, no-store, must-revalidate") + c.Set("Expires", "0") + return c.Status(200).Send(generateHealthBadgeSVG(healthStatus)) } func generateUptimeBadgeSVG(duration string, uptime float64) []byte { diff --git a/controller/handler/badge_test.go b/api/badge_test.go similarity index 96% rename from controller/handler/badge_test.go rename to api/badge_test.go index 4dc7db15..d0ca313c 100644 --- a/controller/handler/badge_test.go +++ b/api/badge_test.go @@ -1,4 +1,4 @@ -package handler +package api import ( "net/http" @@ -31,38 +31,13 @@ func TestBadge(t *testing.T) { }, } - testSuccessfulResult = core.Result{ - Hostname: "example.org", - IP: "127.0.0.1", - HTTPStatus: 200, - Errors: nil, - Connected: true, - Success: true, - Timestamp: timestamp, - Duration: 150 * time.Millisecond, - CertificateExpiration: 10 * time.Hour, - ConditionResults: []*core.ConditionResult{ - { - Condition: "[STATUS] == 200", - Success: true, - }, - { - Condition: "[RESPONSE_TIME] < 500", - Success: true, - }, - { - Condition: "[CERTIFICATE_EXPIRATION] < 72h", - Success: true, - }, - }, - } - cfg.Endpoints[0].UIConfig = ui.GetDefaultConfig() cfg.Endpoints[1].UIConfig = ui.GetDefaultConfig() watchdog.UpdateEndpointStatuses(cfg.Endpoints[0], &core.Result{Success: true, Connected: true, Duration: time.Millisecond, Timestamp: time.Now()}) watchdog.UpdateEndpointStatuses(cfg.Endpoints[1], &core.Result{Success: false, Connected: false, Duration: time.Second, Timestamp: time.Now()}) - router := CreateRouter(cfg) + api := New(cfg) + router := api.Router() type Scenario struct { Name string Path string @@ -153,14 +128,16 @@ func TestBadge(t *testing.T) { } for _, scenario := range scenarios { t.Run(scenario.Name, func(t *testing.T) { - request, _ := http.NewRequest("GET", scenario.Path, http.NoBody) + request := httptest.NewRequest("GET", scenario.Path, http.NoBody) if scenario.Gzip { request.Header.Set("Accept-Encoding", "gzip") } - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != scenario.ExpectedCode { - t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, responseRecorder.Code) + response, err := router.Test(request) + if err != nil { + return + } + if response.StatusCode != scenario.ExpectedCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, response.StatusCode) } }) } @@ -266,6 +243,32 @@ func TestGetBadgeColorFromResponseTime(t *testing.T) { Endpoints: []*core.Endpoint{&firstTestEndpoint, &secondTestEndpoint}, } + testSuccessfulResult := core.Result{ + Hostname: "example.org", + IP: "127.0.0.1", + HTTPStatus: 200, + Errors: nil, + Connected: true, + Success: true, + Timestamp: time.Now(), + Duration: 150 * time.Millisecond, + CertificateExpiration: 10 * time.Hour, + ConditionResults: []*core.ConditionResult{ + { + Condition: "[STATUS] == 200", + Success: true, + }, + { + Condition: "[RESPONSE_TIME] < 500", + Success: true, + }, + { + Condition: "[CERTIFICATE_EXPIRATION] < 72h", + Success: true, + }, + }, + } + store.Get().Insert(&firstTestEndpoint, &testSuccessfulResult) store.Get().Insert(&secondTestEndpoint, &testSuccessfulResult) diff --git a/api/cache.go b/api/cache.go new file mode 100644 index 00000000..a1a58baf --- /dev/null +++ b/api/cache.go @@ -0,0 +1,15 @@ +package api + +import ( + "time" + + "github.com/TwiN/gocache/v2" +) + +const ( + cacheTTL = 10 * time.Second +) + +var ( + cache = gocache.NewCache().WithMaxSize(100).WithEvictionPolicy(gocache.FirstInFirstOut) +) diff --git a/controller/handler/chart.go b/api/chart.go similarity index 74% rename from controller/handler/chart.go rename to api/chart.go index d0fce820..11c634b8 100644 --- a/controller/handler/chart.go +++ b/api/chart.go @@ -1,4 +1,4 @@ -package handler +package api import ( "log" @@ -9,7 +9,7 @@ import ( "github.com/TwiN/gatus/v5/storage/store" "github.com/TwiN/gatus/v5/storage/store/common" - "github.com/gorilla/mux" + "github.com/gofiber/fiber/v2" "github.com/wcharczuk/go-chart/v2" "github.com/wcharczuk/go-chart/v2/drawing" ) @@ -29,9 +29,8 @@ var ( } ) -func ResponseTimeChart(writer http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - duration := vars["duration"] +func ResponseTimeChart(c *fiber.Ctx) error { + duration := c.Params("duration") var from time.Time switch duration { case "7d": @@ -39,23 +38,19 @@ func ResponseTimeChart(writer http.ResponseWriter, r *http.Request) { case "24h": from = time.Now().Truncate(time.Hour).Add(-24 * time.Hour) default: - http.Error(writer, "Durations supported: 7d, 24h", http.StatusBadRequest) - return + return c.Status(400).SendString("Durations supported: 7d, 24h") } - hourlyAverageResponseTime, err := store.Get().GetHourlyAverageResponseTimeByKey(vars["key"], from, time.Now()) + hourlyAverageResponseTime, err := store.Get().GetHourlyAverageResponseTimeByKey(c.Params("key"), from, time.Now()) if err != nil { if err == common.ErrEndpointNotFound { - http.Error(writer, err.Error(), http.StatusNotFound) + return c.Status(404).SendString(err.Error()) } else if err == common.ErrInvalidTimeRange { - http.Error(writer, err.Error(), http.StatusBadRequest) - } else { - http.Error(writer, err.Error(), http.StatusInternalServerError) + return c.Status(400).SendString(err.Error()) } - return + return c.Status(500).SendString(err.Error()) } if len(hourlyAverageResponseTime) == 0 { - http.Error(writer, "", http.StatusNoContent) - return + return c.Status(204).SendString("") } series := chart.TimeSeries{ Name: "Average response time per hour", @@ -111,12 +106,13 @@ func ResponseTimeChart(writer http.ResponseWriter, r *http.Request) { }, Series: []chart.Series{series}, } - writer.Header().Set("Content-Type", "image/svg+xml") - writer.Header().Set("Cache-Control", "no-cache, no-store") - writer.Header().Set("Expires", "0") - writer.WriteHeader(http.StatusOK) - if err := graph.Render(chart.SVG, writer); err != nil { - log.Println("[handler][ResponseTimeChart] Failed to render response time chart:", err.Error()) - return + c.Set("Content-Type", "image/svg+xml") + c.Set("Cache-Control", "no-cache, no-store") + c.Set("Expires", "0") + c.Status(http.StatusOK) + if err := graph.Render(chart.SVG, c); err != nil { + log.Println("[api][ResponseTimeChart] Failed to render response time chart:", err.Error()) + return c.Status(500).SendString(err.Error()) } + return nil } diff --git a/controller/handler/chart_test.go b/api/chart_test.go similarity index 87% rename from controller/handler/chart_test.go rename to api/chart_test.go index d80591b7..341c6725 100644 --- a/controller/handler/chart_test.go +++ b/api/chart_test.go @@ -1,4 +1,4 @@ -package handler +package api import ( "net/http" @@ -30,7 +30,8 @@ func TestResponseTimeChart(t *testing.T) { } watchdog.UpdateEndpointStatuses(cfg.Endpoints[0], &core.Result{Success: true, Duration: time.Millisecond, Timestamp: time.Now()}) watchdog.UpdateEndpointStatuses(cfg.Endpoints[1], &core.Result{Success: false, Duration: time.Second, Timestamp: time.Now()}) - router := CreateRouter(cfg) + api := New(cfg) + router := api.Router() type Scenario struct { Name string Path string @@ -61,14 +62,16 @@ func TestResponseTimeChart(t *testing.T) { } for _, scenario := range scenarios { t.Run(scenario.Name, func(t *testing.T) { - request, _ := http.NewRequest("GET", scenario.Path, http.NoBody) + request := httptest.NewRequest("GET", scenario.Path, http.NoBody) if scenario.Gzip { request.Header.Set("Accept-Encoding", "gzip") } - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != scenario.ExpectedCode { - t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, responseRecorder.Code) + response, err := router.Test(request) + if err != nil { + return + } + if response.StatusCode != scenario.ExpectedCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, response.StatusCode) } }) } diff --git a/api/config.go b/api/config.go new file mode 100644 index 00000000..0b94e92a --- /dev/null +++ b/api/config.go @@ -0,0 +1,25 @@ +package api + +import ( + "fmt" + + "github.com/TwiN/gatus/v5/security" + "github.com/gofiber/fiber/v2" +) + +type ConfigHandler struct { + securityConfig *security.Config +} + +func (handler ConfigHandler) GetConfig(c *fiber.Ctx) error { + hasOIDC := false + isAuthenticated := true // Default to true if no security config is set + if handler.securityConfig != nil { + hasOIDC = handler.securityConfig.OIDC != nil + isAuthenticated = handler.securityConfig.IsAuthenticated(c) + } + // Return the config + c.Set("Content-Type", "application/json") + return c.Status(200). + SendString(fmt.Sprintf(`{"oidc":%v,"authenticated":%v}`, hasOIDC, isAuthenticated)) +} diff --git a/controller/handler/config_test.go b/api/config_test.go similarity index 51% rename from controller/handler/config_test.go rename to api/config_test.go index de69704c..5dab021d 100644 --- a/controller/handler/config_test.go +++ b/api/config_test.go @@ -1,12 +1,12 @@ -package handler +package api import ( + "io" "net/http" - "net/http/httptest" "testing" "github.com/TwiN/gatus/v5/security" - "github.com/gorilla/mux" + "github.com/gofiber/fiber/v2" ) func TestConfigHandler_ServeHTTP(t *testing.T) { @@ -20,15 +20,27 @@ func TestConfigHandler_ServeHTTP(t *testing.T) { } handler := ConfigHandler{securityConfig: securityConfig} // Create a fake router. We're doing this because I need the gate to be initialized. - securityConfig.ApplySecurityMiddleware(mux.NewRouter()) + app := fiber.New() + app.Get("/api/v1/config", handler.GetConfig) + err := securityConfig.ApplySecurityMiddleware(app) + if err != nil { + t.Error("expected err to be nil, but was", err) + } // Test the config handler request, _ := http.NewRequest("GET", "/api/v1/config", http.NoBody) - responseRecorder := httptest.NewRecorder() - handler.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusOK { - t.Error("expected code to be 200, but was", responseRecorder.Code) + response, err := app.Test(request) + if err != nil { + t.Error("expected err to be nil, but was", err) } - if responseRecorder.Body.String() != `{"oidc":true,"authenticated":false}` { - t.Error("expected body to be `{\"oidc\":true,\"authenticated\":false}`, but was", responseRecorder.Body.String()) + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + t.Error("expected code to be 200, but was", response.StatusCode) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Error("expected err to be nil, but was", err) + } + if string(body) != `{"oidc":true,"authenticated":false}` { + t.Error("expected body to be `{\"oidc\":true,\"authenticated\":false}`, but was", string(body)) } } diff --git a/controller/handler/endpoint_status.go b/api/endpoint_status.go similarity index 60% rename from controller/handler/endpoint_status.go rename to api/endpoint_status.go index b75fedc9..14a1d55c 100644 --- a/controller/handler/endpoint_status.go +++ b/api/endpoint_status.go @@ -1,12 +1,10 @@ -package handler +package api import ( "encoding/json" "fmt" "io" "log" - "net/http" - "time" "github.com/TwiN/gatus/v5/client" "github.com/TwiN/gatus/v5/config" @@ -15,32 +13,21 @@ import ( "github.com/TwiN/gatus/v5/storage/store" "github.com/TwiN/gatus/v5/storage/store/common" "github.com/TwiN/gatus/v5/storage/store/common/paging" - "github.com/TwiN/gocache/v2" - "github.com/gorilla/mux" -) - -const ( - cacheTTL = 10 * time.Second -) - -var ( - cache = gocache.NewCache().WithMaxSize(100).WithEvictionPolicy(gocache.FirstInFirstOut) + "github.com/gofiber/fiber/v2" ) // EndpointStatuses handles requests to retrieve all EndpointStatus // Due to how intensive this operation can be on the storage, this function leverages a cache. -func EndpointStatuses(cfg *config.Config) http.HandlerFunc { - return func(writer http.ResponseWriter, r *http.Request) { - page, pageSize := extractPageAndPageSizeFromRequest(r) +func EndpointStatuses(cfg *config.Config) fiber.Handler { + return func(c *fiber.Ctx) error { + page, pageSize := extractPageAndPageSizeFromRequest(c) value, exists := cache.Get(fmt.Sprintf("endpoint-status-%d-%d", page, pageSize)) var data []byte if !exists { - var err error endpointStatuses, err := store.Get().GetAllEndpointStatuses(paging.NewEndpointStatusParams().WithResults(page, pageSize)) if err != nil { - log.Printf("[handler][EndpointStatuses] Failed to retrieve endpoint statuses: %s", err.Error()) - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + log.Printf("[api][EndpointStatuses] Failed to retrieve endpoint statuses: %s", err.Error()) + return c.Status(500).SendString(err.Error()) } // ALPHA: Retrieve endpoint statuses from remote instances if endpointStatusesFromRemote, err := getEndpointStatusesFromRemoteInstances(cfg.Remote); err != nil { @@ -51,17 +38,15 @@ func EndpointStatuses(cfg *config.Config) http.HandlerFunc { // Marshal endpoint statuses to JSON data, err = json.Marshal(endpointStatuses) if err != nil { - log.Printf("[handler][EndpointStatuses] Unable to marshal object to JSON: %s", err.Error()) - http.Error(writer, "unable to marshal object to JSON", http.StatusInternalServerError) - return + log.Printf("[api][EndpointStatuses] Unable to marshal object to JSON: %s", err.Error()) + return c.Status(500).SendString("unable to marshal object to JSON") } cache.SetWithTTL(fmt.Sprintf("endpoint-status-%d-%d", page, pageSize), data, cacheTTL) } else { data = value.([]byte) } - writer.Header().Add("Content-Type", "application/json") - writer.WriteHeader(http.StatusOK) - _, _ = writer.Write(data) + c.Set("Content-Type", "application/json") + return c.Status(200).Send(data) } } @@ -98,31 +83,25 @@ func getEndpointStatusesFromRemoteInstances(remoteConfig *remote.Config) ([]*cor } // EndpointStatus retrieves a single core.EndpointStatus by group and endpoint name -func EndpointStatus(writer http.ResponseWriter, r *http.Request) { - page, pageSize := extractPageAndPageSizeFromRequest(r) - vars := mux.Vars(r) - endpointStatus, err := store.Get().GetEndpointStatusByKey(vars["key"], paging.NewEndpointStatusParams().WithResults(page, pageSize).WithEvents(1, common.MaximumNumberOfEvents)) +func EndpointStatus(c *fiber.Ctx) error { + page, pageSize := extractPageAndPageSizeFromRequest(c) + endpointStatus, err := store.Get().GetEndpointStatusByKey(c.Params("key"), paging.NewEndpointStatusParams().WithResults(page, pageSize).WithEvents(1, common.MaximumNumberOfEvents)) if err != nil { if err == common.ErrEndpointNotFound { - http.Error(writer, err.Error(), http.StatusNotFound) - return + return c.Status(404).SendString(err.Error()) } - log.Printf("[handler][EndpointStatus] Failed to retrieve endpoint status: %s", err.Error()) - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + log.Printf("[api][EndpointStatus] Failed to retrieve endpoint status: %s", err.Error()) + return c.Status(500).SendString(err.Error()) } - if endpointStatus == nil { - log.Printf("[handler][EndpointStatus] Endpoint with key=%s not found", vars["key"]) - http.Error(writer, "not found", http.StatusNotFound) - return + if endpointStatus == nil { // XXX: is this check necessary? + log.Printf("[api][EndpointStatus] Endpoint with key=%s not found", c.Params("key")) + return c.Status(404).SendString("not found") } output, err := json.Marshal(endpointStatus) if err != nil { - log.Printf("[handler][EndpointStatus] Unable to marshal object to JSON: %s", err.Error()) - http.Error(writer, "unable to marshal object to JSON", http.StatusInternalServerError) - return + log.Printf("[api][EndpointStatus] Unable to marshal object to JSON: %s", err.Error()) + return c.Status(500).SendString("unable to marshal object to JSON") } - writer.Header().Add("Content-Type", "application/json") - writer.WriteHeader(http.StatusOK) - _, _ = writer.Write(output) + c.Set("Content-Type", "application/json") + return c.Status(200).Send(output) } diff --git a/controller/handler/endpoint_status_test.go b/api/endpoint_status_test.go similarity index 90% rename from controller/handler/endpoint_status_test.go rename to api/endpoint_status_test.go index 608bc624..b88a915f 100644 --- a/controller/handler/endpoint_status_test.go +++ b/api/endpoint_status_test.go @@ -1,6 +1,7 @@ -package handler +package api import ( + "io" "net/http" "net/http/httptest" "testing" @@ -97,8 +98,8 @@ func TestEndpointStatus(t *testing.T) { } watchdog.UpdateEndpointStatuses(cfg.Endpoints[0], &core.Result{Success: true, Duration: time.Millisecond, Timestamp: time.Now()}) watchdog.UpdateEndpointStatuses(cfg.Endpoints[1], &core.Result{Success: false, Duration: time.Second, Timestamp: time.Now()}) - router := CreateRouter(cfg) - + api := New(cfg) + router := api.Router() type Scenario struct { Name string Path string @@ -130,14 +131,16 @@ func TestEndpointStatus(t *testing.T) { } for _, scenario := range scenarios { t.Run(scenario.Name, func(t *testing.T) { - request, _ := http.NewRequest("GET", scenario.Path, http.NoBody) + request := httptest.NewRequest("GET", scenario.Path, http.NoBody) if scenario.Gzip { request.Header.Set("Accept-Encoding", "gzip") } - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != scenario.ExpectedCode { - t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, responseRecorder.Code) + response, err := router.Test(request) + if err != nil { + return + } + if response.StatusCode != scenario.ExpectedCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, response.StatusCode) } }) } @@ -153,8 +156,8 @@ func TestEndpointStatuses(t *testing.T) { // Can't be bothered dealing with timezone issues on the worker that runs the automated tests firstResult.Timestamp = time.Time{} secondResult.Timestamp = time.Time{} - router := CreateRouter(&config.Config{Metrics: true}) - + api := New(&config.Config{Metrics: true}) + router := api.Router() type Scenario struct { Name string Path string @@ -196,15 +199,21 @@ func TestEndpointStatuses(t *testing.T) { for _, scenario := range scenarios { t.Run(scenario.Name, func(t *testing.T) { - request, _ := http.NewRequest("GET", scenario.Path, http.NoBody) - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != scenario.ExpectedCode { - t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, responseRecorder.Code) + request := httptest.NewRequest("GET", scenario.Path, http.NoBody) + response, err := router.Test(request) + if err != nil { + return } - output := responseRecorder.Body.String() - if output != scenario.ExpectedBody { - t.Errorf("expected:\n %s\n\ngot:\n %s", scenario.ExpectedBody, output) + defer response.Body.Close() + if response.StatusCode != scenario.ExpectedCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, response.StatusCode) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Error("expected err to be nil, but was", err) + } + if string(body) != scenario.ExpectedBody { + t.Errorf("expected:\n %s\n\ngot:\n %s", scenario.ExpectedBody, string(body)) } }) } diff --git a/api/spa.go b/api/spa.go new file mode 100644 index 00000000..77c29f53 --- /dev/null +++ b/api/spa.go @@ -0,0 +1,30 @@ +package api + +import ( + _ "embed" + "html/template" + "log" + + "github.com/TwiN/gatus/v5/config/ui" + static "github.com/TwiN/gatus/v5/web" + "github.com/gofiber/fiber/v2" +) + +func SinglePageApplication(ui *ui.Config) fiber.Handler { + return func(c *fiber.Ctx) error { + t, err := template.ParseFS(static.FileSystem, static.IndexPath) + if err != nil { + // This should never happen, because ui.ValidateAndSetDefaults validates that the template works. + log.Println("[api][SinglePageApplication] Failed to parse template. This should never happen, because the template is validated on start. Error:", err.Error()) + return c.Status(500).SendString("Failed to parse template. This should never happen, because the template is validated on start.") + } + c.Set("Content-Type", "text/html") + err = t.Execute(c, ui) + if err != nil { + // This should never happen, because ui.ValidateAndSetDefaults validates that the template works. + log.Println("[api][SinglePageApplication] Failed to execute template. This should never happen, because the template is validated on start. Error:", err.Error()) + return c.Status(500).SendString("Failed to parse template. This should never happen, because the template is validated on start.") + } + return c.SendStatus(200) + } +} diff --git a/controller/handler/spa_test.go b/api/spa_test.go similarity index 68% rename from controller/handler/spa_test.go rename to api/spa_test.go index f900d4a1..250ecdde 100644 --- a/controller/handler/spa_test.go +++ b/api/spa_test.go @@ -1,12 +1,15 @@ -package handler +package api import ( + "io" "net/http" "net/http/httptest" + "strings" "testing" "time" "github.com/TwiN/gatus/v5/config" + "github.com/TwiN/gatus/v5/config/ui" "github.com/TwiN/gatus/v5/core" "github.com/TwiN/gatus/v5/storage/store" "github.com/TwiN/gatus/v5/watchdog" @@ -27,10 +30,14 @@ func TestSinglePageApplication(t *testing.T) { Group: "core", }, }, + UI: &ui.Config{ + Title: "example-title", + }, } watchdog.UpdateEndpointStatuses(cfg.Endpoints[0], &core.Result{Success: true, Duration: time.Millisecond, Timestamp: time.Now()}) watchdog.UpdateEndpointStatuses(cfg.Endpoints[1], &core.Result{Success: false, Duration: time.Second, Timestamp: time.Now()}) - router := CreateRouter(cfg) + api := New(cfg) + router := api.Router() type Scenario struct { Name string Path string @@ -41,24 +48,31 @@ func TestSinglePageApplication(t *testing.T) { { Name: "frontend-home", Path: "/", - ExpectedCode: http.StatusOK, + ExpectedCode: 200, }, { Name: "frontend-endpoint", Path: "/endpoints/core_frontend", - ExpectedCode: http.StatusOK, + ExpectedCode: 200, }, } for _, scenario := range scenarios { t.Run(scenario.Name, func(t *testing.T) { - request, _ := http.NewRequest("GET", scenario.Path, http.NoBody) + request := httptest.NewRequest("GET", scenario.Path, http.NoBody) if scenario.Gzip { request.Header.Set("Accept-Encoding", "gzip") } - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != scenario.ExpectedCode { - t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, responseRecorder.Code) + response, err := router.Test(request) + if err != nil { + return + } + defer response.Body.Close() + if response.StatusCode != scenario.ExpectedCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, response.StatusCode) + } + body, _ := io.ReadAll(response.Body) + if !strings.Contains(string(body), cfg.UI.Title) { + t.Errorf("%s %s should have contained the title", request.Method, request.URL) } }) } diff --git a/controller/handler/util.go b/api/util.go similarity index 76% rename from controller/handler/util.go rename to api/util.go index 5ffd5cb2..90e647b5 100644 --- a/controller/handler/util.go +++ b/api/util.go @@ -1,10 +1,10 @@ -package handler +package api import ( - "net/http" "strconv" "github.com/TwiN/gatus/v5/storage/store/common" + "github.com/gofiber/fiber/v2" ) const ( @@ -18,9 +18,9 @@ const ( MaximumPageSize = common.MaximumNumberOfResults ) -func extractPageAndPageSizeFromRequest(r *http.Request) (page int, pageSize int) { +func extractPageAndPageSizeFromRequest(c *fiber.Ctx) (page, pageSize int) { var err error - if pageParameter := r.URL.Query().Get("page"); len(pageParameter) == 0 { + if pageParameter := c.Query("page"); len(pageParameter) == 0 { page = DefaultPage } else { page, err = strconv.Atoi(pageParameter) @@ -31,7 +31,7 @@ func extractPageAndPageSizeFromRequest(r *http.Request) (page int, pageSize int) page = DefaultPage } } - if pageSizeParameter := r.URL.Query().Get("pageSize"); len(pageSizeParameter) == 0 { + if pageSizeParameter := c.Query("pageSize"); len(pageSizeParameter) == 0 { pageSize = DefaultPageSize } else { pageSize, err = strconv.Atoi(pageSizeParameter) diff --git a/controller/handler/util_test.go b/api/util_test.go similarity index 77% rename from controller/handler/util_test.go rename to api/util_test.go index c91b72e1..652146ed 100644 --- a/controller/handler/util_test.go +++ b/api/util_test.go @@ -1,9 +1,11 @@ -package handler +package api import ( "fmt" - "net/http" "testing" + + "github.com/gofiber/fiber/v2" + "github.com/valyala/fasthttp" ) func TestExtractPageAndPageSizeFromRequest(t *testing.T) { @@ -54,8 +56,12 @@ func TestExtractPageAndPageSizeFromRequest(t *testing.T) { } for _, scenario := range scenarios { t.Run("page-"+scenario.Page+"-pageSize-"+scenario.PageSize, func(t *testing.T) { - request, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/statuses?page=%s&pageSize=%s", scenario.Page, scenario.PageSize), http.NoBody) - actualPage, actualPageSize := extractPageAndPageSizeFromRequest(request) + //request := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/statuses?page=%s&pageSize=%s", scenario.Page, scenario.PageSize), http.NoBody) + app := fiber.New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().SetRequestURI(fmt.Sprintf("/api/v1/statuses?page=%s&pageSize=%s", scenario.Page, scenario.PageSize)) + actualPage, actualPageSize := extractPageAndPageSizeFromRequest(c) if actualPage != scenario.ExpectedPage { t.Errorf("expected %d, got %d", scenario.ExpectedPage, actualPage) } diff --git a/config/endpoints/README.md b/config/endpoints/README.md new file mode 100644 index 00000000..9416ef98 --- /dev/null +++ b/config/endpoints/README.md @@ -0,0 +1 @@ +TODO: move files from core to here. \ No newline at end of file diff --git a/controller/controller.go b/controller/controller.go index 4d152ddc..cc20be89 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -1,53 +1,43 @@ package controller import ( - "context" - "fmt" "log" - "net/http" "os" "time" + "github.com/TwiN/gatus/v5/api" "github.com/TwiN/gatus/v5/config" - "github.com/TwiN/gatus/v5/controller/handler" + "github.com/gofiber/fiber/v2" ) var ( - // server is the http.Server created by Handle. - // The only reason it exists is for testing purposes. - server *http.Server + app *fiber.App ) // Handle creates the router and starts the server func Handle(cfg *config.Config) { - var router http.Handler = handler.CreateRouter(cfg) - if os.Getenv("ENVIRONMENT") == "dev" { - router = handler.DevelopmentCORS(router) - } - tlsConfig := cfg.Web.TLSConfig() - server = &http.Server{ - Addr: fmt.Sprintf("%s:%d", cfg.Web.Address, cfg.Web.Port), - TLSConfig: tlsConfig, - Handler: router, - ReadTimeout: 15 * time.Second, - WriteTimeout: 15 * time.Second, - IdleTimeout: 15 * time.Second, - } - log.Println("[controller][Handle] Listening on " + cfg.Web.SocketAddress()) + api := api.New(cfg) + app = api.Router() + server := app.Server() + server.ReadTimeout = 15 * time.Second + server.WriteTimeout = 15 * time.Second + server.IdleTimeout = 15 * time.Second + server.TLSConfig = cfg.Web.TLSConfig() if os.Getenv("ROUTER_TEST") == "true" { return } - if tlsConfig != nil { - log.Println("[controller][Handle]", server.ListenAndServeTLS("", "")) + log.Println("[controller][Handle] Listening on " + cfg.Web.SocketAddress()) + if server.TLSConfig != nil { + log.Println("[controller][Handle]", app.ListenTLS(cfg.Web.SocketAddress(), "", "")) } else { - log.Println("[controller][Handle]", server.ListenAndServe()) + log.Println("[controller][Handle]", app.Listen(cfg.Web.SocketAddress())) } } // Shutdown stops the server func Shutdown() { - if server != nil { - _ = server.Shutdown(context.TODO()) - server = nil + if app != nil { + _ = app.Shutdown() + app = nil } } diff --git a/controller/controller_test.go b/controller/controller_test.go index ca398c20..62a28297 100644 --- a/controller/controller_test.go +++ b/controller/controller_test.go @@ -10,6 +10,7 @@ import ( "github.com/TwiN/gatus/v5/config" "github.com/TwiN/gatus/v5/config/web" "github.com/TwiN/gatus/v5/core" + "github.com/gofiber/fiber/v2" ) func TestHandle(t *testing.T) { @@ -34,13 +35,15 @@ func TestHandle(t *testing.T) { defer os.Clearenv() Handle(cfg) defer Shutdown() - request, _ := http.NewRequest("GET", "/health", http.NoBody) - responseRecorder := httptest.NewRecorder() - server.Handler.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusOK { + request := httptest.NewRequest("GET", "/health", http.NoBody) + response, err := app.Test(request) + if err != nil { + t.Fatal(err) + } + if response.StatusCode != 200 { t.Error("expected GET /health to return status code 200") } - if server == nil { + if app == nil { t.Fatal("server should've been set (but because we set ROUTER_TEST, it shouldn't have been started)") } } @@ -74,13 +77,15 @@ func TestHandleTLS(t *testing.T) { defer os.Clearenv() Handle(cfg) defer Shutdown() - request, _ := http.NewRequest("GET", "/health", http.NoBody) - responseRecorder := httptest.NewRecorder() - server.Handler.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != scenario.expectedStatusCode { - t.Errorf("expected GET /health to return status code %d, got %d", scenario.expectedStatusCode, responseRecorder.Code) + request := httptest.NewRequest("GET", "/health", http.NoBody) + response, err := app.Test(request) + if err != nil { + t.Fatal(err) } - if server == nil { + if response.StatusCode != scenario.expectedStatusCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.expectedStatusCode, response.StatusCode) + } + if app == nil { t.Fatal("server should've been set (but because we set ROUTER_TEST, it shouldn't have been started)") } }) @@ -89,9 +94,9 @@ func TestHandleTLS(t *testing.T) { func TestShutdown(t *testing.T) { // Pretend that we called controller.Handle(), which initializes the server variable - server = &http.Server{} + app = fiber.New() Shutdown() - if server != nil { + if app != nil { t.Error("server should've been shut down") } } diff --git a/controller/handler/config.go b/controller/handler/config.go deleted file mode 100644 index 1d626a53..00000000 --- a/controller/handler/config.go +++ /dev/null @@ -1,26 +0,0 @@ -package handler - -import ( - "fmt" - "net/http" - - "github.com/TwiN/gatus/v5/security" -) - -// ConfigHandler is a handler that returns information for the front end of the application. -type ConfigHandler struct { - securityConfig *security.Config -} - -func (handler ConfigHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - hasOIDC := false - isAuthenticated := true // Default to true if no security config is set - if handler.securityConfig != nil { - hasOIDC = handler.securityConfig.OIDC != nil - isAuthenticated = handler.securityConfig.IsAuthenticated(r) - } - // Return the config - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(fmt.Sprintf(`{"oidc":%v,"authenticated":%v}`, hasOIDC, isAuthenticated))) -} diff --git a/controller/handler/cors.go b/controller/handler/cors.go deleted file mode 100644 index f7e0b9b9..00000000 --- a/controller/handler/cors.go +++ /dev/null @@ -1,14 +0,0 @@ -package handler - -import "net/http" - -func DevelopmentCORS(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Origin", "http://localhost:8081") - if r.Method == "OPTIONS" { - return - } - next.ServeHTTP(w, r) - }) -} diff --git a/controller/handler/gzip.go b/controller/handler/gzip.go deleted file mode 100644 index 03664795..00000000 --- a/controller/handler/gzip.go +++ /dev/null @@ -1,58 +0,0 @@ -package handler - -import ( - "compress/gzip" - "io" - "net/http" - "strings" - "sync" -) - -var gzPool = sync.Pool{ - New: func() interface{} { - return gzip.NewWriter(io.Discard) - }, -} - -type gzipResponseWriter struct { - io.Writer - http.ResponseWriter -} - -// WriteHeader sends an HTTP response header with the provided status code. -// It also deletes the Content-Length header, since the GZIP compression may modify the size of the payload -func (w *gzipResponseWriter) WriteHeader(status int) { - w.Header().Del("Content-Length") - w.ResponseWriter.WriteHeader(status) -} - -// Write writes len(b) bytes from b to the underlying data stream. -func (w *gzipResponseWriter) Write(b []byte) (int, error) { - return w.Writer.Write(b) -} - -// GzipHandler compresses the response of a given http.Handler if the request's headers specify that the client -// supports gzip encoding -func GzipHandler(next http.Handler) http.Handler { - return GzipHandlerFunc(func(writer http.ResponseWriter, r *http.Request) { - next.ServeHTTP(writer, r) - }) -} - -// GzipHandlerFunc compresses the response of a given http.HandlerFunc if the request's headers specify that the client -// supports gzip encoding -func GzipHandlerFunc(next http.HandlerFunc) http.HandlerFunc { - return func(writer http.ResponseWriter, r *http.Request) { - // If the request doesn't specify that it supports gzip, then don't compress it - if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - next.ServeHTTP(writer, r) - return - } - writer.Header().Set("Content-Encoding", "gzip") - gz := gzPool.Get().(*gzip.Writer) - defer gzPool.Put(gz) - gz.Reset(writer) - defer gz.Close() - next.ServeHTTP(&gzipResponseWriter{ResponseWriter: writer, Writer: gz}, r) - } -} diff --git a/controller/handler/handler.go b/controller/handler/handler.go deleted file mode 100644 index 7d2a783b..00000000 --- a/controller/handler/handler.go +++ /dev/null @@ -1,54 +0,0 @@ -package handler - -import ( - "io/fs" - "net/http" - - "github.com/TwiN/gatus/v5/config" - static "github.com/TwiN/gatus/v5/web" - "github.com/TwiN/health" - "github.com/gorilla/mux" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -func CreateRouter(cfg *config.Config) *mux.Router { - router := mux.NewRouter() - if cfg.Metrics { - router.Handle("/metrics", promhttp.InstrumentMetricHandler(prometheus.DefaultRegisterer, promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ - DisableCompression: true, - }))).Methods("GET") - } - router.Use(GzipHandler) - api := router.PathPrefix("/api").Subrouter() - protected := api.PathPrefix("/").Subrouter() - unprotected := api.PathPrefix("/").Subrouter() - if cfg.Security != nil { - if err := cfg.Security.RegisterHandlers(router); err != nil { - panic(err) - } - if err := cfg.Security.ApplySecurityMiddleware(protected); err != nil { - panic(err) - } - } - // Endpoints - unprotected.Handle("/v1/config", ConfigHandler{securityConfig: cfg.Security}).Methods("GET") - protected.HandleFunc("/v1/endpoints/statuses", EndpointStatuses(cfg)).Methods("GET") - protected.HandleFunc("/v1/endpoints/{key}/statuses", EndpointStatus).Methods("GET") - unprotected.HandleFunc("/v1/endpoints/{key}/health/badge.svg", HealthBadge).Methods("GET") - unprotected.HandleFunc("/v1/endpoints/{key}/uptimes/{duration}/badge.svg", UptimeBadge).Methods("GET") - unprotected.HandleFunc("/v1/endpoints/{key}/response-times/{duration}/badge.svg", ResponseTimeBadge(cfg)).Methods("GET") - unprotected.HandleFunc("/v1/endpoints/{key}/response-times/{duration}/chart.svg", ResponseTimeChart).Methods("GET") - // Misc - router.Handle("/health", health.Handler().WithJSON(true)).Methods("GET") - // SPA - router.HandleFunc("/endpoints/{name}", SinglePageApplication(cfg.UI)).Methods("GET") - router.HandleFunc("/", SinglePageApplication(cfg.UI)).Methods("GET") - // Everything else falls back on static content - staticFileSystem, err := fs.Sub(static.FileSystem, static.RootPath) - if err != nil { - panic(err) - } - router.PathPrefix("/").Handler(http.FileServer(http.FS(staticFileSystem))) - return router -} diff --git a/controller/handler/handler_test.go b/controller/handler/handler_test.go deleted file mode 100644 index fee3b4af..00000000 --- a/controller/handler/handler_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package handler - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/TwiN/gatus/v5/config" -) - -func TestCreateRouter(t *testing.T) { - router := CreateRouter(&config.Config{Metrics: true}) - type Scenario struct { - Name string - Path string - ExpectedCode int - Gzip bool - } - scenarios := []Scenario{ - { - Name: "health", - Path: "/health", - ExpectedCode: http.StatusOK, - }, - { - Name: "metrics", - Path: "/metrics", - ExpectedCode: http.StatusOK, - }, - { - Name: "favicon.ico", - Path: "/favicon.ico", - ExpectedCode: http.StatusOK, - }, - { - Name: "app.js", - Path: "/js/app.js", - ExpectedCode: http.StatusOK, - }, - { - Name: "app.js-gzipped", - Path: "/js/app.js", - ExpectedCode: http.StatusOK, - Gzip: true, - }, - { - Name: "chunk-vendors.js", - Path: "/js/chunk-vendors.js", - ExpectedCode: http.StatusOK, - }, - { - Name: "chunk-vendors.js-gzipped", - Path: "/js/chunk-vendors.js", - ExpectedCode: http.StatusOK, - Gzip: true, - }, - { - Name: "index-redirect", - Path: "/index.html", - ExpectedCode: http.StatusMovedPermanently, - }, - } - for _, scenario := range scenarios { - t.Run(scenario.Name, func(t *testing.T) { - request, _ := http.NewRequest("GET", scenario.Path, http.NoBody) - if scenario.Gzip { - request.Header.Set("Accept-Encoding", "gzip") - } - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != scenario.ExpectedCode { - t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, scenario.ExpectedCode, responseRecorder.Code) - } - }) - } -} diff --git a/controller/handler/spa.go b/controller/handler/spa.go deleted file mode 100644 index 6d702d46..00000000 --- a/controller/handler/spa.go +++ /dev/null @@ -1,31 +0,0 @@ -package handler - -import ( - _ "embed" - "html/template" - "log" - "net/http" - - "github.com/TwiN/gatus/v5/config/ui" - static "github.com/TwiN/gatus/v5/web" -) - -func SinglePageApplication(ui *ui.Config) http.HandlerFunc { - return func(writer http.ResponseWriter, request *http.Request) { - t, err := template.ParseFS(static.FileSystem, static.IndexPath) - if err != nil { - // This should never happen, because ui.ValidateAndSetDefaults validates that the template works. - log.Println("[handler][SinglePageApplication] Failed to parse template. This should never happen, because the template is validated on start. Error:", err.Error()) - http.Error(writer, "Failed to parse template. This should never happen, because the template is validated on start.", http.StatusInternalServerError) - return - } - writer.Header().Set("Content-Type", "text/html") - err = t.Execute(writer, ui) - if err != nil { - // This should never happen, because ui.ValidateAndSetDefaults validates that the template works. - log.Println("[handler][SinglePageApplication] Failed to execute template. This should never happen, because the template is validated on start. Error:", err.Error()) - http.Error(writer, "Failed to execute template. This should never happen, because the template is validated on start.", http.StatusInternalServerError) - return - } - } -} diff --git a/go.mod b/go.mod index 5bc07281..0f9e2ba9 100644 --- a/go.mod +++ b/go.mod @@ -9,14 +9,15 @@ require ( github.com/TwiN/health v1.6.0 github.com/TwiN/whois v1.1.2 github.com/coreos/go-oidc/v3 v3.6.0 + github.com/gofiber/fiber/v2 v2.46.0 github.com/google/go-github/v48 v48.2.0 github.com/google/uuid v1.3.0 - github.com/gorilla/mux v1.8.0 github.com/ishidawataru/sctp v0.0.0-20210707070123-9a39160e9062 github.com/lib/pq v1.10.7 github.com/miekg/dns v1.1.54 github.com/prometheus-community/pro-bing v0.2.0 github.com/prometheus/client_golang v1.14.0 + github.com/valyala/fasthttp v1.47.0 github.com/wcharczuk/go-chart/v2 v2.1.0 golang.org/x/crypto v0.10.0 golang.org/x/oauth2 v0.8.0 @@ -26,6 +27,7 @@ require ( ) require ( + github.com/andybalholm/brotli v1.0.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -35,18 +37,28 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/mattn/go-isatty v0.0.16 // indirect + github.com/klauspost/compress v1.16.5 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-runewidth v0.0.14 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect + github.com/philhofer/fwd v1.1.2 // indirect github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rivo/uniseg v0.4.4 // indirect + github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 // indirect + github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect + github.com/tinylib/msgp v1.1.8 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/tcplisten v1.0.0 // indirect golang.org/x/image v0.5.0 // indirect golang.org/x/mod v0.7.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sync v0.2.0 // indirect golang.org/x/sys v0.9.0 // indirect - golang.org/x/tools v0.3.0 // indirect + golang.org/x/tools v0.4.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect diff --git a/go.sum b/go.sum index 4159de44..625790fd 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,8 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= +github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -86,6 +88,8 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofiber/fiber/v2 v2.46.0 h1:wkkWotblsGVlLjXj2dpgKQAYHtXumsK/HyFugQM68Ns= +github.com/gofiber/fiber/v2 v2.46.0/go.mod h1:DNl0/c37WLe0g92U6lx1VMQuxGUQY5V7EIaVoEsUffc= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= @@ -150,8 +154,6 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= @@ -169,6 +171,8 @@ github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= +github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -179,8 +183,13 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -193,6 +202,9 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= +github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw= +github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -229,7 +241,15 @@ github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0ua github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= +github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= +github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3D/WJsDd1iXHT96alCoN2KJo6/4x1DZC3wZs8= +github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4= +github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= +github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= @@ -240,11 +260,21 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= +github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= +github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.47.0 h1:y7moDoxYzMooFpT5aHgNgVOQDrS3qlkfiP9mDtGGK9c= +github.com/valyala/fasthttp v1.47.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/wcharczuk/go-chart/v2 v2.1.0 h1:tY2slqVQ6bN+yHSnDYwZebLQFkphK4WNrVwnt7CJZ2I= github.com/wcharczuk/go-chart/v2 v2.1.0/go.mod h1:yx7MvAVNcP/kN9lKXM/NTce4au4DFN99j6i1OwDclNA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= @@ -325,11 +355,13 @@ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -349,8 +381,10 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -384,6 +418,7 @@ golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -394,10 +429,13 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -405,6 +443,7 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -449,9 +488,10 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.3.0 h1:SrNbZl6ECOS1qFzgTdQfWXZM9XBkiA6tkFrH9YSTPHM= -golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= +golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= +golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/security/config.go b/security/config.go index 6215bd28..e767be20 100644 --- a/security/config.go +++ b/security/config.go @@ -2,10 +2,13 @@ package security import ( "encoding/base64" + "log" "net/http" g8 "github.com/TwiN/g8/v2" - "github.com/gorilla/mux" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" + "github.com/gofiber/fiber/v2/middleware/basicauth" "golang.org/x/crypto/bcrypt" ) @@ -29,20 +32,20 @@ func (c *Config) IsValid() bool { } // RegisterHandlers registers all handlers required based on the security configuration -func (c *Config) RegisterHandlers(router *mux.Router) error { +func (c *Config) RegisterHandlers(router fiber.Router) error { if c.OIDC != nil { if err := c.OIDC.initialize(); err != nil { return err } - router.HandleFunc("/oidc/login", c.OIDC.loginHandler) - router.HandleFunc("/authorization-code/callback", c.OIDC.callbackHandler) + router.All("/oidc/login", c.OIDC.loginHandler) + router.All("/authorization-code/callback", adaptor.HTTPHandlerFunc(c.OIDC.callbackHandler)) } return nil } // ApplySecurityMiddleware applies an authentication middleware to the router passed. -// The router passed should be a subrouter in charge of handlers that require authentication. -func (c *Config) ApplySecurityMiddleware(api *mux.Router) error { +// The router passed should be a sub-router in charge of handlers that require authentication. +func (c *Config) ApplySecurityMiddleware(router fiber.Router) error { if c.OIDC != nil { // We're going to use g8 for session handling clientProvider := g8.NewClientProvider(func(token string) *g8.Client { @@ -61,7 +64,7 @@ func (c *Config) ApplySecurityMiddleware(api *mux.Router) error { // TODO: g8: Add a way to update cookie after? would need the writer authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider) c.gate = g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) - api.Use(c.gate.Protect) + router.Use(adaptor.HTTPMiddleware(c.gate.Protect)) } else if c.Basic != nil { var decodedBcryptHash []byte if len(c.Basic.PasswordBcryptHashBase64Encoded) > 0 { @@ -71,29 +74,35 @@ func (c *Config) ApplySecurityMiddleware(api *mux.Router) error { return err } } - api.Use(func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - usernameEntered, passwordEntered, ok := r.BasicAuth() + router.Use(basicauth.New(basicauth.Config{ + Authorizer: func(username, password string) bool { if len(c.Basic.PasswordBcryptHashBase64Encoded) > 0 { - if !ok || usernameEntered != c.Basic.Username || bcrypt.CompareHashAndPassword(decodedBcryptHash, []byte(passwordEntered)) != nil { - w.Header().Set("WWW-Authenticate", "Basic") - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte("Unauthorized")) - return + if username != c.Basic.Username || bcrypt.CompareHashAndPassword(decodedBcryptHash, []byte(password)) != nil { + return false } } - handler.ServeHTTP(w, r) - }) - }) + return true + }, + Unauthorized: func(ctx *fiber.Ctx) error { + ctx.Set("WWW-Authenticate", "Basic") + return ctx.Status(401).SendString("Unauthorized") + }, + })) } return nil } // IsAuthenticated checks whether the user is authenticated // If the Config does not warrant authentication, it will always return true. -func (c *Config) IsAuthenticated(r *http.Request) bool { +func (c *Config) IsAuthenticated(ctx *fiber.Ctx) bool { if c.gate != nil { - token := c.gate.ExtractTokenFromRequest(r) + // TODO: Update g8 to support fasthttp natively? (see g8's fasthttp branch) + request, err := adaptor.ConvertRequest(ctx, false) + if err != nil { + log.Printf("[IsAuthenticated] Unexpected error converting request: %v", err) + return false + } + token := c.gate.ExtractTokenFromRequest(request) _, hasSession := sessions.Get(token) return hasSession } diff --git a/security/config_test.go b/security/config_test.go index 26f534ea..45fb245b 100644 --- a/security/config_test.go +++ b/security/config_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/gorilla/mux" + "github.com/gofiber/fiber/v2" "golang.org/x/oauth2" ) @@ -23,83 +23,96 @@ func TestConfig_ApplySecurityMiddleware(t *testing.T) { /////////// // BASIC // /////////// - // Bcrypt - c := &Config{Basic: &BasicConfig{ - Username: "john.doe", - PasswordBcryptHashBase64Encoded: "JDJhJDA4JDFoRnpPY1hnaFl1OC9ISlFsa21VS09wOGlPU1ZOTDlHZG1qeTFvb3dIckRBUnlHUmNIRWlT", - }} - api := mux.NewRouter() - api.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + t.Run("basic", func(t *testing.T) { + // Bcrypt + c := &Config{Basic: &BasicConfig{ + Username: "john.doe", + PasswordBcryptHashBase64Encoded: "JDJhJDA4JDFoRnpPY1hnaFl1OC9ISlFsa21VS09wOGlPU1ZOTDlHZG1qeTFvb3dIckRBUnlHUmNIRWlT", + }} + app := fiber.New() + if err := c.ApplySecurityMiddleware(app); err != nil { + t.Error("expected no error, got", err) + } + app.Get("/test", func(c *fiber.Ctx) error { + return c.SendStatus(200) + }) + // Try to access the route without basic auth + request := httptest.NewRequest("GET", "/test", http.NoBody) + response, err := app.Test(request) + if err != nil { + t.Fatal("expected no error, got", err) + } + if response.StatusCode != 401 { + t.Error("expected code to be 401, but was", response.StatusCode) + } + // Try again, but with basic auth + request = httptest.NewRequest("GET", "/test", http.NoBody) + request.SetBasicAuth("john.doe", "hunter2") + response, err = app.Test(request) + if err != nil { + t.Fatal("expected no error, got", err) + } + if response.StatusCode != 200 { + t.Error("expected code to be 200, but was", response.StatusCode) + } }) - if err := c.ApplySecurityMiddleware(api); err != nil { - t.Error("expected no error, but was", err) - } - // Try to access the route without basic auth - request, _ := http.NewRequest("GET", "/test", http.NoBody) - responseRecorder := httptest.NewRecorder() - api.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusUnauthorized { - t.Error("expected code to be 401, but was", responseRecorder.Code) - } - // Try again, but with basic auth - request, _ = http.NewRequest("GET", "/test", http.NoBody) - responseRecorder = httptest.NewRecorder() - request.SetBasicAuth("john.doe", "hunter2") - api.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusOK { - t.Error("expected code to be 200, but was", responseRecorder.Code) - } ////////// // OIDC // ////////// - api = mux.NewRouter() - api.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + t.Run("oidc", func(t *testing.T) { + c := &Config{OIDC: &OIDCConfig{ + IssuerURL: "https://sso.gatus.io/", + RedirectURL: "http://localhost:80/authorization-code/callback", + Scopes: []string{"openid"}, + AllowedSubjects: []string{"user1@example.com"}, + oauth2Config: oauth2.Config{}, + verifier: nil, + }} + app := fiber.New() + if err := c.ApplySecurityMiddleware(app); err != nil { + t.Error("expected no error, got", err) + } + app.Get("/test", func(c *fiber.Ctx) error { + return c.SendStatus(200) + }) + // Try without any session cookie + request := httptest.NewRequest("GET", "/test", http.NoBody) + response, err := app.Test(request) + if err != nil { + t.Fatal("expected no error, got", err) + } + if response.StatusCode != 401 { + t.Error("expected code to be 401, but was", response.StatusCode) + } + // Try with a session cookie + request = httptest.NewRequest("GET", "/test", http.NoBody) + request.AddCookie(&http.Cookie{Name: "session", Value: "123"}) + response, err = app.Test(request) + if err != nil { + t.Fatal("expected no error, got", err) + } + if response.StatusCode != 401 { + t.Error("expected code to be 401, but was", response.StatusCode) + } }) - c.OIDC = &OIDCConfig{ - IssuerURL: "https://sso.gatus.io/", - RedirectURL: "http://localhost:80/authorization-code/callback", - Scopes: []string{"openid"}, - AllowedSubjects: []string{"user1@example.com"}, - oauth2Config: oauth2.Config{}, - verifier: nil, - } - c.Basic = nil - if err := c.ApplySecurityMiddleware(api); err != nil { - t.Error("expected no error, but was", err) - } - // Try without any session cookie - request, _ = http.NewRequest("GET", "/test", http.NoBody) - responseRecorder = httptest.NewRecorder() - api.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusUnauthorized { - t.Error("expected code to be 401, but was", responseRecorder.Code) - } - // Try with a session cookie - request, _ = http.NewRequest("GET", "/test", http.NoBody) - request.AddCookie(&http.Cookie{Name: "session", Value: "123"}) - responseRecorder = httptest.NewRecorder() - api.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusUnauthorized { - t.Error("expected code to be 401, but was", responseRecorder.Code) - } } func TestConfig_RegisterHandlers(t *testing.T) { c := &Config{} - router := mux.NewRouter() - c.RegisterHandlers(router) + app := fiber.New() + c.RegisterHandlers(app) // Try to access the OIDC handler. This should fail, because the security config doesn't have OIDC - request, _ := http.NewRequest("GET", "/oidc/login", http.NoBody) - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusNotFound { - t.Error("expected code to be 404, but was", responseRecorder.Code) + request := httptest.NewRequest("GET", "/oidc/login", http.NoBody) + response, err := app.Test(request) + if err != nil { + t.Fatal("expected no error, got", err) + } + if response.StatusCode != 404 { + t.Error("expected code to be 404, but was", response.StatusCode) } // Set an empty OIDC config. This should fail, because the IssuerURL is required. c.OIDC = &OIDCConfig{} - if err := c.RegisterHandlers(router); err == nil { + if err := c.RegisterHandlers(app); err == nil { t.Fatal("expected an error, but got none") } // Set the OIDC config and try again @@ -109,13 +122,15 @@ func TestConfig_RegisterHandlers(t *testing.T) { Scopes: []string{"openid"}, AllowedSubjects: []string{"user1@example.com"}, } - if err := c.RegisterHandlers(router); err != nil { + if err := c.RegisterHandlers(app); err != nil { t.Fatal("expected no error, but got", err) } - request, _ = http.NewRequest("GET", "/oidc/login", http.NoBody) - responseRecorder = httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - if responseRecorder.Code != http.StatusFound { - t.Error("expected code to be 302, but was", responseRecorder.Code) + request = httptest.NewRequest("GET", "/oidc/login", http.NoBody) + response, err = app.Test(request) + if err != nil { + t.Fatal("expected no error, got", err) + } + if response.StatusCode != 302 { + t.Error("expected code to be 302, but was", response.StatusCode) } } diff --git a/security/oidc.go b/security/oidc.go index c854d134..980d3fc5 100644 --- a/security/oidc.go +++ b/security/oidc.go @@ -8,6 +8,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/gofiber/fiber/v2" "github.com/google/uuid" "golang.org/x/oauth2" ) @@ -47,28 +48,28 @@ func (c *OIDCConfig) initialize() error { return nil } -func (c *OIDCConfig) loginHandler(w http.ResponseWriter, r *http.Request) { +func (c *OIDCConfig) loginHandler(ctx *fiber.Ctx) error { state, nonce := uuid.NewString(), uuid.NewString() - http.SetCookie(w, &http.Cookie{ + ctx.Cookie(&fiber.Cookie{ Name: cookieNameState, Value: state, Path: "/", MaxAge: int(time.Hour.Seconds()), - SameSite: http.SameSiteLaxMode, - HttpOnly: true, + SameSite: "lax", + HTTPOnly: true, }) - http.SetCookie(w, &http.Cookie{ + ctx.Cookie(&fiber.Cookie{ Name: cookieNameNonce, Value: nonce, Path: "/", MaxAge: int(time.Hour.Seconds()), - SameSite: http.SameSiteLaxMode, - HttpOnly: true, + SameSite: "lax", + HTTPOnly: true, }) - http.Redirect(w, r, c.oauth2Config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound) + return ctx.Redirect(c.oauth2Config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound) } -func (c *OIDCConfig) callbackHandler(w http.ResponseWriter, r *http.Request) { +func (c *OIDCConfig) callbackHandler(w http.ResponseWriter, r *http.Request) { // TODO: Migrate to a native fiber handler // Check if there's an error if len(r.URL.Query().Get("error")) > 0 { http.Error(w, r.URL.Query().Get("error")+": "+r.URL.Query().Get("error_description"), http.StatusBadRequest)