Add support for verified sessions.

Signed-off-by: Pol Henarejos <pol.henarejos@cttc.es>
This commit is contained in:
Pol Henarejos
2026-04-21 15:25:53 +02:00
parent dfeb5b973b
commit 0eeac93416
3 changed files with 159 additions and 19 deletions

View File

@@ -32,7 +32,7 @@ rest_session_t *rest_session_create(const rest_session_role_t role, rest_session
rest_sessions[i].status = status;
rest_sessions[i].role = role;
random_fill_buffer(rest_sessions[i].id, sizeof(rest_sessions[i].id));
rest_sessions[i].created_at = get_rtc_time();
rest_sessions[i].created_at = board_millis();
rest_sessions[i].last_activity_timestamp = rest_sessions[i].created_at;
return &rest_sessions[i];
}
@@ -68,7 +68,7 @@ int rest_session_update_activity(const uint8_t *id, size_t id_len) {
if (session == NULL) {
return -1;
}
session->last_activity_timestamp = get_rtc_time();
session->last_activity_timestamp = board_millis();
return 0;
}
@@ -92,7 +92,7 @@ int rest_session_set_role(const uint8_t *id, size_t id_len, rest_session_role_t
int rest_session_cleanup_expired(time_t expiration_time) {
int count = 0;
time_t now = get_rtc_time();
time_t now = board_millis();
for (int i = 0; i < REST_MAX_SESSIONS; i++) {
if (rest_sessions[i].status != REST_SESSION_UNKNOWN && rest_sessions[i].status != REST_SESSION_EXPIRED && rest_sessions[i].status != REST_SESSION_TERMINATED) {
if (now - rest_sessions[i].last_activity_timestamp > expiration_time) {

View File

@@ -38,12 +38,26 @@ typedef enum {
REST_HTTP_DELETE
} rest_http_method_t;
typedef enum {
REST_HEADER_USER_AGENT = 0,
REST_HEADER_AUTHORIZATION,
REST_HEADER_CONTENT_TYPE,
REST_HEADER_CONTENT_LENGTH,
REST_HEADER_HOST,
REST_HEADER_ACCEPT,
REST_HEADER_X_SESSION_ID,
REST_HEADER_X_SEQ,
REST_HEADER_X_SIGNATURE,
REST_HEADER_TOTAL_COUNT
} rest_header_id_t;
typedef struct {
rest_http_method_t method;
char path[REST_MAX_PATH_SIZE];
const char *body;
size_t body_len;
const char *content_type;
char *headers[REST_HEADER_TOTAL_COUNT];
} rest_request_t;
typedef struct {
@@ -52,13 +66,15 @@ typedef struct {
char *body; // heap !
size_t body_len;
cJSON *json;
char *headers[REST_HEADER_TOTAL_COUNT];
} rest_response_t;
typedef int (*rest_route_handler_t)(const rest_request_t *request, rest_response_t *response);
typedef enum {
REST_ROUTE_NONE = 0x0,
REST_ROUTE_AUTH = 0x1,
REST_ROUTE_NONE = 0x0,
REST_ROUTE_REQUIRE_AUTH = 0x1,
REST_ROUTE_REQUIRE_TLS = 0x2,
} rest_route_flags_t;
typedef struct {

View File

@@ -19,9 +19,15 @@
#include "rest_server.h"
#include "rest_server_tls.h"
#include "usb.h"
#include "pico_time.h"
#include "serial.h"
#include <ctype.h>
#include <strings.h>
#include "mbedtls/base64.h"
#include "mbedtls/md.h"
#include "mbedtls/hkdf.h"
#include "crypto_utils.h"
#ifdef ENABLE_EMULATION
#ifndef _MSC_VER
@@ -37,6 +43,9 @@
#include "lwip/def.h"
#endif
#define REST_SESSION_TIMEOUT_INACTIVITY_MS (10 * 60 * 1000) // 10 minutes
#define REST_SESSION_TIMEOUT_TOTAL_MS (2 * 60 * 60 * 1000) // 2 hours
#ifndef ENABLE_EMULATION
static struct tcp_pcb *listener_pcb = NULL;
#else
@@ -64,7 +73,7 @@ static rest_core1_result_t rest_core1_result = {0};
static void *rest_core1_thread(void *arg);
static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len);
static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body);
void rest_close_conn(rest_conn_t *conn);
static int rest_start_core1_job(rest_conn_t *conn, const rest_request_t *request, rest_route_handler_t handler) {
if (request == NULL || handler == NULL || rest_core1_job.pending) {
@@ -76,7 +85,6 @@ static int rest_start_core1_job(rest_conn_t *conn, const rest_request_t *request
rest_core1_job.pending = true;
rest_core1_job.conn = conn;
rest_core1_job.handler = handler;
rest_core1_job.request = *request;
card_start(ITF_LWIP_NET, rest_core1_thread);
usb_send_event(EV_CMD_AVAILABLE);
@@ -114,6 +122,20 @@ static void *rest_core1_thread(void *arg) {
return NULL;
}
static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body) {
send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body));
}
static void send_json_error(rest_conn_t *conn, int status_code, const char *error_message) {
char json[256];
int json_len = snprintf(json, sizeof(json), "{\"error\":\"%s\"}", error_message);
if (json_len <= 0 || (size_t)json_len >= sizeof(json)) {
rest_close_conn(conn);
return;
}
send_json(conn, status_code, rest_status_text_from_code(status_code), json);
}
void rest_task(void) {
int status;
rest_conn_t *conn;
@@ -139,7 +161,7 @@ void rest_task(void) {
send_response(conn, code, rest_status_text_from_code(code), response->content_type, response->body, response->body_len);
}
else {
send_json(conn, 500, "Internal Server Error", "{\"error\":\"internal_error\"}");
send_json_error(conn, 500, "internal_error");
}
}
@@ -362,9 +384,22 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
rest_close_conn(conn);
}
static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body) {
send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body));
}
typedef struct {
rest_header_id_t id;
const char *name;
} rest_header_descriptor_t;
static const rest_header_descriptor_t rest_http_headers[REST_HEADER_TOTAL_COUNT] = {
{ REST_HEADER_USER_AGENT, "User-Agent" },
{ REST_HEADER_AUTHORIZATION, "Authorization" },
{ REST_HEADER_CONTENT_TYPE, "Content-Type" },
{ REST_HEADER_CONTENT_LENGTH, "Content-Length" },
{ REST_HEADER_HOST, "Host" },
{ REST_HEADER_ACCEPT, "Accept" },
{ REST_HEADER_X_SESSION_ID, "X-Session-ID" },
{ REST_HEADER_X_SEQ, "X-Seq" },
{ REST_HEADER_X_SIGNATURE, "X-Signature" }
};
static int parse_request(rest_conn_t *conn, rest_request_t *request) {
char *header_end, *line_end, *cursor;
@@ -440,6 +475,14 @@ static int parse_request(rest_conn_t *conn, rest_request_t *request) {
else if (strcasecmp(name, "Content-Type") == 0) {
request->content_type = value;
}
else {
for (int i = 0; i < REST_HEADER_TOTAL_COUNT; i++) {
if (strcasecmp(name, rest_http_headers[i].name) == 0) {
request->headers[rest_http_headers[i].id] = value;
break;
}
}
}
}
cursor = next + 2;
}
@@ -451,6 +494,63 @@ static int parse_request(rest_conn_t *conn, rest_request_t *request) {
return 1;
}
static int rest_session_derive_key(const uint8_t *session_id, size_t session_id_len, uint8_t derived_key[32]) {
uint8_t kver[32];
const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256);
derive_kver(session_id, session_id_len, kver);
mbedtls_hkdf(md_info, pico_serial_hash, sizeof(pico_serial_hash), kver, 32, (const uint8_t *)"REST/SESSION", 12, derived_key, 32);
mbedtls_platform_zeroize(kver, sizeof(kver));
return PICOKEYS_OK;
}
static int rest_verify_request_signature(const rest_request_t *request, const rest_session_t *session) {
mbedtls_md_context_t ctx;
const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256);
unsigned char hmac[32], hmac_x[32];
size_t olen = 0;
if (md_info == NULL) {
return PICOKEYS_ERR_MEMORY_FATAL;
}
if (mbedtls_base64_decode(hmac_x, sizeof(hmac_x), &olen, (const unsigned char *)request->headers[REST_HEADER_X_SIGNATURE], strlen(request->headers[REST_HEADER_X_SIGNATURE])) != 0) {
return PICOKEYS_EXEC_ERROR;
}
mbedtls_md_init(&ctx);
if (mbedtls_md_setup(&ctx, md_info, 1) != 0) {
mbedtls_md_free(&ctx);
return PICOKEYS_ERR_MEMORY_FATAL;
}
const char *body_empty = "{}";
const char *body = request->body_len > 0 ? request->body : body_empty;
const char *method_str = rest_method_to_string(request->method);
char seq[16];
size_t body_len = request->body_len > 0 ? request->body_len : strlen((const char *)body_empty);
snprintf(seq, sizeof(seq), "%s", request->headers[REST_HEADER_X_SEQ] ? request->headers[REST_HEADER_X_SEQ] : "0");
uint8_t derived_key[32];
if (rest_session_derive_key(session->id, sizeof(session->id), derived_key) != 0) {
mbedtls_md_free(&ctx);
return PICOKEYS_EXEC_ERROR;
}
if (mbedtls_md_hmac_starts(&ctx, (const unsigned char *)derived_key, sizeof(derived_key)) != 0 ||
mbedtls_md_hmac_starts(&ctx, (const unsigned char *)session->id, sizeof(session->id)) != 0 ||
mbedtls_md_hmac_update(&ctx, (const unsigned char *)method_str, strlen(method_str)) != 0 ||
mbedtls_md_hmac_update(&ctx, (const unsigned char *)request->path, strlen(request->path)) != 0 ||
mbedtls_md_hmac_update(&ctx, (const unsigned char *)seq, strlen(seq)) != 0 ||
mbedtls_md_hmac_update(&ctx, (const unsigned char *)body, body_len) != 0) {
mbedtls_md_free(&ctx);
return PICOKEYS_EXEC_ERROR;
}
if (mbedtls_md_hmac_finish(&ctx, hmac) != 0) {
mbedtls_md_free(&ctx);
return PICOKEYS_EXEC_ERROR;
}
mbedtls_md_free(&ctx);
if (ct_memcmp(hmac, hmac_x, sizeof(hmac)) != 0) {
return PICOKEYS_EXEC_ERROR;
}
return PICOKEYS_OK;
}
void rest_handle_request(rest_conn_t *conn) {
rest_request_t *request = &rest_core1_job.request;
const rest_route_t *routes;
@@ -459,7 +559,7 @@ void rest_handle_request(rest_conn_t *conn) {
int parsed;
if (rest_core1_job.pending) {
send_json(conn, 503, "Service Unavailable", "{\"error\":\"busy\"}");
send_json_error(conn, 503, "busy");
return;
}
@@ -467,7 +567,7 @@ void rest_handle_request(rest_conn_t *conn) {
parsed = parse_request(conn, request);
if (parsed <= 0) {
if (parsed < 0) {
send_json(conn, 400, "Bad Request", "{\"error\":\"bad_request\"}");
send_json_error(conn, 400, "bad_request");
}
return;
}
@@ -482,7 +582,7 @@ void rest_handle_request(rest_conn_t *conn) {
if (request->method == REST_HTTP_POST || request->method == REST_HTTP_PUT) {
if (!rest_content_type_is_json(request->content_type)) {
send_json(conn, 415, "Unsupported Media Type", "{\"error\":\"content_type_must_be_application_json\"}");
send_json_error(conn, 415, "content_type_must_be_application_json");
return;
}
}
@@ -498,17 +598,41 @@ void rest_handle_request(rest_conn_t *conn) {
path_exists_for_other_method = true;
continue;
}
if (routes[i].flags & REST_ROUTE_REQUIRE_AUTH) {
if (!request->headers[REST_HEADER_X_SESSION_ID] || strlen(request->headers[REST_HEADER_X_SESSION_ID]) == 0 ||!request->headers[REST_HEADER_X_SIGNATURE] || strlen(request->headers[REST_HEADER_X_SIGNATURE]) == 0 || !request->headers[REST_HEADER_X_SEQ] || strlen(request->headers[REST_HEADER_X_SEQ]) == 0) {
send_json_error(conn, 401, "authentication_required");
return;
}
rest_session_t *session = rest_session_get((const uint8_t *)request->headers[REST_HEADER_X_SESSION_ID], strlen(request->headers[REST_HEADER_X_SESSION_ID]));
if (!session) {
send_json_error(conn, 401, "authentication_required");
return;
}
if (session->status != REST_SESSION_AUTHENTICATED) {
send_json_error(conn, 401, "authentication_required");
return;
}
if (session->last_activity_timestamp + REST_SESSION_TIMEOUT_INACTIVITY_MS < board_millis() || session->created_at + REST_SESSION_TIMEOUT_TOTAL_MS < board_millis()) {
session->status = REST_SESSION_EXPIRED;
send_json_error(conn, 401, "session_expired");
return;
}
if (rest_verify_request_signature(request, session) != PICOKEYS_OK) {
send_json_error(conn, 401, "invalid_signature");
return;
}
}
if (rest_start_core1_job(conn, request, routes[i].handler) != 0) {
send_json(conn, 500, "Internal Server Error", "{\"error\":\"internal_error\"}");
send_json_error(conn, 500, "internal_error");
}
return;
}
if (path_exists_for_other_method) {
send_json(conn, 405, "Method Not Allowed", "{\"error\":\"method_not_allowed\"}");
send_json_error(conn, 405, "method_not_allowed");
}
else {
send_json(conn, 404, "Not Found", "{\"error\":\"not_found\"}");
send_json_error(conn, 404, "not_found");
}
}
@@ -626,7 +750,7 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err
rest_close_conn(conn);
return ERR_ABRT;
}
send_json(conn, 413, "Payload Too Large", "{\"error\":\"payload_too_large\"}");
send_json_error(conn, 413, "payload_too_large");
return ERR_OK;
}
pbuf_copy_partial(p, buffer + *len, p->tot_len, 0);
@@ -812,7 +936,7 @@ static void *rest_emulation_thread(void *arg) {
}
conn->request_len += (size_t)n;
if (conn->request_len > REST_MAX_REQUEST_SIZE) {
send_json(conn, 413, "Payload Too Large", "{\"error\":\"payload_too_large\"}");
send_json_error(conn, 413, "payload_too_large");
break;
}
rest_handle_request(conn);