diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e0d3cfb..82ee83c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -38,7 +38,7 @@ jobs: images: ghcr.io/${{ github.repository }} - name: Build and push Docker image - uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f + uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf with: context: . push: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f787509..c23de56 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ jobs: fetch-depth: 0 persist-credentials: false - - uses: sigstore/cosign-installer@cad07c2e89fa2edd6e2d7bab4c1aa38e53f76003 # v4.1.1 + - uses: sigstore/cosign-installer@6f9f17788090df1f26f669e9d70d6ae9567deba6 # v4.1.2 - name: Set up Go uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0 @@ -27,7 +27,7 @@ jobs: go-version-file: go.mod cache: false - - uses: goreleaser/goreleaser-action@ec59f474b9834571250b370d4735c50f8e2d1e29 # v7.0.0 + - uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2 with: version: "~> v2" args: release --clean diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml index 4acd19b..f621254 100644 --- a/.github/workflows/zizmor.yml +++ b/.github/workflows/zizmor.yml @@ -26,4 +26,4 @@ jobs: persist-credentials: false - name: Run zizmor - uses: zizmorcore/zizmor-action@b1d7e1fb5de872772f31590499237e7cce841e8e # v0.5.3 + uses: zizmorcore/zizmor-action@5f14fd08f7cf1cb1609c1e344975f152c7ee938d # v0.5.6 diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..9d4b957 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,28 @@ +version: "2" + +linters: + enable: + - gocritic + - gocognit + - gocyclo + - maintidx + - dupl + - mnd + - unparam + - ireturn + - goconst + - errcheck + settings: + goconst: + min-len: 4 + min-occurrences: 5 + ignore-tests: true + ignore-string-values: + - "^[a-z]+$" + exclusions: + rules: + - path: _test\.go + linters: + - goconst + - dupl + - mnd diff --git a/.goreleaser.yaml b/.goreleaser.yaml index b6256de..e5881db 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -36,11 +36,10 @@ checksum: signs: - cmd: cosign - certificate: "${artifact}.pem" + signature: "${artifact}.cosign.bundle" args: - sign-blob - - "--output-certificate=${certificate}" - - "--output-signature=${signature}" + - "--bundle=${signature}" - "${artifact}" - "--yes" artifacts: checksum diff --git a/Dockerfile b/Dockerfile index 71a9fcc..5124b1d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25-alpine AS builder +FROM golang:1.26.3-alpine AS builder WORKDIR /src @@ -15,7 +15,7 @@ COPY . . # Build the binary RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /proxy ./cmd/proxy -FROM alpine:3.21 +FROM alpine:3.23.4 RUN apk add --no-cache ca-certificates diff --git a/README.md b/README.md index 411f25c..abf3e11 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,13 @@ Resolution order: package override, then ecosystem override, then global default | pub.dev | Dart | Yes | ✓ | | PyPI | Python | Yes | ✓ | | Maven | Java | | ✓ | +| Gradle Build Cache | Java/Kotlin | | ✓ | | NuGet | .NET | Yes | ✓ | | Composer | PHP | Yes | ✓ | | Conan | C/C++ | | ✓ | | Conda | Python/R | Yes | ✓ | | CRAN | R | | ✓ | +| Julia | Julia | | ✓ | | Container | Docker/OCI | | ✓ | | Debian | Debian/Ubuntu | | ✓ | | RPM | RHEL/Fedora | | ✓ | @@ -208,6 +210,34 @@ Add to your `~/.m2/settings.xml`: ``` +The `/maven/` endpoint uses Maven Central as primary upstream and falls back to the Gradle Plugin Portal for Gradle plugin marker metadata and related artifacts when the primary upstream returns not found. + +For Gradle plugin resolution via the same proxy endpoint: + +```kotlin +pluginManagement { + repositories { + maven(url = "http://localhost:8080/maven/") + } +} +``` + +### Gradle HTTP Build Cache + +Configure in `settings.gradle(.kts)`: + +```kotlin +buildCache { + local { + enabled = false + } + remote { + url = uri("http://localhost:8080/gradle/") + push = true + } +} +``` + ### NuGet Configure in `nuget.config`: @@ -295,6 +325,21 @@ local({ }) ``` +### Julia + +Set the Pkg server before starting Julia: + +```bash +export JULIA_PKG_SERVER=http://localhost:8080/julia +``` + +Or inside a running session: + +```julia +ENV["JULIA_PKG_SERVER"] = "http://localhost:8080/julia" +using Pkg; Pkg.update() +``` + ### Docker / Container Registry Configure Docker to use the proxy as a registry mirror in `/etc/docker/daemon.json`: @@ -353,6 +398,7 @@ sudo dnf update ## Configuration The proxy can be configured via: + 1. Command line flags (highest priority) 2. Environment variables 3. Configuration file (YAML or JSON) @@ -560,7 +606,7 @@ Recently cached: | Endpoint | Description | |----------|-------------| | `GET /` | Dashboard (web UI) | -| `GET /health` | Health check (returns "ok" if healthy) | +| `GET /health` | Health check (JSON; HTTP 200 healthy, 503 unhealthy) | | `GET /stats` | Cache statistics (JSON) | | `GET /metrics` | Prometheus metrics | | `GET /npm/*` | npm registry protocol | @@ -576,6 +622,7 @@ Recently cached: | `GET /conan/*` | Conan C/C++ protocol | | `GET /conda/*` | Conda/Anaconda protocol | | `GET /cran/*` | CRAN (R) protocol | +| `GET /julia/*` | Julia Pkg server protocol | | `GET /v2/*` | OCI/Docker registry protocol | | `GET /debian/*` | Debian/APT repository protocol | | `GET /rpm/*` | RPM/Yum repository protocol | @@ -798,9 +845,28 @@ The proxy exposes Prometheus metrics at `GET /metrics`. All metric names are pre | `proxy_storage_operation_duration_seconds` | histogram | `operation` | Storage read/write latency | | `proxy_storage_errors_total` | counter | `operation` | Storage read/write failures | | `proxy_active_requests` | gauge | | In-flight requests | +| `proxy_health_probe_failures_total` | counter | `step` | Storage health probe failures by failing step (`write`, `size`, `read`, `verify`, `delete`). | Cache size and artifact count are refreshed every 60 seconds. The remaining metrics update on each request. +### Health Check + +`/health` returns a structured JSON report of subsystem health. HTTP 200 if all checks pass; 503 if any fail. + +```json +{ + "status": "ok", + "checks": { + "database": {"status": "ok"}, + "storage": {"status": "ok"} + } +} +``` + +Failing checks include an `"error"` field. Storage failures also include a `"step"` field identifying which probe step failed (`write`, `size`, `read`, `verify`, `delete`). When the database check fails, the storage entry reports `{"status": "skipped"}` so the response always carries the same key set. + +Storage probe results are cached for `health.storage_probe_interval` (default 30s) to bound the cost of probing remote backends. A probe holds an internal mutex for up to 10 seconds (the hardcoded per-probe timeout), so `/health` is intended as a Kubernetes **readiness** probe rather than a liveness probe — a slow S3 round-trip should pull the pod from rotation, not restart it. + Scrape config for Prometheus: ```yaml @@ -924,6 +990,7 @@ The proxy will recreate the database on next start. ## Building from Source Requirements: + - Go 1.25 or later ```bash diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 0268e9e..15a71c0 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -72,6 +72,14 @@ // PROXY_DATABASE_URL - PostgreSQL connection URL // PROXY_LOG_LEVEL - Log level // PROXY_LOG_FORMAT - Log format +// PROXY_UPSTREAM_MAVEN - Maven repository upstream URL +// PROXY_UPSTREAM_GRADLE_PLUGIN_PORTAL - Gradle Plugin Portal upstream URL +// PROXY_GRADLE_BUILD_CACHE_READ_ONLY - Disable Gradle PUT uploads +// PROXY_GRADLE_BUILD_CACHE_MAX_UPLOAD_SIZE - Max Gradle PUT request body size +// PROXY_GRADLE_BUILD_CACHE_MAX_AGE - Gradle cache max age eviction +// PROXY_GRADLE_BUILD_CACHE_MAX_SIZE - Gradle cache max total size +// PROXY_GRADLE_BUILD_CACHE_SWEEP_INTERVAL - Gradle cache eviction sweep interval +// PROXY_HEALTH_STORAGE_PROBE_INTERVAL - Storage health probe cache interval (default "30s") // // Example: // @@ -193,6 +201,14 @@ func runServe() { fmt.Fprintf(os.Stderr, " PROXY_DATABASE_URL PostgreSQL connection URL\n") fmt.Fprintf(os.Stderr, " PROXY_LOG_LEVEL Log level\n") fmt.Fprintf(os.Stderr, " PROXY_LOG_FORMAT Log format\n") + fmt.Fprintf(os.Stderr, " PROXY_UPSTREAM_MAVEN Maven repository upstream URL\n") + fmt.Fprintf(os.Stderr, " PROXY_UPSTREAM_GRADLE_PLUGIN_PORTAL Gradle Plugin Portal upstream URL\n") + fmt.Fprintf(os.Stderr, " PROXY_GRADLE_BUILD_CACHE_READ_ONLY Disable Gradle PUT uploads\n") + fmt.Fprintf(os.Stderr, " PROXY_GRADLE_BUILD_CACHE_MAX_UPLOAD_SIZE Max Gradle PUT request body size\n") + fmt.Fprintf(os.Stderr, " PROXY_GRADLE_BUILD_CACHE_MAX_AGE Gradle cache max age eviction\n") + fmt.Fprintf(os.Stderr, " PROXY_GRADLE_BUILD_CACHE_MAX_SIZE Gradle cache max total size\n") + fmt.Fprintf(os.Stderr, " PROXY_GRADLE_BUILD_CACHE_SWEEP_INTERVAL Gradle cache eviction sweep interval\n") + fmt.Fprintf(os.Stderr, " PROXY_HEALTH_STORAGE_PROBE_INTERVAL Storage health probe cache interval\n") } _ = fs.Parse(os.Args[1:]) @@ -454,6 +470,7 @@ func runMirror() { proxy := handler.NewProxy(db, store, fetcher, resolver, logger) proxy.CacheMetadata = true // mirror always caches metadata proxy.MetadataTTL = cfg.ParseMetadataTTL() + proxy.MetadataMaxSize = cfg.ParseMetadataMaxSize() m := mirror.New(proxy, db, store, logger, *concurrency) diff --git a/config.example.yaml b/config.example.yaml index ea17d15..11c751c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -29,6 +29,23 @@ storage: # Empty or "0" means unlimited max_size: "" + # Redirect cached artifact downloads to presigned storage URLs (HTTP 302) + # instead of streaming through the proxy. Only effective for S3 and Azure. + # Leave disabled if clients reach the proxy through an authenticating gateway, + # since presigned URLs bypass it. + direct_serve: false + + # How long presigned URLs remain valid (e.g. "5m", "1h"). Default: "15m". + direct_serve_ttl: "15m" + + # Public base URL to substitute into presigned URLs. Set this when the + # proxy reaches storage at an internal address (127.0.0.1, a Docker + # service name) but clients must use a public hostname. Only scheme and + # host are used; the signed path and query are preserved. For S3/MinIO + # the reverse proxy at this address must forward requests with the + # internal Host header or the SigV4 signature will not validate. + # direct_serve_base_url: "https://minio.example.com" + # Database configuration database: # Database driver: "sqlite" (default) or "postgres" @@ -54,6 +71,12 @@ upstream: # npm registry URL npm: "https://registry.npmjs.org" + # Maven repository URL (used by /maven endpoint) + maven: "https://repo1.maven.org/maven2" + + # Gradle Plugin Portal Maven URL (fallback for plugin marker artifacts) + gradle_plugin_portal: "https://plugins.gradle.org/m2" + # Cargo sparse index URL cargo: "https://index.crates.io" @@ -91,6 +114,35 @@ upstream: # header_name: "X-Auth-Token" # header_value: "${MAVEN_TOKEN}" +# Gradle HttpBuildCache configuration +gradle: + build_cache: + # Set to true to disable PUT uploads (read-only cache mode) + read_only: false + + # Maximum accepted Gradle cache upload body size + # Required and must be > 0 + max_upload_size: "100MB" + + # Evict entries older than this age (set to "0" to disable age-based eviction) + max_age: "168h" + + # Cap total Gradle cache size; oldest entries are deleted first + # ("0" disables size-based eviction) + # max_size: "20GB" + + # How often eviction runs when max_age or max_size is set + sweep_interval: "10m" + +# Health endpoint configuration. +health: + # Minimum time between storage backend probes. + # The /health endpoint runs a write/read/verify/delete round-trip + # against the configured storage backend and caches the result for + # this interval. Set to "0" to probe on every request. + # Default: "30s". + storage_probe_interval: "30s" + # Version cooldown configuration # Hides package versions published too recently, giving the community time # to spot malicious releases before they're pulled into projects. diff --git a/docs/architecture.md b/docs/architecture.md index 81c41cf..85e5aaf 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -14,7 +14,7 @@ The proxy is a caching HTTP server that sits between package manager clients and │ │ /npm/* -> NPMHandler /health -> healthHandler │ │ │ │ /cargo/* -> CargoHandler /stats -> statsHandler │ │ │ │ /gem/* -> GemHandler /metrics -> prometheus │ │ -│ │ ...16 ecosystems /api/* -> APIHandler │ │ +│ │ ...17 ecosystems /api/* -> APIHandler │ │ │ │ / -> Web UI │ │ │ └──────────────────────────────────────────────────────────┘ │ │ │ │ │ │ @@ -277,15 +277,15 @@ HTTP server setup, web UI, and API handlers. - Web UI: dashboard, package browser, source browser, version comparison - Templates are embedded in the binary via `//go:embed` - Enrichment API for package metadata, vulnerability scanning, and outdated detection -- Health, stats, and Prometheus metrics endpoints +- Health, stats, and Prometheus metrics endpoints. `/health` runs an active write → size-check → read → verify → delete probe against the storage backend and returns a structured JSON response (`HealthResponse`) with `"ok"` / `"error"` status per subsystem. Probe results are cached (default 30 s, configurable via `health.storage_probe_interval`) to avoid overwhelming remote backends. ### `internal/metrics` Prometheus metrics for cache performance, upstream latency, storage operations, and active requests. See the Monitoring section of the README for the full metric list. -### `internal/cooldown` +### Cooldown -Version age filtering for supply chain attack mitigation. Configurable at global, ecosystem, and per-package levels. Supported by npm, PyPI, pub.dev, and Composer handlers. +Version age filtering for supply chain attack mitigation, provided by [github.com/git-pkgs/cooldown](https://github.com/git-pkgs/cooldown). Configurable at global, ecosystem, and per-package levels. Supported by npm, PyPI, pub.dev, and Composer handlers. ### `internal/enrichment` diff --git a/docs/configuration.md b/docs/configuration.md index be196de..1310bd0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -114,6 +114,8 @@ Override default upstream registry URLs: ```yaml upstream: npm: "https://registry.npmjs.org" + maven: "https://repo1.maven.org/maven2" + gradle_plugin_portal: "https://plugins.gradle.org/m2" cargo: "https://index.crates.io" cargo_download: "https://static.crates.io/crates" ``` @@ -184,6 +186,30 @@ upstream: token: "${PRIVATE_TOKEN}" ``` +## Gradle Build Cache + +The `/gradle` endpoint supports optional safeguards for upload control and cache retention. + +```yaml +gradle: + build_cache: + read_only: false + max_upload_size: "100MB" + max_age: "168h" + max_size: "20GB" + sweep_interval: "10m" +``` + +| Config | Environment | Description | +|--------|-------------|-------------| +| `gradle.build_cache.read_only` | `PROXY_GRADLE_BUILD_CACHE_READ_ONLY` | Disable PUT uploads and keep GET/HEAD read-only | +| `gradle.build_cache.max_upload_size` | `PROXY_GRADLE_BUILD_CACHE_MAX_UPLOAD_SIZE` | Maximum accepted PUT body size (must be > 0) | +| `gradle.build_cache.max_age` | `PROXY_GRADLE_BUILD_CACHE_MAX_AGE` | Delete entries older than this duration (default `168h`, set `0` to disable) | +| `gradle.build_cache.max_size` | `PROXY_GRADLE_BUILD_CACHE_MAX_SIZE` | Total size cap for `_gradle/http-build-cache`, deleting oldest first (`0` disables) | +| `gradle.build_cache.sweep_interval` | `PROXY_GRADLE_BUILD_CACHE_SWEEP_INTERVAL` | Frequency for background eviction sweeps | + +`max_age` and `max_size` are independent and can be combined. When both are set, age-based eviction runs first, then size-based eviction trims remaining entries oldest-first. + ## Cooldown The cooldown feature hides package versions published too recently, giving the community time to spot malicious releases before they reach your projects. When a version is within its cooldown period, it's stripped from metadata responses so package managers won't install it. @@ -239,6 +265,16 @@ Set to `"0"` to always revalidate with upstream (ETag-based conditional requests When upstream is unreachable and the cached entry is past its TTL, the proxy serves the stale cached copy with a `Warning: 110 - "Response is Stale"` header so clients can tell the data may be outdated. +### Metadata size limit + +Upstream metadata responses are buffered in memory before being rewritten and served. `metadata_max_size` caps that buffer to protect against OOM from a misbehaving upstream. Some npm packages with thousands of versions (for example `renovate`) exceed the 100 MB default, so raise this if you see `metadata response exceeds size limit` in the logs. + +```yaml +metadata_max_size: "100MB" # default +``` + +Or via environment variable: `PROXY_METADATA_MAX_SIZE=250MB`. + ## Mirror API The `/api/mirror` endpoints are disabled by default. Enable them to allow starting mirror jobs via HTTP: diff --git a/docs/swagger/docs.go b/docs/swagger/docs.go index fedf889..23ff54a 100644 --- a/docs/swagger/docs.go +++ b/docs/swagger/docs.go @@ -64,13 +64,13 @@ const docTemplate = `{ "404": { "description": "Not Found", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -126,19 +126,19 @@ const docTemplate = `{ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "404": { "description": "Not Found", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -177,13 +177,13 @@ const docTemplate = `{ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -240,13 +240,13 @@ const docTemplate = `{ "404": { "description": "Not Found", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -285,13 +285,13 @@ const docTemplate = `{ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -338,13 +338,13 @@ const docTemplate = `{ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -384,13 +384,13 @@ const docTemplate = `{ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -399,7 +399,7 @@ const docTemplate = `{ "/health": { "get": { "produces": [ - "text/plain" + "application/json" ], "tags": [ "meta" @@ -409,13 +409,13 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "type": "string" + "$ref": "#/definitions/server.HealthResponse" } }, "503": { "description": "Service Unavailable", "schema": { - "type": "string" + "$ref": "#/definitions/server.HealthResponse" } } } @@ -440,7 +440,7 @@ const docTemplate = `{ "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -504,6 +504,45 @@ const docTemplate = `{ } } }, + "server.ErrorResponse": { + "type": "object", + "properties": { + "code": { + "type": "string" + }, + "message": { + "type": "string" + } + } + }, + "server.HealthCheck": { + "type": "object", + "properties": { + "error": { + "type": "string" + }, + "status": { + "type": "string" + }, + "step": { + "type": "string" + } + } + }, + "server.HealthResponse": { + "type": "object", + "properties": { + "checks": { + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/server.HealthCheck" + } + }, + "status": { + "type": "string" + } + } + }, "server.OutdatedPackage": { "type": "object", "properties": { diff --git a/docs/swagger/swagger.json b/docs/swagger/swagger.json index 88df1e9..c2b4dfc 100644 --- a/docs/swagger/swagger.json +++ b/docs/swagger/swagger.json @@ -57,13 +57,13 @@ "404": { "description": "Not Found", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -119,19 +119,19 @@ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "404": { "description": "Not Found", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -170,13 +170,13 @@ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -233,13 +233,13 @@ "404": { "description": "Not Found", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -278,13 +278,13 @@ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -331,13 +331,13 @@ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -377,13 +377,13 @@ "400": { "description": "Bad Request", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } }, "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -392,7 +392,7 @@ "/health": { "get": { "produces": [ - "text/plain" + "application/json" ], "tags": [ "meta" @@ -402,13 +402,13 @@ "200": { "description": "OK", "schema": { - "type": "string" + "$ref": "#/definitions/server.HealthResponse" } }, "503": { "description": "Service Unavailable", "schema": { - "type": "string" + "$ref": "#/definitions/server.HealthResponse" } } } @@ -433,7 +433,7 @@ "500": { "description": "Internal Server Error", "schema": { - "type": "string" + "$ref": "#/definitions/server.ErrorResponse" } } } @@ -497,6 +497,45 @@ } } }, + "server.ErrorResponse": { + "type": "object", + "properties": { + "code": { + "type": "string" + }, + "message": { + "type": "string" + } + } + }, + "server.HealthCheck": { + "type": "object", + "properties": { + "error": { + "type": "string" + }, + "status": { + "type": "string" + }, + "step": { + "type": "string" + } + } + }, + "server.HealthResponse": { + "type": "object", + "properties": { + "checks": { + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/server.HealthCheck" + } + }, + "status": { + "type": "string" + } + } + }, "server.OutdatedPackage": { "type": "object", "properties": { diff --git a/go.mod b/go.mod index 1067a3a..199c8c8 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,16 @@ module github.com/git-pkgs/proxy go 1.25.6 require ( - github.com/CycloneDX/cyclonedx-go v0.10.0 - github.com/git-pkgs/archives v0.2.2 - github.com/git-pkgs/enrichment v0.2.2 - github.com/git-pkgs/purl v0.1.10 - github.com/git-pkgs/registries v0.4.1 - github.com/git-pkgs/spdx v0.1.2 - github.com/git-pkgs/vers v0.2.4 - github.com/git-pkgs/vulns v0.1.4 + github.com/BurntSushi/toml v1.6.0 + github.com/CycloneDX/cyclonedx-go v0.11.0 + github.com/git-pkgs/archives v0.3.0 + github.com/git-pkgs/cooldown v0.1.1 + github.com/git-pkgs/enrichment v0.2.3 + github.com/git-pkgs/purl v0.1.12 + github.com/git-pkgs/registries v0.6.1 + github.com/git-pkgs/spdx v0.1.4 + github.com/git-pkgs/vers v0.2.6 + github.com/git-pkgs/vulns v0.1.5 github.com/go-chi/chi/v5 v5.2.5 github.com/jmoiron/sqlx v1.4.0 github.com/lib/pq v1.12.3 @@ -22,12 +24,14 @@ require ( golang.org/x/sync v0.20.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 - modernc.org/sqlite v1.48.2 + modernc.org/sqlite v1.50.1 ) require ( 4d63.com/gocheckcompilerdirectives v1.3.0 // indirect 4d63.com/gochecknoglobals v0.2.2 // indirect + cloud.google.com/go/auth v0.18.2 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect codeberg.org/chavacava/garif v0.2.0 // indirect codeberg.org/polyfloyd/go-errorlint v1.9.0 // indirect @@ -40,7 +44,13 @@ require ( github.com/Antonboom/errname v1.1.1 // indirect github.com/Antonboom/nilnil v1.1.1 // indirect github.com/Antonboom/testifylint v1.6.4 // indirect - github.com/BurntSushi/toml v1.6.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 // indirect + github.com/Azure/go-autorest v14.2.0+incompatible // indirect + github.com/Azure/go-autorest/autorest/to v0.4.1 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect github.com/Djarvur/go-err113 v0.1.1 // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/Masterminds/semver/v3 v3.4.0 // indirect @@ -118,7 +128,8 @@ require ( github.com/fzipp/gocyclo v0.6.0 // indirect github.com/ghostiam/protogetter v0.3.20 // indirect github.com/git-pkgs/packageurl-go v0.3.1 // indirect - github.com/github/go-spdx/v2 v2.4.0 // indirect + github.com/git-pkgs/pom v0.1.4 // indirect + github.com/github/go-spdx/v2 v2.7.0 // indirect github.com/go-critic/go-critic v0.14.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -138,6 +149,7 @@ require ( github.com/gobwas/glob v0.2.3 // indirect github.com/godoc-lint/godoc-lint v0.11.2 // indirect github.com/gofrs/flock v0.13.0 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golangci/asciicheck v0.5.0 // indirect github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect github.com/golangci/go-printf-func-name v0.1.1 // indirect @@ -150,8 +162,10 @@ require ( github.com/golangci/swaggoswag v0.0.0-20250504205917-77f2aca3143e // indirect github.com/golangci/unconvert v0.0.0-20250410112200-a129a6e6413e // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect github.com/google/wire v0.7.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.12 // indirect github.com/googleapis/gax-go/v2 v2.17.0 // indirect github.com/gordonklaus/ineffassign v0.2.0 // indirect github.com/gostaticanalysis/analysisutil v0.7.1 // indirect @@ -174,6 +188,7 @@ require ( github.com/kkHAIKE/contextcheck v1.1.6 // indirect github.com/kulti/thelper v0.7.1 // indirect github.com/kunwardeep/paralleltest v1.0.15 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/lasiar/canonicalheader v1.1.2 // indirect github.com/ldez/exptostd v0.4.5 // indirect github.com/ldez/gomoddirectives v0.8.0 // indirect @@ -206,9 +221,11 @@ require ( github.com/nishanths/predeclared v0.2.2 // indirect github.com/nunnatsa/ginkgolinter v0.23.0 // indirect github.com/oapi-codegen/runtime v1.2.0 // indirect + github.com/package-url/packageurl-go v0.1.6 // indirect github.com/pandatix/go-cvss v0.6.2 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.20.1 // indirect @@ -277,10 +294,12 @@ require ( go.uber.org/zap v1.27.1 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/crypto v0.48.0 // indirect golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.34.0 // indirect golang.org/x/tools v0.42.0 // indirect @@ -291,7 +310,7 @@ require ( gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect honnef.co/go/tools v0.7.0 // indirect - modernc.org/libc v1.70.0 // indirect + modernc.org/libc v1.72.3 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect mvdan.cc/gofumpt v0.9.2 // indirect diff --git a/go.sum b/go.sum index 49b11f9..23c3df7 100644 --- a/go.sum +++ b/go.sum @@ -43,11 +43,31 @@ github.com/Antonboom/nilnil v1.1.1 h1:9Mdr6BYd8WHCDngQnNVV0b554xyisFioEKi30sksuf github.com/Antonboom/nilnil v1.1.1/go.mod h1:yCyAmSw3doopbOWhJlVci+HuyNRuHJKIv6V2oYQa8II= github.com/Antonboom/testifylint v1.6.4 h1:gs9fUEy+egzxkEbq9P4cpcMB6/G0DYdMeiFS87UiqmQ= github.com/Antonboom/testifylint v1.6.4/go.mod h1:YO33FROXX2OoUfwjz8g+gUxQXio5i9qpVy7nXGbxDD4= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 h1:ZJJNFaQ86GVKQ9ehwqyAFE6pIfyicpuJ8IkVaPBc6/4= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3/go.mod h1:URuDvhmATVKqHBH9/0nOiNKk0+YcwfQ3WkK5PqHKxc8= +github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= +github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= +github.com/Azure/go-autorest/autorest/to v0.4.1 h1:CxNHBqdzTr7rLtdrtb5CMjJcDut+WNGCVv7OmS5+lTc= +github.com/Azure/go-autorest/autorest/to v0.4.1/go.mod h1:EtaofgU4zmtvn1zT2ARsjRFdq9vXx0YWtmElwL+GZ9M= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/CycloneDX/cyclonedx-go v0.10.0 h1:7xyklU7YD+CUyGzSFIARG18NYLsKVn4QFg04qSsu+7Y= -github.com/CycloneDX/cyclonedx-go v0.10.0/go.mod h1:vUvbCXQsEm48OI6oOlanxstwNByXjCZ2wuleUlwGEO8= +github.com/CycloneDX/cyclonedx-go v0.11.0 h1:GokP8FiRC+foiuwWhSSLpSD5H4hSWtGnR3wo7apkBFI= +github.com/CycloneDX/cyclonedx-go v0.11.0/go.mod h1:vUvbCXQsEm48OI6oOlanxstwNByXjCZ2wuleUlwGEO8= github.com/Djarvur/go-err113 v0.1.1 h1:eHfopDqXRwAi+YmCUas75ZE0+hoBHJ2GQNLYRSxao4g= github.com/Djarvur/go-err113 v0.1.1/go.mod h1:IaWJdYFLg76t2ihfflPZnM1LIQszWOsFDh2hhhAVF6k= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= @@ -230,24 +250,28 @@ github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= github.com/ghostiam/protogetter v0.3.20 h1:oW7OPFit2FxZOpmMRPP9FffU4uUpfeE/rEdE1f+MzD0= github.com/ghostiam/protogetter v0.3.20/go.mod h1:FjIu5Yfs6FT391m+Fjp3fbAYJ6rkL/J6ySpZBfnODuI= -github.com/git-pkgs/archives v0.2.2 h1:RxOjrV8RzKicbMVdf2GDKOqIOHZNVjrLY/Pc7KSE/WQ= -github.com/git-pkgs/archives v0.2.2/go.mod h1:LTJ1iQVFA7otizWMOyiI82NYVmyBWAPRzwu/e30rcXU= -github.com/git-pkgs/enrichment v0.2.2 h1:vaQu5vs3tjQB5JI0gzBrUCynUc9z3l5byPhgKFaNZrc= -github.com/git-pkgs/enrichment v0.2.2/go.mod h1:5JWGmlHWcv5HQHUrctcpnRUNpEF5VAixD2z4zvqKejs= +github.com/git-pkgs/archives v0.3.0 h1:iXKyO83jEFub1PGEDlHmk2tQ7XeV5LySTc0sEkH3x78= +github.com/git-pkgs/archives v0.3.0/go.mod h1:LTJ1iQVFA7otizWMOyiI82NYVmyBWAPRzwu/e30rcXU= +github.com/git-pkgs/cooldown v0.1.1 h1:9OqqzCB8gANz/y44SmqGD0Jp8Qtu81D1sCbKl6Ehg7w= +github.com/git-pkgs/cooldown v0.1.1/go.mod h1:v7APuK/UouTiu8mWQZbdDmj7DfxxkGUeuhjaRB5gv9E= +github.com/git-pkgs/enrichment v0.2.3 h1:42mqoUhQZNGhlEO671pboI/Cu6F+DoffJoFbVhb2jlw= +github.com/git-pkgs/enrichment v0.2.3/go.mod h1:MBv5nhHzjwLxeSgx2+7waCcpReUjhCD+9B0bvufpMO0= github.com/git-pkgs/packageurl-go v0.3.1 h1:WM3RBABQZLaRBxgKyYughc3cVBE8KyQxbSC6Jt5ak7M= github.com/git-pkgs/packageurl-go v0.3.1/go.mod h1:rcIxiG37BlQLB6FZfgdj9Fm7yjhRQd3l+5o7J0QPAk4= -github.com/git-pkgs/purl v0.1.10 h1:NMjeF10nzFn3tdQlz6rbmHB+i+YkyrFQxho3e33ePTQ= -github.com/git-pkgs/purl v0.1.10/go.mod h1:C5Vp/kyZ/wGckCLexx4wPVfUxEiToRkdsOPh5Z7ig/I= -github.com/git-pkgs/registries v0.4.1 h1:4qlKVNhC/6x6Bt87t3wrGJtF3EFrUpHQt9/zKsa5IvU= -github.com/git-pkgs/registries v0.4.1/go.mod h1:49UCPFWQmwNV7rBEr9TrTDWKR7vYxFcxp3VfdkeFbdE= -github.com/git-pkgs/spdx v0.1.2 h1:wHSK+CqFsO5N7yDTPvxDmer5LgNEa7vAsiZhi5Aci0A= -github.com/git-pkgs/spdx v0.1.2/go.mod h1:V98MgZapNgYw54/pdGR82d7RU93qzJoybahbpZqTfw8= -github.com/git-pkgs/vers v0.2.4 h1:Zr3jR/Xf1i/6cvBaJKPxhCwjzqz7uvYHE0Fhid/GPBk= -github.com/git-pkgs/vers v0.2.4/go.mod h1:biTbSQK1qdbrsxDEKnqe3Jzclxz8vW6uDcwKjfUGcOo= -github.com/git-pkgs/vulns v0.1.4 h1:SlnGWHNmtdQgABjfrX/I/pVe+DWLbZ5Yi9xg+/De5r8= -github.com/git-pkgs/vulns v0.1.4/go.mod h1:34xkR7QncIVfxoi78k3YT6Y9DfTEaL7j6PzCqjsRP9U= -github.com/github/go-spdx/v2 v2.4.0 h1:+4IwVwJJbm3rzvrQ6P1nI9BDMcy3la4RchRy5uehV/M= -github.com/github/go-spdx/v2 v2.4.0/go.mod h1:/5rwgS0txhGtRdUZwc02bTglzg6HK3FfuEbECKlK2Sg= +github.com/git-pkgs/pom v0.1.4 h1:C6st+XSbF75eKuwfdkDZZtYHoTcaWRIEQYar5VtszUo= +github.com/git-pkgs/pom v0.1.4/go.mod h1:ufdMBe1lKzqOeP9IUb9NPZ458xKV8E8NvuyBMxOfwIk= +github.com/git-pkgs/purl v0.1.12 h1:qCskrEU1LWQhCkIVZd992W5++Bsxazvx2Cx1/65qCvU= +github.com/git-pkgs/purl v0.1.12/go.mod h1:ofp4mHsR0cUeVONQaf33n6Wxg2QTEvtUdRfCedI8ouA= +github.com/git-pkgs/registries v0.6.1 h1:xZfVZQmffIfdeJthn5o2EozbVJ6gBeImYwKQnfdKUfU= +github.com/git-pkgs/registries v0.6.1/go.mod h1:a3BP/56VW3O/CFRqiJCtSy+OqRrSH25wF1PWHP76ka0= +github.com/git-pkgs/spdx v0.1.4 h1:eQ0waEV3uUeItpWAOvdN1K1rL9hTgsU7fF74r1mDXMs= +github.com/git-pkgs/spdx v0.1.4/go.mod h1:cqRoZcvl530s/W+oGNvwjt4ODN8T1W6D/20MUZEFdto= +github.com/git-pkgs/vers v0.2.6 h1:IelZd7BP/JhzTloUTDY67nehUgoYva3g9viqAMCHJg8= +github.com/git-pkgs/vers v0.2.6/go.mod h1:biTbSQK1qdbrsxDEKnqe3Jzclxz8vW6uDcwKjfUGcOo= +github.com/git-pkgs/vulns v0.1.5 h1:mtX88/27toFl+B95kaH5QbAdOCQ3YIDGjJrlrrnqQTE= +github.com/git-pkgs/vulns v0.1.5/go.mod h1:bZFikfrR/5gC0ZMwXh7qcEu2gpKfXMBhVsy4kF12Ae0= +github.com/github/go-spdx/v2 v2.7.0 h1:GzfXx4wFdlilARxmFRXW/mgUy3A4vSqZocCMFV6XFdQ= +github.com/github/go-spdx/v2 v2.7.0/go.mod h1:Ftc45YYG1WzpzwEPKRVm9Jv8vDqOrN4gWoCkK+bHer0= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-critic/go-critic v0.14.3 h1:5R1qH2iFeo4I/RJU8vTezdqs08Egi4u5p6vOESA0pog= @@ -305,6 +329,8 @@ github.com/godoc-lint/godoc-lint v0.11.2 h1:Bp0FkJWoSdNsBikdNgIcgtaoo+xz6I/Y9s5W github.com/godoc-lint/godoc-lint v0.11.2/go.mod h1:iVpGdL1JCikNH2gGeAn3Hh+AgN5Gx/I/cxV+91L41jo= github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golangci/asciicheck v0.5.0 h1:jczN/BorERZwK8oiFBOGvlGPknhvq0bjnysTj4nUfo0= @@ -396,6 +422,8 @@ github.com/julz/importas v0.2.0 h1:y+MJN/UdL63QbFJHws9BVC5RpA2iq0kpjrFajTGivjQ= github.com/julz/importas v0.2.0/go.mod h1:pThlt589EnCYtMnmhmRYY/qn9lCf/frPOK+WMx3xiJY= github.com/karamaru-alpha/copyloopvar v1.2.2 h1:yfNQvP9YaGQR7VaWLYcfZUlRP2eo2vhExWKxD/fP6q0= github.com/karamaru-alpha/copyloopvar v1.2.2/go.mod h1:oY4rGZqZ879JkJMtX3RRkcXRkmUvH0x35ykgaKgsgJY= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= github.com/kisielk/errcheck v1.9.0 h1:9xt1zI9EBfcYBvdU1nVrzMzzUPUtPKs9bVSIM3TAb3M= github.com/kisielk/errcheck v1.9.0/go.mod h1:kQxWMMVZgIkDq7U8xtG/n2juOjbLgZtedi0D+/VL/i8= github.com/kkHAIKE/contextcheck v1.1.6 h1:7HIyRcnyzxL9Lz06NGhiKvenXq7Zw6Q0UQu/ttjfJCE= @@ -501,6 +529,8 @@ github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJ github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= +github.com/package-url/packageurl-go v0.1.6 h1:YO3p6u1XmCUliivUg/qWphaY8vI6hxSnnPv7Bfg3m5M= +github.com/package-url/packageurl-go v0.1.6/go.mod h1:nKAWB8E6uk1MHqiS/lQb9pYBGH2+mdJ2PJc2s50dQY0= github.com/pandatix/go-cvss v0.6.2 h1:TFiHlzUkT67s6UkelHmK6s1INKVUG7nlKYiWWDTITGI= github.com/pandatix/go-cvss v0.6.2/go.mod h1:jDXYlQBZrc8nvrMUVVvTG8PhmuShOnKrxP53nOFkt8Q= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= @@ -509,6 +539,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/peterbourgon/g2s v0.0.0-20170223122336-d4e7ad98afea h1:sKwxy1H95npauwu8vtF95vG/syrL0p8fSZo/XlDg5gk= github.com/peterbourgon/g2s v0.0.0-20170223122336-d4e7ad98afea/go.mod h1:1VcHEd3ro4QMoHfiNl/j7Jkln9+KQuorp0PItHMJYNg= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -771,6 +803,7 @@ golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -851,10 +884,10 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU= honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc= -modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= -modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= -modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= -modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0= +modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY= +modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI= +modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ= +modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A= modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= @@ -863,18 +896,18 @@ modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= -modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= -modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU= +modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= -modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg= +modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= -modernc.org/sqlite v1.48.2 h1:5CnW4uP8joZtA0LedVqLbZV5GD7F/0x91AXeSyjoh5c= -modernc.org/sqlite v1.48.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w= +modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= diff --git a/internal/config/config.go b/internal/config/config.go index 2984d5c..0e8405d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -51,6 +51,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/url" "os" "path/filepath" "strconv" @@ -95,9 +96,20 @@ type Config struct { // Default: "5m". Set to "0" to always revalidate. MetadataTTL string `json:"metadata_ttl" yaml:"metadata_ttl"` + // MetadataMaxSize is the maximum size of an upstream metadata response + // the proxy will buffer (e.g. "100MB", "250MB"). Responses over this + // size return ErrMetadataTooLarge. Default: "100MB". + MetadataMaxSize string `json:"metadata_max_size" yaml:"metadata_max_size"` + // MirrorAPI enables the /api/mirror endpoints for starting mirror jobs via HTTP. // Disabled by default to prevent unauthenticated users from triggering downloads. MirrorAPI bool `json:"mirror_api" yaml:"mirror_api"` + + // Gradle configures Gradle HttpBuildCache behavior. + Gradle GradleConfig `json:"gradle" yaml:"gradle"` + + // Health configures the /health endpoint behavior. + Health HealthConfig `json:"health" yaml:"health"` } // CooldownConfig configures version cooldown periods. @@ -133,6 +145,57 @@ type StorageConfig struct { // When exceeded, least recently used artifacts are evicted. // Empty or "0" means unlimited. MaxSize string `json:"max_size" yaml:"max_size"` + + // DirectServe enables redirecting cached artifact downloads to presigned + // storage URLs (HTTP 302) instead of streaming bytes through the proxy. + // Only effective for backends that support URL signing (S3, Azure). + DirectServe bool `json:"direct_serve" yaml:"direct_serve"` + + // DirectServeTTL is how long presigned URLs remain valid. + // Uses Go duration syntax (e.g. "5m", "1h"). Default: "15m". + DirectServeTTL string `json:"direct_serve_ttl" yaml:"direct_serve_ttl"` + + // DirectServeBaseURL overrides the scheme and host of presigned URLs + // before returning them to clients. Useful when the proxy reaches + // storage at an internal address (e.g. 127.0.0.1 or a Docker hostname) + // but clients must use a public one. + DirectServeBaseURL string `json:"direct_serve_base_url" yaml:"direct_serve_base_url"` +} + +// GradleConfig configures Gradle-specific features. +type GradleConfig struct { + // BuildCache configures the /gradle HttpBuildCache endpoint. + BuildCache GradleBuildCacheConfig `json:"build_cache" yaml:"build_cache"` +} + +// GradleBuildCacheConfig configures Gradle HttpBuildCache safeguards. +type GradleBuildCacheConfig struct { + // ReadOnly disables PUT uploads and keeps cache reads (GET/HEAD) enabled. + ReadOnly bool `json:"read_only" yaml:"read_only"` + + // MaxUploadSize caps a single PUT body size (e.g., "100MB"). Must be > 0. + // Default: "100MB". + MaxUploadSize string `json:"max_upload_size" yaml:"max_upload_size"` + + // MaxAge evicts entries older than this duration (e.g., "24h", "7d"). + // Empty or "0" disables age-based eviction. + MaxAge string `json:"max_age" yaml:"max_age"` + + // MaxSize evicts oldest entries until total Gradle cache size is <= MaxSize. + // Empty or "0" disables size-based eviction. + MaxSize string `json:"max_size" yaml:"max_size"` + + // SweepInterval controls periodic eviction frequency. + // Default: "10m". + SweepInterval string `json:"sweep_interval" yaml:"sweep_interval"` +} + +// HealthConfig configures the /health endpoint. +type HealthConfig struct { + // StorageProbeInterval is the minimum time between storage backend probes. + // Uses Go duration syntax (e.g. "30s", "1m"). Default: "30s". + // Set to "0" to probe on every /health request (useful for low-traffic deployments). + StorageProbeInterval string `json:"storage_probe_interval" yaml:"storage_probe_interval"` } // DatabaseConfig configures the cache database. @@ -163,6 +226,15 @@ type UpstreamConfig struct { // Default: https://registry.npmjs.org NPM string `json:"npm" yaml:"npm"` + // Maven is the upstream Maven repository URL. + // Default: https://repo1.maven.org/maven2 + Maven string `json:"maven" yaml:"maven"` + + // GradlePluginPortal is the upstream Gradle Plugin Portal Maven URL. + // Used to resolve Gradle plugin marker artifacts. + // Default: https://plugins.gradle.org/m2 + GradlePluginPortal string `json:"gradle_plugin_portal" yaml:"gradle_plugin_portal"` + // Cargo is the upstream cargo index URL. // Default: https://index.crates.io Cargo string `json:"cargo" yaml:"cargo"` @@ -240,9 +312,20 @@ func Default() *Config { Format: "text", }, Upstream: UpstreamConfig{ - NPM: "https://registry.npmjs.org", - Cargo: "https://index.crates.io", - CargoDownload: "https://static.crates.io/crates", + NPM: "https://registry.npmjs.org", + Maven: "https://repo1.maven.org/maven2", + GradlePluginPortal: "https://plugins.gradle.org/m2", + Cargo: "https://index.crates.io", + CargoDownload: "https://static.crates.io/crates", + }, + Gradle: GradleConfig{ + BuildCache: GradleBuildCacheConfig{ + ReadOnly: false, + MaxUploadSize: defaultGradleMaxUploadSizeStr, + MaxAge: "168h", + MaxSize: "", + SweepInterval: defaultGradleSweepIntervalStr, + }, }, } } @@ -287,6 +370,7 @@ func Load(path string) (*Config, error) { // - PROXY_DATABASE_PATH // - PROXY_LOG_LEVEL // - PROXY_LOG_FORMAT +// - PROXY_HEALTH_STORAGE_PROBE_INTERVAL func (c *Config) LoadFromEnv() { if v := os.Getenv("PROXY_LISTEN"); v != "" { c.Listen = v @@ -303,6 +387,15 @@ func (c *Config) LoadFromEnv() { if v := os.Getenv("PROXY_STORAGE_MAX_SIZE"); v != "" { c.Storage.MaxSize = v } + if v := os.Getenv("PROXY_STORAGE_DIRECT_SERVE"); v != "" { + c.Storage.DirectServe = envBool(v) + } + if v := os.Getenv("PROXY_STORAGE_DIRECT_SERVE_TTL"); v != "" { + c.Storage.DirectServeTTL = v + } + if v := os.Getenv("PROXY_STORAGE_DIRECT_SERVE_BASE_URL"); v != "" { + c.Storage.DirectServeBaseURL = v + } if v := os.Getenv("PROXY_DATABASE_DRIVER"); v != "" { c.Database.Driver = v } @@ -318,18 +411,45 @@ func (c *Config) LoadFromEnv() { if v := os.Getenv("PROXY_LOG_FORMAT"); v != "" { c.Log.Format = v } + if v := os.Getenv("PROXY_UPSTREAM_MAVEN"); v != "" { + c.Upstream.Maven = v + } + if v := os.Getenv("PROXY_UPSTREAM_GRADLE_PLUGIN_PORTAL"); v != "" { + c.Upstream.GradlePluginPortal = v + } if v := os.Getenv("PROXY_COOLDOWN_DEFAULT"); v != "" { c.Cooldown.Default = v } if v := os.Getenv("PROXY_CACHE_METADATA"); v != "" { - c.CacheMetadata = v == "true" || v == "1" + c.CacheMetadata = envBool(v) } if v := os.Getenv("PROXY_MIRROR_API"); v != "" { - c.MirrorAPI = v == "true" || v == "1" + c.MirrorAPI = envBool(v) } if v := os.Getenv("PROXY_METADATA_TTL"); v != "" { c.MetadataTTL = v } + if v := os.Getenv("PROXY_METADATA_MAX_SIZE"); v != "" { + c.MetadataMaxSize = v + } + if v := os.Getenv("PROXY_GRADLE_BUILD_CACHE_READ_ONLY"); v != "" { + c.Gradle.BuildCache.ReadOnly = v == "true" || v == "1" + } + if v := os.Getenv("PROXY_GRADLE_BUILD_CACHE_MAX_UPLOAD_SIZE"); v != "" { + c.Gradle.BuildCache.MaxUploadSize = v + } + if v := os.Getenv("PROXY_GRADLE_BUILD_CACHE_MAX_AGE"); v != "" { + c.Gradle.BuildCache.MaxAge = v + } + if v := os.Getenv("PROXY_GRADLE_BUILD_CACHE_MAX_SIZE"); v != "" { + c.Gradle.BuildCache.MaxSize = v + } + if v := os.Getenv("PROXY_GRADLE_BUILD_CACHE_SWEEP_INTERVAL"); v != "" { + c.Gradle.BuildCache.SweepInterval = v + } + if v := os.Getenv("PROXY_HEALTH_STORAGE_PROBE_INTERVAL"); v != "" { + c.Health.StorageProbeInterval = v + } } // Validate checks the configuration for errors. @@ -379,6 +499,21 @@ func (c *Config) Validate() error { } } + // Validate direct serve TTL if specified + if c.Storage.DirectServeTTL != "" { + if _, err := time.ParseDuration(c.Storage.DirectServeTTL); err != nil { + return fmt.Errorf("invalid storage.direct_serve_ttl %q: %w", c.Storage.DirectServeTTL, err) + } + } + + // Validate direct serve base URL if specified + if c.Storage.DirectServeBaseURL != "" { + u, err := url.Parse(c.Storage.DirectServeBaseURL) + if err != nil || u.Scheme == "" || u.Host == "" { + return fmt.Errorf("invalid storage.direct_serve_base_url %q: must be an absolute URL", c.Storage.DirectServeBaseURL) + } + } + // Validate metadata TTL if specified if c.MetadataTTL != "" && c.MetadataTTL != "0" { if _, err := time.ParseDuration(c.MetadataTTL); err != nil { @@ -386,10 +521,85 @@ func (c *Config) Validate() error { } } + if err := validateMetadataMaxSize(c.MetadataMaxSize); err != nil { + return err + } + + if err := c.Health.Validate(); err != nil { + return err + } + + if err := c.Gradle.BuildCache.Validate(); err != nil { + return err + } + return nil } -const defaultMetadataTTL = 5 * time.Minute //nolint:mnd // sensible default +// Validate checks the /health configuration. An unset interval is allowed +// (the cache uses its default); explicit values must parse and be non-negative. +func (h *HealthConfig) Validate() error { + if h.StorageProbeInterval == "" || h.StorageProbeInterval == "0" { + return nil + } + d, err := time.ParseDuration(h.StorageProbeInterval) + if err != nil { + return fmt.Errorf("invalid health.storage_probe_interval %q: %w", h.StorageProbeInterval, err) + } + if d < 0 { + return fmt.Errorf("invalid health.storage_probe_interval %q: must be non-negative", h.StorageProbeInterval) + } + return nil +} + +// Validate checks Gradle build cache settings, applying the default upload +// size if unset. +func (g *GradleBuildCacheConfig) Validate() error { + if g.MaxUploadSize == "" { + g.MaxUploadSize = defaultGradleMaxUploadSizeStr + } + uploadSize, err := ParseSize(g.MaxUploadSize) + if err != nil { + return fmt.Errorf("invalid gradle.build_cache.max_upload_size: %w", err) + } + if uploadSize <= 0 { + return fmt.Errorf("invalid gradle.build_cache.max_upload_size %q: must be > 0", g.MaxUploadSize) + } + + if g.MaxAge != "" && g.MaxAge != "0" { + if _, err := time.ParseDuration(g.MaxAge); err != nil { + return fmt.Errorf("invalid gradle.build_cache.max_age %q: %w", g.MaxAge, err) + } + } + + if g.MaxSize != "" { + if _, err := ParseSize(g.MaxSize); err != nil { + return fmt.Errorf("invalid gradle.build_cache.max_size: %w", err) + } + } + + if g.SweepInterval != "" { + d, err := time.ParseDuration(g.SweepInterval) + if err != nil { + return fmt.Errorf("invalid gradle.build_cache.sweep_interval %q: %w", g.SweepInterval, err) + } + if d <= 0 { + return fmt.Errorf("invalid gradle.build_cache.sweep_interval %q: must be > 0", g.SweepInterval) + } + } + + return nil +} + +const ( + defaultMetadataTTL = 5 * time.Minute //nolint:mnd // sensible default + defaultDirectServeTTL = 15 * time.Minute //nolint:mnd // sensible default + defaultMetadataMaxSize = 100 << 20 + defaultGradleBuildCacheMaxUploadSize = 100 << 20 + defaultGradleBuildCacheSweepInterval = 10 * time.Minute + defaultGradleMaxUploadSizeStr = "100MB" + defaultGradleSweepIntervalStr = "10m" +) // ParseMaxSize returns the maximum cache size in bytes. // Returns 0 if unset or explicitly disabled (meaning unlimited). @@ -404,6 +614,33 @@ func (c *Config) ParseMaxSize() int64 { return size } +func validateMetadataMaxSize(s string) error { + if s == "" { + return nil + } + size, err := ParseSize(s) + if err != nil { + return fmt.Errorf("invalid metadata_max_size: %w", err) + } + if size <= 0 { + return fmt.Errorf("invalid metadata_max_size %q: must be positive", s) + } + return nil +} + +// ParseMetadataMaxSize returns the maximum metadata response size in bytes. +// Returns 100MB if unset or invalid. +func (c *Config) ParseMetadataMaxSize() int64 { + if c.MetadataMaxSize == "" { + return defaultMetadataMaxSize + } + size, err := ParseSize(c.MetadataMaxSize) + if err != nil || size <= 0 { + return defaultMetadataMaxSize + } + return size +} + // ParseMetadataTTL returns the metadata TTL duration. // Returns 5 minutes if unset, 0 if explicitly disabled. func (c *Config) ParseMetadataTTL() time.Duration { @@ -420,6 +657,71 @@ func (c *Config) ParseMetadataTTL() time.Duration { return d } +// ParseGradleBuildCacheMaxUploadSize returns the max accepted PUT body size. +// Defaults to 100MB if unset or invalid. +func (c *Config) ParseGradleBuildCacheMaxUploadSize() int64 { + if c.Gradle.BuildCache.MaxUploadSize == "" { + return defaultGradleBuildCacheMaxUploadSize + } + size, err := ParseSize(c.Gradle.BuildCache.MaxUploadSize) + if err != nil || size <= 0 { + return defaultGradleBuildCacheMaxUploadSize + } + return size +} + +// ParseGradleBuildCacheMaxAge returns age-based eviction threshold. +// Returns 0 when disabled or invalid. +func (c *Config) ParseGradleBuildCacheMaxAge() time.Duration { + if c.Gradle.BuildCache.MaxAge == "" || c.Gradle.BuildCache.MaxAge == "0" { + return 0 + } + d, err := time.ParseDuration(c.Gradle.BuildCache.MaxAge) + if err != nil || d <= 0 { + return 0 + } + return d +} + +// ParseGradleBuildCacheMaxSize returns total-size cap in bytes. +// Returns 0 when disabled or invalid. +func (c *Config) ParseGradleBuildCacheMaxSize() int64 { + if c.Gradle.BuildCache.MaxSize == "" || c.Gradle.BuildCache.MaxSize == "0" { + return 0 + } + size, err := ParseSize(c.Gradle.BuildCache.MaxSize) + if err != nil || size <= 0 { + return 0 + } + return size +} + +// ParseGradleBuildCacheSweepInterval returns eviction sweep cadence. +// Defaults to 10m if unset or invalid. +func (c *Config) ParseGradleBuildCacheSweepInterval() time.Duration { + if c.Gradle.BuildCache.SweepInterval == "" { + return defaultGradleBuildCacheSweepInterval + } + d, err := time.ParseDuration(c.Gradle.BuildCache.SweepInterval) + if err != nil || d <= 0 { + return defaultGradleBuildCacheSweepInterval + } + return d +} + +// ParseDirectServeTTL returns the presigned URL expiry duration. +// Returns 15 minutes if unset. +func (c *Config) ParseDirectServeTTL() time.Duration { + if c.Storage.DirectServeTTL == "" { + return defaultDirectServeTTL + } + d, err := time.ParseDuration(c.Storage.DirectServeTTL) + if err != nil { + return defaultDirectServeTTL + } + return d +} + // ParseSize parses a human-readable size string (e.g., "10GB", "500MB"). // Returns the size in bytes. func ParseSize(s string) (int64, error) { @@ -500,3 +802,7 @@ func (a *AuthConfig) Header() (name, value string) { func expandEnv(s string) string { return os.Expand(s, os.Getenv) } + +func envBool(v string) bool { + return v == "true" || v == "1" +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 25f3488..d633c25 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -25,6 +25,18 @@ func TestDefault(t *testing.T) { if cfg.Database.Path == "" { t.Error("Database.Path should not be empty") } + if cfg.Gradle.BuildCache.MaxUploadSize != "100MB" { + t.Errorf("Gradle.BuildCache.MaxUploadSize = %q, want %q", cfg.Gradle.BuildCache.MaxUploadSize, "100MB") + } + if cfg.Gradle.BuildCache.MaxAge != "168h" { + t.Errorf("Gradle.BuildCache.MaxAge = %q, want %q", cfg.Gradle.BuildCache.MaxAge, "168h") + } + if cfg.Upstream.Maven != "https://repo1.maven.org/maven2" { + t.Errorf("Upstream.Maven = %q, want %q", cfg.Upstream.Maven, "https://repo1.maven.org/maven2") + } + if cfg.Upstream.GradlePluginPortal != "https://plugins.gradle.org/m2" { + t.Errorf("Upstream.GradlePluginPortal = %q, want %q", cfg.Upstream.GradlePluginPortal, "https://plugins.gradle.org/m2") + } } func TestValidate(t *testing.T) { @@ -98,6 +110,41 @@ func TestValidate(t *testing.T) { modify: func(c *Config) { c.Storage.MaxSize = "10GB" }, wantErr: false, }, + { + name: "invalid gradle upload size", + modify: func(c *Config) { c.Gradle.BuildCache.MaxUploadSize = testInvalid }, + wantErr: true, + }, + { + name: "zero gradle upload size", + modify: func(c *Config) { c.Gradle.BuildCache.MaxUploadSize = "0" }, + wantErr: true, + }, + { + name: "invalid gradle max age", + modify: func(c *Config) { c.Gradle.BuildCache.MaxAge = testInvalid }, + wantErr: true, + }, + { + name: "valid gradle max age", + modify: func(c *Config) { c.Gradle.BuildCache.MaxAge = "24h" }, + wantErr: false, + }, + { + name: "invalid gradle max size", + modify: func(c *Config) { c.Gradle.BuildCache.MaxSize = testInvalid }, + wantErr: true, + }, + { + name: "invalid gradle sweep interval", + modify: func(c *Config) { c.Gradle.BuildCache.SweepInterval = "0" }, + wantErr: true, + }, + { + name: "valid gradle sweep interval", + modify: func(c *Config) { c.Gradle.BuildCache.SweepInterval = "30m" }, + wantErr: false, + }, } for _, tt := range tests { @@ -223,6 +270,13 @@ func TestLoadFromEnv(t *testing.T) { t.Setenv("PROXY_BASE_URL", "https://env.example.com") t.Setenv("PROXY_STORAGE_PATH", "/env/cache") t.Setenv("PROXY_LOG_LEVEL", testLevelDebug) + t.Setenv("PROXY_UPSTREAM_MAVEN", "https://maven.example.com/repository/maven-public") + t.Setenv("PROXY_UPSTREAM_GRADLE_PLUGIN_PORTAL", "https://plugins.example.com/m2") + t.Setenv("PROXY_GRADLE_BUILD_CACHE_READ_ONLY", "true") + t.Setenv("PROXY_GRADLE_BUILD_CACHE_MAX_UPLOAD_SIZE", "32MB") + t.Setenv("PROXY_GRADLE_BUILD_CACHE_MAX_AGE", "12h") + t.Setenv("PROXY_GRADLE_BUILD_CACHE_MAX_SIZE", "10GB") + t.Setenv("PROXY_GRADLE_BUILD_CACHE_SWEEP_INTERVAL", "15m") cfg.LoadFromEnv() @@ -238,6 +292,27 @@ func TestLoadFromEnv(t *testing.T) { if cfg.Log.Level != testLevelDebug { t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, testLevelDebug) } + if cfg.Upstream.Maven != "https://maven.example.com/repository/maven-public" { + t.Errorf("Upstream.Maven = %q, want %q", cfg.Upstream.Maven, "https://maven.example.com/repository/maven-public") + } + if cfg.Upstream.GradlePluginPortal != "https://plugins.example.com/m2" { + t.Errorf("Upstream.GradlePluginPortal = %q, want %q", cfg.Upstream.GradlePluginPortal, "https://plugins.example.com/m2") + } + if !cfg.Gradle.BuildCache.ReadOnly { + t.Error("Gradle.BuildCache.ReadOnly = false, want true") + } + if cfg.Gradle.BuildCache.MaxUploadSize != "32MB" { + t.Errorf("Gradle.BuildCache.MaxUploadSize = %q, want %q", cfg.Gradle.BuildCache.MaxUploadSize, "32MB") + } + if cfg.Gradle.BuildCache.MaxAge != "12h" { + t.Errorf("Gradle.BuildCache.MaxAge = %q, want %q", cfg.Gradle.BuildCache.MaxAge, "12h") + } + if cfg.Gradle.BuildCache.MaxSize != "10GB" { + t.Errorf("Gradle.BuildCache.MaxSize = %q, want %q", cfg.Gradle.BuildCache.MaxSize, "10GB") + } + if cfg.Gradle.BuildCache.SweepInterval != "15m" { + t.Errorf("Gradle.BuildCache.SweepInterval = %q, want %q", cfg.Gradle.BuildCache.SweepInterval, "15m") + } } func TestLoadCooldownConfig(t *testing.T) { @@ -353,6 +428,52 @@ func TestParseMetadataTTL(t *testing.T) { } } +func TestParseMetadataMaxSize(t *testing.T) { + tests := []struct { + name string + size string + want int64 + }{ + {"unset uses default", "", defaultMetadataMaxSize}, + {"explicit value", "250MB", 250 << 20}, + {"bytes", "1024", 1024}, + {"invalid uses default", "lots", defaultMetadataMaxSize}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := Default() + cfg.MetadataMaxSize = tt.size + got := cfg.ParseMetadataMaxSize() + if got != tt.want { + t.Errorf("ParseMetadataMaxSize() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestValidateMetadataMaxSize(t *testing.T) { + cfg := Default() + cfg.MetadataMaxSize = "not-a-size" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for invalid metadata_max_size") + } + + cfg.MetadataMaxSize = "0" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for zero metadata_max_size") + } + + cfg.MetadataMaxSize = "250MB" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for valid metadata_max_size: %v", err) + } + + cfg.MetadataMaxSize = "" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for unset metadata_max_size: %v", err) + } +} + func TestValidateMetadataTTL(t *testing.T) { cfg := Default() cfg.MetadataTTL = "invalid" @@ -371,6 +492,34 @@ func TestValidateMetadataTTL(t *testing.T) { } } +func TestValidateHealthStorageProbeInterval(t *testing.T) { + cfg := Default() + cfg.Health.StorageProbeInterval = "not-a-duration" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for invalid health.storage_probe_interval") + } + + cfg.Health.StorageProbeInterval = "30s" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for valid health.storage_probe_interval: %v", err) + } + + cfg.Health.StorageProbeInterval = "0" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for zero health.storage_probe_interval: %v", err) + } + + cfg.Health.StorageProbeInterval = "" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for empty health.storage_probe_interval: %v", err) + } + + cfg.Health.StorageProbeInterval = "-5s" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for negative health.storage_probe_interval") + } +} + func TestLoadMetadataTTLFromEnv(t *testing.T) { cfg := Default() t.Setenv("PROXY_METADATA_TTL", "10m") @@ -380,3 +529,112 @@ func TestLoadMetadataTTLFromEnv(t *testing.T) { t.Errorf("MetadataTTL = %q, want %q", cfg.MetadataTTL, "10m") } } + +func TestParseGradleBuildCacheConfig(t *testing.T) { + cfg := Default() + + if got := cfg.ParseGradleBuildCacheMaxUploadSize(); got != 100*1024*1024 { + t.Errorf("ParseGradleBuildCacheMaxUploadSize() = %d, want %d", got, 100*1024*1024) + } + if got := cfg.ParseGradleBuildCacheMaxAge(); got != 168*time.Hour { + t.Errorf("ParseGradleBuildCacheMaxAge() = %v, want %v", got, 168*time.Hour) + } + if got := cfg.ParseGradleBuildCacheMaxSize(); got != 0 { + t.Errorf("ParseGradleBuildCacheMaxSize() = %d, want 0", got) + } + if got := cfg.ParseGradleBuildCacheSweepInterval(); got != 10*time.Minute { + t.Errorf("ParseGradleBuildCacheSweepInterval() = %v, want %v", got, 10*time.Minute) + } + + cfg.Gradle.BuildCache.MaxUploadSize = "64MB" + cfg.Gradle.BuildCache.MaxAge = "48h" + cfg.Gradle.BuildCache.MaxSize = "2GB" + cfg.Gradle.BuildCache.SweepInterval = "20m" + + if got := cfg.ParseGradleBuildCacheMaxUploadSize(); got != 64*1024*1024 { + t.Errorf("ParseGradleBuildCacheMaxUploadSize() = %d, want %d", got, 64*1024*1024) + } + if got := cfg.ParseGradleBuildCacheMaxAge(); got != 48*time.Hour { + t.Errorf("ParseGradleBuildCacheMaxAge() = %v, want %v", got, 48*time.Hour) + } + if got := cfg.ParseGradleBuildCacheMaxSize(); got != 2*1024*1024*1024 { + t.Errorf("ParseGradleBuildCacheMaxSize() = %d, want %d", got, 2*1024*1024*1024) + } + if got := cfg.ParseGradleBuildCacheSweepInterval(); got != 20*time.Minute { + t.Errorf("ParseGradleBuildCacheSweepInterval() = %v, want %v", got, 20*time.Minute) + } +} + +func TestParseDirectServeTTL(t *testing.T) { + tests := []struct { + name string + ttl string + want time.Duration + }{ + {"empty defaults to 15m", "", 15 * time.Minute}, + {"5 minutes", "5m", 5 * time.Minute}, + {"1 hour", "1h", 1 * time.Hour}, + {"invalid defaults to 15m", "not-a-duration", 15 * time.Minute}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := Default() + cfg.Storage.DirectServeTTL = tt.ttl + got := cfg.ParseDirectServeTTL() + if got != tt.want { + t.Errorf("ParseDirectServeTTL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestValidateDirectServeTTL(t *testing.T) { + cfg := Default() + cfg.Storage.DirectServeTTL = "invalid" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for invalid storage.direct_serve_ttl") + } + + cfg.Storage.DirectServeTTL = "5m" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for valid storage.direct_serve_ttl: %v", err) + } +} + +func TestLoadDirectServeFromEnv(t *testing.T) { + cfg := Default() + t.Setenv("PROXY_STORAGE_DIRECT_SERVE", "true") + t.Setenv("PROXY_STORAGE_DIRECT_SERVE_TTL", "30m") + t.Setenv("PROXY_STORAGE_DIRECT_SERVE_BASE_URL", "https://cdn.example.com") + cfg.LoadFromEnv() + + if !cfg.Storage.DirectServe { + t.Error("Storage.DirectServe should be true") + } + if cfg.Storage.DirectServeTTL != "30m" { + t.Errorf("Storage.DirectServeTTL = %q, want %q", cfg.Storage.DirectServeTTL, "30m") + } + if cfg.Storage.DirectServeBaseURL != "https://cdn.example.com" { + t.Errorf("Storage.DirectServeBaseURL = %q, want %q", cfg.Storage.DirectServeBaseURL, "https://cdn.example.com") + } +} + +func TestValidateDirectServeBaseURL(t *testing.T) { + cfg := Default() + + cfg.Storage.DirectServeBaseURL = "not a url" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for relative direct_serve_base_url") + } + + cfg.Storage.DirectServeBaseURL = "://bad" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for unparseable direct_serve_base_url") + } + + cfg.Storage.DirectServeBaseURL = "https://cdn.example.com" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for valid direct_serve_base_url: %v", err) + } +} diff --git a/internal/cooldown/cooldown.go b/internal/cooldown/cooldown.go deleted file mode 100644 index f37a2b9..0000000 --- a/internal/cooldown/cooldown.go +++ /dev/null @@ -1,125 +0,0 @@ -package cooldown - -import ( - "fmt" - "strconv" - "strings" - "time" -) - -const hoursPerDay = 24 - -// Config holds cooldown settings for version filtering. -// Cooldown hides package versions published too recently, giving the community -// time to spot malicious releases before they're pulled into projects. -type Config struct { - // Default is the global default cooldown duration (e.g., "3d", "48h"). - Default string `json:"default" yaml:"default"` - - // Ecosystems overrides the default for specific ecosystems. - // Keys are ecosystem names (e.g., "npm", "pypi"). - Ecosystems map[string]string `json:"ecosystems" yaml:"ecosystems"` - - // Packages overrides the cooldown for specific packages. - // Keys are PURLs (e.g., "pkg:npm/lodash", "pkg:npm/@babel/core"). - Packages map[string]string `json:"packages" yaml:"packages"` - - defaultDuration time.Duration - ecosystemDurations map[string]time.Duration - packageDurations map[string]time.Duration - parsed bool -} - -// parse resolves all string durations into time.Duration values. -// Called lazily on first use. -func (c *Config) parse() { - if c.parsed { - return - } - c.parsed = true - - c.defaultDuration, _ = ParseDuration(c.Default) - - c.ecosystemDurations = make(map[string]time.Duration, len(c.Ecosystems)) - for k, v := range c.Ecosystems { - d, _ := ParseDuration(v) - c.ecosystemDurations[k] = d - } - - c.packageDurations = make(map[string]time.Duration, len(c.Packages)) - for k, v := range c.Packages { - d, _ := ParseDuration(v) - c.packageDurations[k] = d - } -} - -// For returns the effective cooldown duration for a given ecosystem and package PURL. -// Resolution order: package override > ecosystem override > global default. -func (c *Config) For(ecosystem, packagePURL string) time.Duration { - c.parse() - - if d, ok := c.packageDurations[packagePURL]; ok { - return d - } - if d, ok := c.ecosystemDurations[ecosystem]; ok { - return d - } - return c.defaultDuration -} - -// IsAllowed returns true if a version with the given publish time has passed -// the cooldown period for this ecosystem/package. -func (c *Config) IsAllowed(ecosystem, packagePURL string, publishedAt time.Time) bool { - d := c.For(ecosystem, packagePURL) - if d == 0 { - return true - } - if publishedAt.IsZero() { - return true - } - return time.Since(publishedAt) >= d -} - -// Enabled returns true if any cooldown is configured. -func (c *Config) Enabled() bool { - c.parse() - if c.defaultDuration > 0 { - return true - } - for _, d := range c.ecosystemDurations { - if d > 0 { - return true - } - } - for _, d := range c.packageDurations { - if d > 0 { - return true - } - } - return false -} - -// ParseDuration parses a duration string supporting days (e.g., "3d"), -// in addition to Go's standard time.ParseDuration formats ("48h", "30m"). -// "0" means disabled (returns 0). -func ParseDuration(s string) (time.Duration, error) { - s = strings.TrimSpace(s) - if s == "" || s == "0" { - return 0, nil - } - - // Handle day suffix - if numStr, ok := strings.CutSuffix(s, "d"); ok { - days, err := strconv.ParseFloat(numStr, 64) - if err != nil { - return 0, fmt.Errorf("invalid duration %q: %w", s, err) - } - return time.Duration(days * float64(hoursPerDay*time.Hour)), nil - } - - d, err := time.ParseDuration(s) - if err != nil { - return 0, fmt.Errorf("invalid duration %q: %w", s, err) - } - return d, nil -} diff --git a/internal/cooldown/cooldown_test.go b/internal/cooldown/cooldown_test.go deleted file mode 100644 index c366077..0000000 --- a/internal/cooldown/cooldown_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package cooldown - -import ( - "testing" - "time" -) - -func TestParseDuration(t *testing.T) { - tests := []struct { - input string - want time.Duration - wantErr bool - }{ - {"", 0, false}, - {"0", 0, false}, - {"3d", 3 * 24 * time.Hour, false}, - {"7d", 7 * 24 * time.Hour, false}, - {"14d", 14 * 24 * time.Hour, false}, - {"1.5d", 36 * time.Hour, false}, - {"48h", 48 * time.Hour, false}, - {"30m", 30 * time.Minute, false}, - {"1h30m", 90 * time.Minute, false}, - {"invalid", 0, true}, - {"d", 0, true}, - {"xd", 0, true}, - } - - for _, tt := range tests { - got, err := ParseDuration(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("ParseDuration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) - continue - } - if got != tt.want { - t.Errorf("ParseDuration(%q) = %v, want %v", tt.input, got, tt.want) - } - } -} - -func TestConfigFor(t *testing.T) { - c := &Config{ - Default: "3d", - Ecosystems: map[string]string{ - "npm": "7d", - "cargo": "0", - }, - Packages: map[string]string{ - "pkg:npm/lodash": "0", - "pkg:npm/@babel/core": "14d", - }, - } - - tests := []struct { - ecosystem string - packagePURL string - want time.Duration - }{ - // Package override takes priority - {"npm", "pkg:npm/lodash", 0}, - {"npm", "pkg:npm/@babel/core", 14 * 24 * time.Hour}, - // Ecosystem override - {"npm", "pkg:npm/express", 7 * 24 * time.Hour}, - {"cargo", "pkg:cargo/serde", 0}, - // Global default - {"pypi", "pkg:pypi/requests", 3 * 24 * time.Hour}, - {"pub", "pkg:pub/flutter", 3 * 24 * time.Hour}, - } - - for _, tt := range tests { - got := c.For(tt.ecosystem, tt.packagePURL) - if got != tt.want { - t.Errorf("For(%q, %q) = %v, want %v", tt.ecosystem, tt.packagePURL, got, tt.want) - } - } -} - -func TestConfigIsAllowed(t *testing.T) { - c := &Config{ - Default: "3d", - Packages: map[string]string{ - "pkg:npm/lodash": "0", - }, - } - - now := time.Now() - - tests := []struct { - name string - ecosystem string - packagePURL string - publishedAt time.Time - want bool - }{ - {"old enough", "npm", "pkg:npm/express", now.Add(-4 * 24 * time.Hour), true}, - {"too recent", "npm", "pkg:npm/express", now.Add(-1 * 24 * time.Hour), false}, - {"exactly at boundary", "npm", "pkg:npm/express", now.Add(-3 * 24 * time.Hour), true}, - {"exempt package", "npm", "pkg:npm/lodash", now.Add(-1 * time.Minute), true}, - {"zero time", "npm", "pkg:npm/express", time.Time{}, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := c.IsAllowed(tt.ecosystem, tt.packagePURL, tt.publishedAt) - if got != tt.want { - t.Errorf("IsAllowed(%q, %q, %v) = %v, want %v", - tt.ecosystem, tt.packagePURL, tt.publishedAt, got, tt.want) - } - }) - } -} - -func TestConfigEnabled(t *testing.T) { - tests := []struct { - name string - cfg Config - want bool - }{ - {"empty config", Config{}, false}, - {"default only", Config{Default: "3d"}, true}, - {"ecosystem only", Config{Ecosystems: map[string]string{"npm": "7d"}}, true}, - {"package only", Config{Packages: map[string]string{"pkg:npm/x": "1d"}}, true}, - {"all zero", Config{Default: "0", Ecosystems: map[string]string{"npm": "0"}}, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.cfg.Enabled() - if got != tt.want { - t.Errorf("Enabled() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/database/schema.go b/internal/database/schema.go index e6f284f..c8d8d1e 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -6,7 +6,11 @@ import ( "time" ) -const postgresTimestamp = "TIMESTAMP" +const ( + postgresTimestamp = "TIMESTAMP" + sqliteDatetime = "DATETIME" + colTypeText = "TEXT" +) // Schema for proxy-specific tables. The packages and versions tables // are compatible with git-pkgs, allowing the proxy to use an existing @@ -369,9 +373,9 @@ func isTableNotFound(err error) bool { func (db *DB) createMigrationsTable() error { var ts string if db.dialect == DialectPostgres { - ts = "TIMESTAMP" + ts = postgresTimestamp } else { - ts = "DATETIME" + ts = sqliteDatetime } query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS migrations ( @@ -457,12 +461,12 @@ func (db *DB) MigrateSchema() error { func migrateAddPackagesEnrichmentColumns(db *DB) error { columns := map[string]string{ - "registry_url": "TEXT", - "supplier_name": "TEXT", - "supplier_type": "TEXT", - "source": "TEXT", - "enriched_at": "DATETIME", - "vulns_synced_at": "DATETIME", + "registry_url": colTypeText, + "supplier_name": colTypeText, + "supplier_type": colTypeText, + "source": colTypeText, + "enriched_at": sqliteDatetime, + "vulns_synced_at": sqliteDatetime, } if db.dialect == DialectPostgres { @@ -487,10 +491,10 @@ func migrateAddPackagesEnrichmentColumns(db *DB) error { func migrateAddVersionsEnrichmentColumns(db *DB) error { columns := map[string]string{ - "integrity": "TEXT", + "integrity": colTypeText, "yanked": "INTEGER DEFAULT 0", - "source": "TEXT", - "enriched_at": "DATETIME", + "source": colTypeText, + "enriched_at": sqliteDatetime, } if db.dialect == DialectPostgres { diff --git a/internal/handler/cargo_test.go b/internal/handler/cargo_test.go index 5ce81b6..10d3faf 100644 --- a/internal/handler/cargo_test.go +++ b/internal/handler/cargo_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" ) func cargoTestProxy() *Proxy { diff --git a/internal/handler/composer.go b/internal/handler/composer.go index 0933ece..065ddf9 100644 --- a/internal/handler/composer.go +++ b/internal/handler/composer.go @@ -16,6 +16,7 @@ import ( const ( composerUpstream = "https://packagist.org" composerRepo = "https://repo.packagist.org" + composerUnset = "__unset" vendorPackageParts = 2 ) @@ -150,7 +151,8 @@ func (h *ComposerHandler) rewriteMetadata(body []byte) ([]byte, error) { // expandMinifiedVersions expands the Composer v2 minified format where each // version entry only contains fields that differ from the previous entry. -// The "~dev" sentinel string resets the inheritance chain. +// The "~dev" sentinel string resets the inheritance chain, and the "__unset" +// value removes a field from the inherited state. func expandMinifiedVersions(versionList []any) []any { expanded := make([]any, 0, len(versionList)) inherited := map[string]any{} @@ -174,6 +176,10 @@ func expandMinifiedVersions(versionList []any) []any { merged[k] = deepCopyValue(val) } for k, val := range vmap { + if val == composerUnset { + delete(merged, k) + continue + } merged[k] = val } diff --git a/internal/handler/composer_test.go b/internal/handler/composer_test.go index 94ff8cb..baf13b6 100644 --- a/internal/handler/composer_test.go +++ b/internal/handler/composer_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" ) func TestComposerRewriteMetadata(t *testing.T) { @@ -177,6 +177,80 @@ func TestComposerRewriteMetadataMinifiedDevReset(t *testing.T) { } } +func TestComposerRewriteMetadataUnset(t *testing.T) { + h := &ComposerHandler{ + proxy: &Proxy{Logger: slog.Default()}, + proxyURL: "http://localhost:8080", + } + + // In the minified format, "__unset" removes a field from the inherited + // state. v1.29.0 has require-dev, v1.28.0 unsets it, v1.27.0 inherits the + // unset state. Composer rejects metadata where require-dev (or any link + // field) is the literal string "__unset" rather than an object. + input := `{ + "minified": "composer/2.0", + "packages": { + "venturecraft/revisionable": [ + { + "name": "venturecraft/revisionable", + "version": "1.29.0", + "require": {"php": ">=5.4"}, + "require-dev": {"orchestra/testbench": "~3.0"}, + "dist": {"url": "https://example.com/a.zip", "type": "zip"} + }, + { + "version": "1.28.0", + "require-dev": "__unset" + }, + { + "version": "1.27.0" + }, + { + "version": "1.26.0", + "require-dev": {"foo/bar": "1.0"} + } + ] + } + }` + + output, err := h.rewriteMetadata([]byte(input)) + if err != nil { + t.Fatalf("rewriteMetadata failed: %v", err) + } + + var result map[string]any + if err := json.Unmarshal(output, &result); err != nil { + t.Fatalf("failed to parse output: %v", err) + } + + versions := result["packages"].(map[string]any)["venturecraft/revisionable"].([]any) + if len(versions) != 4 { + t.Fatalf("expected 4 versions, got %d", len(versions)) + } + + byVersion := map[string]map[string]any{} + for _, v := range versions { + vmap := v.(map[string]any) + byVersion[vmap["version"].(string)] = vmap + } + + if _, ok := byVersion["1.29.0"]["require-dev"].(map[string]any); !ok { + t.Errorf("1.29.0 require-dev should be an object, got %T", byVersion["1.29.0"]["require-dev"]) + } + if rd, ok := byVersion["1.28.0"]["require-dev"]; ok { + t.Errorf("1.28.0 require-dev should be absent, got %v", rd) + } + if rd, ok := byVersion["1.27.0"]["require-dev"]; ok { + t.Errorf("1.27.0 require-dev should be absent (inherited unset), got %v", rd) + } + if _, ok := byVersion["1.26.0"]["require-dev"].(map[string]any); !ok { + t.Errorf("1.26.0 require-dev should be an object, got %T", byVersion["1.26.0"]["require-dev"]) + } + if _, ok := byVersion["1.27.0"]["require"].(map[string]any); !ok { + t.Error("1.27.0 should still inherit require from 1.29.0") + } +} + func TestComposerRewriteMetadataCooldownPreservesNames(t *testing.T) { now := time.Now() old := now.Add(-10 * 24 * time.Hour).Format(time.RFC3339) diff --git a/internal/handler/conda.go b/internal/handler/conda.go index a986f01..cfa20c8 100644 --- a/internal/handler/conda.go +++ b/internal/handler/conda.go @@ -140,7 +140,7 @@ func (h *CondaHandler) handleRepodata(w http.ResponseWriter, r *http.Request) { http.Error(w, "failed to create request", http.StatusInternalServerError) return } - req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set(headerAcceptEncoding, "gzip") resp, err := h.proxy.HTTPClient.Do(req) if err != nil { @@ -161,7 +161,7 @@ func (h *CondaHandler) handleRepodata(w http.ResponseWriter, r *http.Request) { return } - body, err := ReadMetadata(resp.Body) + body, err := h.proxy.ReadMetadata(resp.Body) if err != nil { http.Error(w, "failed to read response", http.StatusInternalServerError) return @@ -241,5 +241,5 @@ func (h *CondaHandler) proxyCached(w http.ResponseWriter, r *http.Request) { // proxyUpstream forwards a request to Anaconda without caching. func (h *CondaHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { - h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{headerAcceptEncoding}) } diff --git a/internal/handler/conda_test.go b/internal/handler/conda_test.go index 24b0236..1b57039 100644 --- a/internal/handler/conda_test.go +++ b/internal/handler/conda_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" ) func TestCondaParseFilename(t *testing.T) { diff --git a/internal/handler/cran.go b/internal/handler/cran.go index 246fcaa..0ecd2a3 100644 --- a/internal/handler/cran.go +++ b/internal/handler/cran.go @@ -159,5 +159,5 @@ func (h *CRANHandler) proxyCached(w http.ResponseWriter, r *http.Request) { // proxyUpstream forwards a request to CRAN without caching. func (h *CRANHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { - h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{headerAcceptEncoding}) } diff --git a/internal/handler/download_test.go b/internal/handler/download_test.go index 639e976..980e234 100644 --- a/internal/handler/download_test.go +++ b/internal/handler/download_test.go @@ -673,7 +673,7 @@ func TestMavenHandler_DownloadCacheHit(t *testing.T) { proxy, db, store, _ := setupTestProxy(t) seedPackageWithPURL(t, db, store, "maven", "com.google.guava:guava", "32.1.3-jre", "guava-32.1.3-jre.jar", "jar content") - h := NewMavenHandler(proxy, "http://localhost") + h := NewMavenHandler(proxy, "http://localhost", "", "") srv := httptest.NewServer(h.Routes()) defer srv.Close() @@ -730,7 +730,7 @@ func TestMavenHandler_MetadataProxied(t *testing.T) { func TestMavenHandler_EmptyPathNotFound(t *testing.T) { proxy, _, _, _ := setupTestProxy(t) - h := NewMavenHandler(proxy, "http://localhost") + h := NewMavenHandler(proxy, "http://localhost", "", "") srv := httptest.NewServer(h.Routes()) defer srv.Close() @@ -748,7 +748,7 @@ func TestMavenHandler_EmptyPathNotFound(t *testing.T) { func TestMavenHandler_ArtifactExtensions(t *testing.T) { proxy, _, _, fetcher := setupTestProxy(t) - extensions := []string{".jar", ".war", ".ear", ".pom", ".aar", ".klib"} + extensions := []string{".jar", ".war", ".ear", ".pom", ".aar", ".klib", ".module"} for _, ext := range extensions { fetcher.artifact = &fetch.Artifact{ Body: io.NopCloser(strings.NewReader("artifact")), @@ -756,7 +756,7 @@ func TestMavenHandler_ArtifactExtensions(t *testing.T) { } fetcher.fetchCalled = false - h := NewMavenHandler(proxy, "http://localhost") + h := NewMavenHandler(proxy, "http://localhost", "", "") upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Errorf("should not proxy artifact file %s to upstream", ext) @@ -789,7 +789,7 @@ func TestMavenHandler_CacheMiss(t *testing.T) { ContentType: "application/java-archive", } - h := NewMavenHandler(proxy, "http://localhost") + h := NewMavenHandler(proxy, "http://localhost", "", "") srv := httptest.NewServer(h.Routes()) defer srv.Close() @@ -809,6 +809,274 @@ func TestMavenHandler_CacheMiss(t *testing.T) { } } +func TestMavenHandler_GradlePluginMarkerFallbackAndCache(t *testing.T) { + tests := []struct { + name string + markerPath string + }{ + { + name: "Spotless", + markerPath: "/com/diffplug/spotless/com.diffplug.spotless.gradle.plugin/8.4.0/com.diffplug.spotless.gradle.plugin-8.4.0.pom", + }, + { + name: "BenManes", + markerPath: "/com/github/ben-manes/versions/com.github.ben-manes.versions.gradle.plugin/0.54.0/com.github.ben-manes.versions.gradle.plugin-0.54.0.pom", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + + primaryUpstream := "https://repo1.maven.org/maven2" + pluginPortalUpstream := "https://plugins.gradle.org/m2" + primaryURL := primaryUpstream + tt.markerPath + + fetcher.fetchErrByURL = map[string]error{ + primaryURL: ErrUpstreamNotFound, + } + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("")), + ContentType: "application/xml", + } + + h := NewMavenHandler(proxy, "http://localhost", primaryUpstream, pluginPortalUpstream) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + tt.markerPath) + if err != nil { + t.Fatalf("request failed: %v", err) + } + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if string(body) != "" { + t.Fatalf("body = %q, want %q", body, "") + } + + wantFallbackURL := pluginPortalUpstream + tt.markerPath + if fetcher.fetchedURL != wantFallbackURL { + t.Fatalf("fallback URL = %q, want %q", fetcher.fetchedURL, wantFallbackURL) + } + + fetcher.fetchCalled = false + resp, err = http.Get(srv.URL + tt.markerPath) + if err != nil { + t.Fatalf("second request failed: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("second status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if fetcher.fetchCalled { + t.Fatal("expected plugin marker POM to be served from cache on second request") + } + }) + } +} + +func TestMavenHandler_GradlePluginMarkerMetadataFallback(t *testing.T) { + paths := map[string]string{ + "/com/diffplug/spotless/com.diffplug.spotless.gradle.plugin/8.4.0/com.diffplug.spotless.gradle.plugin-8.4.0.pom.sha1": "sha1", + "/com/diffplug/spotless/com.diffplug.spotless.gradle.plugin/8.4.0/com.diffplug.spotless.gradle.plugin-8.4.0.pom.sha256": "sha256", + "/com/diffplug/spotless/com.diffplug.spotless.gradle.plugin/8.4.0/com.diffplug.spotless.gradle.plugin-8.4.0.pom.md5": "md5", + "/com/diffplug/spotless/com.diffplug.spotless.gradle.plugin/maven-metadata.xml": "", + } + + primaryHits := map[string]int{} + pluginHits := map[string]int{} + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + primaryHits[r.URL.Path]++ + if _, ok := paths[r.URL.Path]; ok { + http.NotFound(w, r) + return + } + t.Fatalf("unexpected path to primary upstream: %s", r.URL.Path) + })) + defer primary.Close() + + pluginPortal := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pluginHits[r.URL.Path]++ + body, ok := paths[r.URL.Path] + if !ok { + http.NotFound(w, r) + return + } + _, _ = io.WriteString(w, body) + })) + defer pluginPortal.Close() + + proxy, _, _, _ := setupTestProxy(t) + proxy.HTTPClient = primary.Client() + + h := NewMavenHandler(proxy, "http://localhost", primary.URL, pluginPortal.URL) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + for reqPath, wantBody := range paths { + resp, err := http.Get(srv.URL + reqPath) + if err != nil { + t.Fatalf("GET %s failed: %v", reqPath, err) + } + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("GET %s: status = %d, want %d", reqPath, resp.StatusCode, http.StatusOK) + } + if string(body) != wantBody { + t.Fatalf("GET %s: body = %q, want %q", reqPath, body, wantBody) + } + + if primaryHits[reqPath] == 0 { + t.Fatalf("GET %s did not hit primary upstream", reqPath) + } + if pluginHits[reqPath] == 0 { + t.Fatalf("GET %s did not hit plugin portal fallback", reqPath) + } + } +} + +func TestMavenHandler_GradlePluginImplementationMetadataFallback(t *testing.T) { + paths := map[string]string{ + "/com/diffplug/spotless/spotless-plugin-gradle/8.4.0/spotless-plugin-gradle-8.4.0.jar.sha1": "impl-sha1", + "/com/diffplug/spotless/spotless-plugin-gradle/8.4.0/spotless-plugin-gradle-8.4.0.jar.sha256": "impl-sha256", + } + + primaryHits := map[string]int{} + pluginHits := map[string]int{} + + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + primaryHits[r.URL.Path]++ + if _, ok := paths[r.URL.Path]; ok { + http.NotFound(w, r) + return + } + t.Fatalf("unexpected path to primary upstream: %s", r.URL.Path) + })) + defer primary.Close() + + pluginPortal := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + pluginHits[r.URL.Path]++ + body, ok := paths[r.URL.Path] + if !ok { + http.NotFound(w, r) + return + } + _, _ = io.WriteString(w, body) + })) + defer pluginPortal.Close() + + proxy, _, _, _ := setupTestProxy(t) + proxy.HTTPClient = primary.Client() + + h := NewMavenHandler(proxy, "http://localhost", primary.URL, pluginPortal.URL) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + for reqPath, wantBody := range paths { + resp, err := http.Get(srv.URL + reqPath) + if err != nil { + t.Fatalf("GET %s failed: %v", reqPath, err) + } + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("GET %s: status = %d, want %d", reqPath, resp.StatusCode, http.StatusOK) + } + if string(body) != wantBody { + t.Fatalf("GET %s: body = %q, want %q", reqPath, body, wantBody) + } + + if primaryHits[reqPath] == 0 { + t.Fatalf("GET %s did not hit primary upstream", reqPath) + } + if pluginHits[reqPath] == 0 { + t.Fatalf("GET %s did not hit plugin portal fallback", reqPath) + } + } +} + +func TestMavenHandler_GradlePluginImplementation_FallbackToPluginPortal(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + + primaryUpstream := "https://repo1.maven.org/maven2" + pluginPortalUpstream := "https://plugins.gradle.org/m2" + implPath := "/com/diffplug/spotless/spotless-plugin-gradle/8.4.0/spotless-plugin-gradle-8.4.0.jar" + primaryURL := primaryUpstream + implPath + pluginPortalURL := pluginPortalUpstream + implPath + + fetcher.fetchErrByURL = map[string]error{ + primaryURL: ErrUpstreamNotFound, + } + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("plugin impl jar")), + ContentType: "application/java-archive", + } + + h := NewMavenHandler(proxy, "http://localhost", primaryUpstream, pluginPortalUpstream) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + implPath) + if err != nil { + t.Fatalf("request failed: %v", err) + } + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if string(body) != "plugin impl jar" { + t.Fatalf("body = %q, want %q", body, "plugin impl jar") + } + + if fetcher.fetchedURL != pluginPortalURL { + t.Fatalf("implementation artifact should fallback to plugin portal; fetched URL = %q, want %q", fetcher.fetchedURL, pluginPortalURL) + } +} + +func TestMavenHandler_GradlePluginImplementation_NotFoundInBothUpstreams(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + + primaryUpstream := "https://repo1.maven.org/maven2" + pluginPortalUpstream := "https://plugins.gradle.org/m2" + implPath := "/com/diffplug/spotless/spotless-plugin-gradle/8.4.0/spotless-plugin-gradle-8.4.0.jar" + primaryURL := primaryUpstream + implPath + pluginPortalURL := pluginPortalUpstream + implPath + + fetcher.fetchErrByURL = map[string]error{ + primaryURL: ErrUpstreamNotFound, + pluginPortalURL: ErrUpstreamNotFound, + } + + h := NewMavenHandler(proxy, "http://localhost", primaryUpstream, pluginPortalUpstream) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + implPath) + if err != nil { + t.Fatalf("request failed: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } + + if fetcher.fetchedURL != pluginPortalURL { + t.Fatalf("expected fallback attempt to plugin portal; fetched URL = %q, want %q", fetcher.fetchedURL, pluginPortalURL) + } +} + func TestNuGetHandler_DownloadCacheMiss(t *testing.T) { proxy, _, _, fetcher := setupTestProxy(t) fetcher.artifact = &fetch.Artifact{ diff --git a/internal/handler/gem.go b/internal/handler/gem.go index bdb4bb9..9ec57e3 100644 --- a/internal/handler/gem.go +++ b/internal/handler/gem.go @@ -182,7 +182,7 @@ func (h *GemHandler) fetchCompactIndex(r *http.Request, name string) (*http.Resp if err != nil { return nil, err } - for _, hdr := range []string{"Accept", "Accept-Encoding", "If-None-Match", "If-Modified-Since"} { + for _, hdr := range []string{"Accept", headerAcceptEncoding, "If-None-Match", "If-Modified-Since"} { if v := r.Header.Get(hdr); v != "" { req.Header.Set(hdr, v) } @@ -311,7 +311,7 @@ func (h *GemHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { } // Copy relevant headers - for _, h := range []string{"Accept", "Accept-Encoding", "If-None-Match", "If-Modified-Since"} { + for _, h := range []string{"Accept", headerAcceptEncoding, "If-None-Match", "If-Modified-Since"} { if v := r.Header.Get(h); v != "" { req.Header.Set(h, v) } diff --git a/internal/handler/gem_test.go b/internal/handler/gem_test.go index 6dce324..7d90946 100644 --- a/internal/handler/gem_test.go +++ b/internal/handler/gem_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" ) func TestGemParseFilename(t *testing.T) { diff --git a/internal/handler/gradle.go b/internal/handler/gradle.go new file mode 100644 index 0000000..3b703be --- /dev/null +++ b/internal/handler/gradle.go @@ -0,0 +1,178 @@ +package handler + +import ( + "errors" + "io" + "net/http" + "regexp" + "strconv" + "strings" + "time" + + "github.com/git-pkgs/proxy/internal/metrics" + "github.com/git-pkgs/proxy/internal/storage" +) + +const ( + gradleBuildCacheContentType = "application/vnd.gradle.build-cache-artifact.v2" + gradleBuildCacheStorageRoot = "_gradle/http-build-cache" + defaultGradleMaxUploadSize = 100 << 20 +) + +var gradleBuildCacheKeyPattern = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9._-]*$`) + +// GradleBuildCacheHandler handles Gradle HttpBuildCache GET/HEAD/PUT requests. +// +// This handler accepts /{key} when mounted under a base URL. +type GradleBuildCacheHandler struct { + proxy *Proxy +} + +// NewGradleBuildCacheHandler creates a Gradle HttpBuildCache handler. +func NewGradleBuildCacheHandler(proxy *Proxy) *GradleBuildCacheHandler { + return &GradleBuildCacheHandler{proxy: proxy} +} + +// Routes returns the HTTP handler for Gradle HttpBuildCache requests. +func (h *GradleBuildCacheHandler) Routes() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet, http.MethodHead, http.MethodPut: + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + key, statusCode := h.parseCacheKey(r.URL.Path) + if statusCode != http.StatusOK { + if statusCode == http.StatusNotFound { + http.NotFound(w, r) + return + } + http.Error(w, "invalid cache key", statusCode) + return + } + + if r.Method == http.MethodPut { + if h.proxy.GradleReadOnly { + http.Error(w, "gradle build cache is read-only", http.StatusMethodNotAllowed) + return + } + h.handlePut(w, r, key) + return + } + + h.handleGetOrHead(w, r, key) + }) +} + +func (h *GradleBuildCacheHandler) parseCacheKey(urlPath string) (string, int) { + keyPath := strings.TrimPrefix(urlPath, "/") + if keyPath == "" { + return "", http.StatusNotFound + } + + if containsPathTraversal(keyPath) { + return "", http.StatusBadRequest + } + + if strings.Contains(keyPath, "/") { + return "", http.StatusNotFound + } + + if !gradleBuildCacheKeyPattern.MatchString(keyPath) { + return "", http.StatusBadRequest + } + + return keyPath, http.StatusOK +} + +func (h *GradleBuildCacheHandler) cacheStoragePath(key string) string { + return gradleBuildCacheStorageRoot + "/" + key +} + +func (h *GradleBuildCacheHandler) handleGetOrHead(w http.ResponseWriter, r *http.Request, key string) { + storagePath := h.cacheStoragePath(key) + w.Header().Set("Content-Type", gradleBuildCacheContentType) + + if r.Method == http.MethodHead { + existsStart := time.Now() + exists, err := h.proxy.Storage.Exists(r.Context(), storagePath) + metrics.RecordStorageOperation("read", time.Since(existsStart)) + if err != nil { + metrics.RecordStorageError("read") + h.proxy.Logger.Error("failed to check gradle build cache entry", "key", key, "error", err) + http.Error(w, "failed to read cache entry", http.StatusInternalServerError) + return + } + if !exists { + metrics.RecordCacheMiss("gradle") + http.NotFound(w, r) + return + } + metrics.RecordCacheHit("gradle") + + sizeStart := time.Now() + size, err := h.proxy.Storage.Size(r.Context(), storagePath) + metrics.RecordStorageOperation("read", time.Since(sizeStart)) + if err != nil { + metrics.RecordStorageError("read") + } else if size >= 0 { + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + } + + w.WriteHeader(http.StatusOK) + return + } + + readStart := time.Now() + reader, err := h.proxy.Storage.Open(r.Context(), storagePath) + metrics.RecordStorageOperation("read", time.Since(readStart)) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + metrics.RecordCacheMiss("gradle") + http.NotFound(w, r) + return + } + metrics.RecordStorageError("read") + h.proxy.Logger.Error("failed to open gradle build cache entry", "key", key, "error", err) + http.Error(w, "failed to read cache entry", http.StatusInternalServerError) + return + } + defer func() { _ = reader.Close() }() + metrics.RecordCacheHit("gradle") + + w.WriteHeader(http.StatusOK) + _, _ = io.Copy(w, reader) +} + +func (h *GradleBuildCacheHandler) handlePut(w http.ResponseWriter, r *http.Request, key string) { + storagePath := h.cacheStoragePath(key) + maxUploadSize := h.proxy.GradleMaxUploadSize + if maxUploadSize <= 0 { + maxUploadSize = defaultGradleMaxUploadSize + } + + r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize) + + storeStart := time.Now() + _, hash, err := h.proxy.Storage.Store(r.Context(), storagePath, r.Body) + metrics.RecordStorageOperation("write", time.Since(storeStart)) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + http.Error(w, "cache entry too large", http.StatusRequestEntityTooLarge) + return + } + + metrics.RecordStorageError("write") + h.proxy.Logger.Error("failed to store gradle build cache entry", "key", key, "error", err) + http.Error(w, "failed to write cache entry", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Length", "0") + w.Header().Set("ETag", `"`+hash+`"`) + + w.WriteHeader(http.StatusCreated) +} diff --git a/internal/handler/gradle_test.go b/internal/handler/gradle_test.go new file mode 100644 index 0000000..a05d07e --- /dev/null +++ b/internal/handler/gradle_test.go @@ -0,0 +1,285 @@ +package handler + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/git-pkgs/proxy/internal/metrics" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +func TestGradleBuildCacheHandler_PutGetHead(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + key := "a1b2c3d4e5f6" + payload := "cache entry content" + + putReq, err := http.NewRequest(http.MethodPut, srv.URL+"/"+key, strings.NewReader(payload)) + if err != nil { + t.Fatalf("failed to create PUT request: %v", err) + } + putResp, err := http.DefaultClient.Do(putReq) + if err != nil { + t.Fatalf("PUT request failed: %v", err) + } + _ = putResp.Body.Close() + + if putResp.StatusCode != http.StatusCreated { + t.Fatalf("PUT status = %d, want %d", putResp.StatusCode, http.StatusCreated) + } + + getResp, err := http.Get(srv.URL + "/" + key) + if err != nil { + t.Fatalf("GET request failed: %v", err) + } + defer func() { _ = getResp.Body.Close() }() + + if getResp.StatusCode != http.StatusOK { + t.Fatalf("GET status = %d, want %d", getResp.StatusCode, http.StatusOK) + } + if getResp.Header.Get("Content-Type") != gradleBuildCacheContentType { + t.Fatalf("GET Content-Type = %q, want %q", getResp.Header.Get("Content-Type"), gradleBuildCacheContentType) + } + + body, _ := io.ReadAll(getResp.Body) + if string(body) != payload { + t.Fatalf("GET body = %q, want %q", body, payload) + } + + headReq, err := http.NewRequest(http.MethodHead, srv.URL+"/"+key, nil) + if err != nil { + t.Fatalf("failed to create HEAD request: %v", err) + } + headResp, err := http.DefaultClient.Do(headReq) + if err != nil { + t.Fatalf("HEAD request failed: %v", err) + } + defer func() { _ = headResp.Body.Close() }() + + if headResp.StatusCode != http.StatusOK { + t.Fatalf("HEAD status = %d, want %d", headResp.StatusCode, http.StatusOK) + } + body, _ = io.ReadAll(headResp.Body) + if len(body) != 0 { + t.Fatalf("HEAD body length = %d, want 0", len(body)) + } +} + +func TestGradleBuildCacheHandler_RootKeyPath(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + key := "rootpathkey" + putReq, err := http.NewRequest(http.MethodPut, srv.URL+"/"+key, strings.NewReader("root")) + if err != nil { + t.Fatalf("failed to create PUT request: %v", err) + } + putResp, err := http.DefaultClient.Do(putReq) + if err != nil { + t.Fatalf("PUT request failed: %v", err) + } + _ = putResp.Body.Close() + + if putResp.StatusCode != http.StatusCreated { + t.Fatalf("PUT status = %d, want %d", putResp.StatusCode, http.StatusCreated) + } + + getResp, err := http.Get(srv.URL + "/" + key) + if err != nil { + t.Fatalf("GET request failed: %v", err) + } + defer func() { _ = getResp.Body.Close() }() + + if getResp.StatusCode != http.StatusOK { + t.Fatalf("GET status = %d, want %d", getResp.StatusCode, http.StatusOK) + } +} + +func TestGradleBuildCacheHandler_GetMiss(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/missing-key") + if err != nil { + t.Fatalf("GET request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func TestGradleBuildCacheHandler_MethodNotAllowed(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + + req := httptest.NewRequest(http.MethodPost, "/key", nil) + w := httptest.NewRecorder() + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } +} + +func TestGradleBuildCacheHandler_PathTraversalRejected(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + + req := httptest.NewRequest(http.MethodGet, "/../secret", nil) + w := httptest.NewRecorder() + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +func TestGradleBuildCacheHandler_CachePrefixRejected(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + + req := httptest.NewRequest(http.MethodGet, "/cache/key", nil) + w := httptest.NewRecorder() + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestGradleBuildCacheHandler_PutOverwriteReturnsCreated(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + key := "overwrite-key" + + for i, payload := range []string{"first", "second"} { + req, err := http.NewRequest(http.MethodPut, srv.URL+"/"+key, strings.NewReader(payload)) + if err != nil { + t.Fatalf("failed to create PUT request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("PUT request failed: %v", err) + } + _ = resp.Body.Close() + + want := http.StatusCreated + if resp.StatusCode != want { + t.Fatalf("PUT #%d status = %d, want %d", i+1, resp.StatusCode, want) + } + } +} + +func TestGradleBuildCacheHandler_PutReadOnly(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + proxy.GradleReadOnly = true + + h := NewGradleBuildCacheHandler(proxy) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + req, err := http.NewRequest(http.MethodPut, srv.URL+"/readonly-key", strings.NewReader("payload")) + if err != nil { + t.Fatalf("failed to create PUT request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("PUT request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("PUT status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } +} + +func TestGradleBuildCacheHandler_PutTooLarge(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + proxy.GradleMaxUploadSize = 4 + + h := NewGradleBuildCacheHandler(proxy) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + req, err := http.NewRequest(http.MethodPut, srv.URL+"/oversized-key", strings.NewReader("12345")) + if err != nil { + t.Fatalf("failed to create PUT request: %v", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("PUT request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusRequestEntityTooLarge { + t.Fatalf("PUT status = %d, want %d", resp.StatusCode, http.StatusRequestEntityTooLarge) + } +} + +func TestGradleBuildCacheHandler_RecordsMetrics(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGradleBuildCacheHandler(proxy) + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + hitsBefore := testutil.ToFloat64(metrics.CacheHits.WithLabelValues("gradle")) + missesBefore := testutil.ToFloat64(metrics.CacheMisses.WithLabelValues("gradle")) + + key := "metrics-key" + putReq, err := http.NewRequest(http.MethodPut, srv.URL+"/"+key, strings.NewReader("payload")) + if err != nil { + t.Fatalf("failed to create PUT request: %v", err) + } + putResp, err := http.DefaultClient.Do(putReq) + if err != nil { + t.Fatalf("PUT request failed: %v", err) + } + _ = putResp.Body.Close() + + getResp, err := http.Get(srv.URL + "/" + key) + if err != nil { + t.Fatalf("GET request failed: %v", err) + } + _ = getResp.Body.Close() + + headReq, err := http.NewRequest(http.MethodHead, srv.URL+"/"+key, nil) + if err != nil { + t.Fatalf("failed to create HEAD request: %v", err) + } + headResp, err := http.DefaultClient.Do(headReq) + if err != nil { + t.Fatalf("HEAD request failed: %v", err) + } + _ = headResp.Body.Close() + + missResp, err := http.Get(srv.URL + "/missing-key") + if err != nil { + t.Fatalf("GET miss request failed: %v", err) + } + _ = missResp.Body.Close() + + hitsAfter := testutil.ToFloat64(metrics.CacheHits.WithLabelValues("gradle")) + missesAfter := testutil.ToFloat64(metrics.CacheMisses.WithLabelValues("gradle")) + + if diff := hitsAfter - hitsBefore; diff != 2 { + t.Fatalf("cache hits delta = %.0f, want 2", diff) + } + if diff := missesAfter - missesBefore; diff != 1 { + t.Fatalf("cache misses delta = %.0f, want 1", diff) + } +} diff --git a/internal/handler/handler.go b/internal/handler/handler.go index d7d79c9..d06ca83 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -10,11 +10,12 @@ import ( "io" "log/slog" "net/http" + "net/url" "strconv" "strings" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" "github.com/git-pkgs/proxy/internal/database" "github.com/git-pkgs/proxy/internal/metrics" "github.com/git-pkgs/proxy/internal/storage" @@ -23,9 +24,21 @@ import ( ) // containsPathTraversal returns true if the path contains ".." segments -// that could be used to escape the intended directory. +// that could be used to escape the intended directory. It checks the path +// as given and after URL-decoding, and treats backslashes as separators. func containsPathTraversal(path string) bool { - for _, segment := range strings.Split(path, "/") { + if hasDotDotSegment(path) { + return true + } + if decoded, err := url.PathUnescape(path); err == nil && decoded != path { + return hasDotDotSegment(decoded) + } + return false +} + +func hasDotDotSegment(path string) bool { + path = strings.ReplaceAll(path, "\\", "/") + for segment := range strings.SplitSeq(path, "/") { if segment == ".." { return true } @@ -37,23 +50,27 @@ const defaultHTTPTimeout = 30 * time.Second const contentTypeJSON = "application/json" -// maxMetadataSize is the maximum size of upstream metadata responses (100 MB). -// Package metadata (e.g. npm with many versions) can be large, but unbounded -// reads risk OOM if an upstream misbehaves. -const maxMetadataSize = 100 << 20 +const headerAcceptEncoding = "Accept-Encoding" -// ErrMetadataTooLarge is returned when upstream metadata exceeds maxMetadataSize. +// defaultMetadataMaxSize is used when Proxy.MetadataMaxSize is unset. +const defaultMetadataMaxSize = 100 << 20 + +// ErrMetadataTooLarge is returned when upstream metadata exceeds the configured limit. var ErrMetadataTooLarge = errors.New("metadata response exceeds size limit") // ReadMetadata reads an upstream response body with a size limit to prevent OOM // from unexpectedly large responses. Returns ErrMetadataTooLarge if the response // is truncated by the limit. -func ReadMetadata(r io.Reader) ([]byte, error) { - data, err := io.ReadAll(io.LimitReader(r, maxMetadataSize+1)) +func (p *Proxy) ReadMetadata(r io.Reader) ([]byte, error) { + limit := p.MetadataMaxSize + if limit <= 0 { + limit = defaultMetadataMaxSize + } + data, err := io.ReadAll(io.LimitReader(r, limit+1)) if err != nil { return nil, err } - if int64(len(data)) > maxMetadataSize { + if int64(len(data)) > limit { return nil, ErrMetadataTooLarge } return data, nil @@ -61,15 +78,24 @@ func ReadMetadata(r io.Reader) ([]byte, error) { // Proxy provides shared functionality for protocol handlers. type Proxy struct { - DB *database.DB - Storage storage.Storage - Fetcher fetch.FetcherInterface - Resolver *fetch.Resolver - Logger *slog.Logger - Cooldown *cooldown.Config - CacheMetadata bool - MetadataTTL time.Duration - HTTPClient *http.Client + DB *database.DB + Storage storage.Storage + Fetcher fetch.FetcherInterface + Resolver *fetch.Resolver + Logger *slog.Logger + Cooldown *cooldown.Config + CacheMetadata bool + MetadataTTL time.Duration + MetadataMaxSize int64 + GradleReadOnly bool + GradleMaxUploadSize int64 + DirectServe bool + DirectServeTTL time.Duration + // DirectServeBaseURL, if set, replaces the scheme and host of presigned + // URLs so clients receive a public address even when the proxy reaches + // storage at an internal one. + DirectServeBaseURL string + HTTPClient *http.Client } // NewProxy creates a new Proxy with the given dependencies. @@ -92,6 +118,7 @@ func NewProxy(db *database.DB, store storage.Storage, fetcher fetch.FetcherInter // CacheResult contains information about a cached or fetched artifact. type CacheResult struct { Reader io.ReadCloser + RedirectURL string Size int64 ContentType string Hash string @@ -138,6 +165,26 @@ func (p *Proxy) checkCache(ctx context.Context, pkgPURL, versionPURL, filename s return nil, nil } + result := &CacheResult{ + Size: artifact.Size.Int64, + ContentType: artifact.ContentType.String, + Hash: artifact.ContentHash.String, + Cached: true, + } + + if p.DirectServe { + signed, err := p.Storage.SignedURL(ctx, artifact.StoragePath.String, p.DirectServeTTL) + if err == nil { + result.RedirectURL = rewriteSignedURLHost(signed, p.DirectServeBaseURL) + p.recordCacheHit(pkgPURL, versionPURL, filename) + return result, nil + } + if !errors.Is(err, storage.ErrSignedURLUnsupported) { + p.Logger.Warn("failed to sign storage URL, falling back to streaming", + "path", artifact.StoragePath.String, "error", err) + } + } + start := time.Now() reader, err := p.Storage.Open(ctx, artifact.StoragePath.String) metrics.RecordStorageOperation("read", time.Since(start)) @@ -148,20 +195,45 @@ func (p *Proxy) checkCache(ctx context.Context, pkgPURL, versionPURL, filename s return nil, nil } - _ = p.DB.RecordArtifactHit(versionPURL, filename) + result.Reader = newVerifyingReader(reader, artifact.ContentHash.String, ver.Integrity.String, + func(reason string) { + p.Logger.Error("cached artifact failed integrity check", + "purl", versionPURL, "filename", filename, + "path", artifact.StoragePath.String, "reason", reason) + metrics.RecordIntegrityFailure(pkg.Ecosystem) + if err := p.DB.ClearArtifactCache(versionPURL, filename); err != nil { + p.Logger.Warn("failed to clear corrupt artifact from cache", "error", err) + } + }) + p.recordCacheHit(pkgPURL, versionPURL, filename) + return result, nil +} - // Extract ecosystem from pkgPURL for metrics - if p, err := purl.Parse(pkgPURL); err == nil { - metrics.RecordCacheHit(purl.PURLTypeToEcosystem(p.Type)) +// rewriteSignedURLHost replaces the scheme and host of a signed URL with those +// from baseURL, preserving the path and query (which carry the signature). +// Returns signed unchanged if baseURL is empty or either URL fails to parse. +func rewriteSignedURLHost(signed, baseURL string) string { + if baseURL == "" { + return signed } + s, err := url.Parse(signed) + if err != nil { + return signed + } + b, err := url.Parse(baseURL) + if err != nil || b.Scheme == "" || b.Host == "" { + return signed + } + s.Scheme = b.Scheme + s.Host = b.Host + return s.String() +} - return &CacheResult{ - Reader: reader, - Size: artifact.Size.Int64, - ContentType: artifact.ContentType.String, - Hash: artifact.ContentHash.String, - Cached: true, - }, nil +func (p *Proxy) recordCacheHit(pkgPURL, versionPURL, filename string) { + _ = p.DB.RecordArtifactHit(versionPURL, filename) + if parsed, err := purl.Parse(pkgPURL); err == nil { + metrics.RecordCacheHit(purl.PURLTypeToEcosystem(parsed.Type)) + } } func (p *Proxy) fetchAndCache(ctx context.Context, ecosystem, name, version, filename, pkgPURL, versionPURL string) (*CacheResult, error) { @@ -276,6 +348,15 @@ func (p *Proxy) updateCacheDB(ecosystem, name, filename, pkgPURL, versionPURL, u // ServeArtifact writes a CacheResult to an HTTP response. func ServeArtifact(w http.ResponseWriter, result *CacheResult) { + if result.RedirectURL != "" { + if result.Hash != "" { + w.Header().Set("ETag", fmt.Sprintf(`"%s"`, result.Hash)) + } + w.Header().Set("Location", result.RedirectURL) + w.WriteHeader(http.StatusFound) + return + } + defer func() { _ = result.Reader.Close() }() if result.ContentType != "" { @@ -396,7 +477,7 @@ func (p *Proxy) FetchOrCacheMetadata(ctx context.Context, ecosystem, cacheKey, u cached, readErr := p.Storage.Open(ctx, entry.StoragePath) if readErr == nil { defer func() { _ = cached.Close() }() - data, readErr := ReadMetadata(cached) + data, readErr := p.ReadMetadata(cached) if readErr == nil { ct := contentTypeJSON if entry.ContentType.Valid { @@ -441,7 +522,7 @@ func (p *Proxy) FetchOrCacheMetadata(ctx context.Context, ecosystem, cacheKey, u } defer func() { _ = cached.Close() }() - data, readErr := ReadMetadata(cached) + data, readErr := p.ReadMetadata(cached) if readErr != nil { return nil, "", fmt.Errorf("upstream failed and cached read error: %w", err) } @@ -483,7 +564,7 @@ func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, e return nil, "", "", zeroTime, errStale304 } defer func() { _ = cached.Close() }() - data, readErr := ReadMetadata(cached) + data, readErr := p.ReadMetadata(cached) if readErr != nil { return nil, "", "", zeroTime, errStale304 } @@ -505,7 +586,7 @@ func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, e return nil, "", "", zeroTime, fmt.Errorf("upstream returned %d", resp.StatusCode) } - body, err := ReadMetadata(resp.Body) + body, err := p.ReadMetadata(resp.Body) if err != nil { return nil, "", "", zeroTime, fmt.Errorf("reading response: %w", err) } @@ -602,9 +683,14 @@ func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, return } + p.writeMetadataCachedResponse(w, r, ecosystem, cacheKey, body, contentType) +} + +// writeMetadataCachedResponse writes a cached metadata response and handles +// conditional request headers using metadata cache validators. +func (p *Proxy) writeMetadataCachedResponse(w http.ResponseWriter, r *http.Request, ecosystem, cacheKey string, body []byte, contentType string) { cm := p.lookupCachedMeta(ecosystem, cacheKey) - // Honor client conditional request headers if cm.etag != "" { if match := r.Header.Get("If-None-Match"); match != "" && match == cm.etag { w.WriteHeader(http.StatusNotModified) @@ -650,7 +736,7 @@ func (p *Proxy) proxyMetadataStream(w http.ResponseWriter, r *http.Request, upst } req.Header.Set("Accept", accept) - for _, header := range []string{"Accept-Encoding", "If-Modified-Since", "If-None-Match"} { + for _, header := range []string{headerAcceptEncoding, "If-Modified-Since", "If-None-Match"} { if v := r.Header.Get(header); v != "" { req.Header.Set(header, v) } diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index 78ed415..bbcab72 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -21,9 +21,11 @@ import ( // mockStorage implements storage.Storage for testing. type mockStorage struct { - files map[string][]byte - storeErr error - openErr error + files map[string][]byte + storeErr error + openErr error + signedURL string + signErr error } func newMockStorage() *mockStorage { @@ -79,16 +81,27 @@ func (s *mockStorage) UsedSpace(_ context.Context) (int64, error) { return total, nil } +func (s *mockStorage) SignedURL(_ context.Context, _ string, _ time.Duration) (string, error) { + if s.signErr != nil { + return "", s.signErr + } + if s.signedURL == "" { + return "", storage.ErrSignedURLUnsupported + } + return s.signedURL, nil +} + func (s *mockStorage) URL() string { return "mem://" } func (s *mockStorage) Close() error { return nil } // mockFetcher implements fetch.FetcherInterface for testing. type mockFetcher struct { - artifact *fetch.Artifact - fetchErr error - fetchCalled bool - fetchedURL string + artifact *fetch.Artifact + fetchErr error + fetchErrByURL map[string]error + fetchCalled bool + fetchedURL string } func (f *mockFetcher) Fetch(ctx context.Context, url string) (*fetch.Artifact, error) { @@ -98,6 +111,11 @@ func (f *mockFetcher) Fetch(ctx context.Context, url string) (*fetch.Artifact, e func (f *mockFetcher) FetchWithHeaders(_ context.Context, url string, _ http.Header) (*fetch.Artifact, error) { f.fetchCalled = true f.fetchedURL = url + if f.fetchErrByURL != nil { + if err, ok := f.fetchErrByURL[url]; ok { + return nil, err + } + } if f.fetchErr != nil { return nil, f.fetchErr } @@ -311,6 +329,213 @@ func TestGetOrFetchArtifactFromURL_CacheMiss_StorageMissing(t *testing.T) { } } +func TestGetOrFetchArtifact_DirectServe_Redirect(t *testing.T) { + proxy, db, store, fetcher := setupTestProxy(t) + seedPackage(t, db, store, "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz", "cached content") + + proxy.DirectServe = true + proxy.DirectServeTTL = 15 * time.Minute + store.signedURL = "https://bucket.s3.amazonaws.com/npm/lodash?X-Amz-Signature=abc" + + result, err := proxy.GetOrFetchArtifact(context.Background(), "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !result.Cached { + t.Error("expected result to be cached") + } + if result.RedirectURL != store.signedURL { + t.Errorf("RedirectURL = %q, want %q", result.RedirectURL, store.signedURL) + } + if result.Reader != nil { + t.Error("Reader should be nil when redirecting") + } + if fetcher.fetchCalled { + t.Error("fetcher should not be called on cache hit") + } + + // Hit count should still be recorded on the redirect path. + art, _ := db.GetArtifact("pkg:npm/lodash@4.17.21", "lodash-4.17.21.tgz") + if art == nil || art.HitCount != 1 { + t.Errorf("artifact hit count not recorded: %+v", art) + } +} + +func TestGetOrFetchArtifact_DirectServe_BaseURLRewrite(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz", "cached content") + + proxy.DirectServe = true + proxy.DirectServeBaseURL = "https://cdn.example.com" + store.signedURL = "http://127.0.0.1:9000/bucket/npm/lodash?X-Amz-Signature=abc&X-Amz-Expires=900" + + result, err := proxy.GetOrFetchArtifact(context.Background(), "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + want := "https://cdn.example.com/bucket/npm/lodash?X-Amz-Signature=abc&X-Amz-Expires=900" + if result.RedirectURL != want { + t.Errorf("RedirectURL = %q, want %q", result.RedirectURL, want) + } +} + +func TestRewriteSignedURLHost(t *testing.T) { + tests := []struct { + name string + signed string + baseURL string + want string + }{ + { + "empty base url is no-op", + "http://127.0.0.1:9000/bucket/key?sig=abc", + "", + "http://127.0.0.1:9000/bucket/key?sig=abc", + }, + { + "replaces scheme and host", + "http://127.0.0.1:9000/bucket/key?sig=abc", + "https://cdn.example.com", + "https://cdn.example.com/bucket/key?sig=abc", + }, + { + "preserves path and query", + "http://minio:9000/bucket/npm/lodash/4.17.21/lodash.tgz?X-Amz-Signature=abc&X-Amz-Date=20260101", + "https://files.example.com", + "https://files.example.com/bucket/npm/lodash/4.17.21/lodash.tgz?X-Amz-Signature=abc&X-Amz-Date=20260101", + }, + { + "ignores base url path", + "http://127.0.0.1:9000/bucket/key?sig=abc", + "https://cdn.example.com/ignored", + "https://cdn.example.com/bucket/key?sig=abc", + }, + { + "invalid base url is no-op", + "http://127.0.0.1:9000/bucket/key?sig=abc", + "://bad", + "http://127.0.0.1:9000/bucket/key?sig=abc", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := rewriteSignedURLHost(tt.signed, tt.baseURL) + if got != tt.want { + t.Errorf("rewriteSignedURLHost(%q, %q) = %q, want %q", tt.signed, tt.baseURL, got, tt.want) + } + }) + } +} + +func TestGetOrFetchArtifact_DirectServe_FallbackOnUnsupported(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz", "cached content") + + proxy.DirectServe = true + // store.signedURL is empty so SignedURL returns ErrSignedURLUnsupported. + + result, err := proxy.GetOrFetchArtifact(context.Background(), "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer func() { _ = result.Reader.Close() }() + + if result.RedirectURL != "" { + t.Errorf("RedirectURL should be empty, got %q", result.RedirectURL) + } + if result.Reader == nil { + t.Fatal("Reader should be set when signing is unsupported") + } + body, _ := io.ReadAll(result.Reader) + if string(body) != "cached content" { + t.Errorf("got body %q, want %q", body, "cached content") + } +} + +func TestGetOrFetchArtifact_DirectServe_FallbackOnError(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz", "cached content") + + proxy.DirectServe = true + store.signErr = errors.New("signing failed") + + result, err := proxy.GetOrFetchArtifact(context.Background(), "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer func() { _ = result.Reader.Close() }() + + if result.RedirectURL != "" { + t.Errorf("RedirectURL should be empty on signing error, got %q", result.RedirectURL) + } + if result.Reader == nil { + t.Fatal("Reader should be set when signing fails") + } +} + +func TestGetOrFetchArtifact_DirectServe_DisabledIgnoresSigning(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz", "cached content") + + proxy.DirectServe = false + store.signedURL = "https://bucket.example/should-not-be-used" + + result, err := proxy.GetOrFetchArtifact(context.Background(), "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer func() { _ = result.Reader.Close() }() + + if result.RedirectURL != "" { + t.Errorf("RedirectURL should be empty when DirectServe is off, got %q", result.RedirectURL) + } +} + +func TestServeArtifact_Redirect(t *testing.T) { + w := httptest.NewRecorder() + ServeArtifact(w, &CacheResult{ + RedirectURL: "https://bucket.s3.amazonaws.com/file?sig=abc", + Hash: "abc123", + Cached: true, + }) + + if w.Code != http.StatusFound { + t.Errorf("status = %d, want %d", w.Code, http.StatusFound) + } + if loc := w.Header().Get("Location"); loc != "https://bucket.s3.amazonaws.com/file?sig=abc" { + t.Errorf("Location = %q", loc) + } + if etag := w.Header().Get("ETag"); etag != `"abc123"` { + t.Errorf("ETag = %q, want %q", etag, `"abc123"`) + } + if cl := w.Header().Get("Content-Length"); cl != "" { + t.Errorf("Content-Length should not be set on redirect, got %q", cl) + } +} + +func TestServeArtifact_Stream(t *testing.T) { + w := httptest.NewRecorder() + ServeArtifact(w, &CacheResult{ + Reader: io.NopCloser(strings.NewReader("payload")), + Size: 7, + ContentType: "application/octet-stream", + Hash: "abc123", + }) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "payload" { + t.Errorf("body = %q, want %q", w.Body.String(), "payload") + } + if ct := w.Header().Get("Content-Type"); ct != "application/octet-stream" { + t.Errorf("Content-Type = %q", ct) + } +} + func TestGetOrFetchArtifactFromURL_CacheHit(t *testing.T) { proxy, db, store, fetcher := setupTestProxy(t) seedPackage(t, db, store, "pypi", "requests", "2.28.0", "requests-2.28.0.tar.gz", "pypi content") diff --git a/internal/handler/hex_test.go b/internal/handler/hex_test.go index 19d34b4..b02540a 100644 --- a/internal/handler/hex_test.go +++ b/internal/handler/hex_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" "google.golang.org/protobuf/encoding/protowire" ) diff --git a/internal/handler/integrity.go b/internal/handler/integrity.go new file mode 100644 index 0000000..bb29a21 --- /dev/null +++ b/internal/handler/integrity.go @@ -0,0 +1,140 @@ +package handler + +import ( + "crypto/sha256" + "crypto/sha512" + "crypto/subtle" + "encoding/base64" + "encoding/hex" + "fmt" + "hash" + "io" + "strings" +) + +// parseSRI parses a Subresource Integrity string (e.g. "sha512-abc==") into +// an algorithm name and raw digest bytes. Returns ok=false for empty, +// malformed, or unsupported entries. Only the first hash in a multi-hash +// SRI string is considered. +func parseSRI(s string) (algo string, digest []byte, ok bool) { + s = strings.TrimSpace(s) + if s == "" { + return "", nil, false + } + if i := strings.IndexByte(s, ' '); i >= 0 { + s = s[:i] + } + algo, b64, found := strings.Cut(s, "-") + if !found { + return "", nil, false + } + d, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return "", nil, false + } + switch algo { + case "sha256", "sha384", "sha512": + return algo, d, true + default: + return "", nil, false + } +} + +func newSRIHash(algo string) hash.Hash { + switch algo { + case "sha256": + return sha256.New() + case "sha384": + return sha512.New384() + case "sha512": + return sha512.New() + } + return nil +} + +// verifyingReader wraps an io.ReadCloser and computes SHA256 (and optionally +// a second SRI hash) as bytes are read. When the underlying reader reaches +// EOF it compares the digests against the expected values and calls +// onMismatch for each failure. Verification is skipped if the stream was +// not fully consumed (e.g. client disconnect) to avoid false positives. +type verifyingReader struct { + r io.ReadCloser + sha256 hash.Hash + wantSHA256 string + sri hash.Hash + sriAlgo string + wantSRI []byte + onMismatch func(reason string) + eof bool + verified bool +} + +func newVerifyingReader(r io.ReadCloser, contentHash, sri string, onMismatch func(string)) io.ReadCloser { + if contentHash == "" && sri == "" { + return r + } + v := &verifyingReader{ + r: r, + onMismatch: onMismatch, + } + if contentHash != "" { + v.sha256 = sha256.New() + v.wantSHA256 = contentHash + } + if algo, digest, ok := parseSRI(sri); ok { + v.sri = newSRIHash(algo) + v.sriAlgo = algo + v.wantSRI = digest + } + if v.sha256 == nil && v.sri == nil { + return r + } + return v +} + +func (v *verifyingReader) Read(p []byte) (int, error) { + n, err := v.r.Read(p) + if n > 0 { + if v.sha256 != nil { + v.sha256.Write(p[:n]) + } + if v.sri != nil { + v.sri.Write(p[:n]) + } + } + if err == io.EOF { + v.eof = true + v.verify() + } + return n, err +} + +func (v *verifyingReader) Close() error { + if v.eof { + v.verify() + } + return v.r.Close() +} + +func (v *verifyingReader) verify() { + if v.verified { + return + } + v.verified = true + + if v.sha256 != nil { + got := hex.EncodeToString(v.sha256.Sum(nil)) + if subtle.ConstantTimeCompare([]byte(got), []byte(v.wantSHA256)) != 1 { + v.onMismatch(fmt.Sprintf("content_hash mismatch: stored=%s computed=%s", v.wantSHA256, got)) + } + } + if v.sri != nil { + got := v.sri.Sum(nil) + if subtle.ConstantTimeCompare(got, v.wantSRI) != 1 { + v.onMismatch(fmt.Sprintf("integrity mismatch: %s expected=%s computed=%s", + v.sriAlgo, + base64.StdEncoding.EncodeToString(v.wantSRI), + base64.StdEncoding.EncodeToString(got))) + } + } +} diff --git a/internal/handler/integrity_test.go b/internal/handler/integrity_test.go new file mode 100644 index 0000000..93c448c --- /dev/null +++ b/internal/handler/integrity_test.go @@ -0,0 +1,136 @@ +package handler + +import ( + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "encoding/hex" + "io" + "strings" + "testing" +) + +func sha256Hex(data string) string { + sum := sha256.Sum256([]byte(data)) + return hex.EncodeToString(sum[:]) +} + +func sha512SRI(data string) string { + sum := sha512.Sum512([]byte(data)) + return "sha512-" + base64.StdEncoding.EncodeToString(sum[:]) +} + +func TestParseSRI(t *testing.T) { + tests := []struct { + name string + input string + algo string + ok bool + }{ + {"sha512", sha512SRI("hello"), "sha512", true}, + {"sha256", "sha256-" + base64.StdEncoding.EncodeToString([]byte("0123456789012345678901234567890123456789")), "sha256", true}, + {"empty", "", "", false}, + {"no dash", "sha512abc", "", false}, + {"bad base64", "sha512-not!base64", "", false}, + {"unsupported algo", "md5-" + base64.StdEncoding.EncodeToString([]byte("x")), "", false}, + {"multi hash takes first", sha512SRI("a") + " " + sha512SRI("b"), "sha512", true}, + {"whitespace", " " + sha512SRI("x") + " ", "sha512", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + algo, digest, ok := parseSRI(tt.input) + if ok != tt.ok { + t.Fatalf("ok = %v, want %v", ok, tt.ok) + } + if !tt.ok { + return + } + if algo != tt.algo { + t.Errorf("algo = %q, want %q", algo, tt.algo) + } + if len(digest) == 0 { + t.Error("digest is empty") + } + }) + } +} + +func TestVerifyingReader(t *testing.T) { + const data = "hello world" + goodSHA := sha256Hex(data) + goodSRI := sha512SRI(data) + + tests := []struct { + name string + hash string + sri string + wantCalls int + }{ + {"both match", goodSHA, goodSRI, 0}, + {"sha256 only match", goodSHA, "", 0}, + {"sri only match", "", goodSRI, 0}, + {"sha256 mismatch", sha256Hex("other"), "", 1}, + {"sri mismatch", "", sha512SRI("other"), 1}, + {"both mismatch", sha256Hex("other"), sha512SRI("other"), 2}, + {"no checks", "", "", 0}, + {"unparseable sri ignored", goodSHA, "garbage", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var calls []string + r := newVerifyingReader(io.NopCloser(strings.NewReader(data)), tt.hash, tt.sri, + func(reason string) { calls = append(calls, reason) }) + + got, err := io.ReadAll(r) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if string(got) != data { + t.Errorf("data corrupted: got %q", got) + } + if err := r.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + if len(calls) != tt.wantCalls { + t.Errorf("onMismatch called %d times, want %d: %v", len(calls), tt.wantCalls, calls) + } + }) + } +} + +func TestVerifyingReaderPassthrough(t *testing.T) { + src := io.NopCloser(strings.NewReader("x")) + r := newVerifyingReader(src, "", "", func(string) { t.Fatal("should not be called") }) + if r != src { + t.Error("expected passthrough when no hashes provided") + } +} + +func TestVerifyingReaderPartialRead(t *testing.T) { + var calls int + r := newVerifyingReader(io.NopCloser(strings.NewReader("hello world")), + sha256Hex("hello world"), "", func(string) { calls++ }) + + buf := make([]byte, 5) + _, _ = r.Read(buf) + _ = r.Close() + + if calls != 0 { + t.Errorf("onMismatch called %d times for partial read, want 0", calls) + } +} + +func TestVerifyingReaderVerifyOnce(t *testing.T) { + var calls int + r := newVerifyingReader(io.NopCloser(strings.NewReader("x")), sha256Hex("y"), "", + func(string) { calls++ }) + _, _ = io.ReadAll(r) + _ = r.Close() + _ = r.Close() + if calls != 1 { + t.Errorf("onMismatch called %d times, want 1", calls) + } +} diff --git a/internal/handler/julia.go b/internal/handler/julia.go new file mode 100644 index 0000000..08b1fdf --- /dev/null +++ b/internal/handler/julia.go @@ -0,0 +1,347 @@ +package handler + +import ( + "archive/tar" + "bufio" + "bytes" + "compress/gzip" + "context" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "sync" + + "github.com/BurntSushi/toml" +) + +const ( + juliaUpstream = "https://pkg.julialang.org" + juliaGeneralRegistryUUID = "23338594-aafe-5451-b93e-139f81909106" + juliaArtifactName = "_artifact" + juliaRegistryName = "_registry" +) + +var ( + juliaHexPattern = regexp.MustCompile(`^[0-9a-f]{40,64}$`) + juliaUUIDPattern = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) +) + +// JuliaHandler handles Julia Pkg server protocol requests. +// +// See https://pkgdocs.julialang.org/v1/registries/ and the PkgServer.jl +// reference implementation. The protocol is content-addressed: registry, +// package and artifact resources are all identified by git tree hashes +// and are immutable once published. +type JuliaHandler struct { + proxy *Proxy + upstreamURL string + + mu sync.RWMutex + names map[string]string + namesHash string + loadMu sync.Mutex +} + +// NewJuliaHandler creates a new Julia Pkg server handler. +func NewJuliaHandler(proxy *Proxy, _ string) *JuliaHandler { + return &JuliaHandler{ + proxy: proxy, + upstreamURL: juliaUpstream, + names: make(map[string]string), + } +} + +// Routes returns the HTTP handler for Julia requests. +func (h *JuliaHandler) Routes() http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("GET /registries", h.handleRegistries) + mux.HandleFunc("GET /registries.eager", h.handleRegistries) + mux.HandleFunc("GET /registries.conservative", h.handleRegistries) + mux.HandleFunc("GET /registry/{uuid}/{hash}", h.handleRegistry) + mux.HandleFunc("GET /package/{uuid}/{hash}", h.handlePackage) + mux.HandleFunc("GET /artifact/{hash}", h.handleArtifact) + mux.HandleFunc("GET /meta", h.proxyUpstream) + + return mux +} + +// handleRegistries serves the list of available registries. This is the only +// mutable endpoint in the protocol so it goes through the metadata cache. +func (h *JuliaHandler) handleRegistries(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "julia", cacheKey, "*/*") +} + +// handleRegistry serves an immutable registry tarball and refreshes the +// UUID→name map from its Registry.toml. +func (h *JuliaHandler) handleRegistry(w http.ResponseWriter, r *http.Request) { + uuid := r.PathValue("uuid") + hash := r.PathValue("hash") + if !validJuliaUUID(uuid) || !juliaHexPattern.MatchString(hash) { + http.Error(w, "invalid registry reference", http.StatusBadRequest) + return + } + + h.proxy.Logger.Info("julia registry request", "uuid", uuid, "hash", hash) + + upstreamURL := h.upstreamURL + r.URL.Path + result, err := h.proxy.GetOrFetchArtifactFromURL(r.Context(), "julia", juliaRegistryName, hash, hash+".tar.gz", upstreamURL) + if err != nil { + h.proxy.Logger.Error("failed to get registry", "error", err) + http.Error(w, "failed to fetch registry", http.StatusBadGateway) + return + } + + go h.refreshNamesFromRegistry(uuid, hash) + + ServeArtifact(w, result) +} + +// handlePackage serves an immutable package source tarball. +func (h *JuliaHandler) handlePackage(w http.ResponseWriter, r *http.Request) { + uuid := r.PathValue("uuid") + hash := r.PathValue("hash") + if !validJuliaUUID(uuid) || !juliaHexPattern.MatchString(hash) { + http.Error(w, "invalid package reference", http.StatusBadRequest) + return + } + + if err := h.ensureNames(r.Context()); err != nil { + h.proxy.Logger.Warn("julia name map unavailable, using uuid", "error", err) + } + name := h.resolveName(uuid) + + h.proxy.Logger.Info("julia package request", "name", name, "uuid", uuid, "hash", hash) + + upstreamURL := h.upstreamURL + r.URL.Path + result, err := h.proxy.GetOrFetchArtifactFromURL(r.Context(), "julia", name, hash, hash+".tar.gz", upstreamURL) + if err != nil { + h.proxy.Logger.Error("failed to get package", "error", err) + http.Error(w, "failed to fetch package", http.StatusBadGateway) + return + } + + ServeArtifact(w, result) +} + +// handleArtifact serves an immutable binary artifact tarball. Artifacts are +// anonymous content-addressed blobs with no associated package name. +func (h *JuliaHandler) handleArtifact(w http.ResponseWriter, r *http.Request) { + hash := r.PathValue("hash") + if !juliaHexPattern.MatchString(hash) { + http.Error(w, "invalid artifact hash", http.StatusBadRequest) + return + } + + h.proxy.Logger.Info("julia artifact request", "hash", hash) + + upstreamURL := h.upstreamURL + r.URL.Path + result, err := h.proxy.GetOrFetchArtifactFromURL(r.Context(), "julia", juliaArtifactName, hash, hash+".tar.gz", upstreamURL) + if err != nil { + h.proxy.Logger.Error("failed to get artifact", "error", err) + http.Error(w, "failed to fetch artifact", http.StatusBadGateway) + return + } + + ServeArtifact(w, result) +} + +// proxyUpstream forwards a request to the upstream Pkg server without caching. +func (h *JuliaHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, nil) +} + +// resolveName returns the human-readable package name for a UUID, falling +// back to the UUID itself if it is not present in the loaded registry. +func (h *JuliaHandler) resolveName(uuid string) string { + h.mu.RLock() + defer h.mu.RUnlock() + if name, ok := h.names[uuid]; ok { + return name + } + return uuid +} + +// ensureNames lazily populates the UUID→name map from the General registry. +// Returns immediately if the map is already populated; otherwise blocks until +// a single in-flight load completes. Failed loads are retried on the next call. +func (h *JuliaHandler) ensureNames(ctx context.Context) error { + if h.namesLoaded() { + return nil + } + + h.loadMu.Lock() + defer h.loadMu.Unlock() + + if h.namesLoaded() { + return nil + } + return h.loadNamesFromUpstream(ctx) +} + +func (h *JuliaHandler) namesLoaded() bool { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.names) > 0 +} + +// loadNamesFromUpstream fetches the current /registries listing, downloads the +// General registry tarball at its current hash, and parses Registry.toml. +func (h *JuliaHandler) loadNamesFromUpstream(ctx context.Context) error { + hash, err := h.fetchGeneralRegistryHash(ctx) + if err != nil { + return err + } + return h.loadRegistryTarball(ctx, juliaGeneralRegistryUUID, hash) +} + +// fetchGeneralRegistryHash reads /registries and returns the current tree hash +// for the General registry. +func (h *JuliaHandler) fetchGeneralRegistryHash(ctx context.Context) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, h.upstreamURL+"/registries", nil) + if err != nil { + return "", err + } + resp, err := h.proxy.HTTPClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("upstream /registries returned %d", resp.StatusCode) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + uuid, hash, ok := parseRegistryLine(scanner.Text()) + if ok && uuid == juliaGeneralRegistryUUID { + return hash, nil + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return "", fmt.Errorf("general registry not listed in /registries") +} + +// refreshNamesFromRegistry reloads the UUID→name map from a registry tarball +// that has just been cached. Errors are logged but do not affect the response. +func (h *JuliaHandler) refreshNamesFromRegistry(uuid, hash string) { + if uuid != juliaGeneralRegistryUUID { + return + } + h.mu.RLock() + current := h.namesHash + h.mu.RUnlock() + if current == hash { + return + } + if err := h.loadRegistryTarball(context.Background(), uuid, hash); err != nil { + h.proxy.Logger.Warn("failed to refresh julia name map", "error", err) + } +} + +// loadRegistryTarball downloads a registry tarball and replaces the name map +// with the contents of its Registry.toml. +func (h *JuliaHandler) loadRegistryTarball(ctx context.Context, uuid, hash string) error { + url := fmt.Sprintf("%s/registry/%s/%s", h.upstreamURL, uuid, hash) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + resp, err := h.proxy.HTTPClient.Do(req) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("upstream registry returned %d", resp.StatusCode) + } + + names, err := extractRegistryNames(resp.Body) + if err != nil { + return err + } + + h.mu.Lock() + h.names = names + h.namesHash = hash + h.mu.Unlock() + + h.proxy.Logger.Info("loaded julia registry name map", "packages", len(names), "hash", hash) + return nil +} + +// extractRegistryNames reads a gzipped registry tarball, finds Registry.toml +// at the root, and returns its [packages] table as a UUID→name map. +func extractRegistryNames(r io.Reader) (map[string]string, error) { + gz, err := gzip.NewReader(r) + if err != nil { + return nil, fmt.Errorf("opening gzip stream: %w", err) + } + defer func() { _ = gz.Close() }() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err == io.EOF { + return nil, fmt.Errorf("no Registry.toml in tarball") + } + if err != nil { + return nil, err + } + if strings.TrimPrefix(hdr.Name, "./") != "Registry.toml" { + continue + } + + data, err := io.ReadAll(tr) + if err != nil { + return nil, err + } + return parseRegistryToml(data) + } +} + +type juliaRegistryFile struct { + Packages map[string]struct { + Name string `toml:"name"` + } `toml:"packages"` +} + +// parseRegistryToml decodes the [packages] table of a Registry.toml file. +func parseRegistryToml(data []byte) (map[string]string, error) { + var reg juliaRegistryFile + if _, err := toml.NewDecoder(bytes.NewReader(data)).Decode(®); err != nil { + return nil, fmt.Errorf("parsing Registry.toml: %w", err) + } + + names := make(map[string]string, len(reg.Packages)) + for uuid, pkg := range reg.Packages { + if pkg.Name != "" { + names[uuid] = pkg.Name + } + } + return names, nil +} + +// parseRegistryLine parses a single line from /registries of the form +// "/registry/{uuid}/{hash}" and returns the uuid and hash. +func parseRegistryLine(line string) (uuid, hash string, ok bool) { + line = strings.TrimSpace(line) + line = strings.TrimPrefix(line, "/registry/") + uuid, hash, found := strings.Cut(line, "/") + if !found || !validJuliaUUID(uuid) || !juliaHexPattern.MatchString(hash) { + return "", "", false + } + return uuid, hash, true +} + +// validJuliaUUID reports whether s looks like a lowercase RFC 4122 UUID. +func validJuliaUUID(s string) bool { + return juliaUUIDPattern.MatchString(s) +} diff --git a/internal/handler/julia_test.go b/internal/handler/julia_test.go new file mode 100644 index 0000000..68fb975 --- /dev/null +++ b/internal/handler/julia_test.go @@ -0,0 +1,167 @@ +package handler + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "log/slog" + "net/http" + "net/http/httptest" + "testing" +) + +func TestJuliaParseRegistryLine(t *testing.T) { + tests := []struct { + line string + wantUUID string + wantHash string + wantOK bool + }{ + { + "/registry/23338594-aafe-5451-b93e-139f81909106/342327538ed6c1ec54c69fa145e7b6bf5934201e", + "23338594-aafe-5451-b93e-139f81909106", + "342327538ed6c1ec54c69fa145e7b6bf5934201e", + true, + }, + { + " /registry/23338594-aafe-5451-b93e-139f81909106/342327538ed6c1ec54c69fa145e7b6bf5934201e\n", + "23338594-aafe-5451-b93e-139f81909106", + "342327538ed6c1ec54c69fa145e7b6bf5934201e", + true, + }, + {"/registry/not-a-uuid/0000", "", "", false}, + {"junk", "", "", false}, + {"", "", "", false}, + } + + for _, tt := range tests { + uuid, hash, ok := parseRegistryLine(tt.line) + if uuid != tt.wantUUID || hash != tt.wantHash || ok != tt.wantOK { + t.Errorf("parseRegistryLine(%q) = (%q, %q, %v), want (%q, %q, %v)", + tt.line, uuid, hash, ok, tt.wantUUID, tt.wantHash, tt.wantOK) + } + } +} + +func TestJuliaValidUUID(t *testing.T) { + tests := []struct { + s string + want bool + }{ + {"23338594-aafe-5451-b93e-139f81909106", true}, + {"295af30f-e4ad-537b-8983-00126c2a3abe", true}, + {"23338594-AAFE-5451-b93e-139f81909106", false}, + {"23338594aafe5451b93e139f81909106", false}, + {"23338594-aafe-5451-b93e-139f8190910", false}, + {"23338594-aafe-5451-b93e-139f81909106-", false}, + {"23338594-gafe-5451-b93e-139f81909106", false}, + {"", false}, + } + + for _, tt := range tests { + if got := validJuliaUUID(tt.s); got != tt.want { + t.Errorf("validJuliaUUID(%q) = %v, want %v", tt.s, got, tt.want) + } + } +} + +func TestJuliaParseRegistryToml(t *testing.T) { + data := []byte(`name = "General" +uuid = "23338594-aafe-5451-b93e-139f81909106" + +[packages] +295af30f-e4ad-537b-8983-00126c2a3abe = { name = "Revise", path = "R/Revise" } +91a5bcdd-55d7-5caf-9e0b-520d859cae80 = { name = "Plots", path = "P/Plots" } +`) + + names, err := parseRegistryToml(data) + if err != nil { + t.Fatalf("parseRegistryToml: %v", err) + } + if got := names["295af30f-e4ad-537b-8983-00126c2a3abe"]; got != "Revise" { + t.Errorf("names[Revise uuid] = %q, want Revise", got) + } + if got := names["91a5bcdd-55d7-5caf-9e0b-520d859cae80"]; got != "Plots" { + t.Errorf("names[Plots uuid] = %q, want Plots", got) + } + if len(names) != 2 { + t.Errorf("len(names) = %d, want 2", len(names)) + } +} + +func TestJuliaExtractRegistryNames(t *testing.T) { + registryToml := `name = "General" +[packages] +295af30f-e4ad-537b-8983-00126c2a3abe = { name = "Revise", path = "R/Revise" } +` + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + for _, f := range []struct{ name, body string }{ + {"R/Revise/Package.toml", "name = \"Revise\"\n"}, + {"Registry.toml", registryToml}, + } { + if err := tw.WriteHeader(&tar.Header{Name: f.name, Mode: 0o644, Size: int64(len(f.body))}); err != nil { + t.Fatalf("WriteHeader: %v", err) + } + if _, err := tw.Write([]byte(f.body)); err != nil { + t.Fatalf("Write: %v", err) + } + } + if err := tw.Close(); err != nil { + t.Fatalf("tar Close: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip Close: %v", err) + } + + names, err := extractRegistryNames(bytes.NewReader(buf.Bytes())) + if err != nil { + t.Fatalf("extractRegistryNames: %v", err) + } + if got := names["295af30f-e4ad-537b-8983-00126c2a3abe"]; got != "Revise" { + t.Errorf("names[Revise uuid] = %q, want Revise", got) + } +} + +func TestJuliaResolveName(t *testing.T) { + h := &JuliaHandler{ + proxy: &Proxy{Logger: slog.Default()}, + names: map[string]string{ + "295af30f-e4ad-537b-8983-00126c2a3abe": "Revise", + }, + } + + if got := h.resolveName("295af30f-e4ad-537b-8983-00126c2a3abe"); got != "Revise" { + t.Errorf("resolveName(known) = %q, want Revise", got) + } + if got := h.resolveName("00000000-0000-0000-0000-000000000000"); got != "00000000-0000-0000-0000-000000000000" { + t.Errorf("resolveName(unknown) = %q, want uuid fallback", got) + } +} + +func TestJuliaRoutesValidation(t *testing.T) { + h := NewJuliaHandler(&Proxy{Logger: slog.Default()}, "") + routes := h.Routes() + + tests := []struct { + path string + want int + }{ + {"/package/not-a-uuid/342327538ed6c1ec54c69fa145e7b6bf5934201e", http.StatusBadRequest}, + {"/package/295af30f-e4ad-537b-8983-00126c2a3abe/short", http.StatusBadRequest}, + {"/registry/295af30f-e4ad-537b-8983-00126c2a3abe/zzzz", http.StatusBadRequest}, + {"/artifact/nothex", http.StatusBadRequest}, + {"/nope", http.StatusNotFound}, + } + + for _, tt := range tests { + req := httptest.NewRequest(http.MethodGet, tt.path, nil) + rr := httptest.NewRecorder() + routes.ServeHTTP(rr, req) + if rr.Code != tt.want { + t.Errorf("GET %s = %d, want %d", tt.path, rr.Code, tt.want) + } + } +} diff --git a/internal/handler/maven.go b/internal/handler/maven.go index 86664a2..c423645 100644 --- a/internal/handler/maven.go +++ b/internal/handler/maven.go @@ -1,6 +1,7 @@ package handler import ( + "errors" "fmt" "net/http" "path" @@ -8,23 +9,33 @@ import ( ) const ( - mavenUpstream = "https://repo1.maven.org/maven2" - minMavenParts = 4 // group path segments + artifact + version + filename + mavenCentralUpstream = "https://repo1.maven.org/maven2" + gradlePluginPortalUpstream = "https://plugins.gradle.org/m2" + minMavenParts = 4 // group path segments + artifact + version + filename ) // MavenHandler handles Maven repository protocol requests. type MavenHandler struct { - proxy *Proxy - upstreamURL string - proxyURL string + proxy *Proxy + upstreamURL string + pluginPortalUpstreamURL string + proxyURL string } // NewMavenHandler creates a new Maven repository handler. -func NewMavenHandler(proxy *Proxy, proxyURL string) *MavenHandler { +func NewMavenHandler(proxy *Proxy, proxyURL, upstreamURL, pluginPortalUpstreamURL string) *MavenHandler { + if strings.TrimSpace(upstreamURL) == "" { + upstreamURL = mavenCentralUpstream + } + if strings.TrimSpace(pluginPortalUpstreamURL) == "" { + pluginPortalUpstreamURL = gradlePluginPortalUpstream + } + return &MavenHandler{ - proxy: proxy, - upstreamURL: mavenUpstream, - proxyURL: strings.TrimSuffix(proxyURL, "/"), + proxy: proxy, + upstreamURL: strings.TrimSuffix(upstreamURL, "/"), + pluginPortalUpstreamURL: strings.TrimSuffix(pluginPortalUpstreamURL, "/"), + proxyURL: strings.TrimSuffix(proxyURL, "/"), } } @@ -51,8 +62,7 @@ func (h *MavenHandler) handleRequest(w http.ResponseWriter, r *http.Request) { filename := path.Base(urlPath) if h.isMetadataFile(filename) { - cacheKey := strings.ReplaceAll(urlPath, "/", "_") - h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "maven", cacheKey, "*/*") + h.handleMetadata(w, r, urlPath) return } @@ -66,6 +76,32 @@ func (h *MavenHandler) handleRequest(w http.ResponseWriter, r *http.Request) { h.proxyUpstream(w, r) } +func (h *MavenHandler) handleMetadata(w http.ResponseWriter, r *http.Request, urlPath string) { + cacheKey := strings.ReplaceAll(urlPath, "/", "_") + upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, urlPath) + + body, contentType, err := h.proxy.FetchOrCacheMetadata(r.Context(), "maven", cacheKey, upstreamURL, "*/*") + if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + pluginPortalURL := fmt.Sprintf("%s/%s", h.pluginPortalUpstreamURL, urlPath) + h.proxy.Logger.Info("maven metadata unavailable in primary upstream, trying Gradle Plugin Portal", + "path", urlPath) + body, contentType, err = h.proxy.FetchOrCacheMetadata(r.Context(), "maven", cacheKey, pluginPortalURL, "*/*") + } + } + if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } + h.proxy.Logger.Error("metadata fetch failed", "error", err) + http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) + return + } + + h.proxy.writeMetadataCachedResponse(w, r, "maven", cacheKey, body, contentType) +} + // handleDownload serves an artifact file, fetching and caching from upstream if needed. func (h *MavenHandler) handleDownload(w http.ResponseWriter, r *http.Request, urlPath string) { // Parse Maven path: group/artifact/version/filename @@ -86,6 +122,18 @@ func (h *MavenHandler) handleDownload(w http.ResponseWriter, r *http.Request, ur result, err := h.proxy.GetOrFetchArtifactFromURL(r.Context(), "maven", name, version, filename, upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + pluginPortalURL := fmt.Sprintf("%s/%s", h.pluginPortalUpstreamURL, urlPath) + h.proxy.Logger.Info("maven artifact not found in primary upstream, trying Gradle Plugin Portal", + "group", group, "artifact", artifact, "version", version, "filename", filename) + result, err = h.proxy.GetOrFetchArtifactFromURL(r.Context(), "maven", name, version, filename, pluginPortalURL) + } + } + if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("failed to get artifact", "error", err) http.Error(w, "failed to fetch artifact", http.StatusBadGateway) return @@ -115,7 +163,7 @@ func (h *MavenHandler) parsePath(urlPath string) (group, artifact, version, file // isArtifactFile returns true if the filename looks like a Maven artifact. func (h *MavenHandler) isArtifactFile(filename string) bool { // Common artifact extensions - extensions := []string{".jar", ".war", ".ear", ".pom", ".aar", ".klib"} + extensions := []string{".jar", ".war", ".ear", ".pom", ".aar", ".klib", ".module"} for _, ext := range extensions { if strings.HasSuffix(filename, ext) { return true diff --git a/internal/handler/maven_test.go b/internal/handler/maven_test.go index df6917c..9ca5eb6 100644 --- a/internal/handler/maven_test.go +++ b/internal/handler/maven_test.go @@ -52,6 +52,7 @@ func TestMavenIsArtifactFile(t *testing.T) { }{ {"guava-32.1.3-jre.jar", true}, {"guava-32.1.3-jre.pom", true}, + {"guava-32.1.3-jre.module", true}, {"app-1.0.war", true}, {"lib-1.0.aar", true}, {"maven-metadata.xml", false}, @@ -65,3 +66,63 @@ func TestMavenIsArtifactFile(t *testing.T) { } } } + +func TestMavenIsMetadataFile(t *testing.T) { + h := &MavenHandler{} + + tests := []struct { + name string + filename string + want bool + }{ + { + name: "pom is artifact, not metadata", + filename: "com.diffplug.spotless.gradle.plugin-8.4.0.pom", + want: false, + }, + { + name: "pom checksum is metadata", + filename: "com.diffplug.spotless.gradle.plugin-8.4.0.pom.sha1", + want: true, + }, + { + name: "metadata file", + filename: "maven-metadata.xml", + want: true, + }, + { + name: "metadata checksum", + filename: "maven-metadata.xml.sha256", + want: true, + }, + { + name: "jar checksum is metadata", + filename: "guava-32.1.3-jre.jar.sha1", + want: true, + }, + { + name: "asc signature is metadata", + filename: "guava-32.1.3-jre.jar.asc", + want: true, + }, + { + name: "regular jar is not metadata", + filename: "guava-32.1.3-jre.jar", + want: false, + }, + { + name: "pom checksum is metadata", + filename: "guava-32.1.3-jre.pom.sha1", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := h.isMetadataFile(tt.filename) + if got != tt.want { + t.Errorf("isMetadataFile(%q) = %v, want %v", tt.filename, got, tt.want) + } + }) + } +} diff --git a/internal/handler/npm_test.go b/internal/handler/npm_test.go index 7db3539..bc1edde 100644 --- a/internal/handler/npm_test.go +++ b/internal/handler/npm_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" ) const testVersion100 = "1.0.0" diff --git a/internal/handler/nuget.go b/internal/handler/nuget.go index 1c022fb..40b8b5f 100644 --- a/internal/handler/nuget.go +++ b/internal/handler/nuget.go @@ -172,7 +172,7 @@ func (h *NuGetHandler) handleRegistration(w http.ResponseWriter, r *http.Request http.Error(w, "failed to create request", http.StatusInternalServerError) return } - req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set(headerAcceptEncoding, "gzip") resp, err := h.proxy.HTTPClient.Do(req) if err != nil { @@ -193,7 +193,7 @@ func (h *NuGetHandler) handleRegistration(w http.ResponseWriter, r *http.Request return } - body, err := ReadMetadata(resp.Body) + body, err := h.proxy.ReadMetadata(resp.Body) if err != nil { http.Error(w, "failed to read response", http.StatusInternalServerError) return @@ -338,8 +338,8 @@ func (h *NuGetHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { } // Copy accept-encoding for compression - if ae := r.Header.Get("Accept-Encoding"); ae != "" { - req.Header.Set("Accept-Encoding", ae) + if ae := r.Header.Get(headerAcceptEncoding); ae != "" { + req.Header.Set(headerAcceptEncoding, ae) } resp, err := h.proxy.HTTPClient.Do(req) diff --git a/internal/handler/nuget_test.go b/internal/handler/nuget_test.go index 68c9d22..b2164e5 100644 --- a/internal/handler/nuget_test.go +++ b/internal/handler/nuget_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" ) func nugetTestProxy() *Proxy { diff --git a/internal/handler/path_traversal_test.go b/internal/handler/path_traversal_test.go index 14d2218..5ad68a5 100644 --- a/internal/handler/path_traversal_test.go +++ b/internal/handler/path_traversal_test.go @@ -14,6 +14,15 @@ func TestContainsPathTraversal(t *testing.T) { {"pool/main/../../../etc/shadow", true}, {"pool/..hidden/file", false}, // ".." as a segment, not "..hidden" {"", false}, + {"%2e%2e/etc/passwd", true}, + {"%2e%2e%2fetc%2fpasswd", true}, + {"pool/%2e%2e/%2e%2e/etc/shadow", true}, + {"%2E%2E%2Fetc", true}, + {`..\\etc\\passwd`, true}, + {`pool\\..\\..\\etc`, true}, + {"%2e%2e%5cetc%5cpasswd", true}, + {"pool/%2e%2ehidden/file", false}, + {"pool/%zz/bad-encoding", false}, } for _, tt := range tests { diff --git a/internal/handler/pub_test.go b/internal/handler/pub_test.go index 2788714..8a4c098 100644 --- a/internal/handler/pub_test.go +++ b/internal/handler/pub_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" ) func TestPubRewriteMetadata(t *testing.T) { diff --git a/internal/handler/pypi_test.go b/internal/handler/pypi_test.go index 9e2ade0..2b58960 100644 --- a/internal/handler/pypi_test.go +++ b/internal/handler/pypi_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" "github.com/git-pkgs/registries/fetch" ) diff --git a/internal/handler/read_metadata_test.go b/internal/handler/read_metadata_test.go index 60c1cf2..b13bddb 100644 --- a/internal/handler/read_metadata_test.go +++ b/internal/handler/read_metadata_test.go @@ -7,9 +7,12 @@ import ( ) func TestReadMetadata(t *testing.T) { + const limit = 1024 + p := &Proxy{MetadataMaxSize: limit} + t.Run("small body", func(t *testing.T) { data := []byte("hello world") - got, err := ReadMetadata(bytes.NewReader(data)) + got, err := p.ReadMetadata(bytes.NewReader(data)) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -19,27 +22,39 @@ func TestReadMetadata(t *testing.T) { }) t.Run("exactly at limit", func(t *testing.T) { - data := make([]byte, maxMetadataSize) + data := make([]byte, limit) for i := range data { data[i] = 'x' } - got, err := ReadMetadata(bytes.NewReader(data)) + got, err := p.ReadMetadata(bytes.NewReader(data)) if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(got) != int(maxMetadataSize) { - t.Errorf("got length %d, want %d", len(got), maxMetadataSize) + if len(got) != limit { + t.Errorf("got length %d, want %d", len(got), limit) } }) t.Run("over limit returns error", func(t *testing.T) { - data := make([]byte, maxMetadataSize+100) + data := make([]byte, limit+100) for i := range data { data[i] = 'x' } - _, err := ReadMetadata(bytes.NewReader(data)) + _, err := p.ReadMetadata(bytes.NewReader(data)) if !errors.Is(err, ErrMetadataTooLarge) { t.Errorf("got error %v, want ErrMetadataTooLarge", err) } }) + + t.Run("zero limit uses default", func(t *testing.T) { + p := &Proxy{} + data := make([]byte, 1<<20) + got, err := p.ReadMetadata(bytes.NewReader(data)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != len(data) { + t.Errorf("got length %d, want %d", len(got), len(data)) + } + }) } diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index da8bde6..f23a5a9 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -120,6 +120,22 @@ var ( Help: "Number of currently active requests", }, ) + + IntegrityFailures = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "proxy_integrity_failures_total", + Help: "Cached artifacts that failed hash verification on read", + }, + []string{"ecosystem"}, + ) + + HealthProbeFailures = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "proxy_health_probe_failures_total", + Help: "Total number of storage health probe failures, by step (write|size|read|verify|delete).", + }, + []string{"step"}, + ) ) func init() { @@ -138,6 +154,8 @@ func init() { StorageOperationDuration, StorageErrors, ActiveRequests, + IntegrityFailures, + HealthProbeFailures, ) } @@ -178,6 +196,17 @@ func RecordStorageOperation(operation string, duration time.Duration) { StorageOperationDuration.WithLabelValues(operation).Observe(duration.Seconds()) } +// RecordIntegrityFailure increments the integrity failure counter. +func RecordIntegrityFailure(ecosystem string) { + IntegrityFailures.WithLabelValues(ecosystem).Inc() +} + +// RecordHealthProbeFailure increments the health probe failure counter. +// step is one of: "write", "size", "read", "verify", "delete". +func RecordHealthProbeFailure(step string) { + HealthProbeFailures.WithLabelValues(step).Inc() +} + // RecordStorageError increments storage error counter. func RecordStorageError(operation string) { StorageErrors.WithLabelValues(operation).Inc() diff --git a/internal/server/api.go b/internal/server/api.go index 903dc22..ddb9ca7 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -140,10 +140,14 @@ type BulkResponse struct { func (h *APIHandler) HandlePackagePath(w http.ResponseWriter, r *http.Request) { ecosystem := chi.URLParam(r, "ecosystem") wildcard := chi.URLParam(r, "*") + if err := validatePackagePath(wildcard); err != nil { + badRequest(w, err.Error()) + return + } segments := splitWildcardPath(wildcard) if ecosystem == "" || len(segments) == 0 { - http.Error(w, "ecosystem and name are required", http.StatusBadRequest) + badRequest(w, "ecosystem and name are required") return } @@ -190,12 +194,12 @@ func (h *APIHandler) HandlePackagePath(w http.ResponseWriter, r *http.Request) { func (h *APIHandler) getPackage(w http.ResponseWriter, r *http.Request, ecosystem, name string) { info, err := h.enrichment.EnrichPackage(r.Context(), ecosystem, name) if err != nil { - http.Error(w, "failed to enrich package", http.StatusInternalServerError) + writeError(w, http.StatusBadGateway, ErrCodeUpstream, "failed to enrich package") return } if info == nil { - http.Error(w, "package not found", http.StatusNotFound) + notFound(w, "package not found") return } @@ -217,7 +221,7 @@ func (h *APIHandler) getPackage(w http.ResponseWriter, r *http.Request, ecosyste func (h *APIHandler) getVersion(w http.ResponseWriter, r *http.Request, ecosystem, name, version string) { result, err := h.enrichment.EnrichFull(r.Context(), ecosystem, name, version) if err != nil { - http.Error(w, "failed to enrich version", http.StatusInternalServerError) + writeError(w, http.StatusBadGateway, ErrCodeUpstream, "failed to enrich version") return } @@ -274,10 +278,14 @@ func (h *APIHandler) getVersion(w http.ResponseWriter, r *http.Request, ecosyste func (h *APIHandler) HandleVulnsPath(w http.ResponseWriter, r *http.Request) { ecosystem := chi.URLParam(r, "ecosystem") wildcard := chi.URLParam(r, "*") + if err := validatePackagePath(wildcard); err != nil { + badRequest(w, err.Error()) + return + } segments := splitWildcardPath(wildcard) if ecosystem == "" || len(segments) == 0 { - http.Error(w, "ecosystem and name are required", http.StatusBadRequest) + badRequest(w, "ecosystem and name are required") return } @@ -298,7 +306,7 @@ func (h *APIHandler) HandleVulnsPath(w http.ResponseWriter, r *http.Request) { vulns, err := h.enrichment.CheckVulnerabilities(r.Context(), ecosystem, name, version) if err != nil { - http.Error(w, "failed to check vulnerabilities", http.StatusInternalServerError) + writeError(w, http.StatusBadGateway, ErrCodeUpstream, "failed to check vulnerabilities") return } @@ -330,19 +338,19 @@ func (h *APIHandler) HandleVulnsPath(w http.ResponseWriter, r *http.Request) { // @Produce json // @Param request body OutdatedRequest true "Packages to check" // @Success 200 {object} OutdatedResponse -// @Failure 400 {string} string -// @Failure 500 {string} string +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /api/outdated [post] func (h *APIHandler) HandleOutdated(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) var req OutdatedRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid request body", http.StatusBadRequest) + badRequest(w, "invalid request body") return } if len(req.Packages) == 0 { - http.Error(w, "packages list is required", http.StatusBadRequest) + badRequest(w, "packages list is required") return } @@ -376,19 +384,19 @@ func (h *APIHandler) HandleOutdated(w http.ResponseWriter, r *http.Request) { // @Produce json // @Param request body BulkRequest true "PURLs" // @Success 200 {object} BulkResponse -// @Failure 400 {string} string -// @Failure 500 {string} string +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /api/bulk [post] func (h *APIHandler) HandleBulkLookup(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) var req BulkRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid request body", http.StatusBadRequest) + badRequest(w, "invalid request body") return } if len(req.PURLs) == 0 { - http.Error(w, "purls list is required", http.StatusBadRequest) + badRequest(w, "purls list is required") return } @@ -476,15 +484,15 @@ type SearchPackageResult struct { // @Param q query string true "Query" // @Param ecosystem query string false "Ecosystem" // @Success 200 {object} SearchResponse -// @Failure 400 {string} string -// @Failure 500 {string} string +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /api/search [get] func (h *APIHandler) HandleSearch(w http.ResponseWriter, r *http.Request) { query := r.URL.Query().Get("q") ecosystem := r.URL.Query().Get("ecosystem") if query == "" { - http.Error(w, "query parameter 'q' is required", http.StatusBadRequest) + badRequest(w, "query parameter 'q' is required") return } @@ -494,7 +502,7 @@ func (h *APIHandler) HandleSearch(w http.ResponseWriter, r *http.Request) { // Search in database results, err := h.db.SearchPackages(query, ecosystem, limit, (page-1)*limit) if err != nil { - http.Error(w, "search failed", http.StatusInternalServerError) + internalError(w, "search failed") return } @@ -538,7 +546,7 @@ func (h *APIHandler) HandleSearch(w http.ResponseWriter, r *http.Request) { func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(v); err != nil { - http.Error(w, "failed to encode response", http.StatusInternalServerError) + internalError(w, "failed to encode response") } } @@ -573,8 +581,8 @@ type PackageListResult struct { // @Param ecosystem query string false "Ecosystem" // @Param sort query string false "Sort" Enums(hits,name,size,cached_at,ecosystem,vulns) // @Success 200 {object} PackagesListResponse -// @Failure 400 {string} string -// @Failure 500 {string} string +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /api/packages [get] func (h *APIHandler) HandlePackagesList(w http.ResponseWriter, r *http.Request) { ecosystem := r.URL.Query().Get("ecosystem") @@ -592,7 +600,7 @@ func (h *APIHandler) HandlePackagesList(w http.ResponseWriter, r *http.Request) "vulns": true, } if !validSorts[sortBy] { - http.Error(w, "invalid sort parameter", http.StatusBadRequest) + badRequest(w, "invalid sort parameter") return } @@ -601,7 +609,7 @@ func (h *APIHandler) HandlePackagesList(w http.ResponseWriter, r *http.Request) packages, err := h.db.ListCachedPackages(ecosystem, sortBy, limit, (page-1)*limit) if err != nil { - http.Error(w, "failed to list packages", http.StatusInternalServerError) + internalError(w, "failed to list packages") return } diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 548f324..0494b2f 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" "github.com/git-pkgs/proxy/internal/database" @@ -48,6 +49,35 @@ func TestHandlePackagePath_MissingParams(t *testing.T) { } } +func TestHandlePackagePath_InvalidName(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + svc := enrichment.New(logger) + h := NewAPIHandler(svc, nil) + + r := chi.NewRouter() + r.Get("/api/package/{ecosystem}/*", h.HandlePackagePath) + + tests := []struct { + name string + path string + }{ + {"null byte", "/api/package/npm/lodash%00"}, + {"too long", "/api/package/npm/" + strings.Repeat("a", maxPackagePathLen+1)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + }) + } +} + func TestHandleVulnsPath_MissingParams(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) svc := enrichment.New(logger) diff --git a/internal/server/browse.go b/internal/server/browse.go index 6f718cd..be2b04a 100644 --- a/internal/server/browse.go +++ b/internal/server/browse.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "encoding/json" "fmt" "io" @@ -18,6 +17,11 @@ import ( const contentTypePlainText = "text/plain; charset=utf-8" +// maxBrowseArchiveSize caps how much data openArchive will buffer for +// prefix detection. Artifacts larger than this are rejected to prevent +// memory exhaustion from a single request. +const maxBrowseArchiveSize = 512 << 20 // 512 MB + // archiveFilename returns a filename suitable for archive format detection. // Some ecosystems (e.g. composer) store artifacts with bare hash filenames // that have no extension. This adds .zip when the original has no extension @@ -64,26 +68,27 @@ func detectSingleRootDir(reader archives.Reader) string { func openArchive(filename string, content io.Reader, ecosystem string) (archives.Reader, error) { //nolint:ireturn // wraps multiple archive implementations fname := archiveFilename(filename) - // npm always uses package/ prefix - if ecosystem == "npm" { - return archives.OpenWithPrefix(fname, content, "package/") - } - - // Read content into memory so we can scan then wrap with prefix - data, err := io.ReadAll(content) + limited := io.LimitReader(content, maxBrowseArchiveSize+1) + data, err := io.ReadAll(limited) if err != nil { return nil, fmt.Errorf("reading artifact: %w", err) } + if int64(len(data)) > maxBrowseArchiveSize { + return nil, fmt.Errorf("artifact too large for browsing (%d bytes)", len(data)) + } - // Open once to detect root prefix - probe, err := archives.Open(fname, bytes.NewReader(data)) + if ecosystem == "npm" { + return archives.OpenBytesWithPrefix(fname, data, "package/") + } + + probe, err := archives.OpenBytes(fname, data) if err != nil { return nil, err } prefix := detectSingleRootDir(probe) _ = probe.Close() - return archives.OpenWithPrefix(fname, bytes.NewReader(data), prefix) + return archives.OpenBytesWithPrefix(fname, data, prefix) } // BrowseListResponse contains the file listing for a directory in an archives. @@ -112,8 +117,8 @@ type BrowseFileInfo struct { // @Param version path string true "Version" // @Param path query string false "Directory path inside the archive" // @Success 200 {object} BrowseListResponse -// @Failure 404 {string} string -// @Failure 500 {string} string +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /api/browse/{ecosystem}/{name}/{version} [get] // handleBrowsePath dispatches /api/browse/{ecosystem}/* to the appropriate browse handler. // It resolves namespaced package names by consulting the database. @@ -125,10 +130,14 @@ type BrowseFileInfo struct { func (s *Server) handleBrowsePath(w http.ResponseWriter, r *http.Request) { ecosystem := chi.URLParam(r, "ecosystem") wildcard := chi.URLParam(r, "*") + if err := validatePackagePath(wildcard); err != nil { + badRequest(w, err.Error()) + return + } segments := splitWildcardPath(wildcard) if ecosystem == "" || len(segments) < 2 { - http.Error(w, "ecosystem, name, and version required", http.StatusBadRequest) + badRequest(w, "ecosystem, name, and version required") return } @@ -152,7 +161,7 @@ func (s *Server) handleBrowsePath(w http.ResponseWriter, r *http.Request) { rest = nameVersionSegments[len(nameVersionSegments)-1:] } if len(rest) != 1 { - http.Error(w, "not found", http.StatusNotFound) + notFound(w, "not found") return } s.browseFile(w, r, ecosystem, name, rest[0], filePath) @@ -166,7 +175,7 @@ func (s *Server) handleBrowsePath(w http.ResponseWriter, r *http.Request) { rest = segments[len(segments)-1:] } if len(rest) != 1 { - http.Error(w, "not found", http.StatusNotFound) + notFound(w, "not found") return } s.browseList(w, r, ecosystem, name, rest[0]) @@ -177,10 +186,14 @@ func (s *Server) handleBrowsePath(w http.ResponseWriter, r *http.Request) { func (s *Server) handleComparePath(w http.ResponseWriter, r *http.Request) { ecosystem := chi.URLParam(r, "ecosystem") wildcard := chi.URLParam(r, "*") + if err := validatePackagePath(wildcard); err != nil { + badRequest(w, err.Error()) + return + } segments := splitWildcardPath(wildcard) if ecosystem == "" || len(segments) < 3 { - http.Error(w, "ecosystem, name, fromVersion, and toVersion required", http.StatusBadRequest) + badRequest(w, "ecosystem, name, fromVersion, and toVersion required") return } @@ -200,12 +213,12 @@ func (s *Server) browseList(w http.ResponseWriter, r *http.Request, ecosystem, n versionPURL := purl.MakePURLString(ecosystem, name, version) artifacts, err := s.db.GetArtifactsByVersionPURL(versionPURL) if err != nil { - http.Error(w, "version not found", http.StatusNotFound) + notFound(w, "version not found") return } if len(artifacts) == 0 { - http.Error(w, "no artifacts cached", http.StatusNotFound) + notFound(w, "no artifacts cached") return } @@ -219,7 +232,7 @@ func (s *Server) browseList(w http.ResponseWriter, r *http.Request, ecosystem, n } if cachedArtifact == nil { - http.Error(w, "artifact not cached", http.StatusNotFound) + notFound(w, "artifact not cached") return } @@ -227,7 +240,7 @@ func (s *Server) browseList(w http.ResponseWriter, r *http.Request, ecosystem, n artifactReader, err := s.storage.Open(r.Context(), cachedArtifact.StoragePath.String) if err != nil { s.logger.Error("failed to read artifact from storage", "error", err) - http.Error(w, "failed to read artifact", http.StatusInternalServerError) + internalError(w, "failed to read artifact") return } defer func() { _ = artifactReader.Close() }() @@ -236,7 +249,7 @@ func (s *Server) browseList(w http.ResponseWriter, r *http.Request, ecosystem, n archiveReader, err := openArchive(cachedArtifact.Filename, artifactReader, ecosystem) if err != nil { s.logger.Error("failed to open archive", "error", err, "filename", cachedArtifact.Filename) - http.Error(w, "failed to open archive", http.StatusInternalServerError) + internalError(w, "failed to open archive") return } defer func() { _ = archiveReader.Close() }() @@ -245,7 +258,7 @@ func (s *Server) browseList(w http.ResponseWriter, r *http.Request, ecosystem, n files, err := archiveReader.ListDir(dirPath) if err != nil { s.logger.Error("failed to list directory", "error", err, "path", dirPath) - http.Error(w, "failed to list directory", http.StatusInternalServerError) + internalError(w, "failed to list directory") return } @@ -280,13 +293,13 @@ func (s *Server) browseList(w http.ResponseWriter, r *http.Request, ecosystem, n // @Param version path string true "Version" // @Param filepath path string true "File path inside the archive" // @Success 200 {file} file -// @Failure 400 {string} string -// @Failure 404 {string} string -// @Failure 500 {string} string +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /api/browse/{ecosystem}/{name}/{version}/file/{filepath} [get] func (s *Server) browseFile(w http.ResponseWriter, r *http.Request, ecosystem, name, version, filePath string) { if filePath == "" { - http.Error(w, "file path required", http.StatusBadRequest) + badRequest(w, "file path required") return } @@ -294,12 +307,12 @@ func (s *Server) browseFile(w http.ResponseWriter, r *http.Request, ecosystem, n versionPURL := purl.MakePURLString(ecosystem, name, version) artifacts, err := s.db.GetArtifactsByVersionPURL(versionPURL) if err != nil { - http.Error(w, "version not found", http.StatusNotFound) + notFound(w, "version not found") return } if len(artifacts) == 0 { - http.Error(w, "no artifacts cached", http.StatusNotFound) + notFound(w, "no artifacts cached") return } @@ -313,7 +326,7 @@ func (s *Server) browseFile(w http.ResponseWriter, r *http.Request, ecosystem, n } if cachedArtifact == nil { - http.Error(w, "artifact not cached", http.StatusNotFound) + notFound(w, "artifact not cached") return } @@ -321,7 +334,7 @@ func (s *Server) browseFile(w http.ResponseWriter, r *http.Request, ecosystem, n artifactReader, err := s.storage.Open(r.Context(), cachedArtifact.StoragePath.String) if err != nil { s.logger.Error("failed to read artifact from storage", "error", err) - http.Error(w, "failed to read artifact", http.StatusInternalServerError) + internalError(w, "failed to read artifact") return } defer func() { _ = artifactReader.Close() }() @@ -330,7 +343,7 @@ func (s *Server) browseFile(w http.ResponseWriter, r *http.Request, ecosystem, n archiveReader, err := openArchive(cachedArtifact.Filename, artifactReader, ecosystem) if err != nil { s.logger.Error("failed to open archive", "error", err, "filename", cachedArtifact.Filename) - http.Error(w, "failed to open archive", http.StatusInternalServerError) + internalError(w, "failed to open archive") return } defer func() { _ = archiveReader.Close() }() @@ -339,20 +352,20 @@ func (s *Server) browseFile(w http.ResponseWriter, r *http.Request, ecosystem, n fileReader, err := archiveReader.Extract(filePath) if err != nil { if strings.Contains(err.Error(), "not found") { - http.Error(w, "file not found", http.StatusNotFound) + notFound(w, "file not found") return } s.logger.Error("failed to extract file", "error", err, "path", filePath) - http.Error(w, "failed to extract file", http.StatusInternalServerError) + internalError(w, "failed to extract file") return } defer func() { _ = fileReader.Close() }() - // Set content type based on file extension contentType := detectContentType(filePath) w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Security-Policy", "sandbox") + w.Header().Set("X-Content-Type-Options", "nosniff") - // Set filename for download _, filename := path.Split(filePath) w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", filename)) @@ -368,8 +381,8 @@ func detectContentType(filename string) string { // Text formats case ".txt", ".md", ".markdown": return contentTypePlainText - case ".html", ".htm": - return "text/html; charset=utf-8" + case ".html", ".htm", ".xhtml": + return contentTypePlainText case ".css": return "text/css; charset=utf-8" case ".js", ".mjs": @@ -423,7 +436,7 @@ func detectContentType(filename string) string { case ".gif": return "image/gif" case ".svg": - return "image/svg+xml" + return contentTypePlainText case ".ico": return "image/x-icon" @@ -483,8 +496,8 @@ type BrowseSourceData struct { // @Param fromVersion path string true "From version" // @Param toVersion path string true "To version" // @Success 200 {object} map[string]any -// @Failure 404 {string} string -// @Failure 500 {string} string +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /api/compare/{ecosystem}/{name}/{fromVersion}/{toVersion} [get] func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, name, fromVersion, toVersion string) { // Get artifacts for both versions @@ -493,13 +506,13 @@ func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, fromArtifacts, err := s.db.GetArtifactsByVersionPURL(fromPURL) if err != nil || len(fromArtifacts) == 0 { - http.Error(w, "from version not found or not cached", http.StatusNotFound) + notFound(w, "from version not found or not cached") return } toArtifacts, err := s.db.GetArtifactsByVersionPURL(toPURL) if err != nil || len(toArtifacts) == 0 { - http.Error(w, "to version not found or not cached", http.StatusNotFound) + notFound(w, "to version not found or not cached") return } @@ -519,7 +532,7 @@ func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, } if fromArtifact == nil || toArtifact == nil { - http.Error(w, "one or both versions not cached", http.StatusNotFound) + notFound(w, "one or both versions not cached") return } @@ -527,7 +540,7 @@ func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, fromReader, err := s.storage.Open(r.Context(), fromArtifact.StoragePath.String) if err != nil { s.logger.Error("failed to open from artifact", "error", err) - http.Error(w, "failed to read from version", http.StatusInternalServerError) + internalError(w, "failed to read from version") return } defer func() { _ = fromReader.Close() }() @@ -535,7 +548,7 @@ func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, toReader, err := s.storage.Open(r.Context(), toArtifact.StoragePath.String) if err != nil { s.logger.Error("failed to open to artifact", "error", err) - http.Error(w, "failed to read to version", http.StatusInternalServerError) + internalError(w, "failed to read to version") return } defer func() { _ = toReader.Close() }() @@ -543,7 +556,7 @@ func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, fromArchive, err := openArchive(fromArtifact.Filename, fromReader, ecosystem) if err != nil { s.logger.Error("failed to open from archive", "error", err) - http.Error(w, "failed to open from archive", http.StatusInternalServerError) + internalError(w, "failed to open from archive") return } defer func() { _ = fromArchive.Close() }() @@ -551,7 +564,7 @@ func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, toArchive, err := openArchive(toArtifact.Filename, toReader, ecosystem) if err != nil { s.logger.Error("failed to open to archive", "error", err) - http.Error(w, "failed to open to archive", http.StatusInternalServerError) + internalError(w, "failed to open to archive") return } defer func() { _ = toArchive.Close() }() @@ -560,7 +573,7 @@ func (s *Server) compareDiff(w http.ResponseWriter, r *http.Request, ecosystem, result, err := diff.Compare(fromArchive, toArchive) if err != nil { s.logger.Error("failed to generate diff", "error", err) - http.Error(w, "failed to generate diff", http.StatusInternalServerError) + internalError(w, "failed to generate diff") return } diff --git a/internal/server/browse_bench_test.go b/internal/server/browse_bench_test.go new file mode 100644 index 0000000..03f3f02 --- /dev/null +++ b/internal/server/browse_bench_test.go @@ -0,0 +1,57 @@ +package server + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "fmt" + "math/rand" + "testing" +) + +func createBenchTarGz(prefix string, fileCount, fileSize int) []byte { + rnd := rand.New(rand.NewSource(1)) //nolint:gosec + buf := new(bytes.Buffer) + gw := gzip.NewWriter(buf) + tw := tar.NewWriter(gw) + + payload := make([]byte, fileSize) + for i := range fileCount { + rnd.Read(payload) + _ = tw.WriteHeader(&tar.Header{ + Name: fmt.Sprintf("%sfile%04d.dat", prefix, i), + Size: int64(fileSize), + Mode: 0644, + }) + _, _ = tw.Write(payload) + } + _ = tw.Close() + _ = gw.Close() + return buf.Bytes() +} + +func BenchmarkOpenArchive(b *testing.B) { + cases := []struct { + name string + ecosystem string + filename string + data []byte + }{ + {"npm", "npm", "pkg.tgz", createBenchTarGz("package/", 64, 16*1024)}, + {"go", "go", "v1.2.3.tar.gz", createBenchTarGz("repo-abc123/", 64, 16*1024)}, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + b.SetBytes(int64(len(tc.data))) + b.ReportAllocs() + for b.Loop() { + r, err := openArchive(tc.filename, bytes.NewReader(tc.data), tc.ecosystem) + if err != nil { + b.Fatal(err) + } + _ = r.Close() + } + }) + } +} diff --git a/internal/server/browse_test.go b/internal/server/browse_test.go index 2706e90..28f08da 100644 --- a/internal/server/browse_test.go +++ b/internal/server/browse_test.go @@ -179,6 +179,10 @@ func TestDetectContentType(t *testing.T) { {"file.go", "text/x-go; charset=utf-8"}, {"file.py", "text/x-python; charset=utf-8"}, {"file.rs", "text/x-rust; charset=utf-8"}, + {"file.html", contentTypePlainText}, + {"file.htm", contentTypePlainText}, + {"file.xhtml", contentTypePlainText}, + {"file.svg", contentTypePlainText}, {"file.png", "image/png"}, {"file.jpg", "image/jpeg"}, {"README", contentTypePlainText}, @@ -198,6 +202,19 @@ func TestDetectContentType(t *testing.T) { } } +func TestOpenArchiveSizeLimit(t *testing.T) { + huge := bytes.Repeat([]byte("x"), int(maxBrowseArchiveSize)+1) + for _, eco := range []string{"npm", "go"} { + _, err := openArchive("test.tar.gz", bytes.NewReader(huge), eco) + if err == nil { + t.Fatalf("%s: expected error for oversized archive, got nil", eco) + } + if !strings.Contains(err.Error(), "too large") { + t.Fatalf("%s: expected 'too large' error, got: %v", eco, err) + } + } +} + func TestIsLikelyText(t *testing.T) { tests := []struct { filename string diff --git a/internal/server/dashboard.go b/internal/server/dashboard.go index b935628..1de294c 100644 --- a/internal/server/dashboard.go +++ b/internal/server/dashboard.go @@ -127,6 +127,7 @@ func supportedEcosystems() []string { "gem", "golang", "hex", + "julia", "maven", "npm", "nuget", @@ -176,6 +177,8 @@ func ecosystemBadgeClasses(ecosystem string) string { return base + " bg-green-100 text-green-700 dark:bg-green-900/50 dark:text-green-300" case "cran": return base + " bg-slate-100 text-slate-700 dark:bg-slate-800 dark:text-slate-300" + case "julia": + return base + " bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300" case "oci": return base + " bg-sky-100 text-sky-700 dark:bg-sky-900/50 dark:text-sky-300" case "deb": @@ -286,6 +289,20 @@ index-url = ` + baseURL + `/pypi/simple/`), </mirror> </mirrors> </settings>`), + }, + { + ID: "gradle", + Name: "Gradle Build Cache", + Language: "Java/Kotlin", + Endpoint: "/gradle/", + Instructions: template.HTML(`

Configure Gradle to use the proxy for HttpBuildCache:

+
// In settings.gradle(.kts)
+buildCache {
+  remote<HttpBuildCache> {
+    url = uri("` + baseURL + `/gradle/")
+    push = true
+  }
+}
`), }, { ID: "nuget", @@ -363,6 +380,17 @@ local({ r["CRAN"] <- "` + baseURL + `/cran" options(repos = r) })`), + }, + { + ID: "julia", + Name: "Julia", + Language: "Julia", + Endpoint: "/julia/", + Instructions: template.HTML(`

Set the Pkg server before starting Julia:

+
export JULIA_PKG_SERVER=` + baseURL + `/julia
+

Or inside a running session:

+
ENV["JULIA_PKG_SERVER"] = "` + baseURL + `/julia"
+using Pkg; Pkg.update()
`), }, { ID: "oci", diff --git a/internal/server/errors.go b/internal/server/errors.go new file mode 100644 index 0000000..474ecd7 --- /dev/null +++ b/internal/server/errors.go @@ -0,0 +1,42 @@ +package server + +import ( + "encoding/json" + "net/http" +) + +// Error codes returned in API error responses. These are stable identifiers +// that clients can match on; the message text is for humans and may change. +const ( + ErrCodeBadRequest = "BAD_REQUEST" + ErrCodeNotFound = "NOT_FOUND" + ErrCodeUpstream = "UPSTREAM_ERROR" + ErrCodeInternal = "INTERNAL_ERROR" +) + +// ErrorResponse is the JSON body returned for API errors. +type ErrorResponse struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// writeError sends a JSON error response with the given status, code and +// user-facing message. Internal error details should be logged separately +// by the caller, never passed as the message. +func writeError(w http.ResponseWriter, status int, code, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(ErrorResponse{Code: code, Message: message}) +} + +func badRequest(w http.ResponseWriter, message string) { + writeError(w, http.StatusBadRequest, ErrCodeBadRequest, message) +} + +func notFound(w http.ResponseWriter, message string) { + writeError(w, http.StatusNotFound, ErrCodeNotFound, message) +} + +func internalError(w http.ResponseWriter, message string) { + writeError(w, http.StatusInternalServerError, ErrCodeInternal, message) +} diff --git a/internal/server/errors_test.go b/internal/server/errors_test.go new file mode 100644 index 0000000..c660ae2 --- /dev/null +++ b/internal/server/errors_test.go @@ -0,0 +1,93 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteError(t *testing.T) { + tests := []struct { + name string + fn func(w http.ResponseWriter) + status int + code string + message string + }{ + { + name: "badRequest", + fn: func(w http.ResponseWriter) { badRequest(w, "missing field") }, + status: http.StatusBadRequest, + code: ErrCodeBadRequest, + message: "missing field", + }, + { + name: "notFound", + fn: func(w http.ResponseWriter) { notFound(w, "package not found") }, + status: http.StatusNotFound, + code: ErrCodeNotFound, + message: "package not found", + }, + { + name: "internalError", + fn: func(w http.ResponseWriter) { internalError(w, "boom") }, + status: http.StatusInternalServerError, + code: ErrCodeInternal, + message: "boom", + }, + { + name: "upstream", + fn: func(w http.ResponseWriter) { + writeError(w, http.StatusBadGateway, ErrCodeUpstream, "registry unreachable") + }, + status: http.StatusBadGateway, + code: ErrCodeUpstream, + message: "registry unreachable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + tt.fn(w) + + if w.Code != tt.status { + t.Errorf("status = %d, want %d", w.Code, tt.status) + } + if ct := w.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want application/json", ct) + } + + var resp ErrorResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("response body is not valid JSON: %v (body: %q)", err, w.Body.String()) + } + if resp.Code != tt.code { + t.Errorf("code = %q, want %q", resp.Code, tt.code) + } + if resp.Message != tt.message { + t.Errorf("message = %q, want %q", resp.Message, tt.message) + } + }) + } +} + +func TestAPIErrorResponseShape(t *testing.T) { + w := httptest.NewRecorder() + badRequest(w, "x") + + var raw map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &raw); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if _, ok := raw["code"]; !ok { + t.Error("response missing 'code' field") + } + if _, ok := raw["message"]; !ok { + t.Error("response missing 'message' field") + } + if len(raw) != 2 { + t.Errorf("response has unexpected fields: %v", raw) + } +} diff --git a/internal/server/gradle_cache_eviction.go b/internal/server/gradle_cache_eviction.go new file mode 100644 index 0000000..7f546d1 --- /dev/null +++ b/internal/server/gradle_cache_eviction.go @@ -0,0 +1,149 @@ +package server + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/git-pkgs/proxy/internal/storage" +) + +const gradleBuildCacheStoragePrefix = "_gradle/http-build-cache/" + +type gradleBuildCacheLister interface { + ListPrefix(ctx context.Context, prefix string) ([]storage.ObjectInfo, error) +} + +func (s *Server) startGradleBuildCacheEviction(ctx context.Context) { + maxAge := s.cfg.ParseGradleBuildCacheMaxAge() + maxSize := s.cfg.ParseGradleBuildCacheMaxSize() + if maxAge <= 0 && maxSize <= 0 { + return + } + + lister, ok := s.storage.(gradleBuildCacheLister) + if !ok { + s.logger.Warn("gradle cache eviction is enabled, but storage backend cannot list objects") + return + } + + interval := s.cfg.ParseGradleBuildCacheSweepInterval() + s.logger.Info("gradle cache eviction enabled", + "max_age", maxAge, + "max_size_bytes", maxSize, + "interval", interval) + + sweep := func() { + deletedCount, freedBytes, err := sweepGradleBuildCache(ctx, s.storage, lister, maxAge, maxSize, time.Now()) + if err != nil { + s.logger.Warn("gradle cache eviction sweep failed", "error", err) + return + } + if deletedCount > 0 { + s.logger.Info("gradle cache eviction sweep completed", + "deleted_entries", deletedCount, + "freed_bytes", freedBytes) + } + } + + sweep() + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sweep() + } + } + }() +} + +func sweepGradleBuildCache( + ctx context.Context, + store storage.Storage, + lister gradleBuildCacheLister, + maxAge time.Duration, + maxSize int64, + now time.Time, +) (int, int64, error) { + entries, err := lister.ListPrefix(ctx, gradleBuildCacheStoragePrefix) + if err != nil { + return 0, 0, fmt.Errorf("listing gradle cache entries: %w", err) + } + + if len(entries) == 0 { + return 0, 0, nil + } + + sortOldestFirst(entries) + + deletedCount := 0 + freedBytes := int64(0) + var firstDeleteErr error + + deleteEntry := func(entry storage.ObjectInfo) bool { + if err := store.Delete(ctx, entry.Path); err != nil { + if firstDeleteErr == nil { + firstDeleteErr = err + } + return false + } + deletedCount++ + freedBytes += entry.Size + return true + } + + remaining := entries + if maxAge > 0 { + cutoff := now.Add(-maxAge) + kept := make([]storage.ObjectInfo, 0, len(entries)) + + for _, entry := range entries { + if !entry.ModTime.IsZero() && entry.ModTime.Before(cutoff) { + if deleteEntry(entry) { + continue + } + } + kept = append(kept, entry) + } + + remaining = kept + } + + if maxSize > 0 { + totalSize := int64(0) + for _, entry := range remaining { + totalSize += entry.Size + } + + for _, entry := range remaining { + if totalSize <= maxSize { + break + } + if deleteEntry(entry) { + totalSize -= entry.Size + } + } + } + + if firstDeleteErr != nil { + return deletedCount, freedBytes, fmt.Errorf("deleting gradle cache entries: %w", firstDeleteErr) + } + + return deletedCount, freedBytes, nil +} + +func sortOldestFirst(entries []storage.ObjectInfo) { + sort.Slice(entries, func(i, j int) bool { + if entries[i].ModTime.Equal(entries[j].ModTime) { + return entries[i].Path < entries[j].Path + } + return entries[i].ModTime.Before(entries[j].ModTime) + }) +} diff --git a/internal/server/gradle_cache_eviction_test.go b/internal/server/gradle_cache_eviction_test.go new file mode 100644 index 0000000..4e97507 --- /dev/null +++ b/internal/server/gradle_cache_eviction_test.go @@ -0,0 +1,138 @@ +package server + +import ( + "bytes" + "context" + "io" + "strings" + "testing" + "time" + + "github.com/git-pkgs/proxy/internal/storage" +) + +type fakeGradleCacheStore struct { + objects map[string]storage.ObjectInfo +} + +func newFakeGradleCacheStore(objects []storage.ObjectInfo) *fakeGradleCacheStore { + m := make(map[string]storage.ObjectInfo, len(objects)) + for _, obj := range objects { + m[obj.Path] = obj + } + return &fakeGradleCacheStore{objects: m} +} + +func (s *fakeGradleCacheStore) Store(_ context.Context, path string, r io.Reader) (int64, string, error) { + data, _ := io.ReadAll(r) + s.objects[path] = storage.ObjectInfo{Path: path, Size: int64(len(data)), ModTime: time.Now()} + return int64(len(data)), "", nil +} + +func (s *fakeGradleCacheStore) Open(_ context.Context, path string) (io.ReadCloser, error) { + obj, ok := s.objects[path] + if !ok { + return nil, storage.ErrNotFound + } + return io.NopCloser(bytes.NewReader(make([]byte, obj.Size))), nil +} + +func (s *fakeGradleCacheStore) Exists(_ context.Context, path string) (bool, error) { + _, ok := s.objects[path] + return ok, nil +} + +func (s *fakeGradleCacheStore) Delete(_ context.Context, path string) error { + delete(s.objects, path) + return nil +} + +func (s *fakeGradleCacheStore) Size(_ context.Context, path string) (int64, error) { + obj, ok := s.objects[path] + if !ok { + return 0, storage.ErrNotFound + } + return obj.Size, nil +} + +func (s *fakeGradleCacheStore) SignedURL(_ context.Context, _ string, _ time.Duration) (string, error) { + return "", storage.ErrSignedURLUnsupported +} + +func (s *fakeGradleCacheStore) UsedSpace(_ context.Context) (int64, error) { + var total int64 + for _, obj := range s.objects { + total += obj.Size + } + return total, nil +} + +func (s *fakeGradleCacheStore) URL() string { return "mem://" } + +func (s *fakeGradleCacheStore) Close() error { return nil } + +func (s *fakeGradleCacheStore) ListPrefix(_ context.Context, prefix string) ([]storage.ObjectInfo, error) { + objects := make([]storage.ObjectInfo, 0) + for _, obj := range s.objects { + if strings.HasPrefix(obj.Path, prefix) { + objects = append(objects, obj) + } + } + return objects, nil +} + +func TestSweepGradleBuildCache_MaxAge(t *testing.T) { + now := time.Date(2026, 4, 27, 12, 0, 0, 0, time.UTC) + store := newFakeGradleCacheStore([]storage.ObjectInfo{ + {Path: "_gradle/http-build-cache/old", Size: 10, ModTime: now.Add(-48 * time.Hour)}, + {Path: "_gradle/http-build-cache/new", Size: 10, ModTime: now.Add(-2 * time.Hour)}, + }) + + deleted, freed, err := sweepGradleBuildCache(context.Background(), store, store, 24*time.Hour, 0, now) + if err != nil { + t.Fatalf("sweepGradleBuildCache() error = %v", err) + } + if deleted != 1 { + t.Fatalf("deleted entries = %d, want 1", deleted) + } + if freed != 10 { + t.Fatalf("freed bytes = %d, want 10", freed) + } + + if _, ok := store.objects["_gradle/http-build-cache/old"]; ok { + t.Fatal("old entry was not deleted") + } + if _, ok := store.objects["_gradle/http-build-cache/new"]; !ok { + t.Fatal("new entry should remain") + } +} + +func TestSweepGradleBuildCache_MaxSizeOldestFirst(t *testing.T) { + now := time.Date(2026, 4, 27, 12, 0, 0, 0, time.UTC) + store := newFakeGradleCacheStore([]storage.ObjectInfo{ + {Path: "_gradle/http-build-cache/a", Size: 5, ModTime: now.Add(-3 * time.Hour)}, + {Path: "_gradle/http-build-cache/b", Size: 5, ModTime: now.Add(-2 * time.Hour)}, + {Path: "_gradle/http-build-cache/c", Size: 5, ModTime: now.Add(-1 * time.Hour)}, + }) + + deleted, freed, err := sweepGradleBuildCache(context.Background(), store, store, 0, 10, now) + if err != nil { + t.Fatalf("sweepGradleBuildCache() error = %v", err) + } + if deleted != 1 { + t.Fatalf("deleted entries = %d, want 1", deleted) + } + if freed != 5 { + t.Fatalf("freed bytes = %d, want 5", freed) + } + + if _, ok := store.objects["_gradle/http-build-cache/a"]; ok { + t.Fatal("oldest entry was not deleted") + } + if _, ok := store.objects["_gradle/http-build-cache/b"]; !ok { + t.Fatal("middle entry should remain") + } + if _, ok := store.objects["_gradle/http-build-cache/c"]; !ok { + t.Fatal("newest entry should remain") + } +} diff --git a/internal/server/health.go b/internal/server/health.go new file mode 100644 index 0000000..f4e4847 --- /dev/null +++ b/internal/server/health.go @@ -0,0 +1,182 @@ +// Package server implements the proxy HTTP server. +package server + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "log/slog" + "strconv" + "sync" + "time" + + "github.com/git-pkgs/proxy/internal/metrics" + "github.com/git-pkgs/proxy/internal/storage" +) + +const ( + probePathPrefix = ".healthcheck/" + probeMarker = "proxy-healthcheck:" + probeSuffixBytes = 8 + defaultProbeTTL = 30 * time.Second + defaultProbeTimeout = 10 * time.Second +) + +// HealthResponse is the JSON payload returned by /health. +type HealthResponse struct { + Status string `json:"status"` + Checks map[string]HealthCheck `json:"checks"` +} + +// HealthCheck reports the status of a single subsystem check. +type HealthCheck struct { + Status string `json:"status"` + Error string `json:"error,omitempty"` + Step string `json:"step,omitempty"` +} + +// probeError tags a storage probe failure with the step that failed. +type probeError struct { + step string + err error +} + +func (e *probeError) Error() string { return e.step + ": " + e.err.Error() } +func (e *probeError) Unwrap() error { return e.err } + +// storageProbe runs a write → size-check → read → verify → delete round-trip +// against the storage backend. Returns nil on success or a *probeError on failure. +func storageProbe(ctx context.Context, s storage.Storage) (err error) { + suffix, suffixErr := randomSuffix() + if suffixErr != nil { + return &probeError{step: "write", err: fmt.Errorf("generating random suffix: %w", suffixErr)} + } + path := probePathPrefix + strconv.FormatInt(time.Now().UnixNano(), 10) + "-" + suffix + payload := []byte(probeMarker + suffix) + + // 1. Store + size, _, storeErr := s.Store(ctx, path, bytes.NewReader(payload)) + if storeErr != nil { + return &probeError{step: "write", err: storeErr} + } + // After Store succeeds, always attempt to delete on the way out so probe + // objects don't accumulate when a later step (size/open/read/verify) fails. + // Delete is reported as the primary error only if no earlier failure + // already set one. + defer func() { + if delErr := s.Delete(ctx, path); delErr != nil && err == nil { + err = &probeError{step: "delete", err: delErr} + } + }() + // 2. Size check + if size != int64(len(payload)) { + return &probeError{step: "size", err: fmt.Errorf("wrote %d bytes, expected %d", size, len(payload))} + } + // 3. Open + rc, openErr := s.Open(ctx, path) + if openErr != nil { + return &probeError{step: "read", err: openErr} + } + // 4. Read all (classify mid-stream errors as read, not verify). + // Close explicitly (not deferred) so the file handle is released before + // Delete — on Windows, an open handle prevents deletion. + data, readErr := io.ReadAll(rc) + _ = rc.Close() + if readErr != nil { + return &probeError{step: "read", err: readErr} + } + // 5. Verify + if !bytes.Equal(data, payload) { + return &probeError{step: "verify", err: fmt.Errorf("content mismatch")} + } + // 6. Delete is handled via the deferred cleanup above. + return nil +} + +// randomSuffix returns 8 cryptographically random bytes hex-encoded. +func randomSuffix() (string, error) { + b := make([]byte, probeSuffixBytes) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// healthCache memoizes the result of storageProbe for a configurable TTL. +// It is safe for concurrent use. +type healthCache struct { + storage storage.Storage + interval time.Duration + probeTimeout time.Duration + logger *slog.Logger + + mu sync.Mutex + lastAt time.Time + lastErr error +} + +// newHealthCache builds a cache, parsing the interval from a duration string. +// Empty interval string defaults to 30s. "0" or "0s" disables caching. +func newHealthCache(s storage.Storage, intervalStr string, logger *slog.Logger) (*healthCache, error) { + interval := defaultProbeTTL + if intervalStr != "" { + d, err := time.ParseDuration(intervalStr) + if err != nil { + return nil, fmt.Errorf("parsing storage_probe_interval %q: %w", intervalStr, err) + } + interval = d + } + return &healthCache{ + storage: s, + interval: interval, + probeTimeout: defaultProbeTimeout, + logger: logger, + }, nil +} + +// Check returns the cached probe result if still fresh, otherwise runs a fresh probe. +// The probe runs under a context derived from context.Background() with a fixed +// timeout so that caller cancellation (e.g. client disconnect) cannot poison the +// cache with context.Canceled. +func (c *healthCache) Check() error { + c.mu.Lock() + defer c.mu.Unlock() + + // Cache hit + if c.interval > 0 && !c.lastAt.IsZero() && time.Since(c.lastAt) < c.interval { + return c.lastErr + } + + // Fresh probe under a detached context + probeCtx, cancel := context.WithTimeout(context.Background(), c.probeTimeout) + defer cancel() + err := storageProbe(probeCtx, c.storage) + + // Transition logging and metric increment happen only on the fresh-probe path. + c.logTransition(c.lastErr, err) + if err != nil { + var pe *probeError + if errors.As(err, &pe) { + metrics.RecordHealthProbeFailure(pe.step) + } else { + metrics.RecordHealthProbeFailure("unknown") + } + } + + c.lastErr = err + c.lastAt = time.Now() + return err +} + +func (c *healthCache) logTransition(prev, curr error) { + switch { + case prev != nil && curr == nil: + c.logger.Info("storage probe recovered") + case prev == nil && curr != nil: + c.logger.Error("storage probe failed", "error", curr.Error()) + } +} diff --git a/internal/server/health_test.go b/internal/server/health_test.go new file mode 100644 index 0000000..c0f70c9 --- /dev/null +++ b/internal/server/health_test.go @@ -0,0 +1,448 @@ +package server + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/git-pkgs/proxy/internal/metrics" + "github.com/git-pkgs/proxy/internal/storage" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +// fakeStorage is a minimal storage.Storage for probe tests with per-step failure injection. +type fakeStorage struct { + mu sync.Mutex + + storeCalls atomic.Int64 + openCalls atomic.Int64 + closeCalls atomic.Int64 + deleteCalls atomic.Int64 + + paths []string + payloads [][]byte + + // Failure injection. + storeErr error + openErr error + readErr error // returned by the io.ReadCloser.Read after partial bytes + deleteErr error + + // Misbehavior knobs. + sizeDelta int64 // added to the reported size from Store + readOverride []byte // if non-nil, Open returns a reader yielding these bytes instead of stored content + + // storeBlock, if non-nil, causes Store to block until the channel is closed or ctx is done. + storeBlock chan struct{} + + stored map[string][]byte +} + +func newFakeStorage() *fakeStorage { return &fakeStorage{stored: map[string][]byte{}} } + +func (f *fakeStorage) Store(ctx context.Context, path string, r io.Reader) (int64, string, error) { + f.storeCalls.Add(1) + if f.storeErr != nil { + return 0, "", f.storeErr + } + if f.storeBlock != nil { + select { + case <-f.storeBlock: + case <-ctx.Done(): + return 0, "", ctx.Err() + } + } + data, err := io.ReadAll(r) + if err != nil { + return 0, "", err + } + f.mu.Lock() + f.stored[path] = data + f.paths = append(f.paths, path) + f.payloads = append(f.payloads, data) + f.mu.Unlock() + return int64(len(data)) + f.sizeDelta, "fakehash", nil +} + +type fakeReadCloser struct { + data []byte + pos int + readErr error + closed *atomic.Int64 +} + +func (rc *fakeReadCloser) Read(p []byte) (int, error) { + if rc.pos >= len(rc.data) { + if rc.readErr != nil { + return 0, rc.readErr + } + return 0, io.EOF + } + n := copy(p, rc.data[rc.pos:]) + rc.pos += n + if rc.pos >= len(rc.data) && rc.readErr != nil { + return n, rc.readErr + } + return n, nil +} + +func (rc *fakeReadCloser) Close() error { rc.closed.Add(1); return nil } + +func (f *fakeStorage) Open(ctx context.Context, path string) (io.ReadCloser, error) { + f.openCalls.Add(1) + if f.openErr != nil { + return nil, f.openErr + } + f.mu.Lock() + data := f.stored[path] + f.mu.Unlock() + if f.readOverride != nil { + data = f.readOverride + } + return &fakeReadCloser{data: data, readErr: f.readErr, closed: &f.closeCalls}, nil +} + +func (f *fakeStorage) Exists(ctx context.Context, path string) (bool, error) { + f.mu.Lock() + defer f.mu.Unlock() + _, ok := f.stored[path] + return ok, nil +} + +func (f *fakeStorage) Delete(ctx context.Context, path string) error { + f.deleteCalls.Add(1) + if f.deleteErr != nil { + return f.deleteErr + } + f.mu.Lock() + delete(f.stored, path) + f.mu.Unlock() + return nil +} + +func (f *fakeStorage) Size(ctx context.Context, path string) (int64, error) { return 0, nil } +func (f *fakeStorage) SignedURL(ctx context.Context, path string, expiry time.Duration) (string, error) { + return "", storage.ErrSignedURLUnsupported +} +func (f *fakeStorage) UsedSpace(ctx context.Context) (int64, error) { return 0, nil } +func (f *fakeStorage) URL() string { return "fake://" } +func (f *fakeStorage) Close() error { return nil } + +// --- Tests follow. First test: happy path --- + +func TestStorageProbe_HappyPath(t *testing.T) { + fs := newFakeStorage() + if err := storageProbe(context.Background(), fs); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := fs.storeCalls.Load(); got != 1 { + t.Errorf("Store calls = %d, want 1", got) + } + if got := fs.openCalls.Load(); got != 1 { + t.Errorf("Open calls = %d, want 1", got) + } + if got := fs.closeCalls.Load(); got != 1 { + t.Errorf("Close calls = %d, want 1", got) + } + if got := fs.deleteCalls.Load(); got != 1 { + t.Errorf("Delete calls = %d, want 1", got) + } + if len(fs.paths) != 1 || !strings.HasPrefix(fs.paths[0], ".healthcheck/") { + t.Errorf("unexpected probe path: %v", fs.paths) + } +} + +func TestStorageProbe_WriteFails(t *testing.T) { + fs := newFakeStorage() + fs.storeErr = errors.New("disk full") + err := storageProbe(context.Background(), fs) + var pe *probeError + if !errors.As(err, &pe) { + t.Fatalf("expected *probeError, got %T: %v", err, err) + } + if pe.step != "write" { + t.Errorf("step = %q, want write", pe.step) + } + if fs.openCalls.Load() != 0 { + t.Errorf("Open should not be called after write failure") + } +} + +func TestStorageProbe_SizeMismatch(t *testing.T) { + fs := newFakeStorage() + fs.sizeDelta = -1 // Report 1 byte fewer than actually written + err := storageProbe(context.Background(), fs) + var pe *probeError + if !errors.As(err, &pe) || pe.step != "size" { + t.Fatalf("step = %v, want size; err = %v", pe, err) + } + if fs.openCalls.Load() != 0 { + t.Errorf("Open should not be called after size mismatch") + } +} + +func TestStorageProbe_OpenFails(t *testing.T) { + fs := newFakeStorage() + fs.openErr = errors.New("access denied") + err := storageProbe(context.Background(), fs) + var pe *probeError + if !errors.As(err, &pe) || pe.step != "read" { + t.Fatalf("step = %v, want read; err = %v", pe, err) + } +} + +func TestStorageProbe_ReadMidStreamFails(t *testing.T) { + fs := newFakeStorage() + fs.readErr = errors.New("connection reset") + err := storageProbe(context.Background(), fs) + var pe *probeError + if !errors.As(err, &pe) || pe.step != "read" { + t.Fatalf("step = %v, want read (NOT verify); err = %v", pe, err) + } +} + +func TestStorageProbe_ContentMismatch(t *testing.T) { + fs := newFakeStorage() + fs.readOverride = []byte("wrong content") + err := storageProbe(context.Background(), fs) + var pe *probeError + if !errors.As(err, &pe) || pe.step != "verify" { + t.Fatalf("step = %v, want verify; err = %v", pe, err) + } +} + +func TestStorageProbe_DeleteFails(t *testing.T) { + fs := newFakeStorage() + fs.deleteErr = errors.New("permission denied") + err := storageProbe(context.Background(), fs) + var pe *probeError + if !errors.As(err, &pe) || pe.step != "delete" { + t.Fatalf("step = %v, want delete; err = %v", pe, err) + } +} + +// TestStorageProbe_CleanupOnNonDeleteFailure asserts that the probe object is +// deleted even when a step after Store (size/open/read/verify) fails, so +// probe artifacts don't accumulate in the storage backend. +func TestStorageProbe_CleanupOnNonDeleteFailure(t *testing.T) { + cases := []struct { + name string + inject func(*fakeStorage) + wantErr string + }{ + {"size mismatch", func(fs *fakeStorage) { fs.sizeDelta = -1 }, "size"}, + {"open fails", func(fs *fakeStorage) { fs.openErr = errors.New("open boom") }, "read"}, + {"read mid-stream", func(fs *fakeStorage) { fs.readErr = errors.New("mid-stream boom") }, "read"}, + {"content mismatch", func(fs *fakeStorage) { fs.readOverride = []byte("wrong") }, "verify"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + fs := newFakeStorage() + tc.inject(fs) + err := storageProbe(context.Background(), fs) + var pe *probeError + if !errors.As(err, &pe) || pe.step != tc.wantErr { + t.Fatalf("step = %v, want %q; err = %v", pe, tc.wantErr, err) + } + if got := fs.deleteCalls.Load(); got != 1 { + t.Errorf("deleteCalls = %d, want 1 (cleanup should run on non-delete failures)", got) + } + }) + } +} + +func TestStorageProbe_ReaderClosedOnReadFailure(t *testing.T) { + fs := newFakeStorage() + fs.readErr = errors.New("read error") + _ = storageProbe(context.Background(), fs) + if got := fs.closeCalls.Load(); got != fs.openCalls.Load() { + t.Errorf("closeCalls = %d, openCalls = %d (should match)", got, fs.openCalls.Load()) + } +} + +func TestStorageProbe_PathUniqueness(t *testing.T) { + fs := newFakeStorage() + for i := 0; i < 100; i++ { + if err := storageProbe(context.Background(), fs); err != nil { + t.Fatalf("probe %d: %v", i, err) + } + } + seen := make(map[string]bool) + for _, p := range fs.paths { + if !strings.HasPrefix(p, ".healthcheck/") { + t.Errorf("path missing prefix: %q", p) + } + if seen[p] { + t.Errorf("duplicate path: %q", p) + } + seen[p] = true + } +} + +// helper: a healthCache wired to a fakeStorage and a discard logger. +func newTestCache(fs *fakeStorage, interval time.Duration) *healthCache { + return &healthCache{ + storage: fs, + interval: interval, + probeTimeout: 5 * time.Second, + logger: discardLogger(), + } +} + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestHealthCache_CacheHit(t *testing.T) { + fs := newFakeStorage() + c := newTestCache(fs, 30*time.Second) + if err := c.Check(); err != nil { + t.Fatalf("first check: %v", err) + } + if err := c.Check(); err != nil { + t.Fatalf("second check: %v", err) + } + if got := fs.storeCalls.Load(); got != 1 { + t.Errorf("storeCalls = %d, want 1 (second call should be cached)", got) + } +} + +func TestHealthCache_MissAfterTTL(t *testing.T) { + fs := newFakeStorage() + c := newTestCache(fs, 10*time.Millisecond) + _ = c.Check() + time.Sleep(20 * time.Millisecond) + _ = c.Check() + if got := fs.storeCalls.Load(); got != 2 { + t.Errorf("storeCalls = %d, want 2", got) + } +} + +func TestHealthCache_Disabled(t *testing.T) { + fs := newFakeStorage() + c := newTestCache(fs, 0) // interval = 0 means probe every call + _ = c.Check() + _ = c.Check() + if got := fs.storeCalls.Load(); got != 2 { + t.Errorf("storeCalls = %d, want 2", got) + } +} + +func TestHealthCache_LastAtNotAdvancedOnHit(t *testing.T) { + fs := newFakeStorage() + c := newTestCache(fs, 30*time.Second) + for i := 0; i < 100; i++ { + _ = c.Check() + } + if got := fs.storeCalls.Load(); got != 1 { + t.Errorf("storeCalls = %d, want 1 across 100 hits", got) + } +} + +func TestHealthCache_ConcurrentSingleFlight(t *testing.T) { + fs := newFakeStorage() + c := newTestCache(fs, 30*time.Second) + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { defer wg.Done(); _ = c.Check() }() + } + wg.Wait() + if got := fs.storeCalls.Load(); got != 1 { + t.Errorf("storeCalls = %d, want 1 with 20 concurrent callers", got) + } +} + +func TestHealthCache_FailureCounterIncrement(t *testing.T) { + fs := newFakeStorage() + fs.storeErr = errors.New("boom") + c := newTestCache(fs, 30*time.Second) + + before := testutil.ToFloat64(metrics.HealthProbeFailures.WithLabelValues("write")) + + // First call: fresh probe → counter +1 + _ = c.Check() + afterFirst := testutil.ToFloat64(metrics.HealthProbeFailures.WithLabelValues("write")) + if afterFirst-before != 1 { + t.Errorf("counter delta after first call = %v, want 1", afterFirst-before) + } + + // Second call: cache hit → counter NOT re-incremented + _ = c.Check() + afterSecond := testutil.ToFloat64(metrics.HealthProbeFailures.WithLabelValues("write")) + if afterSecond != afterFirst { + t.Errorf("counter changed on cache hit: %v → %v", afterFirst, afterSecond) + } +} + +func TestHealthCache_ProbeTimeout(t *testing.T) { + fs := newFakeStorage() + fs.storeBlock = make(chan struct{}) // Store will block until channel is closed (or never) + t.Cleanup(func() { close(fs.storeBlock) }) + + c := &healthCache{ + storage: fs, + interval: 30 * time.Second, + probeTimeout: 50 * time.Millisecond, + logger: discardLogger(), + } + start := time.Now() + err := c.Check() + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if elapsed > 500*time.Millisecond { + t.Errorf("probe took %v, expected ~50ms (timeout not respected)", elapsed) + } +} + +func TestHealthCache_TransitionLogging(t *testing.T) { + fs := newFakeStorage() + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})) + c := &healthCache{ + storage: fs, + interval: 0, // probe every call + probeTimeout: 5 * time.Second, + logger: logger, + } + + // Steady ok state — should not log + _ = c.Check() + _ = c.Check() + if got := strings.Count(buf.String(), "storage probe"); got != 0 { + t.Errorf("steady-state logs = %d, want 0; output: %s", got, buf.String()) + } + + // ok → err transition: exactly one Error log + buf.Reset() + fs.storeErr = errors.New("boom") + _ = c.Check() + if !strings.Contains(buf.String(), "storage probe failed") { + t.Errorf("missing failure log on transition; output: %s", buf.String()) + } + + // err steady state — should not log again + buf.Reset() + _ = c.Check() + if buf.Len() != 0 { + t.Errorf("steady-err logs = %q, want empty", buf.String()) + } + + // err → ok transition: exactly one Info log + buf.Reset() + fs.storeErr = nil + _ = c.Check() + if !strings.Contains(buf.String(), "storage probe recovered") { + t.Errorf("missing recovery log on transition; output: %s", buf.String()) + } +} diff --git a/internal/server/mirror_api.go b/internal/server/mirror_api.go index 6a6a6ca..028d4e0 100644 --- a/internal/server/mirror_api.go +++ b/internal/server/mirror_api.go @@ -20,19 +20,16 @@ func NewMirrorAPIHandler(jobs *mirror.JobStore) *MirrorAPIHandler { // HandleCreate starts a new mirror job. func (h *MirrorAPIHandler) HandleCreate(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) var req mirror.JobRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - writeJSON(w, map[string]string{"error": "invalid request body"}) + badRequest(w, "invalid request body") return } id, err := h.jobs.Create(req) if err != nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - writeJSON(w, map[string]string{"error": err.Error()}) + badRequest(w, "invalid mirror job request") return } @@ -46,13 +43,10 @@ func (h *MirrorAPIHandler) HandleGet(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") job := h.jobs.Get(id) if job == nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) - writeJSON(w, map[string]string{"error": "job not found"}) + notFound(w, "job not found") return } - w.Header().Set("Content-Type", "application/json") writeJSON(w, job) } @@ -60,11 +54,8 @@ func (h *MirrorAPIHandler) HandleGet(w http.ResponseWriter, r *http.Request) { func (h *MirrorAPIHandler) HandleCancel(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if h.jobs.Cancel(id) { - w.Header().Set("Content-Type", "application/json") writeJSON(w, map[string]string{"status": "canceled"}) } else { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) - writeJSON(w, map[string]string{"error": "job not found or not running"}) + notFound(w, "job not found or not running") } } diff --git a/internal/server/mirror_api_test.go b/internal/server/mirror_api_test.go index 0e84da1..73b8731 100644 --- a/internal/server/mirror_api_test.go +++ b/internal/server/mirror_api_test.go @@ -71,6 +71,19 @@ func TestMirrorAPICreateJob(t *testing.T) { } } +func TestMirrorAPICreateOversizedBody(t *testing.T) { + h := setupMirrorAPI(t) + + body := bytes.Repeat([]byte("x"), int(maxBodySize)+1) + req := httptest.NewRequest("POST", "/api/mirror", bytes.NewReader(body)) + w := httptest.NewRecorder() + h.HandleCreate(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + func TestMirrorAPICreateInvalidBody(t *testing.T) { h := setupMirrorAPI(t) diff --git a/internal/server/resolve.go b/internal/server/resolve.go index 479ede6..51f203d 100644 --- a/internal/server/resolve.go +++ b/internal/server/resolve.go @@ -1,11 +1,39 @@ package server import ( + "fmt" "strings" + "unicode" "github.com/git-pkgs/proxy/internal/database" ) +// maxPackagePathLen bounds the wildcard portion of package routes (name plus +// version and any suffix). npm caps names at 214 and Maven coordinates can be +// longer, so 512 leaves room without admitting pathological inputs. +const maxPackagePathLen = 512 + +// validatePackagePath rejects wildcard package paths that cannot be valid in +// any supported ecosystem. It is a coarse filter applied before database or +// enrichment lookups; ecosystem-specific name rules are layered on top. +func validatePackagePath(path string) error { + if path == "" { + return fmt.Errorf("package name required") + } + if len(path) > maxPackagePathLen { + return fmt.Errorf("package path exceeds %d bytes", maxPackagePathLen) + } + for _, r := range path { + if r == 0 { + return fmt.Errorf("package path contains null byte") + } + if unicode.IsControl(r) { + return fmt.Errorf("package path contains control character %#U", r) + } + } + return nil +} + // resolvePackageName determines the package name from a wildcard path by // checking the database. This handles namespaced packages like Composer's // vendor/name format where the package name contains a slash. diff --git a/internal/server/resolve_test.go b/internal/server/resolve_test.go index 427c2cb..dd7d2dc 100644 --- a/internal/server/resolve_test.go +++ b/internal/server/resolve_test.go @@ -3,6 +3,7 @@ package server import ( "os" "path/filepath" + "strings" "testing" "github.com/git-pkgs/proxy/internal/database" @@ -118,3 +119,36 @@ func TestSplitWildcardPath(t *testing.T) { } } } + +func TestValidatePackagePath(t *testing.T) { + tests := []struct { + name string + path string + wantErr bool + }{ + {"simple", "lodash", false}, + {"with version", "lodash/4.17.21", false}, + {"npm scoped", "@babel/core/7.0.0", false}, + {"composer namespaced", "symfony/console/6.0.0", false}, + {"maven coordinates", "org.apache.commons/commons-lang3/3.12.0", false}, + {"unicode", "café/1.0.0", false}, + {"empty", "", true}, + {"null byte", "lodash\x00/4.17.21", true}, + {"null byte suffix", "lodash\x00", true}, + {"newline", "lodash\n4.17.21", true}, + {"carriage return", "lodash\r", true}, + {"escape", "lodash\x1b[31m", true}, + {"delete", "lodash\x7f", true}, + {"too long", strings.Repeat("a", maxPackagePathLen+1), true}, + {"at limit", strings.Repeat("a", maxPackagePathLen), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePackagePath(tt.path) + if (err != nil) != tt.wantErr { + t.Errorf("validatePackagePath(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr) + } + }) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 2ae0e69..7de5041 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,11 +9,13 @@ // - /pub/* - pub.dev registry protocol // - /pypi/* - PyPI registry protocol // - /maven/* - Maven repository protocol +// - /gradle/* - Gradle HttpBuildCache protocol // - /nuget/* - NuGet V3 API protocol // - /composer/* - Composer/Packagist protocol // - /conan/* - Conan C/C++ protocol // - /conda/* - Conda/Anaconda protocol // - /cran/* - CRAN (R) protocol +// - /julia/* - Julia Pkg server protocol // - /v2/* - OCI/Docker container registry protocol // - /debian/* - Debian/APT repository protocol // - /rpm/* - RPM/Yum repository protocol @@ -39,6 +41,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "log/slog" "net/http" @@ -48,7 +51,7 @@ import ( swaggerdoc "github.com/git-pkgs/proxy/docs/swagger" "github.com/git-pkgs/proxy/internal/config" - "github.com/git-pkgs/proxy/internal/cooldown" + "github.com/git-pkgs/cooldown" "github.com/git-pkgs/proxy/internal/database" "github.com/git-pkgs/proxy/internal/enrichment" "github.com/git-pkgs/proxy/internal/handler" @@ -78,7 +81,8 @@ type Server struct { logger *slog.Logger http *http.Server templates *Templates - cancel context.CancelFunc + cancel context.CancelFunc + healthCache *healthCache } // New creates a new Server with the given configuration. @@ -124,12 +128,20 @@ func New(cfg *config.Config, logger *slog.Logger) (*Server, error) { return nil, fmt.Errorf("verifying storage connectivity: %w", err) } + hc, err := newHealthCache(store, cfg.Health.StorageProbeInterval, logger) + if err != nil { + _ = store.Close() + _ = db.Close() + return nil, fmt.Errorf("initializing health cache: %w", err) + } + return &Server{ - cfg: cfg, - db: db, - storage: store, - logger: logger, - templates: &Templates{}, + cfg: cfg, + db: db, + storage: store, + logger: logger, + templates: &Templates{}, + healthCache: hc, }, nil } @@ -148,6 +160,12 @@ func (s *Server) Start() error { proxy.Cooldown = cd proxy.CacheMetadata = s.cfg.CacheMetadata proxy.MetadataTTL = s.cfg.ParseMetadataTTL() + proxy.MetadataMaxSize = s.cfg.ParseMetadataMaxSize() + proxy.GradleReadOnly = s.cfg.Gradle.BuildCache.ReadOnly + proxy.GradleMaxUploadSize = s.cfg.ParseGradleBuildCacheMaxUploadSize() + proxy.DirectServe = s.cfg.Storage.DirectServe + proxy.DirectServeTTL = s.cfg.ParseDirectServeTTL() + proxy.DirectServeBaseURL = s.cfg.Storage.DirectServeBaseURL // Create router with Chi r := chi.NewRouter() @@ -176,12 +194,19 @@ func (s *Server) Start() error { hexHandler := handler.NewHexHandler(proxy, s.cfg.BaseURL) pubHandler := handler.NewPubHandler(proxy, s.cfg.BaseURL) pypiHandler := handler.NewPyPIHandler(proxy, s.cfg.BaseURL) - mavenHandler := handler.NewMavenHandler(proxy, s.cfg.BaseURL) + mavenHandler := handler.NewMavenHandler( + proxy, + s.cfg.BaseURL, + s.cfg.Upstream.Maven, + s.cfg.Upstream.GradlePluginPortal, + ) + gradleHandler := handler.NewGradleBuildCacheHandler(proxy) nugetHandler := handler.NewNuGetHandler(proxy, s.cfg.BaseURL) composerHandler := handler.NewComposerHandler(proxy, s.cfg.BaseURL) conanHandler := handler.NewConanHandler(proxy, s.cfg.BaseURL) condaHandler := handler.NewCondaHandler(proxy, s.cfg.BaseURL) cranHandler := handler.NewCRANHandler(proxy, s.cfg.BaseURL) + juliaHandler := handler.NewJuliaHandler(proxy, s.cfg.BaseURL) containerHandler := handler.NewContainerHandler(proxy, s.cfg.BaseURL) debianHandler := handler.NewDebianHandler(proxy, s.cfg.BaseURL) rpmHandler := handler.NewRPMHandler(proxy, s.cfg.BaseURL) @@ -194,11 +219,13 @@ func (s *Server) Start() error { r.Mount("/pub", http.StripPrefix("/pub", pubHandler.Routes())) r.Mount("/pypi", http.StripPrefix("/pypi", pypiHandler.Routes())) r.Mount("/maven", http.StripPrefix("/maven", mavenHandler.Routes())) + r.Mount("/gradle", http.StripPrefix("/gradle", gradleHandler.Routes())) r.Mount("/nuget", http.StripPrefix("/nuget", nugetHandler.Routes())) r.Mount("/composer", http.StripPrefix("/composer", composerHandler.Routes())) r.Mount("/conan", http.StripPrefix("/conan", conanHandler.Routes())) r.Mount("/conda", http.StripPrefix("/conda", condaHandler.Routes())) r.Mount("/cran", http.StripPrefix("/cran", cranHandler.Routes())) + r.Mount("/julia", http.StripPrefix("/julia", juliaHandler.Routes())) r.Mount("/v2", http.StripPrefix("/v2", containerHandler.Routes())) r.Mount("/debian", http.StripPrefix("/debian", debianHandler.Routes())) r.Mount("/rpm", http.StripPrefix("/rpm", rpmHandler.Routes())) @@ -235,6 +262,7 @@ func (s *Server) Start() error { // Start background context (used by mirror jobs and cleanup) bgCtx, bgCancel := context.WithCancel(context.Background()) s.cancel = bgCancel + s.startGradleBuildCacheEviction(bgCtx) // Mirror API endpoints (opt-in via mirror_api config or PROXY_MIRROR_API env) if s.cfg.MirrorAPI { @@ -612,6 +640,10 @@ func (s *Server) handlePackagesList(w http.ResponseWriter, r *http.Request) { func (s *Server) handlePackagePath(w http.ResponseWriter, r *http.Request) { ecosystem := chi.URLParam(r, "ecosystem") wildcard := chi.URLParam(r, "*") + if err := validatePackagePath(wildcard); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } segments := splitWildcardPath(wildcard) if ecosystem == "" || len(segments) == 0 { @@ -786,23 +818,49 @@ func (s *Server) showComparePage(w http.ResponseWriter, ecosystem, name, version } } -// handleHealth responds with a simple health check. +// handleHealth responds with a structured JSON health report. +// // @Summary Health check // @Tags meta -// @Produce plain -// @Success 200 {string} string -// @Failure 503 {string} string +// @Produce json +// @Success 200 {object} HealthResponse +// @Failure 503 {object} HealthResponse // @Router /health [get] func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { - // Check database connectivity + w.Header().Set("Content-Type", "application/json") + + resp := HealthResponse{Status: "ok", Checks: map[string]HealthCheck{}} + + // Database check (short-circuit; do not waste a storage probe call when DB is down). + // On DB failure the storage entry reports "skipped" rather than being omitted so + // the response always carries the same key set for monitors that expect it. if _, err := s.db.SchemaVersion(); err != nil { + resp.Status = "error" + resp.Checks["database"] = HealthCheck{Status: "error", Error: err.Error()} + resp.Checks["storage"] = HealthCheck{Status: "skipped"} w.WriteHeader(http.StatusServiceUnavailable) - _, _ = fmt.Fprint(w, "database error") + _ = json.NewEncoder(w).Encode(resp) return } + resp.Checks["database"] = HealthCheck{Status: "ok"} + + // Storage probe (via cache). + if err := s.healthCache.Check(); err != nil { + resp.Status = "error" + sc := HealthCheck{Status: "error", Error: err.Error()} + var pe *probeError + if errors.As(err, &pe) { + sc.Step = pe.step + } + resp.Checks["storage"] = sc + w.WriteHeader(http.StatusServiceUnavailable) + _ = json.NewEncoder(w).Encode(resp) + return + } + resp.Checks["storage"] = HealthCheck{Status: "ok"} w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, "ok") + _ = json.NewEncoder(w).Encode(resp) } // StatsResponse contains cache statistics. @@ -819,20 +877,20 @@ type StatsResponse struct { // @Tags meta // @Produce json // @Success 200 {object} StatsResponse -// @Failure 500 {string} string +// @Failure 500 {object} ErrorResponse // @Router /stats [get] func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) { ctx := r.Context() count, err := s.db.GetCachedArtifactCount() if err != nil { - http.Error(w, "failed to get artifact count", http.StatusInternalServerError) + internalError(w, "failed to get artifact count") return } size, err := s.db.GetTotalCacheSize() if err != nil { - http.Error(w, "failed to get cache size", http.StatusInternalServerError) + internalError(w, "failed to get cache size") return } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index be88bf6..e2dc1c2 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -72,20 +72,30 @@ func newTestServer(t *testing.T) *testServer { gemHandler := handler.NewGemHandler(proxy, cfg.BaseURL) goHandler := handler.NewGoHandler(proxy, cfg.BaseURL) pypiHandler := handler.NewPyPIHandler(proxy, cfg.BaseURL) + gradleHandler := handler.NewGradleBuildCacheHandler(proxy) r.Mount("/npm", http.StripPrefix("/npm", npmHandler.Routes())) r.Mount("/cargo", http.StripPrefix("/cargo", cargoHandler.Routes())) r.Mount("/gem", http.StripPrefix("/gem", gemHandler.Routes())) r.Mount("/go", http.StripPrefix("/go", goHandler.Routes())) r.Mount("/pypi", http.StripPrefix("/pypi", pypiHandler.Routes())) + r.Mount("/gradle", http.StripPrefix("/gradle", gradleHandler.Routes())) + + hc, err := newHealthCache(store, "30s", logger) + if err != nil { + _ = db.Close() + _ = os.RemoveAll(tempDir) + t.Fatalf("failed to create health cache: %v", err) + } // Create a minimal server struct for the handlers s := &Server{ - cfg: cfg, - db: db, - storage: store, - logger: logger, - templates: &Templates{}, + cfg: cfg, + db: db, + storage: store, + logger: logger, + templates: &Templates{}, + healthCache: hc, } r.Get("/health", s.handleHealth) @@ -177,12 +187,55 @@ func TestHealthEndpoint(t *testing.T) { ts.handler.ServeHTTP(w, req) if w.Code != http.StatusOK { - t.Errorf("expected status 200, got %d", w.Code) + t.Fatalf("status = %d, want 200; body: %s", w.Code, w.Body.String()) } + if got := w.Header().Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type = %q, want application/json", got) + } + var resp HealthResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decoding response: %v", err) + } + if resp.Status != "ok" { + t.Errorf("status = %q, want ok", resp.Status) + } + if resp.Checks["database"].Status != "ok" { + t.Errorf("database check = %+v, want ok", resp.Checks["database"]) + } + if resp.Checks["storage"].Status != "ok" { + t.Errorf("storage check = %+v, want ok", resp.Checks["storage"]) + } +} - body := w.Body.String() - if body != "ok" { - t.Errorf("expected body 'ok', got %q", body) +func TestHealthEndpoint_DBFailureShortCircuits(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + // Force DB failure by closing the connection. + _ = ts.db.Close() + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want 503; body: %s", w.Code, w.Body.String()) + } + var resp HealthResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decoding: %v", err) + } + if resp.Status != "error" { + t.Errorf("status = %q, want error", resp.Status) + } + if resp.Checks["database"].Status != "error" { + t.Errorf("database check = %+v, want error", resp.Checks["database"]) + } + storage, present := resp.Checks["storage"] + if !present { + t.Error("storage key should be present (with status=skipped) on DB short-circuit") + } else if storage.Status != "skipped" { + t.Errorf("storage check = %+v, want status=skipped", storage) } } @@ -344,6 +397,33 @@ func TestPyPISimple(t *testing.T) { } } +func TestGradleBuildCachePutGet(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + key := "abc123def456" + body := "build-cache-bytes" + + putReq := httptest.NewRequest(http.MethodPut, "/gradle/"+key, strings.NewReader(body)) + putW := httptest.NewRecorder() + ts.handler.ServeHTTP(putW, putReq) + + if putW.Code != http.StatusCreated { + t.Fatalf("expected status 201, got %d: %s", putW.Code, putW.Body.String()) + } + + getReq := httptest.NewRequest(http.MethodGet, "/gradle/"+key, nil) + getW := httptest.NewRecorder() + ts.handler.ServeHTTP(getW, getReq) + + if getW.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", getW.Code, getW.Body.String()) + } + if got := getW.Body.String(); got != body { + t.Fatalf("expected body %q, got %q", body, got) + } +} + func TestGemSpecs(t *testing.T) { ts := newTestServer(t) defer ts.close() diff --git a/internal/server/templates_test.go b/internal/server/templates_test.go index a8b67f8..c27363b 100644 --- a/internal/server/templates_test.go +++ b/internal/server/templates_test.go @@ -193,7 +193,7 @@ func TestInstallPage(t *testing.T) { body := w.Body.String() // Should contain instructions for all registries - registries := []string{"npm", "Cargo", "RubyGems", "Go Modules", "PyPI", "Maven", "NuGet", "Composer", "Conan", "Conda", "CRAN"} + registries := []string{"npm", "Cargo", "RubyGems", "Go Modules", "PyPI", "Maven", "Gradle Build Cache", "NuGet", "Composer", "Conan", "Conda", "CRAN"} for _, reg := range registries { if !strings.Contains(body, reg) { t.Errorf("install page should contain %s instructions", reg) diff --git a/internal/storage/blob.go b/internal/storage/blob.go index 2d6af46..67e91d0 100644 --- a/internal/storage/blob.go +++ b/internal/storage/blob.go @@ -6,12 +6,15 @@ import ( "encoding/hex" "fmt" "io" + "net/http" "os" "path/filepath" "runtime" "strings" + "time" "gocloud.dev/blob" + _ "gocloud.dev/blob/azureblob" _ "gocloud.dev/blob/fileblob" _ "gocloud.dev/blob/s3blob" "gocloud.dev/gcerrors" @@ -138,6 +141,20 @@ func (b *Blob) Delete(ctx context.Context, path string) error { return nil } +func (b *Blob) SignedURL(ctx context.Context, path string, expiry time.Duration) (string, error) { + url, err := b.bucket.SignedURL(ctx, path, &blob.SignedURLOptions{ + Method: http.MethodGet, + Expiry: expiry, + }) + if err != nil { + if gcerrors.Code(err) == gcerrors.Unimplemented { + return "", ErrSignedURLUnsupported + } + return "", fmt.Errorf("signing URL: %w", err) + } + return url, nil +} + func (b *Blob) Size(ctx context.Context, path string) (int64, error) { attrs, err := b.bucket.Attributes(ctx, path) if err != nil { @@ -167,6 +184,35 @@ func (b *Blob) UsedSpace(ctx context.Context) (int64, error) { return total, nil } +// ListPrefix returns object metadata for keys under a prefix. +func (b *Blob) ListPrefix(ctx context.Context, prefix string) ([]ObjectInfo, error) { + iter := b.bucket.List(&blob.ListOptions{Prefix: prefix}) + objects := make([]ObjectInfo, 0) + + for { + obj, err := iter.Next(ctx) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("listing objects: %w", err) + } + if obj.IsDir { + continue + } + + info := ObjectInfo{ + Path: obj.Key, + Size: obj.Size, + ModTime: obj.ModTime, + } + + objects = append(objects, info) + } + + return objects, nil +} + func (b *Blob) Close() error { return b.bucket.Close() } diff --git a/internal/storage/blob_test.go b/internal/storage/blob_test.go index bb2d089..d80290b 100644 --- a/internal/storage/blob_test.go +++ b/internal/storage/blob_test.go @@ -10,6 +10,7 @@ import ( "runtime" "strings" "testing" + "time" ) func TestOpenBucket(t *testing.T) { @@ -188,6 +189,18 @@ func TestBlobLargeFile(t *testing.T) { assertLargeFileRoundTrip(t, createTestBlob(t)) } +func TestBlobSignedURLUnsupported(t *testing.T) { + b := createTestBlob(t) + ctx := context.Background() + + // fileblob has no URL signer configured, so this must surface as + // ErrSignedURLUnsupported rather than a generic error. + _, err := b.SignedURL(ctx, "test/file.txt", time.Minute) + if !errors.Is(err, ErrSignedURLUnsupported) { + t.Errorf("SignedURL on fileblob = %v, want ErrSignedURLUnsupported", err) + } +} + func TestBlobOverwrite(t *testing.T) { b := createTestBlob(t) ctx := context.Background() diff --git a/internal/storage/filesystem.go b/internal/storage/filesystem.go index cf6a1fe..1e5a24f 100644 --- a/internal/storage/filesystem.go +++ b/internal/storage/filesystem.go @@ -6,8 +6,11 @@ import ( "encoding/hex" "fmt" "io" + fsys "io/fs" "os" "path/filepath" + "strings" + "time" ) // Filesystem implements Storage using the local filesystem. @@ -30,12 +33,19 @@ func NewFilesystem(root string) (*Filesystem, error) { return &Filesystem{root: absRoot}, nil } -func (fs *Filesystem) fullPath(path string) string { - return filepath.Join(fs.root, filepath.FromSlash(path)) +func (fs *Filesystem) fullPath(path string) (string, error) { + full := filepath.Clean(filepath.Join(fs.root, filepath.FromSlash(path))) + if full != fs.root && !strings.HasPrefix(full, fs.root+string(filepath.Separator)) { + return "", fmt.Errorf("%w: path escapes storage root", ErrNotFound) + } + return full, nil } func (fs *Filesystem) Store(ctx context.Context, path string, r io.Reader) (int64, string, error) { - fullPath := fs.fullPath(path) + fullPath, err := fs.fullPath(path) + if err != nil { + return 0, "", err + } dir := filepath.Dir(fullPath) if err := os.MkdirAll(dir, dirPermissions); err != nil { @@ -82,7 +92,10 @@ func (fs *Filesystem) Store(ctx context.Context, path string, r io.Reader) (int6 } func (fs *Filesystem) Open(ctx context.Context, path string) (io.ReadCloser, error) { - fullPath := fs.fullPath(path) + fullPath, err := fs.fullPath(path) + if err != nil { + return nil, err + } f, err := os.Open(fullPath) if err != nil { @@ -96,9 +109,12 @@ func (fs *Filesystem) Open(ctx context.Context, path string) (io.ReadCloser, err } func (fs *Filesystem) Exists(ctx context.Context, path string) (bool, error) { - fullPath := fs.fullPath(path) + fullPath, err := fs.fullPath(path) + if err != nil { + return false, err + } - _, err := os.Stat(fullPath) + _, err = os.Stat(fullPath) if err != nil { if os.IsNotExist(err) { return false, nil @@ -110,9 +126,12 @@ func (fs *Filesystem) Exists(ctx context.Context, path string) (bool, error) { } func (fs *Filesystem) Delete(ctx context.Context, path string) error { - fullPath := fs.fullPath(path) + fullPath, err := fs.fullPath(path) + if err != nil { + return err + } - err := os.Remove(fullPath) + err = os.Remove(fullPath) if err != nil && !os.IsNotExist(err) { return fmt.Errorf("removing file: %w", err) } @@ -129,8 +148,15 @@ func (fs *Filesystem) Delete(ctx context.Context, path string) error { return nil } +func (fs *Filesystem) SignedURL(_ context.Context, _ string, _ time.Duration) (string, error) { + return "", ErrSignedURLUnsupported +} + func (fs *Filesystem) Size(ctx context.Context, path string) (int64, error) { - fullPath := fs.fullPath(path) + fullPath, err := fs.fullPath(path) + if err != nil { + return 0, err + } info, err := os.Stat(fullPath) if err != nil { @@ -162,6 +188,54 @@ func (fs *Filesystem) UsedSpace(ctx context.Context) (int64, error) { return total, nil } +// ListPrefix returns object metadata for paths under a prefix. +func (fs *Filesystem) ListPrefix(ctx context.Context, prefix string) ([]ObjectInfo, error) { + searchRoot, err := fs.fullPath(prefix) + if err != nil { + return nil, err + } + + if _, err := os.Stat(searchRoot); err != nil { + if os.IsNotExist(err) { + return []ObjectInfo{}, nil + } + return nil, fmt.Errorf("stat prefix: %w", err) + } + + objects := make([]ObjectInfo, 0) + err = filepath.WalkDir(searchRoot, func(path string, entry fsys.DirEntry, err error) error { + if err != nil { + return err + } + if entry.IsDir() { + return nil + } + + info, err := entry.Info() + if err != nil { + return err + } + + relPath, err := filepath.Rel(fs.root, path) + if err != nil { + return err + } + + objects = append(objects, ObjectInfo{ + Path: filepath.ToSlash(relPath), + Size: info.Size(), + ModTime: info.ModTime(), + }) + + return nil + }) + if err != nil { + return nil, fmt.Errorf("walking prefix: %w", err) + } + + return objects, nil +} + // Root returns the root directory of the storage. func (fs *Filesystem) Root() string { return fs.root @@ -169,7 +243,8 @@ func (fs *Filesystem) Root() string { // FullPath returns the full filesystem path for a storage path. // Useful for serving files directly or debugging. -func (fs *Filesystem) FullPath(path string) string { +// Returns an error if the resulting path would escape the storage root. +func (fs *Filesystem) FullPath(path string) (string, error) { return fs.fullPath(path) } diff --git a/internal/storage/filesystem_test.go b/internal/storage/filesystem_test.go index 7b7828d..332dfbf 100644 --- a/internal/storage/filesystem_test.go +++ b/internal/storage/filesystem_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "testing" + "time" ) func TestNewFilesystem(t *testing.T) { @@ -51,7 +52,10 @@ func TestFilesystemStore(t *testing.T) { } // Verify file exists on disk - fullPath := fs.FullPath("npm/lodash/4.17.21/lodash.tgz") + fullPath, err := fs.FullPath("npm/lodash/4.17.21/lodash.tgz") + if err != nil { + t.Fatalf("FullPath failed: %v", err) + } data, err := os.ReadFile(fullPath) if err != nil { t.Fatalf("reading stored file: %v", err) @@ -162,7 +166,10 @@ func TestFilesystemDelete(t *testing.T) { } // Empty parent directories should be cleaned up - nestedDir := fs.FullPath("test/delete/nested") + nestedDir, err := fs.FullPath("test/delete/nested") + if err != nil { + t.Fatalf("FullPath failed: %v", err) + } if _, err := os.Stat(nestedDir); !os.IsNotExist(err) { t.Error("empty nested directory not cleaned up") } @@ -236,6 +243,31 @@ func TestFilesystemLargeFile(t *testing.T) { assertLargeFileRoundTrip(t, createTestFilesystem(t)) } +func TestFilesystemRejectsTraversal(t *testing.T) { + tmp := t.TempDir() + fs, err := NewFilesystem(tmp) + if err != nil { + t.Fatal(err) + } + for _, p := range []string{"../etc/passwd", "../../etc/passwd", "a/../../etc/passwd"} { + if _, err := fs.Open(context.Background(), p); err == nil { + t.Errorf("Open(%q) should reject traversal", p) + } + if _, _, err := fs.Store(context.Background(), p, strings.NewReader("x")); err == nil { + t.Errorf("Store(%q) should reject traversal", p) + } + } +} + +func TestFilesystemSignedURLUnsupported(t *testing.T) { + fs := createTestFilesystem(t) + + _, err := fs.SignedURL(context.Background(), "test/file.txt", time.Minute) + if !errors.Is(err, ErrSignedURLUnsupported) { + t.Errorf("SignedURL = %v, want ErrSignedURLUnsupported", err) + } +} + func createTestFilesystem(t *testing.T) *Filesystem { t.Helper() dir := t.TempDir() diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 8a9026c..e11db53 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -15,14 +15,26 @@ import ( "encoding/hex" "errors" "io" + "time" ) const dirPermissions = 0755 var ( ErrNotFound = errors.New("artifact not found") + + // ErrSignedURLUnsupported is returned by SignedURL when the backend + // cannot generate presigned URLs (e.g. local filesystem). + ErrSignedURLUnsupported = errors.New("signed URLs not supported by storage backend") ) +// ObjectInfo contains metadata for a stored object. +type ObjectInfo struct { + Path string + Size int64 + ModTime time.Time +} + // Storage defines the interface for artifact storage backends. type Storage interface { // Store writes content from r to the given path. @@ -45,6 +57,10 @@ type Storage interface { // Returns ErrNotFound if the path does not exist. Size(ctx context.Context, path string) (int64, error) + // SignedURL returns a presigned URL granting time-limited GET access to path. + // Returns ErrSignedURLUnsupported if the backend cannot generate presigned URLs. + SignedURL(ctx context.Context, path string, expiry time.Duration) (string, error) + // UsedSpace returns the total bytes used by all stored content. UsedSpace(ctx context.Context) (int64, error)