tcp: Better helpers for converting between connection pointer and index

The macro CONN_OR_NULL() is used to look up connections by index with
bounds checking.  Replace it with an inline function, which means:
    - Better type checking
    - No danger of multiple evaluation of an @index with side effects

Also add a helper to perform the reverse translation: from connection
pointer to index.  Introduce a macro for this which will make later
cleanups easier and safer.

Signed-off-by: David Gibson <david@gibson.dropbear.id.au>
Signed-off-by: Stefano Brivio <sbrivio@redhat.com>
This commit is contained in:
David Gibson 2022-11-17 16:58:41 +11:00 committed by Stefano Brivio
parent 46b78ce96a
commit 60be7438fa

83
tcp.c
View file

@ -518,14 +518,6 @@ struct tcp_conn {
(conn->events & (SOCK_FIN_RCVD | TAP_FIN_RCVD))) (conn->events & (SOCK_FIN_RCVD | TAP_FIN_RCVD)))
#define CONN_HAS(conn, set) ((conn->events & (set)) == (set)) #define CONN_HAS(conn, set) ((conn->events & (set)) == (set))
#define CONN(index) (tc + (index))
/* We probably don't want to use gcc statement expressions (for portability), so
* use this only after well-defined sequence points (no pre-/post-increments).
*/
#define CONN_OR_NULL(index) \
(((int)(index) >= 0 && (index) < TCP_MAX_CONNS) ? (tc + (index)) : NULL)
static const char *tcp_event_str[] __attribute((__unused__)) = { static const char *tcp_event_str[] __attribute((__unused__)) = {
"SOCK_ACCEPTED", "TAP_SYN_RCVD", "ESTABLISHED", "TAP_SYN_ACK_SENT", "SOCK_ACCEPTED", "TAP_SYN_RCVD", "ESTABLISHED", "TAP_SYN_ACK_SENT",
@ -705,6 +697,21 @@ static size_t tcp6_l2_flags_buf_bytes;
/* TCP connections */ /* TCP connections */
static struct tcp_conn tc[TCP_MAX_CONNS]; static struct tcp_conn tc[TCP_MAX_CONNS];
#define CONN(index) (tc + (index))
#define CONN_IDX(conn) ((conn) - tc)
/** conn_at_idx() - Find a connection by index, if present
* @index: Index of connection to lookup
*
* Return: pointer to connection, or NULL if @index is out of bounds
*/
static inline struct tcp_conn *conn_at_idx(int index)
{
if ((index < 0) || (index >= TCP_MAX_CONNS))
return NULL;
return CONN(index);
}
/* Table for lookup from remote address, local port, remote port */ /* Table for lookup from remote address, local port, remote port */
static struct tcp_conn *tc_hash[TCP_HASH_TABLE_SIZE]; static struct tcp_conn *tc_hash[TCP_HASH_TABLE_SIZE];
@ -761,7 +768,7 @@ static int tcp_epoll_ctl(const struct ctx *c, struct tcp_conn *conn)
{ {
int m = (conn->flags & IN_EPOLL) ? EPOLL_CTL_MOD : EPOLL_CTL_ADD; int m = (conn->flags & IN_EPOLL) ? EPOLL_CTL_MOD : EPOLL_CTL_ADD;
union epoll_ref ref = { .r.proto = IPPROTO_TCP, .r.s = conn->sock, union epoll_ref ref = { .r.proto = IPPROTO_TCP, .r.s = conn->sock,
.r.p.tcp.tcp.index = conn - tc, .r.p.tcp.tcp.index = CONN_IDX(conn),
.r.p.tcp.tcp.v6 = CONN_V6(conn) }; .r.p.tcp.tcp.v6 = CONN_V6(conn) };
struct epoll_event ev = { .data.u64 = ref.u64 }; struct epoll_event ev = { .data.u64 = ref.u64 };
@ -784,7 +791,7 @@ static int tcp_epoll_ctl(const struct ctx *c, struct tcp_conn *conn)
union epoll_ref ref_t = { .r.proto = IPPROTO_TCP, union epoll_ref ref_t = { .r.proto = IPPROTO_TCP,
.r.s = conn->sock, .r.s = conn->sock,
.r.p.tcp.tcp.timer = 1, .r.p.tcp.tcp.timer = 1,
.r.p.tcp.tcp.index = conn - tc }; .r.p.tcp.tcp.index = CONN_IDX(conn) };
struct epoll_event ev_t = { .data.u64 = ref_t.u64, struct epoll_event ev_t = { .data.u64 = ref_t.u64,
.events = EPOLLIN | EPOLLET }; .events = EPOLLIN | EPOLLET };
@ -813,7 +820,7 @@ static void tcp_timer_ctl(const struct ctx *c, struct tcp_conn *conn)
union epoll_ref ref = { .r.proto = IPPROTO_TCP, union epoll_ref ref = { .r.proto = IPPROTO_TCP,
.r.s = conn->sock, .r.s = conn->sock,
.r.p.tcp.tcp.timer = 1, .r.p.tcp.tcp.timer = 1,
.r.p.tcp.tcp.index = conn - tc }; .r.p.tcp.tcp.index = CONN_IDX(conn) };
struct epoll_event ev = { .data.u64 = ref.u64, struct epoll_event ev = { .data.u64 = ref.u64,
.events = EPOLLIN | EPOLLET }; .events = EPOLLIN | EPOLLET };
int fd; int fd;
@ -846,7 +853,7 @@ static void tcp_timer_ctl(const struct ctx *c, struct tcp_conn *conn)
it.it_value.tv_sec = ACT_TIMEOUT; it.it_value.tv_sec = ACT_TIMEOUT;
} }
debug("TCP: index %li, timer expires in %lu.%03lus", conn - tc, debug("TCP: index %li, timer expires in %lu.%03lus", CONN_IDX(conn),
it.it_value.tv_sec, it.it_value.tv_nsec / 1000 / 1000); it.it_value.tv_sec, it.it_value.tv_nsec / 1000 / 1000);
timerfd_settime(conn->timer, 0, &it, NULL); timerfd_settime(conn->timer, 0, &it, NULL);
@ -867,7 +874,7 @@ static void conn_flag_do(const struct ctx *c, struct tcp_conn *conn,
conn->flags &= flag; conn->flags &= flag;
if (fls(~flag) >= 0) { if (fls(~flag) >= 0) {
debug("TCP: index %li: %s dropped", conn - tc, debug("TCP: index %li: %s dropped", CONN_IDX(conn),
tcp_flag_str[fls(~flag)]); tcp_flag_str[fls(~flag)]);
} }
} else { } else {
@ -876,7 +883,7 @@ static void conn_flag_do(const struct ctx *c, struct tcp_conn *conn,
conn->flags |= flag; conn->flags |= flag;
if (fls(flag) >= 0) { if (fls(flag) >= 0) {
debug("TCP: index %li: %s", conn - tc, debug("TCP: index %li: %s", CONN_IDX(conn),
tcp_flag_str[fls(flag)]); tcp_flag_str[fls(flag)]);
} }
} }
@ -926,12 +933,12 @@ static void conn_event_do(const struct ctx *c, struct tcp_conn *conn,
new += 5; new += 5;
if (prev != new) { if (prev != new) {
debug("TCP: index %li, %s: %s -> %s", conn - tc, debug("TCP: index %li, %s: %s -> %s", CONN_IDX(conn),
num == -1 ? "CLOSED" : tcp_event_str[num], num == -1 ? "CLOSED" : tcp_event_str[num],
prev == -1 ? "CLOSED" : tcp_state_str[prev], prev == -1 ? "CLOSED" : tcp_state_str[prev],
(new == -1 || num == -1) ? "CLOSED" : tcp_state_str[new]); (new == -1 || num == -1) ? "CLOSED" : tcp_state_str[new]);
} else { } else {
debug("TCP: index %li, %s", conn - tc, debug("TCP: index %li, %s", CONN_IDX(conn),
num == -1 ? "CLOSED" : tcp_event_str[num]); num == -1 ? "CLOSED" : tcp_event_str[num]);
} }
@ -1355,12 +1362,12 @@ static void tcp_hash_insert(const struct ctx *c, struct tcp_conn *conn,
int b; int b;
b = tcp_hash(c, af, addr, conn->tap_port, conn->sock_port); b = tcp_hash(c, af, addr, conn->tap_port, conn->sock_port);
conn->next_index = tc_hash[b] ? tc_hash[b] - tc : -1; conn->next_index = tc_hash[b] ? CONN_IDX(tc_hash[b]) : -1;
tc_hash[b] = conn; tc_hash[b] = conn;
conn->hash_bucket = b; conn->hash_bucket = b;
debug("TCP: hash table insert: index %li, sock %i, bucket: %i, next: " debug("TCP: hash table insert: index %li, sock %i, bucket: %i, next: "
"%p", conn - tc, conn->sock, b, CONN_OR_NULL(conn->next_index)); "%p", CONN_IDX(conn), conn->sock, b, conn_at_idx(conn->next_index));
} }
/** /**
@ -1373,19 +1380,19 @@ static void tcp_hash_remove(const struct tcp_conn *conn)
int b = conn->hash_bucket; int b = conn->hash_bucket;
for (entry = tc_hash[b]; entry; for (entry = tc_hash[b]; entry;
prev = entry, entry = CONN_OR_NULL(entry->next_index)) { prev = entry, entry = conn_at_idx(entry->next_index)) {
if (entry == conn) { if (entry == conn) {
if (prev) if (prev)
prev->next_index = conn->next_index; prev->next_index = conn->next_index;
else else
tc_hash[b] = CONN_OR_NULL(conn->next_index); tc_hash[b] = conn_at_idx(conn->next_index);
break; break;
} }
} }
debug("TCP: hash table remove: index %li, sock %i, bucket: %i, new: %p", debug("TCP: hash table remove: index %li, sock %i, bucket: %i, new: %p",
conn - tc, conn->sock, b, CONN_IDX(conn), conn->sock, b,
prev ? CONN_OR_NULL(prev->next_index) : tc_hash[b]); prev ? conn_at_idx(prev->next_index) : tc_hash[b]);
} }
/** /**
@ -1399,10 +1406,10 @@ static void tcp_hash_update(struct tcp_conn *old, struct tcp_conn *new)
int b = old->hash_bucket; int b = old->hash_bucket;
for (entry = tc_hash[b]; entry; for (entry = tc_hash[b]; entry;
prev = entry, entry = CONN_OR_NULL(entry->next_index)) { prev = entry, entry = conn_at_idx(entry->next_index)) {
if (entry == old) { if (entry == old) {
if (prev) if (prev)
prev->next_index = new - tc; prev->next_index = CONN_IDX(new);
else else
tc_hash[b] = new; tc_hash[b] = new;
break; break;
@ -1411,7 +1418,7 @@ static void tcp_hash_update(struct tcp_conn *old, struct tcp_conn *new)
debug("TCP: hash table update: old index %li, new index %li, sock %i, " debug("TCP: hash table update: old index %li, new index %li, sock %i, "
"bucket: %i, old: %p, new: %p", "bucket: %i, old: %p, new: %p",
old - tc, new - tc, new->sock, b, old, new); CONN_IDX(old), CONN_IDX(new), new->sock, b, old, new);
} }
/** /**
@ -1431,7 +1438,7 @@ static struct tcp_conn *tcp_hash_lookup(const struct ctx *c, int af,
int b = tcp_hash(c, af, addr, tap_port, sock_port); int b = tcp_hash(c, af, addr, tap_port, sock_port);
struct tcp_conn *conn; struct tcp_conn *conn;
for (conn = tc_hash[b]; conn; conn = CONN_OR_NULL(conn->next_index)) { for (conn = tc_hash[b]; conn; conn = conn_at_idx(conn->next_index)) {
if (tcp_hash_match(conn, af, addr, tap_port, sock_port)) if (tcp_hash_match(conn, af, addr, tap_port, sock_port))
return conn; return conn;
} }
@ -1448,9 +1455,9 @@ static void tcp_table_compact(struct ctx *c, struct tcp_conn *hole)
{ {
struct tcp_conn *from, *to; struct tcp_conn *from, *to;
if ((hole - tc) == --c->tcp.conn_count) { if (CONN_IDX(hole) == --c->tcp.conn_count) {
debug("TCP: hash table compaction: maximum index was %li (%p)", debug("TCP: hash table compaction: maximum index was %li (%p)",
hole - tc, hole); CONN_IDX(hole), hole);
memset(hole, 0, sizeof(*hole)); memset(hole, 0, sizeof(*hole));
return; return;
} }
@ -1465,7 +1472,7 @@ static void tcp_table_compact(struct ctx *c, struct tcp_conn *hole)
debug("TCP: hash table compaction: old index %li, new index %li, " debug("TCP: hash table compaction: old index %li, new index %li, "
"sock %i, from: %p, to: %p", "sock %i, from: %p, to: %p",
from - tc, to - tc, from->sock, from, to); CONN_IDX(from), CONN_IDX(to), from->sock, from, to);
memset(from, 0, sizeof(*from)); memset(from, 0, sizeof(*from));
} }
@ -1488,7 +1495,7 @@ static void tcp_conn_destroy(struct ctx *c, struct tcp_conn *conn)
static void tcp_rst_do(struct ctx *c, struct tcp_conn *conn); static void tcp_rst_do(struct ctx *c, struct tcp_conn *conn);
#define tcp_rst(c, conn) \ #define tcp_rst(c, conn) \
do { \ do { \
debug("TCP: index %li, reset at %s:%i", conn - tc, \ debug("TCP: index %li, reset at %s:%i", CONN_IDX(conn), \
__func__, __LINE__); \ __func__, __LINE__); \
tcp_rst_do(c, conn); \ tcp_rst_do(c, conn); \
} while (0) } while (0)
@ -2734,7 +2741,7 @@ int tcp_tap_handler(struct ctx *c, int af, const void *addr,
return 1; return 1;
} }
trace("TCP: packet length %lu from tap for index %lu", len, conn - tc); trace("TCP: packet length %lu from tap for index %lu", len, CONN_IDX(conn));
if (th->rst) { if (th->rst) {
conn_event(c, conn, CLOSED); conn_event(c, conn, CLOSED);
@ -2942,7 +2949,7 @@ static void tcp_conn_from_sock(struct ctx *c, union epoll_ref ref,
*/ */
static void tcp_timer_handler(struct ctx *c, union epoll_ref ref) static void tcp_timer_handler(struct ctx *c, union epoll_ref ref)
{ {
struct tcp_conn *conn = CONN_OR_NULL(ref.r.p.tcp.tcp.index); struct tcp_conn *conn = conn_at_idx(ref.r.p.tcp.tcp.index);
struct itimerspec check_armed = { { 0 }, { 0 } }; struct itimerspec check_armed = { { 0 }, { 0 } };
if (!conn) if (!conn)
@ -2961,17 +2968,17 @@ static void tcp_timer_handler(struct ctx *c, union epoll_ref ref)
conn_flag(c, conn, ~ACK_TO_TAP_DUE); conn_flag(c, conn, ~ACK_TO_TAP_DUE);
} else if (conn->flags & ACK_FROM_TAP_DUE) { } else if (conn->flags & ACK_FROM_TAP_DUE) {
if (!(conn->events & ESTABLISHED)) { if (!(conn->events & ESTABLISHED)) {
debug("TCP: index %li, handshake timeout", conn - tc); debug("TCP: index %li, handshake timeout", CONN_IDX(conn));
tcp_rst(c, conn); tcp_rst(c, conn);
} else if (CONN_HAS(conn, SOCK_FIN_SENT | TAP_FIN_ACKED)) { } else if (CONN_HAS(conn, SOCK_FIN_SENT | TAP_FIN_ACKED)) {
debug("TCP: index %li, FIN timeout", conn - tc); debug("TCP: index %li, FIN timeout", CONN_IDX(conn));
tcp_rst(c, conn); tcp_rst(c, conn);
} else if (conn->retrans == TCP_MAX_RETRANS) { } else if (conn->retrans == TCP_MAX_RETRANS) {
debug("TCP: index %li, retransmissions count exceeded", debug("TCP: index %li, retransmissions count exceeded",
conn - tc); CONN_IDX(conn));
tcp_rst(c, conn); tcp_rst(c, conn);
} else { } else {
debug("TCP: index %li, ACK timeout, retry", conn - tc); debug("TCP: index %li, ACK timeout, retry", CONN_IDX(conn));
conn->retrans++; conn->retrans++;
conn->seq_to_tap = conn->seq_ack_from_tap; conn->seq_to_tap = conn->seq_ack_from_tap;
tcp_data_from_sock(c, conn); tcp_data_from_sock(c, conn);
@ -2989,7 +2996,7 @@ static void tcp_timer_handler(struct ctx *c, union epoll_ref ref)
*/ */
timerfd_settime(conn->timer, 0, &new, &old); timerfd_settime(conn->timer, 0, &new, &old);
if (old.it_value.tv_sec == ACT_TIMEOUT) { if (old.it_value.tv_sec == ACT_TIMEOUT) {
debug("TCP: index %li, activity timeout", conn - tc); debug("TCP: index %li, activity timeout", CONN_IDX(conn));
tcp_rst(c, conn); tcp_rst(c, conn);
} }
} }
@ -3022,7 +3029,7 @@ void tcp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events,
return; return;
} }
if (!(conn = CONN_OR_NULL(ref.r.p.tcp.tcp.index))) if (!(conn = conn_at_idx(ref.r.p.tcp.tcp.index)))
return; return;
if (conn->events == CLOSED) if (conn->events == CLOSED)