mirror of
https://github.com/luckfox-eng29/kvm.git
synced 2026-01-18 03:28:19 +01:00
feat(tls): #330
This commit is contained in:
9
internal/websecure/log.go
Normal file
9
internal/websecure/log.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package websecure
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var defaultLogger = zerolog.New(os.Stdout).With().Str("component", "websecure").Logger()
|
||||
191
internal/websecure/selfsign.go
Normal file
191
internal/websecure/selfsign.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package websecure
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
const selfSignerCAMagicName = "__ca__"
|
||||
|
||||
type SelfSigner struct {
|
||||
store *CertStore
|
||||
log *zerolog.Logger
|
||||
|
||||
caInfo pkix.Name
|
||||
|
||||
DefaultDomain string
|
||||
DefaultOrg string
|
||||
DefaultOU string
|
||||
}
|
||||
|
||||
func NewSelfSigner(
|
||||
store *CertStore,
|
||||
log *zerolog.Logger,
|
||||
defaultDomain,
|
||||
defaultOrg,
|
||||
defaultOU,
|
||||
caName string,
|
||||
) *SelfSigner {
|
||||
return &SelfSigner{
|
||||
store: store,
|
||||
log: log,
|
||||
DefaultDomain: defaultDomain,
|
||||
DefaultOrg: defaultOrg,
|
||||
DefaultOU: defaultOU,
|
||||
caInfo: pkix.Name{
|
||||
CommonName: caName,
|
||||
Organization: []string{defaultOrg},
|
||||
OrganizationalUnit: []string{defaultOU},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SelfSigner) getCA() *tls.Certificate {
|
||||
return s.createSelfSignedCert(selfSignerCAMagicName)
|
||||
}
|
||||
|
||||
func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate {
|
||||
if tlsCert := s.store.certificates[hostname]; tlsCert != nil {
|
||||
return tlsCert
|
||||
}
|
||||
|
||||
// check if hostname is the CA magic name
|
||||
var ca *tls.Certificate
|
||||
if hostname != selfSignerCAMagicName {
|
||||
ca = s.getCA()
|
||||
if ca == nil {
|
||||
s.log.Error().Msg("Failed to get CA certificate")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Info().Str("hostname", hostname).Msg("Creating self-signed certificate")
|
||||
|
||||
// lock the store while creating the certificate (do not move upwards)
|
||||
s.store.certLock.Lock()
|
||||
defer s.store.certLock.Unlock()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to generate private key")
|
||||
return nil
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.AddDate(1, 0, 0)
|
||||
|
||||
serialNumber, err := generateSerialNumber()
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to generate serial number")
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsName := hostname
|
||||
ip := net.ParseIP(hostname)
|
||||
if ip != nil {
|
||||
dnsName = s.DefaultDomain
|
||||
}
|
||||
|
||||
// set up CSR
|
||||
isCA := hostname == selfSignerCAMagicName
|
||||
subject := pkix.Name{
|
||||
CommonName: hostname,
|
||||
Organization: []string{s.DefaultOrg},
|
||||
OrganizationalUnit: []string{s.DefaultOU},
|
||||
}
|
||||
keyUsage := x509.KeyUsageDigitalSignature
|
||||
extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
|
||||
|
||||
// check if hostname is the CA magic name, and if so, set the subject to the CA info
|
||||
if isCA {
|
||||
subject = s.caInfo
|
||||
keyUsage |= x509.KeyUsageCertSign
|
||||
extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth)
|
||||
notAfter = notBefore.AddDate(10, 0, 0)
|
||||
}
|
||||
|
||||
cert := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: subject,
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
IsCA: isCA,
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: extKeyUsage,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
// set up DNS names and IP addresses
|
||||
if !isCA {
|
||||
cert.DNSNames = []string{dnsName}
|
||||
if ip != nil {
|
||||
cert.IPAddresses = []net.IP{ip}
|
||||
}
|
||||
}
|
||||
|
||||
// set up parent certificate
|
||||
parent := &cert
|
||||
parentPriv := priv
|
||||
if ca != nil {
|
||||
parent, err = x509.ParseCertificate(ca.Certificate[0])
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to parse parent certificate")
|
||||
return nil
|
||||
}
|
||||
parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey)
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to create certificate")
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsCert := &tls.Certificate{
|
||||
Certificate: [][]byte{certBytes},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
if ca != nil {
|
||||
tlsCert.Certificate = append(tlsCert.Certificate, ca.Certificate...)
|
||||
}
|
||||
|
||||
s.store.certificates[hostname] = tlsCert
|
||||
s.store.saveCertificate(hostname)
|
||||
|
||||
return tlsCert
|
||||
}
|
||||
|
||||
// GetCertificate returns the certificate for the given hostname
|
||||
// returns nil if the certificate is not found
|
||||
func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
var hostname string
|
||||
if info.ServerName != "" && info.ServerName != selfSignerCAMagicName {
|
||||
hostname = info.ServerName
|
||||
} else {
|
||||
hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
|
||||
}
|
||||
|
||||
s.log.Info().Str("hostname", hostname).Strs("supported_protos", info.SupportedProtos).Msg("TLS handshake")
|
||||
|
||||
// convert hostname to punycode
|
||||
h, err := idna.Lookup.ToASCII(hostname)
|
||||
if err != nil {
|
||||
s.log.Warn().Str("hostname", hostname).Err(err).Str("remote_addr", info.Conn.RemoteAddr().String()).Msg("Hostname is not valid")
|
||||
hostname = s.DefaultDomain
|
||||
} else {
|
||||
hostname = h
|
||||
}
|
||||
|
||||
cert := s.createSelfSignedCert(hostname)
|
||||
return cert, nil
|
||||
}
|
||||
175
internal/websecure/store.go
Normal file
175
internal/websecure/store.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package websecure
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type CertStore struct {
|
||||
certificates map[string]*tls.Certificate
|
||||
certLock *sync.Mutex
|
||||
|
||||
storePath string
|
||||
|
||||
log *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewCertStore(storePath string, log *zerolog.Logger) *CertStore {
|
||||
if log == nil {
|
||||
log = &defaultLogger
|
||||
}
|
||||
|
||||
return &CertStore{
|
||||
certificates: make(map[string]*tls.Certificate),
|
||||
certLock: &sync.Mutex{},
|
||||
|
||||
storePath: storePath,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CertStore) ensureStorePath() error {
|
||||
// check if directory exists
|
||||
stat, err := os.Stat(s.storePath)
|
||||
if err == nil {
|
||||
if stat.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("TLS store path exists but is not a directory: %s", s.storePath)
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
s.log.Trace().Str("path", s.storePath).Msg("TLS store directory does not exist, creating directory")
|
||||
err = os.MkdirAll(s.storePath, 0755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create TLS store path: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Failed to check TLS store path: %w", err)
|
||||
}
|
||||
|
||||
func (s *CertStore) LoadCertificates() {
|
||||
err := s.ensureStorePath()
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to ensure store path")
|
||||
return
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(s.storePath)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to read TLS directory")
|
||||
return
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(file.Name(), ".crt") {
|
||||
s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CertStore) loadCertificate(hostname string) {
|
||||
s.certLock.Lock()
|
||||
defer s.certLock.Unlock()
|
||||
|
||||
keyFile := path.Join(s.storePath, hostname+".key")
|
||||
crtFile := path.Join(s.storePath, hostname+".crt")
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(crtFile, keyFile)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Str("hostname", hostname).Msg("Failed to load certificate")
|
||||
return
|
||||
}
|
||||
|
||||
s.certificates[hostname] = &cert
|
||||
|
||||
s.log.Info().Str("hostname", hostname).Msg("Loaded certificate")
|
||||
}
|
||||
|
||||
// GetCertificate returns the certificate for the given hostname
|
||||
// returns nil if the certificate is not found
|
||||
func (s *CertStore) GetCertificate(hostname string) *tls.Certificate {
|
||||
s.certLock.Lock()
|
||||
defer s.certLock.Unlock()
|
||||
|
||||
return s.certificates[hostname]
|
||||
}
|
||||
|
||||
// ValidateAndSaveCertificate validates the certificate and saves it to the store
|
||||
// returns are:
|
||||
// - error: if the certificate is invalid or if there's any error during saving the certificate
|
||||
// - error: if there's any warning or error during saving the certificate
|
||||
func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key string, ignoreWarning bool) (error, error) {
|
||||
tlsCert, err := tls.X509KeyPair([]byte(cert), []byte(key))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to parse certificate: %w", err), nil
|
||||
}
|
||||
|
||||
// this can be skipped as current implementation supports one custom certificate only
|
||||
if tlsCert.Leaf != nil {
|
||||
// add recover to avoid panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.log.Error().Interface("recovered", r).Msg("Failed to verify hostname")
|
||||
}
|
||||
}()
|
||||
|
||||
if err = tlsCert.Leaf.VerifyHostname(hostname); err != nil {
|
||||
if !ignoreWarning {
|
||||
return nil, fmt.Errorf("Certificate does not match hostname: %w", err)
|
||||
}
|
||||
s.log.Warn().Err(err).Msg("Certificate does not match hostname")
|
||||
}
|
||||
}
|
||||
|
||||
s.certLock.Lock()
|
||||
s.certificates[hostname] = &tlsCert
|
||||
s.certLock.Unlock()
|
||||
|
||||
s.saveCertificate(hostname)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *CertStore) saveCertificate(hostname string) {
|
||||
// check if certificate already exists
|
||||
tlsCert := s.certificates[hostname]
|
||||
if tlsCert == nil {
|
||||
s.log.Error().Str("hostname", hostname).Msg("Certificate for hostname does not exist, skipping saving certificate")
|
||||
return
|
||||
}
|
||||
|
||||
err := s.ensureStorePath()
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to ensure store path")
|
||||
return
|
||||
}
|
||||
|
||||
keyFile := path.Join(s.storePath, hostname+".key")
|
||||
crtFile := path.Join(s.storePath, hostname+".crt")
|
||||
|
||||
if err := keyToFile(tlsCert, keyFile); err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to save key file")
|
||||
return
|
||||
}
|
||||
|
||||
if err := certToFile(tlsCert, crtFile); err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to save certificate")
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Info().Str("hostname", hostname).Msg("Saved certificate")
|
||||
}
|
||||
80
internal/websecure/utils.go
Normal file
80
internal/websecure/utils.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package websecure
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
)
|
||||
|
||||
var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 4096)
|
||||
|
||||
func withSecretFile(filename string, f func(*os.File) error) error {
|
||||
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return f(file)
|
||||
}
|
||||
|
||||
func keyToFile(cert *tls.Certificate, filename string) error {
|
||||
var keyBlock pem.Block
|
||||
switch k := cert.PrivateKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
keyBlock = pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(k),
|
||||
}
|
||||
case *ecdsa.PrivateKey:
|
||||
b, e := x509.MarshalECPrivateKey(k)
|
||||
if e != nil {
|
||||
return fmt.Errorf("Failed to marshal EC private key: %v", e)
|
||||
}
|
||||
|
||||
keyBlock = pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: b,
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Unknown private key type: %T", k)
|
||||
}
|
||||
|
||||
err := withSecretFile(filename, func(file *os.File) error {
|
||||
return pem.Encode(file, &keyBlock)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to save private key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func certToFile(cert *tls.Certificate, filename string) error {
|
||||
return withSecretFile(filename, func(file *os.File) error {
|
||||
for _, c := range cert.Certificate {
|
||||
block := pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: c,
|
||||
}
|
||||
|
||||
err := pem.Encode(file, &block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to save certificate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func generateSerialNumber() (*big.Int, error) {
|
||||
return rand.Int(rand.Reader, serialNumberLimit)
|
||||
}
|
||||
Reference in New Issue
Block a user