diff --git a/.github/workflows/cmake-multi-platform.yml b/.github/workflows/cmake-multi-platform.yml index 8e4919e..cb001fa 100644 --- a/.github/workflows/cmake-multi-platform.yml +++ b/.github/workflows/cmake-multi-platform.yml @@ -29,6 +29,22 @@ jobs: steps: - uses: actions/checkout@v4 + - name: Install OpenSSL + run: | + if [[ "${{ matrix.os }}" == "ubuntu-latest" ]]; then + sudo apt update + sudo apt install -y libssl-dev + echo "CMAKE_PREFIX_PATH=/usr" >> $GITHUB_ENV + elif [[ "${{ matrix.os }}" == "macos-latest" ]]; then + brew update + brew install openssl + OPENSSL_DIR=$(brew --prefix openssl) + echo "CMAKE_PREFIX_PATH=${OPENSSL_DIR}" >> $GITHUB_ENV + elif [[ "${{ matrix.os }}" == "windows-latest" ]]; then + choco install openssl.light --version=3.1.2 -y + echo "CMAKE_PREFIX_PATH=C:/Program Files/OpenSSL-Win64" >> $GITHUB_ENV + fi + shell: bash - name: Set reusable strings id: strings @@ -42,6 +58,7 @@ jobs: -DCMAKE_C_COMPILER=${{ matrix.c_compiler }} -DCMAKE_CXX_COMPILER=${{ matrix.cpp_compiler }} -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + -DCMAKE_PREFIX_PATH=${{ env.CMAKE_PREFIX_PATH }} -S ${{ github.workspace }} - name: Build diff --git a/include/ssock.hpp b/include/ssock.hpp index 9f4ec10..9bf3d35 100644 --- a/include/ssock.hpp +++ b/include/ssock.hpp @@ -19,10 +19,10 @@ #include #include #include -#include #include #include #include +#include #ifndef SSOCK #define SSOCK 1 @@ -50,6 +50,12 @@ #define SSOCK_LOCALHOST_IPV6 "::1" #endif +#ifdef SSOCK_OPENSSL +#include +#include +#include +#endif + #if defined(__APPLE__) #define SSOCK_MACOS #include @@ -63,8 +69,10 @@ #include #include #include +#if SSOCK_ENABLE_DEPRECATED #include #include +#endif #include #include #include @@ -75,6 +83,7 @@ #include #include #include +#include static constexpr int ns_t_a{1}; static constexpr int ns_t_ns{2}; @@ -137,61 +146,45 @@ namespace ssock::internal_net { namespace ssock { using exception_type = std::exception; - /** - * @brief A class to represent an exception in a socket operation. - */ - class socket_error : public exception_type { - const char* message{"Socket error"}; + class generic_error : public exception_type { + protected: + std::string message; public: - [[nodiscard]] const char* what() const noexcept override { - return message; - } - socket_error() = default; - explicit socket_error(const char* message) : message(message) {}; - explicit socket_error(const std::string& string) : message(string.c_str()) {}; + explicit generic_error(std::string msg) : message(std::move(msg)) {} + [[nodiscard]] const char* what() const noexcept override { return message.c_str(); } }; - /** - * @brief A class to represent an exception trying to parse. - */ - class parsing_error : public exception_type { - const char* message{"Parsing error"}; + class socket_error : public generic_error { public: - [[nodiscard]] const char* what() const noexcept override { - return message; - } - parsing_error() = default; - explicit parsing_error(const char* message) : message(message) {}; - explicit parsing_error(const std::string& string) : message(string.c_str()) {}; + socket_error() : generic_error("Socket error") {} + explicit socket_error(std::string msg) : generic_error(std::move(msg)) {} }; - /** - * @brief A class to represent an exception with an IP address. - */ - class ip_error : public exception_type { - const char* message{"IP error"}; + class parsing_error : public generic_error { public: - [[nodiscard]] const char* what() const noexcept override { - return message; - } - ip_error() = default; - explicit ip_error(const char* message) : message(message) {}; - explicit ip_error(const std::string& string) : message(string.c_str()) {}; + parsing_error() : generic_error("Parsing error") {} + explicit parsing_error(std::string msg) : generic_error(std::move(msg)) {} }; - /** - * @brief A class to represent an exception in DNS resolution. - */ - class dns_error : public exception_type { - const char* message{"DNS error"}; + class ip_error : public generic_error { public: - [[nodiscard]] const char* what() const noexcept override { - return message; - } - dns_error() = default; - explicit dns_error(const char* message) : message(message) {} - explicit dns_error(const std::string& string) : message(string.c_str()) {}; + ip_error() : generic_error("IP error") {} + explicit ip_error(std::string msg) : generic_error(std::move(msg)) {} + }; + + class dns_error : public generic_error { + public: + dns_error() : generic_error("DNS error") {} + explicit dns_error(std::string msg) : generic_error(std::move(msg)) {} + }; + +#ifdef SSOCK_OPENSSL + class ssl_error : public generic_error { + public: + ssl_error() : generic_error("SSL error") {} + explicit ssl_error(std::string msg) : generic_error(std::move(msg)) {} }; +#endif } /** @@ -1897,6 +1890,7 @@ namespace ssock::sock { [[nodiscard]] virtual sock_recv_result recv(int timeout_seconds, const std::string& match) const = 0; [[nodiscard]] virtual sock_recv_result recv(int timeout_seconds, const std::string& match, size_t eof) const = 0; [[nodiscard]] virtual sock_recv_result recv(int timeout_seconds, size_t eof) const = 0; + [[nodiscard]] virtual sock_recv_result primitive_recv() const = 0; [[nodiscard]] virtual std::string overflow_bytes() const = 0; virtual void clear_overflow_bytes() const = 0; virtual void close() = 0; @@ -2105,7 +2099,7 @@ namespace ssock::sock { * @param opts The socket options (reuse_addr, no_reuse_addr). */ sync_sock(int existing_fd, const sock_addr& peer, sock_type t, sock_opt opts = sock_opt::no_reuse_addr|sock_opt::no_delay|sock_opt::blocking) - : sockfd(existing_fd), addr(peer), type(t) { + : addr(peer), type(t), sockfd(existing_fd) { if (sockfd < 0) throw socket_error("invalid fd"); if (this->sockfd >= 0) { this->set_sock_opts(opts); @@ -2469,6 +2463,21 @@ namespace ssock::sock { } } } + + [[nodiscard]] sock_recv_result primitive_recv() const override { + char buf[8192]; + for (;;) { + ssize_t n = ::recv(this->sockfd, buf, sizeof(buf), 0); + if (n > 0) + return {{buf, buf + n}, sock_recv_status::success}; + if (n == 0) + return {{}, sock_recv_status::closed}; + if (errno == EINTR) + continue; + throw socket_error("recv failed"); + } + } + #endif #ifdef SSOCK_WINDOWS [[nodiscard]] sock_recv_result recv(const int timeout_seconds, const std::string& match, size_t eof) const override { @@ -2560,6 +2569,26 @@ namespace ssock::sock { } } } + + sock_recv_result primitive_recv() const override { + constexpr size_t buffer_size = 8192; + char buf[buffer_size]; + + for (;;) { + int n = ::recv(this->sockfd, buf, static_cast(buffer_size), 0); + if (n > 0) { + return {std::string(buf, buf + n), sock_recv_status::success}; + } else if (n == 0) { + return {{}, sock_recv_status::closed}; + } else { + int err = WSAGetLastError(); + if (err == WSAEINTR || err == WSAEWOULDBLOCK || err == WSAEINPROGRESS) { + continue; + } + throw std::runtime_error("recv failed: WSA error " + std::to_string(err)); + } + } + } #endif /* @brief Receive data from the server. @@ -2596,9 +2625,9 @@ namespace ssock::sock { if (this->sockfd == -1) { return; } - if (internal_net::sys_net_close(this->sockfd) < 0) { - throw socket_error("failed to close socket"); - } + + (void)internal_net::sys_net_close(this->sockfd); + this->sockfd = -1; } #endif #ifdef SSOCK_WINDOWS @@ -2607,12 +2636,13 @@ namespace ssock::sock { return; } - if (::closesocket(this->sockfd) == SOCKET_ERROR) { - int err = WSAGetLastError(); - throw socket_error("failed to close socket, error code: " + std::to_string(err)); + ::shutdown(this->sockfd, SD_BOTH); + + if (::closesocket(this->sockfd) != 0) { + ; } - this->sockfd = INVALID_SOCKET; + sockfd = INVALID_SOCKET; } #endif [[nodiscard]] sock_addr get_peer() const { @@ -2731,6 +2761,335 @@ namespace ssock::sock { } } +#ifdef SSOCK_OPENSSL +namespace ssock::crypto { +class ssl_sync_sock { +public: + enum class mode { client, server }; + + explicit ssl_sync_sock(std::unique_ptr underlying, + mode ssl_mode, + const std::string& cert_path = "", + const std::string& key_path = "") + : underlying_sock_(std::move(underlying)), + ssl_mode_(ssl_mode), + cert_path_(cert_path), key_path_(key_path) + { + init_openssl_once(); + create_ssl_context(); + create_ssl_object(); + create_bio(); + + if (ssl_mode_ == mode::client) { + auto underlying_hostname = underlying_sock_->get_addr().get_hostname(); + if (underlying_hostname.empty()) { + throw std::runtime_error{"empty hostname"}; + } + SSL_set_tlsext_host_name(ssl_, underlying_hostname.c_str()); + } + } + + ~ssl_sync_sock() { close(); } + + void connect() { + if (ssl_mode_ != mode::client) + throw std::runtime_error("connect() only valid for client mode"); + underlying_sock_->connect(); + perform_handshake(); + } + + void bind() { underlying_sock_->bind(); } + void unbind() { underlying_sock_->unbind(); } + void listen(int backlog) { underlying_sock_->listen(backlog); } + void listen() { underlying_sock_->listen(); } + + std::unique_ptr accept() { + if (ssl_mode_ != mode::server) + throw std::runtime_error("accept() only valid for server mode"); + auto new_sock = underlying_sock_->accept(); + if (!new_sock) return nullptr; + + auto child = std::make_unique(std::move(new_sock), + mode::server, + cert_path_, key_path_); + child->perform_handshake(); + return child; + } + + int send(const void* buf, size_t len) { + ensure_ready(); + + size_t offset = 0; + while (offset < len) { + int ret = SSL_write( + ssl_, + static_cast(buf) + offset, + static_cast(len - offset)); + + drain_write_bio(); + + if (ret > 0) { + offset += ret; + continue; + } + + int err = SSL_get_error(ssl_, ret); + if (err == SSL_ERROR_WANT_READ) { + feed_read_bio_blocking(); + } else if (err == SSL_ERROR_WANT_WRITE) { + // retry + } else { + throw_ssl_error("SSL_write failed"); + } + } + + return static_cast(len); + } + + void send(const std::string& buf) { send(buf.data(), buf.size()); } + + sock::sock_recv_result recv(int timeout_seconds) const { + return recv_internal(timeout_seconds, nullptr, 0); + } + sock::sock_recv_result recv(int timeout_seconds, const std::string& match) const { + return recv_internal(timeout_seconds, &match, 0); + } + sock::sock_recv_result recv(int timeout_seconds, const std::string& match, size_t eof) const { + return recv_internal(timeout_seconds, &match, eof); + } + sock::sock_recv_result recv(int timeout_seconds, size_t eof) const { + return recv_internal(timeout_seconds, nullptr, eof); + } + + std::string overflow_bytes() const { return overflow_; } + void clear_overflow_bytes() const { overflow_.clear(); } + + void close() { + std::scoped_lock lk(state_mtx_); + + if (ssl_) { + SSL_shutdown(ssl_); // ignore result for sync wrapper + SSL_free(ssl_); + ssl_ = nullptr; + } + + if (ctx_) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + + if (underlying_sock_) { + underlying_sock_->close(); + } + } + +private: + mutable std::string overflow_; + mutable std::mutex state_mtx_; + + std::unique_ptr underlying_sock_; + mode ssl_mode_; + std::string cert_path_; + std::string key_path_; + + SSL_CTX* ctx_ = nullptr; + SSL* ssl_ = nullptr; + + BIO* read_bio_ = nullptr; + BIO* write_bio_ = nullptr; + + static void init_openssl_once() { + static bool initialized = false; + static std::mutex m; + std::scoped_lock lk(m); + if (!initialized) { + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_all_algorithms(); + initialized = true; + } + } + + void create_ssl_context() { + const SSL_METHOD* method = (ssl_mode_ == mode::client) + ? TLS_client_method() + : TLS_server_method(); + ctx_ = SSL_CTX_new(method); + if (!ctx_) throw_ssl_error("SSL_CTX_new failed"); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (ssl_mode_ == mode::server) { + if (SSL_CTX_use_certificate_file(ctx_, cert_path_.c_str(), SSL_FILETYPE_PEM) <= 0) + throw_ssl_error("Failed to load certificate"); + if (SSL_CTX_use_PrivateKey_file(ctx_, key_path_.c_str(), SSL_FILETYPE_PEM) <= 0) + throw_ssl_error("Failed to load private key"); + } else { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + SSL_CTX_set_default_verify_paths(ctx_); + + const char* ca_path = std::getenv("SSL_CERT_FILE"); + if (ca_path) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_path, NULL)) { + throw std::runtime_error{"failed to load ca bundle from environment variable"}; + } + } + + if (!cert_path_.empty()) { + SSL_CTX* ctx = SSL_CTX_new(TLS_client_method()); + if(!SSL_CTX_load_verify_locations(ctx_, cert_path_.c_str(), NULL)) { + throw std::runtime_error{"Failed to load CA bundle"}; + } + } + + X509_VERIFY_PARAM_set1_host(SSL_CTX_get0_param(ctx_), + underlying_sock_->get_addr().get_hostname().c_str(), + 0); + } + } + + void create_ssl_object() { + ssl_ = SSL_new(ctx_); + if (!ssl_) throw_ssl_error("SSL_new failed"); + } + + void create_bio() { + read_bio_ = BIO_new(BIO_s_mem()); + BIO_set_mem_eof_return(read_bio_, -1); + write_bio_ = BIO_new(BIO_s_mem()); + if (!read_bio_ || !write_bio_) + throw_ssl_error("Failed to create memory BIOs"); + + SSL_set_bio(ssl_, read_bio_, write_bio_); + + if (ssl_mode_ == mode::client) { + SSL_set_connect_state(ssl_); + } else { + SSL_set_accept_state(ssl_); + } + } + + bool handshake_complete_ = false; + void perform_handshake() { + while (!SSL_is_init_finished(ssl_)) { + int ret = SSL_do_handshake(ssl_); + drain_write_bio(); + + if (ret == 1) + break; + + int err = SSL_get_error(ssl_, ret); + if (err == SSL_ERROR_WANT_READ) { + feed_read_bio_blocking(); + } else if (err == SSL_ERROR_WANT_WRITE) { + continue; + } else { + throw_ssl_error("TLS handshake failed"); + } + } + + handshake_complete_ = true; + } + + void drain_write_bio() const { + char buf[4096]; + int n; + + while ((n = BIO_read(write_bio_, buf, sizeof(buf))) > 0) { + underlying_sock_->send(buf, n); + } + } + + bool read_eof_ = false; + bool transport_eof_ = false; + void feed_read_bio_blocking() { + auto res = underlying_sock_->primitive_recv(); + + if (res.status == sock::sock_recv_status::closed) { + if (!handshake_complete_) { + throw std::runtime_error("Socket closed during TLS handshake"); + } + + BIO_set_mem_eof_return(read_bio_, 0); + return; + } + + if (res.status != sock::sock_recv_status::success) + throw std::runtime_error("Socket read failed"); + + if (!res.data.empty()) { + int written = BIO_write( + read_bio_, + res.data.data(), + static_cast(res.data.size())); + if (written <= 0) + throw_ssl_error("BIO_write failed"); + } + } + + void ensure_ready() const { + if (!ssl_) throw std::runtime_error("SSL socket closed"); + } + + sock::sock_recv_result recv_internal(int, const std::string* match, size_t eof) const { + ensure_ready(); + sock::sock_recv_result result; + + if (!overflow_.empty()) { + result.data = std::exchange(overflow_, ""); + } + + while (true) { + char buf[4096]; + int ret = SSL_read(ssl_, buf, sizeof(buf)); + drain_write_bio(); + + if (ret > 0) { + result.data.append(buf, ret); + } else { + int err = SSL_get_error(ssl_, ret); + if (err == SSL_ERROR_WANT_READ) { + const_cast(this)->feed_read_bio_blocking(); + continue; + } else if (err == SSL_ERROR_WANT_WRITE) { + continue; + } else if (err == SSL_ERROR_ZERO_RETURN) { + result.status = sock::sock_recv_status::closed; + break; + } else { + result.status = sock::sock_recv_status::error; + break; + } + } + + if (match) { + auto pos = result.data.find(*match); + if (pos != std::string::npos) { + overflow_ = result.data.substr(pos + match->size()); + result.data.resize(pos + match->size()); + break; + } + } + + if (eof && result.data.size() >= eof) { + overflow_ = result.data.substr(eof); + result.data.resize(eof); + break; + } + } + + return result; + } + + static void throw_ssl_error(const std::string& msg) { + char buf[256]; + ERR_error_string_n(ERR_get_error(), buf, sizeof(buf)); + throw std::runtime_error(msg + ": " + buf); + } +}; +} +#endif + namespace ssock::network::dns { class dns_nameserver_list { std::vector ipv4{}; @@ -3358,7 +3717,8 @@ namespace ssock::network::dns { std::vector all_records; - auto send_udp = [&](const std::string& server, ssock::sock::sock_addr_type family) -> std::optional> { + auto send_udp = [&](const std::string &server, + ssock::sock::sock_addr_type family) -> std::optional > { ssock::sock::sock_addr addr(server, 53, family); ssock::sock::sync_sock sock( addr,