158 lines
3.8 KiB
Go
158 lines
3.8 KiB
Go
package utils
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"math/big"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
)
|
|
|
|
// CertBundle holds CA and server certificates loaded from disk.
//
// LoadCertBundle populates every field. The CA pair is what SignCSR uses to
// issue client certificates.
type CertBundle struct {
	// CACert is the CA certificate (read from ca.crt). SignCSR passes it to
	// x509.CreateCertificate as the issuing (parent) certificate.
	CACert *x509.Certificate
	// CAKey is the CA's RSA private key (read from ca.key). SignCSR signs
	// issued certificates with it.
	CAKey *rsa.PrivateKey
	// ServerCert is the server certificate (read from server.crt). It is not
	// used by the code in this file.
	ServerCert *x509.Certificate
	// ServerKey is the server's RSA private key (read from server.key). It is
	// not used by the code in this file.
	ServerKey *rsa.PrivateKey
}
|
|
|
|
// LoadCertBundle loads CA and server certificates from the given directory.
|
|
func LoadCertBundle(certDir string) (*CertBundle, error) {
|
|
caCertPEM, err := os.ReadFile(filepath.Join(certDir, "ca.crt"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ca.crt: %w", err)
|
|
}
|
|
|
|
caKeyPEM, err := os.ReadFile(filepath.Join(certDir, "ca.key"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read ca.key: %w", err)
|
|
}
|
|
|
|
serverCertPEM, err := os.ReadFile(filepath.Join(certDir, "server.crt"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read server.crt: %w", err)
|
|
}
|
|
|
|
serverKeyPEM, err := os.ReadFile(filepath.Join(certDir, "server.key"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read server.key: %w", err)
|
|
}
|
|
|
|
caCert := decodeCert(caCertPEM)
|
|
caKey, err := decodeRSAPrivateKey(caKeyPEM)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse ca.key: %w", err)
|
|
}
|
|
serverCert := decodeCert(serverCertPEM)
|
|
serverKey, err := decodeRSAPrivateKey(serverKeyPEM)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse server.key: %w", err)
|
|
}
|
|
|
|
return &CertBundle{
|
|
CACert: caCert,
|
|
CAKey: caKey,
|
|
ServerCert: serverCert,
|
|
ServerKey: serverKey,
|
|
}, nil
|
|
}
|
|
|
|
// SignCSR signs a client CSR with the CA and returns the client certificate PEM.
|
|
func (b *CertBundle) SignCSR(csrPEM []byte, label string) ([]byte, error) {
|
|
csr := decodeCSR(csrPEM)
|
|
|
|
// Verify CSR signature
|
|
if err := csr.CheckSignature(); err != nil {
|
|
return nil, fmt.Errorf("invalid CSR signature: %w", err)
|
|
}
|
|
|
|
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate serial: %w", err)
|
|
}
|
|
|
|
now := time.Now()
|
|
template := x509.Certificate{
|
|
SerialNumber: serialNumber,
|
|
Subject: pkix.Name{
|
|
CommonName: label,
|
|
Organization: csr.Subject.Organization,
|
|
},
|
|
NotBefore: now,
|
|
NotAfter: now.Add(365 * 24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
|
BasicConstraintsValid: true,
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, b.CACert, csr.PublicKey, b.CAKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create certificate: %w", err)
|
|
}
|
|
|
|
certPEM := pem.EncodeToMemory(&pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: certDER,
|
|
})
|
|
|
|
return certPEM, nil
|
|
}
|
|
|
|
// GetCACertPEM returns the CA certificate as PEM bytes.
|
|
func (b *CertBundle) GetCACertPEM() []byte {
|
|
return pem.EncodeToMemory(&pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: b.CACert.Raw,
|
|
})
|
|
}
|
|
|
|
// decodeCert parses the first PEM block in pemData as an X.509 certificate.
//
// The PEM block type is not checked. It returns nil when pemData contains no
// PEM block or when the block's bytes do not parse as a certificate. The
// underlying parse error is discarded, so callers only see nil.
func decodeCert(pemData []byte) *x509.Certificate {
	if blk, _ := pem.Decode(pemData); blk != nil {
		if parsed, err := x509.ParseCertificate(blk.Bytes); err == nil {
			return parsed
		}
	}
	return nil
}
|
|
|
|
// decodeRSAPrivateKey parses the first PEM block in pemData as an RSA private
// key.
//
// The block's bytes are tried as PKCS#8 first and, if that fails, as PKCS#1.
// The PEM block type is not consulted and no decryption is attempted.
//
// It returns an error in these cases:
//   - pemData contains no PEM block.
//   - The bytes parse as neither PKCS#8 nor PKCS#1. The error keeps the
//     "parse PKCS1:" prefix, wraps the PKCS#1 error and also reports the
//     PKCS#8 error.
//   - The bytes are a valid PKCS#8 key of a non-RSA type.
func decodeRSAPrivateKey(pemData []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(pemData)
	if block == nil {
		return nil, fmt.Errorf("no PEM block found")
	}

	key, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err == nil {
		// PKCS#8 can carry RSA, ECDSA or Ed25519 keys; only RSA is accepted.
		rsaKey, ok := key.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("key is not RSA, got %T", key)
		}
		return rsaKey, nil
	}

	// Fall back to PKCS#1 ("RSA PRIVATE KEY"). Report both failures so a
	// non-RSA or corrupt key is not described solely as a PKCS#1 problem.
	rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if pkcs1Err != nil {
		return nil, fmt.Errorf("parse PKCS1: %w (PKCS8 attempt: %v)", pkcs1Err, pkcs8Err)
	}
	return rsaKey, nil
}
|
|
|
|
// decodeCSR parses the first PEM block in pemData as a PKCS#10 certificate
// request.
//
// The PEM block type is not checked. It returns nil if no PEM block is present
// or the block's bytes do not parse as a certificate request. The parse error
// itself is discarded.
func decodeCSR(pemData []byte) *x509.CertificateRequest {
	der, _ := pem.Decode(pemData)
	if der == nil {
		return nil
	}
	req, parseErr := x509.ParseCertificateRequest(der.Bytes)
	if parseErr != nil {
		return nil
	}
	return req
}
|