From a906628318640a0636f50c60b1f8d1785bcbaf60 Mon Sep 17 00:00:00 2001 From: Pol Henarejos Date: Tue, 21 Apr 2026 20:40:19 +0200 Subject: [PATCH] Added session key negotiation. Signed-off-by: Pol Henarejos --- src/usb/lwip/rest.c | 76 +++++++++++++++++++++++++++++++++++++- src/usb/lwip/rest.h | 4 +- src/usb/lwip/rest_server.c | 29 +++++++++------ 3 files changed, 95 insertions(+), 14 deletions(-) diff --git a/src/usb/lwip/rest.c b/src/usb/lwip/rest.c index 9c9a03e..a6fc6a3 100644 --- a/src/usb/lwip/rest.c +++ b/src/usb/lwip/rest.c @@ -21,17 +21,29 @@ #include #include "random.h" #include "crypto_utils.h" +#include "serial.h" + +#include "mbedtls/ecdh.h" +#include "mbedtls/ecp.h" +#include "mbedtls/hkdf.h" +#include "mbedtls/platform_util.h" #define REST_MAX_SESSIONS 4 static rest_session_t rest_sessions[REST_MAX_SESSIONS] = {0}; +static int x25519_hkdf_derive_key32(const uint8_t sk[32], const uint8_t pk[32], const uint8_t *salt, size_t salt_len, const uint8_t *info, size_t info_len, uint8_t out_key[32]); -rest_session_t *rest_session_create(const rest_session_role_t role, rest_session_status_t status) { +rest_session_t *rest_session_create(const rest_session_role_t role, rest_session_status_t status, const uint8_t public_key[32]) { for (int i = 0; i < REST_MAX_SESSIONS; i++) { if (rest_sessions[i].status == REST_SESSION_UNKNOWN || rest_sessions[i].status == REST_SESSION_EXPIRED || rest_sessions[i].status == REST_SESSION_TERMINATED) { memset(&rest_sessions[i], 0, sizeof(rest_session_t)); rest_sessions[i].status = status; rest_sessions[i].role = role; + if (public_key != NULL) { + memcpy(rest_sessions[i].public_key, public_key, sizeof(rest_sessions[i].public_key)); + } else { + memset(rest_sessions[i].public_key, 0, sizeof(rest_sessions[i].public_key)); + } random_fill_buffer(rest_sessions[i].id, sizeof(rest_sessions[i].id)); rest_sessions[i].created_at = board_millis(); rest_sessions[i].last_activity_timestamp = rest_sessions[i].created_at; @@ -268,3 +280,65 @@ __attribute__((weak)) const rest_route_t *rest_get_routes(size_t *count) { } return NULL; } + + +static int x25519_hkdf_derive_key32(const uint8_t sk[32], const uint8_t pk[32], const uint8_t *salt, size_t salt_len, const uint8_t *info, size_t info_len, uint8_t out_key[32]) { + int ret = -1; + size_t shared_len = 0; + uint8_t shared[32] = {0}; + + mbedtls_ecdh_context ecdh; + mbedtls_ecp_keypair ours, theirs; + const mbedtls_md_info_t *md = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); + + mbedtls_ecdh_init(&ecdh); + mbedtls_ecp_keypair_init(&ours); + mbedtls_ecp_keypair_init(&theirs); + + if (md == NULL) { + ret = MBEDTLS_ERR_MD_BAD_INPUT_DATA; + goto cleanup; + } + + MBEDTLS_MPI_CHK(mbedtls_ecp_group_load(&ours.grp, MBEDTLS_ECP_DP_CURVE25519)); + MBEDTLS_MPI_CHK(mbedtls_ecp_group_load(&theirs.grp, MBEDTLS_ECP_DP_CURVE25519)); + + MBEDTLS_MPI_CHK(mbedtls_ecp_read_key(MBEDTLS_ECP_DP_CURVE25519, &ours, sk, 32)); + + // Carrega pública remota (32 bytes) + MBEDTLS_MPI_CHK(mbedtls_ecp_point_read_binary(&theirs.grp, &theirs.Q, pk, 32)); + + MBEDTLS_MPI_CHK(mbedtls_ecdh_setup(&ecdh, MBEDTLS_ECP_DP_CURVE25519)); + MBEDTLS_MPI_CHK(mbedtls_ecdh_get_params(&ecdh, &ours, MBEDTLS_ECDH_OURS)); + MBEDTLS_MPI_CHK(mbedtls_ecdh_get_params(&ecdh, &theirs, MBEDTLS_ECDH_THEIRS)); + + MBEDTLS_MPI_CHK(mbedtls_ecdh_calc_secret(&ecdh, &shared_len, shared, sizeof(shared), random_fill_iterator, NULL)); + + if (shared_len != 32) { + ret = MBEDTLS_ERR_ECP_BAD_INPUT_DATA; + goto cleanup; + } + + ret = mbedtls_hkdf(md, salt, salt_len, shared, shared_len, info, info_len, out_key, 32); + +cleanup: + mbedtls_platform_zeroize(shared, sizeof(shared)); + mbedtls_ecdh_free(&ecdh); + mbedtls_ecp_keypair_free(&ours); + mbedtls_ecp_keypair_free(&theirs); + return ret; +} + +int rest_session_derive_key(const rest_session_t *session, uint8_t derived_key[32]) { + uint8_t kver[32], sk[32]; + const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); + derive_kver(session->id, sizeof(session->id), kver); + mbedtls_hkdf(md_info, pico_serial_hash, sizeof(pico_serial_hash), kver, 32, (const uint8_t *)"REST/SESSION", 12, derived_key, 32); + mbedtls_platform_zeroize(kver, sizeof(kver)); + int ret = x25519_hkdf_derive_key32(sk, session->public_key, session->id, sizeof(session->id), (const uint8_t *)"REST/SESSION/DERIVE", 20, derived_key); + mbedtls_platform_zeroize(sk, sizeof(sk)); + if (ret != 0) { + return -1; + } + return PICOKEYS_OK; +} diff --git a/src/usb/lwip/rest.h b/src/usb/lwip/rest.h index 52e3322..fa5c105 100644 --- a/src/usb/lwip/rest.h +++ b/src/usb/lwip/rest.h @@ -102,6 +102,7 @@ typedef enum { } rest_session_status_t; typedef struct { + uint8_t public_key[32]; uint8_t id[16]; uint8_t id_str[25]; time_t last_activity_timestamp; @@ -120,7 +121,7 @@ bool rest_content_type_is_json(const char *content_type); const rest_route_t *rest_get_routes(size_t *count); -extern rest_session_t *rest_session_create(const rest_session_role_t role, rest_session_status_t status); +extern rest_session_t *rest_session_create(const rest_session_role_t role, rest_session_status_t status, const uint8_t public_key[32]); extern rest_session_t *rest_session_get(const uint8_t *id, size_t id_len); extern rest_session_t *rest_session_get_by_id_str(const char *id_str); extern int rest_session_terminate(const uint8_t *id, size_t id_len); @@ -128,6 +129,7 @@ extern int rest_session_update_activity(const uint8_t *id, size_t id_len); extern int rest_session_set_status(const uint8_t *id, size_t id_len, rest_session_status_t status); extern int rest_session_set_role(const uint8_t *id, size_t id_len, rest_session_role_t role); extern int rest_session_cleanup_expired(time_t expiration_time); +extern int rest_session_derive_key(const rest_session_t *session, uint8_t derived_key[32]); #ifdef DEBUG_APDU extern void rest_debug_dump_payload(const char *tag, const char *buffer, size_t len); diff --git a/src/usb/lwip/rest_server.c b/src/usb/lwip/rest_server.c index a4c314d..cdd6a57 100644 --- a/src/usb/lwip/rest_server.c +++ b/src/usb/lwip/rest_server.c @@ -597,13 +597,11 @@ static int parse_request(rest_conn_t *conn, rest_request_t *request) { return 1; } -static int rest_session_derive_key(const uint8_t *session_id, size_t session_id_len, uint8_t derived_key[32]) { - uint8_t kver[32]; - const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256); - derive_kver(session_id, session_id_len, kver); - mbedtls_hkdf(md_info, pico_serial_hash, sizeof(pico_serial_hash), kver, 32, (const uint8_t *)"REST/SESSION", 12, derived_key, 32); - mbedtls_platform_zeroize(kver, sizeof(kver)); - return PICOKEYS_OK; +static uint32_t rest_request_get_seq(const rest_request_t *request) { + if (request == NULL || request->headers[REST_HEADER_X_SEQ] == NULL) { + return 0; + } + return (uint32_t)strtoul(request->headers[REST_HEADER_X_SEQ], NULL, 10); } static int rest_verify_request_signature(const rest_request_t *request, const rest_session_t *session) { @@ -625,25 +623,28 @@ static int rest_verify_request_signature(const rest_request_t *request, const re const char *body_empty = "{}"; const char *body = request->body_len > 0 ? request->body : body_empty; const char *method_str = rest_method_to_string(request->method); - char seq[16]; size_t body_len = request->body_len > 0 ? request->body_len : strlen((const char *)body_empty); - snprintf(seq, sizeof(seq), "%s", request->headers[REST_HEADER_X_SEQ] ? request->headers[REST_HEADER_X_SEQ] : "0"); uint8_t derived_key[32]; - if (rest_session_derive_key(session->id, sizeof(session->id), derived_key) != 0) { + if (rest_session_derive_key(session, derived_key) != 0) { mbedtls_md_free(&ctx); return PICOKEYS_EXEC_ERROR; } + uint32_t seq = htonl(rest_request_get_seq(request)); if (mbedtls_md_hmac_starts(&ctx, (const unsigned char *)derived_key, sizeof(derived_key)) != 0 || mbedtls_md_hmac_starts(&ctx, (const unsigned char *)session->id, sizeof(session->id)) != 0 || mbedtls_md_hmac_update(&ctx, (const unsigned char *)method_str, strlen(method_str)) != 0 || mbedtls_md_hmac_update(&ctx, (const unsigned char *)request->path, strlen(request->path)) != 0 || - mbedtls_md_hmac_update(&ctx, (const unsigned char *)seq, strlen(seq)) != 0 || - mbedtls_md_hmac_update(&ctx, (const unsigned char *)body, body_len) != 0) { + mbedtls_md_hmac_update(&ctx, (const unsigned char *)&seq, sizeof(uint32_t)) != 0 || + mbedtls_md_hmac_update(&ctx, (const unsigned char *)body, body_len) != 0) + { + mbedtls_platform_zeroize(derived_key, sizeof(derived_key)); mbedtls_md_free(&ctx); return PICOKEYS_EXEC_ERROR; } + mbedtls_platform_zeroize(derived_key, sizeof(derived_key)); if (mbedtls_md_hmac_finish(&ctx, hmac) != 0) { + mbedtls_platform_zeroize(derived_key, sizeof(derived_key)); mbedtls_md_free(&ctx); return PICOKEYS_EXEC_ERROR; } @@ -720,6 +721,10 @@ void rest_handle_request(rest_conn_t *conn) { send_json_error(conn, 401, "session_expired"); return; } + if (rest_request_get_seq(request) < session->last_seq) { + send_json_error(conn, 401, "invalid_seq"); + return; + } if (rest_verify_request_signature(request, session) != PICOKEYS_OK) { send_json_error(conn, 401, "invalid_signature"); return;