diff --git a/src/usb/lwip/rest_server.c b/src/usb/lwip/rest_server.c index 9b014f3..0060a8d 100644 --- a/src/usb/lwip/rest_server.c +++ b/src/usb/lwip/rest_server.c @@ -26,6 +26,9 @@ #include #ifdef _WIN32 #include "compat/pthread_win32.h" +#ifdef _MSC_VER +#include +#endif typedef SOCKET socket_t; typedef int socklen_t; #define close closesocket @@ -68,7 +71,7 @@ static pthread_t rest_thread; static rest_conn_t conns[REST_MAX_CONNS]; typedef struct { - bool pending; + volatile long pending; rest_conn_t *conn; rest_route_handler_t handler; rest_request_t request; @@ -82,6 +85,31 @@ typedef struct { static rest_core1_job_t rest_core1_job = {0}; static rest_core1_result_t rest_core1_result = {0}; +static bool rest_core1_job_pending_try_acquire(void) { +#ifdef _MSC_VER + return InterlockedCompareExchange(&rest_core1_job.pending, 1, 0) == 0; +#else + long expected = 0; + return __atomic_compare_exchange_n(&rest_core1_job.pending, &expected, 1, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE); +#endif +} + +static bool rest_core1_job_pending_load(void) { +#ifdef _MSC_VER + return InterlockedCompareExchange(&rest_core1_job.pending, 0, 0) != 0; +#else + return __atomic_load_n(&rest_core1_job.pending, __ATOMIC_ACQUIRE) != 0; +#endif +} + +static void rest_core1_job_pending_store(bool pending) { +#ifdef _MSC_VER + InterlockedExchange(&rest_core1_job.pending, pending ? 1 : 0); +#else + __atomic_store_n(&rest_core1_job.pending, pending ? 1 : 0, __ATOMIC_RELEASE); +#endif +} + typedef struct { rest_header_id_t id; const char *name; @@ -104,16 +132,24 @@ static void *rest_core1_thread(void *arg); static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]); void rest_close_conn(rest_conn_t *conn); -static int rest_start_core1_job(rest_conn_t *conn, rest_route_handler_t handler) { - if (handler == NULL || rest_core1_job.pending) { +static int rest_start_core1_job(rest_conn_t *conn, rest_route_handler_t handler, const rest_request_t *request) { + if (handler == NULL) { + return -1; + } + if (!rest_core1_job_pending_try_acquire()) { return -1; } memset(&rest_core1_result, 0, sizeof(rest_core1_result)); - rest_core1_job.pending = true; rest_core1_job.conn = conn; rest_core1_job.handler = handler; + if (request != NULL) { + rest_core1_job.request = *request; + } + else { + memset(&rest_core1_job.request, 0, sizeof(rest_core1_job.request)); + } card_start(ITF_LWIP, rest_core1_thread); usb_send_event(EV_CMD_AVAILABLE); @@ -137,7 +173,7 @@ static void *rest_core1_thread(void *arg) { if (m == EV_CMD_AVAILABLE) { rest_core1_result.ready = false; memset(&rest_core1_result.response, 0, sizeof(rest_core1_result.response)); - if (!rest_core1_job.pending || rest_core1_job.handler == NULL || rest_execute_route_handler(&rest_core1_job.request, rest_core1_job.handler, &rest_core1_result.response) != 0) { + if (!rest_core1_job_pending_load() || rest_core1_job.handler == NULL || rest_execute_route_handler(&rest_core1_job.request, rest_core1_job.handler, &rest_core1_result.response) != 0) { rest_server_error(&rest_core1_result.response, 500, "internal_error"); } rest_core1_result.ready = true; @@ -166,10 +202,10 @@ static void send_json_error(rest_conn_t *conn, int status_code, const char *erro } void rest_task(void) { - if (!rest_core1_job.pending) { + if (!rest_core1_job_pending_load()) { rest_route_handler_t handler = rest_background_job_pop(); if (handler != NULL) { - if (rest_start_core1_job(NULL, handler) != 0) { + if (rest_start_core1_job(NULL, handler, NULL) != 0) { // Failed to start background job, push it back to the queue rest_background_job_push(handler); } @@ -182,34 +218,42 @@ void rest_task(void) { } rest_conn_t *conn = rest_core1_job.conn; - //if (conn == NULL) { - // return; - //} rest_response_t *response = &rest_core1_result.response; rest_request_t *request = &rest_core1_job.request; - if (response == NULL) { - return; + bool ready = rest_core1_result.ready; + + uint16_t code = response->status_code == 0 ? 200 : response->status_code; + const char *content_type = response->content_type; + char *body = response->body; + size_t body_len = response->body_len; + char *headers[REST_HEADER_TOTAL_COUNT] = {0}; + for (size_t i = 0; i < REST_HEADER_TOTAL_COUNT; i++) { + headers[i] = response->headers[i]; } + rest_query_t *query = request->query; + + // Release the shared core1 slot before doing potentially slow network I/O. + memset(&rest_core1_result, 0, sizeof(rest_core1_result)); + rest_core1_job.conn = NULL; + rest_core1_job.handler = NULL; + memset(&rest_core1_job.request, 0, sizeof(rest_core1_job.request)); + rest_core1_job_pending_store(false); + if (conn != NULL) { - if (rest_core1_result.ready && response->body != NULL && response->content_type != NULL) { - uint16_t code = response->status_code == 0 ? 200 : response->status_code; - send_response(conn, code, rest_status_text_from_code(code), response->content_type, response->body, response->body_len, response->headers); + if (ready && body != NULL && content_type != NULL) { + send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers); } else { send_json_error(conn, 500, "internal_error"); } } - if (response->body != NULL) { - free(response->body); - response->body = NULL; + if (body != NULL) { + free(body); } - if (request->query != NULL) { - free(request->query); - request->query = NULL; + if (query != NULL) { + free(query); } - memset(&rest_core1_result, 0, sizeof(rest_core1_result)); - memset(&rest_core1_job, 0, sizeof(rest_core1_job)); } static rest_conn_t *alloc_conn( @@ -787,16 +831,13 @@ static int parse_query(rest_request_t *request, char *query_str) { } void rest_handle_request(rest_conn_t *conn) { - rest_request_t *request = &rest_core1_job.request; + rest_request_t request_local = {0}; + rest_request_t *request = &request_local; const rest_route_t *routes; size_t route_count = 0, i; bool path_exists_for_other_method = false; int parsed; - if (request && request->query) { - free(request->query); - } - memset(request, 0, sizeof(*request)); parsed = parse_request(conn, request); if (parsed <= 0) { if (parsed < 0) { @@ -805,7 +846,11 @@ void rest_handle_request(rest_conn_t *conn) { return; } - if (rest_core1_job.pending && !(request->method == REST_HTTP_POST && strcmp(request->path, "/device/jobs/cancel") == 0)) { + if (rest_core1_job_pending_load() && !(request->method == REST_HTTP_POST && strcmp(request->path, "/device/jobs/cancel") == 0)) { + if (request->query != NULL) { + free(request->query); + request->query = NULL; + } send_json_error(conn, 503, "busy"); return; } @@ -827,8 +872,7 @@ void rest_handle_request(rest_conn_t *conn) { routes = rest_get_routes(&route_count); char *query_str = strchr(request->path, '?'); if (query_str != NULL) { - *query_str++ = '\0'; - if (*query_str == '\0') { + if (*(query_str + 1) == '\0') { send_json_error(conn, 400, "bad_request"); return; } @@ -843,8 +887,10 @@ void rest_handle_request(rest_conn_t *conn) { continue; } } - else if (strcmp(routes[i].path, request->path) != 0) { - continue; + else { + if ((query_str ? strncmp(routes[i].path, request->path, (size_t)(query_str - request->path)) : strcmp(routes[i].path, request->path)) != 0) { + continue; + } } if (!(routes[i].method & request->method)) { path_exists_for_other_method = true; @@ -912,12 +958,12 @@ void rest_handle_request(rest_conn_t *conn) { return; } - if (query_str != NULL && parse_query(request, query_str) != PICOKEYS_OK) { + if (query_str != NULL && parse_query(request, query_str + 1) != PICOKEYS_OK) { send_json_error(conn, 400, "bad_request"); return; } - if (rest_start_core1_job(conn, routes[i].handler) != 0) { - send_json_error(conn, 500, "internal_error"); + if (rest_start_core1_job(conn, routes[i].handler, request) != 0) { + send_json_error(conn, 503, "busy"); if (request->query != NULL) { free(request->query); request->query = NULL; @@ -1066,7 +1112,7 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err static err_t rest_poll(void *arg, struct tcp_pcb *pcb) { rest_conn_t *conn = (rest_conn_t *)arg; LWIP_UNUSED_ARG(pcb); - if (rest_core1_job.pending && rest_core1_job.conn == conn) { + if (rest_core1_job_pending_load() && rest_core1_job.conn == conn) { return ERR_OK; } if (conn->conn_type == REST_CONN_TLS) {