From 623bd87b4ca568c525404bd42707a21525d9d0e0 Mon Sep 17 00:00:00 2001 From: d3m0k1d Date: Thu, 15 Jan 2026 17:01:49 +0300 Subject: [PATCH] tests: Add tests for storage package --- internal/storage/db_test.go | 177 ++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 internal/storage/db_test.go diff --git a/internal/storage/db_test.go b/internal/storage/db_test.go new file mode 100644 index 0000000..42a8ab0 --- /dev/null +++ b/internal/storage/db_test.go @@ -0,0 +1,177 @@ +package storage + +import ( + "database/sql" + "github.com/d3m0k1d/BanForge/internal/logger" + _ "github.com/mattn/go-sqlite3" + "os" + "path/filepath" + "testing" + "time" +) + +func createTestDB(t *testing.T) *sql.DB { + tmpDir, err := os.MkdirTemp("", "banforge-test-*") + if err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(tmpDir, "test.db") + db, err := sql.Open("sqlite3", filePath) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + db.Close() + os.RemoveAll(tmpDir) + }) + + return db +} + +func createTestDBStruct(t *testing.T) *DB { + tmpDir, err := os.MkdirTemp("", "banforge-test-*") + if err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(tmpDir, "test.db") + sqlDB, err := sql.Open("sqlite3", filePath) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + sqlDB.Close() + os.RemoveAll(tmpDir) + }) + + return &DB{ + logger: logger.New(false), + db: sqlDB, + } +} + +func TestCreateTable(t *testing.T) { + d := createTestDBStruct(t) + + err := d.CreateTable() + if err != nil { + t.Fatal(err) + } + + rows, err := d.db.Query("SELECT 1 FROM requests LIMIT 1") + if err != nil { + t.Fatal("requests table should exist:", err) + } + rows.Close() + + rows, err = d.db.Query("SELECT 1 FROM bans LIMIT 1") + if err != nil { + t.Fatal("bans table should exist:", err) + } + rows.Close() +} + +func TestMarkAsViewed(t *testing.T) { + d := createTestDBStruct(t) + + err := d.CreateTable() + if err != nil { + t.Fatal(err) + } + + _, err = d.db.Exec( + "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", + "test", + "127.0.0.1", + "/test", + "GET", + "200", + time.Now().Format(time.RFC3339), + ) + if err != nil { + t.Fatal(err) + } + + err = d.MarkAsViewed(1) + if err != nil { + t.Fatal(err) + } + + var isViewed bool + err = d.db.QueryRow("SELECT viewed FROM requests WHERE id = 1").Scan(&isViewed) + if err != nil { + t.Fatal(err) + } + if !isViewed { + t.Fatal("viewed should be true") + } +} + +func TestSearchUnViewed(t *testing.T) { + d := createTestDBStruct(t) + + err := d.CreateTable() + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + _, err := d.db.Exec( + "INSERT INTO requests (service, ip, path, method, status, created_at) VALUES (?, ?, ?, ?, ?, ?)", + "test", + "127.0.0.1", + "/test", + "GET", + "200", + time.Now().Format(time.RFC3339), + ) + if err != nil { + t.Fatal(err) + } + } + + rows, err := d.SearchUnViewed() + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + count := 0 + for rows.Next() { + var id int + var service, ip, path, status, method string + var viewed bool + var createdAt string + + err := rows.Scan(&id, &service, &ip, &path, &status, &method, &viewed, &createdAt) + if err != nil { + t.Fatal(err) + } + + if viewed { + t.Fatal("should be unviewed") + } + + count++ + } + + if err := rows.Err(); err != nil { + t.Fatal(err) + } + + if count != 2 { + t.Fatalf("expected 2 unviewed requests, got %d", count) + } +} + +func TestClose(t *testing.T) { + d := createTestDBStruct(t) + + err := d.Close() + if err != nil { + t.Fatal(err) + } +}