1
0
Fork 1
mirror of https://github.com/git-pkgs/proxy.git synced 2026-06-02 16:48:16 -04:00
pkg-proxy/internal/server/api_test.go
Andrew Nesbitt 522c6f233e
Validate package paths before database lookups (#109)
The wildcard package routes (/packages/{ecosystem}/*, /api/package/*,
/api/vulns/*, /api/browse/*, /api/compare/*) only checked for an empty
path before passing user input to GetPackageByEcosystemName and the
enrichment service.

Add validatePackagePath as a coarse first-line filter: reject null
bytes, other control characters, and paths over 512 bytes. Wired into
all five entry handlers immediately after the chi wildcard is read.

This is the generic layer; ecosystem-specific name format rules (npm
scoped name shape, Maven coordinate structure, etc.) can be added on
top per #75.

Fixes #75
2026-05-03 09:14:18 +01:00

491 lines
12 KiB
Go

package server
import (
"bytes"
"database/sql"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/git-pkgs/proxy/internal/database"
"github.com/git-pkgs/proxy/internal/enrichment"
"github.com/go-chi/chi/v5"
)
// testEcosystemNPM is the ecosystem identifier shared by the npm fixtures
// and assertions in this test file.
const testEcosystemNPM = "npm"
// TestNewAPIHandler checks that the constructor yields a non-nil handler
// with its enrichment service attached when the second argument is nil.
func TestNewAPIHandler(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))

	handler := NewAPIHandler(enricher, nil)
	if handler == nil {
		// Fatal: the field check below would dereference a nil handler.
		t.Fatal("NewAPIHandler returned nil")
	}

	if handler.enrichment == nil {
		t.Error("enrichment service is nil")
	}
}
// TestHandlePackagePath_MissingParams requests the package route with an
// empty ecosystem and an empty wildcard and expects the request to be
// rejected with either 400 or 404.
func TestHandlePackagePath_MissingParams(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	router := chi.NewRouter()
	router.Get("/api/package/{ecosystem}/*", handler.HandlePackagePath)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/api/package//", nil))

	switch rec.Code {
	case http.StatusBadRequest, http.StatusNotFound:
		// Both rejections are accepted by this test.
	default:
		t.Errorf("expected status 400 or 404, got %d", rec.Code)
	}
}
// TestHandlePackagePath_InvalidName drives the package route with package
// paths that must be refused: one containing an encoded null byte and one
// that is a single byte longer than maxPackagePathLen. Each is expected to
// produce a 400 response.
func TestHandlePackagePath_InvalidName(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	router := chi.NewRouter()
	router.Get("/api/package/{ecosystem}/*", handler.HandlePackagePath)

	type invalidCase struct {
		name string
		path string
	}
	cases := []invalidCase{
		{name: "null byte", path: "/api/package/npm/lodash%00"},
		{name: "too long", path: "/api/package/npm/" + strings.Repeat("a", maxPackagePathLen+1)},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, httptest.NewRequest("GET", tc.path, nil))

			if got := rec.Code; got != http.StatusBadRequest {
				t.Errorf("expected status 400, got %d", got)
			}
		})
	}
}
// TestHandleVulnsPath_MissingParams requests the vulnerabilities route with
// an empty ecosystem and an empty wildcard and expects a 400 or 404.
func TestHandleVulnsPath_MissingParams(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	router := chi.NewRouter()
	router.Get("/api/vulns/{ecosystem}/*", handler.HandleVulnsPath)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/api/vulns//", nil))

	switch rec.Code {
	case http.StatusBadRequest, http.StatusNotFound:
		// Both rejections are accepted by this test.
	default:
		t.Errorf("expected status 400 or 404, got %d", rec.Code)
	}
}
// TestHandleOutdated_EmptyBody posts a JSON object with no fields to the
// outdated handler and expects a 400 response.
func TestHandleOutdated_EmptyBody(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	payload := []byte("{}")
	rec := httptest.NewRecorder()
	handler.HandleOutdated(rec, httptest.NewRequest("POST", "/api/outdated", bytes.NewReader(payload)))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status %d, got %d", http.StatusBadRequest, got)
	}
}
// TestHandleOutdated_OversizedBody posts a 2 MiB body (more than 1 MB) made
// entirely of 'x' bytes to the outdated handler and expects a 400 response.
//
// NOTE(review): the filler is not valid JSON either, so a 400 on its own may
// not distinguish a size-limit rejection from a parse failure — confirm
// against the handler.
func TestHandleOutdated_OversizedBody(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	// Send a body larger than 1 MB.
	oversized := bytes.Repeat([]byte{'x'}, 2<<20)

	rec := httptest.NewRecorder()
	handler.HandleOutdated(rec, httptest.NewRequest("POST", "/api/outdated", bytes.NewReader(oversized)))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status %d for oversized body, got %d", http.StatusBadRequest, got)
	}
}
// TestHandleOutdated_InvalidJSON posts a body that is not JSON to the
// outdated handler and expects a 400 response.
func TestHandleOutdated_InvalidJSON(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	payload := []byte("not json")
	rec := httptest.NewRecorder()
	handler.HandleOutdated(rec, httptest.NewRequest("POST", "/api/outdated", bytes.NewReader(payload)))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status %d, got %d", http.StatusBadRequest, got)
	}
}
// TestHandleBulkLookup_EmptyBody posts a JSON object with no fields to the
// bulk lookup handler and expects a 400 response.
func TestHandleBulkLookup_EmptyBody(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	payload := []byte("{}")
	rec := httptest.NewRecorder()
	handler.HandleBulkLookup(rec, httptest.NewRequest("POST", "/api/bulk", bytes.NewReader(payload)))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status %d, got %d", http.StatusBadRequest, got)
	}
}
// TestHandleBulkLookup_InvalidJSON posts a body that is not JSON to the
// bulk lookup handler and expects a 400 response.
func TestHandleBulkLookup_InvalidJSON(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	payload := []byte("not json")
	rec := httptest.NewRecorder()
	handler.HandleBulkLookup(rec, httptest.NewRequest("POST", "/api/bulk", bytes.NewReader(payload)))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status %d, got %d", http.StatusBadRequest, got)
	}
}
// TestWriteJSON verifies that writeJSON sets the JSON content type and
// emits a body that decodes back to the value it was given.
func TestWriteJSON(t *testing.T) {
	rec := httptest.NewRecorder()
	writeJSON(rec, map[string]string{"foo": "bar"})

	if contentType := rec.Header().Get("Content-Type"); contentType != "application/json" {
		t.Errorf("expected Content-Type application/json, got %s", contentType)
	}

	var got map[string]string
	if err := json.NewDecoder(rec.Body).Decode(&got); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}

	if value := got["foo"]; value != "bar" {
		t.Errorf("expected foo=bar, got foo=%s", value)
	}
}
// TestPackageResponseJSON round-trips a PackageResponse through
// json.Marshal and json.Unmarshal and checks that the ecosystem and name
// survive unchanged.
func TestPackageResponseJSON(t *testing.T) {
	original := &PackageResponse{
		Ecosystem:       testEcosystemNPM,
		Name:            "lodash",
		LatestVersion:   "4.17.21",
		License:         "MIT",
		LicenseCategory: "permissive",
		Description:     "A utility library",
		Homepage:        "https://lodash.com",
		Repository:      "https://github.com/lodash/lodash",
		RegistryURL:     "https://registry.npmjs.org",
	}

	encoded, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to marshal: %v", err)
	}

	var roundTripped PackageResponse
	err = json.Unmarshal(encoded, &roundTripped)
	if err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}

	if got := roundTripped.Ecosystem; got != testEcosystemNPM {
		t.Errorf("expected ecosystem npm, got %s", got)
	}
	if got := roundTripped.Name; got != "lodash" {
		t.Errorf("expected name lodash, got %s", got)
	}
}
// TestVulnResponseJSON round-trips a VulnResponse through json.Marshal and
// json.Unmarshal and checks that the ID and CVSS score survive unchanged.
func TestVulnResponseJSON(t *testing.T) {
	original := &VulnResponse{
		ID:           "CVE-2021-1234",
		Summary:      "Test vulnerability",
		Severity:     "HIGH",
		CVSSScore:    8.5,
		FixedVersion: "1.2.3",
		References:   []string{"https://example.com/advisory"},
	}

	encoded, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to marshal: %v", err)
	}

	var roundTripped VulnResponse
	err = json.Unmarshal(encoded, &roundTripped)
	if err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}

	if got := roundTripped.ID; got != "CVE-2021-1234" {
		t.Errorf("expected ID CVE-2021-1234, got %s", got)
	}
	if got := roundTripped.CVSSScore; got != 8.5 {
		t.Errorf("expected CVSS 8.5, got %f", got)
	}
}
func TestOutdatedRequestJSON(t *testing.T) {
req := &OutdatedRequest{
Packages: []OutdatedPackage{
{Ecosystem: "npm", Name: "lodash", Version: "4.17.0"},
{Ecosystem: "pypi", Name: "requests", Version: "2.25.0"},
},
}
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var decoded OutdatedRequest
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(decoded.Packages) != 2 {
t.Errorf("expected 2 packages, got %d", len(decoded.Packages))
}
}
// TestBulkRequestJSON round-trips a BulkRequest holding two PURLs through
// JSON and checks that both entries are still present.
func TestBulkRequestJSON(t *testing.T) {
	purls := []string{
		"pkg:npm/lodash@4.17.21",
		"pkg:pypi/requests@2.28.0",
	}
	original := &BulkRequest{PURLs: purls}

	encoded, err := json.Marshal(original)
	if err != nil {
		t.Fatalf("failed to marshal: %v", err)
	}

	var roundTripped BulkRequest
	err = json.Unmarshal(encoded, &roundTripped)
	if err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}

	if count := len(roundTripped.PURLs); count != 2 {
		t.Errorf("expected 2 purls, got %d", count)
	}
}
// TestHandleSearch_MissingQuery calls the search handler without any query
// string and expects a 400 response.
func TestHandleSearch_MissingQuery(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))
	handler := NewAPIHandler(enricher, nil)

	rec := httptest.NewRecorder()
	handler.HandleSearch(rec, httptest.NewRequest("GET", "/api/search", nil))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status %d, got %d", http.StatusBadRequest, got)
	}
}
// TestHandleSearch_WithNullValues seeds a package that sets only its PURL,
// ecosystem and name, plus one version and one artifact with a hit count of
// 3, then searches for it. The single result is expected to carry the
// package name, empty LatestVersion and License strings, and 3 hits.
func TestHandleSearch_WithNullValues(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))

	store, err := database.Create(filepath.Join(t.TempDir(), "test.db"))
	if err != nil {
		t.Fatalf("Create failed: %v", err)
	}
	defer func() { _ = store.Close() }()

	// Package with no version or license metadata set.
	seededPkg := &database.Package{
		PURL:      "pkg:npm/api-test",
		Ecosystem: testEcosystemNPM,
		Name:      "api-test",
	}
	if err := store.UpsertPackage(seededPkg); err != nil {
		t.Fatalf("UpsertPackage failed: %v", err)
	}

	seededVersion := &database.Version{
		PURL:        "pkg:npm/api-test@1.0.0",
		PackagePURL: seededPkg.PURL,
	}
	if err := store.UpsertVersion(seededVersion); err != nil {
		t.Fatalf("UpsertVersion failed: %v", err)
	}

	seededArtifact := &database.Artifact{
		VersionPURL: seededVersion.PURL,
		Filename:    "api-test-1.0.0.tgz",
		UpstreamURL: "https://registry.npmjs.org/api-test/-/api-test-1.0.0.tgz",
		StoragePath: sql.NullString{String: "./cache/test.tgz", Valid: true},
		HitCount:    3,
	}
	if err := store.UpsertArtifact(seededArtifact); err != nil {
		t.Fatalf("UpsertArtifact failed: %v", err)
	}

	handler := NewAPIHandler(enricher, store)
	rec := httptest.NewRecorder()
	handler.HandleSearch(rec, httptest.NewRequest("GET", "/api/search?q=api-test", nil))

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, got)
	}

	var body SearchResponse
	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}

	if count := len(body.Results); count != 1 {
		t.Fatalf("expected 1 result, got %d", count)
	}

	hit := body.Results[0]
	if hit.Name != "api-test" {
		t.Errorf("expected name api-test, got %s", hit.Name)
	}
	if hit.LatestVersion != "" {
		t.Errorf("expected empty LatestVersion, got %s", hit.LatestVersion)
	}
	if hit.License != "" {
		t.Errorf("expected empty License, got %s", hit.License)
	}
	if hit.Hits != 3 {
		t.Errorf("expected 3 hits, got %d", hit.Hits)
	}
}
// TestHandlePackagesListAPI seeds two packages (each with a version and an
// artifact), lists packages through the router, and checks that the
// response is 200, has at least two results, reports the default sort
// key, and includes at least one of the seeded package names.
func TestHandlePackagesListAPI(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))

	store, err := database.Create(filepath.Join(t.TempDir(), "test.db"))
	if err != nil {
		t.Fatalf("Create failed: %v", err)
	}
	defer func() { _ = store.Close() }()

	// Seed two packages.
	seedNames := []string{"api-list-one", "api-list-two"}
	for _, seedName := range seedNames {
		seededPkg := &database.Package{
			PURL:      "pkg:npm/" + seedName,
			Ecosystem: testEcosystemNPM,
			Name:      seedName,
		}
		if err := store.UpsertPackage(seededPkg); err != nil {
			t.Fatalf("UpsertPackage failed: %v", err)
		}

		seededVersion := &database.Version{
			PURL:        "pkg:npm/" + seedName + "@1.0.0",
			PackagePURL: seededPkg.PURL,
		}
		if err := store.UpsertVersion(seededVersion); err != nil {
			t.Fatalf("UpsertVersion failed: %v", err)
		}

		seededArtifact := &database.Artifact{
			VersionPURL: seededVersion.PURL,
			Filename:    seedName + "-1.0.0.tgz",
			UpstreamURL: "https://registry.npmjs.org/" + seedName + "/-/" + seedName + "-1.0.0.tgz",
			StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true},
		}
		if err := store.UpsertArtifact(seededArtifact); err != nil {
			t.Fatalf("UpsertArtifact failed: %v", err)
		}
	}

	handler := NewAPIHandler(enricher, store)
	router := chi.NewRouter()
	router.Get("/api/packages", handler.HandlePackagesList)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/api/packages", nil))

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status 200, got %d", got)
	}

	var listing PackagesListResponse
	if err := json.NewDecoder(rec.Body).Decode(&listing); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}

	if count := len(listing.Results); count < 2 {
		t.Fatalf("expected at least 2 results, got %d", count)
	}

	if listing.SortBy != defaultSortBy {
		t.Errorf("expected default sort by hits, got %q", listing.SortBy)
	}

	// Stop at the first result whose name matches either seeded package.
	sawSeeded := false
	for i := 0; i < len(listing.Results) && !sawSeeded; i++ {
		switch listing.Results[i].Name {
		case "api-list-one", "api-list-two":
			sawSeeded = true
		}
	}
	if !sawSeeded {
		t.Error("expected seeded packages in results")
	}
}
// TestHandlePackagesListAPI_InvalidSort lists packages with sort=invalid
// against an empty database and expects a 400 response.
func TestHandlePackagesListAPI_InvalidSort(t *testing.T) {
	enricher := enrichment.New(slog.New(slog.NewTextHandler(os.Stdout, nil)))

	store, err := database.Create(filepath.Join(t.TempDir(), "test.db"))
	if err != nil {
		t.Fatalf("Create failed: %v", err)
	}
	defer func() { _ = store.Close() }()

	handler := NewAPIHandler(enricher, store)
	router := chi.NewRouter()
	router.Get("/api/packages", handler.HandlePackagesList)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/api/packages?sort=invalid", nil))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected status 400 for invalid sort, got %d", got)
	}
}