nstool: Split some command line parsing and socket setup to subcommands

This will make it easier to differentiate the options to those commands
further in future.

Signed-off-by: David Gibson <david@gibson.dropbear.id.au>
Signed-off-by: Stefano Brivio <sbrivio@redhat.com>
This commit is contained in:
David Gibson 2023-04-06 13:28:09 +10:00 committed by Stefano Brivio
parent 42fb218347
commit a4b017d91c

View file

@ -11,6 +11,7 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <stdbool.h>
#include <errno.h> #include <errno.h>
#include <unistd.h> #include <unistd.h>
#include <sys/socket.h> #include <sys/socket.h>
@ -37,19 +38,55 @@ static void usage(void)
" terminate.\n"); " terminate.\n");
} }
static void hold(int fd, const struct sockaddr_un *addr) static int connect_ctl(const char * sockpath, bool wait)
{ {
int fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX);
struct sockaddr_un addr = {
.sun_family = AF_UNIX,
};
int rc; int rc;
rc = bind(fd, (struct sockaddr *)addr, sizeof(*addr)); if (fd < 0)
die("socket(): %s\n", strerror(errno));
strncpy(addr.sun_path, sockpath, UNIX_PATH_MAX);
do {
rc = connect(fd, (struct sockaddr *)&addr, sizeof(addr));
if (rc < 0 &&
(!wait || (errno != ENOENT && errno != ECONNREFUSED)))
die("connect() to %s: %s\n", sockpath, strerror(errno));
} while (rc < 0);
return fd;
}
static void cmd_hold(int argc, char *argv[])
{
int fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX);
struct sockaddr_un addr = {
.sun_family = AF_UNIX,
};
const char *sockpath = argv[1];
int rc;
if (argc != 2)
usage();
if (fd < 0)
die("socket(): %s\n", strerror(errno));
strncpy(addr.sun_path, sockpath, UNIX_PATH_MAX);
rc = bind(fd, (struct sockaddr *)&addr, sizeof(addr));
if (rc < 0) if (rc < 0)
die("bind(): %s\n", strerror(errno)); die("bind() to %s: %s\n", sockpath, strerror(errno));
rc = listen(fd, 0); rc = listen(fd, 0);
if (rc < 0) if (rc < 0)
die("listen(): %s\n", strerror(errno)); die("listen() on %s: %s\n", sockpath, strerror(errno));
printf("nstool: local PID=%d local UID=%u local GID=%u\n", printf("nstool hold: local PID=%d local UID=%u local GID=%u\n",
getpid(), getuid(), getgid()); getpid(), getuid(), getgid());
do { do {
int afd = accept(fd, NULL, NULL); int afd = accept(fd, NULL, NULL);
@ -63,71 +100,68 @@ static void hold(int fd, const struct sockaddr_un *addr)
die("read(): %s\n", strerror(errno)); die("read(): %s\n", strerror(errno));
} while (rc == 0); } while (rc == 0);
unlink(addr->sun_path); unlink(sockpath);
} }
static void pid(int fd, const struct sockaddr_un *addr) static void cmd_pid(int argc, char *argv[])
{ {
int rc; const char *sockpath = argv[1];
struct ucred peercred; struct ucred peercred;
socklen_t optlen = sizeof(peercred); socklen_t optlen = sizeof(peercred);
int fd, rc;
do { if (argc != 2)
rc = connect(fd, (struct sockaddr *)addr, sizeof(*addr)); usage();
if (rc < 0 && errno != ENOENT && errno != ECONNREFUSED)
die("connect(): %s\n", strerror(errno)); fd = connect_ctl(sockpath, true);
} while (rc < 0);
rc = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, rc = getsockopt(fd, SOL_SOCKET, SO_PEERCRED,
&peercred, &optlen); &peercred, &optlen);
if (rc < 0) if (rc < 0)
die("getsockopet(SO_PEERCRED): %s\n", strerror(errno)); die("getsockopet(SO_PEERCRED) %s: %s\n",
sockpath, strerror(errno));
close(fd); close(fd);
printf("%d\n", peercred.pid); printf("%d\n", peercred.pid);
} }
static void stop(int fd, const struct sockaddr_un *addr) static void cmd_stop(int argc, char *argv[])
{ {
int rc; const char *sockpath = argv[1];
int fd, rc;
char buf = 'Q'; char buf = 'Q';
rc = connect(fd, (struct sockaddr *)addr, sizeof(*addr)); if (argc != 2)
if (rc < 0) usage();
die("connect(): %s\n", strerror(errno));
fd = connect_ctl(sockpath, false);
rc = write(fd, &buf, sizeof(buf)); rc = write(fd, &buf, sizeof(buf));
if (rc < 0) if (rc < 0)
die("write(): %s\n", strerror(errno)); die("write() to %s: %s\n", sockpath, strerror(errno));
close(fd); close(fd);
} }
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
const char *subcmd = argv[1];
int fd; int fd;
const char *sockname;
struct sockaddr_un sockaddr = {
.sun_family = AF_UNIX,
};
if (argc != 3) if (argc < 2)
usage(); usage();
sockname = argv[2];
strncpy(sockaddr.sun_path, sockname, UNIX_PATH_MAX);
fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX); fd = socket(AF_UNIX, SOCK_STREAM, PF_UNIX);
if (fd < 0) if (fd < 0)
die("socket(): %s\n", strerror(errno)); die("socket(): %s\n", strerror(errno));
if (strcmp(argv[1], "hold") == 0) if (strcmp(subcmd, "hold") == 0)
hold(fd, &sockaddr); cmd_hold(argc - 1, argv + 1);
else if (strcmp(argv[1], "pid") == 0) else if (strcmp(subcmd, "pid") == 0)
pid(fd, &sockaddr); cmd_pid(argc - 1, argv + 1);
else if (strcmp(argv[1], "stop") == 0) else if (strcmp(subcmd, "stop") == 0)
stop(fd, &sockaddr); cmd_stop(argc - 1, argv + 1);
else else
usage(); usage();