Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,14 @@ The following environment variables configure the exporter:
* `DATA_SOURCE_PASS_FILE`
The same as above but reads the password from a file.

* `PG_EXPORTER_COLLECTION_TIMEOUT`
  Timeout duration to use when collecting the statistics, defaults to `1m`.
  When the timeout is reached, the database connection will be dropped.
  It avoids connections stacking up when the database answers too slowly
  (for instance if the database creates/drops a huge table and locks the tables)
  and will avoid exhausting the pool of connections of the database.
Value of `0` or less than `1ms` is considered invalid and will report an error.

* `PG_EXPORTER_WEB_TELEMETRY_PATH`
Path under which to expose metrics. Default is `/metrics`.

Expand Down
3 changes: 2 additions & 1 deletion cmd/postgres_exporter/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var (
excludeDatabases = kingpin.Flag("exclude-databases", "A list of databases to remove when autoDiscoverDatabases is enabled (DEPRECATED)").Default("").Envar("PG_EXPORTER_EXCLUDE_DATABASES").String()
includeDatabases = kingpin.Flag("include-databases", "A list of databases to include when autoDiscoverDatabases is enabled (DEPRECATED)").Default("").Envar("PG_EXPORTER_INCLUDE_DATABASES").String()
metricPrefix = kingpin.Flag("metric-prefix", "A metric prefix can be used to have non-default (not \"pg\") prefixes for each of the metrics").Default("pg").Envar("PG_EXPORTER_METRIC_PREFIX").String()
collectionTimeout = kingpin.Flag("collection-timeout", "Timeout for collecting the statistics when the database is slow").Default("1m").Envar("PG_EXPORTER_COLLECTION_TIMEOUT").String()
logger = promslog.NewNopLogger()
)

Expand Down Expand Up @@ -137,7 +138,7 @@ func main() {
excludedDatabases,
dsn,
[]string{},
)
collector.WithCollectionTimeout(*collectionTimeout))
if err != nil {
logger.Warn("Failed to create PostgresCollector", "err", err.Error())
} else {
Expand Down
26 changes: 23 additions & 3 deletions collector/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ type PostgresCollector struct {
Collectors map[string]Collector
logger *slog.Logger

instance *instance
instance *instance
CollectionTimeout time.Duration
}

type Option func(*PostgresCollector) error
Expand Down Expand Up @@ -158,6 +159,20 @@ func NewPostgresCollector(logger *slog.Logger, excludeDatabases []string, dsn st
return p, nil
}

// WithCollectionTimeout configures the maximum duration a single metrics
// collection may take before the in-flight database queries are cancelled.
// The value is parsed as a time.Duration string (e.g. "1m", "500ms") and
// must be at least 1ms; smaller values (including 0 and negatives) are
// rejected with an error.
func WithCollectionTimeout(s string) Option {
	return func(e *PostgresCollector) error {
		duration, err := time.ParseDuration(s)
		if err != nil {
			return err
		}
		// 1ms exactly is accepted; only durations strictly below it are invalid.
		if duration < 1*time.Millisecond {
			return errors.New("timeout must be at least 1ms")
		}
		e.CollectionTimeout = duration
		return nil
	}
}

// Describe implements the prometheus.Collector interface.
func (p PostgresCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- scrapeDurationDesc
Expand All @@ -166,8 +181,6 @@ func (p PostgresCollector) Describe(ch chan<- *prometheus.Desc) {

// Collect implements the prometheus.Collector interface.
func (p PostgresCollector) Collect(ch chan<- prometheus.Metric) {
ctx := context.TODO()

// copy the instance so that concurrent scrapes have independent instances
inst := p.instance.copy()

Expand All @@ -178,6 +191,13 @@ func (p PostgresCollector) Collect(ch chan<- prometheus.Metric) {
p.logger.Error("Error opening connection to database", "err", err)
return
}
p.collectFromConnection(inst, ch)
}

func (p PostgresCollector) collectFromConnection(inst *instance, ch chan<- prometheus.Metric) {
// Eventually, connect this to the http scraping context
ctx, cancel := context.WithTimeout(context.Background(), p.CollectionTimeout)
defer cancel()

wg := sync.WaitGroup{}
wg.Add(len(p.Collectors))
Expand Down
73 changes: 73 additions & 0 deletions collector/collector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ package collector

import (
"strings"
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/promslog"
)

type labelMap map[string]string
Expand Down Expand Up @@ -60,3 +64,72 @@ func sanitizeQuery(q string) string {
q = strings.ReplaceAll(q, "$", "\\$")
return q
}

// TestWithConnectionTimeout ensures that when the database responds after a
// long time, the collection still completes in a predictable, bounded manner.
// This avoids accumulating queries against a completely frozen database.
func TestWithConnectionTimeout(t *testing.T) {
	timeoutForQuery := 100 * time.Millisecond

	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("Error opening a stub db connection: %s", err)
	}
	defer db.Close()

	inst := &instance{db: db}

	// The stubbed query answers far later than the collection timeout,
	// simulating a frozen database.
	columns := []string{"pg_roles.rolname", "pg_roles.rolconnlimit"}
	rows := sqlmock.NewRows(columns).AddRow("role1", 2)
	mock.ExpectQuery(pgRolesConnectionLimitsQuery).
		WillDelayFor(30 * time.Second).
		WillReturnRows(rows)

	logger := promslog.New(&promslog.Config{})

	c, err := NewPostgresCollector(logger, []string{}, "postgresql://local", []string{}, WithCollectionTimeout(timeoutForQuery.String()))
	if err != nil {
		t.Fatalf("error creating NewPostgresCollector: %s", err)
	}

	cfg := collectorConfig{
		logger:           logger,
		excludeDatabases: []string{},
	}
	rolesCollector, err := NewPGRolesCollector(cfg)
	if err != nil {
		t.Fatalf("error creating collector: %s", err)
	}
	c.Collectors["test"] = rolesCollector
	c.instance = inst

	ch := make(chan prometheus.Metric)
	// Drain any metrics produced; the goroutine exits once ch is closed below.
	go func() {
		for range ch {
		}
	}()

	startTime := time.Now()
	c.collectFromConnection(inst, ch)
	elapsed := time.Since(startTime)
	close(ch)

	if elapsed <= timeoutForQuery {
		t.Errorf("elapsed time was %v, should be bigger than timeout=%v", elapsed, timeoutForQuery)
	}

	// Ensure we took more than the timeout, but not much more.
	if elapsed >= timeoutForQuery+500*time.Millisecond {
		t.Errorf("elapsed time was %v, should not be much bigger than timeout=%v", elapsed, timeoutForQuery)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}