Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 75 additions & 22 deletions bindings/rust/extended/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ use core::{
};
use libc::c_void;
use s2n_tls_sys::*;
use std::{any::Any, ffi::CStr};
use std::{
any::{Any, TypeId},
collections::HashMap,
ffi::CStr,
};

mod builder;
pub use builder::*;
Expand Down Expand Up @@ -1408,49 +1412,63 @@ impl Connection {
Ok(())
}

/// Associates an arbitrary application context with the Connection to be later retrieved via
/// Associates arbitrary application contexts with the Connection to be later retrieved via
/// the [`Self::application_context()`] and [`Self::application_context_mut()`] APIs.
///
/// This API will override an existing application context set on the Connection.
/// While multiple application contexts of different types may be set, previous values of the same type will be overridden.
///
/// Corresponds to [s2n_connection_set_ctx].
pub fn set_application_context<T: Send + Sync + 'static>(&mut self, app_context: T) {
self.context_mut().app_context = Some(Box::new(app_context));
let context_type_id = TypeId::of::<T>();
self.context_mut()
.app_context
.insert(context_type_id, Box::new(app_context));
}

/// Removes an application context set on the Connection.
///
/// Returns Some containing the removed context if it exists, or None if no context
/// of the specified type was previously set.
pub fn remove_application_context<T: Send + Sync + 'static>(
&mut self,
) -> Option<Box<dyn Any + Send + Sync>> {
let context_type_id = TypeId::of::<T>();
self.context_mut().app_context.remove(&context_type_id)
}

/// Retrieves a reference to the application context associated with the Connection.
///
/// If an application context hasn't already been set on the Connection, or if the set
/// application context isn't of type T, None will be returned.
/// Returns None if the provided type T does not match the type of any application context set on the Connection.
///
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve a
/// mutable reference to the context, use [`Self::application_context_mut()`].
///
/// Corresponds to [s2n_connection_get_ctx].
pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
match self.context().app_context.as_ref() {
None => None,
// The Any trait keeps track of the application context's type. downcast_ref() returns
// Some only if the correct type is provided:
// https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref
Some(app_context) => app_context.downcast_ref::<T>(),
}
let context_type_id = TypeId::of::<T>();
// The Any trait keeps track of the application context's type. downcast_ref() returns
// Some only if the correct type is provided:
// https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref
self.context()
.app_context
.get(&context_type_id)
.and_then(|app_context| app_context.downcast_ref::<T>())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

}

/// Retrieves a mutable reference to the application context associated with the Connection.
///
/// If an application context hasn't already been set on the Connection, or if the set
/// application context isn't of type T, None will be returned.
/// Returns None if the provided type T does not match the type of any application context set on the Connection.
///
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve an
/// immutable reference to the context, use [`Self::application_context()`].
///
/// Corresponds to [s2n_connection_get_ctx].
pub fn application_context_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
match self.context_mut().app_context.as_mut() {
None => None,
Some(app_context) => app_context.downcast_mut::<T>(),
}
let context_type_id = TypeId::of::<T>();
self.context_mut()
.app_context
.get_mut(&context_type_id)
.and_then(|app_context| app_context.downcast_mut::<T>())
}

#[cfg(feature = "unstable-cert_authorities")]
Expand All @@ -1475,7 +1493,7 @@ struct Context {
async_callback: Option<AsyncCallback>,
verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
connection_initialized: bool,
app_context: Option<Box<dyn Any + Send + Sync>>,
app_context: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
#[cfg(feature = "unstable-renegotiate")]
pub(crate) renegotiate_state: RenegotiateState,
#[cfg(feature = "unstable-cert_authorities")]
Expand All @@ -1490,7 +1508,7 @@ impl Context {
async_callback: None,
verify_host_callback: None,
connection_initialized: false,
app_context: None,
app_context: HashMap::new(),
#[cfg(feature = "unstable-renegotiate")]
renegotiate_state: RenegotiateState::default(),
#[cfg(feature = "unstable-cert_authorities")]
Expand Down Expand Up @@ -1602,6 +1620,7 @@ impl Drop for Connection {
mod tests {
use super::*;
use crate::testing::{build_config, SniTestCerts, TestPair};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

// ensure the connection context is send
#[test]
Expand All @@ -1622,10 +1641,11 @@ mod tests {
fn test_app_context_set_and_retrieve() {
let mut connection = Connection::new_server();

let test_value: u32 = 1142;

// Before a context is set, None is returned.
assert!(connection.application_context::<u32>().is_none());

let test_value: u32 = 1142;
connection.set_application_context(test_value);

// After a context is set, the application data is returned.
Expand Down Expand Up @@ -1669,6 +1689,39 @@ mod tests {
assert_eq!(*connection.application_context::<i16>().unwrap(), -20);
}

/// Test that multiple application contexts can be set in a connection
#[test]
fn test_multiple_app_contexts() {
let mut connection = Connection::new_server();

let first_test_value: u16 = 1142;
connection.set_application_context(first_test_value);

assert_eq!(*connection.application_context::<u16>().unwrap(), 1142);

// Insert the second application context to the connection
let second_test_value: SocketAddr =
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
connection.set_application_context(second_test_value);

assert_eq!(
*connection.application_context::<SocketAddr>().unwrap(),
second_test_value
);

// Remove the second application context
assert_eq!(
second_test_value,
*connection
.remove_application_context::<SocketAddr>()
.unwrap()
.downcast::<SocketAddr>()
.unwrap()
);

assert!(connection.application_context::<SocketAddr>().is_none());
}

/// Test that a context of another type can't be retrieved.
#[test]
fn test_app_context_invalid_type() {
Expand Down
12 changes: 6 additions & 6 deletions bindings/rust/standard/integration/tests/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ mod memory_test {
/// lifecycle. The static memory row is an absolute measurement, not a diff.
fn assert_expected(&self) {
const EXPECTED_MEMORY: &[(Lifecycle, usize)] = &[
(Lifecycle::ConnectionInit, 61_482),
(Lifecycle::AfterClientHello, 88_302),
(Lifecycle::AfterServerHello, 116_669),
(Lifecycle::AfterClientFinished, 107_976),
(Lifecycle::HandshakeComplete, 90_563),
(Lifecycle::ApplicationData, 90_563),
(Lifecycle::ConnectionInit, 61_578),
(Lifecycle::AfterClientHello, 88_406),
(Lifecycle::AfterServerHello, 116_773),
(Lifecycle::AfterClientFinished, 108_080),
(Lifecycle::HandshakeComplete, 90_667),
(Lifecycle::ApplicationData, 90_667),
];
let actual_memory: Vec<(Lifecycle, usize)> = Lifecycle::all_stages()
.into_iter()
Expand Down
Loading