diff --git a/cmd/banforge/command/daemon.go b/cmd/banforge/command/daemon.go index 6a27362..1a17193 100644 --- a/cmd/banforge/command/daemon.go +++ b/cmd/banforge/command/daemon.go @@ -50,6 +50,7 @@ var DaemonCmd = &cobra.Command{ } j := judge.New(db, b) j.LoadRules(r) + go j.UnbanChecker() go func() { ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() diff --git a/cmd/banforge/command/rule.go b/cmd/banforge/command/rule.go index 94d6dc5..ac200aa 100644 --- a/cmd/banforge/command/rule.go +++ b/cmd/banforge/command/rule.go @@ -14,6 +14,7 @@ var ( path string status string method string + ttl string ) var RuleCmd = &cobra.Command{ @@ -37,8 +38,10 @@ var AddCmd = &cobra.Command{ fmt.Printf("At least 1 rule field must be filled in.") os.Exit(1) } - - err := config.NewRule(name, service, path, status, method) + if ttl == "" { + ttl = "1y" + } + err := config.NewRule(name, service, path, status, method, ttl) if err != nil { fmt.Println(err) os.Exit(1) @@ -68,6 +71,7 @@ func RuleRegister() { AddCmd.Flags().StringVarP(&name, "name", "n", "", "rule name (required)") AddCmd.Flags().StringVarP(&service, "service", "s", "", "service name") AddCmd.Flags().StringVarP(&path, "path", "p", "", "request path") - AddCmd.Flags().StringVarP(&status, "status", "c", "", "HTTP status code") - AddCmd.Flags().StringVarP(&method, "method", "m", "", "HTTP method") + AddCmd.Flags().StringVarP(&status, "status", "c", "", "status code") + AddCmd.Flags().StringVarP(&method, "method", "m", "", "method") + AddCmd.Flags().StringVarP(&ttl, "ttl", "t", "", "ban time") } diff --git a/internal/config/appconf.go b/internal/config/appconf.go index 68fd56a..f7a6a1b 100644 --- a/internal/config/appconf.go +++ b/internal/config/appconf.go @@ -3,6 +3,9 @@ package config import ( "fmt" "os" + "strconv" + "strings" + "time" "github.com/BurntSushi/toml" "github.com/d3m0k1d/BanForge/internal/logger" @@ -22,7 +25,7 @@ func LoadRuleConfig() ([]Rule, error) { return cfg.Rules, nil } -func NewRule(Name string, ServiceName string, Path string, Status string, Method string) error { +func NewRule(Name string, ServiceName string, Path string, Status string, Method string, ttl string) error { r, err := LoadRuleConfig() if err != nil { r = []Rule{} @@ -31,7 +34,7 @@ func NewRule(Name string, ServiceName string, Path string, Status string, Method fmt.Printf("Rule name can't be empty\n") return nil } - r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method}) + r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method, BanTime: ttl}) file, err := os.Create("/etc/banforge/rules.toml") if err != nil { return err @@ -104,3 +107,31 @@ func EditRule(Name string, ServiceName string, Path string, Status string, Metho return nil } + +func ParseDurationWithYears(s string) (time.Duration, error) { + if strings.HasSuffix(s, "y") { + years, err := strconv.Atoi(strings.TrimSuffix(s, "y")) + if err != nil { + return 0, err + } + return time.Duration(years) * 365 * 24 * time.Hour, nil + } + + if strings.HasSuffix(s, "M") { + months, err := strconv.Atoi(strings.TrimSuffix(s, "M")) + if err != nil { + return 0, err + } + return time.Duration(months) * 30 * 24 * time.Hour, nil + } + + if strings.HasSuffix(s, "d") { + days, err := strconv.Atoi(strings.TrimSuffix(s, "d")) + if err != nil { + return 0, err + } + return time.Duration(days) * 24 * time.Hour, nil + } + + return time.ParseDuration(s) +} diff --git a/internal/config/template.go b/internal/config/template.go index dc0fd0d..6f468ea 100644 --- a/internal/config/template.go +++ b/internal/config/template.go @@ -7,7 +7,6 @@ const Base_config = ` [firewall] name = "" config = "/etc/nftables.conf" -ban_time = 1200 [[service]] name = "nginx" @@ -18,7 +17,4 @@ enabled = true name = "nginx" log_path = "/var/log/nginx/access.log" enabled = false - ` - -// TODO: fix types for use 1 or any services" diff --git a/internal/config/types.go b/internal/config/types.go index c14f805..43d2d11 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -1,9 +1,8 @@ package config type Firewall struct { - Name string `toml:"name"` - Config string `toml:"config"` - BanTime int `toml:"ban_time"` + Name string `toml:"name"` + Config string `toml:"config"` } type Service struct { @@ -28,4 +27,5 @@ type Rule struct { Path string `toml:"path"` Status string `toml:"status"` Method string `toml:"method"` + BanTime string `toml:"ban_time"` } diff --git a/internal/judge/judge.go b/internal/judge/judge.go index ac17200..14d896e 100644 --- a/internal/judge/judge.go +++ b/internal/judge/judge.go @@ -2,6 +2,7 @@ package judge import ( "fmt" + "time" "github.com/d3m0k1d/BanForge/internal/blocker" "github.com/d3m0k1d/BanForge/internal/config" @@ -75,7 +76,7 @@ func (j *Judge) ProcessUnviewed() error { j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err)) } j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP)) - err = j.db.AddBan(entry.IP) + err = j.db.AddBan(entry.IP, rule.BanTime) if err != nil { j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err)) } @@ -100,3 +101,24 @@ func (j *Judge) ProcessUnviewed() error { return nil } + +func (j *Judge) UnbanChecker() { + tick := time.NewTicker(5 * time.Minute) + defer tick.Stop() + + for range tick.C { + ips, err := j.db.CheckExpiredBans() + if err != nil { + j.logger.Error(fmt.Sprintf("Failed to check expired bans: %v", err)) + continue + } + + for _, ip := range ips { + if err := j.Blocker.Unban(ip); err != nil { + j.logger.Error(fmt.Sprintf("Failed to unban IP %s: %v", ip, err)) + continue + } + j.logger.Info(fmt.Sprintf("IP unbanned: %s", ip)) + } + } +} diff --git a/internal/storage/db.go b/internal/storage/db.go index b84f9d3..290e381 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/d3m0k1d/BanForge/internal/config" "github.com/d3m0k1d/BanForge/internal/logger" "github.com/jedib0t/go-pretty/v6/table" _ "github.com/mattn/go-sqlite3" @@ -80,12 +81,28 @@ func (d *DB) IsBanned(ip string) (bool, error) { return true, nil } -func (d *DB) AddBan(ip string) error { - _, err := d.db.Exec("INSERT INTO bans (ip, reason, banned_at) VALUES (?, ?, ?)", ip, "1", time.Now().Format(time.RFC3339)) +func (d *DB) AddBan(ip string, ttl string) error { + duration, err := config.ParseDurationWithYears(ttl) + if err != nil { + d.logger.Error("Invalid duration format", "ttl", ttl, "error", err) + return fmt.Errorf("invalid duration: %w", err) + } + + now := time.Now() + expiredAt := now.Add(duration) + + _, err = d.db.Exec( + "INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)", + ip, + "1", + now.Format(time.RFC3339), + expiredAt.Format(time.RFC3339), + ) if err != nil { d.logger.Error("Failed to add ban", "error", err) return err } + return nil } @@ -116,3 +133,22 @@ func (d *DB) BanList() error { t.Render() return nil } + +func (d *DB) CheckExpiredBans() ([]string, error) { + var ips []string + rows, err := d.db.Query("SELECT ip FROM bans WHERE expired_at < ?", time.Now().Format(time.RFC3339)) + if err != nil { + d.logger.Error("Failed to get ban list", "error", err) + return nil, err + } + for rows.Next() { + var ip string + err := rows.Scan(&ip) + if err != nil { + d.logger.Error("Failed to get ban list", "error", err) + return nil, err + } + ips = append(ips, ip) + } + return ips, nil +} diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 8c5028f..a759807 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -198,7 +198,7 @@ func TestAddBan(t *testing.T) { t.Fatal(err) } - err = d.AddBan("127.0.0.1") + err = d.AddBan("127.0.0.1", "7h") if err != nil { t.Fatal(err) } diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go index a96d3ba..35ad85a 100644 --- a/internal/storage/migrations.go +++ b/internal/storage/migrations.go @@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS bans ( id INTEGER PRIMARY KEY, ip TEXT UNIQUE NOT NULL, reason TEXT, - banned_at DATETIME DEFAULT CURRENT_TIMESTAMP + banned_at DATETIME DEFAULT CURRENT_TIMESTAMP, + expired_at DATETIME ); CREATE INDEX IF NOT EXISTS idx_service ON requests(service);