Skip to content

Commit adfff1a

Browse files
committed
feat: add additional application context into Connection
1 parent bcb3f28 commit adfff1a

File tree

5 files changed

+128
-51
lines changed

5 files changed

+128
-51
lines changed

bindings/rust/extended/s2n-tls/src/callbacks/cert_validation.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ pub trait CertValidationCallbackSync: 'static + Send + Sync {
5353
mod tests {
5454
use super::*;
5555
use crate::{connection::Connection, security, testing::*};
56+
use std::any::TypeId;
5657

5758
struct ValidationContext {
5859
accept: bool,
@@ -66,7 +67,10 @@ mod tests {
6667
_info: &mut CertValidationInfo,
6768
) -> Result<bool, Error> {
6869
self.0.increment();
69-
let context = conn.application_context::<ValidationContext>().unwrap();
70+
let application_context_type_id = TypeId::of::<ValidationContext>();
71+
let context = conn
72+
.application_context::<ValidationContext>(application_context_type_id)
73+
.unwrap();
7074
Ok(context.accept)
7175
}
7276
}

bindings/rust/extended/s2n-tls/src/connection.rs

Lines changed: 105 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ use core::{
2828
};
2929
use libc::c_void;
3030
use s2n_tls_sys::*;
31-
use std::{any::Any, ffi::CStr};
31+
use std::{
32+
any::{Any, TypeId},
33+
collections::HashMap,
34+
ffi::CStr,
35+
};
3236

3337
mod builder;
3438
pub use builder::*;
@@ -1408,27 +1412,41 @@ impl Connection {
14081412
Ok(())
14091413
}
14101414

1411-
/// Associates an arbitrary application context with the Connection to be later retrieved via
1415+
/// Associates arbitrary application contexts with the Connection to be later retrieved via
14121416
/// the [`Self::application_context()`] and [`Self::application_context_mut()`] APIs.
14131417
///
1414-
/// This API will override an existing application context set on the Connection.
1418+
/// This API will add additional application context on top of existing contexts set on the Connection.
14151419
///
14161420
/// Corresponds to [s2n_connection_set_ctx].
14171421
pub fn set_application_context<T: Send + Sync + 'static>(&mut self, app_context: T) {
1418-
self.context_mut().app_context = Some(Box::new(app_context));
1422+
let context_type_id = TypeId::of::<T>();
1423+
self.context_mut()
1424+
.app_context
1425+
.insert(context_type_id, Box::new(app_context));
1426+
}
1427+
1428+
/// Remove a application context set on the Connection.
1429+
pub fn remove_application_context(
1430+
&mut self,
1431+
context_type_id: TypeId,
1432+
) -> Option<Box<dyn Any + Send + Sync>> {
1433+
self.context_mut().app_context.remove(&context_type_id)
14191434
}
14201435

14211436
/// Retrieves a reference to the application context associated with the Connection.
14221437
///
1423-
/// If an application context hasn't already been set on the Connection, or if the set
1424-
/// application context isn't of type T, None will be returned.
1438+
/// If application context hasn't already been set on the Connection, or if the set
1439+
/// application context doesn't match the type id of type T, None will be returned.
14251440
///
14261441
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve a
14271442
/// mutable reference to the context, use [`Self::application_context_mut()`].
14281443
///
14291444
/// Corresponds to [s2n_connection_get_ctx].
1430-
pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
1431-
match self.context().app_context.as_ref() {
1445+
pub fn application_context<T: Send + Sync + 'static>(
1446+
&self,
1447+
context_type_id: TypeId,
1448+
) -> Option<&T> {
1449+
match self.context().app_context.get(&context_type_id) {
14321450
None => None,
14331451
// The Any trait keeps track of the application context's type. downcast_ref() returns
14341452
// Some only if the correct type is provided:
@@ -1439,15 +1457,18 @@ impl Connection {
14391457

14401458
/// Retrieves a mutable reference to the application context associated with the Connection.
14411459
///
1442-
/// If an application context hasn't already been set on the Connection, or if the set
1443-
/// application context isn't of type T, None will be returned.
1460+
/// If application context hasn't already been set on the Connection, or if the set
1461+
/// application context doesn't match the type id of type T, None will be returned.
14441462
///
14451463
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve an
14461464
/// immutable reference to the context, use [`Self::application_context()`].
14471465
///
14481466
/// Corresponds to [s2n_connection_get_ctx].
1449-
pub fn application_context_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
1450-
match self.context_mut().app_context.as_mut() {
1467+
pub fn application_context_mut<T: Send + Sync + 'static>(
1468+
&mut self,
1469+
context_type_id: TypeId,
1470+
) -> Option<&mut T> {
1471+
match self.context_mut().app_context.get_mut(&context_type_id) {
14511472
None => None,
14521473
Some(app_context) => app_context.downcast_mut::<T>(),
14531474
}
@@ -1475,7 +1496,7 @@ struct Context {
14751496
async_callback: Option<AsyncCallback>,
14761497
verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
14771498
connection_initialized: bool,
1478-
app_context: Option<Box<dyn Any + Send + Sync>>,
1499+
app_context: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
14791500
#[cfg(feature = "unstable-renegotiate")]
14801501
pub(crate) renegotiate_state: RenegotiateState,
14811502
#[cfg(feature = "unstable-cert_authorities")]
@@ -1490,7 +1511,7 @@ impl Context {
14901511
async_callback: None,
14911512
verify_host_callback: None,
14921513
connection_initialized: false,
1493-
app_context: None,
1514+
app_context: HashMap::new(),
14941515
#[cfg(feature = "unstable-renegotiate")]
14951516
renegotiate_state: RenegotiateState::default(),
14961517
#[cfg(feature = "unstable-cert_authorities")]
@@ -1602,6 +1623,7 @@ impl Drop for Connection {
16021623
mod tests {
16031624
use super::*;
16041625
use crate::testing::{build_config, SniTestCerts, TestPair};
1626+
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
16051627

16061628
// ensure the connection context is send
16071629
#[test]
@@ -1622,14 +1644,23 @@ mod tests {
16221644
fn test_app_context_set_and_retrieve() {
16231645
let mut connection = Connection::new_server();
16241646

1647+
let test_value: u32 = 1142;
1648+
let test_context_type_id = TypeId::of::<u32>();
1649+
16251650
// Before a context is set, None is returned.
1626-
assert!(connection.application_context::<u32>().is_none());
1651+
assert!(connection
1652+
.application_context::<u32>(test_context_type_id)
1653+
.is_none());
16271654

1628-
let test_value: u32 = 1142;
16291655
connection.set_application_context(test_value);
16301656

16311657
// After a context is set, the application data is returned.
1632-
assert_eq!(*connection.application_context::<u32>().unwrap(), 1142);
1658+
assert_eq!(
1659+
*connection
1660+
.application_context::<u32>(test_context_type_id)
1661+
.unwrap(),
1662+
1142
1663+
);
16331664
}
16341665

16351666
/// Test that an application context can be modified.
@@ -1640,33 +1671,62 @@ mod tests {
16401671
let mut connection = Connection::new_server();
16411672
connection.set_application_context(test_value);
16421673

1643-
let context_value = connection.application_context_mut::<u64>().unwrap();
1674+
let context_type_id = TypeId::of::<u64>();
1675+
let context_value = connection
1676+
.application_context_mut::<u64>(context_type_id)
1677+
.unwrap();
16441678
*context_value += 1;
16451679

1646-
assert_eq!(*connection.application_context::<u64>().unwrap(), 1);
1680+
assert_eq!(
1681+
*connection
1682+
.application_context::<u64>(context_type_id)
1683+
.unwrap(),
1684+
1
1685+
);
16471686
}
16481687

1649-
/// Test that an application context can be overridden.
1688+
/// Test that multiple application contexts can be set in a connection
16501689
#[test]
1651-
fn test_app_context_override() {
1690+
fn test_multiple_app_contexts() {
16521691
let mut connection = Connection::new_server();
16531692

1654-
let test_value: u16 = 1142;
1655-
connection.set_application_context(test_value);
1656-
1657-
assert_eq!(*connection.application_context::<u16>().unwrap(), 1142);
1658-
1659-
// Override the context with a new value.
1660-
let test_value: u16 = 10;
1661-
connection.set_application_context(test_value);
1662-
1663-
assert_eq!(*connection.application_context::<u16>().unwrap(), 10);
1664-
1665-
// Override the context with a new type.
1666-
let test_value: i16 = -20;
1667-
connection.set_application_context(test_value);
1668-
1669-
assert_eq!(*connection.application_context::<i16>().unwrap(), -20);
1693+
let first_test_value: u16 = 1142;
1694+
connection.set_application_context(first_test_value);
1695+
1696+
let first_test_value_type_id = TypeId::of::<u16>();
1697+
assert_eq!(
1698+
*connection
1699+
.application_context::<u16>(first_test_value_type_id)
1700+
.unwrap(),
1701+
1142
1702+
);
1703+
1704+
// Insert the second application context to the connection
1705+
let second_test_value: SocketAddr =
1706+
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
1707+
connection.set_application_context(second_test_value);
1708+
1709+
let second_test_value_type_id = TypeId::of::<SocketAddr>();
1710+
assert_eq!(
1711+
*connection
1712+
.application_context::<SocketAddr>(second_test_value_type_id)
1713+
.unwrap(),
1714+
second_test_value
1715+
);
1716+
1717+
// Remove the second application context
1718+
assert_eq!(
1719+
second_test_value,
1720+
*connection
1721+
.remove_application_context(second_test_value_type_id)
1722+
.unwrap()
1723+
.downcast::<SocketAddr>()
1724+
.unwrap()
1725+
);
1726+
1727+
assert!(connection
1728+
.application_context::<SocketAddr>(second_test_value_type_id)
1729+
.is_none());
16701730
}
16711731

16721732
/// Test that a context of another type can't be retrieved.
@@ -1677,11 +1737,17 @@ mod tests {
16771737
let test_value: u32 = 0;
16781738
connection.set_application_context(test_value);
16791739

1740+
let invalid_context_type_id = TypeId::of::<i16>();
16801741
// A context type that wasn't set shouldn't be returned.
1681-
assert!(connection.application_context::<i16>().is_none());
1742+
assert!(connection
1743+
.application_context::<i16>(invalid_context_type_id)
1744+
.is_none());
16821745

1746+
let valid_context_type_id = TypeId::of::<u32>();
16831747
// Retrieving the correct type succeeds.
1684-
assert!(connection.application_context::<u32>().is_some());
1748+
assert!(connection
1749+
.application_context::<u32>(valid_context_type_id)
1750+
.is_some());
16851751
}
16861752

16871753
/// Test that the `certificate_match` Rust wrapper returns expected enum variant

bindings/rust/extended/s2n-tls/src/renegotiate.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ mod tests {
440440
ErrorCode, Ssl, SslContext, SslFiletype, SslMethod, SslStream, SslVerifyMode, SslVersion,
441441
};
442442
use std::{
443+
any::TypeId,
443444
error::Error,
444445
io::{Read, Write},
445446
pin::Pin,
@@ -1087,7 +1088,10 @@ mod tests {
10871088
ctx.waker().wake_by_ref();
10881089
let this = self.get_mut();
10891090
// Assert that nothing is currently set
1090-
assert!(conn.application_context::<String>().is_none());
1091+
let application_context_type_id = TypeId::of::<String>();
1092+
assert!(conn
1093+
.application_context::<String>(application_context_type_id)
1094+
.is_none());
10911095
if this.count > 1 {
10921096
// Repeatedly block the handshake in order to verify
10931097
// that renegotiate can handle Pending callbacks.
@@ -1122,7 +1126,8 @@ mod tests {
11221126
pair.assert_renegotiate()?;
11231127
assert_eq!(wake_count, count_per_handshake * 2);
11241128

1125-
let context: Option<&String> = pair.client.application_context();
1129+
let application_context_type_id = TypeId::of::<String>();
1130+
let context: Option<&String> = pair.client.application_context(application_context_type_id);
11261131
assert_eq!(Some(&expected_context), context);
11271132

11281133
Ok(())

bindings/rust/extended/s2n-tls/src/testing/s2n_tls.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mod tests {
1313
use core::sync::atomic::Ordering;
1414
use futures_test::task::{new_count_waker, noop_waker};
1515
use security::Policy;
16-
use std::{fs, path::Path, pin::Pin, sync::atomic::AtomicUsize};
16+
use std::{any::TypeId, fs, path::Path, pin::Pin, sync::atomic::AtomicUsize};
1717

1818
#[test]
1919
fn handshake_default() {
@@ -634,8 +634,9 @@ mod tests {
634634
&self,
635635
connection: &mut connection::Connection,
636636
) -> ConnectionFutureResult {
637+
let context_type_id = TypeId::of::<TestApplicationContext>();
637638
let app_context = connection
638-
.application_context_mut::<TestApplicationContext>()
639+
.application_context_mut::<TestApplicationContext>(context_type_id)
639640
.unwrap();
640641
app_context.invoked_count += 1;
641642
Ok(None)
@@ -663,9 +664,10 @@ mod tests {
663664

664665
pair.handshake()?;
665666

667+
let context_type_id = TypeId::of::<TestApplicationContext>();
666668
let context = pair
667669
.server
668-
.application_context::<TestApplicationContext>()
670+
.application_context::<TestApplicationContext>(context_type_id)
669671
.unwrap();
670672
assert_eq!(context.invoked_count, 1);
671673

bindings/rust/standard/integration/tests/memory.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,12 @@ mod memory_test {
274274
/// lifecycle. The static memory row is an absolute measurement, not a diff.
275275
fn assert_expected(&self) {
276276
const EXPECTED_MEMORY: &[(Lifecycle, usize)] = &[
277-
(Lifecycle::ConnectionInit, 61_482),
278-
(Lifecycle::AfterClientHello, 88_302),
279-
(Lifecycle::AfterServerHello, 116_669),
280-
(Lifecycle::AfterClientFinished, 107_976),
281-
(Lifecycle::HandshakeComplete, 90_563),
282-
(Lifecycle::ApplicationData, 90_563),
277+
(Lifecycle::ConnectionInit, 61_578),
278+
(Lifecycle::AfterClientHello, 88_406),
279+
(Lifecycle::AfterServerHello, 116_773),
280+
(Lifecycle::AfterClientFinished, 108_080),
281+
(Lifecycle::HandshakeComplete, 90_667),
282+
(Lifecycle::ApplicationData, 90_667),
283283
];
284284
let actual_memory: Vec<(Lifecycle, usize)> = Lifecycle::all_stages()
285285
.into_iter()

0 commit comments

Comments
 (0)