From 3ddb459e5c399aac220a9335afa9c399be860cf2 Mon Sep 17 00:00:00 2001 From: Pol Henarejos Date: Mon, 18 May 2026 16:41:27 +0200 Subject: [PATCH] Add stop-n-wait mechanism to avoid sending too huge payloads. Signed-off-by: Pol Henarejos --- src/usb/lwip/rest.h | 1 + src/usb/lwip/rest_server.c | 174 +++++++++++++++++++++++++++++-------- src/usb/lwip/rest_server.h | 7 ++ 3 files changed, 144 insertions(+), 38 deletions(-) diff --git a/src/usb/lwip/rest.h b/src/usb/lwip/rest.h index 832d350..8ee9e17 100644 --- a/src/usb/lwip/rest.h +++ b/src/usb/lwip/rest.h @@ -18,6 +18,7 @@ #ifndef REST_H #define REST_H +#include "compat/compat.h" #include #include #include diff --git a/src/usb/lwip/rest_server.c b/src/usb/lwip/rest_server.c index e031f1e..c852916 100644 --- a/src/usb/lwip/rest_server.c +++ b/src/usb/lwip/rest_server.c @@ -133,8 +133,11 @@ static const rest_header_descriptor_t rest_http_headers[REST_HEADER_TOTAL_COUNT] }; static void *rest_core1_thread(void *arg); -static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]); +static bool send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT], bool body_owned); void rest_close_conn(rest_conn_t *conn); +#ifndef ENABLE_EMULATION +static err_t rest_lwip_continue_send(rest_conn_t *conn); +#endif static int rest_start_core1_job(rest_conn_t *conn, rest_route_handler_t handler, const rest_request_t *request) { if (handler == NULL) { @@ -192,7 +195,7 @@ static void *rest_core1_thread(void *arg) { } static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body) { - send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body), NULL); + (void)send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body), NULL, false); } static void send_json_error(rest_conn_t *conn, int status_code, const char *error_message) { @@ -245,7 +248,10 @@ void rest_task(void) { if (conn != NULL) { if (ready && body != NULL && content_type != NULL) { - send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers); + bool body_transferred = send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers, true); + if (body_transferred) { + body = NULL; + } } else { send_json_error(conn, 500, "internal_error"); @@ -319,6 +325,16 @@ void rest_close_conn(rest_conn_t *conn) { clear_conn(conn); #else err_t err; + if (conn != NULL && conn->tx_pending && conn->tx_body != NULL && conn->tx_body_owned) { + free((void *)conn->tx_body); + } + if (conn != NULL) { + conn->tx_body = NULL; + conn->tx_body_len = 0; + conn->tx_body_sent = 0; + conn->tx_pending = false; + conn->tx_body_owned = false; + } if (conn == NULL || conn->pcb == NULL) { clear_conn(conn); return; @@ -343,10 +359,11 @@ void rest_close_conn(rest_conn_t *conn) { #endif } -static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]) { +static bool send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT], bool body_owned) { char headers_buf[256]; int header_len; #ifdef ENABLE_EMULATION + (void)body_owned; size_t sent_total = 0; #else err_t err; @@ -360,7 +377,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status || conn->pcb == NULL #endif ) { - return; + return false; } REST_DEBUG_LOG( @@ -387,7 +404,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status headers[i] = NULL; if (n <= 0 || header_len + n >= (int)sizeof(headers_buf)) { rest_close_conn(conn); - return; + return false; } header_len += n; } @@ -395,13 +412,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status } if (header_len + 2 >= (int)sizeof(headers_buf)) { rest_close_conn(conn); - return; + return false; } memcpy(p + header_len, "\r\n", 2); header_len += 2; if (header_len <= 0 || (size_t)header_len >= sizeof(headers_buf)) { rest_close_conn(conn); - return; + return false; } if (conn->conn_type == REST_CONN_TLS) { @@ -414,13 +431,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { if (++want_retries > 2048) { rest_close_conn(conn); - return; + return false; } continue; } if (ret <= 0) { rest_close_conn(conn); - return; + return false; } want_retries = 0; written += (size_t)ret; @@ -432,13 +449,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { if (++want_retries > 2048) { rest_close_conn(conn); - return; + return false; } continue; } if (ret <= 0) { rest_close_conn(conn); - return; + return false; } want_retries = 0; written += (size_t)ret; @@ -446,14 +463,14 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status (void)mbedtls_ssl_close_notify(&conn->ssl); rest_close_conn(conn); - return; + return false; } #ifdef ENABLE_EMULATION while (sent_total < (size_t)header_len) { ssize_t n = send((socket_t)conn->sock, headers_buf + sent_total, (int)((size_t)header_len - sent_total), 0); if (n <= 0) { rest_close_conn(conn); - return; + return false; } sent_total += (size_t)n; } @@ -462,7 +479,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status ssize_t n = send((socket_t)conn->sock, body + sent_total, (int)(body_len - sent_total), 0); if (n <= 0) { rest_close_conn(conn); - return; + return false; } sent_total += (size_t)n; } @@ -472,7 +489,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status while (tcp_sndbuf(conn->pcb) < (u16_t)header_len) { if (retries >= 16) { rest_close_conn(conn); - return; + return false; } (void)tcp_output(conn->pcb); retries++; @@ -485,44 +502,64 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status } if (err != ERR_MEM || retries >= 16) { rest_close_conn(conn); - return; + return false; } (void)tcp_output(conn->pcb); retries++; } if (body_len > 0) { - retries = 0; - while (tcp_sndbuf(conn->pcb) < (u16_t)body_len) { - if (retries >= 16) { - rest_close_conn(conn); - return; + if (body_len > 1024) { + if (body_owned) { + conn->tx_body = body; + conn->tx_body_owned = true; } - (void)tcp_output(conn->pcb); - retries++; + else { + char *tx_copy = (char *)malloc(body_len); + if (tx_copy == NULL) { + rest_close_conn(conn); + return false; + } + memcpy(tx_copy, body, body_len); + conn->tx_body = tx_copy; + conn->tx_body_owned = true; + } + conn->tx_body_len = body_len; + conn->tx_body_sent = 0; + conn->tx_pending = true; + if (rest_lwip_continue_send(conn) != ERR_OK) { + return false; + } + return body_owned; } - retries = 0; - while (true) { - err = tcp_write(conn->pcb, body, (uint16_t)body_len, TCP_WRITE_FLAG_COPY); - if (err == ERR_OK) { - break; - } - if (err != ERR_MEM || retries >= 16) { - rest_close_conn(conn); - return; + else { + retries = 0; + while (true) { + err = tcp_write(conn->pcb, body, (uint16_t)body_len, TCP_WRITE_FLAG_COPY); + if (err == ERR_OK) { + break; + } + if (err != ERR_MEM || retries >= 16) { + rest_close_conn(conn); + return false; + } + (void)tcp_output(conn->pcb); + retries++; } (void)tcp_output(conn->pcb); - retries++; } } - (void)tcp_output(conn->pcb); + else { + (void)tcp_output(conn->pcb); + } #endif #ifndef ENABLE_EMULATION if (conn->conn_type == REST_CONN_PLAIN) { - return; + return false; } #endif rest_close_conn(conn); + return false; } static void trim_ascii_ws_bounds(const char **start, const char **end) { @@ -949,10 +986,10 @@ void rest_handle_request(rest_conn_t *conn) { } uint16_t code = response.status_code == 0 ? 200 : response.status_code; if (code == 204) { - send_response(conn, code, rest_status_text_from_code(code), "application/json", "", 0, response.headers); + (void)send_response(conn, code, rest_status_text_from_code(code), "application/json", "", 0, response.headers, false); } else if (response.body != NULL && response.content_type != NULL) { - send_response(conn, code, rest_status_text_from_code(code), response.content_type, response.body, response.body_len, response.headers); + (void)send_response(conn, code, rest_status_text_from_code(code), response.content_type, response.body, response.body_len, response.headers, false); } else { send_json_error(conn, 500, "internal_error"); @@ -1072,6 +1109,61 @@ out: } #ifndef ENABLE_EMULATION +static err_t rest_lwip_continue_send(rest_conn_t *conn) { + if (conn == NULL || conn->pcb == NULL) { + return ERR_CONN; + } + if (!conn->tx_pending || conn->tx_body == NULL) { + return ERR_OK; + } + + while (conn->tx_body_sent < conn->tx_body_len) { + size_t remaining = conn->tx_body_len - conn->tx_body_sent; + size_t to_send = remaining > 1024 ? 1024 : remaining; + u16_t snd_avail = tcp_sndbuf(conn->pcb); + err_t err; + u8_t flags = 0; + + if (snd_avail == 0) { + break; + } + if (to_send > snd_avail) { + to_send = snd_avail; + } + if (to_send == 0) { + break; + } + if (conn->tx_body_sent + to_send < conn->tx_body_len) { + flags |= TCP_WRITE_FLAG_MORE; + } + + err = tcp_write(conn->pcb, conn->tx_body + conn->tx_body_sent, (u16_t)to_send, flags); + if (err == ERR_OK) { + conn->tx_body_sent += to_send; + continue; + } + if (err == ERR_MEM) { + break; + } + rest_close_conn(conn); + return err; + } + + (void)tcp_output(conn->pcb); + + if (conn->tx_body_sent >= conn->tx_body_len) { + if (conn->tx_body_owned) { + free((void *)conn->tx_body); + } + conn->tx_body = NULL; + conn->tx_body_len = 0; + conn->tx_body_sent = 0; + conn->tx_pending = false; + conn->tx_body_owned = false; + } + return ERR_OK; +} + static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err) { rest_conn_t *conn = (rest_conn_t *)arg; size_t *len = NULL; @@ -1126,6 +1218,9 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err static err_t rest_poll(void *arg, struct tcp_pcb *pcb) { rest_conn_t *conn = (rest_conn_t *)arg; LWIP_UNUSED_ARG(pcb); + if (conn != NULL && conn->tx_pending) { + return rest_lwip_continue_send(conn); + } if (rest_core1_job_pending_load() && rest_core1_job.conn == conn) { return ERR_OK; } @@ -1147,6 +1242,9 @@ static err_t rest_sent(void *arg, struct tcp_pcb *pcb, u16_t len) { rest_conn_t *conn = (rest_conn_t *)arg; LWIP_UNUSED_ARG(pcb); LWIP_UNUSED_ARG(len); + if (conn != NULL && conn->tx_pending) { + return rest_lwip_continue_send(conn); + } if (conn->conn_type == REST_CONN_TLS) { return tls_progress_conn(conn); } diff --git a/src/usb/lwip/rest_server.h b/src/usb/lwip/rest_server.h index 8c46eae..8707dc2 100644 --- a/src/usb/lwip/rest_server.h +++ b/src/usb/lwip/rest_server.h @@ -71,6 +71,13 @@ typedef struct { bool request_complete; bool request_dispatched; bool request_headers_parsed; +#ifndef ENABLE_EMULATION + const char *tx_body; + size_t tx_body_len; + size_t tx_body_sent; + bool tx_pending; + bool tx_body_owned; +#endif char request[REST_MAX_REQUEST_SIZE + 1]; #ifdef ENABLE_EMULATION char _padding2[2];