tcp: Better helpers for converting between connection pointer and index

The macro CONN_OR_NULL() is used to look up connections by index with
bounds checking.  Replace it with an inline function, which means:
    - Better type checking
    - No danger of multiple evaluation of an @index with side effects

Also add a helper to perform the reverse translation: from connection
pointer to index.  Introduce a macro for this which will make later
cleanups easier and safer.

Signed-off-by: David Gibson <david@gibson.dropbear.id.au>
Signed-off-by: Stefano Brivio <sbrivio@redhat.com>
This commit is contained in:
David Gibson 2022-11-17 16:58:41 +11:00 committed by Stefano Brivio
parent 46b78ce96a
commit 60be7438fa

83
tcp.c
View file

@ -518,14 +518,6 @@ struct tcp_conn {
(conn->events & (SOCK_FIN_RCVD | TAP_FIN_RCVD)))
#define CONN_HAS(conn, set) ((conn->events & (set)) == (set))
#define CONN(index) (tc + (index))
/* We probably don't want to use gcc statement expressions (for portability), so
* use this only after well-defined sequence points (no pre-/post-increments).
*/
#define CONN_OR_NULL(index) \
(((int)(index) >= 0 && (index) < TCP_MAX_CONNS) ? (tc + (index)) : NULL)
static const char *tcp_event_str[] __attribute((__unused__)) = {
"SOCK_ACCEPTED", "TAP_SYN_RCVD", "ESTABLISHED", "TAP_SYN_ACK_SENT",
@ -705,6 +697,21 @@ static size_t tcp6_l2_flags_buf_bytes;
/* TCP connections */
static struct tcp_conn tc[TCP_MAX_CONNS];
#define CONN(index) (tc + (index))
#define CONN_IDX(conn) ((conn) - tc)
/** conn_at_idx() - Find a connection by index, if present
* @index: Index of connection to lookup
*
* Return: pointer to connection, or NULL if @index is out of bounds
*/
static inline struct tcp_conn *conn_at_idx(int index)
{
if ((index < 0) || (index >= TCP_MAX_CONNS))
return NULL;
return CONN(index);
}
/* Table for lookup from remote address, local port, remote port */
static struct tcp_conn *tc_hash[TCP_HASH_TABLE_SIZE];
@ -761,7 +768,7 @@ static int tcp_epoll_ctl(const struct ctx *c, struct tcp_conn *conn)
{
int m = (conn->flags & IN_EPOLL) ? EPOLL_CTL_MOD : EPOLL_CTL_ADD;
union epoll_ref ref = { .r.proto = IPPROTO_TCP, .r.s = conn->sock,
.r.p.tcp.tcp.index = conn - tc,
.r.p.tcp.tcp.index = CONN_IDX(conn),
.r.p.tcp.tcp.v6 = CONN_V6(conn) };
struct epoll_event ev = { .data.u64 = ref.u64 };
@ -784,7 +791,7 @@ static int tcp_epoll_ctl(const struct ctx *c, struct tcp_conn *conn)
union epoll_ref ref_t = { .r.proto = IPPROTO_TCP,
.r.s = conn->sock,
.r.p.tcp.tcp.timer = 1,
.r.p.tcp.tcp.index = conn - tc };
.r.p.tcp.tcp.index = CONN_IDX(conn) };
struct epoll_event ev_t = { .data.u64 = ref_t.u64,
.events = EPOLLIN | EPOLLET };
@ -813,7 +820,7 @@ static void tcp_timer_ctl(const struct ctx *c, struct tcp_conn *conn)
union epoll_ref ref = { .r.proto = IPPROTO_TCP,
.r.s = conn->sock,
.r.p.tcp.tcp.timer = 1,
.r.p.tcp.tcp.index = conn - tc };
.r.p.tcp.tcp.index = CONN_IDX(conn) };
struct epoll_event ev = { .data.u64 = ref.u64,
.events = EPOLLIN | EPOLLET };
int fd;
@ -846,7 +853,7 @@ static void tcp_timer_ctl(const struct ctx *c, struct tcp_conn *conn)
it.it_value.tv_sec = ACT_TIMEOUT;
}
debug("TCP: index %li, timer expires in %lu.%03lus", conn - tc,
debug("TCP: index %li, timer expires in %lu.%03lus", CONN_IDX(conn),
it.it_value.tv_sec, it.it_value.tv_nsec / 1000 / 1000);
timerfd_settime(conn->timer, 0, &it, NULL);
@ -867,7 +874,7 @@ static void conn_flag_do(const struct ctx *c, struct tcp_conn *conn,
conn->flags &= flag;
if (fls(~flag) >= 0) {
debug("TCP: index %li: %s dropped", conn - tc,
debug("TCP: index %li: %s dropped", CONN_IDX(conn),
tcp_flag_str[fls(~flag)]);
}
} else {
@ -876,7 +883,7 @@ static void conn_flag_do(const struct ctx *c, struct tcp_conn *conn,
conn->flags |= flag;
if (fls(flag) >= 0) {
debug("TCP: index %li: %s", conn - tc,
debug("TCP: index %li: %s", CONN_IDX(conn),
tcp_flag_str[fls(flag)]);
}
}
@ -926,12 +933,12 @@ static void conn_event_do(const struct ctx *c, struct tcp_conn *conn,
new += 5;
if (prev != new) {
debug("TCP: index %li, %s: %s -> %s", conn - tc,
debug("TCP: index %li, %s: %s -> %s", CONN_IDX(conn),
num == -1 ? "CLOSED" : tcp_event_str[num],
prev == -1 ? "CLOSED" : tcp_state_str[prev],
(new == -1 || num == -1) ? "CLOSED" : tcp_state_str[new]);
} else {
debug("TCP: index %li, %s", conn - tc,
debug("TCP: index %li, %s", CONN_IDX(conn),
num == -1 ? "CLOSED" : tcp_event_str[num]);
}
@ -1355,12 +1362,12 @@ static void tcp_hash_insert(const struct ctx *c, struct tcp_conn *conn,
int b;
b = tcp_hash(c, af, addr, conn->tap_port, conn->sock_port);
conn->next_index = tc_hash[b] ? tc_hash[b] - tc : -1;
conn->next_index = tc_hash[b] ? CONN_IDX(tc_hash[b]) : -1;
tc_hash[b] = conn;
conn->hash_bucket = b;
debug("TCP: hash table insert: index %li, sock %i, bucket: %i, next: "
"%p", conn - tc, conn->sock, b, CONN_OR_NULL(conn->next_index));
"%p", CONN_IDX(conn), conn->sock, b, conn_at_idx(conn->next_index));
}
/**
@ -1373,19 +1380,19 @@ static void tcp_hash_remove(const struct tcp_conn *conn)
int b = conn->hash_bucket;
for (entry = tc_hash[b]; entry;
prev = entry, entry = CONN_OR_NULL(entry->next_index)) {
prev = entry, entry = conn_at_idx(entry->next_index)) {
if (entry == conn) {
if (prev)
prev->next_index = conn->next_index;
else
tc_hash[b] = CONN_OR_NULL(conn->next_index);
tc_hash[b] = conn_at_idx(conn->next_index);
break;
}
}
debug("TCP: hash table remove: index %li, sock %i, bucket: %i, new: %p",
conn - tc, conn->sock, b,
prev ? CONN_OR_NULL(prev->next_index) : tc_hash[b]);
CONN_IDX(conn), conn->sock, b,
prev ? conn_at_idx(prev->next_index) : tc_hash[b]);
}
/**
@ -1399,10 +1406,10 @@ static void tcp_hash_update(struct tcp_conn *old, struct tcp_conn *new)
int b = old->hash_bucket;
for (entry = tc_hash[b]; entry;
prev = entry, entry = CONN_OR_NULL(entry->next_index)) {
prev = entry, entry = conn_at_idx(entry->next_index)) {
if (entry == old) {
if (prev)
prev->next_index = new - tc;
prev->next_index = CONN_IDX(new);
else
tc_hash[b] = new;
break;
@ -1411,7 +1418,7 @@ static void tcp_hash_update(struct tcp_conn *old, struct tcp_conn *new)
debug("TCP: hash table update: old index %li, new index %li, sock %i, "
"bucket: %i, old: %p, new: %p",
old - tc, new - tc, new->sock, b, old, new);
CONN_IDX(old), CONN_IDX(new), new->sock, b, old, new);
}
/**
@ -1431,7 +1438,7 @@ static struct tcp_conn *tcp_hash_lookup(const struct ctx *c, int af,
int b = tcp_hash(c, af, addr, tap_port, sock_port);
struct tcp_conn *conn;
for (conn = tc_hash[b]; conn; conn = CONN_OR_NULL(conn->next_index)) {
for (conn = tc_hash[b]; conn; conn = conn_at_idx(conn->next_index)) {
if (tcp_hash_match(conn, af, addr, tap_port, sock_port))
return conn;
}
@ -1448,9 +1455,9 @@ static void tcp_table_compact(struct ctx *c, struct tcp_conn *hole)
{
struct tcp_conn *from, *to;
if ((hole - tc) == --c->tcp.conn_count) {
if (CONN_IDX(hole) == --c->tcp.conn_count) {
debug("TCP: hash table compaction: maximum index was %li (%p)",
hole - tc, hole);
CONN_IDX(hole), hole);
memset(hole, 0, sizeof(*hole));
return;
}
@ -1465,7 +1472,7 @@ static void tcp_table_compact(struct ctx *c, struct tcp_conn *hole)
debug("TCP: hash table compaction: old index %li, new index %li, "
"sock %i, from: %p, to: %p",
from - tc, to - tc, from->sock, from, to);
CONN_IDX(from), CONN_IDX(to), from->sock, from, to);
memset(from, 0, sizeof(*from));
}
@ -1488,7 +1495,7 @@ static void tcp_conn_destroy(struct ctx *c, struct tcp_conn *conn)
static void tcp_rst_do(struct ctx *c, struct tcp_conn *conn);
#define tcp_rst(c, conn) \
do { \
debug("TCP: index %li, reset at %s:%i", conn - tc, \
debug("TCP: index %li, reset at %s:%i", CONN_IDX(conn), \
__func__, __LINE__); \
tcp_rst_do(c, conn); \
} while (0)
@ -2734,7 +2741,7 @@ int tcp_tap_handler(struct ctx *c, int af, const void *addr,
return 1;
}
trace("TCP: packet length %lu from tap for index %lu", len, conn - tc);
trace("TCP: packet length %lu from tap for index %lu", len, CONN_IDX(conn));
if (th->rst) {
conn_event(c, conn, CLOSED);
@ -2942,7 +2949,7 @@ static void tcp_conn_from_sock(struct ctx *c, union epoll_ref ref,
*/
static void tcp_timer_handler(struct ctx *c, union epoll_ref ref)
{
struct tcp_conn *conn = CONN_OR_NULL(ref.r.p.tcp.tcp.index);
struct tcp_conn *conn = conn_at_idx(ref.r.p.tcp.tcp.index);
struct itimerspec check_armed = { { 0 }, { 0 } };
if (!conn)
@ -2961,17 +2968,17 @@ static void tcp_timer_handler(struct ctx *c, union epoll_ref ref)
conn_flag(c, conn, ~ACK_TO_TAP_DUE);
} else if (conn->flags & ACK_FROM_TAP_DUE) {
if (!(conn->events & ESTABLISHED)) {
debug("TCP: index %li, handshake timeout", conn - tc);
debug("TCP: index %li, handshake timeout", CONN_IDX(conn));
tcp_rst(c, conn);
} else if (CONN_HAS(conn, SOCK_FIN_SENT | TAP_FIN_ACKED)) {
debug("TCP: index %li, FIN timeout", conn - tc);
debug("TCP: index %li, FIN timeout", CONN_IDX(conn));
tcp_rst(c, conn);
} else if (conn->retrans == TCP_MAX_RETRANS) {
debug("TCP: index %li, retransmissions count exceeded",
conn - tc);
CONN_IDX(conn));
tcp_rst(c, conn);
} else {
debug("TCP: index %li, ACK timeout, retry", conn - tc);
debug("TCP: index %li, ACK timeout, retry", CONN_IDX(conn));
conn->retrans++;
conn->seq_to_tap = conn->seq_ack_from_tap;
tcp_data_from_sock(c, conn);
@ -2989,7 +2996,7 @@ static void tcp_timer_handler(struct ctx *c, union epoll_ref ref)
*/
timerfd_settime(conn->timer, 0, &new, &old);
if (old.it_value.tv_sec == ACT_TIMEOUT) {
debug("TCP: index %li, activity timeout", conn - tc);
debug("TCP: index %li, activity timeout", CONN_IDX(conn));
tcp_rst(c, conn);
}
}
@ -3022,7 +3029,7 @@ void tcp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
return;
}
if (!(conn = CONN_OR_NULL(ref.r.p.tcp.tcp.index)))
if (!(conn = conn_at_idx(ref.r.p.tcp.tcp.index)))
return;
if (conn->events == CLOSED)