11#pragma once
22
33#include < mutex>
4+ #include < random>
45#include < variant>
56
67#include " glaze/net/http_client.hpp"
@@ -30,13 +31,13 @@ namespace glz
3031
3132 std::shared_ptr<asio::io_context> ctx_;
3233 connection_variant connection_;
33- std::mutex connection_mutex_;
34+ std::shared_ptr<std:: mutex> connection_mutex_;
3435
35- // Callbacks
36- message_handler_t on_message_;
37- open_handler_t on_open_;
38- close_handler_t on_close_;
39- error_handler_t on_error_;
36+ // Callbacks - wrapped in shared_ptr to survive client destruction
37+ std::shared_ptr< message_handler_t > on_message_;
38+ std::shared_ptr< open_handler_t > on_open_;
39+ std::shared_ptr< close_handler_t > on_close_;
40+ std::shared_ptr< error_handler_t > on_error_;
4041
4142 // Keep alive the resolver/socket during connection
4243 std::shared_ptr<asio::ip::tcp::resolver> resolver_;
@@ -52,12 +53,16 @@ namespace glz
5253 ctx_ = ctx;
5354 else
5455 ctx_ = std::make_shared<asio::io_context>();
56+
57+ connection_mutex_ = std::make_shared<std::mutex>();
5558 }
5659
57- void on_message (message_handler_t handler) { on_message_ = std::move (handler); }
58- void on_open (open_handler_t handler) { on_open_ = std::move (handler); }
59- void on_close (close_handler_t handler) { on_close_ = std::move (handler); }
60- void on_error (error_handler_t handler) { on_error_ = std::move (handler); }
60+ ~websocket_client () = default ;
61+
62+ void on_message (message_handler_t handler) { on_message_ = std::make_shared<message_handler_t >(std::move (handler)); }
63+ void on_open (open_handler_t handler) { on_open_ = std::make_shared<open_handler_t >(std::move (handler)); }
64+ void on_close (close_handler_t handler) { on_close_ = std::make_shared<close_handler_t >(std::move (handler)); }
65+ void on_error (error_handler_t handler) { on_error_ = std::make_shared<error_handler_t >(std::move (handler)); }
6166
6267 void set_max_message_size (size_t size) { max_message_size_ = size; }
6368
@@ -67,7 +72,7 @@ namespace glz
6772 {
6873 auto url_result = parse_url (url_str);
6974 if (!url_result) {
70- if (on_error_) on_error_ (url_result.error ());
75+ if (on_error_ && *on_error_) (* on_error_) (url_result.error ());
7176 return ;
7277 }
7378
@@ -91,7 +96,7 @@ namespace glz
9196 return ;
9297 }
9398#else
94- if (on_error_) on_error_ (std::make_error_code (std::errc::protocol_not_supported));
99+ if (on_error_ && *on_error_) (* on_error_) (std::make_error_code (std::errc::protocol_not_supported));
95100 return ;
96101#endif
97102 }
@@ -101,30 +106,33 @@ namespace glz
101106
102107 resolver_ = std::make_shared<asio::ip::tcp::resolver>(*ctx_);
103108
109+ auto error_handler = on_error_; // Copy shared_ptr
104110 resolver_->async_resolve (
105111 url.host , std::to_string (url.port ),
106- [this , url](std::error_code ec, asio::ip::tcp::resolver::results_type results) {
112+ [this , url, error_handler ](std::error_code ec, asio::ip::tcp::resolver::results_type results) {
107113 if (ec) {
108- if (on_error_) on_error_ (ec);
114+ if (error_handler && *error_handler) (*error_handler) (ec);
109115 return ;
110116 }
111117
112118 // Determine which socket to connect
113119 auto & socket_ref = get_tcp_socket_ref ();
114120
121+ auto error_handler2 = on_error_; // Copy shared_ptr
115122 asio::async_connect (
116- socket_ref, results, [this , url](std::error_code ec, const asio::ip::tcp::endpoint&) {
123+ socket_ref, results, [this , url, error_handler2 ](std::error_code ec, const asio::ip::tcp::endpoint&) {
117124 if (ec) {
118- if (on_error_) on_error_ (ec);
125+ if (error_handler2 && *error_handler2) (*error_handler2) (ec);
119126 return ;
120127 }
121128
122129 if (url.protocol == " wss" ) {
123130#ifdef GLZ_ENABLE_SSL
124131 // Perform SSL Handshake
125- ssl_socket_->async_handshake (asio::ssl::stream_base::client, [this , url](std::error_code ec) {
132+ auto error_handler3 = on_error_;
133+ ssl_socket_->async_handshake (asio::ssl::stream_base::client, [this , url, error_handler3](std::error_code ec) {
126134 if (ec) {
127- if (on_error_) on_error_ (ec);
135+ if (error_handler3 && *error_handler3) (*error_handler3) (ec);
128136 return ;
129137 }
130138 perform_handshake (ssl_socket_, url);
@@ -140,7 +148,10 @@ namespace glz
140148
141149 void send (std::string_view msg)
142150 {
143- std::lock_guard<std::mutex> lock (connection_mutex_);
151+ auto mutex = connection_mutex_; // Copy shared_ptr to keep mutex alive
152+ if (!mutex) return ;
153+
154+ std::lock_guard<std::mutex> lock (*mutex);
144155 std::visit (
145156 [&](auto && conn) {
146157 if constexpr (!std::is_same_v<std::decay_t <decltype (conn)>, std::monostate>) {
@@ -150,9 +161,27 @@ namespace glz
150161 connection_);
151162 }
152163
164+ void send_binary (std::string_view msg)
165+ {
166+ auto mutex = connection_mutex_; // Copy shared_ptr to keep mutex alive
167+ if (!mutex) return ;
168+
169+ std::lock_guard<std::mutex> lock (*mutex);
170+ std::visit (
171+ [&](auto && conn) {
172+ if constexpr (!std::is_same_v<std::decay_t <decltype (conn)>, std::monostate>) {
173+ if (conn) conn->send_binary (msg);
174+ }
175+ },
176+ connection_);
177+ }
178+
153179 void close ()
154180 {
155- std::lock_guard<std::mutex> lock (connection_mutex_);
181+ auto mutex = connection_mutex_; // Copy shared_ptr to keep mutex alive
182+ if (!mutex) return ;
183+
184+ std::lock_guard<std::mutex> lock (*mutex);
156185 std::visit (
157186 [&](auto && conn) {
158187 if constexpr (!std::is_same_v<std::decay_t <decltype (conn)>, std::monostate>) {
@@ -188,11 +217,12 @@ namespace glz
188217 " \r\n " + " Sec-WebSocket-Version: 13\r\n\r\n " ;
189218
190219 auto req_buf = std::make_shared<std::string>(std::move (handshake));
220+ auto error_handler = on_error_; // Copy shared_ptr
191221
192222 asio::async_write (*socket, asio::buffer (*req_buf),
193- [this , socket, req_buf /* keep alive */ , key](std::error_code ec, std::size_t ) {
223+ [this , socket, req_buf, key, error_handler ](std::error_code ec, std::size_t ) {
194224 if (ec) {
195- if (on_error_) on_error_ (ec);
225+ if (error_handler && *error_handler) (*error_handler) (ec);
196226 return ;
197227 }
198228 read_handshake_response (socket, key);
@@ -205,11 +235,17 @@ namespace glz
205235 // Limit handshake response size to 16KB to prevent DoS
206236 static constexpr size_t max_handshake_size = 1024 * 16 ;
207237 auto response_buf = std::make_shared<asio::streambuf>(max_handshake_size);
238+ auto error_handler = on_error_; // Copy shared_ptr
239+ auto message_handler = on_message_;
240+ auto close_handler = on_close_;
241+ auto open_handler = on_open_;
242+ auto mutex = connection_mutex_;
243+ auto max_msg_size = max_message_size_;
208244
209245 asio::async_read_until (*socket, *response_buf, " \r\n\r\n " ,
210- [this , socket, response_buf, expected_key](std::error_code ec, std::size_t ) {
246+ [this , socket, response_buf, expected_key, error_handler, message_handler, close_handler, open_handler, mutex, max_msg_size ](std::error_code ec, std::size_t ) {
211247 if (ec) {
212- if (on_error_) on_error_ (ec);
248+ if (error_handler && *error_handler) (*error_handler) (ec);
213249 return ;
214250 }
215251
@@ -222,7 +258,7 @@ namespace glz
222258 std::getline (response_stream, status_message);
223259
224260 if (!response_stream || status_code != 101 ) {
225- if (on_error_) on_error_ (std::make_error_code (std::errc::protocol_error));
261+ if (error_handler && *error_handler) (*error_handler) (std::make_error_code (std::errc::protocol_error));
226262 return ;
227263 }
228264
@@ -262,33 +298,35 @@ namespace glz
262298 }
263299
264300 if (!upgrade_websocket || !connection_upgrade || !accept_key_valid) {
265- if (on_error_) on_error_ (std::make_error_code (std::errc::protocol_error));
301+ if (error_handler && *error_handler) (*error_handler) (std::make_error_code (std::errc::protocol_error));
266302 return ;
267303 }
268304
269305 // Handshake successful. Transfer socket to websocket_connection.
270306 auto ws_conn =
271- std::make_shared<websocket_connection<SocketType>>(std::move (* socket) );
307+ std::make_shared<websocket_connection<SocketType>>(socket);
272308 ws_conn->set_client_mode (true );
273- ws_conn->set_max_message_size (max_message_size_ );
309+ ws_conn->set_max_message_size (max_msg_size );
274310
275311 if (response_buf->size () > 0 ) {
276312 std::string_view initial_data{
277313 static_cast <const char *>(response_buf->data ().data ()), response_buf->size ()};
278314 ws_conn->set_initial_data (initial_data);
279315 }
280316
281- if (on_message_ ) ws_conn->on_message (on_message_ );
282- if (on_close_ ) ws_conn->on_close (on_close_ );
283- if (on_error_ ) ws_conn->on_error (on_error_ );
317+ if (message_handler && *message_handler ) ws_conn->on_message (*message_handler );
318+ if (close_handler && *close_handler ) ws_conn->on_close (*close_handler );
319+ if (error_handler && *error_handler ) ws_conn->on_error (*error_handler );
284320
285321 ws_conn->start_read ();
286322 {
287- std::lock_guard<std::mutex> lock (connection_mutex_);
288- connection_ = ws_conn;
323+ if (mutex) {
324+ std::lock_guard<std::mutex> lock (*mutex);
325+ connection_ = ws_conn;
326+ }
289327 }
290328
291- if (on_open_) on_open_ ();
329+ if (open_handler && *open_handler) (*open_handler) ();
292330 });
293331 }
294332 };
0 commit comments