diff --git a/cmd/banforge/command/fw.go b/cmd/banforge/command/fw.go index 03aa7b3..12b66ee 100644 --- a/cmd/banforge/command/fw.go +++ b/cmd/banforge/command/fw.go @@ -7,6 +7,7 @@ import ( "github.com/d3m0k1d/BanForge/internal/blocker" "github.com/d3m0k1d/BanForge/internal/config" + "github.com/d3m0k1d/BanForge/internal/storage" "github.com/spf13/cobra" ) @@ -17,6 +18,11 @@ var UnbanCmd = &cobra.Command{ Use: "unban", Short: "Unban IP", Run: func(cmd *cobra.Command, args []string) { + db, err := storage.NewDB() + if err != nil { + fmt.Println(err) + os.Exit(1) + } cfg, err := config.LoadConfig() if err != nil { fmt.Println(err) @@ -41,6 +47,11 @@ var UnbanCmd = &cobra.Command{ fmt.Println(err) os.Exit(1) } + err = db.RemoveBan(ip) + if err != nil { + fmt.Println(err) + os.Exit(1) + } fmt.Println("IP unblocked successfully!") }, } @@ -49,7 +60,11 @@ var BanCmd = &cobra.Command{ Use: "ban", Short: "Ban IP", Run: func(cmd *cobra.Command, args []string) { - + db, err := storage.NewDB() + if err != nil { + fmt.Println(err) + os.Exit(1) + } cfg, err := config.LoadConfig() if err != nil { fmt.Println(err) @@ -74,7 +89,12 @@ var BanCmd = &cobra.Command{ fmt.Println(err) os.Exit(1) } - fmt.Println("IP unblocked successfully!") + err = db.AddBan(ip, "1y") + if err != nil { + fmt.Println(err) + os.Exit(1) + } + fmt.Println("IP blocked successfully!") }, } diff --git a/internal/blocker/nftables.go b/internal/blocker/nftables.go index d4591ce..f161783 100644 --- a/internal/blocker/nftables.go +++ b/internal/blocker/nftables.go @@ -104,15 +104,14 @@ func (n *Nftables) Setup(config string) error { nftConfig := `table inet banforge { chain input { - type filter hook input priority 0 - policy accept + type filter hook input priority filter; policy accept; + jump banned } chain banned { } } ` - cmd := exec.Command("sudo", "tee", config) stdin, err := cmd.StdinPipe() if err != nil { diff --git a/internal/judge/judge.go b/internal/judge/judge.go index b55ba21..84a9881 100644 --- a/internal/judge/judge.go +++ b/internal/judge/judge.go @@ -129,6 +129,10 @@ func (j *Judge) UnbanChecker() { } for _, ip := range ips { + err = j.db.RemoveBan(ip) + if err != nil { + j.logger.Error(fmt.Sprintf("Failed to remove ban: %v", err)) + } if err := j.Blocker.Unban(ip); err != nil { j.logger.Error(fmt.Sprintf("Failed to unban IP %s: %v", ip, err)) continue diff --git a/internal/storage/db.go b/internal/storage/db.go index 2dff4eb..3cf5f56 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -111,6 +111,15 @@ func (d *DB) AddBan(ip string, ttl string) error { return nil } +func (d *DB) RemoveBan(ip string) error { + _, err := d.db.Exec("DELETE FROM bans WHERE ip = ?", ip) + if err != nil { + d.logger.Error("Failed to remove ban", "error", err) + return err + } + return nil +} + func (d *DB) BanList() error { var count int