diff --git a/cmd/banforge/command/daemon.go b/cmd/banforge/command/daemon.go index e1b70ee..f92d11d 100644 --- a/cmd/banforge/command/daemon.go +++ b/cmd/banforge/command/daemon.go @@ -60,6 +60,11 @@ var DaemonCmd = &cobra.Command{ log.Error("Failed to load config", "error", err) os.Exit(1) } + _, err = config.LoadMetricsConfig() + if err != nil { + log.Error("Failed to load metrics config", "error", err) + os.Exit(1) + } var b blocker.BlockerEngine fw := cfg.Firewall.Name b = blocker.GetBlocker(fw, cfg.Firewall.Config) diff --git a/internal/config/appconf.go b/internal/config/appconf.go index 85bd83c..837bc34 100644 --- a/internal/config/appconf.go +++ b/internal/config/appconf.go @@ -3,15 +3,31 @@ package config import ( "errors" "fmt" + "github.com/BurntSushi/toml" + "github.com/d3m0k1d/BanForge/internal/logger" + "github.com/d3m0k1d/BanForge/internal/metrics" "os" "strconv" "strings" "time" - - "github.com/BurntSushi/toml" - "github.com/d3m0k1d/BanForge/internal/logger" ) +func LoadMetricsConfig() (*Metrics, error) { + cfg := &Metrics{} + _, err := toml.DecodeFile("/etc/banforge/config.toml", cfg) + if err != nil { + return nil, fmt.Errorf("failed to decode config: %w", err) + } + + if cfg.Enabled && cfg.Port > 0 && cfg.Port < 65535 { + go metrics.StartMetricsServer(cfg.Port) + } else if cfg.Enabled { + fmt.Println("Metrics enabled but port invalid, not starting server") + } + + return cfg, nil +} + func LoadRuleConfig() ([]Rule, error) { log := logger.New(false) var cfg Rules diff --git a/internal/config/template.go b/internal/config/template.go index 21a7fcc..f870833 100644 --- a/internal/config/template.go +++ b/internal/config/template.go @@ -8,6 +8,10 @@ const Base_config = ` name = "" config = "/etc/nftables.conf" +[metrics] +enabled = false +port = 2122 + [[service]] name = "nginx" logging = "file" diff --git a/internal/config/types.go b/internal/config/types.go index 73828d1..d3bc943 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -14,6 +14,7 @@ type Service struct { type Config struct { Firewall Firewall `toml:"firewall"` + Metrics Metrics `toml:"metrics"` Service []Service `toml:"service"` } @@ -31,3 +32,8 @@ type Rule struct { MaxRetry int `toml:"max_retry"` BanTime string `toml:"ban_time"` } + +type Metrics struct { + Enabled bool `toml:"enabled"` + Port int `toml:"port"` +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..f1ee416 --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,65 @@ +package metrics + +import ( + "fmt" + "net/http" + "sync" +) + +var ( + metricsMu sync.RWMutex + metrics = make(map[string]int64) +) + +func IncBan(service string) { + metricsMu.Lock() + metrics["ban_count"]++ + metrics[service+"_bans"]++ + metricsMu.Unlock() +} + +func IncUnban(service string) { + metricsMu.Lock() + metrics["unban_count"]++ + metrics[service+"_unbans"]++ + metricsMu.Unlock() +} + +func IncRuleMatched(rule_name string) { + metricsMu.Lock() + metrics[rule_name+"_rule_matched"]++ + metricsMu.Unlock() +} + +func IncLogParsed() { + metricsMu.Lock() + metrics["log_parsed"]++ + metricsMu.Unlock() +} + +func MetricsHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + metricsMu.RLock() + snapshot := make(map[string]int64, len(metrics)) + for k, v := range metrics { + snapshot[k] = v + } + metricsMu.RUnlock() + + w.Header().Set("Content-Type", "text/plain; version=0.0.4") + + for name, value := range snapshot { + metricName := name + "_total" + fmt.Fprintf(w, "# TYPE %s counter\n", metricName) + fmt.Fprintf(w, "%s %d\n", metricName, value) + } + }) +} +func StartMetricsServer(port int) { + http.Handle("/metrics", MetricsHandler()) + addr := fmt.Sprintf(":%d", port) + if err := http.ListenAndServe(addr, nil); err != nil { + } +} + diff --git a/internal/metrics/server.go b/internal/metrics/server.go new file mode 100644 index 0000000..1f47870 --- /dev/null +++ b/internal/metrics/server.go @@ -0,0 +1,14 @@ +package metrics + +import ( + "net/http" +) + +func Handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + for k, v := range metrics { + w.Write([]byte(k + " " + string(v) + "\n")) + } + }) +}