diff --git a/postgres_exporter.go b/postgres_exporter.go index bd393b40..c4efed13 100644 --- a/postgres_exporter.go +++ b/postgres_exporter.go @@ -28,8 +28,6 @@ import ( // (semantic version)-(commitish) form. var Version = "0.0.1" -var sharedDBConn *sql.DB - var ( listenAddress = flag.String( "web.listen-address", ":9187", @@ -673,6 +671,11 @@ type Exporter struct { duration, error prometheus.Gauge totalScrapes prometheus.Counter + // dbDsn is the connection string used to establish the dbConnection + dbDsn string + // dbConnection is used to allow re-using the DB connection between scrapes + dbConnection *sql.DB + // Last version used to calculate metric map. If mismatch on scrape, // then maps are recalculated. lastMapVersion semver.Version @@ -923,8 +926,16 @@ func (e *Exporter) checkMapVersions(ch chan<- prometheus.Metric, db *sql.DB) err return nil } -func getDB(conn string) (*sql.DB, error) { - if sharedDBConn == nil { +func (e *Exporter) getDB(conn string) (*sql.DB, error) { + // Has dsn changed? + if (e.dbConnection != nil) && (e.dsn != e.dbDsn) { + err := e.dbConnection.Close() + log.Warnln("Error while closing obsolete DB connection:", err) + e.dbConnection = nil + e.dbDsn = "" + } + + if e.dbConnection == nil { d, err := sql.Open("postgres", conn) if err != nil { return nil, err @@ -935,10 +946,12 @@ func getDB(conn string) (*sql.DB, error) { } d.SetMaxOpenConns(1) d.SetMaxIdleConns(1) - sharedDBConn = d + e.dbConnection = d + e.dbDsn = e.dsn + log.Infoln("Established new database connection.") } - return sharedDBConn, nil + return e.dbConnection, nil } func (e *Exporter) scrape(ch chan<- prometheus.Metric) { @@ -949,10 +962,11 @@ func (e *Exporter) scrape(ch chan<- prometheus.Metric) { e.error.Set(0) e.totalScrapes.Inc() - db, err := getDB(e.dsn) + db, err := e.getDB(e.dsn) if err != nil { loggableDsn := "could not parse DATA_SOURCE_NAME" if pDsn, pErr := url.Parse(e.dsn); pErr != nil { + log.Debugln("Blanking password for loggable DSN:", e.dsn) pDsn.User = url.UserPassword(pDsn.User.Username(), "xxx") loggableDsn = pDsn.String() }