From febae0e664f2e8500ae08b79a783883ce524336b Mon Sep 17 00:00:00 2001 From: Pol Henarejos Date: Sun, 19 Apr 2026 03:16:54 +0200 Subject: [PATCH] Add support for TLS. Cert is self-signed and auto-generated on first boot. Signed-off-by: Pol Henarejos --- pico_keys_sdk_import.cmake | 21 ++ src/usb/lwip/lwip.c | 3 +- src/usb/lwip/rest.c | 164 ++++++++++ src/usb/lwip/rest.h | 80 +++++ src/usb/lwip/rest_server.c | 580 ++++++++++++++++++++------------- src/usb/lwip/rest_server.h | 64 ++-- src/usb/lwip/rest_server_tls.c | 320 ++++++++++++++++++ src/usb/lwip/rest_server_tls.h | 61 ++++ 8 files changed, 1034 insertions(+), 259 deletions(-) create mode 100644 src/usb/lwip/rest.c create mode 100644 src/usb/lwip/rest.h create mode 100644 src/usb/lwip/rest_server_tls.c create mode 100644 src/usb/lwip/rest_server_tls.h diff --git a/pico_keys_sdk_import.cmake b/pico_keys_sdk_import.cmake index 386bea1..c8ee27d 100644 --- a/pico_keys_sdk_import.cmake +++ b/pico_keys_sdk_import.cmake @@ -400,6 +400,25 @@ set(SYSTEM_INCLUDES ${CMAKE_CURRENT_LIST_DIR}/third-party/cjson ) +if(USB_ITF_LWIP) + add_compile_definitions( + MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED + MBEDTLS_SSL_PROTO_TLS1_2 + MBEDTLS_SSL_SRV_C + MBEDTLS_SSL_TLS_C + ) + list(APPEND MBEDTLS_SOURCES + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/pkparse.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/pk_ecc.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/pkcs12.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/ssl_ciphersuites.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/ssl_msg.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/ssl_tls.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/ssl_tls12_server.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/x509.c + ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/x509_crt.c + ) +endif() if(USB_ITF_HID) list(APPEND MBEDTLS_SOURCES ${CMAKE_CURRENT_LIST_DIR}/third-party/mbedtls/library/x509write_crt.c @@ -623,7 +642,9 @@ endif() if(USB_ITF_LWIP) list(APPEND PICO_KEYS_SOURCES + ${CMAKE_CURRENT_LIST_DIR}/src/usb/lwip/rest.c ${CMAKE_CURRENT_LIST_DIR}/src/usb/lwip/rest_server.c + ${CMAKE_CURRENT_LIST_DIR}/src/usb/lwip/rest_server_tls.c ) list(APPEND INCLUDES ${CMAKE_CURRENT_LIST_DIR}/src/usb/lwip diff --git a/src/usb/lwip/lwip.c b/src/usb/lwip/lwip.c index 924e2ea..9147d1a 100644 --- a/src/usb/lwip/lwip.c +++ b/src/usb/lwip/lwip.c @@ -52,6 +52,7 @@ try changing the first byte of tud_network_mac_address[] below from 0x02 to 0x00 #include "lwip/init.h" #include "lwip/timeouts.h" #include "rest_server.h" +#include "rest_server_tls.h" #define INIT_IP4(a, b, c, d) \ { PP_HTONL(LWIP_MAKEU32(a, b, c, d)) } @@ -217,7 +218,7 @@ int lwip_itf_init(void) { while (!netif_is_up(&netif_data)); while (dhserv_init(&dhcp_config) != ERR_OK); while (dnserv_init(IP_ADDR_ANY, 53, dns_query_proc) != ERR_OK); - while (rest_server_init() != ERR_OK); + while (rest_server_init(REST_CONN_ALL) != ERR_OK); return 0; } diff --git a/src/usb/lwip/rest.c b/src/usb/lwip/rest.c new file mode 100644 index 0000000..7309512 --- /dev/null +++ b/src/usb/lwip/rest.c @@ -0,0 +1,164 @@ +/* + * This file is part of the Pico Keys SDK distribution (https://github.com/polhenarejos/pico-keys-sdk). + * Copyright (c) 2022 Pol Henarejos. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, version 3. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "rest.h" +#include + +#ifdef DEBUG_APDU +void debug_dump_payload(const char *tag, const char *buffer, size_t len) { + size_t i; + if (buffer == NULL) { + printf("[rest] %s: \n", tag); + return; + } + + printf("[rest] %s (%lu bytes): \"", tag, (unsigned long)len); + for (i = 0; i < len; i++) { + unsigned char c = (unsigned char)buffer[i]; + if (c == '\r') { + printf("\\r"); + } + else if (c == '\n') { + printf("\\n"); + } + else if (c == '\t') { + printf("\\t"); + } + else if (c >= 32 && c <= 126) { + putchar((int)c); + } + else { + printf("\\x%02X", c); + } + } + printf("\"\n"); + if (tag[2] == 's') { + printf("\n"); + } +} +#endif + +int execute_route_handler(const rest_request_t *request, rest_route_handler_t handler, rest_response_t *response) { + if (request == NULL || handler == NULL || response == NULL) { + return -1; + } + + memset(response, 0, sizeof(*response)); + response->status_code = 200; + response->content_type = "application/json"; + response->body = "{\"ok\":true}"; + response->json = cJSON_CreateObject(); + if (response->json == NULL) { + return -1; + } + + if (handler(request, response) != 0) { + cJSON_Delete(response->json); + response->json = NULL; + return -1; + } + if (response->content_type == NULL || response->body == NULL) { + cJSON_Delete(response->json); + response->json = NULL; + return -1; + } + if (response->status_code == 0 || response->status_code == 200) { + char *body = cJSON_PrintUnformatted(response->json); + cJSON_Delete(response->json); + response->json = NULL; + if (body == NULL) { + return -1; + } + response->body = body; + } + + response->status_code = (response->status_code == 0) ? 200 : response->status_code; + response->body_len = (response->body_len == 0) ? strlen(response->body) : response->body_len; + return 0; +} + +int rest_response_set_error(rest_response_t *response, int status_code, const char *message) { + char json_template[256]; + int json_len; + if (response == NULL) { + return -1; + } + json_len = snprintf(json_template, sizeof(json_template), "{\"error\":\"%s\"}", message); + if (json_len <= 0 || (size_t)json_len >= sizeof(json_template)) { + return -1; + } + response->status_code = (uint16_t)status_code; + response->content_type = "application/json"; + response->body = strdup(json_template); + if (response->body == NULL) { + return -1; + } + response->body_len = (size_t)json_len; + return 0; +} + +const char *rest_status_text_from_code(uint16_t code) { + switch (code) { + case 200: + return "OK"; + case 400: + return "Bad Request"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 413: + return "Payload Too Large"; + case 415: + return "Unsupported Media Type"; + case 500: + return "Internal Server Error"; + case 503: + return "Service Unavailable"; + default: + return "OK"; + } +} + +const char *rest_method_to_string(rest_http_method_t method) { + switch (method) { + case REST_HTTP_GET: + return "GET"; + case REST_HTTP_POST: + return "POST"; + case REST_HTTP_PUT: + return "PUT"; + case REST_HTTP_DELETE: + return "DELETE"; + default: + return "UNKNOWN"; + } +} + +bool rest_content_type_is_json(const char *content_type) { + if (content_type == NULL) { + return false; + } + return strncasecmp(content_type, "application/json", 16) == 0; +} + +__attribute__((weak)) const rest_route_t *rest_get_routes(size_t *count) { + if (count != NULL) { + *count = 0; + } + return NULL; +} diff --git a/src/usb/lwip/rest.h b/src/usb/lwip/rest.h new file mode 100644 index 0000000..b444226 --- /dev/null +++ b/src/usb/lwip/rest.h @@ -0,0 +1,80 @@ +/* + * This file is part of the Pico Keys SDK distribution (https://github.com/polhenarejos/pico-keys-sdk). + * Copyright (c) 2022 Pol Henarejos. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, version 3. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef REST_SERVER_H +#define REST_SERVER_H + +#include +#include +#include +#include +#include "cJSON.h" + +#define REST_MAX_REQUEST_SIZE 1024 +#define REST_MAX_METHOD_SIZE 8 +#define REST_MAX_CONTENT_TYPE_SIZE 64 +#define REST_MAX_PATH_SIZE 192 + +typedef enum { + REST_HTTP_GET = 0, + REST_HTTP_POST, + REST_HTTP_PUT, + REST_HTTP_DELETE +} rest_http_method_t; + +typedef struct { + rest_http_method_t method; + char path[REST_MAX_PATH_SIZE]; + const char *body; + size_t body_len; + const char *content_type; +} rest_request_t; + +typedef struct { + uint16_t status_code; + const char *content_type; + char *body; // heap ! + size_t body_len; + cJSON *json; +} rest_response_t; + +typedef int (*rest_route_handler_t)(const rest_request_t *request, rest_response_t *response); + +typedef struct { + rest_http_method_t method; + const char *path; + rest_route_handler_t handler; +} rest_route_t; + + +extern int execute_route_handler(const rest_request_t *request, rest_route_handler_t handler, rest_response_t *response); +extern int rest_response_set_error(rest_response_t *response, int status_code, const char *message); +const char *rest_status_text_from_code(uint16_t code); +const char *rest_method_to_string(rest_http_method_t method); +bool rest_content_type_is_json(const char *content_type); + +const rest_route_t *rest_get_routes(size_t *count); + +#ifdef DEBUG_APDU +extern void debug_dump_payload(const char *tag, const char *buffer, size_t len); +#define REST_DEBUG_LOG(...) printf(__VA_ARGS__) +#else +#define debug_dump_payload(tag, buffer, len) do { (void)(tag); (void)(buffer); (void)(len); } while (0) +#define REST_DEBUG_LOG(...) do {} while (0) +#endif + +#endif diff --git a/src/usb/lwip/rest_server.c b/src/usb/lwip/rest_server.c index 529dda0..89da40d 100644 --- a/src/usb/lwip/rest_server.c +++ b/src/usb/lwip/rest_server.c @@ -1,4 +1,23 @@ +/* + * This file is part of the Pico Keys SDK distribution (https://github.com/polhenarejos/pico-keys-sdk). + * Copyright (c) 2022 Pol Henarejos. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, version 3. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#define MBEDTLS_ALLOW_PRIVATE_ACCESS #include "rest_server.h" +#include "rest_server_tls.h" #include "pico_keys.h" #include "usb.h" @@ -24,57 +43,6 @@ #include "lwip/def.h" #endif -#define REST_PORT 80 -#define REST_MAX_CONNS 4 -#define REST_MAX_REQUEST_SIZE 1024 -#define REST_MAX_METHOD_SIZE 8 -#define REST_MAX_CONTENT_TYPE_SIZE 64 - -#ifdef DEBUG_APDU -static void debug_dump_payload(const char *tag, const char *buffer, size_t len) { - size_t i; - if (buffer == NULL) { - printf("[rest-debug] %s: \n", tag); - return; - } - - printf("[rest-debug] %s (%lu bytes): \"", tag, (unsigned long)len); - for (i = 0; i < len; i++) { - unsigned char c = (unsigned char)buffer[i]; - if (c == '\r') { - printf("\\r"); - } else if (c == '\n') { - printf("\\n"); - } else if (c == '\t') { - printf("\\t"); - } else if (c >= 32 && c <= 126) { - putchar((int)c); - } else { - printf("\\x%02X", c); - } - } - printf("\"\n"); - if (tag[2] == 's') { - printf("\n"); - } -} -#define REST_DEBUG_LOG(...) printf(__VA_ARGS__) -#else -#define debug_dump_payload(tag, buffer, len) do { (void)(tag); (void)(buffer); (void)(len); } while (0) -#define REST_DEBUG_LOG(...) do {} while (0) -#endif - -typedef struct { - bool in_use; -#ifdef ENABLE_EMULATION - int sock; -#else - struct tcp_pcb *pcb; -#endif - char request[REST_MAX_REQUEST_SIZE + 1]; - size_t request_len; -} rest_conn_t; - #ifndef ENABLE_EMULATION static struct tcp_pcb *listener_pcb = NULL; #else @@ -89,7 +57,7 @@ typedef struct { bool pending; rest_conn_t *conn; rest_route_handler_t handler; - const rest_request_t *request; + rest_request_t request; } rest_core1_job_t; typedef struct { @@ -103,59 +71,18 @@ static rest_core1_result_t rest_core1_result = {0}; static void *rest_core1_thread(void *arg); static void send_response(rest_conn_t *conn, int status_code, const char *status_text, const char *content_type, const char *body, size_t body_len); static void send_json(rest_conn_t *conn, int status_code, const char *status_text, const char *json_body); -static const char *status_text_from_code(uint16_t code); - -static int execute_route_handler(const rest_request_t *request, rest_route_handler_t handler, rest_response_t *response) { - if (request == NULL || handler == NULL || response == NULL) { - return -1; - } - - memset(response, 0, sizeof(*response)); - response->status_code = 200; - response->content_type = "application/json"; - response->body = "{\"ok\":true}"; - response->json = cJSON_CreateObject(); - if (response->json == NULL) { - return -1; - } - - if (handler(request, response) != 0) { - cJSON_Delete(response->json); - response->json = NULL; - return -1; - } - if (response->content_type == NULL || response->body == NULL) { - cJSON_Delete(response->json); - response->json = NULL; - return -1; - } - if (response->status_code == 0 || response->status_code == 200) { - char *body = cJSON_PrintUnformatted(response->json); - cJSON_Delete(response->json); - response->json = NULL; - if (body == NULL) { - return -1; - } - response->body = body; - } - - response->status_code = (response->status_code == 0) ? 200 : response->status_code; - response->body_len = (response->body_len == 0) ? strlen(response->body) : response->body_len; - return 0; -} static int rest_start_core1_job(rest_conn_t *conn, const rest_request_t *request, rest_route_handler_t handler) { - if (conn == NULL || request == NULL || handler == NULL || rest_core1_job.pending) { + if (request == NULL || handler == NULL || rest_core1_job.pending) { return -1; } - memset(&rest_core1_job, 0, sizeof(rest_core1_job)); memset(&rest_core1_result, 0, sizeof(rest_core1_result)); rest_core1_job.pending = true; rest_core1_job.conn = conn; rest_core1_job.handler = handler; - rest_core1_job.request = request; + rest_core1_job.request = *request; card_start(ITF_LWIP, rest_core1_thread); usb_send_event(EV_CMD_AVAILABLE); @@ -179,7 +106,7 @@ static void *rest_core1_thread(void *arg) { if (m == EV_CMD_AVAILABLE) { rest_core1_result.ready = false; memset(&rest_core1_result.response, 0, sizeof(rest_core1_result.response)); - if (!rest_core1_job.pending || rest_core1_job.handler == NULL || execute_route_handler(rest_core1_job.request, rest_core1_job.handler, &rest_core1_result.response) != 0) { + if (!rest_core1_job.pending || rest_core1_job.handler == NULL || execute_route_handler(&rest_core1_job.request, rest_core1_job.handler, &rest_core1_result.response) != 0) { rest_server_error(&rest_core1_result.response, 500, "internal_error"); } rest_core1_result.ready = true; @@ -205,30 +132,30 @@ void rest_task(void) { } conn = rest_core1_job.conn; + if (conn == NULL) { + return; + } + rest_response_t *response = &rest_core1_result.response; + if (response == NULL) { + return; + } if (conn != NULL) { - if (rest_core1_result.ready && rest_core1_result.response.body != NULL && rest_core1_result.response.content_type != NULL) { - uint16_t code = rest_core1_result.response.status_code == 0 ? 200 : rest_core1_result.response.status_code; - send_response(conn, code, status_text_from_code(code), rest_core1_result.response.content_type, - rest_core1_result.response.body, rest_core1_result.response.body_len); - } else { + if (rest_core1_result.ready && response->body != NULL && response->content_type != NULL) { + uint16_t code = response->status_code == 0 ? 200 : response->status_code; + send_response(conn, code, rest_status_text_from_code(code), response->content_type, response->body, response->body_len); + } + else { send_json(conn, 500, "Internal Server Error", "{\"error\":\"internal_error\"}"); } } - if (rest_core1_result.response.body != NULL) { - free(rest_core1_result.response.body); + if (response->body != NULL) { + free(response->body); } memset(&rest_core1_result, 0, sizeof(rest_core1_result)); memset(&rest_core1_job, 0, sizeof(rest_core1_job)); } -__attribute__((weak)) const rest_route_t *rest_get_routes(size_t *count) { - if (count != NULL) { - *count = 0; - } - return NULL; -} - static rest_conn_t *alloc_conn( #ifdef ENABLE_EMULATION int sock @@ -237,15 +164,23 @@ static rest_conn_t *alloc_conn( #endif ) { size_t i; + uint16_t local_port = REST_PORT; for (i = 0; i < REST_MAX_CONNS; i++) { if (!conns[i].in_use) { memset(&conns[i], 0, sizeof(conns[i])); conns[i].in_use = true; #ifdef ENABLE_EMULATION conns[i].sock = sock; + struct sockaddr_in addr; + socklen_t len = sizeof(addr); + if (getsockname(sock, (struct sockaddr *)&addr, &len) == 0) { + local_port = ntohs(addr.sin_port); + } #else conns[i].pcb = pcb; + local_port = pcb->local_port; #endif + conns[i].conn_type = (local_port == REST_TLS_PORT) ? REST_CONN_TLS : REST_CONN_PLAIN; return &conns[i]; } } @@ -262,11 +197,14 @@ static void clear_conn(rest_conn_t *conn) { #endif } -static void close_conn(rest_conn_t *conn) { +void close_conn(rest_conn_t *conn) { #ifdef ENABLE_EMULATION if (conn == NULL) { return; } + if (conn->conn_type == REST_CONN_TLS) { + mbedtls_ssl_free(&conn->ssl); + } if (conn->sock >= 0) { #ifndef _MSC_VER (void)close(conn->sock); @@ -288,6 +226,9 @@ static void close_conn(rest_conn_t *conn) { if (err != ERR_OK) { tcp_abort(conn->pcb); } + if (conn->conn_type == REST_CONN_TLS) { + mbedtls_ssl_free(&conn->ssl); + } clear_conn(conn); #endif } @@ -313,7 +254,8 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status } REST_DEBUG_LOG( - "[rest-debug] response: status=%d content_type=%s body_len=%lu\n", + "[rest %s] response: status=%d content_type=%s body_len=%lu\n", + (conn->conn_type == REST_CONN_TLS) ? "TLS" : "PLAIN", status_code, (content_type != NULL) ? content_type : "(null)", (unsigned long)body_len @@ -331,6 +273,51 @@ static void send_response(rest_conn_t *conn, int status_code, const char *status close_conn(conn); return; } + + if (conn->conn_type == REST_CONN_TLS) { + size_t written = 0; + int ret; + int want_retries = 0; + + while (written < (size_t)header_len) { + ret = mbedtls_ssl_write(&conn->ssl, (const unsigned char *)headers + written, (size_t)header_len - written); + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + if (++want_retries > 2048) { + close_conn(conn); + return; + } + continue; + } + if (ret <= 0) { + close_conn(conn); + return; + } + want_retries = 0; + written += (size_t)ret; + } + + written = 0; + while (written < body_len) { + ret = mbedtls_ssl_write(&conn->ssl, (const unsigned char *)body + written, body_len - written); + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + if (++want_retries > 2048) { + close_conn(conn); + return; + } + continue; + } + if (ret <= 0) { + close_conn(conn); + return; + } + want_retries = 0; + written += (size_t)ret; + } + + (void)mbedtls_ssl_close_notify(&conn->ssl); + close_conn(conn); + return; + } #ifdef ENABLE_EMULATION while (sent_total < (size_t)header_len) { #ifndef _MSC_VER @@ -385,44 +372,6 @@ static void send_json(rest_conn_t *conn, int status_code, const char *status_tex send_response(conn, status_code, status_text, "application/json", json_body, strlen(json_body)); } -static const char *status_text_from_code(uint16_t code) { - switch (code) { - case 200: - return "OK"; - case 400: - return "Bad Request"; - case 404: - return "Not Found"; - case 405: - return "Method Not Allowed"; - case 413: - return "Payload Too Large"; - case 415: - return "Unsupported Media Type"; - case 500: - return "Internal Server Error"; - case 503: - return "Service Unavailable"; - default: - return "OK"; - } -} - -static const char *method_str_from_request_method(rest_http_method_t method) { - switch (method) { - case REST_HTTP_GET: - return "GET"; - case REST_HTTP_POST: - return "POST"; - case REST_HTTP_PUT: - return "PUT"; - case REST_HTTP_DELETE: - return "DELETE"; - default: - return "UNKNOWN"; - } -} - static int parse_request(rest_conn_t *conn, rest_request_t *request) { char *header_end, *line_end, *cursor; size_t headers_size; @@ -508,15 +457,8 @@ static int parse_request(rest_conn_t *conn, rest_request_t *request) { return 1; } -static int is_json(const char *content_type) { - if (content_type == NULL) { - return 0; - } - return strncasecmp(content_type, "application/json", 16) == 0; -} - -static void handle_request(rest_conn_t *conn) { - rest_request_t request = {0}; +void handle_request(rest_conn_t *conn) { + rest_request_t *request = &rest_core1_job.request; const rest_route_t *routes; size_t route_count = 0, i; bool path_exists_for_other_method = false; @@ -527,7 +469,8 @@ static void handle_request(rest_conn_t *conn) { return; } - parsed = parse_request(conn, &request); + memset(&rest_core1_job, 0, sizeof(rest_core1_job)); + parsed = parse_request(conn, request); if (parsed <= 0) { if (parsed < 0) { send_json(conn, 400, "Bad Request", "{\"error\":\"bad_request\"}"); @@ -536,14 +479,15 @@ static void handle_request(rest_conn_t *conn) { } REST_DEBUG_LOG( - "[rest-debug] request: %s %s\n", - method_str_from_request_method(request.method), - request.path + "[rest %s] request: %s %s\n", + (conn->conn_type == REST_CONN_TLS) ? "TLS" : "PLAIN", + rest_method_to_string(request->method), + request->path ); - debug_dump_payload("request-body", request.body, request.body_len); + debug_dump_payload("request-body", request->body, request->body_len); - if (request.method == REST_HTTP_POST || request.method == REST_HTTP_PUT) { - if (!is_json(request.content_type)) { + if (request->method == REST_HTTP_POST || request->method == REST_HTTP_PUT) { + if (!rest_content_type_is_json(request->content_type)) { send_json(conn, 415, "Unsupported Media Type", "{\"error\":\"content_type_must_be_application_json\"}"); return; } @@ -553,14 +497,14 @@ static void handle_request(rest_conn_t *conn) { if (routes[i].path == NULL || routes[i].handler == NULL) { continue; } - if (strcmp(routes[i].path, request.path) != 0) { + if (strcmp(routes[i].path, request->path) != 0) { continue; } - if (routes[i].method != request.method) { + if (routes[i].method != request->method) { path_exists_for_other_method = true; continue; } - if (rest_start_core1_job(conn, &request, routes[i].handler) != 0) { + if (rest_start_core1_job(conn, request, routes[i].handler) != 0) { send_json(conn, 500, "Internal Server Error", "{\"error\":\"internal_error\"}"); } return; @@ -568,14 +512,95 @@ static void handle_request(rest_conn_t *conn) { if (path_exists_for_other_method) { send_json(conn, 405, "Method Not Allowed", "{\"error\":\"method_not_allowed\"}"); - } else { + } + else { send_json(conn, 404, "Not Found", "{\"error\":\"not_found\"}"); } } +void rest_check_and_load_credentials(void) { + file_t *ef = file_new(EF_TLS_KEY); + if (!file_has_data(ef)) { + mbedtls_ecdsa_context ecdsa; + size_t olen = 0; + uint8_t pkey[MBEDTLS_ECP_MAX_BYTES]; + while (olen != 32) { + mbedtls_ecdsa_init(&ecdsa); + mbedtls_ecp_group_id ec_id = MBEDTLS_ECP_DP_SECP256R1; + mbedtls_ecdsa_genkey(&ecdsa, ec_id, random_gen, NULL); + mbedtls_ecp_write_key_ext(&ecdsa, &olen, pkey, sizeof(pkey)); + mbedtls_ecdsa_free(&ecdsa); + } + file_put_data(ef, pkey, olen); + mbedtls_platform_zeroize(pkey, sizeof(pkey)); + printf("TLS key generated and stored, length: %u bytes\n", (unsigned)olen); + } + tls_credentials.tls_key_pem = (char *)file_get_data(ef); + tls_credentials.tls_key_pem_len = file_get_size(ef); + printf("TLS key loaded, length: %u bytes\n", (unsigned)tls_credentials.tls_key_pem_len); + ef = file_new(EF_TLS_CERT); + if (!file_has_data(ef)) { + mbedtls_x509write_cert crt; + mbedtls_pk_context key; + unsigned char cert_pem[2048] = {0}; + int ret = 0; + file_t *ef_key = search_file(EF_TLS_KEY); + const uint8_t *file = file_get_data(ef_key); + size_t file_len = file_get_size(ef_key); + + mbedtls_x509write_crt_init(&crt); + mbedtls_x509write_crt_set_version(&crt, MBEDTLS_X509_CRT_VERSION_3); + mbedtls_pk_init(&key); + ret = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); + if (ret != 0) goto out; + mbedtls_ecp_read_key(MBEDTLS_ECP_DP_SECP256R1, mbedtls_pk_ec(key), file, file_len); + mbedtls_ecp_check_privkey(&mbedtls_pk_ec(key)->grp, &mbedtls_pk_ec(key)->d); + mbedtls_ecp_mul(&mbedtls_pk_ec(key)->grp, &mbedtls_pk_ec(key)->Q, + &mbedtls_pk_ec(key)->d, &mbedtls_pk_ec(key)->grp.G, + random_gen, NULL); + mbedtls_ecp_check_pubkey(&mbedtls_pk_ec(key)->grp, &mbedtls_pk_ec(key)->Q); + + mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256); + mbedtls_x509write_crt_set_subject_key(&crt, &key); + mbedtls_x509write_crt_set_issuer_key(&crt, &key); // self-signed + + mbedtls_x509write_crt_set_subject_key_identifier(&crt); + mbedtls_x509write_crt_set_authority_key_identifier(&crt); + ret = mbedtls_x509write_crt_set_subject_name(&crt, "CN=pico-novus"); + if (ret != 0) goto out; + ret = mbedtls_x509write_crt_set_issuer_name(&crt, "CN=pico-novus"); + if (ret != 0) goto out; + uint8_t serial[16]; + random_gen(NULL, serial, sizeof(serial)); + mbedtls_x509write_crt_set_serial_raw(&crt, serial, sizeof(serial)); + if (ret != 0) goto out; + ret = mbedtls_x509write_crt_set_validity(&crt, "20260101000000", "20360101000000"); + if (ret != 0) goto out; + ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, 0); + if (ret != 0) goto out; + ret = mbedtls_x509write_crt_set_key_usage(&crt, MBEDTLS_X509_KU_DIGITAL_SIGNATURE | MBEDTLS_X509_KU_KEY_CERT_SIGN | MBEDTLS_X509_KU_KEY_ENCIPHERMENT); + if (ret != 0) goto out; + + ret = mbedtls_x509write_crt_pem(&crt, cert_pem, sizeof(cert_pem), random_gen, NULL); + if (ret == 0) { + file_put_data(ef, cert_pem, strlen((char *)cert_pem) + 1); + printf("TLS certificate generated and stored, length: %u bytes\n", (unsigned)strlen((char *)cert_pem)); + } +out: + mbedtls_x509write_crt_free(&crt); + mbedtls_platform_zeroize(cert_pem, sizeof(cert_pem)); + } + tls_credentials.tls_cert_pem = (char *)file_get_data(ef); + tls_credentials.tls_cert_pem_len = file_get_size(ef); + printf("TLS certificate loaded, length: %u bytes\n", (unsigned)tls_credentials.tls_cert_pem_len); + low_flash_available(); +} + #ifndef ENABLE_EMULATION static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err) { rest_conn_t *conn = (rest_conn_t *)arg; + size_t *len = NULL; + unsigned char *buffer = NULL; LWIP_UNUSED_ARG(pcb); if (err != ERR_OK) { @@ -594,16 +619,31 @@ static err_t rest_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err pbuf_free(p); return ERR_OK; } - if (conn->request_len + p->tot_len > REST_MAX_REQUEST_SIZE) { + if (conn->conn_type == REST_CONN_TLS) { + len = &conn->rx_cipher_len; + buffer = conn->rx_cipher; + } + else { + len = &conn->request_len; + buffer = (unsigned char *)conn->request; + } + if (*len + p->tot_len > REST_MAX_REQUEST_SIZE) { tcp_recved(pcb, p->tot_len); pbuf_free(p); + if (conn->conn_type == REST_CONN_TLS) { + close_conn(conn); + return ERR_ABRT; + } send_json(conn, 413, "Payload Too Large", "{\"error\":\"payload_too_large\"}"); return ERR_OK; } - pbuf_copy_partial(p, conn->request + conn->request_len, p->tot_len, 0); - conn->request_len += p->tot_len; + pbuf_copy_partial(p, buffer + *len, p->tot_len, 0); + *len += p->tot_len; tcp_recved(pcb, p->tot_len); pbuf_free(p); + if (conn->conn_type == REST_CONN_TLS) { + return tls_progress_conn(conn); + } handle_request(conn); return ERR_OK; } @@ -614,13 +654,29 @@ static err_t rest_poll(void *arg, struct tcp_pcb *pcb) { if (rest_core1_job.pending && rest_core1_job.conn == conn) { return ERR_OK; } + if (conn->conn_type == REST_CONN_TLS) { + return tls_progress_conn(conn); + } close_conn(conn); return ERR_OK; } +static err_t rest_sent(void *arg, struct tcp_pcb *pcb, u16_t len) { + rest_conn_t *conn = (rest_conn_t *)arg; + LWIP_UNUSED_ARG(pcb); + LWIP_UNUSED_ARG(len); + if (conn->conn_type == REST_CONN_TLS) { + return tls_progress_conn(conn); + } + return ERR_OK; +} + static void rest_err(void *arg, err_t err) { rest_conn_t *conn = (rest_conn_t *)arg; LWIP_UNUSED_ARG(err); + if (conn == NULL) { + return; + } clear_conn(conn); } @@ -638,33 +694,70 @@ static err_t rest_accept(void *arg, struct tcp_pcb *newpcb, err_t err) { tcp_abort(newpcb); return ERR_ABRT; } + + if (conn->conn_type == REST_CONN_TLS) { + mbedtls_ssl_init(&conn->ssl); + if (mbedtls_ssl_setup(&conn->ssl, &tls_conf) != 0) { + close_conn(conn); + return ERR_ABRT; + } + mbedtls_ssl_set_bio(&conn->ssl, conn, tls_send_cb, tls_recv_cb, NULL); + tcp_nagle_disable(newpcb); + } tcp_arg(newpcb, conn); tcp_recv(newpcb, rest_recv); + tcp_sent(newpcb, rest_sent); tcp_poll(newpcb, rest_poll, 8); tcp_err(newpcb, rest_err); return ERR_OK; } -err_t rest_server_init(void) { - err_t err; - if (listener_pcb != NULL) { +static err_t rest_server_init_conn(struct tcp_pcb **listener_pcb, uint16_t port, rest_conn_type_t conn_type, const tls_credentials_t *tls_credentials) { + if (*listener_pcb != NULL) { return ERR_OK; } - listener_pcb = tcp_new_ip_type(IPADDR_TYPE_ANY); - if (listener_pcb == NULL) { + if (conn_type & REST_CONN_TLS) { + if (tls_credentials == NULL || tls_credentials->tls_key_pem == NULL || tls_credentials->tls_cert_pem == NULL) { + return ERR_VAL; + } + if (tls_init_tls_context(tls_credentials) != 0) { + return ERR_VAL; + } + } + + *listener_pcb = tcp_new_ip_type(IPADDR_TYPE_ANY); + if (*listener_pcb == NULL) { return ERR_MEM; } - err = tcp_bind(listener_pcb, IP_ANY_TYPE, REST_PORT); + err_t err = tcp_bind(*listener_pcb, IP_ANY_TYPE, port); if (err != ERR_OK) { - tcp_abort(listener_pcb); - listener_pcb = NULL; + tcp_abort(*listener_pcb); + *listener_pcb = NULL; return err; } - listener_pcb = tcp_listen_with_backlog(listener_pcb, REST_MAX_CONNS); - if (listener_pcb == NULL) { + *listener_pcb = tcp_listen_with_backlog(*listener_pcb, REST_MAX_CONNS); + if (*listener_pcb == NULL) { return ERR_MEM; } - tcp_accept(listener_pcb, rest_accept); + tcp_accept(*listener_pcb, rest_accept); + return ERR_OK; +} + +err_t rest_server_init(rest_conn_type_t conn_type) { + err_t err; + if (conn_type & REST_CONN_PLAIN) { + err = rest_server_init_conn(&listener_pcb, REST_PORT, REST_CONN_PLAIN, NULL); + if (err != ERR_OK) { + return err; + } + } + if (conn_type & REST_CONN_TLS) { + rest_check_and_load_credentials(); + err = rest_server_init_conn(&tls_listener_pcb, REST_TLS_PORT, REST_CONN_TLS, &tls_credentials); + if (err != ERR_OK) { + return err; + } + } return ERR_OK; } #else @@ -689,11 +782,11 @@ static int emulation_rest_port(void) { #ifndef _MSC_VER static void *rest_emulation_thread(void *arg) { struct sockaddr_in peer; - (void)arg; + int listen_fd = (int)(intptr_t)arg; while (true) { socklen_t peer_len = sizeof(peer); - int accepted = accept(listener_sock, (struct sockaddr *)&peer, &peer_len); + int accepted = accept(listen_fd, (struct sockaddr *)&peer, &peer_len); rest_conn_t *conn; if (accepted < 0) { continue; @@ -703,59 +796,83 @@ static void *rest_emulation_thread(void *arg) { (void)close(accepted); continue; } - while (conn->in_use) { - ssize_t n = recv(conn->sock, conn->request + conn->request_len, REST_MAX_REQUEST_SIZE - conn->request_len, 0); - if (n <= 0) { + if (conn->conn_type == REST_CONN_TLS) { + mbedtls_ssl_init(&conn->ssl); + if (mbedtls_ssl_setup(&conn->ssl, &tls_conf) != 0) { close_conn(conn); - break; + continue; } - conn->request_len += (size_t)n; - if (conn->request_len > REST_MAX_REQUEST_SIZE) { - send_json(conn, 413, "Payload Too Large", "{\"error\":\"payload_too_large\"}"); - break; + mbedtls_ssl_set_bio(&conn->ssl, &conn->sock, tls_send_cb, tls_recv_cb, NULL); + } + while (conn->in_use) { + if (conn->conn_type == REST_CONN_TLS) { + /* TLS on emulation reads directly from socket through mbedtls BIO callbacks. */ + if (tls_progress_conn(conn) != ERR_OK) { + close_conn(conn); + break; + } + } + else { + ssize_t n = recv(conn->sock, conn->request + conn->request_len, REST_MAX_REQUEST_SIZE - conn->request_len, 0); + if (n <= 0) { + close_conn(conn); + break; + } + conn->request_len += (size_t)n; + if (conn->request_len > REST_MAX_REQUEST_SIZE) { + send_json(conn, 413, "Payload Too Large", "{\"error\":\"payload_too_large\"}"); + break; + } + handle_request(conn); } - handle_request(conn); } } return NULL; } #endif -err_t rest_server_init(void) { +static err_t rest_server_init_conn(int *listener_sock, int port, rest_conn_type_t conn_type, const tls_credentials_t *tls_credentials) { #ifndef _MSC_VER struct sockaddr_in addr; int one = 1; - int port = emulation_rest_port(); - if (listener_sock >= 0) { + if (*listener_sock >= 0) { return ERR_OK; } - listener_sock = socket(AF_INET, SOCK_STREAM, 0); - if (listener_sock < 0) { + if (conn_type & REST_CONN_TLS) { + if (tls_credentials == NULL || tls_credentials->tls_key_pem == NULL || tls_credentials->tls_cert_pem == NULL) { + return ERR_VAL; + } + if (tls_init_tls_context(tls_credentials) != 0) { + return ERR_VAL; + } + } + *listener_sock = socket(AF_INET, SOCK_STREAM, 0); + if (*listener_sock < 0) { return -1; } - if (setsockopt(listener_sock, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) != 0) { - (void)close(listener_sock); - listener_sock = -1; + if (setsockopt(*listener_sock, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) != 0) { + (void)close(*listener_sock); + *listener_sock = -1; return -1; } memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_port = htons((uint16_t)port); addr.sin_addr.s_addr = htonl(INADDR_ANY); - if (bind(listener_sock, (struct sockaddr *)&addr, sizeof(addr)) != 0) { - (void)close(listener_sock); - listener_sock = -1; + if (bind(*listener_sock, (struct sockaddr *)&addr, sizeof(addr)) != 0) { + (void)close(*listener_sock); + *listener_sock = -1; return -1; } - if (listen(listener_sock, REST_MAX_CONNS) != 0) { - (void)close(listener_sock); - listener_sock = -1; + if (listen(*listener_sock, REST_MAX_CONNS) != 0) { + (void)close(*listener_sock); + *listener_sock = -1; return -1; } - if (pthread_create(&rest_thread, NULL, rest_emulation_thread, NULL) != 0) { - (void)close(listener_sock); - listener_sock = -1; + if (pthread_create(&rest_thread, NULL, rest_emulation_thread, (void *)(intptr_t)(*listener_sock)) != 0) { + (void)close(*listener_sock); + *listener_sock = -1; return -1; } (void)pthread_detach(rest_thread); @@ -765,25 +882,32 @@ err_t rest_server_init(void) { #endif } +err_t rest_server_init(rest_conn_type_t conn_type) { + + if (conn_type & REST_CONN_PLAIN) { + if (rest_server_init_conn(&listener_sock, emulation_rest_port(), REST_CONN_PLAIN, NULL) != ERR_OK) { + return -1; + } + } + if (conn_type & REST_CONN_TLS) { + rest_check_and_load_credentials(); + if (rest_server_init_conn(&tls_listener_sock, emulation_rest_tls_port(), REST_CONN_TLS, &tls_credentials) != ERR_OK) { + return -1; + } + } + return ERR_OK; + +} + int lwip_itf_init(void) { - return rest_server_init(); + err_t err = rest_server_init(REST_CONN_ALL); + if (err != ERR_OK) { + return err; + } + return ERR_OK; } #endif int rest_server_error(rest_response_t *response, int status_code, const char *message) { - if (response == NULL) { - return -1; - } - char json_template[256]; - int json_len = snprintf(json_template, sizeof(json_template), "{\"error\":\"%s\"}", message); - if (json_len <= 0 || (size_t)json_len >= sizeof(json_template)) { - return -1; - } - response->status_code = status_code; - response->body = strdup(json_template); - if (response->body == NULL) { - return -1; - } - response->body_len = (size_t)json_len; - return 0; + return rest_response_set_error(response, status_code, message); } diff --git a/src/usb/lwip/rest_server.h b/src/usb/lwip/rest_server.h index 074c37b..e8b9780 100644 --- a/src/usb/lwip/rest_server.h +++ b/src/usb/lwip/rest_server.h @@ -22,49 +22,53 @@ #ifdef ENABLE_EMULATION typedef int err_t; #define ERR_OK 0 +#define ERR_VAL -6 +#define ERR_ABRT -13 #else #include "lwip/err.h" #endif #include #include #include "cJSON.h" +#include +#include "mbedtls/ssl.h" +#include "rest.h" +#define REST_PORT 80 +#define REST_MAX_CONNS 4 +#define REST_MAX_REQUEST_SIZE 1024 +#define REST_MAX_METHOD_SIZE 8 +#define REST_MAX_CONTENT_TYPE_SIZE 64 #define REST_MAX_PATH_SIZE 192 +#define EF_TLS_KEY 0xD500 +#define EF_TLS_CERT 0xD501 + typedef enum { - REST_HTTP_GET = 0, - REST_HTTP_POST, - REST_HTTP_PUT, - REST_HTTP_DELETE -} rest_http_method_t; + REST_CONN_PLAIN = 0x1, + REST_CONN_TLS = 0x2, + REST_CONN_ALL = REST_CONN_PLAIN | REST_CONN_TLS +} rest_conn_type_t; typedef struct { - rest_http_method_t method; - char path[REST_MAX_PATH_SIZE]; - const char *body; - size_t body_len; - const char *content_type; -} rest_request_t; + bool in_use; +#ifdef ENABLE_EMULATION + int sock; +#else + struct tcp_pcb *pcb; +#endif + char request[REST_MAX_REQUEST_SIZE + 1]; + size_t request_len; + rest_conn_type_t conn_type; + mbedtls_ssl_context ssl; + unsigned char rx_cipher[REST_MAX_REQUEST_SIZE]; + size_t rx_cipher_len; + bool handshake_done; + bool request_complete; + bool request_dispatched; +} rest_conn_t; -typedef struct { - uint16_t status_code; - const char *content_type; - char *body; // heap ! - size_t body_len; - cJSON *json; -} rest_response_t; - -typedef int (*rest_route_handler_t)(const rest_request_t *request, rest_response_t *response); - -typedef struct { - rest_http_method_t method; - const char *path; - rest_route_handler_t handler; -} rest_route_t; - -const rest_route_t *rest_get_routes(size_t *count); - -err_t rest_server_init(void); +err_t rest_server_init(rest_conn_type_t conn_type); int lwip_itf_init(void); extern int rest_server_error(rest_response_t *response, int status_code, const char *message); diff --git a/src/usb/lwip/rest_server_tls.c b/src/usb/lwip/rest_server_tls.c new file mode 100644 index 0000000..df774f3 --- /dev/null +++ b/src/usb/lwip/rest_server_tls.c @@ -0,0 +1,320 @@ +/* + * This file is part of the Pico Keys SDK distribution (https://github.com/polhenarejos/pico-keys-sdk). + * Copyright (c) 2022 Pol Henarejos. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, version 3. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#define MBEDTLS_ALLOW_PRIVATE_ACCESS +#include "rest_server_tls.h" + +#include +#include +#include +#include +#include + +extern void close_conn(rest_conn_t *conn); +extern void handle_request(rest_conn_t *conn); + +static const int tls_ciphersuites[] = { + MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + 0 +}; + +static bool tls_ctx_ready = false; +mbedtls_ssl_config tls_conf; +mbedtls_x509_crt tls_cert; +mbedtls_pk_context tls_key; +tls_credentials_t tls_credentials = {0}; + +int tls_init_tls_context(const tls_credentials_t *tls_credentials) { + int ret; + if (tls_ctx_ready) { + return 0; + } + if (tls_credentials == NULL || tls_credentials->tls_key_pem == NULL || tls_credentials->tls_cert_pem == NULL) { + return -1; + } + + mbedtls_ssl_config_init(&tls_conf); + mbedtls_x509_crt_init(&tls_cert); + mbedtls_pk_init(&tls_key); + mbedtls_pk_setup(&tls_key, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); + + ret = mbedtls_x509_crt_parse(&tls_cert, (const unsigned char *)tls_credentials->tls_cert_pem, tls_credentials->tls_cert_pem_len); + if (ret != 0) { + return ret; + } + ret = mbedtls_ecp_read_key(MBEDTLS_ECP_DP_SECP256R1, mbedtls_pk_ec(tls_key), (const unsigned char *)tls_credentials->tls_key_pem, tls_credentials->tls_key_pem_len); + if (ret != 0) { + return ret; + } + mbedtls_ecp_check_privkey(&mbedtls_pk_ec(tls_key)->grp, &mbedtls_pk_ec(tls_key)->d); + mbedtls_ecp_mul(&mbedtls_pk_ec(tls_key)->grp, &mbedtls_pk_ec(tls_key)->Q, + &mbedtls_pk_ec(tls_key)->d, &mbedtls_pk_ec(tls_key)->grp.G, + random_gen, NULL); + mbedtls_ecp_check_pubkey(&mbedtls_pk_ec(tls_key)->grp, &mbedtls_pk_ec(tls_key)->Q); + ret = mbedtls_ssl_config_defaults(&tls_conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + return ret; + } + + mbedtls_ssl_conf_rng(&tls_conf, random_gen, NULL); + mbedtls_ssl_conf_min_tls_version(&tls_conf, MBEDTLS_SSL_VERSION_TLS1_2); + mbedtls_ssl_conf_max_tls_version(&tls_conf, MBEDTLS_SSL_VERSION_TLS1_2); + mbedtls_ssl_conf_ciphersuites(&tls_conf, tls_ciphersuites); + ret = mbedtls_ssl_conf_own_cert(&tls_conf, &tls_cert, &tls_key); + if (ret != 0) { + return ret; + } + + tls_ctx_ready = true; + return 0; +} + +static const char *find_header_end(const char *request, size_t request_len) { + if (request == NULL || request_len < 4) { + return NULL; + } + return strstr(request, "\r\n\r\n"); +} + +static int parse_content_length(const char *request, size_t request_len, size_t *content_len) { + const char *line = request; + const char *header_end = find_header_end(request, request_len); + if (header_end == NULL) { + return 0; + } + + *content_len = 0; + while (line < header_end) { + const char *next = strstr(line, "\r\n"); + const char *colon; + if (next == NULL || next > header_end) { + break; + } + if (next == line) { + break; + } + colon = memchr(line, ':', (size_t)(next - line)); + if (colon != NULL) { + size_t name_len = (size_t)(colon - line); + const char *value = colon + 1; + while (value < next && (*value == ' ' || *value == '\t')) { + value++; + } + if (name_len == 14 && strncasecmp(line, "Content-Length", 14) == 0) { + char *endptr = NULL; + unsigned long v = strtoul(value, &endptr, 10); + if (endptr == value || endptr > next) { + return -1; + } + *content_len = (size_t)v; + return 1; + } + } + line = next + 2; + } + return 1; +} + +static int request_is_complete(const char *request, size_t request_len, size_t *payload_offset, size_t *payload_len) { + const char *header_end = find_header_end(request, request_len); + size_t content_len = 0; + int rc; + if (header_end == NULL) { + return 0; + } + *payload_offset = (size_t)((header_end + 4) - request); + rc = parse_content_length(request, request_len, &content_len); + if (rc < 0) { + return -1; + } + *payload_len = content_len; + if (request_len < *payload_offset + content_len) { + return 0; + } + return 1; +} + +#ifdef ENABLE_EMULATION + +#ifndef _MSC_VER +#include +#include +#include +#include +#include +#include + +int tls_listener_sock = -1; + +int emulation_rest_tls_port(void) { + const char *port_env = getenv("PICO_REST_TLS_PORT"); + long v; + if (port_env == NULL || *port_env == '\0') { + return REST_TLS_PORT; + } + errno = 0; + v = strtol(port_env, NULL, 10); + if (errno != 0 || v < 1 || v > 65535) { + return REST_TLS_PORT; + } + return (int)v; +} + +int tls_send_cb(void *ctx, const unsigned char *buf, size_t len) { + const int fd = *(const int *)ctx; + ssize_t r = send(fd, buf, len, 0); + if (r >= 0) { + return (int)r; + } + if (errno == EWOULDBLOCK || errno == EAGAIN || errno == EINTR) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; +} + +int tls_recv_cb(void *ctx, unsigned char *buf, size_t len) { + const int fd = *(const int *)ctx; + ssize_t r = recv(fd, buf, len, 0); + if (r > 0) { + return (int)r; + } + if (r == 0) { + return MBEDTLS_ERR_SSL_CONN_EOF; + } + if (errno == EWOULDBLOCK || errno == EAGAIN || errno == EINTR) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; +} +#endif + +#else + +#include "lwip/def.h" +#include "lwip/tcp.h" + +struct tcp_pcb *tls_listener_pcb = NULL; + +int tls_send_cb(void *ctx, const unsigned char *buf, size_t len) { + rest_conn_t *conn = (rest_conn_t *)ctx; + size_t chunk; + u16_t snd_avail; + err_t err; + if (conn == NULL || conn->pcb == NULL) { + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } + if (len == 0) { + return 0; + } + snd_avail = tcp_sndbuf(conn->pcb); + if (snd_avail == 0) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + chunk = (len > snd_avail) ? snd_avail : len; + if (chunk > TCP_MSS) { + chunk = TCP_MSS; + } + err = tcp_write(conn->pcb, buf, (uint16_t)chunk, TCP_WRITE_FLAG_COPY); + if (err == ERR_MEM) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + if (err != ERR_OK) { + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } + (void)tcp_output(conn->pcb); + return (int)chunk; +} + +int tls_recv_cb(void *ctx, unsigned char *buf, size_t len) { + rest_conn_t *conn = (rest_conn_t *)ctx; + size_t n; + if (conn == NULL || buf == NULL || len == 0) { + return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; + } + if (conn->rx_cipher_len == 0) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + n = (len < conn->rx_cipher_len) ? len : conn->rx_cipher_len; + memcpy(buf, conn->rx_cipher, n); + conn->rx_cipher_len -= n; + if (conn->rx_cipher_len > 0) { + memmove(conn->rx_cipher, conn->rx_cipher + n, conn->rx_cipher_len); + } + return (int)n; +} + +#endif + +err_t tls_progress_conn(rest_conn_t *conn) { + int ret; + if (conn == NULL +#ifdef ENABLE_EMULATION + || conn->sock < 0 +#else + || conn->pcb == NULL +#endif + ) { + return ERR_OK; + } + + if (!conn->handshake_done) { + while ((ret = mbedtls_ssl_handshake(&conn->ssl)) != 0) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return ERR_OK; + } + close_conn(conn); + return ERR_ABRT; + } + conn->handshake_done = true; + } + + while (!conn->request_complete) { + size_t payload_offset, payload_len; + ret = mbedtls_ssl_read(&conn->ssl, (unsigned char *)conn->request + conn->request_len, REST_MAX_REQUEST_SIZE - conn->request_len); + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return ERR_OK; + } + if (ret <= 0) { + close_conn(conn); + return ERR_ABRT; + } + conn->request_len += (size_t)ret; + conn->request[conn->request_len] = '\0'; + ret = request_is_complete(conn->request, conn->request_len, &payload_offset, &payload_len); + if (ret < 0) { + close_conn(conn); + return ERR_ABRT; + } + if (ret == 0) { + if (conn->request_len >= REST_MAX_REQUEST_SIZE) { + close_conn(conn); + return ERR_ABRT; + } + continue; + } + conn->request_complete = true; + } + + if (!conn->request_dispatched) { + conn->request_dispatched = true; + handle_request(conn); + } + return ERR_OK; +} diff --git a/src/usb/lwip/rest_server_tls.h b/src/usb/lwip/rest_server_tls.h new file mode 100644 index 0000000..19b125b --- /dev/null +++ b/src/usb/lwip/rest_server_tls.h @@ -0,0 +1,61 @@ +/* + * This file is part of the Pico Keys SDK distribution (https://github.com/polhenarejos/pico-keys-sdk). + * Copyright (c) 2022 Pol Henarejos. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, version 3. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef PICO_KEYS_REST_SERVER_TLS_H +#define PICO_KEYS_REST_SERVER_TLS_H + +#ifdef ENABLE_EMULATION +typedef int err_t; +#define ERR_OK 0 +#else +#include "lwip/err.h" +#endif + +#include "mbedtls/pk.h" +#include "mbedtls/ssl.h" +#include "mbedtls/x509_crt.h" +#include "random.h" +#include "rest_server.h" + +#define REST_TLS_PORT 443 + +typedef struct { + char *tls_key_pem; + size_t tls_key_pem_len; + char *tls_cert_pem; + size_t tls_cert_pem_len; +} tls_credentials_t; + +extern tls_credentials_t tls_credentials; + +extern mbedtls_ssl_config tls_conf; +extern mbedtls_x509_crt tls_cert; +extern mbedtls_pk_context tls_key; +extern int tls_send_cb(void *ctx, const unsigned char *buf, size_t len); +extern int tls_recv_cb(void *ctx, unsigned char *buf, size_t len); +extern err_t tls_progress_conn(rest_conn_t *conn); +extern int tls_init_tls_context(const tls_credentials_t *tls_credentials); + +#ifdef ENABLE_EMULATION +extern int emulation_rest_tls_port(void); +extern int tls_listener_sock; +extern void tls_handle_client(int client_fd); +#else +extern struct tcp_pcb *tls_listener_pcb; +#endif + +#endif