package auth

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"golang.org/x/crypto/bcrypt"
)

// Sentinel errors returned by Service methods. Compare with errors.Is.
var (
	ErrEmailExists     = errors.New("email already registered")
	ErrInvalidCreds    = errors.New("invalid email or password")
	ErrUserNotFound    = errors.New("user not found")
	ErrInvalidUserID   = errors.New("invalid user ID")
	ErrInvalidRefresh  = errors.New("invalid refresh token")
	ErrRefreshExpired  = errors.New("refresh token expired")
	ErrLogoutInvalid   = errors.New("refresh token not found or already used")
	ErrWrongPassword   = errors.New("current password is incorrect")
	ErrWeakPassword    = errors.New("password must be at least 8 characters with uppercase, lowercase, and digit")
	ErrPasswordTooLong = errors.New("password must be at most 72 bytes")
	ErrSamePassword    = errors.New("new password must differ from current password")
)

// bcryptMaxPasswordBytes is bcrypt's input limit. Newer versions of
// golang.org/x/crypto/bcrypt reject longer passwords in GenerateFromPassword;
// older versions silently ignore the excess bytes. Checking up front turns
// either outcome into a clear validation error.
const bcryptMaxPasswordBytes = 72

// Service implements registration, login, refresh-token rotation, logout,
// and profile/password updates on top of a Repository.
type Service struct {
	repo       *Repository
	jwtSecret  []byte
	jwtExp     time.Duration
	refreshExp time.Duration
}

// NewService returns a Service backed by repo. jwtSecret and jwtExp are
// passed to GenerateToken when minting access tokens; refresh tokens expire
// refreshExp after they are issued.
func NewService(repo *Repository, jwtSecret string, jwtExp, refreshExp time.Duration) *Service {
	return &Service{
		repo:       repo,
		jwtSecret:  []byte(jwtSecret),
		jwtExp:     jwtExp,
		refreshExp: refreshExp,
	}
}

// sha256Hex returns the lowercase hex SHA-256 digest of data. Refresh tokens
// are persisted only in this hashed form.
func sha256Hex(data string) string {
	sum := sha256.Sum256([]byte(data))
	return hex.EncodeToString(sum[:])
}

// generateRandomToken returns 32 bytes from crypto/rand encoded as unpadded
// URL-safe base64 (43 characters).
func generateRandomToken() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

// validatePasswordStrength enforces the password policy: at least 8
// characters (counted as runes, not bytes), at most bcryptMaxPasswordBytes
// bytes, and at least one uppercase letter, one lowercase letter and one
// digit. It returns ErrWeakPassword or ErrPasswordTooLong on violation.
func validatePasswordStrength(password string) error {
	// Count runes so a short password of multi-byte characters cannot
	// satisfy the minimum by byte length alone.
	if utf8.RuneCountInString(password) < 8 {
		return ErrWeakPassword
	}
	if len(password) > bcryptMaxPasswordBytes {
		return ErrPasswordTooLong
	}
	var hasUpper, hasLower, hasDigit bool
	for _, ch := range password {
		switch {
		case unicode.IsUpper(ch):
			hasUpper = true
		case unicode.IsLower(ch):
			hasLower = true
		case unicode.IsDigit(ch):
			hasDigit = true
		}
	}
	if !hasUpper || !hasLower || !hasDigit {
		return ErrWeakPassword
	}
	return nil
}

// issueTokenPair mints an access token for user and a new random refresh
// token. Only the SHA-256 hash of the refresh token is stored; the raw value
// is returned to the caller exactly once in the AuthResponse.
func (s *Service) issueTokenPair(ctx context.Context, user *User) (*AuthResponse, error) {
	accessToken, err := GenerateToken(user.ID, user.Email, s.jwtSecret, s.jwtExp)
	if err != nil {
		return nil, fmt.Errorf("failed to generate access token: %w", err)
	}
	rawRefresh, err := generateRandomToken()
	if err != nil {
		return nil, fmt.Errorf("failed to generate refresh token: %w", err)
	}
	refreshDoc := &RefreshTokenDoc{
		UserID:    user.ID,
		TokenHash: sha256Hex(rawRefresh),
		ExpiresAt: time.Now().UTC().Add(s.refreshExp),
	}
	if err := s.repo.CreateRefreshToken(ctx, refreshDoc); err != nil {
		return nil, fmt.Errorf("failed to store refresh token: %w", err)
	}
	return &AuthResponse{
		Token:        accessToken,
		RefreshToken: rawRefresh,
		User:         NewUserPublic(user),
	}, nil
}

// Register creates an account and returns a fresh token pair for it.
// The email is lowercased before lookup and storage. It returns
// ErrWeakPassword or ErrPasswordTooLong for a password that fails policy,
// and ErrEmailExists if the email is already registered.
func (s *Service) Register(ctx context.Context, req RegisterRequest) (*AuthResponse, error) {
	if err := validatePasswordStrength(req.Password); err != nil {
		return nil, err
	}
	req.Email = strings.ToLower(req.Email)
	existing, err := s.repo.FindByEmail(ctx, req.Email)
	if err != nil && !errors.Is(err, ErrNoRows) {
		return nil, fmt.Errorf("failed to check existing user: %w", err)
	}
	if existing != nil {
		return nil, ErrEmailExists
	}
	hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("failed to hash password: %w", err)
	}
	user := &User{
		Username:     req.Username,
		Email:        req.Email,
		PasswordHash: string(hash),
	}
	if err := s.repo.CreateUser(ctx, user); err != nil {
		// A concurrent registration can slip past the lookup above; the
		// unique constraint is the authoritative check.
		// NOTE(review): every unique violation is reported as ErrEmailExists;
		// if usernames are also unique-constrained this is misleading — confirm schema.
		if isPGUniqueViolation(err) {
			return nil, ErrEmailExists
		}
		return nil, fmt.Errorf("failed to create user: %w", err)
	}
	return s.issueTokenPair(ctx, user)
}

// Login verifies an email/password pair and returns a fresh token pair.
// Unknown email and wrong password both yield ErrInvalidCreds.
func (s *Service) Login(ctx context.Context, req LoginRequest) (*AuthResponse, error) {
	req.Email = strings.ToLower(req.Email)
	user, err := s.repo.FindByEmail(ctx, req.Email)
	if err != nil {
		if errors.Is(err, ErrNoRows) {
			return nil, ErrInvalidCreds
		}
		return nil, fmt.Errorf("failed to find user: %w", err)
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
		return nil, ErrInvalidCreds
	}
	return s.issueTokenPair(ctx, user)
}

// Refresh rotates a refresh token: the presented token is consumed (deleted)
// and a new access/refresh pair is issued for its owner.
//
// It returns ErrInvalidRefresh if the token is unknown, was already consumed
// (for example by a concurrent Refresh or a Logout), or belongs to a user
// that no longer exists. It returns ErrRefreshExpired if the token is past
// its expiry; an expired token is still deleted so it cannot be retried.
func (s *Service) Refresh(ctx context.Context, rawRefresh string) (*AuthResponse, error) {
	hash := sha256Hex(rawRefresh)
	doc, err := s.repo.FindRefreshTokenByHash(ctx, hash)
	if err != nil {
		if errors.Is(err, ErrNoRows) {
			return nil, ErrInvalidRefresh
		}
		return nil, fmt.Errorf("failed to find refresh token: %w", err)
	}

	// Consume the token with a delete that reports whether it removed a row
	// (Logout relies on the same contract). Another request may have consumed
	// it since the lookup above; only the caller whose delete succeeds may
	// mint a new pair, so one refresh token cannot be redeemed twice.
	claimed, err := s.repo.DeleteRefreshTokenByHash(ctx, hash)
	if err != nil {
		return nil, fmt.Errorf("failed to delete old refresh token: %w", err)
	}
	if !claimed {
		return nil, ErrInvalidRefresh
	}

	// time.Time.After compares instants, so the UTC ExpiresAt written by
	// issueTokenPair compares correctly with a local time.Now().
	if time.Now().After(doc.ExpiresAt) {
		return nil, ErrRefreshExpired
	}

	user, err := s.repo.FindByID(ctx, doc.UserID)
	if err != nil {
		if errors.Is(err, ErrNoRows) {
			return nil, ErrInvalidRefresh
		}
		return nil, fmt.Errorf("failed to find user: %w", err)
	}
	return s.issueTokenPair(ctx, user)
}

// Logout revokes the given refresh token. It returns ErrLogoutInvalid if no
// stored token matches (never issued, or already consumed).
func (s *Service) Logout(ctx context.Context, rawRefresh string) error {
	hash := sha256Hex(rawRefresh)
	found, err := s.repo.DeleteRefreshTokenByHash(ctx, hash)
	if err != nil {
		return fmt.Errorf("failed to delete refresh token: %w", err)
	}
	if !found {
		return ErrLogoutInvalid
	}
	return nil
}

// GetUserByID returns the public view of a user. It returns ErrInvalidUserID
// for an empty ID and ErrUserNotFound if no such user exists.
func (s *Service) GetUserByID(ctx context.Context, userID string) (*UserPublic, error) {
	if userID == "" {
		return nil, ErrInvalidUserID
	}
	user, err := s.repo.FindByID(ctx, userID)
	if err != nil {
		if errors.Is(err, ErrNoRows) {
			return nil, ErrUserNotFound
		}
		return nil, fmt.Errorf("failed to find user: %w", err)
	}
	public := NewUserPublic(user)
	return &public, nil
}

// ChangePassword replaces a user's password after verifying the current one.
// It returns ErrInvalidUserID, ErrUserNotFound, ErrWrongPassword,
// ErrSamePassword, ErrWeakPassword or ErrPasswordTooLong as appropriate.
//
// NOTE(review): existing refresh tokens for the user are not revoked here;
// consider invalidating them after a password change.
func (s *Service) ChangePassword(ctx context.Context, userID string, req PasswordChangeRequest) error {
	if userID == "" {
		return ErrInvalidUserID
	}
	user, err := s.repo.FindByID(ctx, userID)
	if err != nil {
		if errors.Is(err, ErrNoRows) {
			return ErrUserNotFound
		}
		return fmt.Errorf("failed to find user: %w", err)
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.OldPassword)); err != nil {
		return ErrWrongPassword
	}
	if req.OldPassword == req.NewPassword {
		return ErrSamePassword
	}
	if err := validatePasswordStrength(req.NewPassword); err != nil {
		return err
	}
	hash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}
	if err := s.repo.UpdateUserPassword(ctx, userID, string(hash)); err != nil {
		return fmt.Errorf("failed to update password: %w", err)
	}
	return nil
}

// UpdateProfile sets the user's username and returns the updated public view.
// It returns ErrInvalidUserID for an empty ID and ErrUserNotFound if no such
// user exists.
//
// NOTE(review): req.Username is not validated here — presumably done by the
// caller; confirm.
func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdateProfileRequest) (*UserPublic, error) {
	if userID == "" {
		return nil, ErrInvalidUserID
	}
	user, err := s.repo.FindByID(ctx, userID)
	if err != nil {
		if errors.Is(err, ErrNoRows) {
			return nil, ErrUserNotFound
		}
		return nil, fmt.Errorf("failed to find user: %w", err)
	}
	if err := s.repo.UpdateUserUsername(ctx, userID, req.Username); err != nil {
		return nil, fmt.Errorf("failed to update username: %w", err)
	}
	user.Username = req.Username
	public := NewUserPublic(user)
	return &public, nil
}

// isPGUniqueViolation reports whether err looks like a PostgreSQL
// unique-constraint violation (SQLSTATE 23505), judged from the error text.
//
// NOTE(review): matching the bare word "unique" can also catch unrelated
// errors; if the driver's typed error is available, prefer errors.As on it
// and compare the SQLSTATE code.
func isPGUniqueViolation(err error) bool {
	return err != nil && (strings.Contains(err.Error(), "unique") || strings.Contains(err.Error(), "23505"))
}