network enhanecment / refactor (#361)

* chore(network): improve connectivity check

* refactor(network): rewrite network and timesync component

* feat(display): show cloud connection status

* chore: change logging verbosity

* chore(websecure): update log message

* fix(ota): validate root certificate when downloading update

* feat(ui): add network settings tab

* fix(display): cloud connecting animation

* fix: golintci issues

* feat: add network settings tab

* feat(timesync): query servers in parallel

* refactor(network): move to internal/network package

* feat(timesync): add metrics

* refactor(log): move log to internal/logging package

* refactor(mdms): move mdns to internal/mdns package

* feat(developer): add pprof endpoint

* feat(logging): add a simple logging streaming endpoint

* fix(mdns): do not start mdns until network is up

* feat(network): allow users to update network settings from ui

* fix(network): handle errors when net.IPAddr is nil

* fix(mdns): scopedLogger SIGSEGV

* fix(dhcp): watch directory instead of file to catch fsnotify.Create event

* refactor(nbd): move platform-specific code to different files

* refactor(native): move platform-specific code to different files

* chore: fix linter issues

* chore(dev_deploy): allow to override PION_LOG_TRACE
This commit is contained in:
Aveline
2025-04-16 01:39:23 +02:00
committed by GitHub
parent 2b2a14204d
commit 189b84380b
71 changed files with 4938 additions and 825 deletions

View File

@@ -0,0 +1,381 @@
package confparser
import (
"fmt"
"net"
"reflect"
"slices"
"strconv"
"strings"
"github.com/guregu/null/v6"
"golang.org/x/net/idna"
)
type FieldConfig struct {
Name string
Required bool
RequiredIf map[string]interface{}
OneOf []string
ValidateTypes []string
Defaults interface{}
IsEmpty bool
CurrentValue interface{}
TypeString string
Delegated bool
shouldUpdateValue bool
}
func SetDefaultsAndValidate(config interface{}) error {
return setDefaultsAndValidate(config, true)
}
func setDefaultsAndValidate(config interface{}, isRoot bool) error {
// first we need to check if the config is a pointer
if reflect.TypeOf(config).Kind() != reflect.Ptr {
return fmt.Errorf("config is not a pointer")
}
// now iterate over the lease struct and set the values
configType := reflect.TypeOf(config).Elem()
configValue := reflect.ValueOf(config).Elem()
fields := make(map[string]FieldConfig)
for i := 0; i < configType.NumField(); i++ {
field := configType.Field(i)
fieldValue := configValue.Field(i)
defaultValue := field.Tag.Get("default")
fieldType := field.Type.String()
fieldConfig := FieldConfig{
Name: field.Name,
OneOf: splitString(field.Tag.Get("one_of")),
ValidateTypes: splitString(field.Tag.Get("validate_type")),
RequiredIf: make(map[string]interface{}),
CurrentValue: fieldValue.Interface(),
IsEmpty: false,
TypeString: fieldType,
}
// check if the field is required
required := field.Tag.Get("required")
if required != "" {
requiredBool, _ := strconv.ParseBool(required)
fieldConfig.Required = requiredBool
}
var canUseOneOff = false
// use switch to get the type
switch fieldValue.Interface().(type) {
case string, null.String:
if defaultValue != "" {
fieldConfig.Defaults = defaultValue
}
canUseOneOff = true
case []string:
if defaultValue != "" {
fieldConfig.Defaults = strings.Split(defaultValue, ",")
}
canUseOneOff = true
case int, null.Int:
if defaultValue != "" {
defaultValueInt, err := strconv.Atoi(defaultValue)
if err != nil {
return fmt.Errorf("invalid default value for field `%s`: %s", field.Name, defaultValue)
}
fieldConfig.Defaults = defaultValueInt
}
case bool, null.Bool:
if defaultValue != "" {
defaultValueBool, err := strconv.ParseBool(defaultValue)
if err != nil {
return fmt.Errorf("invalid default value for field `%s`: %s", field.Name, defaultValue)
}
fieldConfig.Defaults = defaultValueBool
}
default:
if defaultValue != "" {
return fmt.Errorf("field `%s` cannot use default value: unsupported type: %s", field.Name, fieldType)
}
// check if it's a pointer
if fieldValue.Kind() == reflect.Ptr {
// check if the pointer is nil
if fieldValue.IsNil() {
fieldConfig.IsEmpty = true
} else {
fieldConfig.CurrentValue = fieldValue.Elem().Addr()
fieldConfig.Delegated = true
}
} else {
fieldConfig.Delegated = true
}
}
// now check if the field is nullable interface
switch fieldValue.Interface().(type) {
case null.String:
if fieldValue.Interface().(null.String).IsZero() {
fieldConfig.IsEmpty = true
}
case null.Int:
if fieldValue.Interface().(null.Int).IsZero() {
fieldConfig.IsEmpty = true
}
case null.Bool:
if fieldValue.Interface().(null.Bool).IsZero() {
fieldConfig.IsEmpty = true
}
case []string:
if len(fieldValue.Interface().([]string)) == 0 {
fieldConfig.IsEmpty = true
}
}
// now check if the field has required_if
requiredIf := field.Tag.Get("required_if")
if requiredIf != "" {
requiredIfParts := strings.Split(requiredIf, ",")
for _, part := range requiredIfParts {
partVal := strings.SplitN(part, "=", 2)
if len(partVal) != 2 {
return fmt.Errorf("invalid required_if for field `%s`: %s", field.Name, requiredIf)
}
fieldConfig.RequiredIf[partVal[0]] = partVal[1]
}
}
// check if the field can use one_of
if !canUseOneOff && len(fieldConfig.OneOf) > 0 {
return fmt.Errorf("field `%s` cannot use one_of: unsupported type: %s", field.Name, fieldType)
}
fields[field.Name] = fieldConfig
}
if err := validateFields(config, fields); err != nil {
return err
}
return nil
}
func validateFields(config interface{}, fields map[string]FieldConfig) error {
// now we can start to validate the fields
for _, fieldConfig := range fields {
if err := fieldConfig.validate(fields); err != nil {
return err
}
fieldConfig.populate(config)
}
return nil
}
func (f *FieldConfig) validate(fields map[string]FieldConfig) error {
var required bool
var err error
if required, err = f.validateRequired(fields); err != nil {
return err
}
// check if the field needs to be updated and set defaults if needed
if err := f.checkIfFieldNeedsUpdate(); err != nil {
return err
}
// then we can check if the field is one_of
if err := f.validateOneOf(); err != nil {
return err
}
// and validate the type
if err := f.validateField(); err != nil {
return err
}
// if the field is delegated, we need to validate the nested field
// but before that, let's check if the field is required
if required && f.Delegated {
if err := setDefaultsAndValidate(f.CurrentValue.(reflect.Value).Interface(), false); err != nil {
return err
}
}
return nil
}
func (f *FieldConfig) populate(config interface{}) {
// update the field if it's not empty
if !f.shouldUpdateValue {
return
}
reflect.ValueOf(config).Elem().FieldByName(f.Name).Set(reflect.ValueOf(f.CurrentValue))
}
func (f *FieldConfig) checkIfFieldNeedsUpdate() error {
// populate the field if it's empty and has a default value
if f.IsEmpty && f.Defaults != nil {
switch f.CurrentValue.(type) {
case null.String:
f.CurrentValue = null.StringFrom(f.Defaults.(string))
case null.Int:
f.CurrentValue = null.IntFrom(int64(f.Defaults.(int)))
case null.Bool:
f.CurrentValue = null.BoolFrom(f.Defaults.(bool))
case string:
f.CurrentValue = f.Defaults.(string)
case int:
f.CurrentValue = f.Defaults.(int)
case bool:
f.CurrentValue = f.Defaults.(bool)
case []string:
f.CurrentValue = f.Defaults.([]string)
default:
return fmt.Errorf("field `%s` cannot use default value: unsupported type: %s", f.Name, f.TypeString)
}
f.shouldUpdateValue = true
}
return nil
}
func (f *FieldConfig) validateRequired(fields map[string]FieldConfig) (bool, error) {
var required = f.Required
// if the field is not required, we need to check if it's required_if
if !required && len(f.RequiredIf) > 0 {
for key, value := range f.RequiredIf {
// check if the field's result matches the required_if
// right now we only support string and int
requiredField, ok := fields[key]
if !ok {
return required, fmt.Errorf("required_if field `%s` not found", key)
}
switch requiredField.CurrentValue.(type) {
case string:
if requiredField.CurrentValue.(string) == value.(string) {
required = true
}
case int:
if requiredField.CurrentValue.(int) == value.(int) {
required = true
}
case null.String:
if !requiredField.CurrentValue.(null.String).IsZero() &&
requiredField.CurrentValue.(null.String).String == value.(string) {
required = true
}
case null.Int:
if !requiredField.CurrentValue.(null.Int).IsZero() &&
requiredField.CurrentValue.(null.Int).Int64 == value.(int64) {
required = true
}
}
// if the field is required, we can break the loop
// because we only need one of the required_if fields to be true
if required {
break
}
}
}
if required && f.IsEmpty {
return false, fmt.Errorf("field `%s` is required", f.Name)
}
return required, nil
}
func checkIfSliceContains(slice []string, one_of []string) bool {
for _, oneOf := range one_of {
if slices.Contains(slice, oneOf) {
return true
}
}
return false
}
func (f *FieldConfig) validateOneOf() error {
if len(f.OneOf) == 0 {
return nil
}
var val []string
switch f.CurrentValue.(type) {
case string:
val = []string{f.CurrentValue.(string)}
case null.String:
val = []string{f.CurrentValue.(null.String).String}
case []string:
// let's validate the value here
val = f.CurrentValue.([]string)
default:
return fmt.Errorf("field `%s` cannot use one_of: unsupported type: %s", f.Name, f.TypeString)
}
if !checkIfSliceContains(val, f.OneOf) {
return fmt.Errorf(
"field `%s` is not one of the allowed values: %s, current value: %s",
f.Name,
strings.Join(f.OneOf, ", "),
strings.Join(val, ", "),
)
}
return nil
}
func (f *FieldConfig) validateField() error {
if len(f.ValidateTypes) == 0 || f.IsEmpty {
return nil
}
val, err := toString(f.CurrentValue)
if err != nil {
return fmt.Errorf("field `%s` cannot use validate_type: %s", f.Name, err)
}
if val == "" {
return nil
}
for _, validateType := range f.ValidateTypes {
switch validateType {
case "ipv4":
if net.ParseIP(val).To4() == nil {
return fmt.Errorf("field `%s` is not a valid IPv4 address: %s", f.Name, val)
}
case "ipv6":
if net.ParseIP(val).To16() == nil {
return fmt.Errorf("field `%s` is not a valid IPv6 address: %s", f.Name, val)
}
case "hwaddr":
if _, err := net.ParseMAC(val); err != nil {
return fmt.Errorf("field `%s` is not a valid MAC address: %s", f.Name, val)
}
case "hostname":
if _, err := idna.Lookup.ToASCII(val); err != nil {
return fmt.Errorf("field `%s` is not a valid hostname: %s", f.Name, val)
}
default:
return fmt.Errorf("field `%s` cannot use validate_type: unsupported validator: %s", f.Name, validateType)
}
}
return nil
}

View File

@@ -0,0 +1,100 @@
package confparser
import (
"net"
"testing"
"time"
"github.com/guregu/null/v6"
)
type testIPv6Address struct { //nolint:unused
Address net.IP `json:"address"`
Prefix net.IPNet `json:"prefix"`
ValidLifetime *time.Time `json:"valid_lifetime"`
PreferredLifetime *time.Time `json:"preferred_lifetime"`
Scope int `json:"scope"`
}
type testIPv4StaticConfig struct {
Address null.String `json:"address" validate_type:"ipv4" required:"true"`
Netmask null.String `json:"netmask" validate_type:"ipv4" required:"true"`
Gateway null.String `json:"gateway" validate_type:"ipv4" required:"true"`
DNS []string `json:"dns" validate_type:"ipv4" required:"true"`
}
type testIPv6StaticConfig struct {
Address null.String `json:"address" validate_type:"ipv6" required:"true"`
Prefix null.String `json:"prefix" validate_type:"ipv6" required:"true"`
Gateway null.String `json:"gateway" validate_type:"ipv6" required:"true"`
DNS []string `json:"dns" validate_type:"ipv6" required:"true"`
}
type testNetworkConfig struct {
Hostname null.String `json:"hostname,omitempty"`
Domain null.String `json:"domain,omitempty"`
IPv4Mode null.String `json:"ipv4_mode" one_of:"dhcp,static,disabled" default:"dhcp"`
IPv4Static *testIPv4StaticConfig `json:"ipv4_static,omitempty" required_if:"IPv4Mode=static"`
IPv6Mode null.String `json:"ipv6_mode" one_of:"slaac,dhcpv6,slaac_and_dhcpv6,static,link_local,disabled" default:"slaac"`
IPv6Static *testIPv6StaticConfig `json:"ipv6_static,omitempty" required_if:"IPv6Mode=static"`
LLDPMode null.String `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"`
LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"`
MDNSMode null.String `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"`
TimeSyncMode null.String `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"`
TimeSyncOrdering []string `json:"time_sync_ordering,omitempty" one_of:"http,ntp,ntp_dhcp,ntp_user_provided,ntp_fallback" default:"ntp,http"`
TimeSyncDisableFallback null.Bool `json:"time_sync_disable_fallback,omitempty" default:"false"`
TimeSyncParallel null.Int `json:"time_sync_parallel,omitempty" default:"4"`
}
func TestValidateConfig(t *testing.T) {
config := &testNetworkConfig{}
err := SetDefaultsAndValidate(config)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
}
func TestValidateIPv4StaticConfigRequired(t *testing.T) {
config := &testNetworkConfig{
IPv4Static: &testIPv4StaticConfig{
Address: null.StringFrom("192.168.1.1"),
Gateway: null.StringFrom("192.168.1.1"),
},
}
err := SetDefaultsAndValidate(config)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestValidateIPv4StaticConfigRequiredIf(t *testing.T) {
config := &testNetworkConfig{
IPv4Mode: null.StringFrom("static"),
}
err := SetDefaultsAndValidate(config)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestValidateIPv4StaticConfigValidateType(t *testing.T) {
config := &testNetworkConfig{
IPv4Static: &testIPv4StaticConfig{
Address: null.StringFrom("X"),
Netmask: null.StringFrom("255.255.255.0"),
Gateway: null.StringFrom("192.168.1.1"),
DNS: []string{"8.8.8.8", "8.8.4.4"},
},
IPv4Mode: null.StringFrom("static"),
}
err := SetDefaultsAndValidate(config)
if err == nil {
t.Fatalf("expected error, got nil")
}
}

View File

@@ -0,0 +1,28 @@
package confparser
import (
"fmt"
"reflect"
"strings"
"github.com/guregu/null/v6"
)
func splitString(s string) []string {
if s == "" {
return []string{}
}
return strings.Split(s, ",")
}
func toString(v interface{}) (string, error) {
switch v := v.(type) {
case string:
return v, nil
case null.String:
return v.String, nil
}
return "", fmt.Errorf("unsupported type: %s", reflect.TypeOf(v))
}

197
internal/logging/logger.go Normal file
View File

@@ -0,0 +1,197 @@
package logging
import (
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
)
type Logger struct {
l *zerolog.Logger
scopeLoggers map[string]*zerolog.Logger
scopeLevels map[string]zerolog.Level
scopeLevelMutex sync.Mutex
defaultLogLevelFromEnv zerolog.Level
defaultLogLevelFromConfig zerolog.Level
defaultLogLevel zerolog.Level
}
const (
defaultLogLevel = zerolog.ErrorLevel
)
type logOutput struct {
mu *sync.Mutex
}
func (w *logOutput) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
// TODO: write to file or syslog
if sseServer != nil {
// use a goroutine to avoid blocking the Write method
go func() {
sseServer.Message <- string(p)
}()
}
return len(p), nil
}
var (
consoleLogOutput io.Writer = zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: time.RFC3339,
PartsOrder: []string{"time", "level", "scope", "component", "message"},
FieldsExclude: []string{"scope", "component"},
FormatPartValueByName: func(value interface{}, name string) string {
val := fmt.Sprintf("%s", value)
if name == "component" {
if value == nil {
return "-"
}
}
return val
},
}
fileLogOutput io.Writer = &logOutput{mu: &sync.Mutex{}}
defaultLogOutput = zerolog.MultiLevelWriter(consoleLogOutput, fileLogOutput)
zerologLevels = map[string]zerolog.Level{
"DISABLE": zerolog.Disabled,
"NOLEVEL": zerolog.NoLevel,
"PANIC": zerolog.PanicLevel,
"FATAL": zerolog.FatalLevel,
"ERROR": zerolog.ErrorLevel,
"WARN": zerolog.WarnLevel,
"INFO": zerolog.InfoLevel,
"DEBUG": zerolog.DebugLevel,
"TRACE": zerolog.TraceLevel,
}
)
func NewLogger(zerologLogger zerolog.Logger) *Logger {
return &Logger{
l: &zerologLogger,
scopeLoggers: make(map[string]*zerolog.Logger),
scopeLevels: make(map[string]zerolog.Level),
scopeLevelMutex: sync.Mutex{},
defaultLogLevelFromEnv: -2,
defaultLogLevelFromConfig: -2,
defaultLogLevel: defaultLogLevel,
}
}
func (l *Logger) updateLogLevel() {
l.scopeLevelMutex.Lock()
defer l.scopeLevelMutex.Unlock()
l.scopeLevels = make(map[string]zerolog.Level)
finalDefaultLogLevel := l.defaultLogLevel
for name, level := range zerologLevels {
env := os.Getenv(fmt.Sprintf("JETKVM_LOG_%s", name))
if env == "" {
env = os.Getenv(fmt.Sprintf("PION_LOG_%s", name))
}
if env == "" {
env = os.Getenv(fmt.Sprintf("PIONS_LOG_%s", name))
}
if env == "" {
continue
}
if strings.ToLower(env) == "all" {
l.defaultLogLevelFromEnv = level
if finalDefaultLogLevel > level {
finalDefaultLogLevel = level
}
continue
}
scopes := strings.Split(strings.ToLower(env), ",")
for _, scope := range scopes {
l.scopeLevels[scope] = level
}
}
l.defaultLogLevel = finalDefaultLogLevel
}
func (l *Logger) getScopeLoggerLevel(scope string) zerolog.Level {
if l.scopeLevels == nil {
l.updateLogLevel()
}
var scopeLevel zerolog.Level
if l.defaultLogLevelFromConfig != -2 {
scopeLevel = l.defaultLogLevelFromConfig
}
if l.defaultLogLevelFromEnv != -2 {
scopeLevel = l.defaultLogLevelFromEnv
}
// if the scope is not in the map, use the default level from the root logger
if level, ok := l.scopeLevels[scope]; ok {
scopeLevel = level
}
return scopeLevel
}
func (l *Logger) newScopeLogger(scope string) zerolog.Logger {
scopeLevel := l.getScopeLoggerLevel(scope)
logger := l.l.Level(scopeLevel).With().Str("component", scope).Logger()
return logger
}
func (l *Logger) getLogger(scope string) *zerolog.Logger {
logger, ok := l.scopeLoggers[scope]
if !ok || logger == nil {
scopeLogger := l.newScopeLogger(scope)
l.scopeLoggers[scope] = &scopeLogger
}
return l.scopeLoggers[scope]
}
func (l *Logger) UpdateLogLevel(configDefaultLogLevel string) {
needUpdate := false
if configDefaultLogLevel != "" {
if logLevel, ok := zerologLevels[configDefaultLogLevel]; ok {
l.defaultLogLevelFromConfig = logLevel
} else {
l.l.Warn().Str("logLevel", configDefaultLogLevel).Msg("invalid defaultLogLevel from config, using ERROR")
}
if l.defaultLogLevelFromConfig != l.defaultLogLevel {
needUpdate = true
}
}
l.updateLogLevel()
if needUpdate {
for scope, logger := range l.scopeLoggers {
currentLevel := logger.GetLevel()
targetLevel := l.getScopeLoggerLevel(scope)
if currentLevel != targetLevel {
*logger = l.newScopeLogger(scope)
}
}
}
}

63
internal/logging/pion.go Normal file
View File

@@ -0,0 +1,63 @@
package logging
import (
"github.com/pion/logging"
"github.com/rs/zerolog"
)
type pionLogger struct {
logger *zerolog.Logger
}
// Print all messages except trace.
func (c pionLogger) Trace(msg string) {
c.logger.Trace().Msg(msg)
}
func (c pionLogger) Tracef(format string, args ...interface{}) {
c.logger.Trace().Msgf(format, args...)
}
func (c pionLogger) Debug(msg string) {
c.logger.Debug().Msg(msg)
}
func (c pionLogger) Debugf(format string, args ...interface{}) {
c.logger.Debug().Msgf(format, args...)
}
func (c pionLogger) Info(msg string) {
c.logger.Info().Msg(msg)
}
func (c pionLogger) Infof(format string, args ...interface{}) {
c.logger.Info().Msgf(format, args...)
}
func (c pionLogger) Warn(msg string) {
c.logger.Warn().Msg(msg)
}
func (c pionLogger) Warnf(format string, args ...interface{}) {
c.logger.Warn().Msgf(format, args...)
}
func (c pionLogger) Error(msg string) {
c.logger.Error().Msg(msg)
}
func (c pionLogger) Errorf(format string, args ...interface{}) {
c.logger.Error().Msgf(format, args...)
}
// customLoggerFactory satisfies the interface logging.LoggerFactory
// This allows us to create different loggers per subsystem. So we can
// add custom behavior.
type pionLoggerFactory struct{}
func (c pionLoggerFactory) NewLogger(subsystem string) logging.LeveledLogger {
logger := rootLogger.getLogger(subsystem).With().
Str("scope", "pion").
Str("component", subsystem).
Logger()
return pionLogger{logger: &logger}
}
var defaultLoggerFactory = &pionLoggerFactory{}
func GetPionDefaultLoggerFactory() logging.LoggerFactory {
return defaultLoggerFactory
}

20
internal/logging/root.go Normal file
View File

@@ -0,0 +1,20 @@
package logging
import "github.com/rs/zerolog"
var (
rootZerologLogger = zerolog.New(defaultLogOutput).With().
Str("scope", "jetkvm").
Timestamp().
Stack().
Logger()
rootLogger = NewLogger(rootZerologLogger)
)
func GetRootLogger() *Logger {
return rootLogger
}
func GetSubsystemLogger(subsystem string) *zerolog.Logger {
return rootLogger.getLogger(subsystem)
}

137
internal/logging/sse.go Normal file
View File

@@ -0,0 +1,137 @@
package logging
import (
"embed"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog"
)
//go:embed sse.html
var sseHTML embed.FS
type sseEvent struct {
Message chan string
NewClients chan chan string
ClosedClients chan chan string
TotalClients map[chan string]bool
}
// New event messages are broadcast to all registered client connection channels
type sseClientChan chan string
var (
sseServer *sseEvent
sseLogger *zerolog.Logger
)
func init() {
sseServer = newSseServer()
sseLogger = GetSubsystemLogger("sse")
}
// Initialize event and Start procnteessing requests
func newSseServer() (event *sseEvent) {
event = &sseEvent{
Message: make(chan string),
NewClients: make(chan chan string),
ClosedClients: make(chan chan string),
TotalClients: make(map[chan string]bool),
}
go event.listen()
return
}
// It Listens all incoming requests from clients.
// Handles addition and removal of clients and broadcast messages to clients.
func (stream *sseEvent) listen() {
for {
select {
// Add new available client
case client := <-stream.NewClients:
stream.TotalClients[client] = true
sseLogger.Info().
Int("total_clients", len(stream.TotalClients)).
Msg("new client connected")
// Remove closed client
case client := <-stream.ClosedClients:
delete(stream.TotalClients, client)
close(client)
sseLogger.Info().Int("total_clients", len(stream.TotalClients)).Msg("client disconnected")
// Broadcast message to client
case eventMsg := <-stream.Message:
for clientMessageChan := range stream.TotalClients {
select {
case clientMessageChan <- eventMsg:
// Message sent successfully
default:
// Failed to send, dropping message
}
}
}
}
}
func (stream *sseEvent) serveHTTP() gin.HandlerFunc {
return func(c *gin.Context) {
clientChan := make(sseClientChan)
stream.NewClients <- clientChan
go func() {
<-c.Writer.CloseNotify()
for range clientChan {
}
stream.ClosedClients <- clientChan
}()
c.Set("clientChan", clientChan)
c.Next()
}
}
func sseHeadersMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.Method == "GET" && c.NegotiateFormat(gin.MIMEHTML) == gin.MIMEHTML {
c.FileFromFS("/sse.html", http.FS(sseHTML))
c.Status(http.StatusOK)
c.Abort()
return
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Next()
}
}
func AttachSSEHandler(router *gin.RouterGroup) {
router.StaticFS("/log-stream", http.FS(sseHTML))
router.GET("/log-stream", sseHeadersMiddleware(), sseServer.serveHTTP(), func(c *gin.Context) {
v, ok := c.Get("clientChan")
if !ok {
return
}
clientChan, ok := v.(sseClientChan)
if !ok {
return
}
c.Stream(func(w io.Writer) bool {
if msg, ok := <-clientChan; ok {
c.SSEvent("message", msg)
return true
}
return false
})
})
}

319
internal/logging/sse.html Normal file
View File

@@ -0,0 +1,319 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Server Sent Event</title>
<style>
.main-container {
display: flex;
flex-direction: column;
gap: 10px;
font-family: 'Hack', monospace;
font-size: 12px;
}
#loading {
font-style: italic;
}
.log-entry {
font-size: 12px;
line-height: 1.2;
}
.log-entry > span {
min-width: 0;
overflow-wrap: break-word;
word-break: break-word;
margin-right: 10px;
}
.log-entry > span:last-child {
margin-right: 0;
}
.log-entry.log-entry-trace .log-level {
color: blue;
}
.log-entry.log-entry-debug .log-level {
color: gray;
}
.log-entry.log-entry-info .log-level {
color: green;
}
.log-entry.log-entry-warn .log-level {
color: yellow;
}
.log-entry.log-entry-error .log-level,
.log-entry.log-entry-fatal .log-level,
.log-entry.log-entry-panic .log-level {
color: red;
}
.log-entry.log-entry-info .log-message,
.log-entry.log-entry-warn .log-message,
.log-entry.log-entry-error .log-message,
.log-entry.log-entry-fatal .log-message,
.log-entry.log-entry-panic .log-message {
font-weight: bold;
}
.log-timestamp {
color: #666;
min-width: 150px;
}
.log-level {
font-size: 12px;
min-width: 50px;
}
.log-scope {
font-size: 12px;
min-width: 40px;
}
.log-component {
font-size: 12px;
min-width: 80px;
}
.log-message {
font-size: 12px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
max-width: 500px;
}
.log-extras {
color: #000;
}
.log-extras .log-extras-header {
font-weight: bold;
color:cornflowerblue;
}
</style>
</head>
<body>
<div class="main-container">
<div id="header">
<span id="loading">
Connecting to log stream...
</span>
<span id="stats">
</span>
</div>
<div id="event-data">
</div>
</div>
</body>
<script>
class LogStream {
constructor(url, eventDataElement, loadingElement, statsElement) {
this.url = url;
this.eventDataElement = eventDataElement;
this.loadingElement = loadingElement;
this.statsElement = statsElement;
this.stream = null;
this.reconnectAttempts = 0;
this.maxReconnectAttempts = 10;
this.reconnectDelay = 1000; // Start with 1 second
this.maxReconnectDelay = 30000; // Max 30 seconds
this.isConnecting = false;
this.totalMessages = 0;
this.connect();
}
connect() {
if (this.isConnecting) return;
this.isConnecting = true;
this.loadingElement.innerText = "Connecting to log stream...";
this.stream = new EventSource(this.url);
this.stream.onopen = () => {
this.isConnecting = false;
this.reconnectAttempts = 0;
this.reconnectDelay = 1000;
this.loadingElement.innerText = "Log stream connected.";
this.totalMessages = 0;
this.totalBytes = 0;
};
this.stream.onmessage = (event) => {
this.totalBytes += event.data.length;
this.totalMessages++;
const data = JSON.parse(event.data);
this.addLogEntry(data);
this.updateStats();
};
this.stream.onerror = () => {
this.isConnecting = false;
this.loadingElement.innerText = "Log stream disconnected.";
this.stream.close();
this.handleReconnect();
};
}
updateStats() {
this.statsElement.innerHTML = `Messages: <strong>${this.totalMessages}</strong>, Bytes: <strong>${this.totalBytes}</strong> `;
}
handleReconnect() {
if (this.reconnectAttempts >= this.maxReconnectAttempts) {
this.loadingElement.innerText = "Failed to reconnect after multiple attempts";
return;
}
this.reconnectAttempts++;
this.reconnectDelay = Math.min(this.reconnectDelay * 1, this.maxReconnectDelay);
this.loadingElement.innerText = `Reconnecting in ${this.reconnectDelay/1000} seconds... (Attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts})`;
setTimeout(() => {
this.connect();
}, this.reconnectDelay);
}
addLogEntry(data) {
const el = document.createElement("div");
el.className = "log-entry log-entry-" + data.level;
const timestamp = document.createElement("span");
timestamp.className = "log-timestamp";
timestamp.innerText = data.time;
el.appendChild(timestamp);
const level = document.createElement("span");
level.className = "log-level";
level.innerText = this.shortLogLevel(data.level);
el.appendChild(level);
const scope = document.createElement("span");
scope.className = "log-scope";
scope.innerText = data.scope;
el.appendChild(scope);
const component = document.createElement("span");
component.className = "log-component";
component.innerText = data.component;
el.appendChild(component);
const message = document.createElement("span");
message.className = "log-message";
message.innerText = data.message;
el.appendChild(message);
this.addLogExtras(el, data);
this.eventDataElement.appendChild(el);
window.scrollTo(0, document.body.scrollHeight);
}
shortLogLevel(level) {
switch (level) {
case "trace":
return "TRC";
case "debug":
return "DBG";
case "info":
return "INF";
case "warn":
return "WRN";
case "error":
return "ERR";
case "fatal":
return "FTL";
case "panic":
return "PNC";
default:
return level;
}
}
addLogExtras(el, data) {
const excludeKeys = [
"timestamp",
"time",
"level",
"scope",
"component",
"message",
];
const extras = {};
for (const key in data) {
if (excludeKeys.includes(key)) {
continue;
}
extras[key] = data[key];
}
for (const key in extras) {
const extra = document.createElement("span");
extra.className = "log-extras log-extras-" + key;
const extraKey = document.createElement("span");
extraKey.className = "log-extras-header";
extraKey.innerText = key + '=';
extra.appendChild(extraKey);
const extraValue = document.createElement("span");
extraValue.className = "log-extras-value";
let value = extras[key];
if (typeof value === 'object') {
value = JSON.stringify(value);
}
extraValue.innerText = value;
extra.appendChild(extraValue);
el.appendChild(extra);
}
}
disconnect() {
if (this.stream) {
this.stream.close();
this.stream = null;
}
}
}
// Initialize the log stream when the page loads
document.addEventListener('DOMContentLoaded', () => {
const logStream = new LogStream(
"/developer/log-stream",
document.getElementById("event-data"),
document.getElementById("loading"),
document.getElementById("stats"),
);
// Clean up when the page is unloaded
window.addEventListener('beforeunload', () => {
logStream.disconnect();
});
});
</script>
</html>

32
internal/logging/utils.go Normal file
View File

@@ -0,0 +1,32 @@
package logging
import (
"fmt"
"os"
"github.com/rs/zerolog"
)
var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel)
func GetDefaultLogger() *zerolog.Logger {
return &defaultLogger
}
func ErrorfL(l *zerolog.Logger, format string, err error, args ...interface{}) error {
// TODO: move rootLogger to logging package
if l == nil {
l = &defaultLogger
}
l.Error().Err(err).Msgf(format, args...)
if err == nil {
return fmt.Errorf(format, args...)
}
err_msg := err.Error() + ": %v"
err_args := append(args, err)
return fmt.Errorf(err_msg, err_args...)
}

190
internal/mdns/mdns.go Normal file
View File

@@ -0,0 +1,190 @@
package mdns
import (
"fmt"
"net"
"reflect"
"strings"
"sync"
"github.com/jetkvm/kvm/internal/logging"
pion_mdns "github.com/pion/mdns/v2"
"github.com/rs/zerolog"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type MDNS struct {
conn *pion_mdns.Conn
lock sync.Mutex
l *zerolog.Logger
localNames []string
listenOptions *MDNSListenOptions
}
type MDNSListenOptions struct {
IPv4 bool
IPv6 bool
}
type MDNSOptions struct {
Logger *zerolog.Logger
LocalNames []string
ListenOptions *MDNSListenOptions
}
const (
DefaultAddressIPv4 = pion_mdns.DefaultAddressIPv4
DefaultAddressIPv6 = pion_mdns.DefaultAddressIPv6
)
func NewMDNS(opts *MDNSOptions) (*MDNS, error) {
if opts.Logger == nil {
opts.Logger = logging.GetDefaultLogger()
}
if opts.ListenOptions == nil {
opts.ListenOptions = &MDNSListenOptions{
IPv4: true,
IPv6: true,
}
}
return &MDNS{
l: opts.Logger,
lock: sync.Mutex{},
localNames: opts.LocalNames,
listenOptions: opts.ListenOptions,
}, nil
}
func (m *MDNS) start(allowRestart bool) error {
m.lock.Lock()
defer m.lock.Unlock()
if m.conn != nil {
if !allowRestart {
return fmt.Errorf("mDNS server already running")
}
m.conn.Close()
}
if m.listenOptions == nil {
return fmt.Errorf("listen options not set")
}
if !m.listenOptions.IPv4 && !m.listenOptions.IPv6 {
m.l.Info().Msg("mDNS server disabled")
return nil
}
var (
addr4, addr6 *net.UDPAddr
l4, l6 *net.UDPConn
p4 *ipv4.PacketConn
p6 *ipv6.PacketConn
err error
)
if m.listenOptions.IPv4 {
addr4, err = net.ResolveUDPAddr("udp4", DefaultAddressIPv4)
if err != nil {
return err
}
l4, err = net.ListenUDP("udp4", addr4)
if err != nil {
return err
}
p4 = ipv4.NewPacketConn(l4)
}
if m.listenOptions.IPv6 {
addr6, err = net.ResolveUDPAddr("udp6", DefaultAddressIPv6)
if err != nil {
return err
}
l6, err = net.ListenUDP("udp6", addr6)
if err != nil {
return err
}
p6 = ipv6.NewPacketConn(l6)
}
scopeLogger := m.l.With().
Interface("local_names", m.localNames).
Bool("ipv4", m.listenOptions.IPv4).
Bool("ipv6", m.listenOptions.IPv6).
Logger()
newLocalNames := make([]string, len(m.localNames))
for i, name := range m.localNames {
newLocalNames[i] = strings.TrimRight(strings.ToLower(name), ".")
if !strings.HasSuffix(newLocalNames[i], ".local") {
newLocalNames[i] = newLocalNames[i] + ".local"
}
}
mDNSConn, err := pion_mdns.Server(p4, p6, &pion_mdns.Config{
LocalNames: newLocalNames,
LoggerFactory: logging.GetPionDefaultLoggerFactory(),
})
if err != nil {
scopeLogger.Warn().Err(err).Msg("failed to start mDNS server")
return err
}
m.conn = mDNSConn
scopeLogger.Info().Msg("mDNS server started")
return nil
}
func (m *MDNS) Start() error {
return m.start(false)
}
func (m *MDNS) Restart() error {
return m.start(true)
}
func (m *MDNS) Stop() error {
m.lock.Lock()
defer m.lock.Unlock()
if m.conn == nil {
return nil
}
return m.conn.Close()
}
func (m *MDNS) SetLocalNames(localNames []string, always bool) error {
if reflect.DeepEqual(m.localNames, localNames) && !always {
return nil
}
m.localNames = localNames
_ = m.Restart()
return nil
}
func (m *MDNS) SetListenOptions(listenOptions *MDNSListenOptions) error {
if m.listenOptions != nil &&
m.listenOptions.IPv4 == listenOptions.IPv4 &&
m.listenOptions.IPv6 == listenOptions.IPv6 {
return nil
}
m.listenOptions = listenOptions
_ = m.Restart()
return nil
}

1
internal/mdns/utils.go Normal file
View File

@@ -0,0 +1 @@
package mdns

110
internal/network/config.go Normal file
View File

@@ -0,0 +1,110 @@
package network
import (
"fmt"
"net"
"time"
"github.com/guregu/null/v6"
"github.com/jetkvm/kvm/internal/mdns"
"golang.org/x/net/idna"
)
type IPv6Address struct {
Address net.IP `json:"address"`
Prefix net.IPNet `json:"prefix"`
ValidLifetime *time.Time `json:"valid_lifetime"`
PreferredLifetime *time.Time `json:"preferred_lifetime"`
Scope int `json:"scope"`
}
type IPv4StaticConfig struct {
Address null.String `json:"address,omitempty" validate_type:"ipv4" required:"true"`
Netmask null.String `json:"netmask,omitempty" validate_type:"ipv4" required:"true"`
Gateway null.String `json:"gateway,omitempty" validate_type:"ipv4" required:"true"`
DNS []string `json:"dns,omitempty" validate_type:"ipv4" required:"true"`
}
type IPv6StaticConfig struct {
Address null.String `json:"address,omitempty" validate_type:"ipv6" required:"true"`
Prefix null.String `json:"prefix,omitempty" validate_type:"ipv6" required:"true"`
Gateway null.String `json:"gateway,omitempty" validate_type:"ipv6" required:"true"`
DNS []string `json:"dns,omitempty" validate_type:"ipv6" required:"true"`
}
type NetworkConfig struct {
Hostname null.String `json:"hostname,omitempty" validate_type:"hostname"`
Domain null.String `json:"domain,omitempty" validate_type:"hostname"`
IPv4Mode null.String `json:"ipv4_mode,omitempty" one_of:"dhcp,static,disabled" default:"dhcp"`
IPv4Static *IPv4StaticConfig `json:"ipv4_static,omitempty" required_if:"IPv4Mode=static"`
IPv6Mode null.String `json:"ipv6_mode,omitempty" one_of:"slaac,dhcpv6,slaac_and_dhcpv6,static,link_local,disabled" default:"slaac"`
IPv6Static *IPv6StaticConfig `json:"ipv6_static,omitempty" required_if:"IPv6Mode=static"`
LLDPMode null.String `json:"lldp_mode,omitempty" one_of:"disabled,basic,all" default:"basic"`
LLDPTxTLVs []string `json:"lldp_tx_tlvs,omitempty" one_of:"chassis,port,system,vlan" default:"chassis,port,system,vlan"`
MDNSMode null.String `json:"mdns_mode,omitempty" one_of:"disabled,auto,ipv4_only,ipv6_only" default:"auto"`
TimeSyncMode null.String `json:"time_sync_mode,omitempty" one_of:"ntp_only,ntp_and_http,http_only,custom" default:"ntp_and_http"`
TimeSyncOrdering []string `json:"time_sync_ordering,omitempty" one_of:"http,ntp,ntp_dhcp,ntp_user_provided,ntp_fallback" default:"ntp,http"`
TimeSyncDisableFallback null.Bool `json:"time_sync_disable_fallback,omitempty" default:"false"`
TimeSyncParallel null.Int `json:"time_sync_parallel,omitempty" default:"4"`
}
func (c *NetworkConfig) GetMDNSMode() *mdns.MDNSListenOptions {
mode := c.MDNSMode.String
listenOptions := &mdns.MDNSListenOptions{
IPv4: true,
IPv6: true,
}
switch mode {
case "ipv4_only":
listenOptions.IPv6 = false
case "ipv6_only":
listenOptions.IPv4 = false
case "disabled":
listenOptions.IPv4 = false
listenOptions.IPv6 = false
}
return listenOptions
}
func (s *NetworkInterfaceState) GetHostname() string {
hostname := ToValidHostname(s.config.Hostname.String)
if hostname == "" {
return s.defaultHostname
}
return hostname
}
func ToValidDomain(domain string) string {
ascii, err := idna.Lookup.ToASCII(domain)
if err != nil {
return ""
}
return ascii
}
func (s *NetworkInterfaceState) GetDomain() string {
domain := ToValidDomain(s.config.Domain.String)
if domain == "" {
lease := s.dhcpClient.GetLease()
if lease != nil && lease.Domain != "" {
domain = ToValidDomain(lease.Domain)
}
}
if domain == "" {
return "local"
}
return domain
}
func (s *NetworkInterfaceState) GetFQDN() string {
return fmt.Sprintf("%s.%s", s.GetHostname(), s.GetDomain())
}

11
internal/network/dhcp.go Normal file
View File

@@ -0,0 +1,11 @@
package network
type DhcpTargetState int
const (
DhcpTargetStateDoNothing DhcpTargetState = iota
DhcpTargetStateStart
DhcpTargetStateStop
DhcpTargetStateRenew
DhcpTargetStateRelease
)

View File

@@ -0,0 +1,137 @@
package network
import (
"fmt"
"io"
"os"
"os/exec"
"strings"
"sync"
"golang.org/x/net/idna"
)
const (
hostnamePath = "/etc/hostname"
hostsPath = "/etc/hosts"
)
var (
hostnameLock sync.Mutex = sync.Mutex{}
)
func updateEtcHosts(hostname string, fqdn string) error {
// update /etc/hosts
hostsFile, err := os.OpenFile(hostsPath, os.O_RDWR|os.O_SYNC, os.ModeExclusive)
if err != nil {
return fmt.Errorf("failed to open %s: %w", hostsPath, err)
}
defer hostsFile.Close()
// read all lines
if _, err := hostsFile.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek %s: %w", hostsPath, err)
}
lines, err := io.ReadAll(hostsFile)
if err != nil {
return fmt.Errorf("failed to read %s: %w", hostsPath, err)
}
newLines := []string{}
hostLine := fmt.Sprintf("127.0.1.1\t%s %s", hostname, fqdn)
hostLineExists := false
for _, line := range strings.Split(string(lines), "\n") {
if strings.HasPrefix(line, "127.0.1.1") {
hostLineExists = true
line = hostLine
}
newLines = append(newLines, line)
}
if !hostLineExists {
newLines = append(newLines, hostLine)
}
if err := hostsFile.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate %s: %w", hostsPath, err)
}
if _, err := hostsFile.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek %s: %w", hostsPath, err)
}
if _, err := hostsFile.Write([]byte(strings.Join(newLines, "\n"))); err != nil {
return fmt.Errorf("failed to write %s: %w", hostsPath, err)
}
return nil
}
func ToValidHostname(hostname string) string {
ascii, err := idna.Lookup.ToASCII(hostname)
if err != nil {
return ""
}
return ascii
}
func SetHostname(hostname string, fqdn string) error {
hostnameLock.Lock()
defer hostnameLock.Unlock()
hostname = ToValidHostname(strings.TrimSpace(hostname))
fqdn = ToValidHostname(strings.TrimSpace(fqdn))
if hostname == "" {
return fmt.Errorf("invalid hostname: %s", hostname)
}
if fqdn == "" {
fqdn = hostname
}
// update /etc/hostname
if err := os.WriteFile(hostnamePath, []byte(hostname), 0644); err != nil {
return fmt.Errorf("failed to write %s: %w", hostnamePath, err)
}
// update /etc/hosts
if err := updateEtcHosts(hostname, fqdn); err != nil {
return fmt.Errorf("failed to update /etc/hosts: %w", err)
}
// run hostname
if err := exec.Command("hostname", "-F", hostnamePath).Run(); err != nil {
return fmt.Errorf("failed to run hostname: %w", err)
}
return nil
}
func (s *NetworkInterfaceState) setHostnameIfNotSame() error {
hostname := s.GetHostname()
currentHostname, _ := os.Hostname()
fqdn := fmt.Sprintf("%s.%s", hostname, s.GetDomain())
if currentHostname == hostname && s.currentFqdn == fqdn && s.currentHostname == hostname {
return nil
}
scopedLogger := s.l.With().Str("hostname", hostname).Str("fqdn", fqdn).Logger()
err := SetHostname(hostname, fqdn)
if err != nil {
scopedLogger.Error().Err(err).Msg("failed to set hostname")
return err
}
s.currentHostname = hostname
s.currentFqdn = fqdn
scopedLogger.Info().Msg("hostname set")
return nil
}

346
internal/network/netif.go Normal file
View File

@@ -0,0 +1,346 @@
package network
import (
"fmt"
"net"
"sync"
"github.com/jetkvm/kvm/internal/confparser"
"github.com/jetkvm/kvm/internal/logging"
"github.com/jetkvm/kvm/internal/udhcpc"
"github.com/rs/zerolog"
"github.com/vishvananda/netlink"
)
type NetworkInterfaceState struct {
interfaceName string
interfaceUp bool
ipv4Addr *net.IP
ipv4Addresses []string
ipv6Addr *net.IP
ipv6Addresses []IPv6Address
ipv6LinkLocal *net.IP
macAddr *net.HardwareAddr
l *zerolog.Logger
stateLock sync.Mutex
config *NetworkConfig
dhcpClient *udhcpc.DHCPClient
defaultHostname string
currentHostname string
currentFqdn string
onStateChange func(state *NetworkInterfaceState)
onInitialCheck func(state *NetworkInterfaceState)
cbConfigChange func(config *NetworkConfig)
checked bool
}
type NetworkInterfaceOptions struct {
InterfaceName string
DhcpPidFile string
Logger *zerolog.Logger
DefaultHostname string
OnStateChange func(state *NetworkInterfaceState)
OnInitialCheck func(state *NetworkInterfaceState)
OnDhcpLeaseChange func(lease *udhcpc.Lease)
OnConfigChange func(config *NetworkConfig)
NetworkConfig *NetworkConfig
}
func NewNetworkInterfaceState(opts *NetworkInterfaceOptions) (*NetworkInterfaceState, error) {
if opts.NetworkConfig == nil {
return nil, fmt.Errorf("NetworkConfig can not be nil")
}
if opts.DefaultHostname == "" {
opts.DefaultHostname = "jetkvm"
}
err := confparser.SetDefaultsAndValidate(opts.NetworkConfig)
if err != nil {
return nil, err
}
l := opts.Logger
s := &NetworkInterfaceState{
interfaceName: opts.InterfaceName,
defaultHostname: opts.DefaultHostname,
stateLock: sync.Mutex{},
l: l,
onStateChange: opts.OnStateChange,
onInitialCheck: opts.OnInitialCheck,
cbConfigChange: opts.OnConfigChange,
config: opts.NetworkConfig,
}
// create the dhcp client
dhcpClient := udhcpc.NewDHCPClient(&udhcpc.DHCPClientOptions{
InterfaceName: opts.InterfaceName,
PidFile: opts.DhcpPidFile,
Logger: l,
OnLeaseChange: func(lease *udhcpc.Lease) {
_, err := s.update()
if err != nil {
opts.Logger.Error().Err(err).Msg("failed to update network state")
return
}
_ = s.setHostnameIfNotSame()
opts.OnDhcpLeaseChange(lease)
},
})
s.dhcpClient = dhcpClient
return s, nil
}
func (s *NetworkInterfaceState) IsUp() bool {
return s.interfaceUp
}
func (s *NetworkInterfaceState) HasIPAssigned() bool {
return s.ipv4Addr != nil || s.ipv6Addr != nil
}
func (s *NetworkInterfaceState) IsOnline() bool {
return s.IsUp() && s.HasIPAssigned()
}
func (s *NetworkInterfaceState) IPv4() *net.IP {
return s.ipv4Addr
}
func (s *NetworkInterfaceState) IPv4String() string {
if s.ipv4Addr == nil {
return "..."
}
return s.ipv4Addr.String()
}
func (s *NetworkInterfaceState) IPv6() *net.IP {
return s.ipv6Addr
}
func (s *NetworkInterfaceState) IPv6String() string {
if s.ipv6Addr == nil {
return "..."
}
return s.ipv6Addr.String()
}
func (s *NetworkInterfaceState) MAC() *net.HardwareAddr {
return s.macAddr
}
func (s *NetworkInterfaceState) MACString() string {
if s.macAddr == nil {
return ""
}
return s.macAddr.String()
}
func (s *NetworkInterfaceState) update() (DhcpTargetState, error) {
s.stateLock.Lock()
defer s.stateLock.Unlock()
dhcpTargetState := DhcpTargetStateDoNothing
iface, err := netlink.LinkByName(s.interfaceName)
if err != nil {
s.l.Error().Err(err).Msg("failed to get interface")
return dhcpTargetState, err
}
// detect if the interface status changed
var changed bool
attrs := iface.Attrs()
state := attrs.OperState
newInterfaceUp := state == netlink.OperUp
// check if the interface is coming up
interfaceGoingUp := !s.interfaceUp && newInterfaceUp
interfaceGoingDown := s.interfaceUp && !newInterfaceUp
if s.interfaceUp != newInterfaceUp {
s.interfaceUp = newInterfaceUp
changed = true
}
if changed {
if interfaceGoingUp {
s.l.Info().Msg("interface state transitioned to up")
dhcpTargetState = DhcpTargetStateRenew
} else if interfaceGoingDown {
s.l.Info().Msg("interface state transitioned to down")
}
}
// set the mac address
s.macAddr = &attrs.HardwareAddr
// get the ip addresses
addrs, err := netlinkAddrs(iface)
if err != nil {
return dhcpTargetState, logging.ErrorfL(s.l, "failed to get ip addresses", err)
}
var (
ipv4Addresses = make([]net.IP, 0)
ipv4AddressesString = make([]string, 0)
ipv6Addresses = make([]IPv6Address, 0)
// ipv6AddressesString = make([]string, 0)
ipv6LinkLocal *net.IP
)
for _, addr := range addrs {
if addr.IP.To4() != nil {
scopedLogger := s.l.With().Str("ipv4", addr.IP.String()).Logger()
if interfaceGoingDown {
// remove all IPv4 addresses from the interface.
scopedLogger.Info().Msg("state transitioned to down, removing IPv4 address")
err := netlink.AddrDel(iface, &addr)
if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to delete address")
}
// notify the DHCP client to release the lease
dhcpTargetState = DhcpTargetStateRelease
continue
}
ipv4Addresses = append(ipv4Addresses, addr.IP)
ipv4AddressesString = append(ipv4AddressesString, addr.IPNet.String())
} else if addr.IP.To16() != nil {
scopedLogger := s.l.With().Str("ipv6", addr.IP.String()).Logger()
// check if it's a link local address
if addr.IP.IsLinkLocalUnicast() {
ipv6LinkLocal = &addr.IP
continue
}
if !addr.IP.IsGlobalUnicast() {
scopedLogger.Trace().Msg("not a global unicast address, skipping")
continue
}
if interfaceGoingDown {
scopedLogger.Info().Msg("state transitioned to down, removing IPv6 address")
err := netlink.AddrDel(iface, &addr)
if err != nil {
scopedLogger.Warn().Err(err).Msg("failed to delete address")
}
continue
}
ipv6Addresses = append(ipv6Addresses, IPv6Address{
Address: addr.IP,
Prefix: *addr.IPNet,
ValidLifetime: lifetimeToTime(addr.ValidLft),
PreferredLifetime: lifetimeToTime(addr.PreferedLft),
Scope: addr.Scope,
})
// ipv6AddressesString = append(ipv6AddressesString, addr.IPNet.String())
}
}
if len(ipv4Addresses) > 0 {
// compare the addresses to see if there's a change
if s.ipv4Addr == nil || s.ipv4Addr.String() != ipv4Addresses[0].String() {
scopedLogger := s.l.With().Str("ipv4", ipv4Addresses[0].String()).Logger()
if s.ipv4Addr != nil {
scopedLogger.Info().
Str("old_ipv4", s.ipv4Addr.String()).
Msg("IPv4 address changed")
} else {
scopedLogger.Info().Msg("IPv4 address found")
}
s.ipv4Addr = &ipv4Addresses[0]
changed = true
}
}
s.ipv4Addresses = ipv4AddressesString
if ipv6LinkLocal != nil {
if s.ipv6LinkLocal == nil || s.ipv6LinkLocal.String() != ipv6LinkLocal.String() {
scopedLogger := s.l.With().Str("ipv6", ipv6LinkLocal.String()).Logger()
if s.ipv6LinkLocal != nil {
scopedLogger.Info().
Str("old_ipv6", s.ipv6LinkLocal.String()).
Msg("IPv6 link local address changed")
} else {
scopedLogger.Info().Msg("IPv6 link local address found")
}
s.ipv6LinkLocal = ipv6LinkLocal
changed = true
}
}
s.ipv6Addresses = ipv6Addresses
if len(ipv6Addresses) > 0 {
// compare the addresses to see if there's a change
if s.ipv6Addr == nil || s.ipv6Addr.String() != ipv6Addresses[0].Address.String() {
scopedLogger := s.l.With().Str("ipv6", ipv6Addresses[0].Address.String()).Logger()
if s.ipv6Addr != nil {
scopedLogger.Info().
Str("old_ipv6", s.ipv6Addr.String()).
Msg("IPv6 address changed")
} else {
scopedLogger.Info().Msg("IPv6 address found")
}
s.ipv6Addr = &ipv6Addresses[0].Address
changed = true
}
}
// if it's the initial check, we'll set changed to false
initialCheck := !s.checked
if initialCheck {
s.checked = true
changed = false
if dhcpTargetState == DhcpTargetStateRenew {
// it's the initial check, we'll start the DHCP client
// dhcpTargetState = DhcpTargetStateStart
// TODO: manage DHCP client start/stop
dhcpTargetState = DhcpTargetStateDoNothing
}
}
if initialCheck {
s.onInitialCheck(s)
} else if changed {
s.onStateChange(s)
}
return dhcpTargetState, nil
}
func (s *NetworkInterfaceState) CheckAndUpdateDhcp() error {
dhcpTargetState, err := s.update()
if err != nil {
return logging.ErrorfL(s.l, "failed to update network state", err)
}
switch dhcpTargetState {
case DhcpTargetStateRenew:
s.l.Info().Msg("renewing DHCP lease")
_ = s.dhcpClient.Renew()
case DhcpTargetStateRelease:
s.l.Info().Msg("releasing DHCP lease")
_ = s.dhcpClient.Release()
case DhcpTargetStateStart:
s.l.Warn().Msg("dhcpTargetStateStart not implemented")
case DhcpTargetStateStop:
s.l.Warn().Msg("dhcpTargetStateStop not implemented")
}
return nil
}
func (s *NetworkInterfaceState) onConfigChange(config *NetworkConfig) {
_ = s.setHostnameIfNotSame()
s.cbConfigChange(config)
}

View File

@@ -0,0 +1,58 @@
//go:build linux
package network
import (
"time"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
)
func (s *NetworkInterfaceState) HandleLinkUpdate(update netlink.LinkUpdate) {
if update.Link.Attrs().Name == s.interfaceName {
s.l.Info().Interface("update", update).Msg("interface link update received")
_ = s.CheckAndUpdateDhcp()
}
}
func (s *NetworkInterfaceState) Run() error {
updates := make(chan netlink.LinkUpdate)
done := make(chan struct{})
if err := netlink.LinkSubscribe(updates, done); err != nil {
s.l.Warn().Err(err).Msg("failed to subscribe to link updates")
return err
}
_ = s.setHostnameIfNotSame()
// run the dhcp client
go s.dhcpClient.Run() // nolint:errcheck
if err := s.CheckAndUpdateDhcp(); err != nil {
return err
}
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case update := <-updates:
s.HandleLinkUpdate(update)
case <-ticker.C:
_ = s.CheckAndUpdateDhcp()
case <-done:
return
}
}
}()
return nil
}
func netlinkAddrs(iface netlink.Link) ([]netlink.Addr, error) {
return netlink.AddrList(iface, nl.FAMILY_ALL)
}

View File

@@ -0,0 +1,21 @@
//go:build !linux
package network
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (s *NetworkInterfaceState) HandleLinkUpdate() error {
return fmt.Errorf("not implemented")
}
func (s *NetworkInterfaceState) Run() error {
return fmt.Errorf("not implemented")
}
func netlinkAddrs(iface netlink.Link) ([]netlink.Addr, error) {
return nil, fmt.Errorf("not implemented")
}

126
internal/network/rpc.go Normal file
View File

@@ -0,0 +1,126 @@
package network
import (
"fmt"
"time"
"github.com/jetkvm/kvm/internal/confparser"
"github.com/jetkvm/kvm/internal/udhcpc"
)
type RpcIPv6Address struct {
Address string `json:"address"`
ValidLifetime *time.Time `json:"valid_lifetime,omitempty"`
PreferredLifetime *time.Time `json:"preferred_lifetime,omitempty"`
Scope int `json:"scope"`
}
type RpcNetworkState struct {
InterfaceName string `json:"interface_name"`
MacAddress string `json:"mac_address"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
IPv6LinkLocal string `json:"ipv6_link_local,omitempty"`
IPv4Addresses []string `json:"ipv4_addresses,omitempty"`
IPv6Addresses []RpcIPv6Address `json:"ipv6_addresses,omitempty"`
DHCPLease *udhcpc.Lease `json:"dhcp_lease,omitempty"`
}
type RpcNetworkSettings struct {
NetworkConfig
}
func (s *NetworkInterfaceState) MacAddress() string {
if s.macAddr == nil {
return ""
}
return s.macAddr.String()
}
func (s *NetworkInterfaceState) IPv4Address() string {
if s.ipv4Addr == nil {
return ""
}
return s.ipv4Addr.String()
}
func (s *NetworkInterfaceState) IPv6Address() string {
if s.ipv6Addr == nil {
return ""
}
return s.ipv6Addr.String()
}
func (s *NetworkInterfaceState) IPv6LinkLocalAddress() string {
if s.ipv6LinkLocal == nil {
return ""
}
return s.ipv6LinkLocal.String()
}
func (s *NetworkInterfaceState) RpcGetNetworkState() RpcNetworkState {
ipv6Addresses := make([]RpcIPv6Address, 0)
if s.ipv6Addresses != nil {
for _, addr := range s.ipv6Addresses {
ipv6Addresses = append(ipv6Addresses, RpcIPv6Address{
Address: addr.Prefix.String(),
ValidLifetime: addr.ValidLifetime,
PreferredLifetime: addr.PreferredLifetime,
Scope: addr.Scope,
})
}
}
return RpcNetworkState{
InterfaceName: s.interfaceName,
MacAddress: s.MacAddress(),
IPv4: s.IPv4Address(),
IPv6: s.IPv6Address(),
IPv6LinkLocal: s.IPv6LinkLocalAddress(),
IPv4Addresses: s.ipv4Addresses,
IPv6Addresses: ipv6Addresses,
DHCPLease: s.dhcpClient.GetLease(),
}
}
func (s *NetworkInterfaceState) RpcGetNetworkSettings() RpcNetworkSettings {
if s.config == nil {
return RpcNetworkSettings{}
}
return RpcNetworkSettings{
NetworkConfig: *s.config,
}
}
func (s *NetworkInterfaceState) RpcSetNetworkSettings(settings RpcNetworkSettings) error {
currentSettings := s.config
err := confparser.SetDefaultsAndValidate(&settings.NetworkConfig)
if err != nil {
return err
}
if IsSame(currentSettings, settings.NetworkConfig) {
// no changes, do nothing
return nil
}
s.config = &settings.NetworkConfig
s.onConfigChange(s.config)
return nil
}
func (s *NetworkInterfaceState) RpcRenewDHCPLease() error {
if s.dhcpClient == nil {
return fmt.Errorf("dhcp client not initialized")
}
return s.dhcpClient.Renew()
}

26
internal/network/utils.go Normal file
View File

@@ -0,0 +1,26 @@
package network
import (
"encoding/json"
"time"
)
func lifetimeToTime(lifetime int) *time.Time {
if lifetime == 0 {
return nil
}
t := time.Now().Add(time.Duration(lifetime) * time.Second)
return &t
}
func IsSame(a, b interface{}) bool {
aJSON, err := json.Marshal(a)
if err != nil {
return false
}
bJSON, err := json.Marshal(b)
if err != nil {
return false
}
return string(aJSON) == string(bJSON)
}

132
internal/timesync/http.go Normal file
View File

@@ -0,0 +1,132 @@
package timesync
import (
"context"
"errors"
"math/rand"
"net/http"
"strconv"
"time"
)
var defaultHTTPUrls = []string{
"http://www.gstatic.com/generate_204",
"http://cp.cloudflare.com/",
"http://edge-http.microsoft.com/captiveportal/generate_204",
// Firefox, Apple, and Microsoft have inconsistent results, so we don't use it
// "http://detectportal.firefox.com/",
// "http://www.apple.com/library/test/success.html",
// "http://www.msftconnecttest.com/connecttest.txt",
}
func (t *TimeSync) queryAllHttpTime() (now *time.Time) {
chunkSize := 4
httpUrls := t.httpUrls
// shuffle the http urls to avoid always querying the same servers
rand.Shuffle(len(httpUrls), func(i, j int) { httpUrls[i], httpUrls[j] = httpUrls[j], httpUrls[i] })
for i := 0; i < len(httpUrls); i += chunkSize {
chunk := httpUrls[i:min(i+chunkSize, len(httpUrls))]
results := t.queryMultipleHttp(chunk, timeSyncTimeout)
if results != nil {
return results
}
}
return nil
}
func (t *TimeSync) queryMultipleHttp(urls []string, timeout time.Duration) (now *time.Time) {
results := make(chan *time.Time, len(urls))
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
for _, url := range urls {
go func(url string) {
scopedLogger := t.l.With().
Str("http_url", url).
Logger()
metricHttpRequestCount.WithLabelValues(url).Inc()
metricHttpTotalRequestCount.Inc()
startTime := time.Now()
now, response, err := queryHttpTime(
ctx,
url,
timeout,
)
duration := time.Since(startTime)
metricHttpServerLastRTT.WithLabelValues(url).Set(float64(duration.Milliseconds()))
metricHttpServerRttHistogram.WithLabelValues(url).Observe(float64(duration.Milliseconds()))
status := 0
if response != nil {
status = response.StatusCode
}
metricHttpServerInfo.WithLabelValues(
url,
strconv.Itoa(status),
).Set(1)
if err == nil {
metricHttpTotalSuccessCount.Inc()
metricHttpSuccessCount.WithLabelValues(url).Inc()
requestId := response.Header.Get("X-Request-Id")
if requestId != "" {
requestId = response.Header.Get("X-Msedge-Ref")
}
if requestId == "" {
requestId = response.Header.Get("Cf-Ray")
}
scopedLogger.Info().
Str("time", now.Format(time.RFC3339)).
Int("status", status).
Str("request_id", requestId).
Str("time_taken", duration.String()).
Msg("HTTP server returned time")
cancel()
results <- now
} else if errors.Is(err, context.Canceled) {
metricHttpCancelCount.WithLabelValues(url).Inc()
metricHttpTotalCancelCount.Inc()
} else {
scopedLogger.Warn().
Str("error", err.Error()).
Int("status", status).
Msg("failed to query HTTP server")
}
}(url)
}
return <-results
}
func queryHttpTime(
ctx context.Context,
url string,
timeout time.Duration,
) (now *time.Time, response *http.Response, err error) {
client := http.Client{
Timeout: timeout,
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, nil, err
}
dateStr := resp.Header.Get("Date")
parsedTime, err := time.Parse(time.RFC1123, dateStr)
if err != nil {
return nil, nil, err
}
return &parsedTime, resp, nil
}

View File

@@ -0,0 +1,147 @@
package timesync
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
metricTimeSyncStatus = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "jetkvm_timesync_status",
Help: "The status of the timesync, 1 if successful, 0 if not",
},
)
metricTimeSyncCount = promauto.NewCounter(
prometheus.CounterOpts{
Name: "jetkvm_timesync_count",
Help: "The number of times the timesync has been run",
},
)
metricTimeSyncSuccessCount = promauto.NewCounter(
prometheus.CounterOpts{
Name: "jetkvm_timesync_success_count",
Help: "The number of times the timesync has been successful",
},
)
metricRTCUpdateCount = promauto.NewCounter( //nolint:unused
prometheus.CounterOpts{
Name: "jetkvm_timesync_rtc_update_count",
Help: "The number of times the RTC has been updated",
},
)
metricNtpTotalSuccessCount = promauto.NewCounter(
prometheus.CounterOpts{
Name: "jetkvm_timesync_ntp_total_success_count",
Help: "The total number of successful NTP requests",
},
)
metricNtpTotalRequestCount = promauto.NewCounter(
prometheus.CounterOpts{
Name: "jetkvm_timesync_ntp_total_request_count",
Help: "The total number of NTP requests sent",
},
)
metricNtpSuccessCount = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "jetkvm_timesync_ntp_success_count",
Help: "The number of successful NTP requests",
},
[]string{"url"},
)
metricNtpRequestCount = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "jetkvm_timesync_ntp_request_count",
Help: "The number of NTP requests sent to the server",
},
[]string{"url"},
)
metricNtpServerLastRTT = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_timesync_ntp_server_last_rtt",
Help: "The last RTT of the NTP server in milliseconds",
},
[]string{"url"},
)
metricNtpServerRttHistogram = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "jetkvm_timesync_ntp_server_rtt",
Help: "The histogram of the RTT of the NTP server in milliseconds",
Buckets: []float64{
10, 25, 50, 100, 200, 300, 500, 1000,
},
},
[]string{"url"},
)
metricNtpServerInfo = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_timesync_ntp_server_info",
Help: "The info of the NTP server",
},
[]string{"url", "reference", "stratum", "precision"},
)
metricHttpTotalSuccessCount = promauto.NewCounter(
prometheus.CounterOpts{
Name: "jetkvm_timesync_http_total_success_count",
Help: "The total number of successful HTTP requests",
},
)
metricHttpTotalRequestCount = promauto.NewCounter(
prometheus.CounterOpts{
Name: "jetkvm_timesync_http_total_request_count",
Help: "The total number of HTTP requests sent",
},
)
metricHttpTotalCancelCount = promauto.NewCounter(
prometheus.CounterOpts{
Name: "jetkvm_timesync_http_total_cancel_count",
Help: "The total number of HTTP requests cancelled",
},
)
metricHttpSuccessCount = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "jetkvm_timesync_http_success_count",
Help: "The number of successful HTTP requests",
},
[]string{"url"},
)
metricHttpRequestCount = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "jetkvm_timesync_http_request_count",
Help: "The number of HTTP requests sent to the server",
},
[]string{"url"},
)
metricHttpCancelCount = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "jetkvm_timesync_http_cancel_count",
Help: "The number of HTTP requests cancelled",
},
[]string{"url"},
)
metricHttpServerLastRTT = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_timesync_http_server_last_rtt",
Help: "The last RTT of the HTTP server in milliseconds",
},
[]string{"url"},
)
metricHttpServerRttHistogram = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "jetkvm_timesync_http_server_rtt",
Help: "The histogram of the RTT of the HTTP server in milliseconds",
Buckets: []float64{
10, 25, 50, 100, 200, 300, 500, 1000,
},
},
[]string{"url"},
)
metricHttpServerInfo = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "jetkvm_timesync_http_server_info",
Help: "The info of the HTTP server",
},
[]string{"url", "http_code"},
)
)

113
internal/timesync/ntp.go Normal file
View File

@@ -0,0 +1,113 @@
package timesync
import (
"math/rand/v2"
"strconv"
"time"
"github.com/beevik/ntp"
)
var defaultNTPServers = []string{
"time.apple.com",
"time.aws.com",
"time.windows.com",
"time.google.com",
"162.159.200.123", // time.cloudflare.com
"0.pool.ntp.org",
"1.pool.ntp.org",
"2.pool.ntp.org",
"3.pool.ntp.org",
}
func (t *TimeSync) queryNetworkTime() (now *time.Time, offset *time.Duration) {
chunkSize := 4
ntpServers := t.ntpServers
// shuffle the ntp servers to avoid always querying the same servers
rand.Shuffle(len(ntpServers), func(i, j int) { ntpServers[i], ntpServers[j] = ntpServers[j], ntpServers[i] })
for i := 0; i < len(ntpServers); i += chunkSize {
chunk := ntpServers[i:min(i+chunkSize, len(ntpServers))]
now, offset := t.queryMultipleNTP(chunk, timeSyncTimeout)
if now != nil {
return now, offset
}
}
return nil, nil
}
type ntpResult struct {
now *time.Time
offset *time.Duration
}
func (t *TimeSync) queryMultipleNTP(servers []string, timeout time.Duration) (now *time.Time, offset *time.Duration) {
results := make(chan *ntpResult, len(servers))
for _, server := range servers {
go func(server string) {
scopedLogger := t.l.With().
Str("server", server).
Logger()
// increase request count
metricNtpTotalRequestCount.Inc()
metricNtpRequestCount.WithLabelValues(server).Inc()
// query the server
now, response, err := queryNtpServer(server, timeout)
// set the last RTT
metricNtpServerLastRTT.WithLabelValues(
server,
).Set(float64(response.RTT.Milliseconds()))
// set the RTT histogram
metricNtpServerRttHistogram.WithLabelValues(
server,
).Observe(float64(response.RTT.Milliseconds()))
// set the server info
metricNtpServerInfo.WithLabelValues(
server,
response.ReferenceString(),
strconv.Itoa(int(response.Stratum)),
strconv.Itoa(int(response.Precision)),
).Set(1)
if err == nil {
// increase success count
metricNtpTotalSuccessCount.Inc()
metricNtpSuccessCount.WithLabelValues(server).Inc()
scopedLogger.Info().
Str("time", now.Format(time.RFC3339)).
Str("reference", response.ReferenceString()).
Str("rtt", response.RTT.String()).
Str("clockOffset", response.ClockOffset.String()).
Uint8("stratum", response.Stratum).
Msg("NTP server returned time")
results <- &ntpResult{
now: now,
offset: &response.ClockOffset,
}
} else {
scopedLogger.Warn().
Str("error", err.Error()).
Msg("failed to query NTP server")
}
}(server)
}
result := <-results
return result.now, result.offset
}
func queryNtpServer(server string, timeout time.Duration) (now *time.Time, response *ntp.Response, err error) {
resp, err := ntp.QueryWithOptions(server, ntp.QueryOptions{Timeout: timeout})
if err != nil {
return nil, nil, err
}
return &resp.Time, resp, nil
}

26
internal/timesync/rtc.go Normal file
View File

@@ -0,0 +1,26 @@
package timesync
import (
"fmt"
"os"
)
var (
rtcDeviceSearchPaths = []string{
"/dev/rtc",
"/dev/rtc0",
"/dev/rtc1",
"/dev/misc/rtc",
"/dev/misc/rtc0",
"/dev/misc/rtc1",
}
)
func getRtcDevicePath() (string, error) {
for _, path := range rtcDeviceSearchPaths {
if _, err := os.Stat(path); err == nil {
return path, nil
}
}
return "", fmt.Errorf("rtc device not found")
}

View File

@@ -0,0 +1,105 @@
//go:build linux
package timesync
import (
"fmt"
"os"
"time"
"golang.org/x/sys/unix"
)
func TimetoRtcTime(t time.Time) unix.RTCTime {
return unix.RTCTime{
Sec: int32(t.Second()),
Min: int32(t.Minute()),
Hour: int32(t.Hour()),
Mday: int32(t.Day()),
Mon: int32(t.Month() - 1),
Year: int32(t.Year() - 1900),
Wday: int32(0),
Yday: int32(0),
Isdst: int32(0),
}
}
func RtcTimetoTime(t unix.RTCTime) time.Time {
return time.Date(
int(t.Year)+1900,
time.Month(t.Mon+1),
int(t.Mday),
int(t.Hour),
int(t.Min),
int(t.Sec),
0,
time.UTC,
)
}
func (t *TimeSync) getRtcDevice() (*os.File, error) {
if t.rtcDevice == nil {
file, err := os.OpenFile(t.rtcDevicePath, os.O_RDWR, 0666)
if err != nil {
return nil, err
}
t.rtcDevice = file
}
return t.rtcDevice, nil
}
func (t *TimeSync) getRtcDeviceFd() (int, error) {
device, err := t.getRtcDevice()
if err != nil {
return 0, err
}
return int(device.Fd()), nil
}
// Read implements Read for the Linux RTC
func (t *TimeSync) readRtcTime() (time.Time, error) {
fd, err := t.getRtcDeviceFd()
if err != nil {
return time.Time{}, fmt.Errorf("failed to get RTC device fd: %w", err)
}
rtcTime, err := unix.IoctlGetRTCTime(fd)
if err != nil {
return time.Time{}, fmt.Errorf("failed to get RTC time: %w", err)
}
date := RtcTimetoTime(*rtcTime)
return date, nil
}
// Set implements Set for the Linux RTC
// ...
// It might be not accurate as the time consumed by the system call is not taken into account
// but it's good enough for our purposes
func (t *TimeSync) setRtcTime(tu time.Time) error {
rt := TimetoRtcTime(tu)
fd, err := t.getRtcDeviceFd()
if err != nil {
return fmt.Errorf("failed to get RTC device fd: %w", err)
}
currentRtcTime, err := t.readRtcTime()
if err != nil {
return fmt.Errorf("failed to read RTC time: %w", err)
}
t.l.Info().
Interface("rtc_time", tu).
Str("offset", tu.Sub(currentRtcTime).String()).
Msg("set rtc time")
if err := unix.IoctlSetRTCTime(fd, &rt); err != nil {
return fmt.Errorf("failed to set RTC time: %w", err)
}
metricRTCUpdateCount.Inc()
return nil
}

View File

@@ -0,0 +1,16 @@
//go:build !linux
package timesync
import (
"errors"
"time"
)
func (t *TimeSync) readRtcTime() (time.Time, error) {
return time.Now(), nil
}
func (t *TimeSync) setRtcTime(tu time.Time) error {
return errors.New("not supported")
}

View File

@@ -0,0 +1,208 @@
package timesync
import (
"fmt"
"os"
"os/exec"
"sync"
"time"
"github.com/jetkvm/kvm/internal/network"
"github.com/rs/zerolog"
)
const (
timeSyncRetryStep = 5 * time.Second
timeSyncRetryMaxInt = 1 * time.Minute
timeSyncWaitNetChkInt = 100 * time.Millisecond
timeSyncWaitNetUpInt = 3 * time.Second
timeSyncInterval = 1 * time.Hour
timeSyncTimeout = 2 * time.Second
)
var (
timeSyncRetryInterval = 0 * time.Second
)
type TimeSync struct {
syncLock *sync.Mutex
l *zerolog.Logger
ntpServers []string
httpUrls []string
networkConfig *network.NetworkConfig
rtcDevicePath string
rtcDevice *os.File //nolint:unused
rtcLock *sync.Mutex
syncSuccess bool
preCheckFunc func() (bool, error)
}
type TimeSyncOptions struct {
PreCheckFunc func() (bool, error)
Logger *zerolog.Logger
NetworkConfig *network.NetworkConfig
}
type SyncMode struct {
Ntp bool
Http bool
Ordering []string
NtpUseFallback bool
HttpUseFallback bool
}
func NewTimeSync(opts *TimeSyncOptions) *TimeSync {
rtcDevice, err := getRtcDevicePath()
if err != nil {
opts.Logger.Error().Err(err).Msg("failed to get RTC device path")
} else {
opts.Logger.Info().Str("path", rtcDevice).Msg("RTC device found")
}
t := &TimeSync{
syncLock: &sync.Mutex{},
l: opts.Logger,
rtcDevicePath: rtcDevice,
rtcLock: &sync.Mutex{},
preCheckFunc: opts.PreCheckFunc,
ntpServers: defaultNTPServers,
httpUrls: defaultHTTPUrls,
networkConfig: opts.NetworkConfig,
}
if t.rtcDevicePath != "" {
rtcTime, _ := t.readRtcTime()
t.l.Info().Interface("rtc_time", rtcTime).Msg("read RTC time")
}
return t
}
func (t *TimeSync) getSyncMode() SyncMode {
syncMode := SyncMode{
NtpUseFallback: true,
HttpUseFallback: true,
}
var syncModeString string
if t.networkConfig != nil {
syncModeString = t.networkConfig.TimeSyncMode.String
if t.networkConfig.TimeSyncDisableFallback.Bool {
syncMode.NtpUseFallback = false
syncMode.HttpUseFallback = false
}
}
switch syncModeString {
case "ntp_only":
syncMode.Ntp = true
case "http_only":
syncMode.Http = true
default:
syncMode.Ntp = true
syncMode.Http = true
}
return syncMode
}
func (t *TimeSync) doTimeSync() {
metricTimeSyncStatus.Set(0)
for {
if ok, err := t.preCheckFunc(); !ok {
if err != nil {
t.l.Error().Err(err).Msg("pre-check failed")
}
time.Sleep(timeSyncWaitNetChkInt)
continue
}
t.l.Info().Msg("syncing system time")
start := time.Now()
err := t.Sync()
if err != nil {
t.l.Error().Str("error", err.Error()).Msg("failed to sync system time")
// retry after a delay
timeSyncRetryInterval += timeSyncRetryStep
time.Sleep(timeSyncRetryInterval)
// reset the retry interval if it exceeds the max interval
if timeSyncRetryInterval > timeSyncRetryMaxInt {
timeSyncRetryInterval = 0
}
continue
}
t.syncSuccess = true
t.l.Info().Str("now", time.Now().Format(time.RFC3339)).
Str("time_taken", time.Since(start).String()).
Msg("time sync successful")
metricTimeSyncStatus.Set(1)
time.Sleep(timeSyncInterval) // after the first sync is done
}
}
func (t *TimeSync) Sync() error {
var (
now *time.Time
offset *time.Duration
)
syncMode := t.getSyncMode()
metricTimeSyncCount.Inc()
if syncMode.Ntp {
now, offset = t.queryNetworkTime()
}
if syncMode.Http && now == nil {
now = t.queryAllHttpTime()
}
if now == nil {
return fmt.Errorf("failed to get time from any source")
}
if offset != nil {
newNow := time.Now().Add(*offset)
now = &newNow
}
err := t.setSystemTime(*now)
if err != nil {
return fmt.Errorf("failed to set system time: %w", err)
}
metricTimeSyncSuccessCount.Inc()
return nil
}
func (t *TimeSync) IsSyncSuccess() bool {
return t.syncSuccess
}
func (t *TimeSync) Start() {
go t.doTimeSync()
}
func (t *TimeSync) setSystemTime(now time.Time) error {
nowStr := now.Format("2006-01-02 15:04:05")
output, err := exec.Command("date", "-s", nowStr).CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run date -s: %w, %s", err, string(output))
}
if t.rtcDevicePath != "" {
return t.setRtcTime(now)
}
return nil
}

View File

@@ -0,0 +1,12 @@
package udhcpc
func (u *DHCPClient) GetNtpServers() []string {
if u.lease == nil {
return nil
}
servers := make([]string, len(u.lease.NTPServers))
for i, server := range u.lease.NTPServers {
servers[i] = server.String()
}
return servers
}

186
internal/udhcpc/parser.go Normal file
View File

@@ -0,0 +1,186 @@
package udhcpc
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"reflect"
"strconv"
"strings"
"time"
)
type Lease struct {
// from https://udhcp.busybox.net/README.udhcpc
IPAddress net.IP `env:"ip" json:"ip"` // The obtained IP
Netmask net.IP `env:"subnet" json:"netmask"` // The assigned subnet mask
Broadcast net.IP `env:"broadcast" json:"broadcast"` // The broadcast address for this network
TTL int `env:"ipttl" json:"ttl,omitempty"` // The TTL to use for this network
MTU int `env:"mtu" json:"mtu,omitempty"` // The MTU to use for this network
HostName string `env:"hostname" json:"hostname,omitempty"` // The assigned hostname
Domain string `env:"domain" json:"domain,omitempty"` // The domain name of the network
BootPNextServer net.IP `env:"siaddr" json:"bootp_next_server,omitempty"` // The bootp next server option
BootPServerName string `env:"sname" json:"bootp_server_name,omitempty"` // The bootp server name option
BootPFile string `env:"boot_file" json:"bootp_file,omitempty"` // The bootp boot file option
Timezone string `env:"timezone" json:"timezone,omitempty"` // Offset in seconds from UTC
Routers []net.IP `env:"router" json:"routers,omitempty"` // A list of routers
DNS []net.IP `env:"dns" json:"dns_servers,omitempty"` // A list of DNS servers
NTPServers []net.IP `env:"ntpsrv" json:"ntp_servers,omitempty"` // A list of NTP servers
LPRServers []net.IP `env:"lprsvr" json:"lpr_servers,omitempty"` // A list of LPR servers
TimeServers []net.IP `env:"timesvr" json:"_time_servers,omitempty"` // A list of time servers (obsolete)
IEN116NameServers []net.IP `env:"namesvr" json:"_name_servers,omitempty"` // A list of IEN 116 name servers (obsolete)
LogServers []net.IP `env:"logsvr" json:"_log_servers,omitempty"` // A list of MIT-LCS UDP log servers (obsolete)
CookieServers []net.IP `env:"cookiesvr" json:"_cookie_servers,omitempty"` // A list of RFC 865 cookie servers (obsolete)
WINSServers []net.IP `env:"wins" json:"_wins_servers,omitempty"` // A list of WINS servers
SwapServer net.IP `env:"swapsvr" json:"_swap_server,omitempty"` // The IP address of the client's swap server
BootSize int `env:"bootsize" json:"bootsize,omitempty"` // The length in 512 octect blocks of the bootfile
RootPath string `env:"rootpath" json:"root_path,omitempty"` // The path name of the client's root disk
LeaseTime time.Duration `env:"lease" json:"lease,omitempty"` // The lease time, in seconds
DHCPType string `env:"dhcptype" json:"dhcp_type,omitempty"` // DHCP message type (safely ignored)
ServerID string `env:"serverid" json:"server_id,omitempty"` // The IP of the server
Message string `env:"message" json:"reason,omitempty"` // Reason for a DHCPNAK
TFTPServerName string `env:"tftp" json:"tftp,omitempty"` // The TFTP server name
BootFileName string `env:"bootfile" json:"bootfile,omitempty"` // The boot file name
Uptime time.Duration `env:"uptime" json:"uptime,omitempty"` // The uptime of the device when the lease was obtained, in seconds
LeaseExpiry *time.Time `json:"lease_expiry,omitempty"` // The expiry time of the lease
isEmpty map[string]bool
}
func (l *Lease) setIsEmpty(m map[string]bool) {
l.isEmpty = m
}
func (l *Lease) IsEmpty(key string) bool {
return l.isEmpty[key]
}
func (l *Lease) ToJSON() string {
json, err := json.Marshal(l)
if err != nil {
return ""
}
return string(json)
}
func (l *Lease) SetLeaseExpiry() (time.Time, error) {
if l.Uptime == 0 || l.LeaseTime == 0 {
return time.Time{}, fmt.Errorf("uptime or lease time isn't set")
}
// get the uptime of the device
file, err := os.Open("/proc/uptime")
if err != nil {
return time.Time{}, fmt.Errorf("failed to open uptime file: %w", err)
}
defer file.Close()
var uptime time.Duration
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
parts := strings.Split(text, " ")
uptime, err = time.ParseDuration(parts[0] + "s")
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse uptime: %w", err)
}
}
relativeLeaseRemaining := (l.Uptime + l.LeaseTime) - uptime
leaseExpiry := time.Now().Add(relativeLeaseRemaining)
l.LeaseExpiry = &leaseExpiry
return leaseExpiry, nil
}
func UnmarshalDHCPCLease(lease *Lease, str string) error {
// parse the lease file as a map
data := make(map[string]string)
for _, line := range strings.Split(str, "\n") {
line = strings.TrimSpace(line)
// skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
data[key] = value
}
// now iterate over the lease struct and set the values
leaseType := reflect.TypeOf(lease).Elem()
leaseValue := reflect.ValueOf(lease).Elem()
valuesParsed := make(map[string]bool)
for i := 0; i < leaseType.NumField(); i++ {
field := leaseValue.Field(i)
// get the env tag
key := leaseType.Field(i).Tag.Get("env")
if key == "" {
continue
}
valuesParsed[key] = false
// get the value from the data map
value, ok := data[key]
if !ok || value == "" {
continue
}
switch field.Interface().(type) {
case string:
field.SetString(value)
case int:
val, err := strconv.Atoi(value)
if err != nil {
continue
}
field.SetInt(int64(val))
case time.Duration:
val, err := time.ParseDuration(value + "s")
if err != nil {
continue
}
field.Set(reflect.ValueOf(val))
case net.IP:
ip := net.ParseIP(value)
if ip == nil {
continue
}
field.Set(reflect.ValueOf(ip))
case []net.IP:
val := make([]net.IP, 0)
for _, ipStr := range strings.Fields(value) {
ip := net.ParseIP(ipStr)
if ip == nil {
continue
}
val = append(val, ip)
}
field.Set(reflect.ValueOf(val))
default:
return fmt.Errorf("unsupported field `%s` type: %s", key, field.Type().String())
}
valuesParsed[key] = true
}
lease.setIsEmpty(valuesParsed)
return nil
}

View File

@@ -0,0 +1,74 @@
package udhcpc
import (
"testing"
"time"
)
func TestUnmarshalDHCPCLease(t *testing.T) {
lease := &Lease{}
err := UnmarshalDHCPCLease(lease, `
# generated @ Mon Jan 4 19:31:53 UTC 2021
# 19:31:53 up 0 min, 0 users, load average: 0.72, 0.14, 0.04
# the date might be inaccurate if the clock is not set
ip=192.168.0.240
siaddr=192.168.0.1
sname=
boot_file=
subnet=255.255.255.0
timezone=
router=192.168.0.1
timesvr=
namesvr=
dns=172.19.53.2
logsvr=
cookiesvr=
lprsvr=
hostname=
bootsize=
domain=
swapsvr=
rootpath=
ipttl=
mtu=
broadcast=
ntpsrv=162.159.200.123
wins=
lease=172800
dhcptype=
serverid=192.168.0.1
message=
tftp=
bootfile=
`)
if lease.IPAddress.String() != "192.168.0.240" {
t.Fatalf("expected ip to be 192.168.0.240, got %s", lease.IPAddress.String())
}
if lease.Netmask.String() != "255.255.255.0" {
t.Fatalf("expected netmask to be 255.255.255.0, got %s", lease.Netmask.String())
}
if len(lease.Routers) != 1 {
t.Fatalf("expected 1 router, got %d", len(lease.Routers))
}
if lease.Routers[0].String() != "192.168.0.1" {
t.Fatalf("expected router to be 192.168.0.1, got %s", lease.Routers[0].String())
}
if len(lease.NTPServers) != 1 {
t.Fatalf("expected 1 timeserver, got %d", len(lease.NTPServers))
}
if lease.NTPServers[0].String() != "162.159.200.123" {
t.Fatalf("expected timeserver to be 162.159.200.123, got %s", lease.NTPServers[0].String())
}
if len(lease.DNS) != 1 {
t.Fatalf("expected 1 dns, got %d", len(lease.DNS))
}
if lease.DNS[0].String() != "172.19.53.2" {
t.Fatalf("expected dns to be 172.19.53.2, got %s", lease.DNS[0].String())
}
if lease.LeaseTime != 172800*time.Second {
t.Fatalf("expected lease time to be 172800 seconds, got %d", lease.LeaseTime)
}
if err != nil {
t.Fatal(err)
}
}

212
internal/udhcpc/proc.go Normal file
View File

@@ -0,0 +1,212 @@
package udhcpc
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
)
func readFileNoStat(filename string) ([]byte, error) {
const maxBufferSize = 1024 * 1024
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
reader := io.LimitReader(f, maxBufferSize)
return io.ReadAll(reader)
}
func toCmdline(path string) ([]string, error) {
data, err := readFileNoStat(path)
if err != nil {
return nil, err
}
if len(data) < 1 {
return []string{}, nil
}
return strings.Split(string(bytes.TrimRight(data, "\x00")), "\x00"), nil
}
func (p *DHCPClient) findUdhcpcProcess() (int, error) {
// read procfs for udhcpc processes
// we do not use procfs.AllProcs() because we want to avoid the overhead of reading the entire procfs
processes, err := os.ReadDir("/proc")
if err != nil {
return 0, err
}
// iterate over the processes
for _, d := range processes {
// check if file is numeric
pid, err := strconv.Atoi(d.Name())
if err != nil {
continue
}
// check if it's a directory
if !d.IsDir() {
continue
}
cmdline, err := toCmdline(filepath.Join("/proc", d.Name(), "cmdline"))
if err != nil {
continue
}
if len(cmdline) < 1 {
continue
}
if cmdline[0] != "udhcpc" {
continue
}
cmdlineText := strings.Join(cmdline, " ")
// check if it's a udhcpc process
if strings.Contains(cmdlineText, fmt.Sprintf("-i %s", p.InterfaceName)) {
p.logger.Debug().
Str("pid", d.Name()).
Interface("cmdline", cmdline).
Msg("found udhcpc process")
return pid, nil
}
}
return 0, errors.New("udhcpc process not found")
}
func (c *DHCPClient) getProcessPid() (int, error) {
var pid int
if c.pidFile != "" {
// try to read the pid file
pidHandle, err := os.ReadFile(c.pidFile)
if err != nil {
c.logger.Warn().Err(err).
Str("pidFile", c.pidFile).Msg("failed to read udhcpc pid file")
}
// if it exists, try to read the pid
if pidHandle != nil {
pidFromFile, err := strconv.Atoi(string(pidHandle))
if err != nil {
c.logger.Warn().Err(err).
Str("pidFile", c.pidFile).Msg("failed to convert pid file to int")
}
pid = pidFromFile
}
}
// if the pid is 0, try to find the pid using procfs
if pid == 0 {
newPid, err := c.findUdhcpcProcess()
if err != nil {
return 0, err
}
pid = newPid
}
return pid, nil
}
func (c *DHCPClient) getProcess() *os.Process {
pid, err := c.getProcessPid()
if err != nil {
return nil
}
process, err := os.FindProcess(pid)
if err != nil {
c.logger.Warn().Err(err).
Int("pid", pid).Msg("failed to find process")
return nil
}
return process
}
func (c *DHCPClient) GetProcess() *os.Process {
if c.process == nil {
process := c.getProcess()
if process == nil {
return nil
}
c.process = process
}
err := c.process.Signal(syscall.Signal(0))
if err != nil && errors.Is(err, os.ErrProcessDone) {
oldPid := c.process.Pid
c.process = nil
c.process = c.getProcess()
if c.process == nil {
c.logger.Error().Msg("failed to find new udhcpc process")
return nil
}
c.logger.Warn().
Int("oldPid", oldPid).
Int("newPid", c.process.Pid).
Msg("udhcpc process pid changed")
} else if err != nil {
c.logger.Warn().Err(err).
Int("pid", c.process.Pid).Msg("udhcpc process is not running")
}
return c.process
}
func (c *DHCPClient) KillProcess() error {
process := c.GetProcess()
if process == nil {
return nil
}
return process.Kill()
}
func (c *DHCPClient) ReleaseProcess() error {
process := c.GetProcess()
if process == nil {
return nil
}
return process.Release()
}
func (c *DHCPClient) signalProcess(sig syscall.Signal) error {
process := c.GetProcess()
if process == nil {
return nil
}
s := process.Signal(sig)
if s != nil {
c.logger.Warn().Err(s).
Int("pid", process.Pid).
Str("signal", sig.String()).
Msg("failed to signal udhcpc process")
return s
}
return nil
}
func (c *DHCPClient) Renew() error {
return c.signalProcess(syscall.SIGUSR1)
}
func (c *DHCPClient) Release() error {
return c.signalProcess(syscall.SIGUSR2)
}

191
internal/udhcpc/udhcpc.go Normal file
View File

@@ -0,0 +1,191 @@
package udhcpc
import (
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog"
)
const (
DHCPLeaseFile = "/run/udhcpc.%s.info"
DHCPPidFile = "/run/udhcpc.%s.pid"
)
type DHCPClient struct {
InterfaceName string
leaseFile string
pidFile string
lease *Lease
logger *zerolog.Logger
process *os.Process
onLeaseChange func(lease *Lease)
}
type DHCPClientOptions struct {
InterfaceName string
PidFile string
Logger *zerolog.Logger
OnLeaseChange func(lease *Lease)
}
var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel)
func NewDHCPClient(options *DHCPClientOptions) *DHCPClient {
if options.Logger == nil {
options.Logger = &defaultLogger
}
l := options.Logger.With().Str("interface", options.InterfaceName).Logger()
return &DHCPClient{
InterfaceName: options.InterfaceName,
logger: &l,
leaseFile: fmt.Sprintf(DHCPLeaseFile, options.InterfaceName),
pidFile: options.PidFile,
onLeaseChange: options.OnLeaseChange,
}
}
func (c *DHCPClient) getWatchPaths() []string {
watchPaths := make(map[string]interface{})
watchPaths[filepath.Dir(c.leaseFile)] = nil
if c.pidFile != "" {
watchPaths[filepath.Dir(c.pidFile)] = nil
}
paths := make([]string, 0)
for path := range watchPaths {
paths = append(paths, path)
}
return paths
}
// Run starts the DHCP client and watches the lease file for changes.
// this isn't a blocking call, and the lease file is reloaded when a change is detected.
func (c *DHCPClient) Run() error {
err := c.loadLeaseFile()
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
defer watcher.Close()
go func() {
for {
select {
case event, ok := <-watcher.Events:
if !ok {
continue
}
if !event.Has(fsnotify.Write) && !event.Has(fsnotify.Create) {
continue
}
if event.Name == c.leaseFile {
c.logger.Debug().
Str("event", event.Op.String()).
Str("path", event.Name).
Msg("udhcpc lease file updated, reloading lease")
_ = c.loadLeaseFile()
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
c.logger.Error().Err(err).Msg("error watching lease file")
}
}
}()
for _, path := range c.getWatchPaths() {
err = watcher.Add(path)
if err != nil {
c.logger.Error().
Err(err).
Str("path", path).
Msg("failed to watch directory")
return err
}
}
// TODO: update udhcpc pid file
// we'll comment this out for now because the pid might change
// process := c.GetProcess()
// if process == nil {
// c.logger.Error().Msg("udhcpc process not found")
// }
// block the goroutine until the lease file is updated
<-make(chan struct{})
return nil
}
func (c *DHCPClient) loadLeaseFile() error {
file, err := os.ReadFile(c.leaseFile)
if err != nil {
return err
}
data := string(file)
if data == "" {
c.logger.Debug().Msg("udhcpc lease file is empty")
return nil
}
lease := &Lease{}
err = UnmarshalDHCPCLease(lease, string(file))
if err != nil {
return err
}
isFirstLoad := c.lease == nil
c.lease = lease
if lease.IPAddress == nil {
c.logger.Info().
Interface("lease", lease).
Str("data", string(file)).
Msg("udhcpc lease cleared")
return nil
}
msg := "udhcpc lease updated"
if isFirstLoad {
msg = "udhcpc lease loaded"
}
leaseExpiry, err := lease.SetLeaseExpiry()
if err != nil {
c.logger.Error().Err(err).Msg("failed to get dhcp lease expiry")
} else {
expiresIn := time.Until(leaseExpiry)
c.logger.Info().
Interface("expiry", leaseExpiry).
Str("expiresIn", expiresIn.String()).
Msg("current dhcp lease expiry time calculated")
}
c.onLeaseChange(lease)
c.logger.Info().
Str("ip", lease.IPAddress.String()).
Str("leaseTime", lease.LeaseTime.String()).
Interface("data", lease).
Msg(msg)
return nil
}
func (c *DHCPClient) GetLease() *Lease {
return c.lease
}

View File

@@ -96,7 +96,11 @@ func (s *CertStore) loadCertificate(hostname string) {
s.certificates[hostname] = &cert
s.log.Info().Str("hostname", hostname).Msg("Loaded certificate")
if hostname == selfSignerCAMagicName {
s.log.Info().Msg("loaded CA certificate")
} else {
s.log.Info().Str("hostname", hostname).Msg("loaded certificate")
}
}
// GetCertificate returns the certificate for the given hostname
@@ -131,7 +135,7 @@ func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key
if !ignoreWarning {
return nil, fmt.Errorf("certificate does not match hostname: %w", err)
}
s.log.Warn().Err(err).Msg("Certificate does not match hostname")
s.log.Warn().Err(err).Msg("certificate does not match hostname")
}
}