Fix race condition when generating key.

Signed-off-by: Pol Henarejos <pol.henarejos@cttc.es>
This commit is contained in:
Pol Henarejos
2026-05-12 19:04:01 +02:00
parent b3c91f068d
commit a4a1651ed4

View File

@@ -26,6 +26,9 @@
#include <ctype.h>
#ifdef _WIN32
#include "compat/pthread_win32.h"
#ifdef _MSC_VER
#include <windows.h>
#endif
typedef SOCKET socket_t;
typedef int socklen_t;
#define close closesocket
@@ -68,7 +71,7 @@ static pthread_t rest_thread;
static rest_conn_t conns[REST_MAX_CONNS];
typedef struct {
bool pending;
volatile long pending;
rest_conn_t *conn;
rest_route_handler_t handler;
rest_request_t request;
@@ -82,6 +85,31 @@ typedef struct {
static rest_core1_job_t rest_core1_job = {0};
static rest_core1_result_t rest_core1_result = {0};
static bool rest_core1_job_pending_try_acquire(void) {
#ifdef _MSC_VER
return InterlockedCompareExchange(&rest_core1_job.pending, 1, 0) == 0;
#else
long expected = 0;
return __atomic_compare_exchange_n(&rest_core1_job.pending, &expected, 1, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE);
#endif
}
static bool rest_core1_job_pending_load(void) {
#ifdef _MSC_VER
return InterlockedCompareExchange(&rest_core1_job.pending, 0, 0) != 0;
#else
return __atomic_load_n(&rest_core1_job.pending, __ATOMIC_ACQUIRE) != 0;
#endif
}
static void rest_core1_job_pending_store(bool pending) {
#ifdef _MSC_VER
InterlockedExchange(&rest_core1_job.pending, pending ? 1 : 0);
#else
__atomic_store_n(&rest_core1_job.pending, pending ? 1 : 0, __ATOMIC_RELEASE);
#endif
}
typedef struct {
rest_header_id_t id;
const char *name;
@@ -104,16 +132,24 @@ static void *rest_core1_thread(void *arg);
static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]);
void rest_close_conn(rest_conn_t *conn);
static int rest_start_core1_job(rest_conn_t *conn, rest_route_handler_t handler) {
if (handler == NULL || rest_core1_job.pending) {
static int rest_start_core1_job(rest_conn_t *conn, rest_route_handler_t handler, const rest_request_t *request) {
if (handler == NULL) {
return -1;
}
if (!rest_core1_job_pending_try_acquire()) {
return -1;
}
memset(&rest_core1_result, 0, sizeof(rest_core1_result));
rest_core1_job.pending = true;
rest_core1_job.conn = conn;
rest_core1_job.handler = handler;
if (request != NULL) {
rest_core1_job.request = *request;
}
else {
memset(&rest_core1_job.request, 0, sizeof(rest_core1_job.request));
}
card_start(ITF_LWIP, rest_core1_thread);
usb_send_event(EV_CMD_AVAILABLE);
@@ -137,7 +173,7 @@ static void *rest_core1_thread(void *arg) {
if (m == EV_CMD_AVAILABLE) {
rest_core1_result.ready = false;
memset(&rest_core1_result.response, 0, sizeof(rest_core1_result.response));
if (!rest_core1_job.pending || rest_core1_job.handler == NULL || rest_execute_route_handler(&rest_core1_job.request, rest_core1_job.handler, &rest_core1_result.response) != 0) {
if (!rest_core1_job_pending_load() || rest_core1_job.handler == NULL || rest_execute_route_handler(&rest_core1_job.request, rest_core1_job.handler, &rest_core1_result.response) != 0) {
rest_server_error(&rest_core1_result.response, 500, "internal_error");
}
rest_core1_result.ready = true;
@@ -166,10 +202,10 @@ static void send_json_error(rest_conn_t *conn, int status_code, const char *erro
}
void rest_task(void) {
if (!rest_core1_job.pending) {
if (!rest_core1_job_pending_load()) {
rest_route_handler_t handler = rest_background_job_pop();
if (handler != NULL) {
if (rest_start_core1_job(NULL, handler) != 0) {
if (rest_start_core1_job(NULL, handler, NULL) != 0) {
// Failed to start background job, push it back to the queue
rest_background_job_push(handler);
}
@@ -182,34 +218,42 @@ void rest_task(void) {
}
rest_conn_t *conn = rest_core1_job.conn;
//if (conn == NULL) {
// return;
//}
rest_response_t *response = &rest_core1_result.response;
rest_request_t *request = &rest_core1_job.request;
if (response == NULL) {
return;
bool ready = rest_core1_result.ready;
uint16_t code = response->status_code == 0 ? 200 : response->status_code;
const char *content_type = response->content_type;
char *body = response->body;
size_t body_len = response->body_len;
char *headers[REST_HEADER_TOTAL_COUNT] = {0};
for (size_t i = 0; i < REST_HEADER_TOTAL_COUNT; i++) {
headers[i] = response->headers[i];
}
rest_query_t *query = request->query;
// Release the shared core1 slot before doing potentially slow network I/O.
memset(&rest_core1_result, 0, sizeof(rest_core1_result));
rest_core1_job.conn = NULL;
rest_core1_job.handler = NULL;
memset(&rest_core1_job.request, 0, sizeof(rest_core1_job.request));
rest_core1_job_pending_store(false);
if (conn != NULL) {
if (rest_core1_result.ready && response->body != NULL && response->content_type != NULL) {
uint16_t code = response->status_code == 0 ? 200 : response->status_code;
send_response(conn, code, rest_status_text_from_code(code), response->content_type, response->body, response->body_len, response->headers);
if (ready && body != NULL && content_type != NULL) {
send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers);
}
else {
send_json_error(conn, 500, "internal_error");
}
}
if (response->body != NULL) {
free(response->body);
response->body = NULL;
if (body != NULL) {
free(body);
}
if (request->query != NULL) {
free(request->query);
request->query = NULL;
if (query != NULL) {
free(query);
}
memset(&rest_core1_result, 0, sizeof(rest_core1_result));
memset(&rest_core1_job, 0, sizeof(rest_core1_job));
}
static rest_conn_t *alloc_conn(
@@ -787,16 +831,13 @@ static int parse_query(rest_request_t *request, char *query_str) {
}
void rest_handle_request(rest_conn_t *conn) {
rest_request_t *request = &rest_core1_job.request;
rest_request_t request_local = {0};
rest_request_t *request = &request_local;
const rest_route_t *routes;
size_t route_count = 0, i;
bool path_exists_for_other_method = false;
int parsed;
if (request && request->query) {
free(request->query);
}
memset(request, 0, sizeof(*request));
parsed = parse_request(conn, request);
if (parsed <= 0) {
if (parsed < 0) {
@@ -805,7 +846,11 @@ void rest_handle_request(rest_conn_t *conn) {
return;
}
if (rest_core1_job.pending && !(request->method == REST_HTTP_POST && strcmp(request->path, "/device/jobs/cancel") == 0)) {
if (rest_core1_job_pending_load() && !(request->method == REST_HTTP_POST && strcmp(request->path, "/device/jobs/cancel") == 0)) {
if (request->query != NULL) {
free(request->query);
request->query = NULL;
}
send_json_error(conn, 503, "busy");
return;
}
@@ -827,8 +872,7 @@ void rest_handle_request(rest_conn_t *conn) {
routes = rest_get_routes(&route_count);
char *query_str = strchr(request->path, '?');
if (query_str != NULL) {
*query_str++ = '\0';
if (*query_str == '\0') {
if (*(query_str + 1) == '\0') {
send_json_error(conn, 400, "bad_request");
return;
}
@@ -843,8 +887,10 @@ void rest_handle_request(rest_conn_t *conn) {
continue;
}
}
else if (strcmp(routes[i].path, request->path) != 0) {
continue;
else {
if ((query_str ? strncmp(routes[i].path, request->path, (size_t)(query_str - request->path)) : strcmp(routes[i].path, request->path)) != 0) {
continue;
}
}
if (!(routes[i].method & request->method)) {
path_exists_for_other_method = true;
@@ -912,12 +958,12 @@ void rest_handle_request(rest_conn_t *conn) {
return;
}
if (query_str != NULL && parse_query(request, query_str) != PICOKEYS_OK) {
if (query_str != NULL && parse_query(request, query_str + 1) != PICOKEYS_OK) {
send_json_error(conn, 400, "bad_request");
return;
}
if (rest_start_core1_job(conn, routes[i].handler) != 0) {
send_json_error(conn, 500, "internal_error");
if (rest_start_core1_job(conn, routes[i].handler, request) != 0) {
send_json_error(conn, 503, "busy");
if (request->query != NULL) {
free(request->query);
request->query = NULL;
@@ -1066,7 +1112,7 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err
static err_t rest_poll(void *arg, struct tcp_pcb *pcb) {
rest_conn_t *conn = (rest_conn_t *)arg;
LWIP_UNUSED_ARG(pcb);
if (rest_core1_job.pending && rest_core1_job.conn == conn) {
if (rest_core1_job_pending_load() && rest_core1_job.conn == conn) {
return ERR_OK;
}
if (conn->conn_type == REST_CONN_TLS) {