Add stop-n-wait mechanism to avoid sending too huge payloads.

Signed-off-by: Pol Henarejos <pol.henarejos@cttc.es>
This commit is contained in:
Pol Henarejos
2026-05-18 16:41:27 +02:00
parent a9261e34ad
commit 3ddb459e5c
3 changed files with 144 additions and 38 deletions

View File

@@ -18,6 +18,7 @@
#ifndef REST_H
#define REST_H
#include "compat/compat.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>

View File

@@ -133,8 +133,11 @@ static const rest_header_descriptor_t rest_http_headers[REST_HEADER_TOTAL_COUNT]
};
static void *rest_core1_thread(void *arg);
static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]);
static bool send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT], bool body_owned);
void rest_close_conn(rest_conn_t *conn);
#ifndef ENABLE_EMULATION
static err_t rest_lwip_continue_send(rest_conn_t *conn);
#endif
static int rest_start_core1_job(rest_conn_t *conn, rest_route_handler_t handler, const rest_request_t *request) {
if (handler == NULL) {
@@ -192,7 +195,7 @@ static void *rest_core1_thread(void *arg) {
}
static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body) {
send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body), NULL);
(void)send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body), NULL, false);
}
static void send_json_error(rest_conn_t *conn, int status_code, const char *error_message) {
@@ -245,7 +248,10 @@ void rest_task(void) {
if (conn != NULL) {
if (ready && body != NULL && content_type != NULL) {
send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers);
bool body_transferred = send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers, true);
if (body_transferred) {
body = NULL;
}
}
else {
send_json_error(conn, 500, "internal_error");
@@ -319,6 +325,16 @@ void rest_close_conn(rest_conn_t *conn) {
clear_conn(conn);
#else
err_t err;
if (conn != NULL && conn->tx_pending && conn->tx_body != NULL && conn->tx_body_owned) {
free((void *)conn->tx_body);
}
if (conn != NULL) {
conn->tx_body = NULL;
conn->tx_body_len = 0;
conn->tx_body_sent = 0;
conn->tx_pending = false;
conn->tx_body_owned = false;
}
if (conn == NULL || conn->pcb == NULL) {
clear_conn(conn);
return;
@@ -343,10 +359,11 @@ void rest_close_conn(rest_conn_t *conn) {
#endif
}
static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]) {
static bool send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT], bool body_owned) {
char headers_buf[256];
int header_len;
#ifdef ENABLE_EMULATION
(void)body_owned;
size_t sent_total = 0;
#else
err_t err;
@@ -360,7 +377,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|| conn->pcb == NULL
#endif
) {
return;
return false;
}
REST_DEBUG_LOG(
@@ -387,7 +404,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
headers[i] = NULL;
if (n <= 0 || header_len + n >= (int)sizeof(headers_buf)) {
rest_close_conn(conn);
return;
return false;
}
header_len += n;
}
@@ -395,13 +412,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
}
if (header_len + 2 >= (int)sizeof(headers_buf)) {
rest_close_conn(conn);
return;
return false;
}
memcpy(p + header_len, "\r\n", 2);
header_len += 2;
if (header_len <= 0 || (size_t)header_len >= sizeof(headers_buf)) {
rest_close_conn(conn);
return;
return false;
}
if (conn->conn_type == REST_CONN_TLS) {
@@ -414,13 +431,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
if (++want_retries > 2048) {
rest_close_conn(conn);
return;
return false;
}
continue;
}
if (ret <= 0) {
rest_close_conn(conn);
return;
return false;
}
want_retries = 0;
written += (size_t)ret;
@@ -432,13 +449,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
if (++want_retries > 2048) {
rest_close_conn(conn);
return;
return false;
}
continue;
}
if (ret <= 0) {
rest_close_conn(conn);
return;
return false;
}
want_retries = 0;
written += (size_t)ret;
@@ -446,14 +463,14 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
(void)mbedtls_ssl_close_notify(&conn->ssl);
rest_close_conn(conn);
return;
return false;
}
#ifdef ENABLE_EMULATION
while (sent_total < (size_t)header_len) {
ssize_t n = send((socket_t)conn->sock, headers_buf + sent_total, (int)((size_t)header_len - sent_total), 0);
if (n <= 0) {
rest_close_conn(conn);
return;
return false;
}
sent_total += (size_t)n;
}
@@ -462,7 +479,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
ssize_t n = send((socket_t)conn->sock, body + sent_total, (int)(body_len - sent_total), 0);
if (n <= 0) {
rest_close_conn(conn);
return;
return false;
}
sent_total += (size_t)n;
}
@@ -472,7 +489,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
while (tcp_sndbuf(conn->pcb) < (u16_t)header_len) {
if (retries >= 16) {
rest_close_conn(conn);
return;
return false;
}
(void)tcp_output(conn->pcb);
retries++;
@@ -485,44 +502,64 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
}
if (err != ERR_MEM || retries >= 16) {
rest_close_conn(conn);
return;
return false;
}
(void)tcp_output(conn->pcb);
retries++;
}
if (body_len > 0) {
retries = 0;
while (tcp_sndbuf(conn->pcb) < (u16_t)body_len) {
if (retries >= 16) {
rest_close_conn(conn);
return;
if (body_len > 1024) {
if (body_owned) {
conn->tx_body = body;
conn->tx_body_owned = true;
}
(void)tcp_output(conn->pcb);
retries++;
else {
char *tx_copy = (char *)malloc(body_len);
if (tx_copy == NULL) {
rest_close_conn(conn);
return false;
}
memcpy(tx_copy, body, body_len);
conn->tx_body = tx_copy;
conn->tx_body_owned = true;
}
conn->tx_body_len = body_len;
conn->tx_body_sent = 0;
conn->tx_pending = true;
if (rest_lwip_continue_send(conn) != ERR_OK) {
return false;
}
return body_owned;
}
retries = 0;
while (true) {
err = tcp_write(conn->pcb, body, (uint16_t)body_len, TCP_WRITE_FLAG_COPY);
if (err == ERR_OK) {
break;
}
if (err != ERR_MEM || retries >= 16) {
rest_close_conn(conn);
return;
else {
retries = 0;
while (true) {
err = tcp_write(conn->pcb, body, (uint16_t)body_len, TCP_WRITE_FLAG_COPY);
if (err == ERR_OK) {
break;
}
if (err != ERR_MEM || retries >= 16) {
rest_close_conn(conn);
return false;
}
(void)tcp_output(conn->pcb);
retries++;
}
(void)tcp_output(conn->pcb);
retries++;
}
}
(void)tcp_output(conn->pcb);
else {
(void)tcp_output(conn->pcb);
}
#endif
#ifndef ENABLE_EMULATION
if (conn->conn_type == REST_CONN_PLAIN) {
return;
return false;
}
#endif
rest_close_conn(conn);
return false;
}
static void trim_ascii_ws_bounds(const char **start, const char **end) {
@@ -949,10 +986,10 @@ void rest_handle_request(rest_conn_t *conn) {
}
uint16_t code = response.status_code == 0 ? 200 : response.status_code;
if (code == 204) {
send_response(conn, code, rest_status_text_from_code(code), "application/json", "", 0, response.headers);
(void)send_response(conn, code, rest_status_text_from_code(code), "application/json", "", 0, response.headers, false);
}
else if (response.body != NULL && response.content_type != NULL) {
send_response(conn, code, rest_status_text_from_code(code), response.content_type, response.body, response.body_len, response.headers);
(void)send_response(conn, code, rest_status_text_from_code(code), response.content_type, response.body, response.body_len, response.headers, false);
}
else {
send_json_error(conn, 500, "internal_error");
@@ -1072,6 +1109,61 @@ out:
}
#ifndef ENABLE_EMULATION
static err_t rest_lwip_continue_send(rest_conn_t *conn) {
if (conn == NULL || conn->pcb == NULL) {
return ERR_CONN;
}
if (!conn->tx_pending || conn->tx_body == NULL) {
return ERR_OK;
}
while (conn->tx_body_sent < conn->tx_body_len) {
size_t remaining = conn->tx_body_len - conn->tx_body_sent;
size_t to_send = remaining > 1024 ? 1024 : remaining;
u16_t snd_avail = tcp_sndbuf(conn->pcb);
err_t err;
u8_t flags = 0;
if (snd_avail == 0) {
break;
}
if (to_send > snd_avail) {
to_send = snd_avail;
}
if (to_send == 0) {
break;
}
if (conn->tx_body_sent + to_send < conn->tx_body_len) {
flags |= TCP_WRITE_FLAG_MORE;
}
err = tcp_write(conn->pcb, conn->tx_body + conn->tx_body_sent, (u16_t)to_send, flags);
if (err == ERR_OK) {
conn->tx_body_sent += to_send;
continue;
}
if (err == ERR_MEM) {
break;
}
rest_close_conn(conn);
return err;
}
(void)tcp_output(conn->pcb);
if (conn->tx_body_sent >= conn->tx_body_len) {
if (conn->tx_body_owned) {
free((void *)conn->tx_body);
}
conn->tx_body = NULL;
conn->tx_body_len = 0;
conn->tx_body_sent = 0;
conn->tx_pending = false;
conn->tx_body_owned = false;
}
return ERR_OK;
}
static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err) {
rest_conn_t *conn = (rest_conn_t *)arg;
size_t *len = NULL;
@@ -1126,6 +1218,9 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err
static err_t rest_poll(void *arg, struct tcp_pcb *pcb) {
rest_conn_t *conn = (rest_conn_t *)arg;
LWIP_UNUSED_ARG(pcb);
if (conn != NULL && conn->tx_pending) {
return rest_lwip_continue_send(conn);
}
if (rest_core1_job_pending_load() && rest_core1_job.conn == conn) {
return ERR_OK;
}
@@ -1147,6 +1242,9 @@ static err_t rest_sent(void *arg, struct tcp_pcb *pcb, u16_t len) {
rest_conn_t *conn = (rest_conn_t *)arg;
LWIP_UNUSED_ARG(pcb);
LWIP_UNUSED_ARG(len);
if (conn != NULL && conn->tx_pending) {
return rest_lwip_continue_send(conn);
}
if (conn->conn_type == REST_CONN_TLS) {
return tls_progress_conn(conn);
}

View File

@@ -71,6 +71,13 @@ typedef struct {
bool request_complete;
bool request_dispatched;
bool request_headers_parsed;
#ifndef ENABLE_EMULATION
const char *tx_body;
size_t tx_body_len;
size_t tx_body_sent;
bool tx_pending;
bool tx_body_owned;
#endif
char request[REST_MAX_REQUEST_SIZE + 1];
#ifdef ENABLE_EMULATION
char _padding2[2];