Feat/Trickle ice (#336)

* feat(cloud): Use Websocket signaling in cloud mode

* refactor: Enhance WebRTC signaling and connection handling

* refactor: Improve WebRTC connection management and logging in KvmIdRoute

* refactor: Update PeerConnectionDisconnectedOverlay to use Card component for better UI structure

* refactor: Standardize metric naming and improve websocket logging

* refactor: Rename WebRTC signaling functions and update deployment script for debug version

* fix: Handle error when writing new ICE candidate to WebRTC signaling channel

* refactor: Rename signaling handler function for clarity

* refactor: Remove old http local http endpoint

* refactor: Improve metric help text and standardize comparison operator in KvmIdRoute

* chore(websocket): use MetricVec instead of Metric to store metrics

* fix conflicts

* fix: use wss when the page is served over https

* feat: Add app version header and update WebRTC signaling endpoint

* fix: Handle error when writing device metadata to WebRTC signaling channel

---------

Co-authored-by: Siyuan Miao <i@xswan.net>
This commit is contained in:
Adam Shiervani
2025-04-09 00:10:38 +02:00
committed by GitHub
parent fa1b11b228
commit 1a30977085
10 changed files with 654 additions and 275 deletions

188
web.go
View File

@@ -1,6 +1,7 @@
package kvm
import (
"context"
"embed"
"encoding/json"
"fmt"
@@ -10,8 +11,12 @@ import (
"strings"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/pion/webrtc/v4"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/crypto/bcrypt"
)
@@ -94,7 +99,7 @@ func setupRouter() *gin.Engine {
protected := r.Group("/")
protected.Use(protectedMiddleware())
{
protected.POST("/webrtc/session", handleWebRTCSession)
protected.GET("/webrtc/signaling/client", handleLocalWebRTCSignal)
protected.POST("/cloud/register", handleCloudRegister)
protected.GET("/cloud/state", handleCloudState)
protected.GET("/device", handleDevice)
@@ -121,35 +126,182 @@ func setupRouter() *gin.Engine {
// TODO: support multiple sessions?
var currentSession *Session
func handleWebRTCSession(c *gin.Context) {
var req WebRTCSessionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
func handleLocalWebRTCSignal(c *gin.Context) {
cloudLogger.Infof("new websocket connection established")
// Create WebSocket options with InsecureSkipVerify to bypass origin check
wsOptions := &websocket.AcceptOptions{
InsecureSkipVerify: true, // Allow connections from any origin
}
session, err := newSession(SessionConfig{})
wsCon, err := websocket.Accept(c.Writer, c.Request, wsOptions)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sd, err := session.ExchangeOffer(req.Sd)
// get the source from the request
source := c.ClientIP()
// Now use conn for websocket operations
defer wsCon.Close(websocket.StatusNormalClosure, "")
err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if currentSession != nil {
writeJSONRPCEvent("otherSessionConnected", nil, currentSession)
peerConn := currentSession.peerConnection
err = handleWebRTCSignalWsMessages(wsCon, false, source)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func handleWebRTCSignalWsMessages(wsCon *websocket.Conn, isCloudConnection bool, source string) error {
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
// Add connection tracking to detect reconnections
connectionID := uuid.New().String()
cloudLogger.Infof("new websocket connection established with ID: %s", connectionID)
// connection type
var sourceType string
if isCloudConnection {
sourceType = "cloud"
} else {
sourceType = "local"
}
// probably we can use a better logging framework here
logInfof := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Infof(format+", source: %s, sourceType: %s", args...)
}
logWarnf := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Warnf(format+", source: %s, sourceType: %s", args...)
}
logTracef := func(format string, args ...interface{}) {
args = append(args, source, sourceType)
websocketLogger.Tracef(format+", source: %s, sourceType: %s", args...)
}
go func() {
for {
time.Sleep(WebsocketPingInterval)
// set the timer for the ping duration
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
metricConnectionLastPingDuration.WithLabelValues(sourceType, source).Set(v)
metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v)
}))
logInfof("pinging websocket")
err := wsCon.Ping(runCtx)
if err != nil {
logWarnf("websocket ping error: %v", err)
cancelRun()
return
}
// dont use `defer` here because we want to observe the duration of the ping
timer.ObserveDuration()
metricConnectionTotalPingCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
}
}()
if isCloudConnection {
// create a channel to receive the disconnect event, once received, we cancelRun
cloudDisconnectChan = make(chan error)
defer func() {
close(cloudDisconnectChan)
cloudDisconnectChan = nil
}()
go func() {
time.Sleep(1 * time.Second)
_ = peerConn.Close()
for err := range cloudDisconnectChan {
if err == nil {
continue
}
cloudLogger.Infof("disconnecting from cloud due to: %v", err)
cancelRun()
}
}()
}
currentSession = session
c.JSON(http.StatusOK, gin.H{"sd": sd})
for {
typ, msg, err := wsCon.Read(runCtx)
if err != nil {
logWarnf("websocket read error: %v", err)
return err
}
if typ != websocket.MessageText {
// ignore non-text messages
continue
}
var message struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
err = json.Unmarshal(msg, &message)
if err != nil {
logWarnf("unable to parse ws message: %v", err)
continue
}
if message.Type == "offer" {
logInfof("new session request received")
var req WebRTCSessionRequest
err = json.Unmarshal(message.Data, &req)
if err != nil {
logWarnf("unable to parse session request data: %v", err)
continue
}
logInfof("new session request: %v", req.OidcGoogle)
logTracef("session request info: %v", req)
metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc()
metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime()
err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source)
if err != nil {
logWarnf("error starting new session: %v", err)
continue
}
} else if message.Type == "new-ice-candidate" {
logInfof("The client sent us a new ICE candidate: %v", string(message.Data))
var candidate webrtc.ICECandidateInit
// Attempt to unmarshal as a ICECandidateInit
if err := json.Unmarshal(message.Data, &candidate); err != nil {
logWarnf("unable to parse incoming ICE candidate data: %v", string(message.Data))
continue
}
if candidate.Candidate == "" {
logWarnf("empty incoming ICE candidate, skipping")
continue
}
logInfof("unmarshalled incoming ICE candidate: %v", candidate)
if currentSession == nil {
logInfof("no current session, skipping incoming ICE candidate")
continue
}
logInfof("adding incoming ICE candidate to current session: %v", candidate)
if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil {
logWarnf("failed to add incoming ICE candidate to our peer connection: %v", err)
}
}
}
}
func handleLogin(c *gin.Context) {