This commit is contained in:
Siyuan Miao
2025-03-18 17:25:03 +01:00
parent 4c37f7e079
commit 82c018a2f6
11 changed files with 858 additions and 129 deletions

View File

@@ -1,132 +1,211 @@
package kvm
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"errors"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/jetkvm/kvm/internal/websecure"
)
const (
WebSecureListen = ":443"
WebSecureSelfSignedDefaultDomain = "jetkvm.local"
WebSecureSelfSignedDuration = 365 * 24 * time.Hour
tlsStorePath = "/userdata/jetkvm/tls"
webSecureListen = ":443"
webSecureSelfSignedDefaultDomain = "jetkvm.local"
webSecureSelfSignedCAName = "JetKVM Self-Signed CA"
webSecureSelfSignedOrganization = "JetKVM"
webSecureSelfSignedOU = "JetKVM Self-Signed"
webSecureCustomCertificateName = "user-defined"
)
var (
tlsCerts = make(map[string]*tls.Certificate)
tlsCertLock = &sync.Mutex{}
certStore *websecure.CertStore
certSigner *websecure.SelfSigner
)
type TLSState struct {
Mode string `json:"mode"`
Certificate string `json:"certificate"`
PrivateKey string `json:"privateKey"`
}
func initCertStore() {
if certStore != nil {
websecureLogger.Warn().Msg("TLS store already initialized, it should not be initialized again")
return
}
certStore = websecure.NewCertStore(tlsStorePath, &websecureLogger)
certStore.LoadCertificates()
certSigner = websecure.NewSelfSigner(
certStore,
&websecureLogger,
webSecureSelfSignedDefaultDomain,
webSecureSelfSignedOrganization,
webSecureSelfSignedOU,
webSecureSelfSignedCAName,
)
}
func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if config.TLSMode == "self-signed" {
if isTimeSyncNeeded() || !timeSyncSuccess {
return nil, fmt.Errorf("time is not synced")
}
return certSigner.GetCertificate(info)
} else if config.TLSMode == "custom" {
return certStore.GetCertificate(webSecureCustomCertificateName), nil
}
websecureLogger.Info().Msg("TLS mode is disabled but WebSecure is running, returning nil")
return nil, nil
}
func getTLSState() TLSState {
s := TLSState{}
switch config.TLSMode {
case "disabled":
s.Mode = "disabled"
case "custom":
s.Mode = "custom"
cert := certStore.GetCertificate(webSecureCustomCertificateName)
if cert != nil {
var certPEM []byte
// convert to pem format
for _, c := range cert.Certificate {
block := pem.Block{
Type: "CERTIFICATE",
Bytes: c,
}
certPEM = append(certPEM, pem.EncodeToMemory(&block)...)
}
s.Certificate = string(certPEM)
}
case "self-signed":
s.Mode = "self-signed"
}
return s
}
func setTLSState(s TLSState) error {
var isChanged = false
switch s.Mode {
case "disabled":
if config.TLSMode != "" {
isChanged = true
}
config.TLSMode = ""
case "custom":
if config.TLSMode == "" {
isChanged = true
}
// parse pem to cert and key
err, _ := certStore.ValidateAndSaveCertificate(webSecureCustomCertificateName, s.Certificate, s.PrivateKey, true)
// warn doesn't matter as ... we don't know the hostname yet
if err != nil {
return fmt.Errorf("Failed to save certificate: %w", err)
}
config.TLSMode = "custom"
case "self-signed":
if config.TLSMode == "" {
isChanged = true
}
config.TLSMode = "self-signed"
default:
return fmt.Errorf("invalid TLS mode: %s", s.Mode)
}
if !isChanged {
websecureLogger.Info().Msg("TLS enabled state is not changed, not starting/stopping websecure server")
return nil
}
if config.TLSMode == "" {
websecureLogger.Info().Msg("Stopping websecure server, as TLS mode is disabled")
stopWebSecureServer()
} else {
websecureLogger.Info().Msg("Starting websecure server, as TLS mode is enabled")
startWebSecureServer()
}
return nil
}
var (
startTLS = make(chan struct{})
stopTLS = make(chan struct{})
tlsServiceLock = sync.Mutex{}
tlsStarted = false
)
// RunWebSecureServer runs a web server with TLS.
func RunWebSecureServer() {
func runWebSecureServer() {
tlsServiceLock.Lock()
defer tlsServiceLock.Unlock()
tlsStarted = true
defer func() {
tlsStarted = false
}()
r := setupRouter()
server := &http.Server{
Addr: WebSecureListen,
Addr: webSecureListen,
Handler: r,
TLSConfig: &tls.Config{
// TODO: cache certificate in persistent storage
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
var hostname string
if info.ServerName != "" {
hostname = info.ServerName
} else {
hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
}
logger.Info().Str("hostname", hostname).Interface("SupportedProtos", info.SupportedProtos).Msg("TLS handshake")
cert := createSelfSignedCert(hostname)
return cert, nil
},
MaxVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{},
GetCertificate: getCertificate,
},
}
logger.Info().Str("listen", WebSecureListen).Msg("Starting websecure server")
websecureLogger.Info().Str("listen", webSecureListen).Msg("Starting websecure server")
go func() {
for _ = range stopTLS {
websecureLogger.Info().Msg("Shutting down websecure server")
err := server.Shutdown(context.Background())
if err != nil {
websecureLogger.Error().Err(err).Msg("Failed to shutdown websecure server")
}
}
}()
err := server.ListenAndServeTLS("", "")
if err != nil {
if !errors.Is(err, http.ErrServerClosed) {
panic(err)
}
}
func createSelfSignedCert(hostname string) *tls.Certificate {
if tlsCert := tlsCerts[hostname]; tlsCert != nil {
return tlsCert
func stopWebSecureServer() {
if !tlsStarted {
websecureLogger.Info().Msg("Websecure server is not running, not stopping it")
return
}
stopTLS <- struct{}{}
}
func startWebSecureServer() {
if tlsStarted {
websecureLogger.Info().Msg("Websecure server is already running, not starting it again")
return
}
startTLS <- struct{}{}
}
func RunWebSecureServer() {
for _ = range startTLS {
websecureLogger.Info().Msg("Starting websecure server, as we have received a start signal")
if certStore == nil {
initCertStore()
}
go runWebSecureServer()
}
tlsCertLock.Lock()
defer tlsCertLock.Unlock()
logger.Info().Str("hostname", hostname).Msg("Creating self-signed certificate")
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
logger.Warn().Err(err).Msg("Failed to generate private key")
os.Exit(1)
}
keyUsage := x509.KeyUsageDigitalSignature
notBefore := time.Now()
notAfter := notBefore.AddDate(1, 0, 0)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
logger.Warn().Err(err).Msg("Failed to generate serial number")
}
dnsName := hostname
ip := net.ParseIP(hostname)
if ip != nil {
dnsName = WebSecureSelfSignedDefaultDomain
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: hostname,
Organization: []string{"JetKVM"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{dnsName},
IPAddresses: []net.IP{},
}
if ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
logger.Warn().Err(err).Msg("Failed to create certificate")
}
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
if cert == nil {
logger.Warn().Msg("Failed to encode certificate")
}
tlsCert := &tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}
tlsCerts[hostname] = tlsCert
return tlsCert
}