Skip to content

Commit b0e1b94

Browse files
committed
Support libpq-style key-value connection string.
1 parent 13cd02f commit b0e1b94

File tree

1 file changed

+61
-45
lines changed

1 file changed

+61
-45
lines changed

collector/collect.go

Lines changed: 61 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import (
3333

3434
"github.com/jackc/pgx/v5"
3535
"github.com/jackc/pgx/v5/pgtype"
36-
"github.com/jackc/pgx/v5/stdlib"
36+
_ "github.com/jackc/pgx/v5/stdlib"
3737
"github.com/rapidloop/pgmetrics"
3838
"golang.org/x/mod/semver"
3939
)
@@ -180,27 +180,57 @@ func getRegexp(r string) (rx *regexp.Regexp) {
180180
func Collect(o CollectConfig, dbnames []string) *pgmetrics.Model {
181181
// form connection string
182182
var connstr string
183-
if len(o.Host) > 0 {
184-
connstr += makeKV("host", o.Host)
185-
}
186-
connstr += makeKV("port", strconv.Itoa(int(o.Port)))
187-
if len(o.User) > 0 {
188-
connstr += makeKV("user", o.User)
183+
mode := "postgres"
184+
// Support supplying the connection string itself as an argument. If this
185+
// is specified, it takes precedence over other command-line options.
186+
if len(dbnames) == 1 {
187+
// see if this is actually a connection string
188+
cfg, err := pgx.ParseConfig(dbnames[0])
189+
if err == nil {
190+
// yes it is, use it
191+
connstr = cfg.ConnString() + " "
192+
if cfg.Database == "pgbouncer" {
193+
mode = "pgbouncer"
194+
}
195+
dbnames = dbnames[1:]
196+
}
189197
}
190-
if len(o.Password) > 0 {
191-
connstr += makeKV("password", o.Password)
198+
if len(connstr) == 0 {
199+
// connection string was not specified, use command-line options
200+
if len(o.Host) > 0 {
201+
connstr += makeKV("host", o.Host)
202+
}
203+
connstr += makeKV("port", strconv.Itoa(int(o.Port)))
204+
if len(o.User) > 0 {
205+
connstr += makeKV("user", o.User)
206+
}
207+
if len(o.Password) > 0 {
208+
connstr += makeKV("password", o.Password)
209+
}
210+
// pgmetrics defaults to sslmode=disable if unset. Explicitly set
211+
// the environment variable PGSSLMODE before invoking pgmetrics if you want
212+
// a different behavior.
213+
if os.Getenv("PGSSLMODE") == "" {
214+
connstr += makeKV("sslmode", "disable")
215+
}
216+
connstr += makeKV("application_name", "pgmetrics")
217+
if len(dbnames) == 1 && dbnames[0] == "pgbouncer" {
218+
mode = "pgbouncer"
219+
}
192220
}
193-
if os.Getenv("PGSSLMODE") == "" {
194-
connstr += makeKV("sslmode", "disable")
221+
if o.Pgpool {
222+
mode = "pgpool"
195223
}
196-
connstr += makeKV("application_name", "pgmetrics")
197224

198225
// set timeouts (but not for pgbouncer, it does not like them)
199-
if !(len(dbnames) == 1 && dbnames[0] == "pgbouncer") {
226+
if mode != "pgbouncer" {
200227
connstr += makeKV("lock_timeout", strconv.Itoa(int(o.LockTimeoutMillisec)))
201228
connstr += makeKV("statement_timeout", strconv.Itoa(int(o.TimeoutSec)*1000))
202229
}
203230

231+
// use simple protocol for maximum compatibility (pgx-specific keyword)
232+
connstr += makeKV("default_query_exec_mode", "simple_protocol")
233+
204234
// if "all DBs" was specified, collect the names of databases first
205235
if o.AllDBs {
206236
dbnames = getDBNames(connstr, o)
@@ -209,6 +239,7 @@ func Collect(o CollectConfig, dbnames []string) *pgmetrics.Model {
209239
// collect from 1 or more DBs
210240
c := &collector{
211241
dbnames: dbnames,
242+
mode: mode,
212243
}
213244
if len(dbnames) == 0 {
214245
collectFromDB(connstr, c, o)
@@ -236,42 +267,28 @@ func Collect(o CollectConfig, dbnames []string) *pgmetrics.Model {
236267
}
237268

238269
func getConn(connstr string, o CollectConfig) *sql.DB {
239-
// open database/sql connection
240-
cfg, err := pgx.ParseConfig(connstr)
270+
db, err := sql.Open("pgx", connstr)
241271
if err != nil {
242-
log.Fatalf("failed to parse connection string: %v", err)
272+
log.Fatalf("failed to open connection: %v", err)
243273
}
244-
// use simple protocol for maximum compatibility
245-
cfg.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
246-
db := stdlib.OpenDB(*cfg)
247274

248-
// ping (does not work with pgx+pgbouncer)
249-
if cfg.Database != "pgbouncer" {
250-
t := time.Duration(o.TimeoutSec) * time.Second
251-
ctx, cancel := context.WithTimeout(context.Background(), t)
252-
defer cancel()
253-
if err := db.PingContext(ctx); err != nil {
254-
log.Fatal(err)
255-
}
256-
}
275+
// ensure only 1 conn
276+
db.SetMaxIdleConns(1)
277+
db.SetMaxOpenConns(1)
257278

258279
// set role, if specified
259280
if len(o.Role) > 0 {
260281
if !isValidIdent(o.Role) {
261282
log.Fatalf("bad format for role %q", o.Role)
262283
}
263-
t2 := time.Duration(o.TimeoutSec) * time.Second
264-
ctx2, cancel2 := context.WithTimeout(context.Background(), t2)
265-
defer cancel2()
266-
if _, err := db.ExecContext(ctx2, "SET ROLE "+o.Role); err != nil {
284+
t := time.Duration(o.TimeoutSec) * time.Second
285+
ctx, cancel := context.WithTimeout(context.Background(), t)
286+
defer cancel()
287+
if _, err := db.ExecContext(ctx, "SET ROLE "+o.Role); err != nil {
267288
log.Fatalf("failed to set role %q: %v", o.Role, err)
268289
}
269290
}
270291

271-
// ensure only 1 conn
272-
db.SetMaxIdleConns(1)
273-
db.SetMaxOpenConns(1)
274-
275292
return db
276293
}
277294

@@ -331,6 +348,7 @@ type collector struct {
331348
logSpan uint
332349
currLog pgmetrics.LogEntry
333350
rxPrefix *regexp.Regexp
351+
mode string // "postgres", "pgbouncer" or "pgpool"
334352
}
335353

336354
func (c *collector) collect(db *sql.DB, o CollectConfig) {
@@ -363,20 +381,18 @@ func (c *collector) collectFirst(db *sql.DB, o CollectConfig) {
363381
c.result.Metadata.Version = pgmetrics.ModelSchemaVersion
364382

365383
// collect either postgres, pgbouncer or pgpool metrics
366-
if o.Pgpool {
367-
// pgpool mode:
368-
c.result.Metadata.Mode = "pgpool"
384+
c.result.Metadata.Mode = c.mode
385+
switch c.mode {
386+
case "pgpool":
369387
c.getCurrentUser()
370388
c.collectPgpool()
371-
} else if len(c.dbnames) == 1 && c.dbnames[0] == "pgbouncer" {
372-
// pgbouncer mode:
373-
c.result.Metadata.Mode = "pgbouncer"
389+
case "pgbouncer":
374390
c.collectPgBouncer()
375-
} else {
376-
// postgres mode:
377-
c.result.Metadata.Mode = "postgres"
391+
case "postgres":
378392
c.getCurrentUser()
379393
c.collectPostgres(o)
394+
default:
395+
log.Fatalf("unknown mode %q", c.mode)
380396
}
381397
}
382398

0 commit comments

Comments
 (0)