diff --git a/gateway/handlers/external_auth.go b/gateway/handlers/external_auth.go new file mode 100644 index 00000000..5ac8282f --- /dev/null +++ b/gateway/handlers/external_auth.go @@ -0,0 +1 @@ +package handlers diff --git a/gateway/handlers/external_auth_test.go b/gateway/handlers/external_auth_test.go new file mode 100644 index 00000000..294aaa86 --- /dev/null +++ b/gateway/handlers/external_auth_test.go @@ -0,0 +1,72 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func Test_External_Auth_Wrapper_FailsInvalidAuth(t *testing.T) { + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer s.Close() + + next := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + } + + passBody := false + handler := MakeExternalAuthHandler(next, s.URL, passBody) + + req := httptest.NewRequest(http.MethodGet, s.URL, nil) + rr := httptest.NewRecorder() + handler(rr, req) + + if rr.Code == http.StatusOK { + t.Errorf("Status incorrect, did not want: %d, but got %d", http.StatusOK, rr.Code) + } +} + +func Test_External_Auth_Wrapper_PassesValidAuth(t *testing.T) { + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + next := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + } + + passBody := false + handler := MakeExternalAuthHandler(next, s.URL, passBody) + + req := httptest.NewRequest(http.MethodGet, s.URL, nil) + rr := httptest.NewRecorder() + handler(rr, req) + want := http.StatusNotImplemented + if rr.Code != want { + t.Errorf("Status incorrect, want: %d, but got %d", want, rr.Code) + } +} + +func MakeExternalAuthHandler(next http.HandlerFunc, upstreamURL string, passBody bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + req, _ := http.NewRequest(http.MethodGet, upstreamURL, nil) + + res, err := http.DefaultClient.Do(req) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + if res.Body != nil { + defer res.Body.Close() + } + + if res.StatusCode == http.StatusOK { + next.ServeHTTP(w, r) + } + w.WriteHeader(res.StatusCode) + } +}