mirror of
https://github.com/luckfox-eng29/kvm.git
synced 2026-01-18 03:28:19 +01:00
feat(tls): #330
This commit is contained in:
283
web_tls.go
283
web_tls.go
@@ -1,132 +1,211 @@
|
||||
package kvm
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jetkvm/kvm/internal/websecure"
|
||||
)
|
||||
|
||||
const (
|
||||
WebSecureListen = ":443"
|
||||
WebSecureSelfSignedDefaultDomain = "jetkvm.local"
|
||||
WebSecureSelfSignedDuration = 365 * 24 * time.Hour
|
||||
tlsStorePath = "/userdata/jetkvm/tls"
|
||||
webSecureListen = ":443"
|
||||
webSecureSelfSignedDefaultDomain = "jetkvm.local"
|
||||
webSecureSelfSignedCAName = "JetKVM Self-Signed CA"
|
||||
webSecureSelfSignedOrganization = "JetKVM"
|
||||
webSecureSelfSignedOU = "JetKVM Self-Signed"
|
||||
webSecureCustomCertificateName = "user-defined"
|
||||
)
|
||||
|
||||
var (
|
||||
tlsCerts = make(map[string]*tls.Certificate)
|
||||
tlsCertLock = &sync.Mutex{}
|
||||
certStore *websecure.CertStore
|
||||
certSigner *websecure.SelfSigner
|
||||
)
|
||||
|
||||
type TLSState struct {
|
||||
Mode string `json:"mode"`
|
||||
Certificate string `json:"certificate"`
|
||||
PrivateKey string `json:"privateKey"`
|
||||
}
|
||||
|
||||
func initCertStore() {
|
||||
if certStore != nil {
|
||||
websecureLogger.Warn().Msg("TLS store already initialized, it should not be initialized again")
|
||||
return
|
||||
}
|
||||
certStore = websecure.NewCertStore(tlsStorePath, &websecureLogger)
|
||||
certStore.LoadCertificates()
|
||||
|
||||
certSigner = websecure.NewSelfSigner(
|
||||
certStore,
|
||||
&websecureLogger,
|
||||
webSecureSelfSignedDefaultDomain,
|
||||
webSecureSelfSignedOrganization,
|
||||
webSecureSelfSignedOU,
|
||||
webSecureSelfSignedCAName,
|
||||
)
|
||||
}
|
||||
|
||||
func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if config.TLSMode == "self-signed" {
|
||||
if isTimeSyncNeeded() || !timeSyncSuccess {
|
||||
return nil, fmt.Errorf("time is not synced")
|
||||
}
|
||||
return certSigner.GetCertificate(info)
|
||||
} else if config.TLSMode == "custom" {
|
||||
return certStore.GetCertificate(webSecureCustomCertificateName), nil
|
||||
}
|
||||
|
||||
websecureLogger.Info().Msg("TLS mode is disabled but WebSecure is running, returning nil")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func getTLSState() TLSState {
|
||||
s := TLSState{}
|
||||
switch config.TLSMode {
|
||||
case "disabled":
|
||||
s.Mode = "disabled"
|
||||
case "custom":
|
||||
s.Mode = "custom"
|
||||
cert := certStore.GetCertificate(webSecureCustomCertificateName)
|
||||
if cert != nil {
|
||||
var certPEM []byte
|
||||
// convert to pem format
|
||||
for _, c := range cert.Certificate {
|
||||
block := pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: c,
|
||||
}
|
||||
|
||||
certPEM = append(certPEM, pem.EncodeToMemory(&block)...)
|
||||
}
|
||||
s.Certificate = string(certPEM)
|
||||
}
|
||||
case "self-signed":
|
||||
s.Mode = "self-signed"
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func setTLSState(s TLSState) error {
|
||||
var isChanged = false
|
||||
|
||||
switch s.Mode {
|
||||
case "disabled":
|
||||
if config.TLSMode != "" {
|
||||
isChanged = true
|
||||
}
|
||||
config.TLSMode = ""
|
||||
case "custom":
|
||||
if config.TLSMode == "" {
|
||||
isChanged = true
|
||||
}
|
||||
// parse pem to cert and key
|
||||
err, _ := certStore.ValidateAndSaveCertificate(webSecureCustomCertificateName, s.Certificate, s.PrivateKey, true)
|
||||
// warn doesn't matter as ... we don't know the hostname yet
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to save certificate: %w", err)
|
||||
}
|
||||
config.TLSMode = "custom"
|
||||
case "self-signed":
|
||||
if config.TLSMode == "" {
|
||||
isChanged = true
|
||||
}
|
||||
config.TLSMode = "self-signed"
|
||||
default:
|
||||
return fmt.Errorf("invalid TLS mode: %s", s.Mode)
|
||||
}
|
||||
|
||||
if !isChanged {
|
||||
websecureLogger.Info().Msg("TLS enabled state is not changed, not starting/stopping websecure server")
|
||||
return nil
|
||||
}
|
||||
|
||||
if config.TLSMode == "" {
|
||||
websecureLogger.Info().Msg("Stopping websecure server, as TLS mode is disabled")
|
||||
stopWebSecureServer()
|
||||
} else {
|
||||
websecureLogger.Info().Msg("Starting websecure server, as TLS mode is enabled")
|
||||
startWebSecureServer()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
startTLS = make(chan struct{})
|
||||
stopTLS = make(chan struct{})
|
||||
tlsServiceLock = sync.Mutex{}
|
||||
tlsStarted = false
|
||||
)
|
||||
|
||||
// RunWebSecureServer runs a web server with TLS.
|
||||
func RunWebSecureServer() {
|
||||
func runWebSecureServer() {
|
||||
tlsServiceLock.Lock()
|
||||
defer tlsServiceLock.Unlock()
|
||||
|
||||
tlsStarted = true
|
||||
defer func() {
|
||||
tlsStarted = false
|
||||
}()
|
||||
|
||||
r := setupRouter()
|
||||
|
||||
server := &http.Server{
|
||||
Addr: WebSecureListen,
|
||||
Addr: webSecureListen,
|
||||
Handler: r,
|
||||
TLSConfig: &tls.Config{
|
||||
// TODO: cache certificate in persistent storage
|
||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
var hostname string
|
||||
if info.ServerName != "" {
|
||||
hostname = info.ServerName
|
||||
} else {
|
||||
hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0]
|
||||
}
|
||||
|
||||
logger.Info().Str("hostname", hostname).Interface("SupportedProtos", info.SupportedProtos).Msg("TLS handshake")
|
||||
|
||||
cert := createSelfSignedCert(hostname)
|
||||
|
||||
return cert, nil
|
||||
},
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
CurvePreferences: []tls.CurveID{},
|
||||
GetCertificate: getCertificate,
|
||||
},
|
||||
}
|
||||
logger.Info().Str("listen", WebSecureListen).Msg("Starting websecure server")
|
||||
websecureLogger.Info().Str("listen", webSecureListen).Msg("Starting websecure server")
|
||||
|
||||
go func() {
|
||||
for _ = range stopTLS {
|
||||
websecureLogger.Info().Msg("Shutting down websecure server")
|
||||
err := server.Shutdown(context.Background())
|
||||
if err != nil {
|
||||
websecureLogger.Error().Err(err).Msg("Failed to shutdown websecure server")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err := server.ListenAndServeTLS("", "")
|
||||
if err != nil {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createSelfSignedCert(hostname string) *tls.Certificate {
|
||||
if tlsCert := tlsCerts[hostname]; tlsCert != nil {
|
||||
return tlsCert
|
||||
func stopWebSecureServer() {
|
||||
if !tlsStarted {
|
||||
websecureLogger.Info().Msg("Websecure server is not running, not stopping it")
|
||||
return
|
||||
}
|
||||
stopTLS <- struct{}{}
|
||||
}
|
||||
|
||||
func startWebSecureServer() {
|
||||
if tlsStarted {
|
||||
websecureLogger.Info().Msg("Websecure server is already running, not starting it again")
|
||||
return
|
||||
}
|
||||
startTLS <- struct{}{}
|
||||
}
|
||||
|
||||
func RunWebSecureServer() {
|
||||
for _ = range startTLS {
|
||||
websecureLogger.Info().Msg("Starting websecure server, as we have received a start signal")
|
||||
if certStore == nil {
|
||||
initCertStore()
|
||||
}
|
||||
go runWebSecureServer()
|
||||
}
|
||||
tlsCertLock.Lock()
|
||||
defer tlsCertLock.Unlock()
|
||||
|
||||
logger.Info().Str("hostname", hostname).Msg("Creating self-signed certificate")
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to generate private key")
|
||||
os.Exit(1)
|
||||
}
|
||||
keyUsage := x509.KeyUsageDigitalSignature
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.AddDate(1, 0, 0)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to generate serial number")
|
||||
}
|
||||
|
||||
dnsName := hostname
|
||||
ip := net.ParseIP(hostname)
|
||||
if ip != nil {
|
||||
dnsName = WebSecureSelfSignedDefaultDomain
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: hostname,
|
||||
Organization: []string{"JetKVM"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
|
||||
DNSNames: []string{dnsName},
|
||||
IPAddresses: []net.IP{},
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to create certificate")
|
||||
}
|
||||
|
||||
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
if cert == nil {
|
||||
logger.Warn().Msg("Failed to encode certificate")
|
||||
}
|
||||
|
||||
tlsCert := &tls.Certificate{
|
||||
Certificate: [][]byte{derBytes},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
tlsCerts[hostname] = tlsCert
|
||||
|
||||
return tlsCert
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user