Skip to content

Commit f9783b4

Browse files
authored
Merge pull request #17 from RobGeada/SchemeSelection
Choose scheme based on presence of TLS
2 parents f618acf + 4475981 commit f9783b4

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

src/main.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ async fn main() {
6060
.compact()
6161
.init();
6262

63-
let orchestrator_client = Arc::new(
63+
let (client, scheme) =
6464
build_orchestrator_client(&gateway_config.orchestrator.host)
65-
.expect("Failed to build HTTP(s) client for communicating with orchestrator"),
66-
);
65+
.expect("Failed to build HTTP(s) client for communicating with orchestrator");
66+
let orchestrator_client = Arc::new(client);
6767

6868
let mut app = Router::new().layer(
6969
TraceLayer::new_for_http()
@@ -77,6 +77,7 @@ async fn main() {
7777
let path = format!("/{}/v1/chat/completions", route.name);
7878
let fallback_message = route.fallback_message.clone();
7979
let orchestrator_client = orchestrator_client.clone();
80+
let scheme = scheme.clone();
8081
app = app.route(
8182
&path,
8283
post(
@@ -88,6 +89,7 @@ async fn main() {
8889
gateway_config,
8990
fallback_message,
9091
orchestrator_client,
92+
scheme,
9193
)
9294
},
9395
),
@@ -145,6 +147,7 @@ async fn handle_generation(
145147
gateway_config: GatewayConfig,
146148
route_fallback_message: Option<String>,
147149
orchestrator_client: Arc<reqwest::Client>,
150+
scheme: String,
148151
) -> Result<impl IntoResponse, (StatusCode, String)> {
149152
tracing::debug!("handle_generation called with payload: {:?}", payload);
150153

@@ -156,11 +159,14 @@ async fn handle_generation(
156159

157160
let url: String = match gateway_config.orchestrator.port {
158161
Some(port) => format!(
159-
"https://{}:{}/api/v2/chat/completions-detection",
160-
gateway_config.orchestrator.host, port
162+
"{}://{}:{}/api/v2/chat/completions-detection",
163+
scheme,
164+
gateway_config.orchestrator.host,
165+
port
161166
),
162167
None => format!(
163-
"https://{}/api/v2/chat/completions-detection",
168+
"{}://{}/api/v2/chat/completions-detection",
169+
scheme,
164170
gateway_config.orchestrator.host
165171
),
166172
};
@@ -192,7 +198,7 @@ async fn handle_generation(
192198
}
193199
}
194200

195-
fn build_orchestrator_client(hostname: &str) -> Result<reqwest::Client, anyhow::Error> {
201+
fn build_orchestrator_client(hostname: &str) -> Result<(reqwest::Client, String), anyhow::Error> {
196202
use openssl::pkcs12::Pkcs12;
197203
use openssl::pkey::PKey;
198204
use openssl::x509::X509;
@@ -205,6 +211,7 @@ fn build_orchestrator_client(hostname: &str) -> Result<reqwest::Client, anyhow::
205211
let ca_path = "/etc/tls/ca/service-ca.crt";
206212

207213
let mut builder = Client::builder();
214+
let mut scheme = String::from("http");
208215

209216
// Add custom CA if it exists
210217
if fs::metadata(ca_path).is_ok() {
@@ -238,11 +245,14 @@ fn build_orchestrator_client(hostname: &str) -> Result<reqwest::Client, anyhow::
238245
let identity = Identity::from_pkcs12_der(&pkcs12_der, "")?;
239246

240247
builder = builder.identity(identity);
248+
249+
// set https
250+
scheme = String::from("https");
241251
} else {
242252
tracing::warn!("mTLS enabled but TLS cert or key not found, using default client");
243253
};
244254

245-
Ok(builder.build()?)
255+
Ok((builder.build()?, scheme))
246256
}
247257

248258
async fn orchestrator_post_request(

0 commit comments

Comments
 (0)