From ea62c1b12dd1ae1de794e6dd260351cfbd1a6759 Mon Sep 17 00:00:00 2001 From: Lucas Roesler Date: Sun, 17 Oct 2021 17:59:51 +0200 Subject: [PATCH] feat: add support for raw secret values Load the secret value from the RawValue field, if it is empty, use the string value. Add unit tests for the creation handler. Refactor secret parser tests. Resolves #208 Signed-off-by: Lucas Roesler --- pkg/provider/handlers/secret.go | 29 +++- pkg/provider/handlers/secret_test.go | 196 ++++++++++++++++++++------- 2 files changed, 172 insertions(+), 53 deletions(-) diff --git a/pkg/provider/handlers/secret.go b/pkg/provider/handlers/secret.go index 94b6a76..009f22d 100644 --- a/pkg/provider/handlers/secret.go +++ b/pkg/provider/handlers/secret.go @@ -86,6 +86,14 @@ func createSecret(c *containerd.Client, w http.ResponseWriter, r *http.Request, return } + err = validateSecret(secret) + if err != nil { + log.Printf("[secret] error %s", err.Error()) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + log.Printf("[secret] is valid: %q", secret.Name) namespace := getRequestNamespace(secret.Namespace) mountPath = getNamespaceSecretMountPath(mountPath, namespace) @@ -96,7 +104,12 @@ func createSecret(c *containerd.Client, w http.ResponseWriter, r *http.Request, return } - err = ioutil.WriteFile(path.Join(mountPath, secret.Name), []byte(secret.Value), secretFilePermission) + data := secret.RawValue + if len(data) == 0 { + data = []byte(secret.Value) + } + + err = ioutil.WriteFile(path.Join(mountPath, secret.Name), data, secretFilePermission) if err != nil { log.Printf("[secret] error %s", err.Error()) @@ -137,10 +150,6 @@ func parseSecret(r *http.Request) (types.Secret, error) { return secret, err } - if isTraversal(secret.Name) { - return secret, fmt.Errorf(traverseErrorSt) - } - return secret, err } @@ -150,3 +159,13 @@ func isTraversal(name string) bool { return strings.Contains(name, fmt.Sprintf("%s", string(os.PathSeparator))) || strings.Contains(name, "..") } + +func validateSecret(secret types.Secret) error { + if strings.TrimSpace(secret.Name) == "" { + return fmt.Errorf("non-empty name is required") + } + if isTraversal(secret.Name) { + return fmt.Errorf(traverseErrorSt) + } + return nil +} diff --git a/pkg/provider/handlers/secret_test.go b/pkg/provider/handlers/secret_test.go index b4f2bfc..2b3d8d8 100644 --- a/pkg/provider/handlers/secret_test.go +++ b/pkg/provider/handlers/secret_test.go @@ -1,63 +1,163 @@ package handlers import ( - "bytes" - "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" + "reflect" + "strings" "testing" "github.com/openfaas/faas-provider/types" ) -func Test_parseSecretValidName(t *testing.T) { +func Test_parseSecret(t *testing.T) { + cases := []struct { + name string + payload string + expError string + expSecret types.Secret + }{ + { + name: "no error when name is valid without extention and with no traversal", + payload: `{"name": "authorized_keys", "value": "foo"}`, + expSecret: types.Secret{Name: "authorized_keys", Value: "foo"}, + }, + { + name: "no error when name is valid and parses RawValue correctly", + payload: `{"name": "authorized_keys", "rawValue": "YmFy"}`, + expSecret: types.Secret{Name: "authorized_keys", RawValue: []byte("bar")}, + }, + { + name: "no error when name is valid with dot and with no traversal", + payload: `{"name": "authorized.keys", "value": "foo"}`, + expSecret: types.Secret{Name: "authorized.keys", Value: "foo"}, + }, + } - s := types.Secret{Name: "authorized_keys"} - body, _ := json.Marshal(s) - reader := bytes.NewReader(body) - r := httptest.NewRequest(http.MethodPost, "/", reader) - _, err := parseSecret(r) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + reader := strings.NewReader(tc.payload) + r := httptest.NewRequest(http.MethodPost, "/", reader) + secret, err := parseSecret(r) + if err != nil && tc.expError == "" { + t.Fatalf("unexpected error: %s", err) + return + } + if tc.expError != "" { + if err == nil { + t.Fatalf("expected error: %s, got nil", tc.expError) + } + if err.Error() != tc.expError { + t.Fatalf("expected error: %s, got: %s", tc.expError, err) + } + + return + } + + if !reflect.DeepEqual(secret, tc.expSecret) { + t.Fatalf("expected secret: %+v, got: %+v", tc.expSecret, secret) + } + }) + } +} + +func TestSecretCreation(t *testing.T) { + mountPath, err := os.MkdirTemp("", "test_secret_creation") if err != nil { - t.Fatalf("secret name is valid with no traversal characters") - } -} - -func Test_parseSecretValidNameWithDot(t *testing.T) { - - s := types.Secret{Name: "authorized.keys"} - body, _ := json.Marshal(s) - reader := bytes.NewReader(body) - r := httptest.NewRequest(http.MethodPost, "/", reader) - _, err := parseSecret(r) - - if err != nil { - t.Fatalf("secret name is valid with no traversal characters") - } -} - -func Test_parseSecretWithTraversalWithSlash(t *testing.T) { - - s := types.Secret{Name: "/root/.ssh/authorized_keys"} - body, _ := json.Marshal(s) - reader := bytes.NewReader(body) - r := httptest.NewRequest(http.MethodPost, "/", reader) - _, err := parseSecret(r) - - if err == nil { - t.Fatalf("secret name should fail due to path traversal") - } -} - -func Test_parseSecretWithTraversalWithDoubleDot(t *testing.T) { - - s := types.Secret{Name: ".."} - body, _ := json.Marshal(s) - reader := bytes.NewReader(body) - r := httptest.NewRequest(http.MethodPost, "/", reader) - _, err := parseSecret(r) - - if err == nil { - t.Fatalf("secret name should fail due to path traversal") + t.Fatalf("unexpected error while creating temp directory: %s", err) + } + + defer os.RemoveAll(mountPath) + + handler := MakeSecretHandler(nil, mountPath) + + cases := []struct { + name string + verb string + payload string + status int + secretPath string + secret string + err string + }{ + { + name: "returns error when the name contains a traversal", + verb: http.MethodPost, + payload: `{"name": "/root/.ssh/authorized_keys", "value": "foo"}`, + status: http.StatusBadRequest, + err: "directory traversal found in name\n", + }, + { + name: "returns error when the name contains a traversal", + verb: http.MethodPost, + payload: `{"name": "..", "value": "foo"}`, + status: http.StatusBadRequest, + err: "directory traversal found in name\n", + }, + { + name: "empty request returns a validation error", + verb: http.MethodPost, + payload: `{}`, + status: http.StatusBadRequest, + err: "non-empty name is required\n", + }, + { + name: "can create secret from string", + verb: http.MethodPost, + payload: `{"name": "foo", "value": "bar"}`, + status: http.StatusOK, + secretPath: "/openfaas-fn/foo", + secret: "bar", + }, + { + name: "can create secret from raw value", + verb: http.MethodPost, + payload: `{"name": "foo", "rawValue": "YmFy"}`, + status: http.StatusOK, + secretPath: "/openfaas-fn/foo", + secret: "bar", + }, + { + name: "can create secret in non-default namespace from raw value", + verb: http.MethodPost, + payload: `{"name": "pity", "rawValue": "dGhlIGZvbw==", "namespace": "a-team"}`, + status: http.StatusOK, + secretPath: "/a-team/pity", + secret: "the foo", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.verb, "http://example.com/foo", strings.NewReader(tc.payload)) + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + if resp.StatusCode != tc.status { + t.Logf("response body: %s", w.Body.String()) + t.Fatalf("expected status: %d, got: %d", tc.status, resp.StatusCode) + } + + if resp.StatusCode != http.StatusOK && w.Body.String() != tc.err { + t.Fatalf("expected error message: %q, got %q", tc.err, w.Body.String()) + + } + + if tc.secretPath != "" { + data, err := os.ReadFile(filepath.Join(mountPath, tc.secretPath)) + if err != nil { + t.Fatalf("can not read the secret from disk: %s", err) + } + + if string(data) != tc.secret { + t.Fatalf("expected secret value: %s, got %s", tc.secret, string(data)) + } + } + }) } }