Skip to content

Commit 53b157f

Browse files
authored
websocket client fixes (#2071)
* websocket client fixes * Update websocket-client.md * Update websocket_client_test.cpp * Update websocket_client_test.cpp * Update websocket_client_test.cpp * Update websocket_client_test.cpp * more diagnostics * Update websocket_client_test.cpp * Update websocket_connection.hpp * Update websocket_connection.hpp * Update websocket_client_test.cpp * updates * fix payload length cast for 32bit
1 parent 443ac7b commit 53b157f

File tree

10 files changed

+1695
-125
lines changed

10 files changed

+1695
-125
lines changed

docs/networking/websocket-client.md

Lines changed: 692 additions & 0 deletions
Large diffs are not rendered by default.

include/glaze/net/http_client.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ namespace glz
7474
}
7575

7676
std::string protocol(url.substr(0, protocol_end));
77-
if (protocol != "http" && protocol != "https") {
77+
if (protocol != "http" && protocol != "https" && protocol != "ws" && protocol != "wss") {
7878
return std::unexpected(std::make_error_code(std::errc::invalid_argument));
7979
}
8080

@@ -120,7 +120,7 @@ namespace glz
120120

121121
uint16_t port = 0;
122122
if (port_str.empty()) {
123-
port = (protocol == "https") ? 443 : 80;
123+
port = (protocol == "https" || protocol == "wss") ? 443 : 80;
124124
}
125125
else {
126126
try {

include/glaze/net/http_server.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1377,7 +1377,7 @@ namespace glz
13771377

13781378
// Create WebSocket connection and start it
13791379
// Need to include websocket_connection.hpp for this to work
1380-
auto ws_conn = std::make_shared<websocket_connection>(std::move(*socket), ws_it->second.get());
1380+
auto ws_conn = std::make_shared<websocket_connection<asio::ip::tcp::socket>>(socket, ws_it->second.get());
13811381
ws_conn->start(req);
13821382
}
13831383

include/glaze/net/websocket_client.hpp

Lines changed: 72 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <mutex>
4+
#include <random>
45
#include <variant>
56

67
#include "glaze/net/http_client.hpp"
@@ -30,13 +31,13 @@ namespace glz
3031

3132
std::shared_ptr<asio::io_context> ctx_;
3233
connection_variant connection_;
33-
std::mutex connection_mutex_;
34+
std::shared_ptr<std::mutex> connection_mutex_;
3435

35-
// Callbacks
36-
message_handler_t on_message_;
37-
open_handler_t on_open_;
38-
close_handler_t on_close_;
39-
error_handler_t on_error_;
36+
// Callbacks - wrapped in shared_ptr to survive client destruction
37+
std::shared_ptr<message_handler_t> on_message_;
38+
std::shared_ptr<open_handler_t> on_open_;
39+
std::shared_ptr<close_handler_t> on_close_;
40+
std::shared_ptr<error_handler_t> on_error_;
4041

4142
// Keep alive the resolver/socket during connection
4243
std::shared_ptr<asio::ip::tcp::resolver> resolver_;
@@ -52,12 +53,16 @@ namespace glz
5253
ctx_ = ctx;
5354
else
5455
ctx_ = std::make_shared<asio::io_context>();
56+
57+
connection_mutex_ = std::make_shared<std::mutex>();
5558
}
5659

57-
void on_message(message_handler_t handler) { on_message_ = std::move(handler); }
58-
void on_open(open_handler_t handler) { on_open_ = std::move(handler); }
59-
void on_close(close_handler_t handler) { on_close_ = std::move(handler); }
60-
void on_error(error_handler_t handler) { on_error_ = std::move(handler); }
60+
~websocket_client() = default;
61+
62+
void on_message(message_handler_t handler) { on_message_ = std::make_shared<message_handler_t>(std::move(handler)); }
63+
void on_open(open_handler_t handler) { on_open_ = std::make_shared<open_handler_t>(std::move(handler)); }
64+
void on_close(close_handler_t handler) { on_close_ = std::make_shared<close_handler_t>(std::move(handler)); }
65+
void on_error(error_handler_t handler) { on_error_ = std::make_shared<error_handler_t>(std::move(handler)); }
6166

6267
void set_max_message_size(size_t size) { max_message_size_ = size; }
6368

@@ -67,7 +72,7 @@ namespace glz
6772
{
6873
auto url_result = parse_url(url_str);
6974
if (!url_result) {
70-
if (on_error_) on_error_(url_result.error());
75+
if (on_error_ && *on_error_) (*on_error_)(url_result.error());
7176
return;
7277
}
7378

@@ -91,7 +96,7 @@ namespace glz
9196
return;
9297
}
9398
#else
94-
if (on_error_) on_error_(std::make_error_code(std::errc::protocol_not_supported));
99+
if (on_error_ && *on_error_) (*on_error_)(std::make_error_code(std::errc::protocol_not_supported));
95100
return;
96101
#endif
97102
}
@@ -101,30 +106,33 @@ namespace glz
101106

102107
resolver_ = std::make_shared<asio::ip::tcp::resolver>(*ctx_);
103108

109+
auto error_handler = on_error_; // Copy shared_ptr
104110
resolver_->async_resolve(
105111
url.host, std::to_string(url.port),
106-
[this, url](std::error_code ec, asio::ip::tcp::resolver::results_type results) {
112+
[this, url, error_handler](std::error_code ec, asio::ip::tcp::resolver::results_type results) {
107113
if (ec) {
108-
if (on_error_) on_error_(ec);
114+
if (error_handler && *error_handler) (*error_handler)(ec);
109115
return;
110116
}
111117

112118
// Determine which socket to connect
113119
auto& socket_ref = get_tcp_socket_ref();
114120

121+
auto error_handler2 = on_error_; // Copy shared_ptr
115122
asio::async_connect(
116-
socket_ref, results, [this, url](std::error_code ec, const asio::ip::tcp::endpoint&) {
123+
socket_ref, results, [this, url, error_handler2](std::error_code ec, const asio::ip::tcp::endpoint&) {
117124
if (ec) {
118-
if (on_error_) on_error_(ec);
125+
if (error_handler2 && *error_handler2) (*error_handler2)(ec);
119126
return;
120127
}
121128

122129
if (url.protocol == "wss") {
123130
#ifdef GLZ_ENABLE_SSL
124131
// Perform SSL Handshake
125-
ssl_socket_->async_handshake(asio::ssl::stream_base::client, [this, url](std::error_code ec) {
132+
auto error_handler3 = on_error_;
133+
ssl_socket_->async_handshake(asio::ssl::stream_base::client, [this, url, error_handler3](std::error_code ec) {
126134
if (ec) {
127-
if (on_error_) on_error_(ec);
135+
if (error_handler3 && *error_handler3) (*error_handler3)(ec);
128136
return;
129137
}
130138
perform_handshake(ssl_socket_, url);
@@ -140,7 +148,10 @@ namespace glz
140148

141149
void send(std::string_view msg)
142150
{
143-
std::lock_guard<std::mutex> lock(connection_mutex_);
151+
auto mutex = connection_mutex_; // Copy shared_ptr to keep mutex alive
152+
if (!mutex) return;
153+
154+
std::lock_guard<std::mutex> lock(*mutex);
144155
std::visit(
145156
[&](auto&& conn) {
146157
if constexpr (!std::is_same_v<std::decay_t<decltype(conn)>, std::monostate>) {
@@ -150,9 +161,27 @@ namespace glz
150161
connection_);
151162
}
152163

164+
void send_binary(std::string_view msg)
165+
{
166+
auto mutex = connection_mutex_; // Copy shared_ptr to keep mutex alive
167+
if (!mutex) return;
168+
169+
std::lock_guard<std::mutex> lock(*mutex);
170+
std::visit(
171+
[&](auto&& conn) {
172+
if constexpr (!std::is_same_v<std::decay_t<decltype(conn)>, std::monostate>) {
173+
if (conn) conn->send_binary(msg);
174+
}
175+
},
176+
connection_);
177+
}
178+
153179
void close()
154180
{
155-
std::lock_guard<std::mutex> lock(connection_mutex_);
181+
auto mutex = connection_mutex_; // Copy shared_ptr to keep mutex alive
182+
if (!mutex) return;
183+
184+
std::lock_guard<std::mutex> lock(*mutex);
156185
std::visit(
157186
[&](auto&& conn) {
158187
if constexpr (!std::is_same_v<std::decay_t<decltype(conn)>, std::monostate>) {
@@ -188,11 +217,12 @@ namespace glz
188217
"\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n";
189218

190219
auto req_buf = std::make_shared<std::string>(std::move(handshake));
220+
auto error_handler = on_error_; // Copy shared_ptr
191221

192222
asio::async_write(*socket, asio::buffer(*req_buf),
193-
[this, socket, req_buf /* keep alive */, key](std::error_code ec, std::size_t) {
223+
[this, socket, req_buf, key, error_handler](std::error_code ec, std::size_t) {
194224
if (ec) {
195-
if (on_error_) on_error_(ec);
225+
if (error_handler && *error_handler) (*error_handler)(ec);
196226
return;
197227
}
198228
read_handshake_response(socket, key);
@@ -205,11 +235,17 @@ namespace glz
205235
// Limit handshake response size to 16KB to prevent DoS
206236
static constexpr size_t max_handshake_size = 1024 * 16;
207237
auto response_buf = std::make_shared<asio::streambuf>(max_handshake_size);
238+
auto error_handler = on_error_; // Copy shared_ptr
239+
auto message_handler = on_message_;
240+
auto close_handler = on_close_;
241+
auto open_handler = on_open_;
242+
auto mutex = connection_mutex_;
243+
auto max_msg_size = max_message_size_;
208244

209245
asio::async_read_until(*socket, *response_buf, "\r\n\r\n",
210-
[this, socket, response_buf, expected_key](std::error_code ec, std::size_t) {
246+
[this, socket, response_buf, expected_key, error_handler, message_handler, close_handler, open_handler, mutex, max_msg_size](std::error_code ec, std::size_t) {
211247
if (ec) {
212-
if (on_error_) on_error_(ec);
248+
if (error_handler && *error_handler) (*error_handler)(ec);
213249
return;
214250
}
215251

@@ -222,7 +258,7 @@ namespace glz
222258
std::getline(response_stream, status_message);
223259

224260
if (!response_stream || status_code != 101) {
225-
if (on_error_) on_error_(std::make_error_code(std::errc::protocol_error));
261+
if (error_handler && *error_handler) (*error_handler)(std::make_error_code(std::errc::protocol_error));
226262
return;
227263
}
228264

@@ -262,33 +298,35 @@ namespace glz
262298
}
263299

264300
if (!upgrade_websocket || !connection_upgrade || !accept_key_valid) {
265-
if (on_error_) on_error_(std::make_error_code(std::errc::protocol_error));
301+
if (error_handler && *error_handler) (*error_handler)(std::make_error_code(std::errc::protocol_error));
266302
return;
267303
}
268304

269305
// Handshake successful. Transfer socket to websocket_connection.
270306
auto ws_conn =
271-
std::make_shared<websocket_connection<SocketType>>(std::move(*socket));
307+
std::make_shared<websocket_connection<SocketType>>(socket);
272308
ws_conn->set_client_mode(true);
273-
ws_conn->set_max_message_size(max_message_size_);
309+
ws_conn->set_max_message_size(max_msg_size);
274310

275311
if (response_buf->size() > 0) {
276312
std::string_view initial_data{
277313
static_cast<const char*>(response_buf->data().data()), response_buf->size()};
278314
ws_conn->set_initial_data(initial_data);
279315
}
280316

281-
if (on_message_) ws_conn->on_message(on_message_);
282-
if (on_close_) ws_conn->on_close(on_close_);
283-
if (on_error_) ws_conn->on_error(on_error_);
317+
if (message_handler && *message_handler) ws_conn->on_message(*message_handler);
318+
if (close_handler && *close_handler) ws_conn->on_close(*close_handler);
319+
if (error_handler && *error_handler) ws_conn->on_error(*error_handler);
284320

285321
ws_conn->start_read();
286322
{
287-
std::lock_guard<std::mutex> lock(connection_mutex_);
288-
connection_ = ws_conn;
323+
if (mutex) {
324+
std::lock_guard<std::mutex> lock(*mutex);
325+
connection_ = ws_conn;
326+
}
289327
}
290328

291-
if (on_open_) on_open_();
329+
if (open_handler && *open_handler) (*open_handler)();
292330
});
293331
}
294332
};

0 commit comments

Comments
 (0)