diff --git a/backend/pkg/db/postgres/pool/pool.go b/backend/pkg/db/postgres/pool/pool.go index f3e46d03b..f6d82e6c3 100644 --- a/backend/pkg/db/postgres/pool/pool.go +++ b/backend/pkg/db/postgres/pool/pool.go @@ -84,7 +84,10 @@ func (p *poolImpl) Begin() (*Tx, error) { tx, err := p.conn.Begin(context.Background()) p.metrics.RecordRequestDuration(float64(time.Now().Sub(start).Milliseconds()), "begin", "") p.metrics.IncreaseTotalRequests("begin", "") - return &Tx{tx, p.metrics}, err + return &Tx{ + origTx: tx, + metrics: p.metrics, + }, err } func (p *poolImpl) Close() { @@ -94,13 +97,13 @@ func (p *poolImpl) Close() { // TX - start type Tx struct { - pgx.Tx + origTx pgx.Tx metrics database.Database } func (tx *Tx) TxExec(sql string, args ...interface{}) error { start := time.Now() - _, err := tx.Exec(context.Background(), sql, args...) + _, err := tx.origTx.Exec(context.Background(), sql, args...) method, table := methodName(sql) tx.metrics.RecordRequestDuration(float64(time.Now().Sub(start).Milliseconds()), method, table) tx.metrics.IncreaseTotalRequests(method, table) @@ -109,7 +112,7 @@ func (tx *Tx) TxExec(sql string, args ...interface{}) error { func (tx *Tx) TxQueryRow(sql string, args ...interface{}) pgx.Row { start := time.Now() - res := tx.QueryRow(context.Background(), sql, args...) + res := tx.origTx.QueryRow(context.Background(), sql, args...) method, table := methodName(sql) tx.metrics.RecordRequestDuration(float64(time.Now().Sub(start).Milliseconds()), method, table) tx.metrics.IncreaseTotalRequests(method, table) @@ -118,7 +121,7 @@ func (tx *Tx) TxQueryRow(sql string, args ...interface{}) pgx.Row { func (tx *Tx) TxRollback() error { start := time.Now() - err := tx.Rollback(context.Background()) + err := tx.origTx.Rollback(context.Background()) tx.metrics.RecordRequestDuration(float64(time.Now().Sub(start).Milliseconds()), "rollback", "") tx.metrics.IncreaseTotalRequests("rollback", "") return err @@ -126,7 +129,7 @@ func (tx *Tx) TxRollback() error { func (tx *Tx) TxCommit() error { start := time.Now() - err := tx.Commit(context.Background()) + err := tx.origTx.Commit(context.Background()) tx.metrics.RecordRequestDuration(float64(time.Now().Sub(start).Milliseconds()), "commit", "") tx.metrics.IncreaseTotalRequests("commit", "") return err