mirror of
https://github.com/polhenarejos/pico-keys-sdk
synced 2026-05-25 15:45:11 +02:00
Add stop-n-wait mechanism to avoid sending too huge payloads.
Signed-off-by: Pol Henarejos <pol.henarejos@cttc.es>
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#ifndef REST_H
|
||||
#define REST_H
|
||||
|
||||
#include "compat/compat.h"
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
@@ -133,8 +133,11 @@ static const rest_header_descriptor_t rest_http_headers[REST_HEADER_TOTAL_COUNT]
|
||||
};
|
||||
|
||||
static void *rest_core1_thread(void *arg);
|
||||
static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]);
|
||||
static bool send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT], bool body_owned);
|
||||
void rest_close_conn(rest_conn_t *conn);
|
||||
#ifndef ENABLE_EMULATION
|
||||
static err_t rest_lwip_continue_send(rest_conn_t *conn);
|
||||
#endif
|
||||
|
||||
static int rest_start_core1_job(rest_conn_t *conn, rest_route_handler_t handler, const rest_request_t *request) {
|
||||
if (handler == NULL) {
|
||||
@@ -192,7 +195,7 @@ static void *rest_core1_thread(void *arg) {
|
||||
}
|
||||
|
||||
static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body) {
|
||||
send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body), NULL);
|
||||
(void)send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body), NULL, false);
|
||||
}
|
||||
|
||||
static void send_json_error(rest_conn_t *conn, int status_code, const char *error_message) {
|
||||
@@ -245,7 +248,10 @@ void rest_task(void) {
|
||||
|
||||
if (conn != NULL) {
|
||||
if (ready && body != NULL && content_type != NULL) {
|
||||
send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers);
|
||||
bool body_transferred = send_response(conn, code, rest_status_text_from_code(code), content_type, body, body_len, headers, true);
|
||||
if (body_transferred) {
|
||||
body = NULL;
|
||||
}
|
||||
}
|
||||
else {
|
||||
send_json_error(conn, 500, "internal_error");
|
||||
@@ -319,6 +325,16 @@ void rest_close_conn(rest_conn_t *conn) {
|
||||
clear_conn(conn);
|
||||
#else
|
||||
err_t err;
|
||||
if (conn != NULL && conn->tx_pending && conn->tx_body != NULL && conn->tx_body_owned) {
|
||||
free((void *)conn->tx_body);
|
||||
}
|
||||
if (conn != NULL) {
|
||||
conn->tx_body = NULL;
|
||||
conn->tx_body_len = 0;
|
||||
conn->tx_body_sent = 0;
|
||||
conn->tx_pending = false;
|
||||
conn->tx_body_owned = false;
|
||||
}
|
||||
if (conn == NULL || conn->pcb == NULL) {
|
||||
clear_conn(conn);
|
||||
return;
|
||||
@@ -343,10 +359,11 @@ void rest_close_conn(rest_conn_t *conn) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT]) {
|
||||
static bool send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len, char *headers[REST_HEADER_TOTAL_COUNT], bool body_owned) {
|
||||
char headers_buf[256];
|
||||
int header_len;
|
||||
#ifdef ENABLE_EMULATION
|
||||
(void)body_owned;
|
||||
size_t sent_total = 0;
|
||||
#else
|
||||
err_t err;
|
||||
@@ -360,7 +377,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
|| conn->pcb == NULL
|
||||
#endif
|
||||
) {
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
REST_DEBUG_LOG(
|
||||
@@ -387,7 +404,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
headers[i] = NULL;
|
||||
if (n <= 0 || header_len + n >= (int)sizeof(headers_buf)) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
header_len += n;
|
||||
}
|
||||
@@ -395,13 +412,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
}
|
||||
if (header_len + 2 >= (int)sizeof(headers_buf)) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
memcpy(p + header_len, "\r\n", 2);
|
||||
header_len += 2;
|
||||
if (header_len <= 0 || (size_t)header_len >= sizeof(headers_buf)) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (conn->conn_type == REST_CONN_TLS) {
|
||||
@@ -414,13 +431,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||||
if (++want_retries > 2048) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ret <= 0) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
want_retries = 0;
|
||||
written += (size_t)ret;
|
||||
@@ -432,13 +449,13 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||||
if (++want_retries > 2048) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ret <= 0) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
want_retries = 0;
|
||||
written += (size_t)ret;
|
||||
@@ -446,14 +463,14 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
|
||||
(void)mbedtls_ssl_close_notify(&conn->ssl);
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
#ifdef ENABLE_EMULATION
|
||||
while (sent_total < (size_t)header_len) {
|
||||
ssize_t n = send((socket_t)conn->sock, headers_buf + sent_total, (int)((size_t)header_len - sent_total), 0);
|
||||
if (n <= 0) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
sent_total += (size_t)n;
|
||||
}
|
||||
@@ -462,7 +479,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
ssize_t n = send((socket_t)conn->sock, body + sent_total, (int)(body_len - sent_total), 0);
|
||||
if (n <= 0) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
sent_total += (size_t)n;
|
||||
}
|
||||
@@ -472,7 +489,7 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
while (tcp_sndbuf(conn->pcb) < (u16_t)header_len) {
|
||||
if (retries >= 16) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
(void)tcp_output(conn->pcb);
|
||||
retries++;
|
||||
@@ -485,44 +502,64 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status
|
||||
}
|
||||
if (err != ERR_MEM || retries >= 16) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
(void)tcp_output(conn->pcb);
|
||||
retries++;
|
||||
}
|
||||
|
||||
if (body_len > 0) {
|
||||
retries = 0;
|
||||
while (tcp_sndbuf(conn->pcb) < (u16_t)body_len) {
|
||||
if (retries >= 16) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
if (body_len > 1024) {
|
||||
if (body_owned) {
|
||||
conn->tx_body = body;
|
||||
conn->tx_body_owned = true;
|
||||
}
|
||||
(void)tcp_output(conn->pcb);
|
||||
retries++;
|
||||
else {
|
||||
char *tx_copy = (char *)malloc(body_len);
|
||||
if (tx_copy == NULL) {
|
||||
rest_close_conn(conn);
|
||||
return false;
|
||||
}
|
||||
memcpy(tx_copy, body, body_len);
|
||||
conn->tx_body = tx_copy;
|
||||
conn->tx_body_owned = true;
|
||||
}
|
||||
conn->tx_body_len = body_len;
|
||||
conn->tx_body_sent = 0;
|
||||
conn->tx_pending = true;
|
||||
if (rest_lwip_continue_send(conn) != ERR_OK) {
|
||||
return false;
|
||||
}
|
||||
return body_owned;
|
||||
}
|
||||
retries = 0;
|
||||
while (true) {
|
||||
err = tcp_write(conn->pcb, body, (uint16_t)body_len, TCP_WRITE_FLAG_COPY);
|
||||
if (err == ERR_OK) {
|
||||
break;
|
||||
}
|
||||
if (err != ERR_MEM || retries >= 16) {
|
||||
rest_close_conn(conn);
|
||||
return;
|
||||
else {
|
||||
retries = 0;
|
||||
while (true) {
|
||||
err = tcp_write(conn->pcb, body, (uint16_t)body_len, TCP_WRITE_FLAG_COPY);
|
||||
if (err == ERR_OK) {
|
||||
break;
|
||||
}
|
||||
if (err != ERR_MEM || retries >= 16) {
|
||||
rest_close_conn(conn);
|
||||
return false;
|
||||
}
|
||||
(void)tcp_output(conn->pcb);
|
||||
retries++;
|
||||
}
|
||||
(void)tcp_output(conn->pcb);
|
||||
retries++;
|
||||
}
|
||||
}
|
||||
(void)tcp_output(conn->pcb);
|
||||
else {
|
||||
(void)tcp_output(conn->pcb);
|
||||
}
|
||||
#endif
|
||||
#ifndef ENABLE_EMULATION
|
||||
if (conn->conn_type == REST_CONN_PLAIN) {
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
rest_close_conn(conn);
|
||||
return false;
|
||||
}
|
||||
|
||||
static void trim_ascii_ws_bounds(const char **start, const char **end) {
|
||||
@@ -949,10 +986,10 @@ void rest_handle_request(rest_conn_t *conn) {
|
||||
}
|
||||
uint16_t code = response.status_code == 0 ? 200 : response.status_code;
|
||||
if (code == 204) {
|
||||
send_response(conn, code, rest_status_text_from_code(code), "application/json", "", 0, response.headers);
|
||||
(void)send_response(conn, code, rest_status_text_from_code(code), "application/json", "", 0, response.headers, false);
|
||||
}
|
||||
else if (response.body != NULL && response.content_type != NULL) {
|
||||
send_response(conn, code, rest_status_text_from_code(code), response.content_type, response.body, response.body_len, response.headers);
|
||||
(void)send_response(conn, code, rest_status_text_from_code(code), response.content_type, response.body, response.body_len, response.headers, false);
|
||||
}
|
||||
else {
|
||||
send_json_error(conn, 500, "internal_error");
|
||||
@@ -1072,6 +1109,61 @@ out:
|
||||
}
|
||||
|
||||
#ifndef ENABLE_EMULATION
|
||||
static err_t rest_lwip_continue_send(rest_conn_t *conn) {
|
||||
if (conn == NULL || conn->pcb == NULL) {
|
||||
return ERR_CONN;
|
||||
}
|
||||
if (!conn->tx_pending || conn->tx_body == NULL) {
|
||||
return ERR_OK;
|
||||
}
|
||||
|
||||
while (conn->tx_body_sent < conn->tx_body_len) {
|
||||
size_t remaining = conn->tx_body_len - conn->tx_body_sent;
|
||||
size_t to_send = remaining > 1024 ? 1024 : remaining;
|
||||
u16_t snd_avail = tcp_sndbuf(conn->pcb);
|
||||
err_t err;
|
||||
u8_t flags = 0;
|
||||
|
||||
if (snd_avail == 0) {
|
||||
break;
|
||||
}
|
||||
if (to_send > snd_avail) {
|
||||
to_send = snd_avail;
|
||||
}
|
||||
if (to_send == 0) {
|
||||
break;
|
||||
}
|
||||
if (conn->tx_body_sent + to_send < conn->tx_body_len) {
|
||||
flags |= TCP_WRITE_FLAG_MORE;
|
||||
}
|
||||
|
||||
err = tcp_write(conn->pcb, conn->tx_body + conn->tx_body_sent, (u16_t)to_send, flags);
|
||||
if (err == ERR_OK) {
|
||||
conn->tx_body_sent += to_send;
|
||||
continue;
|
||||
}
|
||||
if (err == ERR_MEM) {
|
||||
break;
|
||||
}
|
||||
rest_close_conn(conn);
|
||||
return err;
|
||||
}
|
||||
|
||||
(void)tcp_output(conn->pcb);
|
||||
|
||||
if (conn->tx_body_sent >= conn->tx_body_len) {
|
||||
if (conn->tx_body_owned) {
|
||||
free((void *)conn->tx_body);
|
||||
}
|
||||
conn->tx_body = NULL;
|
||||
conn->tx_body_len = 0;
|
||||
conn->tx_body_sent = 0;
|
||||
conn->tx_pending = false;
|
||||
conn->tx_body_owned = false;
|
||||
}
|
||||
return ERR_OK;
|
||||
}
|
||||
|
||||
static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err) {
|
||||
rest_conn_t *conn = (rest_conn_t *)arg;
|
||||
size_t *len = NULL;
|
||||
@@ -1126,6 +1218,9 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err
|
||||
static err_t rest_poll(void *arg, struct tcp_pcb *pcb) {
|
||||
rest_conn_t *conn = (rest_conn_t *)arg;
|
||||
LWIP_UNUSED_ARG(pcb);
|
||||
if (conn != NULL && conn->tx_pending) {
|
||||
return rest_lwip_continue_send(conn);
|
||||
}
|
||||
if (rest_core1_job_pending_load() && rest_core1_job.conn == conn) {
|
||||
return ERR_OK;
|
||||
}
|
||||
@@ -1147,6 +1242,9 @@ static err_t rest_sent(void *arg, struct tcp_pcb *pcb, u16_t len) {
|
||||
rest_conn_t *conn = (rest_conn_t *)arg;
|
||||
LWIP_UNUSED_ARG(pcb);
|
||||
LWIP_UNUSED_ARG(len);
|
||||
if (conn != NULL && conn->tx_pending) {
|
||||
return rest_lwip_continue_send(conn);
|
||||
}
|
||||
if (conn->conn_type == REST_CONN_TLS) {
|
||||
return tls_progress_conn(conn);
|
||||
}
|
||||
|
||||
@@ -71,6 +71,13 @@ typedef struct {
|
||||
bool request_complete;
|
||||
bool request_dispatched;
|
||||
bool request_headers_parsed;
|
||||
#ifndef ENABLE_EMULATION
|
||||
const char *tx_body;
|
||||
size_t tx_body_len;
|
||||
size_t tx_body_sent;
|
||||
bool tx_pending;
|
||||
bool tx_body_owned;
|
||||
#endif
|
||||
char request[REST_MAX_REQUEST_SIZE + 1];
|
||||
#ifdef ENABLE_EMULATION
|
||||
char _padding2[2];
|
||||
|
||||
Reference in New Issue
Block a user