diff --git a/gateway/handlers/alerthandler.go b/gateway/handlers/alerthandler.go
index 10292fe6..31763797 100644
--- a/gateway/handlers/alerthandler.go
+++ b/gateway/handlers/alerthandler.go
@@ -22,8 +22,89 @@ import (
 // DefaultMaxReplicas is the amount of replicas a service will auto-scale up to.
 const DefaultMaxReplicas = 20
 
+// ServiceQuery provides read and write access to the replica count of a
+// function's service, so the alert handler does not depend on a particular
+// orchestrator client.
+type ServiceQuery interface {
+	// GetReplicas returns the current and the maximum replica count of service.
+	GetReplicas(service string) (currentReplicas uint64, maxReplicas uint64, err error)
+	// SetReplicas scales service to count replicas.
+	SetReplicas(service string, count uint64) error
+}
+
+// NewSwarmServiceQuery creates a ServiceQuery backed by Docker Swarm.
+func NewSwarmServiceQuery(c *client.Client) ServiceQuery {
+	return SwarmServiceQuery{
+		c: c,
+	}
+}
+
+// SwarmServiceQuery is the Docker Swarm implementation of ServiceQuery.
+type SwarmServiceQuery struct {
+	c *client.Client
+}
+
+// GetReplicas returns the current replica count of serviceName and the maximum
+// it may be scaled to. The maximum is DefaultMaxReplicas unless the service
+// carries a valid "com.faas.max_replicas" label. On error the current count is
+// zero.
+func (s SwarmServiceQuery) GetReplicas(serviceName string) (uint64, uint64, error) {
+	maxReplicas := uint64(DefaultMaxReplicas)
+	opts := types.ServiceInspectOptions{
+		InsertDefaults: true,
+	}
+
+	service, _, err := s.c.ServiceInspectWithRaw(context.Background(), serviceName, opts)
+	if err != nil {
+		return 0, maxReplicas, err
+	}
+
+	replicated := service.Spec.Mode.Replicated
+	if replicated == nil || replicated.Replicas == nil {
+		// Services that are not in replicated mode have no replica count;
+		// dereferencing it would panic.
+		return 0, maxReplicas, fmt.Errorf("service %s is not in replicated mode", serviceName)
+	}
+
+	replicaLabel := service.Spec.TaskTemplate.ContainerSpec.Labels["com.faas.max_replicas"]
+	if len(replicaLabel) > 0 {
+		// ParseUint rejects negative values, which Atoi followed by a uint64
+		// conversion would turn into a huge limit.
+		parsed, parseErr := strconv.ParseUint(replicaLabel, 10, 64)
+		if parseErr != nil {
+			log.Printf("Bad replica count: %s, should be uint.\n", replicaLabel)
+		} else {
+			maxReplicas = parsed
+		}
+	}
+
+	return *replicated.Replicas, maxReplicas, nil
+}
+
+// SetReplicas updates serviceName so that it runs count replicas.
+func (s SwarmServiceQuery) SetReplicas(serviceName string, count uint64) error {
+	opts := types.ServiceInspectOptions{
+		InsertDefaults: true,
+	}
+
+	service, _, err := s.c.ServiceInspectWithRaw(context.Background(), serviceName, opts)
+	if err != nil {
+		return err
+	}
+	if service.Spec.Mode.Replicated == nil {
+		return fmt.Errorf("service %s is not in replicated mode", serviceName)
+	}
+
+	service.Spec.Mode.Replicated.Replicas = &count
+	updateOpts := types.ServiceUpdateOptions{}
+	updateOpts.RegistryAuthFrom = types.RegistryAuthFromSpec
+
+	_, err = s.c.ServiceUpdate(context.Background(), service.ID, service.Version, service.Spec, updateOpts)
+	return err
+}
+
 // MakeAlertHandler handles alerts from Prometheus Alertmanager
-func MakeAlertHandler(c *client.Client) http.HandlerFunc {
+func MakeAlertHandler(sq ServiceQuery) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 
 		log.Println("Alert received.")
@@ -48,7 +129,7 @@ func MakeAlertHandler(c *client.Client) http.HandlerFunc {
 			return
 		}
 
-		errors := handleAlerts(&req, c)
+		errors := handleAlerts(&req, sq)
 		if len(errors) > 0 {
 			log.Println(errors)
 			w.WriteHeader(http.StatusInternalServerError)
@@ -64,10 +145,10 @@ func MakeAlertHandler(c *client.Client) http.HandlerFunc {
 	}
 }
 
-func handleAlerts(req *requests.PrometheusAlert, c *client.Client) []error {
+func handleAlerts(req *requests.PrometheusAlert, sq ServiceQuery) []error {
 	var errors []error
 	for _, alert := range req.Alerts {
-		if err := scaleService(alert, c); err != nil {
+		if err := scaleService(alert, sq); err != nil {
 			log.Println(err)
 			errors = append(errors, err)
 		}
@@ -76,47 +157,27 @@ func handleAlerts(req *requests.PrometheusAlert, c *client.Client) []error {
 	return errors
 }
 
-func scaleService(alert requests.PrometheusInnerAlert, c *client.Client) error {
+func scaleService(alert requests.PrometheusInnerAlert, sq ServiceQuery) error {
 	var err error
 	serviceName := alert.Labels.FunctionName
 
 	if len(serviceName) > 0 {
-		opts := types.ServiceInspectOptions{
-			InsertDefaults: true,
-		}
-
-		service, _, inspectErr := c.ServiceInspectWithRaw(context.Background(), serviceName, opts)
-		if inspectErr == nil {
-
-			currentReplicas := *service.Spec.Mode.Replicated.Replicas
+		currentReplicas, maxReplicas, getErr := sq.GetReplicas(serviceName)
+		if getErr == nil {
 			status := alert.Status
 
-			replicaLabel := service.Spec.TaskTemplate.ContainerSpec.Labels["com.faas.max_replicas"]
-			maxReplicas := DefaultMaxReplicas
-			if len(replicaLabel) > 0 {
-				maxReplicas, err = strconv.Atoi(replicaLabel)
-				if err != nil {
-					log.Printf("Bad replica count: %s, should be uint.\n", replicaLabel)
-				}
-			}
 			newReplicas := CalculateReplicas(status, currentReplicas, uint64(maxReplicas))
 
 			log.Printf("[Scale] function=%s %d => %d.\n", serviceName, currentReplicas, newReplicas)
 			if newReplicas == currentReplicas {
 				return nil
 			}
-
-			service.Spec.Mode.Replicated.Replicas = &newReplicas
-			updateOpts := types.ServiceUpdateOptions{}
-			updateOpts.RegistryAuthFrom = types.RegistryAuthFromSpec
-
-			_, updateErr := c.ServiceUpdate(context.Background(), service.ID, service.Version, service.Spec, updateOpts)
+			updateErr := sq.SetReplicas(serviceName, newReplicas)
 			if updateErr != nil {
 				err = updateErr
 			}
-
 		} else {
-			err = inspectErr
+			err = getErr
 		}
 	}
 	return err
diff --git a/gateway/server.go b/gateway/server.go
index 4ba3df30..450ed885 100644
--- a/gateway/server.go
+++ b/gateway/server.go
@@ -77,7 +77,7 @@ func main() {
 	} else {
 		faasHandlers.Proxy = internalHandlers.MakeProxy(metricsOptions, true, dockerClient, &logger)
 		faasHandlers.RoutelessProxy = internalHandlers.MakeProxy(metricsOptions, true, dockerClient, &logger)
-		faasHandlers.Alert = internalHandlers.MakeAlertHandler(dockerClient)
+		faasHandlers.Alert = internalHandlers.MakeAlertHandler(internalHandlers.NewSwarmServiceQuery(dockerClient))
 		faasHandlers.ListFunctions = internalHandlers.MakeFunctionReader(metricsOptions, dockerClient)
 		faasHandlers.DeployFunction = internalHandlers.MakeNewFunctionHandler(metricsOptions, dockerClient)
 		faasHandlers.DeleteFunction = internalHandlers.MakeDeleteFunctionHandler(metricsOptions, dockerClient)