@@ -0,0 +1,157 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CertBundle holds CA and server certificates loaded from disk.
|
||||
type CertBundle struct {
|
||||
CACert *x509.Certificate
|
||||
CAKey *rsa.PrivateKey
|
||||
ServerCert *x509.Certificate
|
||||
ServerKey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// LoadCertBundle loads CA and server certificates from the given directory.
|
||||
func LoadCertBundle(certDir string) (*CertBundle, error) {
|
||||
caCertPEM, err := os.ReadFile(filepath.Join(certDir, "ca.crt"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ca.crt: %w", err)
|
||||
}
|
||||
|
||||
caKeyPEM, err := os.ReadFile(filepath.Join(certDir, "ca.key"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ca.key: %w", err)
|
||||
}
|
||||
|
||||
serverCertPEM, err := os.ReadFile(filepath.Join(certDir, "server.crt"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read server.crt: %w", err)
|
||||
}
|
||||
|
||||
serverKeyPEM, err := os.ReadFile(filepath.Join(certDir, "server.key"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read server.key: %w", err)
|
||||
}
|
||||
|
||||
caCert := decodeCert(caCertPEM)
|
||||
caKey, err := decodeRSAPrivateKey(caKeyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ca.key: %w", err)
|
||||
}
|
||||
serverCert := decodeCert(serverCertPEM)
|
||||
serverKey, err := decodeRSAPrivateKey(serverKeyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse server.key: %w", err)
|
||||
}
|
||||
|
||||
return &CertBundle{
|
||||
CACert: caCert,
|
||||
CAKey: caKey,
|
||||
ServerCert: serverCert,
|
||||
ServerKey: serverKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SignCSR signs a client CSR with the CA and returns the client certificate PEM.
|
||||
func (b *CertBundle) SignCSR(csrPEM []byte, label string) ([]byte, error) {
|
||||
csr := decodeCSR(csrPEM)
|
||||
|
||||
// Verify CSR signature
|
||||
if err := csr.CheckSignature(); err != nil {
|
||||
return nil, fmt.Errorf("invalid CSR signature: %w", err)
|
||||
}
|
||||
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate serial: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: label,
|
||||
Organization: csr.Subject.Organization,
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, b.CACert, csr.PublicKey, b.CAKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create certificate: %w", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
return certPEM, nil
|
||||
}
|
||||
|
||||
// GetCACertPEM returns the CA certificate as PEM bytes.
|
||||
func (b *CertBundle) GetCACertPEM() []byte {
|
||||
return pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: b.CACert.Raw,
|
||||
})
|
||||
}
|
||||
|
||||
func decodeCert(pemData []byte) *x509.Certificate {
|
||||
block, _ := pem.Decode(pemData)
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
func decodeRSAPrivateKey(pemData []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemData)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM block found")
|
||||
}
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
// Try PKCS1 fallback
|
||||
key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse PKCS1: %w", err)
|
||||
}
|
||||
return key.(*rsa.PrivateKey), nil
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key is not RSA, got %T", key)
|
||||
}
|
||||
return rsaKey, nil
|
||||
}
|
||||
|
||||
func decodeCSR(pemData []byte) *x509.CertificateRequest {
|
||||
block, _ := pem.Decode(pemData)
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
csr, err := x509.ParseCertificateRequest(block.Bytes)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return csr
|
||||
}
|
||||
Reference in New Issue
Block a user