feat: Add bantime and goroutines for unban expires ban
All checks were successful
CI.yml / build (push) Successful in 2m24s

This commit is contained in:
d3m0k1d
2026-01-19 16:03:12 +03:00
parent 6f24088069
commit 847002129d
9 changed files with 109 additions and 18 deletions

View File

@@ -50,6 +50,7 @@ var DaemonCmd = &cobra.Command{
} }
j := judge.New(db, b) j := judge.New(db, b)
j.LoadRules(r) j.LoadRules(r)
go j.UnbanChecker()
go func() { go func() {
ticker := time.NewTicker(5 * time.Second) ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop() defer ticker.Stop()

View File

@@ -14,6 +14,7 @@ var (
path string path string
status string status string
method string method string
ttl string
) )
var RuleCmd = &cobra.Command{ var RuleCmd = &cobra.Command{
@@ -37,8 +38,10 @@ var AddCmd = &cobra.Command{
fmt.Printf("At least 1 rule field must be filled in.") fmt.Printf("At least 1 rule field must be filled in.")
os.Exit(1) os.Exit(1)
} }
if ttl == "" {
err := config.NewRule(name, service, path, status, method) ttl = "1y"
}
err := config.NewRule(name, service, path, status, method, ttl)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@@ -68,6 +71,7 @@ func RuleRegister() {
AddCmd.Flags().StringVarP(&name, "name", "n", "", "rule name (required)") AddCmd.Flags().StringVarP(&name, "name", "n", "", "rule name (required)")
AddCmd.Flags().StringVarP(&service, "service", "s", "", "service name") AddCmd.Flags().StringVarP(&service, "service", "s", "", "service name")
AddCmd.Flags().StringVarP(&path, "path", "p", "", "request path") AddCmd.Flags().StringVarP(&path, "path", "p", "", "request path")
AddCmd.Flags().StringVarP(&status, "status", "c", "", "HTTP status code") AddCmd.Flags().StringVarP(&status, "status", "c", "", "status code")
AddCmd.Flags().StringVarP(&method, "method", "m", "", "HTTP method") AddCmd.Flags().StringVarP(&method, "method", "m", "", "method")
AddCmd.Flags().StringVarP(&ttl, "ttl", "t", "", "ban time")
} }

View File

@@ -3,6 +3,9 @@ package config
import ( import (
"fmt" "fmt"
"os" "os"
"strconv"
"strings"
"time"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/d3m0k1d/BanForge/internal/logger" "github.com/d3m0k1d/BanForge/internal/logger"
@@ -22,7 +25,7 @@ func LoadRuleConfig() ([]Rule, error) {
return cfg.Rules, nil return cfg.Rules, nil
} }
func NewRule(Name string, ServiceName string, Path string, Status string, Method string) error { func NewRule(Name string, ServiceName string, Path string, Status string, Method string, ttl string) error {
r, err := LoadRuleConfig() r, err := LoadRuleConfig()
if err != nil { if err != nil {
r = []Rule{} r = []Rule{}
@@ -31,7 +34,7 @@ func NewRule(Name string, ServiceName string, Path string, Status string, Method
fmt.Printf("Rule name can't be empty\n") fmt.Printf("Rule name can't be empty\n")
return nil return nil
} }
r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method}) r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method, BanTime: ttl})
file, err := os.Create("/etc/banforge/rules.toml") file, err := os.Create("/etc/banforge/rules.toml")
if err != nil { if err != nil {
return err return err
@@ -104,3 +107,31 @@ func EditRule(Name string, ServiceName string, Path string, Status string, Metho
return nil return nil
} }
func ParseDurationWithYears(s string) (time.Duration, error) {
if strings.HasSuffix(s, "y") {
years, err := strconv.Atoi(strings.TrimSuffix(s, "y"))
if err != nil {
return 0, err
}
return time.Duration(years) * 365 * 24 * time.Hour, nil
}
if strings.HasSuffix(s, "M") {
months, err := strconv.Atoi(strings.TrimSuffix(s, "M"))
if err != nil {
return 0, err
}
return time.Duration(months) * 30 * 24 * time.Hour, nil
}
if strings.HasSuffix(s, "d") {
days, err := strconv.Atoi(strings.TrimSuffix(s, "d"))
if err != nil {
return 0, err
}
return time.Duration(days) * 24 * time.Hour, nil
}
return time.ParseDuration(s)
}

View File

@@ -7,7 +7,6 @@ const Base_config = `
[firewall] [firewall]
name = "" name = ""
config = "/etc/nftables.conf" config = "/etc/nftables.conf"
ban_time = 1200
[[service]] [[service]]
name = "nginx" name = "nginx"
@@ -18,7 +17,4 @@ enabled = true
name = "nginx" name = "nginx"
log_path = "/var/log/nginx/access.log" log_path = "/var/log/nginx/access.log"
enabled = false enabled = false
` `
// TODO: fix types for use 1 or any services"

View File

@@ -1,9 +1,8 @@
package config package config
type Firewall struct { type Firewall struct {
Name string `toml:"name"` Name string `toml:"name"`
Config string `toml:"config"` Config string `toml:"config"`
BanTime int `toml:"ban_time"`
} }
type Service struct { type Service struct {
@@ -28,4 +27,5 @@ type Rule struct {
Path string `toml:"path"` Path string `toml:"path"`
Status string `toml:"status"` Status string `toml:"status"`
Method string `toml:"method"` Method string `toml:"method"`
BanTime string `toml:"ban_time"`
} }

View File

@@ -2,6 +2,7 @@ package judge
import ( import (
"fmt" "fmt"
"time"
"github.com/d3m0k1d/BanForge/internal/blocker" "github.com/d3m0k1d/BanForge/internal/blocker"
"github.com/d3m0k1d/BanForge/internal/config" "github.com/d3m0k1d/BanForge/internal/config"
@@ -75,7 +76,7 @@ func (j *Judge) ProcessUnviewed() error {
j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err)) j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err))
} }
j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP)) j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP))
err = j.db.AddBan(entry.IP) err = j.db.AddBan(entry.IP, rule.BanTime)
if err != nil { if err != nil {
j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err)) j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err))
} }
@@ -100,3 +101,24 @@ func (j *Judge) ProcessUnviewed() error {
return nil return nil
} }
func (j *Judge) UnbanChecker() {
tick := time.NewTicker(5 * time.Minute)
defer tick.Stop()
for range tick.C {
ips, err := j.db.CheckExpiredBans()
if err != nil {
j.logger.Error(fmt.Sprintf("Failed to check expired bans: %v", err))
continue
}
for _, ip := range ips {
if err := j.Blocker.Unban(ip); err != nil {
j.logger.Error(fmt.Sprintf("Failed to unban IP %s: %v", ip, err))
continue
}
j.logger.Info(fmt.Sprintf("IP unbanned: %s", ip))
}
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/d3m0k1d/BanForge/internal/config"
"github.com/d3m0k1d/BanForge/internal/logger" "github.com/d3m0k1d/BanForge/internal/logger"
"github.com/jedib0t/go-pretty/v6/table" "github.com/jedib0t/go-pretty/v6/table"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@@ -80,12 +81,28 @@ func (d *DB) IsBanned(ip string) (bool, error) {
return true, nil return true, nil
} }
func (d *DB) AddBan(ip string) error { func (d *DB) AddBan(ip string, ttl string) error {
_, err := d.db.Exec("INSERT INTO bans (ip, reason, banned_at) VALUES (?, ?, ?)", ip, "1", time.Now().Format(time.RFC3339)) duration, err := config.ParseDurationWithYears(ttl)
if err != nil {
d.logger.Error("Invalid duration format", "ttl", ttl, "error", err)
return fmt.Errorf("invalid duration: %w", err)
}
now := time.Now()
expiredAt := now.Add(duration)
_, err = d.db.Exec(
"INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)",
ip,
"1",
now.Format(time.RFC3339),
expiredAt.Format(time.RFC3339),
)
if err != nil { if err != nil {
d.logger.Error("Failed to add ban", "error", err) d.logger.Error("Failed to add ban", "error", err)
return err return err
} }
return nil return nil
} }
@@ -116,3 +133,22 @@ func (d *DB) BanList() error {
t.Render() t.Render()
return nil return nil
} }
func (d *DB) CheckExpiredBans() ([]string, error) {
var ips []string
rows, err := d.db.Query("SELECT ip FROM bans WHERE expired_at < ?", time.Now().Format(time.RFC3339))
if err != nil {
d.logger.Error("Failed to get ban list", "error", err)
return nil, err
}
for rows.Next() {
var ip string
err := rows.Scan(&ip)
if err != nil {
d.logger.Error("Failed to get ban list", "error", err)
return nil, err
}
ips = append(ips, ip)
}
return ips, nil
}

View File

@@ -198,7 +198,7 @@ func TestAddBan(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = d.AddBan("127.0.0.1") err = d.AddBan("127.0.0.1", "7h")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS bans (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
ip TEXT UNIQUE NOT NULL, ip TEXT UNIQUE NOT NULL,
reason TEXT, reason TEXT,
banned_at DATETIME DEFAULT CURRENT_TIMESTAMP banned_at DATETIME DEFAULT CURRENT_TIMESTAMP,
expired_at DATETIME
); );
CREATE INDEX IF NOT EXISTS idx_service ON requests(service); CREATE INDEX IF NOT EXISTS idx_service ON requests(service);