diff --git a/cmd/banforge/command/daemon.go b/cmd/banforge/command/daemon.go index fa1390d..125075a 100644 --- a/cmd/banforge/command/daemon.go +++ b/cmd/banforge/command/daemon.go @@ -25,13 +25,27 @@ var DaemonCmd = &cobra.Command{ defer stop() log := logger.New(false) log.Info("Starting BanForge daemon") - db, err := storage.NewDB() + reqDb_w, err := storage.NewRequestsWr() if err != nil { - log.Error("Failed to create database", "error", err) + log.Error("Failed to create request writer", "error", err) + os.Exit(1) + } + banDb_r, err := storage.NewBanReader() + if err != nil { + log.Error("Failed to create ban reader", "error", err) + os.Exit(1) + } + banDb_w, err := storage.NewBanWriter() + if err != nil { + log.Error("Failed to create ban writter", "error", err) os.Exit(1) } defer func() { - err = db.Close() + err = banDb_r.Close() + if err != nil { + log.Error("Failed to close database connection", "error", err) + } + err = banDb_w.Close() if err != nil { log.Error("Failed to close database connection", "error", err) } @@ -49,11 +63,11 @@ var DaemonCmd = &cobra.Command{ log.Error("Failed to load rules", "error", err) os.Exit(1) } - j := judge.New(db, b, resultCh, entryCh) + j := judge.New(banDb_r, banDb_w, b, resultCh, entryCh) j.LoadRules(r) go j.UnbanChecker() go j.Tribunal() - go storage.Write(db, resultCh) + go storage.Write(reqDb_w, resultCh) var scanners []*parser.Scanner for _, svc := range cfg.Service { diff --git a/cmd/banforge/command/fw.go b/cmd/banforge/command/fw.go index 12b66ee..eb2a589 100644 --- a/cmd/banforge/command/fw.go +++ b/cmd/banforge/command/fw.go @@ -18,7 +18,7 @@ var UnbanCmd = &cobra.Command{ Use: "unban", Short: "Unban IP", Run: func(cmd *cobra.Command, args []string) { - db, err := storage.NewDB() + db, err := storage.NewBanWriter() if err != nil { fmt.Println(err) os.Exit(1) @@ -60,7 +60,7 @@ var BanCmd = &cobra.Command{ Use: "ban", Short: "Ban IP", Run: func(cmd *cobra.Command, args []string) { - db, err := storage.NewDB() + db, err := storage.NewBanWriter() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/banforge/command/init.go b/cmd/banforge/command/init.go index c37e6a5..584b4b5 100644 --- a/cmd/banforge/command/init.go +++ b/cmd/banforge/command/init.go @@ -82,23 +82,11 @@ var InitCmd = &cobra.Command{ } fmt.Println("Firewall configured") - db, err := storage.NewDB() + err = storage.CreateTables() if err != nil { fmt.Println(err) os.Exit(1) } - err = db.CreateTable() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - defer func() { - err = db.Close() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - }() fmt.Println("Firewall detected and configured") fmt.Println("BanForge initialized successfully!") diff --git a/cmd/banforge/command/list.go b/cmd/banforge/command/list.go index f3a55d8..690cd2d 100644 --- a/cmd/banforge/command/list.go +++ b/cmd/banforge/command/list.go @@ -13,7 +13,7 @@ var BanListCmd = &cobra.Command{ Short: "List banned IP adresses", Run: func(cmd *cobra.Command, args []string) { var log = logger.New(false) - d, err := storage.NewDB() + d, err := storage.NewBanReader() if err != nil { log.Error("Failed to create database", "error", err) os.Exit(1) diff --git a/internal/judge/judge.go b/internal/judge/judge.go index 7092eba..4171cd9 100644 --- a/internal/judge/judge.go +++ b/internal/judge/judge.go @@ -12,7 +12,8 @@ import ( ) type Judge struct { - db *storage.DB + db_r *storage.BanReader + db_w *storage.BanWriter logger *logger.Logger Blocker blocker.BlockerEngine rulesByService map[string][]config.Rule @@ -21,13 +22,15 @@ type Judge struct { } func New( - db *storage.DB, + db_r *storage.BanReader, + db_w *storage.BanWriter, b blocker.BlockerEngine, resultCh chan *storage.LogEntry, entryCh chan *storage.LogEntry, ) *Judge { return &Judge{ - db: db, + db_w: db_w, + db_r: db_r, logger: logger.New(false), rulesByService: make(map[string][]config.Rule), Blocker: b, @@ -85,7 +88,7 @@ func (j *Judge) Tribunal() { ruleMatched = true j.logger.Info("Rule matched", "rule", rule.Name, "ip", entry.IP) - banned, err := j.db.IsBanned(entry.IP) + banned, err := j.db_r.IsBanned(entry.IP) if err != nil { j.logger.Error("Failed to check ban status", "ip", entry.IP, "error", err) break @@ -97,7 +100,7 @@ func (j *Judge) Tribunal() { break } - err = j.db.AddBan(entry.IP, rule.BanTime) + err = j.db_w.AddBan(entry.IP, rule.BanTime) if err != nil { j.logger.Error( "Failed to add ban to database", @@ -142,22 +145,16 @@ func (j *Judge) UnbanChecker() { defer tick.Stop() for range tick.C { - ips, err := j.db.CheckExpiredBans() + ips, err := j.db_w.RemoveExpiredBans() if err != nil { j.logger.Error(fmt.Sprintf("Failed to check expired bans: %v", err)) continue } for _, ip := range ips { - err = j.db.RemoveBan(ip) - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to remove ban: %v", err)) - } if err := j.Blocker.Unban(ip); err != nil { - j.logger.Error(fmt.Sprintf("Failed to unban IP %s: %v", ip, err)) - continue + j.logger.Error(fmt.Sprintf("Failed to unban IP at firewall: %v", err)) } - j.logger.Info(fmt.Sprintf("IP unbanned: %s", ip)) } } } diff --git a/internal/storage/ban_db.go b/internal/storage/ban_db.go new file mode 100644 index 0000000..46f5312 --- /dev/null +++ b/internal/storage/ban_db.go @@ -0,0 +1,204 @@ +package storage + +import ( + "database/sql" + "fmt" + "github.com/d3m0k1d/BanForge/internal/config" + "github.com/d3m0k1d/BanForge/internal/logger" + "github.com/jedib0t/go-pretty/v6/table" + _ "modernc.org/sqlite" + "os" + "time" +) + +// Writer block +type BanWriter struct { + logger *logger.Logger + db *sql.DB +} + +func NewBanWriter() (*BanWriter, error) { + db, err := sql.Open("sqlite", "/var/lib/banforge/bans.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)") + if err != nil { + return nil, err + } + return &BanWriter{ + logger: logger.New(false), + db: db, + }, nil +} + +func (d *BanWriter) CreateTable() error { + _, err := d.db.Exec(CreateBansTable) + if err != nil { + return err + } + d.logger.Info("Created tables") + return nil +} + +func (d *BanWriter) AddBan(ip string, ttl string) error { + duration, err := config.ParseDurationWithYears(ttl) + if err != nil { + d.logger.Error("Invalid duration format", "ttl", ttl, "error", err) + return fmt.Errorf("invalid duration: %w", err) + } + + now := time.Now() + expiredAt := now.Add(duration) + + _, err = d.db.Exec( + "INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)", + ip, + "1", + now.Format(time.RFC3339), + expiredAt.Format(time.RFC3339), + ) + if err != nil { + d.logger.Error("Failed to add ban", "error", err) + return err + } + + return nil +} + +func (d *BanWriter) RemoveBan(ip string) error { + _, err := d.db.Exec("DELETE FROM bans WHERE ip = ?", ip) + if err != nil { + d.logger.Error("Failed to remove ban", "error", err) + return err + } + return nil +} + +func (w *BanWriter) RemoveExpiredBans() ([]string, error) { + var ips []string + now := time.Now().Format(time.RFC3339) + + rows, err := w.db.Query( + "SELECT ip FROM bans WHERE expired_at < ?", + now, + ) + if err != nil { + w.logger.Error("Failed to get expired bans", "error", err) + return nil, err + } + defer rows.Close() + + for rows.Next() { + var ip string + err := rows.Scan(&ip) + if err != nil { + w.logger.Error("Failed to scan ban", "error", err) + continue + } + ips = append(ips, ip) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + result, err := w.db.Exec( + "DELETE FROM bans WHERE expired_at < ?", + now, + ) + if err != nil { + w.logger.Error("Failed to remove expired bans", "error", err) + return nil, err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, err + } + + if rowsAffected > 0 { + w.logger.Info("Removed expired bans", "count", rowsAffected, "ips", len(ips)) + } + + return ips, nil +} + +func (d *BanWriter) Close() error { + d.logger.Info("Closing database connection") + err := d.db.Close() + if err != nil { + return err + } + return nil +} + +// Reader block + +type BanReader struct { + logger *logger.Logger + db *sql.DB +} + +func NewBanReader() (*BanReader, error) { + db, err := sql.Open("sqlite", + "/var/lib/banforge/bans.db?"+ + "mode=ro&"+ + "_pragma=journal_mode(WAL)&"+ + "_pragma=mmap_size(268435456)&"+ + "_pragma=cache_size(-2000)&"+ + "_pragma=query_only(1)") + if err != nil { + return nil, err + } + + return &BanReader{ + logger: logger.New(false), + db: db, + }, nil +} + +func (d *BanReader) IsBanned(ip string) (bool, error) { + var bannedIP string + err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP) + if err == sql.ErrNoRows { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to check ban status: %w", err) + } + return true, nil +} + +func (d *BanReader) BanList() error { + + var count int + t := table.NewWriter() + t.SetOutputMirror(os.Stdout) + t.SetStyle(table.StyleBold) + t.AppendHeader(table.Row{"№", "IP", "Banned At"}) + rows, err := d.db.Query("SELECT ip, banned_at FROM bans") + if err != nil { + d.logger.Error("Failed to get ban list", "error", err) + return err + } + for rows.Next() { + count++ + var ip string + var bannedAt string + err := rows.Scan(&ip, &bannedAt) + if err != nil { + d.logger.Error("Failed to get ban list", "error", err) + return err + } + t.AppendRow(table.Row{count, ip, bannedAt}) + + } + t.Render() + return nil +} + +func (d *BanReader) Close() error { + d.logger.Info("Closing database connection") + err := d.db.Close() + if err != nil { + return err + } + return nil +} diff --git a/internal/storage/ban_db_test.go b/internal/storage/ban_db_test.go new file mode 100644 index 0000000..b569730 --- /dev/null +++ b/internal/storage/ban_db_test.go @@ -0,0 +1,380 @@ +package storage + +import ( + "database/sql" + + "github.com/d3m0k1d/BanForge/internal/logger" + "path/filepath" + "testing" +) + +func TestBanWriter_AddBan(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + ip := "192.168.1.1" + ttl := "1h" + + err = writer.AddBan(ip, ttl) + if err != nil { + t.Errorf("AddBan failed: %v", err) + } + + reader, err := NewBanReaderWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanReader: %v", err) + } + defer reader.Close() + + isBanned, err := reader.IsBanned(ip) + if err != nil { + t.Errorf("IsBanned failed: %v", err) + } + if !isBanned { + t.Error("Expected IP to be banned, but it's not") + } +} + +func TestBanWriter_RemoveBan(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + ip := "192.168.1.2" + err = writer.AddBan(ip, "1h") + if err != nil { + t.Fatalf("Failed to add ban: %v", err) + } + + reader, err := NewBanReaderWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanReader: %v", err) + } + defer reader.Close() + + isBanned, err := reader.IsBanned(ip) + if err != nil { + t.Fatalf("IsBanned failed: %v", err) + } + if !isBanned { + t.Fatal("Expected IP to be banned before removal") + } + + err = writer.RemoveBan(ip) + if err != nil { + t.Errorf("RemoveBan failed: %v", err) + } + + isBanned, err = reader.IsBanned(ip) + if err != nil { + t.Errorf("IsBanned failed after removal: %v", err) + } + if isBanned { + t.Error("Expected IP to be unbanned after removal, but it's still banned") + } +} + +func TestBanWriter_RemoveExpiredBans(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + expiredIP := "192.168.1.3" + err = writer.AddBan(expiredIP, "-1h") + if err != nil { + t.Fatalf("Failed to add expired ban: %v", err) + } + + activeIP := "192.168.1.4" + err = writer.AddBan(activeIP, "1h") + if err != nil { + t.Fatalf("Failed to add active ban: %v", err) + } + + removedIPs, err := writer.RemoveExpiredBans() + if err != nil { + t.Errorf("RemoveExpiredBans failed: %v", err) + } + + found := false + for _, ip := range removedIPs { + if ip == expiredIP { + found = true + break + } + } + if !found { + t.Error("Expected expired IP to be in removed list") + } + + if len(removedIPs) != 1 { + t.Errorf("Expected 1 removed IP, got %d", len(removedIPs)) + } + + reader, err := NewBanReaderWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanReader: %v", err) + } + defer reader.Close() + + isExpiredBanned, err := reader.IsBanned(expiredIP) + if err != nil { + t.Errorf("IsBanned failed for expired IP: %v", err) + } + if isExpiredBanned { + t.Error("Expected expired IP to be unbanned, but it's still banned") + } + + isActiveBanned, err := reader.IsBanned(activeIP) + if err != nil { + t.Errorf("IsBanned failed for active IP: %v", err) + } + if !isActiveBanned { + t.Error("Expected active IP to still be banned, but it's not") + } +} + +func TestBanReader_IsBanned(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + ip := "192.168.1.5" + err = writer.AddBan(ip, "1h") + if err != nil { + t.Fatalf("Failed to add ban: %v", err) + } + + reader, err := NewBanReaderWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanReader: %v", err) + } + defer reader.Close() + + isBanned, err := reader.IsBanned(ip) + if err != nil { + t.Errorf("IsBanned failed for banned IP: %v", err) + } + if !isBanned { + t.Error("Expected IP to be banned") + } + + isBanned, err = reader.IsBanned("192.168.1.6") + if err != nil { + t.Errorf("IsBanned failed for non-banned IP: %v", err) + } + if isBanned { + t.Error("Expected IP to not be banned") + } +} + +func TestBanWriter_Close(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + err = writer.Close() + if err != nil { + t.Errorf("Close failed: %v", err) + } + + _, err = writer.db.Exec("SELECT 1") + if err == nil { + t.Error("Expected error when using closed connection") + } +} + +func TestBanReader_Close(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + reader, err := NewBanReaderWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanReader: %v", err) + } + + err = reader.Close() + if err != nil { + t.Errorf("Close failed: %v", err) + } + + _, err = reader.db.Query("SELECT 1") + if err == nil { + t.Error("Expected error when using closed connection") + } +} + +func TestBanWriter_AddBan_InvalidDuration(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + err = writer.AddBan("192.168.1.7", "invalid_duration") + if err == nil { + t.Error("Expected error for invalid duration") + } else if err.Error() == "" || err.Error() == "" { + t.Error("Expected meaningful error message for invalid duration") + } +} + +func TestMultipleBans(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + ips := []string{"192.168.1.8", "192.168.1.9", "192.168.1.10"} + + for _, ip := range ips { + err := writer.AddBan(ip, "1h") + if err != nil { + t.Errorf("Failed to add ban for IP %s: %v", ip, err) + } + } + + reader, err := NewBanReaderWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanReader: %v", err) + } + defer reader.Close() + + for _, ip := range ips { + isBanned, err := reader.IsBanned(ip) + if err != nil { + t.Errorf("IsBanned failed for IP %s: %v", ip, err) + continue + } + if !isBanned { + t.Errorf("Expected IP %s to be banned", ip) + } + } +} + +func TestRemoveNonExistentBan(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "bans_test.db") + + writer, err := NewBanWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create BanWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + err = writer.RemoveBan("192.168.1.11") + if err != nil { + t.Errorf("RemoveBan should not return error for non-existent ban: %v", err) + } +} +func NewBanWriterWithDBPath(dbPath string) (*BanWriter, error) { + db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)") + if err != nil { + return nil, err + } + return &BanWriter{ + logger: logger.New(false), + db: db, + }, nil +} + +func NewBanReaderWithDBPath(dbPath string) (*BanReader, error) { + db, err := sql.Open("sqlite", + dbPath+"?"+ + "mode=ro&"+ + "_pragma=journal_mode(WAL)&"+ + "_pragma=mmap_size(268435456)&"+ + "_pragma=cache_size(-2000)&"+ + "_pragma=query_only(1)") + if err != nil { + return nil, err + } + + return &BanReader{ + logger: logger.New(false), + db: db, + }, nil +} diff --git a/internal/storage/db.go b/internal/storage/db.go index bada3a5..8ee4eac 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -3,157 +3,43 @@ package storage import ( "database/sql" "fmt" - "os" - "time" - - "github.com/d3m0k1d/BanForge/internal/config" - "github.com/d3m0k1d/BanForge/internal/logger" - "github.com/jedib0t/go-pretty/v6/table" _ "modernc.org/sqlite" ) -type DB struct { - logger *logger.Logger - db *sql.DB -} - -func NewDB() (*DB, error) { - db, err := sql.Open( - "sqlite", - "/var/lib/banforge/storage.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)", - ) - db.SetMaxOpenConns(1) - db.SetMaxIdleConns(1) - db.SetConnMaxLifetime(0) +func CreateTables() error { + // Requests DB + db_r, err := sql.Open("sqlite", + "/var/lib/banforge/requests.db?"+ + "mode=rwc&"+ + "_pragma=journal_mode(WAL)&"+ + "_pragma=busy_timeout(30000)&"+ + "_pragma=synchronous(NORMAL)") if err != nil { - return nil, err + return fmt.Errorf("failed to open requests db: %w", err) } + defer db_r.Close() - if err := db.Ping(); err != nil { - return nil, err - } - return &DB{ - logger: logger.New(false), - db: db, - }, nil -} - -func (d *DB) Close() error { - d.logger.Info("Closing database connection") - err := d.db.Close() + _, err = db_r.Exec(CreateRequestsTable) if err != nil { - return err + return fmt.Errorf("failed to create requests table: %w", err) } + + // Bans DB + db_b, err := sql.Open("sqlite", + "/var/lib/banforge/bans.db?"+ + "mode=rwc&"+ + "_pragma=journal_mode(WAL)&"+ + "_pragma=busy_timeout(30000)&"+ + "_pragma=synchronous(FULL)") + if err != nil { + return fmt.Errorf("failed to open bans db: %w", err) + } + defer db_b.Close() + + _, err = db_b.Exec(CreateBansTable) + if err != nil { + return fmt.Errorf("failed to create bans table: %w", err) + } + fmt.Println("Tables created successfully!") return nil } - -func (d *DB) CreateTable() error { - _, err := d.db.Exec(CreateTables) - if err != nil { - return err - } - d.logger.Info("Created tables") - return nil -} - -func (d *DB) IsBanned(ip string) (bool, error) { - var bannedIP string - err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP) - if err == sql.ErrNoRows { - return false, nil - } - if err != nil { - return false, fmt.Errorf("failed to check ban status: %w", err) - } - return true, nil -} - -func (d *DB) AddBan(ip string, ttl string) error { - duration, err := config.ParseDurationWithYears(ttl) - if err != nil { - d.logger.Error("Invalid duration format", "ttl", ttl, "error", err) - return fmt.Errorf("invalid duration: %w", err) - } - - now := time.Now() - expiredAt := now.Add(duration) - - _, err = d.db.Exec( - "INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)", - ip, - "1", - now.Format(time.RFC3339), - expiredAt.Format(time.RFC3339), - ) - if err != nil { - d.logger.Error("Failed to add ban", "error", err) - return err - } - - return nil -} - -func (d *DB) RemoveBan(ip string) error { - _, err := d.db.Exec("DELETE FROM bans WHERE ip = ?", ip) - if err != nil { - d.logger.Error("Failed to remove ban", "error", err) - return err - } - return nil -} - -func (d *DB) BanList() error { - - var count int - t := table.NewWriter() - t.SetOutputMirror(os.Stdout) - t.SetStyle(table.StyleBold) - t.AppendHeader(table.Row{"№", "IP", "Banned At"}) - rows, err := d.db.Query("SELECT ip, banned_at FROM bans") - if err != nil { - d.logger.Error("Failed to get ban list", "error", err) - return err - } - for rows.Next() { - count++ - var ip string - var bannedAt string - err := rows.Scan(&ip, &bannedAt) - if err != nil { - d.logger.Error("Failed to get ban list", "error", err) - return err - } - t.AppendRow(table.Row{count, ip, bannedAt}) - - } - t.Render() - return nil -} - -func (d *DB) CheckExpiredBans() ([]string, error) { - var ips []string - rows, err := d.db.Query( - "SELECT ip FROM bans WHERE expired_at < ?", - time.Now().Format(time.RFC3339), - ) - if err != nil { - d.logger.Error("Failed to get ban list", "error", err) - return nil, err - } - for rows.Next() { - var ip string - r, err := d.db.Exec("DELETE FROM bans WHERE ip = ?", ip) - if err != nil { - d.logger.Error("Failed to get ban list", "error", err) - return nil, err - } - d.logger.Info("Ban removed", "ip", ip, "rows", r) - err = rows.Scan(&ip) - if err != nil { - d.logger.Error("Failed to get ban list", "error", err) - return nil, err - } - ips = append(ips, ip) - } - return ips, nil -} diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go deleted file mode 100644 index c426e0b..0000000 --- a/internal/storage/db_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package storage - -import ( - "database/sql" - "github.com/d3m0k1d/BanForge/internal/logger" - _ "modernc.org/sqlite" - "os" - "path/filepath" - "testing" - "time" -) - -func createTestDB(t *testing.T) *sql.DB { - tmpDir, err := os.MkdirTemp("", "banforge-test-*") - if err != nil { - t.Fatal(err) - } - - filePath := filepath.Join(tmpDir, "test.db") - db, err := sql.Open("sqlite", filePath) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - db.Close() - os.RemoveAll(tmpDir) - }) - - return db -} - -func createTestDBStruct(t *testing.T) *DB { - tmpDir, err := os.MkdirTemp("", "banforge-test-*") - if err != nil { - t.Fatal(err) - } - - filePath := filepath.Join(tmpDir, "test.db") - sqlDB, err := sql.Open("sqlite", filePath) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - sqlDB.Close() - os.RemoveAll(tmpDir) - }) - - return &DB{ - logger: logger.New(false), - db: sqlDB, - } -} - -func TestCreateTable(t *testing.T) { - d := createTestDBStruct(t) - - err := d.CreateTable() - if err != nil { - t.Fatal(err) - } - - rows, err := d.db.Query("SELECT 1 FROM requests LIMIT 1") - if err != nil { - t.Fatal("requests table should exist:", err) - } - rows.Close() - - rows, err = d.db.Query("SELECT 1 FROM bans LIMIT 1") - if err != nil { - t.Fatal("bans table should exist:", err) - } - rows.Close() -} - -func TestIsBanned(t *testing.T) { - d := createTestDBStruct(t) - - err := d.CreateTable() - if err != nil { - t.Fatal(err) - } - - _, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339)) - if err != nil { - t.Fatal(err) - } - - isBanned, err := d.IsBanned("127.0.0.1") - if err != nil { - t.Fatal(err) - } - - if !isBanned { - t.Fatal("should be banned") - } -} - -func TestAddBan(t *testing.T) { - d := createTestDBStruct(t) - - err := d.CreateTable() - if err != nil { - t.Fatal(err) - } - - err = d.AddBan("127.0.0.1", "7h") - if err != nil { - t.Fatal(err) - } - - var ip string - err = d.db.QueryRow("SELECT ip FROM bans WHERE ip = ?", "127.0.0.1").Scan(&ip) - if err != nil { - t.Fatal(err) - } - - if ip != "127.0.0.1" { - t.Fatal("ip should be 127.0.0.1") - } -} - -func TestBanList(t *testing.T) { - d := createTestDBStruct(t) - - err := d.CreateTable() - if err != nil { - t.Fatal(err) - } - - _, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339)) - if err != nil { - t.Fatal(err) - } - - err = d.BanList() - if err != nil { - t.Fatal(err) - } -} - -func TestClose(t *testing.T) { - d := createTestDBStruct(t) - - err := d.Close() - if err != nil { - t.Fatal(err) - } -} diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go index 91ca66a..ff42852 100644 --- a/internal/storage/migrations.go +++ b/internal/storage/migrations.go @@ -1,7 +1,6 @@ package storage -const CreateTables = ` - +const CreateRequestsTable = ` CREATE TABLE IF NOT EXISTS requests ( id INTEGER PRIMARY KEY, service TEXT NOT NULL, @@ -12,6 +11,14 @@ CREATE TABLE IF NOT EXISTS requests ( created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); +CREATE INDEX IF NOT EXISTS idx_requests_service ON requests(service); +CREATE INDEX IF NOT EXISTS idx_requests_ip ON requests(ip); +CREATE INDEX IF NOT EXISTS idx_requests_status ON requests(status); +CREATE INDEX IF NOT EXISTS idx_requests_created_at ON requests(created_at); +` + +// Миграция для bans.db +const CreateBansTable = ` CREATE TABLE IF NOT EXISTS bans ( id INTEGER PRIMARY KEY, ip TEXT UNIQUE NOT NULL, @@ -20,9 +27,5 @@ CREATE TABLE IF NOT EXISTS bans ( expired_at DATETIME ); -CREATE INDEX IF NOT EXISTS idx_service ON requests(service); -CREATE INDEX IF NOT EXISTS idx_ip ON requests(ip); -CREATE INDEX IF NOT EXISTS idx_status ON requests(status); -CREATE INDEX IF NOT EXISTS idx_created_at ON requests(created_at); -CREATE INDEX IF NOT EXISTS idx_ban_ip ON bans(ip); - ` +CREATE INDEX IF NOT EXISTS idx_bans_ip ON bans(ip); +` diff --git a/internal/storage/requests_db.go b/internal/storage/requests_db.go new file mode 100644 index 0000000..e64e65b --- /dev/null +++ b/internal/storage/requests_db.go @@ -0,0 +1,26 @@ +package storage + +import ( + "database/sql" + "github.com/d3m0k1d/BanForge/internal/logger" + _ "modernc.org/sqlite" +) + +type Request_Writer struct { + logger *logger.Logger + db *sql.DB +} + +func NewRequestsWr() (*Request_Writer, error) { + db, err := sql.Open("sqlite", "/var/lib/banforge/requests.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)") + if err != nil { + return nil, err + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) + return &Request_Writer{ + logger: logger.New(false), + db: db, + }, nil +} diff --git a/internal/storage/writer.go b/internal/storage/writer.go index fb78bf9..41ccb46 100644 --- a/internal/storage/writer.go +++ b/internal/storage/writer.go @@ -4,7 +4,7 @@ import ( "time" ) -func Write(db *DB, resultCh <-chan *LogEntry) { +func Write(db *Request_Writer, resultCh <-chan *LogEntry) { db.logger.Info("Starting log writer") const batchSize = 100 const flushInterval = 1 * time.Second @@ -28,17 +28,15 @@ func Write(db *DB, resultCh <-chan *LogEntry) { "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", ) if err != nil { - err := tx.Rollback() - if err != nil { - db.logger.Error("Failed to rollback transaction", "error", err) - } db.logger.Error("Failed to prepare statement", "error", err) + if rollbackErr := tx.Rollback(); rollbackErr != nil { + db.logger.Error("Failed to rollback transaction", "error", rollbackErr) + } return } defer func() { - err := stmt.Close() - if err != nil { - db.logger.Error("Failed to close statement", "error", err) + if closeErr := stmt.Close(); closeErr != nil { + db.logger.Error("Failed to close statement", "error", closeErr) } }() @@ -58,10 +56,10 @@ func Write(db *DB, resultCh <-chan *LogEntry) { if err := tx.Commit(); err != nil { db.logger.Error("Failed to commit transaction", "error", err) - } else { - db.logger.Debug("Flushed batch", "count", len(batch)) + return } + db.logger.Debug("Flushed batch", "count", len(batch)) batch = batch[:0] } diff --git a/internal/storage/writer_test.go b/internal/storage/writer_test.go index 659216b..256d657 100644 --- a/internal/storage/writer_test.go +++ b/internal/storage/writer_test.go @@ -1,40 +1,319 @@ package storage import ( + "database/sql" + "github.com/d3m0k1d/BanForge/internal/logger" + _ "modernc.org/sqlite" + "path/filepath" "testing" "time" ) -func TestWrite(t *testing.T) { - var ip string - d := createTestDBStruct(t) +func TestWrite_BatchInsert(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "requests_test.db") - err := d.CreateTable() + writer, err := NewRequestWriterWithDBPath(dbPath) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to create RequestWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) } - resultCh := make(chan *LogEntry, 100) // ← Добавь буфер + resultCh := make(chan *LogEntry, 100) - go Write(d, resultCh) + done := make(chan bool) + go func() { + Write(writer, resultCh) + close(done) + }() - resultCh <- &LogEntry{ - Service: "test", - IP: "127.0.0.1", - Path: "/test", - Method: "GET", - Status: "200", + entries := []*LogEntry{ + {Service: "service1", IP: "192.168.1.1", Path: "/path1", Method: "GET", Status: "200"}, + {Service: "service2", IP: "192.168.1.2", Path: "/path2", Method: "POST", Status: "404"}, + {Service: "service3", IP: "192.168.1.3", Path: "/path3", Method: "PUT", Status: "500"}, + {Service: "service4", IP: "192.168.1.4", Path: "/path4", Method: "DELETE", Status: "200"}, + {Service: "service5", IP: "192.168.1.5", Path: "/path5", Method: "GET", Status: "301"}, + } + + for _, entry := range entries { + resultCh <- entry + } + + close(resultCh) + <-done + + count, err := writer.GetRequestCount() + if err != nil { + t.Fatalf("Failed to get request count: %v", err) + } + + if count != len(entries) { + t.Errorf("Expected %d entries, got %d", len(entries), count) + } + rows, err := writer.db.Query("SELECT service, ip, path, method, status FROM requests ORDER BY id") + if err != nil { + t.Fatalf("Failed to query requests: %v", err) + } + defer rows.Close() + + i := 0 + for rows.Next() { + var service, ip, path, method, status string + err := rows.Scan(&service, &ip, &path, &method, &status) + if err != nil { + t.Fatalf("Failed to scan row: %v", err) + } + + if i >= len(entries) { + t.Fatal("More rows returned than expected") + } + + expected := entries[i] + if service != expected.Service { + t.Errorf("Expected service %s, got %s", expected.Service, service) + } + if ip != expected.IP { + t.Errorf("Expected IP %s, got %s", expected.IP, ip) + } + if path != expected.Path { + t.Errorf("Expected path %s, got %s", expected.Path, path) + } + if method != expected.Method { + t.Errorf("Expected method %s, got %s", expected.Method, method) + } + if status != expected.Status { + t.Errorf("Expected status %s, got %s", expected.Status, status) + } + + i++ + } + + if i != len(entries) { + t.Errorf("Expected to read %d entries, got %d", len(entries), i) + } +} + +func TestWrite_BatchSizeTrigger(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "requests_test.db") + + writer, err := NewRequestWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create RequestWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + resultCh := make(chan *LogEntry, 100) + done := make(chan bool) + go func() { + Write(writer, resultCh) + close(done) + }() + + batchSize := 100 + entries := make([]*LogEntry, batchSize) + for i := 0; i < batchSize; i++ { + entries[i] = &LogEntry{ + Service: "service" + string(rune(i+'0')), + IP: "192.168.1." + string(rune(i+'0')), + Path: "/path" + string(rune(i+'0')), + Method: "GET", + Status: "200", + } + } + + for _, entry := range entries { + resultCh <- entry + } + + close(resultCh) + <-done + + count, err := writer.GetRequestCount() + if err != nil { + t.Fatalf("Failed to get request count: %v", err) + } + + if count != batchSize { + t.Errorf("Expected %d entries, got %d", batchSize, count) + } +} + +func TestWrite_FlushInterval(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "requests_test.db") + + writer, err := NewRequestWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create RequestWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + resultCh := make(chan *LogEntry, 100) + + done := make(chan bool) + go func() { + Write(writer, resultCh) + close(done) + }() + + entries := []*LogEntry{ + {Service: "service1", IP: "192.168.1.1", Path: "/path1", Method: "GET", Status: "200"}, + {Service: "service2", IP: "192.168.1.2", Path: "/path2", Method: "POST", Status: "404"}, + {Service: "service3", IP: "192.168.1.3", Path: "/path3", Method: "PUT", Status: "500"}, + {Service: "service4", IP: "192.168.1.4", Path: "/path4", Method: "DELETE", Status: "200"}, + {Service: "service5", IP: "192.168.1.5", Path: "/path5", Method: "GET", Status: "301"}, + } + + for _, entry := range entries { + resultCh <- entry + } + time.Sleep(1500 * time.Millisecond) + + close(resultCh) + <-done + + count, err := writer.GetRequestCount() + if err != nil { + t.Fatalf("Failed to get request count: %v", err) + } + + if count != len(entries) { + t.Errorf("Expected %d entries, got %d", len(entries), count) + } +} + +func TestWrite_EmptyBatch(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "requests_test.db") + + writer, err := NewRequestWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create RequestWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + resultCh := make(chan *LogEntry, 100) + + done := make(chan bool) + go func() { + Write(writer, resultCh) + close(done) + }() + + close(resultCh) + <-done + count, err := writer.GetRequestCount() + if err != nil { + t.Fatalf("Failed to get request count: %v", err) + } + + if count != 0 { + t.Errorf("Expected 0 entries for empty batch, got %d", count) + } +} + +func TestWrite_ChannelClosed(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "requests_test.db") + + writer, err := NewRequestWriterWithDBPath(dbPath) + if err != nil { + t.Fatalf("Failed to create RequestWriter: %v", err) + } + defer writer.Close() + + err = writer.CreateTable() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + resultCh := make(chan *LogEntry, 100) + + done := make(chan bool) + go func() { + Write(writer, resultCh) + close(done) + }() + + entries := []*LogEntry{ + {Service: "service1", IP: "192.168.1.1", Path: "/path1", Method: "GET", Status: "200"}, + {Service: "service2", IP: "192.168.1.2", Path: "/path2", Method: "POST", Status: "404"}, + } + + for _, entry := range entries { + resultCh <- entry } close(resultCh) - time.Sleep(2 * time.Second) + <-done - err = d.db.QueryRow("SELECT ip FROM requests LIMIT 1").Scan(&ip) + count, err := writer.GetRequestCount() if err != nil { - t.Fatal(err) + t.Fatalf("Failed to get request count: %v", err) } - if ip != "127.0.0.1" { - t.Fatal("ip should be 127.0.0.1") + + if count != len(entries) { + t.Errorf("Expected %d entries, got %d", len(entries), count) } } + +func NewRequestWriterWithDBPath(dbPath string) (*Request_Writer, error) { + db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)") + if err != nil { + return nil, err + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) + return &Request_Writer{ + logger: logger.New(false), + db: db, + }, nil +} + +func (w *Request_Writer) CreateTable() error { + _, err := w.db.Exec(CreateRequestsTable) + if err != nil { + return err + } + w.logger.Info("Created requests table") + return nil +} + +func (w *Request_Writer) Close() error { + w.logger.Info("Closing request database connection") + err := w.db.Close() + if err != nil { + return err + } + return nil +} + +func (w *Request_Writer) GetRequestCount() (int, error) { + var count int + err := w.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&count) + if err != nil { + return 0, err + } + return count, nil +}