diff --git a/.github/goreleaser.yml b/.github/goreleaser.yml index 05d95ad..6d16915 100644 --- a/.github/goreleaser.yml +++ b/.github/goreleaser.yml @@ -43,7 +43,7 @@ dockers: goarch: amd64 use: buildx extra_files: - - config.sample.yaml + - config.docker.yaml build_flag_templates: - "--pull" - "--label=org.opencontainers.image.title={{.ProjectName}}" @@ -62,7 +62,7 @@ dockers: goarch: arm64 use: buildx extra_files: - - config.sample.yaml + - config.docker.yaml build_flag_templates: - "--pull" - "--label=org.opencontainers.image.title={{.ProjectName}}" diff --git a/.github/workflows/changelog_reminder.yml b/.github/workflows/changelog_reminder.yml index 6988029..7d5be99 100644 --- a/.github/workflows/changelog_reminder.yml +++ b/.github/workflows/changelog_reminder.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Changelog Reminder uses: mskelton/changelog-reminder-action@v3 with: diff --git a/.github/workflows/release_build.yml b/.github/workflows/release_build.yml index f31d720..ea64c15 100644 --- a/.github/workflows/release_build.yml +++ b/.github/workflows/release_build.yml @@ -21,31 +21,31 @@ jobs: id: go - name: Check out code into the Go module directory - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 # See: https://goreleaser.com/ci/actions/ - name: Setup QEMU # Used for cross-compiling with goreleaser / docker - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@v4 - name: Setup Docker Buildx # Used for cross-compiling with goreleaser / docker - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@v4 - name: Login to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@v4 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Login to GitHub Container Registry - uses: docker/login-action@v3 + uses: docker/login-action@v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v5 + uses: goreleaser/goreleaser-action@v7 with: version: latest args: release --clean --config .github/goreleaser.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bd0c74..70a2465 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,28 +9,33 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Ability to store and resume processing of certs from where it left off after a restart - see sample config "recovery" (#49) - New CLI switch for creating an index file from a CT log (#49) +- Support for [Static CT](https://github.com/C2SP/C2SP/blob/main/static-ct-api.md) logs - Check for retired CT logs and prevent them from being watched / stop watching them (#77) - Accept websocket connections from all origins - Option to disable the default logs provided by Google - see sample config "disable_default_logs" +- Use of cobra for CLI argument parsing. New commands for displaying version and creating an index file ### Changed +- The configuration file for the docker container is now read from the /app/config/ directory (b9e5e6) ### Removed - Non-functional Dodo log from sample config (#78) ### Fixed - Properly remove stopped ct log workers (#74) - Added missing fields certificatePolicies and ctlPoisonByte (#85) -- Prevent race condition caused by simultaneous rw access to logmetrics +- Prevent race condition caused by simultaneous rw access to logmetrics (#91) +- Properly display metrics for all initially watched logs (#95) +- Properly add new metrics for all newly found logs (#96) ### Docs -## [v1.8.2] - 2025-11-22 +## [1.8.2] - 2025-11-22 ### Fixed - Added missing fields certificatePolicies and ctlPoisonByte (#85) -## [v1.8.1] - 2025-05-04 +## [1.8.1] - 2025-05-04 ### Fixed - No longer reject URLs with trailing slashes defined in the `additional_logs` config (#62) - When using `drop_old_logs` in the config, the server won't remove logs defined in `additional_logs` anymore (#64) -## [v1.8.0] - 2025-05-03 +## [1.8.0] - 2025-05-03 ### Security - Close several CVEs in x/crypto and x/net dependencies (#59) diff --git a/Dockerfile b/Dockerfile index faecb10..9d7797d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,11 +17,12 @@ RUN adduser \ # Copy our static executable. COPY certstream-server-go /app/certstream-server-go -COPY ./config.sample.yaml /app/config.yaml +COPY ./config.docker.yaml /app/config/config.yaml +RUN chown -R certstreamserver:certstreamserver /app/ # Use an unprivileged user. USER certstreamserver:certstreamserver EXPOSE 8080 -ENTRYPOINT ["/app/certstream-server-go"] \ No newline at end of file +ENTRYPOINT ["/app/certstream-server-go", "-config", "/app/config/config.yaml"] \ No newline at end of file diff --git a/Dockerfile_multistage b/Dockerfile_multistage index 854cb2e..d23b978 100644 --- a/Dockerfile_multistage +++ b/Dockerfile_multistage @@ -42,11 +42,11 @@ COPY --from=builder /etc/group /etc/group # Copy our static executable. COPY --from=builder /go/bin/certstream-server-go /app/certstream-server-go -COPY --chown=certstreamserver:certstreamserver ./config.sample.yaml /app/config.yaml +COPY --chown=certstreamserver:certstreamserver ./config.docker.yaml /app/config/config.yaml # Use an unprivileged user. USER certstreamserver:certstreamserver EXPOSE 8080 -ENTRYPOINT ["/app/certstream-server-go"] +ENTRYPOINT ["/app/certstream-server-go", "-config", "/app/config/config.yaml"] diff --git a/cmd/certstream-server-go/createIndex.go b/cmd/certstream-server-go/createIndex.go new file mode 100644 index 0000000..cda010f --- /dev/null +++ b/cmd/certstream-server-go/createIndex.go @@ -0,0 +1,70 @@ +package main + +import ( + "fmt" + "log" + "os" + "path/filepath" + + "github.com/spf13/cobra" + + "github.com/d-Rickyy-b/certstream-server-go/internal/certstream" + "github.com/d-Rickyy-b/certstream-server-go/internal/config" +) + +// createIndexCmd represents the createIndex command +var createIndexCmd = &cobra.Command{ + Use: "create-index", + Short: "Create the ct_index.json based on current STHs/Checkpoints", + Long: `When using the recovery feature, certstream will store an index of the processed certificates for each CT log. +create-index will create and pre fill the ct-index.json file with the current values of the most recent certificate for each CT log.`, + + RunE: func(cmd *cobra.Command, args []string) error { + configPath, err := cmd.Flags().GetString("config") + if err != nil { + return err + } + + conf, readConfErr := config.ReadConfig(configPath) + if readConfErr != nil { + return readConfErr + } + cs := certstream.NewRawCertstream(conf) + + force, err := cmd.Flags().GetBool("force") + if err != nil { + return err + } + + outFilePath, err := cmd.Flags().GetString("out") + if err != nil { + return err + } + + // Check if outfile already exists + outFileAbsPath, err := filepath.Abs(outFilePath) + if err != nil { + return err + } + if _, statErr := os.Stat(outFileAbsPath); statErr == nil { + if !force { + fmt.Printf("Output file '%s' already exists. Use --force to override it.\n", outFileAbsPath) + os.Exit(1) + } + } + + createErr := cs.CreateIndexFile(outFilePath) + if createErr != nil { + log.Fatalf("Error while creating index file: %v", createErr) + } + + return nil + }, +} + +func init() { + rootCmd.AddCommand(createIndexCmd) + + createIndexCmd.Flags().StringP("out", "o", "ct_index.json", "Path to the index file to create") + createIndexCmd.Flags().BoolP("force", "f", false, "Whether to override the index file if it already exists") +} diff --git a/cmd/certstream-server-go/main.go b/cmd/certstream-server-go/main.go index a37d9e7..9cc7d9b 100644 --- a/cmd/certstream-server-go/main.go +++ b/cmd/certstream-server-go/main.go @@ -1,50 +1,11 @@ package main import ( - "flag" - "fmt" "log" - - "github.com/d-Rickyy-b/certstream-server-go/internal/certstream" - "github.com/d-Rickyy-b/certstream-server-go/internal/config" ) // main is the entry point for the application. func main() { - configFile := flag.String("config", "config.yml", "path to the config file") - versionFlag := flag.Bool("version", false, "Print the version and exit") - createIndexFile := flag.Bool("create-index-file", false, "Create the ct_index.json based on current STHs") - flag.Parse() - - if *versionFlag { - fmt.Printf("certstream-server-go v%s\n", config.Version) - return - } - log.SetFlags(log.LstdFlags | log.Lshortfile) - - // If the user only wants to create the index file, we don't need to start the server - if *createIndexFile { - conf, readConfErr := config.ReadConfig(*configFile) - if readConfErr != nil { - log.Fatalf("Error while reading config: %v", readConfErr) - } - cs := certstream.NewRawCertstream(conf) - - createErr := cs.CreateIndexFile() - if createErr != nil { - log.Fatalf("Error while creating index file: %v", createErr) - } - - return - } - - log.Printf("Starting certstream-server-go v%s\n", config.Version) - - cs, err := certstream.NewCertstreamFromConfigFile(*configFile) - if err != nil { - log.Fatalf("Error while creating certstream server: %v", err) - } - - cs.Start() + Execute() } diff --git a/cmd/certstream-server-go/root.go b/cmd/certstream-server-go/root.go new file mode 100644 index 0000000..b16f8b4 --- /dev/null +++ b/cmd/certstream-server-go/root.go @@ -0,0 +1,66 @@ +package main + +import ( + "fmt" + "log" + "os" + + "github.com/spf13/cobra" + + "github.com/d-Rickyy-b/certstream-server-go/internal/certstream" + "github.com/d-Rickyy-b/certstream-server-go/internal/config" +) + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: "certstream-server-go", + Short: "A drop-in replacement for the certstream server by Calidog", + Long: `This tool aggregates, parses, and streams certificate data from multiple +certificate transparency logs via websocket connections to connected clients.`, + + RunE: func(cmd *cobra.Command, args []string) error { + // Handle --version flag + versionBool, err := cmd.Flags().GetBool("version") + if err != nil { + return err + } + if versionBool { + fmt.Printf("certstream-server-go v%s\n", config.Version) + return nil + } + + // Handle --config flag + configPath, err := cmd.Flags().GetString("config") + if err != nil { + return err + } + // Check if path exists and is a file + _, statErr := os.Stat(configPath) + if os.IsNotExist(statErr) { + return fmt.Errorf("config file '%s' does not exist", configPath) + } + + cs, err := certstream.NewCertstreamFromConfigFile(configPath) + if err != nil { + log.Fatalf("Error while creating certstream server: %v", err) + } + + cs.Start() + + return nil + }, +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + err := rootCmd.Execute() + if err != nil { + os.Exit(1) + } +} + +func init() { + rootCmd.PersistentFlags().StringP("config", "c", "config.yml", "Path to the config file") + rootCmd.Flags().BoolP("version", "v", false, "Print the version and exit") +} diff --git a/cmd/certstream-server-go/validate.go b/cmd/certstream-server-go/validate.go new file mode 100644 index 0000000..86015ec --- /dev/null +++ b/cmd/certstream-server-go/validate.go @@ -0,0 +1,51 @@ +package main + +import ( + "fmt" + "log" + "os" + + "github.com/spf13/cobra" + + "github.com/d-Rickyy-b/certstream-server-go/internal/config" +) + +// validateCmd represents the validate command +var validateCmd = &cobra.Command{ + Use: "validate", + Short: "Tests whether the config file is valid", + Long: `Validates a configuration file, then exits. + +This command deserializes the config and checks for errors.`, + PreRunE: func(cmd *cobra.Command, args []string) error { + // Check if config file exists + configPath, err := cmd.Flags().GetString("config") + if err != nil { + return err + } + // Check if path exists and is a file + _, statErr := os.Stat(configPath) + if os.IsNotExist(statErr) { + return fmt.Errorf("config file '%s' does not exist", configPath) + } + + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + configPath, err := cmd.Flags().GetString("config") + if err != nil { + return err + } + + readConfErr := config.ValidateConfig(configPath) + if readConfErr != nil { + log.Fatalln(readConfErr) + } + log.Println("Config file is valid!") + return nil + }, +} + +func init() { + rootCmd.AddCommand(validateCmd) +} diff --git a/config.docker.yaml b/config.docker.yaml new file mode 100644 index 0000000..efbc743 --- /dev/null +++ b/config.docker.yaml @@ -0,0 +1,43 @@ +webserver: + # For IPv6, set the listen_addr to "::" + listen_addr: "0.0.0.0" + listen_port: 8080 + # If you want to use a reverse proxy in front of the server, set this to true + # It will use the X-Forwarded-For header to get the real IP of the client + real_ip: false + full_url: "/full-stream" + lite_url: "/" + domains_only_url: "/domains-only" + cert_path: "" + cert_key_path: "" + compression_enabled: false + +prometheus: + enabled: false + listen_addr: "0.0.0.0" + listen_port: 8080 + metrics_url: "/metrics" + expose_system_metrics: false + real_ip: false + whitelist: + - "127.0.0.1/8" + +general: + # DisableDefaultLogs indicates whether the default logs used in Google Chrome and provided by Google should be disabled. + disable_default_logs: false + + # Google regularly updates the log list. If this option is set to true, the server will remove all logs no longer listed in the Google log list. + # This option defaults to true. See https://github.com/d-Rickyy-b/certstream-server-go/issues/51 + drop_old_logs: true + + # Options for resuming certificate downloads after restart + recovery: + # If enabled, the server will resume downloading certificates from the last processed and stored index for each log. + # If there is no ct_index_file or for a specific log there is no index entry, the server will start from index 0. + # Be aware that this leads to a massive number of certificates being downloaded. + # Depending on your server's performance and network connection, this could be up to 10.000 certificates per second. + # Make sure your infrastructure can handle this! + enabled: false + # Path to the file where indices are stored. Be aware that a temp file in the same path with the same name and ".tmp" as suffix will be created. + # If there are no write permissions to the path, the server will not be able to store the indices. + ct_index_file: "/app/config/ct_index.json" diff --git a/config.sample.yaml b/config.sample.yaml index c442d9c..eae11ea 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -27,10 +27,15 @@ general: disable_default_logs: false # When you want to add logs that are not contained in the log list provided by # Google (https://www.gstatic.com/ct/log_list/v3/log_list.json), you can add them here. - additional_logs: - - url: https://ct.googleapis.com/logs/us1/mirrors/digicert_nessie2022 - operator: "DigiCert" - description: "DigiCert Nessie2022 log" + #additional_logs: + # - url: https://ct.googleapis.com/logs/us1/mirrors/digicert_nessie2022 + # operator: "DigiCert" + # description: "DigiCert Nessie2022 log" + + #additional_tiled_logs: + # - url: https://ct.cloudflare.com/logs/raio2025h2b/ + # operator: "Cloudflare" + # description: "Cloudflare 'Raio2025h2b'" # To optimize the performance of the server, you can overwrite the size of different buffers # For low CPU, low memory machines, you should reduce the buffer sizes to save memory in case the CPU is maxed. diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 4edc2fe..0e93092 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -6,11 +6,17 @@ services: restart: always # Configure the service to run as specific user # user: "1000:1000" + # environment: + # You can change config options via env vars. + # - CERTSTREAM_LISTEN_ADDR="0.0.0.0" + # - CERTSTREAM_LISTEN_PORT=8080 ports: - - 127.0.0.1:8080:80 + - "127.0.0.1:8080:80" # Don't forget to open the other port in case you run the Prometheus endpoint on another port than the websocket server. - # - 127.0.0.1:8081:81 + # - "127.0.0.1:8081:81" volumes: - - ./certstream/config.yml:/app/config.yml + # Starting with v1.9.0, the docker container expects the config file in /app/config/config.yml. + # Mounting a config directory is suggested. + - ./certstream/config:/app/config networks: - monitoring diff --git a/go.mod b/go.mod index ebad9e7..2f4bdd9 100644 --- a/go.mod +++ b/go.mod @@ -8,17 +8,30 @@ require ( github.com/VictoriaMetrics/metrics v1.40.2 github.com/go-chi/chi/v5 v5.2.3 github.com/google/certificate-transparency-go v1.3.2 + github.com/google/trillian v1.7.2 github.com/gorilla/websocket v1.5.3 - gopkg.in/yaml.v3 v3.0.1 + github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 + golang.org/x/crypto v0.46.0 ) require ( + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect - github.com/google/trillian v1.7.2 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect github.com/valyala/fastrand v1.1.0 // indirect github.com/valyala/histogram v1.2.0 // indirect - golang.org/x/crypto v0.46.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect google.golang.org/grpc v1.78.0 // indirect google.golang.org/protobuf v1.36.11 // indirect diff --git a/go.sum b/go.sum index 0a89c3e..c326743 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,18 @@ github.com/VictoriaMetrics/metrics v1.40.2 h1:OVSjKcQEx6JAwGeu8/KQm9Su5qJ72TMEW4xYn5vw3Ac= github.com/VictoriaMetrics/metrics v1.40.2/go.mod h1:XE4uudAAIRaJE614Tl5HMrtoEU6+GDZO4QTnNSsZRuA= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/certificate-transparency-go v1.3.2 h1:9ahSNZF2o7SYMaKaXhAumVEzXB2QaayzII9C8rv7v+A= @@ -14,14 +23,48 @@ github.com/google/trillian v1.7.2 h1:EPBxc4YWY4Ak8tcuhyFleY+zYlbCDCa4Sn24e1Ka8Js github.com/google/trillian v1.7.2/go.mod h1:mfQJW4qRH6/ilABtPYNBerVJAJ/upxHLX81zxNQw05s= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/valyala/fastrand v1.1.0 h1:f+5HkLW4rsgzdNoleUOB69hyT9IlD2ZQh9GyDMfb5G8= github.com/valyala/fastrand v1.1.0/go.mod h1:HWqCzkrkg6QXT8V2EXWvXCoow7vLwOFN002oeRzjapQ= github.com/valyala/histogram v1.2.0 h1:wyYGAZZt3CpwUiIb9AU/Zbllg1llXyrtApRS815OLoQ= github.com/valyala/histogram v1.2.0/go.mod h1:Hb4kBwb4UxsaNbbbh+RRz8ZR6pdodR57tzWUS3BUzXY= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= @@ -36,8 +79,9 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= diff --git a/internal/certificatetransparency/ct-parser.go b/internal/certificatetransparency/ct-parser.go index 694707b..6f62466 100644 --- a/internal/certificatetransparency/ct-parser.go +++ b/internal/certificatetransparency/ct-parser.go @@ -11,6 +11,7 @@ import ( "hash" "log" "math/big" + "slices" "strings" "time" @@ -24,6 +25,7 @@ import ( // parseData converts a *ct.RawLogEntry struct into a certstream.Data struct by copying some values and calculating others. func parseData(entry *ct.RawLogEntry, operatorName, logName, ctURL string) (models.Data, error) { certLink := fmt.Sprintf("%s/ct/v1/get-entries?start=%d&end=%d", ctURL, entry.Index, entry.Index) + // TODO implement tiled cert link // Create main data structure data := models.Data{ @@ -131,14 +133,7 @@ func leafCertFromX509cert(cert x509.Certificate) models.LeafCert { leafCert.Subject = buildSubject(cert.Subject) if *leafCert.Subject.CN != "" && !leafCert.IsCA { - domainAlreadyAdded := false - // TODO check if CN matches domain regex - for _, domain := range leafCert.AllDomains { - if domain == *leafCert.Subject.CN { - domainAlreadyAdded = true - break - } - } + domainAlreadyAdded := slices.Contains(leafCert.AllDomains, *leafCert.Subject.CN) if !domainAlreadyAdded { leafCert.AllDomains = append(leafCert.AllDomains, *leafCert.Subject.CN) diff --git a/internal/certificatetransparency/ct-tiled.go b/internal/certificatetransparency/ct-tiled.go new file mode 100644 index 0000000..2c27b32 --- /dev/null +++ b/internal/certificatetransparency/ct-tiled.go @@ -0,0 +1,238 @@ +package certificatetransparency + +import ( + "bufio" + "context" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + ct "github.com/google/certificate-transparency-go" + "golang.org/x/crypto/cryptobyte" +) + +const TileSize = 256 + +// TiledCheckpoint represents the checkpoint information from a tiled CT log +type TiledCheckpoint struct { + Origin string + Size uint64 + Hash string +} + +// TileLeaf represents a single entry in a tile +type TileLeaf struct { + Timestamp uint64 + EntryType uint16 + X509Entry []byte // For X.509 certificates + PrecertEntry []byte // For precertificates + Chain [][]byte + IssuerKeyHash [32]byte +} + +// EncodeTilePath encodes a tile index into the proper path format +func EncodeTilePath(index uint64) string { + if index == 0 { + return "000" + } + + // Collect 3-digit groups + var groups []uint64 + for n := index; n > 0; n /= 1000 { + groups = append(groups, n%1000) + } + + // Build path from groups in reverse + var b strings.Builder + for i := len(groups) - 1; i >= 0; i-- { + if i < len(groups)-1 { + b.WriteByte('/') + } + if i > 0 { + b.WriteByte('x') + } + fmt.Fprintf(&b, "%03d", groups[i]) + } + + return b.String() +} + +// FetchCheckpoint fetches the checkpoint from a tiled CT log using the provided client +func FetchCheckpoint(ctx context.Context, client *http.Client, baseURL string) (*TiledCheckpoint, error) { + baseURL = strings.TrimRight(baseURL, "/") + url := fmt.Sprintf("%s/checkpoint", baseURL) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("User-Agent", userAgent) + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetching checkpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("checkpoint request failed with status: %d", resp.StatusCode) + } + + scanner := bufio.NewScanner(resp.Body) + lines := make([]string, 0, 3) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading checkpoint response: %w", err) + } + + if len(lines) < 3 { + return nil, fmt.Errorf("invalid checkpoint format: expected at least 3 lines, got %d", len(lines)) + } + + size, err := strconv.ParseUint(lines[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("parsing tree size: %w", err) + } + + return &TiledCheckpoint{ + Origin: lines[0], + Size: size, + Hash: lines[2], + }, nil +} + +// FetchTile fetches a tile from the tiled CT log using the provided client. +// If partialWidth > 0, fetches a partial tile with that width (1-255). +func FetchTile(ctx context.Context, client *http.Client, baseURL string, tileIndex uint64, partialWidth uint64) ([]TileLeaf, error) { + baseURL = strings.TrimRight(baseURL, "/") + tilePath := EncodeTilePath(tileIndex) + if partialWidth > 0 { + tilePath = fmt.Sprintf("%s.p/%d", tilePath, partialWidth) + } + url := fmt.Sprintf("%s/tile/data/%s", baseURL, tilePath) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("User-Agent", userAgent) + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetching tile %d: %w", tileIndex, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("tile request failed with status: %d", resp.StatusCode) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading tile data: %w", err) + } + + return ParseTileData(data) +} + +// ParseTileData parses the binary tile data into TileLeaf entries using cryptobyte +func ParseTileData(data []byte) ([]TileLeaf, error) { + var leaves []TileLeaf + s := cryptobyte.String(data) + + for !s.Empty() { + var leaf TileLeaf + + if !s.ReadUint64(&leaf.Timestamp) || !s.ReadUint16(&leaf.EntryType) { + return nil, fmt.Errorf("invalid data tile header") + } + + switch leaf.EntryType { + case 0: // x509_entry + var cert cryptobyte.String + var extensions, fingerprints cryptobyte.String + if !s.ReadUint24LengthPrefixed(&cert) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.ReadUint16LengthPrefixed(&fingerprints) { + return nil, fmt.Errorf("invalid data tile x509_entry") + } + leaf.X509Entry = append([]byte(nil), cert...) + for !fingerprints.Empty() { + var fp [32]byte + if !fingerprints.CopyBytes(fp[:]) { + return nil, fmt.Errorf("invalid fingerprints: truncated") + } + leaf.Chain = append(leaf.Chain, fp[:]) + } + + case 1: // precert_entry + var issuerKeyHash [32]byte + var defangedCrt, extensions, entry, fingerprints cryptobyte.String + if !s.CopyBytes(issuerKeyHash[:]) || + !s.ReadUint24LengthPrefixed(&defangedCrt) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.ReadUint24LengthPrefixed(&entry) || + !s.ReadUint16LengthPrefixed(&fingerprints) { + return nil, fmt.Errorf("invalid data tile precert_entry") + } + leaf.PrecertEntry = append([]byte(nil), defangedCrt...) + leaf.IssuerKeyHash = issuerKeyHash + for !fingerprints.Empty() { + var fp [32]byte + if !fingerprints.CopyBytes(fp[:]) { + return nil, fmt.Errorf("invalid fingerprints: truncated") + } + leaf.Chain = append(leaf.Chain, fp[:]) + } + + default: + return nil, fmt.Errorf("unknown entry type: %d", leaf.EntryType) + } + + leaves = append(leaves, leaf) + } + return leaves, nil +} + +// ConvertTileLeafToRawLogEntry converts a TileLeaf to ct.RawLogEntry for compatibility +func ConvertTileLeafToRawLogEntry(leaf TileLeaf, index uint64) *ct.RawLogEntry { + rawEntry := &ct.RawLogEntry{ + Index: int64(index), + Leaf: ct.MerkleTreeLeaf{ + Version: ct.V1, + LeafType: ct.TimestampedEntryLeafType, + }, + } + + switch leaf.EntryType { + case 0: // x509_entry + // Use the DER certificate from X509Entry + certData := leaf.X509Entry + rawEntry.Leaf.TimestampedEntry = &ct.TimestampedEntry{ + Timestamp: leaf.Timestamp, + EntryType: ct.X509LogEntryType, + X509Entry: &ct.ASN1Cert{Data: certData}, + } + rawEntry.Cert = ct.ASN1Cert{Data: certData} + + case 1: // precert_entry + // Build a minimal PreCert. TBSCertificate is the defanged TBS; IssuerKeyHash from tile. + rawEntry.Leaf.TimestampedEntry = &ct.TimestampedEntry{ + Timestamp: leaf.Timestamp, + EntryType: ct.PrecertLogEntryType, + PrecertEntry: &ct.PreCert{ + IssuerKeyHash: leaf.IssuerKeyHash, + TBSCertificate: leaf.PrecertEntry, + }, + } + + default: + // Unknown type; leave as zero-value + } + + return rawEntry +} diff --git a/internal/certificatetransparency/ct-watcher.go b/internal/certificatetransparency/ct-watcher.go index 2628303..6a07f72 100644 --- a/internal/certificatetransparency/ct-watcher.go +++ b/internal/certificatetransparency/ct-watcher.go @@ -13,7 +13,10 @@ import ( "sync/atomic" "time" + "github.com/google/trillian/client/backoff" + "github.com/d-Rickyy-b/certstream-server-go/internal/config" + "github.com/d-Rickyy-b/certstream-server-go/internal/metrics" "github.com/d-Rickyy-b/certstream-server-go/internal/models" "github.com/d-Rickyy-b/certstream-server-go/internal/web" @@ -63,9 +66,9 @@ func (w *Watcher) Start() { return } // Load Saved CT Indexes - metrics.LoadCTIndex(ctIndexFilePath) + metrics.Metrics.LoadCTIndex(ctIndexFilePath) // Save CTIndexes at regular intervals - go metrics.SaveCertIndexesAtInterval(time.Second*30, ctIndexFilePath) // save indexes every X seconds + go metrics.Metrics.SaveCertIndexesAtInterval(time.Second*30, ctIndexFilePath) // save indexes every X seconds } // initialize the watcher with currently available logs @@ -115,7 +118,7 @@ func (w *Watcher) updateLogs() { defer w.workersMu.Unlock() for _, operator := range logList.Operators { - // Iterate over each log of the operator + // Classic logs for _, transparencyLog := range operator.Logs { url := transparencyLog.URL desc := transparencyLog.Description @@ -127,7 +130,24 @@ func (w *Watcher) updateLogs() { } monitoredURLs[normURL] = struct{}{} - if w.addLogIfNew(operator.Name, desc, url) { + if w.addLogIfNew(operator.Name, desc, url, false) { + newCTs++ + } + } + + // Tiled logs + for _, transparencyLog := range operator.TiledLogs { + url := transparencyLog.MonitoringURL + desc := transparencyLog.Description + normURL := normalizeCtlogURL(url) + + if transparencyLog.State.LogStatus() == loglist3.RetiredLogStatus { + log.Printf("Skipping retired CT log: %s\n", normURL) + continue + } + + monitoredURLs[normURL] = struct{}{} + if w.addLogIfNew(operator.Name, desc, url, true) { newCTs++ } } @@ -154,7 +174,7 @@ func (w *Watcher) updateLogs() { // addLogIfNew checks if a log is already being watched and adds it if not. // Returns true if a new log was added, false otherwise. -func (w *Watcher) addLogIfNew(operatorName, description, url string) bool { +func (w *Watcher) addLogIfNew(operatorName, description, url string, isTiled bool) bool { normURL := normalizeCtlogURL(url) // Check if the log is already being watched @@ -168,16 +188,17 @@ func (w *Watcher) addLogIfNew(operatorName, description, url string) bool { // Log is not being watched, so add it w.wg.Add(1) - lastCTIndex := metrics.GetCTIndex(normURL) + lastCTIndex := metrics.Metrics.GetCTIndex(normURL) ctWorker := worker{ name: description, operatorName: operatorName, ctURL: url, entryChan: w.certChan, ctIndex: lastCTIndex, + isTiled: isTiled, } w.workers = append(w.workers, &ctWorker) - metrics.Init(operatorName, normURL) + metrics.Metrics.Init(operatorName, normURL) // Start a goroutine for each worker go func() { @@ -212,7 +233,7 @@ func (w *Watcher) Stop() { if config.AppConfig.General.Recovery.Enabled { // Store current CT Indexes before shutting down filePath := config.AppConfig.General.Recovery.CTIndexFile - metrics.SaveCertIndexes(filePath) + metrics.Metrics.SaveCertIndexes(filePath) } w.cancelFunc() @@ -230,9 +251,15 @@ func (w *Watcher) CreateIndexFile(filePath string) error { for _, operator := range logs.Operators { // Iterate over each log of the operator for _, transparencyLog := range operator.Logs { + if transparencyLog.State.LogStatus() == loglist3.RetiredLogStatus { + log.Printf("Skipping retired CT log: %s\n", transparencyLog.URL) + continue + } + + normalizedURL := normalizeCtlogURL(transparencyLog.URL) // Check if the log is already being watched - metrics.Init(operator.Name, normalizeCtlogURL(transparencyLog.URL)) - log.Println("Fetching STH for", normalizeCtlogURL(transparencyLog.URL)) + metrics.Metrics.Init(operator.Name, normalizedURL) + log.Println("Fetching STH for", normalizedURL) hc := http.Client{Timeout: 5 * time.Second} jsonClient, e := client.New(transparencyLog.URL, &hc, jsonclient.Options{UserAgent: userAgent}) @@ -248,12 +275,31 @@ func (w *Watcher) CreateIndexFile(filePath string) error { continue } - metrics.SetCTIndex(normalizeCtlogURL(transparencyLog.URL), sth.TreeSize) + metrics.Metrics.SetCTIndex(normalizedURL, sth.TreeSize) + } + for _, transparencyLog := range operator.TiledLogs { + if transparencyLog.State.LogStatus() == loglist3.RetiredLogStatus { + log.Printf("Skipping retired CT log: %s\n", transparencyLog.MonitoringURL) + continue + } + // Check if the log is already being watched + normalizedURL := normalizeCtlogURL(transparencyLog.MonitoringURL) + metrics.Metrics.Init(operator.Name, normalizedURL) + log.Println("Fetching checkpoint for", normalizedURL) + + hc := &http.Client{Timeout: 10 * time.Second} + checkpoint, fetchErr := FetchCheckpoint(w.context, hc, transparencyLog.MonitoringURL) + if fetchErr != nil { + log.Printf("Could not get checkpoint for '%s': %s\n", transparencyLog.MonitoringURL, fetchErr) + return errFetchingSTHFailed + } + + metrics.Metrics.SetCTIndex(normalizedURL, checkpoint.Size) } } w.cancelFunc() - metrics.SaveCertIndexes(filePath) + metrics.Metrics.SaveCertIndexes(filePath) log.Println("Index file saved to", filePath) return nil @@ -269,6 +315,7 @@ type worker struct { mu sync.Mutex running bool cancel context.CancelFunc + isTiled bool } // startDownloadingCerts starts downloading certificates from the CT log. This method is blocking. @@ -298,7 +345,14 @@ func (w *worker) startDownloadingCerts(ctx context.Context) { for { log.Printf("Starting worker for CT log: %s\n", w.ctURL) - workerErr := w.runWorker(ctx) + + var workerErr error + if w.isTiled { + workerErr = w.runTiledWorker(ctx) + } else { + workerErr = w.runStandardWorker(ctx) + } + if workerErr != nil { if errors.Is(workerErr, errFetchingSTHFailed) { // TODO this could happen due to a 429 error. We should retry the request @@ -315,7 +369,7 @@ func (w *worker) startDownloadingCerts(ctx context.Context) { log.Printf("Worker for '%s' failed with unexpected error: %s\n", w.ctURL, workerErr) } - // Check if the context was cancelled + // Check if the context was canceled select { case <-ctx.Done(): log.Printf("Context was cancelled; Stopping worker for '%s'\n", w.ctURL) @@ -338,8 +392,8 @@ func (w *worker) stop() { w.cancel() } -// runWorker runs a single worker for a single CT log. This method is blocking. -func (w *worker) runWorker(ctx context.Context) error { +// runStandardWorker runs the worker for a single standard CT log. This method is blocking. +func (w *worker) runStandardWorker(ctx context.Context) error { hc := http.Client{Timeout: 30 * time.Second} jsonClient, e := client.New(w.ctURL, &hc, jsonclient.Options{UserAgent: userAgent}) if e != nil { @@ -384,6 +438,128 @@ func (w *worker) runWorker(ctx context.Context) error { return nil } +// runTiledWorker runs the worker for a single tiled CT log. This method is blocking. +func (w *worker) runTiledWorker(ctx context.Context) error { + hc := &http.Client{Timeout: 30 * time.Second} + + // If recovery is enabled and the CT index is set, we start at the saved index. Otherwise, we start at the latest checkpoint. + validSavedCTIndexExists := config.AppConfig.General.Recovery.Enabled + if !validSavedCTIndexExists { + checkpoint, err := FetchCheckpoint(ctx, hc, w.ctURL) + if err != nil { + log.Printf("Could not get checkpoint for '%s': %s\n", w.ctURL, err) + return errFetchingSTHFailed + } + // Start at the latest checkpoint to skip all the past certificates + w.ctIndex = checkpoint.Size + } + + // Initialize backoff for polling + pollBackoff := &backoff.Backoff{ + Min: 1 * time.Second, + Max: 30 * time.Second, + Factor: 2, + Jitter: true, + } + + // Continuous monitoring loop + for { + hadNewEntries, err := w.processTiledLogUpdates(ctx, hc) + if err != nil { + log.Printf("Error processing tiled log updates for '%s': %s\n", w.ctURL, err) + return err + } + + // Reset backoff if we found new entries + if hadNewEntries { + pollBackoff.Reset() + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(pollBackoff.Duration()): + // Continue to the next iteration + } + } +} + +// processTiledLogUpdates checks for new entries in the tiled log and processes them +func (w *worker) processTiledLogUpdates(ctx context.Context, hc *http.Client) (bool, error) { + // Fetch current checkpoint + checkpoint, err := FetchCheckpoint(ctx, hc, w.ctURL) + if err != nil { + return false, fmt.Errorf("fetching checkpoint: %w", err) + } + + currentTreeSize := checkpoint.Size + if currentTreeSize <= w.ctIndex { + // No new entries + return false, nil + } + + // Process entries from current index to new tree size + startTile := w.ctIndex / TileSize + endTile := currentTreeSize / TileSize + + // Process complete tiles + for tileIndex := startTile; tileIndex < endTile; tileIndex++ { + if err := w.processTile(ctx, hc, tileIndex, 0); err != nil { + return false, fmt.Errorf("processing tile %d: %w", tileIndex, err) + } + } + + // Process partial tile if exists + partialSize := currentTreeSize % TileSize + if partialSize > 0 { + if err := w.processTile(ctx, hc, endTile, partialSize); err != nil { + log.Printf("Warning: error processing partial tile %d: %s\n", endTile, err) + // Don't return error for partial tiles as they might be incomplete + } + } + + return true, nil +} + +// processTile processes a single tile from the tiled log. +// partialWidth of 0 means full tile, otherwise fetch partial tile with that width. +func (w *worker) processTile(ctx context.Context, hc *http.Client, tileIndex uint64, partialWidth uint64) error { + leaves, err := FetchTile(ctx, hc, w.ctURL, tileIndex, partialWidth) + if err != nil { + return fmt.Errorf("fetching tile: %w", err) + } + + // Calculate the starting index for entries in this tile + baseIndex := tileIndex * TileSize + + for i, leaf := range leaves { + entryIndex := baseIndex + uint64(i) + + // Skip entries we've already processed + if entryIndex <= w.ctIndex { + continue + } + + // Convert TileLeaf to RawLogEntry for compatibility with existing parsing + rawEntry := ConvertTileLeafToRawLogEntry(leaf, entryIndex) + + // Process the entry using existing callbacks + switch leaf.EntryType { + case 0: + w.foundCertCallback(rawEntry) + case 1: + w.foundPrecertCallback(rawEntry) + default: + log.Printf("Unknown entry type %d in tile %d, skipping entry at index %d\n", leaf.EntryType, tileIndex, entryIndex) + } + + // Update the index + w.ctIndex = entryIndex + } + + return nil +} + // foundCertCallback is the callback that handles cases where new regular certs are found. func (w *worker) foundCertCallback(rawEntry *ct.RawLogEntry) { entry, parseErr := ParseCertstreamEntry(rawEntry, w.operatorName, w.name, w.ctURL) @@ -395,7 +571,7 @@ func (w *worker) foundCertCallback(rawEntry *ct.RawLogEntry) { entry.Data.UpdateType = "X509LogEntry" w.entryChan <- entry - atomic.AddInt64(&processedCerts, 1) + atomic.AddInt64(&metrics.ProcessedCerts, 1) } // foundPrecertCallback is the callback that handles cases where new precerts are found. @@ -409,13 +585,13 @@ func (w *worker) foundPrecertCallback(rawEntry *ct.RawLogEntry) { entry.Data.UpdateType = "PrecertLogEntry" w.entryChan <- entry - atomic.AddInt64(&processedPrecerts, 1) + atomic.AddInt64(&metrics.ProcessedPrecerts, 1) } // certHandler takes the entries out of the entryChan channel and broadcasts them to all clients. // Only a single instance of the certHandler runs per certstream server. func certHandler(entryChan chan models.Entry) { - var processed int64 + var processed uint64 for { entry := <-entryChan @@ -427,7 +603,7 @@ func certHandler(entryChan chan models.Entry) { web.SetExampleCert(entry) } - // Run json encoding in the background and send the result to the clients. + // Run JSON encoding in the background and send the result to the clients. web.ClientHandler.Broadcast <- entry // Update metrics @@ -435,13 +611,13 @@ func certHandler(entryChan chan models.Entry) { operator := entry.Data.Source.Operator index := entry.Data.CertIndex - metrics.Inc(operator, url, index) + metrics.Metrics.Inc(operator, url, index) } } // getGoogleLogList fetches the list of all CT logs from Google Chromes CT LogList. func getGoogleLogList() (loglist3.LogList, error) { - // Download the list of all logs from ctLogInfo and decode json + // Download the list of all logs from ctLogInfo and decode JSON resp, err := http.Get(loglist3.LogListURL) if err != nil { return loglist3.LogList{}, err @@ -480,10 +656,11 @@ func getAllLogs() (loglist3.LogList, error) { } // Add manually added logs from config to the allLogs list - if config.AppConfig.General.AdditionalLogs == nil { - return allLogs, nil - } + // if config.AppConfig.General.AdditionalLogs == nil { + // return allLogs, nil + // } +logFound: for _, additionalLog := range config.AppConfig.General.AdditionalLogs { customLog := loglist3.Log{ URL: additionalLog.URL, @@ -493,10 +670,16 @@ func getAllLogs() (loglist3.LogList, error) { operatorFound := false for _, operator := range allLogs.Operators { if operator.Name == additionalLog.Operator { - // TODO Check if the log is already in the list - operator.Logs = append(operator.Logs, &customLog) operatorFound = true + for _, ctlog := range operator.Logs { + if ctlog.URL == additionalLog.URL { + // Log already exists, skip it. + break logFound + } + } + // This works, since allLogs.Operators is a slice of pointers. + operator.Logs = append(operator.Logs, &customLog) break } } @@ -510,6 +693,39 @@ func getAllLogs() (loglist3.LogList, error) { } } + for _, additionalLog := range config.AppConfig.General.AdditionalTiledLogs { + customLog := loglist3.TiledLog{ + MonitoringURL: additionalLog.URL, + Description: additionalLog.Description, + } + + operatorFound := false + + tiledLogFound: + for _, operator := range allLogs.Operators { + if operator.Name == additionalLog.Operator { + operatorFound = true + for _, tl := range operator.TiledLogs { + if tl.MonitoringURL == additionalLog.URL { + // Log already exists, skip it. + break tiledLogFound + } + } + // This works, since allLogs.Operators is a slice of pointers. + operator.TiledLogs = append(operator.TiledLogs, &customLog) + break + } + } + + if !operatorFound { + newOperator := loglist3.Operator{ + Name: additionalLog.Operator, + TiledLogs: []*loglist3.TiledLog{&customLog}, + } + allLogs.Operators = append(allLogs.Operators, &newOperator) + } + } + return allLogs, nil } diff --git a/internal/certstream/certstream.go b/internal/certstream/certstream.go index 6ade396..14a2e3b 100644 --- a/internal/certstream/certstream.go +++ b/internal/certstream/certstream.go @@ -62,11 +62,11 @@ func (cs *Certstream) setupMetrics(webserver *web.WebServer) { if (cs.config.Prometheus.ListenAddr == "" || cs.config.Prometheus.ListenAddr == cs.config.Webserver.ListenAddr) && (cs.config.Prometheus.ListenPort == 0 || cs.config.Prometheus.ListenPort == cs.config.Webserver.ListenPort) { log.Println("Starting prometheus server on same interface as webserver") - webserver.RegisterPrometheus(cs.config.Prometheus.MetricsURL, metrics.WritePrometheus) + webserver.RegisterPrometheus(cs.config.Prometheus.MetricsURL, metrics.Prometheus.Write) } else { log.Println("Starting prometheus server on new interface") cs.metricsServer = web.NewMetricsServer(cs.config.Prometheus.ListenAddr, cs.config.Prometheus.ListenPort, cs.config.Prometheus.CertPath, cs.config.Prometheus.CertKeyPath) - cs.metricsServer.RegisterPrometheus(cs.config.Prometheus.MetricsURL, metrics.WritePrometheus) + cs.metricsServer.RegisterPrometheus(cs.config.Prometheus.MetricsURL, metrics.Prometheus.Write) } } } @@ -118,13 +118,13 @@ func (cs *Certstream) Stop() { // CreateIndexFile creates the index file for the certificate transparency logs. // It gets only called when the CLI flag --create-index-file is set. -func (cs *Certstream) CreateIndexFile() error { +func (cs *Certstream) CreateIndexFile(outFile string) error { // If there is no watcher initialized, create a new one if cs.watcher == nil { cs.watcher = &certificatetransparency.Watcher{} } - return cs.watcher.CreateIndexFile(cs.config.General.Recovery.CTIndexFile) + return cs.watcher.CreateIndexFile(outFile) } // signalHandler listens for signals in order to gracefully shut down the server. diff --git a/internal/config/config.go b/internal/config/config.go index 24dd3fb..e25079f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,146 +1,165 @@ package config import ( + "errors" "log" "net" - "os" - "path/filepath" "regexp" "strings" - "gopkg.in/yaml.v3" + "github.com/spf13/viper" ) var ( - AppConfig Config - Version = "1.8.1" + AppConfig Config + viperInstance *viper.Viper + Version = "1.9.0" ) type ServerConfig struct { - ListenAddr string `yaml:"listen_addr"` - ListenPort int `yaml:"listen_port"` - CertPath string `yaml:"cert_path"` - CertKeyPath string `yaml:"cert_key_path"` - RealIP bool `yaml:"real_ip"` - Whitelist []string `yaml:"whitelist"` + ListenAddr string `mapstructure:"listen_addr"` + ListenPort int `mapstructure:"listen_port"` + CertPath string `mapstructure:"cert_path"` + CertKeyPath string `mapstructure:"cert_key_path"` + RealIP bool `mapstructure:"real_ip"` + Whitelist []string `mapstructure:"whitelist"` } type LogConfig struct { - Operator string `yaml:"operator"` - URL string `yaml:"url"` - Description string `yaml:"description"` + Operator string `mapstructure:"operator"` + URL string `mapstructure:"url"` + Description string `mapstructure:"description"` } type BufferSizes struct { - Websocket int `yaml:"websocket"` - CTLog int `yaml:"ctlog"` - BroadcastManager int `yaml:"broadcastmanager"` + Websocket int `mapstructure:"websocket"` + CTLog int `mapstructure:"ctlog"` + BroadcastManager int `mapstructure:"broadcastmanager"` } type Config struct { Webserver struct { - ServerConfig `yaml:",inline"` - FullURL string `yaml:"full_url"` - LiteURL string `yaml:"lite_url"` - DomainsOnlyURL string `yaml:"domains_only_url"` - CompressionEnabled bool `yaml:"compression_enabled"` + ServerConfig `mapstructure:",squash"` + FullURL string `mapstructure:"full_url"` + LiteURL string `mapstructure:"lite_url"` + DomainsOnlyURL string `mapstructure:"domains_only_url"` + CompressionEnabled bool `mapstructure:"compression_enabled"` } Prometheus struct { - ServerConfig `yaml:",inline"` - Enabled bool `yaml:"enabled"` - MetricsURL string `yaml:"metrics_url"` - ExposeSystemMetrics bool `yaml:"expose_system_metrics"` + ServerConfig `mapstructure:",squash"` + Enabled bool `mapstructure:"enabled"` + MetricsURL string `mapstructure:"metrics_url"` + ExposeSystemMetrics bool `mapstructure:"expose_system_metrics"` } General struct { // DisableDefaultLogs indicates whether the default logs used in Google Chrome and provided by Google should be disabled. - DisableDefaultLogs bool `yaml:"disable_default_logs"` + DisableDefaultLogs bool `mapstructure:"disable_default_logs"` // AdditionalLogs contains additional logs provided by the user that can be used in addition to the default logs. - AdditionalLogs []LogConfig `yaml:"additional_logs"` - BufferSizes BufferSizes `yaml:"buffer_sizes"` - DropOldLogs *bool `yaml:"drop_old_logs"` - Recovery struct { - Enabled bool `yaml:"enabled"` - CTIndexFile string `yaml:"ct_index_file"` - } `yaml:"recovery"` + AdditionalLogs []LogConfig `mapstructure:"additional_logs"` + AdditionalTiledLogs []LogConfig `mapstructure:"additional_tiled_logs"` + BufferSizes BufferSizes `mapstructure:"buffer_sizes"` + DropOldLogs *bool `mapstructure:"drop_old_logs"` + Recovery struct { + Enabled bool `mapstructure:"enabled"` + CTIndexFile string `mapstructure:"ct_index_file"` + } `mapstructure:"recovery"` } } -// ReadConfig reads the config file and returns a filled Config struct. +// ReadConfig reads the configuration using Viper and returns a filled Config struct. +// It also validates and stores the result in AppConfig. func ReadConfig(configPath string) (Config, error) { - log.Printf("Reading config file '%s'...\n", configPath) + v := initViper(configPath) + return loadConfigFromViper(v) +} - conf, parseErr := parseConfigFromFile(configPath) - if parseErr != nil { - log.Fatalln("Error while parsing yaml file:", parseErr) - } +// ValidateConfig validates the config file and returns an error if the config is invalid. +func ValidateConfig(configPath string) error { + _, parseErr := ReadConfig(configPath) + return parseErr +} - if !validateConfig(conf) { - log.Fatalln("Invalid config") +// initViper sets up the viper instance with defaults, config file and environment variable support. +// configPath is the path to the YAML config file (e.g. "config.yaml"). +// Environment variables are mapped with the prefix "CERTSTREAM" and "__" as key delimiter. +// Example: CERTSTREAM_WEBSERVER__LISTEN_PORT overrides webserver.listen_port. +func initViper(configPath string) *viper.Viper { + v := viper.NewWithOptions(viper.KeyDelimiter(".")) + + // Defaults + v.SetDefault("webserver.listen_addr", "0.0.0.0") + v.SetDefault("webserver.listen_port", 8080) + v.SetDefault("webserver.full_url", "/full-stream") + v.SetDefault("webserver.lite_url", "/") + v.SetDefault("webserver.domains_only_url", "/domains-only") + v.SetDefault("webserver.real_ip", false) + v.SetDefault("webserver.compression_enabled", false) + + v.SetDefault("prometheus.enabled", false) + v.SetDefault("prometheus.listen_addr", "0.0.0.0") + v.SetDefault("prometheus.listen_port", 9090) + v.SetDefault("prometheus.metrics_url", "/metrics") + v.SetDefault("prometheus.expose_system_metrics", false) + v.SetDefault("prometheus.real_ip", false) + + v.SetDefault("general.disable_default_logs", false) + v.SetDefault("general.buffer_sizes.websocket", 300) + v.SetDefault("general.buffer_sizes.ctlog", 1000) + v.SetDefault("general.buffer_sizes.broadcastmanager", 10000) + v.SetDefault("general.drop_old_logs", true) + v.SetDefault("general.recovery.enabled", false) + v.SetDefault("general.recovery.ct_index_file", "./ct_index.json") + + // TODO check for missing file?! + // Config file + if configPath != "" { + v.SetConfigFile(configPath) + } else { + v.SetConfigName("config") + v.AddConfigPath(".") + v.AddConfigPath("/app/config") } - AppConfig = *conf - return *conf, nil -} - -// parseConfigFromFile reads the config file as bytes and passes it to parseConfigFromBytes. -// It returns a filled Config struct. -func parseConfigFromFile(configFile string) (*Config, error) { - if configFile == "" { - configFile = "config.yml" - } - - // Check if the file exists - absPath, err := filepath.Abs(configFile) - if err != nil { - log.Printf("Couldn't convert to absolute path: '%s'\n", configFile) - return &Config{}, err - } - - if _, statErr := os.Stat(absPath); os.IsNotExist(statErr) { - log.Printf("Config file '%s' does not exist\n", absPath) - ext := filepath.Ext(absPath) - absPath = strings.TrimSuffix(absPath, ext) - - switch ext { - case ".yaml": - absPath += ".yml" - case ".yml": - absPath += ".yaml" - default: - log.Printf("Config file '%s' does not have a valid extension\n", configFile) - return &Config{}, statErr - } + v.SetConfigType("yaml") - if _, secondStatErr := os.Stat(absPath); os.IsNotExist(secondStatErr) { - log.Printf("Config file '%s' does not exist\n", absPath) - return &Config{}, secondStatErr + if err := v.ReadInConfig(); err != nil { + var notFound viper.ConfigFileNotFoundError + if errors.As(err, ¬Found) { + log.Println("No config file found, using defaults and environment variables only") + } else { + log.Fatalf("Error reading config file: %v", err) } + } else { + log.Printf("Using config file: %s\n", v.ConfigFileUsed()) } - log.Printf("File '%s' exists\n", absPath) - yamlFileContent, readErr := os.ReadFile(absPath) - if readErr != nil { - return &Config{}, readErr - } - - conf, parseErr := parseConfigFromBytes(yamlFileContent) - if parseErr != nil { - return &Config{}, parseErr - } + // Environment variables + // Prefix: CERTSTREAM (e.g. CERTSTREAM_WEBSERVER_LISTEN_PORT) + // Viper uses "." as key delimiter internally; environment variables use "_". + // We replace "." with "_" when looking up env vars automatically. + v.SetEnvPrefix("CERTSTREAM") + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.AutomaticEnv() - return conf, nil + viperInstance = v + return v } -// parseConfigFromBytes parses the config bytes and returns a filled Config struct. -func parseConfigFromBytes(data []byte) (*Config, error) { - var config Config +// loadConfigFromViper unmarshals a viper instance into a Config struct, validates it +// and stores the result in AppConfig. +func loadConfigFromViper(v *viper.Viper) (Config, error) { + var cfg Config - err := yaml.Unmarshal(data, &config) - if err != nil { - return &config, err + if err := v.Unmarshal(&cfg); err != nil { + return cfg, err } - return &config, nil + if !validateConfig(&cfg) { + return cfg, errors.New("invalid configuration") + } + + AppConfig = cfg + return cfg, nil } // validateConfig validates the config values and sets defaults for missing values. @@ -211,7 +230,7 @@ func validateConfig(config *Config) bool { } } - var validLogs []LogConfig + var validLogs, validTiledLogs []LogConfig if len(config.General.AdditionalLogs) > 0 { for _, ctLog := range config.General.AdditionalLogs { if !URLRegex.MatchString(ctLog.URL) { @@ -221,11 +240,25 @@ func validateConfig(config *Config) bool { validLogs = append(validLogs, ctLog) } - } else if len(config.General.AdditionalLogs) == 0 && config.General.DisableDefaultLogs { - log.Fatalln("Default logs are disabled, but no additional logs are configured. Please add at least one log to the config or enable default logs.") + } + + if len(config.General.AdditionalTiledLogs) > 0 { + for _, ctLog := range config.General.AdditionalTiledLogs { + if !URLRegex.MatchString(ctLog.URL) { + log.Println("Ignoring invalid additional log URL: ", ctLog.URL) + continue + } + + validTiledLogs = append(validTiledLogs, ctLog) + } } config.General.AdditionalLogs = validLogs + config.General.AdditionalTiledLogs = validTiledLogs + + if len(config.General.AdditionalLogs) == 0 && len(config.General.AdditionalTiledLogs) == 0 && config.General.DisableDefaultLogs { + log.Fatalln("Default logs are disabled, but no additional logs are configured. Please add at least one log to the config or enable default logs.") + } if config.General.BufferSizes.Websocket <= 0 { config.General.BufferSizes.Websocket = 300 diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..9d856cd --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,453 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/spf13/viper" +) + +// minimalValidYAML is the smallest config that passes validateConfig. +// It only sets the fields that validateConfig strictly requires, leaving +// everything else at its viper default so we can assert default values. +const minimalValidYAML = ` +webserver: + listen_addr: "0.0.0.0" + listen_port: 8080 + full_url: "/full-stream" + lite_url: "/" + domains_only_url: "/domains-only" +` + +// writeConfigFile writes content to a temporary YAML file and returns its path. +func writeConfigFile(t *testing.T, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "config.yaml") + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("failed to write temp config file: %v", err) + } + return path +} + +// TestWebserverDefaults uses a config file that only sets required +// webserver fields so that optional keys still reflect their defaults. +func TestWebserverDefaults(t *testing.T) { + // Only listen_addr and listen_port are set. Remaining keys should be default. + yaml := ` +webserver: + listen_addr: "0.0.0.0" + listen_port: 8080 +` + configPath := writeConfigFile(t, yaml) + v := initViper(configPath) + + cases := []struct { + key string + want interface{} + }{ + {"webserver.full_url", "/full-stream"}, + {"webserver.lite_url", "/"}, + {"webserver.domains_only_url", "/domains-only"}, + {"webserver.real_ip", false}, + {"webserver.compression_enabled", false}, + } + + for _, testcase := range cases { + switch want := testcase.want.(type) { + case string: + if got := v.GetString(testcase.key); got != want { + t.Errorf("key %s: want %q, got %q", testcase.key, want, got) + } + case bool: + if got := v.GetBool(testcase.key); got != want { + t.Errorf("key %s: want %v, got %v", testcase.key, want, got) + } + } + } +} + +// TestPrometheusDefaults uses a config file that only sets required +// webserver fields so that optional keys still reflect their defaults. +func TestPrometheusDefaults(t *testing.T) { + // No prometheus section in the config – all keys should come from defaults. + configPath := writeConfigFile(t, minimalValidYAML) + v := initViper(configPath) + + if got := v.GetBool("prometheus.enabled"); got != false { + t.Errorf("prometheus.enabled: want false, got %v", got) + } + if got := v.GetString("prometheus.listen_addr"); got != "0.0.0.0" { + t.Errorf("prometheus.listen_addr: want '0.0.0.0', got %q", got) + } + if got := v.GetInt("prometheus.listen_port"); got != 9090 { + t.Errorf("prometheus.listen_port: want 9090, got %d", got) + } + if got := v.GetString("prometheus.metrics_url"); got != "/metrics" { + t.Errorf("prometheus.metrics_url: want '/metrics', got %q", got) + } + if got := v.GetBool("prometheus.expose_system_metrics"); got != false { + t.Errorf("prometheus.expose_system_metrics: want false, got %v", got) + } + if got := v.GetBool("prometheus.real_ip"); got != false { + t.Errorf("prometheus.real_ip: want false, got %v", got) + } +} + +// TestGeneralDefaults uses an empty string as configPath. +// It tests whether the defaults are properly configured. +func TestGeneralDefaults(t *testing.T) { + // No general section in the config – all keys should come from defaults. + + v := initViper("") + + if got := v.GetBool("general.disable_default_logs"); got != false { + t.Errorf("general.disable_default_logs: want false, got %v", got) + } + if got := v.GetInt("general.buffer_sizes.websocket"); got != 300 { + t.Errorf("general.buffer_sizes.websocket: want 300, got %d", got) + } + if got := v.GetInt("general.buffer_sizes.ctlog"); got != 1000 { + t.Errorf("general.buffer_sizes.ctlog: want 1000, got %d", got) + } + if got := v.GetInt("general.buffer_sizes.broadcastmanager"); got != 10000 { + t.Errorf("general.buffer_sizes.broadcastmanager: want 10000, got %d", got) + } + if got := v.GetBool("general.drop_old_logs"); got != true { + t.Errorf("general.drop_old_logs: want true, got %v", got) + } + if got := v.GetBool("general.recovery.enabled"); got != false { + t.Errorf("general.recovery.enabled: want false, got %v", got) + } + if got := v.GetString("general.recovery.ct_index_file"); got != "./ct_index.json" { + t.Errorf("general.recovery.ct_index_file: want './ct_index.json', got %q", got) + } +} + +// TestConfigFileOverridesDefaults verifies that values from the config file override +// the defaults, while keys absent from the file still return their default values. +func TestConfigFileOverridesDefaults(t *testing.T) { + yaml := ` +webserver: + listen_addr: "127.0.0.1" + listen_port: 9999 +` + configPath := writeConfigFile(t, yaml) + v := initViper(configPath) + + if got := v.GetString("webserver.listen_addr"); got != "127.0.0.1" { + t.Errorf("listen_addr: want '127.0.0.1', got %q", got) + } + if got := v.GetInt("webserver.listen_port"); got != 9999 { + t.Errorf("listen_port: want 9999, got %d", got) + } + + // Keys absent from the file must still return the registered defaults. + if got := v.GetString("webserver.full_url"); got != "/full-stream" { + t.Errorf("full_url (default): want '/full-stream', got %q", got) + } +} + +// TestEnvOverridesDefaults verifies that environment variables +// take precedence over defaults when no config file is provided. +func TestEnvOverridesDefaults(t *testing.T) { + t.Setenv("CERTSTREAM_WEBSERVER_LISTEN_PORT", "6543") + + configPath := writeConfigFile(t, minimalValidYAML) + v := initViper(configPath) + + if got := v.GetInt("webserver.listen_port"); got != 6543 { + t.Errorf("listen_port via env: want 6543, got %d", got) + } +} + +// TestEnvOverridesConfigFile verifies that environment variables +// take precedence over config file values. +func TestEnvOverridesConfigFile(t *testing.T) { + configPath := writeConfigFile(t, minimalValidYAML) + t.Setenv("CERTSTREAM_WEBSERVER_LISTEN_PORT", "7777") + + v := initViper(configPath) + + if got := v.GetInt("webserver.listen_port"); got != 7777 { + t.Errorf("listen_port via env: want 7777, got %d", got) + } +} + +// TestEnvOverridesPrometheusPort verifies that environment variables +// can override config file values for the prometheus.listen_port key. +func TestEnvOverridesPrometheusPort(t *testing.T) { + configPath := writeConfigFile(t, minimalValidYAML) + t.Setenv("CERTSTREAM_PROMETHEUS_LISTEN_PORT", "19090") + + v := initViper(configPath) + + if got := v.GetInt("prometheus.listen_port"); got != 19090 { + t.Errorf("prometheus.listen_port via env: want 19090, got %d", got) + } +} + +// viperFromYAML creates a fully initialized *viper.Viper from an in-memory +// YAML string by writing it to a temp file and calling initViper. This +// exercises the same code path as production (defaults + file merge) and +// correctly handles the yaml:",inline" embedded struct tags via mapstructure. +func viperFromYAML(t *testing.T, content string) *viper.Viper { + t.Helper() + return initViper(writeConfigFile(t, content)) +} + +// validViperInstance returns a viper instance built from minimalValidYAML. +func validViperInstance(t *testing.T) *viper.Viper { + t.Helper() + return viperFromYAML(t, minimalValidYAML) +} + +// TestLoadConfigFromViper_UnmarshalsWebserver tests if the nested Webserver +// struct is unmarshalled properly. +func TestLoadConfigFromViper_UnmarshalsWebserver(t *testing.T) { + v := validViperInstance(t) + cfg, err := loadConfigFromViper(v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Webserver.ListenAddr != "0.0.0.0" { + t.Errorf("ListenAddr: want '0.0.0.0', got %q", cfg.Webserver.ListenAddr) + } + if cfg.Webserver.ListenPort != 8080 { + t.Errorf("ListenPort: want 8080, got %d", cfg.Webserver.ListenPort) + } + if cfg.Webserver.FullURL != "/full-stream" { + t.Errorf("FullURL: want '/full-stream', got %q", cfg.Webserver.FullURL) + } + if cfg.Webserver.LiteURL != "/" { + t.Errorf("LiteURL: want '/', got %q", cfg.Webserver.LiteURL) + } + if cfg.Webserver.DomainsOnlyURL != "/domains-only" { + t.Errorf("DomainsOnlyURL: want '/domains-only', got %q", cfg.Webserver.DomainsOnlyURL) + } +} + +// TestLoadConfigFromViper_UnmarshalsBufferSizes tests if the nested BufferSizes +// struct is unmarshalled properly. +func TestLoadConfigFromViper_UnmarshalsBufferSizes(t *testing.T) { + v := viperFromYAML(t, ` +webserver: + listen_addr: "0.0.0.0" + listen_port: 8080 + full_url: "/full-stream" + lite_url: "/" + domains_only_url: "/domains-only" +general: + buffer_sizes: + websocket: 512 + ctlog: 2048 + broadcastmanager: 4096 +`) + cfg, err := loadConfigFromViper(v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.General.BufferSizes.Websocket != 512 { + t.Errorf("BufferSizes.Websocket: want 512, got %d", cfg.General.BufferSizes.Websocket) + } + if cfg.General.BufferSizes.CTLog != 2048 { + t.Errorf("BufferSizes.CTLog: want 2048, got %d", cfg.General.BufferSizes.CTLog) + } + if cfg.General.BufferSizes.BroadcastManager != 4096 { + t.Errorf("BufferSizes.BroadcastManager: want 4096, got %d", cfg.General.BufferSizes.BroadcastManager) + } +} + +// TestLoadConfigFromViper_SetsAppConfig verifies that loadConfigFromViper updates +// the global AppConfig variable to match the loaded configuration. +func TestLoadConfigFromViper_SetsAppConfig(t *testing.T) { + v := validViperInstance(t) + cfg, err := loadConfigFromViper(v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if AppConfig.Webserver.ListenPort != cfg.Webserver.ListenPort { + t.Errorf("AppConfig.Webserver.ListenPort not synced: want %d, got %d", + cfg.Webserver.ListenPort, AppConfig.Webserver.ListenPort) + } + if AppConfig.Webserver.ListenAddr != cfg.Webserver.ListenAddr { + t.Errorf("AppConfig.Webserver.ListenAddr not synced: want %q, got %q", + cfg.Webserver.ListenAddr, AppConfig.Webserver.ListenAddr) + } +} + +// TestLoadConfigFromViper_RecoverySettings tests if the nested BufferSizes +// struct is unmarshalled properly. +func TestLoadConfigFromViper_RecoverySettings(t *testing.T) { + v := viperFromYAML(t, ` +webserver: + listen_addr: "0.0.0.0" + listen_port: 8080 + full_url: "/full-stream" + lite_url: "/" + domains_only_url: "/domains-only" +general: + recovery: + enabled: true + ct_index_file: "/tmp/my_index.json" +`) + cfg, err := loadConfigFromViper(v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !cfg.General.Recovery.Enabled { + t.Errorf("Recovery.Enabled: want true, got false") + } + if cfg.General.Recovery.CTIndexFile != "/tmp/my_index.json" { + t.Errorf("Recovery.CTIndexFile: want '/tmp/my_index.json', got %q", cfg.General.Recovery.CTIndexFile) + } +} + +func TestReadConfigViper_MinimalValidFile(t *testing.T) { + configPath := writeConfigFile(t, minimalValidYAML) + + cfg, err := ReadConfig(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Webserver.ListenAddr != "0.0.0.0" { + t.Errorf("ListenAddr: want '0.0.0.0', got %q", cfg.Webserver.ListenAddr) + } + if cfg.Webserver.ListenPort != 8080 { + t.Errorf("ListenPort: want 8080, got %d", cfg.Webserver.ListenPort) + } +} + +func TestReadConfigViper_PrometheusSection(t *testing.T) { + yaml := minimalValidYAML + ` +prometheus: + enabled: true + listen_addr: "0.0.0.0" + listen_port: 9090 + metrics_url: "/metrics" + expose_system_metrics: true +` + configPath := writeConfigFile(t, yaml) + + cfg, err := ReadConfig(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !cfg.Prometheus.Enabled { + t.Errorf("Prometheus.Enabled: want true, got false") + } + if cfg.Prometheus.ListenPort != 9090 { + t.Errorf("Prometheus.ListenPort: want 9090, got %d", cfg.Prometheus.ListenPort) + } + if cfg.Prometheus.MetricsURL != "/metrics" { + t.Errorf("Prometheus.MetricsURL: want '/metrics', got %q", cfg.Prometheus.MetricsURL) + } + if !cfg.Prometheus.ExposeSystemMetrics { + t.Errorf("Prometheus.ExposeSystemMetrics: want true, got false") + } +} + +func TestReadConfigViper_CustomBufferSizes(t *testing.T) { + yaml := minimalValidYAML + ` +general: + buffer_sizes: + websocket: 500 + ctlog: 2000 + broadcastmanager: 20000 +` + configPath := writeConfigFile(t, yaml) + + cfg, err := ReadConfig(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.General.BufferSizes.Websocket != 500 { + t.Errorf("BufferSizes.Websocket: want 500, got %d", cfg.General.BufferSizes.Websocket) + } + if cfg.General.BufferSizes.CTLog != 2000 { + t.Errorf("BufferSizes.CTLog: want 2000, got %d", cfg.General.BufferSizes.CTLog) + } + if cfg.General.BufferSizes.BroadcastManager != 20000 { + t.Errorf("BufferSizes.BroadcastManager: want 20000, got %d", cfg.General.BufferSizes.BroadcastManager) + } +} + +func TestReadConfigViper_AdditionalLogs(t *testing.T) { + yaml := minimalValidYAML + ` +general: + additional_logs: + - url: "https://ct.example.com/log" + operator: "Example" + description: "Example CT log" +` + configPath := writeConfigFile(t, yaml) + + cfg, err := ReadConfig(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(cfg.General.AdditionalLogs) != 1 { + t.Fatalf("AdditionalLogs: want 1 entry, got %d", len(cfg.General.AdditionalLogs)) + } + if cfg.General.AdditionalLogs[0].URL != "https://ct.example.com/log" { + t.Errorf("AdditionalLogs[0].URL: want 'https://ct.example.com/log', got %q", + cfg.General.AdditionalLogs[0].URL) + } + if cfg.General.AdditionalLogs[0].Operator != "Example" { + t.Errorf("AdditionalLogs[0].Operator: want 'Example', got %q", + cfg.General.AdditionalLogs[0].Operator) + } +} + +func TestReadConfigViper_InvalidAdditionalLogURLIgnored(t *testing.T) { + yaml := minimalValidYAML + ` +general: + additional_logs: + - url: "not-a-valid-url" + operator: "Bad" + description: "Invalid log" + - url: "https://ct.example.com/valid-log" + operator: "Good" + description: "Valid log" +` + configPath := writeConfigFile(t, yaml) + + cfg, err := ReadConfig(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(cfg.General.AdditionalLogs) != 1 { + t.Errorf("AdditionalLogs: want 1 valid entry (invalid filtered), got %d", len(cfg.General.AdditionalLogs)) + } +} + +func TestReadConfigViper_RecoveryConfig(t *testing.T) { + yaml := minimalValidYAML + ` +general: + recovery: + enabled: true + ct_index_file: "./my_index.json" +` + configPath := writeConfigFile(t, yaml) + + cfg, err := ReadConfig(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !cfg.General.Recovery.Enabled { + t.Errorf("Recovery.Enabled: want true, got false") + } + if cfg.General.Recovery.CTIndexFile != "./my_index.json" { + t.Errorf("Recovery.CTIndexFile: want './my_index.json', got %q", cfg.General.Recovery.CTIndexFile) + } +} diff --git a/internal/certificatetransparency/logmetrics.go b/internal/metrics/logmetrics.go similarity index 87% rename from internal/certificatetransparency/logmetrics.go rename to internal/metrics/logmetrics.go index f320ab7..8fa0f4a 100644 --- a/internal/certificatetransparency/logmetrics.go +++ b/internal/metrics/logmetrics.go @@ -1,4 +1,4 @@ -package certificatetransparency +package metrics import ( "encoding/json" @@ -22,9 +22,9 @@ type ( ) var ( - processedCerts int64 - processedPrecerts int64 - metrics = LogMetrics{metrics: make(CTMetrics), index: make(CTCertIndex)} + ProcessedCerts int64 + ProcessedPrecerts int64 + Metrics = LogMetrics{metrics: make(CTMetrics), index: make(CTCertIndex)} ) // LogMetrics is a struct that holds a map of metrics for each CT log grouped by operator. @@ -40,6 +40,8 @@ func (m *LogMetrics) GetCTMetrics() CTMetrics { m.mutex.RLock() defer m.mutex.RUnlock() + // Using maps.copy() does not copy the nested maps. + // That leads to an issue where simultaneous reads and writes to the same nested map can happen. copiedMap := make(CTMetrics) for operator, urls := range m.metrics { copiedMap[operator] = make(OperatorMetric) @@ -91,6 +93,9 @@ func (m *LogMetrics) Init(operator, url string) { if _, ok := m.index[url]; !ok { m.index[url] = 0 } + + // Register the metric for this operator and url with Prometheus + Prometheus.RegisterLog(operator, url) } // Get the metric for a given operator and ct url. @@ -162,12 +167,14 @@ func (m *LogMetrics) SetCTIndex(url string, index uint64) { m.mutex.Lock() defer m.mutex.Unlock() - log.Println("Setting CT index for ", url, " to ", index) + log.Printf("Setting CT index for %s to %d\n", url, index) m.index[url] = index } // LoadCTIndex loads the last cert index processed for each CT url if it exists. func (m *LogMetrics) LoadCTIndex(ctIndexFilePath string) { + log.Println("Loading CT indexes from file: ", ctIndexFilePath) + m.mutex.Lock() defer m.mutex.Unlock() @@ -175,13 +182,18 @@ func (m *LogMetrics) LoadCTIndex(ctIndexFilePath string) { if readErr != nil { // Create the file if it doesn't exist if os.IsNotExist(readErr) { - err := createCTIndexFile(ctIndexFilePath, m) + log.Println("CT index file does not exist, creating a new one...") + if m.index == nil { + m.index = make(CTCertIndex) + } + err := m.createCTIndexFile(ctIndexFilePath) if err != nil { log.Printf("Error creating CT index file: '%s'\n", ctIndexFilePath) log.Panicln(err) } + bytes = []byte("{}") } else { - // If the file exists but we can't read it, log the error and panic + // If the file exists, but we can't read it, log the error and panic log.Panicln(readErr) } } @@ -195,10 +207,7 @@ func (m *LogMetrics) LoadCTIndex(ctIndexFilePath string) { log.Println("Successfully loaded saved CT indexes") } -func createCTIndexFile(ctIndexFilePath string, m *LogMetrics) error { - m.mutex.RLock() - defer m.mutex.RUnlock() - +func (m *LogMetrics) createCTIndexFile(ctIndexFilePath string) error { log.Printf("Specified CT index file does not exist: '%s'\n", ctIndexFilePath) log.Println("Creating CT index file now!") @@ -207,7 +216,11 @@ func createCTIndexFile(ctIndexFilePath string, m *LogMetrics) error { log.Printf("Error creating CT index file: '%s'\n", ctIndexFilePath) log.Panicln(createErr) } + defer file.Close() + if m.index == nil { + m.index = make(CTCertIndex) + } bytes, marshalErr := json.Marshal(m.index) if marshalErr != nil { return marshalErr @@ -217,7 +230,6 @@ func createCTIndexFile(ctIndexFilePath string, m *LogMetrics) error { log.Printf("Error writing to CT index file: '%s'\n", ctIndexFilePath) log.Panicln(writeErr) } - file.Close() return nil } @@ -282,18 +294,18 @@ func (m *LogMetrics) SaveCertIndexes(ctIndexFilePath string) { // GetProcessedCerts returns the total number of processed certificates. func GetProcessedCerts() int64 { - return processedCerts + return ProcessedCerts } // GetProcessedPrecerts returns the total number of processed precertificates. func GetProcessedPrecerts() int64 { - return processedPrecerts + return ProcessedPrecerts } func GetCertMetrics() CTMetrics { - return metrics.GetCTMetrics() + return Metrics.GetCTMetrics() } func GetLogOperators() map[string][]string { - return metrics.OperatorLogMapping() + return Metrics.OperatorLogMapping() } diff --git a/internal/metrics/logmetrics_test.go b/internal/metrics/logmetrics_test.go new file mode 100644 index 0000000..00e9e69 --- /dev/null +++ b/internal/metrics/logmetrics_test.go @@ -0,0 +1,25 @@ +package metrics + +import ( + "path/filepath" + "testing" + "time" +) + +func TestLoadCTIndex_DoesNotDeadlockWhenFileMissing(t *testing.T) { + metrics := LogMetrics{metrics: make(CTMetrics), index: make(CTCertIndex)} + ctIndexPath := filepath.Join(t.TempDir(), "ct_index.json") + + done := make(chan struct{}) + go func() { + metrics.LoadCTIndex(ctIndexPath) + close(done) + }() + + select { + case <-done: + // ok + case <-time.After(200 * time.Millisecond): + t.Fatalf("LoadCTIndex appears to deadlock when index file is missing") + } +} diff --git a/internal/metrics/prometheus.go b/internal/metrics/prometheus.go index 7132f4c..6da1796 100644 --- a/internal/metrics/prometheus.go +++ b/internal/metrics/prometheus.go @@ -3,120 +3,108 @@ package metrics import ( "fmt" "io" - "strings" + "log" "sync" "time" - "github.com/d-Rickyy-b/certstream-server-go/internal/certificatetransparency" - "github.com/d-Rickyy-b/certstream-server-go/internal/web" - "github.com/VictoriaMetrics/metrics" ) -var ( - ctLogMetricsInitialized = false - ctLogMetricsInitMutex = &sync.Mutex{} - - tempCertMetricsLastRefreshed = time.Time{} - tempCertMetrics = certificatetransparency.CTMetrics{} - tempCertMetricsMutex = &sync.RWMutex{} +var Prometheus = NewPrometheusExporter() +type PrometheusExporter struct { // Number of currently connected clients. - fullClientCount = metrics.NewGauge("certstreamservergo_clients_total{type=\"full\"}", func() float64 { - return float64(web.ClientHandler.ClientFullCount()) - }) - liteClientCount = metrics.NewGauge("certstreamservergo_clients_total{type=\"lite\"}", func() float64 { - return float64(web.ClientHandler.ClientLiteCount()) - }) - domainClientCount = metrics.NewGauge("certstreamservergo_clients_total{type=\"domain\"}", func() float64 { - return float64(web.ClientHandler.ClientDomainsCount()) - }) + fullClientCount metrics.Gauge + liteClientCount metrics.Gauge + domainClientCount metrics.Gauge + + tempCertMetricsLastRefreshed time.Time + tempCertMetrics CTMetrics + tempCertMetricsMutex sync.RWMutex + + skippedCertsCallback func() map[string]int64 +} - // Number of certificates processed by the CT watcher. - processedCertificates = metrics.NewGauge("certstreamservergo_certificates_total{type=\"regular\"}", func() float64 { - return float64(certificatetransparency.GetProcessedCerts()) +// NewPrometheusExporter creates a new PrometheusExporter and registers the default metrics for the number of processed certificates. +func NewPrometheusExporter() *PrometheusExporter { + e := &PrometheusExporter{} + // Register metrics for the total number of certificates processed by the CT watcher. + metrics.GetOrCreateGauge("certstreamservergo_certificates_total{type=\"regular\"}", func() float64 { + return float64(GetProcessedCerts()) }) - processedPreCertificates = metrics.NewGauge("certstreamservergo_certificates_total{type=\"precert\"}", func() float64 { - return float64(certificatetransparency.GetProcessedPrecerts()) + metrics.GetOrCreateGauge("certstreamservergo_certificates_total{type=\"precert\"}", func() float64 { + return float64(GetProcessedPrecerts()) }) -) - -// WritePrometheus provides an easy way to write metrics to a writer. -func WritePrometheus(w io.Writer, exposeProcessMetrics bool) { - ctLogMetricsInitMutex.Lock() - if !ctLogMetricsInitialized { - initCtLogMetrics() - } - ctLogMetricsInitMutex.Unlock() + return e +} - getSkippedCertMetrics() +// Write is a callback function that is called by a webserver in order to write metrics data to the http response. +func (pm *PrometheusExporter) Write(w io.Writer, exposeProcessMetrics bool) { + // getSkippedCertMetrics() metrics.WritePrometheus(w, exposeProcessMetrics) } -// For having metrics regarding each individual CT log, we need to register them manually. -// initCtLogMetrics fetches all the CT Logs and registers one metric per log. -func initCtLogMetrics() { - logs := certificatetransparency.GetLogOperators() - - for operator, urls := range logs { - operator := operator // Copy variable to new scope - - for i := range urls { - url := urls[i] - name := fmt.Sprintf("certstreamservergo_certs_by_log_total{url=\"%s\",operator=\"%s\"}", url, operator) - metrics.NewGauge(name, func() float64 { - return float64(getCertCountForLog(operator, url)) - }) - } - } +// RegisterGaugeMetric is a helper function that registers a new gauge metric with a float64 callback function. +func (pm *PrometheusExporter) RegisterGaugeMetric(label string, callback func() float64) { + metrics.GetOrCreateGauge(label, callback) +} - if len(logs) > 0 { - ctLogMetricsInitialized = true - } +// RegisterGaugeMetricInt is a helper function that registers a new gauge metric with an int64 callback function. +func (pm *PrometheusExporter) RegisterGaugeMetricInt(label string, callback func() int64) { + metrics.GetOrCreateGauge(label, func() float64 { return float64(callback()) }) +} + +// RegisterClient registers a new gauge metric for the client with the given name. +func (pm *PrometheusExporter) RegisterClient(name string, callback func() float64) { + label := fmt.Sprintf("certstreamservergo_skipped_certs{client=\"%s\"}", name) + metrics.GetOrCreateGauge(label, callback) +} + +// UnregisterClient unregisters the metric for the client with the given name. +func (pm *PrometheusExporter) UnregisterClient(name string) { + label := fmt.Sprintf("certstreamservergo_skipped_certs{client=\"%s\"}", name) + metrics.UnregisterMetric(label) +} + +// RegisterLog registers a new gauge metric for the given CT log. +// The metric will be named "certstreamservergo_certs_by_log_total{url=\"\",operator=\"\"}" and +// will call the given callback function to get the current value of the metric. +func (pm *PrometheusExporter) RegisterLog(operatorName, url string) { + label := fmt.Sprintf("certstreamservergo_certs_by_log_total{url=\"%s\",operator=\"%s\"}", url, operatorName) + metrics.GetOrCreateGauge(label, func() float64 { + return float64(pm.getCertCountForLog(operatorName, url)) + }) +} + +// UnregisterMetric unregisters a metric with a given label. +func (pm *PrometheusExporter) UnregisterMetric(label string) { + metrics.UnregisterMetric(label) } // getCertCountForLog returns the number of certificates processed from a specific CT log. // It caches the result for 5 seconds. Subsequent calls to this method will return the cached result. -func getCertCountForLog(operatorName, logname string) int64 { - tempCertMetricsMutex.Lock() - defer tempCertMetricsMutex.Unlock() +func (pm *PrometheusExporter) getCertCountForLog(operatorName, logname string) int64 { + pm.tempCertMetricsMutex.Lock() + defer pm.tempCertMetricsMutex.Unlock() // Add some caching to avoid having to lock the mutex every time - if time.Since(tempCertMetricsLastRefreshed) > time.Second*5 { - tempCertMetricsLastRefreshed = time.Now() - tempCertMetrics = certificatetransparency.GetCertMetrics() + if time.Since(pm.tempCertMetricsLastRefreshed) > time.Second*5 { + pm.tempCertMetricsLastRefreshed = time.Now() + pm.tempCertMetrics = GetCertMetrics() } - return tempCertMetrics[operatorName][logname] -} - -// getSkippedCertMetrics gets the number of skipped certificates for each client and creates metrics for it. -// It also removes metrics for clients that are not connected anymore. -func getSkippedCertMetrics() { - skippedCerts := web.ClientHandler.GetSkippedCerts() - for clientName := range skippedCerts { - // Get or register a new counter for each client - metricName := fmt.Sprintf("certstreamservergo_skipped_certs{client=\"%s\"}", clientName) - c := metrics.GetOrCreateCounter(metricName) - c.Set(skippedCerts[clientName]) + operatorMetrics, ok := pm.tempCertMetrics[operatorName] + if !ok { + log.Printf("No metrics for operator \"%s\"", operatorName) + return 0 } - // Remove all metrics that are not in the list of current client skipped cert metrics - // Get a list of current client skipped cert metrics - for _, metricName := range metrics.ListMetricNames() { - if !strings.HasPrefix(metricName, "certstreamservergo_skipped_certs") { - continue - } - - clientName := strings.TrimPrefix(metricName, "certstreamservergo_skipped_certs{client=\"") - clientName = strings.TrimSuffix(clientName, "\"}") - - // Check if the registered metric is in the list of current client skipped cert metrics - // If not, unregister the metric - _, exists := skippedCerts[clientName] - if !exists { - metrics.UnregisterMetric(metricName) - } + count, ok := operatorMetrics[logname] + if !ok { + log.Printf("No metrics for log \"%s\" of operator \"%s\"", logname, operatorName) + return 0 } + + return count } diff --git a/internal/web/broadcastmanager.go b/internal/web/broadcastmanager.go index 18382be..6f96efb 100644 --- a/internal/web/broadcastmanager.go +++ b/internal/web/broadcastmanager.go @@ -4,6 +4,7 @@ import ( "log" "sync" + "github.com/d-Rickyy-b/certstream-server-go/internal/metrics" "github.com/d-Rickyy-b/certstream-server-go/internal/models" ) @@ -13,12 +14,21 @@ type BroadcastManager struct { clientLock sync.RWMutex } +func NewBroadcastManager() *BroadcastManager { + bm := &BroadcastManager{} + metrics.Prometheus.RegisterGaugeMetricInt("certstreamservergo_clients_total{type=\"full\"}", bm.ClientFullCount) + metrics.Prometheus.RegisterGaugeMetricInt("certstreamservergo_clients_total{type=\"lite\"}", bm.ClientLiteCount) + metrics.Prometheus.RegisterGaugeMetricInt("certstreamservergo_clients_total{type=\"domain\"}", bm.ClientDomainsCount) + return bm +} + // registerClient adds a client to the list of clients of the BroadcastManager. // The client will receive certificate broadcasts right after registration. func (bm *BroadcastManager) registerClient(c *client) { bm.clientLock.Lock() bm.clients = append(bm.clients, c) log.Printf("Clients: %d, Capacity: %d\n", len(bm.clients), cap(bm.clients)) + metrics.Prometheus.RegisterClient(c.name, func() float64 { return float64(c.skippedCerts) }) bm.clientLock.Unlock() } @@ -38,6 +48,8 @@ func (bm *BroadcastManager) unregisterClient(c *client) { // Close the broadcast channel of the client, otherwise this leads to a memory leak close(c.broadcastChan) + metrics.Prometheus.UnregisterClient(c.name) + break } } diff --git a/internal/web/server.go b/internal/web/server.go index 50a55b6..ab94be6 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -21,7 +21,7 @@ import ( ) var ( - ClientHandler = BroadcastManager{} + ClientHandler = NewBroadcastManager() upgrader websocket.Upgrader )