diff --git a/gateway/handlers/basic_auth.go b/gateway/handlers/basic_auth.go new file mode 100644 index 00000000..8d266c3f --- /dev/null +++ b/gateway/handlers/basic_auth.go @@ -0,0 +1,29 @@ +package handlers + +import ( + "net/http" +) + +// DecorateWithBasicAuth enforces basic auth as a middleware with given credentials +func DecorateWithBasicAuth(next http.HandlerFunc, credentials *BasicAuthCredentials) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + user, password, ok := r.BasicAuth() + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + + if !ok || !(credentials.Password == password && user == credentials.User) { + + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("invalid credentials")) + return + } + + next.ServeHTTP(w, r) + } +} + +// BasicAuthCredentials for credentials +type BasicAuthCredentials struct { + User string + Password string +} diff --git a/gateway/handlers/basic_auth_test.go b/gateway/handlers/basic_auth_test.go new file mode 100644 index 00000000..7514bde7 --- /dev/null +++ b/gateway/handlers/basic_auth_test.go @@ -0,0 +1,63 @@ +package handlers + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func Test_AuthWithValidPassword_Gives200(t *testing.T) { + + handler := func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "Hello World!") + } + w := httptest.NewRecorder() + + wantUser := "admin" + wantPassword := "password" + r := httptest.NewRequest(http.MethodGet, "http://localhost:8080", nil) + r.SetBasicAuth(wantUser, wantPassword) + wantCredentials := &BasicAuthCredentials{ + User: wantUser, + Password: wantPassword, + } + + decorated := DecorateWithBasicAuth(handler, wantCredentials) + decorated.ServeHTTP(w, r) + + wantCode := http.StatusOK + + if w.Code != wantCode { + t.Errorf("status code, want: %d, got: %d", wantCode, w.Code) + t.Fail() + } +} + +func Test_AuthWithInvalidPassword_Gives403(t *testing.T) { + + handler := func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "Hello World!") + } + + w := httptest.NewRecorder() + + wantUser := "admin" + wantPassword := "test" + r := httptest.NewRequest(http.MethodGet, "http://localhost:8080", nil) + r.SetBasicAuth(wantUser, wantPassword) + + wantCredentials := &BasicAuthCredentials{ + User: wantUser, + Password: "", + } + + decorated := DecorateWithBasicAuth(handler, wantCredentials) + decorated.ServeHTTP(w, r) + + wantCode := http.StatusUnauthorized + if w.Code != wantCode { + t.Errorf("status code, want: %d, got: %d", wantCode, w.Code) + t.Fail() + } +} diff --git a/gateway/server.go b/gateway/server.go index a86e929a..090210ee 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -5,8 +5,10 @@ package main import ( "fmt" + "io/ioutil" "log" "net/http" + "strings" "time" "github.com/gorilla/mux" @@ -33,6 +35,27 @@ func main() { log.Printf("Binding to external function provider: %s", config.FunctionsProviderURL) + var credentials *handlers.BasicAuthCredentials + + if config.UseBasicAuth { + userPath := "/var/secrets/basic_auth_user" + user, userErr := ioutil.ReadFile(userPath) + if userErr != nil { + log.Panicf("Unable to load %s", userPath) + } + + userPassword := "/var/secrets/basic_auth_password" + password, passErr := ioutil.ReadFile(userPassword) + if passErr != nil { + log.Panicf("Unable to load %s", userPassword) + } + + credentials = &handlers.BasicAuthCredentials{ + User: strings.TrimSpace(string(user)), + Password: strings.TrimSpace(string(password)), + } + } + metricsOptions := metrics.BuildMetricsOptions() metrics.RegisterMetrics(metricsOptions) @@ -84,8 +107,20 @@ func main() { } prometheusQuery := metrics.NewPrometheusQuery(config.PrometheusHost, config.PrometheusPort, &http.Client{}) - listFunctions := metrics.AddMetricsHandler(faasHandlers.ListFunctions, prometheusQuery) + faasHandlers.ListFunctions = metrics.AddMetricsHandler(faasHandlers.ListFunctions, prometheusQuery) faasHandlers.Proxy = handlers.MakeCallIDMiddleware(faasHandlers.Proxy) + + if credentials != nil { + faasHandlers.UpdateFunction = + handlers.DecorateWithBasicAuth(faasHandlers.UpdateFunction, credentials) + faasHandlers.DeleteFunction = + handlers.DecorateWithBasicAuth(faasHandlers.DeleteFunction, credentials) + faasHandlers.DeployFunction = + handlers.DecorateWithBasicAuth(faasHandlers.DeployFunction, credentials) + faasHandlers.ListFunctions = + handlers.DecorateWithBasicAuth(faasHandlers.ListFunctions, credentials) + } + r := mux.NewRouter() // r.StrictSlash(false) // This didn't work, so register routes twice. @@ -97,7 +132,7 @@ func main() { r.HandleFunc("/system/alert", faasHandlers.Alert) r.HandleFunc("/system/function/{name:[-a-zA-Z_0-9]+}", queryFunction).Methods(http.MethodGet) - r.HandleFunc("/system/functions", listFunctions).Methods(http.MethodGet) + r.HandleFunc("/system/functions", faasHandlers.ListFunctions).Methods(http.MethodGet) r.HandleFunc("/system/functions", faasHandlers.DeployFunction).Methods(http.MethodPost) r.HandleFunc("/system/functions", faasHandlers.DeleteFunction).Methods(http.MethodDelete) r.HandleFunc("/system/functions", faasHandlers.UpdateFunction).Methods(http.MethodPut) @@ -115,7 +150,12 @@ func main() { allowedCORSHost := "raw.githubusercontent.com" fsCORS := handlers.DecorateWithCORS(fs, allowedCORSHost) - r.PathPrefix("/ui/").Handler(http.StripPrefix("/ui", fsCORS)).Methods(http.MethodGet) + uiHandler := http.StripPrefix("/ui", fsCORS) + if credentials != nil { + r.PathPrefix("/ui/").Handler(handlers.DecorateWithBasicAuth(uiHandler.ServeHTTP, credentials)).Methods(http.MethodGet) + } else { + r.PathPrefix("/ui/").Handler(uiHandler).Methods(http.MethodGet) + } metricsHandler := metrics.PrometheusHandler() r.Handle("/metrics", metricsHandler) diff --git a/gateway/types/readconfig.go b/gateway/types/readconfig.go index 7eea2f1f..8152b8ef 100644 --- a/gateway/types/readconfig.go +++ b/gateway/types/readconfig.go @@ -105,6 +105,8 @@ func (ReadConfig) Read(hasEnv HasEnv) GatewayConfig { cfg.DirectFunctions = parseBoolValue(hasEnv.Getenv("direct_functions")) cfg.DirectFunctionsSuffix = hasEnv.Getenv("direct_functions_suffix") + cfg.UseBasicAuth = parseBoolValue(hasEnv.Getenv("basic_auth")) + return cfg } @@ -140,6 +142,9 @@ type GatewayConfig struct { // If set this will be used to resolve functions directly DirectFunctionsSuffix string + + // If set, reads secrets from file-system for enabling basic auth. + UseBasicAuth bool } // UseNATS Use NATSor not diff --git a/gateway/types/readconfig_test.go b/gateway/types/readconfig_test.go index d9d827b0..cd220931 100644 --- a/gateway/types/readconfig_test.go +++ b/gateway/types/readconfig_test.go @@ -197,3 +197,30 @@ func TestRead_PrometheusDefaults(t *testing.T) { t.Fail() } } + +func TestRead_BasicAuthDefaults(t *testing.T) { + defaults := NewEnvBucket() + + readConfig := ReadConfig{} + + config := readConfig.Read(defaults) + + if config.UseBasicAuth != false { + t.Logf("config.UseBasicAuth, want: %t, got: %t\n", false, config.UseBasicAuth) + t.Fail() + } +} + +func TestRead_BasicAuth_SetTrue(t *testing.T) { + defaults := NewEnvBucket() + defaults.Setenv("basic_auth", "true") + + readConfig := ReadConfig{} + + config := readConfig.Read(defaults) + + if config.UseBasicAuth != true { + t.Logf("config.UseBasicAuth, want: %t, got: %t\n", true, config.UseBasicAuth) + t.Fail() + } +}