diff --git a/gateway/scaling/function_query.go b/gateway/scaling/function_query.go index 26639424..7f40b448 100644 --- a/gateway/scaling/function_query.go +++ b/gateway/scaling/function_query.go @@ -3,12 +3,16 @@ package scaling -import "fmt" +import ( + "fmt" + "log" +) type CachedFunctionQuery struct { cache FunctionCacher serviceQuery ServiceQuery emptyAnnotations map[string]string + singleFlight *SingleFlight } func NewCachedFunctionQuery(cache FunctionCacher, serviceQuery ServiceQuery) FunctionQuery { @@ -16,6 +20,7 @@ func NewCachedFunctionQuery(cache FunctionCacher, serviceQuery ServiceQuery) Fun cache: cache, serviceQuery: serviceQuery, emptyAnnotations: map[string]string{}, + singleFlight: NewSingleFlight(), } } @@ -35,13 +40,23 @@ func (c *CachedFunctionQuery) Get(fn string, ns string) (ServiceQueryResponse, e query, hit := c.cache.Get(fn, ns) if !hit { + key := fmt.Sprintf("GetReplicas-%s.%s", fn, ns) + queryResponse, err := c.singleFlight.Do(key, func() (interface{}, error) { + log.Printf("Cache miss - run GetReplicas") + // If there is a cache miss, then fetch the value from the provider API + return c.serviceQuery.GetReplicas(fn, ns) + }) + + log.Printf("Result: %v %v", queryResponse, err) - // If there is a cache miss, then fetch the value from the provider API - queryResponse, err := c.serviceQuery.GetReplicas(fn, ns) if err != nil { return ServiceQueryResponse{}, err } - c.cache.Set(fn, ns, queryResponse) + + if queryResponse != nil { + c.cache.Set(fn, ns, queryResponse.(ServiceQueryResponse)) + } + } else { return query, nil } diff --git a/gateway/scaling/function_scaler.go b/gateway/scaling/function_scaler.go index bac88fb9..8a3fb824 100644 --- a/gateway/scaling/function_scaler.go +++ b/gateway/scaling/function_scaler.go @@ -10,15 +10,17 @@ import ( // ScalingConfig func NewFunctionScaler(config ScalingConfig, functionCacher FunctionCacher) FunctionScaler { return FunctionScaler{ - Cache: functionCacher, - Config: config, + Cache: functionCacher, + Config: config, + SingleFlight: NewSingleFlight(), } } // FunctionScaler scales from zero type FunctionScaler struct { - Cache FunctionCacher - Config ScalingConfig + Cache FunctionCacher + Config ScalingConfig + SingleFlight *SingleFlight } // FunctionScaleResult holds the result of scaling from zero @@ -43,8 +45,11 @@ func (f *FunctionScaler) Scale(functionName, namespace string) FunctionScaleResu Duration: time.Since(start), } } + getKey := fmt.Sprintf("GetReplicas-%s.%s", functionName, namespace) - queryResponse, err := f.Config.ServiceQuery.GetReplicas(functionName, namespace) + res, err := f.SingleFlight.Do(getKey, func() (interface{}, error) { + return f.Config.ServiceQuery.GetReplicas(functionName, namespace) + }) if err != nil { return FunctionScaleResult{ @@ -54,6 +59,16 @@ func (f *FunctionScaler) Scale(functionName, namespace string) FunctionScaleResu Duration: time.Since(start), } } + if res == nil { + return FunctionScaleResult{ + Error: fmt.Errorf("empty response from server"), + Available: false, + Found: false, + Duration: time.Since(start), + } + } + + queryResponse := res.(ServiceQueryResponse) f.Cache.Set(functionName, namespace, queryResponse) @@ -64,21 +79,35 @@ func (f *FunctionScaler) Scale(functionName, namespace string) FunctionScaleResu } scaleResult := backoff(func(attempt int) error { - queryResponse, err := f.Config.ServiceQuery.GetReplicas(functionName, namespace) + + res, err := f.SingleFlight.Do(getKey, func() (interface{}, error) { + return f.Config.ServiceQuery.GetReplicas(functionName, namespace) + }) + if err != nil { return err } + queryResponse = res.(ServiceQueryResponse) + f.Cache.Set(functionName, namespace, queryResponse) if queryResponse.Replicas > 0 { return nil } - log.Printf("[Scale %d] function=%s 0 => %d requested", attempt, functionName, minReplicas) - setScaleErr := f.Config.ServiceQuery.SetReplicas(functionName, namespace, minReplicas) - if setScaleErr != nil { - return fmt.Errorf("unable to scale function [%s], err: %s", functionName, setScaleErr) + setKey := fmt.Sprintf("SetReplicas-%s.%s", functionName, namespace) + + if _, err := f.SingleFlight.Do(setKey, func() (interface{}, error) { + + log.Printf("[Scale %d] function=%s 0 => %d requested", attempt, functionName, minReplicas) + + if err := f.Config.ServiceQuery.SetReplicas(functionName, namespace, minReplicas); err != nil { + return nil, fmt.Errorf("unable to scale function [%s], err: %s", functionName, err) + } + return nil, nil + }); err != nil { + return err } return nil @@ -95,10 +124,16 @@ func (f *FunctionScaler) Scale(functionName, namespace string) FunctionScaleResu } for i := 0; i < int(f.Config.MaxPollCount); i++ { - queryResponse, err := f.Config.ServiceQuery.GetReplicas(functionName, namespace) + + res, err := f.SingleFlight.Do(getKey, func() (interface{}, error) { + return f.Config.ServiceQuery.GetReplicas(functionName, namespace) + }) + queryResponse := res.(ServiceQueryResponse) + if err == nil { f.Cache.Set(functionName, namespace, queryResponse) } + totalTime := time.Since(start) if err != nil { diff --git a/gateway/scaling/single.go b/gateway/scaling/single.go new file mode 100644 index 00000000..6b5f395a --- /dev/null +++ b/gateway/scaling/single.go @@ -0,0 +1,73 @@ +package scaling + +import ( + "log" + "sync" +) + +type Call struct { + wg *sync.WaitGroup + res *SingleFlightResult +} + +type SingleFlight struct { + lock *sync.RWMutex + calls map[string]*Call +} + +type SingleFlightResult struct { + Result interface{} + Error error +} + +func NewSingleFlight() *SingleFlight { + return &SingleFlight{ + lock: &sync.RWMutex{}, + calls: map[string]*Call{}, + } +} + +func (s *SingleFlight) Do(key string, f func() (interface{}, error)) (interface{}, error) { + + s.lock.Lock() + + if call, ok := s.calls[key]; ok { + s.lock.Unlock() + call.wg.Wait() + + return call.res.Result, call.res.Error + } + + var call *Call + if s.calls[key] == nil { + call = &Call{ + wg: &sync.WaitGroup{}, + } + s.calls[key] = call + } + + call.wg.Add(1) + + s.lock.Unlock() + + go func() { + log.Printf("Miss, so running: %s", key) + res, err := f() + + s.lock.Lock() + call.res = &SingleFlightResult{ + Result: res, + Error: err, + } + + call.wg.Done() + + delete(s.calls, key) + + s.lock.Unlock() + }() + + call.wg.Wait() + + return call.res.Result, call.res.Error +}