diff --git a/gateway/main.go b/gateway/main.go index 39c91eda..3fdeb0fb 100644 --- a/gateway/main.go +++ b/gateway/main.go @@ -168,7 +168,7 @@ func main() { prometheusQuery := metrics.NewPrometheusQuery(config.PrometheusHost, config.PrometheusPort, &http.Client{}) faasHandlers.ListFunctions = metrics.AddMetricsHandler(faasHandlers.ListFunctions, prometheusQuery) - faasHandlers.ScaleFunction = handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector) + faasHandlers.ScaleFunction = scaling.MakeHorizontalScalingHandler(handlers.MakeForwardingProxyHandler(reverseProxy, forwardingNotifiers, urlResolver, nilURLTransformer, serviceAuthInjector)) if credentials != nil { faasHandlers.Alert = diff --git a/gateway/scaling/ranges.go b/gateway/scaling/ranges.go index 4e6cab2c..33850249 100644 --- a/gateway/scaling/ranges.go +++ b/gateway/scaling/ranges.go @@ -1,5 +1,14 @@ package scaling +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + + "github.com/openfaas/faas-provider/types" +) + const ( // DefaultMinReplicas is the minimal amount of replicas for a service. DefaultMinReplicas = 1 @@ -21,3 +30,43 @@ const ( // ScalingFactorLabel label indicates the scaling factor for a function ScalingFactorLabel = "com.openfaas.scale.factor" ) + +func MakeHorizontalScalingHandler(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Only POST is allowed", http.StatusMethodNotAllowed) + return + } + + if r.Body == nil { + http.Error(w, "Error reading request body", http.StatusBadRequest) + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading request body", http.StatusBadRequest) + return + } + + scaleRequest := types.ScaleServiceRequest{} + if err := json.Unmarshal(body, &scaleRequest); err != nil { + http.Error(w, "Error unmarshalling request body", http.StatusBadRequest) + return + } + + if scaleRequest.Replicas < 1 { + scaleRequest.Replicas = 1 + } + + if scaleRequest.Replicas > DefaultMaxReplicas { + scaleRequest.Replicas = DefaultMaxReplicas + } + + upstreamReq, _ := json.Marshal(scaleRequest) + // Restore the io.ReadCloser to its original state + r.Body = ioutil.NopCloser(bytes.NewBuffer(upstreamReq)) + + next.ServeHTTP(w, r) + } +}