diff --git a/internal/storage/writer.go b/internal/storage/writer.go index 576b2c2..993767a 100644 --- a/internal/storage/writer.go +++ b/internal/storage/writer.go @@ -1,6 +1,9 @@ package storage import ( + "database/sql" + "errors" + "fmt" "time" ) @@ -14,53 +17,59 @@ func WriteReq(db *RequestWriter, resultCh <-chan *LogEntry) { defer ticker.Stop() flush := func() { - if len(batch) == 0 { - return - } - - tx, err := db.db.Begin() - if err != nil { - db.logger.Error("Failed to begin transaction", "error", err) - return - } - - stmt, err := tx.Prepare( - "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", - ) - if err != nil { - db.logger.Error("Failed to prepare statement", "error", err) - if rollbackErr := tx.Rollback(); rollbackErr != nil { - db.logger.Error("Failed to rollback transaction", "error", rollbackErr) + defer db.logger.Debug("Flushed batch", "count", len(batch)) + err := func() (err error) { + if len(batch) == 0 { + return nil } - return - } - defer func() { - if closeErr := stmt.Close(); closeErr != nil { - db.logger.Error("Failed to close statement", "error", closeErr) - } - }() - for _, entry := range batch { - _, err := stmt.Exec( - entry.Service, - entry.IP, - entry.Path, - entry.Method, - entry.Status, - time.Now().Format(time.RFC3339), + tx, err := db.db.Begin() + if err != nil { + return fmt.Errorf("Failed to begin transaction: %w", err) + } + defer func() { + if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) { + err = errors.Join(err, fmt.Errorf("Failed to rollback transaction: %w", rollbackErr)) + } + }() + + stmt, err := tx.Prepare( + "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", ) if err != nil { - db.logger.Error("Failed to insert entry", "error", err) + err = fmt.Errorf("Failed to prepare statement: %w", err) + return err } - } + defer func() { + if closeErr := stmt.Close(); closeErr != nil { + err = errors.Join(err, fmt.Errorf("Failed to close statement: %w", closeErr)) + } + }() - if err := tx.Commit(); err != nil { - db.logger.Error("Failed to commit transaction", "error", err) - return - } + for _, entry := range batch { + _, err := stmt.Exec( + entry.Service, + entry.IP, + entry.Path, + entry.Method, + entry.Status, + time.Now().Format(time.RFC3339), + ) + if err != nil { + db.logger.Error(fmt.Errorf("Failed to insert entry: %w", err).Error()) + } + } - db.logger.Debug("Flushed batch", "count", len(batch)) - batch = batch[:0] + if err := tx.Commit(); err != nil { + return fmt.Errorf("Failed to commit transaction: %w", err) + } + + batch = batch[:0] + return err + }() + if err != nil { + db.logger.Error(err.Error()) + } } for {