From 791d64ae4d086ac30e854abc3b114b1e0d1d7d15 Mon Sep 17 00:00:00 2001 From: d3m0k1d Date: Thu, 22 Jan 2026 00:09:56 +0300 Subject: [PATCH] feat: Recode logic for add logs to db --- cmd/banforge/command/daemon.go | 20 ++---- internal/judge/judge.go | 126 +++++++++++++++------------------ internal/judge/judge_test.go | 6 +- internal/parser/NginxParser.go | 11 ++- internal/parser/sshd.go | 11 ++- internal/storage/db.go | 30 ++------ internal/storage/db_test.go | 93 ------------------------ internal/storage/migrations.go | 1 - internal/storage/models.go | 1 - 9 files changed, 81 insertions(+), 218 deletions(-) diff --git a/cmd/banforge/command/daemon.go b/cmd/banforge/command/daemon.go index 5b745f9..fa1390d 100644 --- a/cmd/banforge/command/daemon.go +++ b/cmd/banforge/command/daemon.go @@ -5,7 +5,6 @@ import ( "os" "os/signal" "syscall" - "time" "github.com/d3m0k1d/BanForge/internal/blocker" "github.com/d3m0k1d/BanForge/internal/config" @@ -20,6 +19,8 @@ var DaemonCmd = &cobra.Command{ Use: "daemon", Short: "Run BanForge daemon process", Run: func(cmd *cobra.Command, args []string) { + entryCh := make(chan *storage.LogEntry, 1000) + resultCh := make(chan *storage.LogEntry, 100) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) defer stop() log := logger.New(false) @@ -48,19 +49,10 @@ var DaemonCmd = &cobra.Command{ log.Error("Failed to load rules", "error", err) os.Exit(1) } - j := judge.New(db, b) + j := judge.New(db, b, resultCh, entryCh) j.LoadRules(r) go j.UnbanChecker() - go func() { - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - for range ticker.C { - if err := j.ProcessUnviewed(); err != nil { - log.Error("Failed to process unviewed", "error", err) - } - } - }() - resultCh := make(chan *storage.LogEntry, 1000) + go j.Tribunal() go storage.Write(db, resultCh) var scanners []*parser.Scanner @@ -99,12 +91,12 @@ var DaemonCmd = &cobra.Command{ if svc.Name == "nginx" { log.Info("Starting nginx parser", "service", serviceName) ng := parser.NewNginxParser() - ng.Parse(p.Events(), resultCh) + ng.Parse(p.Events(), entryCh) } if svc.Name == "ssh" { log.Info("Starting ssh parser", "service", serviceName) ssh := parser.NewSshdParser() - ssh.Parse(p.Events(), resultCh) + ssh.Parse(p.Events(), entryCh) } }(pars, svc.Name) continue diff --git a/internal/judge/judge.go b/internal/judge/judge.go index 20cb897..d64a8da 100644 --- a/internal/judge/judge.go +++ b/internal/judge/judge.go @@ -16,14 +16,18 @@ type Judge struct { logger *logger.Logger Blocker blocker.BlockerEngine rulesByService map[string][]config.Rule + entryCh chan *storage.LogEntry + resultCh chan *storage.LogEntry } -func New(db *storage.DB, b blocker.BlockerEngine) *Judge { +func New(db *storage.DB, b blocker.BlockerEngine, resultCh chan *storage.LogEntry, entryCh chan *storage.LogEntry) *Judge { return &Judge{ db: db, logger: logger.New(false), rulesByService: make(map[string][]config.Rule), Blocker: b, + entryCh: entryCh, + resultCh: resultCh, } } @@ -38,84 +42,70 @@ func (j *Judge) LoadRules(rules []config.Rule) { j.logger.Info("Rules loaded and indexed by service") } -func (j *Judge) ProcessUnviewed() error { - rows, err := j.db.SearchUnViewed() - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to query database: %v", err)) - return err - } - j.logger.Info("Unviewed logs found") - defer func() { - err = rows.Close() - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to close database connection: %v", err)) - } - }() - for rows.Next() { - var entry storage.LogEntry - err = rows.Scan( - &entry.ID, - &entry.Service, - &entry.IP, - &entry.Path, - &entry.Status, - &entry.Method, - &entry.IsViewed, - &entry.CreatedAt, - ) - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to scan database row: %v", err)) +func (j *Judge) Tribunal() { + j.logger.Info("Tribunal started") + + for entry := range j.entryCh { + j.logger.Debug("Processing entry", "ip", entry.IP, "service", entry.Service, "status", entry.Status) + + rules, serviceExists := j.rulesByService[entry.Service] + if !serviceExists { + j.logger.Debug("No rules for service", "service", entry.Service) continue } - rules, serviceExists := j.rulesByService[entry.Service] - if serviceExists { - for _, rule := range rules { - if (rule.Method == "" || entry.Method == rule.Method) && - (rule.Status == "" || entry.Status == rule.Status) && - matchPath(entry.Path, rule.Path) { + ruleMatched := false + for _, rule := range rules { + methodMatch := rule.Method == "" || entry.Method == rule.Method + statusMatch := rule.Status == "" || entry.Status == rule.Status + pathMatch := matchPath(entry.Path, rule.Path) - j.logger.Info( - fmt.Sprintf( - "Rule matched for IP: %s, Service: %s", - entry.IP, - entry.Service, - ), - ) - ban_status, err := j.db.IsBanned(entry.IP) - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to check ban status: %v", err)) - return err - } - if !ban_status { - err = j.Blocker.Ban(entry.IP) - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err)) - } - j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP)) - err = j.db.AddBan(entry.IP, rule.BanTime) - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err)) - } - } + j.logger.Debug( + "Testing rule", + "rule", rule.Name, + "method_match", methodMatch, + "status_match", statusMatch, + "path_match", pathMatch, + ) + + if methodMatch && statusMatch && pathMatch { + ruleMatched = true + j.logger.Info("Rule matched", "rule", rule.Name, "ip", entry.IP) + + banned, err := j.db.IsBanned(entry.IP) + if err != nil { + j.logger.Error("Failed to check ban status", "ip", entry.IP, "error", err) break } + + if banned { + j.logger.Info("IP already banned", "ip", entry.IP) + j.resultCh <- entry + break + } + + err = j.db.AddBan(entry.IP, rule.BanTime) + if err != nil { + j.logger.Error("Failed to add ban to database", "ip", entry.IP, "ban_time", rule.BanTime, "error", err) + break + } + + if err := j.Blocker.Ban(entry.IP); err != nil { + j.logger.Error("Failed to ban IP at firewall", "ip", entry.IP, "error", err) + break + } + j.logger.Info("IP banned successfully", "ip", entry.IP, "rule", rule.Name, "ban_time", rule.BanTime) + j.resultCh <- entry + break } } - err = j.db.MarkAsViewed(entry.ID) - if err != nil { - j.logger.Error(fmt.Sprintf("Failed to mark entry as viewed: %v", err)) - } else { - j.logger.Info(fmt.Sprintf("Entry marked as viewed: ID=%d", entry.ID)) + + if !ruleMatched { + j.logger.Debug("No rules matched", "ip", entry.IP, "service", entry.Service) } } - if err = rows.Err(); err != nil { - j.logger.Error(fmt.Sprintf("Error iterating rows: %v", err)) - return err - } - - return nil + j.logger.Info("Tribunal stopped - entryCh closed") } func (j *Judge) UnbanChecker() { diff --git a/internal/judge/judge_test.go b/internal/judge/judge_test.go index 72340bc..3d66605 100644 --- a/internal/judge/judge_test.go +++ b/internal/judge/judge_test.go @@ -18,21 +18,21 @@ func TestJudgeLogic(t *testing.T) { { name: "Empty rule", inputRule: config.Rule{Name: "", ServiceName: "", Path: "", Status: "", Method: ""}, - inputLog: storage.LogEntry{ID: 0, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, + inputLog: storage.LogEntry{ID: 0, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""}, wantErr: true, wantMatch: false, }, { name: "Matching rule", inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "200", Method: "GET"}, - inputLog: storage.LogEntry{ID: 1, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, + inputLog: storage.LogEntry{ID: 1, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""}, wantErr: false, wantMatch: true, }, { name: "Non-matching status", inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "404", Method: "GET"}, - inputLog: storage.LogEntry{ID: 2, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, + inputLog: storage.LogEntry{ID: 2, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""}, wantErr: false, wantMatch: false, }, diff --git a/internal/parser/NginxParser.go b/internal/parser/NginxParser.go index 3e15b19..7e94144 100644 --- a/internal/parser/NginxParser.go +++ b/internal/parser/NginxParser.go @@ -34,12 +34,11 @@ func (p *NginxParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEn method := matches[3] resultCh <- &storage.LogEntry{ - Service: "nginx", - IP: matches[1], - Path: path, - Status: status, - Method: method, - IsViewed: false, + Service: "nginx", + IP: matches[1], + Path: path, + Status: status, + Method: method, } p.logger.Info( "Parsed nginx log entry", diff --git a/internal/parser/sshd.go b/internal/parser/sshd.go index 87e4a4c..6a9b0f4 100644 --- a/internal/parser/sshd.go +++ b/internal/parser/sshd.go @@ -31,12 +31,11 @@ func (p *SshdParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEnt continue } resultCh <- &storage.LogEntry{ - Service: "ssh", - IP: matches[6], - Path: matches[5], // user - Status: "Failed", - Method: matches[4], // method auth - IsViewed: false, + Service: "ssh", + IP: matches[6], + Path: matches[5], // user + Status: "Failed", + Method: matches[4], // method auth } p.logger.Info( "Parsed ssh log entry", diff --git a/internal/storage/db.go b/internal/storage/db.go index 5de2793..054d9a2 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -2,15 +2,13 @@ package storage import ( "database/sql" - "os" - "fmt" - "time" - "github.com/d3m0k1d/BanForge/internal/config" "github.com/d3m0k1d/BanForge/internal/logger" "github.com/jedib0t/go-pretty/v6/table" _ "modernc.org/sqlite" + "os" + "time" ) type DB struct { @@ -23,8 +21,8 @@ func NewDB() (*DB, error) { "sqlite", "/var/lib/banforge/storage.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)", ) - db.SetMaxOpenConns(4) - db.SetMaxIdleConns(2) + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) db.SetConnMaxLifetime(0) if err != nil { return nil, err @@ -57,26 +55,6 @@ func (d *DB) CreateTable() error { return nil } -func (d *DB) SearchUnViewed() (*sql.Rows, error) { - rows, err := d.db.Query( - "SELECT id, service, ip, path, status, method, viewed, created_at FROM requests WHERE viewed = 0", - ) - if err != nil { - d.logger.Error("Failed to query database") - return nil, err - } - return rows, nil -} - -func (d *DB) MarkAsViewed(id int) error { - _, err := d.db.Exec("UPDATE requests SET viewed = 1 WHERE id = ?", id) - if err != nil { - d.logger.Error("Failed to mark as viewed", "error", err) - return err - } - return nil -} - func (d *DB) IsBanned(ip string) (bool, error) { var bannedIP string err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP) diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index bc3d613..c426e0b 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -74,99 +74,6 @@ func TestCreateTable(t *testing.T) { rows.Close() } -func TestMarkAsViewed(t *testing.T) { - d := createTestDBStruct(t) - - err := d.CreateTable() - if err != nil { - t.Fatal(err) - } - - _, err = d.db.Exec( - "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", - "test", - "127.0.0.1", - "/test", - "GET", - "200", - time.Now().Format(time.RFC3339), - ) - if err != nil { - t.Fatal(err) - } - - err = d.MarkAsViewed(1) - if err != nil { - t.Fatal(err) - } - - var isViewed bool - err = d.db.QueryRow("SELECT viewed FROM requests WHERE id = 1").Scan(&isViewed) - if err != nil { - t.Fatal(err) - } - if !isViewed { - t.Fatal("viewed should be true") - } -} - -func TestSearchUnViewed(t *testing.T) { - d := createTestDBStruct(t) - - err := d.CreateTable() - if err != nil { - t.Fatal(err) - } - - for i := 0; i < 2; i++ { - _, err := d.db.Exec( - "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", - "test", - "127.0.0.1", - "/test", - "GET", - "200", - time.Now().Format(time.RFC3339), - ) - if err != nil { - t.Fatal(err) - } - } - - rows, err := d.SearchUnViewed() - if err != nil { - t.Fatal(err) - } - defer rows.Close() - - count := 0 - for rows.Next() { - var id int - var service, ip, path, status, method string - var viewed bool - var createdAt string - - err := rows.Scan(&id, &service, &ip, &path, &status, &method, &viewed, &createdAt) - if err != nil { - t.Fatal(err) - } - - if viewed { - t.Fatal("should be unviewed") - } - - count++ - } - - if err := rows.Err(); err != nil { - t.Fatal(err) - } - - if count != 2 { - t.Fatalf("expected 2 unviewed requests, got %d", count) - } -} - func TestIsBanned(t *testing.T) { d := createTestDBStruct(t) diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go index 35ad85a..91ca66a 100644 --- a/internal/storage/migrations.go +++ b/internal/storage/migrations.go @@ -9,7 +9,6 @@ CREATE TABLE IF NOT EXISTS requests ( path TEXT, method TEXT, status TEXT, - viewed BOOLEAN DEFAULT FALSE, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); diff --git a/internal/storage/models.go b/internal/storage/models.go index 013a106..dcf3da7 100644 --- a/internal/storage/models.go +++ b/internal/storage/models.go @@ -7,7 +7,6 @@ type LogEntry struct { Path string `db:"path"` Status string `db:"status"` Method string `db:"method"` - IsViewed bool `db:"viewed"` CreatedAt string `db:"created_at"` }