diff --git a/Makefile b/Makefile index 40b1f06..8756c56 100644 --- a/Makefile +++ b/Makefile @@ -25,3 +25,6 @@ clean: test: go test ./... + +test-cover: + go test -cover ./... diff --git a/internal/blocker/validators_test.go b/internal/blocker/validators_test.go new file mode 100644 index 0000000..ee8b806 --- /dev/null +++ b/internal/blocker/validators_test.go @@ -0,0 +1,47 @@ +package blocker + +import ( + "testing" +) + +func TestValidateConfigPath(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {name: "empty", input: "", wantErr: true}, + {name: "valid path", input: "/path/to/config", wantErr: false}, + {name: "invalid path", input: "path/to/config", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateConfigPath(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateConfigPath(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateIP(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {name: "empty", input: "", wantErr: true}, + {name: "invalid IP", input: "1.1.1", wantErr: true}, + {name: "valid IP", input: "1.1.1.1", wantErr: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateIP(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateIP(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} diff --git a/internal/storage/writer_test.go b/internal/storage/writer_test.go new file mode 100644 index 0000000..d216d54 --- /dev/null +++ b/internal/storage/writer_test.go @@ -0,0 +1,40 @@ +package storage + +import ( + "testing" + "time" +) + +func TestWrite(t *testing.T) { + var ip string + d := createTestDBStruct(t) + + err := d.CreateTable() + if err != nil { + t.Fatal(err) + } + + resultCh := make(chan *LogEntry) + + go Write(d, resultCh) + + resultCh <- &LogEntry{ + Service: "test", + IP: "127.0.0.1", + Path: "/test", + Method: "GET", + Status: "200", + } + + close(resultCh) + + time.Sleep(100 * time.Millisecond) + + err = d.db.QueryRow("SELECT ip FROM requests LIMIT 1").Scan(&ip) + if err != nil { + t.Fatal(err) + } + if ip != "127.0.0.1" { + t.Fatal("ip should be 127.0.0.1") + } +}