feat: Add bantime and goroutines for unban expires ban
All checks were successful
CI.yml / build (push) Successful in 2m24s
All checks were successful
CI.yml / build (push) Successful in 2m24s
This commit is contained in:
@@ -50,6 +50,7 @@ var DaemonCmd = &cobra.Command{
|
||||
}
|
||||
j := judge.New(db, b)
|
||||
j.LoadRules(r)
|
||||
go j.UnbanChecker()
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -14,6 +14,7 @@ var (
|
||||
path string
|
||||
status string
|
||||
method string
|
||||
ttl string
|
||||
)
|
||||
|
||||
var RuleCmd = &cobra.Command{
|
||||
@@ -37,8 +38,10 @@ var AddCmd = &cobra.Command{
|
||||
fmt.Printf("At least 1 rule field must be filled in.")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err := config.NewRule(name, service, path, status, method)
|
||||
if ttl == "" {
|
||||
ttl = "1y"
|
||||
}
|
||||
err := config.NewRule(name, service, path, status, method, ttl)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
@@ -68,6 +71,7 @@ func RuleRegister() {
|
||||
AddCmd.Flags().StringVarP(&name, "name", "n", "", "rule name (required)")
|
||||
AddCmd.Flags().StringVarP(&service, "service", "s", "", "service name")
|
||||
AddCmd.Flags().StringVarP(&path, "path", "p", "", "request path")
|
||||
AddCmd.Flags().StringVarP(&status, "status", "c", "", "HTTP status code")
|
||||
AddCmd.Flags().StringVarP(&method, "method", "m", "", "HTTP method")
|
||||
AddCmd.Flags().StringVarP(&status, "status", "c", "", "status code")
|
||||
AddCmd.Flags().StringVarP(&method, "method", "m", "", "method")
|
||||
AddCmd.Flags().StringVarP(&ttl, "ttl", "t", "", "ban time")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/d3m0k1d/BanForge/internal/logger"
|
||||
@@ -22,7 +25,7 @@ func LoadRuleConfig() ([]Rule, error) {
|
||||
return cfg.Rules, nil
|
||||
}
|
||||
|
||||
func NewRule(Name string, ServiceName string, Path string, Status string, Method string) error {
|
||||
func NewRule(Name string, ServiceName string, Path string, Status string, Method string, ttl string) error {
|
||||
r, err := LoadRuleConfig()
|
||||
if err != nil {
|
||||
r = []Rule{}
|
||||
@@ -31,7 +34,7 @@ func NewRule(Name string, ServiceName string, Path string, Status string, Method
|
||||
fmt.Printf("Rule name can't be empty\n")
|
||||
return nil
|
||||
}
|
||||
r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method})
|
||||
r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method, BanTime: ttl})
|
||||
file, err := os.Create("/etc/banforge/rules.toml")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -104,3 +107,31 @@ func EditRule(Name string, ServiceName string, Path string, Status string, Metho
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseDurationWithYears(s string) (time.Duration, error) {
|
||||
if strings.HasSuffix(s, "y") {
|
||||
years, err := strconv.Atoi(strings.TrimSuffix(s, "y"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return time.Duration(years) * 365 * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
if strings.HasSuffix(s, "M") {
|
||||
months, err := strconv.Atoi(strings.TrimSuffix(s, "M"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return time.Duration(months) * 30 * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
if strings.HasSuffix(s, "d") {
|
||||
days, err := strconv.Atoi(strings.TrimSuffix(s, "d"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return time.Duration(days) * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
return time.ParseDuration(s)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ const Base_config = `
|
||||
[firewall]
|
||||
name = ""
|
||||
config = "/etc/nftables.conf"
|
||||
ban_time = 1200
|
||||
|
||||
[[service]]
|
||||
name = "nginx"
|
||||
@@ -18,7 +17,4 @@ enabled = true
|
||||
name = "nginx"
|
||||
log_path = "/var/log/nginx/access.log"
|
||||
enabled = false
|
||||
|
||||
`
|
||||
|
||||
// TODO: fix types for use 1 or any services"
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package config
|
||||
|
||||
type Firewall struct {
|
||||
Name string `toml:"name"`
|
||||
Config string `toml:"config"`
|
||||
BanTime int `toml:"ban_time"`
|
||||
Name string `toml:"name"`
|
||||
Config string `toml:"config"`
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
@@ -28,4 +27,5 @@ type Rule struct {
|
||||
Path string `toml:"path"`
|
||||
Status string `toml:"status"`
|
||||
Method string `toml:"method"`
|
||||
BanTime string `toml:"ban_time"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package judge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/d3m0k1d/BanForge/internal/blocker"
|
||||
"github.com/d3m0k1d/BanForge/internal/config"
|
||||
@@ -75,7 +76,7 @@ func (j *Judge) ProcessUnviewed() error {
|
||||
j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err))
|
||||
}
|
||||
j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP))
|
||||
err = j.db.AddBan(entry.IP)
|
||||
err = j.db.AddBan(entry.IP, rule.BanTime)
|
||||
if err != nil {
|
||||
j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err))
|
||||
}
|
||||
@@ -100,3 +101,24 @@ func (j *Judge) ProcessUnviewed() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *Judge) UnbanChecker() {
|
||||
tick := time.NewTicker(5 * time.Minute)
|
||||
defer tick.Stop()
|
||||
|
||||
for range tick.C {
|
||||
ips, err := j.db.CheckExpiredBans()
|
||||
if err != nil {
|
||||
j.logger.Error(fmt.Sprintf("Failed to check expired bans: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if err := j.Blocker.Unban(ip); err != nil {
|
||||
j.logger.Error(fmt.Sprintf("Failed to unban IP %s: %v", ip, err))
|
||||
continue
|
||||
}
|
||||
j.logger.Info(fmt.Sprintf("IP unbanned: %s", ip))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/d3m0k1d/BanForge/internal/config"
|
||||
"github.com/d3m0k1d/BanForge/internal/logger"
|
||||
"github.com/jedib0t/go-pretty/v6/table"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
@@ -80,12 +81,28 @@ func (d *DB) IsBanned(ip string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (d *DB) AddBan(ip string) error {
|
||||
_, err := d.db.Exec("INSERT INTO bans (ip, reason, banned_at) VALUES (?, ?, ?)", ip, "1", time.Now().Format(time.RFC3339))
|
||||
func (d *DB) AddBan(ip string, ttl string) error {
|
||||
duration, err := config.ParseDurationWithYears(ttl)
|
||||
if err != nil {
|
||||
d.logger.Error("Invalid duration format", "ttl", ttl, "error", err)
|
||||
return fmt.Errorf("invalid duration: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expiredAt := now.Add(duration)
|
||||
|
||||
_, err = d.db.Exec(
|
||||
"INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)",
|
||||
ip,
|
||||
"1",
|
||||
now.Format(time.RFC3339),
|
||||
expiredAt.Format(time.RFC3339),
|
||||
)
|
||||
if err != nil {
|
||||
d.logger.Error("Failed to add ban", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -116,3 +133,22 @@ func (d *DB) BanList() error {
|
||||
t.Render()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) CheckExpiredBans() ([]string, error) {
|
||||
var ips []string
|
||||
rows, err := d.db.Query("SELECT ip FROM bans WHERE expired_at < ?", time.Now().Format(time.RFC3339))
|
||||
if err != nil {
|
||||
d.logger.Error("Failed to get ban list", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var ip string
|
||||
err := rows.Scan(&ip)
|
||||
if err != nil {
|
||||
d.logger.Error("Failed to get ban list", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ func TestAddBan(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = d.AddBan("127.0.0.1")
|
||||
err = d.AddBan("127.0.0.1", "7h")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS bans (
|
||||
id INTEGER PRIMARY KEY,
|
||||
ip TEXT UNIQUE NOT NULL,
|
||||
reason TEXT,
|
||||
banned_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
banned_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
expired_at DATETIME
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_service ON requests(service);
|
||||
|
||||
Reference in New Issue
Block a user