From 0857515c943d439eade80710c16f15f146dfa9e8 Mon Sep 17 00:00:00 2001
From: David Gibson <david@gibson.dropbear.id.au>
Date: Mon, 17 Mar 2025 20:24:23 +1100
Subject: [PATCH] packet: ASSERT on signs of pool corruption

If packet_check_range() fails in packet_get_try_do() we just return NULL.
But this check only takes places after we've already validated the given
range against the packet it's in.  That means that if packet_check_range()
fails, the packet pool is already in a corrupted state (we should have
made strictly stronger checks when the packet was added).  Simply returning
NULL and logging a trace() level message isn't really adequate for that
situation; ASSERT instead.

Similarly we check the given idx against both p->count and p->size.  The
latter should be redundant, because count should always be <= size.  If
that's not the case then, again, the pool is already in a corrupted state
and we may have overwritten unknown memory.  Assert for this case too.

Signed-off-by: David Gibson <david@gibson.dropbear.id.au>
Signed-off-by: Stefano Brivio <sbrivio@redhat.com>
---
 packet.c | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/packet.c b/packet.c
index b3e8c79..be28f27 100644
--- a/packet.c
+++ b/packet.c
@@ -129,9 +129,13 @@ void *packet_get_try_do(const struct pool *p, size_t idx, size_t offset,
 {
 	char *ptr;
 
-	if (idx >= p->size || idx >= p->count) {
-		trace("packet %zu from pool size: %zu, count: %zu, %s:%i",
-		      idx, p->size, p->count, func, line);
+	ASSERT_WITH_MSG(p->count <= p->size,
+			"Corrupt pool count: %zu, size: %zu, %s:%i",
+			p->count, p->size, func, line);
+
+	if (idx >= p->count) {
+		trace("packet %zu from pool count: %zu, %s:%i",
+		      idx, p->count, func, line);
 		return NULL;
 	}
 
@@ -141,8 +145,8 @@ void *packet_get_try_do(const struct pool *p, size_t idx, size_t offset,
 
 	ptr = (char *)p->pkt[idx].iov_base + offset;
 
-	if (packet_check_range(p, ptr, len, func, line))
-		return NULL;
+	ASSERT_WITH_MSG(!packet_check_range(p, ptr, len, func, line),
+			"Corrupt packet pool, %s:%i", func, line);
 
 	if (left)
 		*left = p->pkt[idx].iov_len - offset - len;