Fix rare race condition with hwrng.

Signed-off-by: Pol Henarejos <pol.henarejos@cttc.es>
This commit is contained in:
Pol Henarejos
2026-04-27 20:32:28 +02:00
parent dcf747a766
commit 6069c3dc2e
2 changed files with 85 additions and 12 deletions

View File

@@ -21,9 +21,13 @@
#if defined(PICO_PLATFORM)
#include "pico/rand.h"
#include "pico/mutex.h"
#elif defined(ESP_PLATFORM)
#include "bootloader_random.h"
#include "esp_random.h"
#include "compat/esp_compat.h"
#else
#include "compat/queue.h"
#endif
static void hwrng_start(void) {
@@ -75,6 +79,21 @@ struct hwrng_buf {
unsigned int empty : 1;
};
static mutex_t hwrng_mutex;
static bool hwrng_mutex_initialized = false;
static inline void hwrng_lock(void) {
if (hwrng_mutex_initialized) {
mutex_enter_blocking(&hwrng_mutex);
}
}
static inline void hwrng_unlock(void) {
if (hwrng_mutex_initialized) {
mutex_exit(&hwrng_mutex);
}
}
static void hwrng_buf_init(struct hwrng_buf *rb, uint32_t *p, uint8_t size) {
rb->buf = p;
rb->size = size;
@@ -84,22 +103,34 @@ static void hwrng_buf_init(struct hwrng_buf *rb, uint32_t *p, uint8_t size) {
}
static void hwrng_buf_add(struct hwrng_buf *rb, uint32_t v) {
rb->buf[rb->tail++] = v;
if (rb->tail == rb->size) {
rb->tail = 0;
if (rb->full) {
return;
}
if (rb->tail == rb->head) {
rb->full = 1;
uint8_t tail = rb->tail;
rb->buf[tail] = v;
tail++;
if (tail >= rb->size) {
tail = 0;
}
rb->tail = tail;
rb->full = (rb->tail == rb->head);
rb->empty = 0;
}
static uint32_t hwrng_buf_del(struct hwrng_buf *rb) {
uint32_t v = rb->buf[rb->head++];
if (rb->head == rb->size) {
rb->head = 0;
uint32_t v = 0;
if (rb->empty) {
return v;
}
uint8_t head = rb->head;
v = rb->buf[head];
head++;
if (head >= rb->size) {
head = 0;
}
rb->head = head;
if (rb->head == rb->tail) {
rb->empty = 1;
}
@@ -115,6 +146,7 @@ void *hwrng_task(void) {
int n;
hwrng_lock();
if ((n = hwrng_mix_process())) {
const uint32_t *vp = (const uint32_t *) &random_word;
@@ -125,12 +157,15 @@ void *hwrng_task(void) {
}
}
}
hwrng_unlock();
return NULL;
}
void hwrng_init(uint32_t *buf, uint8_t size) {
struct hwrng_buf *rb = &ring_buffer;
mutex_init(&hwrng_mutex);
hwrng_mutex_initialized = true;
hwrng_buf_init(rb, buf, size);
hwrng_start();
@@ -140,19 +175,28 @@ void hwrng_init(uint32_t *buf, uint8_t size) {
void hwrng_flush(void) {
struct hwrng_buf *rb = &ring_buffer;
hwrng_lock();
while (!rb->empty) {
hwrng_buf_del(rb);
}
hwrng_unlock();
}
uint32_t hwrng_get(void) {
struct hwrng_buf *rb = &ring_buffer;
uint32_t v;
while (rb->empty) {
while (true) {
hwrng_lock();
bool empty = rb->empty;
if (!empty) {
v = hwrng_buf_del(rb);
hwrng_unlock();
break;
}
hwrng_unlock();
hwrng_task();
}
v = hwrng_buf_del(rb);
return v;
}
@@ -164,7 +208,13 @@ void hwrng_wait_full(void) {
#elif defined(PICO_PLATFORM)
uint core = get_core_num();
#endif
while (!rb->full) {
while (true) {
hwrng_lock();
bool full = rb->full;
hwrng_unlock();
if (full) {
break;
}
#if defined(PICO_PLATFORM) || defined(ESP_PLATFORM)
if (core == 1) {
sleep_ms(1);

View File

@@ -20,11 +20,22 @@
#include "picokeys.h"
#include "hwrng.h"
#include "random.h"
#if defined(PICO_PLATFORM)
#include "pico/mutex.h"
#elif defined(ESP_PLATFORM)
#include "compat/esp_compat.h"
#else
#include "compat/queue.h"
#endif
#define RANDOM_BYTES_LENGTH 32
static uint32_t random_word[RANDOM_BYTES_LENGTH / sizeof(uint32_t)];
static mutex_t random_mutex;
static bool random_mutex_initialized = false;
void random_init(void) {
mutex_init(&random_mutex);
random_mutex_initialized = true;
hwrng_init(random_word, RANDOM_BYTES_LENGTH / sizeof(uint32_t));
for (int i = 0; i < HWRNG_PRE_LOOP; i++) {
@@ -50,11 +61,17 @@ const uint8_t *random_bytes_get(size_t len) {
return NULL;
}
static uint32_t return_word[MAX_RANDOM_BUFFER / sizeof(uint32_t)];
if (random_mutex_initialized) {
mutex_enter_blocking(&random_mutex);
}
for (size_t ix = 0; ix < len; ix += RANDOM_BYTES_LENGTH) {
hwrng_wait_full();
memcpy(return_word + ix / sizeof(uint32_t), random_word, RANDOM_BYTES_LENGTH);
random_bytes_free((const uint8_t *) random_word);
}
if (random_mutex_initialized) {
mutex_exit(&random_mutex);
}
return (const uint8_t *) return_word;
}
@@ -66,6 +83,9 @@ int random_fill_iterator(void *arg, unsigned char *out, size_t out_len) {
uint8_t index = index_p ? *index_p : 0;
uint8_t n;
if (random_mutex_initialized) {
mutex_enter_blocking(&random_mutex);
}
while (out_len) {
hwrng_wait_full();
@@ -88,6 +108,9 @@ int random_fill_iterator(void *arg, unsigned char *out, size_t out_len) {
if (index_p) {
*index_p = index;
}
if (random_mutex_initialized) {
mutex_exit(&random_mutex);
}
return 0;
}