This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/d3m0k1d/BanForge/internal/blocker"
|
"github.com/d3m0k1d/BanForge/internal/blocker"
|
||||||
"github.com/d3m0k1d/BanForge/internal/config"
|
"github.com/d3m0k1d/BanForge/internal/config"
|
||||||
@@ -20,6 +19,8 @@ var DaemonCmd = &cobra.Command{
|
|||||||
Use: "daemon",
|
Use: "daemon",
|
||||||
Short: "Run BanForge daemon process",
|
Short: "Run BanForge daemon process",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
entryCh := make(chan *storage.LogEntry, 1000)
|
||||||
|
resultCh := make(chan *storage.LogEntry, 100)
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||||
defer stop()
|
defer stop()
|
||||||
log := logger.New(false)
|
log := logger.New(false)
|
||||||
@@ -48,19 +49,10 @@ var DaemonCmd = &cobra.Command{
|
|||||||
log.Error("Failed to load rules", "error", err)
|
log.Error("Failed to load rules", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
j := judge.New(db, b)
|
j := judge.New(db, b, resultCh, entryCh)
|
||||||
j.LoadRules(r)
|
j.LoadRules(r)
|
||||||
go j.UnbanChecker()
|
go j.UnbanChecker()
|
||||||
go func() {
|
go j.Tribunal()
|
||||||
ticker := time.NewTicker(5 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for range ticker.C {
|
|
||||||
if err := j.ProcessUnviewed(); err != nil {
|
|
||||||
log.Error("Failed to process unviewed", "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
resultCh := make(chan *storage.LogEntry, 1000)
|
|
||||||
go storage.Write(db, resultCh)
|
go storage.Write(db, resultCh)
|
||||||
var scanners []*parser.Scanner
|
var scanners []*parser.Scanner
|
||||||
|
|
||||||
@@ -99,12 +91,12 @@ var DaemonCmd = &cobra.Command{
|
|||||||
if svc.Name == "nginx" {
|
if svc.Name == "nginx" {
|
||||||
log.Info("Starting nginx parser", "service", serviceName)
|
log.Info("Starting nginx parser", "service", serviceName)
|
||||||
ng := parser.NewNginxParser()
|
ng := parser.NewNginxParser()
|
||||||
ng.Parse(p.Events(), resultCh)
|
ng.Parse(p.Events(), entryCh)
|
||||||
}
|
}
|
||||||
if svc.Name == "ssh" {
|
if svc.Name == "ssh" {
|
||||||
log.Info("Starting ssh parser", "service", serviceName)
|
log.Info("Starting ssh parser", "service", serviceName)
|
||||||
ssh := parser.NewSshdParser()
|
ssh := parser.NewSshdParser()
|
||||||
ssh.Parse(p.Events(), resultCh)
|
ssh.Parse(p.Events(), entryCh)
|
||||||
}
|
}
|
||||||
}(pars, svc.Name)
|
}(pars, svc.Name)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -16,14 +16,18 @@ type Judge struct {
|
|||||||
logger *logger.Logger
|
logger *logger.Logger
|
||||||
Blocker blocker.BlockerEngine
|
Blocker blocker.BlockerEngine
|
||||||
rulesByService map[string][]config.Rule
|
rulesByService map[string][]config.Rule
|
||||||
|
entryCh chan *storage.LogEntry
|
||||||
|
resultCh chan *storage.LogEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(db *storage.DB, b blocker.BlockerEngine) *Judge {
|
func New(db *storage.DB, b blocker.BlockerEngine, resultCh chan *storage.LogEntry, entryCh chan *storage.LogEntry) *Judge {
|
||||||
return &Judge{
|
return &Judge{
|
||||||
db: db,
|
db: db,
|
||||||
logger: logger.New(false),
|
logger: logger.New(false),
|
||||||
rulesByService: make(map[string][]config.Rule),
|
rulesByService: make(map[string][]config.Rule),
|
||||||
Blocker: b,
|
Blocker: b,
|
||||||
|
entryCh: entryCh,
|
||||||
|
resultCh: resultCh,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,84 +42,70 @@ func (j *Judge) LoadRules(rules []config.Rule) {
|
|||||||
j.logger.Info("Rules loaded and indexed by service")
|
j.logger.Info("Rules loaded and indexed by service")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *Judge) ProcessUnviewed() error {
|
func (j *Judge) Tribunal() {
|
||||||
rows, err := j.db.SearchUnViewed()
|
j.logger.Info("Tribunal started")
|
||||||
if err != nil {
|
|
||||||
j.logger.Error(fmt.Sprintf("Failed to query database: %v", err))
|
for entry := range j.entryCh {
|
||||||
return err
|
j.logger.Debug("Processing entry", "ip", entry.IP, "service", entry.Service, "status", entry.Status)
|
||||||
}
|
|
||||||
j.logger.Info("Unviewed logs found")
|
rules, serviceExists := j.rulesByService[entry.Service]
|
||||||
defer func() {
|
if !serviceExists {
|
||||||
err = rows.Close()
|
j.logger.Debug("No rules for service", "service", entry.Service)
|
||||||
if err != nil {
|
|
||||||
j.logger.Error(fmt.Sprintf("Failed to close database connection: %v", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for rows.Next() {
|
|
||||||
var entry storage.LogEntry
|
|
||||||
err = rows.Scan(
|
|
||||||
&entry.ID,
|
|
||||||
&entry.Service,
|
|
||||||
&entry.IP,
|
|
||||||
&entry.Path,
|
|
||||||
&entry.Status,
|
|
||||||
&entry.Method,
|
|
||||||
&entry.IsViewed,
|
|
||||||
&entry.CreatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
j.logger.Error(fmt.Sprintf("Failed to scan database row: %v", err))
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, serviceExists := j.rulesByService[entry.Service]
|
ruleMatched := false
|
||||||
if serviceExists {
|
for _, rule := range rules {
|
||||||
for _, rule := range rules {
|
methodMatch := rule.Method == "" || entry.Method == rule.Method
|
||||||
if (rule.Method == "" || entry.Method == rule.Method) &&
|
statusMatch := rule.Status == "" || entry.Status == rule.Status
|
||||||
(rule.Status == "" || entry.Status == rule.Status) &&
|
pathMatch := matchPath(entry.Path, rule.Path)
|
||||||
matchPath(entry.Path, rule.Path) {
|
|
||||||
|
|
||||||
j.logger.Info(
|
j.logger.Debug(
|
||||||
fmt.Sprintf(
|
"Testing rule",
|
||||||
"Rule matched for IP: %s, Service: %s",
|
"rule", rule.Name,
|
||||||
entry.IP,
|
"method_match", methodMatch,
|
||||||
entry.Service,
|
"status_match", statusMatch,
|
||||||
),
|
"path_match", pathMatch,
|
||||||
)
|
)
|
||||||
ban_status, err := j.db.IsBanned(entry.IP)
|
|
||||||
if err != nil {
|
if methodMatch && statusMatch && pathMatch {
|
||||||
j.logger.Error(fmt.Sprintf("Failed to check ban status: %v", err))
|
ruleMatched = true
|
||||||
return err
|
j.logger.Info("Rule matched", "rule", rule.Name, "ip", entry.IP)
|
||||||
}
|
|
||||||
if !ban_status {
|
banned, err := j.db.IsBanned(entry.IP)
|
||||||
err = j.Blocker.Ban(entry.IP)
|
if err != nil {
|
||||||
if err != nil {
|
j.logger.Error("Failed to check ban status", "ip", entry.IP, "error", err)
|
||||||
j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err))
|
|
||||||
}
|
|
||||||
j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP))
|
|
||||||
err = j.db.AddBan(entry.IP, rule.BanTime)
|
|
||||||
if err != nil {
|
|
||||||
j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if banned {
|
||||||
|
j.logger.Info("IP already banned", "ip", entry.IP)
|
||||||
|
j.resultCh <- entry
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
err = j.db.AddBan(entry.IP, rule.BanTime)
|
||||||
|
if err != nil {
|
||||||
|
j.logger.Error("Failed to add ban to database", "ip", entry.IP, "ban_time", rule.BanTime, "error", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := j.Blocker.Ban(entry.IP); err != nil {
|
||||||
|
j.logger.Error("Failed to ban IP at firewall", "ip", entry.IP, "error", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
j.logger.Info("IP banned successfully", "ip", entry.IP, "rule", rule.Name, "ban_time", rule.BanTime)
|
||||||
|
j.resultCh <- entry
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = j.db.MarkAsViewed(entry.ID)
|
|
||||||
if err != nil {
|
if !ruleMatched {
|
||||||
j.logger.Error(fmt.Sprintf("Failed to mark entry as viewed: %v", err))
|
j.logger.Debug("No rules matched", "ip", entry.IP, "service", entry.Service)
|
||||||
} else {
|
|
||||||
j.logger.Info(fmt.Sprintf("Entry marked as viewed: ID=%d", entry.ID))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = rows.Err(); err != nil {
|
j.logger.Info("Tribunal stopped - entryCh closed")
|
||||||
j.logger.Error(fmt.Sprintf("Error iterating rows: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *Judge) UnbanChecker() {
|
func (j *Judge) UnbanChecker() {
|
||||||
|
|||||||
@@ -18,21 +18,21 @@ func TestJudgeLogic(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Empty rule",
|
name: "Empty rule",
|
||||||
inputRule: config.Rule{Name: "", ServiceName: "", Path: "", Status: "", Method: ""},
|
inputRule: config.Rule{Name: "", ServiceName: "", Path: "", Status: "", Method: ""},
|
||||||
inputLog: storage.LogEntry{ID: 0, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""},
|
inputLog: storage.LogEntry{ID: 0, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
wantMatch: false,
|
wantMatch: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Matching rule",
|
name: "Matching rule",
|
||||||
inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "200", Method: "GET"},
|
inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "200", Method: "GET"},
|
||||||
inputLog: storage.LogEntry{ID: 1, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""},
|
inputLog: storage.LogEntry{ID: 1, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
wantMatch: true,
|
wantMatch: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Non-matching status",
|
name: "Non-matching status",
|
||||||
inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "404", Method: "GET"},
|
inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "404", Method: "GET"},
|
||||||
inputLog: storage.LogEntry{ID: 2, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""},
|
inputLog: storage.LogEntry{ID: 2, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
wantMatch: false,
|
wantMatch: false,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -34,12 +34,11 @@ func (p *NginxParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEn
|
|||||||
method := matches[3]
|
method := matches[3]
|
||||||
|
|
||||||
resultCh <- &storage.LogEntry{
|
resultCh <- &storage.LogEntry{
|
||||||
Service: "nginx",
|
Service: "nginx",
|
||||||
IP: matches[1],
|
IP: matches[1],
|
||||||
Path: path,
|
Path: path,
|
||||||
Status: status,
|
Status: status,
|
||||||
Method: method,
|
Method: method,
|
||||||
IsViewed: false,
|
|
||||||
}
|
}
|
||||||
p.logger.Info(
|
p.logger.Info(
|
||||||
"Parsed nginx log entry",
|
"Parsed nginx log entry",
|
||||||
|
|||||||
@@ -31,12 +31,11 @@ func (p *SshdParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEnt
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
resultCh <- &storage.LogEntry{
|
resultCh <- &storage.LogEntry{
|
||||||
Service: "ssh",
|
Service: "ssh",
|
||||||
IP: matches[6],
|
IP: matches[6],
|
||||||
Path: matches[5], // user
|
Path: matches[5], // user
|
||||||
Status: "Failed",
|
Status: "Failed",
|
||||||
Method: matches[4], // method auth
|
Method: matches[4], // method auth
|
||||||
IsViewed: false,
|
|
||||||
}
|
}
|
||||||
p.logger.Info(
|
p.logger.Info(
|
||||||
"Parsed ssh log entry",
|
"Parsed ssh log entry",
|
||||||
|
|||||||
@@ -2,15 +2,13 @@ package storage
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"os"
|
|
||||||
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/d3m0k1d/BanForge/internal/config"
|
"github.com/d3m0k1d/BanForge/internal/config"
|
||||||
"github.com/d3m0k1d/BanForge/internal/logger"
|
"github.com/d3m0k1d/BanForge/internal/logger"
|
||||||
"github.com/jedib0t/go-pretty/v6/table"
|
"github.com/jedib0t/go-pretty/v6/table"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
@@ -23,8 +21,8 @@ func NewDB() (*DB, error) {
|
|||||||
"sqlite",
|
"sqlite",
|
||||||
"/var/lib/banforge/storage.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)",
|
"/var/lib/banforge/storage.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)",
|
||||||
)
|
)
|
||||||
db.SetMaxOpenConns(4)
|
db.SetMaxOpenConns(1)
|
||||||
db.SetMaxIdleConns(2)
|
db.SetMaxIdleConns(1)
|
||||||
db.SetConnMaxLifetime(0)
|
db.SetConnMaxLifetime(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -57,26 +55,6 @@ func (d *DB) CreateTable() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DB) SearchUnViewed() (*sql.Rows, error) {
|
|
||||||
rows, err := d.db.Query(
|
|
||||||
"SELECT id, service, ip, path, status, method, viewed, created_at FROM requests WHERE viewed = 0",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
d.logger.Error("Failed to query database")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DB) MarkAsViewed(id int) error {
|
|
||||||
_, err := d.db.Exec("UPDATE requests SET viewed = 1 WHERE id = ?", id)
|
|
||||||
if err != nil {
|
|
||||||
d.logger.Error("Failed to mark as viewed", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DB) IsBanned(ip string) (bool, error) {
|
func (d *DB) IsBanned(ip string) (bool, error) {
|
||||||
var bannedIP string
|
var bannedIP string
|
||||||
err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP)
|
err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP)
|
||||||
|
|||||||
@@ -74,99 +74,6 @@ func TestCreateTable(t *testing.T) {
|
|||||||
rows.Close()
|
rows.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarkAsViewed(t *testing.T) {
|
|
||||||
d := createTestDBStruct(t)
|
|
||||||
|
|
||||||
err := d.CreateTable()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = d.db.Exec(
|
|
||||||
"INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
"test",
|
|
||||||
"127.0.0.1",
|
|
||||||
"/test",
|
|
||||||
"GET",
|
|
||||||
"200",
|
|
||||||
time.Now().Format(time.RFC3339),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = d.MarkAsViewed(1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var isViewed bool
|
|
||||||
err = d.db.QueryRow("SELECT viewed FROM requests WHERE id = 1").Scan(&isViewed)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if !isViewed {
|
|
||||||
t.Fatal("viewed should be true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSearchUnViewed(t *testing.T) {
|
|
||||||
d := createTestDBStruct(t)
|
|
||||||
|
|
||||||
err := d.CreateTable()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
_, err := d.db.Exec(
|
|
||||||
"INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
"test",
|
|
||||||
"127.0.0.1",
|
|
||||||
"/test",
|
|
||||||
"GET",
|
|
||||||
"200",
|
|
||||||
time.Now().Format(time.RFC3339),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := d.SearchUnViewed()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
for rows.Next() {
|
|
||||||
var id int
|
|
||||||
var service, ip, path, status, method string
|
|
||||||
var viewed bool
|
|
||||||
var createdAt string
|
|
||||||
|
|
||||||
err := rows.Scan(&id, &service, &ip, &path, &status, &method, &viewed, &createdAt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if viewed {
|
|
||||||
t.Fatal("should be unviewed")
|
|
||||||
}
|
|
||||||
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if count != 2 {
|
|
||||||
t.Fatalf("expected 2 unviewed requests, got %d", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsBanned(t *testing.T) {
|
func TestIsBanned(t *testing.T) {
|
||||||
d := createTestDBStruct(t)
|
d := createTestDBStruct(t)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ CREATE TABLE IF NOT EXISTS requests (
|
|||||||
path TEXT,
|
path TEXT,
|
||||||
method TEXT,
|
method TEXT,
|
||||||
status TEXT,
|
status TEXT,
|
||||||
viewed BOOLEAN DEFAULT FALSE,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ type LogEntry struct {
|
|||||||
Path string `db:"path"`
|
Path string `db:"path"`
|
||||||
Status string `db:"status"`
|
Status string `db:"status"`
|
||||||
Method string `db:"method"`
|
Method string `db:"method"`
|
||||||
IsViewed bool `db:"viewed"`
|
|
||||||
CreatedAt string `db:"created_at"`
|
CreatedAt string `db:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user