feat: full working max_retry logic
All checks were successful
build / build (push) Successful in 2m45s

This commit is contained in:
d3m0k1d
2026-02-22 16:06:51 +03:00
parent 8c0cfcdbe7
commit a602207369
5 changed files with 24 additions and 11 deletions

View File

@@ -30,6 +30,11 @@ var DaemonCmd = &cobra.Command{
log.Error("Failed to create request writer", "error", err) log.Error("Failed to create request writer", "error", err)
os.Exit(1) os.Exit(1)
} }
reqDb_r, err := storage.NewRequestsRd()
if err != nil {
log.Error("Failed to create request reader", "error", err)
os.Exit(1)
}
banDb_r, err := storage.NewBanReader() banDb_r, err := storage.NewBanReader()
if err != nil { if err != nil {
log.Error("Failed to create ban reader", "error", err) log.Error("Failed to create ban reader", "error", err)
@@ -63,7 +68,7 @@ var DaemonCmd = &cobra.Command{
log.Error("Failed to load rules", "error", err) log.Error("Failed to load rules", "error", err)
os.Exit(1) os.Exit(1)
} }
j := judge.New(banDb_r, banDb_w, b, resultCh, entryCh) j := judge.New(banDb_r, banDb_w, reqDb_r, b, resultCh, entryCh)
j.LoadRules(r) j.LoadRules(r)
go j.UnbanChecker() go j.UnbanChecker()
go j.Tribunal() go j.Tribunal()

View File

@@ -61,11 +61,12 @@ var ListCmd = &cobra.Command{
} }
for _, rule := range r { for _, rule := range r {
fmt.Printf( fmt.Printf(
"Name: %s\nService: %s\nPath: %s\nStatus: %s\nMethod: %s\n\n", "Name: %s\nService: %s\nPath: %s\nStatus: %s\n MaxRetry: %d\nMethod: %s\n\n",
rule.Name, rule.Name,
rule.ServiceName, rule.ServiceName,
rule.Path, rule.Path,
rule.Status, rule.Status,
rule.MaxRetry,
rule.Method, rule.Method,
) )
} }

View File

@@ -12,13 +12,11 @@ config = "/etc/nftables.conf"
name = "nginx" name = "nginx"
logging = "file" logging = "file"
log_path = "/var/log/nginx/access.log" log_path = "/var/log/nginx/access.log"
max_retry = 3
enabled = true enabled = true
[[service]] [[service]]
name = "nginx" name = "nginx"
logging = "journald" logging = "journald"
log_path = "/var/log/nginx/access.log" log_path = "/var/log/nginx/access.log"
max_retry = 3
enabled = false enabled = false
` `

View File

@@ -14,6 +14,7 @@ import (
type Judge struct { type Judge struct {
db_r *storage.BanReader db_r *storage.BanReader
db_w *storage.BanWriter db_w *storage.BanWriter
db_rq *storage.RequestReader
logger *logger.Logger logger *logger.Logger
Blocker blocker.BlockerEngine Blocker blocker.BlockerEngine
rulesByService map[string][]config.Rule rulesByService map[string][]config.Rule
@@ -24,6 +25,7 @@ type Judge struct {
func New( func New(
db_r *storage.BanReader, db_r *storage.BanReader,
db_w *storage.BanWriter, db_w *storage.BanWriter,
db_rq *storage.RequestReader,
b blocker.BlockerEngine, b blocker.BlockerEngine,
resultCh chan *storage.LogEntry, resultCh chan *storage.LogEntry,
entryCh chan *storage.LogEntry, entryCh chan *storage.LogEntry,
@@ -31,6 +33,7 @@ func New(
return &Judge{ return &Judge{
db_w: db_w, db_w: db_w,
db_r: db_r, db_r: db_r,
db_rq: db_rq,
logger: logger.New(false), logger: logger.New(false),
rulesByService: make(map[string][]config.Rule), rulesByService: make(map[string][]config.Rule),
Blocker: b, Blocker: b,
@@ -75,11 +78,10 @@ func (j *Judge) Tribunal() {
methodMatch := rule.Method == "" || entry.Method == rule.Method methodMatch := rule.Method == "" || entry.Method == rule.Method
statusMatch := rule.Status == "" || entry.Status == rule.Status statusMatch := rule.Status == "" || entry.Status == rule.Status
pathMatch := matchPath(entry.Path, rule.Path) pathMatch := matchPath(entry.Path, rule.Path)
if methodMatch && statusMatch && pathMatch { if methodMatch && statusMatch && pathMatch {
ruleMatched = true ruleMatched = true
j.logger.Info("Rule matched", "rule", rule.Name, "ip", entry.IP) j.logger.Info("Rule matched", "rule", rule.Name, "ip", entry.IP)
j.resultCh <- entry
banned, err := j.db_r.IsBanned(entry.IP) banned, err := j.db_r.IsBanned(entry.IP)
if err != nil { if err != nil {
j.logger.Error("Failed to check ban status", "ip", entry.IP, "error", err) j.logger.Error("Failed to check ban status", "ip", entry.IP, "error", err)
@@ -87,10 +89,17 @@ func (j *Judge) Tribunal() {
} }
if banned { if banned {
j.logger.Info("IP already banned", "ip", entry.IP) j.logger.Info("IP already banned", "ip", entry.IP)
j.resultCh <- entry
break break
} }
exceeded, err := j.db_rq.IsMaxRetryExceeded(entry.IP, rule.MaxRetry)
if err != nil {
j.logger.Error("Failed to check retry count", "ip", entry.IP, "error", err)
break
}
if !exceeded {
j.logger.Info("Max retry not exceeded", "ip", entry.IP)
break
}
err = j.db_w.AddBan(entry.IP, rule.BanTime, rule.Name) err = j.db_w.AddBan(entry.IP, rule.BanTime, rule.Name)
if err != nil { if err != nil {
j.logger.Error( j.logger.Error(
@@ -118,7 +127,6 @@ func (j *Judge) Tribunal() {
"ban_time", "ban_time",
rule.BanTime, rule.BanTime,
) )
j.resultCh <- entry
break break
} }
} }

View File

@@ -51,12 +51,13 @@ func NewRequestsRd() (*RequestReader, error) {
}, nil }, nil
} }
func (r *RequestReader) IsMaxRetryExceeded(ip string, max_retry int) (bool, error) { func (r *RequestReader) IsMaxRetryExceeded(ip string, maxRetry int) (bool, error) {
var count int var count int
err := r.db.QueryRow("SELECT COUNT(*) FROM requests WHERE ip = ?", ip).Scan(&count) err := r.db.QueryRow("SELECT COUNT(*) FROM requests WHERE ip = ?", ip).Scan(&count)
if err != nil { if err != nil {
r.logger.Error("error query count: " + err.Error()) r.logger.Error("error query count: " + err.Error())
return false, err return false, err
} }
return count >= max_retry, nil r.logger.Info("Current request count for IP", "ip", ip, "count", count, "maxRetry", maxRetry)
return count >= maxRetry, nil
} }