diff --git a/internal/storage/db.go b/internal/storage/db.go index ff82e21..90edc43 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -2,55 +2,59 @@ package storage import ( "database/sql" + "errors" "fmt" + "strings" _ "modernc.org/sqlite" ) -func CreateTables() error { - // Requests DB - db_r, err := sql.Open("sqlite", - "/var/lib/banforge/requests.db?"+ - "mode=rwc&"+ - "_pragma=journal_mode(WAL)&"+ - "_pragma=busy_timeout(30000)&"+ - "_pragma=synchronous(NORMAL)") - if err != nil { - return fmt.Errorf("failed to open requests db: %w", err) - } - defer func() { - err = db_r.Close() - if err != nil { - fmt.Println(err) - } - }() +const ( + DBDir = "/var/lib/banforge/" + ReqDBPath = DBDir + "requests.db" + banDBPath = DBDir + "bans.db" +) - _, err = db_r.Exec(CreateRequestsTable) - if err != nil { - return fmt.Errorf("failed to create requests table: %w", err) - } - - // Bans DB - db_b, err := sql.Open("sqlite", - "/var/lib/banforge/bans.db?"+ - "mode=rwc&"+ - "_pragma=journal_mode(WAL)&"+ - "_pragma=busy_timeout(30000)&"+ - "_pragma=synchronous(FULL)") - if err != nil { - return fmt.Errorf("failed to open bans db: %w", err) - } - defer func() { - err = db_b.Close() - if err != nil { - fmt.Println(err) - } - }() - - _, err = db_b.Exec(CreateBansTable) - if err != nil { - return fmt.Errorf("failed to create bans table: %w", err) - } - fmt.Println("Tables created successfully!") - return nil +var pragmas = map[string]string{ + `journal_mode`: `wal`, + `synchronous`: `normal`, + `busy_timeout`: `30000`, + // also consider these + // `temp_store`: `memory`, + // `cache_size`: `1000000000`, +} + +func buildSqliteDsn(path string, pragmas map[string]string) string { + pragmastrs := make([]string, len(pragmas)) + i := 0 + for k, v := range pragmas { + pragmastrs[i] = (fmt.Sprintf(`pragma=%s(%s)`, k, v)) + i++ + } + return path + "?" + "mode=rwc&" + strings.Join(pragmastrs, "&") +} + +func initDB(dsn, sqlstr string) (err error) { + db, err := sql.Open("sqlite", dsn) + if err != nil { + return fmt.Errorf("failed to open %q: %w", dsn, err) + } + defer func() { + closeErr := db.Close() + if closeErr != nil { + err = errors.Join(err, fmt.Errorf("failed to close %q: %w", dsn, closeErr)) + } + }() + if err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + return err +} + +func CreateTables() (err error) { + // Requests DB + err1 := initDB(buildSqliteDsn(ReqDBPath, pragmas), CreateRequestsTable) + err2 := initDB(buildSqliteDsn(banDBPath, pragmas), CreateBansTable) + + return errors.Join(err1, err2) }