diff --git a/auth/basic-auth/main.go b/auth/basic-auth/main.go index 380aae8b..194a23db 100644 --- a/auth/basic-auth/main.go +++ b/auth/basic-auth/main.go @@ -60,9 +60,20 @@ func makeLogger(next http.Handler) func(w http.ResponseWriter, r *http.Request) next.ServeHTTP(rr, r) log.Printf("Validated request %d.\n", rr.Code) + resHeader := rr.Header() + copyHeaders(w.Header(), &resHeader) + w.WriteHeader(rr.Code) if rr.Body != nil { w.Write(rr.Body.Bytes()) } } } + +func copyHeaders(destination http.Header, source *http.Header) { + for k, v := range *source { + vClone := make([]string, len(v)) + copy(vClone, v) + (destination)[k] = vClone + } +} diff --git a/auth/basic-auth/main_test.go b/auth/basic-auth/main_test.go new file mode 100644 index 00000000..b1241c52 --- /dev/null +++ b/auth/basic-auth/main_test.go @@ -0,0 +1,29 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func Test_makeLogger(t *testing.T) { + handler := http.HandlerFunc(makeLogger(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Unit-Test", "true") + }))) + + s := httptest.NewServer(handler) + defer s.Close() + + req := httptest.NewRequest(http.MethodGet, s.URL, nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + got := rr.Header().Get("X-Unit-Test") + want := "true" + if want != got { + t.Errorf("Header X-Unit-Test, want: %s, got %s", want, got) + } + +}