diff --git a/src/usb/lwip/rest.h b/src/usb/lwip/rest.h index 96221d6..30040e0 100644 --- a/src/usb/lwip/rest.h +++ b/src/usb/lwip/rest.h @@ -84,6 +84,7 @@ typedef enum { REST_ROUTE_NONE = 0x0, REST_ROUTE_REQUIRE_AUTH = 0x1, REST_ROUTE_REQUIRE_TLS = 0x2, + REST_ROUTE_ALLOW_QUERY = 0x4 } rest_route_flags_t; typedef int (*rest_route_param_parser_t)(const char *str, const char *param_str, rest_param_t params_out[REST_MAX_REQUEST_PARAMS]); @@ -122,6 +123,11 @@ typedef struct { uint8_t user_id; } rest_session_t; +typedef struct { + const char *key; + const char *value; +} rest_query_t; + typedef struct { rest_http_method_t method; char path[REST_MAX_PATH_SIZE]; @@ -130,6 +136,8 @@ typedef struct { const char *content_type; char *headers[REST_HEADER_TOTAL_COUNT]; rest_param_t params[REST_MAX_REQUEST_PARAMS]; + rest_query_t *query; + uint8_t query_count; rest_session_t *session; rest_request_conn_type_t conn_type; } rest_request_t; diff --git a/src/usb/lwip/rest_server.c b/src/usb/lwip/rest_server.c index ff970cb..9b014f3 100644 --- a/src/usb/lwip/rest_server.c +++ b/src/usb/lwip/rest_server.c @@ -186,6 +186,7 @@ void rest_task(void) { // return; //} rest_response_t *response = &rest_core1_result.response; + rest_request_t *request = &rest_core1_job.request; if (response == NULL) { return; } @@ -201,6 +202,11 @@ void rest_task(void) { if (response->body != NULL) { free(response->body); + response->body = NULL; + } + if (request->query != NULL) { + free(request->query); + request->query = NULL; } memset(&rest_core1_result, 0, sizeof(rest_core1_result)); memset(&rest_core1_job, 0, sizeof(rest_core1_job)); @@ -725,6 +731,61 @@ static int rest_verify_request_signature(const rest_request_t *request, const re return PICOKEYS_OK; } +static int parse_query(rest_request_t *request, char *query_str) { + uint8_t count = 1; + for (const char *p = query_str; (p = strchr(p, '&')) != NULL && count < UINT8_MAX; p++, count++); + if (count >= UINT8_MAX) { + return PICOKEYS_EXEC_ERROR; + } + request->query = (rest_query_t *)calloc(count, sizeof(rest_query_t)); + if (request->query == NULL) { + return PICOKEYS_ERR_MEMORY_FATAL; + } + request->query_count = count; + count = 0; + for (char *p = query_str; *p != '\0'; p++) { + if (*p == '&') { + *p++ = '\0'; + if (*p == '\0' || *p == '&' || *p == '=') { + free(request->query); + request->query = NULL; + return PICOKEYS_ERR_NULL_PARAM; + } + if (request->query[count].key == NULL) { + request->query[count++].key = query_str; + } + else { + count++; + } + query_str = p; + } + else if (*p == '=') { + if (p == query_str || request->query[count].key != NULL) { + free(request->query); + request->query = NULL; + return PICOKEYS_ERR_NULL_PARAM; + } + *p++ = '\0'; + if (*p == '\0' || *p == '&' || *p == '=') { + free(request->query); + request->query = NULL; + return PICOKEYS_ERR_NULL_PARAM; + } + request->query[count].key = query_str; + request->query[count].value = p; + } + } + if (count < request->query_count) { + request->query[count++].key = query_str; + } + if (count != request->query_count) { + free(request->query); + request->query = NULL; + return PICOKEYS_ERR_MEMORY_FATAL; + } + return PICOKEYS_OK; +} + void rest_handle_request(rest_conn_t *conn) { rest_request_t *request = &rest_core1_job.request; const rest_route_t *routes; @@ -732,6 +793,9 @@ void rest_handle_request(rest_conn_t *conn) { bool path_exists_for_other_method = false; int parsed; + if (request && request->query) { + free(request->query); + } memset(request, 0, sizeof(*request)); parsed = parse_request(conn, request); if (parsed <= 0) { @@ -761,6 +825,14 @@ void rest_handle_request(rest_conn_t *conn) { } } routes = rest_get_routes(&route_count); + char *query_str = strchr(request->path, '?'); + if (query_str != NULL) { + *query_str++ = '\0'; + if (*query_str == '\0') { + send_json_error(conn, 400, "bad_request"); + return; + } + } for (i = 0; i < route_count; i++) { if (routes[i].path == NULL || routes[i].handler == NULL) { continue; @@ -778,6 +850,10 @@ void rest_handle_request(rest_conn_t *conn) { path_exists_for_other_method = true; continue; } + if (query_str && !(routes[i].flags & REST_ROUTE_ALLOW_QUERY)) { + send_json_error(conn, 400, "query_not_allowed"); + return; + } if (routes[i].flags & REST_ROUTE_REQUIRE_AUTH) { if (!request->headers[REST_HEADER_X_SESSION_ID] || strlen(request->headers[REST_HEADER_X_SESSION_ID]) == 0 ||!request->headers[REST_HEADER_X_SIGNATURE] || strlen(request->headers[REST_HEADER_X_SIGNATURE]) == 0 || !request->headers[REST_HEADER_X_SEQ] || strlen(request->headers[REST_HEADER_X_SEQ]) == 0) { send_json_error(conn, 401, "authentication_required"); @@ -835,8 +911,17 @@ void rest_handle_request(rest_conn_t *conn) { } return; } + + if (query_str != NULL && parse_query(request, query_str) != PICOKEYS_OK) { + send_json_error(conn, 400, "bad_request"); + return; + } if (rest_start_core1_job(conn, routes[i].handler) != 0) { send_json_error(conn, 500, "internal_error"); + if (request->query != NULL) { + free(request->query); + request->query = NULL; + } } return; }