Merge master into breakout_swarm

Signed-off-by: Alex Ellis <alexellis2@gmail.com>
Alex Ellis
2018-02-01 09:25:39 +00:00
parent afeb7bbce4
commit f954bf0733
1953 changed files with 614131 additions and 175582 deletions

@@ -0,0 +1,185 @@
package errcode
import (
"encoding/json"
"net/http"
"reflect"
"strings"
"testing"
)
// TestErrorsManagement does a quick check of the Errors type to ensure that
// members are properly pushed and marshaled.
var ErrorCodeTest1 = Register("test.errors", ErrorDescriptor{
Value: "TEST1",
Message: "test error 1",
Description: `Just a test message #1.`,
HTTPStatusCode: http.StatusInternalServerError,
})
var ErrorCodeTest2 = Register("test.errors", ErrorDescriptor{
Value: "TEST2",
Message: "test error 2",
Description: `Just a test message #2.`,
HTTPStatusCode: http.StatusNotFound,
})
var ErrorCodeTest3 = Register("test.errors", ErrorDescriptor{
Value: "TEST3",
Message: "Sorry %q isn't valid",
Description: `Just a test message #3.`,
HTTPStatusCode: http.StatusNotFound,
})
// TestErrorCodes ensures that error code format, mappings and
// marshaling/unmarshaling round trips are stable.
func TestErrorCodes(t *testing.T) {
if len(errorCodeToDescriptors) == 0 {
t.Fatal("errors aren't loaded!")
}
for ec, desc := range errorCodeToDescriptors {
if ec != desc.Code {
t.Fatalf("error code in descriptor isn't correct, %q != %q", ec, desc.Code)
}
if idToDescriptors[desc.Value].Code != ec {
t.Fatalf("error code in idToDesc isn't correct, %q != %q", idToDescriptors[desc.Value].Code, ec)
}
if ec.Message() != desc.Message {
t.Fatalf("ec.Message doesn't mtach desc.Message: %q != %q", ec.Message(), desc.Message)
}
// Test (de)serializing the ErrorCode
p, err := json.Marshal(ec)
if err != nil {
t.Fatalf("couldn't marshal ec %v: %v", ec, err)
}
if len(p) <= 0 {
t.Fatalf("expected content in marshaled before for error code %v", ec)
}
// First, unmarshal to interface and ensure we have a string.
var ecUnspecified interface{}
if err := json.Unmarshal(p, &ecUnspecified); err != nil {
t.Fatalf("error unmarshaling error code %v: %v", ec, err)
}
if _, ok := ecUnspecified.(string); !ok {
t.Fatalf("expected a string for error code %v on unmarshal got a %T", ec, ecUnspecified)
}
// Now, unmarshal with the error code type and ensure they are equal
var ecUnmarshaled ErrorCode
if err := json.Unmarshal(p, &ecUnmarshaled); err != nil {
t.Fatalf("error unmarshaling error code %v: %v", ec, err)
}
if ecUnmarshaled != ec {
t.Fatalf("unexpected error code during error code marshal/unmarshal: %v != %v", ecUnmarshaled, ec)
}
expectedErrorString := strings.ToLower(strings.Replace(ec.Descriptor().Value, "_", " ", -1))
if ec.Error() != expectedErrorString {
t.Fatalf("unexpected return from %v.Error(): %q != %q", ec, ec.Error(), expectedErrorString)
}
}
}
func TestErrorsManagement(t *testing.T) {
var errs Errors
errs = append(errs, ErrorCodeTest1)
errs = append(errs, ErrorCodeTest2.WithDetail(
map[string]interface{}{"digest": "sometestblobsumdoesntmatter"}))
errs = append(errs, ErrorCodeTest3.WithArgs("BOOGIE"))
errs = append(errs, ErrorCodeTest3.WithArgs("BOOGIE").WithDetail("data"))
p, err := json.Marshal(errs)
if err != nil {
t.Fatalf("error marashaling errors: %v", err)
}
expectedJSON := `{"errors":[` +
`{"code":"TEST1","message":"test error 1"},` +
`{"code":"TEST2","message":"test error 2","detail":{"digest":"sometestblobsumdoesntmatter"}},` +
`{"code":"TEST3","message":"Sorry \"BOOGIE\" isn't valid"},` +
`{"code":"TEST3","message":"Sorry \"BOOGIE\" isn't valid","detail":"data"}` +
`]}`
if string(p) != expectedJSON {
t.Fatalf("unexpected json:\ngot:\n%q\n\nexpected:\n%q", string(p), expectedJSON)
}
// Now test the reverse
var unmarshaled Errors
if err := json.Unmarshal(p, &unmarshaled); err != nil {
t.Fatalf("unexpected error unmarshaling error envelope: %v", err)
}
if !reflect.DeepEqual(unmarshaled, errs) {
t.Fatalf("errors not equal after round trip:\nunmarshaled:\n%#v\n\nerrs:\n%#v", unmarshaled, errs)
}
// Test the arg substitution stuff
e1 := unmarshaled[3].(Error)
exp1 := `Sorry "BOOGIE" isn't valid`
if e1.Message != exp1 {
t.Fatalf("Wrong msg, got:\n%q\n\nexpected:\n%q", e1.Message, exp1)
}
exp1 = "test3: " + exp1
if e1.Error() != exp1 {
t.Fatalf("Error() didn't return the right string, got:%s\nexpected:%s", e1.Error(), exp1)
}
// Test again with a single value this time
errs = Errors{ErrorCodeUnknown}
expectedJSON = "{\"errors\":[{\"code\":\"UNKNOWN\",\"message\":\"unknown error\"}]}"
p, err = json.Marshal(errs)
if err != nil {
t.Fatalf("error marashaling errors: %v", err)
}
if string(p) != expectedJSON {
t.Fatalf("unexpected json: %q != %q", string(p), expectedJSON)
}
// Now test the reverse
unmarshaled = nil
if err := json.Unmarshal(p, &unmarshaled); err != nil {
t.Fatalf("unexpected error unmarshaling error envelope: %v", err)
}
if !reflect.DeepEqual(unmarshaled, errs) {
t.Fatalf("errors not equal after round trip:\nunmarshaled:\n%#v\n\nerrs:\n%#v", unmarshaled, errs)
}
// Verify that calling WithArgs() more than once does the right thing,
// meaning it creates a new Error and uses the ErrorCode's Message.
e1 = ErrorCodeTest3.WithArgs("test1")
e2 := e1.WithArgs("test2")
if &e1 == &e2 {
t.Fatalf("args: e2 and e1 should not be the same, but they are")
}
if e2.Message != `Sorry "test2" isn't valid` {
t.Fatalf("e2 had wrong message: %q", e2.Message)
}
// Verify that calling WithDetail() more than once does the right thing,
// meaning it creates a new Error and overwrites the old detail field.
e1 = ErrorCodeTest3.WithDetail("stuff1")
e2 = e1.WithDetail("stuff2")
if &e1 == &e2 {
t.Fatalf("detail: e2 and e1 should not be the same, but they are")
}
if e2.Detail != `stuff2` {
t.Fatalf("e2 had wrong detail: %q", e2.Detail)
}
}
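The tests above exercise the whole lifecycle of a coded error: Register ties a value, message and HTTP status to an ErrorCode, WithArgs fills printf-style placeholders in the message, and WithDetail attaches a payload that survives the round trip through the {"errors":[...]} envelope. Below is a minimal sketch of how a caller in this package might use that API; the code value, message and detail payload are invented for illustration only.

// Hypothetical error code, registered the same way as the test fixtures above.
var ErrorCodeExample = Register("test.errors", ErrorDescriptor{
	Value:          "EXAMPLE",
	Message:        "example failure for %q",
	Description:    `Illustrates WithArgs/WithDetail composition.`,
	HTTPStatusCode: http.StatusBadRequest,
})

// buildExampleBody accumulates coded errors and marshals them into the same
// envelope that the assertions above expect.
func buildExampleBody(name string) ([]byte, error) {
	var errs Errors
	errs = append(errs, ErrorCodeExample.WithArgs(name).WithDetail(map[string]string{"name": name}))
	return json.Marshal(errs)
}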

@@ -4,9 +4,9 @@ import (
"net/http"
"regexp"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/api/errcode"
"github.com/opencontainers/go-digest"
)
var (

@@ -0,0 +1,161 @@
package v2
import (
"testing"
)
func TestParseForwardedHeader(t *testing.T) {
for _, tc := range []struct {
name string
raw string
expected map[string]string
expectedRest string
expectedError bool
}{
{
name: "empty",
raw: "",
},
{
name: "one pair",
raw: " key = value ",
expected: map[string]string{"key": "value"},
},
{
name: "two pairs",
raw: " key1 = value1; key2=value2",
expected: map[string]string{"key1": "value1", "key2": "value2"},
},
{
name: "uppercase parameter",
raw: "KeY=VaL",
expected: map[string]string{"key": "VaL"},
},
{
name: "missing key=value pair - be tolerant",
raw: "key=val;",
expected: map[string]string{"key": "val"},
},
{
name: "quoted values",
raw: `key="val";param = "[[ $((1 + 1)) == 3 ]] && echo panic!;" ; p=" abcd "`,
expected: map[string]string{"key": "val", "param": "[[ $((1 + 1)) == 3 ]] && echo panic!;", "p": " abcd "},
},
{
name: "empty quoted value",
raw: `key=""`,
expected: map[string]string{"key": ""},
},
{
name: "quoted double quotes",
raw: `key="\"value\""`,
expected: map[string]string{"key": `"value"`},
},
{
name: "quoted backslash",
raw: `key="\"\\\""`,
expected: map[string]string{"key": `"\"`},
},
{
name: "ignore subsequent elements",
raw: "key=a, param= b",
expected: map[string]string{"key": "a"},
expectedRest: " param= b",
},
{
name: "empty element - be tolerant",
raw: " , key=val",
expectedRest: " key=val",
},
{
name: "obscure key",
raw: `ob₷C&r€ = value`,
expected: map[string]string{`ob₷c&r€`: "value"},
},
{
name: "duplicate parameter",
raw: "key=a; p=b; key=c",
expectedError: true,
},
{
name: "empty parameter",
raw: "=value",
expectedError: true,
},
{
name: "empty value",
raw: "key= ",
expectedError: true,
},
{
name: "empty value before a new element ",
raw: "key=,",
expectedError: true,
},
{
name: "empty value before a new pair",
raw: "key=;",
expectedError: true,
},
{
name: "just parameter",
raw: "key",
expectedError: true,
},
{
name: "missing key-value",
raw: "a=b;;",
expectedError: true,
},
{
name: "unclosed quoted value",
raw: `key="value`,
expectedError: true,
},
{
name: "escaped terminating dquote",
raw: `key="value\"`,
expectedError: true,
},
{
name: "just a quoted value",
raw: `"key=val"`,
expectedError: true,
},
{
name: "quoted key",
raw: `"key"=val`,
expectedError: true,
},
} {
parsed, rest, err := parseForwardedHeader(tc.raw)
if err != nil && !tc.expectedError {
t.Errorf("[%s] got unexpected error: %v", tc.name, err)
}
if err == nil && tc.expectedError {
t.Errorf("[%s] got unexpected non-error", tc.name)
}
if err != nil || tc.expectedError {
continue
}
for key, value := range tc.expected {
v, exists := parsed[key]
if !exists {
t.Errorf("[%s] missing expected parameter %q", tc.name, key)
continue
}
if v != value {
t.Errorf("[%s] got unexpected value for parameter %q: %q != %q", tc.name, key, v, value)
}
}
for key, value := range parsed {
if _, exists := tc.expected[key]; !exists {
t.Errorf("[%s] got unexpected key/value pair: %q=%q", tc.name, key, value)
}
}
if rest != tc.expectedRest {
t.Errorf("[%s] got unexpected unparsed string: %q != %q", tc.name, rest, tc.expectedRest)
}
}
}
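The table above pins down the observable contract of parseForwardedHeader: parameter names are lowercased, quoting and escaping follow RFC 7239, only the first comma-separated element is consumed, and whatever follows the comma is returned unparsed. A hedged sketch of a caller, written as an extra test in the same package because the function is unexported; the header value itself is illustrative.

func TestParseForwardedHeaderFirstElementSketch(t *testing.T) {
	// Two list elements; only the first should be parsed.
	raw := `host="reg.example.com:5000";proto=https, host=other.example.com`
	parsed, rest, err := parseForwardedHeader(raw)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if parsed["host"] != "reg.example.com:5000" || parsed["proto"] != "https" {
		t.Fatalf("unexpected parameters: %v", parsed)
	}
	// The second element comes back verbatim for the caller to ignore or reparse.
	if rest != " host=other.example.com" {
		t.Fatalf("unexpected rest: %q", rest)
	}
}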

@@ -0,0 +1,355 @@
package v2
import (
"encoding/json"
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gorilla/mux"
)
type routeTestCase struct {
RequestURI string
ExpectedURI string
Vars map[string]string
RouteName string
StatusCode int
}
// TestRouter registers a test handler with all the routes and ensures that
// each route returns the expected path variables. No method verification is
// done. This is not meant to be exhaustive, but serves as a check to ensure
// that the expected variables are extracted.
//
// This may go away as the application structure comes together.
func TestRouter(t *testing.T) {
testCases := []routeTestCase{
{
RouteName: RouteNameBase,
RequestURI: "/v2/",
Vars: map[string]string{},
},
{
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/manifests/bar",
Vars: map[string]string{
"name": "foo",
"reference": "bar",
},
},
{
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/bar/manifests/tag",
Vars: map[string]string{
"name": "foo/bar",
"reference": "tag",
},
},
{
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/bar/manifests/sha256:abcdef01234567890",
Vars: map[string]string{
"name": "foo/bar",
"reference": "sha256:abcdef01234567890",
},
},
{
RouteName: RouteNameTags,
RequestURI: "/v2/foo/bar/tags/list",
Vars: map[string]string{
"name": "foo/bar",
},
},
{
RouteName: RouteNameTags,
RequestURI: "/v2/docker.com/foo/tags/list",
Vars: map[string]string{
"name": "docker.com/foo",
},
},
{
RouteName: RouteNameTags,
RequestURI: "/v2/docker.com/foo/bar/tags/list",
Vars: map[string]string{
"name": "docker.com/foo/bar",
},
},
{
RouteName: RouteNameTags,
RequestURI: "/v2/docker.com/foo/bar/baz/tags/list",
Vars: map[string]string{
"name": "docker.com/foo/bar/baz",
},
},
{
RouteName: RouteNameBlob,
RequestURI: "/v2/foo/bar/blobs/sha256:abcdef0919234",
Vars: map[string]string{
"name": "foo/bar",
"digest": "sha256:abcdef0919234",
},
},
{
RouteName: RouteNameBlobUpload,
RequestURI: "/v2/foo/bar/blobs/uploads/",
Vars: map[string]string{
"name": "foo/bar",
},
},
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/uuid",
Vars: map[string]string{
"name": "foo/bar",
"uuid": "uuid",
},
},
{
// support uuid proper
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
Vars: map[string]string{
"name": "foo/bar",
"uuid": "D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
},
},
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/RDk1MzA2RkEtRkFEMy00RTM2LThENDEtQ0YxQzkzRUY4Mjg2IA==",
Vars: map[string]string{
"name": "foo/bar",
"uuid": "RDk1MzA2RkEtRkFEMy00RTM2LThENDEtQ0YxQzkzRUY4Mjg2IA==",
},
},
{
// supports urlsafe base64
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/RDk1MzA2RkEtRkFEMy00RTM2LThENDEtQ0YxQzkzRUY4Mjg2IA_-==",
Vars: map[string]string{
"name": "foo/bar",
"uuid": "RDk1MzA2RkEtRkFEMy00RTM2LThENDEtQ0YxQzkzRUY4Mjg2IA_-==",
},
},
{
// does not match
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/bar/blobs/uploads/totalandcompletejunk++$$-==",
StatusCode: http.StatusNotFound,
},
{
// Check ambiguity: ensure we can distinguish between tags for
// "foo/bar/image/image" and image for "foo/bar/image" with tag
// "tags"
RouteName: RouteNameManifest,
RequestURI: "/v2/foo/bar/manifests/manifests/tags",
Vars: map[string]string{
"name": "foo/bar/manifests",
"reference": "tags",
},
},
{
// This case presents an ambiguity between foo/bar with tag="tags"
// and list tags for "foo/bar/manifest"
RouteName: RouteNameTags,
RequestURI: "/v2/foo/bar/manifests/tags/list",
Vars: map[string]string{
"name": "foo/bar/manifests",
},
},
{
RouteName: RouteNameManifest,
RequestURI: "/v2/locahost:8080/foo/bar/baz/manifests/tag",
Vars: map[string]string{
"name": "locahost:8080/foo/bar/baz",
"reference": "tag",
},
},
}
checkTestRouter(t, testCases, "", true)
checkTestRouter(t, testCases, "/prefix/", true)
}
func TestRouterWithPathTraversals(t *testing.T) {
testCases := []routeTestCase{
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/../../blob/uploads/D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
ExpectedURI: "/blob/uploads/D95306FA-FAD3-4E36-8D41-CF1C93EF8286",
StatusCode: http.StatusNotFound,
},
{
// Testing for path traversal attack handling
RouteName: RouteNameTags,
RequestURI: "/v2/foo/../bar/baz/tags/list",
ExpectedURI: "/v2/bar/baz/tags/list",
Vars: map[string]string{
"name": "bar/baz",
},
},
}
checkTestRouter(t, testCases, "", false)
}
func TestRouterWithBadCharacters(t *testing.T) {
if testing.Short() {
testCases := []routeTestCase{
{
RouteName: RouteNameBlobUploadChunk,
RequestURI: "/v2/foo/blob/uploads/不95306FA-FAD3-4E36-8D41-CF1C93EF8286",
StatusCode: http.StatusNotFound,
},
{
// Testing handling of an invalid multibyte character in the repository name
RouteName: RouteNameTags,
RequestURI: "/v2/foo/不bar/tags/list",
StatusCode: http.StatusNotFound,
},
}
checkTestRouter(t, testCases, "", true)
} else {
// in the long version we're going to fuzz the router
// with random UTF-8 characters outside the 7-bit ASCII range.
// These are not valid characters for the router and we expect
// 404s on every test.
rand.Seed(time.Now().UTC().UnixNano())
testCases := make([]routeTestCase, 1000)
for idx := range testCases {
testCases[idx] = routeTestCase{
RouteName: RouteNameTags,
RequestURI: fmt.Sprintf("/v2/%v/%v/tags/list", randomString(10), randomString(10)),
StatusCode: http.StatusNotFound,
}
}
checkTestRouter(t, testCases, "", true)
}
}
func checkTestRouter(t *testing.T, testCases []routeTestCase, prefix string, deeplyEqual bool) {
router := RouterWithPrefix(prefix)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
testCase := routeTestCase{
RequestURI: r.RequestURI,
Vars: mux.Vars(r),
RouteName: mux.CurrentRoute(r).GetName(),
}
enc := json.NewEncoder(w)
if err := enc.Encode(testCase); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
// Startup test server
server := httptest.NewServer(router)
for _, testcase := range testCases {
testcase.RequestURI = strings.TrimSuffix(prefix, "/") + testcase.RequestURI
// Register the endpoint
route := router.GetRoute(testcase.RouteName)
if route == nil {
t.Fatalf("route for name %q not found", testcase.RouteName)
}
route.Handler(testHandler)
u := server.URL + testcase.RequestURI
resp, err := http.Get(u)
if err != nil {
t.Fatalf("error issuing get request: %v", err)
}
if testcase.StatusCode == 0 {
// Override default, zero-value
testcase.StatusCode = http.StatusOK
}
if testcase.ExpectedURI == "" {
// Override default, zero-value
testcase.ExpectedURI = testcase.RequestURI
}
if resp.StatusCode != testcase.StatusCode {
t.Fatalf("unexpected status for %s: %v %v", u, resp.Status, resp.StatusCode)
}
if testcase.StatusCode != http.StatusOK {
resp.Body.Close()
// We don't care about json response.
continue
}
dec := json.NewDecoder(resp.Body)
var actualRouteInfo routeTestCase
if err := dec.Decode(&actualRouteInfo); err != nil {
t.Fatalf("error reading json response: %v", err)
}
// Needs to be set out of band
actualRouteInfo.StatusCode = resp.StatusCode
if actualRouteInfo.RequestURI != testcase.ExpectedURI {
t.Fatalf("URI %v incorrectly parsed, expected %v", actualRouteInfo.RequestURI, testcase.ExpectedURI)
}
if actualRouteInfo.RouteName != testcase.RouteName {
t.Fatalf("incorrect route %q matched, expected %q", actualRouteInfo.RouteName, testcase.RouteName)
}
// When testing deep equality, actualRouteInfo has an empty ExpectedURI and we
// don't want that to make the comparison fail. We're otherwise done with the
// testcase, so clear testcase.ExpectedURI.
testcase.ExpectedURI = ""
if deeplyEqual && !reflect.DeepEqual(actualRouteInfo, testcase) {
t.Fatalf("actual does not equal expected: %#v != %#v", actualRouteInfo, testcase)
}
resp.Body.Close()
}
}
// -------------- START LICENSED CODE --------------
// The following code is derivative of https://github.com/google/gofuzz
// gofuzz is licensed under the Apache License, Version 2.0, January 2004,
// a copy of which can be found in the LICENSE file at the root of this
// repository.
// These functions allow us to generate strings containing only multibyte
// characters that are invalid in our URLs. They are used above for fuzzing
// to ensure we always get 404s on these invalid strings
type charRange struct {
first, last rune
}
// choose returns a random unicode character from the given range, using the
// given randomness source.
func (r *charRange) choose() rune {
count := int64(r.last - r.first)
return r.first + rune(rand.Int63n(count))
}
var unicodeRanges = []charRange{
{'\u00a0', '\u02af'}, // Multi-byte encoded characters
{'\u4e00', '\u9fff'}, // Common CJK (even longer encodings)
}
func randomString(length int) string {
runes := make([]rune, length)
for i := range runes {
runes[i] = unicodeRanges[rand.Intn(len(unicodeRanges))].choose()
}
return string(runes)
}
// -------------- END LICENSED CODE --------------
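checkTestRouter above also documents the intended wiring: RouterWithPrefix builds a gorilla/mux router with every v2 route pre-registered by name, GetRoute looks a route up, and mux.Vars exposes the path variables the cases assert on. The following is a hedged sketch of attaching a real handler to one of those named routes; the handler body is purely illustrative.

// newExampleManifestRouter wires an illustrative handler onto the manifest
// route, mirroring what checkTestRouter does with its JSON echo handler.
func newExampleManifestRouter() *mux.Router {
	router := RouterWithPrefix("")
	route := router.GetRoute(RouteNameManifest)
	if route == nil {
		panic("manifest route not registered")
	}
	route.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// "name" and "reference" are the variables asserted in the cases above.
		vars := mux.Vars(r)
		fmt.Fprintf(w, "repository=%s reference=%s", vars["name"], vars["reference"])
	}))
	return router
}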

@@ -1,7 +1,6 @@
package v2
import (
"fmt"
"net/http"
"net/url"
"strings"
@@ -150,8 +149,6 @@ func (ub *URLBuilder) BuildManifestURL(ref reference.Named) (string, error) {
tagOrDigest = v.Tag()
case reference.Digested:
tagOrDigest = v.Digest().String()
default:
return "", fmt.Errorf("reference must have a tag or digest")
}
manifestURL, err := route.URL("name", ref.Name(), "reference", tagOrDigest)

@@ -0,0 +1,484 @@
package v2
import (
"net/http"
"net/url"
"testing"
"github.com/docker/distribution/reference"
)
type urlBuilderTestCase struct {
description string
expectedPath string
build func() (string, error)
}
func makeURLBuilderTestCases(urlBuilder *URLBuilder) []urlBuilderTestCase {
fooBarRef, _ := reference.ParseNamed("foo/bar")
return []urlBuilderTestCase{
{
description: "test base url",
expectedPath: "/v2/",
build: urlBuilder.BuildBaseURL,
},
{
description: "test tags url",
expectedPath: "/v2/foo/bar/tags/list",
build: func() (string, error) {
return urlBuilder.BuildTagsURL(fooBarRef)
},
},
{
description: "test manifest url",
expectedPath: "/v2/foo/bar/manifests/tag",
build: func() (string, error) {
ref, _ := reference.WithTag(fooBarRef, "tag")
return urlBuilder.BuildManifestURL(ref)
},
},
{
description: "build blob url",
expectedPath: "/v2/foo/bar/blobs/sha256:3b3692957d439ac1928219a83fac91e7bf96c153725526874673ae1f2023f8d5",
build: func() (string, error) {
ref, _ := reference.WithDigest(fooBarRef, "sha256:3b3692957d439ac1928219a83fac91e7bf96c153725526874673ae1f2023f8d5")
return urlBuilder.BuildBlobURL(ref)
},
},
{
description: "build blob upload url",
expectedPath: "/v2/foo/bar/blobs/uploads/",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadURL(fooBarRef)
},
},
{
description: "build blob upload url with digest and size",
expectedPath: "/v2/foo/bar/blobs/uploads/?digest=sha256%3A3b3692957d439ac1928219a83fac91e7bf96c153725526874673ae1f2023f8d5&size=10000",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadURL(fooBarRef, url.Values{
"size": []string{"10000"},
"digest": []string{"sha256:3b3692957d439ac1928219a83fac91e7bf96c153725526874673ae1f2023f8d5"},
})
},
},
{
description: "build blob upload chunk url",
expectedPath: "/v2/foo/bar/blobs/uploads/uuid-part",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadChunkURL(fooBarRef, "uuid-part")
},
},
{
description: "build blob upload chunk url with digest and size",
expectedPath: "/v2/foo/bar/blobs/uploads/uuid-part?digest=sha256%3A3b3692957d439ac1928219a83fac91e7bf96c153725526874673ae1f2023f8d5&size=10000",
build: func() (string, error) {
return urlBuilder.BuildBlobUploadChunkURL(fooBarRef, "uuid-part", url.Values{
"size": []string{"10000"},
"digest": []string{"sha256:3b3692957d439ac1928219a83fac91e7bf96c153725526874673ae1f2023f8d5"},
})
},
},
}
}
// TestURLBuilder tests the various url building functions, ensuring they are
// returning the expected values.
func TestURLBuilder(t *testing.T) {
roots := []string{
"http://example.com",
"https://example.com",
"http://localhost:5000",
"https://localhost:5443",
}
doTest := func(relative bool) {
for _, root := range roots {
urlBuilder, err := NewURLBuilderFromString(root, relative)
if err != nil {
t.Fatalf("unexpected error creating urlbuilder: %v", err)
}
for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
expectedURL := testCase.expectedPath
if !relative {
expectedURL = root + expectedURL
}
if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
}
}
}
doTest(true)
doTest(false)
}
func TestURLBuilderWithPrefix(t *testing.T) {
roots := []string{
"http://example.com/prefix/",
"https://example.com/prefix/",
"http://localhost:5000/prefix/",
"https://localhost:5443/prefix/",
}
doTest := func(relative bool) {
for _, root := range roots {
urlBuilder, err := NewURLBuilderFromString(root, relative)
if err != nil {
t.Fatalf("unexpected error creating urlbuilder: %v", err)
}
for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
expectedURL := testCase.expectedPath
if !relative {
expectedURL = root[0:len(root)-1] + expectedURL
}
if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
}
}
}
doTest(true)
doTest(false)
}
type builderFromRequestTestCase struct {
request *http.Request
base string
}
func TestBuilderFromRequest(t *testing.T) {
u, err := url.Parse("http://example.com")
if err != nil {
t.Fatal(err)
}
testRequests := []struct {
name string
request *http.Request
base string
configHost url.URL
}{
{
name: "no forwarded header",
request: &http.Request{URL: u, Host: u.Host},
base: "http://example.com",
},
{
name: "https protocol forwarded with a non-standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Custom-Forwarded-Proto": []string{"https"},
}},
base: "http://example.com",
},
{
name: "forwarded protocol is the same",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Proto": []string{"https"},
}},
base: "https://example.com",
},
{
name: "forwarded host with a non-standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Host": []string{"first.example.com"},
}},
base: "http://first.example.com",
},
{
name: "forwarded multiple hosts a with non-standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Host": []string{"first.example.com, proxy1.example.com"},
}},
base: "http://first.example.com",
},
{
name: "host configured in config file takes priority",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Host": []string{"first.example.com, proxy1.example.com"},
}},
base: "https://third.example.com:5000",
configHost: url.URL{
Scheme: "https",
Host: "third.example.com:5000",
},
},
{
name: "forwarded host and port with just one non-standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Host": []string{"first.example.com:443"},
}},
base: "http://first.example.com:443",
},
{
name: "forwarded port with a non-standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Host": []string{"example.com:5000"},
"X-Forwarded-Port": []string{"5000"},
}},
base: "http://example.com:5000",
},
{
name: "forwarded multiple ports with a non-standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Port": []string{"443 , 5001"},
}},
base: "http://example.com",
},
{
name: "forwarded standard port with non-standard headers",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Proto": []string{"https"},
"X-Forwarded-Host": []string{"example.com"},
"X-Forwarded-Port": []string{"443"},
}},
base: "https://example.com",
},
{
name: "forwarded standard port with non-standard headers and explicit port",
request: &http.Request{URL: u, Host: u.Host + ":443", Header: http.Header{
"X-Forwarded-Proto": []string{"https"},
"X-Forwarded-Host": []string{u.Host + ":443"},
"X-Forwarded-Port": []string{"443"},
}},
base: "https://example.com:443",
},
{
name: "several non-standard headers",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Proto": []string{"https"},
"X-Forwarded-Host": []string{" first.example.com:12345 "},
}},
base: "https://first.example.com:12345",
},
{
name: "forwarded host with port supplied takes priority",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Host": []string{"first.example.com:5000"},
"X-Forwarded-Port": []string{"80"},
}},
base: "http://first.example.com:5000",
},
{
name: "malformed forwarded port",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Host": []string{"first.example.com"},
"X-Forwarded-Port": []string{"abcd"},
}},
base: "http://first.example.com",
},
{
name: "forwarded protocol and addr using standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"Forwarded": []string{`proto=https;host="192.168.22.30:80"`},
}},
base: "https://192.168.22.30:80",
},
{
name: "forwarded host takes priority over for",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"Forwarded": []string{`host="reg.example.com:5000";for="192.168.22.30"`},
}},
base: "http://reg.example.com:5000",
},
{
name: "forwarded host and protocol using standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"Forwarded": []string{`host=reg.example.com;proto=https`},
}},
base: "https://reg.example.com",
},
{
name: "process just the first standard forwarded header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"Forwarded": []string{`host="reg.example.com:88";proto=http`, `host=reg.example.com;proto=https`},
}},
base: "http://reg.example.com:88",
},
{
name: "process just the first list element of standard header",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"Forwarded": []string{`host="reg.example.com:443";proto=https, host="reg.example.com:80";proto=http`},
}},
base: "https://reg.example.com:443",
},
{
name: "IPv6 address use host",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"Forwarded": []string{`for="2607:f0d0:1002:51::4";host="[2607:f0d0:1002:51::4]:5001"`},
"X-Forwarded-Port": []string{"5002"},
}},
base: "http://[2607:f0d0:1002:51::4]:5001",
},
{
name: "IPv6 address with port",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"Forwarded": []string{`host="[2607:f0d0:1002:51::4]:4000"`},
"X-Forwarded-Port": []string{"5001"},
}},
base: "http://[2607:f0d0:1002:51::4]:4000",
},
{
name: "non-standard and standard forward headers",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Proto": []string{`https`},
"X-Forwarded-Host": []string{`first.example.com`},
"X-Forwarded-Port": []string{``},
"Forwarded": []string{`host=first.example.com; proto=https`},
}},
base: "https://first.example.com",
},
{
name: "standard header takes precedence over non-standard headers",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Proto": []string{`http`},
"Forwarded": []string{`host=second.example.com; proto=https`},
"X-Forwarded-Host": []string{`first.example.com`},
"X-Forwarded-Port": []string{`4000`},
}},
base: "https://second.example.com",
},
{
name: "incomplete standard header uses default",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Proto": []string{`https`},
"Forwarded": []string{`for=127.0.0.1`},
"X-Forwarded-Host": []string{`first.example.com`},
"X-Forwarded-Port": []string{`4000`},
}},
base: "http://" + u.Host,
},
{
name: "standard with just proto",
request: &http.Request{URL: u, Host: u.Host, Header: http.Header{
"X-Forwarded-Proto": []string{`https`},
"Forwarded": []string{`proto=https`},
"X-Forwarded-Host": []string{`first.example.com`},
"X-Forwarded-Port": []string{`4000`},
}},
base: "https://" + u.Host,
},
}
doTest := func(relative bool) {
for _, tr := range testRequests {
var builder *URLBuilder
if tr.configHost.Scheme != "" && tr.configHost.Host != "" {
builder = NewURLBuilder(&tr.configHost, relative)
} else {
builder = NewURLBuilderFromRequest(tr.request, relative)
}
for _, testCase := range makeURLBuilderTestCases(builder) {
buildURL, err := testCase.build()
if err != nil {
t.Fatalf("[relative=%t, request=%q, case=%q]: error building url: %v", relative, tr.name, testCase.description, err)
}
expectedURL := testCase.expectedPath
if !relative {
expectedURL = tr.base + expectedURL
}
if buildURL != expectedURL {
t.Errorf("[relative=%t, request=%q, case=%q]: %q != %q", relative, tr.name, testCase.description, buildURL, expectedURL)
}
}
}
}
doTest(true)
doTest(false)
}
func TestBuilderFromRequestWithPrefix(t *testing.T) {
u, err := url.Parse("http://example.com/prefix/v2/")
if err != nil {
t.Fatal(err)
}
forwardedProtoHeader := make(http.Header, 1)
forwardedProtoHeader.Set("X-Forwarded-Proto", "https")
testRequests := []struct {
request *http.Request
base string
configHost url.URL
}{
{
request: &http.Request{URL: u, Host: u.Host},
base: "http://example.com/prefix/",
},
{
request: &http.Request{URL: u, Host: u.Host, Header: forwardedProtoHeader},
base: "http://example.com/prefix/",
},
{
request: &http.Request{URL: u, Host: u.Host, Header: forwardedProtoHeader},
base: "https://example.com/prefix/",
},
{
request: &http.Request{URL: u, Host: u.Host, Header: forwardedProtoHeader},
base: "https://subdomain.example.com/prefix/",
configHost: url.URL{
Scheme: "https",
Host: "subdomain.example.com",
Path: "/prefix/",
},
},
}
var relative bool
for _, tr := range testRequests {
var builder *URLBuilder
if tr.configHost.Scheme != "" && tr.configHost.Host != "" {
builder = NewURLBuilder(&tr.configHost, false)
} else {
builder = NewURLBuilderFromRequest(tr.request, false)
}
for _, testCase := range makeURLBuilderTestCases(builder) {
buildURL, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
var expectedURL string
proto, ok := tr.request.Header["X-Forwarded-Proto"]
if !ok {
expectedURL = testCase.expectedPath
if !relative {
expectedURL = tr.base[0:len(tr.base)-1] + expectedURL
}
} else {
urlBase, err := url.Parse(tr.base)
if err != nil {
t.Fatal(err)
}
urlBase.Scheme = proto[0]
expectedURL = testCase.expectedPath
if !relative {
expectedURL = urlBase.String()[0:len(urlBase.String())-1] + expectedURL
}
}
if buildURL != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, buildURL, expectedURL)
}
}
}
}
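The builder under test maps each named route to a concrete URL for a given base, in either absolute or relative form. A minimal sketch using only constructors and reference helpers already seen above; the registry host, repository name and tag are placeholders.

// exampleManifestURL builds the manifest URL for a tagged reference, the same
// way the "test manifest url" case above does.
func exampleManifestURL() (string, error) {
	builder, err := NewURLBuilderFromString("https://registry.example.com", false)
	if err != nil {
		return "", err
	}
	named, err := reference.ParseNamed("foo/bar")
	if err != nil {
		return "", err
	}
	tagged, err := reference.WithTag(named, "latest")
	if err != nil {
		return "", err
	}
	// Expected result: "https://registry.example.com/v2/foo/bar/manifests/latest".
	return builder.BuildManifestURL(tagged)
}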

@@ -0,0 +1,202 @@
// Package auth defines a standard interface for request access controllers.
//
// An access controller has a simple interface with a single `Authorized`
// method which checks that a given request is authorized to perform one or
// more actions on one or more resources. This method should return a non-nil
// error if the request is not authorized.
//
// An implementation registers its access controller by name with a constructor
// which accepts an options map for configuring the access controller.
//
// options := map[string]interface{}{"sillySecret": "whysosilly?"}
// accessController, _ := auth.GetAccessController("silly", options)
//
// This `accessController` can then be used in a request handler like so:
//
// func updateOrder(w http.ResponseWriter, r *http.Request) {
// orderNumber := r.FormValue("orderNumber")
// resource := auth.Resource{Type: "customerOrder", Name: orderNumber}
// access := auth.Access{Resource: resource, Action: "update"}
//
// if ctx, err := accessController.Authorized(ctx, access); err != nil {
// if challenge, ok := err.(auth.Challenge); ok {
// // Let the challenge write the response.
// challenge.SetHeaders(w)
// w.WriteHeader(http.StatusUnauthorized)
// return
// } else {
// // Some other error.
// }
// }
// }
//
package auth
import (
"errors"
"fmt"
"net/http"
"github.com/docker/distribution/context"
)
const (
// UserKey is used to get the user object from
// a user context
UserKey = "auth.user"
// UserNameKey is used to get the user name from
// a user context
UserNameKey = "auth.user.name"
)
var (
// ErrInvalidCredential is returned when the auth token does not authenticate correctly.
ErrInvalidCredential = errors.New("invalid authorization credential")
// ErrAuthenticationFailure returned when authentication fails.
ErrAuthenticationFailure = errors.New("authentication failure")
)
// UserInfo carries information about
// an authenticated/authorized client.
type UserInfo struct {
Name string
}
// Resource describes a resource by type and name.
type Resource struct {
Type string
Class string
Name string
}
// Access describes a specific action that is
// requested or allowed for a given resource.
type Access struct {
Resource
Action string
}
// Challenge is a special error type which is used for HTTP 401 Unauthorized
// responses and is able to write the response with WWW-Authenticate challenge
// header values based on the error.
type Challenge interface {
error
// SetHeaders prepares the request to conduct a challenge response by
// adding an HTTP challenge header to the response message. Callers
// are expected to set the appropriate HTTP status code (e.g. 401)
// themselves.
SetHeaders(w http.ResponseWriter)
}
// AccessController controls access to registry resources based on a request
// and required access levels for a request. Implementations can support both
// complete denial and http authorization challenges.
type AccessController interface {
// Authorized returns a new authorized context and a nil error if the context
// is granted access. If one or more Access structs are
// provided, the requested access will be compared with what is available
// to the context. The given context will contain a "http.request" key with
// a `*http.Request` value. If the error is non-nil, access should always
// be denied. The error may be of type Challenge, in which case the caller
// may have the Challenge handle the request or choose what action to take
// based on the Challenge header or response status. The returned context
// object should have an "auth.user" value set to a UserInfo struct.
Authorized(ctx context.Context, access ...Access) (context.Context, error)
}
// CredentialAuthenticator is an object which is able to authenticate credentials
type CredentialAuthenticator interface {
AuthenticateUser(username, password string) error
}
// WithUser returns a context with the authorized user info.
func WithUser(ctx context.Context, user UserInfo) context.Context {
return userInfoContext{
Context: ctx,
user: user,
}
}
type userInfoContext struct {
context.Context
user UserInfo
}
func (uic userInfoContext) Value(key interface{}) interface{} {
switch key {
case UserKey:
return uic.user
case UserNameKey:
return uic.user.Name
}
return uic.Context.Value(key)
}
// WithResources returns a context with the authorized resources.
func WithResources(ctx context.Context, resources []Resource) context.Context {
return resourceContext{
Context: ctx,
resources: resources,
}
}
type resourceContext struct {
context.Context
resources []Resource
}
type resourceKey struct{}
func (rc resourceContext) Value(key interface{}) interface{} {
if key == (resourceKey{}) {
return rc.resources
}
return rc.Context.Value(key)
}
// AuthorizedResources returns the list of resources which have
// been authorized for this request.
func AuthorizedResources(ctx context.Context) []Resource {
if resources, ok := ctx.Value(resourceKey{}).([]Resource); ok {
return resources
}
return nil
}
// InitFunc is the type of an AccessController factory function and is used
// to register the constructor for different AccessController backends.
type InitFunc func(options map[string]interface{}) (AccessController, error)
var accessControllers map[string]InitFunc
func init() {
accessControllers = make(map[string]InitFunc)
}
// Register is used to register an InitFunc for
// an AccessController backend with the given name.
func Register(name string, initFunc InitFunc) error {
if _, exists := accessControllers[name]; exists {
return fmt.Errorf("name already registered: %s", name)
}
accessControllers[name] = initFunc
return nil
}
// GetAccessController constructs an AccessController
// with the given options using the named backend.
func GetAccessController(name string, options map[string]interface{}) (AccessController, error) {
if initFunc, exists := accessControllers[name]; exists {
return initFunc(options)
}
return nil, fmt.Errorf("no access controller registered with name: %s", name)
}
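Beyond the usage example in the package comment, the helpers above also define how downstream handlers read back what an access controller attached to the context: the user via UserKey/UserNameKey and the granted resources via AuthorizedResources. A short hedged sketch of such a consumer; the function itself is hypothetical.

// authenticatedSubject extracts the user name and the authorized resources
// that WithUser and WithResources stored on the context.
func authenticatedSubject(ctx context.Context) (string, []Resource) {
	name, _ := ctx.Value(UserNameKey).(string)
	return name, AuthorizedResources(ctx)
}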

@@ -0,0 +1,115 @@
// Package htpasswd provides a simple authentication scheme that checks for the
// user credential hash in an htpasswd formatted file in a configuration-determined
// location.
//
// This authentication method MUST be used under TLS, as a simple token-replay attack is otherwise possible.
package htpasswd
import (
"fmt"
"net/http"
"os"
"sync"
"time"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
)
type accessController struct {
realm string
path string
modtime time.Time
mu sync.Mutex
htpasswd *htpasswd
}
var _ auth.AccessController = &accessController{}
func newAccessController(options map[string]interface{}) (auth.AccessController, error) {
realm, present := options["realm"]
if _, ok := realm.(string); !present || !ok {
return nil, fmt.Errorf(`"realm" must be set for htpasswd access controller`)
}
path, present := options["path"]
if _, ok := path.(string); !present || !ok {
return nil, fmt.Errorf(`"path" must be set for htpasswd access controller`)
}
return &accessController{realm: realm.(string), path: path.(string)}, nil
}
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := context.GetRequest(ctx)
if err != nil {
return nil, err
}
username, password, ok := req.BasicAuth()
if !ok {
return nil, &challenge{
realm: ac.realm,
err: auth.ErrInvalidCredential,
}
}
// Dynamically reload the account list if the htpasswd file has changed on disk
fstat, err := os.Stat(ac.path)
if err != nil {
return nil, err
}
lastModified := fstat.ModTime()
ac.mu.Lock()
if ac.htpasswd == nil || !ac.modtime.Equal(lastModified) {
ac.modtime = lastModified
f, err := os.Open(ac.path)
if err != nil {
ac.mu.Unlock()
return nil, err
}
defer f.Close()
h, err := newHTPasswd(f)
if err != nil {
ac.mu.Unlock()
return nil, err
}
ac.htpasswd = h
}
localHTPasswd := ac.htpasswd
ac.mu.Unlock()
if err := localHTPasswd.authenticateUser(username, password); err != nil {
context.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err)
return nil, &challenge{
realm: ac.realm,
err: auth.ErrAuthenticationFailure,
}
}
return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil
}
// challenge implements the auth.Challenge interface.
type challenge struct {
realm string
err error
}
var _ auth.Challenge = challenge{}
// SetHeaders sets the basic challenge header on the response.
func (ch challenge) SetHeaders(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", ch.realm))
}
func (ch challenge) Error() string {
return fmt.Sprintf("basic authentication challenge for realm %q: %s", ch.realm, ch.err)
}
func init() {
auth.Register("htpasswd", auth.InitFunc(newAccessController))
}
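Because init registers this controller under the name "htpasswd", callers would normally construct it through the generic auth.GetAccessController entry point rather than calling newAccessController directly. A hedged sketch; the realm string and file path are placeholders.

// exampleHtpasswdController builds the controller with the two options the
// constructor above requires: "realm" and "path".
func exampleHtpasswdController() (auth.AccessController, error) {
	return auth.GetAccessController("htpasswd", map[string]interface{}{
		"realm": "registry.example.com",
		"path":  "/etc/registry/htpasswd", // hypothetical htpasswd file location
	})
}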

@@ -0,0 +1,122 @@
package htpasswd
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
)
func TestBasicAccessController(t *testing.T) {
testRealm := "The-Shire"
testUsers := []string{"bilbo", "frodo", "MiShil", "DeokMan"}
testPasswords := []string{"baggins", "baggins", "새주", "공주님"}
testHtpasswdContent := `bilbo:{SHA}5siv5c0SHx681xU6GiSx9ZQryqs=
frodo:$2y$05$926C3y10Quzn/LnqQH86VOEVh/18T6RnLaS.khre96jLNL/7e.K5W
MiShil:$2y$05$0oHgwMehvoe8iAWS8I.7l.KoECXrwVaC16RPfaSCU5eVTFrATuMI2
DeokMan:공주님`
tempFile, err := ioutil.TempFile("", "htpasswd-test")
if err != nil {
t.Fatal("could not create temporary htpasswd file")
}
if _, err = tempFile.WriteString(testHtpasswdContent); err != nil {
t.Fatal("could not write temporary htpasswd file")
}
options := map[string]interface{}{
"realm": testRealm,
"path": tempFile.Name(),
}
ctx := context.Background()
accessController, err := newAccessController(options)
if err != nil {
t.Fatal("error creating access controller")
}
tempFile.Close()
var userNumber = 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithRequest(ctx, r)
authCtx, err := accessController.Authorized(ctx)
if err != nil {
switch err := err.(type) {
case auth.Challenge:
err.SetHeaders(w)
w.WriteHeader(http.StatusUnauthorized)
return
default:
t.Fatalf("unexpected error authorizing request: %v", err)
}
}
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo)
if !ok {
t.Fatal("basic accessController did not set auth.user context")
}
if userInfo.Name != testUsers[userNumber] {
t.Fatalf("expected user name %q, got %q", testUsers[userNumber], userInfo.Name)
}
w.WriteHeader(http.StatusNoContent)
}))
client := &http.Client{
CheckRedirect: nil,
}
req, _ := http.NewRequest("GET", server.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("unexpected error during GET: %v", err)
}
defer resp.Body.Close()
// Request should not be authorized
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("unexpected non-fail response status: %v != %v", resp.StatusCode, http.StatusUnauthorized)
}
nonbcrypt := map[string]struct{}{
"bilbo": {},
"DeokMan": {},
}
for i := 0; i < len(testUsers); i++ {
userNumber = i
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("error allocating new request: %v", err)
}
req.SetBasicAuth(testUsers[i], testPasswords[i])
resp, err = client.Do(req)
if err != nil {
t.Fatalf("unexpected error during GET: %v", err)
}
defer resp.Body.Close()
if _, ok := nonbcrypt[testUsers[i]]; ok {
// these entries do not use bcrypt and are not allowed.
// Request should not be authorized
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("unexpected response status: %v != %v for %s %s", resp.StatusCode, http.StatusUnauthorized, testUsers[i], testPasswords[i])
}
} else {
// Request should be authorized
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("unexpected non-success response status: %v != %v for %s %s", resp.StatusCode, http.StatusNoContent, testUsers[i], testPasswords[i])
}
}
}
}

@@ -0,0 +1,82 @@
package htpasswd
import (
"bufio"
"fmt"
"io"
"strings"
"github.com/docker/distribution/registry/auth"
"golang.org/x/crypto/bcrypt"
)
// htpasswd holds a path to a system .htpasswd file and the machinery to parse
// it. Only bcrypt hash entries are supported.
type htpasswd struct {
entries map[string][]byte // maps username to password byte slice.
}
// newHTPasswd parses the reader and returns an htpasswd or an error.
func newHTPasswd(rd io.Reader) (*htpasswd, error) {
entries, err := parseHTPasswd(rd)
if err != nil {
return nil, err
}
return &htpasswd{entries: entries}, nil
}
// authenticateUser checks a given user:password credential against the
// receiving htpasswd's entries. If the check passes, nil is returned.
func (htpasswd *htpasswd) authenticateUser(username string, password string) error {
credentials, ok := htpasswd.entries[username]
if !ok {
// timing attack paranoia
bcrypt.CompareHashAndPassword([]byte{}, []byte(password))
return auth.ErrAuthenticationFailure
}
err := bcrypt.CompareHashAndPassword([]byte(credentials), []byte(password))
if err != nil {
return auth.ErrAuthenticationFailure
}
return nil
}
// parseHTPasswd parses the contents of an htpasswd file. This will read all
// the entries in the file, whether or not they are needed. An error is
// returned if a syntax error is encountered or if the reader fails.
func parseHTPasswd(rd io.Reader) (map[string][]byte, error) {
entries := map[string][]byte{}
scanner := bufio.NewScanner(rd)
var line int
for scanner.Scan() {
line++ // 1-based line numbering
t := strings.TrimSpace(scanner.Text())
if len(t) < 1 {
continue
}
// lines that *begin* with a '#' are considered comments
if t[0] == '#' {
continue
}
i := strings.Index(t, ":")
if i < 0 || i >= len(t) {
return nil, fmt.Errorf("htpasswd: invalid entry at line %d: %q", line, scanner.Text())
}
entries[t[:i]] = []byte(t[i+1:])
}
if err := scanner.Err(); err != nil {
return nil, err
}
return entries, nil
}
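A hedged sketch of the parse-then-check flow the two functions above provide; the bcrypt entry and password are the "frodo"/"baggins" pair used by the test fixtures elsewhere in this commit.

// exampleAuthenticate parses a one-line htpasswd body and verifies a password
// against it using the bcrypt comparison in authenticateUser.
func exampleAuthenticate() error {
	h, err := newHTPasswd(strings.NewReader("frodo:$2y$05$926C3y10Quzn/LnqQH86VOEVh/18T6RnLaS.khre96jLNL/7e.K5W"))
	if err != nil {
		return err
	}
	return h.authenticateUser("frodo", "baggins")
}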

@@ -0,0 +1,85 @@
package htpasswd
import (
"fmt"
"reflect"
"strings"
"testing"
)
func TestParseHTPasswd(t *testing.T) {
for _, tc := range []struct {
desc string
input string
err error
entries map[string][]byte
}{
{
desc: "basic example",
input: `
# This is a comment in a basic example.
bilbo:{SHA}5siv5c0SHx681xU6GiSx9ZQryqs=
frodo:$2y$05$926C3y10Quzn/LnqQH86VOEVh/18T6RnLaS.khre96jLNL/7e.K5W
MiShil:$2y$05$0oHgwMehvoe8iAWS8I.7l.KoECXrwVaC16RPfaSCU5eVTFrATuMI2
DeokMan:공주님
`,
entries: map[string][]byte{
"bilbo": []byte("{SHA}5siv5c0SHx681xU6GiSx9ZQryqs="),
"frodo": []byte("$2y$05$926C3y10Quzn/LnqQH86VOEVh/18T6RnLaS.khre96jLNL/7e.K5W"),
"MiShil": []byte("$2y$05$0oHgwMehvoe8iAWS8I.7l.KoECXrwVaC16RPfaSCU5eVTFrATuMI2"),
"DeokMan": []byte("공주님"),
},
},
{
desc: "ensures comments are filtered",
input: `
# asdf:asdf
`,
},
{
desc: "ensure midline hash is not comment",
input: `
asdf:as#df
`,
entries: map[string][]byte{
"asdf": []byte("as#df"),
},
},
{
desc: "ensure midline hash is not comment",
input: `
# A valid comment
valid:entry
asdf
`,
err: fmt.Errorf(`htpasswd: invalid entry at line 4: "asdf"`),
},
} {
entries, err := parseHTPasswd(strings.NewReader(tc.input))
if err != tc.err {
if tc.err == nil {
t.Fatalf("%s: unexpected error: %v", tc.desc, err)
} else {
if err.Error() != tc.err.Error() { // use string equality here.
t.Fatalf("%s: expected error not returned: %v != %v", tc.desc, err, tc.err)
}
}
}
if tc.err != nil {
continue // don't test output
}
// allow empty and nil to be equal
if tc.entries == nil {
tc.entries = map[string][]byte{}
}
if !reflect.DeepEqual(entries, tc.entries) {
t.Fatalf("%s: entries not parsed correctly: %v != %v", tc.desc, entries, tc.entries)
}
}
}

@@ -0,0 +1,97 @@
// Package silly provides a simple authentication scheme that checks for the
// existence of an Authorization header and issues access if it is present and
// non-empty.
//
// This package is present as an example implementation of a minimal
// auth.AccessController and for testing. This is not suitable for any kind of
// production security.
package silly
import (
"fmt"
"net/http"
"strings"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
)
// accessController provides a simple implementation of auth.AccessController
// that simply checks for a non-empty Authorization header. It is useful for
// demonstration and testing.
type accessController struct {
realm string
service string
}
var _ auth.AccessController = &accessController{}
func newAccessController(options map[string]interface{}) (auth.AccessController, error) {
realm, present := options["realm"]
if _, ok := realm.(string); !present || !ok {
return nil, fmt.Errorf(`"realm" must be set for silly access controller`)
}
service, present := options["service"]
if _, ok := service.(string); !present || !ok {
return nil, fmt.Errorf(`"service" must be set for silly access controller`)
}
return &accessController{realm: realm.(string), service: service.(string)}, nil
}
// Authorized simply checks for the existence of the authorization header,
// responding with a bearer challenge if it doesn't exist.
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := context.GetRequest(ctx)
if err != nil {
return nil, err
}
if req.Header.Get("Authorization") == "" {
challenge := challenge{
realm: ac.realm,
service: ac.service,
}
if len(accessRecords) > 0 {
var scopes []string
for _, access := range accessRecords {
scopes = append(scopes, fmt.Sprintf("%s:%s:%s", access.Type, access.Resource.Name, access.Action))
}
challenge.scope = strings.Join(scopes, " ")
}
return nil, &challenge
}
return auth.WithUser(ctx, auth.UserInfo{Name: "silly"}), nil
}
type challenge struct {
realm string
service string
scope string
}
var _ auth.Challenge = challenge{}
// SetHeaders sets a simple bearer challenge on the response.
func (ch challenge) SetHeaders(w http.ResponseWriter) {
header := fmt.Sprintf("Bearer realm=%q,service=%q", ch.realm, ch.service)
if ch.scope != "" {
header = fmt.Sprintf("%s,scope=%q", header, ch.scope)
}
w.Header().Set("WWW-Authenticate", header)
}
func (ch challenge) Error() string {
return fmt.Sprintf("silly authentication challenge: %#v", ch)
}
// init registers the silly auth backend.
func init() {
auth.Register("silly", auth.InitFunc(newAccessController))
}

@@ -0,0 +1,71 @@
package silly
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
)
func TestSillyAccessController(t *testing.T) {
ac := &accessController{
realm: "test-realm",
service: "test-service",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithRequest(context.Background(), r)
authCtx, err := ac.Authorized(ctx)
if err != nil {
switch err := err.(type) {
case auth.Challenge:
err.SetHeaders(w)
w.WriteHeader(http.StatusUnauthorized)
return
default:
t.Fatalf("unexpected error authorizing request: %v", err)
}
}
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo)
if !ok {
t.Fatal("silly accessController did not set auth.user context")
}
if userInfo.Name != "silly" {
t.Fatalf("expected user name %q, got %q", "silly", userInfo.Name)
}
w.WriteHeader(http.StatusNoContent)
}))
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("unexpected error during GET: %v", err)
}
defer resp.Body.Close()
// Request should not be authorized
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("unexpected response status: %v != %v", resp.StatusCode, http.StatusUnauthorized)
}
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatalf("unexpected error creating new request: %v", err)
}
req.Header.Set("Authorization", "seriously, anything")
resp, err = http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unexpected error during GET: %v", err)
}
defer resp.Body.Close()
// Request should now be authorized
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("unexpected response status: %v != %v", resp.StatusCode, http.StatusNoContent)
}
}

@@ -0,0 +1,272 @@
package token
import (
"crypto"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
"github.com/docker/libtrust"
)
// accessSet maps a typed, named resource to
// a set of actions requested or authorized.
type accessSet map[auth.Resource]actionSet
// newAccessSet constructs an accessSet from
// a variable number of auth.Access items.
func newAccessSet(accessItems ...auth.Access) accessSet {
accessSet := make(accessSet, len(accessItems))
for _, access := range accessItems {
resource := auth.Resource{
Type: access.Type,
Name: access.Name,
}
set, exists := accessSet[resource]
if !exists {
set = newActionSet()
accessSet[resource] = set
}
set.add(access.Action)
}
return accessSet
}
// contains returns whether or not the given access is in this accessSet.
func (s accessSet) contains(access auth.Access) bool {
actionSet, ok := s[access.Resource]
if ok {
return actionSet.contains(access.Action)
}
return false
}
// scopeParam returns a collection of scopes which can
// be used for a WWW-Authenticate challenge parameter.
// See https://tools.ietf.org/html/rfc6750#section-3
func (s accessSet) scopeParam() string {
scopes := make([]string, 0, len(s))
for resource, actionSet := range s {
actions := strings.Join(actionSet.keys(), ",")
scopes = append(scopes, fmt.Sprintf("%s:%s:%s", resource.Type, resource.Name, actions))
}
return strings.Join(scopes, " ")
}
// Errors used and exported by this package.
var (
ErrInsufficientScope = errors.New("insufficient scope")
ErrTokenRequired = errors.New("authorization token required")
)
// authChallenge implements the auth.Challenge interface.
type authChallenge struct {
err error
realm string
service string
accessSet accessSet
}
var _ auth.Challenge = authChallenge{}
// Error returns the internal error string for this authChallenge.
func (ac authChallenge) Error() string {
return ac.err.Error()
}
// Status returns the HTTP Response Status Code for this authChallenge.
func (ac authChallenge) Status() int {
return http.StatusUnauthorized
}
// challengeParams constructs the value to be used in
// the WWW-Authenticate response challenge header.
// See https://tools.ietf.org/html/rfc6750#section-3
func (ac authChallenge) challengeParams() string {
str := fmt.Sprintf("Bearer realm=%q,service=%q", ac.realm, ac.service)
if scope := ac.accessSet.scopeParam(); scope != "" {
str = fmt.Sprintf("%s,scope=%q", str, scope)
}
if ac.err == ErrInvalidToken || ac.err == ErrMalformedToken {
str = fmt.Sprintf("%s,error=%q", str, "invalid_token")
} else if ac.err == ErrInsufficientScope {
str = fmt.Sprintf("%s,error=%q", str, "insufficient_scope")
}
return str
}
// SetHeaders sets the WWW-Authenticate header value for the response.
func (ac authChallenge) SetHeaders(w http.ResponseWriter) {
w.Header().Add("WWW-Authenticate", ac.challengeParams())
}
// accessController implements the auth.AccessController interface.
type accessController struct {
realm string
issuer string
service string
rootCerts *x509.CertPool
trustedKeys map[string]libtrust.PublicKey
}
// tokenAccessOptions is a convenience type for handling
// options to the constructor of an accessController.
type tokenAccessOptions struct {
realm string
issuer string
service string
rootCertBundle string
}
// checkOptions gathers the necessary options
// for an accessController from the given map.
func checkOptions(options map[string]interface{}) (tokenAccessOptions, error) {
var opts tokenAccessOptions
keys := []string{"realm", "issuer", "service", "rootcertbundle"}
vals := make([]string, 0, len(keys))
for _, key := range keys {
val, ok := options[key].(string)
if !ok {
return opts, fmt.Errorf("token auth requires a valid option string: %q", key)
}
vals = append(vals, val)
}
opts.realm, opts.issuer, opts.service, opts.rootCertBundle = vals[0], vals[1], vals[2], vals[3]
return opts, nil
}
// newAccessController creates an accessController using the given options.
func newAccessController(options map[string]interface{}) (auth.AccessController, error) {
config, err := checkOptions(options)
if err != nil {
return nil, err
}
fp, err := os.Open(config.rootCertBundle)
if err != nil {
return nil, fmt.Errorf("unable to open token auth root certificate bundle file %q: %s", config.rootCertBundle, err)
}
defer fp.Close()
rawCertBundle, err := ioutil.ReadAll(fp)
if err != nil {
return nil, fmt.Errorf("unable to read token auth root certificate bundle file %q: %s", config.rootCertBundle, err)
}
var rootCerts []*x509.Certificate
pemBlock, rawCertBundle := pem.Decode(rawCertBundle)
for pemBlock != nil {
if pemBlock.Type == "CERTIFICATE" {
cert, err := x509.ParseCertificate(pemBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse token auth root certificate: %s", err)
}
rootCerts = append(rootCerts, cert)
}
pemBlock, rawCertBundle = pem.Decode(rawCertBundle)
}
if len(rootCerts) == 0 {
return nil, errors.New("token auth requires at least one token signing root certificate")
}
rootPool := x509.NewCertPool()
trustedKeys := make(map[string]libtrust.PublicKey, len(rootCerts))
for _, rootCert := range rootCerts {
rootPool.AddCert(rootCert)
pubKey, err := libtrust.FromCryptoPublicKey(crypto.PublicKey(rootCert.PublicKey))
if err != nil {
return nil, fmt.Errorf("unable to get public key from token auth root certificate: %s", err)
}
trustedKeys[pubKey.KeyID()] = pubKey
}
return &accessController{
realm: config.realm,
issuer: config.issuer,
service: config.service,
rootCerts: rootPool,
trustedKeys: trustedKeys,
}, nil
}
// Authorized handles checking whether the given request is authorized
// for actions on resources described by the given access items.
func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) {
challenge := &authChallenge{
realm: ac.realm,
service: ac.service,
accessSet: newAccessSet(accessItems...),
}
req, err := context.GetRequest(ctx)
if err != nil {
return nil, err
}
parts := strings.Split(req.Header.Get("Authorization"), " ")
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
challenge.err = ErrTokenRequired
return nil, challenge
}
rawToken := parts[1]
token, err := NewToken(rawToken)
if err != nil {
challenge.err = err
return nil, challenge
}
verifyOpts := VerifyOptions{
TrustedIssuers: []string{ac.issuer},
AcceptedAudiences: []string{ac.service},
Roots: ac.rootCerts,
TrustedKeys: ac.trustedKeys,
}
if err = token.Verify(verifyOpts); err != nil {
challenge.err = err
return nil, challenge
}
accessSet := token.accessSet()
for _, access := range accessItems {
if !accessSet.contains(access) {
challenge.err = ErrInsufficientScope
return nil, challenge
}
}
ctx = auth.WithResources(ctx, token.resources())
return auth.WithUser(ctx, auth.UserInfo{Name: token.Claims.Subject}), nil
}
// init handles registering the token auth backend.
func init() {
auth.Register("token", auth.InitFunc(newAccessController))
}
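// Illustrative construction sketch (paths and hostnames are hypothetical):
// the options map must supply the four string keys read by checkOptions.
//
//    controller, err := newAccessController(map[string]interface{}{
//        "realm":          "https://auth.example.com/token",
//        "issuer":         "auth.example.com",
//        "service":        "registry.example.com",
//        "rootcertbundle": "/etc/registry/root_bundle.pem",
//    })
//    if err != nil {
//        // a missing option or an unreadable certificate bundle is reported here
//    }
//    _ = controller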

View File

@ -0,0 +1,35 @@
package token
// stringSet is a useful type for looking up strings.
type stringSet map[string]struct{}
// newStringSet creates a new stringSet with the given strings.
func newStringSet(keys ...string) stringSet {
ss := make(stringSet, len(keys))
ss.add(keys...)
return ss
}
// add inserts the given keys into this stringSet.
func (ss stringSet) add(keys ...string) {
for _, key := range keys {
ss[key] = struct{}{}
}
}
// contains returns whether the given key is in this stringSet.
func (ss stringSet) contains(key string) bool {
_, ok := ss[key]
return ok
}
// keys returns a slice of all keys in this stringSet.
func (ss stringSet) keys() []string {
keys := make([]string, 0, len(ss))
for key := range ss {
keys = append(keys, key)
}
return keys
}
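// Illustrative stringSet usage (hypothetical keys):
//
//    ss := newStringSet("pull", "push")
//    ss.add("delete")
//    ss.contains("pull") // true
//    ss.contains("*")    // false: plain lookup, no wildcard handling here
//    _ = ss.keys()       // "pull", "push", "delete" in unspecified order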

View File

@ -0,0 +1,378 @@
package token
import (
"crypto"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/docker/libtrust"
"github.com/docker/distribution/registry/auth"
)
const (
// TokenSeparator is the value which separates the header, claims, and
// signature in the compact serialization of a JSON Web Token.
TokenSeparator = "."
// Leeway is the Duration that will be added to NBF and EXP claim
// checks to account for clock skew as per https://tools.ietf.org/html/rfc7519#section-4.1.5
Leeway = 60 * time.Second
)
// Errors used by token parsing and verification.
var (
ErrMalformedToken = errors.New("malformed token")
ErrInvalidToken = errors.New("invalid token")
)
// ResourceActions stores allowed actions on a named and typed resource.
type ResourceActions struct {
Type string `json:"type"`
Class string `json:"class,omitempty"`
Name string `json:"name"`
Actions []string `json:"actions"`
}
// ClaimSet describes the main section of a JSON Web Token.
type ClaimSet struct {
// Public claims
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience string `json:"aud"`
Expiration int64 `json:"exp"`
NotBefore int64 `json:"nbf"`
IssuedAt int64 `json:"iat"`
JWTID string `json:"jti"`
// Private claims
Access []*ResourceActions `json:"access"`
}
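// An illustrative ClaimSet, populated the way the tests' makeTestToken does
// (all values hypothetical; now is an assumed time.Time):
//
//    claims := &ClaimSet{
//        Issuer:     "auth.example.com",
//        Subject:    "someuser",
//        Audience:   "registry.example.com",
//        Expiration: now.Add(5 * time.Minute).Unix(),
//        NotBefore:  now.Unix(),
//        IssuedAt:   now.Unix(),
//        JWTID:      "abc123",
//        Access: []*ResourceActions{{
//            Type: "repository", Name: "foo/bar", Actions: []string{"pull", "push"},
//        }},
//    }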
// Header describes the header section of a JSON Web Token.
type Header struct {
Type string `json:"typ"`
SigningAlg string `json:"alg"`
KeyID string `json:"kid,omitempty"`
X5c []string `json:"x5c,omitempty"`
RawJWK *json.RawMessage `json:"jwk,omitempty"`
}
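// An illustrative Header for an ES256-signed token carrying its JWK inline,
// as the tests do (rawJWK is an assumed json.RawMessage); KeyID would instead
// hold a libtrust key fingerprint:
//
//    hdr := &Header{Type: "JWT", SigningAlg: "ES256", RawJWK: &rawJWK}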
// Token describes a JSON Web Token.
type Token struct {
Raw string
Header *Header
Claims *ClaimSet
Signature []byte
}
// VerifyOptions is used to specify
// options when verifying a JSON Web Token.
type VerifyOptions struct {
TrustedIssuers []string
AcceptedAudiences []string
Roots *x509.CertPool
TrustedKeys map[string]libtrust.PublicKey
}
// NewToken parses the given raw token string
// and constructs an unverified JSON Web Token.
func NewToken(rawToken string) (*Token, error) {
parts := strings.Split(rawToken, TokenSeparator)
if len(parts) != 3 {
return nil, ErrMalformedToken
}
var (
rawHeader, rawClaims = parts[0], parts[1]
headerJSON, claimsJSON []byte
err error
)
defer func() {
if err != nil {
log.Infof("error while unmarshalling raw token: %s", err)
}
}()
if headerJSON, err = joseBase64UrlDecode(rawHeader); err != nil {
err = fmt.Errorf("unable to decode header: %s", err)
return nil, ErrMalformedToken
}
if claimsJSON, err = joseBase64UrlDecode(rawClaims); err != nil {
err = fmt.Errorf("unable to decode claims: %s", err)
return nil, ErrMalformedToken
}
token := new(Token)
token.Header = new(Header)
token.Claims = new(ClaimSet)
token.Raw = strings.Join(parts[:2], TokenSeparator)
if token.Signature, err = joseBase64UrlDecode(parts[2]); err != nil {
err = fmt.Errorf("unable to decode signature: %s", err)
return nil, ErrMalformedToken
}
if err = json.Unmarshal(headerJSON, token.Header); err != nil {
return nil, ErrMalformedToken
}
if err = json.Unmarshal(claimsJSON, token.Claims); err != nil {
return nil, ErrMalformedToken
}
return token, nil
}
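// Illustrative parsing sketch; rawToken is a hypothetical compact JWT of the
// form "<base64url header>.<base64url claims>.<base64url signature>":
//
//    token, err := NewToken(rawToken)
//    if err != nil {
//        // err is ErrMalformedToken: not three dot-separated sections, or a
//        // section failed to base64url/JSON decode.
//    }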
// Verify attempts to verify this token using the given options.
// Returns a nil error if the token is valid.
func (t *Token) Verify(verifyOpts VerifyOptions) error {
// Verify that the Issuer claim is a trusted authority.
if !contains(verifyOpts.TrustedIssuers, t.Claims.Issuer) {
log.Infof("token from untrusted issuer: %q", t.Claims.Issuer)
return ErrInvalidToken
}
// Verify that the Audience claim is allowed.
if !contains(verifyOpts.AcceptedAudiences, t.Claims.Audience) {
log.Infof("token intended for another audience: %q", t.Claims.Audience)
return ErrInvalidToken
}
// Verify that the token is currently usable and not expired.
currentTime := time.Now()
ExpWithLeeway := time.Unix(t.Claims.Expiration, 0).Add(Leeway)
if currentTime.After(ExpWithLeeway) {
log.Infof("token not to be used after %s - currently %s", ExpWithLeeway, currentTime)
return ErrInvalidToken
}
NotBeforeWithLeeway := time.Unix(t.Claims.NotBefore, 0).Add(-Leeway)
if currentTime.Before(NotBeforeWithLeeway) {
log.Infof("token not to be used before %s - currently %s", NotBeforeWithLeeway, currentTime)
return ErrInvalidToken
}
// Verify the token signature.
if len(t.Signature) == 0 {
log.Info("token has no signature")
return ErrInvalidToken
}
// Verify that the signing key is trusted.
signingKey, err := t.VerifySigningKey(verifyOpts)
if err != nil {
log.Info(err)
return ErrInvalidToken
}
// Finally, verify the signature of the token using the key which signed it.
if err := signingKey.Verify(strings.NewReader(t.Raw), t.Header.SigningAlg, t.Signature); err != nil {
log.Infof("unable to verify token signature: %s", err)
return ErrInvalidToken
}
return nil
}
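// Illustrative verification sketch (issuer, audience, pool and key map are
// hypothetical and come from the deployment's configuration):
//
//    opts := VerifyOptions{
//        TrustedIssuers:    []string{"auth.example.com"},
//        AcceptedAudiences: []string{"registry.example.com"},
//        Roots:             rootPool,    // *x509.CertPool of trusted signing roots
//        TrustedKeys:       trustedKeys, // map[keyID]libtrust.PublicKey
//    }
//    if err := token.Verify(opts); err != nil {
//        // err is always ErrInvalidToken; the specific reason is logged at info level.
//    }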
// VerifySigningKey attempts to get the key which was used to sign this token.
// The token header should contain one of these 3 fields:
// `x5c` - The x509 certificate chain for the signing key. Needs to be
// verified.
// `jwk` - The JSON Web Key representation of the signing key.
// May contain its own `x5c` field which needs to be verified.
// `kid` - The unique identifier for the key. This library interprets it
// as a libtrust fingerprint. The key itself can be looked up in
// the trustedKeys field of the given verify options.
// Each of these methods is tried in that order of preference until the
// signing key is found or an error is returned.
func (t *Token) VerifySigningKey(verifyOpts VerifyOptions) (signingKey libtrust.PublicKey, err error) {
// First attempt to get an x509 certificate chain from the header.
var (
x5c = t.Header.X5c
rawJWK = t.Header.RawJWK
keyID = t.Header.KeyID
)
switch {
case len(x5c) > 0:
signingKey, err = parseAndVerifyCertChain(x5c, verifyOpts.Roots)
case rawJWK != nil:
signingKey, err = parseAndVerifyRawJWK(rawJWK, verifyOpts)
case len(keyID) > 0:
signingKey = verifyOpts.TrustedKeys[keyID]
if signingKey == nil {
err = fmt.Errorf("token signed by untrusted key with ID: %q", keyID)
}
default:
err = errors.New("unable to get token signing key")
}
return
}
func parseAndVerifyCertChain(x5c []string, roots *x509.CertPool) (leafKey libtrust.PublicKey, err error) {
if len(x5c) == 0 {
return nil, errors.New("empty x509 certificate chain")
}
// Ensure the first element is encoded correctly.
leafCertDer, err := base64.StdEncoding.DecodeString(x5c[0])
if err != nil {
return nil, fmt.Errorf("unable to decode leaf certificate: %s", err)
}
// And that it is a valid x509 certificate.
leafCert, err := x509.ParseCertificate(leafCertDer)
if err != nil {
return nil, fmt.Errorf("unable to parse leaf certificate: %s", err)
}
// The rest of the certificate chain are intermediate certificates.
intermediates := x509.NewCertPool()
for i := 1; i < len(x5c); i++ {
intermediateCertDer, err := base64.StdEncoding.DecodeString(x5c[i])
if err != nil {
return nil, fmt.Errorf("unable to decode intermediate certificate: %s", err)
}
intermediateCert, err := x509.ParseCertificate(intermediateCertDer)
if err != nil {
return nil, fmt.Errorf("unable to parse intermediate certificate: %s", err)
}
intermediates.AddCert(intermediateCert)
}
verifyOpts := x509.VerifyOptions{
Intermediates: intermediates,
Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
}
// TODO: this call returns certificate chains which we ignore for now, but
// we should check them for revocations if we have the ability later.
if _, err = leafCert.Verify(verifyOpts); err != nil {
return nil, fmt.Errorf("unable to verify certificate chain: %s", err)
}
// Get the public key from the leaf certificate.
leafCryptoKey, ok := leafCert.PublicKey.(crypto.PublicKey)
if !ok {
return nil, errors.New("unable to get leaf cert public key value")
}
leafKey, err = libtrust.FromCryptoPublicKey(leafCryptoKey)
if err != nil {
return nil, fmt.Errorf("unable to make libtrust public key from leaf certificate: %s", err)
}
return
}
func parseAndVerifyRawJWK(rawJWK *json.RawMessage, verifyOpts VerifyOptions) (pubKey libtrust.PublicKey, err error) {
pubKey, err = libtrust.UnmarshalPublicKeyJWK([]byte(*rawJWK))
if err != nil {
return nil, fmt.Errorf("unable to decode raw JWK value: %s", err)
}
// Check to see if the key includes a certificate chain.
x5cVal, ok := pubKey.GetExtendedField("x5c").([]interface{})
if !ok {
// The JWK should be one of the trusted root keys.
if _, trusted := verifyOpts.TrustedKeys[pubKey.KeyID()]; !trusted {
return nil, errors.New("untrusted JWK with no certificate chain")
}
// The JWK is one of the trusted keys.
return
}
// Ensure each item in the chain is of the correct type.
x5c := make([]string, len(x5cVal))
for i, val := range x5cVal {
certString, ok := val.(string)
if !ok || len(certString) == 0 {
return nil, errors.New("malformed certificate chain")
}
x5c[i] = certString
}
// Ensure that the x509 certificate chain can
// be verified up to one of our trusted roots.
leafKey, err := parseAndVerifyCertChain(x5c, verifyOpts.Roots)
if err != nil {
return nil, fmt.Errorf("could not verify JWK certificate chain: %s", err)
}
// Verify that the public key in the leaf cert *is* the signing key.
if pubKey.KeyID() != leafKey.KeyID() {
return nil, errors.New("leaf certificate public key ID does not match JWK key ID")
}
return
}
// accessSet returns a set of actions available for the resource
// actions listed in the `access` section of this token.
func (t *Token) accessSet() accessSet {
if t.Claims == nil {
return nil
}
accessSet := make(accessSet, len(t.Claims.Access))
for _, resourceActions := range t.Claims.Access {
resource := auth.Resource{
Type: resourceActions.Type,
Name: resourceActions.Name,
}
set, exists := accessSet[resource]
if !exists {
set = newActionSet()
accessSet[resource] = set
}
for _, action := range resourceActions.Actions {
set.add(action)
}
}
return accessSet
}
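// Illustratively, a claim granting "pull" and "push" on repository "foo/bar"
// yields an accessSet for which
//
//    accessSet.contains(auth.Access{
//        Resource: auth.Resource{Type: "repository", Name: "foo/bar"},
//        Action:   "pull",
//    })
//
// is true; a claimed action of "*" would match any requested action.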
func (t *Token) resources() []auth.Resource {
if t.Claims == nil {
return nil
}
resourceSet := map[auth.Resource]struct{}{}
for _, resourceActions := range t.Claims.Access {
resource := auth.Resource{
Type: resourceActions.Type,
Class: resourceActions.Class,
Name: resourceActions.Name,
}
resourceSet[resource] = struct{}{}
}
resources := make([]auth.Resource, 0, len(resourceSet))
for resource := range resourceSet {
resources = append(resources, resource)
}
return resources
}
func (t *Token) compactRaw() string {
return fmt.Sprintf("%s.%s", t.Raw, joseBase64UrlEncode(t.Signature))
}

View File

@ -0,0 +1,510 @@
package token
import (
"crypto"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"testing"
"time"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/auth"
"github.com/docker/libtrust"
)
func makeRootKeys(numKeys int) ([]libtrust.PrivateKey, error) {
keys := make([]libtrust.PrivateKey, 0, numKeys)
for i := 0; i < numKeys; i++ {
key, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
return nil, err
}
keys = append(keys, key)
}
return keys, nil
}
func makeSigningKeyWithChain(rootKey libtrust.PrivateKey, depth int) (libtrust.PrivateKey, error) {
if depth == 0 {
// Don't need to build a chain.
return rootKey, nil
}
var (
x5c = make([]string, depth)
parentKey = rootKey
key libtrust.PrivateKey
cert *x509.Certificate
err error
)
for depth > 0 {
if key, err = libtrust.GenerateECP256PrivateKey(); err != nil {
return nil, err
}
if cert, err = libtrust.GenerateCACert(parentKey, key); err != nil {
return nil, err
}
depth--
x5c[depth] = base64.StdEncoding.EncodeToString(cert.Raw)
parentKey = key
}
key.AddExtendedField("x5c", x5c)
return key, nil
}
func makeRootCerts(rootKeys []libtrust.PrivateKey) ([]*x509.Certificate, error) {
certs := make([]*x509.Certificate, 0, len(rootKeys))
for _, key := range rootKeys {
cert, err := libtrust.GenerateCACert(key, key)
if err != nil {
return nil, err
}
certs = append(certs, cert)
}
return certs, nil
}
func makeTrustedKeyMap(rootKeys []libtrust.PrivateKey) map[string]libtrust.PublicKey {
trustedKeys := make(map[string]libtrust.PublicKey, len(rootKeys))
for _, key := range rootKeys {
trustedKeys[key.KeyID()] = key.PublicKey()
}
return trustedKeys
}
func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey libtrust.PrivateKey, depth int, now time.Time, exp time.Time) (*Token, error) {
signingKey, err := makeSigningKeyWithChain(rootKey, depth)
if err != nil {
return nil, fmt.Errorf("unable to make signing key with chain: %s", err)
}
var rawJWK json.RawMessage
rawJWK, err = signingKey.PublicKey().MarshalJSON()
if err != nil {
return nil, fmt.Errorf("unable to marshal signing key to JSON: %s", err)
}
joseHeader := &Header{
Type: "JWT",
SigningAlg: "ES256",
RawJWK: &rawJWK,
}
randomBytes := make([]byte, 15)
if _, err = rand.Read(randomBytes); err != nil {
return nil, fmt.Errorf("unable to read random bytes for jwt id: %s", err)
}
claimSet := &ClaimSet{
Issuer: issuer,
Subject: "foo",
Audience: audience,
Expiration: exp.Unix(),
NotBefore: now.Unix(),
IssuedAt: now.Unix(),
JWTID: base64.URLEncoding.EncodeToString(randomBytes),
Access: access,
}
var joseHeaderBytes, claimSetBytes []byte
if joseHeaderBytes, err = json.Marshal(joseHeader); err != nil {
return nil, fmt.Errorf("unable to marshal jose header: %s", err)
}
if claimSetBytes, err = json.Marshal(claimSet); err != nil {
return nil, fmt.Errorf("unable to marshal claim set: %s", err)
}
encodedJoseHeader := joseBase64UrlEncode(joseHeaderBytes)
encodedClaimSet := joseBase64UrlEncode(claimSetBytes)
encodingToSign := fmt.Sprintf("%s.%s", encodedJoseHeader, encodedClaimSet)
var signatureBytes []byte
if signatureBytes, _, err = signingKey.Sign(strings.NewReader(encodingToSign), crypto.SHA256); err != nil {
return nil, fmt.Errorf("unable to sign jwt payload: %s", err)
}
signature := joseBase64UrlEncode(signatureBytes)
tokenString := fmt.Sprintf("%s.%s", encodingToSign, signature)
return NewToken(tokenString)
}
// This test makes 4 tokens with a varying number of intermediate
// certificates, ranging from no intermediate chain up to a chain of 3
// intermediates.
func TestTokenVerify(t *testing.T) {
var (
numTokens = 4
issuer = "test-issuer"
audience = "test-audience"
access = []*ResourceActions{
{
Type: "repository",
Name: "foo/bar",
Actions: []string{"pull", "push"},
},
}
)
rootKeys, err := makeRootKeys(numTokens)
if err != nil {
t.Fatal(err)
}
rootCerts, err := makeRootCerts(rootKeys)
if err != nil {
t.Fatal(err)
}
rootPool := x509.NewCertPool()
for _, rootCert := range rootCerts {
rootPool.AddCert(rootCert)
}
trustedKeys := makeTrustedKeyMap(rootKeys)
tokens := make([]*Token, 0, numTokens)
for i := 0; i < numTokens; i++ {
token, err := makeTestToken(issuer, audience, access, rootKeys[i], i, time.Now(), time.Now().Add(5*time.Minute))
if err != nil {
t.Fatal(err)
}
tokens = append(tokens, token)
}
verifyOps := VerifyOptions{
TrustedIssuers: []string{issuer},
AcceptedAudiences: []string{audience},
Roots: rootPool,
TrustedKeys: trustedKeys,
}
for _, token := range tokens {
if err := token.Verify(verifyOps); err != nil {
t.Fatal(err)
}
}
}
// This test ensures that tokens whose nbf and exp claims fall within
// the defined leeway (in seconds) are not rejected.
func TestLeeway(t *testing.T) {
var (
issuer = "test-issuer"
audience = "test-audience"
access = []*ResourceActions{
{
Type: "repository",
Name: "foo/bar",
Actions: []string{"pull", "push"},
},
}
)
rootKeys, err := makeRootKeys(1)
if err != nil {
t.Fatal(err)
}
trustedKeys := makeTrustedKeyMap(rootKeys)
verifyOps := VerifyOptions{
TrustedIssuers: []string{issuer},
AcceptedAudiences: []string{audience},
Roots: nil,
TrustedKeys: trustedKeys,
}
// nbf verification should pass within leeway
futureNow := time.Now().Add(time.Duration(5) * time.Second)
token, err := makeTestToken(issuer, audience, access, rootKeys[0], 0, futureNow, futureNow.Add(5*time.Minute))
if err != nil {
t.Fatal(err)
}
if err := token.Verify(verifyOps); err != nil {
t.Fatal(err)
}
// nbf verification should fail with a skew larger than leeway
futureNow = time.Now().Add(time.Duration(61) * time.Second)
token, err = makeTestToken(issuer, audience, access, rootKeys[0], 0, futureNow, futureNow.Add(5*time.Minute))
if err != nil {
t.Fatal(err)
}
if err = token.Verify(verifyOps); err == nil {
t.Fatal("Verification should fail for token with nbf in the future outside leeway")
}
// exp verification should pass within leeway
token, err = makeTestToken(issuer, audience, access, rootKeys[0], 0, time.Now(), time.Now().Add(-59*time.Second))
if err != nil {
t.Fatal(err)
}
if err = token.Verify(verifyOps); err != nil {
t.Fatal(err)
}
// exp verification should fail with a skew larger than leeway
token, err = makeTestToken(issuer, audience, access, rootKeys[0], 0, time.Now(), time.Now().Add(-60*time.Second))
if err != nil {
t.Fatal(err)
}
if err = token.Verify(verifyOps); err == nil {
t.Fatal("Verification should fail for token with exp in the future outside leeway")
}
}
func writeTempRootCerts(rootKeys []libtrust.PrivateKey) (filename string, err error) {
rootCerts, err := makeRootCerts(rootKeys)
if err != nil {
return "", err
}
tempFile, err := ioutil.TempFile("", "rootCertBundle")
if err != nil {
return "", err
}
defer tempFile.Close()
for _, cert := range rootCerts {
if err = pem.Encode(tempFile, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}); err != nil {
os.Remove(tempFile.Name())
return "", err
}
}
return tempFile.Name(), nil
}
// TestAccessController tests complete integration of the token auth package.
// It starts by mocking the options for a token auth accessController which
// it creates. It then tries a few mock requests:
// - don't supply a token; should error with challenge
// - supply an invalid token; should error with challenge
// - supply a token with insufficient access; should error with challenge
// - supply a valid token; should not error
func TestAccessController(t *testing.T) {
// Make 2 keys; only the first is to be a trusted root key.
rootKeys, err := makeRootKeys(2)
if err != nil {
t.Fatal(err)
}
rootCertBundleFilename, err := writeTempRootCerts(rootKeys[:1])
if err != nil {
t.Fatal(err)
}
defer os.Remove(rootCertBundleFilename)
realm := "https://auth.example.com/token/"
issuer := "test-issuer.example.com"
service := "test-service.example.com"
options := map[string]interface{}{
"realm": realm,
"issuer": issuer,
"service": service,
"rootcertbundle": rootCertBundleFilename,
}
accessController, err := newAccessController(options)
if err != nil {
t.Fatal(err)
}
// 1. Make a mock http.Request with no token.
req, err := http.NewRequest("GET", "http://example.com/foo", nil)
if err != nil {
t.Fatal(err)
}
testAccess := auth.Access{
Resource: auth.Resource{
Type: "foo",
Name: "bar",
},
Action: "baz",
}
ctx := context.WithRequest(context.Background(), req)
authCtx, err := accessController.Authorized(ctx, testAccess)
challenge, ok := err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
}
if challenge.Error() != ErrTokenRequired.Error() {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
}
if authCtx != nil {
t.Fatalf("expected nil auth context but got %s", authCtx)
}
// 2. Supply an invalid token.
token, err := makeTestToken(
issuer, service,
[]*ResourceActions{{
Type: testAccess.Type,
Name: testAccess.Name,
Actions: []string{testAccess.Action},
}},
rootKeys[1], 1, time.Now(), time.Now().Add(5*time.Minute), // Everything is valid except the key which signed it.
)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))
authCtx, err = accessController.Authorized(ctx, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
}
if challenge.Error() != ErrInvalidToken.Error() {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
}
if authCtx != nil {
t.Fatalf("expected nil auth context but got %s", authCtx)
}
// 3. Supply a token with insufficient access.
token, err = makeTestToken(
issuer, service,
[]*ResourceActions{}, // No access specified.
rootKeys[0], 1, time.Now(), time.Now().Add(5*time.Minute),
)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))
authCtx, err = accessController.Authorized(ctx, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
}
if challenge.Error() != ErrInsufficientScope.Error() {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope)
}
if authCtx != nil {
t.Fatalf("expected nil auth context but got %s", authCtx)
}
// 4. Supply the token we need, or deserve, or whatever.
token, err = makeTestToken(
issuer, service,
[]*ResourceActions{{
Type: testAccess.Type,
Name: testAccess.Name,
Actions: []string{testAccess.Action},
}},
rootKeys[0], 1, time.Now(), time.Now().Add(5*time.Minute),
)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))
authCtx, err = accessController.Authorized(ctx, testAccess)
if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err)
}
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo)
if !ok {
t.Fatal("token accessController did not set auth.user context")
}
if userInfo.Name != "foo" {
t.Fatalf("expected user name %q, got %q", "foo", userInfo.Name)
}
}
// This tests that newAccessController can handle PEM blocks in the certificate
// file other than certificates, for example a private key.
func TestNewAccessControllerPemBlock(t *testing.T) {
rootKeys, err := makeRootKeys(2)
if err != nil {
t.Fatal(err)
}
rootCertBundleFilename, err := writeTempRootCerts(rootKeys)
if err != nil {
t.Fatal(err)
}
defer os.Remove(rootCertBundleFilename)
// Add something other than a certificate to the rootcertbundle
file, err := os.OpenFile(rootCertBundleFilename, os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
t.Fatal(err)
}
keyBlock, err := rootKeys[0].PEMBlock()
if err != nil {
t.Fatal(err)
}
err = pem.Encode(file, keyBlock)
if err != nil {
t.Fatal(err)
}
err = file.Close()
if err != nil {
t.Fatal(err)
}
realm := "https://auth.example.com/token/"
issuer := "test-issuer.example.com"
service := "test-service.example.com"
options := map[string]interface{}{
"realm": realm,
"issuer": issuer,
"service": service,
"rootcertbundle": rootCertBundleFilename,
}
ac, err := newAccessController(options)
if err != nil {
t.Fatal(err)
}
if len(ac.(*accessController).rootCerts.Subjects()) != 2 {
t.Fatal("accessController has the wrong number of certificates")
}
}

View File

@ -0,0 +1,58 @@
package token
import (
"encoding/base64"
"errors"
"strings"
)
// joseBase64UrlEncode encodes the given data using the standard base64 url
// encoding format but with all trailing '=' characters omitted in accordance
// with the jose specification.
// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
func joseBase64UrlEncode(b []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
}
// joseBase64UrlDecode decodes the given string using the standard base64 url
// decoder but first adds the appropriate number of trailing '=' characters in
// accordance with the jose specification.
// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
func joseBase64UrlDecode(s string) ([]byte, error) {
switch len(s) % 4 {
case 0:
case 2:
s += "=="
case 3:
s += "="
default:
return nil, errors.New("illegal base64url string")
}
return base64.URLEncoding.DecodeString(s)
}
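// Illustrative round trip (a two-byte input chosen so the padding behaviour
// is visible):
//
//    s := joseBase64UrlEncode([]byte{0xfb, 0xff}) // "-_8", trailing '=' stripped
//    b, err := joseBase64UrlDecode(s)             // re-pads to "-_8=" before decoding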
// actionSet is a special type of stringSet.
type actionSet struct {
stringSet
}
func newActionSet(actions ...string) actionSet {
return actionSet{newStringSet(actions...)}
}
// contains calls stringSet.contains() for
// either "*" or the given action string.
func (s actionSet) contains(action string) bool {
return s.stringSet.contains("*") || s.stringSet.contains(action)
}
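// Illustrative behaviour (hypothetical actions):
//
//    s := newActionSet("pull")
//    s.contains("pull") // true
//    s.contains("push") // false
//
//    any := newActionSet("*")
//    any.contains("push") // true: "*" grants every action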
// contains returns true if q is found in ss.
func contains(ss []string, q string) bool {
for _, s := range ss {
if s == q {
return true
}
}
return false
}

View File

@ -0,0 +1,127 @@
package challenge
import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"testing"
)
func TestAuthChallengeParse(t *testing.T) {
header := http.Header{}
header.Add("WWW-Authenticate", `Bearer realm="https://auth.example.com/token",service="registry.example.com",other=fun,slashed="he\"\l\lo"`)
challenges := parseAuthHeader(header)
if len(challenges) != 1 {
t.Fatalf("Unexpected number of auth challenges: %d, expected 1", len(challenges))
}
challenge := challenges[0]
if expected := "bearer"; challenge.Scheme != expected {
t.Fatalf("Unexpected scheme: %s, expected: %s", challenge.Scheme, expected)
}
if expected := "https://auth.example.com/token"; challenge.Parameters["realm"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["realm"], expected)
}
if expected := "registry.example.com"; challenge.Parameters["service"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["service"], expected)
}
if expected := "fun"; challenge.Parameters["other"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["other"], expected)
}
if expected := "he\"llo"; challenge.Parameters["slashed"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["slashed"], expected)
}
}
func TestAuthChallengeNormalization(t *testing.T) {
testAuthChallengeNormalization(t, "reg.EXAMPLE.com")
testAuthChallengeNormalization(t, "bɿɒʜɔiɿ-ɿɘƚƨim-ƚol-ɒ-ƨʞnɒʜƚ.com")
testAuthChallengeNormalization(t, "reg.example.com:80")
testAuthChallengeConcurrent(t, "reg.EXAMPLE.com")
}
func testAuthChallengeNormalization(t *testing.T, host string) {
scm := NewSimpleManager()
url, err := url.Parse(fmt.Sprintf("http://%s/v2/", host))
if err != nil {
t.Fatal(err)
}
resp := &http.Response{
Request: &http.Request{
URL: url,
},
Header: make(http.Header),
StatusCode: http.StatusUnauthorized,
}
resp.Header.Add("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"https://%s/token\",service=\"registry.example.com\"", host))
err = scm.AddResponse(resp)
if err != nil {
t.Fatal(err)
}
lowered := *url
lowered.Host = strings.ToLower(lowered.Host)
lowered.Host = canonicalAddr(&lowered)
c, err := scm.GetChallenges(lowered)
if err != nil {
t.Fatal(err)
}
if len(c) == 0 {
t.Fatal("Expected challenge for lower-cased-host URL")
}
}
func testAuthChallengeConcurrent(t *testing.T, host string) {
scm := NewSimpleManager()
url, err := url.Parse(fmt.Sprintf("http://%s/v2/", host))
if err != nil {
t.Fatal(err)
}
resp := &http.Response{
Request: &http.Request{
URL: url,
},
Header: make(http.Header),
StatusCode: http.StatusUnauthorized,
}
resp.Header.Add("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"https://%s/token\",service=\"registry.example.com\"", host))
var s sync.WaitGroup
s.Add(2)
go func() {
defer s.Done()
for i := 0; i < 200; i++ {
err = scm.AddResponse(resp)
if err != nil {
t.Error(err)
}
}
}()
go func() {
defer s.Done()
lowered := *url
lowered.Host = strings.ToLower(lowered.Host)
for k := 0; k < 200; k++ {
_, err := scm.GetChallenges(lowered)
if err != nil {
t.Error(err)
}
}
}()
s.Wait()
}

View File

@ -155,9 +155,7 @@ type RepositoryScope struct {
// using the scope grammar
func (rs RepositoryScope) String() string {
repoType := "repository"
// Keep existing format for image class to maintain backwards compatibility
// with authorization servers which do not support the expanded grammar.
if rs.Class != "" && rs.Class != "image" {
if rs.Class != "" {
repoType = fmt.Sprintf("%s(%s)", repoType, rs.Class)
}
return fmt.Sprintf("%s:%s:%s", repoType, rs.Repository, strings.Join(rs.Actions, ","))

View File

@ -0,0 +1,866 @@
package auth
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/docker/distribution/registry/client/auth/challenge"
"github.com/docker/distribution/registry/client/transport"
"github.com/docker/distribution/testutil"
)
// An implementation of clock for providing fake time data.
type fakeClock struct {
current time.Time
}
// Now implements clock
func (fc *fakeClock) Now() time.Time { return fc.current }
func testServer(rrm testutil.RequestResponseMap) (string, func()) {
h := testutil.NewHandler(rrm)
s := httptest.NewServer(h)
return s.URL, s.Close
}
type testAuthenticationWrapper struct {
headers http.Header
authCheck func(string) bool
next http.Handler
}
func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth == "" || !w.authCheck(auth) {
h := rw.Header()
for k, values := range w.headers {
h[k] = values
}
rw.WriteHeader(http.StatusUnauthorized)
return
}
w.next.ServeHTTP(rw, r)
}
func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (string, func()) {
h := testutil.NewHandler(rrm)
wrapper := &testAuthenticationWrapper{
headers: http.Header(map[string][]string{
"X-API-Version": {"registry/2.0"},
"X-Multi-API-Version": {"registry/2.0", "registry/2.1", "trust/1.0"},
"WWW-Authenticate": {authenticate},
}),
authCheck: authCheck,
next: h,
}
s := httptest.NewServer(wrapper)
return s.URL, s.Close
}
// ping pings the provided endpoint to determine its required authorization challenges.
// If a version header is provided, the versions will be returned.
func ping(manager challenge.Manager, endpoint, versionHeader string) ([]APIVersion, error) {
resp, err := http.Get(endpoint)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if err := manager.AddResponse(resp); err != nil {
return nil, err
}
return APIVersions(resp, versionHeader), err
}
type testCredentialStore struct {
username string
password string
refreshTokens map[string]string
}
func (tcs *testCredentialStore) Basic(*url.URL) (string, string) {
return tcs.username, tcs.password
}
func (tcs *testCredentialStore) RefreshToken(u *url.URL, service string) string {
return tcs.refreshTokens[service]
}
func (tcs *testCredentialStore) SetRefreshToken(u *url.URL, service string, token string) {
if tcs.refreshTokens != nil {
tcs.refreshTokens[service] = token
}
}
func TestEndpointAuthorizeToken(t *testing.T) {
service := "localhost.localdomain"
repo1 := "some/registry"
repo2 := "other/registry"
scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope1), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken"}`),
},
},
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope2), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"badtoken"}`),
},
},
})
te, tc := testServer(tokenMap)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
validCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
challengeManager1 := challenge.NewSimpleManager()
versions, err := ping(challengeManager1, e+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager1, NewTokenHandler(nil, nil, repo1, "pull", "push")))
client := &http.Client{Transport: transport1}
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
e2, c2 := testServerWithAuth(m, authenicate, validCheck)
defer c2()
challengeManager2 := challenge.NewSimpleManager()
versions, err = ping(challengeManager2, e2+"/v2/", "x-multi-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 3 {
t.Fatalf("Unexpected version count: %d, expected 3", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
if check := (APIVersion{Type: "registry", Version: "2.1"}); versions[1] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[1], check)
}
if check := (APIVersion{Type: "trust", Version: "1.0"}); versions[2] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[2], check)
}
transport2 := transport.NewTransport(nil, NewAuthorizer(challengeManager2, NewTokenHandler(nil, nil, repo2, "pull", "push")))
client2 := &http.Client{Transport: transport2}
req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
resp, err = client2.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func TestEndpointAuthorizeRefreshToken(t *testing.T) {
service := "localhost.localdomain"
repo1 := "some/registry"
repo2 := "other/registry"
scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
refreshToken1 := "0123456790abcdef"
refreshToken2 := "0123456790fedcba"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope1), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken1)),
},
},
{
// In the future this test may fail and require using basic auth to get a different refresh token
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope2), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken2)),
},
},
{
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken2, url.QueryEscape(scope2), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"badtoken","refresh_token":"%s"}`),
},
},
})
te, tc := testServer(tokenMap)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
validCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
challengeManager1 := challenge.NewSimpleManager()
versions, err := ping(challengeManager1, e+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
creds := &testCredentialStore{
refreshTokens: map[string]string{
service: refreshToken1,
},
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager1, NewTokenHandler(nil, creds, repo1, "pull", "push")))
client := &http.Client{Transport: transport1}
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
// Try with refresh token setting
e2, c2 := testServerWithAuth(m, authenicate, validCheck)
defer c2()
challengeManager2 := challenge.NewSimpleManager()
versions, err = ping(challengeManager2, e2+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
transport2 := transport.NewTransport(nil, NewAuthorizer(challengeManager2, NewTokenHandler(nil, creds, repo2, "pull", "push")))
client2 := &http.Client{Transport: transport2}
req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
resp, err = client2.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
if creds.refreshTokens[service] != refreshToken2 {
t.Fatalf("Refresh token not set after change")
}
// Try with bad token
e3, c3 := testServerWithAuth(m, authenicate, validCheck)
defer c3()
challengeManager3 := challenge.NewSimpleManager()
versions, err = ping(challengeManager3, e3+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
transport3 := transport.NewTransport(nil, NewAuthorizer(challengeManager3, NewTokenHandler(nil, creds, repo2, "pull", "push")))
client3 := &http.Client{Transport: transport3}
req, _ = http.NewRequest("GET", e3+"/v2/hello", nil)
resp, err = client3.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func TestEndpointAuthorizeV2RefreshToken(t *testing.T) {
service := "localhost.localdomain"
scope1 := "registry:catalog:search"
refreshToken1 := "0123456790abcdef"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope1), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken1)),
},
},
})
te, tc := testServer(tokenMap)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v1/search",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
validCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
challengeManager1 := challenge.NewSimpleManager()
versions, err := ping(challengeManager1, e+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
tho := TokenHandlerOptions{
Credentials: &testCredentialStore{
refreshTokens: map[string]string{
service: refreshToken1,
},
},
Scopes: []Scope{
RegistryScope{
Name: "catalog",
Actions: []string{"search"},
},
},
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager1, NewTokenHandlerWithOptions(tho)))
client := &http.Client{Transport: transport1}
req, _ := http.NewRequest("GET", e+"/v1/search", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
}
func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
func TestEndpointAuthorizeTokenBasic(t *testing.T) {
service := "localhost.localdomain"
repo := "some/fun/registry"
scope := fmt.Sprintf("repository:%s:pull,push", repo)
username := "tokenuser"
password := "superSecretPa$$word"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"statictoken"}`),
},
},
})
authenicate1 := "Basic realm=localhost"
basicCheck := func(a string) bool {
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
bearerCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
defer c()
creds := &testCredentialStore{
username: username,
password: password,
}
challengeManager := challenge.NewSimpleManager()
_, err := ping(challengeManager, e+"/v2/", "")
if err != nil {
t.Fatal(err)
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, NewTokenHandler(nil, creds, repo, "pull", "push"), NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
}
func TestEndpointAuthorizeTokenBasicWithExpiresIn(t *testing.T) {
service := "localhost.localdomain"
repo := "some/fun/registry"
scope := fmt.Sprintf("repository:%s:pull,push", repo)
username := "tokenuser"
password := "superSecretPa$$word"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken", "expires_in": 3001}`),
},
},
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"statictoken", "expires_in": 3001}`),
},
},
})
authenicate1 := "Basic realm=localhost"
tokenExchanges := 0
basicCheck := func(a string) bool {
tokenExchanges = tokenExchanges + 1
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
bearerCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
defer c()
creds := &testCredentialStore{
username: username,
password: password,
}
challengeManager := challenge.NewSimpleManager()
_, err := ping(challengeManager, e+"/v2/", "")
if err != nil {
t.Fatal(err)
}
clock := &fakeClock{current: time.Now()}
options := TokenHandlerOptions{
Transport: nil,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: repo,
Actions: []string{"pull", "push"},
},
},
}
tHandler := NewTokenHandlerWithOptions(options)
tHandler.(*tokenHandler).clock = clock
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
// First call should result in a token exchange
// Subsequent calls should recycle the token from the first request, until the expiration has lapsed.
timeIncrement := 1000 * time.Second
for i := 0; i < 4; i++ {
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 1 {
t.Fatalf("Unexpected number of token exchanges, want: 1, got %d (iteration: %d)", tokenExchanges, i)
}
clock.current = clock.current.Add(timeIncrement)
}
// After we've exceeded the expiration, we should see a second token exchange.
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 2 {
t.Fatalf("Unexpected number of token exchanges, want: 2, got %d", tokenExchanges)
}
}
func TestEndpointAuthorizeTokenBasicWithExpiresInAndIssuedAt(t *testing.T) {
service := "localhost.localdomain"
repo := "some/fun/registry"
scope := fmt.Sprintf("repository:%s:pull,push", repo)
username := "tokenuser"
password := "superSecretPa$$word"
// This test sets things up such that the token was issued one increment
// earlier than its sibling in TestEndpointAuthorizeTokenBasicWithExpiresIn.
// This will mean that the token expires after 3 increments instead of 4.
clock := &fakeClock{current: time.Now()}
timeIncrement := 1000 * time.Second
firstIssuedAt := clock.Now()
clock.current = clock.current.Add(timeIncrement)
secondIssuedAt := clock.current.Add(2 * timeIncrement)
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken", "issued_at": "` + firstIssuedAt.Format(time.RFC3339Nano) + `", "expires_in": 3001}`),
},
},
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"statictoken", "issued_at": "` + secondIssuedAt.Format(time.RFC3339Nano) + `", "expires_in": 3001}`),
},
},
})
authenicate1 := "Basic realm=localhost"
tokenExchanges := 0
basicCheck := func(a string) bool {
tokenExchanges = tokenExchanges + 1
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
bearerCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
defer c()
creds := &testCredentialStore{
username: username,
password: password,
}
challengeManager := challenge.NewSimpleManager()
_, err := ping(challengeManager, e+"/v2/", "")
if err != nil {
t.Fatal(err)
}
options := TokenHandlerOptions{
Transport: nil,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: repo,
Actions: []string{"pull", "push"},
},
},
}
tHandler := NewTokenHandlerWithOptions(options)
tHandler.(*tokenHandler).clock = clock
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
// First call should result in a token exchange
// Subsequent calls should recycle the token from the first request, until the expiration has lapsed.
// We shaved one increment off of the equivalent logic in TestEndpointAuthorizeTokenBasicWithExpiresIn
// so this loop should have one fewer iteration.
for i := 0; i < 3; i++ {
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 1 {
t.Fatalf("Unexpected number of token exchanges, want: 1, got %d (iteration: %d)", tokenExchanges, i)
}
clock.current = clock.current.Add(timeIncrement)
}
// After we've exceeded the expiration, we should see a second token exchange.
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 2 {
t.Fatalf("Unexpected number of token exchanges, want: 2, got %d", tokenExchanges)
}
}
func TestEndpointAuthorizeBasic(t *testing.T) {
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
username := "user1"
password := "funSecretPa$$word"
authenicate := "Basic realm=localhost"
validCheck := func(a string) bool {
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
creds := &testCredentialStore{
username: username,
password: password,
}
challengeManager := challenge.NewSimpleManager()
_, err := ping(challengeManager, e+"/v2/", "")
if err != nil {
t.Fatal(err)
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
}

View File

@ -0,0 +1,211 @@
package client
import (
"bytes"
"fmt"
"net/http"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/testutil"
)
// Compile-time assertion that httpBlobUpload implements distribution.BlobWriter.
var _ distribution.BlobWriter = &httpBlobUpload{}
func TestUploadReadFrom(t *testing.T) {
_, b := newRandomBlob(64)
repo := "test/upload/readfrom"
locationPath := fmt.Sprintf("/v2/%s/uploads/testid", repo)
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/",
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Headers: http.Header(map[string][]string{
"Docker-Distribution-API-Version": {"registry/2.0"},
}),
},
},
// Test Valid case
{
Request: testutil.Request{
Method: "PATCH",
Route: locationPath,
Body: b,
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
Headers: http.Header(map[string][]string{
"Docker-Upload-UUID": {"46603072-7a1b-4b41-98f9-fd8a7da89f9b"},
"Location": {locationPath},
"Range": {"0-63"},
}),
},
},
// Test invalid range
{
Request: testutil.Request{
Method: "PATCH",
Route: locationPath,
Body: b,
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
Headers: http.Header(map[string][]string{
"Docker-Upload-UUID": {"46603072-7a1b-4b41-98f9-fd8a7da89f9b"},
"Location": {locationPath},
"Range": {""},
}),
},
},
// Test 404
{
Request: testutil.Request{
Method: "PATCH",
Route: locationPath,
Body: b,
},
Response: testutil.Response{
StatusCode: http.StatusNotFound,
},
},
// Test 400 valid json
{
Request: testutil.Request{
Method: "PATCH",
Route: locationPath,
Body: b,
},
Response: testutil.Response{
StatusCode: http.StatusBadRequest,
Body: []byte(`
{ "errors":
[
{
"code": "BLOB_UPLOAD_INVALID",
"message": "blob upload invalid",
"detail": "more detail"
}
]
} `),
},
},
// Test 400 invalid json
{
Request: testutil.Request{
Method: "PATCH",
Route: locationPath,
Body: b,
},
Response: testutil.Response{
StatusCode: http.StatusBadRequest,
Body: []byte("something bad happened"),
},
},
// Test 500
{
Request: testutil.Request{
Method: "PATCH",
Route: locationPath,
Body: b,
},
Response: testutil.Response{
StatusCode: http.StatusInternalServerError,
},
},
})
e, c := testServer(m)
defer c()
blobUpload := &httpBlobUpload{
client: &http.Client{},
}
// Valid case
blobUpload.location = e + locationPath
n, err := blobUpload.ReadFrom(bytes.NewReader(b))
if err != nil {
t.Fatalf("Error calling ReadFrom: %s", err)
}
if n != 64 {
t.Fatalf("Wrong length returned from ReadFrom: %d, expected 64", n)
}
// Bad range
blobUpload.location = e + locationPath
_, err = blobUpload.ReadFrom(bytes.NewReader(b))
if err == nil {
t.Fatalf("Expected error when bad range received")
}
// 404
blobUpload.location = e + locationPath
_, err = blobUpload.ReadFrom(bytes.NewReader(b))
if err == nil {
t.Fatalf("Expected error when not found")
}
if err != distribution.ErrBlobUploadUnknown {
t.Fatalf("Wrong error thrown: %s, expected %s", err, distribution.ErrBlobUploadUnknown)
}
// 400 valid json
blobUpload.location = e + locationPath
_, err = blobUpload.ReadFrom(bytes.NewReader(b))
if err == nil {
t.Fatalf("Expected error when not found")
}
if uploadErr, ok := err.(errcode.Errors); !ok {
t.Fatalf("Wrong error type %T: %s", err, err)
} else if len(uploadErr) != 1 {
t.Fatalf("Unexpected number of errors: %d, expected 1", len(uploadErr))
} else {
v2Err, ok := uploadErr[0].(errcode.Error)
if !ok {
t.Fatalf("Not an 'Error' type: %#v", uploadErr[0])
}
if v2Err.Code != v2.ErrorCodeBlobUploadInvalid {
t.Fatalf("Unexpected error code: %s, expected %d", v2Err.Code.String(), v2.ErrorCodeBlobUploadInvalid)
}
if expected := "blob upload invalid"; v2Err.Message != expected {
t.Fatalf("Unexpected error message: %q, expected %q", v2Err.Message, expected)
}
if expected := "more detail"; v2Err.Detail.(string) != expected {
t.Fatalf("Unexpected error message: %q, expected %q", v2Err.Detail.(string), expected)
}
}
// 400 invalid json
blobUpload.location = e + locationPath
_, err = blobUpload.ReadFrom(bytes.NewReader(b))
if err == nil {
t.Fatalf("Expected error when not found")
}
if uploadErr, ok := err.(*UnexpectedHTTPResponseError); !ok {
t.Fatalf("Wrong error type %T: %s", err, err)
} else {
respStr := string(uploadErr.Response)
if expected := "something bad happened"; respStr != expected {
t.Fatalf("Unexpected response string: %s, expected: %s", respStr, expected)
}
}
// 500
blobUpload.location = e + locationPath
_, err = blobUpload.ReadFrom(bytes.NewReader(b))
if err == nil {
t.Fatalf("Expected error when not found")
}
if uploadErr, ok := err.(*UnexpectedHTTPStatusError); !ok {
t.Fatalf("Wrong error type %T: %s", err, err)
} else if expected := "500 " + http.StatusText(http.StatusInternalServerError); uploadErr.Status != expected {
t.Fatalf("Unexpected response status: %s, expected %s", uploadErr.Status, expected)
}
}

View File

@ -0,0 +1,104 @@
package client
import (
"bytes"
"io"
"net/http"
"strings"
"testing"
)
type nopCloser struct {
io.Reader
}
func (nopCloser) Close() error { return nil }
func TestHandleErrorResponse401ValidBody(t *testing.T) {
json := "{\"errors\":[{\"code\":\"UNAUTHORIZED\",\"message\":\"action requires authentication\"}]}"
response := &http.Response{
Status: "401 Unauthorized",
StatusCode: 401,
Body: nopCloser{bytes.NewBufferString(json)},
}
err := HandleErrorResponse(response)
expectedMsg := "unauthorized: action requires authentication"
if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
}
}
func TestHandleErrorResponse401WithInvalidBody(t *testing.T) {
json := "{invalid json}"
response := &http.Response{
Status: "401 Unauthorized",
StatusCode: 401,
Body: nopCloser{bytes.NewBufferString(json)},
}
err := HandleErrorResponse(response)
expectedMsg := "unauthorized: authentication required"
if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
}
}
func TestHandleErrorResponseExpectedStatusCode400ValidBody(t *testing.T) {
json := "{\"errors\":[{\"code\":\"DIGEST_INVALID\",\"message\":\"provided digest does not match\"}]}"
response := &http.Response{
Status: "400 Bad Request",
StatusCode: 400,
Body: nopCloser{bytes.NewBufferString(json)},
}
err := HandleErrorResponse(response)
expectedMsg := "digest invalid: provided digest does not match"
if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
}
}
func TestHandleErrorResponseExpectedStatusCode404EmptyErrorSlice(t *testing.T) {
json := `{"randomkey": "randomvalue"}`
response := &http.Response{
Status: "404 Not Found",
StatusCode: 404,
Body: nopCloser{bytes.NewBufferString(json)},
}
err := HandleErrorResponse(response)
expectedMsg := `error parsing HTTP 404 response body: no error details found in HTTP response body: "{\"randomkey\": \"randomvalue\"}"`
if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
}
}
func TestHandleErrorResponseExpectedStatusCode404InvalidBody(t *testing.T) {
json := "{invalid json}"
response := &http.Response{
Status: "404 Not Found",
StatusCode: 404,
Body: nopCloser{bytes.NewBufferString(json)},
}
err := HandleErrorResponse(response)
expectedMsg := "error parsing HTTP 404 response body: invalid character 'i' looking for beginning of object key string: \"{invalid json}\""
if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
}
}
func TestHandleErrorResponseUnexpectedStatusCode501(t *testing.T) {
response := &http.Response{
Status: "501 Not Implemented",
StatusCode: 501,
Body: nopCloser{bytes.NewBufferString("{\"Error Encountered\" : \"Function not implemented.\"}")},
}
err := HandleErrorResponse(response)
expectedMsg := "received unexpected HTTP status: 501 Not Implemented"
if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
}
}
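The cases above pin down how HandleErrorResponse maps status codes and bodies onto Go errors. A hedged sketch of how a caller might apply it around a plain GET (doGet is a hypothetical helper, not part of the package):
// doGet turns any non-2xx/3xx response into one of the errors exercised above.
func doGet(c *http.Client, url string) (*http.Response, error) {
    resp, err := c.Get(url)
    if err != nil {
        return nil, err
    }
    if resp.StatusCode < 200 || resp.StatusCode >= 400 {
        defer resp.Body.Close()
        return nil, HandleErrorResponse(resp)
    }
    return resp, nil
}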

View File

@ -15,12 +15,12 @@ import (
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/registry/client/transport"
"github.com/docker/distribution/registry/storage/cache"
"github.com/docker/distribution/registry/storage/cache/memory"
"github.com/opencontainers/go-digest"
)
// Registry provides an interface for calling Repositories, which returns a catalog of repositories.
@ -268,7 +268,7 @@ func descriptorFromResponse(response *http.Response) (distribution.Descriptor, e
return desc, nil
}
dgst, err := digest.Parse(digestHeader)
dgst, err := digest.ParseDigest(digestHeader)
if err != nil {
return distribution.Descriptor{}, err
}
@ -475,7 +475,7 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
return nil, distribution.ErrManifestNotModified
} else if SuccessStatus(resp.StatusCode) {
if contentDgst != nil {
dgst, err := digest.Parse(resp.Header.Get("Docker-Content-Digest"))
dgst, err := digest.ParseDigest(resp.Header.Get("Docker-Content-Digest"))
if err == nil {
*contentDgst = dgst
}
@ -553,7 +553,7 @@ func (ms *manifests) Put(ctx context.Context, m distribution.Manifest, options .
if SuccessStatus(resp.StatusCode) {
dgstHeader := resp.Header.Get("Docker-Content-Digest")
dgst, err := digest.Parse(dgstHeader)
dgst, err := digest.ParseDigest(dgstHeader)
if err != nil {
return "", err
}
@ -661,7 +661,7 @@ func (bs *blobs) Put(ctx context.Context, mediaType string, p []byte) (distribut
if err != nil {
return distribution.Descriptor{}, err
}
dgstr := digest.Canonical.Digester()
dgstr := digest.Canonical.New()
n, err := io.Copy(writer, io.TeeReader(bytes.NewReader(p), dgstr.Hash()))
if err != nil {
return distribution.Descriptor{}, err

File diff suppressed because it is too large

View File

@ -0,0 +1,2 @@
// Package registry provides the main entrypoints for running a registry.
package registry

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,279 @@
package handlers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/registry/auth"
_ "github.com/docker/distribution/registry/auth/silly"
"github.com/docker/distribution/registry/storage"
memorycache "github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver/testdriver"
)
// TestAppDispatcher builds an application with a test dispatcher and ensures
// that requests are properly dispatched and the handlers are constructed.
// This only tests the dispatch mechanism. The underlying dispatchers must be
// tested individually.
func TestAppDispatcher(t *testing.T) {
driver := testdriver.New()
ctx := context.Background()
registry, err := storage.NewRegistry(ctx, driver, storage.BlobDescriptorCacheProvider(memorycache.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableDelete, storage.EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
app := &App{
Config: &configuration.Configuration{},
Context: ctx,
router: v2.Router(),
driver: driver,
registry: registry,
}
server := httptest.NewServer(app)
defer server.Close()
router := v2.Router()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("error parsing server url: %v", err)
}
varCheckingDispatcher := func(expectedVars map[string]string) dispatchFunc {
return func(ctx *Context, r *http.Request) http.Handler {
// Always check that the repository name in the context matches the request vars
if ctx.Repository.Named().Name() != getName(ctx) {
t.Fatalf("unexpected name: %q != %q", ctx.Repository.Named().Name(), getName(ctx))
}
// Check that we have all that is expected
for expectedK, expectedV := range expectedVars {
if ctx.Value(expectedK) != expectedV {
t.Fatalf("unexpected %s in context vars: %q != %q", expectedK, ctx.Value(expectedK), expectedV)
}
}
// Check that we only have variables that are expected
for k, v := range ctx.Value("vars").(map[string]string) {
_, ok := expectedVars[k]
if !ok { // name is checked on context
// We have an unexpected key, fail
t.Fatalf("unexpected key %q in vars with value %q", k, v)
}
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
}
}
// unflatten a list of variables, suitable for gorilla/mux, to a map[string]string
unflatten := func(vars []string) map[string]string {
m := make(map[string]string)
for i := 0; i < len(vars)-1; i = i + 2 {
m[vars[i]] = vars[i+1]
}
return m
}
for _, testcase := range []struct {
endpoint string
vars []string
}{
{
endpoint: v2.RouteNameManifest,
vars: []string{
"name", "foo/bar",
"reference", "sometag",
},
},
{
endpoint: v2.RouteNameTags,
vars: []string{
"name", "foo/bar",
},
},
{
endpoint: v2.RouteNameBlobUpload,
vars: []string{
"name", "foo/bar",
},
},
{
endpoint: v2.RouteNameBlobUploadChunk,
vars: []string{
"name", "foo/bar",
"uuid", "theuuid",
},
},
} {
app.register(testcase.endpoint, varCheckingDispatcher(unflatten(testcase.vars)))
route := router.GetRoute(testcase.endpoint).Host(serverURL.Host)
u, err := route.URL(testcase.vars...)
if err != nil {
t.Fatal(err)
}
resp, err := http.Get(u.String())
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %v != %v", resp.StatusCode, http.StatusOK)
}
}
}
// TestNewApp covers the creation of an application via NewApp with a
// configuration.
func TestNewApp(t *testing.T) {
ctx := context.Background()
config := configuration.Configuration{
Storage: configuration.Storage{
"testdriver": nil,
"maintenance": configuration.Parameters{"uploadpurging": map[interface{}]interface{}{
"enabled": false,
}},
},
Auth: configuration.Auth{
// For now, we simply test that new auth results in a viable
// application.
"silly": {
"realm": "realm-test",
"service": "service-test",
},
},
}
// Mostly, with this test, given a sane configuration, we are simply
// ensuring that NewApp doesn't panic. We might want to tweak this
// behavior.
app := NewApp(ctx, &config)
server := httptest.NewServer(app)
defer server.Close()
builder, err := v2.NewURLBuilderFromString(server.URL, false)
if err != nil {
t.Fatalf("error creating urlbuilder: %v", err)
}
baseURL, err := builder.BuildBaseURL()
if err != nil {
t.Fatalf("error creating baseURL: %v", err)
}
// TODO(stevvooe): The rest of this test might belong in the API tests.
// Just hit the app and make sure we get a 401 Unauthorized error.
req, err := http.Get(baseURL)
if err != nil {
t.Fatalf("unexpected error during GET: %v", err)
}
defer req.Body.Close()
if req.StatusCode != http.StatusUnauthorized {
t.Fatalf("unexpected status code during request: %v", err)
}
if req.Header.Get("Content-Type") != "application/json; charset=utf-8" {
t.Fatalf("unexpected content-type: %v != %v", req.Header.Get("Content-Type"), "application/json; charset=utf-8")
}
expectedAuthHeader := "Bearer realm=\"realm-test\",service=\"service-test\""
if e, a := expectedAuthHeader, req.Header.Get("WWW-Authenticate"); e != a {
t.Fatalf("unexpected WWW-Authenticate header: %q != %q", e, a)
}
var errs errcode.Errors
dec := json.NewDecoder(req.Body)
if err := dec.Decode(&errs); err != nil {
t.Fatalf("error decoding error response: %v", err)
}
err2, ok := errs[0].(errcode.ErrorCoder)
if !ok {
t.Fatalf("not an ErrorCoder: %#v", errs[0])
}
if err2.ErrorCode() != errcode.ErrorCodeUnauthorized {
t.Fatalf("unexpected error code: %v != %v", err2.ErrorCode(), errcode.ErrorCodeUnauthorized)
}
}
// Test the access record accumulator
func TestAppendAccessRecords(t *testing.T) {
repo := "testRepo"
expectedResource := auth.Resource{
Type: "repository",
Name: repo,
}
expectedPullRecord := auth.Access{
Resource: expectedResource,
Action: "pull",
}
expectedPushRecord := auth.Access{
Resource: expectedResource,
Action: "push",
}
expectedAllRecord := auth.Access{
Resource: expectedResource,
Action: "*",
}
records := []auth.Access{}
result := appendAccessRecords(records, "GET", repo)
expectedResult := []auth.Access{expectedPullRecord}
if ok := reflect.DeepEqual(result, expectedResult); !ok {
t.Fatalf("Actual access record differs from expected")
}
records = []auth.Access{}
result = appendAccessRecords(records, "HEAD", repo)
expectedResult = []auth.Access{expectedPullRecord}
if ok := reflect.DeepEqual(result, expectedResult); !ok {
t.Fatalf("Actual access record differs from expected")
}
records = []auth.Access{}
result = appendAccessRecords(records, "POST", repo)
expectedResult = []auth.Access{expectedPullRecord, expectedPushRecord}
if ok := reflect.DeepEqual(result, expectedResult); !ok {
t.Fatalf("Actual access record differs from expected")
}
records = []auth.Access{}
result = appendAccessRecords(records, "PUT", repo)
expectedResult = []auth.Access{expectedPullRecord, expectedPushRecord}
if ok := reflect.DeepEqual(result, expectedResult); !ok {
t.Fatalf("Actual access record differs from expected")
}
records = []auth.Access{}
result = appendAccessRecords(records, "PATCH", repo)
expectedResult = []auth.Access{expectedPullRecord, expectedPushRecord}
if ok := reflect.DeepEqual(result, expectedResult); !ok {
t.Fatalf("Actual access record differs from expected")
}
records = []auth.Access{}
result = appendAccessRecords(records, "DELETE", repo)
expectedResult = []auth.Access{expectedAllRecord}
if ok := reflect.DeepEqual(result, expectedResult); !ok {
t.Fatalf("Actual access record differs from expected")
}
}

View File

@ -0,0 +1,11 @@
// +build go1.4
package handlers
import (
"net/http"
)
func basicAuth(r *http.Request) (username, password string, ok bool) {
return r.BasicAuth()
}

View File

@ -0,0 +1,41 @@
// +build !go1.4
package handlers
import (
"encoding/base64"
"net/http"
"strings"
)
// NOTE(stevvooe): This is basic auth support from go1.4 present to ensure we
// can compile on go1.3 and earlier.
// BasicAuth returns the username and password provided in the request's
// Authorization header, if the request uses HTTP Basic Authentication.
// See RFC 2617, Section 2.
func basicAuth(r *http.Request) (username, password string, ok bool) {
auth := r.Header.Get("Authorization")
if auth == "" {
return
}
return parseBasicAuth(auth)
}
// parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
func parseBasicAuth(auth string) (username, password string, ok bool) {
if !strings.HasPrefix(auth, "Basic ") {
return
}
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
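The comment above already decodes the classic RFC 2617 example; a small illustrative round trip (it would live in a _test.go file and only builds under the same !go1.4 constraint as this file):
func TestParseBasicAuthRoundTrip(t *testing.T) {
    // "Aladdin:open sesame" encodes to QWxhZGRpbjpvcGVuIHNlc2FtZQ==
    header := "Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame"))
    user, pass, ok := parseBasicAuth(header)
    if !ok || user != "Aladdin" || pass != "open sesame" {
        t.Fatalf("unexpected result: %q %q %v", user, pass, ok)
    }
}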

View File

@ -0,0 +1,99 @@
package handlers
import (
"net/http"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/gorilla/handlers"
)
// blobDispatcher uses the request context to build a blobHandler.
func blobDispatcher(ctx *Context, r *http.Request) http.Handler {
dgst, err := getDigest(ctx)
if err != nil {
if err == errDigestNotAvailable {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx.Errors = append(ctx.Errors, v2.ErrorCodeDigestInvalid.WithDetail(err))
})
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx.Errors = append(ctx.Errors, v2.ErrorCodeDigestInvalid.WithDetail(err))
})
}
blobHandler := &blobHandler{
Context: ctx,
Digest: dgst,
}
mhandler := handlers.MethodHandler{
"GET": http.HandlerFunc(blobHandler.GetBlob),
"HEAD": http.HandlerFunc(blobHandler.GetBlob),
}
if !ctx.readOnly {
mhandler["DELETE"] = http.HandlerFunc(blobHandler.DeleteBlob)
}
return mhandler
}
// blobHandler serves http blob requests.
type blobHandler struct {
*Context
Digest digest.Digest
}
// GetBlob fetches the binary data from backend storage and returns it in the
// response.
func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) {
context.GetLogger(bh).Debug("GetBlob")
blobs := bh.Repository.Blobs(bh)
desc, err := blobs.Stat(bh, bh.Digest)
if err != nil {
if err == distribution.ErrBlobUnknown {
bh.Errors = append(bh.Errors, v2.ErrorCodeBlobUnknown.WithDetail(bh.Digest))
} else {
bh.Errors = append(bh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
return
}
if err := blobs.ServeBlob(bh, w, r, desc.Digest); err != nil {
context.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err)
bh.Errors = append(bh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
}
// DeleteBlob deletes a layer blob
func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) {
context.GetLogger(bh).Debug("DeleteBlob")
blobs := bh.Repository.Blobs(bh)
err := blobs.Delete(bh, bh.Digest)
if err != nil {
switch err {
case distribution.ErrUnsupported:
bh.Errors = append(bh.Errors, errcode.ErrorCodeUnsupported)
return
case distribution.ErrBlobUnknown:
bh.Errors = append(bh.Errors, v2.ErrorCodeBlobUnknown)
return
default:
bh.Errors = append(bh.Errors, err)
context.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error())
return
}
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusAccepted)
}

View File

@ -0,0 +1,368 @@
package handlers
import (
"fmt"
"net/http"
"net/url"
"github.com/docker/distribution"
ctxu "github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/registry/storage"
"github.com/gorilla/handlers"
)
// blobUploadDispatcher constructs and returns the blob upload handler for the
// given request context.
func blobUploadDispatcher(ctx *Context, r *http.Request) http.Handler {
buh := &blobUploadHandler{
Context: ctx,
UUID: getUploadUUID(ctx),
}
handler := handlers.MethodHandler{
"GET": http.HandlerFunc(buh.GetUploadStatus),
"HEAD": http.HandlerFunc(buh.GetUploadStatus),
}
if !ctx.readOnly {
handler["POST"] = http.HandlerFunc(buh.StartBlobUpload)
handler["PATCH"] = http.HandlerFunc(buh.PatchBlobData)
handler["PUT"] = http.HandlerFunc(buh.PutBlobUploadComplete)
handler["DELETE"] = http.HandlerFunc(buh.CancelBlobUpload)
}
if buh.UUID != "" {
state, err := hmacKey(ctx.Config.HTTP.Secret).unpackUploadState(r.FormValue("_state"))
if err != nil {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(ctx).Infof("error resolving upload: %v", err)
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
})
}
buh.State = state
if state.Name != ctx.Repository.Named().Name() {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(ctx).Infof("mismatched repository name in upload state: %q != %q", state.Name, buh.Repository.Named().Name())
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
})
}
if state.UUID != buh.UUID {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(ctx).Infof("mismatched uuid in upload state: %q != %q", state.UUID, buh.UUID)
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
})
}
blobs := ctx.Repository.Blobs(buh)
upload, err := blobs.Resume(buh, buh.UUID)
if err != nil {
ctxu.GetLogger(ctx).Errorf("error resolving upload: %v", err)
if err == distribution.ErrBlobUploadUnknown {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadUnknown.WithDetail(err))
})
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
})
}
buh.Upload = upload
if size := upload.Size(); size != buh.State.Offset {
defer upload.Close()
ctxu.GetLogger(ctx).Errorf("upload resumed at wrong offest: %d != %d", size, buh.State.Offset)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
upload.Cancel(buh)
})
}
return closeResources(handler, buh.Upload)
}
return handler
}
// blobUploadHandler handles the http blob upload process.
type blobUploadHandler struct {
*Context
// UUID identifies the upload instance for the current request. Blob
// writers are keyed by UUID in this implementation.
UUID string
Upload distribution.BlobWriter
State blobUploadState
}
// StartBlobUpload begins the blob upload process and allocates a server-side
// blob writer session, optionally mounting the blob from a separate repository.
func (buh *blobUploadHandler) StartBlobUpload(w http.ResponseWriter, r *http.Request) {
var options []distribution.BlobCreateOption
fromRepo := r.FormValue("from")
mountDigest := r.FormValue("mount")
if mountDigest != "" && fromRepo != "" {
opt, err := buh.createBlobMountOption(fromRepo, mountDigest)
if opt != nil && err == nil {
options = append(options, opt)
}
}
blobs := buh.Repository.Blobs(buh)
upload, err := blobs.Create(buh, options...)
if err != nil {
if ebm, ok := err.(distribution.ErrBlobMounted); ok {
if err := buh.writeBlobCreatedHeaders(w, ebm.Descriptor); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
} else if err == distribution.ErrUnsupported {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnsupported)
} else {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
return
}
buh.Upload = upload
if err := buh.blobUploadResponse(w, r, true); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
w.Header().Set("Docker-Upload-UUID", buh.Upload.ID())
w.WriteHeader(http.StatusAccepted)
}
// GetUploadStatus returns the status of a given upload, identified by id.
func (buh *blobUploadHandler) GetUploadStatus(w http.ResponseWriter, r *http.Request) {
if buh.Upload == nil {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadUnknown)
return
}
// TODO(dmcgowan): Set last argument to false in blobUploadResponse when
// resumable upload is supported. This will enable returning a non-zero
// range for clients to begin uploading at an offset.
if err := buh.blobUploadResponse(w, r, true); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
w.Header().Set("Docker-Upload-UUID", buh.UUID)
w.WriteHeader(http.StatusNoContent)
}
// PatchBlobData writes data to an upload.
func (buh *blobUploadHandler) PatchBlobData(w http.ResponseWriter, r *http.Request) {
if buh.Upload == nil {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadUnknown)
return
}
ct := r.Header.Get("Content-Type")
if ct != "" && ct != "application/octet-stream" {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(fmt.Errorf("Bad Content-Type")))
// TODO(dmcgowan): encode error
return
}
// TODO(dmcgowan): support Content-Range header to seek and write range
if err := copyFullPayload(w, r, buh.Upload, -1, buh, "blob PATCH"); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err.Error()))
return
}
if err := buh.blobUploadResponse(w, r, false); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
w.WriteHeader(http.StatusAccepted)
}
// PutBlobUploadComplete takes the final request of a blob upload. The
// request may include all the blob data or no blob data. Any data
// provided is received and verified. If successful, the blob is linked
// into the blob store and 201 Created is returned with the canonical
// url of the blob.
func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *http.Request) {
if buh.Upload == nil {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadUnknown)
return
}
dgstStr := r.FormValue("digest") // TODO(stevvooe): Support multiple digest parameters!
if dgstStr == "" {
// no digest? return error, but allow retry.
buh.Errors = append(buh.Errors, v2.ErrorCodeDigestInvalid.WithDetail("digest missing"))
return
}
dgst, err := digest.ParseDigest(dgstStr)
if err != nil {
// no digest? return error, but allow retry.
buh.Errors = append(buh.Errors, v2.ErrorCodeDigestInvalid.WithDetail("digest parsing failed"))
return
}
if err := copyFullPayload(w, r, buh.Upload, -1, buh, "blob PUT"); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err.Error()))
return
}
desc, err := buh.Upload.Commit(buh, distribution.Descriptor{
Digest: dgst,
// TODO(stevvooe): This isn't wildly important yet, but we should
// really set the mediatype. For now, we can let the backend take care
// of this.
})
if err != nil {
switch err := err.(type) {
case distribution.ErrBlobInvalidDigest:
buh.Errors = append(buh.Errors, v2.ErrorCodeDigestInvalid.WithDetail(err))
case errcode.Error:
buh.Errors = append(buh.Errors, err)
default:
switch err {
case distribution.ErrAccessDenied:
buh.Errors = append(buh.Errors, errcode.ErrorCodeDenied)
case distribution.ErrUnsupported:
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnsupported)
case distribution.ErrBlobInvalidLength, distribution.ErrBlobDigestUnsupported:
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
default:
ctxu.GetLogger(buh).Errorf("unknown error completing upload: %v", err)
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
}
// Clean up the backend blob data if there was an error.
if err := buh.Upload.Cancel(buh); err != nil {
// If the cleanup fails, all we can do is observe and report.
ctxu.GetLogger(buh).Errorf("error canceling upload after error: %v", err)
}
return
}
if err := buh.writeBlobCreatedHeaders(w, desc); err != nil {
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
}
// CancelBlobUpload cancels an in-progress upload of a blob.
func (buh *blobUploadHandler) CancelBlobUpload(w http.ResponseWriter, r *http.Request) {
if buh.Upload == nil {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadUnknown)
return
}
w.Header().Set("Docker-Upload-UUID", buh.UUID)
if err := buh.Upload.Cancel(buh); err != nil {
ctxu.GetLogger(buh).Errorf("error encountered canceling upload: %v", err)
buh.Errors = append(buh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
w.WriteHeader(http.StatusNoContent)
}
// blobUploadResponse writes the standard response headers for blob upload and
// chunk requests. It sets the correct headers but the response status is
// left to the caller. The fresh argument is used to ensure that new blob
// uploads always start at a 0 offset. This allows disabling resumable push by
// always returning a 0 offset on check status.
func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.Request, fresh bool) error {
// TODO(stevvooe): Need a better way to manage the upload state automatically.
buh.State.Name = buh.Repository.Named().Name()
buh.State.UUID = buh.Upload.ID()
buh.Upload.Close()
buh.State.Offset = buh.Upload.Size()
buh.State.StartedAt = buh.Upload.StartedAt()
token, err := hmacKey(buh.Config.HTTP.Secret).packUploadState(buh.State)
if err != nil {
ctxu.GetLogger(buh).Infof("error building upload state token: %s", err)
return err
}
uploadURL, err := buh.urlBuilder.BuildBlobUploadChunkURL(
buh.Repository.Named(), buh.Upload.ID(),
url.Values{
"_state": []string{token},
})
if err != nil {
ctxu.GetLogger(buh).Infof("error building upload url: %s", err)
return err
}
endRange := buh.Upload.Size()
if endRange > 0 {
endRange = endRange - 1
}
w.Header().Set("Docker-Upload-UUID", buh.UUID)
w.Header().Set("Location", uploadURL)
w.Header().Set("Content-Length", "0")
w.Header().Set("Range", fmt.Sprintf("0-%d", endRange))
return nil
}
// createBlobMountOption builds a BlobCreateOption from the "from" repository
// and "mount" digest request parameters. When passed to Blobs.Create, it asks
// the backend to mount the blob from the other repository instead of uploading it.
func (buh *blobUploadHandler) createBlobMountOption(fromRepo, mountDigest string) (distribution.BlobCreateOption, error) {
dgst, err := digest.ParseDigest(mountDigest)
if err != nil {
return nil, err
}
ref, err := reference.ParseNamed(fromRepo)
if err != nil {
return nil, err
}
canonical, err := reference.WithDigest(ref, dgst)
if err != nil {
return nil, err
}
return storage.WithMountFrom(canonical), nil
}
// writeBlobCreatedHeaders writes the standard headers describing a newly
// created blob. A 201 Created is written as well as the canonical URL and
// blob digest.
func (buh *blobUploadHandler) writeBlobCreatedHeaders(w http.ResponseWriter, desc distribution.Descriptor) error {
ref, err := reference.WithDigest(buh.Repository.Named(), desc.Digest)
if err != nil {
return err
}
blobURL, err := buh.urlBuilder.BuildBlobURL(ref)
if err != nil {
return err
}
w.Header().Set("Location", blobURL)
w.Header().Set("Content-Length", "0")
w.Header().Set("Docker-Content-Digest", desc.Digest.String())
w.WriteHeader(http.StatusCreated)
return nil
}
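Seen from the client side, the write handlers above implement the three-step v2 push flow: POST opens an upload session, PATCH streams data, PUT with ?digest= commits. A rough sketch of that flow (pushBlob is illustrative only; it assumes bytes and net/http imports and an absolute Location header, and it ignores errors for brevity):
func pushBlob(registry, name string, blob []byte, dgst digest.Digest) {
    // 1. Open an upload session (handled by StartBlobUpload, 202 Accepted).
    resp, _ := http.Post(registry+"/v2/"+name+"/blobs/uploads/", "", nil)
    location := resp.Header.Get("Location")

    // 2. Stream the blob body (handled by PatchBlobData, 202 Accepted).
    req, _ := http.NewRequest("PATCH", location, bytes.NewReader(blob))
    req.Header.Set("Content-Type", "application/octet-stream")
    resp, _ = http.DefaultClient.Do(req)
    location = resp.Header.Get("Location") // already carries the ?_state=... token

    // 3. Commit with the digest (handled by PutBlobUploadComplete, 201 Created).
    req, _ = http.NewRequest("PUT", location+"&digest="+dgst.String(), nil)
    http.DefaultClient.Do(req)
}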

View File

@ -0,0 +1,98 @@
package handlers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/storage/driver"
"github.com/gorilla/handlers"
)
const maximumReturnedEntries = 100
func catalogDispatcher(ctx *Context, r *http.Request) http.Handler {
catalogHandler := &catalogHandler{
Context: ctx,
}
return handlers.MethodHandler{
"GET": http.HandlerFunc(catalogHandler.GetCatalog),
}
}
type catalogHandler struct {
*Context
}
type catalogAPIResponse struct {
Repositories []string `json:"repositories"`
}
func (ch *catalogHandler) GetCatalog(w http.ResponseWriter, r *http.Request) {
var moreEntries = true
q := r.URL.Query()
lastEntry := q.Get("last")
maxEntries, err := strconv.Atoi(q.Get("n"))
if err != nil || maxEntries < 0 {
maxEntries = maximumReturnedEntries
}
repos := make([]string, maxEntries)
filled, err := ch.App.registry.Repositories(ch.Context, repos, lastEntry)
_, pathNotFound := err.(driver.PathNotFoundError)
if err == io.EOF || pathNotFound {
moreEntries = false
} else if err != nil {
ch.Errors = append(ch.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
// Add a link header if there are more entries to retrieve
if moreEntries {
lastEntry = repos[len(repos)-1]
urlStr, err := createLinkEntry(r.URL.String(), maxEntries, lastEntry)
if err != nil {
ch.Errors = append(ch.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
w.Header().Set("Link", urlStr)
}
enc := json.NewEncoder(w)
if err := enc.Encode(catalogAPIResponse{
Repositories: repos[0:filled],
}); err != nil {
ch.Errors = append(ch.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
}
// Use the original URL from the request to create a new URL for
// the link header
func createLinkEntry(origURL string, maxEntries int, lastEntry string) (string, error) {
calledURL, err := url.Parse(origURL)
if err != nil {
return "", err
}
v := url.Values{}
v.Add("n", strconv.Itoa(maxEntries))
v.Add("last", lastEntry)
calledURL.RawQuery = v.Encode()
calledURL.Fragment = ""
urlStr := fmt.Sprintf("<%s>; rel=\"next\"", calledURL.String())
return urlStr, nil
}
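A concrete example of the header createLinkEntry produces (values are illustrative; note that url.Values.Encode sorts the query keys):
// Illustrative only; would live in a _test.go file in this package.
func ExampleCreateLinkEntry() {
    link, _ := createLinkEntry("/v2/_catalog?n=100", 100, "dev/app")
    fmt.Println(link)
    // Output: </v2/_catalog?last=dev%2Fapp&n=100>; rel="next"
}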

View File

@ -0,0 +1,152 @@
package handlers
import (
"fmt"
"net/http"
"sync"
"github.com/docker/distribution"
ctxu "github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/registry/auth"
"golang.org/x/net/context"
)
// Context should contain the request specific context for use across
// handlers. Resources that don't need to be shared across handlers should not
// be on this object.
type Context struct {
// App points to the application structure that created this context.
*App
context.Context
// Repository is the repository for the current request. All requests
// should be scoped to a single repository. This field may be nil.
Repository distribution.Repository
// Errors is a collection of errors encountered during the request to be
// returned to the client API. If errors are added to the collection, the
// handler *must not* start the response via http.ResponseWriter.
Errors errcode.Errors
urlBuilder *v2.URLBuilder
// TODO(stevvooe): The goal is to completely factor this context and
// dispatching out of the web application. Ideally, we should lean on
// context.Context for injection of these resources.
}
// Value overrides context.Context.Value to ensure that calls are routed to
// the correct context.
func (ctx *Context) Value(key interface{}) interface{} {
return ctx.Context.Value(key)
}
func getName(ctx context.Context) (name string) {
return ctxu.GetStringValue(ctx, "vars.name")
}
func getReference(ctx context.Context) (reference string) {
return ctxu.GetStringValue(ctx, "vars.reference")
}
var errDigestNotAvailable = fmt.Errorf("digest not available in context")
func getDigest(ctx context.Context) (dgst digest.Digest, err error) {
dgstStr := ctxu.GetStringValue(ctx, "vars.digest")
if dgstStr == "" {
ctxu.GetLogger(ctx).Errorf("digest not available")
return "", errDigestNotAvailable
}
d, err := digest.ParseDigest(dgstStr)
if err != nil {
ctxu.GetLogger(ctx).Errorf("error parsing digest=%q: %v", dgstStr, err)
return "", err
}
return d, nil
}
func getUploadUUID(ctx context.Context) (uuid string) {
return ctxu.GetStringValue(ctx, "vars.uuid")
}
// getUserName attempts to resolve a username from the context and request. If
// a username cannot be resolved, the empty string is returned.
func getUserName(ctx context.Context, r *http.Request) string {
username := ctxu.GetStringValue(ctx, auth.UserNameKey)
// Fallback to request user with basic auth
if username == "" {
var ok bool
uname, _, ok := basicAuth(r)
if ok {
username = uname
}
}
return username
}
// contextManager allows us to associate net/context.Context instances with a
// request, based on the memory identity of http.Request. This prepares http-
// level context, which is not application specific. If this is called,
// (*contextManager).release must be called on the context when the request is
// completed.
//
// Providing this circumvents a lot of necessity for dispatchers with the
// benefit of instantiating the request context much earlier.
//
// TODO(stevvooe): Consider making this facility a part of the context package.
type contextManager struct {
contexts map[*http.Request]context.Context
mu sync.Mutex
}
// defaultContextManager is just a global instance to register request contexts.
var defaultContextManager = newContextManager()
func newContextManager() *contextManager {
return &contextManager{
contexts: make(map[*http.Request]context.Context),
}
}
// context either returns a new context or looks it up in the manager.
func (cm *contextManager) context(parent context.Context, w http.ResponseWriter, r *http.Request) context.Context {
cm.mu.Lock()
defer cm.mu.Unlock()
ctx, ok := cm.contexts[r]
if ok {
return ctx
}
if parent == nil {
parent = ctxu.Background()
}
ctx = ctxu.WithRequest(parent, r)
ctx, w = ctxu.WithResponseWriter(ctx, w)
ctx = ctxu.WithLogger(ctx, ctxu.GetRequestLogger(ctx))
cm.contexts[r] = ctx
return ctx
}
// release frees any resources associated with the request.
func (cm *contextManager) release(ctx context.Context) {
cm.mu.Lock()
defer cm.mu.Unlock()
r, err := ctxu.GetRequest(ctx)
if err != nil {
ctxu.GetLogger(ctx).Errorf("no request found in context during release")
return
}
delete(cm.contexts, r)
}
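The comment above stresses that release must be paired with every context call; a minimal sketch of that pairing, using a hypothetical wrapper in this package:
// withRequestContext is illustrative only: acquire the request context, run fn, then release.
func withRequestContext(app *App, w http.ResponseWriter, r *http.Request, fn func(ctx context.Context)) {
    ctx := defaultContextManager.context(app.Context, w, r)
    defer defaultContextManager.release(ctx)
    fn(ctx)
}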

View File

@ -0,0 +1,210 @@
package handlers
import (
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/context"
"github.com/docker/distribution/health"
)
func TestFileHealthCheck(t *testing.T) {
interval := time.Second
tmpfile, err := ioutil.TempFile(os.TempDir(), "healthcheck")
if err != nil {
t.Fatalf("could not create temporary file: %v", err)
}
defer tmpfile.Close()
config := &configuration.Configuration{
Storage: configuration.Storage{
"inmemory": configuration.Parameters{},
"maintenance": configuration.Parameters{"uploadpurging": map[interface{}]interface{}{
"enabled": false,
}},
},
Health: configuration.Health{
FileCheckers: []configuration.FileChecker{
{
Interval: interval,
File: tmpfile.Name(),
},
},
},
}
ctx := context.Background()
app := NewApp(ctx, config)
healthRegistry := health.NewRegistry()
app.RegisterHealthChecks(healthRegistry)
// Wait for health check to happen
<-time.After(2 * interval)
status := healthRegistry.CheckStatus()
if len(status) != 1 {
t.Fatal("expected 1 item in health check results")
}
if status[tmpfile.Name()] != "file exists" {
t.Fatal(`did not get "file exists" result for health check`)
}
os.Remove(tmpfile.Name())
<-time.After(2 * interval)
if len(healthRegistry.CheckStatus()) != 0 {
t.Fatal("expected 0 items in health check results")
}
}
func TestTCPHealthCheck(t *testing.T) {
interval := time.Second
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("could not create listener: %v", err)
}
addrStr := ln.Addr().String()
// Start accepting
go func() {
for {
conn, err := ln.Accept()
if err != nil {
// listener was closed
return
}
defer conn.Close()
}
}()
config := &configuration.Configuration{
Storage: configuration.Storage{
"inmemory": configuration.Parameters{},
"maintenance": configuration.Parameters{"uploadpurging": map[interface{}]interface{}{
"enabled": false,
}},
},
Health: configuration.Health{
TCPCheckers: []configuration.TCPChecker{
{
Interval: interval,
Addr: addrStr,
Timeout: 500 * time.Millisecond,
},
},
},
}
ctx := context.Background()
app := NewApp(ctx, config)
healthRegistry := health.NewRegistry()
app.RegisterHealthChecks(healthRegistry)
// Wait for health check to happen
<-time.After(2 * interval)
if len(healthRegistry.CheckStatus()) != 0 {
t.Fatal("expected 0 items in health check results")
}
ln.Close()
<-time.After(2 * interval)
// Health check should now fail
status := healthRegistry.CheckStatus()
if len(status) != 1 {
t.Fatal("expected 1 item in health check results")
}
if status[addrStr] != "connection to "+addrStr+" failed" {
t.Fatal(`did not get "connection failed" result for health check`)
}
}
func TestHTTPHealthCheck(t *testing.T) {
interval := time.Second
threshold := 3
stopFailing := make(chan struct{})
checkedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "HEAD" {
t.Fatalf("expected HEAD request, got %s", r.Method)
}
select {
case <-stopFailing:
w.WriteHeader(http.StatusOK)
default:
w.WriteHeader(http.StatusInternalServerError)
}
}))
config := &configuration.Configuration{
Storage: configuration.Storage{
"inmemory": configuration.Parameters{},
"maintenance": configuration.Parameters{"uploadpurging": map[interface{}]interface{}{
"enabled": false,
}},
},
Health: configuration.Health{
HTTPCheckers: []configuration.HTTPChecker{
{
Interval: interval,
URI: checkedServer.URL,
Threshold: threshold,
},
},
},
}
ctx := context.Background()
app := NewApp(ctx, config)
healthRegistry := health.NewRegistry()
app.RegisterHealthChecks(healthRegistry)
for i := 0; ; i++ {
<-time.After(interval)
status := healthRegistry.CheckStatus()
if i < threshold-1 {
// definitely shouldn't have hit the threshold yet
if len(status) != 0 {
t.Fatal("expected 1 item in health check results")
}
continue
}
if i < threshold+1 {
// right on the threshold - don't expect a failure yet
continue
}
if len(status) != 1 {
t.Fatal("expected 1 item in health check results")
}
if status[checkedServer.URL] != "downstream service returned unexpected status: 500" {
t.Fatal("did not get expected result for health check")
}
break
}
// Signal HTTP handler to start returning 200
close(stopFailing)
<-time.After(2 * interval)
if len(healthRegistry.CheckStatus()) != 0 {
t.Fatal("expected 0 items in health check results")
}
}

View File

@ -0,0 +1,71 @@
package handlers
import (
"errors"
"io"
"net/http"
ctxu "github.com/docker/distribution/context"
)
// closeResources closes all the provided resources after running the target
// handler.
func closeResources(handler http.Handler, closers ...io.Closer) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, closer := range closers {
defer closer.Close()
}
handler.ServeHTTP(w, r)
})
}
// copyFullPayload copies the payload of an HTTP request to destWriter. If it
// receives less content than expected, and the client disconnected during the
// upload, it avoids sending a 400 error to keep the logs cleaner.
//
// The copy will be limited to `limit` bytes, if limit is greater than zero.
func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, limit int64, context ctxu.Context, action string) error {
// Get a channel that tells us if the client disconnects
var clientClosed <-chan bool
if notifier, ok := responseWriter.(http.CloseNotifier); ok {
clientClosed = notifier.CloseNotify()
} else {
ctxu.GetLogger(context).Warnf("the ResponseWriter does not implement CloseNotifier (type: %T)", responseWriter)
}
var body = r.Body
if limit > 0 {
body = http.MaxBytesReader(responseWriter, body, limit)
}
// Read in the data, if any.
copied, err := io.Copy(destWriter, body)
if clientClosed != nil && (err != nil || (r.ContentLength > 0 && copied < r.ContentLength)) {
// Didn't receive as much content as expected. Did the client
// disconnect during the request? If so, avoid returning a 400
// error to keep the logs cleaner.
select {
case <-clientClosed:
// Set the response code to "499 Client Closed Request"
// Even though the connection has already been closed,
// this causes the logger to pick up a 499 error
// instead of showing 0 for the HTTP status.
responseWriter.WriteHeader(499)
ctxu.GetLoggerWithFields(context, map[interface{}]interface{}{
"error": err,
"copied": copied,
"contentLength": r.ContentLength,
}, "error", "copied", "contentLength").Error("client disconnected during " + action)
return errors.New("client disconnected")
default:
}
}
if err != nil {
ctxu.GetLogger(context).Errorf("unknown error reading request payload: %v", err)
return err
}
return nil
}
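copyFullPayload is what the blob and manifest handlers in this diff lean on to read request bodies; a compact, hypothetical use with a size cap (assumes a bytes import; the *Context parameter stands in for the per-request handler context):
// putExample reads at most 1 MiB of payload into a buffer before decoding it.
func putExample(w http.ResponseWriter, r *http.Request, h *Context) {
    var buf bytes.Buffer
    if err := copyFullPayload(w, r, &buf, 1<<20, h, "example PUT"); err != nil {
        // copyFullPayload has already logged the failure and, on client
        // disconnect, written a 499 status for the access log.
        return
    }
    // ... decode buf.Bytes() ...
}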

View File

@ -0,0 +1,74 @@
package handlers
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"time"
)
// blobUploadState captures the serializable state of the blob upload.
type blobUploadState struct {
// Name is the primary repository under which the blob will be linked.
Name string
// UUID identifies the upload.
UUID string
// Offset contains the current progress of the upload.
Offset int64
// StartedAt is the original start time of the upload.
StartedAt time.Time
}
type hmacKey string
var errInvalidSecret = fmt.Errorf("invalid secret")
// unpackUploadState unpacks and validates the blob upload state from the
// token, using the hmacKey secret.
func (secret hmacKey) unpackUploadState(token string) (blobUploadState, error) {
var state blobUploadState
tokenBytes, err := base64.URLEncoding.DecodeString(token)
if err != nil {
return state, err
}
mac := hmac.New(sha256.New, []byte(secret))
if len(tokenBytes) < mac.Size() {
return state, errInvalidSecret
}
macBytes := tokenBytes[:mac.Size()]
messageBytes := tokenBytes[mac.Size():]
mac.Write(messageBytes)
if !hmac.Equal(mac.Sum(nil), macBytes) {
return state, errInvalidSecret
}
if err := json.Unmarshal(messageBytes, &state); err != nil {
return state, err
}
return state, nil
}
// packUploadState packs the upload state, signed with an hmac digest, using
// the hmacKey secret, encoding to url safe base64. The resulting token can be
// used to share data with minimized risk of external tampering.
func (secret hmacKey) packUploadState(lus blobUploadState) (string, error) {
mac := hmac.New(sha256.New, []byte(secret))
p, err := json.Marshal(lus)
if err != nil {
return "", err
}
mac.Write(p)
return base64.URLEncoding.EncodeToString(append(mac.Sum(nil), p...)), nil
}

View File

@ -0,0 +1,117 @@
package handlers
import "testing"
var blobUploadStates = []blobUploadState{
{
Name: "hello",
UUID: "abcd-1234-qwer-0987",
Offset: 0,
},
{
Name: "hello-world",
UUID: "abcd-1234-qwer-0987",
Offset: 0,
},
{
Name: "h3ll0_w0rld",
UUID: "abcd-1234-qwer-0987",
Offset: 1337,
},
{
Name: "ABCDEFG",
UUID: "ABCD-1234-QWER-0987",
Offset: 1234567890,
},
{
Name: "this-is-A-sort-of-Long-name-for-Testing",
UUID: "dead-1234-beef-0987",
Offset: 8675309,
},
}
var secrets = []string{
"supersecret",
"12345",
"a",
"SuperSecret",
"Sup3r... S3cr3t!",
"This is a reasonably long secret key that is used for the purpose of testing.",
"\u2603+\u2744", // snowman+snowflake
}
// TestLayerUploadTokens constructs stateTokens from LayerUploadStates and
// validates that the tokens can be used to reconstruct the proper upload state.
func TestLayerUploadTokens(t *testing.T) {
secret := hmacKey("supersecret")
for _, testcase := range blobUploadStates {
token, err := secret.packUploadState(testcase)
if err != nil {
t.Fatal(err)
}
lus, err := secret.unpackUploadState(token)
if err != nil {
t.Fatal(err)
}
assertBlobUploadStateEquals(t, testcase, lus)
}
}
// TestHMACValidation ensures that any HMAC token providers are compatible if and
// only if they share the same secret.
func TestHMACValidation(t *testing.T) {
for _, secret := range secrets {
secret1 := hmacKey(secret)
secret2 := hmacKey(secret)
badSecret := hmacKey("DifferentSecret")
for _, testcase := range blobUploadStates {
token, err := secret1.packUploadState(testcase)
if err != nil {
t.Fatal(err)
}
lus, err := secret2.unpackUploadState(token)
if err != nil {
t.Fatal(err)
}
assertBlobUploadStateEquals(t, testcase, lus)
_, err = badSecret.unpackUploadState(token)
if err == nil {
t.Fatalf("Expected token provider to fail at retrieving state from token: %s", token)
}
badToken, err := badSecret.packUploadState(lus)
if err != nil {
t.Fatal(err)
}
_, err = secret1.unpackUploadState(badToken)
if err == nil {
t.Fatalf("Expected token provider to fail at retrieving state from token: %s", badToken)
}
_, err = secret2.unpackUploadState(badToken)
if err == nil {
t.Fatalf("Expected token provider to fail at retrieving state from token: %s", badToken)
}
}
}
}
func assertBlobUploadStateEquals(t *testing.T, expected blobUploadState, received blobUploadState) {
if expected.Name != received.Name {
t.Fatalf("Expected Name=%q, Received Name=%q", expected.Name, received.Name)
}
if expected.UUID != received.UUID {
t.Fatalf("Expected UUID=%q, Received UUID=%q", expected.UUID, received.UUID)
}
if expected.Offset != received.Offset {
t.Fatalf("Expected Offset=%d, Received Offset=%d", expected.Offset, received.Offset)
}
}

View File

@ -0,0 +1,53 @@
package handlers
import (
"bytes"
"errors"
"fmt"
"strings"
"text/template"
"github.com/Sirupsen/logrus"
)
// logHook is a logrus hook that mails selected log entries (such as panics) from the web application
type logHook struct {
LevelsParam []string
Mail *mailer
}
// Fire mails the log entry using the hook's configured mailer
func (hook *logHook) Fire(entry *logrus.Entry) error {
addr := strings.Split(hook.Mail.Addr, ":")
if len(addr) != 2 {
return errors.New("Invalid Mail Address")
}
host := addr[0]
subject := fmt.Sprintf("[%s] %s: %s", entry.Level, host, entry.Message)
html := `
{{.Message}}
{{range $key, $value := .Data}}
{{$key}}: {{$value}}
{{end}}
`
b := bytes.NewBuffer(make([]byte, 0))
t := template.Must(template.New("mail body").Parse(html))
if err := t.Execute(b, entry); err != nil {
return err
}
body := b.String()
return hook.Mail.sendMail(subject, body)
}
// Levels returns the log levels to be caught by this hook
func (hook *logHook) Levels() []logrus.Level {
levels := []logrus.Level{}
for _, v := range hook.LevelsParam {
lv, _ := logrus.ParseLevel(v)
levels = append(levels, lv)
}
return levels
}
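A sketch of how this hook might be wired into logrus (illustrative only; the mailer's fields other than Addr are not shown in this hunk, so the literal below is deliberately partial):
func registerMailHook() {
    hook := &logHook{
        LevelsParam: []string{"panic", "error"},
        Mail:        &mailer{Addr: "smtp.example.com:25"}, // remaining mailer fields omitted
    }
    logrus.AddHook(hook)
}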

View File

@ -0,0 +1,462 @@
package handlers
import (
"bytes"
"fmt"
"net/http"
"strings"
"github.com/docker/distribution"
ctxu "github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest/manifestlist"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/distribution/registry/auth"
"github.com/gorilla/handlers"
)
// These constants determine which architecture and OS to choose from a
// manifest list when downconverting it to a schema1 manifest.
const (
defaultArch = "amd64"
defaultOS = "linux"
maxManifestBodySize = 4 << 20
)
// imageManifestDispatcher takes the request context and builds the
// appropriate handler for handling image manifest requests.
func imageManifestDispatcher(ctx *Context, r *http.Request) http.Handler {
imageManifestHandler := &imageManifestHandler{
Context: ctx,
}
reference := getReference(ctx)
dgst, err := digest.ParseDigest(reference)
if err != nil {
// We just have a tag
imageManifestHandler.Tag = reference
} else {
imageManifestHandler.Digest = dgst
}
mhandler := handlers.MethodHandler{
"GET": http.HandlerFunc(imageManifestHandler.GetImageManifest),
"HEAD": http.HandlerFunc(imageManifestHandler.GetImageManifest),
}
if !ctx.readOnly {
mhandler["PUT"] = http.HandlerFunc(imageManifestHandler.PutImageManifest)
mhandler["DELETE"] = http.HandlerFunc(imageManifestHandler.DeleteImageManifest)
}
return mhandler
}
// imageManifestHandler handles http operations on image manifests.
type imageManifestHandler struct {
*Context
// One of tag or digest gets set, depending on what is present in context.
Tag string
Digest digest.Digest
}
// GetImageManifest fetches the image manifest from the storage backend, if it exists.
func (imh *imageManifestHandler) GetImageManifest(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(imh).Debug("GetImageManifest")
manifests, err := imh.Repository.Manifests(imh)
if err != nil {
imh.Errors = append(imh.Errors, err)
return
}
var manifest distribution.Manifest
if imh.Tag != "" {
tags := imh.Repository.Tags(imh)
desc, err := tags.Get(imh, imh.Tag)
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithDetail(err))
return
}
imh.Digest = desc.Digest
}
if etagMatch(r, imh.Digest.String()) {
w.WriteHeader(http.StatusNotModified)
return
}
var options []distribution.ManifestServiceOption
if imh.Tag != "" {
options = append(options, distribution.WithTag(imh.Tag))
}
manifest, err = manifests.Get(imh, imh.Digest, options...)
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithDetail(err))
return
}
supportsSchema2 := false
supportsManifestList := false
// this parsing of Accept headers is not quite as full-featured as godoc.org's parser, but we don't care about "q=" values
// https://github.com/golang/gddo/blob/e91d4165076d7474d20abda83f92d15c7ebc3e81/httputil/header/header.go#L165-L202
for _, acceptHeader := range r.Header["Accept"] {
// r.Header[...] is a slice in case the request contains the same header more than once
// if the header isn't set, we'll get the zero value, which "range" will handle gracefully
// we need to split each header value on "," to get the full list of "Accept" values (per RFC 2616)
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.1
for _, mediaType := range strings.Split(acceptHeader, ",") {
// remove "; q=..." if present
if i := strings.Index(mediaType, ";"); i >= 0 {
mediaType = mediaType[:i]
}
// it's common (but not required) for Accept values to be space separated ("a/b, c/d, e/f")
mediaType = strings.TrimSpace(mediaType)
if mediaType == schema2.MediaTypeManifest {
supportsSchema2 = true
}
if mediaType == manifestlist.MediaTypeManifestList {
supportsManifestList = true
}
}
}
schema2Manifest, isSchema2 := manifest.(*schema2.DeserializedManifest)
manifestList, isManifestList := manifest.(*manifestlist.DeserializedManifestList)
// Only rewrite schema2 manifests when they are being fetched by tag.
// If they are being fetched by digest, we can't return something not
// matching the digest.
if imh.Tag != "" && isSchema2 && !supportsSchema2 {
// Rewrite manifest in schema1 format
ctxu.GetLogger(imh).Infof("rewriting manifest %s in schema1 format to support old client", imh.Digest.String())
manifest, err = imh.convertSchema2Manifest(schema2Manifest)
if err != nil {
return
}
} else if imh.Tag != "" && isManifestList && !supportsManifestList {
// Rewrite manifest in schema1 format
ctxu.GetLogger(imh).Infof("rewriting manifest list %s in schema1 format to support old client", imh.Digest.String())
// Find the image manifest corresponding to the default
// platform
var manifestDigest digest.Digest
for _, manifestDescriptor := range manifestList.Manifests {
if manifestDescriptor.Platform.Architecture == defaultArch && manifestDescriptor.Platform.OS == defaultOS {
manifestDigest = manifestDescriptor.Digest
break
}
}
if manifestDigest == "" {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown)
return
}
manifest, err = manifests.Get(imh, manifestDigest)
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithDetail(err))
return
}
// If necessary, convert the image manifest
if schema2Manifest, isSchema2 := manifest.(*schema2.DeserializedManifest); isSchema2 && !supportsSchema2 {
manifest, err = imh.convertSchema2Manifest(schema2Manifest)
if err != nil {
return
}
}
}
ct, p, err := manifest.Payload()
if err != nil {
return
}
w.Header().Set("Content-Type", ct)
w.Header().Set("Content-Length", fmt.Sprint(len(p)))
w.Header().Set("Docker-Content-Digest", imh.Digest.String())
w.Header().Set("Etag", fmt.Sprintf(`"%s"`, imh.Digest))
w.Write(p)
}
func (imh *imageManifestHandler) convertSchema2Manifest(schema2Manifest *schema2.DeserializedManifest) (distribution.Manifest, error) {
targetDescriptor := schema2Manifest.Target()
blobs := imh.Repository.Blobs(imh)
configJSON, err := blobs.Get(imh, targetDescriptor.Digest)
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err))
return nil, err
}
ref := imh.Repository.Named()
if imh.Tag != "" {
ref, err = reference.WithTag(ref, imh.Tag)
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeTagInvalid.WithDetail(err))
return nil, err
}
}
builder := schema1.NewConfigManifestBuilder(imh.Repository.Blobs(imh), imh.Context.App.trustKey, ref, configJSON)
for _, d := range schema2Manifest.Layers {
if err := builder.AppendReference(d); err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err))
return nil, err
}
}
manifest, err := builder.Build(imh)
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err))
return nil, err
}
imh.Digest = digest.FromBytes(manifest.(*schema1.SignedManifest).Canonical)
return manifest, nil
}
func etagMatch(r *http.Request, etag string) bool {
for _, headerVal := range r.Header["If-None-Match"] {
if headerVal == etag || headerVal == fmt.Sprintf(`"%s"`, etag) { // allow quoted or unquoted
return true
}
}
return false
}
// PutImageManifest validates and stores an image in the registry.
func (imh *imageManifestHandler) PutImageManifest(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(imh).Debug("PutImageManifest")
manifests, err := imh.Repository.Manifests(imh)
if err != nil {
imh.Errors = append(imh.Errors, err)
return
}
var jsonBuf bytes.Buffer
if err := copyFullPayload(w, r, &jsonBuf, maxManifestBodySize, imh, "image manifest PUT"); err != nil {
// copyFullPayload reports the error if necessary
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err.Error()))
return
}
mediaType := r.Header.Get("Content-Type")
manifest, desc, err := distribution.UnmarshalManifest(mediaType, jsonBuf.Bytes())
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestInvalid.WithDetail(err))
return
}
if imh.Digest != "" {
if desc.Digest != imh.Digest {
ctxu.GetLogger(imh).Errorf("payload digest does match: %q != %q", desc.Digest, imh.Digest)
imh.Errors = append(imh.Errors, v2.ErrorCodeDigestInvalid)
return
}
} else if imh.Tag != "" {
imh.Digest = desc.Digest
} else {
imh.Errors = append(imh.Errors, v2.ErrorCodeTagInvalid.WithDetail("no tag or digest specified"))
return
}
var options []distribution.ManifestServiceOption
if imh.Tag != "" {
options = append(options, distribution.WithTag(imh.Tag))
}
if err := imh.applyResourcePolicy(manifest); err != nil {
imh.Errors = append(imh.Errors, err)
return
}
_, err = manifests.Put(imh, manifest, options...)
if err != nil {
// TODO(stevvooe): These error handling switches really need to be
// handled by an app global mapper.
if err == distribution.ErrUnsupported {
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnsupported)
return
}
if err == distribution.ErrAccessDenied {
imh.Errors = append(imh.Errors, errcode.ErrorCodeDenied)
return
}
switch err := err.(type) {
case distribution.ErrManifestVerification:
for _, verificationError := range err {
switch verificationError := verificationError.(type) {
case distribution.ErrManifestBlobUnknown:
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestBlobUnknown.WithDetail(verificationError.Digest))
case distribution.ErrManifestNameInvalid:
imh.Errors = append(imh.Errors, v2.ErrorCodeNameInvalid.WithDetail(err))
case distribution.ErrManifestUnverified:
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnverified)
default:
if verificationError == digest.ErrDigestInvalidFormat {
imh.Errors = append(imh.Errors, v2.ErrorCodeDigestInvalid)
} else {
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown, verificationError)
}
}
}
case errcode.Error:
imh.Errors = append(imh.Errors, err)
default:
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
return
}
// Tag this manifest
if imh.Tag != "" {
tags := imh.Repository.Tags(imh)
err = tags.Tag(imh, imh.Tag, desc)
if err != nil {
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
}
// Construct a canonical url for the uploaded manifest.
ref, err := reference.WithDigest(imh.Repository.Named(), imh.Digest)
if err != nil {
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
location, err := imh.urlBuilder.BuildManifestURL(ref)
if err != nil {
// NOTE(stevvooe): Given the behavior above, this is absurdly unlikely to
// happen. We'll log the error here but proceed as if it worked. Worst
// case, we set an empty location header.
ctxu.GetLogger(imh).Errorf("error building manifest url from digest: %v", err)
}
w.Header().Set("Location", location)
w.Header().Set("Docker-Content-Digest", imh.Digest.String())
w.WriteHeader(http.StatusCreated)
}
// applyResourcePolicy checks whether the resource class matches what has
// been authorized and allowed by the policy configuration.
func (imh *imageManifestHandler) applyResourcePolicy(manifest distribution.Manifest) error {
allowedClasses := imh.App.Config.Policy.Repository.Classes
if len(allowedClasses) == 0 {
return nil
}
var class string
switch m := manifest.(type) {
case *schema1.SignedManifest:
class = "image"
case *schema2.DeserializedManifest:
switch m.Config.MediaType {
case schema2.MediaTypeConfig:
class = "image"
case schema2.MediaTypePluginConfig:
class = "plugin"
default:
message := fmt.Sprintf("unknown manifest class for %s", m.Config.MediaType)
return errcode.ErrorCodeDenied.WithMessage(message)
}
}
if class == "" {
return nil
}
// Check to see if class is allowed in registry
var allowedClass bool
for _, c := range allowedClasses {
if class == c {
allowedClass = true
break
}
}
if !allowedClass {
message := fmt.Sprintf("registry does not allow %s manifest", class)
return errcode.ErrorCodeDenied.WithMessage(message)
}
resources := auth.AuthorizedResources(imh)
n := imh.Repository.Named().Name()
var foundResource bool
for _, r := range resources {
if r.Name == n {
if r.Class == "" {
r.Class = "image"
}
if r.Class == class {
return nil
}
foundResource = true
}
}
// resource was found but no matching class was found
if foundResource {
message := fmt.Sprintf("repository not authorized for %s manifest", class)
return errcode.ErrorCodeDenied.WithMessage(message)
}
return nil
}
// DeleteImageManifest removes the manifest with the given digest from the registry.
func (imh *imageManifestHandler) DeleteImageManifest(w http.ResponseWriter, r *http.Request) {
ctxu.GetLogger(imh).Debug("DeleteImageManifest")
manifests, err := imh.Repository.Manifests(imh)
if err != nil {
imh.Errors = append(imh.Errors, err)
return
}
err = manifests.Delete(imh, imh.Digest)
if err != nil {
switch err {
case digest.ErrDigestUnsupported, digest.ErrDigestInvalidFormat:
imh.Errors = append(imh.Errors, v2.ErrorCodeDigestInvalid)
return
case distribution.ErrBlobUnknown:
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown)
return
case distribution.ErrUnsupported:
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnsupported)
return
default:
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown)
return
}
}
tagService := imh.Repository.Tags(imh)
referencedTags, err := tagService.Lookup(imh, distribution.Descriptor{Digest: imh.Digest})
if err != nil {
imh.Errors = append(imh.Errors, err)
return
}
for _, tag := range referencedTags {
if err := tagService.Untag(imh, tag); err != nil {
imh.Errors = append(imh.Errors, err)
return
}
}
w.WriteHeader(http.StatusAccepted)
}

View File

@ -0,0 +1,45 @@
package handlers
import (
"errors"
"net/smtp"
"strings"
)
// mailer provides fields of email configuration for sending.
type mailer struct {
Addr, Username, Password, From string
Insecure bool
To []string
}
// sendMail allows users to send email, but only if the mail parameters are configured correctly.
func (mail *mailer) sendMail(subject, message string) error {
addr := strings.Split(mail.Addr, ":")
if len(addr) != 2 {
return errors.New("Invalid Mail Address")
}
host := addr[0]
msg := []byte("To:" + strings.Join(mail.To, ";") +
"\r\nFrom: " + mail.From +
"\r\nSubject: " + subject +
"\r\nContent-Type: text/plain\r\n\r\n" +
message)
auth := smtp.PlainAuth(
"",
mail.Username,
mail.Password,
host,
)
err := smtp.SendMail(
mail.Addr,
auth,
mail.From,
mail.To,
msg,
)
if err != nil {
return err
}
return nil
}
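// exampleSendMail is an illustrative usage sketch, not part of the original
// change; the SMTP host, credentials and addresses below are assumptions.
func exampleSendMail() error {
m := &mailer{
Addr:     "smtp.example.com:25",
Username: "alerts",
Password: "secret",
From:     "registry@example.com",
To:       []string{"ops@example.com"},
}
return m.sendMail("registry alert", "manifest push failed")
}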

View File

@ -0,0 +1,62 @@
package handlers
import (
"encoding/json"
"net/http"
"github.com/docker/distribution"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/api/v2"
"github.com/gorilla/handlers"
)
// tagsDispatcher constructs the tags handler api endpoint.
func tagsDispatcher(ctx *Context, r *http.Request) http.Handler {
tagsHandler := &tagsHandler{
Context: ctx,
}
return handlers.MethodHandler{
"GET": http.HandlerFunc(tagsHandler.GetTags),
}
}
// tagsHandler handles requests for lists of tags under a repository name.
type tagsHandler struct {
*Context
}
type tagsAPIResponse struct {
Name string `json:"name"`
Tags []string `json:"tags"`
}
// GetTags returns a json list of tags for a specific image name.
func (th *tagsHandler) GetTags(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
tagService := th.Repository.Tags(th)
tags, err := tagService.All(th)
if err != nil {
switch err := err.(type) {
case distribution.ErrRepositoryUnknown:
th.Errors = append(th.Errors, v2.ErrorCodeNameUnknown.WithDetail(map[string]string{"name": th.Repository.Named().Name()}))
case errcode.Error:
th.Errors = append(th.Errors, err)
default:
th.Errors = append(th.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
enc := json.NewEncoder(w)
if err := enc.Encode(tagsAPIResponse{
Name: th.Repository.Named().Name(),
Tags: tags,
}); err != nil {
th.Errors = append(th.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
}
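// exampleTagsResponseJSON is an illustrative sketch, not part of the original
// change; the repository name and tags are made-up values. It shows the wire
// format GetTags produces.
func exampleTagsResponseJSON() ([]byte, error) {
// Marshals to: {"name":"library/ubuntu","tags":["latest","16.04"]}
return json.Marshal(tagsAPIResponse{
Name: "library/ubuntu",
Tags: []string{"latest", "16.04"},
})
}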

View File

@ -0,0 +1,74 @@
package listener
import (
"fmt"
"net"
"os"
"time"
)
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
// It is a plain copy-paste from net/http/server.go.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// NewListener announces on laddr and net. Accepted values of net are
// 'unix' and 'tcp'.
func NewListener(net, laddr string) (net.Listener, error) {
switch net {
case "unix":
return newUnixListener(laddr)
case "tcp", "": // an empty net means tcp
return newTCPListener(laddr)
default:
return nil, fmt.Errorf("unknown address type %s", net)
}
}
func newUnixListener(laddr string) (net.Listener, error) {
fi, err := os.Stat(laddr)
if err == nil {
// the file exists.
// try to remove it if it's a socket
if !isSocket(fi.Mode()) {
return nil, fmt.Errorf("file %s exists and is not a socket", laddr)
}
if err := os.Remove(laddr); err != nil {
return nil, err
}
} else if !os.IsNotExist(err) {
// we can't do stat on the file.
// it means we cannot remove it
return nil, err
}
return net.Listen("unix", laddr)
}
func isSocket(m os.FileMode) bool {
return m&os.ModeSocket != 0
}
func newTCPListener(laddr string) (net.Listener, error) {
ln, err := net.Listen("tcp", laddr)
if err != nil {
return nil, err
}
return tcpKeepAliveListener{ln.(*net.TCPListener)}, nil
}
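// exampleNewListener is an illustrative usage sketch, not part of the original
// change; the listen address is an assumption.
func exampleNewListener() (net.Listener, error) {
// An empty network string also defaults to "tcp", with keep-alive enabled.
return NewListener("tcp", ":5000")
}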

View File

@ -0,0 +1,54 @@
package middleware
import (
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage"
)
// InitFunc is the type of a RegistryMiddleware factory function and is
// used to register the constructor for different RegistryMiddleware backends.
type InitFunc func(ctx context.Context, registry distribution.Namespace, options map[string]interface{}) (distribution.Namespace, error)
var middlewares map[string]InitFunc
var registryoptions []storage.RegistryOption
// Register is used to register an InitFunc for
// a RegistryMiddleware backend with the given name.
func Register(name string, initFunc InitFunc) error {
if middlewares == nil {
middlewares = make(map[string]InitFunc)
}
if _, exists := middlewares[name]; exists {
return fmt.Errorf("name already registered: %s", name)
}
middlewares[name] = initFunc
return nil
}
// Get constructs a RegistryMiddleware with the given options using the named backend.
func Get(ctx context.Context, name string, options map[string]interface{}, registry distribution.Namespace) (distribution.Namespace, error) {
if middlewares != nil {
if initFunc, exists := middlewares[name]; exists {
return initFunc(ctx, registry, options)
}
}
return nil, fmt.Errorf("no registry middleware registered with name: %s", name)
}
// RegisterOptions adds more options to RegistryOption list. Options get applied before
// any other configuration-based options.
func RegisterOptions(options ...storage.RegistryOption) error {
registryoptions = append(registryoptions, options...)
return nil
}
// GetRegistryOptions returns list of RegistryOption.
func GetRegistryOptions() []storage.RegistryOption {
return registryoptions
}
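// exampleRegisterNoop is an illustrative sketch, not part of the original
// change: the "noop" name and pass-through middleware are assumptions. A real
// backend would typically call Register from its init function.
func exampleRegisterNoop() error {
return Register("noop", func(ctx context.Context, registry distribution.Namespace, options map[string]interface{}) (distribution.Namespace, error) {
// Return the registry unchanged; a real middleware would wrap it.
return registry, nil
})
}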

View File

@ -0,0 +1,40 @@
package middleware
import (
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
)
// InitFunc is the type of a RepositoryMiddleware factory function and is
// used to register the constructor for different RepositoryMiddleware backends.
type InitFunc func(ctx context.Context, repository distribution.Repository, options map[string]interface{}) (distribution.Repository, error)
var middlewares map[string]InitFunc
// Register is used to register an InitFunc for
// a RepositoryMiddleware backend with the given name.
func Register(name string, initFunc InitFunc) error {
if middlewares == nil {
middlewares = make(map[string]InitFunc)
}
if _, exists := middlewares[name]; exists {
return fmt.Errorf("name already registered: %s", name)
}
middlewares[name] = initFunc
return nil
}
// Get constructs a RepositoryMiddleware with the given options using the named backend.
func Get(ctx context.Context, name string, options map[string]interface{}, repository distribution.Repository) (distribution.Repository, error) {
if middlewares != nil {
if initFunc, exists := middlewares[name]; exists {
return initFunc(ctx, repository, options)
}
}
return nil, fmt.Errorf("no repository middleware registered with name: %s", name)
}

View File

@ -0,0 +1,87 @@
package proxy
import (
"net/http"
"net/url"
"strings"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/client/auth"
"github.com/docker/distribution/registry/client/auth/challenge"
)
const challengeHeader = "Docker-Distribution-Api-Version"
type userpass struct {
username string
password string
}
type credentials struct {
creds map[string]userpass
}
func (c credentials) Basic(u *url.URL) (string, string) {
up := c.creds[u.String()]
return up.username, up.password
}
func (c credentials) RefreshToken(u *url.URL, service string) string {
return ""
}
func (c credentials) SetRefreshToken(u *url.URL, service, token string) {
}
// configureAuth stores credentials for challenge responses
func configureAuth(username, password, remoteURL string) (auth.CredentialStore, error) {
creds := map[string]userpass{}
authURLs, err := getAuthURLs(remoteURL)
if err != nil {
return nil, err
}
for _, url := range authURLs {
context.GetLogger(context.Background()).Infof("Discovered token authentication URL: %s", url)
creds[url] = userpass{
username: username,
password: password,
}
}
return credentials{creds: creds}, nil
}
func getAuthURLs(remoteURL string) ([]string, error) {
authURLs := []string{}
resp, err := http.Get(remoteURL + "/v2/")
if err != nil {
return nil, err
}
defer resp.Body.Close()
for _, c := range challenge.ResponseChallenges(resp) {
if strings.EqualFold(c.Scheme, "bearer") {
authURLs = append(authURLs, c.Parameters["realm"])
}
}
return authURLs, nil
}
func ping(manager challenge.Manager, endpoint, versionHeader string) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
}
defer resp.Body.Close()
if err := manager.AddResponse(resp); err != nil {
return err
}
return nil
}
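// exampleConfigureAuth is an illustrative sketch, not part of the original
// change; the remote URL and credentials are assumptions. Note that
// configureAuth issues an HTTP GET against <remoteURL>/v2/ to discover token
// authentication realms.
func exampleConfigureAuth() (auth.CredentialStore, error) {
return configureAuth("puller", "secret", "https://registry.example.com")
}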

View File

@ -0,0 +1,224 @@
package proxy
import (
"io"
"net/http"
"strconv"
"sync"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/proxy/scheduler"
)
// todo(richardscothern): from cache control header or config file
const blobTTL = time.Duration(24 * 7 * time.Hour)
type proxyBlobStore struct {
localStore distribution.BlobStore
remoteStore distribution.BlobService
scheduler *scheduler.TTLExpirationScheduler
repositoryName reference.Named
authChallenger authChallenger
}
var _ distribution.BlobStore = &proxyBlobStore{}
// inflight tracks currently downloading blobs
var inflight = make(map[digest.Digest]struct{})
// mu protects inflight
var mu sync.Mutex
func setResponseHeaders(w http.ResponseWriter, length int64, mediaType string, digest digest.Digest) {
w.Header().Set("Content-Length", strconv.FormatInt(length, 10))
w.Header().Set("Content-Type", mediaType)
w.Header().Set("Docker-Content-Digest", digest.String())
w.Header().Set("Etag", digest.String())
}
func (pbs *proxyBlobStore) copyContent(ctx context.Context, dgst digest.Digest, writer io.Writer) (distribution.Descriptor, error) {
desc, err := pbs.remoteStore.Stat(ctx, dgst)
if err != nil {
return distribution.Descriptor{}, err
}
if w, ok := writer.(http.ResponseWriter); ok {
setResponseHeaders(w, desc.Size, desc.MediaType, dgst)
}
remoteReader, err := pbs.remoteStore.Open(ctx, dgst)
if err != nil {
return distribution.Descriptor{}, err
}
defer remoteReader.Close()
_, err = io.CopyN(writer, remoteReader, desc.Size)
if err != nil {
return distribution.Descriptor{}, err
}
proxyMetrics.BlobPush(uint64(desc.Size))
return desc, nil
}
func (pbs *proxyBlobStore) serveLocal(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) (bool, error) {
localDesc, err := pbs.localStore.Stat(ctx, dgst)
if err != nil {
// Stat can report a zero sized file here if it's checked between creation
// and population. Return nil error, and continue
return false, nil
}
if err == nil {
proxyMetrics.BlobPush(uint64(localDesc.Size))
return true, pbs.localStore.ServeBlob(ctx, w, r, dgst)
}
return false, nil
}
func (pbs *proxyBlobStore) storeLocal(ctx context.Context, dgst digest.Digest) error {
defer func() {
mu.Lock()
delete(inflight, dgst)
mu.Unlock()
}()
var desc distribution.Descriptor
var err error
var bw distribution.BlobWriter
bw, err = pbs.localStore.Create(ctx)
if err != nil {
return err
}
desc, err = pbs.copyContent(ctx, dgst, bw)
if err != nil {
return err
}
_, err = bw.Commit(ctx, desc)
if err != nil {
return err
}
return nil
}
func (pbs *proxyBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
served, err := pbs.serveLocal(ctx, w, r, dgst)
if err != nil {
context.GetLogger(ctx).Errorf("Error serving blob from local storage: %s", err.Error())
return err
}
if served {
return nil
}
if err := pbs.authChallenger.tryEstablishChallenges(ctx); err != nil {
return err
}
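// Deduplicate concurrent pulls of the same digest: if another request is
// already fetching this blob, stream it straight from the remote instead of
// scheduling a second local store.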
mu.Lock()
_, ok := inflight[dgst]
if ok {
mu.Unlock()
_, err := pbs.copyContent(ctx, dgst, w)
return err
}
inflight[dgst] = struct{}{}
mu.Unlock()
go func(dgst digest.Digest) {
if err := pbs.storeLocal(ctx, dgst); err != nil {
context.GetLogger(ctx).Errorf("Error committing to storage: %s", err.Error())
}
blobRef, err := reference.WithDigest(pbs.repositoryName, dgst)
if err != nil {
context.GetLogger(ctx).Errorf("Error creating reference: %s", err)
return
}
pbs.scheduler.AddBlob(blobRef, repositoryTTL)
}(dgst)
_, err = pbs.copyContent(ctx, dgst, w)
if err != nil {
return err
}
return nil
}
func (pbs *proxyBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
desc, err := pbs.localStore.Stat(ctx, dgst)
if err == nil {
return desc, err
}
if err != distribution.ErrBlobUnknown {
return distribution.Descriptor{}, err
}
if err := pbs.authChallenger.tryEstablishChallenges(ctx); err != nil {
return distribution.Descriptor{}, err
}
return pbs.remoteStore.Stat(ctx, dgst)
}
func (pbs *proxyBlobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) {
blob, err := pbs.localStore.Get(ctx, dgst)
if err == nil {
return blob, nil
}
if err := pbs.authChallenger.tryEstablishChallenges(ctx); err != nil {
return []byte{}, err
}
blob, err = pbs.remoteStore.Get(ctx, dgst)
if err != nil {
return []byte{}, err
}
_, err = pbs.localStore.Put(ctx, "", blob)
if err != nil {
return []byte{}, err
}
return blob, nil
}
// Unsupported functions
func (pbs *proxyBlobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) {
return distribution.Descriptor{}, distribution.ErrUnsupported
}
func (pbs *proxyBlobStore) Create(ctx context.Context, options ...distribution.BlobCreateOption) (distribution.BlobWriter, error) {
return nil, distribution.ErrUnsupported
}
func (pbs *proxyBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
return nil, distribution.ErrUnsupported
}
func (pbs *proxyBlobStore) Mount(ctx context.Context, sourceRepo reference.Named, dgst digest.Digest) (distribution.Descriptor, error) {
return distribution.Descriptor{}, distribution.ErrUnsupported
}
func (pbs *proxyBlobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) {
return nil, distribution.ErrUnsupported
}
func (pbs *proxyBlobStore) Delete(ctx context.Context, dgst digest.Digest) error {
return distribution.ErrUnsupported
}

View File

@ -0,0 +1,416 @@
package proxy
import (
"io/ioutil"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/proxy/scheduler"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver/filesystem"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
var sbsMu sync.Mutex
type statsBlobStore struct {
stats map[string]int
blobs distribution.BlobStore
}
func (sbs statsBlobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) {
sbsMu.Lock()
sbs.stats["put"]++
sbsMu.Unlock()
return sbs.blobs.Put(ctx, mediaType, p)
}
func (sbs statsBlobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) {
sbsMu.Lock()
sbs.stats["get"]++
sbsMu.Unlock()
return sbs.blobs.Get(ctx, dgst)
}
func (sbs statsBlobStore) Create(ctx context.Context, options ...distribution.BlobCreateOption) (distribution.BlobWriter, error) {
sbsMu.Lock()
sbs.stats["create"]++
sbsMu.Unlock()
return sbs.blobs.Create(ctx, options...)
}
func (sbs statsBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
sbsMu.Lock()
sbs.stats["resume"]++
sbsMu.Unlock()
return sbs.blobs.Resume(ctx, id)
}
func (sbs statsBlobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) {
sbsMu.Lock()
sbs.stats["open"]++
sbsMu.Unlock()
return sbs.blobs.Open(ctx, dgst)
}
func (sbs statsBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
sbsMu.Lock()
sbs.stats["serveblob"]++
sbsMu.Unlock()
return sbs.blobs.ServeBlob(ctx, w, r, dgst)
}
func (sbs statsBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
sbsMu.Lock()
sbs.stats["stat"]++
sbsMu.Unlock()
return sbs.blobs.Stat(ctx, dgst)
}
func (sbs statsBlobStore) Delete(ctx context.Context, dgst digest.Digest) error {
sbsMu.Lock()
sbs.stats["delete"]++
sbsMu.Unlock()
return sbs.blobs.Delete(ctx, dgst)
}
type testEnv struct {
numUnique int
inRemote []distribution.Descriptor
store proxyBlobStore
ctx context.Context
}
func (te *testEnv) LocalStats() *map[string]int {
sbsMu.Lock()
ls := te.store.localStore.(statsBlobStore).stats
sbsMu.Unlock()
return &ls
}
func (te *testEnv) RemoteStats() *map[string]int {
sbsMu.Lock()
rs := te.store.remoteStore.(statsBlobStore).stats
sbsMu.Unlock()
return &rs
}
// Populate remote store and record the digests
func makeTestEnv(t *testing.T, name string) *testEnv {
nameRef, err := reference.ParseNamed(name)
if err != nil {
t.Fatalf("unable to parse reference: %s", err)
}
ctx := context.Background()
truthDir, err := ioutil.TempDir("", "truth")
if err != nil {
t.Fatalf("unable to create tempdir: %s", err)
}
cacheDir, err := ioutil.TempDir("", "cache")
if err != nil {
t.Fatalf("unable to create tempdir: %s", err)
}
localDriver, err := filesystem.FromParameters(map[string]interface{}{
"rootdirectory": truthDir,
})
if err != nil {
t.Fatalf("unable to create filesystem driver: %s", err)
}
// todo: create a tempfile area here
localRegistry, err := storage.NewRegistry(ctx, localDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
localRepo, err := localRegistry.Repository(ctx, nameRef)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
cacheDriver, err := filesystem.FromParameters(map[string]interface{}{
"rootdirectory": cacheDir,
})
if err != nil {
t.Fatalf("unable to create filesystem driver: %s", err)
}
truthRegistry, err := storage.NewRegistry(ctx, cacheDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()))
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
truthRepo, err := truthRegistry.Repository(ctx, nameRef)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
truthBlobs := statsBlobStore{
stats: make(map[string]int),
blobs: truthRepo.Blobs(ctx),
}
localBlobs := statsBlobStore{
stats: make(map[string]int),
blobs: localRepo.Blobs(ctx),
}
s := scheduler.New(ctx, inmemory.New(), "/scheduler-state.json")
proxyBlobStore := proxyBlobStore{
repositoryName: nameRef,
remoteStore: truthBlobs,
localStore: localBlobs,
scheduler: s,
authChallenger: &mockChallenger{},
}
te := &testEnv{
store: proxyBlobStore,
ctx: ctx,
}
return te
}
func makeBlob(size int) []byte {
blob := make([]byte, size, size)
for i := 0; i < size; i++ {
blob[i] = byte('A' + rand.Int()%48)
}
return blob
}
func init() {
rand.Seed(42)
}
func perm(m []distribution.Descriptor) []distribution.Descriptor {
for i := 0; i < len(m); i++ {
j := rand.Intn(i + 1)
tmp := m[i]
m[i] = m[j]
m[j] = tmp
}
return m
}
func populate(t *testing.T, te *testEnv, blobCount, size, numUnique int) {
var inRemote []distribution.Descriptor
for i := 0; i < numUnique; i++ {
bytes := makeBlob(size)
for j := 0; j < blobCount/numUnique; j++ {
desc, err := te.store.remoteStore.Put(te.ctx, "", bytes)
if err != nil {
t.Fatalf("Put in store")
}
inRemote = append(inRemote, desc)
}
}
te.inRemote = inRemote
te.numUnique = numUnique
}
func TestProxyStoreGet(t *testing.T) {
te := makeTestEnv(t, "foo/bar")
localStats := te.LocalStats()
remoteStats := te.RemoteStats()
populate(t, te, 1, 10, 1)
_, err := te.store.Get(te.ctx, te.inRemote[0].Digest)
if err != nil {
t.Fatal(err)
}
if (*localStats)["get"] != 1 && (*localStats)["put"] != 1 {
t.Errorf("Unexpected local counts")
}
if (*remoteStats)["get"] != 1 {
t.Errorf("Unexpected remote get count")
}
_, err = te.store.Get(te.ctx, te.inRemote[0].Digest)
if err != nil {
t.Fatal(err)
}
if (*localStats)["get"] != 2 && (*localStats)["put"] != 1 {
t.Errorf("Unexpected local counts")
}
if (*remoteStats)["get"] != 1 {
t.Errorf("Unexpected remote get count")
}
}
func TestProxyStoreStat(t *testing.T) {
te := makeTestEnv(t, "foo/bar")
remoteBlobCount := 1
populate(t, te, remoteBlobCount, 10, 1)
localStats := te.LocalStats()
remoteStats := te.RemoteStats()
// Stat - touches both stores
for _, d := range te.inRemote {
_, err := te.store.Stat(te.ctx, d.Digest)
if err != nil {
t.Fatalf("Error stating proxy store")
}
}
if (*localStats)["stat"] != remoteBlobCount {
t.Errorf("Unexpected local stat count")
}
if (*remoteStats)["stat"] != remoteBlobCount {
t.Errorf("Unexpected remote stat count")
}
if te.store.authChallenger.(*mockChallenger).count != len(te.inRemote) {
t.Fatalf("Unexpected auth challenge count, got %#v", te.store.authChallenger)
}
}
func TestProxyStoreServeHighConcurrency(t *testing.T) {
te := makeTestEnv(t, "foo/bar")
blobSize := 200
blobCount := 10
numUnique := 1
populate(t, te, blobCount, blobSize, numUnique)
numClients := 16
testProxyStoreServe(t, te, numClients)
}
func TestProxyStoreServeMany(t *testing.T) {
te := makeTestEnv(t, "foo/bar")
blobSize := 200
blobCount := 10
numUnique := 4
populate(t, te, blobCount, blobSize, numUnique)
numClients := 4
testProxyStoreServe(t, te, numClients)
}
// todo(richardscothern): blobCount must be smaller than num clients
func TestProxyStoreServeBig(t *testing.T) {
te := makeTestEnv(t, "foo/bar")
blobSize := 2 << 20
blobCount := 4
numUnique := 2
populate(t, te, blobCount, blobSize, numUnique)
numClients := 4
testProxyStoreServe(t, te, numClients)
}
// testProxyStoreServe will create clients to consume all blobs
// populated in the truth store
func testProxyStoreServe(t *testing.T, te *testEnv, numClients int) {
localStats := te.LocalStats()
remoteStats := te.RemoteStats()
var wg sync.WaitGroup
for i := 0; i < numClients; i++ {
// Serveblob - pulls through blobs
wg.Add(1)
go func() {
defer wg.Done()
for _, remoteBlob := range te.inRemote {
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "", nil)
if err != nil {
t.Fatal(err)
}
err = te.store.ServeBlob(te.ctx, w, r, remoteBlob.Digest)
if err != nil {
t.Fatalf(err.Error())
}
bodyBytes := w.Body.Bytes()
localDigest := digest.FromBytes(bodyBytes)
if localDigest != remoteBlob.Digest {
t.Fatalf("Mismatching blob fetch from proxy")
}
}
}()
}
wg.Wait()
remoteBlobCount := len(te.inRemote)
sbsMu.Lock()
if (*localStats)["stat"] != remoteBlobCount*numClients && (*localStats)["create"] != te.numUnique {
sbsMu.Unlock()
t.Fatal("Expected: stat:", remoteBlobCount*numClients, "create:", remoteBlobCount)
}
sbsMu.Unlock()
// Wait for any async storage goroutines to finish
time.Sleep(3 * time.Second)
sbsMu.Lock()
remoteStatCount := (*remoteStats)["stat"]
remoteOpenCount := (*remoteStats)["open"]
sbsMu.Unlock()
// Serveblob - blobs come from local
for _, dr := range te.inRemote {
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "", nil)
if err != nil {
t.Fatal(err)
}
err = te.store.ServeBlob(te.ctx, w, r, dr.Digest)
if err != nil {
t.Fatalf(err.Error())
}
dl := digest.FromBytes(w.Body.Bytes())
if dl != dr.Digest {
t.Errorf("Mismatching blob fetch from proxy")
}
}
localStats = te.LocalStats()
remoteStats = te.RemoteStats()
// Ensure remote unchanged
sbsMu.Lock()
defer sbsMu.Unlock()
if (*remoteStats)["stat"] != remoteStatCount && (*remoteStats)["open"] != remoteOpenCount {
t.Fatalf("unexpected remote stats: %#v", remoteStats)
}
}

View File

@ -0,0 +1,95 @@
package proxy
import (
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/proxy/scheduler"
)
// todo(richardscothern): from cache control header or config
const repositoryTTL = time.Duration(24 * 7 * time.Hour)
type proxyManifestStore struct {
ctx context.Context
localManifests distribution.ManifestService
remoteManifests distribution.ManifestService
repositoryName reference.Named
scheduler *scheduler.TTLExpirationScheduler
authChallenger authChallenger
}
var _ distribution.ManifestService = &proxyManifestStore{}
func (pms proxyManifestStore) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
exists, err := pms.localManifests.Exists(ctx, dgst)
if err != nil {
return false, err
}
if exists {
return true, nil
}
if err := pms.authChallenger.tryEstablishChallenges(ctx); err != nil {
return false, err
}
return pms.remoteManifests.Exists(ctx, dgst)
}
func (pms proxyManifestStore) Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error) {
// At this point `dgst` was either specified explicitly, or returned by the
// tagstore with the most recent association.
var fromRemote bool
manifest, err := pms.localManifests.Get(ctx, dgst, options...)
if err != nil {
if err := pms.authChallenger.tryEstablishChallenges(ctx); err != nil {
return nil, err
}
manifest, err = pms.remoteManifests.Get(ctx, dgst, options...)
if err != nil {
return nil, err
}
fromRemote = true
}
_, payload, err := manifest.Payload()
if err != nil {
return nil, err
}
proxyMetrics.ManifestPush(uint64(len(payload)))
if fromRemote {
proxyMetrics.ManifestPull(uint64(len(payload)))
_, err = pms.localManifests.Put(ctx, manifest)
if err != nil {
return nil, err
}
// Schedule the manifest blob for removal
repoBlob, err := reference.WithDigest(pms.repositoryName, dgst)
if err != nil {
context.GetLogger(ctx).Errorf("Error creating reference: %s", err)
return nil, err
}
pms.scheduler.AddManifest(repoBlob, repositoryTTL)
// Ensure the manifest blob is cleaned up
//pms.scheduler.AddBlob(blobRef, repositoryTTL)
}
return manifest, err
}
func (pms proxyManifestStore) Put(ctx context.Context, manifest distribution.Manifest, options ...distribution.ManifestServiceOption) (digest.Digest, error) {
var d digest.Digest
return d, distribution.ErrUnsupported
}
func (pms proxyManifestStore) Delete(ctx context.Context, dgst digest.Digest) error {
return distribution.ErrUnsupported
}

View File

@ -0,0 +1,275 @@
package proxy
import (
"io"
"sync"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/client/auth"
"github.com/docker/distribution/registry/client/auth/challenge"
"github.com/docker/distribution/registry/proxy/scheduler"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver/inmemory"
"github.com/docker/distribution/testutil"
"github.com/docker/libtrust"
)
type statsManifest struct {
manifests distribution.ManifestService
stats map[string]int
}
type manifestStoreTestEnv struct {
manifestDigest digest.Digest // digest of the signed manifest in the local storage
manifests proxyManifestStore
}
func (te manifestStoreTestEnv) LocalStats() *map[string]int {
ls := te.manifests.localManifests.(statsManifest).stats
return &ls
}
func (te manifestStoreTestEnv) RemoteStats() *map[string]int {
rs := te.manifests.remoteManifests.(statsManifest).stats
return &rs
}
func (sm statsManifest) Delete(ctx context.Context, dgst digest.Digest) error {
sm.stats["delete"]++
return sm.manifests.Delete(ctx, dgst)
}
func (sm statsManifest) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
sm.stats["exists"]++
return sm.manifests.Exists(ctx, dgst)
}
func (sm statsManifest) Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error) {
sm.stats["get"]++
return sm.manifests.Get(ctx, dgst)
}
func (sm statsManifest) Put(ctx context.Context, manifest distribution.Manifest, options ...distribution.ManifestServiceOption) (digest.Digest, error) {
sm.stats["put"]++
return sm.manifests.Put(ctx, manifest)
}
type mockChallenger struct {
sync.Mutex
count int
}
// Called for remote operations only
func (m *mockChallenger) tryEstablishChallenges(context.Context) error {
m.Lock()
defer m.Unlock()
m.count++
return nil
}
func (m *mockChallenger) credentialStore() auth.CredentialStore {
return nil
}
func (m *mockChallenger) challengeManager() challenge.Manager {
return nil
}
func newManifestStoreTestEnv(t *testing.T, name, tag string) *manifestStoreTestEnv {
nameRef, err := reference.ParseNamed(name)
if err != nil {
t.Fatalf("unable to parse reference: %s", err)
}
k, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
truthRegistry, err := storage.NewRegistry(ctx, inmemory.New(),
storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()),
storage.Schema1SigningKey(k))
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
truthRepo, err := truthRegistry.Repository(ctx, nameRef)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
tr, err := truthRepo.Manifests(ctx)
if err != nil {
t.Fatal(err.Error())
}
truthManifests := statsManifest{
manifests: tr,
stats: make(map[string]int),
}
manifestDigest, err := populateRepo(ctx, t, truthRepo, name, tag)
if err != nil {
t.Fatalf(err.Error())
}
localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption, storage.Schema1SigningKey(k))
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
localRepo, err := localRegistry.Repository(ctx, nameRef)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
lr, err := localRepo.Manifests(ctx)
if err != nil {
t.Fatal(err.Error())
}
localManifests := statsManifest{
manifests: lr,
stats: make(map[string]int),
}
s := scheduler.New(ctx, inmemory.New(), "/scheduler-state.json")
return &manifestStoreTestEnv{
manifestDigest: manifestDigest,
manifests: proxyManifestStore{
ctx: ctx,
localManifests: localManifests,
remoteManifests: truthManifests,
scheduler: s,
repositoryName: nameRef,
authChallenger: &mockChallenger{},
},
}
}
func populateRepo(ctx context.Context, t *testing.T, repository distribution.Repository, name, tag string) (digest.Digest, error) {
m := schema1.Manifest{
Versioned: manifest.Versioned{
SchemaVersion: 1,
},
Name: name,
Tag: tag,
}
for i := 0; i < 2; i++ {
wr, err := repository.Blobs(ctx).Create(ctx)
if err != nil {
t.Fatalf("unexpected error creating test upload: %v", err)
}
rs, ts, err := testutil.CreateRandomTarFile()
if err != nil {
t.Fatalf("unexpected error generating test layer file")
}
dgst := digest.Digest(ts)
if _, err := io.Copy(wr, rs); err != nil {
t.Fatalf("unexpected error copying to upload: %v", err)
}
if _, err := wr.Commit(ctx, distribution.Descriptor{Digest: dgst}); err != nil {
t.Fatalf("unexpected error finishing upload: %v", err)
}
}
pk, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("unexpected error generating private key: %v", err)
}
sm, err := schema1.Sign(&m, pk)
if err != nil {
t.Fatalf("error signing manifest: %v", err)
}
ms, err := repository.Manifests(ctx)
if err != nil {
t.Fatalf(err.Error())
}
dgst, err := ms.Put(ctx, sm)
if err != nil {
t.Fatalf("unexpected errors putting manifest: %v", err)
}
return dgst, nil
}
// TestProxyManifests contains basic acceptance tests
// for the pull-through behavior
func TestProxyManifests(t *testing.T) {
name := "foo/bar"
env := newManifestStoreTestEnv(t, name, "latest")
localStats := env.LocalStats()
remoteStats := env.RemoteStats()
ctx := context.Background()
// Stat - must check local and remote
exists, err := env.manifests.Exists(ctx, env.manifestDigest)
if err != nil {
t.Fatalf("Error checking existence")
}
if !exists {
t.Errorf("Unexpected non-existant manifest")
}
if (*localStats)["exists"] != 1 && (*remoteStats)["exists"] != 1 {
t.Errorf("Unexpected exists count : \n%v \n%v", localStats, remoteStats)
}
if env.manifests.authChallenger.(*mockChallenger).count != 1 {
t.Fatalf("Expected 1 auth challenge, got %#v", env.manifests.authChallenger)
}
// Get - should succeed and pull manifest into local
_, err = env.manifests.Get(ctx, env.manifestDigest)
if err != nil {
t.Fatal(err)
}
if (*localStats)["get"] != 1 && (*remoteStats)["get"] != 1 {
t.Errorf("Unexpected get count")
}
if (*localStats)["put"] != 1 {
t.Errorf("Expected local put")
}
if env.manifests.authChallenger.(*mockChallenger).count != 2 {
t.Fatalf("Expected 2 auth challenges, got %#v", env.manifests.authChallenger)
}
// Stat - should only go to local
exists, err = env.manifests.Exists(ctx, env.manifestDigest)
if err != nil {
t.Fatal(err)
}
if !exists {
t.Errorf("Unexpected non-existant manifest")
}
if (*localStats)["exists"] != 2 && (*remoteStats)["exists"] != 1 {
t.Errorf("Unexpected exists count")
}
if env.manifests.authChallenger.(*mockChallenger).count != 2 {
t.Fatalf("Expected 2 auth challenges, got %#v", env.manifests.authChallenger)
}
// Get proxied - won't require another authchallenge
_, err = env.manifests.Get(ctx, env.manifestDigest)
if err != nil {
t.Fatal(err)
}
if env.manifests.authChallenger.(*mockChallenger).count != 2 {
t.Fatalf("Expected 2 auth challenges, got %#v", env.manifests.authChallenger)
}
}

View File

@ -0,0 +1,74 @@
package proxy
import (
"expvar"
"sync/atomic"
)
// Metrics is used to hold metric counters
// related to the proxy
type Metrics struct {
Requests uint64
Hits uint64
Misses uint64
BytesPulled uint64
BytesPushed uint64
}
type proxyMetricsCollector struct {
blobMetrics Metrics
manifestMetrics Metrics
}
// BlobPull tracks metrics about blobs pulled into the cache
func (pmc *proxyMetricsCollector) BlobPull(bytesPulled uint64) {
atomic.AddUint64(&pmc.blobMetrics.Misses, 1)
atomic.AddUint64(&pmc.blobMetrics.BytesPulled, bytesPulled)
}
// BlobPush tracks metrics about blobs pushed to clients
func (pmc *proxyMetricsCollector) BlobPush(bytesPushed uint64) {
atomic.AddUint64(&pmc.blobMetrics.Requests, 1)
atomic.AddUint64(&pmc.blobMetrics.Hits, 1)
atomic.AddUint64(&pmc.blobMetrics.BytesPushed, bytesPushed)
}
// ManifestPull tracks metrics related to Manifests pulled into the cache
func (pmc *proxyMetricsCollector) ManifestPull(bytesPulled uint64) {
atomic.AddUint64(&pmc.manifestMetrics.Misses, 1)
atomic.AddUint64(&pmc.manifestMetrics.BytesPulled, bytesPulled)
}
// ManifestPush tracks metrics about manifests pushed to clients
func (pmc *proxyMetricsCollector) ManifestPush(bytesPushed uint64) {
atomic.AddUint64(&pmc.manifestMetrics.Requests, 1)
atomic.AddUint64(&pmc.manifestMetrics.Hits, 1)
atomic.AddUint64(&pmc.manifestMetrics.BytesPushed, bytesPushed)
}
// proxyMetrics tracks metrics about the proxy cache. This is
// kept globally and made available via expvar.
var proxyMetrics = &proxyMetricsCollector{}
func init() {
registry := expvar.Get("registry")
if registry == nil {
registry = expvar.NewMap("registry")
}
pm := registry.(*expvar.Map).Get("proxy")
if pm == nil {
pm = &expvar.Map{}
pm.(*expvar.Map).Init()
registry.(*expvar.Map).Set("proxy", pm)
}
pm.(*expvar.Map).Set("blobs", expvar.Func(func() interface{} {
return proxyMetrics.blobMetrics
}))
pm.(*expvar.Map).Set("manifests", expvar.Func(func() interface{} {
return proxyMetrics.manifestMetrics
}))
}
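// exampleProxyMetricsJSON is an illustrative sketch, not part of the original
// change, showing how the counters can be read back programmatically; in
// practice they are scraped from the expvar debug endpoint.
func exampleProxyMetricsJSON() string {
if v := expvar.Get("registry"); v != nil {
// e.g. {"proxy": {"blobs": {"Requests": 2, ...}, "manifests": {...}}}
return v.String()
}
return ""
}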

View File

@ -0,0 +1,249 @@
package proxy
import (
"fmt"
"net/http"
"net/url"
"sync"
"github.com/docker/distribution"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/client"
"github.com/docker/distribution/registry/client/auth"
"github.com/docker/distribution/registry/client/auth/challenge"
"github.com/docker/distribution/registry/client/transport"
"github.com/docker/distribution/registry/proxy/scheduler"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver"
)
// proxyingRegistry fetches content from a remote registry and caches it locally
type proxyingRegistry struct {
embedded distribution.Namespace // provides local registry functionality
scheduler *scheduler.TTLExpirationScheduler
remoteURL url.URL
authChallenger authChallenger
}
// NewRegistryPullThroughCache creates a registry acting as a pull through cache
func NewRegistryPullThroughCache(ctx context.Context, registry distribution.Namespace, driver driver.StorageDriver, config configuration.Proxy) (distribution.Namespace, error) {
remoteURL, err := url.Parse(config.RemoteURL)
if err != nil {
return nil, err
}
v := storage.NewVacuum(ctx, driver)
s := scheduler.New(ctx, driver, "/scheduler-state.json")
s.OnBlobExpire(func(ref reference.Reference) error {
var r reference.Canonical
var ok bool
if r, ok = ref.(reference.Canonical); !ok {
return fmt.Errorf("unexpected reference type : %T", ref)
}
repo, err := registry.Repository(ctx, r)
if err != nil {
return err
}
blobs := repo.Blobs(ctx)
// Clear the repository reference and descriptor caches
err = blobs.Delete(ctx, r.Digest())
if err != nil {
return err
}
err = v.RemoveBlob(r.Digest().String())
if err != nil {
return err
}
return nil
})
s.OnManifestExpire(func(ref reference.Reference) error {
var r reference.Canonical
var ok bool
if r, ok = ref.(reference.Canonical); !ok {
return fmt.Errorf("unexpected reference type : %T", ref)
}
repo, err := registry.Repository(ctx, r)
if err != nil {
return err
}
manifests, err := repo.Manifests(ctx)
if err != nil {
return err
}
err = manifests.Delete(ctx, r.Digest())
if err != nil {
return err
}
return nil
})
err = s.Start()
if err != nil {
return nil, err
}
cs, err := configureAuth(config.Username, config.Password, config.RemoteURL)
if err != nil {
return nil, err
}
return &proxyingRegistry{
embedded: registry,
scheduler: s,
remoteURL: *remoteURL,
authChallenger: &remoteAuthChallenger{
remoteURL: *remoteURL,
cm: challenge.NewSimpleManager(),
cs: cs,
},
}, nil
}
func (pr *proxyingRegistry) Scope() distribution.Scope {
return distribution.GlobalScope
}
func (pr *proxyingRegistry) Repositories(ctx context.Context, repos []string, last string) (n int, err error) {
return pr.embedded.Repositories(ctx, repos, last)
}
func (pr *proxyingRegistry) Repository(ctx context.Context, name reference.Named) (distribution.Repository, error) {
c := pr.authChallenger
tr := transport.NewTransport(http.DefaultTransport,
auth.NewAuthorizer(c.challengeManager(), auth.NewTokenHandler(http.DefaultTransport, c.credentialStore(), name.Name(), "pull")))
localRepo, err := pr.embedded.Repository(ctx, name)
if err != nil {
return nil, err
}
localManifests, err := localRepo.Manifests(ctx, storage.SkipLayerVerification())
if err != nil {
return nil, err
}
remoteRepo, err := client.NewRepository(ctx, name, pr.remoteURL.String(), tr)
if err != nil {
return nil, err
}
remoteManifests, err := remoteRepo.Manifests(ctx)
if err != nil {
return nil, err
}
return &proxiedRepository{
blobStore: &proxyBlobStore{
localStore: localRepo.Blobs(ctx),
remoteStore: remoteRepo.Blobs(ctx),
scheduler: pr.scheduler,
repositoryName: name,
authChallenger: pr.authChallenger,
},
manifests: &proxyManifestStore{
repositoryName: name,
localManifests: localManifests, // Options?
remoteManifests: remoteManifests,
ctx: ctx,
scheduler: pr.scheduler,
authChallenger: pr.authChallenger,
},
name: name,
tags: &proxyTagService{
localTags: localRepo.Tags(ctx),
remoteTags: remoteRepo.Tags(ctx),
authChallenger: pr.authChallenger,
},
}, nil
}
func (pr *proxyingRegistry) Blobs() distribution.BlobEnumerator {
return pr.embedded.Blobs()
}
func (pr *proxyingRegistry) BlobStatter() distribution.BlobStatter {
return pr.embedded.BlobStatter()
}
// authChallenger encapsulates a request to the upstream to establish credential challenges
type authChallenger interface {
tryEstablishChallenges(context.Context) error
challengeManager() challenge.Manager
credentialStore() auth.CredentialStore
}
type remoteAuthChallenger struct {
remoteURL url.URL
sync.Mutex
cm challenge.Manager
cs auth.CredentialStore
}
func (r *remoteAuthChallenger) credentialStore() auth.CredentialStore {
return r.cs
}
func (r *remoteAuthChallenger) challengeManager() challenge.Manager {
return r.cm
}
// tryEstablishChallenges will attempt to get a challenge type for the upstream if none currently exist
func (r *remoteAuthChallenger) tryEstablishChallenges(ctx context.Context) error {
r.Lock()
defer r.Unlock()
remoteURL := r.remoteURL
remoteURL.Path = "/v2/"
challenges, err := r.cm.GetChallenges(remoteURL)
if err != nil {
return err
}
if len(challenges) > 0 {
return nil
}
// establish challenge type with upstream
if err := ping(r.cm, remoteURL.String(), challengeHeader); err != nil {
return err
}
context.GetLogger(ctx).Infof("Challenge established with upstream : %s %s", remoteURL, r.cm)
return nil
}
// proxiedRepository uses proxying blob and manifest services to serve content
// from local storage, pulling it through from the remote and caching it
// locally when it doesn't already exist.
type proxiedRepository struct {
blobStore distribution.BlobStore
manifests distribution.ManifestService
name reference.Named
tags distribution.TagService
}
func (pr *proxiedRepository) Manifests(ctx context.Context, options ...distribution.ManifestServiceOption) (distribution.ManifestService, error) {
return pr.manifests, nil
}
func (pr *proxiedRepository) Blobs(ctx context.Context) distribution.BlobStore {
return pr.blobStore
}
func (pr *proxiedRepository) Named() reference.Named {
return pr.name
}
func (pr *proxiedRepository) Tags(ctx context.Context) distribution.TagService {
return pr.tags
}

View File

@ -0,0 +1,65 @@
package proxy
import (
"github.com/docker/distribution"
"github.com/docker/distribution/context"
)
// proxyTagService supports local and remote lookup of tags.
type proxyTagService struct {
localTags distribution.TagService
remoteTags distribution.TagService
authChallenger authChallenger
}
var _ distribution.TagService = proxyTagService{}
// Get attempts to get the most recent digest for the tag by checking the remote
// tag service first and then caching it locally. If the remote is unavailable,
// the local association is returned.
func (pt proxyTagService) Get(ctx context.Context, tag string) (distribution.Descriptor, error) {
err := pt.authChallenger.tryEstablishChallenges(ctx)
if err == nil {
desc, err := pt.remoteTags.Get(ctx, tag)
if err == nil {
err := pt.localTags.Tag(ctx, tag, desc)
if err != nil {
return distribution.Descriptor{}, err
}
return desc, nil
}
}
desc, err := pt.localTags.Get(ctx, tag)
if err != nil {
return distribution.Descriptor{}, err
}
return desc, nil
}
func (pt proxyTagService) Tag(ctx context.Context, tag string, desc distribution.Descriptor) error {
return distribution.ErrUnsupported
}
func (pt proxyTagService) Untag(ctx context.Context, tag string) error {
err := pt.localTags.Untag(ctx, tag)
if err != nil {
return err
}
return nil
}
func (pt proxyTagService) All(ctx context.Context) ([]string, error) {
err := pt.authChallenger.tryEstablishChallenges(ctx)
if err == nil {
tags, err := pt.remoteTags.All(ctx)
if err == nil {
return tags, err
}
}
return pt.localTags.All(ctx)
}
func (pt proxyTagService) Lookup(ctx context.Context, digest distribution.Descriptor) ([]string, error) {
return []string{}, distribution.ErrUnsupported
}

View File

@ -0,0 +1,182 @@
package proxy
import (
"reflect"
"sort"
"sync"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
)
type mockTagStore struct {
mapping map[string]distribution.Descriptor
sync.Mutex
}
var _ distribution.TagService = &mockTagStore{}
func (m *mockTagStore) Get(ctx context.Context, tag string) (distribution.Descriptor, error) {
m.Lock()
defer m.Unlock()
if d, ok := m.mapping[tag]; ok {
return d, nil
}
return distribution.Descriptor{}, distribution.ErrTagUnknown{}
}
func (m *mockTagStore) Tag(ctx context.Context, tag string, desc distribution.Descriptor) error {
m.Lock()
defer m.Unlock()
m.mapping[tag] = desc
return nil
}
func (m *mockTagStore) Untag(ctx context.Context, tag string) error {
m.Lock()
defer m.Unlock()
if _, ok := m.mapping[tag]; ok {
delete(m.mapping, tag)
return nil
}
return distribution.ErrTagUnknown{}
}
func (m *mockTagStore) All(ctx context.Context) ([]string, error) {
m.Lock()
defer m.Unlock()
var tags []string
for tag := range m.mapping {
tags = append(tags, tag)
}
return tags, nil
}
func (m *mockTagStore) Lookup(ctx context.Context, digest distribution.Descriptor) ([]string, error) {
panic("not implemented")
}
func testProxyTagService(local, remote map[string]distribution.Descriptor) *proxyTagService {
if local == nil {
local = make(map[string]distribution.Descriptor)
}
if remote == nil {
remote = make(map[string]distribution.Descriptor)
}
return &proxyTagService{
localTags: &mockTagStore{mapping: local},
remoteTags: &mockTagStore{mapping: remote},
authChallenger: &mockChallenger{},
}
}
func TestGet(t *testing.T) {
remoteDesc := distribution.Descriptor{Size: 42}
remoteTag := "remote"
proxyTags := testProxyTagService(map[string]distribution.Descriptor{remoteTag: remoteDesc}, nil)
ctx := context.Background()
// Get pre-loaded tag
d, err := proxyTags.Get(ctx, remoteTag)
if err != nil {
t.Fatal(err)
}
if proxyTags.authChallenger.(*mockChallenger).count != 1 {
t.Fatalf("Expected 1 auth challenge call, got %#v", proxyTags.authChallenger)
}
if !reflect.DeepEqual(d, remoteDesc) {
t.Fatal("unable to get put tag")
}
local, err := proxyTags.localTags.Get(ctx, remoteTag)
if err != nil {
t.Fatal("remote tag not pulled into store")
}
if !reflect.DeepEqual(local, remoteDesc) {
t.Fatalf("unexpected descriptor pulled through")
}
// Manually overwrite remote tag
newRemoteDesc := distribution.Descriptor{Size: 43}
err = proxyTags.remoteTags.Tag(ctx, remoteTag, newRemoteDesc)
if err != nil {
t.Fatal(err)
}
d, err = proxyTags.Get(ctx, remoteTag)
if err != nil {
t.Fatal(err)
}
if proxyTags.authChallenger.(*mockChallenger).count != 2 {
t.Fatalf("Expected 2 auth challenge calls, got %#v", proxyTags.authChallenger)
}
if !reflect.DeepEqual(d, newRemoteDesc) {
t.Fatal("unable to get put tag")
}
_, err = proxyTags.localTags.Get(ctx, remoteTag)
if err != nil {
t.Fatal("remote tag not pulled into store")
}
// untag, ensure it's removed locally, but present in remote
err = proxyTags.Untag(ctx, remoteTag)
if err != nil {
t.Fatal(err)
}
_, err = proxyTags.localTags.Get(ctx, remoteTag)
if err == nil {
t.Fatalf("Expected error getting Untag'd tag")
}
_, err = proxyTags.remoteTags.Get(ctx, remoteTag)
if err != nil {
t.Fatalf("remote tag should not be untagged with proxyTag.Untag")
}
_, err = proxyTags.Get(ctx, remoteTag)
if err != nil {
t.Fatal("untagged tag should be pulled through")
}
if proxyTags.authChallenger.(*mockChallenger).count != 3 {
t.Fatalf("Expected 3 auth challenge calls, got %#v", proxyTags.authChallenger)
}
// Add another tag. Ensure both tags appear in 'All'
err = proxyTags.remoteTags.Tag(ctx, "funtag", distribution.Descriptor{Size: 42})
if err != nil {
t.Fatal(err)
}
all, err := proxyTags.All(ctx)
if err != nil {
t.Fatal(err)
}
if len(all) != 2 {
t.Fatalf("Unexpected tag length returned from All() : %d ", len(all))
}
sort.Strings(all)
if all[0] != "funtag" && all[1] != "remote" {
t.Fatalf("Unexpected tags returned from All() : %v ", all)
}
if proxyTags.authChallenger.(*mockChallenger).count != 4 {
t.Fatalf("Expected 4 auth challenge calls, got %#v", proxyTags.authChallenger)
}
}

View File

@ -0,0 +1,259 @@
package scheduler
import (
"encoding/json"
"fmt"
"sync"
"time"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/driver"
)
// expiryFunc is called when an entry's TTL expires
type expiryFunc func(reference.Reference) error
const (
entryTypeBlob = iota
entryTypeManifest
indexSaveFrequency = 5 * time.Second
)
// schedulerEntry represents an entry in the scheduler
// fields are exported for serialization
type schedulerEntry struct {
Key string `json:"Key"`
Expiry time.Time `json:"ExpiryData"`
EntryType int `json:"EntryType"`
timer *time.Timer
}
// New returns a new instance of the scheduler
func New(ctx context.Context, driver driver.StorageDriver, path string) *TTLExpirationScheduler {
return &TTLExpirationScheduler{
entries: make(map[string]*schedulerEntry),
driver: driver,
pathToStateFile: path,
ctx: ctx,
stopped: true,
doneChan: make(chan struct{}),
saveTimer: time.NewTicker(indexSaveFrequency),
}
}
// TTLExpirationScheduler is a scheduler used to perform actions
// when TTLs expire
type TTLExpirationScheduler struct {
sync.Mutex
entries map[string]*schedulerEntry
driver driver.StorageDriver
ctx context.Context
pathToStateFile string
stopped bool
onBlobExpire expiryFunc
onManifestExpire expiryFunc
indexDirty bool
saveTimer *time.Ticker
doneChan chan struct{}
}
// OnBlobExpire is called when a scheduled blob's TTL expires
func (ttles *TTLExpirationScheduler) OnBlobExpire(f expiryFunc) {
ttles.Lock()
defer ttles.Unlock()
ttles.onBlobExpire = f
}
// OnManifestExpire is called when a scheduled manifest's TTL expires
func (ttles *TTLExpirationScheduler) OnManifestExpire(f expiryFunc) {
ttles.Lock()
defer ttles.Unlock()
ttles.onManifestExpire = f
}
// AddBlob schedules a blob cleanup after ttl expires
func (ttles *TTLExpirationScheduler) AddBlob(blobRef reference.Canonical, ttl time.Duration) error {
ttles.Lock()
defer ttles.Unlock()
if ttles.stopped {
return fmt.Errorf("scheduler not started")
}
ttles.add(blobRef, ttl, entryTypeBlob)
return nil
}
// AddManifest schedules a manifest cleanup after ttl expires
func (ttles *TTLExpirationScheduler) AddManifest(manifestRef reference.Canonical, ttl time.Duration) error {
ttles.Lock()
defer ttles.Unlock()
if ttles.stopped {
return fmt.Errorf("scheduler not started")
}
ttles.add(manifestRef, ttl, entryTypeManifest)
return nil
}
// Start starts the scheduler
func (ttles *TTLExpirationScheduler) Start() error {
ttles.Lock()
defer ttles.Unlock()
err := ttles.readState()
if err != nil {
return err
}
if !ttles.stopped {
return fmt.Errorf("Scheduler already started")
}
context.GetLogger(ttles.ctx).Infof("Starting cached object TTL expiration scheduler...")
ttles.stopped = false
// Start timer for each deserialized entry
for _, entry := range ttles.entries {
entry.timer = ttles.startTimer(entry, entry.Expiry.Sub(time.Now()))
}
// Start a ticker to periodically save the entries index
go func() {
for {
select {
case <-ttles.saveTimer.C:
ttles.Lock()
if !ttles.indexDirty {
ttles.Unlock()
continue
}
err := ttles.writeState()
if err != nil {
context.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
} else {
ttles.indexDirty = false
}
ttles.Unlock()
case <-ttles.doneChan:
return
}
}
}()
return nil
}
func (ttles *TTLExpirationScheduler) add(r reference.Reference, ttl time.Duration, eType int) {
entry := &schedulerEntry{
Key: r.String(),
Expiry: time.Now().Add(ttl),
EntryType: eType,
}
context.GetLogger(ttles.ctx).Infof("Adding new scheduler entry for %s with ttl=%s", entry.Key, entry.Expiry.Sub(time.Now()))
if oldEntry, present := ttles.entries[entry.Key]; present && oldEntry.timer != nil {
oldEntry.timer.Stop()
}
ttles.entries[entry.Key] = entry
entry.timer = ttles.startTimer(entry, ttl)
ttles.indexDirty = true
}
func (ttles *TTLExpirationScheduler) startTimer(entry *schedulerEntry, ttl time.Duration) *time.Timer {
return time.AfterFunc(ttl, func() {
ttles.Lock()
defer ttles.Unlock()
var f expiryFunc
switch entry.EntryType {
case entryTypeBlob:
f = ttles.onBlobExpire
case entryTypeManifest:
f = ttles.onManifestExpire
default:
f = func(reference.Reference) error {
return fmt.Errorf("scheduler entry type")
}
}
ref, err := reference.Parse(entry.Key)
if err == nil {
if err := f(ref); err != nil {
context.GetLogger(ttles.ctx).Errorf("Scheduler error returned from OnExpire(%s): %s", entry.Key, err)
}
} else {
context.GetLogger(ttles.ctx).Errorf("Error unpacking reference: %s", err)
}
delete(ttles.entries, entry.Key)
ttles.indexDirty = true
})
}
// Stop stops the scheduler.
func (ttles *TTLExpirationScheduler) Stop() {
ttles.Lock()
defer ttles.Unlock()
if err := ttles.writeState(); err != nil {
context.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err)
}
for _, entry := range ttles.entries {
entry.timer.Stop()
}
close(ttles.doneChan)
ttles.saveTimer.Stop()
ttles.stopped = true
}
func (ttles *TTLExpirationScheduler) writeState() error {
jsonBytes, err := json.Marshal(ttles.entries)
if err != nil {
return err
}
err = ttles.driver.PutContent(ttles.ctx, ttles.pathToStateFile, jsonBytes)
if err != nil {
return err
}
return nil
}
func (ttles *TTLExpirationScheduler) readState() error {
if _, err := ttles.driver.Stat(ttles.ctx, ttles.pathToStateFile); err != nil {
switch err := err.(type) {
case driver.PathNotFoundError:
return nil
default:
return err
}
}
bytes, err := ttles.driver.GetContent(ttles.ctx, ttles.pathToStateFile)
if err != nil {
return err
}
err = json.Unmarshal(bytes, &ttles.entries)
if err != nil {
return err
}
return nil
}
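A minimal caller-side sketch of driving the scheduler above, assuming the in-memory storage driver; the function name, state path and digest literal are illustrative placeholders, not part of the package.
package scheduler
import (
	"time"
	"github.com/docker/distribution/context"
	"github.com/docker/distribution/reference"
	"github.com/docker/distribution/registry/storage/driver/inmemory"
)
// exampleProxyWiring shows the intended call order: register callbacks, Start,
// schedule entries, and Stop to persist pending state. Illustrative only.
func exampleProxyWiring() error {
	ctx := context.Background()
	s := New(ctx, inmemory.New(), "/scheduler-state.json")
	// Invoked when a scheduled blob's TTL passes; a real proxy would remove
	// the blob from its local store here.
	s.OnBlobExpire(func(ref reference.Reference) error {
		context.GetLogger(ctx).Infof("expiring %s", ref.String())
		return nil
	})
	if err := s.Start(); err != nil {
		return err
	}
	// Stop persists any pending entries; a long-running proxy would defer
	// this until shutdown.
	defer s.Stop()
	// Schedule a pulled-through blob for cleanup 24 hours from now.
	ref, err := reference.Parse("testrepo@sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
	if err != nil {
		return err
	}
	if canonical, ok := ref.(reference.Canonical); ok {
		return s.AddBlob(canonical, 24*time.Hour)
	}
	return nil
}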

View File

@ -0,0 +1,211 @@
package scheduler
import (
"encoding/json"
"sync"
"testing"
"time"
"github.com/docker/distribution/context"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
func testRefs(t *testing.T) (reference.Reference, reference.Reference, reference.Reference) {
ref1, err := reference.Parse("testrepo@sha256:aaaaeaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
if err != nil {
t.Fatalf("could not parse reference: %v", err)
}
ref2, err := reference.Parse("testrepo@sha256:bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
if err != nil {
t.Fatalf("could not parse reference: %v", err)
}
ref3, err := reference.Parse("testrepo@sha256:cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc")
if err != nil {
t.Fatalf("could not parse reference: %v", err)
}
return ref1, ref2, ref3
}
func TestSchedule(t *testing.T) {
ref1, ref2, ref3 := testRefs(t)
timeUnit := time.Millisecond
remainingRepos := map[string]bool{
ref1.String(): true,
ref2.String(): true,
ref3.String(): true,
}
var mu sync.Mutex
s := New(context.Background(), inmemory.New(), "/ttl")
deleteFunc := func(repoName reference.Reference) error {
if len(remainingRepos) == 0 {
t.Fatalf("Incorrect expiry count")
}
_, ok := remainingRepos[repoName.String()]
if !ok {
t.Fatalf("Trying to remove nonexistent repo: %s", repoName)
}
t.Log("removing", repoName)
mu.Lock()
delete(remainingRepos, repoName.String())
mu.Unlock()
return nil
}
s.onBlobExpire = deleteFunc
err := s.Start()
if err != nil {
t.Fatalf("Error starting ttlExpirationScheduler: %s", err)
}
s.add(ref1, 3*timeUnit, entryTypeBlob)
s.add(ref2, 1*timeUnit, entryTypeBlob)
func() {
s.Lock()
s.add(ref3, 1*timeUnit, entryTypeBlob)
s.Unlock()
}()
// Ensure all repos are deleted
<-time.After(50 * timeUnit)
mu.Lock()
defer mu.Unlock()
if len(remainingRepos) != 0 {
t.Fatalf("Repositories remaining: %#v", remainingRepos)
}
}
func TestRestoreOld(t *testing.T) {
ref1, ref2, _ := testRefs(t)
remainingRepos := map[string]bool{
ref1.String(): true,
ref2.String(): true,
}
var wg sync.WaitGroup
wg.Add(len(remainingRepos))
var mu sync.Mutex
deleteFunc := func(r reference.Reference) error {
mu.Lock()
defer mu.Unlock()
if r.String() == ref1.String() && len(remainingRepos) == 2 {
t.Errorf("ref1 should not be removed first")
}
_, ok := remainingRepos[r.String()]
if !ok {
t.Fatalf("Trying to remove nonexistent repo: %s", r)
}
delete(remainingRepos, r.String())
wg.Done()
return nil
}
timeUnit := time.Millisecond
serialized, err := json.Marshal(&map[string]schedulerEntry{
ref1.String(): {
Expiry: time.Now().Add(10 * timeUnit),
Key: ref1.String(),
EntryType: 0,
},
ref2.String(): {
Expiry: time.Now().Add(-3 * timeUnit), // TTL passed, should be removed first
Key: ref2.String(),
EntryType: 0,
},
})
if err != nil {
t.Fatalf("Error serializing test data: %s", err.Error())
}
ctx := context.Background()
pathToStatFile := "/ttl"
fs := inmemory.New()
err = fs.PutContent(ctx, pathToStatFile, serialized)
if err != nil {
t.Fatal("Unable to write serialized data to fs")
}
s := New(context.Background(), fs, "/ttl")
s.OnBlobExpire(deleteFunc)
err = s.Start()
if err != nil {
t.Fatalf("Error starting ttlExpirationScheduler: %s", err)
}
defer s.Stop()
wg.Wait()
mu.Lock()
defer mu.Unlock()
if len(remainingRepos) != 0 {
t.Fatalf("Repositories remaining: %#v", remainingRepos)
}
}
func TestStopRestore(t *testing.T) {
ref1, ref2, _ := testRefs(t)
timeUnit := time.Millisecond
remainingRepos := map[string]bool{
ref1.String(): true,
ref2.String(): true,
}
var mu sync.Mutex
deleteFunc := func(r reference.Reference) error {
mu.Lock()
delete(remainingRepos, r.String())
mu.Unlock()
return nil
}
fs := inmemory.New()
pathToStateFile := "/ttl"
s := New(context.Background(), fs, pathToStateFile)
s.onBlobExpire = deleteFunc
err := s.Start()
if err != nil {
t.Fatalf(err.Error())
}
s.add(ref1, 300*timeUnit, entryTypeBlob)
s.add(ref2, 100*timeUnit, entryTypeBlob)
// Start and stop before all operations complete
// state will be written to fs
s.Stop()
time.Sleep(10 * time.Millisecond)
// v2 will restore state from fs
s2 := New(context.Background(), fs, pathToStateFile)
s2.onBlobExpire = deleteFunc
err = s2.Start()
if err != nil {
t.Fatalf("Error starting v2: %s", err.Error())
}
<-time.After(500 * timeUnit)
mu.Lock()
defer mu.Unlock()
if len(remainingRepos) != 0 {
t.Fatalf("Repositories remaining: %#v", remainingRepos)
}
}
func TestDoubleStart(t *testing.T) {
s := New(context.Background(), inmemory.New(), "/ttl")
err := s.Start()
if err != nil {
t.Fatalf("Unable to start scheduler")
}
err = s.Start()
if err == nil {
t.Fatalf("Scheduler started twice without error")
}
}

View File

@ -0,0 +1,356 @@
package registry
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"os"
"time"
"rsc.io/letsencrypt"
log "github.com/Sirupsen/logrus"
"github.com/Sirupsen/logrus/formatters/logstash"
"github.com/bugsnag/bugsnag-go"
"github.com/docker/distribution/configuration"
"github.com/docker/distribution/context"
"github.com/docker/distribution/health"
"github.com/docker/distribution/registry/handlers"
"github.com/docker/distribution/registry/listener"
"github.com/docker/distribution/uuid"
"github.com/docker/distribution/version"
gorhandlers "github.com/gorilla/handlers"
"github.com/spf13/cobra"
"github.com/yvasiyarov/gorelic"
)
// ServeCmd is a cobra command for running the registry.
var ServeCmd = &cobra.Command{
Use: "serve <config>",
Short: "`serve` stores and distributes Docker images",
Long: "`serve` stores and distributes Docker images.",
Run: func(cmd *cobra.Command, args []string) {
// setup context
ctx := context.WithVersion(context.Background(), version.Version)
config, err := resolveConfiguration(args)
if err != nil {
fmt.Fprintf(os.Stderr, "configuration error: %v\n", err)
cmd.Usage()
os.Exit(1)
}
if config.HTTP.Debug.Addr != "" {
go func(addr string) {
log.Infof("debug server listening %v", addr)
if err := http.ListenAndServe(addr, nil); err != nil {
log.Fatalf("error listening on debug interface: %v", err)
}
}(config.HTTP.Debug.Addr)
}
registry, err := NewRegistry(ctx, config)
if err != nil {
log.Fatalln(err)
}
if err = registry.ListenAndServe(); err != nil {
log.Fatalln(err)
}
},
}
// A Registry represents a complete instance of the registry.
// TODO(aaronl): It might make sense for Registry to become an interface.
type Registry struct {
config *configuration.Configuration
app *handlers.App
server *http.Server
}
// NewRegistry creates a new registry from a context and configuration struct.
func NewRegistry(ctx context.Context, config *configuration.Configuration) (*Registry, error) {
var err error
ctx, err = configureLogging(ctx, config)
if err != nil {
return nil, fmt.Errorf("error configuring logger: %v", err)
}
// inject a logger into the uuid library. warns us if there is a problem
// with uuid generation under low entropy.
uuid.Loggerf = context.GetLogger(ctx).Warnf
app := handlers.NewApp(ctx, config)
// TODO(aaronl): The global scope of the health checks means NewRegistry
// can only be called once per process.
app.RegisterHealthChecks()
handler := configureReporting(app)
handler = alive("/", handler)
handler = health.Handler(handler)
handler = panicHandler(handler)
if !config.Log.AccessLog.Disabled {
handler = gorhandlers.CombinedLoggingHandler(os.Stdout, handler)
}
server := &http.Server{
Handler: handler,
}
return &Registry{
app: app,
config: config,
server: server,
}, nil
}
// ListenAndServe runs the registry's HTTP server.
func (registry *Registry) ListenAndServe() error {
config := registry.config
ln, err := listener.NewListener(config.HTTP.Net, config.HTTP.Addr)
if err != nil {
return err
}
if config.HTTP.TLS.Certificate != "" || config.HTTP.TLS.LetsEncrypt.CacheFile != "" {
tlsConf := &tls.Config{
ClientAuth: tls.NoClientCert,
NextProtos: nextProtos(config),
MinVersion: tls.VersionTLS10,
PreferServerCipherSuites: true,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
},
}
if config.HTTP.TLS.LetsEncrypt.CacheFile != "" {
if config.HTTP.TLS.Certificate != "" {
return fmt.Errorf("cannot specify both certificate and Let's Encrypt")
}
var m letsencrypt.Manager
if err := m.CacheFile(config.HTTP.TLS.LetsEncrypt.CacheFile); err != nil {
return err
}
if !m.Registered() {
if err := m.Register(config.HTTP.TLS.LetsEncrypt.Email, nil); err != nil {
return err
}
}
tlsConf.GetCertificate = m.GetCertificate
} else {
tlsConf.Certificates = make([]tls.Certificate, 1)
tlsConf.Certificates[0], err = tls.LoadX509KeyPair(config.HTTP.TLS.Certificate, config.HTTP.TLS.Key)
if err != nil {
return err
}
}
if len(config.HTTP.TLS.ClientCAs) != 0 {
pool := x509.NewCertPool()
for _, ca := range config.HTTP.TLS.ClientCAs {
caPem, err := ioutil.ReadFile(ca)
if err != nil {
return err
}
if ok := pool.AppendCertsFromPEM(caPem); !ok {
return fmt.Errorf("Could not add CA to pool")
}
}
for _, subj := range pool.Subjects() {
context.GetLogger(registry.app).Debugf("CA Subject: %s", string(subj))
}
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
tlsConf.ClientCAs = pool
}
ln = tls.NewListener(ln, tlsConf)
context.GetLogger(registry.app).Infof("listening on %v, tls", ln.Addr())
} else {
context.GetLogger(registry.app).Infof("listening on %v", ln.Addr())
}
return registry.server.Serve(ln)
}
func configureReporting(app *handlers.App) http.Handler {
var handler http.Handler = app
if app.Config.Reporting.Bugsnag.APIKey != "" {
bugsnagConfig := bugsnag.Configuration{
APIKey: app.Config.Reporting.Bugsnag.APIKey,
// TODO(brianbland): provide the registry version here
// AppVersion: "2.0",
}
if app.Config.Reporting.Bugsnag.ReleaseStage != "" {
bugsnagConfig.ReleaseStage = app.Config.Reporting.Bugsnag.ReleaseStage
}
if app.Config.Reporting.Bugsnag.Endpoint != "" {
bugsnagConfig.Endpoint = app.Config.Reporting.Bugsnag.Endpoint
}
bugsnag.Configure(bugsnagConfig)
handler = bugsnag.Handler(handler)
}
if app.Config.Reporting.NewRelic.LicenseKey != "" {
agent := gorelic.NewAgent()
agent.NewrelicLicense = app.Config.Reporting.NewRelic.LicenseKey
if app.Config.Reporting.NewRelic.Name != "" {
agent.NewrelicName = app.Config.Reporting.NewRelic.Name
}
agent.CollectHTTPStat = true
agent.Verbose = app.Config.Reporting.NewRelic.Verbose
agent.Run()
handler = agent.WrapHTTPHandler(handler)
}
return handler
}
// configureLogging prepares the context with a logger using the
// configuration.
func configureLogging(ctx context.Context, config *configuration.Configuration) (context.Context, error) {
if config.Log.Level == "" && config.Log.Formatter == "" {
// If no config for logging is set, fall back to the deprecated "Loglevel".
log.SetLevel(logLevel(config.Loglevel))
ctx = context.WithLogger(ctx, context.GetLogger(ctx))
return ctx, nil
}
log.SetLevel(logLevel(config.Log.Level))
formatter := config.Log.Formatter
if formatter == "" {
formatter = "text" // default formatter
}
switch formatter {
case "json":
log.SetFormatter(&log.JSONFormatter{
TimestampFormat: time.RFC3339Nano,
})
case "text":
log.SetFormatter(&log.TextFormatter{
TimestampFormat: time.RFC3339Nano,
})
case "logstash":
log.SetFormatter(&logstash.LogstashFormatter{
TimestampFormat: time.RFC3339Nano,
})
default:
// just let the library use default on empty string.
if config.Log.Formatter != "" {
return ctx, fmt.Errorf("unsupported logging formatter: %q", config.Log.Formatter)
}
}
if config.Log.Formatter != "" {
log.Debugf("using %q logging formatter", config.Log.Formatter)
}
if len(config.Log.Fields) > 0 {
// build up the static fields, if present.
var fields []interface{}
for k := range config.Log.Fields {
fields = append(fields, k)
}
ctx = context.WithValues(ctx, config.Log.Fields)
ctx = context.WithLogger(ctx, context.GetLogger(ctx, fields...))
}
return ctx, nil
}
func logLevel(level configuration.Loglevel) log.Level {
l, err := log.ParseLevel(string(level))
if err != nil {
l = log.InfoLevel
log.Warnf("error parsing level %q: %v, using %q ", level, err, l)
}
return l
}
// panicHandler wraps the given HTTP handler and recovers from any panic raised
// while serving a request. logrus.Panic forwards the panic message to the
// pre-configured log hooks defined in config.yml.
func panicHandler(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
log.Panic(fmt.Sprintf("%v", err))
}
}()
handler.ServeHTTP(w, r)
})
}
// alive simply wraps the handler with a route that always returns an http 200
// response when the path is matched. If the path is not matched, the request
// is passed to the provided handler. There is no guarantee of anything but
// that the server is up. Wrap with other handlers (such as health.Handler)
// for greater effect.
func alive(path string, handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == path {
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(http.StatusOK)
return
}
handler.ServeHTTP(w, r)
})
}
func resolveConfiguration(args []string) (*configuration.Configuration, error) {
var configurationPath string
if len(args) > 0 {
configurationPath = args[0]
} else if os.Getenv("REGISTRY_CONFIGURATION_PATH") != "" {
configurationPath = os.Getenv("REGISTRY_CONFIGURATION_PATH")
}
if configurationPath == "" {
return nil, fmt.Errorf("configuration path unspecified")
}
fp, err := os.Open(configurationPath)
if err != nil {
return nil, err
}
defer fp.Close()
config, err := configuration.Parse(fp)
if err != nil {
return nil, fmt.Errorf("error parsing %s: %v", configurationPath, err)
}
return config, nil
}
func nextProtos(config *configuration.Configuration) []string {
switch config.HTTP.HTTP2.Disabled {
case true:
return []string{"http/1.1"}
default:
return []string{"h2", "http/1.1"}
}
}
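A hedged sketch of embedding the registry from another Go program rather than running `registry serve`; the YAML literal is a minimal, hypothetical configuration and the function name is illustrative.
package registry
import (
	"strings"
	"github.com/docker/distribution/configuration"
	"github.com/docker/distribution/context"
	"github.com/docker/distribution/version"
)
// runEmbedded builds a Registry from an in-process configuration and serves
// it, mirroring what ServeCmd does after resolveConfiguration. Illustrative only.
func runEmbedded() error {
	configYAML := `
version: 0.1
storage:
  inmemory: {}
http:
  addr: :5000
`
	config, err := configuration.Parse(strings.NewReader(configYAML))
	if err != nil {
		return err
	}
	ctx := context.WithVersion(context.Background(), version.Version)
	registry, err := NewRegistry(ctx, config)
	if err != nil {
		return err
	}
	// Blocks, serving the registry API on config.HTTP.Addr.
	return registry.ListenAndServe()
}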

View File

@ -0,0 +1,30 @@
package registry
import (
"reflect"
"testing"
"github.com/docker/distribution/configuration"
)
// Tests to ensure nextProtos returns the correct protocols when:
// * config.HTTP.HTTP2.Disabled is not explicitly set => [h2 http/1.1]
// * config.HTTP.HTTP2.Disabled is explicitly set to false => [h2 http/1.1]
// * config.HTTP.HTTP2.Disabled is explicitly set to true => [http/1.1]
func TestNextProtos(t *testing.T) {
config := &configuration.Configuration{}
protos := nextProtos(config)
if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
}
config.HTTP.HTTP2.Disabled = false
protos = nextProtos(config)
if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
}
config.HTTP.HTTP2.Disabled = true
protos = nextProtos(config)
if !reflect.DeepEqual(protos, []string{"http/1.1"}) {
t.Fatalf("expected protos to equal [http/1.1], got %s", protos)
}
}

View File

@ -0,0 +1,84 @@
package registry
import (
"fmt"
"os"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/docker/distribution/version"
"github.com/docker/libtrust"
"github.com/spf13/cobra"
)
var showVersion bool
func init() {
RootCmd.AddCommand(ServeCmd)
RootCmd.AddCommand(GCCmd)
GCCmd.Flags().BoolVarP(&dryRun, "dry-run", "d", false, "do everything except remove the blobs")
RootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "show the version and exit")
}
// RootCmd is the main command for the 'registry' binary.
var RootCmd = &cobra.Command{
Use: "registry",
Short: "`registry`",
Long: "`registry`",
Run: func(cmd *cobra.Command, args []string) {
if showVersion {
version.PrintVersion()
return
}
cmd.Usage()
},
}
var dryRun bool
// GCCmd is the cobra command that corresponds to the garbage-collect subcommand
var GCCmd = &cobra.Command{
Use: "garbage-collect <config>",
Short: "`garbage-collect` deletes layers not referenced by any manifests",
Long: "`garbage-collect` deletes layers not referenced by any manifests",
Run: func(cmd *cobra.Command, args []string) {
config, err := resolveConfiguration(args)
if err != nil {
fmt.Fprintf(os.Stderr, "configuration error: %v\n", err)
cmd.Usage()
os.Exit(1)
}
driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters())
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
os.Exit(1)
}
ctx := context.Background()
ctx, err = configureLogging(ctx, config)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to configure logging with config: %s", err)
os.Exit(1)
}
k, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
fmt.Fprint(os.Stderr, err)
os.Exit(1)
}
registry, err := storage.NewRegistry(ctx, driver, storage.Schema1SigningKey(k))
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct registry: %v", err)
os.Exit(1)
}
err = storage.MarkAndSweep(ctx, driver, registry, dryRun)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to garbage collect: %v", err)
os.Exit(1)
}
},
}

View File

@ -0,0 +1,614 @@
package storage
import (
"bytes"
"crypto/sha256"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"reflect"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver/testdriver"
"github.com/docker/distribution/testutil"
)
// TestWriteSeek tests that the current file size can be
// obtained using Seek
func TestWriteSeek(t *testing.T) {
ctx := context.Background()
imageName, _ := reference.ParseNamed("foo/bar")
driver := testdriver.New()
registry, err := NewRegistry(ctx, driver, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
repository, err := registry.Repository(ctx, imageName)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
bs := repository.Blobs(ctx)
blobUpload, err := bs.Create(ctx)
if err != nil {
t.Fatalf("unexpected error starting layer upload: %s", err)
}
contents := []byte{1, 2, 3}
blobUpload.Write(contents)
blobUpload.Close()
offset := blobUpload.Size()
if offset != int64(len(contents)) {
t.Fatalf("unexpected value for blobUpload offset: %v != %v", offset, len(contents))
}
}
// TestSimpleBlobUpload covers the blob upload process, exercising common
// error paths that might be seen during an upload.
func TestSimpleBlobUpload(t *testing.T) {
randomDataReader, dgst, err := testutil.CreateRandomTarFile()
if err != nil {
t.Fatalf("error creating random reader: %v", err)
}
ctx := context.Background()
imageName, _ := reference.ParseNamed("foo/bar")
driver := testdriver.New()
registry, err := NewRegistry(ctx, driver, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
repository, err := registry.Repository(ctx, imageName)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
bs := repository.Blobs(ctx)
h := sha256.New()
rd := io.TeeReader(randomDataReader, h)
blobUpload, err := bs.Create(ctx)
if err != nil {
t.Fatalf("unexpected error starting layer upload: %s", err)
}
// Cancel the upload then restart it
if err := blobUpload.Cancel(ctx); err != nil {
t.Fatalf("unexpected error during upload cancellation: %v", err)
}
// get the enclosing directory
uploadPath := path.Dir(blobUpload.(*blobWriter).path)
// ensure state was cleaned up
_, err = driver.List(ctx, uploadPath)
if err == nil {
t.Fatal("files in upload path after cleanup")
}
// Do a resume, get unknown upload
blobUpload, err = bs.Resume(ctx, blobUpload.ID())
if err != distribution.ErrBlobUploadUnknown {
t.Fatalf("unexpected error resuming upload, should be unknown: %v", err)
}
// Restart!
blobUpload, err = bs.Create(ctx)
if err != nil {
t.Fatalf("unexpected error starting layer upload: %s", err)
}
// Get the size of our random tarfile
randomDataSize, err := seekerSize(randomDataReader)
if err != nil {
t.Fatalf("error getting seeker size of random data: %v", err)
}
nn, err := io.Copy(blobUpload, rd)
if err != nil {
t.Fatalf("unexpected error uploading layer data: %v", err)
}
if nn != randomDataSize {
t.Fatalf("layer data write incomplete")
}
blobUpload.Close()
offset := blobUpload.Size()
if offset != nn {
t.Fatalf("blobUpload not updated with correct offset: %v != %v", offset, nn)
}
// Do a resume, for good fun
blobUpload, err = bs.Resume(ctx, blobUpload.ID())
if err != nil {
t.Fatalf("unexpected error resuming upload: %v", err)
}
sha256Digest := digest.NewDigest("sha256", h)
desc, err := blobUpload.Commit(ctx, distribution.Descriptor{Digest: dgst})
if err != nil {
t.Fatalf("unexpected error finishing layer upload: %v", err)
}
// ensure state was cleaned up
uploadPath = path.Dir(blobUpload.(*blobWriter).path)
_, err = driver.List(ctx, uploadPath)
if err == nil {
t.Fatal("files in upload path after commit")
}
// After finishing an upload, it should no longer exist.
if _, err := bs.Resume(ctx, blobUpload.ID()); err != distribution.ErrBlobUploadUnknown {
t.Fatalf("expected layer upload to be unknown, got %v", err)
}
// Test for existence.
statDesc, err := bs.Stat(ctx, desc.Digest)
if err != nil {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs)
}
if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
}
rc, err := bs.Open(ctx, desc.Digest)
if err != nil {
t.Fatalf("unexpected error opening blob for read: %v", err)
}
defer rc.Close()
h.Reset()
nn, err = io.Copy(h, rc)
if err != nil {
t.Fatalf("error reading layer: %v", err)
}
if nn != randomDataSize {
t.Fatalf("incorrect read length")
}
if digest.NewDigest("sha256", h) != sha256Digest {
t.Fatalf("unexpected digest from uploaded layer: %q != %q", digest.NewDigest("sha256", h), sha256Digest)
}
// Delete a blob
err = bs.Delete(ctx, desc.Digest)
if err != nil {
t.Fatalf("Unexpected error deleting blob")
}
d, err := bs.Stat(ctx, desc.Digest)
if err == nil {
t.Fatalf("unexpected non-error stating deleted blob: %v", d)
}
switch err {
case distribution.ErrBlobUnknown:
break
default:
t.Errorf("Unexpected error type stat-ing deleted manifest: %#v", err)
}
_, err = bs.Open(ctx, desc.Digest)
if err == nil {
t.Fatalf("unexpected success opening deleted blob for read")
}
switch err {
case distribution.ErrBlobUnknown:
break
default:
t.Errorf("Unexpected error type getting deleted manifest: %#v", err)
}
// Re-upload the blob
randomBlob, err := ioutil.ReadAll(randomDataReader)
if err != nil {
t.Fatalf("Error reading all of blob %s", err.Error())
}
expectedDigest := digest.FromBytes(randomBlob)
simpleUpload(t, bs, randomBlob, expectedDigest)
d, err = bs.Stat(ctx, expectedDigest)
if err != nil {
t.Errorf("unexpected error stat-ing blob")
}
if d.Digest != expectedDigest {
t.Errorf("Mismatching digest with restored blob")
}
_, err = bs.Open(ctx, expectedDigest)
if err != nil {
t.Errorf("Unexpected error opening blob")
}
// Reuse state to test delete with a delete-disabled registry
registry, err = NewRegistry(ctx, driver, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
repository, err = registry.Repository(ctx, imageName)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
bs = repository.Blobs(ctx)
err = bs.Delete(ctx, desc.Digest)
if err == nil {
t.Errorf("Unexpected success deleting while disabled")
}
}
// TestSimpleBlobRead just creates a simple blob file and ensures that basic
// open, read, seek, read works. More specific edge cases should be covered in
// other tests.
func TestSimpleBlobRead(t *testing.T) {
ctx := context.Background()
imageName, _ := reference.ParseNamed("foo/bar")
driver := testdriver.New()
registry, err := NewRegistry(ctx, driver, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
repository, err := registry.Repository(ctx, imageName)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
bs := repository.Blobs(ctx)
randomLayerReader, dgst, err := testutil.CreateRandomTarFile() // TODO(stevvooe): Consider using just a random string.
if err != nil {
t.Fatalf("error creating random data: %v", err)
}
// Test for existence.
desc, err := bs.Stat(ctx, dgst)
if err != distribution.ErrBlobUnknown {
t.Fatalf("expected not found error when testing for existence: %v", err)
}
rc, err := bs.Open(ctx, dgst)
if err != distribution.ErrBlobUnknown {
t.Fatalf("expected not found error when opening non-existent blob: %v", err)
}
randomLayerSize, err := seekerSize(randomLayerReader)
if err != nil {
t.Fatalf("error getting seeker size for random layer: %v", err)
}
descBefore := distribution.Descriptor{Digest: dgst, MediaType: "application/octet-stream", Size: randomLayerSize}
t.Logf("desc: %v", descBefore)
desc, err = addBlob(ctx, bs, descBefore, randomLayerReader)
if err != nil {
t.Fatalf("error adding blob to blobservice: %v", err)
}
if desc.Size != randomLayerSize {
t.Fatalf("committed blob has incorrect length: %v != %v", desc.Size, randomLayerSize)
}
rc, err = bs.Open(ctx, desc.Digest) // note that we are opening with original digest.
if err != nil {
t.Fatalf("error opening blob with %v: %v", dgst, err)
}
defer rc.Close()
// Now check the sha digest and ensure its the same
h := sha256.New()
nn, err := io.Copy(h, rc)
if err != nil {
t.Fatalf("unexpected error copying to hash: %v", err)
}
if nn != randomLayerSize {
t.Fatalf("stored incorrect number of bytes in blob: %d != %d", nn, randomLayerSize)
}
sha256Digest := digest.NewDigest("sha256", h)
if sha256Digest != desc.Digest {
t.Fatalf("fetched digest does not match: %q != %q", sha256Digest, desc.Digest)
}
// Now seek back the blob, read the whole thing and check against randomLayerData
offset, err := rc.Seek(0, os.SEEK_SET)
if err != nil {
t.Fatalf("error seeking blob: %v", err)
}
if offset != 0 {
t.Fatalf("seek failed: expected 0 offset, got %d", offset)
}
p, err := ioutil.ReadAll(rc)
if err != nil {
t.Fatalf("error reading all of blob: %v", err)
}
if len(p) != int(randomLayerSize) {
t.Fatalf("blob data read has different length: %v != %v", len(p), randomLayerSize)
}
// Reset the randomLayerReader and read back the buffer
_, err = randomLayerReader.Seek(0, os.SEEK_SET)
if err != nil {
t.Fatalf("error resetting layer reader: %v", err)
}
randomLayerData, err := ioutil.ReadAll(randomLayerReader)
if err != nil {
t.Fatalf("random layer read failed: %v", err)
}
if !bytes.Equal(p, randomLayerData) {
t.Fatalf("layer data not equal")
}
}
// TestBlobMount covers the blob mount process, exercising common
// error paths that might be seen during a mount.
func TestBlobMount(t *testing.T) {
randomDataReader, dgst, err := testutil.CreateRandomTarFile()
if err != nil {
t.Fatalf("error creating random reader: %v", err)
}
ctx := context.Background()
imageName, _ := reference.ParseNamed("foo/bar")
sourceImageName, _ := reference.ParseNamed("foo/source")
driver := testdriver.New()
registry, err := NewRegistry(ctx, driver, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
repository, err := registry.Repository(ctx, imageName)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
sourceRepository, err := registry.Repository(ctx, sourceImageName)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
sbs := sourceRepository.Blobs(ctx)
blobUpload, err := sbs.Create(ctx)
if err != nil {
t.Fatalf("unexpected error starting layer upload: %s", err)
}
// Get the size of our random tarfile
randomDataSize, err := seekerSize(randomDataReader)
if err != nil {
t.Fatalf("error getting seeker size of random data: %v", err)
}
nn, err := io.Copy(blobUpload, randomDataReader)
if err != nil {
t.Fatalf("unexpected error uploading layer data: %v", err)
}
desc, err := blobUpload.Commit(ctx, distribution.Descriptor{Digest: dgst})
if err != nil {
t.Fatalf("unexpected error finishing layer upload: %v", err)
}
// Test for existence.
statDesc, err := sbs.Stat(ctx, desc.Digest)
if err != nil {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, sbs)
}
if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
}
bs := repository.Blobs(ctx)
// Test destination for existence.
statDesc, err = bs.Stat(ctx, desc.Digest)
if err == nil {
t.Fatalf("unexpected non-error stating unmounted blob: %v", desc)
}
canonicalRef, err := reference.WithDigest(sourceRepository.Named(), desc.Digest)
if err != nil {
t.Fatal(err)
}
bw, err := bs.Create(ctx, WithMountFrom(canonicalRef))
if bw != nil {
t.Fatal("unexpected blobwriter returned from Create call, should mount instead")
}
ebm, ok := err.(distribution.ErrBlobMounted)
if !ok {
t.Fatalf("unexpected error mounting layer: %v", err)
}
if !reflect.DeepEqual(ebm.Descriptor, desc) {
t.Fatalf("descriptors not equal: %v != %v", ebm.Descriptor, desc)
}
// Test for existence.
statDesc, err = bs.Stat(ctx, desc.Digest)
if err != nil {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs)
}
if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
}
rc, err := bs.Open(ctx, desc.Digest)
if err != nil {
t.Fatalf("unexpected error opening blob for read: %v", err)
}
defer rc.Close()
h := sha256.New()
nn, err = io.Copy(h, rc)
if err != nil {
t.Fatalf("error reading layer: %v", err)
}
if nn != randomDataSize {
t.Fatalf("incorrect read length")
}
if digest.NewDigest("sha256", h) != dgst {
t.Fatalf("unexpected digest from uploaded layer: %q != %q", digest.NewDigest("sha256", h), dgst)
}
// Delete the blob from the source repo
err = sbs.Delete(ctx, desc.Digest)
if err != nil {
t.Fatalf("Unexpected error deleting blob")
}
d, err := bs.Stat(ctx, desc.Digest)
if err != nil {
t.Fatalf("unexpected error stating blob deleted from source repository: %v", err)
}
d, err = sbs.Stat(ctx, desc.Digest)
if err == nil {
t.Fatalf("unexpected non-error stating deleted blob: %v", d)
}
switch err {
case distribution.ErrBlobUnknown:
break
default:
t.Errorf("Unexpected error type stat-ing deleted manifest: %#v", err)
}
// Delete the blob from the dest repo
err = bs.Delete(ctx, desc.Digest)
if err != nil {
t.Fatalf("Unexpected error deleting blob")
}
d, err = bs.Stat(ctx, desc.Digest)
if err == nil {
t.Fatalf("unexpected non-error stating deleted blob: %v", d)
}
switch err {
case distribution.ErrBlobUnknown:
break
default:
t.Errorf("Unexpected error type stat-ing deleted manifest: %#v", err)
}
}
// TestLayerUploadZeroLength uploads a zero-length blob and ensures it can be committed.
func TestLayerUploadZeroLength(t *testing.T) {
ctx := context.Background()
imageName, _ := reference.ParseNamed("foo/bar")
driver := testdriver.New()
registry, err := NewRegistry(ctx, driver, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
repository, err := registry.Repository(ctx, imageName)
if err != nil {
t.Fatalf("unexpected error getting repo: %v", err)
}
bs := repository.Blobs(ctx)
simpleUpload(t, bs, []byte{}, digest.DigestSha256EmptyTar)
}
func simpleUpload(t *testing.T, bs distribution.BlobIngester, blob []byte, expectedDigest digest.Digest) {
ctx := context.Background()
wr, err := bs.Create(ctx)
if err != nil {
t.Fatalf("unexpected error starting upload: %v", err)
}
nn, err := io.Copy(wr, bytes.NewReader(blob))
if err != nil {
t.Fatalf("error copying into blob writer: %v", err)
}
if nn != 0 {
t.Fatalf("unexpected number of bytes copied: %v > 0", nn)
}
dgst, err := digest.FromReader(bytes.NewReader(blob))
if err != nil {
t.Fatalf("error getting digest: %v", err)
}
if dgst != expectedDigest {
// sanity check on zero digest
t.Fatalf("digest not as expected: %v != %v", dgst, expectedDigest)
}
desc, err := wr.Commit(ctx, distribution.Descriptor{Digest: dgst})
if err != nil {
t.Fatalf("unexpected error committing write: %v", err)
}
if desc.Digest != dgst {
t.Fatalf("unexpected digest: %v != %v", desc.Digest, dgst)
}
}
// seekerSize seeks to the end of seeker, checks the size and returns it to
// the original state, returning the size. The state of the seeker should be
// treated as unknown if an error is returned.
func seekerSize(seeker io.ReadSeeker) (int64, error) {
current, err := seeker.Seek(0, os.SEEK_CUR)
if err != nil {
return 0, err
}
end, err := seeker.Seek(0, os.SEEK_END)
if err != nil {
return 0, err
}
resumed, err := seeker.Seek(current, os.SEEK_SET)
if err != nil {
return 0, err
}
if resumed != current {
return 0, fmt.Errorf("error returning seeker to original state, could not seek back to original location")
}
return end, nil
}
// addBlob simply consumes the reader and inserts into the blob service,
// returning a descriptor on success.
func addBlob(ctx context.Context, bs distribution.BlobIngester, desc distribution.Descriptor, rd io.Reader) (distribution.Descriptor, error) {
wr, err := bs.Create(ctx)
if err != nil {
return distribution.Descriptor{}, err
}
defer wr.Cancel(ctx)
if nn, err := io.Copy(wr, rd); err != nil {
return distribution.Descriptor{}, err
} else if nn != desc.Size {
return distribution.Descriptor{}, fmt.Errorf("incorrect number of bytes copied: %v != %v", nn, desc.Size)
}
return wr.Commit(ctx, desc)
}

View File

@ -0,0 +1,60 @@
package storage
import (
"expvar"
"sync/atomic"
"github.com/docker/distribution/registry/storage/cache"
)
type blobStatCollector struct {
metrics cache.Metrics
}
func (bsc *blobStatCollector) Hit() {
atomic.AddUint64(&bsc.metrics.Requests, 1)
atomic.AddUint64(&bsc.metrics.Hits, 1)
}
func (bsc *blobStatCollector) Miss() {
atomic.AddUint64(&bsc.metrics.Requests, 1)
atomic.AddUint64(&bsc.metrics.Misses, 1)
}
func (bsc *blobStatCollector) Metrics() cache.Metrics {
return bsc.metrics
}
// blobStatterCacheMetrics keeps track of cache metrics for blob descriptor
// cache requests. Note this is kept globally and made available via expvar.
// For more detailed metrics, it's recommended to instrument a particular cache
// implementation.
var blobStatterCacheMetrics cache.MetricsTracker = &blobStatCollector{}
func init() {
registry := expvar.Get("registry")
if registry == nil {
registry = expvar.NewMap("registry")
}
cache := registry.(*expvar.Map).Get("cache")
if cache == nil {
cache = &expvar.Map{}
cache.(*expvar.Map).Init()
registry.(*expvar.Map).Set("cache", cache)
}
storage := cache.(*expvar.Map).Get("storage")
if storage == nil {
storage = &expvar.Map{}
storage.(*expvar.Map).Init()
cache.(*expvar.Map).Set("storage", storage)
}
storage.(*expvar.Map).Set("blobdescriptor", expvar.Func(func() interface{} {
// no need for synchronous access: the increments are atomic and
// during reading, we don't care if the data is up to date. The
// numbers will always *eventually* be reported correctly.
return blobStatterCacheMetrics
}))
}
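A small, hypothetical helper showing where these counters surface: expvar renders the map registered in init above as JSON under registry.cache.storage.blobdescriptor (also reachable through the debug server's /debug/vars endpoint).
package storage
import (
	"expvar"
	"fmt"
)
// dumpCacheMetrics prints the expvar tree populated by init above; the blob
// descriptor cache counters appear under cache.storage.blobdescriptor.
// Illustrative only.
func dumpCacheMetrics() {
	if v := expvar.Get("registry"); v != nil {
		fmt.Println(v.String())
	}
}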

View File

@ -0,0 +1,78 @@
package storage
import (
"fmt"
"net/http"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/registry/storage/driver"
)
// TODO(stevvooe): This should be configurable in the future.
const blobCacheControlMaxAge = 365 * 24 * time.Hour
// blobServer simply serves blobs from a driver instance using a path function
// to identify paths and a descriptor service to fill in metadata.
type blobServer struct {
driver driver.StorageDriver
statter distribution.BlobStatter
pathFn func(dgst digest.Digest) (string, error)
redirect bool // allows disabling URLFor redirects
}
func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
desc, err := bs.statter.Stat(ctx, dgst)
if err != nil {
return err
}
path, err := bs.pathFn(desc.Digest)
if err != nil {
return err
}
if bs.redirect {
redirectURL, err := bs.driver.URLFor(ctx, path, map[string]interface{}{"method": r.Method})
switch err.(type) {
case nil:
// Redirect to storage URL.
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
return err
case driver.ErrUnsupportedMethod:
// Fallback to serving the content directly.
default:
// Some unexpected error.
return err
}
}
br, err := newFileReader(ctx, bs.driver, path, desc.Size)
if err != nil {
return err
}
defer br.Close()
w.Header().Set("ETag", fmt.Sprintf(`"%s"`, desc.Digest)) // If-None-Match handled by ServeContent
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%.f", blobCacheControlMaxAge.Seconds()))
if w.Header().Get("Docker-Content-Digest") == "" {
w.Header().Set("Docker-Content-Digest", desc.Digest.String())
}
if w.Header().Get("Content-Type") == "" {
// Set the content type if not already set.
w.Header().Set("Content-Type", desc.MediaType)
}
if w.Header().Get("Content-Length") == "" {
// Set the content length if not already set.
w.Header().Set("Content-Length", fmt.Sprint(desc.Size))
}
http.ServeContent(w, r, desc.Digest.String(), time.Time{}, br)
return nil
}
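A hypothetical adapter illustrating how a caller with an already-constructed *blobServer delegates to ServeBlob for a known digest; the handler name and error mapping are assumptions, and the real registry routes blob requests through the handlers package instead.
package storage
import (
	"net/http"
	"github.com/docker/distribution/context"
	"github.com/docker/distribution/digest"
)
// blobHandler wraps ServeBlob in a plain http.Handler. Illustrative only.
func blobHandler(ctx context.Context, bs *blobServer, dgst digest.Digest) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := bs.ServeBlob(ctx, w, r, dgst); err != nil {
			// A real handler would map distribution errors (for example
			// ErrBlobUnknown) to the proper API error codes.
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
	})
}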

View File

@ -0,0 +1,223 @@
package storage
import (
"path"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/registry/storage/driver"
)
// blobStore implements the read side of the blob store interface over a
// driver without enforcing per-repository membership. This object is
// intentionally a leaky abstraction, providing utility methods that support
// creating and traversing backend links.
type blobStore struct {
driver driver.StorageDriver
statter distribution.BlobStatter
}
var _ distribution.BlobProvider = &blobStore{}
// Get implements the BlobReadService.Get call.
func (bs *blobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) {
bp, err := bs.path(dgst)
if err != nil {
return nil, err
}
p, err := getContent(ctx, bs.driver, bp)
if err != nil {
switch err.(type) {
case driver.PathNotFoundError:
return nil, distribution.ErrBlobUnknown
}
return nil, err
}
return p, nil
}
func (bs *blobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) {
desc, err := bs.statter.Stat(ctx, dgst)
if err != nil {
return nil, err
}
path, err := bs.path(desc.Digest)
if err != nil {
return nil, err
}
return newFileReader(ctx, bs.driver, path, desc.Size)
}
// Put stores the content p in the blob store, calculating the digest. If the
// content is already present, only the digest will be returned. This should
// only be used for small objects, such as manifests. It is implemented as a convenience for other Put implementations.
func (bs *blobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) {
dgst := digest.FromBytes(p)
desc, err := bs.statter.Stat(ctx, dgst)
if err == nil {
// content already present
return desc, nil
} else if err != distribution.ErrBlobUnknown {
context.GetLogger(ctx).Errorf("blobStore: error stating content (%v): %v", dgst, err)
// real error, return it
return distribution.Descriptor{}, err
}
bp, err := bs.path(dgst)
if err != nil {
return distribution.Descriptor{}, err
}
// TODO(stevvooe): Write out mediatype here, as well.
return distribution.Descriptor{
Size: int64(len(p)),
// NOTE(stevvooe): The central blob store firewalls media types from
// other users. The caller should look this up and override the value
// for the specific repository.
MediaType: "application/octet-stream",
Digest: dgst,
}, bs.driver.PutContent(ctx, bp, p)
}
func (bs *blobStore) Enumerate(ctx context.Context, ingester func(dgst digest.Digest) error) error {
specPath, err := pathFor(blobsPathSpec{})
if err != nil {
return err
}
err = Walk(ctx, bs.driver, specPath, func(fileInfo driver.FileInfo) error {
// skip directories
if fileInfo.IsDir() {
return nil
}
currentPath := fileInfo.Path()
// we only want to parse paths that end with /data
_, fileName := path.Split(currentPath)
if fileName != "data" {
return nil
}
digest, err := digestFromPath(currentPath)
if err != nil {
return err
}
return ingester(digest)
})
return err
}
// path returns the canonical path for the blob identified by digest. The blob
// may or may not exist.
func (bs *blobStore) path(dgst digest.Digest) (string, error) {
bp, err := pathFor(blobDataPathSpec{
digest: dgst,
})
if err != nil {
return "", err
}
return bp, nil
}
// link links the path to the provided digest by writing the digest into the
// target file. Caller must ensure that the blob actually exists.
func (bs *blobStore) link(ctx context.Context, path string, dgst digest.Digest) error {
// The contents of the "link" file are the exact string contents of the
// digest, which is specified in that package.
return bs.driver.PutContent(ctx, path, []byte(dgst))
}
// readlink returns the linked digest at path.
func (bs *blobStore) readlink(ctx context.Context, path string) (digest.Digest, error) {
content, err := bs.driver.GetContent(ctx, path)
if err != nil {
return "", err
}
linked, err := digest.ParseDigest(string(content))
if err != nil {
return "", err
}
return linked, nil
}
// resolve reads the digest link at path and returns the blob store path.
func (bs *blobStore) resolve(ctx context.Context, path string) (string, error) {
dgst, err := bs.readlink(ctx, path)
if err != nil {
return "", err
}
return bs.path(dgst)
}
type blobStatter struct {
driver driver.StorageDriver
}
var _ distribution.BlobDescriptorService = &blobStatter{}
// Stat implements BlobStatter.Stat by returning the descriptor for the blob
// in the main blob store. If this method returns successfully, there is a
// strong guarantee that the blob exists and is available.
func (bs *blobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
path, err := pathFor(blobDataPathSpec{
digest: dgst,
})
if err != nil {
return distribution.Descriptor{}, err
}
fi, err := bs.driver.Stat(ctx, path)
if err != nil {
switch err := err.(type) {
case driver.PathNotFoundError:
return distribution.Descriptor{}, distribution.ErrBlobUnknown
default:
return distribution.Descriptor{}, err
}
}
if fi.IsDir() {
// NOTE(stevvooe): This represents a corruption situation. Somehow, we
// calculated a blob path and then detected a directory. We log the
// error and err on the side of not knowing about the blob.
context.GetLogger(ctx).Warnf("blob path should not be a directory: %q", path)
return distribution.Descriptor{}, distribution.ErrBlobUnknown
}
// TODO(stevvooe): Add method to resolve the mediatype. We can store and
// cache a "global" media type for the blob, even if a specific repo has a
// mediatype that overrides the main one.
return distribution.Descriptor{
Size: fi.Size(),
// NOTE(stevvooe): The central blob store firewalls media types from
// other users. The caller should look this up and override the value
// for the specific repository.
MediaType: "application/octet-stream",
Digest: dgst,
}, nil
}
func (bs *blobStatter) Clear(ctx context.Context, dgst digest.Digest) error {
return distribution.ErrUnsupported
}
func (bs *blobStatter) SetDescriptor(ctx context.Context, dgst digest.Digest, desc distribution.Descriptor) error {
return distribution.ErrUnsupported
}
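A short sketch of the Put/Get round trip described by the comments above, using the in-memory driver as a stand-in; blobStore and blobStatter are package-internal, so this only makes sense inside the storage package, and the payload is a placeholder.
package storage
import (
	"github.com/docker/distribution/context"
	"github.com/docker/distribution/registry/storage/driver/inmemory"
)
// blobStoreRoundTrip stores a small payload and reads it back by digest.
// Illustrative only.
func blobStoreRoundTrip() error {
	ctx := context.Background()
	d := inmemory.New()
	bs := &blobStore{
		driver:  d,
		statter: &blobStatter{driver: d},
	}
	// Put computes the digest, writes the content and returns a descriptor
	// with the generic octet-stream media type.
	desc, err := bs.Put(ctx, "application/octet-stream", []byte("hello"))
	if err != nil {
		return err
	}
	// Get retrieves the same content by digest.
	if _, err := bs.Get(ctx, desc.Digest); err != nil {
		return err
	}
	return nil
}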

View File

@ -0,0 +1,399 @@
package storage
import (
"errors"
"fmt"
"io"
"path"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
var (
errResumableDigestNotAvailable = errors.New("resumable digest not available")
)
// blobWriter is used to control the various aspects of resumable
// blob upload.
type blobWriter struct {
ctx context.Context
blobStore *linkedBlobStore
id string
startedAt time.Time
digester digest.Digester
written int64 // track the contiguous write
fileWriter storagedriver.FileWriter
driver storagedriver.StorageDriver
path string
resumableDigestEnabled bool
committed bool
}
var _ distribution.BlobWriter = &blobWriter{}
// ID returns the identifier for this upload.
func (bw *blobWriter) ID() string {
return bw.id
}
func (bw *blobWriter) StartedAt() time.Time {
return bw.startedAt
}
// Commit marks the upload as completed, returning a valid descriptor. The
// final size and digest are checked against the first descriptor provided.
func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
context.GetLogger(ctx).Debug("(*blobWriter).Commit")
if err := bw.fileWriter.Commit(); err != nil {
return distribution.Descriptor{}, err
}
bw.Close()
desc.Size = bw.Size()
canonical, err := bw.validateBlob(ctx, desc)
if err != nil {
return distribution.Descriptor{}, err
}
if err := bw.moveBlob(ctx, canonical); err != nil {
return distribution.Descriptor{}, err
}
if err := bw.blobStore.linkBlob(ctx, canonical, desc.Digest); err != nil {
return distribution.Descriptor{}, err
}
if err := bw.removeResources(ctx); err != nil {
return distribution.Descriptor{}, err
}
err = bw.blobStore.blobAccessController.SetDescriptor(ctx, canonical.Digest, canonical)
if err != nil {
return distribution.Descriptor{}, err
}
bw.committed = true
return canonical, nil
}
// Cancel the blob upload process, releasing any resources associated with
// the writer and canceling the operation.
func (bw *blobWriter) Cancel(ctx context.Context) error {
context.GetLogger(ctx).Debug("(*blobWriter).Cancel")
if err := bw.fileWriter.Cancel(); err != nil {
return err
}
if err := bw.Close(); err != nil {
context.GetLogger(ctx).Errorf("error closing blobwriter: %s", err)
}
if err := bw.removeResources(ctx); err != nil {
return err
}
return nil
}
func (bw *blobWriter) Size() int64 {
return bw.fileWriter.Size()
}
func (bw *blobWriter) Write(p []byte) (int, error) {
// Ensure that the current write offset matches how many bytes have been
// written to the digester. If not, we need to update the digest state to
// match the current write position.
if err := bw.resumeDigest(bw.blobStore.ctx); err != nil && err != errResumableDigestNotAvailable {
return 0, err
}
n, err := io.MultiWriter(bw.fileWriter, bw.digester.Hash()).Write(p)
bw.written += int64(n)
return n, err
}
func (bw *blobWriter) ReadFrom(r io.Reader) (n int64, err error) {
// Ensure that the current write offset matches how many bytes have been
// written to the digester. If not, we need to update the digest state to
// match the current write position.
if err := bw.resumeDigest(bw.blobStore.ctx); err != nil && err != errResumableDigestNotAvailable {
return 0, err
}
nn, err := io.Copy(io.MultiWriter(bw.fileWriter, bw.digester.Hash()), r)
bw.written += nn
return nn, err
}
func (bw *blobWriter) Close() error {
if bw.committed {
return errors.New("blobwriter close after commit")
}
if err := bw.storeHashState(bw.blobStore.ctx); err != nil && err != errResumableDigestNotAvailable {
return err
}
return bw.fileWriter.Close()
}
// validateBlob checks the data against the digest, returning an error if it
// does not match. The canonical descriptor is returned.
func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
var (
verified, fullHash bool
canonical digest.Digest
)
if desc.Digest == "" {
// if no descriptors are provided, we have nothing to validate
// against. We don't really want to support this for the registry.
return distribution.Descriptor{}, distribution.ErrBlobInvalidDigest{
Reason: fmt.Errorf("cannot validate against empty digest"),
}
}
var size int64
// Stat the on disk file
if fi, err := bw.driver.Stat(ctx, bw.path); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
// NOTE(stevvooe): We really don't care if the file is
// not actually present for the reader. We now assume
// that the desc length is zero.
desc.Size = 0
default:
// Any other error we want propagated up the stack.
return distribution.Descriptor{}, err
}
} else {
if fi.IsDir() {
return distribution.Descriptor{}, fmt.Errorf("unexpected directory at upload location %q", bw.path)
}
size = fi.Size()
}
if desc.Size > 0 {
if desc.Size != size {
return distribution.Descriptor{}, distribution.ErrBlobInvalidLength
}
} else {
// if provided 0 or negative length, we can assume caller doesn't know or
// care about length.
desc.Size = size
}
// TODO(stevvooe): This section is very meandering. Need to be broken down
// to be a lot more clear.
if err := bw.resumeDigest(ctx); err == nil {
canonical = bw.digester.Digest()
if canonical.Algorithm() == desc.Digest.Algorithm() {
// Common case: client and server prefer the same canonical digest
// algorithm - currently SHA256.
verified = desc.Digest == canonical
} else {
// The client wants to use a different digest algorithm. They'll just
// have to be patient and wait for us to download and re-hash the
// uploaded content using that digest algorithm.
fullHash = true
}
} else if err == errResumableDigestNotAvailable {
// Not using resumable digests, so we need to hash the entire layer.
fullHash = true
} else {
return distribution.Descriptor{}, err
}
if fullHash {
// a fantastic optimization: if the written data and the size are
// the same, we don't need to read the data from the backend. This is
// because we've written the entire file in the lifecycle of the
// current instance.
if bw.written == size && digest.Canonical == desc.Digest.Algorithm() {
canonical = bw.digester.Digest()
verified = desc.Digest == canonical
}
// If the check based on size fails, we fall back to the slowest of
// paths. We may be able to make the size-based check a stronger
// guarantee, so this may be defensive.
if !verified {
digester := digest.Canonical.New()
digestVerifier, err := digest.NewDigestVerifier(desc.Digest)
if err != nil {
return distribution.Descriptor{}, err
}
// Read the file from the backend driver and validate it.
fr, err := newFileReader(ctx, bw.driver, bw.path, desc.Size)
if err != nil {
return distribution.Descriptor{}, err
}
defer fr.Close()
tr := io.TeeReader(fr, digester.Hash())
if _, err := io.Copy(digestVerifier, tr); err != nil {
return distribution.Descriptor{}, err
}
canonical = digester.Digest()
verified = digestVerifier.Verified()
}
}
if !verified {
context.GetLoggerWithFields(ctx,
map[interface{}]interface{}{
"canonical": canonical,
"provided": desc.Digest,
}, "canonical", "provided").
Errorf("canonical digest does match provided digest")
return distribution.Descriptor{}, distribution.ErrBlobInvalidDigest{
Digest: desc.Digest,
Reason: fmt.Errorf("content does not match digest"),
}
}
// update desc with canonical hash
desc.Digest = canonical
if desc.MediaType == "" {
desc.MediaType = "application/octet-stream"
}
return desc, nil
}
// moveBlob moves the data into its final, hash-qualified destination,
// identified by dgst. The layer should be validated before commencing the
// move.
func (bw *blobWriter) moveBlob(ctx context.Context, desc distribution.Descriptor) error {
blobPath, err := pathFor(blobDataPathSpec{
digest: desc.Digest,
})
if err != nil {
return err
}
// Check for existence
if _, err := bw.blobStore.driver.Stat(ctx, blobPath); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
break // ensure that it doesn't exist.
default:
return err
}
} else {
// If the path exists, we can assume that the content has already
// been uploaded, since the blob storage is content-addressable.
// While it may be corrupted, detection of such corruption belongs
// elsewhere.
return nil
}
// If no data was received, we may not actually have a file on disk. Check
// the size here and write a zero-length file to blobPath if this is the
// case. For the most part, this should only ever happen with zero-length
// tars.
if _, err := bw.blobStore.driver.Stat(ctx, bw.path); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
// HACK(stevvooe): This is slightly dangerous: if we verify above,
// get a hash, then the underlying file is deleted, we risk moving
// a zero-length blob into a nonzero-length blob location. To
// prevent this horrid thing, we employ the hack of only allowing
// this to happen for the digest of an empty tar.
if desc.Digest == digest.DigestSha256EmptyTar {
return bw.blobStore.driver.PutContent(ctx, blobPath, []byte{})
}
// We let this fail during the move below.
logrus.
WithField("upload.id", bw.ID()).
WithField("digest", desc.Digest).Warnf("attempted to move zero-length content with non-zero digest")
default:
return err // unrelated error
}
}
// TODO(stevvooe): We should also write the mediatype when executing this move.
return bw.blobStore.driver.Move(ctx, bw.path, blobPath)
}
// removeResources should clean up all resources associated with the upload
// instance. An error will be returned if the clean up cannot proceed. If the
// resources are already not present, no error will be returned.
func (bw *blobWriter) removeResources(ctx context.Context) error {
dataPath, err := pathFor(uploadDataPathSpec{
name: bw.blobStore.repository.Named().Name(),
id: bw.id,
})
if err != nil {
return err
}
// Resolve and delete the containing directory, which should include any
// upload related files.
dirPath := path.Dir(dataPath)
if err := bw.blobStore.driver.Delete(ctx, dirPath); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
break // already gone!
default:
// This should be uncommon enough such that returning an error
// should be okay. At this point, the upload should be mostly
// complete, but perhaps the backend became inaccessible.
context.GetLogger(ctx).Errorf("unable to delete layer upload resources %q: %v", dirPath, err)
return err
}
}
return nil
}
func (bw *blobWriter) Reader() (io.ReadCloser, error) {
// TODO(richardscothern): Change to exponential backoff, i=0.5, e=2, n=4
try := 1
for try <= 5 {
_, err := bw.driver.Stat(bw.ctx, bw.path)
if err == nil {
break
}
switch err.(type) {
case storagedriver.PathNotFoundError:
context.GetLogger(bw.ctx).Debugf("Nothing found on try %d, sleeping...", try)
time.Sleep(1 * time.Second)
try++
default:
return nil, err
}
}
readCloser, err := bw.driver.Reader(bw.ctx, bw.path, 0)
if err != nil {
return nil, err
}
return readCloser, nil
}

View File

@ -0,0 +1,17 @@
// +build noresumabledigest
package storage
import (
"github.com/docker/distribution/context"
)
// resumeDigest is a noop when resumable digest support is disabled.
func (bw *blobWriter) resumeDigest(ctx context.Context) error {
return errResumableDigestNotAvailable
}
// storeHashState is a noop when resumable digest support is disabled.
func (bw *blobWriter) storeHashState(ctx context.Context) error {
return errResumableDigestNotAvailable
}

View File

@ -0,0 +1,145 @@
// +build !noresumabledigest
package storage
import (
"fmt"
"path"
"strconv"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/stevvooe/resumable"
// register resumable hashes with import
_ "github.com/stevvooe/resumable/sha256"
_ "github.com/stevvooe/resumable/sha512"
)
// resumeDigest attempts to restore the state of the internal hash function
// by loading the most recent saved hash state equal to the current size of the blob.
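// As a sketch of the behaviour implemented below: a state can only be
// restored when one was stored at exactly the current blob size. For a
// 4096-byte partial upload, a saved state at offset 4096 is loaded and
// hashing resumes from there; if the closest saved state sits at any other
// offset, the gap check fails with errResumableDigestNotAvailable and the
// caller falls back to computing the digest without resuming.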
func (bw *blobWriter) resumeDigest(ctx context.Context) error {
if !bw.resumableDigestEnabled {
return errResumableDigestNotAvailable
}
h, ok := bw.digester.Hash().(resumable.Hash)
if !ok {
return errResumableDigestNotAvailable
}
offset := bw.fileWriter.Size()
if offset == int64(h.Len()) {
// State of digester is already at the requested offset.
return nil
}
// List hash states from storage backend.
var hashStateMatch hashStateEntry
hashStates, err := bw.getStoredHashStates(ctx)
if err != nil {
return fmt.Errorf("unable to get stored hash states with offset %d: %s", offset, err)
}
// Find the highest stored hashState with offset equal to
// the requested offset.
for _, hashState := range hashStates {
if hashState.offset == offset {
hashStateMatch = hashState
break // Found an exact offset match.
}
}
if hashStateMatch.offset == 0 {
// No need to load any state, just reset the hasher.
h.Reset()
} else {
storedState, err := bw.driver.GetContent(ctx, hashStateMatch.path)
if err != nil {
return err
}
if err = h.Restore(storedState); err != nil {
return err
}
}
// Mind the gap.
if gapLen := offset - int64(h.Len()); gapLen > 0 {
return errResumableDigestNotAvailable
}
return nil
}
type hashStateEntry struct {
offset int64
path string
}
// getStoredHashStates returns a slice of hashStateEntries for this upload.
func (bw *blobWriter) getStoredHashStates(ctx context.Context) ([]hashStateEntry, error) {
uploadHashStatePathPrefix, err := pathFor(uploadHashStatePathSpec{
name: bw.blobStore.repository.Named().String(),
id: bw.id,
alg: bw.digester.Digest().Algorithm(),
list: true,
})
if err != nil {
return nil, err
}
paths, err := bw.blobStore.driver.List(ctx, uploadHashStatePathPrefix)
if err != nil {
if _, ok := err.(storagedriver.PathNotFoundError); !ok {
return nil, err
}
// Treat PathNotFoundError as no entries.
paths = nil
}
hashStateEntries := make([]hashStateEntry, 0, len(paths))
for _, p := range paths {
pathSuffix := path.Base(p)
// The suffix should be the offset.
offset, err := strconv.ParseInt(pathSuffix, 0, 64)
if err != nil {
logrus.Errorf("unable to parse offset from upload state path %q: %s", p, err)
}
hashStateEntries = append(hashStateEntries, hashStateEntry{offset: offset, path: p})
}
return hashStateEntries, nil
}
func (bw *blobWriter) storeHashState(ctx context.Context) error {
if !bw.resumableDigestEnabled {
return errResumableDigestNotAvailable
}
h, ok := bw.digester.Hash().(resumable.Hash)
if !ok {
return errResumableDigestNotAvailable
}
uploadHashStatePath, err := pathFor(uploadHashStatePathSpec{
name: bw.blobStore.repository.Named().String(),
id: bw.id,
alg: bw.digester.Digest().Algorithm(),
offset: int64(h.Len()),
})
if err != nil {
return err
}
hashState, err := h.State()
if err != nil {
return err
}
return bw.driver.PutContent(ctx, uploadHashStatePath, hashState)
}

View File

@ -0,0 +1,180 @@
package cachecheck
import (
"reflect"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/registry/storage/cache"
)
// CheckBlobDescriptorCache takes a cache implementation through a common set
// of operations. If adding new tests, please add them here so new
// implementations get the benefit. This should be used for unit tests.
func CheckBlobDescriptorCache(t *testing.T, provider cache.BlobDescriptorCacheProvider) {
ctx := context.Background()
checkBlobDescriptorCacheEmptyRepository(ctx, t, provider)
checkBlobDescriptorCacheSetAndRead(ctx, t, provider)
checkBlobDescriptorCacheClear(ctx, t, provider)
}
func checkBlobDescriptorCacheEmptyRepository(ctx context.Context, t *testing.T, provider cache.BlobDescriptorCacheProvider) {
if _, err := provider.Stat(ctx, "sha384:abc111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"); err != distribution.ErrBlobUnknown {
t.Fatalf("expected unknown blob error with empty store: %v", err)
}
cache, err := provider.RepositoryScoped("")
if err == nil {
t.Fatalf("expected an error when asking for invalid repo")
}
cache, err = provider.RepositoryScoped("foo/bar")
if err != nil {
t.Fatalf("unexpected error getting repository: %v", err)
}
if err := cache.SetDescriptor(ctx, "", distribution.Descriptor{
Digest: "sha384:abc",
Size: 10,
MediaType: "application/octet-stream"}); err != digest.ErrDigestInvalidFormat {
t.Fatalf("expected error with invalid digest: %v", err)
}
if err := cache.SetDescriptor(ctx, "sha384:abc111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111", distribution.Descriptor{
Digest: "",
Size: 10,
MediaType: "application/octet-stream"}); err == nil {
t.Fatalf("expected error setting value on invalid descriptor")
}
if _, err := cache.Stat(ctx, ""); err != digest.ErrDigestInvalidFormat {
t.Fatalf("expected error checking for cache item with empty digest: %v", err)
}
if _, err := cache.Stat(ctx, "sha384:abc111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"); err != distribution.ErrBlobUnknown {
t.Fatalf("expected unknown blob error with empty repo: %v", err)
}
}
func checkBlobDescriptorCacheSetAndRead(ctx context.Context, t *testing.T, provider cache.BlobDescriptorCacheProvider) {
localDigest := digest.Digest("sha384:abc111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111")
expected := distribution.Descriptor{
Digest: "sha256:abc1111111111111111111111111111111111111111111111111111111111111",
Size: 10,
MediaType: "application/octet-stream"}
cache, err := provider.RepositoryScoped("foo/bar")
if err != nil {
t.Fatalf("unexpected error getting scoped cache: %v", err)
}
if err := cache.SetDescriptor(ctx, localDigest, expected); err != nil {
t.Fatalf("error setting descriptor: %v", err)
}
desc, err := cache.Stat(ctx, localDigest)
if err != nil {
t.Fatalf("unexpected error statting fake2:abc: %v", err)
}
if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
// also check that we set the canonical key ("fake:abc")
desc, err = cache.Stat(ctx, localDigest)
if err != nil {
t.Fatalf("descriptor not returned for canonical key: %v", err)
}
if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
// ensure that global gets extra descriptor mapping
desc, err = provider.Stat(ctx, localDigest)
if err != nil {
t.Fatalf("expected blob unknown in global cache: %v, %v", err, desc)
}
if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
// get at it through canonical descriptor
desc, err = provider.Stat(ctx, expected.Digest)
if err != nil {
t.Fatalf("unexpected error checking glboal descriptor: %v", err)
}
if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
// now, we set the repo local mediatype to something else and ensure it
// doesn't get changed in the provider cache.
expected.MediaType = "application/json"
if err := cache.SetDescriptor(ctx, localDigest, expected); err != nil {
t.Fatalf("unexpected error setting descriptor: %v", err)
}
desc, err = cache.Stat(ctx, localDigest)
if err != nil {
t.Fatalf("unexpected error getting descriptor: %v", err)
}
if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected)
}
desc, err = provider.Stat(ctx, localDigest)
if err != nil {
t.Fatalf("unexpected error getting global descriptor: %v", err)
}
expected.MediaType = "application/octet-stream" // expect original mediatype in global
if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected)
}
}
func checkBlobDescriptorCacheClear(ctx context.Context, t *testing.T, provider cache.BlobDescriptorCacheProvider) {
localDigest := digest.Digest("sha384:def111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111")
expected := distribution.Descriptor{
Digest: "sha256:def1111111111111111111111111111111111111111111111111111111111111",
Size: 10,
MediaType: "application/octet-stream"}
cache, err := provider.RepositoryScoped("foo/bar")
if err != nil {
t.Fatalf("unexpected error getting scoped cache: %v", err)
}
if err := cache.SetDescriptor(ctx, localDigest, expected); err != nil {
t.Fatalf("error setting descriptor: %v", err)
}
desc, err := cache.Stat(ctx, localDigest)
if err != nil {
t.Fatalf("unexpected error statting fake2:abc: %v", err)
}
if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
err = cache.Clear(ctx, localDigest)
if err != nil {
t.Error(err)
}
desc, err = cache.Stat(ctx, localDigest)
if err == nil {
t.Fatalf("expected error statting deleted blob: %v", err)
}
}

View File

@ -2,7 +2,7 @@ package cache
import (
"github.com/docker/distribution/context"
"github.com/opencontainers/go-digest"
"github.com/docker/distribution/digest"
"github.com/docker/distribution"
)

View File

@ -5,9 +5,9 @@ import (
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache"
"github.com/opencontainers/go-digest"
)
type inMemoryBlobDescriptorCacheProvider struct {
@ -26,7 +26,7 @@ func NewInMemoryBlobDescriptorCacheProvider() cache.BlobDescriptorCacheProvider
}
func (imbdcp *inMemoryBlobDescriptorCacheProvider) RepositoryScoped(repo string) (distribution.BlobDescriptorService, error) {
if _, err := reference.ParseNormalizedNamed(repo); err != nil {
if _, err := reference.ParseNamed(repo); err != nil {
return nil, err
}

View File

@ -0,0 +1,13 @@
package memory
import (
"testing"
"github.com/docker/distribution/registry/storage/cache/cachecheck"
)
// TestInMemoryBlobInfoCache checks the in memory implementation is working
// correctly.
func TestInMemoryBlobInfoCache(t *testing.T) {
cachecheck.CheckBlobDescriptorCache(t, NewInMemoryBlobDescriptorCacheProvider())
}

View File

@ -0,0 +1,268 @@
package redis
import (
"fmt"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache"
"github.com/garyburd/redigo/redis"
)
// redisBlobStatService provides an implementation of
// BlobDescriptorCacheProvider based on redis. Blob descriptors are stored in
// two parts. The first provides fast access to repository membership through a
// redis set for each repo. The second is a redis hash keyed by the digest of
// the layer, providing path, length and mediatype information. There is also
// a per-repository redis hash of the blob descriptor, allowing override of
// data. This is currently used to override the mediatype on a per-repository
// basis.
//
// Note that there is no implied relationship between these two caches. The
// layer may exist in one, both or none and the code must be written this way.
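//
// A sketch of the key layout, derived from the helpers below
// (blobDescriptorHashKey and repositoryBlobSetKey):
//
//	blobs::<digest>                      hash of digest, size and mediatype
//	repository::<name>::blobs            set of digests known to the repository
//	repository::<name>::blobs::<digest>  hash carrying the per-repository mediatype override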
type redisBlobDescriptorService struct {
pool *redis.Pool
// TODO(stevvooe): We use a pool because we don't have great control over
// the cache lifecycle to manage connections. A new connection is fetched
// for each operation. Once we have better lifecycle management of the
// request objects, we can change this to a connection.
}
// NewRedisBlobDescriptorCacheProvider returns a new redis-based
// BlobDescriptorCacheProvider using the provided redis connection pool.
func NewRedisBlobDescriptorCacheProvider(pool *redis.Pool) cache.BlobDescriptorCacheProvider {
return &redisBlobDescriptorService{
pool: pool,
}
}
// RepositoryScoped returns the scoped cache.
func (rbds *redisBlobDescriptorService) RepositoryScoped(repo string) (distribution.BlobDescriptorService, error) {
if _, err := reference.ParseNamed(repo); err != nil {
return nil, err
}
return &repositoryScopedRedisBlobDescriptorService{
repo: repo,
upstream: rbds,
}, nil
}
// Stat retrieves the descriptor data from the redis hash entry.
func (rbds *redisBlobDescriptorService) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
if err := dgst.Validate(); err != nil {
return distribution.Descriptor{}, err
}
conn := rbds.pool.Get()
defer conn.Close()
return rbds.stat(ctx, conn, dgst)
}
func (rbds *redisBlobDescriptorService) Clear(ctx context.Context, dgst digest.Digest) error {
if err := dgst.Validate(); err != nil {
return err
}
conn := rbds.pool.Get()
defer conn.Close()
// Not atomic in redis <= 2.3
reply, err := conn.Do("HDEL", rbds.blobDescriptorHashKey(dgst), "digest", "length", "mediatype")
if err != nil {
return err
}
if reply == 0 {
return distribution.ErrBlobUnknown
}
return nil
}
// stat provides an internal stat call that takes a connection parameter. This
// allows some internal management of the connection scope.
func (rbds *redisBlobDescriptorService) stat(ctx context.Context, conn redis.Conn, dgst digest.Digest) (distribution.Descriptor, error) {
reply, err := redis.Values(conn.Do("HMGET", rbds.blobDescriptorHashKey(dgst), "digest", "size", "mediatype"))
if err != nil {
return distribution.Descriptor{}, err
}
// NOTE(stevvooe): The "size" field used to be "length". We treat a
// missing "size" field here as an unknown blob, which causes a cache
// miss, effectively migrating the field.
if len(reply) < 3 || reply[0] == nil || reply[1] == nil { // don't care if mediatype is nil
return distribution.Descriptor{}, distribution.ErrBlobUnknown
}
var desc distribution.Descriptor
if _, err := redis.Scan(reply, &desc.Digest, &desc.Size, &desc.MediaType); err != nil {
return distribution.Descriptor{}, err
}
return desc, nil
}
// SetDescriptor sets the descriptor data for the given digest using a redis
// hash. A hash is used here since we may store unrelated fields about a layer
// in the future.
func (rbds *redisBlobDescriptorService) SetDescriptor(ctx context.Context, dgst digest.Digest, desc distribution.Descriptor) error {
if err := dgst.Validate(); err != nil {
return err
}
if err := cache.ValidateDescriptor(desc); err != nil {
return err
}
conn := rbds.pool.Get()
defer conn.Close()
return rbds.setDescriptor(ctx, conn, dgst, desc)
}
func (rbds *redisBlobDescriptorService) setDescriptor(ctx context.Context, conn redis.Conn, dgst digest.Digest, desc distribution.Descriptor) error {
if _, err := conn.Do("HMSET", rbds.blobDescriptorHashKey(dgst),
"digest", desc.Digest,
"size", desc.Size); err != nil {
return err
}
// Only set mediatype if not already set.
if _, err := conn.Do("HSETNX", rbds.blobDescriptorHashKey(dgst),
"mediatype", desc.MediaType); err != nil {
return err
}
return nil
}
func (rbds *redisBlobDescriptorService) blobDescriptorHashKey(dgst digest.Digest) string {
return "blobs::" + dgst.String()
}
type repositoryScopedRedisBlobDescriptorService struct {
repo string
upstream *redisBlobDescriptorService
}
var _ distribution.BlobDescriptorService = &repositoryScopedRedisBlobDescriptorService{}
// Stat ensures that the digest is a member of the specified repository and
// forwards the descriptor request to the global blob store. If the media type
// differs for the repository, we override it.
func (rsrbds *repositoryScopedRedisBlobDescriptorService) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
if err := dgst.Validate(); err != nil {
return distribution.Descriptor{}, err
}
conn := rsrbds.upstream.pool.Get()
defer conn.Close()
// Check membership to repository first
member, err := redis.Bool(conn.Do("SISMEMBER", rsrbds.repositoryBlobSetKey(rsrbds.repo), dgst))
if err != nil {
return distribution.Descriptor{}, err
}
if !member {
return distribution.Descriptor{}, distribution.ErrBlobUnknown
}
upstream, err := rsrbds.upstream.stat(ctx, conn, dgst)
if err != nil {
return distribution.Descriptor{}, err
}
// We allow a per repository mediatype, let's look it up here.
mediatype, err := redis.String(conn.Do("HGET", rsrbds.blobDescriptorHashKey(dgst), "mediatype"))
if err != nil {
return distribution.Descriptor{}, err
}
if mediatype != "" {
upstream.MediaType = mediatype
}
return upstream, nil
}
// Clear removes the descriptor from the cache and forwards to the upstream descriptor store
func (rsrbds *repositoryScopedRedisBlobDescriptorService) Clear(ctx context.Context, dgst digest.Digest) error {
if err := dgst.Validate(); err != nil {
return err
}
conn := rsrbds.upstream.pool.Get()
defer conn.Close()
// Check membership to repository first
member, err := redis.Bool(conn.Do("SISMEMBER", rsrbds.repositoryBlobSetKey(rsrbds.repo), dgst))
if err != nil {
return err
}
if !member {
return distribution.ErrBlobUnknown
}
return rsrbds.upstream.Clear(ctx, dgst)
}
func (rsrbds *repositoryScopedRedisBlobDescriptorService) SetDescriptor(ctx context.Context, dgst digest.Digest, desc distribution.Descriptor) error {
if err := dgst.Validate(); err != nil {
return err
}
if err := cache.ValidateDescriptor(desc); err != nil {
return err
}
if dgst != desc.Digest {
if dgst.Algorithm() == desc.Digest.Algorithm() {
return fmt.Errorf("redis cache: digest for descriptors differ but algorthim does not: %q != %q", dgst, desc.Digest)
}
}
conn := rsrbds.upstream.pool.Get()
defer conn.Close()
return rsrbds.setDescriptor(ctx, conn, dgst, desc)
}
func (rsrbds *repositoryScopedRedisBlobDescriptorService) setDescriptor(ctx context.Context, conn redis.Conn, dgst digest.Digest, desc distribution.Descriptor) error {
if _, err := conn.Do("SADD", rsrbds.repositoryBlobSetKey(rsrbds.repo), dgst); err != nil {
return err
}
if err := rsrbds.upstream.setDescriptor(ctx, conn, dgst, desc); err != nil {
return err
}
// Override repository mediatype.
if _, err := conn.Do("HSET", rsrbds.blobDescriptorHashKey(dgst), "mediatype", desc.MediaType); err != nil {
return err
}
// Also set the values for the primary descriptor, if they differ by
// algorithm (ie sha256 vs sha512).
if desc.Digest != "" && dgst != desc.Digest && dgst.Algorithm() != desc.Digest.Algorithm() {
if err := rsrbds.setDescriptor(ctx, conn, desc.Digest, desc); err != nil {
return err
}
}
return nil
}
func (rsrbds *repositoryScopedRedisBlobDescriptorService) blobDescriptorHashKey(dgst digest.Digest) string {
return "repository::" + rsrbds.repo + "::blobs::" + dgst.String()
}
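// Note: the repo argument below is currently unused; the key is always built
// from the receiver's repo field.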
func (rsrbds *repositoryScopedRedisBlobDescriptorService) repositoryBlobSetKey(repo string) string {
return "repository::" + rsrbds.repo + "::blobs"
}

View File

@ -0,0 +1,53 @@
package redis
import (
"flag"
"os"
"testing"
"time"
"github.com/docker/distribution/registry/storage/cache/cachecheck"
"github.com/garyburd/redigo/redis"
)
var redisAddr string
func init() {
flag.StringVar(&redisAddr, "test.registry.storage.cache.redis.addr", "", "configure the address of a test instance of redis")
}
// TestRedisLayerInfoCache exercises a live redis instance using the cache
// implementation.
func TestRedisBlobDescriptorCacheProvider(t *testing.T) {
if redisAddr == "" {
// fall back to an environment variable
redisAddr = os.Getenv("TEST_REGISTRY_STORAGE_CACHE_REDIS_ADDR")
}
if redisAddr == "" {
// skip if still not set
t.Skip("please set -test.registry.storage.cache.redis.addr to test layer info cache against redis")
}
pool := &redis.Pool{
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", redisAddr)
},
MaxIdle: 1,
MaxActive: 2,
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
Wait: false, // if a connection is not available, proceed without cache.
}
// Clear the database
conn := pool.Get()
if _, err := conn.Do("FLUSHDB"); err != nil {
t.Fatalf("unexpected error flushing redis db: %v", err)
}
conn.Close()
cachecheck.CheckBlobDescriptorCache(t, NewRedisBlobDescriptorCacheProvider(pool))
}

View File

@ -0,0 +1,153 @@
package storage
import (
"errors"
"io"
"path"
"strings"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage/driver"
)
// errFinishedWalk signals an early exit to the walk when the current query
// is satisfied.
var errFinishedWalk = errors.New("finished walk")
// Returns a list, or partial list, of repositories in the registry.
// Because it's quite an expensive operation, it should only be used when building up
// an initial set of repositories.
func (reg *registry) Repositories(ctx context.Context, repos []string, last string) (n int, err error) {
var foundRepos []string
if len(repos) == 0 {
return 0, errors.New("no space in slice")
}
root, err := pathFor(repositoriesRootPathSpec{})
if err != nil {
return 0, err
}
err = Walk(ctx, reg.blobStore.driver, root, func(fileInfo driver.FileInfo) error {
err := handleRepository(fileInfo, root, last, func(repoPath string) error {
foundRepos = append(foundRepos, repoPath)
return nil
})
if err != nil {
return err
}
// if we've filled our array, no need to walk any further
if len(foundRepos) == len(repos) {
return errFinishedWalk
}
return nil
})
n = copy(repos, foundRepos)
switch err {
case nil:
// nil means that we completed walk and didn't fill buffer. No more
// records are available.
err = io.EOF
case errFinishedWalk:
// more records are available.
err = nil
}
return n, err
}
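// A minimal caller sketch (variable names and the handle consumer are
// illustrative only): page through the catalog by passing the last
// repository seen until io.EOF signals that no more records are available.
//
//	buf := make([]string, 100)
//	last := ""
//	for {
//		n, err := reg.Repositories(ctx, buf, last)
//		handle(buf[:n]) // hypothetical consumer of the page
//		if err == io.EOF {
//			break
//		} else if err != nil {
//			return err
//		}
//		last = buf[n-1]
//	}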
// Enumerate applies ingester to each repository
func (reg *registry) Enumerate(ctx context.Context, ingester func(string) error) error {
root, err := pathFor(repositoriesRootPathSpec{})
if err != nil {
return err
}
err = Walk(ctx, reg.blobStore.driver, root, func(fileInfo driver.FileInfo) error {
return handleRepository(fileInfo, root, "", ingester)
})
return err
}
// lessPath returns true if one path a is less than path b.
//
// A component-wise comparison is done, rather than the lexical comparison of
// strings.
func lessPath(a, b string) bool {
// We provide this behavior by making the separator always sort first.
return compareReplaceInline(a, b, '/', '\x00') < 0
}
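// For example, lessPath("foo/bar", "foo-bar") is true: the separator is
// remapped to 0x00 and sorts before '-', whereas a plain string comparison
// would order the two the other way around ('/' > '-').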
// compareReplaceInline modifies runtime.cmpstring to replace old with new
// during a byte-wise comparison.
func compareReplaceInline(s1, s2 string, old, new byte) int {
// TODO(stevvooe): We are missing an optimization when the s1 and s2 have
// the exact same slice header. It will make the code unsafe but can
// provide some extra performance.
l := len(s1)
if len(s2) < l {
l = len(s2)
}
for i := 0; i < l; i++ {
c1, c2 := s1[i], s2[i]
if c1 == old {
c1 = new
}
if c2 == old {
c2 = new
}
if c1 < c2 {
return -1
}
if c1 > c2 {
return +1
}
}
if len(s1) < len(s2) {
return -1
}
if len(s1) > len(s2) {
return +1
}
return 0
}
// handleRepository calls function fn with a repository path if fileInfo
// has a path of a repository under root and that it is lexicographically
// after last. Otherwise, it will return ErrSkipDir. This should be used
// with Walk to handle repositories in a storage.
func handleRepository(fileInfo driver.FileInfo, root, last string, fn func(repoPath string) error) error {
filePath := fileInfo.Path()
// lop the base path off
repo := filePath[len(root)+1:]
_, file := path.Split(repo)
if file == "_layers" {
repo = strings.TrimSuffix(repo, "/_layers")
if lessPath(last, repo) {
if err := fn(repo); err != nil {
return err
}
}
return ErrSkipDir
} else if strings.HasPrefix(file, "_") {
return ErrSkipDir
}
return nil
}

View File

@ -0,0 +1,324 @@
package storage
import (
"fmt"
"io"
"math/rand"
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/inmemory"
"github.com/docker/distribution/testutil"
)
type setupEnv struct {
ctx context.Context
driver driver.StorageDriver
expected []string
registry distribution.Namespace
}
func setupFS(t *testing.T) *setupEnv {
d := inmemory.New()
ctx := context.Background()
registry, err := NewRegistry(ctx, d, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
repos := []string{
"foo/a",
"foo/b",
"foo-bar/a",
"bar/c",
"bar/d",
"bar/e",
"foo/d/in",
"foo-bar/b",
"test",
}
for _, repo := range repos {
makeRepo(ctx, t, repo, registry)
}
expected := []string{
"bar/c",
"bar/d",
"bar/e",
"foo/a",
"foo/b",
"foo/d/in",
"foo-bar/a",
"foo-bar/b",
"test",
}
return &setupEnv{
ctx: ctx,
driver: d,
expected: expected,
registry: registry,
}
}
func makeRepo(ctx context.Context, t *testing.T, name string, reg distribution.Namespace) {
named, err := reference.ParseNamed(name)
if err != nil {
t.Fatal(err)
}
repo, _ := reg.Repository(ctx, named)
manifests, _ := repo.Manifests(ctx)
layers, err := testutil.CreateRandomLayers(1)
if err != nil {
t.Fatal(err)
}
err = testutil.UploadBlobs(repo, layers)
if err != nil {
t.Fatalf("failed to upload layers: %v", err)
}
getKeys := func(digests map[digest.Digest]io.ReadSeeker) (ds []digest.Digest) {
for d := range digests {
ds = append(ds, d)
}
return
}
manifest, err := testutil.MakeSchema1Manifest(getKeys(layers))
if err != nil {
t.Fatal(err)
}
_, err = manifests.Put(ctx, manifest)
if err != nil {
t.Fatalf("manifest upload failed: %v", err)
}
}
func TestCatalog(t *testing.T) {
env := setupFS(t)
p := make([]string, 50)
numFilled, err := env.registry.Repositories(env.ctx, p, "")
if numFilled != len(env.expected) {
t.Errorf("missing items in catalog")
}
if !testEq(p, env.expected, len(env.expected)) {
t.Errorf("Expected catalog repos err")
}
if err != io.EOF {
t.Errorf("Catalog has more values which we aren't expecting")
}
}
func TestCatalogInParts(t *testing.T) {
env := setupFS(t)
chunkLen := 3
p := make([]string, chunkLen)
numFilled, err := env.registry.Repositories(env.ctx, p, "")
if err == io.EOF || numFilled != len(p) {
t.Errorf("Expected more values in catalog")
}
if !testEq(p, env.expected[0:chunkLen], numFilled) {
t.Errorf("Expected catalog first chunk err")
}
lastRepo := p[len(p)-1]
numFilled, err = env.registry.Repositories(env.ctx, p, lastRepo)
if err == io.EOF || numFilled != len(p) {
t.Errorf("Expected more values in catalog")
}
if !testEq(p, env.expected[chunkLen:chunkLen*2], numFilled) {
t.Errorf("Expected catalog second chunk err")
}
lastRepo = p[len(p)-1]
numFilled, err = env.registry.Repositories(env.ctx, p, lastRepo)
if err != io.EOF || numFilled != len(p) {
t.Errorf("Expected end of catalog")
}
if !testEq(p, env.expected[chunkLen*2:chunkLen*3], numFilled) {
t.Errorf("Expected catalog third chunk err")
}
lastRepo = p[len(p)-1]
numFilled, err = env.registry.Repositories(env.ctx, p, lastRepo)
if err != io.EOF {
t.Errorf("Catalog has more values which we aren't expecting")
}
if numFilled != 0 {
t.Errorf("Expected catalog fourth chunk err")
}
}
func TestCatalogEnumerate(t *testing.T) {
env := setupFS(t)
var repos []string
repositoryEnumerator := env.registry.(distribution.RepositoryEnumerator)
err := repositoryEnumerator.Enumerate(env.ctx, func(repoName string) error {
repos = append(repos, repoName)
return nil
})
if err != nil {
t.Errorf("Expected catalog enumerate err")
}
if len(repos) != len(env.expected) {
t.Errorf("Expected catalog enumerate doesn't have correct number of values")
}
if !testEq(repos, env.expected, len(env.expected)) {
t.Errorf("Expected catalog enumerate not over all values")
}
}
func testEq(a, b []string, size int) bool {
for cnt := 0; cnt < size-1; cnt++ {
if a[cnt] != b[cnt] {
return false
}
}
return true
}
func setupBadWalkEnv(t *testing.T) *setupEnv {
d := newBadListDriver()
ctx := context.Background()
registry, err := NewRegistry(ctx, d, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
return &setupEnv{
ctx: ctx,
driver: d,
registry: registry,
}
}
type badListDriver struct {
driver.StorageDriver
}
var _ driver.StorageDriver = &badListDriver{}
func newBadListDriver() *badListDriver {
return &badListDriver{StorageDriver: inmemory.New()}
}
func (d *badListDriver) List(ctx context.Context, path string) ([]string, error) {
return nil, fmt.Errorf("List error")
}
func TestCatalogWalkError(t *testing.T) {
env := setupBadWalkEnv(t)
p := make([]string, 1)
_, err := env.registry.Repositories(env.ctx, p, "")
if err == io.EOF {
t.Errorf("Expected catalog driver list error")
}
}
func BenchmarkPathCompareEqual(B *testing.B) {
B.StopTimer()
pp := randomPath(100)
// make a real copy
ppb := append([]byte{}, []byte(pp)...)
a, b := pp, string(ppb)
B.StartTimer()
for i := 0; i < B.N; i++ {
lessPath(a, b)
}
}
func BenchmarkPathCompareNotEqual(B *testing.B) {
B.StopTimer()
a, b := randomPath(100), randomPath(100)
B.StartTimer()
for i := 0; i < B.N; i++ {
lessPath(a, b)
}
}
func BenchmarkPathCompareNative(B *testing.B) {
B.StopTimer()
a, b := randomPath(100), randomPath(100)
B.StartTimer()
for i := 0; i < B.N; i++ {
c := a < b
c = c && false
}
}
func BenchmarkPathCompareNativeEqual(B *testing.B) {
B.StopTimer()
pp := randomPath(100)
a, b := pp, pp
B.StartTimer()
for i := 0; i < B.N; i++ {
c := a < b
c = c && false
}
}
var filenameChars = []byte("abcdefghijklmnopqrstuvwxyz0123456789")
var separatorChars = []byte("._-")
func randomPath(length int64) string {
path := "/"
for int64(len(path)) < length {
chunkLength := rand.Int63n(length-int64(len(path))) + 1
chunk := randomFilename(chunkLength)
path += chunk
remaining := length - int64(len(path))
if remaining == 1 {
path += randomFilename(1)
} else if remaining > 1 {
path += "/"
}
}
return path
}
func randomFilename(length int64) string {
b := make([]byte, length)
wasSeparator := true
for i := range b {
if !wasSeparator && i < len(b)-1 && rand.Intn(4) == 0 {
b[i] = separatorChars[rand.Intn(len(separatorChars))]
wasSeparator = true
} else {
b[i] = filenameChars[rand.Intn(len(filenameChars))]
wasSeparator = false
}
}
return string(b)
}

View File

@ -0,0 +1,3 @@
// Package storage contains storage services for use in the registry
// application. It should be considered an internal package, as of Go 1.4.
package storage

View File

@ -0,0 +1,483 @@
// Package azure provides a storagedriver.StorageDriver implementation to
// store blobs in Microsoft Azure Blob Storage Service.
package azure
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
const driverName = "azure"
const (
paramAccountName = "accountname"
paramAccountKey = "accountkey"
paramContainer = "container"
paramRealm = "realm"
maxChunkSize = 4 * 1024 * 1024
)
type driver struct {
client azure.BlobStorageClient
container string
}
type baseEmbed struct{ base.Base }
// Driver is a storagedriver.StorageDriver implementation backed by
// Microsoft Azure Blob Storage Service.
type Driver struct{ baseEmbed }
func init() {
factory.Register(driverName, &azureDriverFactory{})
}
type azureDriverFactory struct{}
func (factory *azureDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
// FromParameters constructs a new Driver with a given parameters map.
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
accountName, ok := parameters[paramAccountName]
if !ok || fmt.Sprint(accountName) == "" {
return nil, fmt.Errorf("No %s parameter provided", paramAccountName)
}
accountKey, ok := parameters[paramAccountKey]
if !ok || fmt.Sprint(accountKey) == "" {
return nil, fmt.Errorf("No %s parameter provided", paramAccountKey)
}
container, ok := parameters[paramContainer]
if !ok || fmt.Sprint(container) == "" {
return nil, fmt.Errorf("No %s parameter provided", paramContainer)
}
realm, ok := parameters[paramRealm]
if !ok || fmt.Sprint(realm) == "" {
realm = azure.DefaultBaseURL
}
return New(fmt.Sprint(accountName), fmt.Sprint(accountKey), fmt.Sprint(container), fmt.Sprint(realm))
}
// New constructs a new Driver with the given Azure Storage Account credentials
func New(accountName, accountKey, container, realm string) (*Driver, error) {
api, err := azure.NewClient(accountName, accountKey, realm, azure.DefaultAPIVersion, true)
if err != nil {
return nil, err
}
blobClient := api.GetBlobService()
// Create registry container
if _, err = blobClient.CreateContainerIfNotExists(container, azure.ContainerAccessTypePrivate); err != nil {
return nil, err
}
d := &driver{
client: blobClient,
container: container}
return &Driver{baseEmbed: baseEmbed{Base: base.Base{StorageDriver: d}}}, nil
}
// Implement the storagedriver.StorageDriver interface.
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
blob, err := d.client.GetBlob(d.container, path)
if err != nil {
if is404(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
defer blob.Close()
return ioutil.ReadAll(blob)
}
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
if _, err := d.client.DeleteBlobIfExists(d.container, path, nil); err != nil {
return err
}
writer, err := d.Writer(ctx, path, false)
if err != nil {
return err
}
defer writer.Close()
_, err = writer.Write(contents)
if err != nil {
return err
}
return writer.Commit()
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
if ok, err := d.client.BlobExists(d.container, path); err != nil {
return nil, err
} else if !ok {
return nil, storagedriver.PathNotFoundError{Path: path}
}
info, err := d.client.GetBlobProperties(d.container, path)
if err != nil {
return nil, err
}
size := int64(info.ContentLength)
if offset >= size {
return ioutil.NopCloser(bytes.NewReader(nil)), nil
}
bytesRange := fmt.Sprintf("%v-", offset)
resp, err := d.client.GetBlobRange(d.container, path, bytesRange, nil)
if err != nil {
return nil, err
}
return resp, nil
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
blobExists, err := d.client.BlobExists(d.container, path)
if err != nil {
return nil, err
}
var size int64
if blobExists {
if append {
blobProperties, err := d.client.GetBlobProperties(d.container, path)
if err != nil {
return nil, err
}
size = blobProperties.ContentLength
} else {
err := d.client.DeleteBlob(d.container, path, nil)
if err != nil {
return nil, err
}
}
} else {
if append {
return nil, storagedriver.PathNotFoundError{Path: path}
}
err := d.client.PutAppendBlob(d.container, path, nil)
if err != nil {
return nil, err
}
}
return d.newWriter(path, size), nil
}
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
// Check if the path is a blob
if ok, err := d.client.BlobExists(d.container, path); err != nil {
return nil, err
} else if ok {
blob, err := d.client.GetBlobProperties(d.container, path)
if err != nil {
return nil, err
}
mtim, err := time.Parse(http.TimeFormat, blob.LastModified)
if err != nil {
return nil, err
}
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: int64(blob.ContentLength),
ModTime: mtim,
IsDir: false,
}}, nil
}
// Check if path is a virtual container
virtContainerPath := path
if !strings.HasSuffix(virtContainerPath, "/") {
virtContainerPath += "/"
}
blobs, err := d.client.ListBlobs(d.container, azure.ListBlobsParameters{
Prefix: virtContainerPath,
MaxResults: 1,
})
if err != nil {
return nil, err
}
if len(blobs.Blobs) > 0 {
// path is a virtual container
return storagedriver.FileInfoInternal{FileInfoFields: storagedriver.FileInfoFields{
Path: path,
IsDir: true,
}}, nil
}
// path is not a blob or virtual container
return nil, storagedriver.PathNotFoundError{Path: path}
}
// List returns a list of the objects that are direct descendants of the given
// path.
func (d *driver) List(ctx context.Context, path string) ([]string, error) {
if path == "/" {
path = ""
}
blobs, err := d.listBlobs(d.container, path)
if err != nil {
return blobs, err
}
list := directDescendants(blobs, path)
if path != "" && len(list) == 0 {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return list, nil
}
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
sourceBlobURL := d.client.GetBlobURL(d.container, sourcePath)
err := d.client.CopyBlob(d.container, destPath, sourceBlobURL)
if err != nil {
if is404(err) {
return storagedriver.PathNotFoundError{Path: sourcePath}
}
return err
}
return d.client.DeleteBlob(d.container, sourcePath, nil)
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, path string) error {
ok, err := d.client.DeleteBlobIfExists(d.container, path, nil)
if err != nil {
return err
}
if ok {
return nil // was a blob and deleted, return
}
// Not a blob, see if path is a virtual container with blobs
blobs, err := d.listBlobs(d.container, path)
if err != nil {
return err
}
for _, b := range blobs {
if err = d.client.DeleteBlob(d.container, b, nil); err != nil {
return err
}
}
if len(blobs) == 0 {
return storagedriver.PathNotFoundError{Path: path}
}
return nil
}
// URLFor returns a publicly accessible URL for the blob stored at given path
// for specified duration by making use of Azure Storage Shared Access Signatures (SAS).
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration
expires, ok := options["expiry"]
if ok {
t, ok := expires.(time.Time)
if ok {
expiresTime = t
}
}
return d.client.GetBlobSASURI(d.container, path, expiresTime, "r")
}
// directDescendants will find the direct descendants (blobs or virtual
// containers) of the given prefix in a list of blob paths and will return
// their full paths. Elements in the blobs list must be prefixed with a "/".
//
// Example: direct descendants of "/" in {"/foo", "/bar/1", "/bar/2"} is
// {"/foo", "/bar"} and direct descendants of "bar" is {"/bar/1", "/bar/2"}
func directDescendants(blobs []string, prefix string) []string {
if !strings.HasPrefix(prefix, "/") { // add trailing '/'
prefix = "/" + prefix
}
if !strings.HasSuffix(prefix, "/") { // containerify the path
prefix += "/"
}
out := make(map[string]bool)
for _, b := range blobs {
if strings.HasPrefix(b, prefix) {
rel := b[len(prefix):]
c := strings.Count(rel, "/")
if c == 0 {
out[b] = true
} else {
out[prefix+rel[:strings.Index(rel, "/")]] = true
}
}
}
var keys []string
for k := range out {
keys = append(keys, k)
}
return keys
}
func (d *driver) listBlobs(container, virtPath string) ([]string, error) {
if virtPath != "" && !strings.HasSuffix(virtPath, "/") { // containerify the path
virtPath += "/"
}
out := []string{}
marker := ""
for {
resp, err := d.client.ListBlobs(d.container, azure.ListBlobsParameters{
Marker: marker,
Prefix: virtPath,
})
if err != nil {
return out, err
}
for _, b := range resp.Blobs {
out = append(out, b.Name)
}
if len(resp.Blobs) == 0 || resp.NextMarker == "" {
break
}
marker = resp.NextMarker
}
return out, nil
}
func is404(err error) bool {
statusCodeErr, ok := err.(azure.AzureStorageServiceError)
return ok && statusCodeErr.StatusCode == http.StatusNotFound
}
type writer struct {
driver *driver
path string
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(path string, size int64) storagedriver.FileWriter {
return &writer{
driver: d,
path: path,
size: size,
bw: bufio.NewWriterSize(&blockWriter{
client: d.client,
container: d.container,
path: path,
}, maxChunkSize),
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := w.bw.Write(p)
w.size += int64(n)
return n, err
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.bw.Flush()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
return w.driver.client.DeleteBlob(w.driver.container, w.path, nil)
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
w.committed = true
return w.bw.Flush()
}
type blockWriter struct {
client azure.BlobStorageClient
container string
path string
}
func (bw *blockWriter) Write(p []byte) (int, error) {
n := 0
for offset := 0; offset < len(p); offset += maxChunkSize {
chunkSize := maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
err := bw.client.AppendBlock(bw.container, bw.path, p[offset:offset+chunkSize], nil)
if err != nil {
return n, err
}
n += chunkSize
}
return n, nil
}
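// With maxChunkSize of 4 MiB, a single 9 MiB Write above results in three
// AppendBlock calls of 4 MiB, 4 MiB and 1 MiB respectively.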

View File

@ -0,0 +1,63 @@
package azure
import (
"fmt"
"os"
"strings"
"testing"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
. "gopkg.in/check.v1"
)
const (
envAccountName = "AZURE_STORAGE_ACCOUNT_NAME"
envAccountKey = "AZURE_STORAGE_ACCOUNT_KEY"
envContainer = "AZURE_STORAGE_CONTAINER"
envRealm = "AZURE_STORAGE_REALM"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { TestingT(t) }
func init() {
var (
accountName string
accountKey string
container string
realm string
)
config := []struct {
env string
value *string
}{
{envAccountName, &accountName},
{envAccountKey, &accountKey},
{envContainer, &container},
{envRealm, &realm},
}
missing := []string{}
for _, v := range config {
*v.value = os.Getenv(v.env)
if *v.value == "" {
missing = append(missing, v.env)
}
}
azureDriverConstructor := func() (storagedriver.StorageDriver, error) {
return New(accountName, accountKey, container, realm)
}
// Skip Azure storage driver tests if environment variable parameters are not provided
skipCheck := func() string {
if len(missing) > 0 {
return fmt.Sprintf("Must set %s environment variables to run Azure tests", strings.Join(missing, ", "))
}
return ""
}
testsuites.RegisterSuite(azureDriverConstructor, skipCheck)
}

View File

@ -0,0 +1,198 @@
// Package base provides a base implementation of the storage driver that can
// be used to implement common checks. The goal is to increase the amount of
// code sharing.
//
// The canonical approach to use this class is to embed in the exported driver
// struct such that calls are proxied through this implementation. First,
// declare the internal driver, as follows:
//
// type driver struct { ... internal ...}
//
// The resulting type should implement StorageDriver such that it can be the
// target of a Base struct. The exported type can then be declared as follows:
//
// type Driver struct {
// Base
// }
//
// Because Driver embeds Base, it effectively implements Base. If the driver
// needs to intercept a call, before going to base, Driver should implement
// that method. Effectively, Driver can intercept calls before they reach
// Base, while the internal driver implements the actual logic.
//
// To further shield the embed from other packages, it is recommended to
// employ a private embed struct:
//
// type baseEmbed struct {
// base.Base
// }
//
// Then, declare driver to embed baseEmbed, rather than Base directly:
//
// type Driver struct {
// baseEmbed
// }
//
// The type now implements StorageDriver, proxying through Base, without
// exporting an unnecessary field.
package base
import (
"io"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
// Base provides a wrapper around a storagedriver implementation that provides
// common path and bounds checking.
type Base struct {
storagedriver.StorageDriver
}
// Format errors received from the storage driver
func (base *Base) setDriverName(e error) error {
switch actual := e.(type) {
case nil:
return nil
case storagedriver.ErrUnsupportedMethod:
actual.DriverName = base.StorageDriver.Name()
return actual
case storagedriver.PathNotFoundError:
actual.DriverName = base.StorageDriver.Name()
return actual
case storagedriver.InvalidPathError:
actual.DriverName = base.StorageDriver.Name()
return actual
case storagedriver.InvalidOffsetError:
actual.DriverName = base.StorageDriver.Name()
return actual
default:
storageError := storagedriver.Error{
DriverName: base.StorageDriver.Name(),
Enclosed: e,
}
return storageError
}
}
// GetContent wraps GetContent of underlying storage driver.
func (base *Base) GetContent(ctx context.Context, path string) ([]byte, error) {
ctx, done := context.WithTrace(ctx)
defer done("%s.GetContent(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
b, e := base.StorageDriver.GetContent(ctx, path)
return b, base.setDriverName(e)
}
// PutContent wraps PutContent of underlying storage driver.
func (base *Base) PutContent(ctx context.Context, path string, content []byte) error {
ctx, done := context.WithTrace(ctx)
defer done("%s.PutContent(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
return base.setDriverName(base.StorageDriver.PutContent(ctx, path, content))
}
// Reader wraps Reader of underlying storage driver.
func (base *Base) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
ctx, done := context.WithTrace(ctx)
defer done("%s.Reader(%q, %d)", base.Name(), path, offset)
if offset < 0 {
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset, DriverName: base.StorageDriver.Name()}
}
if !storagedriver.PathRegexp.MatchString(path) {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
rc, e := base.StorageDriver.Reader(ctx, path, offset)
return rc, base.setDriverName(e)
}
// Writer wraps Writer of underlying storage driver.
func (base *Base) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
ctx, done := context.WithTrace(ctx)
defer done("%s.Writer(%q, %v)", base.Name(), path, append)
if !storagedriver.PathRegexp.MatchString(path) {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
writer, e := base.StorageDriver.Writer(ctx, path, append)
return writer, base.setDriverName(e)
}
// Stat wraps Stat of underlying storage driver.
func (base *Base) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
ctx, done := context.WithTrace(ctx)
defer done("%s.Stat(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) && path != "/" {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
fi, e := base.StorageDriver.Stat(ctx, path)
return fi, base.setDriverName(e)
}
// List wraps List of underlying storage driver.
func (base *Base) List(ctx context.Context, path string) ([]string, error) {
ctx, done := context.WithTrace(ctx)
defer done("%s.List(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) && path != "/" {
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
str, e := base.StorageDriver.List(ctx, path)
return str, base.setDriverName(e)
}
// Move wraps Move of underlying storage driver.
func (base *Base) Move(ctx context.Context, sourcePath string, destPath string) error {
ctx, done := context.WithTrace(ctx)
defer done("%s.Move(%q, %q", base.Name(), sourcePath, destPath)
if !storagedriver.PathRegexp.MatchString(sourcePath) {
return storagedriver.InvalidPathError{Path: sourcePath, DriverName: base.StorageDriver.Name()}
} else if !storagedriver.PathRegexp.MatchString(destPath) {
return storagedriver.InvalidPathError{Path: destPath, DriverName: base.StorageDriver.Name()}
}
return base.setDriverName(base.StorageDriver.Move(ctx, sourcePath, destPath))
}
// Delete wraps Delete of underlying storage driver.
func (base *Base) Delete(ctx context.Context, path string) error {
ctx, done := context.WithTrace(ctx)
defer done("%s.Delete(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
return base.setDriverName(base.StorageDriver.Delete(ctx, path))
}
// URLFor wraps URLFor of underlying storage driver.
func (base *Base) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
ctx, done := context.WithTrace(ctx)
defer done("%s.URLFor(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
str, e := base.StorageDriver.URLFor(ctx, path, options)
return str, base.setDriverName(e)
}

View File

@ -0,0 +1,145 @@
package base
import (
"io"
"sync"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
type regulator struct {
storagedriver.StorageDriver
*sync.Cond
available uint64
}
// NewRegulator wraps the given driver and is used to regulate concurrent calls
// to the given storage driver to a maximum of the given limit. This is useful
// for storage drivers that would otherwise create an unbounded number of OS
// threads if allowed to be called unregulated.
func NewRegulator(driver storagedriver.StorageDriver, limit uint64) storagedriver.StorageDriver {
return &regulator{
StorageDriver: driver,
Cond: sync.NewCond(&sync.Mutex{}),
available: limit,
}
}
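// A minimal usage sketch (the wrapped someDriver and the limit of 100 are
// illustrative only):
//
//	limited := NewRegulator(someDriver, 100)
//	content, err := limited.GetContent(ctx, "/some/path")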
func (r *regulator) enter() {
r.L.Lock()
for r.available == 0 {
r.Wait()
}
r.available--
r.L.Unlock()
}
func (r *regulator) exit() {
r.L.Lock()
// We only need to signal to a waiting FS operation if we're already at the
// limit of threads used
if r.available == 0 {
r.Signal()
}
r.available++
r.L.Unlock()
}
// Name returns the human-readable "name" of the driver, useful in error
// messages and logging. By convention, this will just be the registration
// name, but drivers may provide other information here.
func (r *regulator) Name() string {
r.enter()
defer r.exit()
return r.StorageDriver.Name()
}
// GetContent retrieves the content stored at "path" as a []byte.
// This should primarily be used for small objects.
func (r *regulator) GetContent(ctx context.Context, path string) ([]byte, error) {
r.enter()
defer r.exit()
return r.StorageDriver.GetContent(ctx, path)
}
// PutContent stores the []byte content at a location designated by "path".
// This should primarily be used for small objects.
func (r *regulator) PutContent(ctx context.Context, path string, content []byte) error {
r.enter()
defer r.exit()
return r.StorageDriver.PutContent(ctx, path, content)
}
// Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset.
func (r *regulator) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Reader(ctx, path, offset)
}
// Writer stores the contents of the provided io.ReadCloser at a
// location designated by the given path.
// May be used to resume writing a stream by providing a nonzero offset.
// The offset must be no larger than the CurrentSize for this path.
func (r *regulator) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Writer(ctx, path, append)
}
// Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time.
func (r *regulator) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Stat(ctx, path)
}
// List returns a list of the objects that are direct descendants of the
// given path.
func (r *regulator) List(ctx context.Context, path string) ([]string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.List(ctx, path)
}
// Move moves an object stored at sourcePath to destPath, removing the
// original object.
// Note: This may be no more efficient than a copy followed by a delete for
// many implementations.
func (r *regulator) Move(ctx context.Context, sourcePath string, destPath string) error {
r.enter()
defer r.exit()
return r.StorageDriver.Move(ctx, sourcePath, destPath)
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (r *regulator) Delete(ctx context.Context, path string) error {
r.enter()
defer r.exit()
return r.StorageDriver.Delete(ctx, path)
}
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.URLFor(ctx, path, options)
}

View File

@ -0,0 +1,64 @@
package factory
import (
"fmt"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
// driverFactories stores an internal mapping between storage driver names and their respective
// factories
var driverFactories = make(map[string]StorageDriverFactory)
// StorageDriverFactory is a factory interface for creating storagedriver.StorageDriver interfaces
// Storage drivers should call Register() with a factory to make the driver available by name.
// Individual StorageDriver implementations generally register with the factory via the Register
// func (below) in their init() funcs, and as such they should be imported anonymously before use.
// See below for an example of how to register and get a StorageDriver for S3
//
// import _ "github.com/docker/distribution/registry/storage/driver/s3-aws"
// s3Driver, err = factory.Create("s3", storageParams)
// // assuming no error, s3Driver is the StorageDriver that communicates with S3 according to storageParams
type StorageDriverFactory interface {
// Create returns a new storagedriver.StorageDriver with the given parameters
// Parameters will vary by driver and may be ignored
// Each parameter key must only consist of lowercase letters and numbers
Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error)
}
// Register makes a storage driver available by the provided name.
// If Register is called twice with the same name or if the driver factory is nil, it panics.
// Additionally, it is not concurrency safe. Most Storage Drivers call this function
// in their init() functions. See the documentation for StorageDriverFactory for more.
func Register(name string, factory StorageDriverFactory) {
if factory == nil {
panic("Must not provide nil StorageDriverFactory")
}
_, registered := driverFactories[name]
if registered {
panic(fmt.Sprintf("StorageDriverFactory named %s already registered", name))
}
driverFactories[name] = factory
}
// Create a new storagedriver.StorageDriver with the given name and
// parameters. To use a driver, the StorageDriverFactory must first be
// registered with the given name. If no factory is registered under that
// name, an InvalidStorageDriverError is returned.
func Create(name string, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
driverFactory, ok := driverFactories[name]
if !ok {
return nil, InvalidStorageDriverError{name}
}
return driverFactory.Create(parameters)
}
// InvalidStorageDriverError records an attempt to construct an unregistered storage driver
type InvalidStorageDriverError struct {
Name string
}
func (err InvalidStorageDriverError) Error() string {
return fmt.Sprintf("StorageDriver not registered: %s", err.Name)
}
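// As an illustrative sketch (names below are hypothetical), a storage driver
// implements StorageDriverFactory, registers itself in init(), and is then
// constructed by name through Create:
//
// type myFactory struct{}
//
// func (f *myFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
//     return newMyDriver(parameters) // hypothetical constructor
// }
//
// func init() { Register("mydriver", &myFactory{}) }
//
// // elsewhere:
// // driver, err := factory.Create("mydriver", map[string]interface{}{})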

View File

@ -0,0 +1,79 @@
package driver
import "time"
// FileInfo returns information about a given path. Inspired by os.FileInfo,
// it elides the base name method for a full path instead.
type FileInfo interface {
// Path provides the full path of the target of this file info.
Path() string
// Size returns current length in bytes of the file. The return value can
// be used to write to the end of the file at path. The value is
// meaningless if IsDir returns true.
Size() int64
// ModTime returns the modification time for the file. For backends that
// don't have a modification time, the creation time should be returned.
ModTime() time.Time
// IsDir returns true if the path is a directory.
IsDir() bool
}
// NOTE(stevvooe): The next two types, FileInfoFields and FileInfoInternal
// should only be used by storagedriver implementations. They should be moved to
// a "driver" package, similar to database/sql.
// FileInfoFields provides the exported fields for implementing FileInfo
// interface in storagedriver implementations. It should be used with
// InternalFileInfo.
type FileInfoFields struct {
// Path provides the full path of the target of this file info.
Path string
// Size is current length in bytes of the file. The value of this field
// can be used to write to the end of the file at path. The value is
// meaningless if IsDir is set to true.
Size int64
// ModTime returns the modification time for the file. For backends that
// don't have a modification time, the creation time should be returned.
ModTime time.Time
// IsDir returns true if the path is a directory.
IsDir bool
}
// FileInfoInternal implements the FileInfo interface. This should only be
// used by storagedriver implementations that don't have a specialized
// FileInfo type.
type FileInfoInternal struct {
FileInfoFields
}
var _ FileInfo = FileInfoInternal{}
var _ FileInfo = &FileInfoInternal{}
// Path provides the full path of the target of this file info.
func (fi FileInfoInternal) Path() string {
return fi.FileInfoFields.Path
}
// Size returns current length in bytes of the file. The return value can
// be used to write to the end of the file at path. The value is
// meaningless if IsDir returns true.
func (fi FileInfoInternal) Size() int64 {
return fi.FileInfoFields.Size
}
// ModTime returns the modification time for the file. For backends that
// don't have a modification time, the creation time should be returned.
func (fi FileInfoInternal) ModTime() time.Time {
return fi.FileInfoFields.ModTime
}
// IsDir returns true if the path is a directory.
func (fi FileInfoInternal) IsDir() bool {
return fi.FileInfoFields.IsDir
}
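// As an illustrative sketch (the values are made up), a driver's Stat
// implementation typically fills in FileInfoFields and wraps it in
// FileInfoInternal:
//
// fi := FileInfoFields{
//     Path:    "/some/key",
//     Size:    1024,
//     ModTime: time.Now(),
//     IsDir:   false,
// }
// var info FileInfo = FileInfoInternal{FileInfoFields: fi}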

View File

@ -0,0 +1,440 @@
package filesystem
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"reflect"
"strconv"
"time"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
)
const (
driverName = "filesystem"
defaultRootDirectory = "/var/lib/registry"
defaultMaxThreads = uint64(100)
// minThreads is the minimum value for the maxthreads configuration
// parameter. If the configured value is lower than this, it is raised to
// minThreads.
minThreads = uint64(25)
)
// DriverParameters represents all configuration options available for the
// filesystem driver
type DriverParameters struct {
RootDirectory string
MaxThreads uint64
}
func init() {
factory.Register(driverName, &filesystemDriverFactory{})
}
// filesystemDriverFactory implements the factory.StorageDriverFactory interface
type filesystemDriverFactory struct{}
func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
type driver struct {
rootDirectory string
}
type baseEmbed struct {
base.Base
}
// Driver is a storagedriver.StorageDriver implementation backed by a local
// filesystem. All provided paths will be subpaths of the RootDirectory.
type Driver struct {
baseEmbed
}
// FromParameters constructs a new Driver with a given parameters map
// Optional Parameters:
// - rootdirectory
// - maxthreads
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
params, err := fromParametersImpl(parameters)
if err != nil || params == nil {
return nil, err
}
return New(*params), nil
}
func fromParametersImpl(parameters map[string]interface{}) (*DriverParameters, error) {
var (
err error
maxThreads = defaultMaxThreads
rootDirectory = defaultRootDirectory
)
if parameters != nil {
if rootDir, ok := parameters["rootdirectory"]; ok {
rootDirectory = fmt.Sprint(rootDir)
}
// Get maximum number of threads for blocking filesystem operations,
// if specified
threads := parameters["maxthreads"]
switch v := threads.(type) {
case string:
if maxThreads, err = strconv.ParseUint(v, 0, 64); err != nil {
return nil, fmt.Errorf("maxthreads parameter must be an integer, %v invalid", threads)
}
case uint64:
maxThreads = v
case int, int32, int64:
val := reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Int()
// If threads is negative, casting to uint64 would wrap around and yield an
// enormously large thread limit, so only accept positive values here.
if val > 0 {
maxThreads = uint64(val)
}
case uint, uint32:
maxThreads = reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Uint()
case nil:
// do nothing
default:
return nil, fmt.Errorf("invalid value for maxthreads: %#v", threads)
}
if maxThreads < minThreads {
maxThreads = minThreads
}
}
params := &DriverParameters{
RootDirectory: rootDirectory,
MaxThreads: maxThreads,
}
return params, nil
}
// New constructs a new Driver with a given rootDirectory
func New(params DriverParameters) *Driver {
fsDriver := &driver{rootDirectory: params.RootDirectory}
return &Driver{
baseEmbed: baseEmbed{
Base: base.Base{
StorageDriver: base.NewRegulator(fsDriver, params.MaxThreads),
},
},
}
}
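// As an illustrative sketch (the values are assumptions, not required
// settings), the driver can be constructed from a configuration-style
// parameters map:
//
// d, err := FromParameters(map[string]interface{}{
//     "rootdirectory": "/var/lib/registry",
//     "maxthreads":    200,
// })
// if err != nil {
//     // handle configuration error
// }
// _ = d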
// Implement the storagedriver.StorageDriver interface
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
rc, err := d.Reader(ctx, path, 0)
if err != nil {
return nil, err
}
defer rc.Close()
p, err := ioutil.ReadAll(rc)
if err != nil {
return nil, err
}
return p, nil
}
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte) error {
writer, err := d.Writer(ctx, subPath, false)
if err != nil {
return err
}
defer writer.Close()
_, err = io.Copy(writer, bytes.NewReader(contents))
if err != nil {
writer.Cancel()
return err
}
return writer.Commit()
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
file, err := os.OpenFile(d.fullPath(path), os.O_RDONLY, 0644)
if err != nil {
if os.IsNotExist(err) {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
seekPos, err := file.Seek(int64(offset), os.SEEK_SET)
if err != nil {
file.Close()
return nil, err
} else if seekPos < int64(offset) {
file.Close()
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
return file, nil
}
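// Writer returns a FileWriter that stores content at the given path, creating
// any missing parent directories. When append is false the file is truncated;
// otherwise writes resume from the current end of the file.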
func (d *driver) Writer(ctx context.Context, subPath string, append bool) (storagedriver.FileWriter, error) {
fullPath := d.fullPath(subPath)
parentDir := path.Dir(fullPath)
if err := os.MkdirAll(parentDir, 0777); err != nil {
return nil, err
}
fp, err := os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE, 0666)
if err != nil {
return nil, err
}
var offset int64
if !append {
err := fp.Truncate(0)
if err != nil {
fp.Close()
return nil, err
}
} else {
n, err := fp.Seek(0, os.SEEK_END)
if err != nil {
fp.Close()
return nil, err
}
offset = int64(n)
}
return newFileWriter(fp, offset), nil
}
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.
func (d *driver) Stat(ctx context.Context, subPath string) (storagedriver.FileInfo, error) {
fullPath := d.fullPath(subPath)
fi, err := os.Stat(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, storagedriver.PathNotFoundError{Path: subPath}
}
return nil, err
}
return fileInfo{
path: subPath,
FileInfo: fi,
}, nil
}
// List returns a list of the objects that are direct descendants of the given
// path.
func (d *driver) List(ctx context.Context, subPath string) ([]string, error) {
fullPath := d.fullPath(subPath)
dir, err := os.Open(fullPath)
if err != nil {
if os.IsNotExist(err) {
return nil, storagedriver.PathNotFoundError{Path: subPath}
}
return nil, err
}
defer dir.Close()
fileNames, err := dir.Readdirnames(0)
if err != nil {
return nil, err
}
keys := make([]string, 0, len(fileNames))
for _, fileName := range fileNames {
keys = append(keys, path.Join(subPath, fileName))
}
return keys, nil
}
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
source := d.fullPath(sourcePath)
dest := d.fullPath(destPath)
if _, err := os.Stat(source); os.IsNotExist(err) {
return storagedriver.PathNotFoundError{Path: sourcePath}
}
if err := os.MkdirAll(path.Dir(dest), 0755); err != nil {
return err
}
err := os.Rename(source, dest)
return err
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, subPath string) error {
fullPath := d.fullPath(subPath)
_, err := os.Stat(fullPath)
if err != nil && !os.IsNotExist(err) {
return err
} else if err != nil {
return storagedriver.PathNotFoundError{Path: subPath}
}
err = os.RemoveAll(fullPath)
return err
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an ErrUnsupportedMethod in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
}
// fullPath returns the absolute path of a key within the Driver's storage.
func (d *driver) fullPath(subPath string) string {
return path.Join(d.rootDirectory, subPath)
}
type fileInfo struct {
os.FileInfo
path string
}
var _ storagedriver.FileInfo = fileInfo{}
// Path provides the full path of the target of this file info.
func (fi fileInfo) Path() string {
return fi.path
}
// Size returns current length in bytes of the file. The return value can
// be used to write to the end of the file at path. The value is
// meaningless if IsDir returns true.
func (fi fileInfo) Size() int64 {
if fi.IsDir() {
return 0
}
return fi.FileInfo.Size()
}
// ModTime returns the modification time for the file. For backends that
// don't have a modification time, the creation time should be returned.
func (fi fileInfo) ModTime() time.Time {
return fi.FileInfo.ModTime()
}
// IsDir returns true if the path is a directory.
func (fi fileInfo) IsDir() bool {
return fi.FileInfo.IsDir()
}
type fileWriter struct {
file *os.File
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func newFileWriter(file *os.File, size int64) *fileWriter {
return &fileWriter{
file: file,
size: size,
bw: bufio.NewWriter(file),
}
}
func (fw *fileWriter) Write(p []byte) (int, error) {
if fw.closed {
return 0, fmt.Errorf("already closed")
} else if fw.committed {
return 0, fmt.Errorf("already committed")
} else if fw.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := fw.bw.Write(p)
fw.size += int64(n)
return n, err
}
func (fw *fileWriter) Size() int64 {
return fw.size
}
func (fw *fileWriter) Close() error {
if fw.closed {
return fmt.Errorf("already closed")
}
if err := fw.bw.Flush(); err != nil {
return err
}
if err := fw.file.Sync(); err != nil {
return err
}
if err := fw.file.Close(); err != nil {
return err
}
fw.closed = true
return nil
}
func (fw *fileWriter) Cancel() error {
if fw.closed {
return fmt.Errorf("already closed")
}
fw.cancelled = true
fw.file.Close()
return os.Remove(fw.file.Name())
}
func (fw *fileWriter) Commit() error {
if fw.closed {
return fmt.Errorf("already closed")
} else if fw.committed {
return fmt.Errorf("already committed")
} else if fw.cancelled {
return fmt.Errorf("already cancelled")
}
if err := fw.bw.Flush(); err != nil {
return err
}
if err := fw.file.Sync(); err != nil {
return err
}
fw.committed = true
return nil
}

View File

@ -0,0 +1,113 @@
package filesystem
import (
"io/ioutil"
"os"
"reflect"
"testing"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
. "gopkg.in/check.v1"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { TestingT(t) }
func init() {
root, err := ioutil.TempDir("", "driver-")
if err != nil {
panic(err)
}
defer os.Remove(root)
driver, err := FromParameters(map[string]interface{}{
"rootdirectory": root,
})
if err != nil {
panic(err)
}
testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) {
return driver, nil
}, testsuites.NeverSkip)
}
func TestFromParametersImpl(t *testing.T) {
tests := []struct {
params map[string]interface{} // technically the yaml can contain anything
expected DriverParameters
pass bool
}{
// check we use default threads and root dirs
{
params: map[string]interface{}{},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: defaultMaxThreads,
},
pass: true,
},
// Testing initiation with a string maxThreads which can't be parsed
{
params: map[string]interface{}{
"maxthreads": "fail",
},
expected: DriverParameters{},
pass: false,
},
{
params: map[string]interface{}{
"maxthreads": "100",
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: uint64(100),
},
pass: true,
},
{
params: map[string]interface{}{
"maxthreads": 100,
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: uint64(100),
},
pass: true,
},
// check that we use minimum thread counts
{
params: map[string]interface{}{
"maxthreads": 1,
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: minThreads,
},
pass: true,
},
}
for _, item := range tests {
params, err := fromParametersImpl(item.params)
if !item.pass {
// We only need to assert that expected failures have an error
if err == nil {
t.Fatalf("expected error configuring filesystem driver with invalid param: %+v", item.params)
}
continue
}
if err != nil {
t.Fatalf("unexpected error creating filesystem driver: %s", err)
}
// Note that we get a pointer to params back
if !reflect.DeepEqual(*params, item.expected) {
t.Fatalf("unexpected params from filesystem driver. expected %+v, got %+v", item.expected, params)
}
}
}

View File

@ -0,0 +1,3 @@
// Package gcs implements the Google Cloud Storage driver backend. Support can be
// enabled by including the "include_gcs" build tag.
package gcs

View File

@ -0,0 +1,873 @@
// Package gcs provides a storagedriver.StorageDriver implementation to
// store blobs in Google Cloud Storage.
//
// This package leverages the google.golang.org/cloud/storage client library
// for interfacing with gcs.
//
// Because gcs is a key-value store, the Stat call does not support last
// modification time for directories (directories are an abstraction for
// key-value stores).
//
// Note that the contents of incomplete uploads are not accessible even though
// Stat returns their length
//
// +build include_gcs
package gcs
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"time"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/oauth2/jwt"
"google.golang.org/api/googleapi"
"google.golang.org/cloud"
"google.golang.org/cloud/storage"
"github.com/Sirupsen/logrus"
ctx "github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
)
const (
driverName = "gcs"
dummyProjectID = "<unknown>"
uploadSessionContentType = "application/x-docker-upload-session"
minChunkSize = 256 * 1024
defaultChunkSize = 20 * minChunkSize
maxTries = 5
)
var rangeHeader = regexp.MustCompile(`^bytes=([0-9])+-([0-9]+)$`)
// driverParameters is a struct that encapsulates all of the driver parameters after all values have been set
type driverParameters struct {
bucket string
config *jwt.Config
email string
privateKey []byte
client *http.Client
rootDirectory string
chunkSize int
}
func init() {
factory.Register(driverName, &gcsDriverFactory{})
}
// gcsDriverFactory implements the factory.StorageDriverFactory interface
type gcsDriverFactory struct{}
// Create StorageDriver from parameters
func (factory *gcsDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
// driver is a storagedriver.StorageDriver implementation backed by GCS
// Objects are stored at absolute keys in the provided bucket.
type driver struct {
client *http.Client
bucket string
email string
privateKey []byte
rootDirectory string
chunkSize int
}
// FromParameters constructs a new Driver with a given parameters map
// Required parameters:
// - bucket
func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
bucket, ok := parameters["bucket"]
if !ok || fmt.Sprint(bucket) == "" {
return nil, fmt.Errorf("No bucket parameter provided")
}
rootDirectory, ok := parameters["rootdirectory"]
if !ok {
rootDirectory = ""
}
chunkSize := defaultChunkSize
chunkSizeParam, ok := parameters["chunksize"]
if ok {
switch v := chunkSizeParam.(type) {
case string:
vv, err := strconv.Atoi(v)
if err != nil {
return nil, fmt.Errorf("chunksize parameter must be an integer, %v invalid", chunkSizeParam)
}
chunkSize = vv
case int, uint, int32, uint32, uint64, int64:
chunkSize = int(reflect.ValueOf(v).Convert(reflect.TypeOf(chunkSize)).Int())
default:
return nil, fmt.Errorf("invalid valud for chunksize: %#v", chunkSizeParam)
}
if chunkSize < minChunkSize {
return nil, fmt.Errorf("The chunksize %#v parameter should be a number that is larger than or equal to %d", chunkSize, minChunkSize)
}
if chunkSize%minChunkSize != 0 {
return nil, fmt.Errorf("chunksize should be a multiple of %d", minChunkSize)
}
}
var ts oauth2.TokenSource
jwtConf := new(jwt.Config)
if keyfile, ok := parameters["keyfile"]; ok {
jsonKey, err := ioutil.ReadFile(fmt.Sprint(keyfile))
if err != nil {
return nil, err
}
jwtConf, err = google.JWTConfigFromJSON(jsonKey, storage.ScopeFullControl)
if err != nil {
return nil, err
}
ts = jwtConf.TokenSource(context.Background())
} else {
var err error
ts, err = google.DefaultTokenSource(context.Background(), storage.ScopeFullControl)
if err != nil {
return nil, err
}
}
params := driverParameters{
bucket: fmt.Sprint(bucket),
rootDirectory: fmt.Sprint(rootDirectory),
email: jwtConf.Email,
privateKey: jwtConf.PrivateKey,
client: oauth2.NewClient(context.Background(), ts),
chunkSize: chunkSize,
}
return New(params)
}
// New constructs a new driver
func New(params driverParameters) (storagedriver.StorageDriver, error) {
rootDirectory := strings.Trim(params.rootDirectory, "/")
if rootDirectory != "" {
rootDirectory += "/"
}
if params.chunkSize <= 0 || params.chunkSize%minChunkSize != 0 {
return nil, fmt.Errorf("Invalid chunksize: %d is not a positive multiple of %d", params.chunkSize, minChunkSize)
}
d := &driver{
bucket: params.bucket,
rootDirectory: rootDirectory,
email: params.email,
privateKey: params.privateKey,
client: params.client,
chunkSize: params.chunkSize,
}
return &base.Base{
StorageDriver: d,
}, nil
}
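// As an illustrative sketch (the bucket name and key file path are
// placeholders), the driver can be constructed through FromParameters with
// the parameters handled above:
//
// d, err := FromParameters(map[string]interface{}{
//     "bucket":        "my-registry-bucket",
//     "rootdirectory": "/registry",
//     "chunksize":     5 * minChunkSize,
//     "keyfile":       "/etc/gcs/key.json",
// })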
// Implement the storagedriver.StorageDriver interface
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
// This should primarily be used for small objects.
func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) {
gcsContext := d.context(context)
name := d.pathToKey(path)
var rc io.ReadCloser
err := retry(func() error {
var err error
rc, err = storage.NewReader(gcsContext, d.bucket, name)
return err
})
if err == storage.ErrObjectNotExist {
return nil, storagedriver.PathNotFoundError{Path: path}
}
if err != nil {
return nil, err
}
defer rc.Close()
p, err := ioutil.ReadAll(rc)
if err != nil {
return nil, err
}
return p, nil
}
// PutContent stores the []byte content at a location designated by "path".
// This should primarily be used for small objects.
func (d *driver) PutContent(context ctx.Context, path string, contents []byte) error {
return retry(func() error {
wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path))
wc.ContentType = "application/octet-stream"
return putContentsClose(wc, contents)
})
}
// Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset.
func (d *driver) Reader(context ctx.Context, path string, offset int64) (io.ReadCloser, error) {
res, err := getObject(d.client, d.bucket, d.pathToKey(path), offset)
if err != nil {
if res != nil {
if res.StatusCode == http.StatusNotFound {
res.Body.Close()
return nil, storagedriver.PathNotFoundError{Path: path}
}
if res.StatusCode == http.StatusRequestedRangeNotSatisfiable {
res.Body.Close()
obj, err := storageStatObject(d.context(context), d.bucket, d.pathToKey(path))
if err != nil {
return nil, err
}
if offset == int64(obj.Size) {
return ioutil.NopCloser(bytes.NewReader([]byte{})), nil
}
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
}
return nil, err
}
if res.Header.Get("Content-Type") == uploadSessionContentType {
defer res.Body.Close()
return nil, storagedriver.PathNotFoundError{Path: path}
}
return res.Body, nil
}
func getObject(client *http.Client, bucket string, name string, offset int64) (*http.Response, error) {
// copied from google.golang.org/cloud/storage#NewReader:
// to set the additional "Range" header
u := &url.URL{
Scheme: "https",
Host: "storage.googleapis.com",
Path: fmt.Sprintf("/%s/%s", bucket, name),
}
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, err
}
if offset > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%v-", offset))
}
var res *http.Response
err = retry(func() error {
var err error
res, err = client.Do(req)
return err
})
if err != nil {
return nil, err
}
return res, googleapi.CheckMediaResponse(res)
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(context ctx.Context, path string, append bool) (storagedriver.FileWriter, error) {
writer := &writer{
client: d.client,
bucket: d.bucket,
name: d.pathToKey(path),
buffer: make([]byte, d.chunkSize),
}
if append {
err := writer.init(path)
if err != nil {
return nil, err
}
}
return writer, nil
}
type writer struct {
client *http.Client
bucket string
name string
size int64
offset int64
closed bool
sessionURI string
buffer []byte
buffSize int
}
// Cancel removes any written content from this FileWriter.
func (w *writer) Cancel() error {
w.closed = true
err := storageDeleteObject(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
if err != nil {
if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound {
err = nil
}
}
}
return err
}
func (w *writer) Close() error {
if w.closed {
return nil
}
w.closed = true
err := w.writeChunk()
if err != nil {
return err
}
// Copy the remaining bytes from the buffer to the upload session
// Normally buffSize will be smaller than minChunkSize. However, in the
// unlikely event that the upload session failed to start, this number could be higher.
// In this case we can safely clip the remaining bytes to the minChunkSize
if w.buffSize > minChunkSize {
w.buffSize = minChunkSize
}
// commit the writes by updating the upload session
err = retry(func() error {
wc := storage.NewWriter(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
wc.ContentType = uploadSessionContentType
wc.Metadata = map[string]string{
"Session-URI": w.sessionURI,
"Offset": strconv.FormatInt(w.offset, 10),
}
return putContentsClose(wc, w.buffer[0:w.buffSize])
})
if err != nil {
return err
}
w.size = w.offset + int64(w.buffSize)
w.buffSize = 0
return nil
}
func putContentsClose(wc *storage.Writer, contents []byte) error {
size := len(contents)
var nn int
var err error
for nn < size {
n, err := wc.Write(contents[nn:size])
nn += n
if err != nil {
break
}
}
if err != nil {
wc.CloseWithError(err)
return err
}
return wc.Close()
}
// Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and
// StorageDriver.Reader.
func (w *writer) Commit() error {
if err := w.checkClosed(); err != nil {
return err
}
w.closed = true
// no session started yet; just perform a simple upload
if w.sessionURI == "" {
err := retry(func() error {
wc := storage.NewWriter(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
wc.ContentType = "application/octet-stream"
return putContentsClose(wc, w.buffer[0:w.buffSize])
})
if err != nil {
return err
}
w.size = w.offset + int64(w.buffSize)
w.buffSize = 0
return nil
}
size := w.offset + int64(w.buffSize)
var nn int
// loop must be performed at least once to ensure the file is committed even when
// the buffer is empty
for {
n, err := putChunk(w.client, w.sessionURI, w.buffer[nn:w.buffSize], w.offset, size)
nn += int(n)
w.offset += n
w.size = w.offset
if err != nil {
w.buffSize = copy(w.buffer, w.buffer[nn:w.buffSize])
return err
}
if nn == w.buffSize {
break
}
}
w.buffSize = 0
return nil
}
func (w *writer) checkClosed() error {
if w.closed {
return fmt.Errorf("Writer already closed")
}
return nil
}
func (w *writer) writeChunk() error {
var err error
// chunks can be uploaded only in multiples of minChunkSize
// chunkSize is a multiple of minChunkSize less than or equal to buffSize
chunkSize := w.buffSize - (w.buffSize % minChunkSize)
if chunkSize == 0 {
return nil
}
// if there is no sessionURI yet, obtain one by starting the session
if w.sessionURI == "" {
w.sessionURI, err = startSession(w.client, w.bucket, w.name)
}
if err != nil {
return err
}
nn, err := putChunk(w.client, w.sessionURI, w.buffer[0:chunkSize], w.offset, -1)
w.offset += nn
if w.offset > w.size {
w.size = w.offset
}
// shift the remaining bytes to the start of the buffer
w.buffSize = copy(w.buffer, w.buffer[int(nn):w.buffSize])
return err
}
func (w *writer) Write(p []byte) (int, error) {
err := w.checkClosed()
if err != nil {
return 0, err
}
var nn int
for nn < len(p) {
n := copy(w.buffer[w.buffSize:], p[nn:])
w.buffSize += n
if w.buffSize == cap(w.buffer) {
err = w.writeChunk()
if err != nil {
break
}
}
nn += n
}
return nn, err
}
// Size returns the number of bytes written to this FileWriter.
func (w *writer) Size() int64 {
return w.size
}
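// init resumes an interrupted upload: it fetches the session metadata object
// written by Close, restoring the session URI, the committed offset and any
// buffered bytes that had not yet been uploaded.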
func (w *writer) init(path string) error {
res, err := getObject(w.client, w.bucket, w.name, 0)
if err != nil {
return err
}
defer res.Body.Close()
if res.Header.Get("Content-Type") != uploadSessionContentType {
return storagedriver.PathNotFoundError{Path: path}
}
offset, err := strconv.ParseInt(res.Header.Get("X-Goog-Meta-Offset"), 10, 64)
if err != nil {
return err
}
buffer, err := ioutil.ReadAll(res.Body)
if err != nil {
return err
}
w.sessionURI = res.Header.Get("X-Goog-Meta-Session-URI")
w.buffSize = copy(w.buffer, buffer)
w.offset = offset
w.size = offset + int64(w.buffSize)
return nil
}
type request func() error
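// retry invokes req up to maxTries times, backing off exponentially with up
// to one second of random jitter. Only HTTP 429 and 5xx googleapi errors are
// retried; any other error is returned immediately.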
func retry(req request) error {
backoff := time.Second
var err error
for i := 0; i < maxTries; i++ {
err = req()
if err == nil {
return nil
}
status, ok := err.(*googleapi.Error)
if !ok || (status.Code != 429 && status.Code < http.StatusInternalServerError) {
return err
}
time.Sleep(backoff - time.Second + (time.Duration(rand.Int31n(1000)) * time.Millisecond))
if i <= 4 {
backoff = backoff * 2
}
}
return err
}
// Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time.
func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo, error) {
var fi storagedriver.FileInfoFields
// try to get as file
gcsContext := d.context(context)
obj, err := storageStatObject(gcsContext, d.bucket, d.pathToKey(path))
if err == nil {
if obj.ContentType == uploadSessionContentType {
return nil, storagedriver.PathNotFoundError{Path: path}
}
fi = storagedriver.FileInfoFields{
Path: path,
Size: obj.Size,
ModTime: obj.Updated,
IsDir: false,
}
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
}
// try to get as folder
dirpath := d.pathToDirKey(path)
query := &storage.Query{}
query.Prefix = dirpath
query.MaxResults = 1
objects, err := storageListObjects(gcsContext, d.bucket, query)
if err != nil {
return nil, err
}
if len(objects.Results) < 1 {
return nil, storagedriver.PathNotFoundError{Path: path}
}
fi = storagedriver.FileInfoFields{
Path: path,
IsDir: true,
}
obj = objects.Results[0]
if obj.Name == dirpath {
fi.Size = obj.Size
fi.ModTime = obj.Updated
}
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
}
// List returns a list of the objects that are direct descendants of the
// given path.
func (d *driver) List(context ctx.Context, path string) ([]string, error) {
query := &storage.Query{}
query.Delimiter = "/"
query.Prefix = d.pathToDirKey(path)
list := make([]string, 0, 64)
for {
objects, err := storageListObjects(d.context(context), d.bucket, query)
if err != nil {
return nil, err
}
for _, object := range objects.Results {
// GCS does not guarantee strong consistency between
// DELETE and LIST operations. Check that the object is not deleted,
// and filter out any objects with a non-zero time-deleted
if object.Deleted.IsZero() && object.ContentType != uploadSessionContentType {
list = append(list, d.keyToPath(object.Name))
}
}
for _, subpath := range objects.Prefixes {
subpath = d.keyToPath(subpath)
list = append(list, subpath)
}
query = objects.Next
if query == nil {
break
}
}
if path != "/" && len(list) == 0 {
// Treat empty response as missing directory, since we don't actually
// have directories in Google Cloud Storage.
return nil, storagedriver.PathNotFoundError{Path: path}
}
return list, nil
}
// Move moves an object stored at sourcePath to destPath, removing the
// original object.
func (d *driver) Move(context ctx.Context, sourcePath string, destPath string) error {
gcsContext := d.context(context)
_, err := storageCopyObject(gcsContext, d.bucket, d.pathToKey(sourcePath), d.bucket, d.pathToKey(destPath), nil)
if err != nil {
if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound {
return storagedriver.PathNotFoundError{Path: sourcePath}
}
}
return err
}
err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(sourcePath))
// if deleting the file fails, log the error, but do not fail; the file was successfully copied,
// and the original should eventually be cleaned when purging the uploads folder.
if err != nil {
logrus.Infof("error deleting file: %v due to %v", sourcePath, err)
}
return nil
}
// listAll recursively lists all names of objects stored at "prefix" and its subpaths.
func (d *driver) listAll(context context.Context, prefix string) ([]string, error) {
list := make([]string, 0, 64)
query := &storage.Query{}
query.Prefix = prefix
query.Versions = false
for {
objects, err := storageListObjects(d.context(context), d.bucket, query)
if err != nil {
return nil, err
}
for _, obj := range objects.Results {
// GCS does not guarantee strong consistency between
// DELETE and LIST operations. Check that the object is not deleted,
// and filter out any objects with a non-zero time-deleted
if obj.Deleted.IsZero() {
list = append(list, obj.Name)
}
}
query = objects.Next
if query == nil {
break
}
}
return list, nil
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(context ctx.Context, path string) error {
prefix := d.pathToDirKey(path)
gcsContext := d.context(context)
keys, err := d.listAll(gcsContext, prefix)
if err != nil {
return err
}
if len(keys) > 0 {
sort.Sort(sort.Reverse(sort.StringSlice(keys)))
for _, key := range keys {
err := storageDeleteObject(gcsContext, d.bucket, key)
// GCS only guarantees eventual consistency, so listAll might return
// paths that no longer exist. If this happens, just ignore any not
// found error
if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound {
err = nil
}
}
if err != nil {
return err
}
}
return nil
}
err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(path))
if err != nil {
if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound {
return storagedriver.PathNotFoundError{Path: path}
}
}
}
return err
}
func storageDeleteObject(context context.Context, bucket string, name string) error {
return retry(func() error {
return storage.DeleteObject(context, bucket, name)
})
}
func storageStatObject(context context.Context, bucket string, name string) (*storage.Object, error) {
var obj *storage.Object
err := retry(func() error {
var err error
obj, err = storage.StatObject(context, bucket, name)
return err
})
return obj, err
}
func storageListObjects(context context.Context, bucket string, q *storage.Query) (*storage.Objects, error) {
var objs *storage.Objects
err := retry(func() error {
var err error
objs, err = storage.ListObjects(context, bucket, q)
return err
})
return objs, err
}
func storageCopyObject(context context.Context, srcBucket, srcName string, destBucket, destName string, attrs *storage.ObjectAttrs) (*storage.Object, error) {
var obj *storage.Object
err := retry(func() error {
var err error
obj, err = storage.CopyObject(context, srcBucket, srcName, destBucket, destName, attrs)
return err
})
return obj, err
}
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// Returns ErrUnsupportedMethod if this driver has no privateKey
func (d *driver) URLFor(context ctx.Context, path string, options map[string]interface{}) (string, error) {
if d.privateKey == nil {
return "", storagedriver.ErrUnsupportedMethod{}
}
name := d.pathToKey(path)
methodString := "GET"
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != "GET" && methodString != "HEAD") {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
expiresTime := time.Now().Add(20 * time.Minute)
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresTime = et
}
}
opts := &storage.SignedURLOptions{
GoogleAccessID: d.email,
PrivateKey: d.privateKey,
Method: methodString,
Expires: expiresTime,
}
return storage.SignedURL(d.bucket, name, opts)
}
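// startSession initiates a resumable upload session against the GCS JSON API
// and returns the session URI reported in the Location response header.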
func startSession(client *http.Client, bucket string, name string) (uri string, err error) {
u := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: fmt.Sprintf("/upload/storage/v1/b/%v/o", bucket),
RawQuery: fmt.Sprintf("uploadType=resumable&name=%v", name),
}
err = retry(func() error {
req, err := http.NewRequest("POST", u.String(), nil)
if err != nil {
return err
}
req.Header.Set("X-Upload-Content-Type", "application/octet-stream")
req.Header.Set("Content-Length", "0")
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
err = googleapi.CheckMediaResponse(resp)
if err != nil {
return err
}
uri = resp.Header.Get("Location")
return nil
})
return uri, err
}
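// putChunk uploads a single chunk of a resumable upload session. A negative
// totalSize means the final object size is still unknown and is sent as "*"
// in the Content-Range header; in that case a 308 response reports (via its
// Range header) how many bytes the server actually persisted.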
func putChunk(client *http.Client, sessionURI string, chunk []byte, from int64, totalSize int64) (int64, error) {
bytesPut := int64(0)
err := retry(func() error {
req, err := http.NewRequest("PUT", sessionURI, bytes.NewReader(chunk))
if err != nil {
return err
}
length := int64(len(chunk))
to := from + length - 1
size := "*"
if totalSize >= 0 {
size = strconv.FormatInt(totalSize, 10)
}
req.Header.Set("Content-Type", "application/octet-stream")
if from == to+1 {
req.Header.Set("Content-Range", fmt.Sprintf("bytes */%v", size))
} else {
req.Header.Set("Content-Range", fmt.Sprintf("bytes %v-%v/%v", from, to, size))
}
req.Header.Set("Content-Length", strconv.FormatInt(length, 10))
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if totalSize < 0 && resp.StatusCode == 308 {
groups := rangeHeader.FindStringSubmatch(resp.Header.Get("Range"))
end, err := strconv.ParseInt(groups[2], 10, 64)
if err != nil {
return err
}
bytesPut = end - from + 1
return nil
}
err = googleapi.CheckMediaResponse(resp)
if err != nil {
return err
}
bytesPut = to - from + 1
return nil
})
return bytesPut, err
}
func (d *driver) context(context ctx.Context) context.Context {
return cloud.WithContext(context, dummyProjectID, d.client)
}
func (d *driver) pathToKey(path string) string {
return strings.TrimRight(d.rootDirectory+strings.TrimLeft(path, "/"), "/")
}
func (d *driver) pathToDirKey(path string) string {
return d.pathToKey(path) + "/"
}
func (d *driver) keyToPath(key string) string {
return "/" + strings.Trim(strings.TrimPrefix(key, d.rootDirectory), "/")
}

View File

@ -0,0 +1,311 @@
// +build include_gcs
package gcs
import (
"io/ioutil"
"os"
"testing"
"fmt"
ctx "github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/googleapi"
"google.golang.org/cloud/storage"
"gopkg.in/check.v1"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
var gcsDriverConstructor func(rootDirectory string) (storagedriver.StorageDriver, error)
var skipGCS func() string
func init() {
bucket := os.Getenv("REGISTRY_STORAGE_GCS_BUCKET")
credentials := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS")
// Skip GCS storage driver tests if environment variable parameters are not provided
skipGCS = func() string {
if bucket == "" || credentials == "" {
return "The following environment variables must be set to enable these tests: REGISTRY_STORAGE_GCS_BUCKET, GOOGLE_APPLICATION_CREDENTIALS"
}
return ""
}
if skipGCS() != "" {
return
}
root, err := ioutil.TempDir("", "driver-")
if err != nil {
panic(err)
}
defer os.Remove(root)
var ts oauth2.TokenSource
var email string
var privateKey []byte
ts, err = google.DefaultTokenSource(ctx.Background(), storage.ScopeFullControl)
if err != nil {
// Assume that the file contents are within the environment variable since it exists
// but does not contain a valid file path
jwtConfig, err := google.JWTConfigFromJSON([]byte(credentials), storage.ScopeFullControl)
if err != nil {
panic(fmt.Sprintf("Error reading JWT config : %s", err))
}
email = jwtConfig.Email
privateKey = []byte(jwtConfig.PrivateKey)
if len(privateKey) == 0 {
panic("Error reading JWT config : missing private_key property")
}
if email == "" {
panic("Error reading JWT config : missing client_email property")
}
ts = jwtConfig.TokenSource(ctx.Background())
}
gcsDriverConstructor = func(rootDirectory string) (storagedriver.StorageDriver, error) {
parameters := driverParameters{
bucket: bucket,
rootDirectory: root,
email: email,
privateKey: privateKey,
client: oauth2.NewClient(ctx.Background(), ts),
chunkSize: defaultChunkSize,
}
return New(parameters)
}
testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) {
return gcsDriverConstructor(root)
}, skipGCS)
}
// Test Committing a FileWriter without having called Write
func TestCommitEmpty(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
driver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
filename := "/test"
ctx := ctx.Background()
writer, err := driver.Writer(ctx, filename, false)
defer driver.Delete(ctx, filename)
if err != nil {
t.Fatalf("driver.Writer: unexpected error: %v", err)
}
err = writer.Commit()
if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err)
}
err = writer.Close()
if err != nil {
t.Fatalf("writer.Close: unexpected error: %v", err)
}
if writer.Size() != 0 {
t.Fatalf("writer.Size: %d != 0", writer.Size())
}
readContents, err := driver.GetContent(ctx, filename)
if err != nil {
t.Fatalf("driver.GetContent: unexpected error: %v", err)
}
if len(readContents) != 0 {
t.Fatalf("len(driver.GetContent(..)): %d != 0", len(readContents))
}
}
// Test Committing a FileWriter after having written exactly
// defaultChunkSize bytes.
func TestCommit(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
driver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
filename := "/test"
ctx := ctx.Background()
contents := make([]byte, defaultChunkSize)
writer, err := driver.Writer(ctx, filename, false)
defer driver.Delete(ctx, filename)
if err != nil {
t.Fatalf("driver.Writer: unexpected error: %v", err)
}
_, err = writer.Write(contents)
if err != nil {
t.Fatalf("writer.Write: unexpected error: %v", err)
}
err = writer.Commit()
if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err)
}
err = writer.Close()
if err != nil {
t.Fatalf("writer.Close: unexpected error: %v", err)
}
if writer.Size() != int64(len(contents)) {
t.Fatalf("writer.Size: %d != %d", writer.Size(), len(contents))
}
readContents, err := driver.GetContent(ctx, filename)
if err != nil {
t.Fatalf("driver.GetContent: unexpected error: %v", err)
}
if len(readContents) != len(contents) {
t.Fatalf("len(driver.GetContent(..)): %d != %d", len(readContents), len(contents))
}
}
func TestRetry(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
assertError := func(expected string, observed error) {
observedMsg := "<nil>"
if observed != nil {
observedMsg = observed.Error()
}
if observedMsg != expected {
t.Fatalf("expected %v, observed %v\n", expected, observedMsg)
}
}
err := retry(func() error {
return &googleapi.Error{
Code: 503,
Message: "google api error",
}
})
assertError("googleapi: Error 503: google api error", err)
err = retry(func() error {
return &googleapi.Error{
Code: 404,
Message: "google api error",
}
})
assertError("googleapi: Error 404: google api error", err)
err = retry(func() error {
return fmt.Errorf("error")
})
assertError("error", err)
}
func TestEmptyRootList(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
rootedDriver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
emptyRootDriver, err := gcsDriverConstructor("")
if err != nil {
t.Fatalf("unexpected error creating empty root driver: %v", err)
}
slashRootDriver, err := gcsDriverConstructor("/")
if err != nil {
t.Fatalf("unexpected error creating slash root driver: %v", err)
}
filename := "/test"
contents := []byte("contents")
ctx := ctx.Background()
err = rootedDriver.PutContent(ctx, filename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer func() {
err := rootedDriver.Delete(ctx, filename)
if err != nil {
t.Fatalf("failed to remove %v due to %v\n", filename, err)
}
}()
keys, err := emptyRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
keys, err = slashRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
}
// TestMoveDirectory checks that moving a directory returns an error.
func TestMoveDirectory(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
driver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
ctx := ctx.Background()
contents := []byte("contents")
// Create a regular file.
err = driver.PutContent(ctx, "/parent/dir/foo", contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer func() {
err := driver.Delete(ctx, "/parent")
if err != nil {
t.Fatalf("failed to remove /parent due to %v\n", err)
}
}()
err = driver.Move(ctx, "/parent/dir", "/parent/other")
if err == nil {
t.Fatalf("Moving directory /parent/dir /parent/other should have return a non-nil error\n")
}
}

View File

@ -0,0 +1,312 @@
package inmemory
import (
"fmt"
"io"
"io/ioutil"
"sync"
"time"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
)
const driverName = "inmemory"
func init() {
factory.Register(driverName, &inMemoryDriverFactory{})
}
// inMemoryDriverFactory implements the factory.StorageDriverFactory interface.
type inMemoryDriverFactory struct{}
func (factory *inMemoryDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return New(), nil
}
type driver struct {
root *dir
mutex sync.RWMutex
}
// baseEmbed allows us to hide the Base embed.
type baseEmbed struct {
base.Base
}
// Driver is a storagedriver.StorageDriver implementation backed by a local map.
// Intended solely for example and testing purposes.
type Driver struct {
baseEmbed // embedded, hidden base driver.
}
var _ storagedriver.StorageDriver = &Driver{}
// New constructs a new Driver.
func New() *Driver {
return &Driver{
baseEmbed: baseEmbed{
Base: base.Base{
StorageDriver: &driver{
root: &dir{
common: common{
p: "/",
mod: time.Now(),
},
},
},
},
},
}
}
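// As an illustrative sketch (path and contents are arbitrary), the in-memory
// driver needs no parameters and can be exercised directly:
//
// d := New()
// err := d.PutContent(ctx, "/greeting", []byte("hello")) // ctx is any context.Context
// data, _ := d.GetContent(ctx, "/greeting")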
// Implement the storagedriver.StorageDriver interface.
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
d.mutex.RLock()
defer d.mutex.RUnlock()
rc, err := d.Reader(ctx, path, 0)
if err != nil {
return nil, err
}
defer rc.Close()
return ioutil.ReadAll(rc)
}
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, p string, contents []byte) error {
d.mutex.Lock()
defer d.mutex.Unlock()
normalized := normalize(p)
f, err := d.root.mkfile(normalized)
if err != nil {
// TODO(stevvooe): Again, we need to clarify when this is not a
// directory in StorageDriver API.
return fmt.Errorf("not a file")
}
f.truncate()
f.WriteAt(contents, 0)
return nil
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
d.mutex.RLock()
defer d.mutex.RUnlock()
if offset < 0 {
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
normalized := normalize(path)
found := d.root.find(normalized)
if found.path() != normalized {
return nil, storagedriver.PathNotFoundError{Path: path}
}
if found.isdir() {
return nil, fmt.Errorf("%q is a directory", path)
}
return ioutil.NopCloser(found.(*file).sectionReader(offset)), nil
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
d.mutex.Lock()
defer d.mutex.Unlock()
normalized := normalize(path)
f, err := d.root.mkfile(normalized)
if err != nil {
return nil, fmt.Errorf("not a file")
}
if !append {
f.truncate()
}
return d.newWriter(f), nil
}
// Stat returns info about the provided path.
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
d.mutex.RLock()
defer d.mutex.RUnlock()
normalized := normalize(path)
found := d.root.find(normalized)
if found.path() != normalized {
return nil, storagedriver.PathNotFoundError{Path: path}
}
fi := storagedriver.FileInfoFields{
Path: path,
IsDir: found.isdir(),
ModTime: found.modtime(),
}
if !fi.IsDir {
fi.Size = int64(len(found.(*file).data))
}
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
}
// List returns a list of the objects that are direct descendants of the given
// path.
func (d *driver) List(ctx context.Context, path string) ([]string, error) {
d.mutex.RLock()
defer d.mutex.RUnlock()
normalized := normalize(path)
found := d.root.find(normalized)
if !found.isdir() {
return nil, fmt.Errorf("not a directory") // TODO(stevvooe): Need error type for this...
}
entries, err := found.(*dir).list(normalized)
if err != nil {
switch err {
case errNotExists:
return nil, storagedriver.PathNotFoundError{Path: path}
case errIsNotDir:
return nil, fmt.Errorf("not a directory")
default:
return nil, err
}
}
return entries, nil
}
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
d.mutex.Lock()
defer d.mutex.Unlock()
normalizedSrc, normalizedDst := normalize(sourcePath), normalize(destPath)
err := d.root.move(normalizedSrc, normalizedDst)
switch err {
case errNotExists:
return storagedriver.PathNotFoundError{Path: destPath}
default:
return err
}
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, path string) error {
d.mutex.Lock()
defer d.mutex.Unlock()
normalized := normalize(path)
err := d.root.delete(normalized)
switch err {
case errNotExists:
return storagedriver.PathNotFoundError{Path: path}
default:
return err
}
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an ErrUnsupportedMethod in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
}
type writer struct {
d *driver
f *file
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(f *file) storagedriver.FileWriter {
return &writer{
d: d,
f: f,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
w.d.mutex.Lock()
defer w.d.mutex.Unlock()
return w.f.WriteAt(p, int64(len(w.f.data)))
}
func (w *writer) Size() int64 {
w.d.mutex.RLock()
defer w.d.mutex.RUnlock()
return int64(len(w.f.data))
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return nil
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
w.d.mutex.Lock()
defer w.d.mutex.Unlock()
return w.d.root.delete(w.f.path())
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
w.committed = true
return nil
}

View File

@ -0,0 +1,19 @@
package inmemory
import (
"testing"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
"gopkg.in/check.v1"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
func init() {
inmemoryDriverConstructor := func() (storagedriver.StorageDriver, error) {
return New(), nil
}
testsuites.RegisterSuite(inmemoryDriverConstructor, testsuites.NeverSkip)
}

View File

@ -0,0 +1,338 @@
package inmemory
import (
"fmt"
"io"
"path"
"sort"
"strings"
"time"
)
var (
errExists = fmt.Errorf("exists")
errNotExists = fmt.Errorf("notexists")
errIsNotDir = fmt.Errorf("notdir")
errIsDir = fmt.Errorf("isdir")
)
type node interface {
name() string
path() string
isdir() bool
modtime() time.Time
}
// dir is the central type for the memory-based storagedriver. All operations
// are dispatched from a root dir.
type dir struct {
common
// TODO(stevvooe): Use sorted slice + search.
children map[string]node
}
var _ node = &dir{}
func (d *dir) isdir() bool {
return true
}
// add places the node n into dir d.
func (d *dir) add(n node) {
if d.children == nil {
d.children = make(map[string]node)
}
d.children[n.name()] = n
d.mod = time.Now()
}
// find searches for the node with path q under dir d. If the node is found,
// it is returned and the returned (node).path() will match q. If the node is
// not found, the closest existing parent is returned instead.
func (d *dir) find(q string) node {
q = strings.Trim(q, "/")
i := strings.Index(q, "/")
if q == "" {
return d
}
if i == 0 {
panic("shouldn't happen, no root paths")
}
var component string
if i < 0 {
// No more path components
component = q
} else {
component = q[:i]
}
child, ok := d.children[component]
if !ok {
// Node was not found. Return the closest existing parent (the current dir).
return d
}
if child.isdir() {
// traverse down!
q = q[i+1:]
return child.(*dir).find(q)
}
return child
}
func (d *dir) list(p string) ([]string, error) {
n := d.find(p)
if n.path() != p {
return nil, errNotExists
}
if !n.isdir() {
return nil, errIsNotDir
}
var children []string
for _, child := range n.(*dir).children {
children = append(children, child.path())
}
sort.Strings(children)
return children, nil
}
// mkfile creates the file at p or returns the existing one. It returns an
// error if the path exists and is a directory. Essentially, this is open or
// create.
func (d *dir) mkfile(p string) (*file, error) {
n := d.find(p)
if n.path() == p {
if n.isdir() {
return nil, errIsDir
}
return n.(*file), nil
}
dirpath, filename := path.Split(p)
// Make any non-existent directories
n, err := d.mkdirs(dirpath)
if err != nil {
return nil, err
}
dd := n.(*dir)
n = &file{
common: common{
p: path.Join(dd.path(), filename),
mod: time.Now(),
},
}
dd.add(n)
return n.(*file), nil
}
// mkdirs creates any missing directory entries in p and returns the result.
func (d *dir) mkdirs(p string) (*dir, error) {
p = normalize(p)
n := d.find(p)
if !n.isdir() {
// Found something there
return nil, errIsNotDir
}
if n.path() == p {
return n.(*dir), nil
}
dd := n.(*dir)
relative := strings.Trim(strings.TrimPrefix(p, n.path()), "/")
if relative == "" {
return dd, nil
}
components := strings.Split(relative, "/")
for _, component := range components {
d, err := dd.mkdir(component)
if err != nil {
// This should actually never happen, since there are no children.
return nil, err
}
dd = d
}
return dd, nil
}
// mkdir creates a child directory under d with the given name.
func (d *dir) mkdir(name string) (*dir, error) {
if name == "" {
return nil, fmt.Errorf("invalid dirname")
}
_, ok := d.children[name]
if ok {
return nil, errExists
}
child := &dir{
common: common{
p: path.Join(d.path(), name),
mod: time.Now(),
},
}
d.add(child)
d.mod = time.Now()
return child, nil
}
func (d *dir) move(src, dst string) error {
dstDirname, _ := path.Split(dst)
dp, err := d.mkdirs(dstDirname)
if err != nil {
return err
}
srcDirname, srcFilename := path.Split(src)
sp := d.find(srcDirname)
if normalize(srcDirname) != normalize(sp.path()) {
return errNotExists
}
spd, ok := sp.(*dir)
if !ok {
return errIsNotDir // paranoid.
}
s, ok := spd.children[srcFilename]
if !ok {
return errNotExists
}
delete(spd.children, srcFilename)
switch n := s.(type) {
case *dir:
n.p = dst
case *file:
n.p = dst
}
dp.add(s)
return nil
}
func (d *dir) delete(p string) error {
dirname, filename := path.Split(p)
parent := d.find(dirname)
if normalize(dirname) != normalize(parent.path()) {
return errNotExists
}
if _, ok := parent.(*dir).children[filename]; !ok {
return errNotExists
}
delete(parent.(*dir).children, filename)
return nil
}
// dump outputs a primitive directory structure to stdout.
func (d *dir) dump(indent string) {
fmt.Println(indent, d.name()+"/")
for _, child := range d.children {
if child.isdir() {
child.(*dir).dump(indent + "\t")
} else {
fmt.Println(indent, child.name())
}
}
}
func (d *dir) String() string {
return fmt.Sprintf("&dir{path: %v, children: %v}", d.p, d.children)
}
// file stores actual data in the fs tree. It acts like an open, seekable file
// where operations are conducted through ReadAt and WriteAt. Use it with
// SectionReader for the best effect.
type file struct {
common
data []byte
}
var _ node = &file{}
func (f *file) isdir() bool {
return false
}
func (f *file) truncate() {
f.data = f.data[:0]
}
func (f *file) sectionReader(offset int64) io.Reader {
return io.NewSectionReader(f, offset, int64(len(f.data))-offset)
}
func (f *file) ReadAt(p []byte, offset int64) (n int, err error) {
return copy(p, f.data[offset:]), nil
}
func (f *file) WriteAt(p []byte, offset int64) (n int, err error) {
off := int(offset)
if cap(f.data) < off+len(p) {
data := make([]byte, len(f.data), off+len(p))
copy(data, f.data)
f.data = data
}
f.mod = time.Now()
f.data = f.data[:off+len(p)]
return copy(f.data[off:off+len(p)], p), nil
}
func (f *file) String() string {
return fmt.Sprintf("&file{path: %q}", f.p)
}
// common provides shared fields and methods for node implementations.
type common struct {
p string
mod time.Time
}
func (c *common) name() string {
_, name := path.Split(c.p)
return name
}
func (c *common) path() string {
return c.p
}
func (c *common) modtime() time.Time {
return c.mod
}
func normalize(p string) string {
return "/" + strings.Trim(p, "/")
}
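A minimal usage sketch follows (illustrative only, not part of the original change): it builds a root dir, creates a file with mkfile, writes through WriteAt, and lists the parent directory, using only identifiers defined in this file.
// exampleTreeUsage is an illustrative sketch and is not used by the driver.
func exampleTreeUsage() ([]string, error) {
	// Start from an empty root, as the in-memory driver does.
	root := &dir{common: common{p: "/", mod: time.Now()}}
	// mkfile creates any missing parent directories and returns the file node.
	f, err := root.mkfile("/docker/registry/blob")
	if err != nil {
		return nil, err
	}
	// WriteAt grows the backing slice as needed.
	if _, err := f.WriteAt([]byte("hello"), 0); err != nil {
		return nil, err
	}
	// list returns the absolute paths of the direct children of the directory.
	return root.list("/docker/registry")
}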

View File

@ -0,0 +1,136 @@
// Package middleware provides a CloudFront wrapper for storage drivers.
// N.B. currently only works with S3, not arbitrary sites
//
package middleware
import (
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"net/url"
"strings"
"time"
"github.com/aws/aws-sdk-go/service/cloudfront/sign"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
)
// cloudFrontStorageMiddleware provides a simple implementation of a storage
// middleware that constructs temporary signed CloudFront URLs from the
// storagedriver layer URL, then issues HTTP Temporary Redirects to this
// CloudFront content URL.
type cloudFrontStorageMiddleware struct {
storagedriver.StorageDriver
urlSigner *sign.URLSigner
baseURL string
duration time.Duration
}
var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
// newCloudFrontStorageMiddleware constructs and returns a new CloudFront
// storage middleware implementation.
// Required options: baseurl, privatekey, keypairid
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
base, ok := options["baseurl"]
if !ok {
return nil, fmt.Errorf("no baseurl provided")
}
baseURL, ok := base.(string)
if !ok {
return nil, fmt.Errorf("baseurl must be a string")
}
if !strings.Contains(baseURL, "://") {
baseURL = "https://" + baseURL
}
if !strings.HasSuffix(baseURL, "/") {
baseURL += "/"
}
if _, err := url.Parse(baseURL); err != nil {
return nil, fmt.Errorf("invalid baseurl: %v", err)
}
pk, ok := options["privatekey"]
if !ok {
return nil, fmt.Errorf("no privatekey provided")
}
pkPath, ok := pk.(string)
if !ok {
return nil, fmt.Errorf("privatekey must be a string")
}
kpid, ok := options["keypairid"]
if !ok {
return nil, fmt.Errorf("no keypairid provided")
}
keypairID, ok := kpid.(string)
if !ok {
return nil, fmt.Errorf("keypairid must be a string")
}
pkBytes, err := ioutil.ReadFile(pkPath)
if err != nil {
return nil, fmt.Errorf("failed to read privatekey file: %s", err)
}
block, _ := pem.Decode([]byte(pkBytes))
if block == nil {
return nil, fmt.Errorf("failed to decode private key as an rsa private key")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
urlSigner := sign.NewURLSigner(keypairID, privateKey)
duration := 20 * time.Minute
d, ok := options["duration"]
if ok {
switch d := d.(type) {
case time.Duration:
duration = d
case string:
dur, err := time.ParseDuration(d)
if err != nil {
return nil, fmt.Errorf("invalid duration: %s", err)
}
duration = dur
}
}
return &cloudFrontStorageMiddleware{
StorageDriver: storageDriver,
urlSigner: urlSigner,
baseURL: baseURL,
duration: duration,
}, nil
}
// S3BucketKeyer is any type that is capable of returning the S3 bucket key
// which should be cached by AWS CloudFront.
type S3BucketKeyer interface {
S3BucketKey(path string) string
}
// URLFor returns a signed CloudFront URL for the content at the given path,
// falling back to the wrapped storage driver if the backend is not supported.
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
// TODO(endophage): currently only supports S3
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
if !ok {
context.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver")
return lh.StorageDriver.URLFor(ctx, path, options)
}
cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
if err != nil {
return "", err
}
return cfURL, nil
}
// init registers the cloudfront storage middleware backend.
func init() {
storagemiddleware.Register("cloudfront", storagemiddleware.InitFunc(newCloudFrontStorageMiddleware))
}
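A configuration sketch follows (illustrative only, not part of the original change): it shows the option keys consumed by newCloudFrontStorageMiddleware above. The domain, file path, and key-pair ID are hypothetical placeholders.
// exampleCloudFrontOptions is an illustrative sketch and is not used by the middleware.
func exampleCloudFrontOptions(backend storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
	options := map[string]interface{}{
		"baseurl":    "d111111abcdef8.cloudfront.net", // a scheme and trailing "/" are added if missing (placeholder)
		"privatekey": "/etc/docker/cloudfront/pk.pem", // path to a PEM-encoded PKCS#1 RSA private key (placeholder)
		"keypairid":  "APKAEXAMPLE",                   // CloudFront key pair ID (placeholder)
		"duration":   "20m",                           // optional; strings are parsed with time.ParseDuration
	}
	return newCloudFrontStorageMiddleware(backend, options)
}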

View File

@ -0,0 +1,50 @@
package middleware
import (
"fmt"
"net/url"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
)
type redirectStorageMiddleware struct {
storagedriver.StorageDriver
scheme string
host string
}
var _ storagedriver.StorageDriver = &redirectStorageMiddleware{}
func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
o, ok := options["baseurl"]
if !ok {
return nil, fmt.Errorf("no baseurl provided")
}
b, ok := o.(string)
if !ok {
return nil, fmt.Errorf("baseurl must be a string")
}
u, err := url.Parse(b)
if err != nil {
return nil, fmt.Errorf("unable to parse redirect baseurl: %s", b)
}
if u.Scheme == "" {
return nil, fmt.Errorf("no scheme specified for redirect baseurl")
}
if u.Host == "" {
return nil, fmt.Errorf("no host specified for redirect baseurl")
}
return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host}, nil
}
func (r *redirectStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
u := &url.URL{Scheme: r.scheme, Host: r.host, Path: path}
return u.String(), nil
}
func init() {
storagemiddleware.Register("redirect", storagemiddleware.InitFunc(newRedirectStorageMiddleware))
}

View File

@ -0,0 +1,58 @@
package middleware
import (
"testing"
check "gopkg.in/check.v1"
)
func Test(t *testing.T) { check.TestingT(t) }
type MiddlewareSuite struct{}
var _ = check.Suite(&MiddlewareSuite{})
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
options := make(map[string]interface{})
_, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.ErrorMatches, "no baseurl provided")
}
func (s *MiddlewareSuite) TestMissingScheme(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "example.com"
_, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl")
}
func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "https://example.com:5443"
middleware, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware)
c.Assert(ok, check.Equals, true)
c.Assert(m.scheme, check.Equals, "https")
c.Assert(m.host, check.Equals, "example.com:5443")
url, err := middleware.URLFor(nil, "/rick/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com:5443/rick/data")
}
func (s *MiddlewareSuite) TestHTTP(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "http://example.com"
middleware, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware)
c.Assert(ok, check.Equals, true)
c.Assert(m.scheme, check.Equals, "http")
c.Assert(m.host, check.Equals, "example.com")
url, err := middleware.URLFor(nil, "morty/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "http://example.com/morty/data")
}

View File

@ -0,0 +1,39 @@
package storagemiddleware
import (
"fmt"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
// InitFunc is the type of a StorageMiddleware factory function and is
// used to register the constructor for different StorageMiddleware backends.
type InitFunc func(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error)
var storageMiddlewares map[string]InitFunc
// Register is used to register an InitFunc for
// a StorageMiddleware backend with the given name.
func Register(name string, initFunc InitFunc) error {
if storageMiddlewares == nil {
storageMiddlewares = make(map[string]InitFunc)
}
if _, exists := storageMiddlewares[name]; exists {
return fmt.Errorf("name already registered: %s", name)
}
storageMiddlewares[name] = initFunc
return nil
}
// Get constructs a StorageMiddleware with the given options using the named backend.
func Get(name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
if storageMiddlewares != nil {
if initFunc, exists := storageMiddlewares[name]; exists {
return initFunc(storageDriver, options)
}
}
return nil, fmt.Errorf("no storage middleware registered with name: %s", name)
}
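A registration sketch follows (illustrative only, not part of the original change): it shows how a middleware registers an InitFunc and how a consumer resolves it by name with Get. The "noop" middleware is hypothetical and simply passes the driver through.
// exampleRegistration is an illustrative sketch and is not used by the registry.
func exampleRegistration(sd storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
	// Register is normally called from a middleware package's init function.
	_ = Register("noop", InitFunc(func(next storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
		return next, nil // a real middleware would wrap next here
	}))
	// Get looks up the named InitFunc and invokes it with the options and driver.
	return Get("noop", map[string]interface{}{}, sd)
}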

View File

@ -0,0 +1,3 @@
// Package oss implements the Aliyun OSS Storage driver backend. Support can be
// enabled by including the "include_oss" build tag.
package oss

View File

@ -0,0 +1,683 @@
// Package oss provides a storagedriver.StorageDriver implementation to
// store blobs in Aliyun OSS cloud storage.
//
// This package leverages the denverdino/aliyungo client library for interfacing with
// oss.
//
// Because OSS is a key-value store, the Stat call does not support last modification
// time for directories (directories are an abstraction for key-value stores).
//
// +build include_oss
package oss
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/docker/distribution/context"
"github.com/Sirupsen/logrus"
"github.com/denverdino/aliyungo/oss"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
)
const driverName = "oss"
// minChunkSize defines the minimum multipart upload chunk size
// OSS API requires multipart upload chunks to be at least 5MB
const minChunkSize = 5 << 20
const defaultChunkSize = 2 * minChunkSize
const defaultTimeout = 2 * time.Minute // 2 minute timeout per chunk
// listMax is the largest number of objects you can request from OSS in a list call
const listMax = 1000
// DriverParameters encapsulates all of the driver parameters after all values have been set
type DriverParameters struct {
AccessKeyID string
AccessKeySecret string
Bucket string
Region oss.Region
Internal bool
Encrypt bool
Secure bool
ChunkSize int64
RootDirectory string
Endpoint string
}
func init() {
factory.Register(driverName, &ossDriverFactory{})
}
// ossDriverFactory implements the factory.StorageDriverFactory interface
type ossDriverFactory struct{}
func (factory *ossDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
type driver struct {
Client *oss.Client
Bucket *oss.Bucket
ChunkSize int64
Encrypt bool
RootDirectory string
}
type baseEmbed struct {
base.Base
}
// Driver is a storagedriver.StorageDriver implementation backed by Aliyun OSS
// Objects are stored at absolute keys in the provided bucket.
type Driver struct {
baseEmbed
}
// FromParameters constructs a new Driver with a given parameters map
// Required parameters:
// - accesskeyid
// - accesskeysecret
// - region
// - bucket
// Optional parameters: internal, encrypt, secure, chunksize, rootdirectory, endpoint
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
// Providing empty values for the access key parameters is valid in case the user
// is authenticating by other means; the keys themselves must still be present.
accessKey, ok := parameters["accesskeyid"]
if !ok {
return nil, fmt.Errorf("No accesskeyid parameter provided")
}
secretKey, ok := parameters["accesskeysecret"]
if !ok {
return nil, fmt.Errorf("No accesskeysecret parameter provided")
}
regionName, ok := parameters["region"]
if !ok || fmt.Sprint(regionName) == "" {
return nil, fmt.Errorf("No region parameter provided")
}
bucket, ok := parameters["bucket"]
if !ok || fmt.Sprint(bucket) == "" {
return nil, fmt.Errorf("No bucket parameter provided")
}
internalBool := false
internal, ok := parameters["internal"]
if ok {
internalBool, ok = internal.(bool)
if !ok {
return nil, fmt.Errorf("The internal parameter should be a boolean")
}
}
encryptBool := false
encrypt, ok := parameters["encrypt"]
if ok {
encryptBool, ok = encrypt.(bool)
if !ok {
return nil, fmt.Errorf("The encrypt parameter should be a boolean")
}
}
secureBool := true
secure, ok := parameters["secure"]
if ok {
secureBool, ok = secure.(bool)
if !ok {
return nil, fmt.Errorf("The secure parameter should be a boolean")
}
}
chunkSize := int64(defaultChunkSize)
chunkSizeParam, ok := parameters["chunksize"]
if ok {
switch v := chunkSizeParam.(type) {
case string:
vv, err := strconv.ParseInt(v, 0, 64)
if err != nil {
return nil, fmt.Errorf("chunksize parameter must be an integer, %v invalid", chunkSizeParam)
}
chunkSize = vv
case int64:
chunkSize = v
case int, uint, int32, uint32, uint64:
chunkSize = reflect.ValueOf(v).Convert(reflect.TypeOf(chunkSize)).Int()
default:
return nil, fmt.Errorf("invalid valud for chunksize: %#v", chunkSizeParam)
}
if chunkSize < minChunkSize {
return nil, fmt.Errorf("The chunksize %#v parameter should be a number that is larger than or equal to %d", chunkSize, minChunkSize)
}
}
rootDirectory, ok := parameters["rootdirectory"]
if !ok {
rootDirectory = ""
}
endpoint, ok := parameters["endpoint"]
if !ok {
endpoint = ""
}
params := DriverParameters{
AccessKeyID: fmt.Sprint(accessKey),
AccessKeySecret: fmt.Sprint(secretKey),
Bucket: fmt.Sprint(bucket),
Region: oss.Region(fmt.Sprint(regionName)),
ChunkSize: chunkSize,
RootDirectory: fmt.Sprint(rootDirectory),
Encrypt: encryptBool,
Secure: secureBool,
Internal: internalBool,
Endpoint: fmt.Sprint(endpoint),
}
return New(params)
}
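A parameters sketch follows (illustrative only, not part of the original change): it shows a map accepted by FromParameters above. The credential, region, and bucket values are placeholders.
// exampleFromParameters is an illustrative sketch and is not used by the driver.
func exampleFromParameters() (*Driver, error) {
	return FromParameters(map[string]interface{}{
		"accesskeyid":     "my-access-key-id",     // placeholder
		"accesskeysecret": "my-access-key-secret", // placeholder
		"region":          "oss-cn-hangzhou",      // placeholder
		"bucket":          "my-registry-bucket",   // placeholder
		"rootdirectory":   "/registry",            // optional
		"chunksize":       "10485760",             // optional; must be >= minChunkSize
	})
}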
// New constructs a new Driver with the given Aliyun credentials, region, encryption flag, and
// bucket name
func New(params DriverParameters) (*Driver, error) {
client := oss.NewOSSClient(params.Region, params.Internal, params.AccessKeyID, params.AccessKeySecret, params.Secure)
client.SetEndpoint(params.Endpoint)
bucket := client.Bucket(params.Bucket)
client.SetDebug(false)
// Validate that the given credentials have at least read permissions in the
// given bucket scope.
if _, err := bucket.List(strings.TrimRight(params.RootDirectory, "/"), "", "", 1); err != nil {
return nil, err
}
// TODO(tg123): Currently multipart uploads have no timestamps, so this would be unwise
// if you initiated a new OSS client while another one is running on the same bucket.
d := &driver{
Client: client,
Bucket: bucket,
ChunkSize: params.ChunkSize,
Encrypt: params.Encrypt,
RootDirectory: params.RootDirectory,
}
return &Driver{
baseEmbed: baseEmbed{
Base: base.Base{
StorageDriver: d,
},
},
}, nil
}
// Implement the storagedriver.StorageDriver interface
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
content, err := d.Bucket.Get(d.ossPath(path))
if err != nil {
return nil, parseError(path, err)
}
return content, nil
}
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
return parseError(path, d.Bucket.Put(d.ossPath(path), contents, d.getContentType(), getPermissions(), d.getOptions()))
}
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
headers := make(http.Header)
headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-")
resp, err := d.Bucket.GetResponseWithHeaders(d.ossPath(path), headers)
if err != nil {
return nil, parseError(path, err)
}
// Due to the Aliyun OSS API, a 200 status and the whole object will be returned instead of an
// InvalidRange error when the range is invalid.
//
// The OSS server will always return http.StatusPartialContent if the range is acceptable.
if resp.StatusCode != http.StatusPartialContent {
resp.Body.Close()
return ioutil.NopCloser(bytes.NewReader(nil)), nil
}
return resp.Body, nil
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
key := d.ossPath(path)
if !append {
// TODO (brianbland): cancel other uploads at this path
multi, err := d.Bucket.InitMulti(key, d.getContentType(), getPermissions(), d.getOptions())
if err != nil {
return nil, err
}
return d.newWriter(key, multi, nil), nil
}
multis, _, err := d.Bucket.ListMulti(key, "")
if err != nil {
return nil, parseError(path, err)
}
for _, multi := range multis {
if key != multi.Key {
continue
}
parts, err := multi.ListParts()
if err != nil {
return nil, parseError(path, err)
}
var multiSize int64
for _, part := range parts {
multiSize += part.Size
}
return d.newWriter(key, multi, parts), nil
}
return nil, storagedriver.PathNotFoundError{Path: path}
}
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
listResponse, err := d.Bucket.List(d.ossPath(path), "", "", 1)
if err != nil {
return nil, err
}
fi := storagedriver.FileInfoFields{
Path: path,
}
if len(listResponse.Contents) == 1 {
if listResponse.Contents[0].Key != d.ossPath(path) {
fi.IsDir = true
} else {
fi.IsDir = false
fi.Size = listResponse.Contents[0].Size
timestamp, err := time.Parse(time.RFC3339Nano, listResponse.Contents[0].LastModified)
if err != nil {
return nil, err
}
fi.ModTime = timestamp
}
} else if len(listResponse.CommonPrefixes) == 1 {
fi.IsDir = true
} else {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
}
// List returns a list of the objects that are direct descendants of the given path.
func (d *driver) List(ctx context.Context, opath string) ([]string, error) {
path := opath
if path != "/" && opath[len(path)-1] != '/' {
path = path + "/"
}
// This is to cover for the cases when the rootDirectory of the driver is either "" or "/".
// In those cases, there is no root prefix to replace and we must actually add a "/" to all
// results in order to keep them as valid paths as recognized by storagedriver.PathRegexp
prefix := ""
if d.ossPath("") == "" {
prefix = "/"
}
listResponse, err := d.Bucket.List(d.ossPath(path), "/", "", listMax)
if err != nil {
return nil, parseError(opath, err)
}
files := []string{}
directories := []string{}
for {
for _, key := range listResponse.Contents {
files = append(files, strings.Replace(key.Key, d.ossPath(""), prefix, 1))
}
for _, commonPrefix := range listResponse.CommonPrefixes {
directories = append(directories, strings.Replace(commonPrefix[0:len(commonPrefix)-1], d.ossPath(""), prefix, 1))
}
if listResponse.IsTruncated {
listResponse, err = d.Bucket.List(d.ossPath(path), "/", listResponse.NextMarker, listMax)
if err != nil {
return nil, err
}
} else {
break
}
}
if opath != "/" {
if len(files) == 0 && len(directories) == 0 {
// Treat empty response as missing directory, since we don't actually
// have directories in OSS.
return nil, storagedriver.PathNotFoundError{Path: opath}
}
}
return append(files, directories...), nil
}
const maxConcurrency = 10
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
logrus.Infof("Move from %s to %s", d.ossPath(sourcePath), d.ossPath(destPath))
err := d.Bucket.CopyLargeFileInParallel(d.ossPath(sourcePath), d.ossPath(destPath),
d.getContentType(),
getPermissions(),
oss.Options{},
maxConcurrency)
if err != nil {
logrus.Errorf("Failed for move from %s to %s: %v", d.ossPath(sourcePath), d.ossPath(destPath), err)
return parseError(sourcePath, err)
}
return d.Delete(ctx, sourcePath)
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, path string) error {
ossPath := d.ossPath(path)
listResponse, err := d.Bucket.List(ossPath, "", "", listMax)
if err != nil || len(listResponse.Contents) == 0 {
return storagedriver.PathNotFoundError{Path: path}
}
ossObjects := make([]oss.Object, listMax)
for len(listResponse.Contents) > 0 {
numOssObjects := len(listResponse.Contents)
for index, key := range listResponse.Contents {
// Stop if we encounter a key that is not a subpath (so that deleting "/a" does not delete "/ab").
if len(key.Key) > len(ossPath) && (key.Key)[len(ossPath)] != '/' {
numOssObjects = index
break
}
ossObjects[index].Key = key.Key
}
err := d.Bucket.DelMulti(oss.Delete{Quiet: false, Objects: ossObjects[0:numOssObjects]})
if err != nil {
return err
}
if numOssObjects < len(listResponse.Contents) {
return nil
}
listResponse, err = d.Bucket.List(d.ossPath(path), "", "", listMax)
if err != nil {
return err
}
}
return nil
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
methodString := "GET"
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != "GET") {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
expiresTime := time.Now().Add(20 * time.Minute)
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresTime = et
}
}
logrus.Infof("methodString: %s, expiresTime: %v", methodString, expiresTime)
signedURL := d.Bucket.SignedURLWithMethod(methodString, d.ossPath(path), expiresTime, nil, nil)
logrus.Infof("signed URL: %s", signedURL)
return signedURL, nil
}
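A usage sketch follows (illustrative only, not part of the original change): it requests a signed URL from URLFor above. Only "GET" is accepted as a method, and "expiry" overrides the default 20-minute window; the blob path is a placeholder.
// exampleSignedURL is an illustrative sketch and is not used by the driver.
func exampleSignedURL(ctx context.Context, d *driver) (string, error) {
	return d.URLFor(ctx, "/docker/registry/v2/blobs/example", map[string]interface{}{
		"method": "GET",
		"expiry": time.Now().Add(15 * time.Minute),
	})
}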
func (d *driver) ossPath(path string) string {
return strings.TrimLeft(strings.TrimRight(d.RootDirectory, "/")+path, "/")
}
func parseError(path string, err error) error {
if ossErr, ok := err.(*oss.Error); ok && ossErr.StatusCode == http.StatusNotFound && (ossErr.Code == "NoSuchKey" || ossErr.Code == "") {
return storagedriver.PathNotFoundError{Path: path}
}
return err
}
func hasCode(err error, code string) bool {
ossErr, ok := err.(*oss.Error)
return ok && ossErr.Code == code
}
func (d *driver) getOptions() oss.Options {
return oss.Options{ServerSideEncryption: d.Encrypt}
}
func getPermissions() oss.ACL {
return oss.Private
}
func (d *driver) getContentType() string {
return "application/octet-stream"
}
// writer attempts to upload parts to OSS in a buffered fashion where the last
// part is at least as large as the chunksize, so the multipart upload could be
// cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written.
type writer struct {
driver *driver
key string
multi *oss.Multi
parts []oss.Part
size int64
readyPart []byte
pendingPart []byte
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(key string, multi *oss.Multi, parts []oss.Part) storagedriver.FileWriter {
var size int64
for _, part := range parts {
size += part.Size
}
return &writer{
driver: d,
key: key,
multi: multi,
parts: parts,
size: size,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
// If the last written part is smaller than minChunkSize, we need to make a
// new multipart upload :sadface:
if len(w.parts) > 0 && int(w.parts[len(w.parts)-1].Size) < minChunkSize {
err := w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return 0, err
}
multi, err := w.driver.Bucket.InitMulti(w.key, w.driver.getContentType(), getPermissions(), w.driver.getOptions())
if err != nil {
return 0, err
}
w.multi = multi
// If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face:
if w.size < minChunkSize {
contents, err := w.driver.Bucket.Get(w.key)
if err != nil {
return 0, err
}
w.parts = nil
w.readyPart = contents
} else {
// Otherwise we can use the old file as the new first part
_, part, err := multi.PutPartCopy(1, oss.CopyOptions{}, w.driver.Bucket.Name+"/"+w.key)
if err != nil {
return 0, err
}
w.parts = []oss.Part{part}
}
}
var n int
for len(p) > 0 {
// If no parts are ready to write, fill up the first part
if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.readyPart = append(w.readyPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
} else {
w.readyPart = append(w.readyPart, p...)
n += len(p)
p = nil
}
}
if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
err := w.flushPart()
if err != nil {
w.size += int64(n)
return n, err
}
} else {
w.pendingPart = append(w.pendingPart, p...)
n += len(p)
p = nil
}
}
}
w.size += int64(n)
return n, nil
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.flushPart()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
err := w.multi.Abort()
return err
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
err := w.flushPart()
if err != nil {
return err
}
w.committed = true
err = w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return err
}
return nil
}
// flushPart flushes buffers to write a part to OSS.
// Only called by Write (with both buffers full) and Close/Commit (always)
func (w *writer) flushPart() error {
if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
// nothing to write
return nil
}
if len(w.pendingPart) < int(w.driver.ChunkSize) {
// closing with a small pending part
// combine ready and pending to avoid writing a small part
w.readyPart = append(w.readyPart, w.pendingPart...)
w.pendingPart = nil
}
part, err := w.multi.PutPart(len(w.parts)+1, bytes.NewReader(w.readyPart))
if err != nil {
return err
}
w.parts = append(w.parts, part)
w.readyPart = w.pendingPart
w.pendingPart = nil
return nil
}
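A writer usage sketch follows (illustrative only, not part of the original change): it shows the write/commit/close sequence the FileWriter above expects, with Cancel aborting the multipart upload when a write fails. The upload path is a placeholder.
// exampleUpload is an illustrative sketch and is not used by the driver.
func exampleUpload(ctx context.Context, sd storagedriver.StorageDriver, contents []byte) error {
	// A fresh (non-append) writer starts a new multipart upload.
	fw, err := sd.Writer(ctx, "/docker/registry/v2/uploads/example", false)
	if err != nil {
		return err
	}
	if _, err := fw.Write(contents); err != nil {
		fw.Cancel() // abort the multipart upload on error
		return err
	}
	if err := fw.Commit(); err != nil {
		return err
	}
	return fw.Close()
}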

View File

@ -0,0 +1,144 @@
// +build include_oss
package oss
import (
"io/ioutil"
alioss "github.com/denverdino/aliyungo/oss"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
//"log"
"os"
"strconv"
"testing"
"gopkg.in/check.v1"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
var ossDriverConstructor func(rootDirectory string) (*Driver, error)
var skipCheck func() string
func init() {
accessKey := os.Getenv("ALIYUN_ACCESS_KEY_ID")
secretKey := os.Getenv("ALIYUN_ACCESS_KEY_SECRET")
bucket := os.Getenv("OSS_BUCKET")
region := os.Getenv("OSS_REGION")
internal := os.Getenv("OSS_INTERNAL")
encrypt := os.Getenv("OSS_ENCRYPT")
secure := os.Getenv("OSS_SECURE")
endpoint := os.Getenv("OSS_ENDPOINT")
root, err := ioutil.TempDir("", "driver-")
if err != nil {
panic(err)
}
defer os.Remove(root)
ossDriverConstructor = func(rootDirectory string) (*Driver, error) {
encryptBool := false
if encrypt != "" {
encryptBool, err = strconv.ParseBool(encrypt)
if err != nil {
return nil, err
}
}
secureBool := false
if secure != "" {
secureBool, err = strconv.ParseBool(secure)
if err != nil {
return nil, err
}
}
internalBool := false
if internal != "" {
internalBool, err = strconv.ParseBool(internal)
if err != nil {
return nil, err
}
}
parameters := DriverParameters{
AccessKeyID: accessKey,
AccessKeySecret: secretKey,
Bucket: bucket,
Region: alioss.Region(region),
Internal: internalBool,
ChunkSize: minChunkSize,
RootDirectory: rootDirectory,
Encrypt: encryptBool,
Secure: secureBool,
Endpoint: endpoint,
}
return New(parameters)
}
// Skip OSS storage driver tests if environment variable parameters are not provided
skipCheck = func() string {
if accessKey == "" || secretKey == "" || region == "" || bucket == "" || encrypt == "" {
return "Must set ALIYUN_ACCESS_KEY_ID, ALIYUN_ACCESS_KEY_SECRET, OSS_REGION, OSS_BUCKET, and OSS_ENCRYPT to run OSS tests"
}
return ""
}
testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) {
return ossDriverConstructor(root)
}, skipCheck)
}
func TestEmptyRootList(t *testing.T) {
if skipCheck() != "" {
t.Skip(skipCheck())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
rootedDriver, err := ossDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
emptyRootDriver, err := ossDriverConstructor("")
if err != nil {
t.Fatalf("unexpected error creating empty root driver: %v", err)
}
slashRootDriver, err := ossDriverConstructor("/")
if err != nil {
t.Fatalf("unexpected error creating slash root driver: %v", err)
}
filename := "/test"
contents := []byte("contents")
ctx := context.Background()
err = rootedDriver.PutContent(ctx, filename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer rootedDriver.Delete(ctx, filename)
keys, err := emptyRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
keys, err = slashRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
}

File diff suppressed because it is too large

View File

@ -0,0 +1,313 @@
package s3
import (
"bytes"
"io/ioutil"
"math/rand"
"os"
"strconv"
"testing"
"gopkg.in/check.v1"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
var s3DriverConstructor func(rootDirectory, storageClass string) (*Driver, error)
var skipS3 func() string
func init() {
accessKey := os.Getenv("AWS_ACCESS_KEY")
secretKey := os.Getenv("AWS_SECRET_KEY")
bucket := os.Getenv("S3_BUCKET")
encrypt := os.Getenv("S3_ENCRYPT")
keyID := os.Getenv("S3_KEY_ID")
secure := os.Getenv("S3_SECURE")
v4Auth := os.Getenv("S3_V4_AUTH")
region := os.Getenv("AWS_REGION")
objectACL := os.Getenv("S3_OBJECT_ACL")
root, err := ioutil.TempDir("", "driver-")
regionEndpoint := os.Getenv("REGION_ENDPOINT")
if err != nil {
panic(err)
}
defer os.Remove(root)
s3DriverConstructor = func(rootDirectory, storageClass string) (*Driver, error) {
encryptBool := false
if encrypt != "" {
encryptBool, err = strconv.ParseBool(encrypt)
if err != nil {
return nil, err
}
}
secureBool := true
if secure != "" {
secureBool, err = strconv.ParseBool(secure)
if err != nil {
return nil, err
}
}
v4Bool := true
if v4Auth != "" {
v4Bool, err = strconv.ParseBool(v4Auth)
if err != nil {
return nil, err
}
}
parameters := DriverParameters{
accessKey,
secretKey,
bucket,
region,
regionEndpoint,
encryptBool,
keyID,
secureBool,
v4Bool,
minChunkSize,
defaultMultipartCopyChunkSize,
defaultMultipartCopyMaxConcurrency,
defaultMultipartCopyThresholdSize,
rootDirectory,
storageClass,
driverName + "-test",
objectACL,
}
return New(parameters)
}
// Skip S3 storage driver tests if environment variable parameters are not provided
skipS3 = func() string {
if accessKey == "" || secretKey == "" || region == "" || bucket == "" || encrypt == "" {
return "Must set AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION, S3_BUCKET, and S3_ENCRYPT to run S3 tests"
}
return ""
}
testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) {
return s3DriverConstructor(root, s3.StorageClassStandard)
}, skipS3)
}
func TestEmptyRootList(t *testing.T) {
if skipS3() != "" {
t.Skip(skipS3())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
rootedDriver, err := s3DriverConstructor(validRoot, s3.StorageClassStandard)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
emptyRootDriver, err := s3DriverConstructor("", s3.StorageClassStandard)
if err != nil {
t.Fatalf("unexpected error creating empty root driver: %v", err)
}
slashRootDriver, err := s3DriverConstructor("/", s3.StorageClassStandard)
if err != nil {
t.Fatalf("unexpected error creating slash root driver: %v", err)
}
filename := "/test"
contents := []byte("contents")
ctx := context.Background()
err = rootedDriver.PutContent(ctx, filename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer rootedDriver.Delete(ctx, filename)
keys, err := emptyRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
keys, err = slashRootDriver.List(ctx, "/")
for _, path := range keys {
if !storagedriver.PathRegexp.MatchString(path) {
t.Fatalf("unexpected string in path: %q != %q", path, storagedriver.PathRegexp)
}
}
}
func TestStorageClass(t *testing.T) {
if skipS3() != "" {
t.Skip(skipS3())
}
rootDir, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(rootDir)
standardDriver, err := s3DriverConstructor(rootDir, s3.StorageClassStandard)
if err != nil {
t.Fatalf("unexpected error creating driver with standard storage: %v", err)
}
rrDriver, err := s3DriverConstructor(rootDir, s3.StorageClassReducedRedundancy)
if err != nil {
t.Fatalf("unexpected error creating driver with reduced redundancy storage: %v", err)
}
if _, err = s3DriverConstructor(rootDir, noStorageClass); err != nil {
t.Fatalf("unexpected error creating driver without storage class: %v", err)
}
standardFilename := "/test-standard"
rrFilename := "/test-rr"
contents := []byte("contents")
ctx := context.Background()
err = standardDriver.PutContent(ctx, standardFilename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer standardDriver.Delete(ctx, standardFilename)
err = rrDriver.PutContent(ctx, rrFilename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
defer rrDriver.Delete(ctx, rrFilename)
standardDriverUnwrapped := standardDriver.Base.StorageDriver.(*driver)
resp, err := standardDriverUnwrapped.S3.GetObject(&s3.GetObjectInput{
Bucket: aws.String(standardDriverUnwrapped.Bucket),
Key: aws.String(standardDriverUnwrapped.s3Path(standardFilename)),
})
if err != nil {
t.Fatalf("unexpected error retrieving standard storage file: %v", err)
}
defer resp.Body.Close()
// Amazon only populates this header value for non-standard storage classes
if resp.StorageClass != nil {
t.Fatalf("unexpected storage class for standard file: %v", resp.StorageClass)
}
rrDriverUnwrapped := rrDriver.Base.StorageDriver.(*driver)
resp, err = rrDriverUnwrapped.S3.GetObject(&s3.GetObjectInput{
Bucket: aws.String(rrDriverUnwrapped.Bucket),
Key: aws.String(rrDriverUnwrapped.s3Path(rrFilename)),
})
if err != nil {
t.Fatalf("unexpected error retrieving reduced-redundancy storage file: %v", err)
}
defer resp.Body.Close()
if resp.StorageClass == nil {
t.Fatalf("unexpected storage class for reduced-redundancy file: %v", s3.StorageClassStandard)
} else if *resp.StorageClass != s3.StorageClassReducedRedundancy {
t.Fatalf("unexpected storage class for reduced-redundancy file: %v", *resp.StorageClass)
}
}
func TestOverThousandBlobs(t *testing.T) {
if skipS3() != "" {
t.Skip(skipS3())
}
rootDir, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(rootDir)
standardDriver, err := s3DriverConstructor(rootDir, s3.StorageClassStandard)
if err != nil {
t.Fatalf("unexpected error creating driver with standard storage: %v", err)
}
ctx := context.Background()
for i := 0; i < 1005; i++ {
filename := "/thousandfiletest/file" + strconv.Itoa(i)
contents := []byte("contents")
err = standardDriver.PutContent(ctx, filename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
}
// can't actually verify deletion because read-after-delete is inconsistent, but can ensure no errors
err = standardDriver.Delete(ctx, "/thousandfiletest")
if err != nil {
t.Fatalf("unexpected error deleting thousand files: %v", err)
}
}
func TestMoveWithMultipartCopy(t *testing.T) {
if skipS3() != "" {
t.Skip(skipS3())
}
rootDir, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(rootDir)
d, err := s3DriverConstructor(rootDir, s3.StorageClassStandard)
if err != nil {
t.Fatalf("unexpected error creating driver: %v", err)
}
ctx := context.Background()
sourcePath := "/source"
destPath := "/dest"
defer d.Delete(ctx, sourcePath)
defer d.Delete(ctx, destPath)
// An object larger than d's MultipartCopyThresholdSize will cause d.Move() to perform a multipart copy.
multipartCopyThresholdSize := d.baseEmbed.Base.StorageDriver.(*driver).MultipartCopyThresholdSize
contents := make([]byte, 2*multipartCopyThresholdSize)
rand.Read(contents)
err = d.PutContent(ctx, sourcePath, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
}
err = d.Move(ctx, sourcePath, destPath)
if err != nil {
t.Fatalf("unexpected error moving file: %v", err)
}
received, err := d.GetContent(ctx, destPath)
if err != nil {
t.Fatalf("unexpected error getting content: %v", err)
}
if !bytes.Equal(contents, received) {
t.Fatal("content differs")
}
_, err = d.GetContent(ctx, sourcePath)
switch err.(type) {
case storagedriver.PathNotFoundError:
default:
t.Fatalf("unexpected error getting content: %v", err)
}
}

View File

@ -0,0 +1,219 @@
package s3
// Source: https://github.com/pivotal-golang/s3cli
// Copyright (c) 2013 Damien Le Berrigaud and Nick Wade
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"net/http"
"net/url"
"sort"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
)
const (
signatureVersion = "2"
signatureMethod = "HmacSHA1"
timeFormat = "2006-01-02T15:04:05Z"
)
type signer struct {
// Values that must be populated from the request
Request *http.Request
Time time.Time
Credentials *credentials.Credentials
Query url.Values
stringToSign string
signature string
}
var s3ParamsToSign = map[string]bool{
"acl": true,
"location": true,
"logging": true,
"notification": true,
"partNumber": true,
"policy": true,
"requestPayment": true,
"torrent": true,
"uploadId": true,
"uploads": true,
"versionId": true,
"versioning": true,
"versions": true,
"response-content-type": true,
"response-content-language": true,
"response-expires": true,
"response-cache-control": true,
"response-content-disposition": true,
"response-content-encoding": true,
"website": true,
"delete": true,
}
// setv2Handlers will set up v2 signature signing on the S3 driver
func setv2Handlers(svc *s3.S3) {
svc.Handlers.Build.PushBack(func(r *request.Request) {
parsedURL, err := url.Parse(r.HTTPRequest.URL.String())
if err != nil {
log.Fatalf("Failed to parse URL: %v", err)
}
r.HTTPRequest.URL.Opaque = parsedURL.Path
})
svc.Handlers.Sign.Clear()
svc.Handlers.Sign.PushBack(Sign)
svc.Handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
}
// Sign requests with signature version 2.
//
// Will sign the requests with the service config's Credentials object.
// Signing is skipped if the credentials object is
// credentials.AnonymousCredentials.
func Sign(req *request.Request) {
// Skip signing the request if the AnonymousCredentials object is used.
if req.Config.Credentials == credentials.AnonymousCredentials {
return
}
v2 := signer{
Request: req.HTTPRequest,
Time: req.Time,
Credentials: req.Config.Credentials,
}
v2.Sign()
}
func (v2 *signer) Sign() error {
credValue, err := v2.Credentials.Get()
if err != nil {
return err
}
accessKey := credValue.AccessKeyID
var (
md5, ctype, date, xamz string
xamzDate bool
sarray []string
smap map[string]string
sharray []string
)
headers := v2.Request.Header
params := v2.Request.URL.Query()
parsedURL, err := url.Parse(v2.Request.URL.String())
if err != nil {
return err
}
host, canonicalPath := parsedURL.Host, parsedURL.Path
v2.Request.Header["Host"] = []string{host}
v2.Request.Header["date"] = []string{v2.Time.In(time.UTC).Format(time.RFC1123)}
smap = make(map[string]string)
for k, v := range headers {
k = strings.ToLower(k)
switch k {
case "content-md5":
md5 = v[0]
case "content-type":
ctype = v[0]
case "date":
if !xamzDate {
date = v[0]
}
default:
if strings.HasPrefix(k, "x-amz-") {
vall := strings.Join(v, ",")
smap[k] = k + ":" + vall
if k == "x-amz-date" {
xamzDate = true
date = ""
}
sharray = append(sharray, k)
}
}
}
if len(sharray) > 0 {
sort.StringSlice(sharray).Sort()
for _, h := range sharray {
sarray = append(sarray, smap[h])
}
xamz = strings.Join(sarray, "\n") + "\n"
}
expires := false
if v, ok := params["Expires"]; ok {
expires = true
date = v[0]
params["AWSAccessKeyId"] = []string{accessKey}
}
sarray = sarray[0:0]
for k, v := range params {
if s3ParamsToSign[k] {
for _, vi := range v {
if vi == "" {
sarray = append(sarray, k)
} else {
sarray = append(sarray, k+"="+vi)
}
}
}
}
if len(sarray) > 0 {
sort.StringSlice(sarray).Sort()
canonicalPath = canonicalPath + "?" + strings.Join(sarray, "&")
}
v2.stringToSign = strings.Join([]string{
v2.Request.Method,
md5,
ctype,
date,
xamz + canonicalPath,
}, "\n")
hash := hmac.New(sha1.New, []byte(credValue.SecretAccessKey))
hash.Write([]byte(v2.stringToSign))
v2.signature = base64.StdEncoding.EncodeToString(hash.Sum(nil))
if expires {
params["Signature"] = []string{string(v2.signature)}
} else {
headers["Authorization"] = []string{"AWS " + accessKey + ":" + string(v2.signature)}
}
log.WithFields(log.Fields{
"string-to-sign": v2.stringToSign,
"signature": v2.signature,
}).Debugln("request signature")
return nil
}
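A wiring sketch follows (illustrative only, not part of the original change): it attaches the v2 signing handlers above to an aws-sdk-go S3 client. It assumes additional imports of "github.com/aws/aws-sdk-go/aws" and "github.com/aws/aws-sdk-go/aws/session"; the credentials and region shown are placeholders.
// exampleV2Client is an illustrative sketch and is not used by the driver.
// It assumes the aws and session packages are imported alongside the imports above.
func exampleV2Client() *s3.S3 {
	sess := session.New(&aws.Config{
		Region:           aws.String("us-east-1"),                                                // placeholder
		Credentials:      credentials.NewStaticCredentials("my-access-key", "my-secret-key", ""), // placeholder
		S3ForcePathStyle: aws.Bool(true),
	})
	svc := s3.New(sess)
	// Replace the default v4 signing handlers with the v2 signer defined above.
	setv2Handlers(svc)
	return svc
}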

Some files were not shown because too many files have changed in this diff