Skip to content

Commit 20abeae

Browse files
kostasrimromange
authored andcommitted
chore: write path for TlsSocket::AsyncReadSome (#415)
This pr implements the asynchronous write path (when engine state NEED_WRITE) for TlsSocket::AsyncReadSome. * add async write * add test Signed-off-by: kostas <[email protected]> Signed-off-by: Roman Gershman <[email protected]>
1 parent af21832 commit 20abeae

File tree

7 files changed

+116
-46
lines changed

7 files changed

+116
-46
lines changed

.github/workflows/mac-os.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ jobs:
4545
ninja -k 5 base/all io/all strings/all util/all echo_server ping_iouring_server \
4646
https_client_cli s3_demo
4747
./fibers_test --logtostderr --gtest_repeat=10
48-
ctest -V -L CI
48+
GLOG_logtostderr=1 ctest -V -L CI

io/io.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct AsyncReadState {
5050

5151
AsyncReadState(AsyncSource* source, const iovec* v, uint32_t length)
5252
: arr(length), owner(source) {
53+
cur = arr.data();
5354
std::copy(v, v + length, arr.data());
5455
}
5556

util/fibers/epoll_socket.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ void EpollSocket::AsyncWriteSome(const iovec* v, uint32_t len, io::AsyncProgress
386386
async_write_pending_ = 1;
387387
}
388388

389+
// TODO implement async functionality
389390
void EpollSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
390391
auto res = ReadSome(v, len);
391392
cb(res);

util/fibers/listener_interface.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ void ListenerInterface::RunAcceptLoop() {
9898

9999
if (!sock_->IsUDS()) {
100100
ep = sock_->LocalEndpoint();
101-
VSOCK(0, *sock_) << "AcceptServer - listening on port " << ep.port();
101+
VSOCK(0, *sock_) << "AcceptServer - listening on " << ep.address() << ":" << ep.port();
102102
}
103103

104104
PreAcceptLoop(sock_->proactor());

util/tls/tls_socket.cc

Lines changed: 90 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,11 @@ void TlsSocket::SetProactor(ProactorBase* p) {
516516

517517
void TlsSocket::AsyncReq::MaybeSendOutputAsyncWithRead() {
518518
if (owner->engine_->OutputPending() != 0) {
519-
// sync interface, works because we are still executing within a fiber
520-
// used for "mocking" and shall be replaced on the next PR with actual async op
521-
owner->MaybeSendOutput();
519+
// Once the networking socket completes the write, it will start the read path
520+
// We use this bool to signal this.
521+
should_read = true;
522+
StartUpstreamWrite();
523+
return;
522524
}
523525

524526
// TODO handle WRITE_IN_PROGRESS here by adding pending_blocked_
@@ -535,16 +537,17 @@ void TlsSocket::AsyncReq::AsyncProgressCb(io::Result<size_t> read_result) {
535537
VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
536538
<< ", write_total:" << owner->upstream_write_ << " "
537539
<< " pending output: " << owner->engine_->OutputPending() << " "
538-
<< "StartUpstreamRead failed " << read_result.error();
540+
<< "AsyncProgressCb failed " << read_result.error();
539541
}
540542
// Erronous path. Apply the completion callback and exit.
541543
CompleteAsyncReq(read_result);
542544
return;
543545
}
544546

545-
DVLOG(1) << "HandleUpstreamRead " << *read_result << " bytes";
547+
DVLOG(1) << "AsyncProgressCb " << *read_result << " bytes";
546548
owner->engine_->CommitInput(*read_result);
547-
Engine::OpResult engine_read = owner->MaybeReadFromEngine(vec, len);
549+
Engine::OpResult engine_read =
550+
owner->engine_->Read(reinterpret_cast<uint8_t*>(vec->iov_base), vec->iov_len);
548551
if (engine_read > 0) {
549552
CompleteAsyncReq(engine_read);
550553
return;
@@ -579,42 +582,21 @@ void TlsSocket::AsyncReq::HandleOpAsync() {
579582
case Engine::NEED_READ_AND_MAYBE_WRITE:
580583
MaybeSendOutputAsyncWithRead();
581584
break;
582-
// TODO handle NEED_WRITE
585+
case Engine::NEED_WRITE:
586+
MaybeSendOutputAsync();
587+
break;
583588
default:
584589
// EOF_STREAM should be handled earlier
585590
LOG(DFATAL) << "Unsupported " << op_val;
586591
}
587592
}
588593

589-
Engine::OpResult TlsSocket::MaybeReadFromEngine(const iovec* v, uint32_t len) {
590-
size_t read_len = std::min(v->iov_len, size_t(INT_MAX));
591-
Engine::OpResult op_val = engine_->Read(reinterpret_cast<uint8_t*>(v->iov_base), read_len);
592-
DVLOG(2) << "Engine::Read " << read_len << " bytes, got " << op_val;
593-
// if read_len == op_val we could try to read more. However, the next read might require
594-
// an async operation on the underline socket because op_val < 0.
595-
// The problem here is that SSL_read from engine_->Read is *not* idempotent and we might
596-
// end up in a situation where we need to do two things at the same time:
597-
// 1. Call the callers completion callback which will start another async op because
598-
// we read less bytes than what was requested, i.e, read_total < sum_of_all(v->len).
599-
// 2. Start another async operation to satisfy the protocol because op_val < 0 and we
600-
// called engine_->Read which is *not* idempotent.
601-
// For that, it's best to let it flow naturally. If there is some data in the engine read it
602-
// and call the completion callback which will in turn try to read more from the engine.
603-
// It will read everything or reach to a point that an async operation needs to be dispatched.
604-
// That way, we get a linear view of the operations involved with the downside of a few more
605-
// function calls (since we don't try to drain the whole engine as we don't know if the next
606-
// read can be satisfied or dispatch as an async operation).
607-
// Last but not least, it was advised here:
608-
// https://github.com/romange/helio/pull/408#discussion_r2080998216
609-
// That we should remove engine reads from the AsyncRequest all together and return
610-
// to the caller if there was some data read.
611-
return op_val;
612-
}
613-
614594
void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb cb) {
595+
// Engine read
615596
CHECK(!async_read_req_);
616597

617-
Engine::OpResult op_val = MaybeReadFromEngine(v, len);
598+
Engine::OpResult op_val = engine_->Read(reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len);
599+
DVLOG(2) << "Engine::Read tried to read " << v->iov_len << " bytes, got " << op_val;
618600
// We read some data from the engine. Satisfy the request and return.
619601
if (op_val > 0) {
620602
return cb(op_val);
@@ -632,5 +614,80 @@ void TlsSocket::AsyncReadSome(const iovec* v, uint32_t len, io::AsyncProgressCb
632614
async_read_req_->HandleOpAsync();
633615
}
634616

617+
void TlsSocket::AsyncReq::CompleteAsyncWrite(io::Result<size_t> write_result) {
618+
if (!write_result) {
619+
owner->state_ &= ~WRITE_IN_PROGRESS;
620+
621+
// broken_pipe - happens when the other side closes the connection. do not log this.
622+
if (write_result.error() != errc::broken_pipe) {
623+
VLOG(1) << "sock[" << owner->native_handle() << "], state " << int(owner->state_)
624+
<< ", write_total:" << owner->upstream_write_ << " "
625+
<< " pending output: " << owner->engine_->OutputPending()
626+
<< " HandleUpstreamAsyncWrite failed " << write_result.error();
627+
}
628+
629+
// We are done. Errornous exit.
630+
CompleteAsyncReq(write_result);
631+
return;
632+
}
633+
634+
CHECK_GT(*write_result, 0u);
635+
owner->upstream_write_ += *write_result;
636+
owner->engine_->ConsumeOutputBuf(*write_result);
637+
// We might have more data pending. Peek again.
638+
Buffer buffer = owner->engine_->PeekOutputBuf();
639+
640+
// We are not done. Re-arm the async write until we drive it to completion or error.
641+
// We would also like to avoid fragmented socket writes so we make sure we drain it here
642+
if (!buffer.empty()) {
643+
auto& scratch = scratch_iovec;
644+
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
645+
scratch.iov_len = buffer.size();
646+
owner->next_sock_->AsyncWriteSome(
647+
&scratch, 1, [this](auto write_result) { CompleteAsyncWrite(write_result); });
648+
return;
649+
}
650+
651+
if (owner->engine_->OutputPending() > 0) {
652+
LOG(DFATAL) << "ssl buffer is not empty with " << owner->engine_->OutputPending()
653+
<< " bytes. Async short write detected";
654+
}
655+
656+
owner->state_ &= ~WRITE_IN_PROGRESS;
657+
658+
// We are done with the writes, check if we also need to read because we are
659+
// in NEED_READ_AND_MAYBE_WRITE state
660+
if (should_read) {
661+
should_read = false;
662+
StartUpstreamRead();
663+
}
664+
}
665+
666+
void TlsSocket::AsyncReq::StartUpstreamWrite() {
667+
Engine::Buffer buffer = owner->engine_->PeekOutputBuf();
668+
DCHECK(!buffer.empty());
669+
DCHECK((owner->state_ & WRITE_IN_PROGRESS) == 0);
670+
671+
DVLOG(2) << "StartUpstreamWrite " << buffer.size();
672+
// we do not allow concurrent writes from multiple fibers.
673+
owner->state_ |= WRITE_IN_PROGRESS;
674+
675+
auto& scratch = scratch_iovec;
676+
scratch.iov_base = const_cast<uint8_t*>(buffer.data());
677+
scratch.iov_len = buffer.size();
678+
679+
owner->next_sock_->AsyncWriteSome(
680+
&scratch, 1, [this](auto write_result) { CompleteAsyncWrite(write_result); });
681+
}
682+
683+
void TlsSocket::AsyncReq::MaybeSendOutputAsync() {
684+
if (owner->engine_->OutputPending() == 0) {
685+
return;
686+
}
687+
688+
// TODO handle WRITE_IN_PROGRESS to avoid deadlock
689+
StartUpstreamWrite();
690+
}
691+
635692
} // namespace tls
636693
} // namespace util

util/tls/tls_socket.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class TlsSocket final : public FiberSocketBase {
116116
enum { WRITE_IN_PROGRESS = 1, READ_IN_PROGRESS = 2, SHUTDOWN_IN_PROGRESS = 4, SHUTDOWN_DONE = 8 };
117117
uint8_t state_{0};
118118

119+
// TODO turn this into a class with proper access specifiers
119120
struct AsyncReq {
120121
TlsSocket* owner;
121122
// Callback passed from the user.
@@ -127,26 +128,23 @@ class TlsSocket final : public FiberSocketBase {
127128

128129
iovec scratch_iovec;
129130

131+
bool should_read = false;
132+
130133
// Asynchronous helpers
131134
void MaybeSendOutputAsyncWithRead();
135+
void MaybeSendOutputAsync();
132136

133137
void HandleOpAsync();
134138

135139
void StartUpstreamRead();
140+
void StartUpstreamWrite();
136141

137142
void CompleteAsyncReq(io::Result<size_t> result);
143+
void CompleteAsyncWrite(io::Result<size_t> write_result);
138144

139145
void AsyncProgressCb(io::Result<size_t> result);
140146
};
141147

142-
// Helper function that resets the internal async request, applies the
143-
// user AsyncProgressCb and returns. We need this, because progress callbacks
144-
// can start another async request and for that to work, we need to clean up
145-
// the one we are running on.
146-
void CompleteAsyncRequest(io::Result<size_t> result);
147-
148-
Engine::OpResult MaybeReadFromEngine(const iovec* v, uint32_t len);
149-
150148
std::unique_ptr<AsyncReq> async_read_req_;
151149
};
152150

util/tls/tls_socket_test.cc

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,17 @@ class AsyncTlsSocketTest : public testing::TestWithParam<string_view> {
227227

228228
using IoResult = int;
229229

230-
// TODO clean up
231230
virtual void HandleRequest() {
232231
tls_socket_ = std::make_unique<tls::TlsSocket>(conn_socket_.release());
233232
ssl_ctx_ = CreateSslCntx(SERVER);
234233
tls_socket_->InitSSL(ssl_ctx_);
235234
tls_socket_->Accept();
236235

237236
uint8_t buf[16];
238-
std::fill(std::begin(buf), std::end(buf), uint8_t(120));
237+
auto res = tls_socket_->Recv(buf);
238+
EXPECT_TRUE(res.has_value());
239+
EXPECT_TRUE(res.value() == 16);
240+
239241
auto write_res = tls_socket_->Write(buf);
240242
EXPECT_FALSE(write_res);
241243
}
@@ -345,11 +347,22 @@ TEST_P(AsyncTlsSocketTest, AsyncRW) {
345347
proactor_->Await([&] {
346348
ThisFiber::SetName("ConnectFb");
347349

348-
LOG(INFO) << "Connecting to " << listen_ep_;
349350
error_code ec = tls_sock->Connect(listen_ep_);
350351
EXPECT_FALSE(ec);
351352
uint8_t res[16];
352353
std::fill(std::begin(res), std::end(res), uint8_t(120));
354+
{
355+
Done done;
356+
iovec v{.iov_base = &res, .iov_len = 16};
357+
358+
tls_sock->AsyncWriteSome(&v, 1, [done](auto result) mutable {
359+
EXPECT_TRUE(result.has_value());
360+
EXPECT_EQ(*result, 16);
361+
done.Notify();
362+
});
363+
364+
done.Wait();
365+
}
353366
{
354367
uint8_t buf[16];
355368
Done done;

0 commit comments

Comments
 (0)