// Mirror of https://github.com/git-pkgs/proxy.git
// (synced 2026-06-02 16:48:16 -04:00; 136 lines, 3.4 KiB, Go).
|
|
package handler
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/sha256"
|
||
|
|
"crypto/sha512"
|
||
|
|
"encoding/base64"
|
||
|
|
"encoding/hex"
|
||
|
|
"io"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
// sha256Hex returns the SHA-256 digest of data encoded as a lowercase
// hexadecimal string. The tests use it to build expected hash values.
func sha256Hex(data string) string {
	sum := sha256.Sum256([]byte(data))
	return hex.EncodeToString(sum[:])
}
|
||
|
|
|
||
|
|
// sha512SRI returns an SRI-style integrity string for data: the literal
// prefix "sha512-" followed by the standard (padded) base64 encoding of the
// SHA-512 digest of data.
func sha512SRI(data string) string {
	sum := sha512.Sum512([]byte(data))
	return "sha512-" + base64.StdEncoding.EncodeToString(sum[:])
}
|
||
|
|
|
||
|
|
func TestParseSRI(t *testing.T) {
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
input string
|
||
|
|
algo string
|
||
|
|
ok bool
|
||
|
|
}{
|
||
|
|
{"sha512", sha512SRI("hello"), "sha512", true},
|
||
|
|
{"sha256", "sha256-" + base64.StdEncoding.EncodeToString([]byte("0123456789012345678901234567890123456789")), "sha256", true},
|
||
|
|
{"empty", "", "", false},
|
||
|
|
{"no dash", "sha512abc", "", false},
|
||
|
|
{"bad base64", "sha512-not!base64", "", false},
|
||
|
|
{"unsupported algo", "md5-" + base64.StdEncoding.EncodeToString([]byte("x")), "", false},
|
||
|
|
{"multi hash takes first", sha512SRI("a") + " " + sha512SRI("b"), "sha512", true},
|
||
|
|
{"whitespace", " " + sha512SRI("x") + " ", "sha512", true},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
algo, digest, ok := parseSRI(tt.input)
|
||
|
|
if ok != tt.ok {
|
||
|
|
t.Fatalf("ok = %v, want %v", ok, tt.ok)
|
||
|
|
}
|
||
|
|
if !tt.ok {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if algo != tt.algo {
|
||
|
|
t.Errorf("algo = %q, want %q", algo, tt.algo)
|
||
|
|
}
|
||
|
|
if len(digest) == 0 {
|
||
|
|
t.Error("digest is empty")
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestVerifyingReader(t *testing.T) {
|
||
|
|
const data = "hello world"
|
||
|
|
goodSHA := sha256Hex(data)
|
||
|
|
goodSRI := sha512SRI(data)
|
||
|
|
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
hash string
|
||
|
|
sri string
|
||
|
|
wantCalls int
|
||
|
|
}{
|
||
|
|
{"both match", goodSHA, goodSRI, 0},
|
||
|
|
{"sha256 only match", goodSHA, "", 0},
|
||
|
|
{"sri only match", "", goodSRI, 0},
|
||
|
|
{"sha256 mismatch", sha256Hex("other"), "", 1},
|
||
|
|
{"sri mismatch", "", sha512SRI("other"), 1},
|
||
|
|
{"both mismatch", sha256Hex("other"), sha512SRI("other"), 2},
|
||
|
|
{"no checks", "", "", 0},
|
||
|
|
{"unparseable sri ignored", goodSHA, "garbage", 0},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
var calls []string
|
||
|
|
r := newVerifyingReader(io.NopCloser(strings.NewReader(data)), tt.hash, tt.sri,
|
||
|
|
func(reason string) { calls = append(calls, reason) })
|
||
|
|
|
||
|
|
got, err := io.ReadAll(r)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("ReadAll: %v", err)
|
||
|
|
}
|
||
|
|
if string(got) != data {
|
||
|
|
t.Errorf("data corrupted: got %q", got)
|
||
|
|
}
|
||
|
|
if err := r.Close(); err != nil {
|
||
|
|
t.Fatalf("Close: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(calls) != tt.wantCalls {
|
||
|
|
t.Errorf("onMismatch called %d times, want %d: %v", len(calls), tt.wantCalls, calls)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestVerifyingReaderPassthrough(t *testing.T) {
|
||
|
|
src := io.NopCloser(strings.NewReader("x"))
|
||
|
|
r := newVerifyingReader(src, "", "", func(string) { t.Fatal("should not be called") })
|
||
|
|
if r != src {
|
||
|
|
t.Error("expected passthrough when no hashes provided")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestVerifyingReaderPartialRead(t *testing.T) {
|
||
|
|
var calls int
|
||
|
|
r := newVerifyingReader(io.NopCloser(strings.NewReader("hello world")),
|
||
|
|
sha256Hex("hello world"), "", func(string) { calls++ })
|
||
|
|
|
||
|
|
buf := make([]byte, 5)
|
||
|
|
_, _ = r.Read(buf)
|
||
|
|
_ = r.Close()
|
||
|
|
|
||
|
|
if calls != 0 {
|
||
|
|
t.Errorf("onMismatch called %d times for partial read, want 0", calls)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestVerifyingReaderVerifyOnce(t *testing.T) {
|
||
|
|
var calls int
|
||
|
|
r := newVerifyingReader(io.NopCloser(strings.NewReader("x")), sha256Hex("y"), "",
|
||
|
|
func(string) { calls++ })
|
||
|
|
_, _ = io.ReadAll(r)
|
||
|
|
_ = r.Close()
|
||
|
|
_ = r.Close()
|
||
|
|
if calls != 1 {
|
||
|
|
t.Errorf("onMismatch called %d times, want 1", calls)
|
||
|
|
}
|
||
|
|
}
|