diff --git a/internal/parser/parser.go b/internal/parser/parser.go index 6379b92..f24cf2b 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -2,8 +2,12 @@ package parser import ( "bufio" + "fmt" "os" "os/exec" + "path/filepath" + "regexp" + "strings" "time" "github.com/d3m0k1d/BanForge/internal/logger" @@ -23,8 +27,56 @@ type Scanner struct { pollDelay time.Duration } +func validateLogPath(path string) error { + if path == "" { + return fmt.Errorf("log path cannot be empty") + } + + if !filepath.IsAbs(path) { + return fmt.Errorf("log path must be absolute: %s", path) + } + + if strings.Contains(path, "..") { + return fmt.Errorf("log path contains '..': %s", path) + } + + if _, err := os.Stat(path); os.IsNotExist(err) { + return fmt.Errorf("log file does not exist: %s", path) + } + + info, err := os.Lstat(path) + if err != nil { + return fmt.Errorf("failed to stat log file: %w", err) + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("log path is a symlink: %s", path) + } + + return nil +} + +func validateJournaldUnit(unit string) error { + if unit == "" { + return fmt.Errorf("journald unit cannot be empty") + } + + if !regexp.MustCompile(`^[a-zA-Z0-9._-]+$`).MatchString(unit) { + return fmt.Errorf("invalid journald unit name: %s", unit) + } + + if strings.HasPrefix(unit, "-") { + return fmt.Errorf("journald unit cannot start with '-': %s", unit) + } + + return nil +} + func NewScannerTail(path string) (*Scanner, error) { - // #nosec G204 - managed by system adminstartor + if err := validateLogPath(path); err != nil { + return nil, fmt.Errorf("invalid log path: %w", err) + } + + // #nosec G204 - path is validated above via validateLogPath() cmd := exec.Command("tail", "-F", "-n", "10", path) stdout, err := cmd.StdoutPipe() if err != nil { @@ -47,7 +99,11 @@ func NewScannerTail(path string) (*Scanner, error) { } func NewScannerJournald(unit string) (*Scanner, error) { - // #nosec G204 - managed by system adminstartor + if err := validateJournaldUnit(unit); err != nil { + return nil, fmt.Errorf("invalid journald unit: %w", err) + } + + // #nosec G204 - unit is validated above via validateJournaldUnit() cmd := exec.Command("journalctl", "-u", unit, "-f", "-n", "0", "-o", "short", "--no-pager") stdout, err := cmd.StdoutPipe() if err != nil { diff --git a/internal/parser/parser_test.go b/internal/parser/parser_test.go index 1bb555e..ce56f0d 100644 --- a/internal/parser/parser_test.go +++ b/internal/parser/parser_test.go @@ -2,6 +2,7 @@ package parser import ( "os" + "strings" "testing" "time" ) @@ -281,3 +282,201 @@ func BenchmarkScanner(b *testing.B) { <-scanner.Events() } } + +func TestValidateLogPath(t *testing.T) { + tests := []struct { + name string + path string + setup func() (string, func()) + wantErr bool + errMsg string + }{ + { + name: "empty path", + path: "", + wantErr: true, + errMsg: "log path cannot be empty", + }, + { + name: "relative path", + path: "logs/test.log", + wantErr: true, + errMsg: "log path must be absolute", + }, + { + name: "path with traversal", + path: "/var/log/../etc/passwd", + wantErr: true, + errMsg: "log path contains '..'", + }, + { + name: "non-existent file", + path: "/var/log/nonexistent.log", + wantErr: true, + errMsg: "log file does not exist", + }, + { + name: "valid file", + path: "/tmp/test-valid.log", + setup: func() (string, func()) { + _, _ = os.Create("/tmp/test-valid.log") + return "/tmp/test-valid.log", func() { os.Remove("/tmp/test-valid.log") } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cleanup func() + if tt.setup != nil { + tt.path, cleanup = tt.setup() + defer cleanup() + } + + err := validateLogPath(tt.path) + if (err != nil) != tt.wantErr { + t.Errorf("validateLogPath() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errMsg != "" && err != nil { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateLogPath() error = %v, want message containing %q", err, tt.errMsg) + } + } + }) + } +} + +func TestValidateJournaldUnit(t *testing.T) { + tests := []struct { + name string + unit string + wantErr bool + errMsg string + }{ + { + name: "empty unit", + unit: "", + wantErr: true, + errMsg: "journald unit cannot be empty", + }, + { + name: "unit starting with dash", + unit: "-dangerous", + wantErr: true, + errMsg: "journald unit cannot start with '-'", + }, + { + name: "unit with special chars", + unit: "test;rm -rf /", + wantErr: true, + errMsg: "invalid journald unit name", + }, + { + name: "unit with spaces", + unit: "test unit", + wantErr: true, + errMsg: "invalid journald unit name", + }, + { + name: "valid unit simple", + unit: "nginx", + wantErr: false, + }, + { + name: "valid unit with dash", + unit: "ssh-agent", + wantErr: false, + }, + { + name: "valid unit with dot", + unit: "systemd-journald.service", + wantErr: false, + }, + { + name: "valid unit with underscore", + unit: "my_service", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateJournaldUnit(tt.unit) + if (err != nil) != tt.wantErr { + t.Errorf("validateJournaldUnit() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errMsg != "" && err != nil { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateJournaldUnit() error = %v, want message containing %q", err, tt.errMsg) + } + } + }) + } +} + +func TestNewScannerTailValidation(t *testing.T) { + tests := []struct { + name string + path string + wantErr bool + }{ + { + name: "empty path", + path: "", + wantErr: true, + }, + { + name: "relative path", + path: "test.log", + wantErr: true, + }, + { + name: "non-existent path", + path: "/nonexistent/path/file.log", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewScannerTail(tt.path) + if (err != nil) != tt.wantErr { + t.Errorf("NewScannerTail() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNewScannerJournaldValidation(t *testing.T) { + tests := []struct { + name string + unit string + wantErr bool + }{ + { + name: "empty unit", + unit: "", + wantErr: true, + }, + { + name: "unit with semicolon", + unit: "test;rm -rf /", + wantErr: true, + }, + { + name: "unit starting with dash", + unit: "-dangerous", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewScannerJournald(tt.unit) + if (err != nil) != tt.wantErr { + t.Errorf("NewScannerJournald() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/parser/sshd.go b/internal/parser/sshd.go index 6a9b0f4..69154ca 100644 --- a/internal/parser/sshd.go +++ b/internal/parser/sshd.go @@ -24,30 +24,28 @@ func NewSshdParser() *SshdParser { func (p *SshdParser) Parse(eventCh <-chan Event, resultCh chan<- *storage.LogEntry) { // Group 1: Timestamp, Group 2: hostame, Group 3: pid, Group 4: Method auth, Group 5: User, Group 6: IP, Group 7: port - go func() { - for event := range eventCh { - matches := p.pattern.FindStringSubmatch(event.Data) - if matches == nil { - continue - } - resultCh <- &storage.LogEntry{ - Service: "ssh", - IP: matches[6], - Path: matches[5], // user - Status: "Failed", - Method: matches[4], // method auth - } - p.logger.Info( - "Parsed ssh log entry", - "ip", - matches[6], - "user", - matches[5], - "method", - matches[4], - "status", - "Failed", - ) + for event := range eventCh { + matches := p.pattern.FindStringSubmatch(event.Data) + if matches == nil { + continue } - }() + resultCh <- &storage.LogEntry{ + Service: "ssh", + IP: matches[6], + Path: matches[5], // user + Status: "Failed", + Method: matches[4], // method auth + } + p.logger.Info( + "Parsed ssh log entry", + "ip", + matches[6], + "user", + matches[5], + "method", + matches[4], + "status", + "Failed", + ) + } }