refactoring: full refactoring the database structure from 1 file to 2 file db struct to avoid conflict 2 writters and sqllite busy, improve tests
Some checks failed
build / build (push) Has been cancelled
Some checks failed
build / build (push) Has been cancelled
This commit is contained in:
380
internal/storage/ban_db_test.go
Normal file
380
internal/storage/ban_db_test.go
Normal file
@@ -0,0 +1,380 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/d3m0k1d/BanForge/internal/logger"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBanWriter_AddBan(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
ip := "192.168.1.1"
|
||||
ttl := "1h"
|
||||
|
||||
err = writer.AddBan(ip, ttl)
|
||||
if err != nil {
|
||||
t.Errorf("AddBan failed: %v", err)
|
||||
}
|
||||
|
||||
reader, err := NewBanReaderWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanReader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
isBanned, err := reader.IsBanned(ip)
|
||||
if err != nil {
|
||||
t.Errorf("IsBanned failed: %v", err)
|
||||
}
|
||||
if !isBanned {
|
||||
t.Error("Expected IP to be banned, but it's not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanWriter_RemoveBan(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
ip := "192.168.1.2"
|
||||
err = writer.AddBan(ip, "1h")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add ban: %v", err)
|
||||
}
|
||||
|
||||
reader, err := NewBanReaderWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanReader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
isBanned, err := reader.IsBanned(ip)
|
||||
if err != nil {
|
||||
t.Fatalf("IsBanned failed: %v", err)
|
||||
}
|
||||
if !isBanned {
|
||||
t.Fatal("Expected IP to be banned before removal")
|
||||
}
|
||||
|
||||
err = writer.RemoveBan(ip)
|
||||
if err != nil {
|
||||
t.Errorf("RemoveBan failed: %v", err)
|
||||
}
|
||||
|
||||
isBanned, err = reader.IsBanned(ip)
|
||||
if err != nil {
|
||||
t.Errorf("IsBanned failed after removal: %v", err)
|
||||
}
|
||||
if isBanned {
|
||||
t.Error("Expected IP to be unbanned after removal, but it's still banned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanWriter_RemoveExpiredBans(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
expiredIP := "192.168.1.3"
|
||||
err = writer.AddBan(expiredIP, "-1h")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add expired ban: %v", err)
|
||||
}
|
||||
|
||||
activeIP := "192.168.1.4"
|
||||
err = writer.AddBan(activeIP, "1h")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add active ban: %v", err)
|
||||
}
|
||||
|
||||
removedIPs, err := writer.RemoveExpiredBans()
|
||||
if err != nil {
|
||||
t.Errorf("RemoveExpiredBans failed: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, ip := range removedIPs {
|
||||
if ip == expiredIP {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected expired IP to be in removed list")
|
||||
}
|
||||
|
||||
if len(removedIPs) != 1 {
|
||||
t.Errorf("Expected 1 removed IP, got %d", len(removedIPs))
|
||||
}
|
||||
|
||||
reader, err := NewBanReaderWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanReader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
isExpiredBanned, err := reader.IsBanned(expiredIP)
|
||||
if err != nil {
|
||||
t.Errorf("IsBanned failed for expired IP: %v", err)
|
||||
}
|
||||
if isExpiredBanned {
|
||||
t.Error("Expected expired IP to be unbanned, but it's still banned")
|
||||
}
|
||||
|
||||
isActiveBanned, err := reader.IsBanned(activeIP)
|
||||
if err != nil {
|
||||
t.Errorf("IsBanned failed for active IP: %v", err)
|
||||
}
|
||||
if !isActiveBanned {
|
||||
t.Error("Expected active IP to still be banned, but it's not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanReader_IsBanned(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
ip := "192.168.1.5"
|
||||
err = writer.AddBan(ip, "1h")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add ban: %v", err)
|
||||
}
|
||||
|
||||
reader, err := NewBanReaderWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanReader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
isBanned, err := reader.IsBanned(ip)
|
||||
if err != nil {
|
||||
t.Errorf("IsBanned failed for banned IP: %v", err)
|
||||
}
|
||||
if !isBanned {
|
||||
t.Error("Expected IP to be banned")
|
||||
}
|
||||
|
||||
isBanned, err = reader.IsBanned("192.168.1.6")
|
||||
if err != nil {
|
||||
t.Errorf("IsBanned failed for non-banned IP: %v", err)
|
||||
}
|
||||
if isBanned {
|
||||
t.Error("Expected IP to not be banned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanWriter_Close(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = writer.db.Exec("SELECT 1")
|
||||
if err == nil {
|
||||
t.Error("Expected error when using closed connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanReader_Close(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
reader, err := NewBanReaderWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanReader: %v", err)
|
||||
}
|
||||
|
||||
err = reader.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = reader.db.Query("SELECT 1")
|
||||
if err == nil {
|
||||
t.Error("Expected error when using closed connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanWriter_AddBan_InvalidDuration(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
err = writer.AddBan("192.168.1.7", "invalid_duration")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid duration")
|
||||
} else if err.Error() == "" || err.Error() == "<nil>" {
|
||||
t.Error("Expected meaningful error message for invalid duration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleBans(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
ips := []string{"192.168.1.8", "192.168.1.9", "192.168.1.10"}
|
||||
|
||||
for _, ip := range ips {
|
||||
err := writer.AddBan(ip, "1h")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to add ban for IP %s: %v", ip, err)
|
||||
}
|
||||
}
|
||||
|
||||
reader, err := NewBanReaderWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanReader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
for _, ip := range ips {
|
||||
isBanned, err := reader.IsBanned(ip)
|
||||
if err != nil {
|
||||
t.Errorf("IsBanned failed for IP %s: %v", ip, err)
|
||||
continue
|
||||
}
|
||||
if !isBanned {
|
||||
t.Errorf("Expected IP %s to be banned", ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveNonExistentBan(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "bans_test.db")
|
||||
|
||||
writer, err := NewBanWriterWithDBPath(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create BanWriter: %v", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
err = writer.CreateTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
err = writer.RemoveBan("192.168.1.11")
|
||||
if err != nil {
|
||||
t.Errorf("RemoveBan should not return error for non-existent ban: %v", err)
|
||||
}
|
||||
}
|
||||
func NewBanWriterWithDBPath(dbPath string) (*BanWriter, error) {
|
||||
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &BanWriter{
|
||||
logger: logger.New(false),
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewBanReaderWithDBPath(dbPath string) (*BanReader, error) {
|
||||
db, err := sql.Open("sqlite",
|
||||
dbPath+"?"+
|
||||
"mode=ro&"+
|
||||
"_pragma=journal_mode(WAL)&"+
|
||||
"_pragma=mmap_size(268435456)&"+
|
||||
"_pragma=cache_size(-2000)&"+
|
||||
"_pragma=query_only(1)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &BanReader{
|
||||
logger: logger.New(false),
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user