69 Commits

Author SHA1 Message Date
d3m0k1d
aef2647a82 chore: fix deps
All checks were successful
build / build (push) Successful in 2m28s
CD - BanForge Release / release (push) Successful in 4m10s
2026-02-22 19:52:22 +03:00
d3m0k1d
c3b6708a98 fix: gofmt
Some checks failed
build / build (push) Has been cancelled
2026-02-22 19:48:37 +03:00
d3m0k1d
3acd0b899c fix: linter run and gosec fix
Some checks failed
build / build (push) Failing after 1m52s
2026-02-22 19:45:59 +03:00
d3m0k1d
3ac1250bfc feat: add metrics support 2026-02-22 19:45:47 +03:00
d3m0k1d
7bba444522 feat: upgrade max_retry logic and change version
All checks were successful
build / build (push) Successful in 2m9s
2026-02-22 18:27:21 +03:00
d3m0k1d
97eb626237 fix: update libs
Some checks failed
CD - BanForge Release / release (push) Failing after 2m24s
2026-02-22 17:21:20 +03:00
d3m0k1d
b7a1ac06d4 feat: new ver
All checks were successful
CD - BanForge Release / release (push) Successful in 3m42s
2026-02-22 16:13:51 +03:00
d3m0k1d
49f0acb777 docs: update add to example max retry
All checks were successful
build / build (push) Successful in 2m8s
2026-02-22 16:12:52 +03:00
d3m0k1d
a602207369 feat: full working max_retry logic
All checks were successful
build / build (push) Successful in 2m45s
2026-02-22 16:06:51 +03:00
d3m0k1d
8c0cfcdbe7 refactoring: method on reader req db
All checks were successful
build / build (push) Successful in 2m8s
2026-02-19 12:36:56 +03:00
d3m0k1d
35a1a89baf fix: run tests in storage
All checks were successful
build / build (push) Successful in 2m6s
2026-02-19 11:22:52 +03:00
d3m0k1d
f3387b169a fix: gosec
Some checks failed
build / build (push) Failing after 1m59s
2026-02-19 11:17:51 +03:00
d3m0k1d
5782072f91 fix: ci one more time
Some checks failed
build / build (push) Failing after 1m42s
2026-02-19 11:14:45 +03:00
d3m0k1d
7918b3efe6 feat: add new nosec flags for fix ci
Some checks failed
build / build (push) Failing after 1m38s
2026-02-19 11:09:59 +03:00
d3m0k1d
f628e24f58 fix: golangci fix
Some checks failed
build / build (push) Failing after 1m40s
2026-02-19 11:03:52 +03:00
d3m0k1d
7f54db0cd4 feat: add new method and for db req and add to template max retry
Some checks failed
build / build (push) Failing after 1m48s
2026-02-19 10:53:55 +03:00
Ilya Chernishev
2e9b307194 Merge pull request #1 from shinyzero00/master
All checks were successful
build / build (push) Successful in 2m25s
refactoring pr by shinyzero00
2026-02-15 13:17:01 +03:00
Ilya Chernishev
726594a712 Change return value to nil on successful IP block 2026-02-15 13:13:26 +03:00
Ilya Chernishev
b27038a59c Execute SQL statement to create table in database 2026-02-15 13:08:40 +03:00
Ilya Chernishev
72025dab7d Remove comment about potential failure in encoding
Removed commented-out question regarding error handling.
2026-02-15 12:59:20 +03:00
Ilya Chernishev
dd131477e2 fix ST1005 2026-02-15 12:51:18 +03:00
Ilya Chernishev
670aec449a fix ST1005 staticcheck 2026-02-15 12:49:57 +03:00
zero@thinky
fc37e641be refactor(internal/config): use CutSuffix 2026-02-15 04:56:22 +03:00
zero@thinky
361de03208 refactor(cmd/fw): wtf is that error handling 2026-02-15 04:56:22 +03:00
zero@thinky
a2268fda5d fix(cmd/fw): why to fucking log when it is printed by the only caller 2026-02-15 04:56:22 +03:00
zero@thinky
9dc0b6002e refactor(internal/config): error handling 2026-02-15 04:56:22 +03:00
zero@thinky
4953be3ef6 refactor(internal/storage/RequestWriter/WriteReq): wtf is that error handling 2026-02-15 04:56:22 +03:00
zero@thinky
c386a2d6bc refactor(internal/storage/RequestWriter): deduplicate dsn 2026-02-15 04:54:38 +03:00
zero@thinky
dea03a6f70 refactor(*): what the fuck is that naming 2026-02-15 04:54:38 +03:00
zero@thinky
11f755c03c style(internal/storage/BanWriter): rm extra newline 2026-02-15 04:54:38 +03:00
zero@thinky
1c7a1c1778 refactor(internal/storage/BanWriter): deduplicate dsn 2026-02-15 04:54:38 +03:00
zero@thinky
411574cabe refactor(internal/storage): generalization and deduplication 2026-02-15 04:28:34 +03:00
d3m0k1d
820c9410a1 feat: update docs for new commands
All checks were successful
build / build (push) Successful in 2m8s
CD - BanForge Release / release (push) Successful in 3m46s
2026-02-09 22:27:28 +03:00
d3m0k1d
6f261803a7 feat: add to cli commands for open/close ports on firewall
All checks were successful
build / build (push) Successful in 2m2s
2026-02-09 21:51:31 +03:00
d3m0k1d
aacc98668f feat: add logic for PortClose and PortOpen on interfaces
All checks were successful
build / build (push) Successful in 2m4s
2026-02-09 21:31:19 +03:00
d3m0k1d
9519eedf4f feat: add new interface method to firewals
All checks were successful
build / build (push) Successful in 3m9s
2026-02-09 19:50:06 +03:00
d3m0k1d
b8b9b227a9 Fix: daemon chanels
All checks were successful
build / build (push) Successful in 3m9s
CD - BanForge Release / release (push) Successful in 5m9s
2026-01-27 17:10:01 +03:00
d3m0k1d
08d3214f22 Fix: goimport linter fix
All checks were successful
build / build (push) Successful in 3m27s
2026-01-27 17:04:36 +03:00
d3m0k1d
6ebda76738 feat: Add apache support
Some checks failed
build / build (push) Failing after 2m48s
2026-01-27 16:59:32 +03:00
d3m0k1d
b9754f605b fix: Delete sudo calls on exec
All checks were successful
build / build (push) Successful in 3m8s
CD - BanForge Release / release (push) Successful in 5m24s
2026-01-27 16:20:03 +03:00
d3m0k1d
be6b19426b docs: Add installation guide
All checks were successful
build / build (push) Successful in 3m16s
2026-01-26 16:51:40 +03:00
d3m0k1d
3ebffda2c7 feat: improve table on cli interface
All checks were successful
build / build (push) Successful in 3m14s
CD - BanForge Release / release (push) Successful in 5m8s
2026-01-26 14:21:35 +03:00
d3m0k1d
cadbbc9080 feat: improve reason string on db
All checks were successful
build / build (push) Successful in 3m8s
2026-01-26 14:04:30 +03:00
d3m0k1d
e907fb0b1a feat: update ban/unban command
All checks were successful
build / build (push) Successful in 3m13s
2026-01-25 21:13:56 +03:00
d3m0k1d
b0fc0646d2 fix: typo
All checks were successful
build / build (push) Successful in 3m23s
2026-01-24 20:30:27 +03:00
d3m0k1d
c2eb02afc7 docs: fix roadmap
All checks were successful
build / build (push) Successful in 3m22s
2026-01-23 18:22:18 +03:00
d3m0k1d
262f3daee4 docs: update reaadme.md Roadmap and overview
All checks were successful
build / build (push) Successful in 3m3s
2026-01-23 17:58:03 +03:00
d3m0k1d
fb32886d4a refactoring: rename func writer
All checks were successful
build / build (push) Successful in 3m31s
2026-01-22 21:08:55 +03:00
d3m0k1d
fb624a9147 fix: errcheck
All checks were successful
build / build (push) Successful in 3m10s
2026-01-22 20:34:49 +03:00
d3m0k1d
7741e08ebc fix: linter 2026-01-22 20:34:36 +03:00
d3m0k1d
5f607d0be0 refactoring: full refactoring the database structure from 1 file to 2 file db struct to avoid conflict 2 writters and sqllite busy, improve tests
Some checks failed
build / build (push) Has been cancelled
2026-01-22 20:29:19 +03:00
d3m0k1d
9a7e5a4796 fix: fix matchPath logic
All checks were successful
build / build (push) Successful in 3m26s
CD - BanForge Release / release (push) Successful in 5m21s
2026-01-22 00:37:57 +03:00
d3m0k1d
95bc7683ea fix: fix test for test server
All checks were successful
build / build (push) Successful in 3m18s
2026-01-22 00:25:50 +03:00
d3m0k1d
dca0241f17 fix: golangci-lint --fix run
Some checks failed
build / build (push) Failing after 3m51s
2026-01-22 00:11:29 +03:00
d3m0k1d
791d64ae4d feat: Recode logic for add logs to db
Some checks failed
build / build (push) Has been cancelled
2026-01-22 00:09:56 +03:00
d3m0k1d
7df9925f94 fix: db connection bug and delete debug logs
Some checks failed
build / build (push) Failing after 5m6s
CD - BanForge Release / release (push) Failing after 6m11s
2026-01-21 22:43:08 +03:00
d3m0k1d
211e019c68 fix: fix gosec and err checl
Some checks failed
build / build (push) Failing after 3m10s
2026-01-21 22:30:27 +03:00
d3m0k1d
de000ab5b6 fix: fix Nginx parser start without gourutine
Some checks failed
build / build (push) Failing after 2m54s
2026-01-21 21:40:57 +03:00
d3m0k1d
0fe34d1537 Fix: fix daemon logic
Some checks failed
build / build (push) Failing after 2m33s
2026-01-21 21:36:38 +03:00
d3m0k1d
341f49c4b4 feat: improve db logic and logger(untested)
Some checks failed
build / build (push) Failing after 2m48s
2026-01-21 20:44:28 +03:00
d3m0k1d
7522071a03 fix: fix tests
Some checks failed
build / build (push) Failing after 2m47s
2026-01-21 19:16:25 +03:00
d3m0k1d
4e8dc51ac8 chore: switch driver one more time ncruses -> modernc
Some checks failed
build / build (push) Has been cancelled
2026-01-21 19:15:35 +03:00
d3m0k1d
11453bd0d9 chore: delete old driver from deps
All checks were successful
build / build (push) Successful in 1m43s
CD - BanForge Release / release (push) Successful in 3m18s
2026-01-21 16:56:26 +03:00
d3m0k1d
f03ec114b1 fix: fix test, new driver 2026-01-21 16:56:07 +03:00
d3m0k1d
26f4f17760 fix: golangci --fix run
Some checks failed
build / build (push) Failing after 2m11s
2026-01-21 16:52:07 +03:00
d3m0k1d
3001282d88 chore: Add script for postinstall and postremove package, add scripts to gorealeaser
Some checks failed
build / build (push) Failing after 1m8s
2026-01-21 16:44:42 +03:00
d3m0k1d
9198f19805 chore: new driver for sqlite3 with wasm and edit goreleaser env to compiler CGO_ENABLED = 0, add empty file for postinstall script
Some checks failed
build / build (push) Failing after 1m15s
2026-01-21 16:26:55 +03:00
d3m0k1d
b6e92a2a57 feat: Add new logic for judge path with *
Some checks failed
build / build (push) Failing after 1m38s
2026-01-21 16:07:31 +03:00
d3m0k1d
16a174cf56 Fix: fix init nftables, fix logic ban/unban command
All checks were successful
build / build (push) Successful in 2m39s
2026-01-20 23:41:22 +03:00
40 changed files with 2175 additions and 697 deletions

View File

@@ -25,6 +25,8 @@ builds:
- arm64 - arm64
ldflags: ldflags:
- "-s -w" - "-s -w"
env:
- CGO_ENABLED=0
archives: archives:
- format: tar.gz - format: tar.gz
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
@@ -43,7 +45,9 @@ nfpms:
- rpm - rpm
- archlinux - archlinux
bindir: /usr/bin bindir: /usr/bin
scripts:
postinstall: build/postinstall.sh
postremove: build/postremove.sh
release: release:
gitea: gitea:
owner: d3m0k1d owner: d3m0k1d

View File

@@ -15,14 +15,15 @@ Log-based IPS system written in Go for Linux-based system.
# Overview # Overview
BanForge is a simple IPS for replacement fail2ban in Linux system. BanForge is a simple IPS for replacement fail2ban in Linux system.
The project is currently in its early stages of development. All release are available on my self-hosted [Gitea](https://gitea.d3m0k1d.ru/d3m0k1d/BanForge) after release v1.0.0 are available on Github release page.
All release are available on my self-hosted [Gitea](https://gitea.d3m0k1d.ru/d3m0k1d/BanForge) because Github has limits for Actions.
If you have any questions or suggestions, create issue on [Github](https://github.com/d3m0k1d/BanForge/issues). If you have any questions or suggestions, create issue on [Github](https://github.com/d3m0k1d/BanForge/issues).
## Roadmap ## Roadmap
- [x] Real-time Nginx log monitoring - [x] Rule system
- [ ] Add support for other service - [x] Nginx and Sshd support
- [ ] Add support for user service with regular expressions - [x] Working with ufw/iptables/nftables/firewalld
- [ ] Add support for most popular web-service
- [ ] User regexp for custom services
- [ ] TUI interface - [ ] TUI interface
# Requirements # Requirements
@@ -31,15 +32,79 @@ If you have any questions or suggestions, create issue on [Github](https://githu
- ufw/iptables/nftables/firewalld - ufw/iptables/nftables/firewalld
# Installation # Installation
Search for a release on the [Gitea](https://gitea.d3m0k1d.ru/d3m0k1d/BanForge/releases) releases page and download it. Then create or copy a systemd unit file. Search for a release on the [Gitea](https://gitea.d3m0k1d.ru/d3m0k1d/BanForge/releases) releases page and download it.
Or clone the repo and use the Makefile. In release page you can find rpm, deb, apk packages, for amd or arm architecture.
```
git clone https://gitea.d3m0k1d.ru/d3m0k1d/BanForge.git ## Installation guide for packages
cd BanForge
sudo make build-daemon ### Debian/Ubuntu(.deb)
cd bin ```bash
# Download the latest DEB package
wget https://gitea.d3m0k1d.ru/d3m0k1d/BanForge/releases/download/v0.4.0/banforge_0.4.0_linux_amd64.deb
# Install
sudo dpkg -i banforge_0.4.0_linux_amd64.deb
# Verify installation
sudo systemctl status banforge
``` ```
### RHEL-based(.rpm)
```bash
# Download
wget https://gitea.d3m0k1d.ru/d3m0k1d/BanForge/releases/download/v0.4.0/banforge_0.4.0_linux_amd64.rpm
# Install
sudo rpm -i banforge_0.4.0_linux_amd64.rpm
# Or with dnf (CentOS 8+, AlmaLinux)
sudo dnf install banforge_0.4.0_linux_amd64.rpm
# Verify
sudo systemctl status banforge
```
### Alpine(.apk)
```bash
# Download
wget https://gitea.d3m0k1d.ru/d3m0k1d/BanForge/releases/download/v0.4.0/banforge_0.4.0_linux_amd64.apk
# Install
sudo apk add --allow-untrusted banforge_0.4.0_linux_amd64.apk
# Verify
sudo rc-service banforge status
```
### Arch Linux(.pkg.tar.zst)
```bash
# Download
wget https://gitea.d3m0k1d.ru/d3m0k1d/BanForge/releases/download/v0.4.0/banforge_0.4.0_linux_amd64.pkg.tar.zst
# Install
sudo pacman -U banforge_0.4.0_linux_amd64.pkg.tar.zst
# Verify
sudo systemctl status banforge
```
This is examples for other versions with different architecture or new versions check release page on [Gitea](https://gitea.d3m0k1d.ru/d3m0k1d/BanForge/releases).
## Installation guide for source code
```bash
# Download
git clone https://github.com/d3m0k1d/BanForge.git
cd BanForge
make build-daemon
cd bin
mv banforge /usr/bin/banforge
cd ..
# Add init script and uses banforge init
cd build
./postinstall.sh
```
# Usage # Usage
For first steps use this commands For first steps use this commands
```bash ```bash

61
build/postinstall.sh Normal file
View File

@@ -0,0 +1,61 @@
#!/bin/sh
if command -v systemctl >/dev/null 2>&1; then
# for systemd based systems
banforge init
cat > /etc/systemd/system/banforge.service << 'EOF'
[Unit]
Description=BanForge - IPS log based system
After=network-online.target
Wants=network-online.target
Documentation=https://github.com/d3m0k1d/BanForge
[Service]
Type=simple
ExecStart=/usr/local/bin/banforge daemon
User=root
Group=root
Restart=always
StandardOutput=journal
StandardError=journal
SyslogIdentifier=banforge
TimeoutStopSec=90
KillSignal=SIGTERM
[Install]
WantedBy=multi-user.target
EOF
chmod 644 /etc/systemd/system/banforge.service
systemctl daemon-reload
systemctl enable banforge
fi
if command -v rc-service >/dev/null 2>&1; then
# for openrc based systems
banforge init
cat > /etc/init.d/banforge << 'EOF'
#!/sbin/openrc-run
description="BanForge - IPS log based system"
command="/usr/bin/banforge"
command_args="daemon"
pidfile="/run/${RC_SVCNAME}.pid"
command_background="yes"
depend() {
need net
after network
}
start_post() {
einfo "BanForge is now running"
}
stop_post() {
einfo "BanForge is now stopped"
}
EOF
chmod 755 /etc/init.d/banforge
rc-update add banforge
fi

20
build/postremove.sh Normal file
View File

@@ -0,0 +1,20 @@
#!/bin/sh
if command -v systemctl >/dev/null 2>&1; then
# for systemd based systems
systemctl stop banforge 2>/dev/null || true
systemctl disable banforge 2>/dev/null || true
rm -f /etc/systemd/system/banforge.service
systemctl daemon-reload
fi
if command -v rc-service >/dev/null 2>&1; then
# for openrc based systems
rc-service banforge stop 2>/dev/null || true
rc-update del banforge 2>/dev/null || true
rm -f /etc/init.d/banforge
fi
rm -rf /etc/banforge/
rm -rf /var/lib/banforge/
rm -rf /var/log/banforge/

View File

@@ -5,7 +5,6 @@ import (
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
"github.com/d3m0k1d/BanForge/internal/blocker" "github.com/d3m0k1d/BanForge/internal/blocker"
"github.com/d3m0k1d/BanForge/internal/config" "github.com/d3m0k1d/BanForge/internal/config"
@@ -20,17 +19,38 @@ var DaemonCmd = &cobra.Command{
Use: "daemon", Use: "daemon",
Short: "Run BanForge daemon process", Short: "Run BanForge daemon process",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
entryCh := make(chan *storage.LogEntry, 1000)
resultCh := make(chan *storage.LogEntry, 100)
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop() defer stop()
log := logger.New(false) log := logger.New(false)
log.Info("Starting BanForge daemon") log.Info("Starting BanForge daemon")
db, err := storage.NewDB() reqDb_w, err := storage.NewRequestsWr()
if err != nil { if err != nil {
log.Error("Failed to create database", "error", err) log.Error("Failed to create request writer", "error", err)
os.Exit(1)
}
reqDb_r, err := storage.NewRequestsRd()
if err != nil {
log.Error("Failed to create request reader", "error", err)
os.Exit(1)
}
banDb_r, err := storage.NewBanReader()
if err != nil {
log.Error("Failed to create ban reader", "error", err)
os.Exit(1)
}
banDb_w, err := storage.NewBanWriter()
if err != nil {
log.Error("Failed to create ban writter", "error", err)
os.Exit(1) os.Exit(1)
} }
defer func() { defer func() {
err = db.Close() err = banDb_r.Close()
if err != nil {
log.Error("Failed to close database connection", "error", err)
}
err = banDb_w.Close()
if err != nil { if err != nil {
log.Error("Failed to close database connection", "error", err) log.Error("Failed to close database connection", "error", err)
} }
@@ -40,6 +60,11 @@ var DaemonCmd = &cobra.Command{
log.Error("Failed to load config", "error", err) log.Error("Failed to load config", "error", err)
os.Exit(1) os.Exit(1)
} }
_, err = config.LoadMetricsConfig()
if err != nil {
log.Error("Failed to load metrics config", "error", err)
os.Exit(1)
}
var b blocker.BlockerEngine var b blocker.BlockerEngine
fw := cfg.Firewall.Name fw := cfg.Firewall.Name
b = blocker.GetBlocker(fw, cfg.Firewall.Config) b = blocker.GetBlocker(fw, cfg.Firewall.Config)
@@ -48,19 +73,11 @@ var DaemonCmd = &cobra.Command{
log.Error("Failed to load rules", "error", err) log.Error("Failed to load rules", "error", err)
os.Exit(1) os.Exit(1)
} }
j := judge.New(db, b) j := judge.New(banDb_r, banDb_w, reqDb_r, b, resultCh, entryCh)
j.LoadRules(r) j.LoadRules(r)
go j.UnbanChecker() go j.UnbanChecker()
go func() { go j.Tribunal()
ticker := time.NewTicker(5 * time.Second) go storage.WriteReq(reqDb_w, resultCh)
defer ticker.Stop()
for range ticker.C {
if err := j.ProcessUnviewed(); err != nil {
log.Error("Failed to process unviewed", "error", err)
}
}
}()
var scanners []*parser.Scanner var scanners []*parser.Scanner
for _, svc := range cfg.Service { for _, svc := range cfg.Service {
@@ -98,16 +115,17 @@ var DaemonCmd = &cobra.Command{
if svc.Name == "nginx" { if svc.Name == "nginx" {
log.Info("Starting nginx parser", "service", serviceName) log.Info("Starting nginx parser", "service", serviceName)
ng := parser.NewNginxParser() ng := parser.NewNginxParser()
resultCh := make(chan *storage.LogEntry, 100) ng.Parse(p.Events(), entryCh)
ng.Parse(p.Events(), resultCh)
go storage.Write(db, resultCh)
} }
if svc.Name == "ssh" { if svc.Name == "ssh" {
log.Info("Starting ssh parser", "service", serviceName) log.Info("Starting ssh parser", "service", serviceName)
ssh := parser.NewSshdParser() ssh := parser.NewSshdParser()
resultCh := make(chan *storage.LogEntry, 100) ssh.Parse(p.Events(), entryCh)
ssh.Parse(p.Events(), resultCh) }
go storage.Write(db, resultCh) if svc.Name == "apache" {
log.Info("Starting apache parser", "service", serviceName)
ap := parser.NewApacheParser()
ap.Parse(p.Events(), entryCh)
} }
}(pars, svc.Name) }(pars, svc.Name)
continue continue
@@ -128,16 +146,18 @@ var DaemonCmd = &cobra.Command{
if svc.Name == "nginx" { if svc.Name == "nginx" {
log.Info("Starting nginx parser", "service", serviceName) log.Info("Starting nginx parser", "service", serviceName)
ng := parser.NewNginxParser() ng := parser.NewNginxParser()
resultCh := make(chan *storage.LogEntry, 100) ng.Parse(p.Events(), entryCh)
ng.Parse(p.Events(), resultCh)
go storage.Write(db, resultCh)
} }
if svc.Name == "ssh" { if svc.Name == "ssh" {
log.Info("Starting ssh parser", "service", serviceName) log.Info("Starting ssh parser", "service", serviceName)
ssh := parser.NewSshdParser() ssh := parser.NewSshdParser()
resultCh := make(chan *storage.LogEntry, 100) ssh.Parse(p.Events(), entryCh)
ssh.Parse(p.Events(), resultCh) }
go storage.Write(db, resultCh) if svc.Name == "apache" {
log.Info("Starting apache parser", "service", serviceName)
ap := parser.NewApacheParser()
ap.Parse(p.Events(), entryCh)
} }
}(pars, svc.Name) }(pars, svc.Name)

View File

@@ -7,41 +7,62 @@ import (
"github.com/d3m0k1d/BanForge/internal/blocker" "github.com/d3m0k1d/BanForge/internal/blocker"
"github.com/d3m0k1d/BanForge/internal/config" "github.com/d3m0k1d/BanForge/internal/config"
"github.com/d3m0k1d/BanForge/internal/storage"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( var (
ip string ttl_fw string
port int
protocol string
) )
var UnbanCmd = &cobra.Command{ var UnbanCmd = &cobra.Command{
Use: "unban", Use: "unban",
Short: "Unban IP", Short: "Unban IP",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
cfg, err := config.LoadConfig() err := func() error {
if len(args) == 0 {
return fmt.Errorf("IP can't be empty")
}
if ttl_fw == "" {
ttl_fw = "1y"
}
ip := args[0]
db, err := storage.NewBanWriter()
if err != nil {
return err
}
cfg, err := config.LoadConfig()
if err != nil {
return err
}
fw := cfg.Firewall.Name
b := blocker.GetBlocker(fw, cfg.Firewall.Config)
if ip == "" {
return fmt.Errorf("IP can't be empty")
}
if net.ParseIP(ip) == nil {
return fmt.Errorf("invalid IP")
}
if err != nil {
return err
}
err = b.Unban(ip)
if err != nil {
return err
}
err = db.RemoveBan(ip)
if err != nil {
return err
}
fmt.Println("IP unblocked successfully!")
return nil
}()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
fw := cfg.Firewall.Name
b := blocker.GetBlocker(fw, cfg.Firewall.Config)
if ip == "" {
fmt.Println("IP can't be empty")
os.Exit(1)
}
if net.ParseIP(ip) == nil {
fmt.Println("Invalid IP")
os.Exit(1)
}
if err != nil {
fmt.Println(err)
os.Exit(1)
}
err = b.Unban(ip)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
fmt.Println("IP unblocked successfully!")
}, },
} }
@@ -49,7 +70,64 @@ var BanCmd = &cobra.Command{
Use: "ban", Use: "ban",
Short: "Ban IP", Short: "Ban IP",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
err := func() error {
if len(args) == 0 {
return fmt.Errorf("IP can't be empty")
}
if ttl_fw == "" {
ttl_fw = "1y"
}
ip := args[0]
db, err := storage.NewBanWriter()
if err != nil {
return err
}
cfg, err := config.LoadConfig()
if err != nil {
return err
}
fw := cfg.Firewall.Name
b := blocker.GetBlocker(fw, cfg.Firewall.Config)
if ip == "" {
return fmt.Errorf("IP can't be empty")
}
if net.ParseIP(ip) == nil {
return fmt.Errorf("invalid IP")
}
if err != nil {
return err
}
err = b.Ban(ip)
if err != nil {
return err
}
err = db.AddBan(ip, ttl_fw, "manual ban")
if err != nil {
return err
}
fmt.Println("IP blocked successfully!")
return nil
}()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
},
}
var PortCmd = &cobra.Command{
Use: "port",
Short: "Ports commands",
}
var PortOpenCmd = &cobra.Command{
Use: "open",
Short: "Open ports on firewall",
Run: func(cmd *cobra.Command, args []string) {
if protocol == "" {
fmt.Println("Protocol can't be empty")
os.Exit(1)
}
cfg, err := config.LoadConfig() cfg, err := config.LoadConfig()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -57,28 +135,45 @@ var BanCmd = &cobra.Command{
} }
fw := cfg.Firewall.Name fw := cfg.Firewall.Name
b := blocker.GetBlocker(fw, cfg.Firewall.Config) b := blocker.GetBlocker(fw, cfg.Firewall.Config)
if ip == "" { err = b.PortOpen(port, protocol)
fmt.Println("IP can't be empty")
os.Exit(1)
}
if net.ParseIP(ip) == nil {
fmt.Println("Invalid IP")
os.Exit(1)
}
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
err = b.Ban(ip) fmt.Println("Port opened successfully!")
},
}
var PortCloseCmd = &cobra.Command{
Use: "close",
Short: "Close ports on firewall",
Run: func(cmd *cobra.Command, args []string) {
if protocol == "" {
fmt.Println("Protocol can't be empty")
os.Exit(1)
}
cfg, err := config.LoadConfig()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
fmt.Println("IP unblocked successfully!") fw := cfg.Firewall.Name
b := blocker.GetBlocker(fw, cfg.Firewall.Config)
err = b.PortClose(port, protocol)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
fmt.Println("Port closed successfully!")
}, },
} }
func FwRegister() { func FwRegister() {
BanCmd.Flags().StringVarP(&ip, "ip", "i", "", "ip to ban") BanCmd.Flags().StringVarP(&ttl_fw, "ttl", "t", "", "ban time")
UnbanCmd.Flags().StringVarP(&ip, "ip", "i", "", "ip to unban") PortCmd.AddCommand(PortOpenCmd)
PortCmd.AddCommand(PortCloseCmd)
PortOpenCmd.Flags().IntVarP(&port, "port", "p", 0, "port number")
PortOpenCmd.Flags().StringVarP(&protocol, "protocol", "c", "", "protocol")
PortCloseCmd.Flags().IntVarP(&port, "port", "p", 0, "port number")
PortCloseCmd.Flags().StringVarP(&protocol, "protocol", "c", "", "protocol")
} }

View File

@@ -82,23 +82,11 @@ var InitCmd = &cobra.Command{
} }
fmt.Println("Firewall configured") fmt.Println("Firewall configured")
db, err := storage.NewDB() err = storage.CreateTables()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
err = db.CreateTable()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
defer func() {
err = db.Close()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
}()
fmt.Println("Firewall detected and configured") fmt.Println("Firewall detected and configured")
fmt.Println("BanForge initialized successfully!") fmt.Println("BanForge initialized successfully!")

View File

@@ -13,7 +13,7 @@ var BanListCmd = &cobra.Command{
Short: "List banned IP adresses", Short: "List banned IP adresses",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
var log = logger.New(false) var log = logger.New(false)
d, err := storage.NewDB() d, err := storage.NewBanReader()
if err != nil { if err != nil {
log.Error("Failed to create database", "error", err) log.Error("Failed to create database", "error", err)
os.Exit(1) os.Exit(1)

View File

@@ -9,12 +9,13 @@ import (
) )
var ( var (
name string name string
service string service string
path string path string
status string status string
method string method string
ttl string ttl string
max_retry int
) )
var RuleCmd = &cobra.Command{ var RuleCmd = &cobra.Command{
@@ -41,7 +42,7 @@ var AddCmd = &cobra.Command{
if ttl == "" { if ttl == "" {
ttl = "1y" ttl = "1y"
} }
err := config.NewRule(name, service, path, status, method, ttl) err := config.NewRule(name, service, path, status, method, ttl, max_retry)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@@ -61,11 +62,12 @@ var ListCmd = &cobra.Command{
} }
for _, rule := range r { for _, rule := range r {
fmt.Printf( fmt.Printf(
"Name: %s\nService: %s\nPath: %s\nStatus: %s\nMethod: %s\n\n", "Name: %s\nService: %s\nPath: %s\nStatus: %s\n MaxRetry: %d\nMethod: %s\n\n",
rule.Name, rule.Name,
rule.ServiceName, rule.ServiceName,
rule.Path, rule.Path,
rule.Status, rule.Status,
rule.MaxRetry,
rule.Method, rule.Method,
) )
} }
@@ -81,4 +83,5 @@ func RuleRegister() {
AddCmd.Flags().StringVarP(&status, "status", "c", "", "status code") AddCmd.Flags().StringVarP(&status, "status", "c", "", "status code")
AddCmd.Flags().StringVarP(&method, "method", "m", "", "method") AddCmd.Flags().StringVarP(&method, "method", "m", "", "method")
AddCmd.Flags().StringVarP(&ttl, "ttl", "t", "", "ban time") AddCmd.Flags().StringVarP(&ttl, "ttl", "t", "", "ban time")
AddCmd.Flags().IntVarP(&max_retry, "max_retry", "r", 0, "max retry")
} }

View File

@@ -0,0 +1,17 @@
package command
import (
"fmt"
"github.com/spf13/cobra"
)
var version = "0.5.2"
var VersionCmd = &cobra.Command{
Use: "version",
Short: "BanForge version",
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("BanForge version:", version)
},
}

View File

@@ -13,7 +13,6 @@ var rootCmd = &cobra.Command{
Use: "banforge", Use: "banforge",
Short: "IPS log-based written on Golang", Short: "IPS log-based written on Golang",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
}, },
} }
@@ -28,6 +27,8 @@ func Execute() {
rootCmd.AddCommand(command.BanCmd) rootCmd.AddCommand(command.BanCmd)
rootCmd.AddCommand(command.UnbanCmd) rootCmd.AddCommand(command.UnbanCmd)
rootCmd.AddCommand(command.BanListCmd) rootCmd.AddCommand(command.BanListCmd)
rootCmd.AddCommand(command.VersionCmd)
rootCmd.AddCommand(command.PortCmd)
command.RuleRegister() command.RuleRegister()
command.FwRegister() command.FwRegister()
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {

View File

@@ -11,6 +11,16 @@ banforge init
**Description** **Description**
This command creates the necessary directories and base configuration files This command creates the necessary directories and base configuration files
required for the daemon to operate. required for the daemon to operate.
### version - Display BanForge version
```shell
banforge version
```
**Description**
This command displays the current version of the BanForge software.
### daemon - Starts the BanForge daemon process ### daemon - Starts the BanForge daemon process
```shell ```shell
@@ -31,6 +41,18 @@ banforge unban <ip>
**Description** **Description**
These commands provide an abstraction over your firewall. If you want to simplify the interface to your firewall, you can use these commands. These commands provide an abstraction over your firewall. If you want to simplify the interface to your firewall, you can use these commands.
Flag -t or -ttl add bantime if not used default ban 1 year
### ports - Open and Close ports on firewall
```shell
banforge open -port <port> -protocol <protocol>
banforge close -port <port> -protocol <protocol>
```
**Description**
These commands provide an abstraction over your firewall. If you want to simplify the interface to your firewall, you can use these commands.
### list - Lists the IP addresses that are currently blocked ### list - Lists the IP addresses that are currently blocked
```shell ```shell
banforge list banforge list
@@ -57,5 +79,6 @@ These command help you to create and manage detection rules in CLI interface.
| -m -method | - | | -m -method | - |
| -c -status | - | | -c -status | - |
| -t -ttl | -(if not used default ban 1 year) | | -t -ttl | -(if not used default ban 1 year) |
| -r -max_retry | - |
You must specify at least 1 of the optional flags to create a rule. You must specify at least 1 of the optional flags to create a rule.

View File

@@ -40,9 +40,12 @@ Example:
service = "nginx" service = "nginx"
path = "" path = ""
status = "304" status = "304"
max_retry = 3
method = "" method = ""
ban_time = "1m" ban_time = "1m"
``` ```
**Description** **Description**
The [[rule]] section require name and one of the following parameters: service, path, status, method. To add a rule, create a [[rule]] block and specify the parameters. The [[rule]] section require name and one of the following parameters: service, path, status, method. To add a rule, create a [[rule]] block and specify the parameters.
ban_time require in format "1m", "1h", "1d", "1M", "1y" ban_time require in format "1m", "1h", "1d", "1M", "1y".
If you want to ban all requests to PHP files (e.g., path = "*.php") or requests to the admin panel (e.g., path = "/admin/*").
If max_retry = 0 ban on first request.

20
go.mod
View File

@@ -5,15 +5,25 @@ go 1.25.5
require ( require (
github.com/BurntSushi/toml v1.6.0 github.com/BurntSushi/toml v1.6.0
github.com/jedib0t/go-pretty/v6 v6.7.8 github.com/jedib0t/go-pretty/v6 v6.7.8
github.com/mattn/go-sqlite3 v1.14.33
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
gopkg.in/natefinch/lumberjack.v2 v2.2.1
modernc.org/sqlite v1.46.1
) )
require ( require (
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/mattn/go-runewidth v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/spf13/pflag v1.0.10 // indirect github.com/spf13/pflag v1.0.10 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect
golang.org/x/text v0.22.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
modernc.org/libc v1.68.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
) )

72
go.sum
View File

@@ -1,21 +1,32 @@
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jedib0t/go-pretty/v6 v6.7.8 h1:BVYrDy5DPBA3Qn9ICT+PokP9cvCv1KaHv2i+Hc8sr5o= github.com/jedib0t/go-pretty/v6 v6.7.8 h1:BVYrDy5DPBA3Qn9ICT+PokP9cvCv1KaHv2i+Hc8sr5o=
github.com/jedib0t/go-pretty/v6 v6.7.8/go.mod h1:YwC5CE4fJ1HFUDeivSV1r//AmANFHyqczZk+U6BDALU= github.com/jedib0t/go-pretty/v6 v6.7.8/go.mod h1:YwC5CE4fJ1HFUDeivSV1r//AmANFHyqczZk+U6BDALU=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjcQQaQ=
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
@@ -25,10 +36,49 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.2 h1:4yPaaq9dXYXZ2V8s1UgrC3KIj580l2N4ClrLwnbv2so=
modernc.org/ccgo/v4 v4.30.2/go.mod h1:yZMnhWEdW0qw3EtCndG1+ldRrVGS+bIwyWmAWzS0XEw=
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.68.0 h1:PJ5ikFOV5pwpW+VqCK1hKJuEWsonkIJhhIXyuF/91pQ=
modernc.org/libc v1.68.0/go.mod h1:NnKCYeoYgsEqnY3PgvNgAeaJnso968ygU8Z0DxjoEc0=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@@ -1,7 +1,9 @@
package blocker package blocker
import ( import (
"fmt"
"os/exec" "os/exec"
"strconv"
"github.com/d3m0k1d/BanForge/internal/logger" "github.com/d3m0k1d/BanForge/internal/logger"
) )
@@ -21,14 +23,15 @@ func (f *Firewalld) Ban(ip string) error {
if err != nil { if err != nil {
return err return err
} }
cmd := exec.Command("sudo", "firewall-cmd", "--zone=drop", "--add-source", ip, "--permanent") // #nosec G204 - ip is validated
cmd := exec.Command("firewall-cmd", "--zone=drop", "--add-source", ip, "--permanent")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
f.logger.Error(err.Error()) f.logger.Error(err.Error())
return err return err
} }
f.logger.Info("Add source " + ip + " " + string(output)) f.logger.Info("Add source " + ip + " " + string(output))
output, err = exec.Command("sudo", "firewall-cmd", "--reload").CombinedOutput() output, err = exec.Command("firewall-cmd", "--reload").CombinedOutput()
if err != nil { if err != nil {
f.logger.Error(err.Error()) f.logger.Error(err.Error())
return err return err
@@ -42,14 +45,15 @@ func (f *Firewalld) Unban(ip string) error {
if err != nil { if err != nil {
return err return err
} }
cmd := exec.Command("sudo", "firewall-cmd", "--zone=drop", "--remove-source", ip, "--permanent") // #nosec G204 - ip is validated
cmd := exec.Command("firewall-cmd", "--zone=drop", "--remove-source", ip, "--permanent")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
f.logger.Error(err.Error()) f.logger.Error(err.Error())
return err return err
} }
f.logger.Info("Remove source " + ip + " " + string(output)) f.logger.Info("Remove source " + ip + " " + string(output))
output, err = exec.Command("sudo", "firewall-cmd", "--reload").CombinedOutput() output, err = exec.Command("firewall-cmd", "--reload").CombinedOutput()
if err != nil { if err != nil {
f.logger.Error(err.Error()) f.logger.Error(err.Error())
return err return err
@@ -58,6 +62,63 @@ func (f *Firewalld) Unban(ip string) error {
return nil return nil
} }
func (f *Firewalld) PortOpen(port int, protocol string) error {
// #nosec G204 - handle is extracted from Firewalld output and validated
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
f.logger.Error("invalid protocol")
return fmt.Errorf("invalid protocol")
}
s := strconv.Itoa(port)
cmd := exec.Command(
"firewall-cmd",
"--zone=public",
"--add-port="+s+"/"+protocol,
"--permanent",
)
output, err := cmd.CombinedOutput()
if err != nil {
f.logger.Error(err.Error())
return err
}
f.logger.Info("Add port " + s + " " + string(output))
output, err = exec.Command("firewall-cmd", "--reload").CombinedOutput()
if err != nil {
f.logger.Error(err.Error())
return err
}
f.logger.Info("Reload " + string(output))
}
return nil
}
func (f *Firewalld) PortClose(port int, protocol string) error {
// #nosec G204 - handle is extracted from nftables output and validated
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
return fmt.Errorf("invalid protocol")
}
s := strconv.Itoa(port)
cmd := exec.Command(
"firewall-cmd",
"--zone=public",
"--remove-port="+s+"/"+protocol,
"--permanent",
)
output, err := cmd.CombinedOutput()
if err != nil {
return err
}
f.logger.Info("Remove port " + s + " " + string(output))
output, err = exec.Command("firewall-cmd", "--reload").CombinedOutput()
if err != nil {
return err
}
f.logger.Info("Reload " + string(output))
}
return nil
}
func (f *Firewalld) Setup(config string) error { func (f *Firewalld) Setup(config string) error {
return nil return nil
} }

View File

@@ -10,6 +10,8 @@ type BlockerEngine interface {
Ban(ip string) error Ban(ip string) error
Unban(ip string) error Unban(ip string) error
Setup(config string) error Setup(config string) error
PortOpen(port int, protocol string) error
PortClose(port int, protocol string) error
} }
func GetBlocker(fw string, config string) BlockerEngine { func GetBlocker(fw string, config string) BlockerEngine {

View File

@@ -2,6 +2,7 @@ package blocker
import ( import (
"os/exec" "os/exec"
"strconv"
"github.com/d3m0k1d/BanForge/internal/logger" "github.com/d3m0k1d/BanForge/internal/logger"
) )
@@ -27,7 +28,8 @@ func (f *Iptables) Ban(ip string) error {
if err != nil { if err != nil {
return err return err
} }
cmd := exec.Command("sudo", "iptables", "-A", "INPUT", "-s", ip, "-j", "DROP") // #nosec G204 - f.config is validated above via validateConfigPath()
cmd := exec.Command("iptables", "-A", "INPUT", "-s", ip, "-j", "DROP")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
f.logger.Error("failed to ban IP", f.logger.Error("failed to ban IP",
@@ -45,7 +47,7 @@ func (f *Iptables) Ban(ip string) error {
return err return err
} }
// #nosec G204 - f.config is validated above via validateConfigPath() // #nosec G204 - f.config is validated above via validateConfigPath()
cmd = exec.Command("sudo", "iptables-save", "-f", f.config) cmd = exec.Command("iptables-save", "-f", f.config)
output, err = cmd.CombinedOutput() output, err = cmd.CombinedOutput()
if err != nil { if err != nil {
f.logger.Error("failed to save config", f.logger.Error("failed to save config",
@@ -69,7 +71,8 @@ func (f *Iptables) Unban(ip string) error {
if err != nil { if err != nil {
return err return err
} }
cmd := exec.Command("sudo", "iptables", "-D", "INPUT", "-s", ip, "-j", "DROP") // #nosec G204 - f.config is validated above via validateConfigPath()
cmd := exec.Command("iptables", "-D", "INPUT", "-s", ip, "-j", "DROP")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
f.logger.Error("failed to unban IP", f.logger.Error("failed to unban IP",
@@ -87,7 +90,7 @@ func (f *Iptables) Unban(ip string) error {
return err return err
} }
// #nosec G204 - f.config is validated above via validateConfigPath() // #nosec G204 - f.config is validated above via validateConfigPath()
cmd = exec.Command("sudo", "iptables-save", "-f", f.config) cmd = exec.Command("iptables-save", "-f", f.config)
output, err = cmd.CombinedOutput() output, err = cmd.CombinedOutput()
if err != nil { if err != nil {
f.logger.Error("failed to save config", f.logger.Error("failed to save config",
@@ -102,6 +105,64 @@ func (f *Iptables) Unban(ip string) error {
return nil return nil
} }
func (f *Iptables) PortOpen(port int, protocol string) error {
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
f.logger.Error("invalid protocol")
return nil
}
s := strconv.Itoa(port)
// #nosec G204 - managed by system adminstartor
cmd := exec.Command("iptables", "-A", "INPUT", "-p", protocol, "--dport", s, "-j", "ACCEPT")
output, err := cmd.CombinedOutput()
if err != nil {
f.logger.Error(err.Error())
return err
}
f.logger.Info("Add port " + s + " " + string(output))
// #nosec G204 - f.config is validated above via validateConfigPath()
cmd = exec.Command("iptables-save", "-f", f.config)
output, err = cmd.CombinedOutput()
if err != nil {
f.logger.Error("failed to save config",
"config_path", f.config,
"error", err.Error(),
"output", string(output))
return err
}
}
return nil
}
func (f *Iptables) PortClose(port int, protocol string) error {
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
f.logger.Error("invalid protocol")
return nil
}
s := strconv.Itoa(port)
// #nosec G204 - managed by system adminstartor
cmd := exec.Command("iptables", "-D", "INPUT", "-p", protocol, "--dport", s, "-j", "ACCEPT")
output, err := cmd.CombinedOutput()
if err != nil {
f.logger.Error(err.Error())
return err
}
f.logger.Info("Add port " + s + " " + string(output))
// #nosec G204 - f.config is validated above via validateConfigPath()
cmd = exec.Command("iptables-save", "-f", f.config)
output, err = cmd.CombinedOutput()
if err != nil {
f.logger.Error("failed to save config",
"config_path", f.config,
"error", err.Error(),
"output", string(output))
return err
}
}
return nil
}
func (f *Iptables) Setup(config string) error { func (f *Iptables) Setup(config string) error {
return nil return nil
} }

View File

@@ -3,6 +3,7 @@ package blocker
import ( import (
"fmt" "fmt"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"github.com/d3m0k1d/BanForge/internal/logger" "github.com/d3m0k1d/BanForge/internal/logger"
@@ -25,8 +26,8 @@ func (n *Nftables) Ban(ip string) error {
if err != nil { if err != nil {
return err return err
} }
// #nosec G204 - ip is validated
cmd := exec.Command("sudo", "nft", "add", "rule", "inet", "banforge", "banned", cmd := exec.Command("nft", "add", "rule", "inet", "banforge", "banned",
"ip", "saddr", ip, "drop") "ip", "saddr", ip, "drop")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
@@ -70,7 +71,7 @@ func (n *Nftables) Unban(ip string) error {
return fmt.Errorf("no rule found for IP %s", ip) return fmt.Errorf("no rule found for IP %s", ip)
} }
// #nosec G204 - handle is extracted from nftables output and validated // #nosec G204 - handle is extracted from nftables output and validated
cmd := exec.Command("sudo", "nft", "delete", "rule", "inet", "banforge", "banned", cmd := exec.Command("nft", "delete", "rule", "inet", "banforge", "banned",
"handle", handle) "handle", handle)
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
@@ -104,16 +105,16 @@ func (n *Nftables) Setup(config string) error {
nftConfig := `table inet banforge { nftConfig := `table inet banforge {
chain input { chain input {
type filter hook input priority 0 type filter hook input priority filter; policy accept;
policy accept jump banned
} }
chain banned { chain banned {
} }
} }
` `
// #nosec G204 - config is managed by adminstartor
cmd := exec.Command("sudo", "tee", config) cmd := exec.Command("tee", config)
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
return fmt.Errorf("failed to create stdin pipe: %w", err) return fmt.Errorf("failed to create stdin pipe: %w", err)
@@ -135,8 +136,8 @@ func (n *Nftables) Setup(config string) error {
if err = cmd.Wait(); err != nil { if err = cmd.Wait(); err != nil {
return fmt.Errorf("failed to save config: %w", err) return fmt.Errorf("failed to save config: %w", err)
} }
// #nosec G204 - config is managed by adminstartor
cmd = exec.Command("sudo", "nft", "-f", config) cmd = exec.Command("nft", "-f", config)
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return fmt.Errorf("failed to load nftables config: %s", string(output)) return fmt.Errorf("failed to load nftables config: %s", string(output))
@@ -146,7 +147,7 @@ func (n *Nftables) Setup(config string) error {
} }
func (n *Nftables) findRuleHandle(ip string) (string, error) { func (n *Nftables) findRuleHandle(ip string) (string, error) {
cmd := exec.Command("sudo", "nft", "-a", "list", "chain", "inet", "banforge", "banned") cmd := exec.Command("nft", "-a", "list", "chain", "inet", "banforge", "banned")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to list chain rules: %w", err) return "", fmt.Errorf("failed to list chain rules: %w", err)
@@ -167,19 +168,94 @@ func (n *Nftables) findRuleHandle(ip string) (string, error) {
return "", nil return "", nil
} }
func (n *Nftables) PortOpen(port int, protocol string) error {
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
n.logger.Error("invalid protocol")
return fmt.Errorf("invalid protocol")
}
s := strconv.Itoa(port)
// #nosec G204 - managed by system adminstartor
cmd := exec.Command(
"nft",
"add",
"rule",
"inet",
"banforge",
"input",
protocol,
"dport",
s,
"accept",
)
output, err := cmd.CombinedOutput()
if err != nil {
n.logger.Error(err.Error())
return err
}
n.logger.Info("Add port " + s + " " + string(output))
err = saveNftablesConfig(n.config)
if err != nil {
n.logger.Error("failed to save config",
"config_path", n.config,
"error", err.Error())
return err
}
}
return nil
}
func (n *Nftables) PortClose(port int, protocol string) error {
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
n.logger.Error("invalid protocol")
return fmt.Errorf("invalid protocol")
}
s := strconv.Itoa(port)
// #nosec G204 - managed by system adminstartor
cmd := exec.Command(
"nft",
"add",
"rule",
"inet",
"banforge",
"input",
protocol,
"dport",
s,
"drop",
)
output, err := cmd.CombinedOutput()
if err != nil {
n.logger.Error(err.Error())
return err
}
n.logger.Info("Add port " + s + " " + string(output))
err = saveNftablesConfig(n.config)
if err != nil {
n.logger.Error("failed to save config",
"config_path", n.config,
"error", err.Error())
return err
}
}
return nil
}
func saveNftablesConfig(configPath string) error { func saveNftablesConfig(configPath string) error {
err := validateConfigPath(configPath) err := validateConfigPath(configPath)
if err != nil { if err != nil {
return err return err
} }
cmd := exec.Command("sudo", "nft", "list", "ruleset") cmd := exec.Command("nft", "list", "ruleset")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return fmt.Errorf("failed to get nftables ruleset: %w", err) return fmt.Errorf("failed to get nftables ruleset: %w", err)
} }
// #nosec G204 - managed by system adminstartor
cmd = exec.Command("sudo", "tee", configPath) cmd = exec.Command("tee", configPath)
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
return fmt.Errorf("failed to create stdin pipe: %w", err) return fmt.Errorf("failed to create stdin pipe: %w", err)

View File

@@ -3,6 +3,7 @@ package blocker
import ( import (
"fmt" "fmt"
"os/exec" "os/exec"
"strconv"
"github.com/d3m0k1d/BanForge/internal/logger" "github.com/d3m0k1d/BanForge/internal/logger"
) )
@@ -22,8 +23,8 @@ func (u *Ufw) Ban(ip string) error {
if err != nil { if err != nil {
return err return err
} }
// #nosec G204 - ip is validated
cmd := exec.Command("sudo", "ufw", "--force", "deny", "from", ip) cmd := exec.Command("ufw", "--force", "deny", "from", ip)
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
u.logger.Error("failed to ban IP", u.logger.Error("failed to ban IP",
@@ -41,8 +42,8 @@ func (u *Ufw) Unban(ip string) error {
if err != nil { if err != nil {
return err return err
} }
// #nosec G204 - ip is validated
cmd := exec.Command("sudo", "ufw", "--force", "delete", "deny", "from", ip) cmd := exec.Command("ufw", "--force", "delete", "deny", "from", ip)
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
u.logger.Error("failed to unban IP", u.logger.Error("failed to unban IP",
@@ -56,10 +57,48 @@ func (u *Ufw) Unban(ip string) error {
return nil return nil
} }
func (u *Ufw) PortOpen(port int, protocol string) error {
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
u.logger.Error("invalid protocol")
return fmt.Errorf("invalid protocol")
}
s := strconv.Itoa(port)
// #nosec G204 - managed by system adminstartor
cmd := exec.Command("ufw", "allow", s+"/"+protocol)
output, err := cmd.CombinedOutput()
if err != nil {
u.logger.Error(err.Error())
return err
}
u.logger.Info("Add port " + s + " " + string(output))
}
return nil
}
func (u *Ufw) PortClose(port int, protocol string) error {
if port >= 0 && port <= 65535 {
if protocol != "tcp" && protocol != "udp" {
u.logger.Error("invalid protocol")
return nil
}
s := strconv.Itoa(port)
// #nosec G204 - managed by system adminstartor
cmd := exec.Command("ufw", "deny", s+"/"+protocol)
output, err := cmd.CombinedOutput()
if err != nil {
u.logger.Error(err.Error())
return err
}
u.logger.Info("Add port " + s + " " + string(output))
}
return nil
}
func (u *Ufw) Setup(config string) error { func (u *Ufw) Setup(config string) error {
if config != "" { if config != "" {
fmt.Printf("Ufw dont support config file\n") fmt.Printf("Ufw dont support config file\n")
cmd := exec.Command("sudo", "ufw", "enable") cmd := exec.Command("ufw", "enable")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
u.logger.Error("failed to enable ufw", u.logger.Error("failed to enable ufw",
@@ -69,7 +108,7 @@ func (u *Ufw) Setup(config string) error {
} }
} }
if config == "" { if config == "" {
cmd := exec.Command("sudo", "ufw", "enable") cmd := exec.Command("ufw", "enable")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
u.logger.Error("failed to enable ufw", u.logger.Error("failed to enable ufw",

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@@ -9,8 +10,25 @@ import (
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/d3m0k1d/BanForge/internal/logger" "github.com/d3m0k1d/BanForge/internal/logger"
"github.com/d3m0k1d/BanForge/internal/metrics"
) )
func LoadMetricsConfig() (*Metrics, error) {
cfg := &Metrics{}
_, err := toml.DecodeFile("/etc/banforge/config.toml", cfg)
if err != nil {
return nil, fmt.Errorf("failed to decode config: %w", err)
}
if cfg.Enabled && cfg.Port > 0 && cfg.Port < 65535 {
go metrics.StartMetricsServer(cfg.Port)
} else if cfg.Enabled {
fmt.Println("Metrics enabled but port invalid, not starting server")
}
return cfg, nil
}
func LoadRuleConfig() ([]Rule, error) { func LoadRuleConfig() ([]Rule, error) {
log := logger.New(false) log := logger.New(false)
var cfg Rules var cfg Rules
@@ -32,6 +50,7 @@ func NewRule(
Status string, Status string,
Method string, Method string,
ttl string, ttl string,
max_retry int,
) error { ) error {
r, err := LoadRuleConfig() r, err := LoadRuleConfig()
if err != nil { if err != nil {
@@ -50,6 +69,7 @@ func NewRule(
Status: Status, Status: Status,
Method: Method, Method: Method,
BanTime: ttl, BanTime: ttl,
MaxRetry: max_retry,
}, },
) )
file, err := os.Create("/etc/banforge/rules.toml") file, err := os.Create("/etc/banforge/rules.toml")
@@ -57,13 +77,9 @@ func NewRule(
return err return err
} }
defer func() { defer func() {
err = file.Close() err = errors.Join(err, file.Close())
if err != nil {
fmt.Println(err)
}
}() }()
cfg := Rules{Rules: r} cfg := Rules{Rules: r}
err = toml.NewEncoder(file).Encode(cfg) err = toml.NewEncoder(file).Encode(cfg)
if err != nil { if err != nil {
return err return err
@@ -126,24 +142,24 @@ func EditRule(Name string, ServiceName string, Path string, Status string, Metho
} }
func ParseDurationWithYears(s string) (time.Duration, error) { func ParseDurationWithYears(s string) (time.Duration, error) {
if strings.HasSuffix(s, "y") { if ss, ok := strings.CutSuffix(s, "y"); ok {
years, err := strconv.Atoi(strings.TrimSuffix(s, "y")) years, err := strconv.Atoi(ss)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return time.Duration(years) * 365 * 24 * time.Hour, nil return time.Duration(years) * 365 * 24 * time.Hour, nil
} }
if strings.HasSuffix(s, "M") { if ss, ok := strings.CutSuffix(s, "M"); ok {
months, err := strconv.Atoi(strings.TrimSuffix(s, "M")) months, err := strconv.Atoi(ss)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return time.Duration(months) * 30 * 24 * time.Hour, nil return time.Duration(months) * 30 * 24 * time.Hour, nil
} }
if strings.HasSuffix(s, "d") { if ss, ok := strings.CutSuffix(s, "d"); ok {
days, err := strconv.Atoi(strings.TrimSuffix(s, "d")) days, err := strconv.Atoi(ss)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@@ -8,6 +8,10 @@ const Base_config = `
name = "" name = ""
config = "/etc/nftables.conf" config = "/etc/nftables.conf"
[metrics]
enabled = false
port = 2122
[[service]] [[service]]
name = "nginx" name = "nginx"
logging = "file" logging = "file"

View File

@@ -14,6 +14,7 @@ type Service struct {
type Config struct { type Config struct {
Firewall Firewall `toml:"firewall"` Firewall Firewall `toml:"firewall"`
Metrics Metrics `toml:"metrics"`
Service []Service `toml:"service"` Service []Service `toml:"service"`
} }
@@ -28,5 +29,11 @@ type Rule struct {
Path string `toml:"path"` Path string `toml:"path"`
Status string `toml:"status"` Status string `toml:"status"`
Method string `toml:"method"` Method string `toml:"method"`
MaxRetry int `toml:"max_retry"`
BanTime string `toml:"ban_time"` BanTime string `toml:"ban_time"`
} }
type Metrics struct {
Enabled bool `toml:"enabled"`
Port int `toml:"port"`
}

View File

@@ -2,6 +2,7 @@ package judge
import ( import (
"fmt" "fmt"
"strings"
"time" "time"
"github.com/d3m0k1d/BanForge/internal/blocker" "github.com/d3m0k1d/BanForge/internal/blocker"
@@ -11,18 +12,33 @@ import (
) )
type Judge struct { type Judge struct {
db *storage.DB db_r *storage.BanReader
db_w *storage.BanWriter
db_rq *storage.RequestReader
logger *logger.Logger logger *logger.Logger
Blocker blocker.BlockerEngine Blocker blocker.BlockerEngine
rulesByService map[string][]config.Rule rulesByService map[string][]config.Rule
entryCh chan *storage.LogEntry
resultCh chan *storage.LogEntry
} }
func New(db *storage.DB, b blocker.BlockerEngine) *Judge { func New(
db_r *storage.BanReader,
db_w *storage.BanWriter,
db_rq *storage.RequestReader,
b blocker.BlockerEngine,
resultCh chan *storage.LogEntry,
entryCh chan *storage.LogEntry,
) *Judge {
return &Judge{ return &Judge{
db: db, db_w: db_w,
db_r: db_r,
db_rq: db_rq,
logger: logger.New(false), logger: logger.New(false),
rulesByService: make(map[string][]config.Rule), rulesByService: make(map[string][]config.Rule),
Blocker: b, Blocker: b,
entryCh: entryCh,
resultCh: resultCh,
} }
} }
@@ -37,84 +53,90 @@ func (j *Judge) LoadRules(rules []config.Rule) {
j.logger.Info("Rules loaded and indexed by service") j.logger.Info("Rules loaded and indexed by service")
} }
func (j *Judge) ProcessUnviewed() error { func (j *Judge) Tribunal() {
rows, err := j.db.SearchUnViewed() j.logger.Info("Tribunal started")
if err != nil {
j.logger.Error(fmt.Sprintf("Failed to query database: %v", err)) for entry := range j.entryCh {
return err j.logger.Debug(
} "Processing entry",
defer func() { "ip",
err = rows.Close() entry.IP,
if err != nil { "service",
j.logger.Error(fmt.Sprintf("Failed to close database connection: %v", err)) entry.Service,
} "status",
}() entry.Status,
for rows.Next() {
var entry storage.LogEntry
err = rows.Scan(
&entry.ID,
&entry.Service,
&entry.IP,
&entry.Path,
&entry.Status,
&entry.Method,
&entry.IsViewed,
&entry.CreatedAt,
) )
if err != nil {
j.logger.Error(fmt.Sprintf("Failed to scan database row: %v", err)) rules, serviceExists := j.rulesByService[entry.Service]
if !serviceExists {
j.logger.Debug("No rules for service", "service", entry.Service)
continue continue
} }
rules, serviceExists := j.rulesByService[entry.Service] ruleMatched := false
if serviceExists { for _, rule := range rules {
for _, rule := range rules { methodMatch := rule.Method == "" || entry.Method == rule.Method
if (rule.Method == "" || entry.Method == rule.Method) && statusMatch := rule.Status == "" || entry.Status == rule.Status
(rule.Status == "" || entry.Status == rule.Status) && pathMatch := matchPath(entry.Path, rule.Path)
(rule.Path == "" || entry.Path == rule.Path) { if methodMatch && statusMatch && pathMatch {
ruleMatched = true
j.logger.Info( j.logger.Info("Rule matched", "rule", rule.Name, "ip", entry.IP)
fmt.Sprintf( j.resultCh <- entry
"Rule matched for IP: %s, Service: %s", banned, err := j.db_r.IsBanned(entry.IP)
entry.IP, if err != nil {
entry.Service, j.logger.Error("Failed to check ban status", "ip", entry.IP, "error", err)
),
)
ban_status, err := j.db.IsBanned(entry.IP)
if err != nil {
j.logger.Error(fmt.Sprintf("Failed to check ban status: %v", err))
return err
}
if !ban_status {
err = j.Blocker.Ban(entry.IP)
if err != nil {
j.logger.Error(fmt.Sprintf("Failed to ban IP: %v", err))
}
j.logger.Info(fmt.Sprintf("IP banned: %s", entry.IP))
err = j.db.AddBan(entry.IP, rule.BanTime)
if err != nil {
j.logger.Error(fmt.Sprintf("Failed to add ban: %v", err))
}
}
break break
} }
if banned {
j.logger.Info("IP already banned", "ip", entry.IP)
break
}
exceeded, err := j.db_rq.IsMaxRetryExceeded(entry.IP, rule.MaxRetry)
if err != nil {
j.logger.Error("Failed to check retry count", "ip", entry.IP, "error", err)
break
}
if !exceeded {
j.logger.Info("Max retry not exceeded", "ip", entry.IP)
break
}
err = j.db_w.AddBan(entry.IP, rule.BanTime, rule.Name)
if err != nil {
j.logger.Error(
"Failed to add ban to database",
"ip",
entry.IP,
"ban_time",
rule.BanTime,
"error",
err,
)
break
}
if err := j.Blocker.Ban(entry.IP); err != nil {
j.logger.Error("Failed to ban IP at firewall", "ip", entry.IP, "error", err)
break
}
j.logger.Info(
"IP banned successfully",
"ip",
entry.IP,
"rule",
rule.Name,
"ban_time",
rule.BanTime,
)
break
} }
} }
err = j.db.MarkAsViewed(entry.ID) if !ruleMatched {
if err != nil { j.logger.Debug("No rules matched", "ip", entry.IP, "service", entry.Service)
j.logger.Error(fmt.Sprintf("Failed to mark entry as viewed: %v", err))
} else {
j.logger.Info(fmt.Sprintf("Entry marked as viewed: ID=%d", entry.ID))
} }
} }
if err = rows.Err(); err != nil { j.logger.Info("Tribunal stopped - entryCh closed")
j.logger.Error(fmt.Sprintf("Error iterating rows: %v", err))
return err
}
return nil
} }
func (j *Judge) UnbanChecker() { func (j *Judge) UnbanChecker() {
@@ -122,7 +144,7 @@ func (j *Judge) UnbanChecker() {
defer tick.Stop() defer tick.Stop()
for range tick.C { for range tick.C {
ips, err := j.db.CheckExpiredBans() ips, err := j.db_w.RemoveExpiredBans()
if err != nil { if err != nil {
j.logger.Error(fmt.Sprintf("Failed to check expired bans: %v", err)) j.logger.Error(fmt.Sprintf("Failed to check expired bans: %v", err))
continue continue
@@ -130,10 +152,30 @@ func (j *Judge) UnbanChecker() {
for _, ip := range ips { for _, ip := range ips {
if err := j.Blocker.Unban(ip); err != nil { if err := j.Blocker.Unban(ip); err != nil {
j.logger.Error(fmt.Sprintf("Failed to unban IP %s: %v", ip, err)) j.logger.Error(fmt.Sprintf("Failed to unban IP at firewall: %v", err))
continue
} }
j.logger.Info(fmt.Sprintf("IP unbanned: %s", ip))
} }
} }
} }
func matchPath(path string, rulePath string) bool {
if rulePath == "" {
return true
}
if strings.HasPrefix(rulePath, "*") {
suffix := strings.TrimPrefix(rulePath, "*")
return strings.HasSuffix(path, suffix)
}
if strings.HasPrefix(rulePath, "/*") {
suffix := strings.TrimPrefix(rulePath, "/*")
return strings.HasSuffix(path, suffix)
}
if strings.HasSuffix(rulePath, "*") {
prefix := strings.TrimSuffix(rulePath, "*")
return strings.HasPrefix(path, prefix)
}
return path == rulePath
}

View File

@@ -18,21 +18,21 @@ func TestJudgeLogic(t *testing.T) {
{ {
name: "Empty rule", name: "Empty rule",
inputRule: config.Rule{Name: "", ServiceName: "", Path: "", Status: "", Method: ""}, inputRule: config.Rule{Name: "", ServiceName: "", Path: "", Status: "", Method: ""},
inputLog: storage.LogEntry{ID: 0, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, inputLog: storage.LogEntry{ID: 0, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""},
wantErr: true, wantErr: true,
wantMatch: false, wantMatch: false,
}, },
{ {
name: "Matching rule", name: "Matching rule",
inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "200", Method: "GET"}, inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "200", Method: "GET"},
inputLog: storage.LogEntry{ID: 1, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, inputLog: storage.LogEntry{ID: 1, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""},
wantErr: false, wantErr: false,
wantMatch: true, wantMatch: true,
}, },
{ {
name: "Non-matching status", name: "Non-matching status",
inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "404", Method: "GET"}, inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "404", Method: "GET"},
inputLog: storage.LogEntry{ID: 2, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, inputLog: storage.LogEntry{ID: 2, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", CreatedAt: ""},
wantErr: false, wantErr: false,
wantMatch: false, wantMatch: false,
}, },

View File

@@ -1,8 +1,12 @@
package logger package logger
import ( import (
"io"
"log/slog" "log/slog"
"os" "os"
"path/filepath"
"gopkg.in/natefinch/lumberjack.v2"
) )
type Logger struct { type Logger struct {
@@ -10,13 +14,28 @@ type Logger struct {
} }
func New(debug bool) *Logger { func New(debug bool) *Logger {
logDir := "/var/log/banforge"
if err := os.MkdirAll(logDir, 0750); err != nil {
return nil
}
fileWriter := &lumberjack.Logger{
Filename: filepath.Join(logDir, "banforge.log"),
MaxSize: 500,
MaxBackups: 3,
MaxAge: 28,
Compress: true,
}
var level slog.Level var level slog.Level
if debug { if debug {
level = slog.LevelDebug level = slog.LevelDebug
} else { } else {
level = slog.LevelInfo level = slog.LevelInfo
} }
handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ multiWriter := io.MultiWriter(fileWriter, os.Stdout)
handler := slog.NewTextHandler(multiWriter, &slog.HandlerOptions{
Level: level, Level: level,
}) })

View File

@@ -0,0 +1,77 @@
package metrics
import (
"fmt"
"log"
"net/http"
"sync"
"time"
)
var (
metricsMu sync.RWMutex
metrics = make(map[string]int64)
)
func IncBan(service string) {
metricsMu.Lock()
metrics["ban_count"]++
metrics[service+"_bans"]++
metricsMu.Unlock()
}
func IncUnban(service string) {
metricsMu.Lock()
metrics["unban_count"]++
metrics[service+"_unbans"]++
metricsMu.Unlock()
}
func IncRuleMatched(rule_name string) {
metricsMu.Lock()
metrics[rule_name+"_rule_matched"]++
metricsMu.Unlock()
}
func IncLogParsed() {
metricsMu.Lock()
metrics["log_parsed"]++
metricsMu.Unlock()
}
func MetricsHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
metricsMu.RLock()
snapshot := make(map[string]int64, len(metrics))
for k, v := range metrics {
snapshot[k] = v
}
metricsMu.RUnlock()
w.Header().Set("Content-Type", "text/plain; version=0.0.4")
for name, value := range snapshot {
metricName := name + "_total"
_, _ = fmt.Fprintf(w, "# TYPE %s counter\n", metricName)
_, _ = fmt.Fprintf(w, "%s %d\n", metricName, value)
}
})
}
func StartMetricsServer(port int) {
mux := http.NewServeMux()
mux.Handle("/metrics", MetricsHandler())
server := &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 15 * time.Second,
}
log.Printf("Starting metrics server on %s", server.Addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Printf("Metrics server error: %v", err)
}
}

View File

@@ -0,0 +1,61 @@
package parser
import (
"regexp"
"github.com/d3m0k1d/BanForge/internal/logger"
"github.com/d3m0k1d/BanForge/internal/storage"
)
type ApacheParser struct {
pattern *regexp.Regexp
logger *logger.Logger
}
func NewApacheParser() *ApacheParser {
pattern := regexp.MustCompile(
`^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\s+-\s+-\s+\[(.*?)\]\s+"(\w+)\s+(.*?)\s+HTTP/[\d.]+"\s+(\d+)\s+(\d+|-)\s+"(.*?)"\s+"(.*?)"`,
)
// Groups:
// 1: IP
// 2: Timestamp
// 3: Method (GET, POST, etc.)
// 4: Path
// 5: Status Code (200, 404, 403...)
// 6: Response Size
// 7: Referer
// 8: User-Agent
return &ApacheParser{
pattern: pattern,
logger: logger.New(false),
}
}
func (p *ApacheParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEntry) {
// Group 1: IP, Group 2: Timestamp, Group 3: Method, Group 4: Path, Group 5: Status
for event := range eventCh {
matches := p.pattern.FindStringSubmatch(event.Data)
if matches == nil {
continue
}
path := matches[4]
status := matches[5]
method := matches[3]
resultCh <- &storage.LogEntry{
Service: "apache",
IP: matches[1],
Path: path,
Status: status,
Method: method,
}
p.logger.Info(
"Parsed apache log entry",
"ip", matches[1],
"path", path,
"status", status,
"method", method,
)
}
}

View File

@@ -24,35 +24,32 @@ func NewNginxParser() *NginxParser {
func (p *NginxParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEntry) { func (p *NginxParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEntry) {
// Group 1: IP, Group 2: Timestamp, Group 3: Method, Group 4: Path, Group 5: Status // Group 1: IP, Group 2: Timestamp, Group 3: Method, Group 4: Path, Group 5: Status
go func() { for event := range eventCh {
for event := range eventCh { matches := p.pattern.FindStringSubmatch(event.Data)
matches := p.pattern.FindStringSubmatch(event.Data) if matches == nil {
if matches == nil { continue
continue
}
path := matches[4]
status := matches[5]
method := matches[3]
resultCh <- &storage.LogEntry{
Service: "nginx",
IP: matches[1],
Path: path,
Status: status,
Method: method,
IsViewed: false,
}
p.logger.Info(
"Parsed nginx log entry",
"ip",
matches[1],
"path",
path,
"status",
status,
"method",
method,
)
} }
}() path := matches[4]
status := matches[5]
method := matches[3]
resultCh <- &storage.LogEntry{
Service: "nginx",
IP: matches[1],
Path: path,
Status: status,
Method: method,
}
p.logger.Info(
"Parsed nginx log entry",
"ip",
matches[1],
"path",
path,
"status",
status,
"method",
method,
)
}
} }

View File

@@ -24,6 +24,7 @@ type Scanner struct {
} }
func NewScannerTail(path string) (*Scanner, error) { func NewScannerTail(path string) (*Scanner, error) {
// #nosec G204 - managed by system adminstartor
cmd := exec.Command("tail", "-F", "-n", "10", path) cmd := exec.Command("tail", "-F", "-n", "10", path)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
@@ -46,6 +47,7 @@ func NewScannerTail(path string) (*Scanner, error) {
} }
func NewScannerJournald(unit string) (*Scanner, error) { func NewScannerJournald(unit string) (*Scanner, error) {
// #nosec G204 - managed by system adminstartor
cmd := exec.Command("journalctl", "-u", unit, "-f", "-n", "0", "-o", "short", "--no-pager") cmd := exec.Command("journalctl", "-u", unit, "-f", "-n", "0", "-o", "short", "--no-pager")
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {

View File

@@ -31,12 +31,11 @@ func (p *SshdParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEnt
continue continue
} }
resultCh <- &storage.LogEntry{ resultCh <- &storage.LogEntry{
Service: "ssh", Service: "ssh",
IP: matches[6], IP: matches[6],
Path: matches[5], // user Path: matches[5], // user
Status: "Failed", Status: "Failed",
Method: matches[4], // method auth Method: matches[4], // method auth
IsViewed: false,
} }
p.logger.Info( p.logger.Info(
"Parsed ssh log entry", "Parsed ssh log entry",

213
internal/storage/ban_db.go Normal file
View File

@@ -0,0 +1,213 @@
package storage
import (
"database/sql"
"fmt"
"os"
"time"
"github.com/d3m0k1d/BanForge/internal/config"
"github.com/d3m0k1d/BanForge/internal/logger"
"github.com/jedib0t/go-pretty/v6/table"
_ "modernc.org/sqlite"
)
// Writer block
type BanWriter struct {
logger *logger.Logger
db *sql.DB
}
func NewBanWriter() (*BanWriter, error) {
db, err := sql.Open(
"sqlite",
buildSqliteDsn(banDBPath, pragmas),
)
if err != nil {
return nil, err
}
return &BanWriter{
logger: logger.New(false),
db: db,
}, nil
}
func (d *BanWriter) CreateTable() error {
_, err := d.db.Exec(CreateBansTable)
if err != nil {
return err
}
d.logger.Info("Created tables")
return nil
}
func (d *BanWriter) AddBan(ip string, ttl string, reason string) error {
duration, err := config.ParseDurationWithYears(ttl)
if err != nil {
d.logger.Error("Invalid duration format", "ttl", ttl, "error", err)
return fmt.Errorf("invalid duration: %w", err)
}
now := time.Now()
expiredAt := now.Add(duration)
_, err = d.db.Exec(
"INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)",
ip,
reason,
now.Format(time.RFC3339),
expiredAt.Format(time.RFC3339),
)
if err != nil {
d.logger.Error("Failed to add ban", "error", err)
return err
}
return nil
}
func (d *BanWriter) RemoveBan(ip string) error {
_, err := d.db.Exec("DELETE FROM bans WHERE ip = ?", ip)
if err != nil {
d.logger.Error("Failed to remove ban", "error", err)
return err
}
return nil
}
func (w *BanWriter) RemoveExpiredBans() ([]string, error) {
var ips []string
now := time.Now().Format(time.RFC3339)
rows, err := w.db.Query(
"SELECT ip FROM bans WHERE expired_at < ?",
now,
)
if err != nil {
w.logger.Error("Failed to get expired bans", "error", err)
return nil, err
}
defer func() {
if err := rows.Close(); err != nil {
w.logger.Error("Failed to close rows", "error", err)
}
}()
for rows.Next() {
var ip string
err := rows.Scan(&ip)
if err != nil {
w.logger.Error("Failed to scan ban", "error", err)
continue
}
ips = append(ips, ip)
}
if err = rows.Err(); err != nil {
return nil, err
}
result, err := w.db.Exec(
"DELETE FROM bans WHERE expired_at < ?",
now,
)
if err != nil {
w.logger.Error("Failed to remove expired bans", "error", err)
return nil, err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, err
}
if rowsAffected > 0 {
w.logger.Info("Removed expired bans", "count", rowsAffected, "ips", len(ips))
}
return ips, nil
}
func (d *BanWriter) Close() error {
d.logger.Info("Closing database connection")
err := d.db.Close()
if err != nil {
return err
}
return nil
}
// Reader block
type BanReader struct {
logger *logger.Logger
db *sql.DB
}
func NewBanReader() (*BanReader, error) {
db, err := sql.Open("sqlite",
"/var/lib/banforge/bans.db?"+
"mode=ro&"+
"_pragma=journal_mode(WAL)&"+
"_pragma=mmap_size(268435456)&"+
"_pragma=cache_size(-2000)&"+
"_pragma=query_only(1)")
if err != nil {
return nil, err
}
return &BanReader{
logger: logger.New(false),
db: db,
}, nil
}
func (d *BanReader) IsBanned(ip string) (bool, error) {
var bannedIP string
err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, fmt.Errorf("failed to check ban status: %w", err)
}
return true, nil
}
func (d *BanReader) BanList() error {
var count int
t := table.NewWriter()
t.SetOutputMirror(os.Stdout)
t.SetStyle(table.StyleBold)
t.AppendHeader(table.Row{"№", "IP", "Banned At", "Reason", "Expires At"})
rows, err := d.db.Query("SELECT ip, banned_at, reason, expired_at FROM bans")
if err != nil {
d.logger.Error("Failed to get ban list", "error", err)
return err
}
for rows.Next() {
count++
var ip string
var bannedAt string
var reason string
var expiredAt string
err := rows.Scan(&ip, &bannedAt, &reason, &expiredAt)
if err != nil {
d.logger.Error("Failed to get ban list", "error", err)
return err
}
t.AppendRow(table.Row{count, ip, bannedAt, reason, expiredAt})
}
t.Render()
return nil
}
func (d *BanReader) Close() error {
d.logger.Info("Closing database connection")
err := d.db.Close()
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,380 @@
package storage
import (
"database/sql"
"github.com/d3m0k1d/BanForge/internal/logger"
"path/filepath"
"testing"
)
func TestBanWriter_AddBan(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ip := "192.168.1.1"
ttl := "1h"
err = writer.AddBan(ip, ttl, "test")
if err != nil {
t.Errorf("AddBan failed: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed: %v", err)
}
if !isBanned {
t.Error("Expected IP to be banned, but it's not")
}
}
func TestBanWriter_RemoveBan(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ip := "192.168.1.2"
err = writer.AddBan(ip, "1h", "test")
if err != nil {
t.Fatalf("Failed to add ban: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Fatalf("IsBanned failed: %v", err)
}
if !isBanned {
t.Fatal("Expected IP to be banned before removal")
}
err = writer.RemoveBan(ip)
if err != nil {
t.Errorf("RemoveBan failed: %v", err)
}
isBanned, err = reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed after removal: %v", err)
}
if isBanned {
t.Error("Expected IP to be unbanned after removal, but it's still banned")
}
}
func TestBanWriter_RemoveExpiredBans(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
expiredIP := "192.168.1.3"
err = writer.AddBan(expiredIP, "-1h", "tes")
if err != nil {
t.Fatalf("Failed to add expired ban: %v", err)
}
activeIP := "192.168.1.4"
err = writer.AddBan(activeIP, "1h", "test")
if err != nil {
t.Fatalf("Failed to add active ban: %v", err)
}
removedIPs, err := writer.RemoveExpiredBans()
if err != nil {
t.Errorf("RemoveExpiredBans failed: %v", err)
}
found := false
for _, ip := range removedIPs {
if ip == expiredIP {
found = true
break
}
}
if !found {
t.Error("Expected expired IP to be in removed list")
}
if len(removedIPs) != 1 {
t.Errorf("Expected 1 removed IP, got %d", len(removedIPs))
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isExpiredBanned, err := reader.IsBanned(expiredIP)
if err != nil {
t.Errorf("IsBanned failed for expired IP: %v", err)
}
if isExpiredBanned {
t.Error("Expected expired IP to be unbanned, but it's still banned")
}
isActiveBanned, err := reader.IsBanned(activeIP)
if err != nil {
t.Errorf("IsBanned failed for active IP: %v", err)
}
if !isActiveBanned {
t.Error("Expected active IP to still be banned, but it's not")
}
}
func TestBanReader_IsBanned(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ip := "192.168.1.5"
err = writer.AddBan(ip, "1h", "test")
if err != nil {
t.Fatalf("Failed to add ban: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed for banned IP: %v", err)
}
if !isBanned {
t.Error("Expected IP to be banned")
}
isBanned, err = reader.IsBanned("192.168.1.6")
if err != nil {
t.Errorf("IsBanned failed for non-banned IP: %v", err)
}
if isBanned {
t.Error("Expected IP to not be banned")
}
}
func TestBanWriter_Close(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
err = writer.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
_, err = writer.db.Exec("SELECT 1")
if err == nil {
t.Error("Expected error when using closed connection")
}
}
func TestBanReader_Close(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
err = reader.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
_, err = reader.db.Query("SELECT 1")
if err == nil {
t.Error("Expected error when using closed connection")
}
}
func TestBanWriter_AddBan_InvalidDuration(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
err = writer.AddBan("192.168.1.7", "invalid_duration", "test")
if err == nil {
t.Error("Expected error for invalid duration")
} else if err.Error() == "" || err.Error() == "<nil>" {
t.Error("Expected meaningful error message for invalid duration")
}
}
func TestMultipleBans(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ips := []string{"192.168.1.8", "192.168.1.9", "192.168.1.10"}
for _, ip := range ips {
err := writer.AddBan(ip, "1h", "test")
if err != nil {
t.Errorf("Failed to add ban for IP %s: %v", ip, err)
}
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
for _, ip := range ips {
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed for IP %s: %v", ip, err)
continue
}
if !isBanned {
t.Errorf("Expected IP %s to be banned", ip)
}
}
}
func TestRemoveNonExistentBan(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
err = writer.RemoveBan("192.168.1.11")
if err != nil {
t.Errorf("RemoveBan should not return error for non-existent ban: %v", err)
}
}
func NewBanWriterWithDBPath(dbPath string) (*BanWriter, error) {
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)")
if err != nil {
return nil, err
}
return &BanWriter{
logger: logger.New(false),
db: db,
}, nil
}
func NewBanReaderWithDBPath(dbPath string) (*BanReader, error) {
db, err := sql.Open("sqlite",
dbPath+"?"+
"mode=ro&"+
"_pragma=journal_mode(WAL)&"+
"_pragma=mmap_size(268435456)&"+
"_pragma=cache_size(-2000)&"+
"_pragma=query_only(1)")
if err != nil {
return nil, err
}
return &BanReader{
logger: logger.New(false),
db: db,
}, nil
}

View File

@@ -2,167 +2,60 @@ package storage
import ( import (
"database/sql" "database/sql"
"os" "errors"
"fmt" "fmt"
"time" "strings"
"github.com/d3m0k1d/BanForge/internal/config" _ "modernc.org/sqlite"
"github.com/d3m0k1d/BanForge/internal/logger"
"github.com/jedib0t/go-pretty/v6/table"
_ "github.com/mattn/go-sqlite3"
) )
type DB struct { const (
logger *logger.Logger DBDir = "/var/lib/banforge/"
db *sql.DB ReqDBPath = DBDir + "requests.db"
banDBPath = DBDir + "bans.db"
)
var pragmas = map[string]string{
`journal_mode`: `wal`,
`synchronous`: `normal`,
`busy_timeout`: `30000`,
// also consider these
// `temp_store`: `memory`,
// `cache_size`: `1000000000`,
} }
func NewDB() (*DB, error) { func buildSqliteDsn(path string, pragmas map[string]string) string {
db, err := sql.Open( pragmastrs := make([]string, len(pragmas))
"sqlite3", i := 0
"/var/lib/banforge/storage.db?mode=rwc&_journal_mode=WAL&_busy_timeout=10000&cache=shared", for k, v := range pragmas {
) pragmastrs[i] = (fmt.Sprintf(`pragma=%s(%s)`, k, v))
if err != nil { i++
return nil, err
} }
return path + "?" + "mode=rwc&" + strings.Join(pragmastrs, "&")
if err := db.Ping(); err != nil {
return nil, err
}
return &DB{
logger: logger.New(false),
db: db,
}, nil
} }
func (d *DB) Close() error { func initDB(dsn, sqlstr string) (err error) {
d.logger.Info("Closing database connection") db, err := sql.Open("sqlite", dsn)
err := d.db.Close()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to open %q: %w", dsn, err)
} }
return nil defer func() {
} closeErr := db.Close()
if closeErr != nil {
func (d *DB) CreateTable() error { err = errors.Join(err, fmt.Errorf("failed to close %q: %w", dsn, closeErr))
_, err := d.db.Exec(CreateTables)
if err != nil {
return err
}
d.logger.Info("Created tables")
return nil
}
func (d *DB) SearchUnViewed() (*sql.Rows, error) {
rows, err := d.db.Query(
"SELECT id, service, ip, path, status, method, viewed, created_at FROM requests WHERE viewed = 0",
)
if err != nil {
d.logger.Error("Failed to query database")
return nil, err
}
return rows, nil
}
func (d *DB) MarkAsViewed(id int) error {
_, err := d.db.Exec("UPDATE requests SET viewed = 1 WHERE id = ?", id)
if err != nil {
d.logger.Error("Failed to mark as viewed", "error", err)
return err
}
return nil
}
func (d *DB) IsBanned(ip string) (bool, error) {
var bannedIP string
err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, fmt.Errorf("failed to check ban status: %w", err)
}
return true, nil
}
func (d *DB) AddBan(ip string, ttl string) error {
duration, err := config.ParseDurationWithYears(ttl)
if err != nil {
d.logger.Error("Invalid duration format", "ttl", ttl, "error", err)
return fmt.Errorf("invalid duration: %w", err)
}
now := time.Now()
expiredAt := now.Add(duration)
_, err = d.db.Exec(
"INSERT INTO bans (ip, reason, banned_at, expired_at) VALUES (?, ?, ?, ?)",
ip,
"1",
now.Format(time.RFC3339),
expiredAt.Format(time.RFC3339),
)
if err != nil {
d.logger.Error("Failed to add ban", "error", err)
return err
}
return nil
}
func (d *DB) BanList() error {
var count int
t := table.NewWriter()
t.SetOutputMirror(os.Stdout)
t.SetStyle(table.StyleBold)
t.AppendHeader(table.Row{"№", "IP", "Banned At"})
rows, err := d.db.Query("SELECT ip, banned_at FROM bans")
if err != nil {
d.logger.Error("Failed to get ban list", "error", err)
return err
}
for rows.Next() {
count++
var ip string
var bannedAt string
err := rows.Scan(&ip, &bannedAt)
if err != nil {
d.logger.Error("Failed to get ban list", "error", err)
return err
} }
t.AppendRow(table.Row{count, ip, bannedAt}) }()
_, err = db.Exec(sqlstr)
}
t.Render()
return nil
}
func (d *DB) CheckExpiredBans() ([]string, error) {
var ips []string
rows, err := d.db.Query(
"SELECT ip FROM bans WHERE expired_at < ?",
time.Now().Format(time.RFC3339),
)
if err != nil { if err != nil {
d.logger.Error("Failed to get ban list", "error", err) return fmt.Errorf("failed to create table: %w", err)
return nil, err
} }
for rows.Next() { return err
var ip string }
r, err := d.db.Exec("DELETE FROM bans WHERE ip = ?", ip)
if err != nil { func CreateTables() (err error) {
d.logger.Error("Failed to get ban list", "error", err) // Requests DB
return nil, err err1 := initDB(buildSqliteDsn(ReqDBPath, pragmas), CreateRequestsTable)
} err2 := initDB(buildSqliteDsn(banDBPath, pragmas), CreateBansTable)
d.logger.Info("Ban removed", "ip", ip, "rows", r)
err = rows.Scan(&ip) return errors.Join(err1, err2)
if err != nil {
d.logger.Error("Failed to get ban list", "error", err)
return nil, err
}
ips = append(ips, ip)
}
return ips, nil
} }

View File

@@ -1,243 +0,0 @@
package storage
import (
"database/sql"
"github.com/d3m0k1d/BanForge/internal/logger"
_ "github.com/mattn/go-sqlite3"
"os"
"path/filepath"
"testing"
"time"
)
func createTestDB(t *testing.T) *sql.DB {
tmpDir, err := os.MkdirTemp("", "banforge-test-*")
if err != nil {
t.Fatal(err)
}
filePath := filepath.Join(tmpDir, "test.db")
db, err := sql.Open("sqlite3", filePath)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
db.Close()
os.RemoveAll(tmpDir)
})
return db
}
func createTestDBStruct(t *testing.T) *DB {
tmpDir, err := os.MkdirTemp("", "banforge-test-*")
if err != nil {
t.Fatal(err)
}
filePath := filepath.Join(tmpDir, "test.db")
sqlDB, err := sql.Open("sqlite3", filePath)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
sqlDB.Close()
os.RemoveAll(tmpDir)
})
return &DB{
logger: logger.New(false),
db: sqlDB,
}
}
func TestCreateTable(t *testing.T) {
d := createTestDBStruct(t)
err := d.CreateTable()
if err != nil {
t.Fatal(err)
}
rows, err := d.db.Query("SELECT 1 FROM requests LIMIT 1")
if err != nil {
t.Fatal("requests table should exist:", err)
}
rows.Close()
rows, err = d.db.Query("SELECT 1 FROM bans LIMIT 1")
if err != nil {
t.Fatal("bans table should exist:", err)
}
rows.Close()
}
func TestMarkAsViewed(t *testing.T) {
d := createTestDBStruct(t)
err := d.CreateTable()
if err != nil {
t.Fatal(err)
}
_, err = d.db.Exec(
"INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)",
"test",
"127.0.0.1",
"/test",
"GET",
"200",
time.Now().Format(time.RFC3339),
)
if err != nil {
t.Fatal(err)
}
err = d.MarkAsViewed(1)
if err != nil {
t.Fatal(err)
}
var isViewed bool
err = d.db.QueryRow("SELECT viewed FROM requests WHERE id = 1").Scan(&isViewed)
if err != nil {
t.Fatal(err)
}
if !isViewed {
t.Fatal("viewed should be true")
}
}
func TestSearchUnViewed(t *testing.T) {
d := createTestDBStruct(t)
err := d.CreateTable()
if err != nil {
t.Fatal(err)
}
for i := 0; i < 2; i++ {
_, err := d.db.Exec(
"INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)",
"test",
"127.0.0.1",
"/test",
"GET",
"200",
time.Now().Format(time.RFC3339),
)
if err != nil {
t.Fatal(err)
}
}
rows, err := d.SearchUnViewed()
if err != nil {
t.Fatal(err)
}
defer rows.Close()
count := 0
for rows.Next() {
var id int
var service, ip, path, status, method string
var viewed bool
var createdAt string
err := rows.Scan(&id, &service, &ip, &path, &status, &method, &viewed, &createdAt)
if err != nil {
t.Fatal(err)
}
if viewed {
t.Fatal("should be unviewed")
}
count++
}
if err := rows.Err(); err != nil {
t.Fatal(err)
}
if count != 2 {
t.Fatalf("expected 2 unviewed requests, got %d", count)
}
}
func TestIsBanned(t *testing.T) {
d := createTestDBStruct(t)
err := d.CreateTable()
if err != nil {
t.Fatal(err)
}
_, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339))
if err != nil {
t.Fatal(err)
}
isBanned, err := d.IsBanned("127.0.0.1")
if err != nil {
t.Fatal(err)
}
if !isBanned {
t.Fatal("should be banned")
}
}
func TestAddBan(t *testing.T) {
d := createTestDBStruct(t)
err := d.CreateTable()
if err != nil {
t.Fatal(err)
}
err = d.AddBan("127.0.0.1", "7h")
if err != nil {
t.Fatal(err)
}
var ip string
err = d.db.QueryRow("SELECT ip FROM bans WHERE ip = ?", "127.0.0.1").Scan(&ip)
if err != nil {
t.Fatal(err)
}
if ip != "127.0.0.1" {
t.Fatal("ip should be 127.0.0.1")
}
}
func TestBanList(t *testing.T) {
d := createTestDBStruct(t)
err := d.CreateTable()
if err != nil {
t.Fatal(err)
}
_, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339))
if err != nil {
t.Fatal(err)
}
err = d.BanList()
if err != nil {
t.Fatal(err)
}
}
func TestClose(t *testing.T) {
d := createTestDBStruct(t)
err := d.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -1,7 +1,6 @@
package storage package storage
const CreateTables = ` const CreateRequestsTable = `
CREATE TABLE IF NOT EXISTS requests ( CREATE TABLE IF NOT EXISTS requests (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
service TEXT NOT NULL, service TEXT NOT NULL,
@@ -9,10 +8,17 @@ CREATE TABLE IF NOT EXISTS requests (
path TEXT, path TEXT,
method TEXT, method TEXT,
status TEXT, status TEXT,
viewed BOOLEAN DEFAULT FALSE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP created_at DATETIME DEFAULT CURRENT_TIMESTAMP
); );
CREATE INDEX IF NOT EXISTS idx_requests_service ON requests(service);
CREATE INDEX IF NOT EXISTS idx_requests_ip ON requests(ip);
CREATE INDEX IF NOT EXISTS idx_requests_status ON requests(status);
CREATE INDEX IF NOT EXISTS idx_requests_created_at ON requests(created_at);
`
// Миграция для bans.db
const CreateBansTable = `
CREATE TABLE IF NOT EXISTS bans ( CREATE TABLE IF NOT EXISTS bans (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
ip TEXT UNIQUE NOT NULL, ip TEXT UNIQUE NOT NULL,
@@ -21,9 +27,5 @@ CREATE TABLE IF NOT EXISTS bans (
expired_at DATETIME expired_at DATETIME
); );
CREATE INDEX IF NOT EXISTS idx_service ON requests(service); CREATE INDEX IF NOT EXISTS idx_bans_ip ON bans(ip);
CREATE INDEX IF NOT EXISTS idx_ip ON requests(ip); `
CREATE INDEX IF NOT EXISTS idx_status ON requests(status);
CREATE INDEX IF NOT EXISTS idx_created_at ON requests(created_at);
CREATE INDEX IF NOT EXISTS idx_ban_ip ON bans(ip);
`

View File

@@ -7,7 +7,6 @@ type LogEntry struct {
Path string `db:"path"` Path string `db:"path"`
Status string `db:"status"` Status string `db:"status"`
Method string `db:"method"` Method string `db:"method"`
IsViewed bool `db:"viewed"`
CreatedAt string `db:"created_at"` CreatedAt string `db:"created_at"`
} }

View File

@@ -0,0 +1,66 @@
package storage
import (
"database/sql"
"github.com/d3m0k1d/BanForge/internal/logger"
_ "modernc.org/sqlite"
)
type RequestWriter struct {
logger *logger.Logger
db *sql.DB
}
func NewRequestsWr() (*RequestWriter, error) {
db, err := sql.Open(
"sqlite",
buildSqliteDsn(ReqDBPath, pragmas),
)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0)
return &RequestWriter{
logger: logger.New(false),
db: db,
}, nil
}
type RequestReader struct {
logger *logger.Logger
db *sql.DB
}
func NewRequestsRd() (*RequestReader, error) {
db, err := sql.Open(
"sqlite",
buildSqliteDsn(ReqDBPath, pragmas),
)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0)
return &RequestReader{
logger: logger.New(false),
db: db,
}, nil
}
func (r *RequestReader) IsMaxRetryExceeded(ip string, maxRetry int) (bool, error) {
var count int
if maxRetry == 0 {
return true, nil
}
err := r.db.QueryRow("SELECT COUNT(*) FROM requests WHERE ip = ?", ip).Scan(&count)
if err != nil {
r.logger.Error("error query count: " + err.Error())
return false, err
}
r.logger.Info("Current request count for IP", "ip", ip, "count", count, "maxRetry", maxRetry)
return count >= maxRetry, nil
}

View File

@@ -1,22 +1,106 @@
package storage package storage
import ( import (
"database/sql"
"errors"
"fmt"
"time" "time"
) )
func Write(db *DB, resultCh <-chan *LogEntry) { func WriteReq(db *RequestWriter, resultCh <-chan *LogEntry) {
for result := range resultCh { db.logger.Info("Starting log writer")
_, err := db.db.Exec( const batchSize = 100
"INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", const flushInterval = 1 * time.Second
result.Service,
result.IP, batch := make([]*LogEntry, 0, batchSize)
result.Path, ticker := time.NewTicker(flushInterval)
result.Method, defer ticker.Stop()
result.Status,
time.Now().Format(time.RFC3339), flush := func() {
) defer db.logger.Debug("Flushed batch", "count", len(batch))
err := func() (err error) {
if len(batch) == 0 {
return nil
}
tx, err := db.db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if rollbackErr := tx.Rollback(); rollbackErr != nil &&
!errors.Is(rollbackErr, sql.ErrTxDone) {
err = errors.Join(
err,
fmt.Errorf("failed to rollback transaction: %w", rollbackErr),
)
}
}()
stmt, err := tx.Prepare(
"INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)",
)
if err != nil {
err = fmt.Errorf("failed to prepare statement: %w", err)
return err
}
defer func() {
if closeErr := stmt.Close(); closeErr != nil {
err = errors.Join(err, fmt.Errorf("failed to close statement: %w", closeErr))
}
}()
for _, entry := range batch {
_, err := stmt.Exec(
entry.Service,
entry.IP,
entry.Path,
entry.Method,
entry.Status,
time.Now().Format(time.RFC3339),
)
if err != nil {
db.logger.Error(fmt.Errorf("failed to insert entry: %w", err).Error())
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
batch = batch[:0]
return err
}()
if err != nil { if err != nil {
db.logger.Error("Failed to write to database", "error", err) db.logger.Error(err.Error())
}
}
for {
select {
case result, ok := <-resultCh:
if !ok {
flush()
return
}
batch = append(batch, result)
if len(batch) >= batchSize {
flush()
}
case <-ticker.C:
flush()
} }
} }
} }
func (w *RequestWriter) GetRequestCount() (int, error) {
var count int
err := w.db.QueryRow("SELECT COUNT(*) FROM requests").Scan(&count)
return count, err
}
func (w *RequestWriter) Close() error {
return w.db.Close()
}

View File

@@ -1,40 +1,301 @@
package storage package storage
import ( import (
"database/sql"
"github.com/d3m0k1d/BanForge/internal/logger"
_ "modernc.org/sqlite"
"path/filepath"
"testing" "testing"
"time" "time"
) )
func TestWrite(t *testing.T) { func TestWrite_BatchInsert(t *testing.T) {
var ip string tempDir := t.TempDir()
d := createTestDBStruct(t) dbPath := filepath.Join(tempDir, "requests_test.db")
err := d.CreateTable() writer, err := NewRequestWriterWithDBPath(dbPath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("Failed to create RequestWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
} }
resultCh := make(chan *LogEntry) resultCh := make(chan *LogEntry, 100)
go Write(d, resultCh) done := make(chan bool)
go func() {
WriteReq(writer, resultCh)
close(done)
}()
resultCh <- &LogEntry{ entries := []*LogEntry{
Service: "test", {Service: "service1", IP: "192.168.1.1", Path: "/path1", Method: "GET", Status: "200"},
IP: "127.0.0.1", {Service: "service2", IP: "192.168.1.2", Path: "/path2", Method: "POST", Status: "404"},
Path: "/test", {Service: "service3", IP: "192.168.1.3", Path: "/path3", Method: "PUT", Status: "500"},
Method: "GET", {Service: "service4", IP: "192.168.1.4", Path: "/path4", Method: "DELETE", Status: "200"},
Status: "200", {Service: "service5", IP: "192.168.1.5", Path: "/path5", Method: "GET", Status: "301"},
}
for _, entry := range entries {
resultCh <- entry
}
close(resultCh)
<-done
count, err := writer.GetRequestCount()
if err != nil {
t.Fatalf("Failed to get request count: %v", err)
}
if count != len(entries) {
t.Errorf("Expected %d entries, got %d", len(entries), count)
}
rows, err := writer.db.Query("SELECT service, ip, path, method, status FROM requests ORDER BY id")
if err != nil {
t.Fatalf("Failed to query requests: %v", err)
}
defer rows.Close()
i := 0
for rows.Next() {
var service, ip, path, method, status string
err := rows.Scan(&service, &ip, &path, &method, &status)
if err != nil {
t.Fatalf("Failed to scan row: %v", err)
}
if i >= len(entries) {
t.Fatal("More rows returned than expected")
}
expected := entries[i]
if service != expected.Service {
t.Errorf("Expected service %s, got %s", expected.Service, service)
}
if ip != expected.IP {
t.Errorf("Expected IP %s, got %s", expected.IP, ip)
}
if path != expected.Path {
t.Errorf("Expected path %s, got %s", expected.Path, path)
}
if method != expected.Method {
t.Errorf("Expected method %s, got %s", expected.Method, method)
}
if status != expected.Status {
t.Errorf("Expected status %s, got %s", expected.Status, status)
}
i++
}
if i != len(entries) {
t.Errorf("Expected to read %d entries, got %d", len(entries), i)
}
}
func TestWrite_BatchSizeTrigger(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "requests_test.db")
writer, err := NewRequestWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create RequestWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
resultCh := make(chan *LogEntry, 100)
done := make(chan bool)
go func() {
WriteReq(writer, resultCh)
close(done)
}()
batchSize := 100
entries := make([]*LogEntry, batchSize)
for i := 0; i < batchSize; i++ {
entries[i] = &LogEntry{
Service: "service" + string(rune(i+'0')),
IP: "192.168.1." + string(rune(i+'0')),
Path: "/path" + string(rune(i+'0')),
Method: "GET",
Status: "200",
}
}
for _, entry := range entries {
resultCh <- entry
}
close(resultCh)
<-done
count, err := writer.GetRequestCount()
if err != nil {
t.Fatalf("Failed to get request count: %v", err)
}
if count != batchSize {
t.Errorf("Expected %d entries, got %d", batchSize, count)
}
}
func TestWrite_FlushInterval(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "requests_test.db")
writer, err := NewRequestWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create RequestWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
resultCh := make(chan *LogEntry, 100)
done := make(chan bool)
go func() {
WriteReq(writer, resultCh)
close(done)
}()
entries := []*LogEntry{
{Service: "service1", IP: "192.168.1.1", Path: "/path1", Method: "GET", Status: "200"},
{Service: "service2", IP: "192.168.1.2", Path: "/path2", Method: "POST", Status: "404"},
{Service: "service3", IP: "192.168.1.3", Path: "/path3", Method: "PUT", Status: "500"},
{Service: "service4", IP: "192.168.1.4", Path: "/path4", Method: "DELETE", Status: "200"},
{Service: "service5", IP: "192.168.1.5", Path: "/path5", Method: "GET", Status: "301"},
}
for _, entry := range entries {
resultCh <- entry
}
time.Sleep(1500 * time.Millisecond)
close(resultCh)
<-done
count, err := writer.GetRequestCount()
if err != nil {
t.Fatalf("Failed to get request count: %v", err)
}
if count != len(entries) {
t.Errorf("Expected %d entries, got %d", len(entries), count)
}
}
func TestWrite_EmptyBatch(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "requests_test.db")
writer, err := NewRequestWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create RequestWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
resultCh := make(chan *LogEntry, 100)
done := make(chan bool)
go func() {
WriteReq(writer, resultCh)
close(done)
}()
close(resultCh)
<-done
count, err := writer.GetRequestCount()
if err != nil {
t.Fatalf("Failed to get request count: %v", err)
}
if count != 0 {
t.Errorf("Expected 0 entries for empty batch, got %d", count)
}
}
func TestWrite_ChannelClosed(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "requests_test.db")
writer, err := NewRequestWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create RequestWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
resultCh := make(chan *LogEntry, 100)
done := make(chan bool)
go func() {
WriteReq(writer, resultCh)
close(done)
}()
entries := []*LogEntry{
{Service: "service1", IP: "192.168.1.1", Path: "/path1", Method: "GET", Status: "200"},
{Service: "service2", IP: "192.168.1.2", Path: "/path2", Method: "POST", Status: "404"},
}
for _, entry := range entries {
resultCh <- entry
} }
close(resultCh) close(resultCh)
time.Sleep(100 * time.Millisecond) <-done
err = d.db.QueryRow("SELECT ip FROM requests LIMIT 1").Scan(&ip) count, err := writer.GetRequestCount()
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("Failed to get request count: %v", err)
} }
if ip != "127.0.0.1" {
t.Fatal("ip should be 127.0.0.1") if count != len(entries) {
t.Errorf("Expected %d entries, got %d", len(entries), count)
} }
} }
func NewRequestWriterWithDBPath(dbPath string) (*RequestWriter, error) {
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)")
if err != nil {
return nil, err
}
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0)
return &RequestWriter{
logger: logger.New(false),
db: db,
}, nil
}
func (w *RequestWriter) CreateTable() error {
_, err := w.db.Exec(CreateRequestsTable)
if err != nil {
return err
}
w.logger.Info("Created requests table")
return nil
}