diff --git a/Cargo.lock b/Cargo.lock index c37b946..e5832b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,10 +2,17 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + [[package]] name = "cairo-debugger" version = "0.1.0" dependencies = [ + "anyhow", "dap", "tracing", ] @@ -13,8 +20,7 @@ dependencies = [ [[package]] name = "dap" version = "0.4.1-alpha1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35c7fc89d334ab745ba679f94c7314c9b17ecdcd923c111df6206e9fd7729fa9" +source = "git+https://github.com/software-mansion-labs/dap-rs?rev=4440a6f#4440a6fa4ffd26f88f1191ee2371a482fde2a539" dependencies = [ "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 9838fab..a1e33af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,6 @@ version = "0.1.0" edition = "2024" [dependencies] -dap = "0.4.1-alpha1" +dap = { git = "https://github.com/software-mansion-labs/dap-rs", rev = "4440a6f" } tracing = "0.1" +anyhow = "1.0" diff --git a/src/connection.rs b/src/connection.rs new file mode 100644 index 0000000..bbf7398 --- /dev/null +++ b/src/connection.rs @@ -0,0 +1,126 @@ +use std::io::{BufReader, BufWriter}; +use std::net::{TcpListener, TcpStream}; +use std::sync::mpsc; +use std::thread; +use std::thread::JoinHandle; + +use anyhow::Context; +use anyhow::Result; +use dap::base_message::Sendable; +use dap::errors::ServerError; +use dap::prelude::{Event, Request, ResponseBody, Server}; +use dap::server::{ServerReader, ServerWriter}; + +pub struct Connection { + inbound_rx: mpsc::Receiver, + outbound_tx: mpsc::Sender, + + // NOTE: The order of members matters here. + // I/O threads must be dropped after the channels. + _io_threads: IoThreads, +} + +impl Connection { + pub fn new() -> Result { + let tcp_listener = TcpListener::bind("127.0.0.1:0").map_err(ServerError::IoError)?; + let os_assigned_port = tcp_listener.local_addr()?.port(); + // Print it so that the client can read it. + println!("\nDEBUGGER PORT: {os_assigned_port}"); + + let (stream, _client_addr) = tcp_listener.accept().map_err(ServerError::IoError)?; + let input = BufReader::new(stream.try_clone()?); + let output = BufWriter::new(stream); + + let (server_reader, server_writer) = Server::new(input, output).split_server(); + + let (inbound_tx, inbound_rx) = mpsc::channel::(); + let (outbound_tx, outbound_rx) = mpsc::channel::(); + + Ok(Self { + inbound_rx, + outbound_tx, + _io_threads: IoThreads::spawn(server_reader, server_writer, inbound_tx, outbound_rx), + }) + } + + pub fn next_request(&self) -> Result { + self.inbound_rx.recv().context("Inbound connection closed") + } + + pub fn send_event(&self, event: Event) -> Result<()> { + self.outbound_tx + .send(Sendable::Event(event)) + .context("Sending event to outbound channel failed") + } + + pub fn send_success(&self, request: Request, body: ResponseBody) -> Result<()> { + self.outbound_tx + .send(Sendable::Response(request.success(body))) + .context("Sending success response to outbound channel failed") + } + + pub fn send_error(&self, request: Request, msg: &str) -> Result<()> { + self.outbound_tx + .send(Sendable::Response(request.error(msg))) + .context("Sending error response to outbound channel failed") + } +} + +struct IoThreads { + pub reader: Option>, + pub writer: Option>, +} + +impl IoThreads { + fn spawn( + server_reader: ServerReader, + server_writer: ServerWriter, + inbound_tx: mpsc::Sender, + outbound_rx: mpsc::Receiver, + ) -> Self { + Self { + reader: Some(spawn_reader_thread(server_reader, inbound_tx)), + writer: Some(spawn_writer_thread(server_writer, outbound_rx)), + } + } +} + +impl Drop for IoThreads { + fn drop(&mut self) { + self.reader.take().map(|h| h.join()); + self.writer.take().map(|h| h.join()); + } +} + +fn spawn_reader_thread( + mut server_reader: ServerReader, + inbound_tx: mpsc::Sender, +) -> JoinHandle<()> { + thread::spawn(move || { + while let Ok(Some(request)) = server_reader.poll_request() { + if inbound_tx.send(request).is_err() { + // TODO: Add error tracing + break; + } + } + }) +} + +fn spawn_writer_thread( + mut server_writer: ServerWriter, + outbound_rx: mpsc::Receiver, +) -> JoinHandle<()> { + thread::spawn(move || { + while let Ok(msg) = outbound_rx.recv() { + match msg { + Sendable::Response(response) => { + server_writer.respond(response).expect("Failed to send response") + } + Sendable::Event(event) => { + server_writer.send_event(event).expect("Failed to send event") + } + Sendable::ReverseRequest(_) => unreachable!("Reverse requests are not supported"), + } + } + }) +} diff --git a/src/lib.rs b/src/lib.rs index 82eb092..a3c7dd8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,7 @@ -use std::io::{BufReader, BufWriter}; -use std::net::{TcpListener, TcpStream}; - -use dap::errors::ServerError; +use anyhow::Result; +use connection::Connection; use dap::events::{Event, StoppedEventBody}; -use dap::prelude::{Command, Request, ResponseBody, Server}; +use dap::prelude::{Command, Request, ResponseBody}; use dap::responses::{ EvaluateResponse, ScopesResponse, SetBreakpointsResponse, StackTraceResponse, ThreadsResponse, VariablesResponse, @@ -11,9 +9,11 @@ use dap::responses::{ use dap::types::{Breakpoint, Capabilities, Source, StackFrame, StoppedEventReason, Thread}; use tracing::trace; +mod connection; + // TODO: add vm, add handlers for requests. pub struct CairoDebugger { - server: Server, + connection: Connection, } enum ServerResponse { @@ -24,27 +24,20 @@ enum ServerResponse { } impl CairoDebugger { - pub fn connect() -> Result { - let tcp_listener = TcpListener::bind("127.0.0.1:0").map_err(ServerError::IoError)?; - let os_assigned_port = tcp_listener.local_addr().unwrap().port(); - // Print it so that the client can read it. - println!("\nDEBUGGER PORT: {os_assigned_port}"); - - let (stream, _client_addr) = tcp_listener.accept().map_err(ServerError::IoError)?; - let input = BufReader::new(stream.try_clone().unwrap()); - let output = BufWriter::new(stream); - Ok(Self { server: Server::new(input, output) }) + pub fn connect() -> Result { + let connection = Connection::new()?; + Ok(Self { connection }) } - pub fn run(&mut self) -> Result<(), ServerError> { - while let Some(req) = self.server.poll_request()? { + pub fn run(&mut self) -> Result<()> { + while let Ok(req) = self.connection.next_request() { match handle_request(&req) { - ServerResponse::Success(body) => self.server.respond(req.success(body))?, - ServerResponse::Error(msg) => self.server.respond(req.error(&msg))?, - ServerResponse::Event(event) => self.server.send_event(event)?, + ServerResponse::Success(body) => self.connection.send_success(req, body)?, + ServerResponse::Error(msg) => self.connection.send_error(req, &msg)?, + ServerResponse::Event(event) => self.connection.send_event(event)?, ServerResponse::SuccessThenEvent(body, event) => { - self.server.respond(req.success(body))?; - self.server.send_event(event)?; + self.connection.send_success(req, body)?; + self.connection.send_event(event)?; } } }