Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/vast/Conversion/Parser/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
Option< "socket", "socket", "std::string", "",
"Unix socket path to use for server."
>,
Option< "tcp_port", "tcp-port", "int", "-1",
"TCP port to use for server."
>,
Option< "tcp_host", "tcp-host", "int", "0",
"TCP host to use for server."
>,
Option< "yaml_out", "yaml-out", "std::string", "",
"Path to YAML output file for models got from user."
>
Expand Down
2 changes: 2 additions & 0 deletions include/vast/server/io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ namespace vast::server {
void close() override;

static std::unique_ptr< sock_adapter > create_unix_socket(const std::string &path);
static std::unique_ptr< sock_adapter >
create_tcp_server_socket(uint32_t host, uint16_t port);

private:
std::unique_ptr< struct impl > pimpl;
Expand Down
6 changes: 6 additions & 0 deletions lib/vast/Conversion/Parser/ToParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,12 @@ namespace vast::conv {
vast::server::sock_adapter::create_unix_socket(socket), 1,
server_handler{ models }
);
} else if (tcp_port >= 0) {
server = std::make_shared<
vast::server::server< server_handler, get_function_model_request > >(
vast::server::sock_adapter::create_tcp_server_socket(tcp_host, tcp_port), 1,
server_handler{ models }
);
}
}

Expand Down
65 changes: 49 additions & 16 deletions lib/vast/server/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <stdexcept>
#include <system_error>

#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
Expand All @@ -11,13 +12,20 @@ namespace vast::server {
union addr {
sockaddr base;
sockaddr_un unix;
sockaddr_in net;
};

struct descriptor
{
int fd;

explicit descriptor(int fd) : fd(fd) {}
explicit descriptor() : fd(-1) {}

explicit descriptor(int fd) : fd(fd) {
if (fd < 0) {
throw std::system_error(errno, std::generic_category());
}
}

descriptor(const descriptor &) = delete;
descriptor &operator=(const descriptor &) = delete;
Expand Down Expand Up @@ -53,8 +61,8 @@ namespace vast::server {
sock_adapter::~sock_adapter() = default;

void sock_adapter::close() {
pimpl->clientd = descriptor{ -1 };
pimpl->serverd = descriptor{ -1 };
pimpl->clientd = descriptor{};
pimpl->serverd = descriptor{};
}

size_t sock_adapter::read_some(std::span< char > dst) {
Expand All @@ -73,15 +81,29 @@ namespace vast::server {
return static_cast< size_t >(res);
}

static descriptor bind_and_accept(
descriptor &serverd, addr &sockaddr_server, size_t socklen_server,
sockaddr *sockaddr_client, socklen_t *socklen_client
) {
int rc = bind(serverd, &sockaddr_server.base, static_cast< socklen_t >(socklen_server));
if (rc < 0) {
throw std::system_error(errno, std::generic_category());
}

rc = listen(serverd, 1);
if (rc < 0) {
throw std::system_error(errno, std::generic_category());
}

return descriptor{ accept(serverd, sockaddr_client, socklen_client) };
}

std::unique_ptr< sock_adapter > sock_adapter::create_unix_socket(const std::string &path) {
if (path.size() > (sizeof(sockaddr_un::sun_path) - 1)) {
throw std::runtime_error("Unix socket pathname is too long");
}

descriptor serverd{ socket(AF_UNIX, SOCK_STREAM, 0) };
if (serverd < 0) {
throw std::system_error(errno, std::generic_category());
}

addr sock_addr{};
sock_addr.unix.sun_family = AF_UNIX;
Expand All @@ -95,18 +117,29 @@ namespace vast::server {
if (unlink(path.c_str()) < 0 && errno != ENOENT) {
throw std::system_error(errno, std::generic_category());
}
int rc =
bind(serverd, &sock_addr.base, static_cast< socklen_t >(SUN_LEN(&sock_addr.unix)));
if (rc < 0) {
throw std::system_error(errno, std::generic_category());
}

rc = listen(serverd, 1);
if (rc < 0) {
throw std::system_error(errno, std::generic_category());
}
auto clientd =
bind_and_accept(serverd, sock_addr, sizeof(sock_addr.unix), nullptr, nullptr);

return std::unique_ptr< sock_adapter >(new sock_adapter{
std::make_unique< impl >(std::move(serverd), std::move(clientd)) });
}

std::unique_ptr< sock_adapter >
sock_adapter::create_tcp_server_socket(uint32_t host, uint16_t port) {
descriptor serverd{ socket(AF_INET, SOCK_STREAM, 0) };

addr sock_addr{};
sock_addr.net.sin_family = AF_INET;
sock_addr.net.sin_addr.s_addr = htonl(host);
sock_addr.net.sin_port = htons(port);

addr client_addr{};
socklen_t client_addr_size = sizeof(client_addr.net);

descriptor clientd{ accept(serverd, nullptr, nullptr) };
auto clientd = bind_and_accept(
serverd, sock_addr, sizeof(sock_addr.net), &client_addr.base, &client_addr_size
);

return std::unique_ptr< sock_adapter >(new sock_adapter{
std::make_unique< impl >(std::move(serverd), std::move(clientd)) });
Expand Down