Refactoring: variable names, adding tests and http constants

Signed-off-by: Alex Ellis <alexellis2@gmail.com>
This commit is contained in:
Alex Ellis
2017-12-01 18:19:54 +00:00
parent 2452fdea0b
commit 23a7187435
19 changed files with 166 additions and 80 deletions

View File

@ -21,6 +21,13 @@ import (
// DefaultMaxReplicas is the amount of replicas a service will auto-scale up to. // DefaultMaxReplicas is the amount of replicas a service will auto-scale up to.
const DefaultMaxReplicas = 20 const DefaultMaxReplicas = 20
// MinScaleLabel label indicating min scale for a function
const MinScaleLabel = "com.openfaas.scale.min"
// MaxScaleLabel label indicating max scale for a function
const MaxScaleLabel = "com.openfaas.scale.max"
// ServiceQuery provides interface for replica querying/setting
type ServiceQuery interface { type ServiceQuery interface {
GetReplicas(service string) (currentReplicas uint64, maxReplicas uint64, minReplicas uint64, err error) GetReplicas(service string) (currentReplicas uint64, maxReplicas uint64, minReplicas uint64, err error)
SetReplicas(service string, count uint64) error SetReplicas(service string, count uint64) error
@ -33,7 +40,7 @@ func NewSwarmServiceQuery(c *client.Client) ServiceQuery {
} }
} }
// SwarmServiceQuery Docker Swarm implementation // SwarmServiceQuery implementation for Docker Swarm
type SwarmServiceQuery struct { type SwarmServiceQuery struct {
c *client.Client c *client.Client
} }
@ -42,21 +49,21 @@ type SwarmServiceQuery struct {
func (s SwarmServiceQuery) GetReplicas(serviceName string) (uint64, uint64, uint64, error) { func (s SwarmServiceQuery) GetReplicas(serviceName string) (uint64, uint64, uint64, error) {
var err error var err error
var currentReplicas uint64 var currentReplicas uint64
maxReplicas := uint64(DefaultMaxReplicas) maxReplicas := uint64(DefaultMaxReplicas)
minReplicas := uint64(1) minReplicas := uint64(1)
opts := types.ServiceInspectOptions{ opts := types.ServiceInspectOptions{
InsertDefaults: true, InsertDefaults: true,
} }
service, _, err := s.c.ServiceInspectWithRaw(context.Background(), serviceName, opts) service, _, err := s.c.ServiceInspectWithRaw(context.Background(), serviceName, opts)
if err == nil { if err == nil {
currentReplicas = *service.Spec.Mode.Replicated.Replicas currentReplicas = *service.Spec.Mode.Replicated.Replicas
log.Println("service.Spec.Annotations.Labels ", service.Spec.Annotations.Labels)
log.Println("service.Spec.TaskTemplate.ContainerSpec.Labels ", service.Spec.TaskTemplate.ContainerSpec.Labels)
log.Println("service.Spec.Labels ", service.Spec.Labels)
minScale := service.Spec.Annotations.Labels["com.openfaas.scale.min"] minScale := service.Spec.Annotations.Labels[MinScaleLabel]
maxScale := service.Spec.Annotations.Labels["com.openfaas.scale.max"] maxScale := service.Spec.Annotations.Labels[MaxScaleLabel]
if len(maxScale) > 0 { if len(maxScale) > 0 {
labelValue, err := strconv.Atoi(maxScale) labelValue, err := strconv.Atoi(maxScale)
@ -98,12 +105,14 @@ func (s SwarmServiceQuery) SetReplicas(serviceName string, count uint64) error {
err = updateErr err = updateErr
} }
} }
return err return err
} }
// MakeAlertHandler handles alerts from Prometheus Alertmanager // MakeAlertHandler handles alerts from Prometheus Alertmanager
func MakeAlertHandler(sq ServiceQuery) http.HandlerFunc { func MakeAlertHandler(service ServiceQuery) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
log.Println("Alert received.") log.Println("Alert received.")
body, readErr := ioutil.ReadAll(r.Body) body, readErr := ioutil.ReadAll(r.Body)
@ -127,7 +136,7 @@ func MakeAlertHandler(sq ServiceQuery) http.HandlerFunc {
return return
} }
errors := handleAlerts(&req, sq) errors := handleAlerts(&req, service)
if len(errors) > 0 { if len(errors) > 0 {
log.Println(errors) log.Println(errors)
var errorOutput string var errorOutput string
@ -143,10 +152,10 @@ func MakeAlertHandler(sq ServiceQuery) http.HandlerFunc {
} }
} }
func handleAlerts(req *requests.PrometheusAlert, sq ServiceQuery) []error { func handleAlerts(req *requests.PrometheusAlert, service ServiceQuery) []error {
var errors []error var errors []error
for _, alert := range req.Alerts { for _, alert := range req.Alerts {
if err := scaleService(alert, sq); err != nil { if err := scaleService(alert, service); err != nil {
log.Println(err) log.Println(err)
errors = append(errors, err) errors = append(errors, err)
} }
@ -155,12 +164,12 @@ func handleAlerts(req *requests.PrometheusAlert, sq ServiceQuery) []error {
return errors return errors
} }
func scaleService(alert requests.PrometheusInnerAlert, sq ServiceQuery) error { func scaleService(alert requests.PrometheusInnerAlert, service ServiceQuery) error {
var err error var err error
serviceName := alert.Labels.FunctionName serviceName := alert.Labels.FunctionName
if len(serviceName) > 0 { if len(serviceName) > 0 {
currentReplicas, maxReplicas, minReplicas, getErr := sq.GetReplicas(serviceName) currentReplicas, maxReplicas, minReplicas, getErr := service.GetReplicas(serviceName)
if getErr == nil { if getErr == nil {
status := alert.Status status := alert.Status
@ -171,7 +180,7 @@ func scaleService(alert requests.PrometheusInnerAlert, sq ServiceQuery) error {
return nil return nil
} }
updateErr := sq.SetReplicas(serviceName, newReplicas) updateErr := service.SetReplicas(serviceName, newReplicas)
if updateErr != nil { if updateErr != nil {
err = updateErr err = updateErr
} }

View File

@ -2,12 +2,13 @@ package handlers
import "net/http" import "net/http"
type CorsHandler struct { // CORSHandler set custom CORS instructions for the store.
type CORSHandler struct {
Upstream *http.Handler Upstream *http.Handler
AllowedHost string AllowedHost string
} }
func (c CorsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (c CORSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// https://raw.githubusercontent.com/openfaas/store/master/store.json // https://raw.githubusercontent.com/openfaas/store/master/store.json
w.Header().Set("Access-Control-Allow-Headers", "Content-Type") w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET") w.Header().Set("Access-Control-Allow-Methods", "GET")
@ -16,8 +17,9 @@ func (c CorsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
(*c.Upstream).ServeHTTP(w, r) (*c.Upstream).ServeHTTP(w, r)
} }
// DecorateWithCORS decorate a handler with CORS-injecting middleware
func DecorateWithCORS(upstream http.Handler, allowedHost string) http.Handler { func DecorateWithCORS(upstream http.Handler, allowedHost string) http.Handler {
return CorsHandler{ return CORSHandler{
Upstream: &upstream, Upstream: &upstream,
AllowedHost: allowedHost, AllowedHost: allowedHost,
} }

View File

@ -135,6 +135,7 @@ func buildEnv(envProcess string, envVars map[string]string) []string {
if len(envProcess) > 0 { if len(envProcess) > 0 {
env = append(env, fmt.Sprintf("fprocess=%s", envProcess)) env = append(env, fmt.Sprintf("fprocess=%s", envProcess))
} }
for k, v := range envVars { for k, v := range envVars {
env = append(env, fmt.Sprintf("%s=%s", k, v)) env = append(env, fmt.Sprintf("%s=%s", k, v))
} }

View File

@ -25,7 +25,8 @@ func MakeForwardingProxyHandler(proxy *httputil.ReverseProxy, metrics *metrics.M
proxy.ServeHTTP(writeAdapter, r) proxy.ServeHTTP(writeAdapter, r)
seconds := time.Since(start).Seconds() seconds := time.Since(start).Seconds()
log.Printf("< [%s] - %d took %f seconds\n", r.URL.String(), writeAdapter.GetHeaderCode(), seconds) log.Printf("< [%s] - %d took %f seconds\n", r.URL.String(),
writeAdapter.GetHeaderCode(), seconds)
forward := "/function/" forward := "/function/"
if startsWith(uri, forward) { if startsWith(uri, forward) {
@ -39,7 +40,9 @@ func MakeForwardingProxyHandler(proxy *httputil.ReverseProxy, metrics *metrics.M
code := strconv.Itoa(writeAdapter.GetHeaderCode()) code := strconv.Itoa(writeAdapter.GetHeaderCode())
metrics.GatewayFunctionInvocation.With(prometheus.Labels{"function_name": service, "code": code}).Inc() metrics.GatewayFunctionInvocation.
With(prometheus.Labels{"function_name": service, "code": code}).
Inc()
} }
} }
} }

View File

@ -26,6 +26,8 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
const watchdogPort = 8080
// MakeProxy creates a proxy for HTTP web requests which can be routed to a function. // MakeProxy creates a proxy for HTTP web requests which can be routed to a function.
func MakeProxy(metrics metrics.MetricOptions, wildcard bool, client *client.Client, logger *logrus.Logger) http.HandlerFunc { func MakeProxy(metrics metrics.MetricOptions, wildcard bool, client *client.Client, logger *logrus.Logger) http.HandlerFunc {
proxyClient := http.Client{ proxyClient := http.Client{
@ -123,8 +125,6 @@ func invokeService(w http.ResponseWriter, r *http.Request, metrics metrics.Metri
dnsrr = true dnsrr = true
} }
watchdogPort := 8080
addr := service addr := service
// Use DNS-RR via tasks.servicename if enabled as override, otherwise VIP. // Use DNS-RR via tasks.servicename if enabled as override, otherwise VIP.
if dnsrr { if dnsrr {
@ -165,6 +165,7 @@ func invokeService(w http.ResponseWriter, r *http.Request, metrics metrics.Metri
w.Header().Set("Content-Type", GetContentType(response.Header, r.Header, defaultHeader)) w.Header().Set("Content-Type", GetContentType(response.Header, r.Header, defaultHeader))
writeHead(service, metrics, response.StatusCode, w) writeHead(service, metrics, response.StatusCode, w)
if response.Body != nil { if response.Body != nil {
io.Copy(w, response.Body) io.Copy(w, response.Body)
} }
@ -188,10 +189,10 @@ func GetContentType(request http.Header, proxyResponse http.Header, defaultValue
} }
func copyHeaders(destination *http.Header, source *http.Header) { func copyHeaders(destination *http.Header, source *http.Header) {
for k, vv := range *source { for k, v := range *source {
vvClone := make([]string, len(vv)) vClone := make([]string, len(v))
copy(vvClone, vv) copy(vClone, v)
(*destination)[k] = vvClone (*destination)[k] = vClone
} }
} }
@ -207,14 +208,20 @@ func writeHead(service string, metrics metrics.MetricOptions, code int, w http.R
} }
func trackInvocation(service string, metrics metrics.MetricOptions, code int) { func trackInvocation(service string, metrics metrics.MetricOptions, code int) {
metrics.GatewayFunctionInvocation.With(prometheus.Labels{"function_name": service, "code": strconv.Itoa(code)}).Inc() metrics.GatewayFunctionInvocation.With(
prometheus.Labels{"function_name": service,
"code": strconv.Itoa(code)}).Inc()
} }
func trackTime(then time.Time, metrics metrics.MetricOptions, name string) { func trackTime(then time.Time, metrics metrics.MetricOptions, name string) {
since := time.Since(then) since := time.Since(then)
metrics.GatewayFunctionsHistogram.WithLabelValues(name).Observe(since.Seconds()) metrics.GatewayFunctionsHistogram.
WithLabelValues(name).
Observe(since.Seconds())
} }
func trackTimeExact(duration time.Duration, metrics metrics.MetricOptions, name string) { func trackTimeExact(duration time.Duration, metrics metrics.MetricOptions, name string) {
metrics.GatewayFunctionsHistogram.WithLabelValues(name).Observe(float64(duration)) metrics.GatewayFunctionsHistogram.
WithLabelValues(name).
Observe(float64(duration))
} }

View File

@ -44,6 +44,7 @@ func MakeQueuedProxy(metrics metrics.MetricOptions, wildcard bool, logger *logru
callbackURL = urlVal callbackURL = urlVal
} }
req := &queue.Request{ req := &queue.Request{
Function: name, Function: name,
Body: body, Body: body,
@ -54,13 +55,14 @@ func MakeQueuedProxy(metrics metrics.MetricOptions, wildcard bool, logger *logru
} }
err = canQueueRequests.Queue(req) err = canQueueRequests.Queue(req)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error())) w.Write([]byte(err.Error()))
fmt.Println(err) fmt.Println(err)
return return
} }
w.WriteHeader(http.StatusAccepted)
w.WriteHeader(http.StatusAccepted)
} }
} }

View File

@ -42,16 +42,11 @@ func MakeFunctionReader(metricsOptions metrics.MetricOptions, c client.ServiceAP
for _, service := range services { for _, service := range services {
if len(service.Spec.TaskTemplate.ContainerSpec.Labels["function"]) > 0 { if len(service.Spec.TaskTemplate.ContainerSpec.Labels["function"]) > 0 {
var envProcess string envProcess := getEnvProcess(service.Spec.TaskTemplate.ContainerSpec.Env)
for _, env := range service.Spec.TaskTemplate.ContainerSpec.Env {
if strings.Contains(env, "fprocess=") {
envProcess = env[len("fprocess="):]
}
}
// Required (copy by value) // Required (copy by value)
labels := service.Spec.Annotations.Labels labels := service.Spec.Annotations.Labels
f := requests.Function{ f := requests.Function{
Name: service.Spec.Name, Name: service.Spec.Name,
Image: service.Spec.TaskTemplate.ContainerSpec.Image, Image: service.Spec.TaskTemplate.ContainerSpec.Image,
@ -67,7 +62,19 @@ func MakeFunctionReader(metricsOptions metrics.MetricOptions, c client.ServiceAP
functionBytes, _ := json.Marshal(functions) functionBytes, _ := json.Marshal(functions)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200) w.WriteHeader(http.StatusOK)
w.Write(functionBytes) w.Write(functionBytes)
} }
} }
func getEnvProcess(envVars []string) string {
var value string
for _, env := range envVars {
if strings.Contains(env, "fprocess=") {
value = env[len("fprocess="):]
}
}
return value
}

View File

@ -30,14 +30,16 @@ func NewPrometheusQuery(host string, port int, client *http.Client) PrometheusQu
// Fetch queries aggregated stats // Fetch queries aggregated stats
func (q PrometheusQuery) Fetch(query string) (*VectorQueryResponse, error) { func (q PrometheusQuery) Fetch(query string) (*VectorQueryResponse, error) {
req, reqErr := http.NewRequest("GET", fmt.Sprintf("http://%s:%d/api/v1/query/?query=%s", q.Host, q.Port, query), nil) req, reqErr := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s:%d/api/v1/query/?query=%s", q.Host, q.Port, query), nil)
if reqErr != nil { if reqErr != nil {
return nil, reqErr return nil, reqErr
} }
res, getErr := q.Client.Do(req) res, getErr := q.Client.Do(req)
if getErr != nil { if getErr != nil {
return nil, getErr return nil, getErr
} }
defer res.Body.Close() defer res.Body.Close()
bytesOut, readErr := ioutil.ReadAll(res.Body) bytesOut, readErr := ioutil.ReadAll(res.Body)
if readErr != nil { if readErr != nil {

View File

@ -15,8 +15,9 @@ import (
// AttachSwarmWatcher adds a go-route to monitor the amount of service replicas in the swarm // AttachSwarmWatcher adds a go-route to monitor the amount of service replicas in the swarm
// matching a 'function' label. // matching a 'function' label.
func AttachSwarmWatcher(dockerClient *client.Client, metricsOptions MetricOptions, label string) { func AttachSwarmWatcher(dockerClient *client.Client, metricsOptions MetricOptions, label string, interval time.Duration) {
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(interval)
quit := make(chan struct{}) quit := make(chan struct{})
go func() { go func() {
@ -47,5 +48,4 @@ func AttachSwarmWatcher(dockerClient *client.Client, metricsOptions MetricOption
} }
} }
}() }()
} }

View File

@ -55,15 +55,21 @@ func (s ExternalServiceQuery) GetReplicas(serviceName string) (uint64, uint64, u
function := requests.Function{} function := requests.Function{}
urlPath := fmt.Sprintf("%ssystem/function/%s", s.URL.String(), serviceName) urlPath := fmt.Sprintf("%ssystem/function/%s", s.URL.String(), serviceName)
req, _ := http.NewRequest("GET", urlPath, nil)
res, err := s.ProxyClient.Do(req)
if err != nil {
log.Println(urlPath, err)
}
if res.StatusCode == 200 { req, _ := http.NewRequest(http.MethodGet, urlPath, nil)
res, err := s.ProxyClient.Do(req)
if err != nil {
log.Println(urlPath, err)
} else {
if res.Body != nil { if res.Body != nil {
defer res.Body.Close() defer res.Body.Close()
}
if res.StatusCode == http.StatusOK {
bytesOut, _ := ioutil.ReadAll(res.Body) bytesOut, _ := ioutil.ReadAll(res.Body)
err = json.Unmarshal(bytesOut, &function) err = json.Unmarshal(bytesOut, &function)
if err != nil { if err != nil {
@ -77,8 +83,8 @@ func (s ExternalServiceQuery) GetReplicas(serviceName string) (uint64, uint64, u
if function.Labels != nil { if function.Labels != nil {
labels := *function.Labels labels := *function.Labels
minScale := labels["com.openfaas.scale.min"] minScale := labels[handlers.MinScaleLabel]
maxScale := labels["com.openfaas.scale.max"] maxScale := labels[handlers.MaxScaleLabel]
if len(minScale) > 0 { if len(minScale) > 0 {
labelValue, err := strconv.Atoi(minScale) labelValue, err := strconv.Atoi(minScale)
@ -123,7 +129,7 @@ func (s ExternalServiceQuery) SetReplicas(serviceName string, count uint64) erro
} }
urlPath := fmt.Sprintf("%ssystem/scale-function/%s", s.URL.String(), serviceName) urlPath := fmt.Sprintf("%ssystem/scale-function/%s", s.URL.String(), serviceName)
req, _ := http.NewRequest("POST", urlPath, bytes.NewReader(requestBody)) req, _ := http.NewRequest(http.MethodPost, urlPath, bytes.NewReader(requestBody))
defer req.Body.Close() defer req.Body.Close()
res, err := s.ProxyClient.Do(req) res, err := s.ProxyClient.Do(req)

View File

@ -71,6 +71,7 @@ func main() {
metrics.RegisterMetrics(metricsOptions) metrics.RegisterMetrics(metricsOptions)
var faasHandlers handlerSet var faasHandlers handlerSet
servicePollInterval := time.Second * 5
if config.UseExternalProvider() { if config.UseExternalProvider() {
@ -86,7 +87,7 @@ func main() {
alertHandler := plugin.NewExternalServiceQuery(*config.FunctionsProviderURL) alertHandler := plugin.NewExternalServiceQuery(*config.FunctionsProviderURL)
faasHandlers.Alert = internalHandlers.MakeAlertHandler(alertHandler) faasHandlers.Alert = internalHandlers.MakeAlertHandler(alertHandler)
metrics.AttachExternalWatcher(*config.FunctionsProviderURL, metricsOptions, "func", time.Second*5) metrics.AttachExternalWatcher(*config.FunctionsProviderURL, metricsOptions, "func", servicePollInterval)
} else { } else {
@ -106,7 +107,7 @@ func main() {
// This could exist in a separate process - records the replicas of each swarm service. // This could exist in a separate process - records the replicas of each swarm service.
functionLabel := "function" functionLabel := "function"
metrics.AttachSwarmWatcher(dockerClient, metricsOptions, functionLabel) metrics.AttachSwarmWatcher(dockerClient, metricsOptions, functionLabel, servicePollInterval)
} }
if config.UseNATS() { if config.UseNATS() {

View File

@ -11,7 +11,7 @@ import (
func TestScale1to5(t *testing.T) { func TestScale1to5(t *testing.T) {
minReplicas := uint64(1) minReplicas := uint64(1)
newReplicas := handlers.CalculateReplicas("firing", 1, 20, minReplicas) newReplicas := handlers.CalculateReplicas("firing", 1, handlers.DefaultMaxReplicas, minReplicas)
if newReplicas != 5 { if newReplicas != 5 {
t.Log("Expected increment in blocks of 5 from 1 to 5") t.Log("Expected increment in blocks of 5 from 1 to 5")
t.Fail() t.Fail()
@ -20,7 +20,7 @@ func TestScale1to5(t *testing.T) {
func TestScale5to10(t *testing.T) { func TestScale5to10(t *testing.T) {
minReplicas := uint64(1) minReplicas := uint64(1)
newReplicas := handlers.CalculateReplicas("firing", 5, 20, minReplicas) newReplicas := handlers.CalculateReplicas("firing", 5, handlers.DefaultMaxReplicas, minReplicas)
if newReplicas != 10 { if newReplicas != 10 {
t.Log("Expected increment in blocks of 5 from 5 to 10") t.Log("Expected increment in blocks of 5 from 5 to 10")
t.Fail() t.Fail()
@ -29,7 +29,7 @@ func TestScale5to10(t *testing.T) {
func TestScaleCeilingOf20Replicas_Noaction(t *testing.T) { func TestScaleCeilingOf20Replicas_Noaction(t *testing.T) {
minReplicas := uint64(1) minReplicas := uint64(1)
newReplicas := handlers.CalculateReplicas("firing", 20, 20, minReplicas) newReplicas := handlers.CalculateReplicas("firing", 20, handlers.DefaultMaxReplicas, minReplicas)
if newReplicas != 20 { if newReplicas != 20 {
t.Log("Expected ceiling of 20 replicas") t.Log("Expected ceiling of 20 replicas")
t.Fail() t.Fail()
@ -38,7 +38,7 @@ func TestScaleCeilingOf20Replicas_Noaction(t *testing.T) {
func TestScaleCeilingOf20Replicas(t *testing.T) { func TestScaleCeilingOf20Replicas(t *testing.T) {
minReplicas := uint64(1) minReplicas := uint64(1)
newReplicas := handlers.CalculateReplicas("firing", 19, 20, minReplicas) newReplicas := handlers.CalculateReplicas("firing", 19, handlers.DefaultMaxReplicas, minReplicas)
if newReplicas != 20 { if newReplicas != 20 {
t.Log("Expected ceiling of 20 replicas") t.Log("Expected ceiling of 20 replicas")
t.Fail() t.Fail()
@ -47,7 +47,7 @@ func TestScaleCeilingOf20Replicas(t *testing.T) {
func TestBackingOff10to1(t *testing.T) { func TestBackingOff10to1(t *testing.T) {
minReplicas := uint64(1) minReplicas := uint64(1)
newReplicas := handlers.CalculateReplicas("resolved", 10, 20, minReplicas) newReplicas := handlers.CalculateReplicas("resolved", 10, handlers.DefaultMaxReplicas, minReplicas)
if newReplicas != 1 { if newReplicas != 1 {
t.Log("Expected backing off to 1 replica") t.Log("Expected backing off to 1 replica")
t.Fail() t.Fail()

View File

@ -0,0 +1,36 @@
package tests
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/openfaas/faas/gateway/handlers"
)
type customHandler struct {
}
func (h customHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func Test_HeadersAdded(t *testing.T) {
rr := httptest.NewRecorder()
handler := customHandler{}
host := "store.openfaas.com"
decorated := handlers.DecorateWithCORS(handler, host)
request, _ := http.NewRequest(http.MethodGet, "/", nil)
decorated.ServeHTTP(rr, request)
actual := rr.Header().Get("Access-Control-Allow-Origin")
if actual != host {
t.Errorf("Access-Control-Allow-Origin: want: %s got: %s", host, actual)
}
actualMethods := rr.Header().Get("Access-Control-Allow-Methods")
if actualMethods != "GET" {
t.Errorf("Access-Control-Allow-Methods: want: %s got: %s", "GET", actualMethods)
}
}

View File

@ -29,7 +29,6 @@ func Test_GetContentType_UsesRequest_WhenResponseEmpty(t *testing.T) {
if contentType != request.Get("Content-Type") { if contentType != request.Get("Content-Type") {
t.Errorf("Got: %s, want: %s", contentType, request.Get("Content-Type")) t.Errorf("Got: %s, want: %s", contentType, request.Get("Content-Type"))
} }
} }
func Test_GetContentType_UsesDefaultWhenRequestResponseEmpty(t *testing.T) { func Test_GetContentType_UsesDefaultWhenRequestResponseEmpty(t *testing.T) {
@ -42,5 +41,4 @@ func Test_GetContentType_UsesDefaultWhenRequestResponseEmpty(t *testing.T) {
if contentType != "default" { if contentType != "default" {
t.Errorf("Got: %s, want: %s", contentType, "default") t.Errorf("Got: %s, want: %s", contentType, "default")
} }
} }

View File

@ -167,6 +167,7 @@ func TestReaderSuccessReturnsCorrectBodyWithOneFunction(t *testing.T) {
}, },
}, },
} }
marshalled, _ := json.Marshal(functions) marshalled, _ := json.Marshal(functions)
expected := string(marshalled) expected := string(marshalled)
if w.Body.String() != expected { if w.Body.String() != expected {

View File

@ -15,14 +15,14 @@ import (
func TestBuildEncodedAuthConfig(t *testing.T) { func TestBuildEncodedAuthConfig(t *testing.T) {
// custom repository with valid data // custom repository with valid data
assertValidEncodedAuthConfig(t, "user", "password", "my.repository.com/user/imagename", "my.repository.com") testValidEncodedAuthConfig(t, "user", "password", "my.repository.com/user/imagename", "my.repository.com")
assertValidEncodedAuthConfig(t, "user", "weird:password:", "my.repository.com/user/imagename", "my.repository.com") testValidEncodedAuthConfig(t, "user", "weird:password:", "my.repository.com/user/imagename", "my.repository.com")
assertValidEncodedAuthConfig(t, "userWithNoPassword", "", "my.repository.com/user/imagename", "my.repository.com") testValidEncodedAuthConfig(t, "userWithNoPassword", "", "my.repository.com/user/imagename", "my.repository.com")
assertValidEncodedAuthConfig(t, "", "", "my.repository.com/user/imagename", "my.repository.com") testValidEncodedAuthConfig(t, "", "", "my.repository.com/user/imagename", "my.repository.com")
// docker hub default repository // docker hub default repository
assertValidEncodedAuthConfig(t, "user", "password", "user/imagename", "docker.io") testValidEncodedAuthConfig(t, "user", "password", "user/imagename", "docker.io")
assertValidEncodedAuthConfig(t, "", "", "user/imagename", "docker.io") testValidEncodedAuthConfig(t, "", "", "user/imagename", "docker.io")
// invalid base64 basic auth // invalid base64 basic auth
assertEncodedAuthError(t, "invalidBasicAuth", "my.repository.com/user/imagename") assertEncodedAuthError(t, "invalidBasicAuth", "my.repository.com/user/imagename")
@ -32,7 +32,7 @@ func TestBuildEncodedAuthConfig(t *testing.T) {
assertEncodedAuthError(t, b64BasicAuth("user", "password"), "invalid name") assertEncodedAuthError(t, b64BasicAuth("user", "password"), "invalid name")
} }
func assertValidEncodedAuthConfig(t *testing.T, user, password, imageName, expectedRegistryHost string) { func testValidEncodedAuthConfig(t *testing.T, user, password, imageName, expectedRegistryHost string) {
encodedAuthConfig, err := handlers.BuildEncodedAuthConfig(b64BasicAuth(user, password), imageName) encodedAuthConfig, err := handlers.BuildEncodedAuthConfig(b64BasicAuth(user, password), imageName)
if err != nil { if err != nil {
t.Log("Unexpected error while building auth config with correct values") t.Log("Unexpected error while building auth config with correct values")
@ -50,10 +50,12 @@ func assertValidEncodedAuthConfig(t *testing.T, user, password, imageName, expec
t.Log("Auth config username mismatch", user, authConfig.Username) t.Log("Auth config username mismatch", user, authConfig.Username)
t.Fail() t.Fail()
} }
if password != authConfig.Password { if password != authConfig.Password {
t.Log("Auth config password mismatch", password, authConfig.Password) t.Log("Auth config password mismatch", password, authConfig.Password)
t.Fail() t.Fail()
} }
if expectedRegistryHost != authConfig.ServerAddress { if expectedRegistryHost != authConfig.ServerAddress {
t.Log("Auth config registry server address mismatch", expectedRegistryHost, authConfig.ServerAddress) t.Log("Auth config registry server address mismatch", expectedRegistryHost, authConfig.ServerAddress)
t.Fail() t.Fail()

View File

@ -6,6 +6,8 @@ import (
"github.com/openfaas/faas/gateway/handlers" "github.com/openfaas/faas/gateway/handlers"
) )
// Test_ParseMemory exploratory testing to document how to convert
// from Docker limits notation to bytes value.
func Test_ParseMemory(t *testing.T) { func Test_ParseMemory(t *testing.T) {
value := "512 m" value := "512 m"

View File

@ -4,11 +4,9 @@
package tests package tests
import ( import (
"testing"
"io/ioutil"
"encoding/json" "encoding/json"
"io/ioutil"
"testing"
"github.com/openfaas/faas/gateway/requests" "github.com/openfaas/faas/gateway/requests"
) )
@ -19,21 +17,27 @@ func TestUnmarshallAlert(t *testing.T) {
var alert requests.PrometheusAlert var alert requests.PrometheusAlert
err := json.Unmarshal(file, &alert) err := json.Unmarshal(file, &alert)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if (len(alert.Status)) == 0 { if (len(alert.Status)) == 0 {
t.Fatal("No status read") t.Fatal("No status read")
} }
if (len(alert.Receiver)) == 0 { if (len(alert.Receiver)) == 0 {
t.Fatal("No status read") t.Fatal("No status read")
} }
if (len(alert.Alerts)) == 0 { if (len(alert.Alerts)) == 0 {
t.Fatal("No alerts read") t.Fatal("No alerts read")
} }
if (len(alert.Alerts[0].Labels.AlertName)) == 0 { if (len(alert.Alerts[0].Labels.AlertName)) == 0 {
t.Fatal("No alerts name") t.Fatal("No alerts name")
} }
if (len(alert.Alerts[0].Labels.FunctionName)) == 0 { if (len(alert.Alerts[0].Labels.FunctionName)) == 0 {
t.Fatal("No function name read") t.Fatal("No function name read")
} }

View File

@ -4,22 +4,24 @@
package types package types
import ( import (
"fmt" "log"
"net/http" "net/http"
) )
// WriteAdapter adapts a ResponseWriter // WriteAdapter adapts a ResponseWriter
type WriteAdapter struct { type WriteAdapter struct {
Writer http.ResponseWriter Writer http.ResponseWriter
HttpResult *HttpResult HTTPResult *HTTPResult
} }
type HttpResult struct {
HeaderCode int // HTTPResult captures data from forwarded HTTP call
type HTTPResult struct {
HeaderCode int // HeaderCode is the result of WriteHeader(int)
} }
//NewWriteAdapter create a new NewWriteAdapter //NewWriteAdapter create a new NewWriteAdapter
func NewWriteAdapter(w http.ResponseWriter) WriteAdapter { func NewWriteAdapter(w http.ResponseWriter) WriteAdapter {
return WriteAdapter{Writer: w, HttpResult: &HttpResult{}} return WriteAdapter{Writer: w, HTTPResult: &HTTPResult{}}
} }
//Header adapts Header //Header adapts Header
@ -27,19 +29,20 @@ func (w WriteAdapter) Header() http.Header {
return w.Writer.Header() return w.Writer.Header()
} }
// Write adapts Write // Write adapts Write for a straight pass-through
func (w WriteAdapter) Write(data []byte) (int, error) { func (w WriteAdapter) Write(data []byte) (int, error) {
return w.Writer.Write(data) return w.Writer.Write(data)
} }
// WriteHeader adapts WriteHeader // WriteHeader adapts WriteHeader
func (w WriteAdapter) WriteHeader(i int) { func (w WriteAdapter) WriteHeader(statusCode int) {
w.Writer.WriteHeader(i) w.Writer.WriteHeader(statusCode)
w.HttpResult.HeaderCode = i w.HTTPResult.HeaderCode = statusCode
fmt.Println("GetHeaderCode before", w.HttpResult.HeaderCode)
log.Printf("GetHeaderCode %d", w.HTTPResult.HeaderCode)
} }
// GetHeaderCode result from WriteHeader // GetHeaderCode result from WriteHeader
func (w *WriteAdapter) GetHeaderCode() int { func (w *WriteAdapter) GetHeaderCode() int {
return w.HttpResult.HeaderCode return w.HTTPResult.HeaderCode
} }