diff --git a/gateway/handlers/external_auth.go b/gateway/handlers/external_auth.go index 5ac8282f..70f8c5aa 100644 --- a/gateway/handlers/external_auth.go +++ b/gateway/handlers/external_auth.go @@ -1 +1,37 @@ package handlers + +import ( + "context" + "net/http" + "time" +) + +// MakeExternalAuthHandler make an authentication proxy handler +func MakeExternalAuthHandler(next http.HandlerFunc, upstreamTimeout time.Duration, upstreamURL string, passBody bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + req, _ := http.NewRequest(http.MethodGet, upstreamURL, nil) + + deadlineContext, cancel := context.WithDeadline( + context.Background(), + time.Now().Add(upstreamTimeout)) + + defer cancel() + + res, err := http.DefaultClient.Do(req.WithContext(deadlineContext)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if res.Body != nil { + defer res.Body.Close() + } + + if res.StatusCode == http.StatusOK { + next.ServeHTTP(w, r) + return + } + + w.WriteHeader(res.StatusCode) + } +} diff --git a/gateway/handlers/external_auth_test.go b/gateway/handlers/external_auth_test.go index 294aaa86..205ce438 100644 --- a/gateway/handlers/external_auth_test.go +++ b/gateway/handlers/external_auth_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" ) func Test_External_Auth_Wrapper_FailsInvalidAuth(t *testing.T) { @@ -18,7 +19,7 @@ func Test_External_Auth_Wrapper_FailsInvalidAuth(t *testing.T) { } passBody := false - handler := MakeExternalAuthHandler(next, s.URL, passBody) + handler := MakeExternalAuthHandler(next, time.Second*5, s.URL, passBody) req := httptest.NewRequest(http.MethodGet, s.URL, nil) rr := httptest.NewRecorder() @@ -41,7 +42,7 @@ func Test_External_Auth_Wrapper_PassesValidAuth(t *testing.T) { } passBody := false - handler := MakeExternalAuthHandler(next, s.URL, passBody) + handler := MakeExternalAuthHandler(next, time.Second*5, s.URL, passBody) req := httptest.NewRequest(http.MethodGet, s.URL, nil) rr := httptest.NewRecorder() @@ -52,21 +53,27 @@ func Test_External_Auth_Wrapper_PassesValidAuth(t *testing.T) { } } -func MakeExternalAuthHandler(next http.HandlerFunc, upstreamURL string, passBody bool) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - req, _ := http.NewRequest(http.MethodGet, upstreamURL, nil) +func Test_External_Auth_Wrapper_TimeoutGivesInternalServerError(t *testing.T) { - res, err := http.DefaultClient.Do(req) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - } - if res.Body != nil { - defer res.Body.Close() - } + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer s.Close() - if res.StatusCode == http.StatusOK { - next.ServeHTTP(w, r) - } - w.WriteHeader(res.StatusCode) + next := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + } + + passBody := false + handler := MakeExternalAuthHandler(next, time.Millisecond*10, s.URL, passBody) + + req := httptest.NewRequest(http.MethodGet, s.URL, nil) + rr := httptest.NewRecorder() + handler(rr, req) + + want := http.StatusInternalServerError + if rr.Code != want { + t.Errorf("Status incorrect, want: %d, but got %d", want, rr.Code) } }