feat: Add bantime and goroutines for unban expires ban
All checks were successful
CI.yml / build (push) Successful in 2m24s
All checks were successful
CI.yml / build (push) Successful in 2m24s
This commit is contained in:
@@ -50,6 +50,7 @@ var DaemonCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
j := judge.New(db, b)
|
j := judge.New(db, b)
|
||||||
j.LoadRules(r)
|
j.LoadRules(r)
|
||||||
|
go j.UnbanChecker()
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(5 * time.Second)
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ var (
|
|||||||
path string
|
path string
|
||||||
status string
|
status string
|
||||||
method string
|
method string
|
||||||
|
ttl string
|
||||||
)
|
)
|
||||||
|
|
||||||
var RuleCmd = &cobra.Command{
|
var RuleCmd = &cobra.Command{
|
||||||
@@ -37,8 +38,10 @@ var AddCmd = &cobra.Command{
|
|||||||
fmt.Printf("At least 1 rule field must be filled in.")
|
fmt.Printf("At least 1 rule field must be filled in.")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
if ttl == "" {
|
||||||
err := config.NewRule(name, service, path, status, method)
|
ttl = "1y"
|
||||||
|
}
|
||||||
|
err := config.NewRule(name, service, path, status, method, ttl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
@@ -68,6 +71,7 @@ func RuleRegister() {
|
|||||||
AddCmd.Flags().StringVarP(&name, "name", "n", "", "rule name (required)")
|
AddCmd.Flags().StringVarP(&name, "name", "n", "", "rule name (required)")
|
||||||
AddCmd.Flags().StringVarP(&service, "service", "s", "", "service name")
|
AddCmd.Flags().StringVarP(&service, "service", "s", "", "service name")
|
||||||
AddCmd.Flags().StringVarP(&path, "path", "p", "", "request path")
|
AddCmd.Flags().StringVarP(&path, "path", "p", "", "request path")
|
||||||
AddCmd.Flags().StringVarP(&status, "status", "c", "", "HTTP status code")
|
AddCmd.Flags().StringVarP(&status, "status", "c", "", "status code")
|
||||||
AddCmd.Flags().StringVarP(&method, "method", "m", "", "HTTP method")
|
AddCmd.Flags().StringVarP(&method, "method", "m", "", "method")
|
||||||
|
AddCmd.Flags().StringVarP(&ttl, "ttl", "t", "", "ban time")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package config
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
"github.com/d3m0k1d/BanForge/internal/logger"
|
"github.com/d3m0k1d/BanForge/internal/logger"
|
||||||
@@ -22,7 +25,7 @@ func LoadRuleConfig() ([]Rule, error) {
|
|||||||
return cfg.Rules, nil
|
return cfg.Rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRule(Name string, ServiceName string, Path string, Status string, Method string) error {
|
func NewRule(Name string, ServiceName string, Path string, Status string, Method string, ttl string) error {
|
||||||
r, err := LoadRuleConfig()
|
r, err := LoadRuleConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r = []Rule{}
|
r = []Rule{}
|
||||||
@@ -31,7 +34,7 @@ func NewRule(Name string, ServiceName string, Path string, Status string, Method
|
|||||||
fmt.Printf("Rule name can't be empty\n")
|
fmt.Printf("Rule name can't be empty\n")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method})
|
r = append(r, Rule{Name: Name, ServiceName: ServiceName, Path: Path, Status: Status, Method: Method, BanTime: ttl})
|
||||||
file, err := os.Create("/etc/banforge/rules.toml")
|
file, err := os.Create("/etc/banforge/rules.toml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -104,3 +107,31 @@ func EditRule(Name string, ServiceName string, Path string, Status string, Metho
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ParseDurationWithYears(s string) (time.Duration, error) {
|
||||||
|
if strings.HasSuffix(s, "y") {
|
||||||
|
years, err := strconv.Atoi(strings.TrimSuffix(s, "y"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return time.Duration(years) * 365 * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(s, "M") {
|
||||||
|
months, err := strconv.Atoi(strings.TrimSuffix(s, "M"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return time.Duration(months) * 30 * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(s, "d") {
|
||||||
|
days, err := strconv.Atoi(strings.TrimSuffix(s, "d"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return time.Duration(days) * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.ParseDuration(s)
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ const Base_config = `
|
|||||||
[firewall]
|
[firewall]
|
||||||
name = ""
|
name = ""
|
||||||
config = "/etc/nftables.conf"
|
config = "/etc/nftables.conf"
|
||||||
ban_time = 1200
|
|
||||||
|
|
||||||
[[service]]
|
[[service]]
|
||||||
name = "nginx"
|
name = "nginx"
|
||||||
@@ -18,7 +17,4 @@ enabled = true
|
|||||||
name = "nginx"
|
name = "nginx"
|
||||||
log_path = "/var/log/nginx/access.log"
|
log_path = "/var/log/nginx/access.log"
|
||||||
enabled = false
|
enabled = false
|
||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
// TODO: fix types for use 1 or any services"
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package config
|
|||||||
type Firewall struct {
|
type Firewall struct {
|
||||||
Name string `toml:"name"`
|
Name string `toml:"name"`
|
||||||
Config string `toml:"config"`
|
Config string `toml:"config"`
|
||||||
BanTime int `toml:"ban_time"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
@@ -28,4 +27,5 @@ type Rule struct {
|
|||||||
Path string `toml:"path"`
|
Path string `toml:"path"`
|
||||||
Status string `toml:"status"`
|
Status string `toml:"status"`
|
||||||
Method string `toml:"method"`
|
Method string `toml:"method"`
|
||||||
|
BanTime string `toml:"ban_time"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package judge
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/d3m0k1d/BanForge/internal/blocker"
|
"github.com/d3m0k1d/BanForge/internal/blocker"
|
||||||
"github.com/d3m0k1d/BanForge/internal/config"
|
"github.com/d3m0k1d/BanForge/internal/config"
|
||||||
@@ -75,7 +76,7 @@ func (j *Judge) ProcessUnviewed() error {
|
|||||||
j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err))
|
j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err))
|
||||||
}
|
}
|
||||||
j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP))
|
j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP))
|
||||||
err = j.db.AddBan(entry.IP)
|
err = j.db.AddBan(entry.IP, rule.BanTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err))
|
j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err))
|
||||||
}
|
}
|
||||||
@@ -100,3 +101,24 @@ func (j *Judge) ProcessUnviewed() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (j *Judge) UnbanChecker() {
|
||||||
|
tick := time.NewTicker(5 * time.Minute)
|
||||||
|
defer tick.Stop()
|
||||||
|
|
||||||
|
for range tick.C {
|
||||||
|
ips, err := j.db.CheckExpiredBans()
|
||||||
|
if err != nil {
|
||||||
|
j.logger.Error(fmt.Sprintf("Failed to check expired bans: %v", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
if err := j.Blocker.Unban(ip); err != nil {
|
||||||
|
j.logger.Error(fmt.Sprintf("Failed to unban IP %s: %v", ip, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
j.logger.Info(fmt.Sprintf("IP unbanned: %s", ip))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/d3m0k1d/BanForge/internal/config"
|
||||||
"github.com/d3m0k1d/BanForge/internal/logger"
|
"github.com/d3m0k1d/BanForge/internal/logger"
|
||||||
"github.com/jedib0t/go-pretty/v6/table"
|
"github.com/jedib0t/go-pretty/v6/table"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
@@ -80,12 +81,28 @@ func (d *DB) IsBanned(ip string) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DB) AddBan(ip string) error {
|
func (d *DB) AddBan(ip string, ttl string) error {
|
||||||
_, err := d.db.Exec("INSERT INTO bans (ip, reason, banned_at) VALUES (?, ?, ?)", ip, "1", time.Now().Format(time.RFC3339))
|
duration, err := config.ParseDurationWithYears(ttl)
|
||||||
|
if err != nil {
|
||||||
|
d.logger.Error("Invalid duration format", "ttl", ttl, "error", err)
|
||||||
|
return fmt.Errorf("invalid duration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
expiredAt := now.Add(duration)
|
||||||
|
|
||||||
|
_, err = d.db.Exec(
|
||||||
|
"INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)",
|
||||||
|
ip,
|
||||||
|
"1",
|
||||||
|
now.Format(time.RFC3339),
|
||||||
|
expiredAt.Format(time.RFC3339),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.logger.Error("Failed to add ban", "error", err)
|
d.logger.Error("Failed to add ban", "error", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,3 +133,22 @@ func (d *DB) BanList() error {
|
|||||||
t.Render()
|
t.Render()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DB) CheckExpiredBans() ([]string, error) {
|
||||||
|
var ips []string
|
||||||
|
rows, err := d.db.Query("SELECT ip FROM bans WHERE expired_at < ?", time.Now().Format(time.RFC3339))
|
||||||
|
if err != nil {
|
||||||
|
d.logger.Error("Failed to get ban list", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for rows.Next() {
|
||||||
|
var ip string
|
||||||
|
err := rows.Scan(&ip)
|
||||||
|
if err != nil {
|
||||||
|
d.logger.Error("Failed to get ban list", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
return ips, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ func TestAddBan(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = d.AddBan("127.0.0.1")
|
err = d.AddBan("127.0.0.1", "7h")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS bans (
|
|||||||
id INTEGER PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
ip TEXT UNIQUE NOT NULL,
|
ip TEXT UNIQUE NOT NULL,
|
||||||
reason TEXT,
|
reason TEXT,
|
||||||
banned_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
banned_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
expired_at DATETIME
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_service ON requests(service);
|
CREATE INDEX IF NOT EXISTS idx_service ON requests(service);
|
||||||
|
|||||||
Reference in New Issue
Block a user