diff --git a/internal/judge/judge_test.go b/internal/judge/judge_test.go new file mode 100644 index 0000000..72340bc --- /dev/null +++ b/internal/judge/judge_test.go @@ -0,0 +1,60 @@ +package judge + +import ( + "testing" + + "github.com/d3m0k1d/BanForge/internal/config" + "github.com/d3m0k1d/BanForge/internal/storage" +) + +func TestJudgeLogic(t *testing.T) { + tests := []struct { + name string + inputRule config.Rule + inputLog storage.LogEntry + wantErr bool + wantMatch bool + }{ + { + name: "Empty rule", + inputRule: config.Rule{Name: "", ServiceName: "", Path: "", Status: "", Method: ""}, + inputLog: storage.LogEntry{ID: 0, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, + wantErr: true, + wantMatch: false, + }, + { + name: "Matching rule", + inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "200", Method: "GET"}, + inputLog: storage.LogEntry{ID: 1, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, + wantErr: false, + wantMatch: true, + }, + { + name: "Non-matching status", + inputRule: config.Rule{Name: "test", ServiceName: "nginx", Path: "/api", Status: "404", Method: "GET"}, + inputLog: storage.LogEntry{ID: 2, Service: "nginx", IP: "127.0.0.1", Path: "/api", Status: "200", Method: "GET", IsViewed: false, CreatedAt: ""}, + wantErr: false, + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.inputRule.Name == "" { + if !tt.wantErr { + t.Errorf("Expected error for empty rule name, but got none") + } + return + } + + result := (tt.inputRule.Method == "" || tt.inputLog.Method == tt.inputRule.Method) && + (tt.inputRule.Status == "" || tt.inputLog.Status == tt.inputRule.Status) && + (tt.inputRule.Path == "" || tt.inputLog.Path == tt.inputRule.Path) && + (tt.inputRule.ServiceName == "" || tt.inputLog.Service == tt.inputRule.ServiceName) + + if result != tt.wantMatch { + t.Errorf("Expected error: %v, but got: %v", tt.wantErr, result) + } + }) + } +} diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go index 42a8ab0..8c5028f 100644 --- a/internal/storage/db_test.go +++ b/internal/storage/db_test.go @@ -167,6 +167,72 @@ func TestSearchUnViewed(t *testing.T) { } } +func TestIsBanned(t *testing.T) { + d := createTestDBStruct(t) + + err := d.CreateTable() + if err != nil { + t.Fatal(err) + } + + _, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339)) + if err != nil { + t.Fatal(err) + } + + isBanned, err := d.IsBanned("127.0.0.1") + if err != nil { + t.Fatal(err) + } + + if !isBanned { + t.Fatal("should be banned") + } +} + +func TestAddBan(t *testing.T) { + d := createTestDBStruct(t) + + err := d.CreateTable() + if err != nil { + t.Fatal(err) + } + + err = d.AddBan("127.0.0.1") + if err != nil { + t.Fatal(err) + } + + var ip string + err = d.db.QueryRow("SELECT ip FROM bans WHERE ip = ?", "127.0.0.1").Scan(&ip) + if err != nil { + t.Fatal(err) + } + + if ip != "127.0.0.1" { + t.Fatal("ip should be 127.0.0.1") + } +} + +func TestBanList(t *testing.T) { + d := createTestDBStruct(t) + + err := d.CreateTable() + if err != nil { + t.Fatal(err) + } + + _, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339)) + if err != nil { + t.Fatal(err) + } + + err = d.BanList() + if err != nil { + t.Fatal(err) + } +} + func TestClose(t *testing.T) { d := createTestDBStruct(t)