151 lines
2.5 KiB
Go
151 lines
2.5 KiB
Go
package storage
|
|
|
|
import (
|
|
"database/sql"
|
|
"github.com/d3m0k1d/BanForge/internal/logger"
|
|
_ "modernc.org/sqlite"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// createTestDB opens a database handle with the "sqlite" driver, backed by a
// file named "test.db" inside a freshly created temporary directory.
//
// Cleanup is registered on t: the handle is closed first and the temporary
// directory is removed afterwards. Any setup error aborts the calling test
// via t.Fatal.
func createTestDB(t *testing.T) *sql.DB {
	t.Helper()

	tmpDir, err := os.MkdirTemp("", "banforge-test-*")
	if err != nil {
		t.Fatal(err)
	}
	// Register removal right away so the directory is not leaked when a later
	// setup step fails. Cleanups run last-in-first-out, so the Close
	// registered below runs before the directory is deleted.
	t.Cleanup(func() {
		// Best-effort: report the problem without failing the test.
		if err := os.RemoveAll(tmpDir); err != nil {
			t.Logf("removing temp dir %q: %v", tmpDir, err)
		}
	})

	db, err := sql.Open("sqlite", filepath.Join(tmpDir, "test.db"))
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		// Best-effort: report the problem without failing the test.
		if err := db.Close(); err != nil {
			t.Logf("closing test database: %v", err)
		}
	})

	return db
}
|
|
|
|
func createTestDBStruct(t *testing.T) *DB {
|
|
tmpDir, err := os.MkdirTemp("", "banforge-test-*")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
filePath := filepath.Join(tmpDir, "test.db")
|
|
sqlDB, err := sql.Open("sqlite", filePath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
sqlDB.Close()
|
|
os.RemoveAll(tmpDir)
|
|
})
|
|
|
|
return &DB{
|
|
logger: logger.New(false),
|
|
db: sqlDB,
|
|
}
|
|
}
|
|
|
|
func TestCreateTable(t *testing.T) {
|
|
d := createTestDBStruct(t)
|
|
|
|
err := d.CreateTable()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
rows, err := d.db.Query("SELECT 1 FROM requests LIMIT 1")
|
|
if err != nil {
|
|
t.Fatal("requests table should exist:", err)
|
|
}
|
|
rows.Close()
|
|
|
|
rows, err = d.db.Query("SELECT 1 FROM bans LIMIT 1")
|
|
if err != nil {
|
|
t.Fatal("bans table should exist:", err)
|
|
}
|
|
rows.Close()
|
|
}
|
|
|
|
func TestIsBanned(t *testing.T) {
|
|
d := createTestDBStruct(t)
|
|
|
|
err := d.CreateTable()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
isBanned, err := d.IsBanned("127.0.0.1")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !isBanned {
|
|
t.Fatal("should be banned")
|
|
}
|
|
}
|
|
|
|
func TestAddBan(t *testing.T) {
|
|
d := createTestDBStruct(t)
|
|
|
|
err := d.CreateTable()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = d.AddBan("127.0.0.1", "7h")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var ip string
|
|
err = d.db.QueryRow("SELECT ip FROM bans WHERE ip = ?", "127.0.0.1").Scan(&ip)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if ip != "127.0.0.1" {
|
|
t.Fatal("ip should be 127.0.0.1")
|
|
}
|
|
}
|
|
|
|
func TestBanList(t *testing.T) {
|
|
d := createTestDBStruct(t)
|
|
|
|
err := d.CreateTable()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = d.db.Exec("INSERT INTO bans (ip, banned_at) VALUES (?, ?)", "127.0.0.1", time.Now().Format(time.RFC3339))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = d.BanList()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestClose(t *testing.T) {
|
|
d := createTestDBStruct(t)
|
|
|
|
err := d.Close()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|