From 098baba7ccb11986ef0d77b55c725cb8b1d18946 Mon Sep 17 00:00:00 2001 From: "Alex Ellis (OpenFaaS Ltd)" Date: Fri, 3 Jan 2020 12:06:53 +0000 Subject: [PATCH] Add unit test for proxy and shutdown channel * Proxy has initial unit test and more can be added * Shutdown channel and cancellation added for proper shutdown of the proxy Signed-off-by: Alex Ellis (OpenFaaS Ltd) --- cmd/up.go | 11 +++-- pkg/proxy.go | 123 +++++++++++++++++++++++++++------------------- pkg/proxy_test.go | 73 +++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 53 deletions(-) create mode 100644 pkg/proxy_test.go diff --git a/cmd/up.go b/cmd/up.go index 6cb446f..289fd31 100644 --- a/cmd/up.go +++ b/cmd/up.go @@ -77,6 +77,7 @@ func runUp(_ *cobra.Command, _ []string) error { shutdownTimeout := time.Second * 1 timeout := time.Second * 60 + proxyDoneCh := make(chan bool) wg := sync.WaitGroup{} wg.Add(1) @@ -92,14 +93,18 @@ func runUp(_ *cobra.Command, _ []string) error { if err != nil { fmt.Println(err) } + + // Close proxy + proxyDoneCh <- true time.AfterFunc(shutdownTimeout, func() { wg.Done() }) }() gatewayURLChan := make(chan string, 1) - proxy := pkg.NewProxy(timeout) - go proxy.Start(gatewayURLChan) + proxyPort := 8080 + proxy := pkg.NewProxy(proxyPort, timeout) + go proxy.Start(gatewayURLChan, proxyDoneCh) go func() { wd, _ := os.Getwd() @@ -119,7 +124,7 @@ func runUp(_ *cobra.Command, _ []string) error { } } log.Printf("[up] Sending %s to proxy\n", host) - gatewayURLChan <- host + gatewayURLChan <- host + ":8080" close(gatewayURLChan) }() diff --git a/pkg/proxy.go b/pkg/proxy.go index cf0ad9f..d951b83 100644 --- a/pkg/proxy.go +++ b/pkg/proxy.go @@ -1,6 +1,7 @@ package pkg import ( + "context" "fmt" "io" "io/ioutil" @@ -10,83 +11,58 @@ import ( "time" ) -func NewProxy(timeout time.Duration) *Proxy { +func NewProxy(port int, timeout time.Duration) *Proxy { return &Proxy{ + Port: port, Timeout: timeout, } } type Proxy struct { Timeout time.Duration + Port int } -func (p *Proxy) Start(gatewayChan chan string) error { - tcp := 8080 +func (p *Proxy) Start(gatewayChan chan string, done chan bool) error { + tcp := p.Port http.DefaultClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } - - data := struct{ host string }{ - host: "", + ps := proxyState{ + Host: "", } - data.host = <-gatewayChan + ps.Host = <-gatewayChan log.Printf("Starting faasd proxy on %d\n", tcp) - fmt.Printf("Gateway: %s\n", data.host) + fmt.Printf("Gateway: %s\n", ps.Host) s := &http.Server{ Addr: fmt.Sprintf(":%d", tcp), ReadTimeout: p.Timeout, WriteTimeout: p.Timeout, MaxHeaderBytes: 1 << 20, // Max header of 1MB - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - query := "" - if len(r.URL.RawQuery) > 0 { - query = "?" + r.URL.RawQuery - } - - upstream := fmt.Sprintf("http://%s:8080%s%s", data.host, r.URL.Path, query) - fmt.Printf("[faasd] proxy: %s\n", upstream) - - if r.Body != nil { - defer r.Body.Close() - } - - wrapper := ioutil.NopCloser(r.Body) - upReq, upErr := http.NewRequest(r.Method, upstream, wrapper) - - copyHeaders(upReq.Header, &r.Header) - - if upErr != nil { - log.Println(upErr) - - http.Error(w, upErr.Error(), http.StatusInternalServerError) - return - } - - upRes, upResErr := http.DefaultClient.Do(upReq) - - if upResErr != nil { - log.Println(upResErr) - - http.Error(w, upResErr.Error(), http.StatusInternalServerError) - return - } - - copyHeaders(w.Header(), &upRes.Header) - - w.WriteHeader(upRes.StatusCode) - io.Copy(w, upRes.Body) - - }), + Handler: http.HandlerFunc(makeProxy(&ps)), } - return s.ListenAndServe() + go func() { + log.Printf("[proxy] Begin listen on %d\n", p.Port) + if err := s.ListenAndServe(); err != http.ErrServerClosed { + log.Printf("Error ListenAndServe: %v", err) + } + }() + + log.Println("[proxy] Wait for done") + <-done + log.Println("[proxy] Done received") + if err := s.Shutdown(context.Background()); err != nil { + log.Printf("[proxy] Error in Shutdown: %v", err) + } + + return nil } // copyHeaders clones the header values from the source into the destination. @@ -97,3 +73,50 @@ func copyHeaders(destination http.Header, source *http.Header) { destination[k] = vClone } } + +type proxyState struct { + Host string +} + +func makeProxy(ps *proxyState) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + + query := "" + if len(r.URL.RawQuery) > 0 { + query = "?" + r.URL.RawQuery + } + + upstream := fmt.Sprintf("http://%s%s%s", ps.Host, r.URL.Path, query) + fmt.Printf("[faasd] proxy: %s\n", upstream) + + if r.Body != nil { + defer r.Body.Close() + } + + wrapper := ioutil.NopCloser(r.Body) + upReq, upErr := http.NewRequest(r.Method, upstream, wrapper) + + copyHeaders(upReq.Header, &r.Header) + + if upErr != nil { + log.Println(upErr) + + http.Error(w, upErr.Error(), http.StatusInternalServerError) + return + } + + upRes, upResErr := http.DefaultClient.Do(upReq) + + if upResErr != nil { + log.Println(upResErr) + + http.Error(w, upResErr.Error(), http.StatusInternalServerError) + return + } + + copyHeaders(w.Header(), &upRes.Header) + + w.WriteHeader(upRes.StatusCode) + io.Copy(w, upRes.Body) + } +} diff --git a/pkg/proxy_test.go b/pkg/proxy_test.go new file mode 100644 index 0000000..e7f7b45 --- /dev/null +++ b/pkg/proxy_test.go @@ -0,0 +1,73 @@ +package pkg + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" +) + +func Test_Proxy_ToPrivateServer(t *testing.T) { + + wantBodyText := "OK" + wantBody := []byte(wantBodyText) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if r.Body != nil { + defer r.Body.Close() + } + + w.WriteHeader(http.StatusOK) + w.Write(wantBody) + + })) + + defer upstream.Close() + port := 8080 + proxy := NewProxy(port, time.Second*1) + + gwChan := make(chan string, 1) + doneCh := make(chan bool) + + go proxy.Start(gwChan, doneCh) + + u, _ := url.Parse(upstream.URL) + log.Println("Host", u.Host) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + gwChan <- u.Host + wg.Done() + }() + wg.Wait() + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d", port), nil) + if err != nil { + t.Fatal(err) + } + + for i := 1; i < 11; i++ { + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Logf("Try %d, gave error: %s", i, err) + + time.Sleep(time.Millisecond * 100) + } else { + + resBody, _ := ioutil.ReadAll(res.Body) + if string(resBody) != string(wantBody) { + t.Errorf("want %s, but got %s in body", string(wantBody), string(resBody)) + } + break + } + } + + go func() { + doneCh <- true + }() +}