From 0eeac93416ae9dbdb0414e7491b0bcfbaa5f3db8 Mon Sep 17 00:00:00 2001 From: Pol Henarejos Date: Tue, 21 Apr 2026 15:25:53 +0200 Subject: [PATCH] Add support for verified sessions. Signed-off-by: Pol Henarejos --- src/usb/lwip/rest.c | 6 +- src/usb/lwip/rest.h | 20 ++++- src/usb/lwip/rest_server.c | 152 +++++++++++++++++++++++++++++++++---- 3 files changed, 159 insertions(+), 19 deletions(-) diff --git a/src/usb/lwip/rest.c b/src/usb/lwip/rest.c index c9d55d1..e96b025 100644 --- a/src/usb/lwip/rest.c +++ b/src/usb/lwip/rest.c @@ -32,7 +32,7 @@ rest_session_t *rest_session_create(const rest_session_role_t role, rest_session rest_sessions[i].status = status; rest_sessions[i].role = role; random_fill_buffer(rest_sessions[i].id, sizeof(rest_sessions[i].id)); - rest_sessions[i].created_at = get_rtc_time(); + rest_sessions[i].created_at = board_millis(); rest_sessions[i].last_activity_timestamp = rest_sessions[i].created_at; return &rest_sessions[i]; } @@ -68,7 +68,7 @@ int rest_session_update_activity(const uint8_t *id, size_t id_len) { if (session == NULL) { return -1; } - session->last_activity_timestamp = get_rtc_time(); + session->last_activity_timestamp = board_millis(); return 0; } @@ -92,7 +92,7 @@ int rest_session_set_role(const uint8_t *id, size_t id_len, rest_session_role_t int rest_session_cleanup_expired(time_t expiration_time) { int count = 0; - time_t now = get_rtc_time(); + time_t now = board_millis(); for (int i = 0; i < REST_MAX_SESSIONS; i++) { if (rest_sessions[i].status != REST_SESSION_UNKNOWN && rest_sessions[i].status != REST_SESSION_EXPIRED && rest_sessions[i].status != REST_SESSION_TERMINATED) { if (now - rest_sessions[i].last_activity_timestamp > expiration_time) { diff --git a/src/usb/lwip/rest.h b/src/usb/lwip/rest.h index bd870bb..87cc2d1 100644 --- a/src/usb/lwip/rest.h +++ b/src/usb/lwip/rest.h @@ -38,12 +38,26 @@ typedef enum { REST_HTTP_DELETE } rest_http_method_t; +typedef enum { + REST_HEADER_USER_AGENT = 0, + REST_HEADER_AUTHORIZATION, + REST_HEADER_CONTENT_TYPE, + REST_HEADER_CONTENT_LENGTH, + REST_HEADER_HOST, + REST_HEADER_ACCEPT, + REST_HEADER_X_SESSION_ID, + REST_HEADER_X_SEQ, + REST_HEADER_X_SIGNATURE, + REST_HEADER_TOTAL_COUNT +} rest_header_id_t; + typedef struct { rest_http_method_t method; char path[REST_MAX_PATH_SIZE]; const char *body; size_t body_len; const char *content_type; + char *headers[REST_HEADER_TOTAL_COUNT]; } rest_request_t; typedef struct { @@ -52,13 +66,15 @@ typedef struct { char *body; // heap ! size_t body_len; cJSON *json; + char *headers[REST_HEADER_TOTAL_COUNT]; } rest_response_t; typedef int (*rest_route_handler_t)(const rest_request_t *request, rest_response_t *response); typedef enum { - REST_ROUTE_NONE = 0x0, - REST_ROUTE_AUTH = 0x1, + REST_ROUTE_NONE = 0x0, + REST_ROUTE_REQUIRE_AUTH = 0x1, + REST_ROUTE_REQUIRE_TLS = 0x2, } rest_route_flags_t; typedef struct { diff --git a/src/usb/lwip/rest_server.c b/src/usb/lwip/rest_server.c index 997ab51..4e7d3a7 100644 --- a/src/usb/lwip/rest_server.c +++ b/src/usb/lwip/rest_server.c @@ -19,9 +19,15 @@ #include "rest_server.h" #include "rest_server_tls.h" #include "usb.h" +#include "pico_time.h" +#include "serial.h" #include #include +#include "mbedtls/base64.h" +#include "mbedtls/md.h" +#include "mbedtls/hkdf.h" +#include "crypto_utils.h" #ifdef ENABLE_EMULATION #ifndef _MSC_VER @@ -37,6 +43,9 @@ #include "lwip/def.h" #endif +#define REST_SESSION_TIMEOUT_INACTIVITY_MS (10 * 60 * 1000) // 10 minutes +#define REST_SESSION_TIMEOUT_TOTAL_MS (2 * 60 * 60 * 1000) // 2 hours + #ifndef ENABLE_EMULATION static struct tcp_pcb *listener_pcb = NULL; #else @@ -64,7 +73,7 @@ static rest_core1_result_t rest_core1_result = {0}; static void *rest_core1_thread(void *arg); static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len); -static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body); +void rest_close_conn(rest_conn_t *conn); static int rest_start_core1_job(rest_conn_t *conn, const rest_request_t *request, rest_route_handler_t handler) { if (request == NULL || handler == NULL || rest_core1_job.pending) { @@ -76,7 +85,6 @@ static int rest_start_core1_job(rest_conn_t *conn, const rest_request_t *request rest_core1_job.pending = true; rest_core1_job.conn = conn; rest_core1_job.handler = handler; - rest_core1_job.request = *request; card_start(ITF_LWIP_NET, rest_core1_thread); usb_send_event(EV_CMD_AVAILABLE); @@ -114,6 +122,20 @@ static void *rest_core1_thread(void *arg) { return NULL; } +static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body) { + send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body)); +} + +static void send_json_error(rest_conn_t *conn, int status_code, const char *error_message) { + char json[256]; + int json_len = snprintf(json, sizeof(json), "{\"error\":\"%s\"}", error_message); + if (json_len <= 0 || (size_t)json_len >= sizeof(json)) { + rest_close_conn(conn); + return; + } + send_json(conn, status_code, rest_status_text_from_code(status_code), json); +} + void rest_task(void) { int status; rest_conn_t *conn; @@ -139,7 +161,7 @@ void rest_task(void) { send_response(conn, code, rest_status_text_from_code(code), response->content_type, response->body, response->body_len); } else { - send_json(conn, 500, "Internal Server Error", "{\"error\":\"internal_error\"}"); + send_json_error(conn, 500, "internal_error"); } } @@ -362,9 +384,22 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status rest_close_conn(conn); } -static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body) { - send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body)); -} +typedef struct { + rest_header_id_t id; + const char *name; +} rest_header_descriptor_t; + +static const rest_header_descriptor_t rest_http_headers[REST_HEADER_TOTAL_COUNT] = { + { REST_HEADER_USER_AGENT, "User-Agent" }, + { REST_HEADER_AUTHORIZATION, "Authorization" }, + { REST_HEADER_CONTENT_TYPE, "Content-Type" }, + { REST_HEADER_CONTENT_LENGTH, "Content-Length" }, + { REST_HEADER_HOST, "Host" }, + { REST_HEADER_ACCEPT, "Accept" }, + { REST_HEADER_X_SESSION_ID, "X-Session-ID" }, + { REST_HEADER_X_SEQ, "X-Seq" }, + { REST_HEADER_X_SIGNATURE, "X-Signature" } +}; static int parse_request(rest_conn_t *conn, rest_request_t *request) { char *header_end, *line_end, *cursor; @@ -440,6 +475,14 @@ static int parse_request(rest_conn_t *conn, rest_request_t *request) { else if (strcasecmp(name, "Content-Type") == 0) { request->content_type = value; } + else { + for (int i = 0; i < REST_HEADER_TOTAL_COUNT; i++) { + if (strcasecmp(name, rest_http_headers[i].name) == 0) { + request->headers[rest_http_headers[i].id] = value; + break; + } + } + } } cursor = next + 2; } @@ -451,6 +494,63 @@ static int parse_request(rest_conn_t *conn, rest_request_t *request) { return 1; } +static int rest_session_derive_key(const uint8_t *session_id, size_t session_id_len, uint8_t derived_key[32]) { + uint8_t kver[32]; + const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); + derive_kver(session_id, session_id_len, kver); + mbedtls_hkdf(md_info, pico_serial_hash, sizeof(pico_serial_hash), kver, 32, (const uint8_t *)"REST/SESSION", 12, derived_key, 32); + mbedtls_platform_zeroize(kver, sizeof(kver)); + return PICOKEYS_OK; +} + +static int rest_verify_request_signature(const rest_request_t *request, const rest_session_t *session) { + mbedtls_md_context_t ctx; + const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); + unsigned char hmac[32], hmac_x[32]; + size_t olen = 0; + if (md_info == NULL) { + return PICOKEYS_ERR_MEMORY_FATAL; + } + if (mbedtls_base64_decode(hmac_x, sizeof(hmac_x), &olen, (const unsigned char *)request->headers[REST_HEADER_X_SIGNATURE], strlen(request->headers[REST_HEADER_X_SIGNATURE])) != 0) { + return PICOKEYS_EXEC_ERROR; + } + mbedtls_md_init(&ctx); + if (mbedtls_md_setup(&ctx, md_info, 1) != 0) { + mbedtls_md_free(&ctx); + return PICOKEYS_ERR_MEMORY_FATAL; + } + const char *body_empty = "{}"; + const char *body = request->body_len > 0 ? request->body : body_empty; + const char *method_str = rest_method_to_string(request->method); + char seq[16]; + size_t body_len = request->body_len > 0 ? request->body_len : strlen((const char *)body_empty); + snprintf(seq, sizeof(seq), "%s", request->headers[REST_HEADER_X_SEQ] ? request->headers[REST_HEADER_X_SEQ] : "0"); + uint8_t derived_key[32]; + if (rest_session_derive_key(session->id, sizeof(session->id), derived_key) != 0) { + mbedtls_md_free(&ctx); + return PICOKEYS_EXEC_ERROR; + } + + if (mbedtls_md_hmac_starts(&ctx, (const unsigned char *)derived_key, sizeof(derived_key)) != 0 || + mbedtls_md_hmac_starts(&ctx, (const unsigned char *)session->id, sizeof(session->id)) != 0 || + mbedtls_md_hmac_update(&ctx, (const unsigned char *)method_str, strlen(method_str)) != 0 || + mbedtls_md_hmac_update(&ctx, (const unsigned char *)request->path, strlen(request->path)) != 0 || + mbedtls_md_hmac_update(&ctx, (const unsigned char *)seq, strlen(seq)) != 0 || + mbedtls_md_hmac_update(&ctx, (const unsigned char *)body, body_len) != 0) { + mbedtls_md_free(&ctx); + return PICOKEYS_EXEC_ERROR; + } + if (mbedtls_md_hmac_finish(&ctx, hmac) != 0) { + mbedtls_md_free(&ctx); + return PICOKEYS_EXEC_ERROR; + } + mbedtls_md_free(&ctx); + if (ct_memcmp(hmac, hmac_x, sizeof(hmac)) != 0) { + return PICOKEYS_EXEC_ERROR; + } + return PICOKEYS_OK; +} + void rest_handle_request(rest_conn_t *conn) { rest_request_t *request = &rest_core1_job.request; const rest_route_t *routes; @@ -459,7 +559,7 @@ void rest_handle_request(rest_conn_t *conn) { int parsed; if (rest_core1_job.pending) { - send_json(conn, 503, "Service Unavailable", "{\"error\":\"busy\"}"); + send_json_error(conn, 503, "busy"); return; } @@ -467,7 +567,7 @@ void rest_handle_request(rest_conn_t *conn) { parsed = parse_request(conn, request); if (parsed <= 0) { if (parsed < 0) { - send_json(conn, 400, "Bad Request", "{\"error\":\"bad_request\"}"); + send_json_error(conn, 400, "bad_request"); } return; } @@ -482,7 +582,7 @@ void rest_handle_request(rest_conn_t *conn) { if (request->method == REST_HTTP_POST || request->method == REST_HTTP_PUT) { if (!rest_content_type_is_json(request->content_type)) { - send_json(conn, 415, "Unsupported Media Type", "{\"error\":\"content_type_must_be_application_json\"}"); + send_json_error(conn, 415, "content_type_must_be_application_json"); return; } } @@ -498,17 +598,41 @@ void rest_handle_request(rest_conn_t *conn) { path_exists_for_other_method = true; continue; } + if (routes[i].flags & REST_ROUTE_REQUIRE_AUTH) { + if (!request->headers[REST_HEADER_X_SESSION_ID] || strlen(request->headers[REST_HEADER_X_SESSION_ID]) == 0 ||!request->headers[REST_HEADER_X_SIGNATURE] || strlen(request->headers[REST_HEADER_X_SIGNATURE]) == 0 || !request->headers[REST_HEADER_X_SEQ] || strlen(request->headers[REST_HEADER_X_SEQ]) == 0) { + send_json_error(conn, 401, "authentication_required"); + return; + } + rest_session_t *session = rest_session_get((const uint8_t *)request->headers[REST_HEADER_X_SESSION_ID], strlen(request->headers[REST_HEADER_X_SESSION_ID])); + if (!session) { + send_json_error(conn, 401, "authentication_required"); + return; + } + if (session->status != REST_SESSION_AUTHENTICATED) { + send_json_error(conn, 401, "authentication_required"); + return; + } + if (session->last_activity_timestamp + REST_SESSION_TIMEOUT_INACTIVITY_MS < board_millis() || session->created_at + REST_SESSION_TIMEOUT_TOTAL_MS < board_millis()) { + session->status = REST_SESSION_EXPIRED; + send_json_error(conn, 401, "session_expired"); + return; + } + if (rest_verify_request_signature(request, session) != PICOKEYS_OK) { + send_json_error(conn, 401, "invalid_signature"); + return; + } + } if (rest_start_core1_job(conn, request, routes[i].handler) != 0) { - send_json(conn, 500, "Internal Server Error", "{\"error\":\"internal_error\"}"); + send_json_error(conn, 500, "internal_error"); } return; } if (path_exists_for_other_method) { - send_json(conn, 405, "Method Not Allowed", "{\"error\":\"method_not_allowed\"}"); + send_json_error(conn, 405, "method_not_allowed"); } else { - send_json(conn, 404, "Not Found", "{\"error\":\"not_found\"}"); + send_json_error(conn, 404, "not_found"); } } @@ -626,7 +750,7 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err rest_close_conn(conn); return ERR_ABRT; } - send_json(conn, 413, "Payload Too Large", "{\"error\":\"payload_too_large\"}"); + send_json_error(conn, 413, "payload_too_large"); return ERR_OK; } pbuf_copy_partial(p, buffer + *len, p->tot_len, 0); @@ -812,7 +936,7 @@ static void *rest_emulation_thread(void *arg) { } conn->request_len += (size_t)n; if (conn->request_len > REST_MAX_REQUEST_SIZE) { - send_json(conn, 413, "Payload Too Large", "{\"error\":\"payload_too_large\"}"); + send_json_error(conn, 413, "payload_too_large"); break; } rest_handle_request(conn);