Files
BanForge/internal/storage/ban_db_test.go

381 lines
8.2 KiB
Go

package storage
import (
"database/sql"
"github.com/d3m0k1d/BanForge/internal/logger"
"path/filepath"
"testing"
)
func TestBanWriter_AddBan(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ip := "192.168.1.1"
ttl := "1h"
err = writer.AddBan(ip, ttl)
if err != nil {
t.Errorf("AddBan failed: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed: %v", err)
}
if !isBanned {
t.Error("Expected IP to be banned, but it's not")
}
}
func TestBanWriter_RemoveBan(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ip := "192.168.1.2"
err = writer.AddBan(ip, "1h")
if err != nil {
t.Fatalf("Failed to add ban: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Fatalf("IsBanned failed: %v", err)
}
if !isBanned {
t.Fatal("Expected IP to be banned before removal")
}
err = writer.RemoveBan(ip)
if err != nil {
t.Errorf("RemoveBan failed: %v", err)
}
isBanned, err = reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed after removal: %v", err)
}
if isBanned {
t.Error("Expected IP to be unbanned after removal, but it's still banned")
}
}
func TestBanWriter_RemoveExpiredBans(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
expiredIP := "192.168.1.3"
err = writer.AddBan(expiredIP, "-1h")
if err != nil {
t.Fatalf("Failed to add expired ban: %v", err)
}
activeIP := "192.168.1.4"
err = writer.AddBan(activeIP, "1h")
if err != nil {
t.Fatalf("Failed to add active ban: %v", err)
}
removedIPs, err := writer.RemoveExpiredBans()
if err != nil {
t.Errorf("RemoveExpiredBans failed: %v", err)
}
found := false
for _, ip := range removedIPs {
if ip == expiredIP {
found = true
break
}
}
if !found {
t.Error("Expected expired IP to be in removed list")
}
if len(removedIPs) != 1 {
t.Errorf("Expected 1 removed IP, got %d", len(removedIPs))
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isExpiredBanned, err := reader.IsBanned(expiredIP)
if err != nil {
t.Errorf("IsBanned failed for expired IP: %v", err)
}
if isExpiredBanned {
t.Error("Expected expired IP to be unbanned, but it's still banned")
}
isActiveBanned, err := reader.IsBanned(activeIP)
if err != nil {
t.Errorf("IsBanned failed for active IP: %v", err)
}
if !isActiveBanned {
t.Error("Expected active IP to still be banned, but it's not")
}
}
func TestBanReader_IsBanned(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ip := "192.168.1.5"
err = writer.AddBan(ip, "1h")
if err != nil {
t.Fatalf("Failed to add ban: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed for banned IP: %v", err)
}
if !isBanned {
t.Error("Expected IP to be banned")
}
isBanned, err = reader.IsBanned("192.168.1.6")
if err != nil {
t.Errorf("IsBanned failed for non-banned IP: %v", err)
}
if isBanned {
t.Error("Expected IP to not be banned")
}
}
func TestBanWriter_Close(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
err = writer.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
_, err = writer.db.Exec("SELECT 1")
if err == nil {
t.Error("Expected error when using closed connection")
}
}
func TestBanReader_Close(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
err = reader.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
_, err = reader.db.Query("SELECT 1")
if err == nil {
t.Error("Expected error when using closed connection")
}
}
func TestBanWriter_AddBan_InvalidDuration(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
err = writer.AddBan("192.168.1.7", "invalid_duration")
if err == nil {
t.Error("Expected error for invalid duration")
} else if err.Error() == "" || err.Error() == "<nil>" {
t.Error("Expected meaningful error message for invalid duration")
}
}
func TestMultipleBans(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
ips := []string{"192.168.1.8", "192.168.1.9", "192.168.1.10"}
for _, ip := range ips {
err := writer.AddBan(ip, "1h")
if err != nil {
t.Errorf("Failed to add ban for IP %s: %v", ip, err)
}
}
reader, err := NewBanReaderWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanReader: %v", err)
}
defer reader.Close()
for _, ip := range ips {
isBanned, err := reader.IsBanned(ip)
if err != nil {
t.Errorf("IsBanned failed for IP %s: %v", ip, err)
continue
}
if !isBanned {
t.Errorf("Expected IP %s to be banned", ip)
}
}
}
func TestRemoveNonExistentBan(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "bans_test.db")
writer, err := NewBanWriterWithDBPath(dbPath)
if err != nil {
t.Fatalf("Failed to create BanWriter: %v", err)
}
defer writer.Close()
err = writer.CreateTable()
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
err = writer.RemoveBan("192.168.1.11")
if err != nil {
t.Errorf("RemoveBan should not return error for non-existent ban: %v", err)
}
}
func NewBanWriterWithDBPath(dbPath string) (*BanWriter, error) {
db, err := sql.Open("sqlite", dbPath+"?_pragma=journal_mode(WAL)&_pragma=busy_timeout(30000)&_pragma=synchronous(NORMAL)")
if err != nil {
return nil, err
}
return &BanWriter{
logger: logger.New(false),
db: db,
}, nil
}
func NewBanReaderWithDBPath(dbPath string) (*BanReader, error) {
db, err := sql.Open("sqlite",
dbPath+"?"+
"mode=ro&"+
"_pragma=journal_mode(WAL)&"+
"_pragma=mmap_size(268435456)&"+
"_pragma=cache_size(-2000)&"+
"_pragma=query_only(1)")
if err != nil {
return nil, err
}
return &BanReader{
logger: logger.New(false),
db: db,
}, nil
}