diff --git a/cmd/banforge/command/fw.go b/cmd/banforge/command/fw.go index ae4591e..50a7f0a 100644 --- a/cmd/banforge/command/fw.go +++ b/cmd/banforge/command/fw.go @@ -16,53 +16,53 @@ var ( port int protocol string ) + var UnbanCmd = &cobra.Command{ Use: "unban", Short: "Unban IP", Run: func(cmd *cobra.Command, args []string) { - if len(args) == 0 { - fmt.Println("IP can't be empty") - os.Exit(1) - } - if ttl_fw == "" { - ttl_fw = "1y" - } - ip := args[0] - db, err := storage.NewBanWriter() + err := func() error { + if len(args) == 0 { + return fmt.Errorf("IP can't be empty") + } + if ttl_fw == "" { + ttl_fw = "1y" + } + ip := args[0] + db, err := storage.NewBanWriter() + if err != nil { + return err + } + cfg, err := config.LoadConfig() + if err != nil { + return err + } + fw := cfg.Firewall.Name + b := blocker.GetBlocker(fw, cfg.Firewall.Config) + if ip == "" { + return fmt.Errorf("IP can't be empty") + } + if net.ParseIP(ip) == nil { + return fmt.Errorf("invalid IP") + } + if err != nil { + return err + } + err = b.Unban(ip) + if err != nil { + return err + } + err = db.RemoveBan(ip) + if err != nil { + return err + } + fmt.Println("IP unblocked successfully!") + return nil + }() if err != nil { fmt.Println(err) os.Exit(1) } - cfg, err := config.LoadConfig() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - fw := cfg.Firewall.Name - b := blocker.GetBlocker(fw, cfg.Firewall.Config) - if ip == "" { - fmt.Println("IP can't be empty") - os.Exit(1) - } - if net.ParseIP(ip) == nil { - fmt.Println("Invalid IP") - os.Exit(1) - } - if err != nil { - fmt.Println(err) - os.Exit(1) - } - err = b.Unban(ip) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - err = db.RemoveBan(ip) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - fmt.Println("IP unblocked successfully!") }, } @@ -70,49 +70,48 @@ var BanCmd = &cobra.Command{ Use: "ban", Short: "Ban IP", Run: func(cmd *cobra.Command, args []string) { - if len(args) == 0 { - fmt.Println("IP can't be empty") - os.Exit(1) - } - if ttl_fw == "" { - ttl_fw = "1y" - } - ip := args[0] - db, err := storage.NewBanWriter() + err := func() error { + if len(args) == 0 { + return fmt.Errorf("IP can't be empty") + } + if ttl_fw == "" { + ttl_fw = "1y" + } + ip := args[0] + db, err := storage.NewBanWriter() + if err != nil { + return err + } + cfg, err := config.LoadConfig() + if err != nil { + return err + } + fw := cfg.Firewall.Name + b := blocker.GetBlocker(fw, cfg.Firewall.Config) + if ip == "" { + return fmt.Errorf("IP can't be empty") + } + if net.ParseIP(ip) == nil { + return fmt.Errorf("invalid IP") + } + if err != nil { + return err + } + err = b.Ban(ip) + if err != nil { + return err + } + err = db.AddBan(ip, ttl_fw, "manual ban") + if err != nil { + return err + } + fmt.Println("IP blocked successfully!") + return nil + }() if err != nil { fmt.Println(err) os.Exit(1) } - cfg, err := config.LoadConfig() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - fw := cfg.Firewall.Name - b := blocker.GetBlocker(fw, cfg.Firewall.Config) - if ip == "" { - fmt.Println("IP can't be empty") - os.Exit(1) - } - if net.ParseIP(ip) == nil { - fmt.Println("Invalid IP") - os.Exit(1) - } - if err != nil { - fmt.Println(err) - os.Exit(1) - } - err = b.Ban(ip) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - err = db.AddBan(ip, ttl_fw, "manual ban") - if err != nil { - fmt.Println(err) - os.Exit(1) - } - fmt.Println("IP blocked successfully!") }, } diff --git a/internal/blocker/firewalld.go b/internal/blocker/firewalld.go index a28506d..2a0ba1b 100644 --- a/internal/blocker/firewalld.go +++ b/internal/blocker/firewalld.go @@ -94,7 +94,6 @@ func (f *Firewalld) PortClose(port int, protocol string) error { // #nosec G204 - handle is extracted from nftables output and validated if port >= 0 && port <= 65535 { if protocol != "tcp" && protocol != "udp" { - f.logger.Error("invalid protocol") return fmt.Errorf("invalid protocol") } s := strconv.Itoa(port) @@ -106,13 +105,11 @@ func (f *Firewalld) PortClose(port int, protocol string) error { ) output, err := cmd.CombinedOutput() if err != nil { - f.logger.Error(err.Error()) return err } f.logger.Info("Remove port " + s + " " + string(output)) output, err = exec.Command("firewall-cmd", "--reload").CombinedOutput() if err != nil { - f.logger.Error(err.Error()) return err } f.logger.Info("Reload " + string(output)) diff --git a/internal/config/appconf.go b/internal/config/appconf.go index 69933ee..bc76592 100644 --- a/internal/config/appconf.go +++ b/internal/config/appconf.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "os" "strconv" @@ -57,13 +58,9 @@ func NewRule( return err } defer func() { - err = file.Close() - if err != nil { - fmt.Println(err) - } + err = errors.Join(err, file.Close()) }() cfg := Rules{Rules: r} - err = toml.NewEncoder(file).Encode(cfg) if err != nil { return err @@ -126,24 +123,24 @@ func EditRule(Name string, ServiceName string, Path string, Status string, Metho } func ParseDurationWithYears(s string) (time.Duration, error) { - if strings.HasSuffix(s, "y") { - years, err := strconv.Atoi(strings.TrimSuffix(s, "y")) + if ss, ok := strings.CutSuffix(s, "y"); ok { + years, err := strconv.Atoi(ss) if err != nil { return 0, err } return time.Duration(years) * 365 * 24 * time.Hour, nil } - if strings.HasSuffix(s, "M") { - months, err := strconv.Atoi(strings.TrimSuffix(s, "M")) + if ss, ok := strings.CutSuffix(s, "M"); ok { + months, err := strconv.Atoi(ss) if err != nil { return 0, err } return time.Duration(months) * 30 * 24 * time.Hour, nil } - if strings.HasSuffix(s, "d") { - days, err := strconv.Atoi(strings.TrimSuffix(s, "d")) + if ss, ok := strings.CutSuffix(s, "d"); ok { + days, err := strconv.Atoi(ss) if err != nil { return 0, err } diff --git a/internal/storage/ban_db.go b/internal/storage/ban_db.go index b8470b4..4553f19 100644 --- a/internal/storage/ban_db.go +++ b/internal/storage/ban_db.go @@ -21,7 +21,7 @@ type BanWriter struct { func NewBanWriter() (*BanWriter, error) { db, err := sql.Open( "sqlite", - "/var/lib/banforge/bans.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)", + buildSqliteDsn(banDBPath, pragmas), ) if err != nil { return nil, err @@ -175,7 +175,6 @@ func (d *BanReader) IsBanned(ip string) (bool, error) { } func (d *BanReader) BanList() error { - var count int t := table.NewWriter() t.SetOutputMirror(os.Stdout) diff --git a/internal/storage/db.go b/internal/storage/db.go index ff82e21..e1d2a45 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -2,55 +2,60 @@ package storage import ( "database/sql" + "errors" "fmt" + "strings" _ "modernc.org/sqlite" ) -func CreateTables() error { - // Requests DB - db_r, err := sql.Open("sqlite", - "/var/lib/banforge/requests.db?"+ - "mode=rwc&"+ - "_pragma=journal_mode(WAL)&"+ - "_pragma=busy_timeout(30000)&"+ - "_pragma=synchronous(NORMAL)") - if err != nil { - return fmt.Errorf("failed to open requests db: %w", err) - } - defer func() { - err = db_r.Close() - if err != nil { - fmt.Println(err) - } - }() +const ( + DBDir = "/var/lib/banforge/" + ReqDBPath = DBDir + "requests.db" + banDBPath = DBDir + "bans.db" +) - _, err = db_r.Exec(CreateRequestsTable) - if err != nil { - return fmt.Errorf("failed to create requests table: %w", err) - } - - // Bans DB - db_b, err := sql.Open("sqlite", - "/var/lib/banforge/bans.db?"+ - "mode=rwc&"+ - "_pragma=journal_mode(WAL)&"+ - "_pragma=busy_timeout(30000)&"+ - "_pragma=synchronous(FULL)") - if err != nil { - return fmt.Errorf("failed to open bans db: %w", err) - } - defer func() { - err = db_b.Close() - if err != nil { - fmt.Println(err) - } - }() - - _, err = db_b.Exec(CreateBansTable) - if err != nil { - return fmt.Errorf("failed to create bans table: %w", err) - } - fmt.Println("Tables created successfully!") - return nil +var pragmas = map[string]string{ + `journal_mode`: `wal`, + `synchronous`: `normal`, + `busy_timeout`: `30000`, + // also consider these + // `temp_store`: `memory`, + // `cache_size`: `1000000000`, +} + +func buildSqliteDsn(path string, pragmas map[string]string) string { + pragmastrs := make([]string, len(pragmas)) + i := 0 + for k, v := range pragmas { + pragmastrs[i] = (fmt.Sprintf(`pragma=%s(%s)`, k, v)) + i++ + } + return path + "?" + "mode=rwc&" + strings.Join(pragmastrs, "&") +} + +func initDB(dsn, sqlstr string) (err error) { + db, err := sql.Open("sqlite", dsn) + if err != nil { + return fmt.Errorf("failed to open %q: %w", dsn, err) + } + defer func() { + closeErr := db.Close() + if closeErr != nil { + err = errors.Join(err, fmt.Errorf("failed to close %q: %w", dsn, closeErr)) + } + }() + _, err = db.Exec(sqlstr) + if err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + return err +} + +func CreateTables() (err error) { + // Requests DB + err1 := initDB(buildSqliteDsn(ReqDBPath, pragmas), CreateRequestsTable) + err2 := initDB(buildSqliteDsn(banDBPath, pragmas), CreateBansTable) + + return errors.Join(err1, err2) } diff --git a/internal/storage/requests_db.go b/internal/storage/requests_db.go index de29c2a..379a592 100644 --- a/internal/storage/requests_db.go +++ b/internal/storage/requests_db.go @@ -7,15 +7,15 @@ import ( _ "modernc.org/sqlite" ) -type Request_Writer struct { +type RequestWriter struct { logger *logger.Logger db *sql.DB } -func NewRequestsWr() (*Request_Writer, error) { +func NewRequestsWr() (*RequestWriter, error) { db, err := sql.Open( "sqlite", - "/var/lib/banforge/requests.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)", + buildSqliteDsn(ReqDBPath, pragmas), ) if err != nil { return nil, err @@ -23,7 +23,7 @@ func NewRequestsWr() (*Request_Writer, error) { db.SetMaxOpenConns(1) db.SetMaxIdleConns(1) db.SetConnMaxLifetime(0) - return &Request_Writer{ + return &RequestWriter{ logger: logger.New(false), db: db, }, nil diff --git a/internal/storage/writer.go b/internal/storage/writer.go index 8b63be1..f135feb 100644 --- a/internal/storage/writer.go +++ b/internal/storage/writer.go @@ -1,10 +1,13 @@ package storage import ( + "database/sql" + "errors" + "fmt" "time" ) -func WriteReq(db *Request_Writer, resultCh <-chan *LogEntry) { +func WriteReq(db *RequestWriter, resultCh <-chan *LogEntry) { db.logger.Info("Starting log writer") const batchSize = 100 const flushInterval = 1 * time.Second @@ -14,53 +17,63 @@ func WriteReq(db *Request_Writer, resultCh <-chan *LogEntry) { defer ticker.Stop() flush := func() { - if len(batch) == 0 { - return - } - - tx, err := db.db.Begin() - if err != nil { - db.logger.Error("Failed to begin transaction", "error", err) - return - } - - stmt, err := tx.Prepare( - "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", - ) - if err != nil { - db.logger.Error("Failed to prepare statement", "error", err) - if rollbackErr := tx.Rollback(); rollbackErr != nil { - db.logger.Error("Failed to rollback transaction", "error", rollbackErr) + defer db.logger.Debug("Flushed batch", "count", len(batch)) + err := func() (err error) { + if len(batch) == 0 { + return nil } - return - } - defer func() { - if closeErr := stmt.Close(); closeErr != nil { - db.logger.Error("Failed to close statement", "error", closeErr) - } - }() - for _, entry := range batch { - _, err := stmt.Exec( - entry.Service, - entry.IP, - entry.Path, - entry.Method, - entry.Status, - time.Now().Format(time.RFC3339), + tx, err := db.db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { + if rollbackErr := tx.Rollback(); rollbackErr != nil && + !errors.Is(rollbackErr, sql.ErrTxDone) { + err = errors.Join( + err, + fmt.Errorf("failed to rollback transaction: %w", rollbackErr), + ) + } + }() + + stmt, err := tx.Prepare( + "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", ) if err != nil { - db.logger.Error("Failed to insert entry", "error", err) + err = fmt.Errorf("failed to prepare statement: %w", err) + return err } - } + defer func() { + if closeErr := stmt.Close(); closeErr != nil { + err = errors.Join(err, fmt.Errorf("failed to close statement: %w", closeErr)) + } + }() - if err := tx.Commit(); err != nil { - db.logger.Error("Failed to commit transaction", "error", err) - return - } + for _, entry := range batch { + _, err := stmt.Exec( + entry.Service, + entry.IP, + entry.Path, + entry.Method, + entry.Status, + time.Now().Format(time.RFC3339), + ) + if err != nil { + db.logger.Error(fmt.Errorf("failed to insert entry: %w", err).Error()) + } + } - db.logger.Debug("Flushed batch", "count", len(batch)) - batch = batch[:0] + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + batch = batch[:0] + return err + }() + if err != nil { + db.logger.Error(err.Error()) + } } for { diff --git a/internal/storage/writer_test.go b/internal/storage/writer_test.go index 5b1be20..68b31e5 100644 --- a/internal/storage/writer_test.go +++ b/internal/storage/writer_test.go @@ -277,7 +277,7 @@ func TestWrite_ChannelClosed(t *testing.T) { } } -func NewRequestWriterWithDBPath(dbPath string) (*Request_Writer, error) { +func NewRequestWriterWithDBPath(dbPath string) (*RequestWriter, error) { db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)") if err != nil { return nil, err @@ -285,13 +285,13 @@ func NewRequestWriterWithDBPath(dbPath string) (*Request_Writer, error) { db.SetMaxOpenConns(1) db.SetMaxIdleConns(1) db.SetConnMaxLifetime(0) - return &Request_Writer{ + return &RequestWriter{ logger: logger.New(false), db: db, }, nil } -func (w *Request_Writer) CreateTable() error { +func (w *RequestWriter) CreateTable() error { _, err := w.db.Exec(CreateRequestsTable) if err != nil { return err @@ -300,7 +300,7 @@ func (w *Request_Writer) CreateTable() error { return nil } -func (w *Request_Writer) Close() error { +func (w *RequestWriter) Close() error { w.logger.Info("Closing request database connection") err := w.db.Close() if err != nil { @@ -309,7 +309,7 @@ func (w *Request_Writer) Close() error { return nil } -func (w *Request_Writer) GetRequestCount() (int, error) { +func (w *RequestWriter) GetRequestCount() (int, error) { var count int err := w.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&count) if err != nil {