diff --git a/tcp.c b/tcp.c index 108006e..4e3a134 100644 --- a/tcp.c +++ b/tcp.c @@ -573,22 +573,12 @@ static unsigned int tcp6_l2_flags_buf_used; #define CONN(idx) (&(FLOW(idx)->tcp)) -/** conn_at_idx() - Find a connection by index, if present - * @idx: Index of connection to lookup - * - * Return: pointer to connection, or NULL if @idx is out of bounds - */ -static inline struct tcp_tap_conn *conn_at_idx(unsigned idx) -{ - if (idx >= FLOW_MAX) - return NULL; - ASSERT(CONN(idx)->f.type == FLOW_TCP); - return CONN(idx); -} - /* Table for lookup from remote address, local port, remote port */ static struct tcp_tap_conn *tc_hash[TCP_HASH_TABLE_SIZE]; +static_assert(ARRAY_SIZE(tc_hash) >= FLOW_MAX, + "Safe linear probing requires hash table larger than connection table"); + /* Pools for pre-opened sockets (in init) */ int init_sock_pool4 [TCP_SOCK_POOL_SIZE]; int init_sock_pool6 [TCP_SOCK_POOL_SIZE]; @@ -1197,6 +1187,26 @@ static unsigned int tcp_conn_hash(const struct ctx *c, return tcp_hash(c, &conn->faddr, conn->eport, conn->fport); } +/** + * tcp_hash_probe() - Find hash bucket for a connection + * @c: Execution context + * @conn: Connection to find bucket for + * + * Return: If @conn is in the table, its current bucket, otherwise a suitable + * free bucket for it. + */ +static inline unsigned tcp_hash_probe(const struct ctx *c, + const struct tcp_tap_conn *conn) +{ + unsigned b = tcp_conn_hash(c, conn); + + /* Linear probing */ + while (tc_hash[b] && tc_hash[b] != conn) + b = mod_sub(b, 1, TCP_HASH_TABLE_SIZE); + + return b; +} + /** * tcp_hash_insert() - Insert connection into hash table, chain link * @c: Execution context @@ -1204,14 +1214,10 @@ static unsigned int tcp_conn_hash(const struct ctx *c, */ static void tcp_hash_insert(const struct ctx *c, struct tcp_tap_conn *conn) { - int b; + unsigned b = tcp_hash_probe(c, conn); - b = tcp_hash(c, &conn->faddr, conn->eport, conn->fport); - conn->next_index = tc_hash[b] ? FLOW_IDX(tc_hash[b]) : -1U; tc_hash[b] = conn; - - flow_dbg(conn, "hash table insert: sock %i, bucket: %i, next: %p", - conn->sock, b, (void *)conn_at_idx(conn->next_index)); + flow_dbg(conn, "hash table insert: sock %i, bucket: %u", conn->sock, b); } /** @@ -1222,23 +1228,27 @@ static void tcp_hash_insert(const struct ctx *c, struct tcp_tap_conn *conn) static void tcp_hash_remove(const struct ctx *c, const struct tcp_tap_conn *conn) { - struct tcp_tap_conn *entry, *prev = NULL; - int b = tcp_conn_hash(c, conn); + unsigned b = tcp_hash_probe(c, conn), s; - for (entry = tc_hash[b]; entry; - prev = entry, entry = conn_at_idx(entry->next_index)) { - if (entry == conn) { - if (prev) - prev->next_index = conn->next_index; - else - tc_hash[b] = conn_at_idx(conn->next_index); - break; + if (!tc_hash[b]) + return; /* Redundant remove */ + + flow_dbg(conn, "hash table remove: sock %i, bucket: %u", conn->sock, b); + + /* Scan the remainder of the cluster */ + for (s = mod_sub(b, 1, TCP_HASH_TABLE_SIZE); tc_hash[s]; + s = mod_sub(s, 1, TCP_HASH_TABLE_SIZE)) { + unsigned h = tcp_conn_hash(c, tc_hash[s]); + + if (!mod_between(h, s, b, TCP_HASH_TABLE_SIZE)) { + /* tc_hash[s] can live in tc_hash[b]'s slot */ + debug("hash table remove: shuffle %u -> %u", s, b); + tc_hash[b] = tc_hash[s]; + b = s; } } - flow_dbg(conn, "hash table remove: sock %i, bucket: %i, new: %p", - conn->sock, b, - (void *)(prev ? conn_at_idx(prev->next_index) : tc_hash[b])); + tc_hash[b] = NULL; } /** @@ -1251,24 +1261,15 @@ void tcp_tap_conn_update(const struct ctx *c, struct tcp_tap_conn *old, struct tcp_tap_conn *new) { - struct tcp_tap_conn *entry, *prev = NULL; - int b = tcp_conn_hash(c, old); + unsigned b = tcp_hash_probe(c, old); - for (entry = tc_hash[b]; entry; - prev = entry, entry = conn_at_idx(entry->next_index)) { - if (entry == old) { - if (prev) - prev->next_index = FLOW_IDX(new); - else - tc_hash[b] = new; - break; - } - } + if (!tc_hash[b]) + return; /* Not in hash table, nothing to update */ + + tc_hash[b] = new; debug("TCP: hash table update: old index %u, new index %u, sock %i, " - "bucket: %i, old: %p, new: %p", - FLOW_IDX(old), FLOW_IDX(new), new->sock, b, - (void *)old, (void *)new); + "bucket: %u", FLOW_IDX(old), FLOW_IDX(new), new->sock, b); tcp_epoll_ctl(c, new); } @@ -1288,17 +1289,15 @@ static struct tcp_tap_conn *tcp_hash_lookup(const struct ctx *c, in_port_t eport, in_port_t fport) { union inany_addr aany; - struct tcp_tap_conn *conn; - int b; + unsigned b; inany_from_af(&aany, af, faddr); - b = tcp_hash(c, &aany, eport, fport); - for (conn = tc_hash[b]; conn; conn = conn_at_idx(conn->next_index)) { - if (tcp_hash_match(conn, &aany, eport, fport)) - return conn; - } - return NULL; + b = tcp_hash(c, &aany, eport, fport); + while (tc_hash[b] && !tcp_hash_match(tc_hash[b], &aany, eport, fport)) + b = mod_sub(b, 1, TCP_HASH_TABLE_SIZE); + + return tc_hash[b]; } /** diff --git a/tcp_conn.h b/tcp_conn.h index 3900305..e3400bb 100644 --- a/tcp_conn.h +++ b/tcp_conn.h @@ -13,7 +13,6 @@ * struct tcp_tap_conn - Descriptor for a TCP connection (not spliced) * @f: Generic flow information * @in_epoll: Is the connection in the epoll set? - * @next_index: Connection index of next item in hash chain, -1 for none * @tap_mss: MSS advertised by tap/guest, rounded to 2 ^ TCP_MSS_BITS * @sock: Socket descriptor number * @events: Connection events, implying connection states @@ -40,7 +39,6 @@ struct tcp_tap_conn { struct flow_common f; bool in_epoll :1; - unsigned next_index :FLOW_INDEX_BITS + 2; #define TCP_RETRANS_BITS 3 unsigned int retrans :TCP_RETRANS_BITS; diff --git a/util.h b/util.h index 53bb54b..9446ea7 100644 --- a/util.h +++ b/util.h @@ -227,6 +227,34 @@ int __daemon(int pidfile_fd, int devnull_fd); int fls(unsigned long x); int write_file(const char *path, const char *buf); +/** + * mod_sub() - Modular arithmetic subtraction + * @a: Minued, unsigned value < @m + * @b: Subtrahend, unsigned value < @m + * @m: Modulus, must be less than (UINT_MAX / 2) + * + * Returns (@a - @b) mod @m, correctly handling unsigned underflows. + */ +static inline unsigned mod_sub(unsigned a, unsigned b, unsigned m) +{ + if (a < b) + a += m; + return a - b; +} + +/** + * mod_between() - Determine if a value is in a cyclic range + * @x, @i, @j: Unsigned values < @m + * @m: Modulus + * + * Returns true iff @x is in the cyclic range of values from @i..@j (mod @m), + * inclusive of @i, exclusive of @j. + */ +static inline bool mod_between(unsigned x, unsigned i, unsigned j, unsigned m) +{ + return mod_sub(x, i, m) < mod_sub(j, i, m); +} + /* * Workarounds for https://github.com/llvm/llvm-project/issues/58992 *