diff --git a/gateway/handlers/forwarding_proxy.go b/gateway/handlers/forwarding_proxy.go index bfae4ef6..e4e729a7 100644 --- a/gateway/handlers/forwarding_proxy.go +++ b/gateway/handlers/forwarding_proxy.go @@ -39,7 +39,11 @@ type URLPathTransformer interface { } // MakeForwardingProxyHandler create a handler which forwards HTTP requests -func MakeForwardingProxyHandler(proxy *types.HTTPClientReverseProxy, notifiers []HTTPNotifier, baseURLResolver BaseURLResolver, urlPathTransformer URLPathTransformer) http.HandlerFunc { +func MakeForwardingProxyHandler(proxy *types.HTTPClientReverseProxy, + notifiers []HTTPNotifier, + baseURLResolver BaseURLResolver, + urlPathTransformer URLPathTransformer, + serviceAuthInjector AuthInjector) http.HandlerFunc { writeRequestURI := false if _, exists := os.LookupEnv("write_request_uri"); exists { @@ -54,7 +58,7 @@ func MakeForwardingProxyHandler(proxy *types.HTTPClientReverseProxy, notifiers [ start := time.Now() - statusCode, err := forwardRequest(w, r, proxy.Client, baseURL, requestURL, proxy.Timeout, writeRequestURI) + statusCode, err := forwardRequest(w, r, proxy.Client, baseURL, requestURL, proxy.Timeout, writeRequestURI, serviceAuthInjector) seconds := time.Since(start) if err != nil { @@ -96,13 +100,24 @@ func buildUpstreamRequest(r *http.Request, baseURL string, requestURL string) *h return upstreamReq } -func forwardRequest(w http.ResponseWriter, r *http.Request, proxyClient *http.Client, baseURL string, requestURL string, timeout time.Duration, writeRequestURI bool) (int, error) { +func forwardRequest(w http.ResponseWriter, + r *http.Request, + proxyClient *http.Client, + baseURL string, + requestURL string, + timeout time.Duration, + writeRequestURI bool, + serviceAuthInjector AuthInjector) (int, error) { upstreamReq := buildUpstreamRequest(r, baseURL, requestURL) if upstreamReq.Body != nil { defer upstreamReq.Body.Close() } + if serviceAuthInjector != nil { + serviceAuthInjector.Inject(upstreamReq) + } + if writeRequestURI { log.Printf("forwardRequest: %s %s\n", upstreamReq.Host, upstreamReq.URL.String()) } diff --git a/gateway/handlers/serviceauthinjector.go b/gateway/handlers/serviceauthinjector.go new file mode 100644 index 00000000..21a03881 --- /dev/null +++ b/gateway/handlers/serviceauthinjector.go @@ -0,0 +1,7 @@ +package handlers + +import "net/http" + +type AuthInjector interface { + Inject(r *http.Request) +} diff --git a/gateway/plugin/external.go b/gateway/plugin/external.go index 2ad3d3b7..b3e14df8 100644 --- a/gateway/plugin/external.go +++ b/gateway/plugin/external.go @@ -15,13 +15,13 @@ import ( "strconv" "time" - "github.com/openfaas/faas-provider/auth" + "github.com/openfaas/faas/gateway/handlers" "github.com/openfaas/faas/gateway/requests" "github.com/openfaas/faas/gateway/scaling" ) // NewExternalServiceQuery proxies service queries to external plugin via HTTP -func NewExternalServiceQuery(externalURL url.URL, credentials *auth.BasicAuthCredentials) scaling.ServiceQuery { +func NewExternalServiceQuery(externalURL url.URL, authInjector handlers.AuthInjector) scaling.ServiceQuery { timeout := 3 * time.Second proxyClient := http.Client{ @@ -39,17 +39,17 @@ func NewExternalServiceQuery(externalURL url.URL, credentials *auth.BasicAuthCre } return ExternalServiceQuery{ - URL: externalURL, - ProxyClient: proxyClient, - Credentials: credentials, + URL: externalURL, + ProxyClient: proxyClient, + AuthInjector: authInjector, } } // ExternalServiceQuery proxies service queries to external plugin via HTTP type ExternalServiceQuery struct { - URL url.URL - ProxyClient http.Client - Credentials *auth.BasicAuthCredentials + URL url.URL + ProxyClient http.Client + AuthInjector handlers.AuthInjector } // ScaleServiceRequest request scaling of replica @@ -71,8 +71,8 @@ func (s ExternalServiceQuery) GetReplicas(serviceName string) (scaling.ServiceQu req, _ := http.NewRequest(http.MethodGet, urlPath, nil) - if s.Credentials != nil { - req.SetBasicAuth(s.Credentials.User, s.Credentials.Password) + if s.AuthInjector != nil { + s.AuthInjector.Inject(req) } res, err := s.ProxyClient.Do(req) @@ -144,8 +144,8 @@ func (s ExternalServiceQuery) SetReplicas(serviceName string, count uint64) erro urlPath := fmt.Sprintf("%ssystem/scale-function/%s", s.URL.String(), serviceName) req, _ := http.NewRequest(http.MethodPost, urlPath, bytes.NewReader(requestBody)) - if s.Credentials != nil { - req.SetBasicAuth(s.Credentials.User, s.Credentials.Password) + if s.AuthInjector != nil { + s.AuthInjector.Inject(req) } defer req.Body.Close() diff --git a/gateway/plugin/external_test.go b/gateway/plugin/external_test.go index 851c0585..d16e80b3 100644 --- a/gateway/plugin/external_test.go +++ b/gateway/plugin/external_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "github.com/openfaas/faas-provider/auth" + "github.com/openfaas/faas/gateway/handlers" "github.com/openfaas/faas/gateway/scaling" ) @@ -47,11 +47,10 @@ func TestGetReplicasNonExistentFn(t *testing.T) { })) defer testServer.Close() - var creds auth.BasicAuthCredentials - + var injector handlers.AuthInjector url, _ := url.Parse(testServer.URL + "/") - esq := NewExternalServiceQuery(*url, &creds) + esq := NewExternalServiceQuery(*url, injector) svcQryResp, err := esq.GetReplicas("burt") @@ -78,11 +77,10 @@ func TestGetReplicasExistentFn(t *testing.T) { AvailableReplicas: 0, } - var creds auth.BasicAuthCredentials - + var injector handlers.AuthInjector url, _ := url.Parse(testServer.URL + "/") - esq := NewExternalServiceQuery(*url, &creds) + esq := NewExternalServiceQuery(*url, injector) svcQryResp, err := esq.GetReplicas("burt") @@ -104,9 +102,9 @@ func TestSetReplicasNonExistentFn(t *testing.T) { })) defer testServer.Close() - var creds auth.BasicAuthCredentials + var injector handlers.AuthInjector url, _ := url.Parse(testServer.URL + "/") - esq := NewExternalServiceQuery(*url, &creds) + esq := NewExternalServiceQuery(*url, injector) err := esq.SetReplicas("burt", 1) @@ -126,9 +124,10 @@ func TestSetReplicasExistentFn(t *testing.T) { })) defer testServer.Close() - var creds auth.BasicAuthCredentials + var injector handlers.AuthInjector + url, _ := url.Parse(testServer.URL + "/") - esq := NewExternalServiceQuery(*url, &creds) + esq := NewExternalServiceQuery(*url, injector) err := esq.SetReplicas("burt", 1) diff --git a/gateway/server.go b/gateway/server.go index 7377639d..789bd6ca 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -34,6 +34,7 @@ func main() { log.Printf("Binding to external function provider: %s", config.FunctionsProviderURL) + // credentials is used for service-to-service auth var credentials *auth.BasicAuthCredentials if config.UseBasicAuth { @@ -57,12 +58,17 @@ func main() { exporter.StartServiceWatcher(*config.FunctionsProviderURL, metricsOptions, "func", servicePollInterval) metrics.RegisterExporter(exporter) - reverseProxy := types.NewHTTPClientReverseProxy(config.FunctionsProviderURL, config.UpstreamTimeout, config.MaxIdleConns, config.MaxIdleConnsPerHost) + reverseProxy := types.NewHTTPClientReverseProxy(config.FunctionsProviderURL, + config.UpstreamTimeout, + config.MaxIdleConns, + config.MaxIdleConnsPerHost) loggingNotifier := handlers.LoggingNotifier{} + prometheusNotifier := handlers.PrometheusFunctionNotifier{ Metrics: &metricsOptions, } + prometheusServiceNotifier := handlers.PrometheusServiceNotifier{ ServiceMetrics: metricsOptions.ServiceMetrics, } @@ -83,20 +89,22 @@ func main() { functionURLTransformer = nilURLTransformer } + serviceAuthInjector := &BasicAuthInjector{Credentials: credentials} + decorateExternalAuth := handlers.MakeExternalAuthHandler - faasHandlers.Proxy = handlers.MakeForwardingProxyHandler(reverseProxy, functionNotifiers, functionURLResolver, functionURLTransformer) + faasHandlers.Proxy = handlers.MakeForwardingProxyHandler(reverseProxy, functionNotifiers, functionURLResolver, functionURLTransformer, nil) - faasHandlers.RoutelessProxy = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) - faasHandlers.ListFunctions = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) - faasHandlers.DeployFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) - faasHandlers.DeleteFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) - faasHandlers.UpdateFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) - faasHandlers.QueryFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) - faasHandlers.InfoHandler = handlers.MakeInfoHandler(handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer)) - faasHandlers.SecretHandler = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) + faasHandlers.RoutelessProxy = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) + faasHandlers.ListFunctions = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) + faasHandlers.DeployFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) + faasHandlers.DeleteFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) + faasHandlers.UpdateFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) + faasHandlers.QueryFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) + faasHandlers.InfoHandler = handlers.MakeInfoHandler(handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector)) + faasHandlers.SecretHandler = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) - alertHandler := plugin.NewExternalServiceQuery(*config.FunctionsProviderURL, credentials) + alertHandler := plugin.NewExternalServiceQuery(*config.FunctionsProviderURL, serviceAuthInjector) faasHandlers.Alert = handlers.MakeNotifierWrapper( handlers.MakeAlertHandler(alertHandler), forwardingNotifiers, @@ -129,7 +137,7 @@ func main() { faasHandlers.ListFunctions = metrics.AddMetricsHandler(faasHandlers.ListFunctions, prometheusQuery) faasHandlers.Proxy = handlers.MakeCallIDMiddleware(faasHandlers.Proxy) - faasHandlers.ScaleFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer) + faasHandlers.ScaleFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) if credentials != nil { faasHandlers.Alert = @@ -211,7 +219,7 @@ func main() { //Start metrics server in a goroutine go runMetricsServer() - r.HandleFunc("/healthz", handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer)).Methods(http.MethodGet) + r.HandleFunc("/healthz", handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector)).Methods(http.MethodGet) r.Handle("/", http.RedirectHandler("/ui/", http.StatusMovedPermanently)).Methods(http.MethodGet) @@ -250,3 +258,11 @@ func runMetricsServer() { log.Fatal(s.ListenAndServe()) } + +type BasicAuthInjector struct { + Credentials *auth.BasicAuthCredentials +} + +func (b BasicAuthInjector) Inject(r *http.Request) { + r.SetBasicAuth(b.Credentials.User, b.Credentials.Password) +}