119 lines
2.5 KiB
Go
119 lines
2.5 KiB
Go
package storage
|
|
|
|
import (
|
|
"database/sql"
|
|
"os"
|
|
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/d3m0k1d/BanForge/internal/logger"
|
|
"github.com/jedib0t/go-pretty/v6/table"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
type DB struct {
|
|
logger *logger.Logger
|
|
db *sql.DB
|
|
}
|
|
|
|
func NewDB() (*DB, error) {
|
|
db, err := sql.Open("sqlite3", "/var/lib/banforge/storage.db?mode=rwc&_journal_mode=WAL&_busy_timeout=10000&cache=shared")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
return &DB{
|
|
logger: logger.New(false),
|
|
db: db,
|
|
}, nil
|
|
}
|
|
|
|
func (d *DB) Close() error {
|
|
d.logger.Info("Closing database connection")
|
|
err := d.db.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (d *DB) CreateTable() error {
|
|
_, err := d.db.Exec(CreateTables)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
d.logger.Info("Created tables")
|
|
return nil
|
|
}
|
|
|
|
func (d *DB) SearchUnViewed() (*sql.Rows, error) {
|
|
rows, err := d.db.Query("SELECT id, service, ip, path, status, method, viewed, created_at FROM requests WHERE viewed = 0")
|
|
if err != nil {
|
|
d.logger.Error("Failed to query database")
|
|
return nil, err
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
func (d *DB) MarkAsViewed(id int) error {
|
|
_, err := d.db.Exec("UPDATE requests SET viewed = 1 WHERE id = ?", id)
|
|
if err != nil {
|
|
d.logger.Error("Failed to mark as viewed", "error", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (d *DB) IsBanned(ip string) (bool, error) {
|
|
var bannedIP string
|
|
err := d.db.QueryRow("SELECT ip FROM bans WHERE ip = ? ", ip).Scan(&bannedIP)
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check ban status: %w", err)
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func (d *DB) AddBan(ip string) error {
|
|
_, err := d.db.Exec("INSERT INTO bans (ip, reason, banned_at) VALUES (?, ?, ?)", ip, "1", time.Now().Format(time.RFC3339))
|
|
if err != nil {
|
|
d.logger.Error("Failed to add ban", "error", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (d *DB) BanList() error {
|
|
|
|
var count int
|
|
t := table.NewWriter()
|
|
t.SetOutputMirror(os.Stdout)
|
|
t.SetStyle(table.StyleBold)
|
|
t.AppendHeader(table.Row{"№", "IP", "Banned At"})
|
|
rows, err := d.db.Query("SELECT ip, banned_at FROM bans")
|
|
if err != nil {
|
|
d.logger.Error("Failed to get ban list", "error", err)
|
|
return err
|
|
}
|
|
for rows.Next() {
|
|
count++
|
|
var ip string
|
|
var bannedAt string
|
|
err := rows.Scan(&ip, &bannedAt)
|
|
if err != nil {
|
|
d.logger.Error("Failed to get ban list", "error", err)
|
|
return err
|
|
}
|
|
t.AppendRow(table.Row{count, ip, bannedAt})
|
|
|
|
}
|
|
t.Render()
|
|
return nil
|
|
}
|