From c95a3e831aa33f36d03ed5c83dc64ed2afb7aca3 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Thu, 3 Apr 2025 13:36:15 -0300 Subject: [PATCH 01/24] Stream content tests and error handling fix (#350) Signed-off-by: Mateus Devino --- src/server.rs | 10 +- tests/common/orchestrator.rs | 25 +- tests/streaming_classification_with_gen.rs | 2 + tests/streaming_content_detection.rs | 640 +++++++++++++++++++++ 4 files changed, 671 insertions(+), 6 deletions(-) create mode 100644 tests/streaming_content_detection.rs diff --git a/src/server.rs b/src/server.rs index d896584d..b798db87 100644 --- a/src/server.rs +++ b/src/server.rs @@ -455,10 +455,12 @@ async fn stream_content_detection( // Create input stream let input_stream = json_lines - .map(|result| { - let message = result.unwrap(); - message.validate()?; - Ok(message) + .map(|result| match result { + Ok(message) => { + message.validate()?; + Ok(message) + } + Err(error) => Err(orchestrator::errors::Error::Validation(error.to_string())), }) .boxed(); diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 257fd850..8de48100 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -26,11 +26,16 @@ use std::{ use bytes::Bytes; use eventsource_stream::{EventStream, Eventsource}; use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator}; -use futures::{Stream, StreamExt, stream::BoxStream}; +use futures::{ + Stream, StreamExt, + stream::{ + BoxStream, {self}, + }, +}; use mocktail::server::MockServer; use rand::Rng; use rustls::crypto::ring; -use serde::de::DeserializeOwned; +use serde::{Serialize, de::DeserializeOwned}; use tokio::task::JoinHandle; use url::Url; @@ -45,6 +50,8 @@ pub const ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT: &str = "/api/v2/text/generation-detection"; pub const ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/content"; +pub const 
ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT: &str = + "/api/v2/text/detection/stream-content"; pub const ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT: &str = "/api/v2/text/detection/generated"; pub const ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/context"; pub const ORCHESTRATOR_CHAT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/chat"; @@ -327,3 +334,17 @@ fn find_available_port() -> Option { fn port_is_available(port: u16) -> bool { std::net::TcpListener::bind(("0.0.0.0", port)).is_ok() } + +pub fn json_lines_stream( + messages: impl IntoIterator, +) -> impl Stream, std::io::Error>> { + let chunks = messages + .into_iter() + .map(|msg| { + let mut bytes = serde_json::to_vec(&msg).unwrap(); + bytes.push(b'\n'); + Ok(bytes) + }) + .collect::>>(); + stream::iter(chunks) +} diff --git a/tests/streaming_classification_with_gen.rs b/tests/streaming_classification_with_gen.rs index 7e79ec34..adaebd91 100644 --- a/tests/streaming_classification_with_gen.rs +++ b/tests/streaming_classification_with_gen.rs @@ -772,6 +772,8 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { }); // Add output detection mock + // TODO: Simply clone mocks instead of create two exact MockSets when/if + // this gets merged: https://github.com/IBM/mocktail/pull/41 let mut angle_brackets_mocks = MockSet::new(); angle_brackets_mocks.mock(|when, then| { when.post() diff --git a/tests/streaming_content_detection.rs b/tests/streaming_content_detection.rs new file mode 100644 index 00000000..db261269 --- /dev/null +++ b/tests/streaming_content_detection.rs @@ -0,0 +1,640 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::collections::HashMap; + +use common::{ + chunker::{CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE, CHUNKER_STREAMING_ENDPOINT}, + detectors::{ + DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_PARENTHESIS_SENTENCE, + TEXT_CONTENTS_DETECTOR_ENDPOINT, + }, + errors::{DetectorError, OrchestratorError}, + orchestrator::{ + ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, + ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT, TestOrchestratorServer, json_lines_stream, + }, +}; +use fms_guardrails_orchestr8::{ + clients::detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + models::{ + DetectorParams, Metadata, StreamingContentDetectionRequest, + StreamingContentDetectionResponse, + }, + pb::{ + caikit::runtime::chunkers::BidiStreamingChunkerTokenizationTaskRequest, + caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token}, + }, +}; +use futures::StreamExt; +use mocktail::{MockSet, server::MockServer}; +use serde_json::json; +use test_log::test; +use tracing::debug; + +pub mod common; + +/// Asserts scenario with no detections +#[test(tokio::test)] +async fn no_detections() -> Result<(), anyhow::Error> { + let chunker_id = CHUNKER_NAME_SENTENCE; + let angle_brackets_detector = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let parenthesis_detector = DETECTOR_NAME_PARENTHESIS_SENTENCE; + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id) + .pb_stream(vec![ + 
BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Hi".into(), + input_index_stream: 0, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " there!".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " How".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " you?".into(), + input_index_stream: 4, + }, + ]); + + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 9, + text: "Hi there!".into(), + }], + token_count: 0, + processed_index: 9, + start_index: 0, + input_start_index: 0, + input_end_index: 0, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 9, + end: 22, + text: " How are you?".into(), + }], + token_count: 0, + processed_index: 22, + start_index: 9, + input_start_index: 0, + input_end_index: 0, + }, + ]); + }); + + // Add input detection mock + // TODO: Simply clone mocks instead of create two exact MockSets when/if + // this gets merged: https://github.com/IBM/mocktail/pull/41 + let mut angle_brackets_detection_mocks = MockSet::new(); + angle_brackets_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec!["Hi there!".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + angle_brackets_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![" How are you?".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + + let mut parenthesis_detection_mocks = MockSet::new(); + parenthesis_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + 
contents: vec!["Hi there!".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + parenthesis_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![" How are you?".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + + // Run test orchestrator server + let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); + let mock_angle_brackets_detector_server = + MockServer::new(angle_brackets_detector).with_mocks(angle_brackets_detection_mocks); + let mock_parenthesis_detector_server = + MockServer::new(parenthesis_detector).with_mocks(parenthesis_detection_mocks); + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([ + &mock_angle_brackets_detector_server, + &mock_parenthesis_detector_server, + ]) + .chunker_servers([&mock_chunker_server]) + .build() + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([ + (angle_brackets_detector.into(), DetectorParams::new()), + (parenthesis_detector.into(), DetectorParams::new()), + ])), + content: "Hi".into(), + }, + StreamingContentDetectionRequest { + detectors: None, + content: " there!".into(), + }, + StreamingContentDetectionRequest { + detectors: None, + content: " How".into(), + }, + StreamingContentDetectionRequest { + detectors: None, + content: " are".into(), + }, + StreamingContentDetectionRequest { + detectors: None, + content: " you?".into(), + }, + ]))) + .send() + .await?; + + // Collects stream results + let mut messages = Vec::::with_capacity(1); + let mut stream = 
response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + + // assertions + let expected_messages = [ + StreamingContentDetectionResponse { + detections: vec![], + start_index: 0, + processed_index: 9, + }, + StreamingContentDetectionResponse { + detections: vec![], + start_index: 9, + processed_index: 22, + }, + ]; + assert_eq!(messages, expected_messages); + + Ok(()) +} + +/// Asserts scenario with detections +#[test(tokio::test)] +async fn detections() -> Result<(), anyhow::Error> { + let chunker_id = CHUNKER_NAME_SENTENCE; + let angle_brackets_detector = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let parenthesis_detector = DETECTOR_NAME_PARENTHESIS_SENTENCE; + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id) + .pb_stream(vec![BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Hi (there)! 
How are ?".into(), + input_index_stream: 0, + }]); + + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 11, + text: "Hi (there)!".into(), + }], + token_count: 0, + processed_index: 11, + start_index: 0, + input_start_index: 0, + input_end_index: 0, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 11, + end: 26, + text: " How are ?".into(), + }], + token_count: 0, + processed_index: 26, + start_index: 11, + input_start_index: 0, + input_end_index: 0, + }, + ]); + }); + + // Add input detection mock + let mut angle_brackets_detection_mocks = MockSet::new(); + angle_brackets_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec!["Hi (there)!".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + angle_brackets_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![" How are ?".into()], + detector_params: DetectorParams::new(), + }); + then.json([[ContentAnalysisResponse { + start: 10, + end: 13, + text: "you".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(angle_brackets_detector.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }]]); + }); + + let mut parenthesis_detection_mocks = MockSet::new(); + parenthesis_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec!["Hi (there)!".into()], + detector_params: DetectorParams::new(), + }); + then.json([[ContentAnalysisResponse { + start: 4, + end: 9, + text: "there".into(), + detection: "has_parenthesis".into(), + detection_type: "parenthesis".into(), + detector_id: Some(parenthesis_detector.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }]]); + 
}); + parenthesis_detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![" How are ?".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + + // Run test orchestrator server + let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); + let mock_angle_brackets_detector_server = + MockServer::new(angle_brackets_detector).with_mocks(angle_brackets_detection_mocks); + let mock_parenthesis_detector_server = + MockServer::new(parenthesis_detector).with_mocks(parenthesis_detection_mocks); + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([ + &mock_angle_brackets_detector_server, + &mock_parenthesis_detector_server, + ]) + .chunker_servers([&mock_chunker_server]) + .build() + .await?; + + // Example orchestrator request with streaming response + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([ + (angle_brackets_detector.into(), DetectorParams::new()), + (parenthesis_detector.into(), DetectorParams::new()), + ])), + content: "Hi (there)! 
How are ?".into(), + }, + ]))) + .send() + .await?; + + // Collects stream results + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + + // assertions + let expected_messages = [ + StreamingContentDetectionResponse { + detections: vec![ContentAnalysisResponse { + start: 4, + end: 9, + text: "there".into(), + detection: "has_parenthesis".into(), + detection_type: "parenthesis".into(), + detector_id: Some(parenthesis_detector.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }], + start_index: 0, + processed_index: 11, + }, + StreamingContentDetectionResponse { + detections: vec![ContentAnalysisResponse { + start: 10, + end: 13, + text: "you".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(angle_brackets_detector.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }], + start_index: 11, + processed_index: 26, + }, + ]; + assert_eq!(messages, expected_messages); + + Ok(()) +} + +/// Asserts clients returning errors. 
+#[test(tokio::test)] +async fn client_error() -> Result<(), anyhow::Error> { + let chunker_id = CHUNKER_NAME_SENTENCE; + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + + let chunker_error_payload = "Chunker should return an error."; + let detector_error_payload = "Detector should return an error."; + + let mut chunker_mocks = MockSet::new(); + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id) + .pb_stream(vec![BidiStreamingChunkerTokenizationTaskRequest { + text_stream: chunker_error_payload.into(), + input_index_stream: 0, + }]); + then.internal_server_error(); + }); + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, chunker_id) + .pb_stream(vec![BidiStreamingChunkerTokenizationTaskRequest { + text_stream: detector_error_payload.into(), + input_index_stream: 0, + }]); + then.pb_stream([ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: detector_error_payload.len() as i64, + text: detector_error_payload.into(), + }], + token_count: 0, + processed_index: detector_error_payload.len() as i64, + start_index: 0, + input_start_index: 0, + input_end_index: 0, + }]); + }); + + // Add input detection mock + let mut detection_mocks = MockSet::new(); + detection_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![detector_error_payload.into()], + detector_params: DetectorParams::new(), + }); + then.json(DetectorError { + code: 500, + message: "There was an error when running the detection".into(), + }); + }); + + // Run test orchestrator server + let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); + let orchestrator_server = TestOrchestratorServer::builder() + 
.config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .build() + .await?; + + // Assert chunker error + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + content: chunker_error_payload.into(), + }, + ]))) + .send() + .await?; + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + let expected_messages = [OrchestratorError { + code: 500, + details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into(), + }]; + assert_eq!(messages, expected_messages); + + // Assert detector error + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + content: detector_error_payload.into(), + }, + ]))) + .send() + .await?; + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + let expected_messages = [OrchestratorError { + code: 500, + details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into(), + }]; + assert_eq!(messages, expected_messages); + + Ok(()) +} + +/// Asserts orchestrator request validation +#[test(tokio::test)] +async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { + let detector_name = 
DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + + // Run test orchestrator server + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .build() + .await?; + + // assert extra argument on request + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([json!( { + "detectors": {detector_name: {}}, + "content": "Hi there!", + "extra_arg": true + })]))) + .send() + .await?; + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].code, 422); + assert!(messages[0].details.starts_with("unknown field `extra_arg`")); + + // assert missing `content` on first frame + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([json!( { + "detectors": {detector_name: {}} + })]))) + .send() + .await?; + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].code, 422); + assert!(messages[0].details.starts_with("missing field `content`")); + + // assert missing `detectors` on first frame + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: None, + content: "Hi".into(), + }, + ]))) + .send() + .await?; + let mut 
messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + let expected_messages = [OrchestratorError { + code: 422, + details: "`detectors` is required for the first message".into(), + }]; + assert_eq!( + messages, expected_messages, + "failed on missing `detectors` scenario" + ); + + // assert empty `detectors` on first frame + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::new()), + content: "Hi".into(), + }, + ]))) + .send() + .await?; + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + let expected_messages = [OrchestratorError { + code: 422, + details: "`detectors` must not be empty".into(), + }]; + assert_eq!( + messages, expected_messages, + "failed on empty `detectors` scenario" + ); + + Ok(()) +} From 49af6ea504449807c8c14332bddd17f1c057b81d Mon Sep 17 00:00:00 2001 From: Paulo Marques Caldeira Junior <7291154+pmcjr@users.noreply.github.com> Date: Fri, 4 Apr 2025 14:44:03 -0300 Subject: [PATCH 02/24] Integration tests for `/api/v2/chat/completions-detection` (#360) --- src/clients/openai.rs | 30 +- tests/chat_completions_detection.rs | 1009 +++++++++++++++++++++++++++ tests/common/chat_completion.rs | 19 + tests/common/chat_completions.rs | 19 + tests/common/mod.rs | 1 + tests/common/orchestrator.rs | 3 + tests/test_config.yaml | 8 +- 7 files changed, 1072 insertions(+), 17 deletions(-) create mode 100644 tests/chat_completions_detection.rs create mode 100644 tests/common/chat_completion.rs create mode 
100644 tests/common/chat_completions.rs diff --git a/src/clients/openai.rs b/src/clients/openai.rs index d22db9bc..4e03c5e6 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -476,7 +476,7 @@ pub struct ImageUrl { pub detail: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ToolCall { /// The ID of the tool call. pub id: String, @@ -487,7 +487,7 @@ pub struct ToolCall { pub function: Function, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Function { /// The name of the function to call. pub name: String, @@ -497,7 +497,7 @@ pub struct Function { } /// Represents a chat completion response returned by model, based on the provided input. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletion { /// A unique identifier for the chat completion. pub id: String, @@ -527,7 +527,7 @@ pub struct ChatCompletion { } /// A chat completion choice. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionChoice { /// The index of the choice in the list of choices. pub index: u32, @@ -540,7 +540,7 @@ pub struct ChatCompletionChoice { } /// A chat completion message generated by the model. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionMessage { /// The role of the author of this message. pub role: Role, @@ -554,7 +554,7 @@ pub struct ChatCompletionMessage { pub refusal: Option, } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub struct ChatCompletionLogprobs { /// A list of message content tokens with log probability information. 
pub content: Option>, @@ -564,7 +564,7 @@ pub struct ChatCompletionLogprobs { } /// Log probability information for a choice. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionLogprob { /// The token. pub token: String, @@ -577,7 +577,7 @@ pub struct ChatCompletionLogprob { pub top_logprobs: Option>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionTopLogprob { /// The token. pub token: String, @@ -645,7 +645,7 @@ pub struct ChatCompletionDelta { } /// Usage statistics for a completion. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub struct Usage { /// Number of tokens in the prompt. pub prompt_tokens: u32, @@ -661,13 +661,13 @@ pub struct Usage { pub completion_token_details: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct CompletionTokenDetails { pub audio_tokens: u32, pub reasoning_tokens: u32, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct PromptTokenDetails { pub audio_tokens: u32, pub cached_tokens: u32, @@ -691,7 +691,7 @@ pub struct OpenAiError { } /// Guardrails detection results. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatDetections { #[serde(default, skip_serializing_if = "Vec::is_empty")] pub input: Vec, @@ -700,7 +700,7 @@ pub struct ChatDetections { } /// Guardrails detection result for application on input. 
-#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct InputDetectionResult { pub message_index: u32, #[serde(default, skip_serializing_if = "Vec::is_empty")] @@ -708,7 +708,7 @@ pub struct InputDetectionResult { } /// Guardrails detection result for application output. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct OutputDetectionResult { pub choice_index: u32, #[serde(default, skip_serializing_if = "Vec::is_empty")] @@ -724,7 +724,7 @@ pub struct DetectionResult { } /// Warnings generated by guardrails. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct OrchestratorWarning { r#type: DetectionWarningReason, message: String, diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_detection.rs new file mode 100644 index 00000000..079c0e30 --- /dev/null +++ b/tests/chat_completions_detection.rs @@ -0,0 +1,1009 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ + +use std::{collections::HashMap, vec}; + +use anyhow::Ok; +use common::{ + chat_completions::CHAT_COMPLETIONS_ENDPOINT, + chunker::CHUNKER_UNARY_ENDPOINT, + detectors::{ + DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, + TEXT_CONTENTS_DETECTOR_ENDPOINT, + }, + errors::{DetectorError, OrchestratorError}, + orchestrator::{ + ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, + ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, TestOrchestratorServer, + }, +}; +use fms_guardrails_orchestr8::{ + clients::{ + chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, + detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + openai::{ + ChatCompletion, ChatCompletionChoice, ChatCompletionMessage, ChatCompletionsRequest, + ChatDetections, Content, DetectorConfig, InputDetectionResult, Message, + OrchestratorWarning, OutputDetectionResult, Role, + }, + }, + models::{ + DetectionWarningReason, DetectorParams, Metadata, UNSUITABLE_INPUT_MESSAGE, + UNSUITABLE_OUTPUT_MESSAGE, + }, + pb::{ + caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, + caikit_data_model::nlp::{Token, TokenizationResults}, + }, +}; +use hyper::StatusCode; +use mocktail::prelude::*; +use test_log::test; + +pub mod common; + +// Constants +const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; +const MODEL_ID: &str = "my-super-model-8B"; + +// Validate that requests without detectors, input detector and output detector configured +// returns text generated by model +#[test(tokio::test)] +async fn no_detections() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + + let messages = vec![ + Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }, + Message { + content: Some(Content::Text("Hello!".to_string())), + role: Role::Assistant, + ..Default::default() + }, + ]; + + // Add mocksets + let mut detector_mocks = 
MockSet::new(); + let mut chat_mocks = MockSet::new(); + + let expected_choices = vec![ + ChatCompletionChoice { + message: ChatCompletionMessage { + role: messages[0].role.clone(), + content: Some("Hi there!".to_string()), + refusal: None, + tool_calls: vec![], + }, + index: 0, + logprobs: None, + finish_reason: "NOT_FINISHED".to_string(), + }, + ChatCompletionChoice { + message: ChatCompletionMessage { + role: messages[1].role.clone(), + content: Some("Hello!".to_string()), + refusal: None, + tool_calls: vec![], + }, + index: 1, + logprobs: None, + finish_reason: "EOS_TOKEN".to_string(), + }, + ]; + + let expected_detections = Some(ChatDetections { + input: vec![], + output: vec![], + }); + + let chat_completions_response = ChatCompletion { + model: MODEL_ID.into(), + choices: expected_choices.clone(), + detections: Some(ChatDetections { + input: vec![], + output: vec![], + }), + warnings: vec![], + ..Default::default() + }; + + // Add detector input mock + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec!["Hi there!".into()], + detector_params: DetectorParams::new(), + }); + then.json(vec![Vec::::new()]); + }); + // Add detector output mock + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec!["Hello!".into()], + detector_params: DetectorParams::new(), + }); + then.json(vec![Vec::::new()]); + }); + + // Add chat completions mock + chat_mocks.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(ChatCompletionsRequest { + messages: messages.clone(), + model: MODEL_ID.into(), + stream: false, + ..Default::default() + }); + then.json(&chat_completions_response); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); + let mock_chat_completions_server = 
MockServer::new("chat_completions").with_mocks(chat_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chat_generation_server(&mock_chat_completions_server) + .build() + .await?; + + // Make orchestrator call for input/output no detections + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + output: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + }), + messages, + ..Default::default() + }) + .send() + .await?; + + // Assertions for no detections + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], chat_completions_response.choices[0]); + assert_eq!(results.choices[1], chat_completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert_eq!(results.detections, expected_detections); + + Ok(()) +} + +// Validates that requests with input detector configured returns detections +#[test(tokio::test)] +async fn input_detections() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let input_text = "Hi there! 
Can you help me with ?"; + + let messages = vec![Message { + content: Some(Content::Text(input_text.to_string())), + role: Role::User, + ..Default::default() + }]; + + // Add mocksets + let mut detector_mocks = MockSet::new(); + let mut chunker_mocks = MockSet::new(); + let mut chat_mocks = MockSet::new(); + + // Add input detection mock response for input detection + let expected_detections = vec![ContentAnalysisResponse { + start: 34, + end: 42, + text: "something".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }]; + + let chat_completions_response = ChatCompletion { + model: MODEL_ID.into(), + choices: vec![], + detections: Some(ChatDetections { + input: vec![InputDetectionResult { + message_index: 0, + results: expected_detections.clone(), + }], + output: vec![], + }), + warnings: vec![OrchestratorWarning::new( + DetectionWarningReason::UnsuitableInput, + UNSUITABLE_INPUT_MESSAGE, + )], + ..Default::default() + }; + + // Add chunker tokenization mock for input detection + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: input_text.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: input_text.len() as i64, + text: input_text.into(), + }], + token_count: 0, + }); + }); + + // Add detector input mock + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![input_text.into()], + detector_params: DetectorParams::new(), + }); + then.json(vec![&expected_detections]); + }); + + // Add chat completions mock + chat_mocks.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(ChatCompletionsRequest { + messages: messages.clone(), + model: 
MODEL_ID.into(), + stream: false, + ..Default::default() + }); + then.json(&chat_completions_response); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); + let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); + let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) + .grpc() + .with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .chat_generation_server(&mock_chat_completions_server) + .build() + .await?; + + // Make orchestrator call for input/output no detections + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + output: None, + }), + messages, + ..Default::default() + }) + .send() + .await?; + + // Assertions for input detections + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.detections, chat_completions_response.detections); + assert_eq!(results.choices, chat_completions_response.choices); + assert_eq!(results.warnings, chat_completions_response.warnings); + + Ok(()) +} + +// Validates that requests with input detector configured returns propagated errors +#[test(tokio::test)] +async fn input_client_error() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + // Add 500 expected input detector mock response + let expected_detector_error = DetectorError { + code: 500, + message: "Internal detector error.".into(), + }; + // Add 500 expected orchestrator error response + let expected_orchestrator_error = 
OrchestratorError { + code: 500, + details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.to_string(), + }; + + // Add input for error scenarios + let chunker_error_input = "This should return a 500 error on chunker"; + let detector_error_input = "This should return a 500 error on detector"; + let chat_completions_error_input = "This should return a 500 error on chat completions"; + + // Add mocksets + let mut chunker_mocks = MockSet::new(); + let mut detector_mocks = MockSet::new(); + let mut chat_mocks = MockSet::new(); + + let messages_chunker_error = vec![Message { + content: Some(Content::Text(chunker_error_input.to_string())), + role: Role::User, + ..Default::default() + }]; + + let messages_detector_error = vec![Message { + content: Some(Content::Text(detector_error_input.to_string())), + role: Role::User, + ..Default::default() + }]; + + let messages_chat_completions_error = vec![Message { + content: Some(Content::Text(chat_completions_error_input.to_string())), + role: Role::User, + ..Default::default() + }]; + + // Add chunker tokenization mock for detector internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: detector_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: detector_error_input.len() as i64, + text: detector_error_input.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for completions internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: chat_completions_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: chat_completions_error_input.len() as i64, + text: chat_completions_error_input.into(), + }], + 
token_count: 0, + }); + }); + + // Add chunker tokenization mock for chunker internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: chunker_error_input.into(), + }); + then.internal_server_error(); + }); + + // Add detector mock for chat completions error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![chat_completions_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.json(vec![Vec::::new()]); + }); + + // Add detector mock for detector error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![detector_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.internal_server_error().json(&expected_detector_error); + }); + + // Add chat completions mock for chat completions error scenario + chat_mocks.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(ChatCompletionsRequest { + messages: messages_chat_completions_error.clone(), + model: MODEL_ID.into(), + stream: false, + ..Default::default() + }); + then.internal_server_error(); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); + let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); + let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) + .grpc() + .with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .chat_generation_server(&mock_chat_completions_server) + .build() + 
.await?; + + // Make orchestrator call for chunker error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + output: None, + }), + messages: messages_chunker_error.clone(), + ..Default::default() + }) + .send() + .await?; + + // Assertions for chunker error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + // Make orchestrator call for detector error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + output: None, + }), + messages: messages_detector_error.clone(), + ..Default::default() + }) + .send() + .await?; + + // Assertions for detector error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + // Make orchestrator call for chat completions error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + output: None, + }), + messages: messages_chat_completions_error.clone(), + ..Default::default() + }) + .send() + .await?; + + // Assertions for chat completions error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + Ok(()) +} + +// Validates that requests with output detector configured returns detections +#[test(tokio::test)] +async fn output_detections() -> Result<(), anyhow::Error> { + let 
detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let input_text = "Hi there! Can you help me with something?"; + let output_text = "Sure! Let me help you with , just tell me what you need."; + + let messages = vec![ + Message { + content: Some(Content::Text(input_text.to_string())), + role: Role::User, + ..Default::default() + }, + Message { + content: Some(Content::Text(output_text.to_string())), + role: Role::Assistant, + ..Default::default() + }, + ]; + + // Add mocksets + let mut detector_mocks = MockSet::new(); + let mut chat_mocks = MockSet::new(); + let mut chunker_mocks = MockSet::new(); + + // Add output detection mock response for output detection + let expected_detections = vec![ContentAnalysisResponse { + start: 28, + end: 37, + text: "something".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }]; + + // Add chat completion choices response for output detection + let expected_choices = vec![ + ChatCompletionChoice { + message: ChatCompletionMessage { + role: messages[0].role.clone(), + content: Some(input_text.to_string()), + refusal: None, + tool_calls: vec![], + }, + index: 0, + logprobs: None, + finish_reason: "NOT_FINISHED".to_string(), + }, + ChatCompletionChoice { + message: ChatCompletionMessage { + role: messages[1].role.clone(), + content: Some(output_text.to_string()), + refusal: None, + tool_calls: vec![], + }, + index: 1, + logprobs: None, + finish_reason: "EOS_TOKEN".to_string(), + }, + ]; + + // Add chat completion response for output detection + let chat_completions_response = ChatCompletion { + model: MODEL_ID.into(), + choices: expected_choices.clone(), + detections: Some(ChatDetections { + input: vec![], + output: vec![OutputDetectionResult { + choice_index: 1, + results: expected_detections.clone(), + }], + }), + warnings: vec![OrchestratorWarning::new( + 
DetectionWarningReason::UnsuitableOutput, + UNSUITABLE_OUTPUT_MESSAGE, + )], + ..Default::default() + }; + + // Add detector output mock for first message + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![input_text.into()], + detector_params: DetectorParams::new(), + }); + then.json(vec![Vec::::new()]); + }); + + // Add detector output mock for generated message + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![output_text.into()], + detector_params: DetectorParams::new(), + }); + then.json(vec![&expected_detections]); + }); + + // Add chunker tokenization mock for output detection user input + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: input_text.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: input_text.len() as i64, + text: input_text.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for output detection assistant output + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: output_text.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: output_text.len() as i64, + text: output_text.into(), + }], + token_count: 0, + }); + }); + + // Add chat completions mock + chat_mocks.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(ChatCompletionsRequest { + messages: messages.clone(), + model: MODEL_ID.into(), + stream: false, + ..Default::default() + }); + then.json(&chat_completions_response); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = 
MockServer::new(detector_name).with_mocks(detector_mocks); + let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); + let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) + .grpc() + .with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .chat_generation_server(&mock_chat_completions_server) + .build() + .await?; + + // Make orchestrator call for output detections + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: None, + output: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + }), + messages, + ..Default::default() + }) + .send() + .await?; + + // Assertions for output detections + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.detections, chat_completions_response.detections); + assert_eq!(results.choices, chat_completions_response.choices); + assert_eq!(results.warnings, chat_completions_response.warnings); + + Ok(()) +} + +// Validates that requests with output detector configured returns propagated errors +// from detector, chunker and completions server when applicable +#[test(tokio::test)] +async fn output_client_error() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + // Add 500 expected output detector mock response + let expected_detector_error = DetectorError { + code: 500, + message: "Internal detector error.".into(), + }; + // Add 500 expected orchestrator mock response + let expected_orchestrator_error = OrchestratorError { + code: 500, + details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.to_string(), + }; + + // Add input for error scenarios + 
let chunker_error_input = "This should return a 500 error on chunker"; + let detector_error_input = "This should return a 500 error on detector"; + let chat_completions_error_input = "This should return a 500 error on chat completions"; + + // Add mocksets + let mut chunker_mocks = MockSet::new(); + let mut detector_mocks = MockSet::new(); + let mut chat_mocks = MockSet::new(); + + let messages_chunker_error = vec![Message { + content: Some(Content::Text(chunker_error_input.to_string())), + role: Role::User, + ..Default::default() + }]; + + let messages_detector_error = vec![Message { + content: Some(Content::Text(detector_error_input.to_string())), + role: Role::User, + ..Default::default() + }]; + + let messages_chat_completions_error = vec![Message { + content: Some(Content::Text(chat_completions_error_input.to_string())), + role: Role::User, + ..Default::default() + }]; + + // Add chunker tokenization mock for detector internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: detector_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: detector_error_input.len() as i64, + text: detector_error_input.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for completions internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: chat_completions_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: chat_completions_error_input.len() as i64, + text: chat_completions_error_input.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for chunker internal server error scenario + chunker_mocks.mock(|when, then| { + 
when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: chunker_error_input.into(), + }); + then.internal_server_error(); + }); + + // Add detector mock for chat completions error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![chat_completions_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.json(vec![Vec::::new()]); + }); + + // Add detector mock for detector error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![detector_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.internal_server_error().json(&expected_detector_error); + }); + + // Add chat completions mock for chunker error scenario + chat_mocks.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(ChatCompletionsRequest { + messages: messages_chunker_error.clone(), + model: MODEL_ID.into(), + stream: false, + ..Default::default() + }); + then.internal_server_error(); + }); + + // Add chat completions mock for detector error scenario + chat_mocks.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(ChatCompletionsRequest { + messages: messages_detector_error.clone(), + model: MODEL_ID.into(), + stream: false, + ..Default::default() + }); + then.internal_server_error().json(&expected_detector_error); + }); + + // Add chat completions mock for chat completions error scenario + chat_mocks.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(ChatCompletionsRequest { + messages: messages_chat_completions_error.clone(), + model: MODEL_ID.into(), + stream: false, + ..Default::default() + }); + then.internal_server_error(); + }); + + // Start orchestrator server and its dependencies + let 
mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); + let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); + let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) + .grpc() + .with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .chat_generation_server(&mock_chat_completions_server) + .build() + .await?; + + // Make orchestrator call for chunker error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: None, + output: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + }), + messages: messages_chunker_error.clone(), + ..Default::default() + }) + .send() + .await?; + + // Assertions for chunker error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + // Make orchestrator call for detector error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: None, + output: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + }), + messages: messages_detector_error.clone(), + ..Default::default() + }) + .send() + .await?; + + // Assertions for detector error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + // Make orchestrator call for chat completions error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { 
+ input: None, + output: Some(HashMap::from([( + detector_name.into(), + DetectorParams::new(), + )])), + }), + messages: messages_chat_completions_error.clone(), + ..Default::default() + }) + .send() + .await?; + + // Assertions for chat completions error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + Ok(()) +} + +// Validate that invalid orchestrator requests returns 422 error +#[test(tokio::test)] +async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + // Start orchestrator server and its dependencies + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .build() + .await?; + + // Orchestrator request with non existing field + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&serde_json::json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {} + } + }, + "messages": vec![Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }], + "some_extra_field": "random value" + })) + .send() + .await?; + + // Assertions for invalid request + let results = response.json::().await?; + assert_eq!(results.code, StatusCode::UNPROCESSABLE_ENTITY); + assert!( + results + .details + .starts_with("some_extra_field: unknown field `some_extra_field`") + ); + + Ok(()) +} diff --git a/tests/common/chat_completion.rs b/tests/common/chat_completion.rs new file mode 100644 index 00000000..7ac67d0e --- /dev/null +++ b/tests/common/chat_completion.rs @@ -0,0 +1,19 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +// Chat completions server endpoint +pub const CHAT_COMPLETIONS_ENDPOINT: &str = "/v1/chat/completions"; diff --git a/tests/common/chat_completions.rs b/tests/common/chat_completions.rs new file mode 100644 index 00000000..7ac67d0e --- /dev/null +++ b/tests/common/chat_completions.rs @@ -0,0 +1,19 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +// Chat completions server endpoint +pub const CHAT_COMPLETIONS_ENDPOINT: &str = "/v1/chat/completions"; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index de269335..59720ca5 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -14,6 +14,7 @@ limitations under the License. 
*/ +pub mod chat_completions; pub mod chunker; pub mod detectors; pub mod errors; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 8de48100..73b16e0a 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -56,6 +56,9 @@ pub const ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT: &str = "/api/v2/text/de pub const ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/context"; pub const ORCHESTRATOR_CHAT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/chat"; +pub const ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT: &str = + "/api/v2/chat/completions-detection"; + // Messages pub const ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE: &str = "unexpected error occurred while processing request"; diff --git a/tests/test_config.yaml b/tests/test_config.yaml index e6fe5f2b..89ecb2df 100644 --- a/tests/test_config.yaml +++ b/tests/test_config.yaml @@ -1,3 +1,7 @@ +chat_generation: + service: + hostname: localhost + port: 3000 generation: provider: nlp # tgis or nlp service: @@ -19,7 +23,7 @@ detectors: service: hostname: localhost chunker_id: sentence_chunker - default_threshold: 0.5 + default_threshold: 0.5 parenthesis_detector_sentence: type: text_contents service: @@ -49,4 +53,4 @@ detectors: service: hostname: localhost chunker_id: whole_doc_chunker - default_threshold: 0.5 \ No newline at end of file + default_threshold: 0.5 From 2557f3c05d83f622f50c0c68dbb4162e5aecdc4e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Apr 2025 11:37:43 -0600 Subject: [PATCH 03/24] Bump openssl from 0.10.71 to 0.10.72 (#361) Bumps [openssl](https://github.com/sfackler/rust-openssl) from 0.10.71 to 0.10.72. 
- [Release notes](https://github.com/sfackler/rust-openssl/releases) - [Commits](https://github.com/sfackler/rust-openssl/compare/openssl-v0.10.71...openssl-v0.10.72) --- updated-dependencies: - dependency-name: openssl dependency-version: 0.10.72 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8b3bd63d..0f2dc60c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1669,9 +1669,9 @@ checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc" [[package]] name = "openssl" -version = "0.10.71" +version = "0.10.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" +checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" dependencies = [ "bitflags", "cfg-if", @@ -1701,9 +1701,9 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.106" +version = "0.9.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" +checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" dependencies = [ "cc", "libc", From 6473cb8aafa5411db4a75229ab9e219404cf4600 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Mon, 7 Apr 2025 17:52:31 -0300 Subject: [PATCH 04/24] Update Rust to 1.86 (#364) Signed-off-by: Mateus Devino --- Dockerfile | 2 +- rust-toolchain.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 028c05e5..a81fdc32 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ ARG CONFIG_FILE=config/config.yaml ## Rust builder 
################################################################ # Specific debian version so that compatible glibc version is used -FROM rust:1.85.1-bullseye AS rust-builder +FROM rust:1.86.0-bullseye AS rust-builder ARG PROTOC_VERSION ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/rust-toolchain.toml b/rust-toolchain.toml index e60b212f..95b8a269 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.85.1" +channel = "1.86.0" components = ["rustfmt", "clippy"] \ No newline at end of file From 3d2aa3cf195ff0ab1548eed2671727a518b00028 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 08:45:15 -0600 Subject: [PATCH 05/24] Bump tokio from 1.44.1 to 1.44.2 (#365) Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.44.1 to 1.44.2. - [Release notes](https://github.com/tokio-rs/tokio/releases) - [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.44.1...tokio-1.44.2) --- updated-dependencies: - dependency-name: tokio dependency-version: 1.44.2 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f2dc60c..4c3fbb68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2767,9 +2767,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.1" +version = "1.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a" +checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" dependencies = [ "backtrace", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 75f79a7a..bffb1bad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ serde = { version = "1.0.217", features = ["derive"] } serde_json = { version = "1.0.135", features = ["preserve_order"] } serde_yml = "0.0.12" thiserror = "2.0.11" -tokio = { version = "1.43.0", features = [ +tokio = { version = "1.44.2", features = [ "rt", "rt-multi-thread", "parking_lot", From 1d872de2fb30fb8d299325fa5307fed375c13287 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Tue, 8 Apr 2025 15:23:43 -0300 Subject: [PATCH 06/24] Single detector tests (#362) * Add single detector tests for streaming_content_detection Signed-off-by: Mateus Devino * Add single detector tests for streaming_classification_with_gen Signed-off-by: Mateus Devino --------- Signed-off-by: Mateus Devino --- tests/streaming_classification_with_gen.rs | 131 ++++++++++++++++++++- tests/streaming_content_detection.rs | 124 +++++++++++++++++-- 2 files changed, 241 insertions(+), 14 deletions(-) diff --git a/tests/streaming_classification_with_gen.rs b/tests/streaming_classification_with_gen.rs index adaebd91..bc353f86 100644 --- a/tests/streaming_classification_with_gen.rs +++ 
b/tests/streaming_classification_with_gen.rs @@ -833,7 +833,7 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { .build() .await?; - // Example orchestrator request with streaming response + // Single-detector scenario let response = orchestrator_server .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { @@ -852,10 +852,56 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { }) .send() .await?; + debug!("{response:#?}"); + + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].generated_text, Some("I am great!".into())); + assert_eq!( + messages[0].token_classification_results.output, + Some(vec![]) + ); + assert_eq!(messages[0].start_index, Some(0)); + assert_eq!(messages[0].processed_index, Some(11)); + + assert_eq!(messages[1].generated_text, Some(" What about you?".into())); + assert_eq!( + messages[1].token_classification_results.output, + Some(vec![]) + ); + assert_eq!(messages[1].start_index, Some(11)); + assert_eq!(messages[1].processed_index, Some(27)); + + // Multi-detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! 
How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: Some(GuardrailsConfigOutput { + models: HashMap::from([ + (angle_brackets_detector.into(), DetectorParams::new()), + (parenthesis_detector.into(), DetectorParams::new()), + ]), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; debug!("{response:#?}"); - // Test custom SseStream wrapper let sse_stream: SseStream = SseStream::new(response.bytes_stream()); let messages = sse_stream @@ -1076,7 +1122,79 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { .build() .await?; - // Example orchestrator request with streaming response + // Single-detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: Some(GuardrailsConfigOutput { + models: HashMap::from([( + angle_brackets_detector.into(), + DetectorParams::new(), + )]), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream + .collect::>() + .await + .into_iter() + .collect::, anyhow::Error>>()?; + debug!("{messages:#?}"); + + let expected_messages = vec![ + ClassifiedGeneratedTextStreamResult { + generated_text: Some("I (am) great!".into()), + token_classification_results: TextGenTokenClassificationResults { + input: None, + output: Some(vec![]), + }, + processed_index: Some(13), + start_index: Some(0), + tokens: Some(vec![]), + input_tokens: Some(vec![]), + ..Default::default() + }, + ClassifiedGeneratedTextStreamResult { + generated_text: Some(" What about ?".into()), + token_classification_results: TextGenTokenClassificationResults { + input: None, + output: Some(vec![TokenClassificationResult { + start: 13, + end: 16, + word: "you".into(), + 
entity: "has_angle_brackets".into(), + entity_group: "angle_brackets".into(), + detector_id: Some(angle_brackets_detector.into()), + score: 1.0, + token_count: None, + }]), + }, + processed_index: Some(31), + start_index: Some(13), + tokens: Some(vec![]), + input_tokens: Some(vec![]), + ..Default::default() + }, + ]; + + assert_eq!(messages.len(), 2); + assert_eq!( + messages, expected_messages, + "failed on single-detector scenario" + ); + + // Multi-detector scenario let response = orchestrator_server .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { @@ -1095,10 +1213,8 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { }) .send() .await?; - debug!("{response:#?}"); - // Test custom SseStream wrapper let sse_stream: SseStream = SseStream::new(response.bytes_stream()); let messages = sse_stream @@ -1154,7 +1270,10 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { ]; assert_eq!(messages.len(), 2); - assert_eq!(messages, expected_messages); + assert_eq!( + messages, expected_messages, + "failed on multi-detector scenario" + ); Ok(()) } diff --git a/tests/streaming_content_detection.rs b/tests/streaming_content_detection.rs index db261269..fc519bcf 100644 --- a/tests/streaming_content_detection.rs +++ b/tests/streaming_content_detection.rs @@ -168,7 +168,63 @@ async fn no_detections() -> Result<(), anyhow::Error> { .build() .await?; - // Example orchestrator request with streaming response + // Single-detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([( + angle_brackets_detector.into(), + DetectorParams::new(), + )])), + content: "Hi".into(), + }, + StreamingContentDetectionRequest { + detectors: None, + content: " there!".into(), + }, + StreamingContentDetectionRequest { 
+ detectors: None, + content: " How".into(), + }, + StreamingContentDetectionRequest { + detectors: None, + content: " are".into(), + }, + StreamingContentDetectionRequest { + detectors: None, + content: " you?".into(), + }, + ]))) + .send() + .await?; + + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + + let expected_messages = [ + StreamingContentDetectionResponse { + detections: vec![], + start_index: 0, + processed_index: 9, + }, + StreamingContentDetectionResponse { + detections: vec![], + start_index: 9, + processed_index: 22, + }, + ]; + assert_eq!( + messages, expected_messages, + "failed on single-detector scenario" + ); + + // Multi-detector scenario let response = orchestrator_server .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) .header("content-type", "application/x-ndjson") @@ -200,7 +256,6 @@ async fn no_detections() -> Result<(), anyhow::Error> { .send() .await?; - // Collects stream results let mut messages = Vec::::with_capacity(1); let mut stream = response.bytes_stream(); while let Some(Ok(msg)) = stream.next().await { @@ -208,7 +263,6 @@ async fn no_detections() -> Result<(), anyhow::Error> { messages.push(serde_json::from_slice(&msg[..]).unwrap()); } - // assertions let expected_messages = [ StreamingContentDetectionResponse { detections: vec![], @@ -221,7 +275,10 @@ async fn no_detections() -> Result<(), anyhow::Error> { processed_index: 22, }, ]; - assert_eq!(messages, expected_messages); + assert_eq!( + messages, expected_messages, + "failed on multi-detector scenario" + ); Ok(()) } @@ -347,7 +404,57 @@ async fn detections() -> Result<(), anyhow::Error> { .build() .await?; - // Example orchestrator request with streaming response + // Single-detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + 
.header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([( + angle_brackets_detector.into(), + DetectorParams::new(), + )])), + content: "Hi (there)! How are ?".into(), + }, + ]))) + .send() + .await?; + + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + + let expected_messages = [ + StreamingContentDetectionResponse { + detections: vec![], + start_index: 0, + processed_index: 11, + }, + StreamingContentDetectionResponse { + detections: vec![ContentAnalysisResponse { + start: 10, + end: 13, + text: "you".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(angle_brackets_detector.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }], + start_index: 11, + processed_index: 26, + }, + ]; + assert_eq!( + messages, expected_messages, + "failed on single-detector scenario" + ); + + // Multi-detector scenario let response = orchestrator_server .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) .header("content-type", "application/x-ndjson") @@ -363,7 +470,6 @@ async fn detections() -> Result<(), anyhow::Error> { .send() .await?; - // Collects stream results let mut messages = Vec::::with_capacity(1); let mut stream = response.bytes_stream(); while let Some(Ok(msg)) = stream.next().await { @@ -371,7 +477,6 @@ async fn detections() -> Result<(), anyhow::Error> { messages.push(serde_json::from_slice(&msg[..]).unwrap()); } - // assertions let expected_messages = [ StreamingContentDetectionResponse { detections: vec![ContentAnalysisResponse { @@ -404,7 +509,10 @@ async fn detections() -> Result<(), anyhow::Error> { processed_index: 26, }, ]; - assert_eq!(messages, expected_messages); + assert_eq!( + 
messages, expected_messages, + "failed on multi-detector scenario" + ); Ok(()) } From 3a78353d98be61e62aab68dd3494465f3a82f487 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Tue, 15 Apr 2025 11:02:56 -0700 Subject: [PATCH 07/24] refactor: task handlers (#355) Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> Co-authored-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/args.rs | 2 +- src/clients.rs | 6 +- src/clients/chunker.rs | 6 +- src/clients/detector.rs | 2 - src/clients/detector/text_chat.rs | 3 +- src/clients/detector/text_contents.rs | 3 +- src/clients/detector/text_context_doc.rs | 3 +- src/clients/detector/text_generation.rs | 3 +- src/clients/generation.rs | 22 - src/clients/http.rs | 20 +- src/clients/openai.rs | 11 +- src/clients/tgis.rs | 9 +- src/models.rs | 36 +- src/orchestrator.rs | 292 +--- .../chat_completions_detection.rs | 686 -------- src/orchestrator/common/client.rs | 56 +- src/orchestrator/common/tasks.rs | 283 ++-- src/orchestrator/common/utils.rs | 7 + src/orchestrator/detector_processing.rs | 18 - .../detector_processing/content.rs | 118 -- src/orchestrator/handlers.rs | 69 + .../handlers/chat_completions_detection.rs | 65 + .../chat_completions_detection/streaming.rs | 55 + .../chat_completions_detection/unary.rs | 201 +++ src/orchestrator/handlers/chat_detection.rs | 85 + .../handlers/classification_with_gen.rs | 213 +++ .../handlers/context_docs_detection.rs | 89 + .../handlers/detection_on_generation.rs | 88 + .../handlers/generation_with_detection.rs | 112 ++ .../streaming_classification_with_gen.rs | 437 +++++ .../handlers/streaming_content_detection.rs | 258 +++ .../handlers/text_content_detection.rs | 85 + src/orchestrator/streaming.rs | 639 -------- src/orchestrator/streaming/aggregator.rs | 678 -------- .../streaming_content_detection.rs | 491 ------ .../streaming_content_detection/aggregator.rs | 277 ---- src/orchestrator/types/detection.rs | 
21 + src/orchestrator/unary.rs | 1459 ----------------- src/server.rs | 73 +- src/utils/trace.rs | 8 +- tests/chat_completions_detection.rs | 79 +- tests/chat_detection.rs | 1 - 42 files changed, 2066 insertions(+), 5003 deletions(-) delete mode 100644 src/orchestrator/chat_completions_detection.rs delete mode 100644 src/orchestrator/detector_processing.rs delete mode 100644 src/orchestrator/detector_processing/content.rs create mode 100644 src/orchestrator/handlers.rs create mode 100644 src/orchestrator/handlers/chat_completions_detection.rs create mode 100644 src/orchestrator/handlers/chat_completions_detection/streaming.rs create mode 100644 src/orchestrator/handlers/chat_completions_detection/unary.rs create mode 100644 src/orchestrator/handlers/chat_detection.rs create mode 100644 src/orchestrator/handlers/classification_with_gen.rs create mode 100644 src/orchestrator/handlers/context_docs_detection.rs create mode 100644 src/orchestrator/handlers/detection_on_generation.rs create mode 100644 src/orchestrator/handlers/generation_with_detection.rs create mode 100644 src/orchestrator/handlers/streaming_classification_with_gen.rs create mode 100644 src/orchestrator/handlers/streaming_content_detection.rs create mode 100644 src/orchestrator/handlers/text_content_detection.rs delete mode 100644 src/orchestrator/streaming.rs delete mode 100644 src/orchestrator/streaming/aggregator.rs delete mode 100644 src/orchestrator/streaming_content_detection.rs delete mode 100644 src/orchestrator/streaming_content_detection/aggregator.rs delete mode 100644 src/orchestrator/unary.rs diff --git a/src/args.rs b/src/args.rs index b5b38f03..152143ab 100644 --- a/src/args.rs +++ b/src/args.rs @@ -140,9 +140,9 @@ impl OtlpProtocol { #[derive(Debug, Clone, Copy, Default, PartialEq)] pub enum LogFormat { + Compact, #[default] Full, - Compact, Pretty, JSON, } diff --git a/src/clients.rs b/src/clients.rs index 0454dac6..cf7062b7 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -32,7 
+32,7 @@ use hyper_timeout::TimeoutConnector; use hyper_util::rt::TokioExecutor; use tonic::{Request, metadata::MetadataMap}; use tower::{ServiceBuilder, timeout::TimeoutLayer}; -use tracing::{Span, debug, instrument}; +use tracing::Span; use tracing_opentelemetry::OpenTelemetrySpanExt; use url::Url; @@ -205,7 +205,6 @@ impl ClientMap { } } -#[instrument(skip_all, fields(hostname = service_config.hostname))] pub async fn create_http_client( default_port: u16, service_config: &ServiceConfig, @@ -220,7 +219,6 @@ pub async fn create_http_client( base_url .set_port(Some(port)) .unwrap_or_else(|_| panic!("error setting port: {}", port)); - debug!(%base_url, "creating HTTP client"); let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC); let request_timeout = Duration::from_secs( @@ -257,7 +255,6 @@ pub async fn create_http_client( Ok(HttpClient::new(base_url, client)) } -#[instrument(skip_all, fields(hostname = service_config.hostname))] pub async fn create_grpc_client( default_port: u16, service_config: &ServiceConfig, @@ -270,7 +267,6 @@ pub async fn create_grpc_client( }; let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap(); base_url.set_port(Some(port)).unwrap(); - debug!(%base_url, "creating gRPC client"); let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC); let request_timeout = Duration::from_secs( service_config diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 922517a5..36904433 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -22,7 +22,7 @@ use axum::http::HeaderMap; use futures::{Future, Stream, StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::{Code, Request, Response, Status, Streaming}; -use tracing::{Span, debug, info, instrument}; +use tracing::{Span, instrument}; use super::{ BoxStream, Client, Error, create_grpc_client, errors::grpc_to_http_code, @@ -68,7 +68,6 @@ impl ChunkerClient { } } - #[instrument(skip_all, 
fields(model_id))] pub async fn tokenization_task_predict( &self, model_id: &str, @@ -76,20 +75,17 @@ impl ChunkerClient { ) -> Result { let mut client = self.client.clone(); let request = request_with_headers(request, model_id); - debug!(?request, "sending client request"); let response = client.chunker_tokenization_task_predict(request).await?; let span = Span::current(); trace_context_from_grpc_response(&span, &response); Ok(response.into_inner()) } - #[instrument(skip_all, fields(model_id))] pub async fn bidi_streaming_tokenization_task_predict( &self, model_id: &str, request_stream: BoxStream, ) -> Result>, Error> { - info!("sending client stream request"); let mut client = self.client.clone(); let request = request_with_headers(request_stream, model_id); // NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors. diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 91d0e979..4b8cc103 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -21,7 +21,6 @@ use axum::http::HeaderMap; use http::header::CONTENT_TYPE; use hyper::StatusCode; use serde::Deserialize; -use tracing::instrument; use url::Url; use super::{ @@ -79,7 +78,6 @@ pub trait DetectorClientExt: HttpClientExt { } impl DetectorClientExt for C { - #[instrument(skip_all, fields(model_id, url))] async fn post_to_detector( &self, model_id: &str, diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index cef9af11..c763f3a9 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use hyper::HeaderMap; use serde::Serialize; -use tracing::{info, instrument}; +use tracing::info; use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; use crate::{ @@ -63,7 +63,6 @@ impl TextChatDetectorClient { &self.client } - #[instrument(skip_all, fields(model_id, ?headers))] pub async fn text_chat( &self, model_id: &str, diff --git 
a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index a4fbfca7..27943a86 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -20,7 +20,7 @@ use std::collections::BTreeMap; use async_trait::async_trait; use hyper::HeaderMap; use serde::{Deserialize, Serialize}; -use tracing::{info, instrument}; +use tracing::info; use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; use crate::{ @@ -61,7 +61,6 @@ impl TextContentsDetectorClient { &self.client } - #[instrument(skip_all, fields(model_id))] pub async fn text_contents( &self, model_id: &str, diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index 513ada99..ee3a1f63 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use hyper::HeaderMap; use serde::{Deserialize, Serialize}; -use tracing::{info, instrument}; +use tracing::info; use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; use crate::{ @@ -59,7 +59,6 @@ impl TextContextDocDetectorClient { &self.client } - #[instrument(skip_all, fields(model_id))] pub async fn text_context_doc( &self, model_id: &str, diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index e7c9890c..6ba6b82a 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use hyper::HeaderMap; use serde::Serialize; -use tracing::{info, instrument}; +use tracing::info; use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; use crate::{ @@ -59,7 +59,6 @@ impl TextGenerationDetectorClient { &self.client } - #[instrument(skip_all, fields(model_id))] pub async fn text_generation( &self, model_id: &str, diff --git a/src/clients/generation.rs b/src/clients/generation.rs index afd54e59..507864b6 100644 --- a/src/clients/generation.rs +++ 
b/src/clients/generation.rs @@ -18,7 +18,6 @@ use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; use hyper::HeaderMap; -use tracing::{debug, instrument}; use super::{BoxStream, Client, Error, NlpClient, TgisClient}; use crate::{ @@ -63,7 +62,6 @@ impl GenerationClient { Self(None) } - #[instrument(skip_all, fields(model_id))] pub async fn tokenize( &self, model_id: String, @@ -79,19 +77,15 @@ impl GenerationClient { return_offsets: false, truncate_input_tokens: 0, }; - debug!(provider = "tgis", ?request, "sending tokenize request"); let mut response = client.tokenize(request, headers).await?; - debug!(provider = "tgis", ?response, "received tokenize response"); let response = response.responses.swap_remove(0); Ok((response.token_count, response.tokens)) } Some(GenerationClientInner::Nlp(client)) => { let request = TokenizationTaskRequest { text }; - debug!(provider = "nlp", ?request, "sending tokenize request"); let response = client .tokenization_task_predict(&model_id, request, headers) .await?; - debug!(provider = "nlp", ?response, "received tokenize response"); let tokens = response .results .into_iter() @@ -103,7 +97,6 @@ impl GenerationClient { } } - #[instrument(skip_all, fields(model_id))] pub async fn generate( &self, model_id: String, @@ -120,9 +113,7 @@ impl GenerationClient { requests: vec![GenerationRequest { text }], params, }; - debug!(provider = "tgis", ?request, "sending generate request"); let response = client.generate(request, headers).await?; - debug!(provider = "tgis", ?response, "received generate response"); Ok(response.into()) } Some(GenerationClientInner::Nlp(client)) => { @@ -157,18 +148,15 @@ impl GenerationClient { ..Default::default() } }; - debug!(provider = "nlp", ?request, "sending generate request"); let response = client .text_generation_task_predict(&model_id, request, headers) .await?; - debug!(provider = "nlp", ?response, "received generate response"); Ok(response.into()) } None => Err(Error::ModelNotFound { 
model_id }), } } - #[instrument(skip_all, fields(model_id))] pub async fn generate_stream( &self, model_id: String, @@ -185,11 +173,6 @@ impl GenerationClient { request: Some(GenerationRequest { text }), params, }; - debug!( - provider = "tgis", - ?request, - "sending generate_stream request" - ); let response_stream = client .generate_stream(request, headers) .await? @@ -229,11 +212,6 @@ impl GenerationClient { ..Default::default() } }; - debug!( - provider = "nlp", - ?request, - "sending generate_stream request" - ); let response_stream = client .server_streaming_text_generation_task_predict(&model_id, request, headers) .await? diff --git a/src/clients/http.rs b/src/clients/http.rs index 7bfdf917..e23d167a 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -20,7 +20,7 @@ use std::{fmt::Debug, ops::Deref, time::Duration}; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use hyper::{ HeaderMap, Method, Request, StatusCode, - body::{Body, Bytes, Incoming}, + body::{Bytes, Incoming}, }; use hyper_rustls::HttpsConnector; use hyper_timeout::TimeoutConnector; @@ -36,7 +36,7 @@ use tower_http::{ Trace, TraceLayer, }, }; -use tracing::{Span, debug, error, info, info_span, instrument}; +use tracing::{Span, error, info, info_span}; use tracing_opentelemetry::OpenTelemetrySpanExt; use url::Url; @@ -137,7 +137,6 @@ impl HttpClient { self.base_url.join(path).unwrap() } - #[instrument(skip_all, fields(url))] pub async fn get( &self, url: Url, @@ -147,7 +146,6 @@ impl HttpClient { self.send(url, Method::GET, headers, body).await } - #[instrument(skip_all, fields(url))] pub async fn post( &self, url: Url, @@ -157,7 +155,6 @@ impl HttpClient { self.send(url, Method::POST, headers, body).await } - #[instrument(skip_all, fields(url))] pub async fn send( &self, url: Url, @@ -172,12 +169,6 @@ impl HttpClient { .uri(url.as_uri()); match builder.headers_mut() { Some(headers_mut) => { - debug!( - ?url, - ?headers, - ?body, - "sending client request" - ); 
headers_mut.extend(headers); let body = Full::new(Bytes::from(serde_json::to_vec(&body).map_err(|e| { @@ -211,13 +202,6 @@ impl HttpClient { message: format!("client request timeout: {}", e), }), }?; - - debug!( - status = ?response.status(), - headers = ?response.headers(), - size = ?response.size_hint(), - "incoming client response" - ); let span = Span::current(); trace::trace_context_from_http_response(&span, &response); Ok(response.into()) diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 4e03c5e6..bd968474 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -24,7 +24,6 @@ use http_body_util::BodyExt; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; -use tracing::{info, instrument}; use super::{ Client, Error, HttpClient, create_http_client, detector::ContentAnalysisResponse, @@ -70,14 +69,12 @@ impl OpenAiClient { &self.client } - #[instrument(skip_all, fields(request.model))] pub async fn chat_completions( &self, request: ChatCompletionsRequest, headers: HeaderMap, ) -> Result { let url = self.inner().endpoint(CHAT_COMPLETIONS_ENDPOINT); - info!("sending Open AI chat completion request to {}", url); if request.stream { let (tx, rx) = mpsc::channel(32); let mut event_stream = self @@ -296,11 +293,11 @@ pub struct ChatCompletionsRequest { #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct DetectorConfig { - #[serde(skip_serializing_if = "Option::is_none")] - pub input: Option>, + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub input: HashMap, - #[serde(skip_serializing_if = "Option::is_none")] - pub output: Option>, + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub output: HashMap, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index 0c9ba649..ff7f845b 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -20,7 +20,7 @@ use 
axum::http::HeaderMap; use futures::{StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::Code; -use tracing::{Span, debug, instrument}; +use tracing::Span; use super::{ BoxStream, Client, Error, create_grpc_client, errors::grpc_to_http_code, @@ -52,14 +52,12 @@ impl TgisClient { Self { client } } - #[instrument(skip_all, fields(model_id = request.model_id))] pub async fn generate( &self, request: BatchedGenerationRequest, headers: HeaderMap, ) -> Result { let request = grpc_request_with_headers(request, headers); - debug!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); let response = client.generate(request).await?; let span = Span::current(); @@ -67,14 +65,12 @@ impl TgisClient { Ok(response.into_inner()) } - #[instrument(skip_all, fields(model_id = request.model_id))] pub async fn generate_stream( &self, request: SingleGenerationRequest, headers: HeaderMap, ) -> Result>, Error> { let request = grpc_request_with_headers(request, headers); - debug!(?request, "sending request to TGIS gRPC service"); let mut client = self.client.clone(); let response = client.generate_stream(request).await?; let span = Span::current(); @@ -82,7 +78,6 @@ impl TgisClient { Ok(response.into_inner().map_err(Into::into).boxed()) } - #[instrument(skip_all, fields(model_id = request.model_id))] pub async fn tokenize( &self, request: BatchedTokenizeRequest, @@ -90,7 +85,6 @@ impl TgisClient { ) -> Result { let mut client = self.client.clone(); let request = grpc_request_with_headers(request, headers); - debug!(?request, "sending request to TGIS gRPC service"); let response = client.tokenize(request).await?; let span = Span::current(); trace_context_from_grpc_response(&span, &response); @@ -98,7 +92,6 @@ impl TgisClient { } pub async fn model_info(&self, request: ModelInfoRequest) -> Result { - debug!(?request, "sending request to TGIS gRPC service"); let request = grpc_request_with_headers(request, HeaderMap::new()); let mut client 
= self.client.clone(); let response = client.model_info(request).await?; diff --git a/src/models.rs b/src/models.rs index 8c7c5a28..21c4c4fd 100644 --- a/src/models.rs +++ b/src/models.rs @@ -135,10 +135,12 @@ impl GuardrailsHttpRequest { // Validate detector params if let Some(config) = guardrail_config { - if let Some(input_detectors) = config.input_detectors() { + let input_detectors = config.input.as_ref().map(|input| &input.models); + let output_detectors = config.output.as_ref().map(|output| &output.models); + if let Some(input_detectors) = input_detectors { validate_detector_params(input_detectors)?; } - if let Some(output_detectors) = config.output_detectors() { + if let Some(output_detectors) = output_detectors { validate_detector_params(output_detectors)?; } } @@ -165,12 +167,18 @@ impl GuardrailsConfig { self.input.as_ref().and_then(|input| input.masks.as_deref()) } - pub fn input_detectors(&self) -> Option<&HashMap> { - self.input.as_ref().map(|input| &input.models) + pub fn input_detectors(&self) -> HashMap { + self.input + .as_ref() + .map(|input| input.models.clone()) + .unwrap_or_default() } - pub fn output_detectors(&self) -> Option<&HashMap> { - self.output.as_ref().map(|output| &output.models) + pub fn output_detectors(&self) -> HashMap { + self.output + .as_ref() + .map(|output| output.models.clone()) + .unwrap_or_default() } } @@ -516,6 +524,22 @@ pub struct DetectionWarning { pub message: Option, } +impl DetectionWarning { + pub fn unsuitable_input() -> Self { + DetectionWarning { + id: Some(DetectionWarningReason::UnsuitableInput), + message: Some(UNSUITABLE_INPUT_MESSAGE.to_string()), + } + } + + pub fn unsuitable_output() -> Self { + DetectionWarning { + id: Some(DetectionWarningReason::UnsuitableOutput), + message: Some(UNSUITABLE_OUTPUT_MESSAGE.to_string()), + } + } +} + /// Enumeration of warning reasons on input detection /// Since this enum's variants do not hold data, we can easily define them as `#[repr(C)]` /// which helps with FFI. 
diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 6452a0ea..0f6ee9ba 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -14,43 +14,28 @@ limitations under the License. */ - pub mod errors; pub use errors::Error; -use futures::Stream; -pub mod chat_completions_detection; pub mod common; -pub mod detector_processing; -pub mod streaming; -pub mod streaming_content_detection; +pub mod handlers; pub mod types; -pub mod unary; -use std::{collections::HashMap, pin::Pin, sync::Arc}; +use std::sync::Arc; -use axum::http::header::HeaderMap; -use opentelemetry::trace::TraceId; use tokio::{sync::RwLock, time::Instant}; use tracing::{debug, info}; use crate::{ clients::{ - self, ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, + ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, chunker::ChunkerClient, detector::{ TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, - text_context_doc::ContextType, }, - openai::{ChatCompletionsRequest, OpenAiClient}, + openai::OpenAiClient, }, config::{DetectorType, GenerationProvider, OrchestratorConfig}, health::HealthCheckCache, - models::{ - ChatDetectionHttpRequest, ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, - DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig, - GuardrailsHttpRequest, GuardrailsTextGenerationParameters, - StreamingContentDetectionRequest, TextContentDetectionHttpRequest, - }, }; #[cfg_attr(test, derive(Default))] @@ -214,272 +199,3 @@ async fn create_clients(config: &OrchestratorConfig) -> Result } Ok(clients) } - -#[derive(Debug, Clone)] -pub struct Chunk { - pub offset: usize, - pub text: String, -} - -#[derive(Debug)] -pub struct ClassificationWithGenTask { - pub trace_id: TraceId, - pub model_id: String, - pub inputs: String, - pub guardrails_config: GuardrailsConfig, - pub text_gen_parameters: Option, - pub headers: HeaderMap, -} - -impl ClassificationWithGenTask { - pub fn 
new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { - Self { - trace_id, - model_id: request.model_id, - inputs: request.inputs, - guardrails_config: request.guardrail_config.unwrap_or_default(), - text_gen_parameters: request.text_gen_parameters, - headers, - } - } -} - -/// Task for the /api/v2/text/detection/content endpoint -#[derive(Debug)] -pub struct GenerationWithDetectionTask { - /// Unique identifier of request trace - pub trace_id: TraceId, - - /// Model ID of the LLM - pub model_id: String, - - /// User prompt to be sent to the LLM - pub prompt: String, - - /// Detectors configuration - pub detectors: HashMap, - - /// LLM Parameters - pub text_gen_parameters: Option, - - // Headermap - pub headers: HeaderMap, -} - -impl GenerationWithDetectionTask { - pub fn new( - trace_id: TraceId, - request: GenerationWithDetectionHttpRequest, - headers: HeaderMap, - ) -> Self { - Self { - trace_id, - model_id: request.model_id, - prompt: request.prompt, - detectors: request.detectors, - text_gen_parameters: request.text_gen_parameters, - headers, - } - } -} - -/// Task for the /api/v2/text/detection/content endpoint -#[derive(Debug)] -pub struct TextContentDetectionTask { - /// Unique identifier of request trace - pub trace_id: TraceId, - - /// Content to run detection on - pub content: String, - - /// Detectors configuration - pub detectors: HashMap, - - // Headermap - pub headers: HeaderMap, -} - -impl TextContentDetectionTask { - pub fn new( - trace_id: TraceId, - request: TextContentDetectionHttpRequest, - headers: HeaderMap, - ) -> Self { - Self { - trace_id, - content: request.content, - detectors: request.detectors, - headers, - } - } -} - -/// Task for the /api/v1/text/task/detection/context endpoint -#[derive(Debug)] -pub struct ContextDocsDetectionTask { - /// Unique identifier of request trace - pub trace_id: TraceId, - - /// Content to run detection on - pub content: String, - - /// Context type - pub context_type: 
ContextType, - - /// Context - pub context: Vec, - - /// Detectors configuration - pub detectors: HashMap, - - // Headermap - pub headers: HeaderMap, -} - -impl ContextDocsDetectionTask { - pub fn new(trace_id: TraceId, request: ContextDocsHttpRequest, headers: HeaderMap) -> Self { - Self { - trace_id, - content: request.content, - context_type: request.context_type, - context: request.context, - detectors: request.detectors, - headers, - } - } -} - -/// Task for the /api/v2/text/detection/chat endpoint -#[derive(Debug)] -pub struct ChatDetectionTask { - /// Request unique identifier - pub trace_id: TraceId, - - /// Detectors configuration - pub detectors: HashMap, - - // Messages to run detection on - pub messages: Vec, - - // Tools definitions, optional - pub tools: Vec, - - // Headermap - pub headers: HeaderMap, -} - -impl ChatDetectionTask { - pub fn new(trace_id: TraceId, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self { - Self { - trace_id, - detectors: request.detectors, - messages: request.messages, - tools: request.tools, - headers, - } - } -} - -/// Task for the /api/v2/text/detection/generated endpoint -#[derive(Debug)] -pub struct DetectionOnGenerationTask { - /// Unique identifier of request trace - pub trace_id: TraceId, - - /// User prompt to be sent to the LLM - pub prompt: String, - - /// Text generated by the LLM - pub generated_text: String, - - /// Detectors configuration - pub detectors: HashMap, - - // Headermap - pub headers: HeaderMap, -} - -impl DetectionOnGenerationTask { - pub fn new( - trace_id: TraceId, - request: DetectionOnGeneratedHttpRequest, - headers: HeaderMap, - ) -> Self { - Self { - trace_id, - prompt: request.prompt, - generated_text: request.generated_text, - detectors: request.detectors, - headers, - } - } -} - -#[allow(dead_code)] -#[derive(Debug)] -pub struct StreamingClassificationWithGenTask { - pub trace_id: TraceId, - pub model_id: String, - pub inputs: String, - pub guardrails_config: GuardrailsConfig, 
- pub text_gen_parameters: Option, - pub headers: HeaderMap, -} - -impl StreamingClassificationWithGenTask { - pub fn new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { - Self { - trace_id, - model_id: request.model_id, - inputs: request.inputs, - guardrails_config: request.guardrail_config.unwrap_or_default(), - text_gen_parameters: request.text_gen_parameters, - headers, - } - } -} - -#[derive(Debug)] -pub struct ChatCompletionsDetectionTask { - /// Unique identifier of request trace - pub trace_id: TraceId, - /// Chat completion request - pub request: ChatCompletionsRequest, - // Headermap - pub headers: HeaderMap, -} - -impl ChatCompletionsDetectionTask { - pub fn new(trace_id: TraceId, request: ChatCompletionsRequest, headers: HeaderMap) -> Self { - Self { - trace_id, - request, - headers, - } - } -} - -pub struct StreamingContentDetectionTask { - pub trace_id: TraceId, - pub headers: HeaderMap, - pub detectors: HashMap, - pub input_stream: - Pin> + Send>>, -} - -impl StreamingContentDetectionTask { - pub fn new( - trace_id: TraceId, - headers: HeaderMap, - input_stream: Pin< - Box> + Send>, - >, - ) -> Self { - Self { - trace_id, - headers, - detectors: HashMap::default(), - input_stream, - } - } -} diff --git a/src/orchestrator/chat_completions_detection.rs b/src/orchestrator/chat_completions_detection.rs deleted file mode 100644 index ce45c61b..00000000 --- a/src/orchestrator/chat_completions_detection.rs +++ /dev/null @@ -1,686 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. - -*/ -use std::{ - collections::{BTreeMap, HashMap, btree_map}, - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, -}; - -use axum::http::HeaderMap; -use futures::future::{join_all, try_join_all}; -use serde::{Deserialize, Serialize}; -use tracing::{debug, info, instrument}; -use uuid::Uuid; - -use super::{ChatCompletionsDetectionTask, Context, Error, Orchestrator}; -use crate::{ - clients::{ - detector::{ChatDetectionRequest, ContentAnalysisRequest, ContentAnalysisResponse}, - openai::{ - ChatCompletion, ChatCompletionChoice, ChatCompletionsRequest, ChatCompletionsResponse, - ChatDetections, Content, DetectionResult, InputDetectionResult, OpenAiClient, - OrchestratorWarning, OutputDetectionResult, Role, - }, - }, - config::DetectorType, - models::{ - DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE, UNSUITABLE_OUTPUT_MESSAGE, - }, - orchestrator::{ - Chunk, - detector_processing::content, - unary::{chunk, detect_content}, - }, -}; - -/// Internal structure to capture chat messages (both request and response) -/// and prepare it for processing -#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] -pub struct ChatMessageInternal { - /// Index of the message - pub message_index: u32, - /// The role of the messages author. - pub role: Role, - /// The contents of the message. - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, - /// The refusal message by the assistant. 
(assistant message only) - #[serde(skip_serializing_if = "Option::is_none")] - pub refusal: Option, -} - -pub enum DetectorRequest { - ContentAnalysisRequest(ContentAnalysisRequest), - ChatDetectionRequest(ChatDetectionRequest), -} - -// Get Vec from ChatCompletionsRequest -impl From<&ChatCompletionsRequest> for Vec { - fn from(value: &ChatCompletionsRequest) -> Self { - value - .messages - .iter() - .enumerate() - .map(|(index, message)| ChatMessageInternal { - message_index: index as u32, - role: message.role.clone(), - content: message.content.clone(), - refusal: message.refusal.clone(), - }) - .collect() - } -} - -// Get Vec from ChatCompletion -impl From<&Box> for Vec { - fn from(value: &Box) -> Self { - value - .choices - .iter() - .map(|choice| ChatMessageInternal { - message_index: choice.index, - role: choice.message.role.clone(), - content: Some(Content::Text( - choice.message.content.clone().unwrap_or_default(), - )), - refusal: choice.message.refusal.clone(), - }) - .collect() - } -} - -// Get Vec from ChatCompletionChoice -impl From for Vec { - fn from(value: ChatCompletionChoice) -> Self { - vec![ChatMessageInternal { - message_index: value.index, - role: value.message.role, - content: Some(Content::Text(value.message.content.unwrap_or_default())), - refusal: value.message.refusal, - }] - } -} -// TODO: Add from function for streaming response as well - -impl Orchestrator { - #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] - - pub async fn handle_chat_completions_detection( - &self, - task: ChatCompletionsDetectionTask, - ) -> Result { - info!("handling chat completions detection task"); - let ctx = self.ctx.clone(); - - let task_handle = tokio::spawn(async move { - // Convert the request into a format that can be used for processing - let chat_messages = Vec::::from(&task.request); - let detectors = task.request.detectors.clone().unwrap_or_default(); - - let input_detections = match detectors.input { - 
Some(detectors) if !detectors.is_empty() => { - // Call out to input detectors using chunk - message_detection(&ctx, &detectors, chat_messages, &task.headers).await? - } - _ => None, - }; - - debug!(?input_detections); - - if let Some(input_detections) = input_detections { - let detections = sort_detections(input_detections); - - Ok(ChatCompletionsResponse::Unary(Box::new(ChatCompletion { - id: Uuid::new_v4().simple().to_string(), - model: task.request.model.clone(), - choices: vec![], - created: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64, - detections: Some(ChatDetections { - input: detections - .into_iter() - .map(|detection_result| InputDetectionResult { - message_index: detection_result.index, - results: detection_result.results, - }) - .collect(), - output: vec![], - }), - warnings: vec![OrchestratorWarning::new( - DetectionWarningReason::UnsuitableInput, - UNSUITABLE_INPUT_MESSAGE, - )], - ..Default::default() - }))) - } else { - let client = ctx - .clients - .get_as::("chat_generation") - .expect("chat_generation client not found"); - let mut chat_request = task.request; - let model_id = chat_request.model.clone(); - // Remove detectors as chat completion server would reject extra parameter - chat_request.detectors = None; - let headers = task.headers.clone(); - let chat_completions = client - .chat_completions(chat_request, headers) - .await - .map_err(|error| Error::ChatCompletionRequestFailed { - id: model_id.clone(), - error, - })?; - - match handle_output_detections( - &chat_completions, - detectors.output, - ctx, - &task.headers, - model_id, - ) - .await - { - Some(chat_completion_detections) => Ok(chat_completion_detections), - None => Ok(chat_completions), - } - } - }); - - match task_handle.await { - // Task completed successfully - Ok(Ok(result)) => Ok(result), - // Task failed, return error propagated from child task that failed - Ok(Err(error)) => { - // TODO: Transform error from chat completion client - 
Err(error) - } - // Task cancelled or panicked - Err(error) => { - // TODO: Transform error from chat completion client - let error = error.into(); - Err(error) - } - } - } -} - -#[instrument(skip_all)] -// pub async fn input_detection( -pub async fn message_detection( - ctx: &Arc, - detectors: &HashMap, - chat_messages: Vec, - headers: &HeaderMap, -) -> Result>, Error> { - debug!(?detectors, "starting input detection on chat completions"); - - let ctx = ctx.clone(); - - // pre-process chat messages based on individual detectors to prepare for chunking - let processed_chat_messages = preprocess_chat_messages(&ctx, detectors, chat_messages)?; - - // Call out to the chunker to get chunks of messages based on detector type - let chunks = detector_chunk_task(&ctx, processed_chat_messages).await?; - - // We run over each detector and take out messages that are appropriate for that detector. - let tasks = detectors - .iter() - .flat_map(|(detector_id, detector_params)| { - let detector_id = detector_id.clone(); - let detector_config = ctx - .config - .detectors - .get(&detector_id) - .unwrap_or_else(|| panic!("detector config not found for {}", detector_id)); - let default_threshold = detector_config.default_threshold; - - let detector_type = &detector_config.r#type; - - // Get chunks corresponding to each message - let messages = chunks.get(&detector_id).unwrap().clone(); - - match detector_type { - DetectorType::TextContents => { - // spawn parallel processes for each message index and run detection on them. 
- messages - .into_iter() - .map(|(index, chunks)| { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_params = detector_params.clone(); - let headers = headers.clone(); - - tokio::spawn({ - async move { - // Call content detector on the chunks of particular message - // and return the index and detection results - let detections = detect_content( - ctx.clone(), - detector_id.clone(), - default_threshold, - detector_params.clone(), - chunks, - headers.clone(), - ) - .await?; - Ok((index, detections)) - } - }) - }) - .collect::>() - } - _ => unimplemented!(), - } - }) - .collect::>(); - - // Await detections - let detections = try_join_all(tasks) - .await? - .into_iter() - .collect::, Error>>()?; - - // Build detection map - let mut detection_map: BTreeMap> = BTreeMap::new(); - for (index, detections) in detections { - if !detections.is_empty() { - match detection_map.entry(index) { - btree_map::Entry::Occupied(mut entry) => { - entry.get_mut().extend_from_slice(&detections); - } - btree_map::Entry::Vacant(entry) => { - entry.insert(detections); - } - } - } - } - - // Build vec of DetectionResult - // NOTE: seems unnecessary, could we just use the BTreeMap instead? 
- let detections = detection_map - .into_iter() - .map(|(index, results)| DetectionResult { index, results }) - .collect::>(); - - Ok((!detections.is_empty()).then_some(detections)) -} - -/// Function to filter messages based on individual detectors -/// Returns a HashMap of detector id to filtered messages -fn preprocess_chat_messages( - ctx: &Arc, - detectors: &HashMap, - messages: Vec, -) -> Result>, Error> { - detectors - .iter() - .map( - |(detector_id, _)| -> Result<(String, Vec), Error> { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_config = ctx - .config - .detectors - .get(&detector_id) - .unwrap_or_else(|| panic!("detector config not found for {}", detector_id)); - let detector_type = &detector_config.r#type; - // Filter messages based on detector type - let messages = match detector_type { - DetectorType::TextContents => content::filter_chat_messages(&messages), - _ => unimplemented!(), - }?; - Ok((detector_id, messages)) - }, - ) - .collect() -} - -// Function to chunk Vec based on the chunker id and return chunks in Vec form -// Output maps each detector_id with corresponding chunk -async fn detector_chunk_task( - ctx: &Arc, - detector_chat_messages: HashMap>, -) -> Result)>>, Error> { - let mut chunks = HashMap::new(); - - // TODO: Improve error handling for the code below - for (detector_id, chat_messages) in detector_chat_messages.into_iter() { - let chunk_tasks = chat_messages - .into_iter() - .map(|message| { - let Some(Content::Text(text)) = message.content else { - panic!("Only text content accepted") - }; - let offset: usize = 0; - let task = tokio::spawn({ - let detector_id = detector_id.clone(); - let ctx = ctx.clone(); - async move { - let chunker_id = ctx.config.get_chunker_id(&detector_id).unwrap(); - chunk(&ctx, chunker_id, offset, text).await - } - }); - // Return tuple of message index and task - (message.message_index, task) - // chunking_tasks.push((detector_id, task)); - }) - .collect::>(); - - let 
results = join_all(chunk_tasks.into_iter().map(|(index, handle)| async move { - match handle.await { - Ok(Ok(value)) => Ok((index, value)), // Success - Ok(Err(err)) => { - // Task returned an error - Err(err) - } - Err(_) => { - // Chunking failed - Err(Error::Other("Chunking task failed".to_string())) - } - } - })) - .await - .into_iter() - .collect::, Error>>(); - - match results { - Ok(chunk_value) => { - chunks.insert(detector_id.clone(), chunk_value); - Ok(()) - } - Err(err) => Err(err), - }? - } - - Ok(chunks) -} - -fn sort_detections(mut detections: Vec) -> Vec { - // Sort input detections by message_index - detections.sort_by_key(|value| value.index); - - detections - .into_iter() - .map(|mut detection| { - // sort detection by starting span - detection.results.sort_by_key(|value| value.start); - detection - }) - .collect::>() -} - -async fn handle_output_detections( - chat_completions: &ChatCompletionsResponse, - detector_output: Option>, - ctx: Arc, - headers: &HeaderMap, - model_id: String, -) -> Option { - if let ChatCompletionsResponse::Unary(chat_completion) = chat_completions { - let choices = Vec::::from(chat_completion); - - let output_detections = match detector_output { - Some(detectors) if !detectors.is_empty() => { - let tasks = choices.into_iter().map(|choice| { - tokio::spawn({ - let ctx = ctx.clone(); - let detectors = detectors.clone(); - let headers = headers.clone(); - async move { - let result = - message_detection(&ctx, &detectors, vec![choice], &headers).await; - - if let Ok(Some(detection_results)) = result { - return detection_results; - } - - vec![] - } - }) - }); - - let detections = try_join_all(tasks).await; - - match detections { - Ok(d) => Some( - d.iter() - .flatten() - .cloned() - .collect::>(), - ), - Err(_) => None, - } - } - _ => None, - }; - - debug!(?output_detections); - - match output_detections { - Some(output_detections) if !output_detections.is_empty() => { - let detections = sort_detections(output_detections); - - 
return Some(ChatCompletionsResponse::Unary(Box::new(ChatCompletion { - id: Uuid::new_v4().simple().to_string(), - object: chat_completion.object.clone(), - created: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64, - model: model_id.to_string(), - choices: chat_completion.choices.clone(), - usage: chat_completion.usage.clone(), - system_fingerprint: chat_completion.system_fingerprint.clone(), - service_tier: chat_completion.service_tier.clone(), - detections: Some(ChatDetections { - input: vec![], - output: detections - .into_iter() - .map(|detection_result| OutputDetectionResult { - choice_index: detection_result.index, - results: detection_result.results, - }) - .collect(), - }), - warnings: vec![OrchestratorWarning::new( - DetectionWarningReason::UnsuitableOutput, - UNSUITABLE_OUTPUT_MESSAGE, - )], - }))); - } - _ => {} - } - } - None -} - -#[cfg(test)] -mod tests { - use std::any::{Any, TypeId}; - - use super::*; - use crate::{ - config::DetectorConfig, - orchestrator::{ClientMap, OrchestratorConfig}, - }; - - // Test to verify preprocess_chat_messages works correctly for multiple content type detectors - // with single message in chat request - #[test] - fn pretest_process_chat_messages_multiple_content_detector() { - // Test setup - let clients = ClientMap::new(); - let detector_1_id = "detector1"; - let detector_2_id = "detector2"; - let mut ctx = Context::new(OrchestratorConfig::default(), clients); - // add detector - ctx.config.detectors.insert( - detector_1_id.to_string().clone(), - DetectorConfig { - ..Default::default() - }, - ); - ctx.config.detectors.insert( - detector_2_id.to_string().clone(), - DetectorConfig { - ..Default::default() - }, - ); - - let ctx = Arc::new(ctx); - let mut detectors = HashMap::new(); - detectors.insert(detector_1_id.to_string(), DetectorParams::new()); - detectors.insert(detector_2_id.to_string(), DetectorParams::new()); - - let messages = vec![ChatMessageInternal { - message_index: 0, - 
content: Some(Content::Text("hello".to_string())), - role: Role::Assistant, - ..Default::default() - }]; - let processed_messages = preprocess_chat_messages(&ctx, &detectors, messages).unwrap(); - // Assertions - assert!(processed_messages[detector_1_id].len() == 1); - assert!(processed_messages[detector_2_id].len() == 1); - } - - // Test preprocess_chat_messages returns error correctly for multiple content type detectors - // with incorrect message requirements - #[test] - fn pretest_process_chat_messages_error_handling() { - // Test setup - let clients = ClientMap::new(); - let detector_1_id = "detector1"; - let mut ctx = Context::new(OrchestratorConfig::default(), clients); - // add detector - ctx.config.detectors.insert( - detector_1_id.to_string().clone(), - DetectorConfig { - ..Default::default() - }, - ); - - let ctx = Arc::new(ctx); - let mut detectors = HashMap::new(); - detectors.insert(detector_1_id.to_string(), DetectorParams::new()); - - let messages = vec![ChatMessageInternal { - message_index: 0, - content: Some(Content::Text("hello".to_string())), - // Invalid role will return error used for testing - role: Role::Tool, - ..Default::default() - }]; - - let processed_messages = preprocess_chat_messages(&ctx, &detectors, messages); - - // Assertions - assert!(processed_messages.is_err()); - let error = processed_messages.unwrap_err(); - assert_eq!(error.type_id(), TypeId::of::()); - assert_eq!( - error.to_string(), - "validation error: Last message role must be user, assistant, or system" - ); - } - // validate chat completions request with invalid fields - // (nonexistant fields or typos) - #[test] - fn test_validate() { - // Additional unknown field (additional_field) - let json_data = r#" - { - "messages": [ - { - "content": "this is a nice sentence", - "role": "user", - "name": "string" - } - ], - "model": "my_model", - "additional_field": "test", - "n": 1, - "temperature": 1, - "top_p": 1, - "user": "user-1234", - "detectors": { - "input": {} - } 
- } - "#; - let result: Result = serde_json::from_str(json_data); - assert!(result.is_err()); - let error = result.unwrap_err().to_string(); - assert!( - error - .to_string() - .contains("unknown field `additional_field") - ); - - // Additional unknown field (additional_message") - let json_data = r#" - { - "messages": [ - { - "content": "this is a nice sentence", - "role": "user", - "name": "string", - "additional_msg: "test" - } - ], - "model": "my_model", - "n": 1, - "temperature": 1, - "top_p": 1, - "user": "user-1234", - "detectors": { - "input": {} - } - } - "#; - let result: Result = serde_json::from_str(json_data); - assert!(result.is_err()); - let error = result.unwrap_err().to_string(); - assert!(error.to_string().contains("unknown field `additional_msg")); - - // Additional unknown field (typo for input field in detectors) - let json_data = r#" - { - "messages": [ - { - "content": "this is a nice sentence", - "role": "user", - "name": "string" - } - ], - "model": "my_model", - "n": 1, - "temperature": 1, - "top_p": 1, - "user": "user-1234", - "detectors": { - "inputs": {} - } - } - "#; - let result: Result = serde_json::from_str(json_data); - assert!(result.is_err()); - let error = result.unwrap_err().to_string(); - assert!(error.to_string().contains("unknown field `inputs")); - } -} diff --git a/src/orchestrator/common/client.rs b/src/orchestrator/common/client.rs index f60c3058..672df54a 100644 --- a/src/orchestrator/common/client.rs +++ b/src/orchestrator/common/client.rs @@ -103,11 +103,12 @@ pub async fn detect_text_contents( detector_id: DetectorId, params: DetectorParams, chunks: Chunks, + apply_chunk_offset: bool, ) -> Result { let detector_id = detector_id.clone(); let contents = chunks - .into_iter() - .map(|chunk| chunk.text) + .iter() + .map(|chunk| chunk.text.clone()) .collect::>(); if contents.is_empty() { return Ok(Detections::default()); @@ -122,7 +123,26 @@ pub async fn detect_text_contents( error, })?; debug!(%detector_id, ?response, 
"received detector response"); - Ok(response.into()) + let detections = chunks + .into_iter() + .zip(response) + .flat_map(|(chunk, detections)| { + detections + .into_iter() + .map(|detection| { + let mut detection: Detection = detection.into(); + detection.detector_id = Some(detector_id.clone()); + if apply_chunk_offset { + let offset = chunk.start; + detection.start = detection.start.map(|start| start + offset); + detection.end = detection.end.map(|end| end + offset); + } + detection + }) + .collect::>() + }) + .collect::(); + Ok(detections) } /// Sends request to text generation detector client. @@ -146,7 +166,15 @@ pub async fn detect_text_generation( error, })?; debug!(%detector_id, ?response, "received detector response"); - Ok(response.into()) + let detections = response + .into_iter() + .map(|detection| { + let mut detection: Detection = detection.into(); + detection.detector_id = Some(detector_id.clone()); + detection + }) + .collect::(); + Ok(detections) } /// Sends request to text chat detector client. @@ -170,7 +198,15 @@ pub async fn detect_text_chat( error, })?; debug!(%detector_id, ?response, "received detector response"); - Ok(response.into()) + let detections = response + .into_iter() + .map(|detection| { + let mut detection: Detection = detection.into(); + detection.detector_id = Some(detector_id.clone()); + detection + }) + .collect::(); + Ok(detections) } /// Sends request to text context detector client. @@ -195,7 +231,15 @@ pub async fn detect_text_context( error, })?; debug!(%detector_id, ?response, "received detector response"); - Ok(response.into()) + let detections = response + .into_iter() + .map(|detection| { + let mut detection: Detection = detection.into(); + detection.detector_id = Some(detector_id.clone()); + detection + }) + .collect::(); + Ok(detections) } /// Sends request to openai chat completions client. 
diff --git a/src/orchestrator/common/tasks.rs b/src/orchestrator/common/tasks.rs index 637cb88c..3326676a 100644 --- a/src/orchestrator/common/tasks.rs +++ b/src/orchestrator/common/tasks.rs @@ -21,7 +21,7 @@ use futures::{StreamExt, TryStreamExt, future::try_join_all, stream}; use http::HeaderMap; use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, instrument}; +use tracing::{Instrument, debug, instrument}; use super::{client::*, utils::*}; use crate::{ @@ -39,7 +39,6 @@ use crate::{ }; /// Spawns chunk tasks. Returns a map of chunks. -#[instrument(skip_all)] pub async fn chunks( ctx: Arc, chunkers: Vec, @@ -55,42 +54,46 @@ pub async fn chunks( let inputs = inputs.clone(); // Spawn task for chunker // Chunkers are processed in-parallel - tokio::spawn(async move { - // Send concurrent requests for inputs - let chunks = stream::iter(inputs) - .map(|(offset, text)| { - let ctx = ctx.clone(); - let chunker_id = chunker_id.clone(); - async move { - if chunker_id == DEFAULT_CHUNKER_ID { - debug!("using whole doc chunker"); - // Return single chunk - return Ok(whole_doc_chunk(offset, text)); + tokio::spawn( + async move { + // Send concurrent requests for inputs + let chunks = stream::iter(inputs) + .map(|(offset, text)| { + let ctx = ctx.clone(); + let chunker_id = chunker_id.clone(); + async move { + if chunker_id == DEFAULT_CHUNKER_ID { + debug!("using whole doc chunker"); + // Return single chunk + return Ok(whole_doc_chunk(offset, text)); + } + let client = ctx + .clients + .get_as::(&chunker_id) + .ok_or_else(|| Error::ChunkerNotFound(chunker_id.clone()))?; + let chunks = chunk(client, chunker_id.clone(), text) + .await? 
+ .into_iter() + .map(|mut chunk| { + chunk.start += offset; + chunk.end += offset; + chunk + }) + .collect::(); + Ok::<_, Error>(chunks) } - let client = ctx - .clients - .get_as::(&chunker_id) - .ok_or_else(|| Error::ChunkerNotFound(chunker_id.clone()))?; - let chunks = chunk(client, chunker_id.clone(), text) - .await? - .into_iter() - .map(|mut chunk| { - chunk.start += offset; - chunk.end += offset; - chunk - }) - .collect::(); - Ok::<_, Error>(chunks) - } - }) - .buffer_unordered(ctx.config.chunker_concurrent_requests) - .try_collect::>() - .await? - .into_iter() - .flatten() - .collect::(); - Ok::<(ChunkerId, Chunks), Error>((chunker_id, chunks)) - }) + .in_current_span() + }) + .buffer_unordered(ctx.config.chunker_concurrent_requests) + .try_collect::>() + .await? + .into_iter() + .flatten() + .collect::(); + Ok::<(ChunkerId, Chunks), Error>((chunker_id, chunks)) + } + .in_current_span(), + ) }) .collect::>(); let chunk_map = try_join_all(tasks) @@ -102,7 +105,6 @@ pub async fn chunks( /// Spawns chunk streaming tasks. /// Returns a map of chunk broadcast channels. 
-#[instrument(skip_all)] pub async fn chunk_streams( ctx: Arc, chunkers: Vec, @@ -154,26 +156,29 @@ fn whole_doc_chunk_stream( // Create output channel let (output_tx, output_rx) = mpsc::channel(1); // Spawn task to collect input channel - tokio::spawn(async move { - // Collect input channel - // Alternatively, wrap receiver in BroadcastStream and collect() via StreamExt - let mut inputs = Vec::new(); - while let Ok(input) = input_broadcast_rx.recv().await.unwrap() { - inputs.push(input); + tokio::spawn( + async move { + // Collect input channel + // Alternatively, wrap receiver in BroadcastStream and collect() via StreamExt + let mut inputs = Vec::new(); + while let Ok(input) = input_broadcast_rx.recv().await.unwrap() { + inputs.push(input); + } + // Build chunk + let (indices, text): (Vec<_>, Vec<_>) = inputs.into_iter().unzip(); + let text = text.concat(); + let chunk = Chunk { + input_start_index: 0, + input_end_index: indices.last().copied().unwrap_or_default(), + start: 0, + end: text.chars().count(), + text, + }; + // Send chunk to output channel + let _ = output_tx.send(Ok::<_, Error>(chunk)).await; } - // Build chunk - let (indices, text): (Vec<_>, Vec<_>) = inputs.into_iter().unzip(); - let text = text.concat(); - let chunk = Chunk { - input_start_index: 0, - input_end_index: indices.last().copied().unwrap_or_default(), - start: 0, - end: text.chars().count(), - text, - }; - // Send chunk to output channel - let _ = output_tx.send(Ok::<_, Error>(chunk)).await; - }); + .in_current_span(), + ); Ok::<_, Error>(ReceiverStream::new(output_rx).boxed()) } @@ -206,24 +211,28 @@ pub async fn text_contents_detections( .map(|(detector_id, mut params, chunks)| { let ctx = ctx.clone(); let headers = headers.clone(); - let threshold = params.pop_threshold().unwrap_or_default(); + let default_threshold = ctx.config.detector(&detector_id).unwrap().default_threshold; + let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { let client = ctx 
.clients .get_as::(&detector_id) .unwrap(); - let detections = - detect_text_contents(client, headers, detector_id.clone(), params, chunks) - .await? - .into_iter() - .filter(|detection| detection.score >= threshold) - .map(|mut detection| { - detection.detector_id = Some(detector_id.clone()); - detection - }) - .collect::(); + let detections = detect_text_contents( + client, + headers, + detector_id.clone(), + params, + chunks.clone(), + true, + ) + .await? + .into_iter() + .filter(|detection| detection.score >= threshold) + .collect::(); Ok::<_, Error>(detections) } + .in_current_span() }) .buffer_unordered(ctx.config.detector_concurrent_requests) .try_collect::>() @@ -251,58 +260,64 @@ pub async fn text_contents_detection_streams( for (detector_id, mut params) in detectors { let ctx = ctx.clone(); let headers = headers.clone(); - let threshold = params.pop_threshold().unwrap_or_default(); + let default_threshold = ctx.config.detector(&detector_id).unwrap().default_threshold; + let threshold = params.pop_threshold().unwrap_or(default_threshold); let chunker_id = ctx.config.get_chunker_id(&detector_id).unwrap(); // Subscribe to chunk broadcast channel let mut chunk_rx = chunk_stream_map.get(&chunker_id).unwrap().subscribe(); // Create detection channel - let (detection_tx, detection_rx) = mpsc::channel(32); + let (detection_tx, detection_rx) = mpsc::channel(128); // Spawn detection task - tokio::spawn(async move { - while let Ok(result) = chunk_rx.recv().await { - match result { - Ok(chunk) => { - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); - match detect_text_contents( - client, - headers.clone(), - detector_id.clone(), - params.clone(), - vec![chunk.clone()].into(), - ) - .await - { - Ok(detections) => { - // Apply threshold and set detector_id - let detections = detections - .into_iter() - .filter(|detection| detection.score >= threshold) - .map(|mut detection| { - detection.detector_id = Some(detector_id.clone()); - detection - }) - 
.collect::(); - // Send to detection channel - let _ = detection_tx - .send(Ok((input_id, detector_id.clone(), chunk, detections))) - .await; - } - Err(error) => { - // Send error to detection channel - let _ = detection_tx.send(Err(error)).await; + tokio::spawn( + async move { + while let Ok(result) = chunk_rx.recv().await { + match result { + Ok(chunk) => { + let client = ctx + .clients + .get_as::(&detector_id) + .unwrap(); + match detect_text_contents( + client, + headers.clone(), + detector_id.clone(), + params.clone(), + vec![chunk.clone()].into(), + false, + ) + .await + { + Ok(detections) => { + // Apply threshold + let detections = detections + .into_iter() + .filter(|detection| detection.score >= threshold) + .collect::(); + // Send to detection channel + let _ = detection_tx + .send(Ok(( + input_id, + detector_id.clone(), + chunk, + detections, + ))) + .await; + } + Err(error) => { + // Send error to detection channel + let _ = detection_tx.send(Err(error)).await; + } } } - } - Err(error) => { - // Send error to detection channel - let _ = detection_tx.send(Err(error)).await; + Err(error) => { + // Send error to detection channel + let _ = detection_tx.send(Err(error)).await; + } } } } - }); + .in_current_span(), + ); let detection_stream = ReceiverStream::new(detection_rx).boxed(); streams.push(detection_stream); } @@ -315,11 +330,10 @@ pub async fn text_contents_detection_streams( pub async fn text_generation_detections( ctx: Arc, headers: HeaderMap, - detectors: &HashMap, - input_id: InputId, + detectors: HashMap, prompt: String, generated_text: String, -) -> Result<(InputId, Detections), Error> { +) -> Result { let inputs = detectors .iter() .map(|(detector_id, params)| { @@ -336,7 +350,8 @@ pub async fn text_generation_detections( .map(|(detector_id, mut params, prompt, generated_text)| { let ctx = ctx.clone(); let headers = headers.clone(); - let threshold = params.pop_threshold().unwrap_or_default(); + let default_threshold = 
ctx.config.detector(&detector_id).unwrap().default_threshold; + let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { let client = ctx .clients @@ -353,19 +368,16 @@ pub async fn text_generation_detections( .await? .into_iter() .filter(|detection| detection.score >= threshold) - .map(|mut detection| { - detection.detector_id = Some(detector_id.clone()); - detection - }) .collect::(); Ok::<_, Error>(detections) } + .in_current_span() }) .buffer_unordered(ctx.config.detector_concurrent_requests) .try_collect::>() .await?; let detections = results.into_iter().flatten().collect::(); - Ok((input_id, detections)) + Ok(detections) } /// Spawns text chat detection tasks. @@ -374,11 +386,10 @@ pub async fn text_generation_detections( pub async fn text_chat_detections( ctx: Arc, headers: HeaderMap, - detectors: &HashMap, - input_id: InputId, + detectors: HashMap, messages: Vec, tools: Vec, -) -> Result<(InputId, Detections), Error> { +) -> Result { let inputs = detectors .iter() .map(|(detector_id, params)| { @@ -395,7 +406,8 @@ pub async fn text_chat_detections( .map(|(detector_id, mut params, messages, tools)| { let ctx = ctx.clone(); let headers = headers.clone(); - let threshold = params.pop_threshold().unwrap_or_default(); + let default_threshold = ctx.config.detector(&detector_id).unwrap().default_threshold; + let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { let client = ctx .clients @@ -412,19 +424,16 @@ pub async fn text_chat_detections( .await? .into_iter() .filter(|detection| detection.score >= threshold) - .map(|mut detection| { - detection.detector_id = Some(detector_id.clone()); - detection - }) .collect::(); Ok::<_, Error>(detections) } + .in_current_span() }) .buffer_unordered(ctx.config.detector_concurrent_requests) .try_collect::>() .await?; let detections = results.into_iter().flatten().collect::(); - Ok((input_id, detections)) + Ok(detections) } /// Spawns text context detection tasks. 
@@ -433,12 +442,11 @@ pub async fn text_chat_detections( pub async fn text_context_detections( ctx: Arc, headers: HeaderMap, - detectors: &HashMap, - input_id: InputId, + detectors: HashMap, content: String, context_type: ContextType, context: Vec, -) -> Result<(InputId, Detections), Error> { +) -> Result { let inputs = detectors .iter() .map(|(detector_id, params)| { @@ -457,7 +465,9 @@ pub async fn text_context_detections( |(detector_id, mut params, content, context_type, context)| { let ctx = ctx.clone(); let headers = headers.clone(); - let threshold = params.pop_threshold().unwrap_or_default(); + let default_threshold = + ctx.config.detector(&detector_id).unwrap().default_threshold; + let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { let client = ctx .clients @@ -475,20 +485,17 @@ pub async fn text_context_detections( .await? .into_iter() .filter(|detection| detection.score >= threshold) - .map(|mut detection| { - detection.detector_id = Some(detector_id.clone()); - detection - }) .collect::(); Ok::<_, Error>(detections) } + .in_current_span() }, ) .buffer_unordered(ctx.config.detector_concurrent_requests) .try_collect::>() .await?; let detections = results.into_iter().flatten().collect::(); - Ok((input_id, detections)) + Ok(detections) } /// Fans-out a stream to a broadcast channel. @@ -496,7 +503,7 @@ pub fn broadcast_stream(mut stream: BoxStream) -> broadcast::Sender where T: Clone + Send + 'static, { - let (broadcast_tx, _) = broadcast::channel(32); + let (broadcast_tx, _) = broadcast::channel(128); tokio::spawn({ let broadcast_tx = broadcast_tx.clone(); async move { diff --git a/src/orchestrator/common/utils.rs b/src/orchestrator/common/utils.rs index f8267f9b..82a5d4a6 100644 --- a/src/orchestrator/common/utils.rs +++ b/src/orchestrator/common/utils.rs @@ -58,6 +58,13 @@ pub fn get_chunker_ids( .collect::, Error>>() } +/// Returns the current unix timestamp. 
+pub fn current_timestamp() -> std::time::Duration { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() +} + /// Updates an orchestrator config, adding entries for mock servers. /// TODO: move this to the test crate, once created. #[cfg(test)] diff --git a/src/orchestrator/detector_processing.rs b/src/orchestrator/detector_processing.rs deleted file mode 100644 index 88cbeb2a..00000000 --- a/src/orchestrator/detector_processing.rs +++ /dev/null @@ -1,18 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -pub mod content; diff --git a/src/orchestrator/detector_processing/content.rs b/src/orchestrator/detector_processing/content.rs deleted file mode 100644 index bffc8ef2..00000000 --- a/src/orchestrator/detector_processing/content.rs +++ /dev/null @@ -1,118 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- -*/ -use crate::{ - clients::openai::{Content, Role}, - models::ValidationError, - orchestrator::chat_completions_detection::ChatMessageInternal, -}; - -/// Function to get content analysis request from chat message by applying rules -pub fn filter_chat_messages( - messages: &[ChatMessageInternal], -) -> Result, ValidationError> { - // Get last message - if messages.is_empty() { - return Err(ValidationError::Invalid("No messages provided".into())); - } - let message = messages.last().unwrap().clone(); - - // Validate message: - // 1. Has text content - if !matches!(message.content, Some(Content::Text(_))) { - return Err(ValidationError::Invalid( - "Last message content must be text".into(), - )); - } - // 2. Role is user | assistant | system - if !matches!(message.role, Role::User | Role::Assistant | Role::System) { - return Err(ValidationError::Invalid( - "Last message role must be user, assistant, or system".into(), - )); - } - - Ok(vec![ChatMessageInternal { - message_index: message.message_index, - role: message.role, - content: message.content, - refusal: message.refusal, - }]) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::orchestrator::chat_completions_detection::ChatMessageInternal; - - #[tokio::test] - async fn test_filter_chat_message_single_messagae() { - let message = vec![ChatMessageInternal { - message_index: 0, - content: Some(Content::Text("hello".to_string())), - role: Role::Assistant, - ..Default::default() - }]; - - let filtered_messages = filter_chat_messages(&message); - - // Assertions - assert!(filtered_messages.is_ok()); - assert_eq!(filtered_messages.unwrap(), message); - } - - #[tokio::test] - async fn test_filter_chat_message_multiple_messages() { - let message = vec![ - ChatMessageInternal { - message_index: 0, - content: Some(Content::Text("hello".to_string())), - role: Role::Assistant, - ..Default::default() - }, - ChatMessageInternal { - message_index: 1, - content: Some(Content::Text("bot".to_string())), - role: 
Role::Assistant, - ..Default::default() - }, - ]; - - let filtered_messages = filter_chat_messages(&message); - - // Assertions - assert!(filtered_messages.is_ok()); - assert_eq!(filtered_messages.unwrap(), vec![message[1].clone()]); - } - - #[tokio::test] - async fn test_filter_chat_messages_incorrect_role() { - let message = vec![ChatMessageInternal { - message_index: 0, - content: Some(Content::Text("hello".to_string())), - role: Role::Tool, - ..Default::default() - }]; - - let filtered_messages = filter_chat_messages(&message); - - // Assertions - assert!(filtered_messages.is_err()); - assert_eq!( - filtered_messages.unwrap_err().to_string(), - "Last message role must be user, assistant, or system" - ); - } -} diff --git a/src/orchestrator/handlers.rs b/src/orchestrator/handlers.rs new file mode 100644 index 00000000..e681e2ca --- /dev/null +++ b/src/orchestrator/handlers.rs @@ -0,0 +1,69 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use http::HeaderMap; +use opentelemetry::trace::TraceId; + +pub mod classification_with_gen; +pub use classification_with_gen::ClassificationWithGenTask; +pub mod streaming_classification_with_gen; +pub use streaming_classification_with_gen::StreamingClassificationWithGenTask; +pub mod chat_completions_detection; +pub mod streaming_content_detection; +pub use streaming_content_detection::StreamingContentDetectionTask; +pub mod generation_with_detection; +pub use generation_with_detection::GenerationWithDetectionTask; +pub mod chat_detection; +pub use chat_detection::ChatDetectionTask; +pub mod context_docs_detection; +pub use context_docs_detection::ContextDocsDetectionTask; +pub mod detection_on_generation; +pub use detection_on_generation::DetectionOnGenerationTask; +pub mod text_content_detection; +pub use text_content_detection::TextContentDetectionTask; + +use super::Error; + +/// Implements a task handler. +pub trait Handle { + type Response: Send + 'static; + + async fn handle(&self, task: Task) -> Result; // TODO: Task +} + +/// A task. +pub struct Task { + /// Trace ID + pub trace_id: TraceId, + /// Headers + pub headers: HeaderMap, + /// Request + pub request: R, +} + +impl Task { + pub fn new(trace_id: TraceId, headers: HeaderMap, request: R) -> Self { + Self { + trace_id, + headers, + request, + } + } + + pub fn into_parts(self) -> (TraceId, HeaderMap, R) { + (self.trace_id, self.headers, self.request) + } +} diff --git a/src/orchestrator/handlers/chat_completions_detection.rs b/src/orchestrator/handlers/chat_completions_detection.rs new file mode 100644 index 00000000..7b78b2d2 --- /dev/null +++ b/src/orchestrator/handlers/chat_completions_detection.rs @@ -0,0 +1,65 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::instrument; + +use super::Handle; +use crate::{ + clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse}, + orchestrator::{Error, Orchestrator}, +}; + +pub mod streaming; +pub mod unary; + +impl Handle for Orchestrator { + type Response = ChatCompletionsResponse; + + #[instrument( + name = "chat_completions_detection", + skip_all, + fields(trace_id = ?task.trace_id, headers = ?task.headers) + )] + async fn handle(&self, task: ChatCompletionsDetectionTask) -> Result { + let ctx = self.ctx.clone(); + match task.request.stream { + true => streaming::handle_streaming(ctx, task).await, + false => unary::handle_unary(ctx, task).await, + } + } +} + +#[derive(Debug)] +pub struct ChatCompletionsDetectionTask { + /// Trace ID + pub trace_id: TraceId, + /// Request + pub request: ChatCompletionsRequest, + /// Headers + pub headers: HeaderMap, +} + +impl ChatCompletionsDetectionTask { + pub fn new(trace_id: TraceId, request: ChatCompletionsRequest, headers: HeaderMap) -> Self { + Self { + trace_id, + request, + headers, + } + } +} diff --git a/src/orchestrator/handlers/chat_completions_detection/streaming.rs b/src/orchestrator/handlers/chat_completions_detection/streaming.rs new file mode 100644 index 00000000..d6c27cf8 --- /dev/null +++ b/src/orchestrator/handlers/chat_completions_detection/streaming.rs @@ -0,0 +1,55 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance 
with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::sync::Arc; + +use tokio::sync::mpsc; +use tracing::{Instrument, info}; + +use super::ChatCompletionsDetectionTask; +use crate::{ + clients::openai::*, + orchestrator::{Context, Error}, +}; + +pub async fn handle_streaming( + _ctx: Arc, + task: ChatCompletionsDetectionTask, +) -> Result { + let trace_id = task.trace_id; + let detectors = task.request.detectors.clone().unwrap_or_default(); + info!(%trace_id, config = ?detectors, "task started"); + let _input_detectors = detectors.input; + let _output_detectors = detectors.output; + + // Create response channel + let (response_tx, response_rx) = + mpsc::channel::, Error>>(128); + + tokio::spawn( + async move { + // TODO + let _ = response_tx + .send(Err(Error::Validation( + "streaming is not yet supported".into(), + ))) + .await; + } + .in_current_span(), + ); + + Ok(ChatCompletionsResponse::Streaming(response_rx)) +} diff --git a/src/orchestrator/handlers/chat_completions_detection/unary.rs b/src/orchestrator/handlers/chat_completions_detection/unary.rs new file mode 100644 index 00000000..112862c0 --- /dev/null +++ b/src/orchestrator/handlers/chat_completions_detection/unary.rs @@ -0,0 +1,201 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::{collections::HashMap, sync::Arc}; + +use futures::future::try_join_all; +use tracing::{Instrument, error, info, instrument}; +use uuid::Uuid; + +use super::ChatCompletionsDetectionTask; +use crate::{ + clients::openai::*, + models::{ + DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE, UNSUITABLE_OUTPUT_MESSAGE, + }, + orchestrator::{Context, Error, common, types::ChatMessageIterator}, +}; + +pub async fn handle_unary( + ctx: Arc, + task: ChatCompletionsDetectionTask, +) -> Result { + let trace_id = task.trace_id; + let detectors = task.request.detectors.clone().unwrap_or_default(); + info!(%trace_id, config = ?detectors, "task started"); + let input_detectors = detectors.input; + let output_detectors = detectors.output; + + // TODO: validate requested guardrails + + if !input_detectors.is_empty() { + // Handle input detection + match handle_input_detection(ctx.clone(), &task, input_detectors).await { + Ok(Some(completion)) => { + info!(%trace_id, "task completed: returning response with input detections"); + // Return response with input detections and terminate + let response = completion.into(); + return Ok(response); + } + Ok(None) => (), // No input detections + Err(error) => { + // Input detections failed + return Err(error); + } + } + } + + // Handle chat completion + let client = ctx + .clients + .get_as::("chat_generation") + .unwrap(); + let chat_completion = + match common::chat_completion(client, task.headers.clone(), task.request.clone()).await { + Ok(ChatCompletionsResponse::Unary(chat_completion)) => *chat_completion, + 
Ok(ChatCompletionsResponse::Streaming(_)) => unimplemented!(), + Err(error) => return Err(error), + }; + + if !output_detectors.is_empty() { + // Handle output detection + let chat_completion = + handle_output_detection(ctx.clone(), task, output_detectors, chat_completion).await?; + Ok(chat_completion.into()) + } else { + // No output detectors, send chat completion response + Ok(chat_completion.into()) + } +} + +#[instrument(skip_all)] +async fn handle_input_detection( + ctx: Arc, + task: &ChatCompletionsDetectionTask, + detectors: HashMap, +) -> Result, Error> { + let trace_id = task.trace_id; + let model_id = task.request.model.clone(); + + // Input detectors are only applied to the last message + // Get the last message + let messages = task.request.messages(); + let message = if let Some(message) = messages.last() { + message + } else { + return Err(Error::Validation("No messages provided".into())); + }; + // Validate role + if !matches!( + message.role, + Some(Role::User) | Some(Role::Assistant) | Some(Role::System) + ) { + return Err(Error::Validation( + "Last message role must be user, assistant, or system".into(), + )); + } + let input_id = message.index; + let input_text = message.text.map(|s| s.to_string()).unwrap_or_default(); + let detections = match common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + input_id, + vec![(0, input_text)], + ) + .await + { + Ok((_, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing input detections"); + return Err(error); + } + }; + if !detections.is_empty() { + // Build chat completion with input detections + let chat_completion = ChatCompletion { + id: Uuid::new_v4().simple().to_string(), + model: model_id, + created: common::current_timestamp().as_secs() as i64, + detections: Some(ChatDetections { + input: vec![InputDetectionResult { + message_index: message.index, + results: detections.into(), + }], + ..Default::default() 
+ }), + warnings: vec![OrchestratorWarning::new( + DetectionWarningReason::UnsuitableInput, + UNSUITABLE_INPUT_MESSAGE, + )], + ..Default::default() + }; + Ok(Some(chat_completion)) + } else { + // No input detections + Ok(None) + } +} + +#[instrument(skip_all)] +async fn handle_output_detection( + ctx: Arc, + task: ChatCompletionsDetectionTask, + detectors: HashMap, + mut chat_completion: ChatCompletion, +) -> Result { + let mut tasks = Vec::with_capacity(chat_completion.choices.len()); + for choice in &chat_completion.choices { + let input_id = choice.index; + let input_text = choice.message.content.clone().unwrap_or_default(); + tasks.push(tokio::spawn( + common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + input_id, + vec![(0, input_text)], + ) + .in_current_span(), + )); + } + let detections = try_join_all(tasks) + .await? + .into_iter() + .collect::, Error>>()?; + if !detections.is_empty() { + // Update chat completion with detections + let output = detections + .into_iter() + .filter(|(_, detections)| !detections.is_empty()) + .map(|(input_id, detections)| OutputDetectionResult { + choice_index: input_id, + results: detections.into(), + }) + .collect::>(); + if !output.is_empty() { + chat_completion.detections = Some(ChatDetections { + output, + ..Default::default() + }); + chat_completion.warnings = vec![OrchestratorWarning::new( + DetectionWarningReason::UnsuitableOutput, + UNSUITABLE_OUTPUT_MESSAGE, + )]; + } + } + Ok(chat_completion) +} diff --git a/src/orchestrator/handlers/chat_detection.rs b/src/orchestrator/handlers/chat_detection.rs new file mode 100644 index 00000000..6ec10a0f --- /dev/null +++ b/src/orchestrator/handlers/chat_detection.rs @@ -0,0 +1,85 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::collections::HashMap; + +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::{info, instrument}; + +use super::Handle; +use crate::{ + clients::openai, + models::{ChatDetectionHttpRequest, ChatDetectionResult, DetectorParams}, + orchestrator::{Error, Orchestrator, common}, +}; + +impl Handle for Orchestrator { + type Response = ChatDetectionResult; + + #[instrument( + name = "chat_detection", + skip_all, + fields(trace_id = ?task.trace_id, headers = ?task.headers) + )] + async fn handle(&self, task: ChatDetectionTask) -> Result { + let ctx = self.ctx.clone(); + let trace_id = task.trace_id; + info!(%trace_id, config = ?task.detectors, "task started"); + + // TODO: validate requested guardrails + + // Handle detection + let detections = common::text_chat_detections( + ctx, + task.headers, + task.detectors, + task.messages, + task.tools, + ) + .await?; + + Ok(ChatDetectionResult { + detections: detections.into(), + }) + } +} + +#[derive(Debug)] +pub struct ChatDetectionTask { + /// Trace ID + pub trace_id: TraceId, + /// Detectors configuration + pub detectors: HashMap, + /// Messages to run detection on + pub messages: Vec, + /// Tools + pub tools: Vec, + /// Headers + pub headers: HeaderMap, +} + +impl ChatDetectionTask { + pub fn new(trace_id: TraceId, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self { + Self { + trace_id, + detectors: request.detectors, + messages: request.messages, + tools: request.tools, + headers, + } + } +} diff --git a/src/orchestrator/handlers/classification_with_gen.rs 
b/src/orchestrator/handlers/classification_with_gen.rs new file mode 100644 index 00000000..f0d8e1fb --- /dev/null +++ b/src/orchestrator/handlers/classification_with_gen.rs @@ -0,0 +1,213 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +use std::{collections::HashMap, sync::Arc}; + +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::{error, info, instrument}; + +use super::Handle; +use crate::{ + clients::GenerationClient, + models::{ + ClassifiedGeneratedTextResult, DetectionWarning, DetectorParams, GuardrailsConfig, + GuardrailsHttpRequest, GuardrailsTextGenerationParameters, + TextGenTokenClassificationResults, + }, + orchestrator::{Context, Error, Orchestrator, common}, +}; + +impl Handle for Orchestrator { + type Response = ClassifiedGeneratedTextResult; + + #[instrument( + name = "classification_with_gen", + skip_all, + fields(trace_id = ?task.trace_id, model_id = task.model_id, headers = ?task.headers) + )] + async fn handle(&self, task: ClassificationWithGenTask) -> Result { + let ctx = self.ctx.clone(); + let trace_id = task.trace_id; + info!(%trace_id, config = ?task.guardrails_config, "task started"); + let input_detectors = task.guardrails_config.input_detectors(); + let output_detectors = task.guardrails_config.output_detectors(); + + // TODO: validate requested guardrails + + if !input_detectors.is_empty() { + // Handle input detection + match handle_input_detection(ctx.clone(), 
&task, input_detectors).await { + Ok(Some(response)) => { + info!(%trace_id, "task completed: returning response with input detections"); + // Return response with input detections and terminate + return Ok(response); + } + Ok(None) => (), // No input detections + Err(error) => { + // Input detections failed + return Err(error); + } + } + } + + // Handle generation + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + let generation = common::generate( + client, + task.headers.clone(), + task.model_id.clone(), + task.inputs.clone(), + task.text_gen_parameters.clone(), + ) + .await?; + + if !output_detectors.is_empty() { + // Handle output detection + handle_output_detection(ctx.clone(), task, output_detectors, generation).await + } else { + // No output detectors, return generation + info!(%trace_id, "task completed: returning generation response"); + Ok(generation) + } + } +} + +#[instrument(skip_all)] +async fn handle_input_detection( + ctx: Arc, + task: &ClassificationWithGenTask, + detectors: HashMap, +) -> Result, Error> { + let trace_id = task.trace_id; + let inputs = common::apply_masks(task.inputs.clone(), task.guardrails_config.input_masks()); + let detections = match common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + 0, + inputs, + ) + .await + { + Ok((_, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing input detections"); + return Err(error); + } + }; + if !detections.is_empty() { + // Get token count + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + let input_token_count = match common::tokenize( + client, + task.headers.clone(), + task.model_id.clone(), + task.inputs.clone(), + ) + .await + { + Ok((token_count, _tokens)) => token_count, + Err(error) => { + error!(%trace_id, %error, "task failed: error tokenizing input text"); + return Err(error); + } + }; + // Build response with input detections + let response = 
ClassifiedGeneratedTextResult { + input_token_count, + token_classification_results: TextGenTokenClassificationResults { + input: Some(detections.into()), + output: None, + }, + warnings: Some(vec![DetectionWarning::unsuitable_input()]), + ..Default::default() + }; + Ok(Some(response)) + } else { + // No input detections + Ok(None) + } +} + +#[instrument(skip_all)] +async fn handle_output_detection( + ctx: Arc, + task: ClassificationWithGenTask, + detectors: HashMap, + generation: ClassifiedGeneratedTextResult, +) -> Result { + let trace_id = task.trace_id; + let generated_text = generation.generated_text.clone().unwrap_or_default(); + let detections = match common::text_contents_detections( + ctx, + task.headers, + detectors, + 0, + vec![(0, generated_text)], + ) + .await + { + Ok((_, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing output detections"); + return Err(error); + } + }; + let mut response = generation; + if !detections.is_empty() { + response.token_classification_results.output = Some(detections.into()); + response.warnings = Some(vec![DetectionWarning::unsuitable_output()]); + } + info!(%trace_id, "task completed: returning response with output detections"); + Ok(response) +} + +#[derive(Debug)] +pub struct ClassificationWithGenTask { + /// Trace ID + pub trace_id: TraceId, + /// Model ID + pub model_id: String, + /// Input text + pub inputs: String, + /// Guardrails config + pub guardrails_config: GuardrailsConfig, + /// Text generation parameters + pub text_gen_parameters: Option, + /// Headers + pub headers: HeaderMap, +} + +impl ClassificationWithGenTask { + pub fn new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { + Self { + trace_id, + model_id: request.model_id, + inputs: request.inputs, + guardrails_config: request.guardrail_config.unwrap_or_default(), + text_gen_parameters: request.text_gen_parameters, + headers, + } + } +} diff --git 
a/src/orchestrator/handlers/context_docs_detection.rs b/src/orchestrator/handlers/context_docs_detection.rs new file mode 100644 index 00000000..1fa9190b --- /dev/null +++ b/src/orchestrator/handlers/context_docs_detection.rs @@ -0,0 +1,89 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::collections::HashMap; + +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::{info, instrument}; + +use super::Handle; +use crate::{ + clients::detector::ContextType, + models::{ContextDocsHttpRequest, ContextDocsResult, DetectorParams}, + orchestrator::{Error, Orchestrator, common}, +}; + +impl Handle for Orchestrator { + type Response = ContextDocsResult; + + #[instrument( + name = "context_docs_detection", + skip_all, + fields(trace_id = ?task.trace_id, headers = ?task.headers) + )] + async fn handle(&self, task: ContextDocsDetectionTask) -> Result { + let ctx = self.ctx.clone(); + let trace_id = task.trace_id; + info!(%trace_id, config = ?task.detectors, "task started"); + + // TODO: validate requested guardrails + + // Handle detection + let detections = common::text_context_detections( + ctx, + task.headers, + task.detectors, + task.content, + task.context_type, + task.context, + ) + .await?; + + Ok(ContextDocsResult { + detections: detections.into(), + }) + } +} + +#[derive(Debug)] +pub struct ContextDocsDetectionTask { + /// Trace ID + pub trace_id: TraceId, + /// Content text + pub content: String, + 
/// Context type + pub context_type: ContextType, + /// Context + pub context: Vec, + /// Detectors configuration + pub detectors: HashMap, + /// Headers + pub headers: HeaderMap, +} + +impl ContextDocsDetectionTask { + pub fn new(trace_id: TraceId, request: ContextDocsHttpRequest, headers: HeaderMap) -> Self { + Self { + trace_id, + content: request.content, + context_type: request.context_type, + context: request.context, + detectors: request.detectors, + headers, + } + } +} diff --git a/src/orchestrator/handlers/detection_on_generation.rs b/src/orchestrator/handlers/detection_on_generation.rs new file mode 100644 index 00000000..0c563116 --- /dev/null +++ b/src/orchestrator/handlers/detection_on_generation.rs @@ -0,0 +1,88 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use std::collections::HashMap; + +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::{info, instrument}; + +use super::Handle; +use crate::{ + models::{DetectionOnGeneratedHttpRequest, DetectionOnGenerationResult, DetectorParams}, + orchestrator::{Error, Orchestrator, common}, +}; + +impl Handle for Orchestrator { + type Response = DetectionOnGenerationResult; + + #[instrument( + name = "detection_on_generation", + skip_all, + fields(trace_id = ?task.trace_id, headers = ?task.headers) + )] + async fn handle(&self, task: DetectionOnGenerationTask) -> Result { + let ctx = self.ctx.clone(); + let trace_id = task.trace_id; + info!(%trace_id, config = ?task.detectors, "task started"); + + // TODO: validate requested guardrails + + // Handle detection + let detections = common::text_generation_detections( + ctx, + task.headers, + task.detectors, + task.prompt, + task.generated_text, + ) + .await?; + + Ok(DetectionOnGenerationResult { + detections: detections.into(), + }) + } +} + +#[derive(Debug)] +pub struct DetectionOnGenerationTask { + /// Trace ID + pub trace_id: TraceId, + /// Prompt text + pub prompt: String, + /// Generated text + pub generated_text: String, + /// Detectors configuration + pub detectors: HashMap, + /// Headers + pub headers: HeaderMap, +} + +impl DetectionOnGenerationTask { + pub fn new( + trace_id: TraceId, + request: DetectionOnGeneratedHttpRequest, + headers: HeaderMap, + ) -> Self { + Self { + trace_id, + prompt: request.prompt, + generated_text: request.generated_text, + detectors: request.detectors, + headers, + } + } +} diff --git a/src/orchestrator/handlers/generation_with_detection.rs b/src/orchestrator/handlers/generation_with_detection.rs new file mode 100644 index 00000000..e6e99463 --- /dev/null +++ b/src/orchestrator/handlers/generation_with_detection.rs @@ -0,0 +1,112 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use 
this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::collections::HashMap; + +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::{info, instrument}; + +use super::Handle; +use crate::{ + clients::GenerationClient, + models::{ + DetectorParams, GenerationWithDetectionHttpRequest, GenerationWithDetectionResult, + GuardrailsTextGenerationParameters, + }, + orchestrator::{Error, Orchestrator, common}, +}; + +impl Handle for Orchestrator { + type Response = GenerationWithDetectionResult; + + #[instrument( + name = "generation_with_detection", + skip_all, + fields(trace_id = ?task.trace_id, headers = ?task.headers) + )] + async fn handle(&self, task: GenerationWithDetectionTask) -> Result { + let ctx = self.ctx.clone(); + let trace_id = task.trace_id; + info!(%trace_id, config = ?task.detectors, "task started"); + + // TODO: validate requested guardrails + + // Handle generation + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + let generation = common::generate( + client, + task.headers.clone(), + task.model_id.clone(), + task.prompt.clone(), + task.text_gen_parameters.clone(), + ) + .await?; + let generated_text = generation.generated_text.unwrap_or_default(); + + // Handle detection + let detections = common::text_generation_detections( + ctx, + task.headers, + task.detectors, + task.prompt, + generated_text.clone(), + ) + .await?; + + Ok(GenerationWithDetectionResult { + generated_text, + input_token_count: generation.input_token_count, + detections: detections.into(), + }) + } +} + +#[derive(Debug)] +pub struct 
GenerationWithDetectionTask { + /// Trace ID + pub trace_id: TraceId, + /// Model ID + pub model_id: String, + /// Prompt text + pub prompt: String, + /// Detectors configuration + pub detectors: HashMap, + /// Text generation parameters + pub text_gen_parameters: Option, + /// Headers + pub headers: HeaderMap, +} + +impl GenerationWithDetectionTask { + pub fn new( + trace_id: TraceId, + request: GenerationWithDetectionHttpRequest, + headers: HeaderMap, + ) -> Self { + Self { + trace_id, + model_id: request.model_id, + prompt: request.prompt, + detectors: request.detectors, + text_gen_parameters: request.text_gen_parameters, + headers, + } + } +} diff --git a/src/orchestrator/handlers/streaming_classification_with_gen.rs b/src/orchestrator/handlers/streaming_classification_with_gen.rs new file mode 100644 index 00000000..b24b2e09 --- /dev/null +++ b/src/orchestrator/handlers/streaming_classification_with_gen.rs @@ -0,0 +1,437 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ + +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use futures::StreamExt; +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{Instrument, error, info, instrument}; + +use super::Handle; +use crate::{ + clients::GenerationClient, + models::{ + ClassifiedGeneratedTextStreamResult, DetectionWarning, DetectorParams, GuardrailsConfig, + GuardrailsHttpRequest, GuardrailsTextGenerationParameters, + TextGenTokenClassificationResults, + }, + orchestrator::{ + Context, Error, Orchestrator, common, + types::{ + Chunk, DetectionBatchStream, DetectionStream, Detections, GenerationStream, + MaxProcessedIndexBatcher, + }, + }, +}; + +impl Handle for Orchestrator { + type Response = ReceiverStream>; + + #[instrument( + name = "streaming_classification_with_gen", + skip_all, + fields( + trace_id = task.trace_id.to_string(), + model_id = task.model_id, + headers = ?task.headers + ) + )] + async fn handle( + &self, + task: StreamingClassificationWithGenTask, + ) -> Result { + let ctx = self.ctx.clone(); + + // Create response channel + let (response_tx, response_rx) = + mpsc::channel::>(128); + + tokio::spawn(async move { + let trace_id = task.trace_id; + info!(%trace_id, config = ?task.guardrails_config, "task started"); + let input_detectors = task.guardrails_config.input_detectors(); + let output_detectors = task.guardrails_config.output_detectors(); + + // TODO: validate requested guardrails + + if !input_detectors.is_empty() { + // Handle input detection + match handle_input_detection(ctx.clone(), &task, input_detectors).await { + Ok(Some(response)) => { + info!(%trace_id, "task completed: returning response with input detections"); + // Send message with input detections to response channel and terminate + let _ = response_tx.send(Ok(response)).await; + return; + } + Ok(None) => (), // No input detections + Err(error) => { + // Input detections failed + // Send 
error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + + // Create generation stream + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + let generation_stream = match common::generate_stream( + client, + task.headers.clone(), + task.model_id.clone(), + task.inputs.clone(), + task.text_gen_parameters.clone(), + ) + .await + { + Ok(stream) => stream, + Err(error) => { + error!(%trace_id, %error, "task failed: error creating generation stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + }; + + if !output_detectors.is_empty() { + // Handle output detection + handle_output_detection( + ctx.clone(), + task, + output_detectors, + generation_stream, + response_tx, + ) + .await; + } else { + // No output detectors, forward generation stream to response stream + forward_generation_stream(trace_id, generation_stream, response_tx).await; + } + }.in_current_span()); + + Ok(ReceiverStream::new(response_rx)) + } +} + +#[instrument(skip_all)] +async fn handle_input_detection( + ctx: Arc, + task: &StreamingClassificationWithGenTask, + detectors: HashMap, +) -> Result, Error> { + let trace_id = task.trace_id; + let inputs = common::apply_masks(task.inputs.clone(), task.guardrails_config.input_masks()); + let detections = match common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + 0, + inputs, + ) + .await + { + Ok((_input_id, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing input detections"); + return Err(error); + } + }; + if !detections.is_empty() { + // Get token count + let client = ctx + .clients + .get_as::("generation") + .unwrap(); + let input_token_count = match common::tokenize( + client, + task.headers.clone(), + task.model_id.clone(), + task.inputs.clone(), + ) + .await + { + Ok((token_count, _tokens)) => token_count, + Err(error) 
=> { + error!(%trace_id, %error, "task failed: error tokenizing input text"); + return Err(error); + } + }; + // Build response with input detections + let response = ClassifiedGeneratedTextStreamResult { + input_token_count, + token_classification_results: TextGenTokenClassificationResults { + input: Some(detections.into()), + output: None, + }, + warnings: Some(vec![DetectionWarning::unsuitable_input()]), + ..Default::default() + }; + Ok(Some(response)) + } else { + // No input detections + Ok(None) + } +} + +#[instrument(skip_all)] +async fn handle_output_detection( + ctx: Arc, + task: StreamingClassificationWithGenTask, + detectors: HashMap, + mut generation_stream: GenerationStream, + response_tx: mpsc::Sender>, +) { + let trace_id = task.trace_id; + // Create input channel for detection pipeline + let (input_tx, input_rx) = mpsc::channel(128); + // Create shared generations + let generations: Arc>> = + Arc::new(RwLock::new(Vec::new())); + // Create detection streams + let detection_streams = common::text_contents_detection_streams( + ctx, + task.headers.clone(), + detectors.clone(), + 0, + input_rx, + ) + .await; + + // Spawn task to process detection streams + tokio::spawn({ + let generations = generations.clone(); + async move { + match detection_streams { + Ok(mut detection_streams) if detection_streams.len() == 1 => { + // Process single detection stream, batching not applicable + let detection_stream = detection_streams.swap_remove(0); + process_detection_stream(trace_id, generations, detection_stream, response_tx) + .await; + } + Ok(detection_streams) => { + // Create detection batch stream + let detection_batch_stream = DetectionBatchStream::new( + MaxProcessedIndexBatcher::new(detectors.len()), + detection_streams, + ); + process_detection_batch_stream( + trace_id, + generations, + detection_batch_stream, + response_tx, + ) + .await; + } + Err(error) => { + error!(%trace_id, %error, "task failed: error creating detection streams"); + // Send error to 
response channel and terminate + let _ = response_tx.send(Err(error)).await; + } + } + } + .in_current_span() + }); + + // Spawn task to consume generations + tokio::spawn( + async move { + while let Some((index, result)) = generation_stream.next().await { + match result { + Ok(generation) => { + // Send generated text to input channel + let input = (index, generation.generated_text.clone().unwrap_or_default()); + let _ = input_tx.send(Ok(input)).await; + // Update shared generations + generations.write().unwrap().push(generation); + } + Err(error) => { + // Send error to input channel + let _ = input_tx.send(Err(error)).await; + // TODO: catch generation errors here to terminate all tasks? + } + } + } + } + .in_current_span(), + ); +} + +/// Consumes a generation stream, forwarding messages to a response channel. +#[instrument(skip_all)] +async fn forward_generation_stream( + trace_id: TraceId, + mut generation_stream: GenerationStream, + response_tx: mpsc::Sender>, +) { + while let Some((_index, result)) = generation_stream.next().await { + match result { + Ok(generation) => { + // Send message to response channel + if response_tx.send(Ok(generation)).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from generation stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: generation stream closed"); +} + +/// Consumes a detection stream, builds responses, and sends them to a response channel. 
+#[instrument(skip_all)] +async fn process_detection_stream( + trace_id: TraceId, + generations: Arc>>, + mut detection_stream: DetectionStream, + response_tx: mpsc::Sender>, +) { + while let Some(result) = detection_stream.next().await { + match result { + Ok((_, _detector_id, chunk, detections)) => { + // Create response for this batch with output detections + let response = output_detection_response(&generations, chunk, detections).unwrap(); + // Send message to response channel + if response_tx.send(Ok(response)).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection stream closed"); +} + +/// Consumes a detection batch stream, builds responses, and sends them to a response channel. +#[instrument(skip_all)] +async fn process_detection_batch_stream( + trace_id: TraceId, + generations: Arc>>, + mut detection_batch_stream: DetectionBatchStream, + response_tx: mpsc::Sender>, +) { + while let Some(result) = detection_batch_stream.next().await { + match result { + Ok((chunk, detections)) => { + // Create response for this batch with output detections + let response = output_detection_response(&generations, chunk, detections).unwrap(); + // Send message to response channel + if response_tx.send(Ok(response)).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection batch stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection batch stream closed"); +} + +/// Builds a response with output detections. 
+fn output_detection_response( + generations: &Arc>>, + chunk: Chunk, + detections: Detections, +) -> Result { + // Get subset of generations relevant for this chunk + let generations_slice = generations + .read() + .unwrap() + .get(chunk.input_start_index..=chunk.input_end_index) + .unwrap_or_default() + .to_vec(); + let last = generations_slice.last().cloned().unwrap_or_default(); + let tokens = generations_slice + .iter() + .flat_map(|generation| generation.tokens.clone().unwrap_or_default()) + .collect::>(); + let mut response = ClassifiedGeneratedTextStreamResult { + generated_text: Some(chunk.text), + start_index: Some(chunk.start as u32), + processed_index: Some(chunk.end as u32), + tokens: Some(tokens), + ..last + }; + response.token_classification_results.output = Some(detections.into()); + if chunk.input_start_index == 0 { + // Get input_token_count and seed from first generation message + let first = generations_slice.first().unwrap(); + response.input_token_count = first.input_token_count; + response.seed = first.seed; + // Get input_tokens from second generation message (if specified) + response.input_tokens = if let Some(second) = generations_slice.get(1) { + second.input_tokens.clone() + } else { + Some(Vec::default()) + }; + } + Ok(response) +} + +#[derive(Debug)] +pub struct StreamingClassificationWithGenTask { + /// Trace ID + pub trace_id: TraceId, + /// Model ID + pub model_id: String, + /// Input text + pub inputs: String, + /// Guardrails configuration + pub guardrails_config: GuardrailsConfig, + /// Text generation parameters + pub text_gen_parameters: Option, + /// Headers + pub headers: HeaderMap, +} + +impl StreamingClassificationWithGenTask { + pub fn new(trace_id: TraceId, request: GuardrailsHttpRequest, headers: HeaderMap) -> Self { + Self { + trace_id, + model_id: request.model_id, + inputs: request.inputs, + guardrails_config: request.guardrail_config.unwrap_or_default(), + text_gen_parameters: request.text_gen_parameters, + headers, 
+ } + } +} diff --git a/src/orchestrator/handlers/streaming_content_detection.rs b/src/orchestrator/handlers/streaming_content_detection.rs new file mode 100644 index 00000000..32e41b55 --- /dev/null +++ b/src/orchestrator/handlers/streaming_content_detection.rs @@ -0,0 +1,258 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::{collections::HashMap, pin::Pin, sync::Arc}; + +use futures::{Stream, StreamExt, stream::Peekable}; +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{Instrument, error, info, instrument}; + +use super::Handle; +use crate::{ + models::{DetectorParams, StreamingContentDetectionRequest, StreamingContentDetectionResponse}, + orchestrator::{ + Context, Error, Orchestrator, common, + types::{BoxStream, DetectionBatchStream, DetectionStream, MaxProcessedIndexBatcher}, + }, +}; + +type InputStream = + Pin)> + Send>>; + +impl Handle for Orchestrator { + type Response = ReceiverStream>; + + #[instrument( + name = "streaming_content_detection", + skip_all, + fields(trace_id = task.trace_id.to_string(), headers = ?task.headers) + )] + async fn handle(&self, task: StreamingContentDetectionTask) -> Result { + let ctx = self.ctx.clone(); + + // Create response channel + let (response_tx, response_rx) = + mpsc::channel::>(128); + + tokio::spawn( + async move { + let trace_id = task.trace_id; + let headers = 
task.headers; + let mut input_stream = Box::pin(task.input_stream.peekable()); + let detectors = match extract_detectors(&mut input_stream).await { + Ok(detectors) => detectors, + Err(error) => { + error!(%error, "error extracting detectors from first message"); + let _ = response_tx.send(Err(error)).await; + return; + } + }; + info!(%trace_id, config = ?detectors, "task started"); + + // TODO: validate requested guardrails + + handle_detection(ctx, trace_id, headers, detectors, input_stream, response_tx) + .await; + } + .in_current_span(), + ); + + Ok(ReceiverStream::new(response_rx)) + } +} + +/// Extracts detectors config from first message. +async fn extract_detectors( + input_stream: &mut Peekable, +) -> Result, Error> { + // We can use Peekable to get a reference to it instead of consuming the message here + // Peekable::peek() takes self: Pin<&mut Peekable<_>>, which is why we need to pin it + // https://docs.rs/futures/latest/futures/stream/struct.Peekable.html + if let Some((_index, result)) = Pin::new(input_stream).peek().await { + match result { + Ok(msg) => { + if let Some(detectors) = &msg.detectors { + if detectors.is_empty() { + return Err(Error::Validation( + "`detectors` must not be empty".to_string(), + )); + } + return Ok(detectors.clone()); + } + } + Err(error) => return Err(error.clone()), + } + } + Err(Error::Validation( + "`detectors` is required for the first message".into(), + )) +} + +#[instrument(skip_all)] +async fn handle_detection( + ctx: Arc, + trace_id: TraceId, + headers: HeaderMap, + detectors: HashMap, + mut input_stream: InputStream, + response_tx: mpsc::Sender>, +) { + // Create input channel for detection pipeline + let (input_tx, input_rx) = mpsc::channel(128); + // Create detection streams + let detection_streams = + common::text_contents_detection_streams(ctx, headers, detectors.clone(), 0, input_rx).await; + + // Spawn task to process detection streams + tokio::spawn( + async move { + match detection_streams { + Ok(mut 
detection_streams) if detection_streams.len() == 1 => { + // Process single detection stream, batching not applicable + let detection_stream = detection_streams.swap_remove(0); + process_detection_stream(trace_id, detection_stream, response_tx).await; + } + Ok(detection_streams) => { + // Create detection batch stream + let detection_batch_stream = DetectionBatchStream::new( + MaxProcessedIndexBatcher::new(detectors.len()), + detection_streams, + ); + process_detection_batch_stream(trace_id, detection_batch_stream, response_tx) + .await; + } + Err(error) => { + error!(%trace_id, %error, "task failed: error creating detection streams"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + } + } + } + .in_current_span(), + ); + + // Spawn task to consume input stream + tokio::spawn( + async move { + while let Some((index, result)) = input_stream.next().await { + match result { + Ok(message) => { + // Send content text to input channel + let _ = input_tx.send(Ok((index, message.content))).await; + } + Err(error) => { + // Send error to input channel + let _ = input_tx.send(Err(error)).await; + } + } + } + } + .in_current_span(), + ); +} + +/// Consumes a detection stream, builds responses, and sends them to a response channel. 
+#[instrument(skip_all)] +async fn process_detection_stream( + trace_id: TraceId, + mut detection_stream: DetectionStream, + response_tx: mpsc::Sender>, +) { + while let Some(result) = detection_stream.next().await { + match result { + Ok((_, _detector_id, chunk, detections)) => { + let response = StreamingContentDetectionResponse { + start_index: chunk.start as u32, + processed_index: chunk.end as u32, + detections: detections.into(), + }; + // Send message to response channel + if response_tx.send(Ok(response)).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection stream closed"); +} + +/// Consumes a detection batch stream, builds responses, and sends them to a response channel. +#[instrument(skip_all)] +async fn process_detection_batch_stream( + trace_id: TraceId, + mut detection_batch_stream: DetectionBatchStream, + response_tx: mpsc::Sender>, +) { + while let Some(result) = detection_batch_stream.next().await { + match result { + Ok((chunk, detections)) => { + let response = StreamingContentDetectionResponse { + start_index: chunk.start as u32, + processed_index: chunk.end as u32, + detections: detections.into(), + }; + // Send message to response channel + if response_tx.send(Ok(response)).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection batch stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection batch stream closed"); +} + +pub struct StreamingContentDetectionTask { + /// Trace ID + pub trace_id: TraceId, + 
/// Headers + pub headers: HeaderMap, + /// Detectors configuration + pub detectors: HashMap, + /// Input stream to run detections on + pub input_stream: BoxStream<(usize, Result)>, +} + +impl StreamingContentDetectionTask { + pub fn new( + trace_id: TraceId, + headers: HeaderMap, + input_stream: BoxStream<(usize, Result)>, + ) -> Self { + Self { + trace_id, + headers, + detectors: HashMap::default(), + input_stream, + } + } +} diff --git a/src/orchestrator/handlers/text_content_detection.rs b/src/orchestrator/handlers/text_content_detection.rs new file mode 100644 index 00000000..a4773444 --- /dev/null +++ b/src/orchestrator/handlers/text_content_detection.rs @@ -0,0 +1,85 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use std::collections::HashMap; + +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::{info, instrument}; + +use super::Handle; +use crate::{ + models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult}, + orchestrator::{Error, Orchestrator, common}, +}; + +impl Handle for Orchestrator { + type Response = TextContentDetectionResult; + + #[instrument( + name = "text_content_detection", + skip_all, + fields(trace_id = ?task.trace_id, headers = ?task.headers) + )] + async fn handle(&self, task: TextContentDetectionTask) -> Result { + let ctx = self.ctx.clone(); + let trace_id = task.trace_id; + info!(%trace_id, config = ?task.detectors, "task started"); + + // TODO: validate requested guardrails + + // Handle detection + let (_, detections) = common::text_contents_detections( + ctx, + task.headers, + task.detectors, + 0, + vec![(0, task.content)], + ) + .await?; + + Ok(TextContentDetectionResult { + detections: detections.into(), + }) + } +} + +#[derive(Debug)] +pub struct TextContentDetectionTask { + /// Trace ID + pub trace_id: TraceId, + /// Content text + pub content: String, + /// Detectors configuration + pub detectors: HashMap, + /// Headers + pub headers: HeaderMap, +} + +impl TextContentDetectionTask { + pub fn new( + trace_id: TraceId, + request: TextContentDetectionHttpRequest, + headers: HeaderMap, + ) -> Self { + Self { + trace_id, + content: request.content, + detectors: request.detectors, + headers, + } + } +} diff --git a/src/orchestrator/streaming.rs b/src/orchestrator/streaming.rs deleted file mode 100644 index 20cc4759..00000000 --- a/src/orchestrator/streaming.rs +++ /dev/null @@ -1,639 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -pub mod aggregator; - -use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; - -use aggregator::Aggregator; -use axum::http::HeaderMap; -use futures::{Stream, StreamExt, TryStreamExt, future::try_join_all}; -use tokio::sync::{broadcast, mpsc}; -use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; -use tracing::{Instrument, Span, debug, error, info, instrument, warn}; - -use super::{Context, Error, Orchestrator, StreamingClassificationWithGenTask}; -use crate::{ - clients::{ - GenerationClient, TextContentsDetectorClient, - chunker::{ChunkerClient, DEFAULT_CHUNKER_ID, tokenize_whole_doc_stream}, - detector::ContentAnalysisRequest, - }, - models::{ - ClassifiedGeneratedTextStreamResult, DetectionWarning, DetectionWarningReason, - DetectorParams, GuardrailsTextGenerationParameters, TextGenTokenClassificationResults, - TokenClassificationResult, UNSUITABLE_INPUT_MESSAGE, - }, - orchestrator::{ - common::get_chunker_ids, - unary::{input_detection_task, tokenize}, - }, - pb::{caikit::runtime::chunkers, caikit_data_model::nlp::ChunkerTokenizationStreamResult}, -}; - -pub type Chunk = ChunkerTokenizationStreamResult; -pub type Detections = Vec; - -impl Orchestrator { - /// Handles streaming tasks. 
- #[instrument(skip_all, fields(trace_id = task.trace_id.to_string(), model_id = task.model_id, headers = ?task.headers))] - pub async fn handle_streaming_classification_with_gen( - &self, - task: StreamingClassificationWithGenTask, - ) -> ReceiverStream> { - info!(config = ?task.guardrails_config, "starting task"); - - let ctx = self.ctx.clone(); - let trace_id = task.trace_id; - let model_id = task.model_id; - let params = task.text_gen_parameters; - let input_text = task.inputs; - let headers = task.headers; - - // Create response channel - #[allow(clippy::type_complexity)] - let (response_tx, response_rx): ( - mpsc::Sender>, - mpsc::Receiver>, - ) = mpsc::channel(1024); - - tokio::spawn(async move { - // Do input detections (unary) - let masks = task.guardrails_config.input_masks(); - let input_detectors = task.guardrails_config.input_detectors(); - let input_detections = match input_detectors { - Some(detectors) if !detectors.is_empty() => { - match input_detection_task( - &ctx, - detectors, - input_text.clone(), - masks, - headers.clone(), - ) - .await - { - Ok(result) => result, - Err(error) => { - error!(%trace_id, %error, "task failed"); - let _ = response_tx.send(Err(error)).await; - return; - } - } - } - _ => None, - }; - debug!(?input_detections); // TODO: metrics - if let Some(mut input_detections) = input_detections { - // Detected HAP/PII - // Do tokenization to get input_token_count - let (input_token_count, _tokens) = - match tokenize(&ctx, model_id.clone(), input_text.clone(), headers.clone()) - .await - { - Ok(result) => result, - Err(error) => { - error!(%trace_id, %error, "task failed"); - let _ = response_tx.send(Err(error)).await; - return; - } - }; - input_detections.sort_by_key(|r| r.start); - // Send result with input detections - let _ = response_tx - .send(Ok(ClassifiedGeneratedTextStreamResult { - input_token_count, - token_classification_results: TextGenTokenClassificationResults { - input: Some(input_detections), - output: None, - }, 
- warnings: Some(vec![DetectionWarning { - id: Some(DetectionWarningReason::UnsuitableInput), - message: Some(UNSUITABLE_INPUT_MESSAGE.to_string()), - }]), - ..Default::default() - })) - .await; - } else { - // No HAP/PII detected - // Do text generation (streaming) - let mut generation_stream = match generate_stream( - &ctx, - model_id.clone(), - input_text.clone(), - params.clone(), - headers.clone(), - ) - .await - { - Ok(generation_stream) => generation_stream, - Err(error) => { - error!(%trace_id, %error, "task failed"); - let _ = response_tx.send(Err(error)).await; - return; - } - }; - - // Do output detections (streaming) - let output_detectors = task.guardrails_config.output_detectors(); - match output_detectors { - Some(detectors) if !detectors.is_empty() => { - // Create error channel - // - // This channel is used for error notification & messaging and task cancellation. - // When a task fails, it notifies other tasks by sending the error to error_tx. - // - // The parent task receives the error, logs it, forwards it to the client via response_tx, - // and terminates the task. - let (error_tx, _) = broadcast::channel(1); - - let mut result_rx = match streaming_output_detection_task( - &ctx, - detectors, - generation_stream, - error_tx.clone(), - headers.clone(), - ) - .await - { - Ok(result_rx) => result_rx, - Err(error) => { - error!(%trace_id, %error, "task failed"); - let _ = error_tx.send(error.clone()); - let _ = response_tx.send(Err(error)).await; - return; - } - }; - // Forward generation results with detections to response channel - tokio::spawn(async move { - let mut error_rx = error_tx.subscribe(); - loop { - tokio::select! 
{ - Ok(error) = error_rx.recv() => { - error!(%trace_id, %error, "task failed"); - debug!(%trace_id, "sending error to client and terminating"); - let _ = response_tx.send(Err(error)).await; - return; - }, - result = result_rx.recv() => { - match result { - Some(result) => { - debug!(%trace_id, ?result, "sending result to client"); - if (response_tx.send(result).await).is_err() { - warn!(%trace_id, "response channel closed (client disconnected), terminating task"); - // Broadcast cancellation signal to tasks - let _ = error_tx.send(Error::Cancelled); - return; - } - }, - None => { - info!(%trace_id, "task completed: stream closed"); - break; - }, - } - } - } - } - }); - } - _ => { - // No output detectors, forward generation results to response channel - tokio::spawn(async move { - while let Some(result) = generation_stream.next().await { - debug!(%trace_id, ?result, "sending result to client"); - if (response_tx.send(result).await).is_err() { - warn!(%trace_id, "response channel closed (client disconnected), terminating task"); - return; - } - } - debug!(%trace_id, "task completed: stream closed"); - }); - } - } - } - }.instrument(Span::current())); - ReceiverStream::new(response_rx) - } -} - -/// Handles streaming output detection task. -#[instrument(skip_all)] -async fn streaming_output_detection_task( - ctx: &Arc, - detectors: &HashMap, - generation_stream: Pin< - Box> + Send>, - >, - error_tx: broadcast::Sender, - headers: HeaderMap, -) -> Result>, Error> { - debug!(?detectors, "creating chunk broadcast streams"); - - // Create generation broadcast stream - let (generation_tx, generation_rx) = broadcast::channel(1024); - - let chunker_ids = get_chunker_ids(ctx, detectors)?; - // Create a map of chunker_id->chunk_broadcast_stream - // This is to enable fan-out of chunk streams to potentially multiple detectors that use the same chunker. - // Each detector task will subscribe to an associated chunk stream. 
- let chunk_broadcast_streams = try_join_all( - chunker_ids - .into_iter() - .map(|chunker_id| { - debug!(%chunker_id, "creating chunk broadcast stream"); - let ctx = ctx.clone(); - let error_tx = error_tx.clone(); - // Subscribe to generation stream - let generation_rx = generation_tx.subscribe(); - async move { - let chunk_tx = - chunk_broadcast_task(ctx, chunker_id.clone(), generation_rx, error_tx) - .await?; - Ok::<(String, broadcast::Sender), Error>((chunker_id, chunk_tx)) - } - }) - .collect::>(), - ) - .await? - .into_iter() - .collect::>(); - - // Spawn detection tasks to subscribe to chunker stream, - // send requests to detector service, and send results to detection stream - debug!("spawning detection tasks"); - let mut detection_streams = Vec::with_capacity(detectors.len()); - for (detector_id, detector_params) in detectors.iter() { - // Create a mutable copy of the parameters, so that we can modify it based on processing - let mut detector_params = detector_params.clone(); - let detector_id = detector_id.to_string(); - let chunker_id = ctx - .config - .get_chunker_id(&detector_id) - .expect("chunker id is not found"); - - // Get the detector config - // TODO: Add error handling - let detector_config = ctx - .config - .detectors - .get(&detector_id) - .expect("detector config not found"); - - // Get the default threshold to use if threshold is not provided by the user - let default_threshold = detector_config.default_threshold; - let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); - - // Create detection stream - let (detector_tx, detector_rx) = mpsc::channel(1024); - // Subscribe to chunk broadcast stream - let chunk_rx = chunk_broadcast_streams - .get(&chunker_id) - .unwrap() - .subscribe(); - let error_tx = error_tx.clone(); - tokio::spawn( - detection_task( - ctx.clone(), - detector_id.clone(), - detector_params, - threshold, - detector_tx, - chunk_rx, - error_tx, - headers.clone(), - ) - .instrument(Span::current()), - ); 
- detection_streams.push((detector_id, detector_rx)); - } - - debug!("processing detection streams"); - let aggregator = Aggregator::default(); - let result_rx = aggregator.run(generation_tx.subscribe(), detection_streams); - - debug!("spawning generation broadcast task"); - // Spawn task to consume generation stream and forward to broadcast stream - tokio::spawn( - generation_broadcast_task(generation_stream, generation_tx, error_tx.clone()) - .instrument(Span::current()), - ); - drop(generation_rx); - - Ok(result_rx) -} - -#[instrument(skip_all)] -async fn generation_broadcast_task( - mut generation_stream: Pin< - Box> + Send>, - >, - generation_tx: broadcast::Sender, - error_tx: broadcast::Sender, -) { - debug!("forwarding response stream"); - let mut error_rx = error_tx.subscribe(); - loop { - tokio::select! { - _ = error_rx.recv() => { - warn!("cancellation signal received, terminating task"); - break - }, - result = generation_stream.next() => { - match result { - Some(Ok(generation)) => { - debug!(?generation, "received generation"); - let _ = generation_tx.send(generation); - }, - Some(Err(error)) => { - error!(%error, "generation error, cancelling task"); - let _ = error_tx.send(error); - tokio::time::sleep(Duration::from_millis(5)).await; - break; - }, - None => { - debug!("stream closed"); - break - }, - } - } - } - } -} - -/// Wraps a unary detector service to make it streaming. -/// Consumes chunk broadcast stream, sends unary requests to a detector service, -/// and sends chunk + responses to detection stream. 
-#[allow(clippy::too_many_arguments)] -#[instrument(skip_all, fields(detector_id))] -async fn detection_task( - ctx: Arc, - detector_id: String, - detector_params: DetectorParams, - threshold: f64, - detector_tx: mpsc::Sender<(Chunk, Detections)>, - mut chunk_rx: broadcast::Receiver, - error_tx: broadcast::Sender, - headers: HeaderMap, -) { - debug!(threshold, "starting task"); - let mut error_rx = error_tx.subscribe(); - - loop { - tokio::select! { - _ = error_rx.recv() => { - warn!("cancellation signal received, terminating task"); - break - }, - result = chunk_rx.recv() => { - match result { - Ok(chunk) => { - debug!(%detector_id, ?chunk, "received chunk"); - // Send request to detector service - let contents = chunk - .results - .iter() - .map(|token| token.text.clone()) - .collect::>(); - if contents.is_empty() { - debug!("empty chunk, skipping detector request."); - break; - } else { - let request = ContentAnalysisRequest::new(contents.clone(), detector_params.clone()); - let headers = headers.clone(); - debug!(%detector_id, ?request, "sending detector request"); - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap_or_else(|| panic!("text contents detector client not found for {}", detector_id)); - match client.text_contents(&detector_id, request, headers) - .await - .map_err(|error| Error::DetectorRequestFailed { id: detector_id.clone(), error }) { - Ok(response) => { - debug!(%detector_id, ?response, "received detector response"); - let detections = response - .into_iter() - .flat_map(|r| { - r.into_iter().filter_map(|resp| { - let mut result: TokenClassificationResult = resp.into(); - // add detector_id - result.detector_id = Some(detector_id.clone()); - (result.score >= threshold).then_some(result) - }) - }) - .collect::>(); - let _ = detector_tx.send((chunk, detections)).await; - }, - Err(error) => { - error!(%detector_id, %error, "detector error, cancelling task"); - let _ = error_tx.send(error); - 
tokio::time::sleep(Duration::from_millis(5)).await; - break; - }, - } - } - }, - Err(broadcast::error::RecvError::Closed) => { - debug!(%detector_id, "stream closed"); - break; - }, - Err(broadcast::error::RecvError::Lagged(_)) => { - debug!(%detector_id, "stream lagged"); - continue; - } - } - }, - } - } -} - -/// Opens bi-directional stream to a chunker service -/// with generation stream input and returns chunk broadcast stream. -#[instrument(skip_all, fields(chunker_id))] -async fn chunk_broadcast_task( - ctx: Arc, - chunker_id: String, - generation_rx: broadcast::Receiver, - error_tx: broadcast::Sender, -) -> Result, Error> { - // Consume generation stream and convert to chunker input stream - debug!("creating chunker input stream"); - // NOTE: Text gen providers can return more than 1 token in single stream object. This can create - // edge cases where the enumeration generated below may not line up with token / response boundaries. - // So the more accurate way here might be to use `Tokens` object from response, but since that is an - // optional response parameter, we are avoiding that for now. 
- let input_stream = BroadcastStream::new(generation_rx) - .enumerate() - .map(|(token_pointer, generation_result)| { - let generated_text = generation_result - .unwrap() - .generated_text - .unwrap_or_default(); - chunkers::BidiStreamingChunkerTokenizationTaskRequest { - text_stream: generated_text, - input_index_stream: token_pointer as i64, - } - }) - .boxed(); - debug!("creating chunker output stream"); - let id = chunker_id.clone(); // workaround for StreamExt::map_err - let response_stream = if chunker_id == DEFAULT_CHUNKER_ID { - info!("Using default whole doc chunker"); - let (response_tx, response_rx) = mpsc::channel(1); - // Spawn task to collect input stream - tokio::spawn( - async move { - // NOTE: this will not resolve until the input stream is closed - let response = tokenize_whole_doc_stream(input_stream).await; - let _ = response_tx.send(response).await; - } - .instrument(Span::current()), - ); - Ok(ReceiverStream::new(response_rx).boxed()) - } else { - let client = ctx.clients.get_as::(&chunker_id).unwrap(); - client - .bidi_streaming_tokenization_task_predict(&chunker_id, input_stream) - .await - }; - - let mut output_stream = response_stream - .map_err(|error| Error::ChunkerRequestFailed { - id: chunker_id.clone(), - error, - })? - .map_err(move |error| Error::ChunkerRequestFailed { - id: id.clone(), - error, - }); // maps stream errors - - // Spawn task to consume output stream forward to broadcast channel - debug!("spawning chunker broadcast task"); - let (chunk_tx, _) = broadcast::channel(1024); - tokio::spawn({ - let mut error_rx = error_tx.subscribe(); - let chunk_tx = chunk_tx.clone(); - async move { - loop { - tokio::select! 
{ - _ = error_rx.recv() => { - warn!("cancellation signal received, terminating task"); - break - }, - result = output_stream.next() => { - match result { - Some(Ok(chunk)) => { - debug!(?chunk, "received chunk"); - let _ = chunk_tx.send(chunk); - }, - Some(Err(error)) => { - error!(%error, "chunker error, cancelling task"); - let _ = error_tx.send(error); - tokio::time::sleep(Duration::from_millis(5)).await; - break; - }, - None => { - debug!("stream closed"); - break - }, - } - } - } - } - } - .instrument(Span::current()) - }); - Ok(chunk_tx) -} - -/// Sends generate stream request to a generation service. -#[allow(clippy::type_complexity)] -#[instrument(skip_all, fields(model_id))] -async fn generate_stream( - ctx: &Arc, - model_id: String, - text: String, - params: Option, - headers: HeaderMap, -) -> Result< - Pin> + Send>>, - Error, -> { - debug!(?params, "sending generate stream request"); - let client = ctx - .clients - .get_as::("generation") - .unwrap(); - Ok(client - .generate_stream(model_id.clone(), text, params, headers) - .await - .map_err(|error| Error::GenerateRequestFailed { - id: model_id.clone(), - error, - })? 
- .map_err(move |error| Error::GenerateRequestFailed { - id: model_id.clone(), - error, - }) // maps stream errors - .boxed()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_generation_broadcast_task() { - let (generation_tx, generation_rx) = mpsc::channel(4); - let (generation_broadcast_tx, mut generation_broadcast_rx) = broadcast::channel(4); - let generation_stream = ReceiverStream::new(generation_rx).boxed(); - let (error_tx, _) = broadcast::channel(1); - let results = vec![ - ClassifiedGeneratedTextStreamResult { - generated_text: Some("hello".into()), - ..Default::default() - }, - ClassifiedGeneratedTextStreamResult { - generated_text: Some(" ".into()), - ..Default::default() - }, - ClassifiedGeneratedTextStreamResult { - generated_text: Some("world".into()), - ..Default::default() - }, - ]; - tokio::spawn( - { - let results = results.clone(); - async move { - for result in results { - let _ = generation_tx.send(Ok(result)).await; - } - } - } - .instrument(Span::current()), - ); - tokio::spawn( - generation_broadcast_task(generation_stream, generation_broadcast_tx, error_tx) - .instrument(Span::current()), - ); - let mut broadcast_results = Vec::with_capacity(results.len()); - while let Ok(result) = generation_broadcast_rx.recv().await { - println!("{result:?}"); - broadcast_results.push(result); - } - assert_eq!(results, broadcast_results) - } -} diff --git a/src/orchestrator/streaming/aggregator.rs b/src/orchestrator/streaming/aggregator.rs deleted file mode 100644 index 8c8ae197..00000000 --- a/src/orchestrator/streaming/aggregator.rs +++ /dev/null @@ -1,678 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -#![allow(dead_code)] -use std::{ - collections::{BTreeMap, btree_map}, - sync::Arc, -}; - -use tokio::sync::{broadcast, mpsc, oneshot}; -use tracing::instrument; - -use crate::{ - models::ClassifiedGeneratedTextStreamResult, - orchestrator::{ - Error, - streaming::{Chunk, Detections}, - }, -}; - -pub type DetectorId = String; -pub type Span = (i64, i64); - -#[derive(Debug, Clone, Copy)] -pub enum AggregationStrategy { - MaxProcessedIndex, -} - -pub struct Aggregator { - strategy: AggregationStrategy, -} - -impl Default for Aggregator { - fn default() -> Self { - Self { - strategy: AggregationStrategy::MaxProcessedIndex, - } - } -} - -impl Aggregator { - pub fn new(strategy: AggregationStrategy) -> Self { - Self { strategy } - } - - #[instrument(skip_all)] - pub fn run( - &self, - mut generation_rx: broadcast::Receiver, - detection_streams: Vec<(DetectorId, mpsc::Receiver<(Chunk, Detections)>)>, - ) -> mpsc::Receiver> { - // Create result channel - let (result_tx, result_rx) = mpsc::channel(32); - - // Create actors - let generation_actor = Arc::new(GenerationActorHandle::new()); - let result_actor = ResultActorHandle::new(generation_actor.clone(), result_tx); - let aggregation_actor = Arc::new(AggregationActorHandle::new( - result_actor, - detection_streams.len(), - //self.strategy, - )); - - // Spawn task to send generations to generation actor - tokio::spawn({ - async move { - while let Ok(generation) = generation_rx.recv().await { - let _ = generation_actor.put(generation).await; - } - } - }); - - // Spawn tasks to process detection streams concurrently 
- for (_detector_id, mut stream) in detection_streams { - let aggregation_actor = aggregation_actor.clone(); - tokio::spawn(async move { - while let Some((chunk, detections)) = stream.recv().await { - // Send to aggregation actor - aggregation_actor.send(chunk, detections).await; - } - }); - } - result_rx - } -} - -#[derive(Debug)] -struct ResultActorMessage { - pub chunk: Chunk, - pub detections: Detections, -} - -/// Builds results and sends them to result channel. -struct ResultActor { - rx: mpsc::Receiver, - generation_actor: Arc, - result_tx: mpsc::Sender>, -} - -impl ResultActor { - pub fn new( - rx: mpsc::Receiver, - generation_actor: Arc, - result_tx: mpsc::Sender>, - ) -> Self { - Self { - rx, - generation_actor, - result_tx, - } - } - - async fn run(&mut self) { - while let Some(msg) = self.rx.recv().await { - self.handle(msg).await; - } - } - - async fn handle(&mut self, msg: ResultActorMessage) { - let chunk = msg.chunk; - let detections = msg.detections; - let generated_text: String = chunk.results.into_iter().map(|t| t.text).collect(); - let input_start_index = chunk.input_start_index as usize; - let input_end_index = chunk.input_end_index as usize; - - // Get subset of generation responses relevant for this chunk - let generations = self - .generation_actor - .get_range(input_start_index, input_end_index) - .await; - - // Build result - let tokens = generations - .iter() - .flat_map(|generation| generation.tokens.clone().unwrap_or_default()) - .collect::>(); - let mut result = ClassifiedGeneratedTextStreamResult { - generated_text: Some(generated_text), - start_index: Some(chunk.start_index as u32), - processed_index: Some(chunk.processed_index as u32), - tokens: Some(tokens), - // Populate fields from last response or default - ..generations.last().cloned().unwrap_or_default() - }; - result.token_classification_results.output = Some(detections); - if input_start_index == 0 { - // Get input_token_count and seed from first generation message - let 
first = generations - .first() - .expect("first element in classified generated text stream result not found"); - result.input_token_count = first.input_token_count; - result.seed = first.seed; - // Get input_tokens from second generation message (if specified) - let input_tokens = if let Some(second) = generations.get(1) { - second.input_tokens.clone() - } else { - Some(Vec::default()) - }; - result.input_tokens = input_tokens; - } - - // Send result to result channel - let _ = self.result_tx.send(Ok(result)).await; - } -} - -/// [`ResultActor`] handle. -struct ResultActorHandle { - tx: mpsc::Sender, -} - -impl ResultActorHandle { - pub fn new( - generation_actor: Arc, - result_tx: mpsc::Sender>, - ) -> Self { - let (tx, rx) = mpsc::channel(32); - let mut actor = ResultActor::new(rx, generation_actor, result_tx); - tokio::spawn(async move { actor.run().await }); - Self { tx } - } - - pub async fn send(&self, chunk: Chunk, detections: Detections) { - let msg = ResultActorMessage { chunk, detections }; - let _ = self.tx.send(msg).await; - } -} - -#[derive(Debug)] -struct AggregationActorMessage { - pub chunk: Chunk, - pub detections: Detections, -} - -/// Aggregates detections and sends them to [`ResultActor`]. 
-struct AggregationActor { - rx: mpsc::Receiver, - result_actor: ResultActorHandle, - tracker: Tracker, - n_detectors: usize, -} - -impl AggregationActor { - pub fn new( - rx: mpsc::Receiver, - result_actor: ResultActorHandle, - n_detectors: usize, - ) -> Self { - let tracker = Tracker::new(); - Self { - rx, - result_actor, - tracker, - n_detectors, - } - } - - async fn run(&mut self) { - while let Some(msg) = self.rx.recv().await { - self.handle(msg).await; - } - } - - async fn handle(&mut self, msg: AggregationActorMessage) { - let chunk = msg.chunk; - let detections = msg.detections; - - // Add to tracker - let span = (chunk.start_index, chunk.processed_index); - self.tracker - .insert(span, TrackerEntry::new(chunk, detections)); - - // Check if we have all detections for the first span - if self - .tracker - .first() - .is_some_and(|first| first.detections.len() == self.n_detectors) - { - // Take first span and send to result actor - if let Some((_key, value)) = self.tracker.pop_first() { - let chunk = value.chunk; - let mut detections: Detections = value.detections.into_iter().flatten().collect(); - // Provide sorted detections within each chunk - detections.sort_by_key(|r| r.start); - let _ = self.result_actor.send(chunk, detections).await; - } - } - } -} - -/// [`AggregationActor`] handle. 
-struct AggregationActorHandle { - tx: mpsc::Sender, -} - -impl AggregationActorHandle { - pub fn new(result_actor: ResultActorHandle, n_detectors: usize) -> Self { - let (tx, rx) = mpsc::channel(32); - let mut actor = AggregationActor::new(rx, result_actor, n_detectors); - tokio::spawn(async move { actor.run().await }); - Self { tx } - } - - pub async fn send(&self, chunk: Chunk, detections: Detections) { - let msg = AggregationActorMessage { chunk, detections }; - let _ = self.tx.send(msg).await; - } -} - -#[derive(Debug)] -enum GenerationActorMessage { - Put(ClassifiedGeneratedTextStreamResult), - Get { - index: usize, - response_tx: oneshot::Sender>, - }, - GetRange { - start: usize, - end: usize, - response_tx: oneshot::Sender>, - }, - Length { - response_tx: oneshot::Sender, - }, -} - -/// Consumes generations from generation stream and provides them to [`ResultActor`]. -struct GenerationActor { - rx: mpsc::Receiver, - generations: Vec, -} - -impl GenerationActor { - pub fn new(rx: mpsc::Receiver) -> Self { - let generations = Vec::new(); - Self { rx, generations } - } - - async fn run(&mut self) { - while let Some(msg) = self.rx.recv().await { - self.handle(msg); - } - } - - fn handle(&mut self, msg: GenerationActorMessage) { - match msg { - GenerationActorMessage::Put(generation) => self.generations.push(generation), - GenerationActorMessage::Get { index, response_tx } => { - let generation = self.generations.get(index).cloned(); - let _ = response_tx.send(generation); - } - GenerationActorMessage::GetRange { - start, - end, - response_tx, - } => { - let generations = self.generations[start..=end].to_vec(); - let _ = response_tx.send(generations); - } - GenerationActorMessage::Length { response_tx } => { - let _ = response_tx.send(self.generations.len()); - } - } - } -} - -/// [`GenerationActor`] handle. 
-struct GenerationActorHandle { - tx: mpsc::Sender, -} - -impl GenerationActorHandle { - pub fn new() -> Self { - let (tx, rx) = mpsc::channel(32); - let mut actor = GenerationActor::new(rx); - tokio::spawn(async move { actor.run().await }); - Self { tx } - } - - pub async fn put(&self, generation: ClassifiedGeneratedTextStreamResult) { - let msg = GenerationActorMessage::Put(generation); - let _ = self.tx.send(msg).await; - } - - pub async fn get(&self, index: usize) -> Option { - let (response_tx, response_rx) = oneshot::channel(); - let msg = GenerationActorMessage::Get { index, response_tx }; - let _ = self.tx.send(msg).await; - response_rx.await.unwrap() - } - - pub async fn get_range( - &self, - start: usize, - end: usize, - ) -> Vec { - let (response_tx, response_rx) = oneshot::channel(); - let msg = GenerationActorMessage::GetRange { - start, - end, - response_tx, - }; - let _ = self.tx.send(msg).await; - response_rx.await.unwrap() - } - - pub async fn len(&self) -> usize { - let (response_tx, response_rx) = oneshot::channel(); - let msg = GenerationActorMessage::Length { response_tx }; - let _ = self.tx.send(msg).await; - response_rx.await.unwrap() - } -} - -#[derive(Debug, Clone)] -pub struct TrackerEntry { - pub chunk: Chunk, - pub detections: Vec, -} - -impl TrackerEntry { - pub fn new(chunk: Chunk, detections: Detections) -> Self { - Self { - chunk, - detections: vec![detections], - } - } -} - -#[derive(Debug, Clone, Default)] -pub struct Tracker { - state: BTreeMap, -} - -impl Tracker { - pub fn new() -> Self { - Self { - state: BTreeMap::new(), - } - } - - pub fn insert(&mut self, key: Span, value: TrackerEntry) { - match self.state.entry(key) { - btree_map::Entry::Vacant(entry) => { - // New span, insert entry with chunk and detections - entry.insert(value); - } - btree_map::Entry::Occupied(mut entry) => { - // Existing span, extend detections - entry.get_mut().detections.extend(value.detections); - } - } - } - - /// Returns the key-value pair of 
the first span. - pub fn first_key_value(&self) -> Option<(&Span, &TrackerEntry)> { - self.state.first_key_value() - } - - /// Returns the value of the first span. - pub fn first(&self) -> Option<&TrackerEntry> { - self.state.first_key_value().map(|(_, value)| value) - } - - /// Removes and returns the key-value pair of the first span. - pub fn pop_first(&mut self) -> Option<(Span, TrackerEntry)> { - self.state.pop_first() - } - - /// Returns the number of elements in the tracker. - pub fn len(&self) -> usize { - self.state.len() - } - - /// Returns true if the tracker contains no elements. - pub fn is_empty(&self) -> bool { - self.state.is_empty() - } - - /// Gets an iterator over the keys of the tracker, in sorted order. - pub fn keys(&self) -> btree_map::Keys<'_, (i64, i64), TrackerEntry> { - self.state.keys() - } - - /// Gets an iterator over the values of the tracker, in sorted order. - pub fn values(&self) -> btree_map::Values<'_, (i64, i64), TrackerEntry> { - self.state.values() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - models::TokenClassificationResult, - pb::caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token}, - }; - - fn get_detection_obj( - span: Span, - text: &str, - detection: &str, - detection_type: &str, - detector_id: &str, - ) -> TokenClassificationResult { - TokenClassificationResult { - start: span.0 as u32, - end: span.1 as u32, - word: text.to_string(), - entity: detection.to_string(), - entity_group: detection_type.to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.99, - token_count: None, - } - } - - #[tokio::test] - /// Test to check the aggregation of streaming generation results with multiple detectors on a single chunk. 
- async fn test_aggregation_single_chunk_multi_detection() { - let chunks = vec![Chunk { - results: [Token { - start: 0, - end: 24, - text: "This is a dummy sentence".into(), - }] - .into(), - token_count: 5, - processed_index: 4, - start_index: 0, - input_start_index: 0, - input_end_index: 0, - }]; - - let detector_count = 2; - let mut detection_streams = Vec::with_capacity(detector_count); - - // Note: below is detection / chunks on batch of size 1 with 1 sentence - for chunk in &chunks { - let chunk_token = chunk.results[0].clone(); - let text = &chunk_token.text; - let whole_span = (chunk_token.start, chunk_token.end); - let partial_span = (chunk_token.start + 2, chunk_token.end - 2); - - let (detector_tx1, detector_rx1) = mpsc::channel(1); - let detection = get_detection_obj(whole_span, text, "has_HAP", "HAP", "en-hap"); - let _ = detector_tx1.send((chunk.clone(), vec![detection])).await; - - let (detector_tx2, detector_rx2) = mpsc::channel(1); - let detection = get_detection_obj(partial_span, text, "email_ID", "PII", "en-pii"); - let _ = detector_tx2.send((chunk.clone(), vec![detection])).await; - - // Push HAP after PII to make sure detection ordering is not coincidental - detection_streams.push(("pii-1".into(), detector_rx2)); - detection_streams.push(("hap-1".into(), detector_rx1)); - } - - let (generation_tx, generation_rx) = broadcast::channel(1); - let _ = generation_tx.send(ClassifiedGeneratedTextStreamResult::default()); - let aggregator = Aggregator::default(); - - let mut result_rx = aggregator.run(generation_rx, detection_streams); - let mut chunk_count = 0; - while let Some(result) = result_rx.recv().await { - let detection = result - .unwrap() - .token_classification_results - .output - .unwrap_or_default(); - assert_eq!(detection.len(), detector_count); - // Expect HAP first since whole_span start is before partial_span start - assert_eq!(detection[0].entity_group, "HAP"); - assert_eq!(detection[1].entity_group, "PII"); - chunk_count += 1; - } - 
assert_eq!(chunk_count, chunks.len()); - } - - #[test] - fn test_tracker_with_out_of_order_chunks() { - let chunks = [ - ChunkerTokenizationStreamResult { - results: [Token { - start: 0, - end: 56, - text: " a powerful tool for the development \ - of complex systems." - .into(), - }] - .to_vec(), - token_count: 0, - processed_index: 56, - start_index: 0, - input_start_index: 0, - input_end_index: 10, - }, - ChunkerTokenizationStreamResult { - results: [Token { - start: 56, - end: 135, - text: " It has been used in many fields, such as \ - computer vision and image processing." - .into(), - }] - .to_vec(), - token_count: 0, - processed_index: 135, - start_index: 56, - input_start_index: 11, - input_end_index: 26, - }, - ]; - let n_detectors = 2; - let mut tracker = Tracker::new(); - - // Insert out-of-order detection results - for (key, value) in [ - // detector 1, chunk 2 - ( - (chunks[1].start_index, chunks[1].processed_index), - TrackerEntry::new(chunks[1].clone(), vec![]), - ), - // detector 2, chunk 1 - ( - (chunks[0].start_index, chunks[0].processed_index), - TrackerEntry::new(chunks[0].clone(), vec![]), - ), - // detector 2, chunk 2 - ( - (chunks[1].start_index, chunks[1].processed_index), - TrackerEntry::new(chunks[1].clone(), vec![]), - ), - ] { - tracker.insert(key, value); - } - // We now have both detector results for chunk 2, but not chunk 1 - - // We do not have all detections for the first chunk - assert!( - tracker - .first() - .is_none_or(|first| first.detections.len() != n_detectors), - "detections length should not be 2 for first chunk" - ); - - // Insert entry for detector 1, chunk 1 - tracker.insert( - (chunks[0].start_index, chunks[0].processed_index), - TrackerEntry::new(chunks[0].clone(), vec![]), - ); - - // We have all detections for the first chunk - assert!( - tracker - .first() - .is_some_and(|first| first.detections.len() == n_detectors), - "detections length should be 2 for first chunk" - ); - - // There should be entries for 2 chunks 
- assert_eq!(tracker.len(), 2, "tracker length should be 2"); - - // Detections length should be 2 for each chunk - assert_eq!( - tracker - .values() - .map(|entry| entry.detections.len()) - .collect::>(), - vec![2, 2], - "detections length should be 2 for each chunk" - ); - - // The first entry should be for chunk 1 - let first_key = *tracker - .first_key_value() - .expect("tracker should have first entry") - .0; - assert_eq!( - first_key, - (chunks[0].start_index, chunks[0].processed_index), - "first should be chunk 1" - ); - - // Tracker should remove and return entry for chunk 1 - let (key, value) = tracker - .pop_first() - .expect("tracker should have first entry"); - assert!( - key.0 == 0 && value.chunk.start_index == 0, - "first should be chunk 1" - ); - assert!(tracker.len() == 1, "tracker length should be 1"); - - // Tracker should remove and return entry for chunk 2 - let (key, value) = tracker - .pop_first() - .expect("tracker should have first entry"); - assert!( - key.0 == 56 && value.chunk.start_index == 56, - "first should be chunk 2" - ); - assert!(tracker.is_empty(), "tracker should be empty"); - } -} diff --git a/src/orchestrator/streaming_content_detection.rs b/src/orchestrator/streaming_content_detection.rs deleted file mode 100644 index 0810df8e..00000000 --- a/src/orchestrator/streaming_content_detection.rs +++ /dev/null @@ -1,491 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- -*/ - -/////////////////////////////////////////////////////////////////////////////////// -// A lot of the code in this file is similar to src/orchestrator/streaming.rs, // -// with expection of `Orchestrator::handle_streaming_content_detection()`` and // -// `extract_detectors().` // -// The main difference in the remaining methods was the replacement of // -// `ClassifiedGeneratedTextStreamResult` with `StreamingContentDetectionRequest` // -// and `StreamingContentDetectionResponse`. // -// This can likely be improved in a future refactor to use generics instead of // -// duplicating these very similar methods. // -/////////////////////////////////////////////////////////////////////////////////// -mod aggregator; - -use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; - -use aggregator::Aggregator; -use futures::{Stream, StreamExt, TryStreamExt, future::try_join_all, stream::Peekable}; -use hyper::HeaderMap; -use tokio::sync::{broadcast, mpsc}; -use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; -use tracing::{debug, error, info, instrument, warn}; - -use super::{Context, Error, Orchestrator, StreamingContentDetectionTask, streaming::Detections}; -use crate::{ - clients::{ - TextContentsDetectorClient, - chunker::{ChunkerClient, DEFAULT_CHUNKER_ID, tokenize_whole_doc_stream}, - detector::ContentAnalysisRequest, - }, - models::{ - DetectorParams, StreamingContentDetectionRequest, StreamingContentDetectionResponse, - TokenClassificationResult, - }, - orchestrator::{common::get_chunker_ids, streaming::Chunk}, - pb::caikit::runtime::chunkers, -}; - -type ContentInputStream = - Pin> + Send>>; - -impl Orchestrator { - /// Handles content detection streaming tasks. 
- #[instrument(skip_all, fields(trace_id = task.trace_id.to_string(), headers = ?task.headers))] - pub async fn handle_streaming_content_detection( - &self, - task: StreamingContentDetectionTask, - ) -> ReceiverStream> { - let ctx = self.ctx.clone(); - let trace_id = task.trace_id; - let headers = task.headers; - - let mut input_stream = Box::pin(task.input_stream.peekable()); - - // Create response channel - #[allow(clippy::type_complexity)] - let (response_tx, response_rx): ( - mpsc::Sender>, - mpsc::Receiver>, - ) = mpsc::channel(32); - - // Spawn task to process input stream - tokio::spawn(async move { - let detectors = match extract_detectors(&mut input_stream).await { - Ok(detectors) => detectors, - Err(error) => { - error!(%error, "error extracting detector information from first stream frame"); - let _ = response_tx.send(Err(error)).await; - return; - } - }; - - // Create error channel - // - // This channel is used for error notification & messaging and task cancellation. - // When a task fails, it notifies other tasks by sending the error to error_tx. - // - // The parent task receives the error, logs it, forwards it to the client via response_tx, - // and terminates the task. - let (error_tx, _) = broadcast::channel(1); - - let mut result_rx = match streaming_detection_task( - &ctx, - &detectors, - input_stream, - error_tx.clone(), - headers.clone(), - ) - .await - { - Ok(result_rx) => result_rx, - Err(error) => { - error!(%trace_id, %error, "task failed"); - let _ = error_tx.send(error.clone()); - let _ = response_tx.send(Err(error)).await; - return; - } - }; - tokio::spawn(async move { - let mut error_rx = error_tx.subscribe(); - loop { - tokio::select! 
{ - Ok(error) = error_rx.recv() => { - error!(%trace_id, %error, "task failed"); - debug!(%trace_id, "sending error to client and terminating"); - let _ = response_tx.send(Err(error)).await; - return; - }, - result = result_rx.recv() => { - match result { - Some(result) => { - debug!(%trace_id, ?result, "sending result to client"); - if (response_tx.send(result).await).is_err() { - warn!(%trace_id, "response channel closed (client disconnected), terminating task"); - // Broadcast cancellation signal to tasks - let _ = error_tx.send(Error::Cancelled); - return; - } - }, - None => { - info!(%trace_id, "task completed: stream closed"); - break; - }, - } - } - } - } - }); - }); - ReceiverStream::new(response_rx) - } -} - -/// Extracts detectors config from first message. -async fn extract_detectors( - input_stream: &mut Peekable, -) -> Result, Error> { - // We can use Peekable to get a reference to it instead of consuming the message here - // Peekable::peek() takes self: Pin<&mut Peekable<_>>, which is why we need to pin it - // https://docs.rs/futures/latest/futures/stream/struct.Peekable.html - if let Some(result) = Pin::new(input_stream).peek().await { - match result { - Ok(msg) => { - if let Some(detectors) = &msg.detectors { - if detectors.is_empty() { - return Err(Error::Validation( - "`detectors` must not be empty".to_string(), - )); - } - return Ok(detectors.clone()); - } - } - Err(error) => return Err(error.clone()), - } - } - Err(Error::Validation( - "`detectors` is required for the first message".into(), - )) -} - -/// Handles streaming output detection task. 
-#[instrument(skip_all)] -async fn streaming_detection_task( - ctx: &Arc, - detectors: &HashMap, - input_stream: ContentInputStream, - error_tx: broadcast::Sender, - headers: HeaderMap, -) -> Result>, Error> { - debug!(?detectors, "creating chunk broadcast streams"); - - // Create input broadcast stream - let (input_tx, input_rx) = broadcast::channel(1024); - - let chunker_ids = get_chunker_ids(ctx, detectors)?; - // Create a map of chunker_id->chunk_broadcast_stream - // This is to enable fan-out of chunk streams to potentially multiple detectors that use the same chunker. - // Each detector task will subscribe to an associated chunk stream. - let chunk_broadcast_streams = try_join_all( - chunker_ids - .into_iter() - .map(|chunker_id| { - debug!(%chunker_id, "creating chunk broadcast stream"); - let ctx = ctx.clone(); - let error_tx = error_tx.clone(); - // Subscribe to input stream - let input_rx = input_tx.subscribe(); - async move { - let chunk_tx = - chunk_broadcast_task(ctx, chunker_id.clone(), input_rx, error_tx).await?; - Ok::<(String, broadcast::Sender), Error>((chunker_id, chunk_tx)) - } - }) - .collect::>(), - ) - .await? 
- .into_iter() - .collect::>(); - - // Spawn detection tasks to subscribe to chunker stream, - // send requests to detector service, and send results to detection stream - debug!("spawning detection tasks"); - let mut detection_streams = Vec::with_capacity(detectors.len()); - for (detector_id, detector_params) in detectors.iter() { - // Create a mutable copy of the parameters, so that we can modify it based on processing - let mut detector_params = detector_params.clone(); - let detector_id = detector_id.to_string(); - let chunker_id = ctx - .config - .get_chunker_id(&detector_id) - .expect("chunker id is not found"); - - // Get the detector config - // TODO: Add error handling - let detector_config = ctx - .config - .detectors - .get(&detector_id) - .expect("detector config not found"); - - // Get the default threshold to use if threshold is not provided by the user - let default_threshold = detector_config.default_threshold; - let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); - - // Create detection stream - let (detector_tx, detector_rx) = mpsc::channel(1024); - // Subscribe to chunk broadcast stream - let chunk_rx = chunk_broadcast_streams - .get(&chunker_id) - .unwrap() - .subscribe(); - let error_tx = error_tx.clone(); - tokio::spawn(detection_task( - ctx.clone(), - detector_id.clone(), - detector_params, - threshold, - detector_tx, - chunk_rx, - error_tx, - headers.clone(), - )); - detection_streams.push((detector_id, detector_rx)); - } - - debug!("processing detection streams"); - let aggregator = Aggregator::default(); - let result_rx = aggregator.run(detection_streams); - - debug!("spawning input broadcast task"); - // Spawn task to consume input stream and forward to broadcast stream - tokio::spawn(input_broadcast_task( - input_stream, - input_tx, - error_tx.clone(), - )); - drop(input_rx); - - Ok(result_rx) -} - -/// Opens bi-directional stream to a chunker service -/// with input stream input and returns chunk broadcast 
stream. -#[instrument(skip_all, fields(chunker_id))] -async fn chunk_broadcast_task( - ctx: Arc, - chunker_id: String, - input_rx: broadcast::Receiver, - error_tx: broadcast::Sender, -) -> Result, Error> { - // Consume input stream and convert to chunker input stream - debug!("creating chunker input stream"); - // NOTE: Text gen providers can return more than 1 token in single stream object. This can create - // edge cases where the enumeration generated below may not line up with token / response boundaries. - // So the more accurate way here might be to use `Tokens` object from response, but since that is an - // optional response parameter, we are avoiding that for now. - let input_stream = BroadcastStream::new(input_rx) - .enumerate() - .map(|(token_pointer, input_result)| { - let generated_text = input_result.unwrap().content; - chunkers::BidiStreamingChunkerTokenizationTaskRequest { - text_stream: generated_text, - input_index_stream: token_pointer as i64, - } - }) - .boxed(); - debug!("creating chunker output stream"); - let id = chunker_id.clone(); // workaround for StreamExt::map_err - - let response_stream = if chunker_id == DEFAULT_CHUNKER_ID { - info!("Using default whole doc chunker"); - let (response_tx, response_rx) = mpsc::channel(1); - // Spawn task to collect input stream - tokio::spawn(async move { - // NOTE: this will not resolve until the input stream is closed - let response = tokenize_whole_doc_stream(input_stream).await; - let _ = response_tx.send(response).await; - }); - Ok(ReceiverStream::new(response_rx).boxed()) - } else { - let client = ctx.clients.get_as::(&chunker_id).unwrap(); - client - .bidi_streaming_tokenization_task_predict(&chunker_id, input_stream) - .await - }; - - let mut output_stream = response_stream - .map_err(|error| Error::ChunkerRequestFailed { - id: chunker_id.clone(), - error, - })? 
- .map_err(move |error| Error::ChunkerRequestFailed { - id: id.clone(), - error, - }); // maps stream errors - - // Spawn task to consume output stream forward to broadcast channel - debug!("spawning chunker broadcast task"); - let (chunk_tx, _) = broadcast::channel(1024); - tokio::spawn({ - let mut error_rx = error_tx.subscribe(); - let chunk_tx = chunk_tx.clone(); - async move { - loop { - tokio::select! { - _ = error_rx.recv() => { - warn!("cancellation signal received, terminating task"); - break - }, - result = output_stream.next() => { - match result { - Some(Ok(chunk)) => { - debug!(?chunk, "received chunk"); - let _ = chunk_tx.send(chunk); - }, - Some(Err(error)) => { - error!(%error, "chunker error, cancelling task"); - let _ = error_tx.send(error); - tokio::time::sleep(Duration::from_millis(5)).await; - break; - }, - None => { - debug!("stream closed"); - break - }, - } - } - } - } - } - }); - Ok(chunk_tx) -} - -/// Wraps a unary detector service to make it streaming. -/// Consumes chunk broadcast stream, sends unary requests to a detector service, -/// and sends chunk + responses to detection stream. -#[allow(clippy::too_many_arguments)] -#[instrument(skip_all, fields(detector_id))] -async fn detection_task( - ctx: Arc, - detector_id: String, - detector_params: DetectorParams, - threshold: f64, - detector_tx: mpsc::Sender<(Chunk, Detections)>, - mut chunk_rx: broadcast::Receiver, - error_tx: broadcast::Sender, - headers: HeaderMap, -) { - debug!(threshold, "starting task"); - let mut error_rx = error_tx.subscribe(); - - loop { - tokio::select! 
{ - _ = error_rx.recv() => { - warn!("cancellation signal received, terminating task"); - break - }, - result = chunk_rx.recv() => { - match result { - Ok(chunk) => { - debug!(%detector_id, ?chunk, "received chunk"); - // Send request to detector service - let contents = chunk - .results - .iter() - .map(|token| token.text.clone()) - .collect::>(); - if contents.is_empty() { - debug!("empty chunk, skipping detector request."); - break; - } else { - let request = ContentAnalysisRequest::new(contents.clone(), detector_params.clone()); - let headers = headers.clone(); - debug!(%detector_id, ?request, "sending detector request"); - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap_or_else(|| panic!("text contents detector client not found for {}", detector_id)); - match client.text_contents(&detector_id, request, headers) - .await - .map_err(|error| Error::DetectorRequestFailed { id: detector_id.clone(), error }) { - Ok(response) => { - debug!(%detector_id, ?response, "received detector response"); - let detections = response - .into_iter() - .flat_map(|r| { - r.into_iter().filter_map(|resp| { - let mut result: TokenClassificationResult = resp.into(); - // add detector_id - result.detector_id = Some(detector_id.clone()); - (result.score >= threshold).then_some(result) - }) - }) - .collect::>(); - let _ = detector_tx.send((chunk, detections)).await; - }, - Err(error) => { - error!(%detector_id, %error, "detector error, cancelling task"); - let _ = error_tx.send(error); - tokio::time::sleep(Duration::from_millis(5)).await; - break; - }, - } - } - }, - Err(broadcast::error::RecvError::Closed) => { - debug!(%detector_id, "stream closed"); - break; - }, - Err(broadcast::error::RecvError::Lagged(_)) => { - debug!(%detector_id, "stream lagged"); - continue; - } - } - }, - } - } -} - -/// Broadcasts messages from input stream to input broadcast channel. -/// Triggers task cancellation if an error message is received. 
-#[instrument(skip_all)] -async fn input_broadcast_task( - mut input_stream: ContentInputStream, - input_tx: broadcast::Sender, - error_tx: broadcast::Sender, -) { - let mut error_rx = error_tx.subscribe(); - loop { - tokio::select! { - _ = error_rx.recv() => { - warn!("cancellation signal received, terminating task"); - break - }, - result = input_stream.next() => { - match result { - Some(Ok(msg)) => { - debug!(?msg, "received message"); - let _ = input_tx.send(msg); - }, - Some(Err(error)) => { - error!(%error, "received error message, cancelling task"); - let _ = error_tx.send(error); - tokio::time::sleep(Duration::from_millis(5)).await; - break; - }, - None => { - debug!("stream closed"); - break - }, - } - } - } - } -} diff --git a/src/orchestrator/streaming_content_detection/aggregator.rs b/src/orchestrator/streaming_content_detection/aggregator.rs deleted file mode 100644 index 40a2cf57..00000000 --- a/src/orchestrator/streaming_content_detection/aggregator.rs +++ /dev/null @@ -1,277 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -////////////////////////////////////////////////////////////////////////////////////// -// This file contains a simplified version of the code in // -// src/orchestrator/streaming/aggregator.rs. // -// The main difference is that this file does not contain generation nor a // -// result actor. 
This code also reuses whatever is possible from the aforementioned // -// file, such as tracker and aggregation strategy structs. // -// `ClassifiedGeneratedTextStreamResult` with `StreamingContentDetectionRequest` // -// and `StreamingContentDetectionResponse`. // -// This can likely be improved in a future refactor to use generics instead of // -// duplicating these very similar methods. // -////////////////////////////////////////////////////////////////////////////////////// -#![allow(dead_code)] -use std::sync::Arc; - -use tokio::sync::mpsc; -use tracing::instrument; - -use crate::{ - clients::detector::ContentAnalysisResponse, - models::{Metadata, StreamingContentDetectionResponse}, - orchestrator::{ - Error, - streaming::{ - Chunk, Detections, - aggregator::{AggregationStrategy, DetectorId, Tracker, TrackerEntry}, - }, - }, -}; - -pub struct Aggregator { - strategy: AggregationStrategy, -} - -impl Default for Aggregator { - fn default() -> Self { - Self { - strategy: AggregationStrategy::MaxProcessedIndex, - } - } -} - -impl Aggregator { - pub fn new(strategy: AggregationStrategy) -> Self { - Self { strategy } - } - - #[instrument(skip_all)] - pub fn run( - &self, - detection_streams: Vec<(DetectorId, mpsc::Receiver<(Chunk, Detections)>)>, - ) -> mpsc::Receiver> { - // Create result channel - let (result_tx, result_rx) = mpsc::channel(32); - - // Create actors - let aggregation_actor = Arc::new(AggregationActorHandle::new( - result_tx, - detection_streams.len(), - )); - - // Spawn tasks to process detection streams concurrently - for (_detector_id, mut stream) in detection_streams { - let aggregation_actor = aggregation_actor.clone(); - tokio::spawn(async move { - while let Some((chunk, detections)) = stream.recv().await { - // Send to aggregation actor - aggregation_actor.send(chunk, detections).await; - } - }); - } - result_rx - } -} - -#[derive(Debug)] -struct AggregationActorMessage { - pub chunk: Chunk, - pub detections: Detections, -} - -/// Aggregates 
detections, builds results, and sends them to result channel. -struct AggregationActor { - rx: mpsc::Receiver, - result_tx: mpsc::Sender>, - tracker: Tracker, - n_detectors: usize, -} - -impl AggregationActor { - pub fn new( - rx: mpsc::Receiver, - result_tx: mpsc::Sender>, - n_detectors: usize, - ) -> Self { - let tracker = Tracker::new(); - Self { - rx, - result_tx, - tracker, - n_detectors, - } - } - - async fn run(&mut self) { - while let Some(msg) = self.rx.recv().await { - self.handle(msg).await; - } - } - - async fn handle(&mut self, msg: AggregationActorMessage) { - let chunk = msg.chunk; - let detections = msg.detections; - - // Add to tracker - let span = (chunk.start_index, chunk.processed_index); - self.tracker - .insert(span, TrackerEntry::new(chunk, detections)); - - // Check if we have all detections for the first span - if self - .tracker - .first() - .is_some_and(|first| first.detections.len() == self.n_detectors) - { - // Take first span and send result - if let Some((_key, value)) = self.tracker.pop_first() { - let chunk = value.chunk; - let mut detections: Detections = value.detections.into_iter().flatten().collect(); - // Provide sorted detections within each chunk - detections.sort_by_key(|r| r.start); - - // Build response message - let response = StreamingContentDetectionResponse { - start_index: chunk.start_index as u32, - processed_index: chunk.processed_index as u32, - detections: detections - .into_iter() - .map(|r| ContentAnalysisResponse { - start: r.start as usize, - end: r.end as usize, - text: r.word, - detection: r.entity, - detection_type: r.entity_group, - detector_id: r.detector_id, - score: r.score, - evidence: None, - metadata: Metadata::new(), - }) - .collect(), - }; - // Send to result channel - let _ = self.result_tx.send(Ok(response)).await; - } - } - } -} - -/// [`AggregationActor`] handle. 
-struct AggregationActorHandle { - tx: mpsc::Sender, -} - -impl AggregationActorHandle { - pub fn new( - result_tx: mpsc::Sender>, - n_detectors: usize, - ) -> Self { - let (tx, rx) = mpsc::channel(32); - let mut actor = AggregationActor::new(rx, result_tx, n_detectors); - tokio::spawn(async move { actor.run().await }); - Self { tx } - } - - pub async fn send(&self, chunk: Chunk, detections: Detections) { - let msg = AggregationActorMessage { chunk, detections }; - let _ = self.tx.send(msg).await; - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - models::TokenClassificationResult, orchestrator::streaming::aggregator::Span, - pb::caikit_data_model::nlp::Token, - }; - - fn get_detection_obj( - span: Span, - text: &str, - detection: &str, - detection_type: &str, - detector_id: &str, - ) -> TokenClassificationResult { - TokenClassificationResult { - start: span.0 as u32, - end: span.1 as u32, - word: text.to_string(), - entity: detection.to_string(), - entity_group: detection_type.to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.99, - token_count: None, - } - } - - #[tokio::test] - /// Test to check the aggregation of streaming generation results with multiple detectors on a single chunk. 
- async fn test_aggregation_single_chunk_multi_detection() { - let chunks = vec![Chunk { - results: [Token { - start: 0, - end: 24, - text: "This is a dummy sentence".into(), - }] - .into(), - token_count: 5, - processed_index: 4, - start_index: 0, - input_start_index: 0, - input_end_index: 0, - }]; - - let detector_count = 2; - let mut detection_streams = Vec::with_capacity(detector_count); - - // Note: below is detection / chunks on batch of size 1 with 1 sentence - for chunk in &chunks { - let chunk_token = chunk.results[0].clone(); - let text = &chunk_token.text; - let whole_span = (chunk_token.start, chunk_token.end); - let partial_span = (chunk_token.start + 2, chunk_token.end - 2); - - let (detector_tx1, detector_rx1) = mpsc::channel(1); - let detection = get_detection_obj(whole_span, text, "has_HAP", "HAP", "en-hap"); - let _ = detector_tx1.send((chunk.clone(), vec![detection])).await; - - let (detector_tx2, detector_rx2) = mpsc::channel(1); - let detection = get_detection_obj(partial_span, text, "email_ID", "PII", "en-pii"); - let _ = detector_tx2.send((chunk.clone(), vec![detection])).await; - - // Push HAP after PII to make sure detection ordering is not coincidental - detection_streams.push(("pii-1".into(), detector_rx2)); - detection_streams.push(("hap-1".into(), detector_rx1)); - } - - let aggregator = Aggregator::new(AggregationStrategy::MaxProcessedIndex); - let mut result_rx = aggregator.run(detection_streams); - let mut chunk_count = 0; - while let Some(result) = result_rx.recv().await { - let detection = result.unwrap().detections; - assert_eq!(detection.len(), detector_count); - // Expect HAP first since whole_span start is before partial_span start - assert_eq!(detection[0].detection_type, "HAP"); - assert_eq!(detection[1].detection_type, "PII"); - chunk_count += 1; - } - assert_eq!(chunk_count, chunks.len()); - } -} diff --git a/src/orchestrator/types/detection.rs b/src/orchestrator/types/detection.rs index b6fdfebb..fdf9e7cf 100644 --- 
a/src/orchestrator/types/detection.rs +++ b/src/orchestrator/types/detection.rs @@ -206,6 +206,27 @@ impl From for Detection { } } +impl From for models::DetectionResult { + fn from(value: Detection) -> Self { + let evidence = (!value.evidence.is_empty()) + .then_some(value.evidence.into_iter().map(Into::into).collect()); + Self { + detection_type: value.detection_type, + detection: value.detection, + detector_id: value.detector_id, + score: value.score, + evidence, + metadata: value.metadata, + } + } +} + +impl From for Vec { + fn from(value: Detections) -> Self { + value.into_iter().map(Into::into).collect() + } +} + impl From> for Detections { fn from(value: Vec) -> Self { value.into_iter().map(Into::into).collect() diff --git a/src/orchestrator/unary.rs b/src/orchestrator/unary.rs deleted file mode 100644 index 1dc21e1d..00000000 --- a/src/orchestrator/unary.rs +++ /dev/null @@ -1,1459 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- -*/ - -use std::{collections::HashMap, sync::Arc}; - -use axum::http::HeaderMap; -use futures::{ - future::try_join_all, - stream::{self, StreamExt}, -}; -use tracing::{Instrument, Span, debug, error, info, instrument}; - -use super::{ - ChatDetectionTask, Chunk, ClassificationWithGenTask, Context, ContextDocsDetectionTask, - DetectionOnGenerationTask, Error, GenerationWithDetectionTask, Orchestrator, - TextContentDetectionTask, -}; -use crate::{ - clients::{ - GenerationClient, - chunker::{ChunkerClient, DEFAULT_CHUNKER_ID, tokenize_whole_doc}, - detector::{ - ChatDetectionRequest, ContentAnalysisRequest, ContentAnalysisResponse, - ContextDocsDetectionRequest, ContextType, GenerationDetectionRequest, - TextChatDetectorClient, TextContentsDetectorClient, TextContextDocDetectorClient, - TextGenerationDetectorClient, - }, - openai::{Message, Tool}, - }, - models::{ - ChatDetectionResult, ClassifiedGeneratedTextResult, ContextDocsResult, - DetectionOnGenerationResult, DetectionResult, DetectionWarning, DetectionWarningReason, - DetectorParams, GenerationWithDetectionResult, GuardrailsTextGenerationParameters, - TextContentDetectionResult, TextGenTokenClassificationResults, TokenClassificationResult, - UNSUITABLE_INPUT_MESSAGE, - }, - orchestrator::common::{apply_masks, get_chunker_ids}, - pb::caikit::runtime::chunkers, -}; - -const DEFAULT_STREAM_BUFFER_SIZE: usize = 5; - -impl Orchestrator { - /// Handles unary tasks. 
- #[instrument(skip_all, fields(trace_id = ?task.trace_id, model_id = task.model_id, headers = ?task.headers))] - pub async fn handle_classification_with_gen( - &self, - task: ClassificationWithGenTask, - ) -> Result { - let ctx = self.ctx.clone(); - let trace_id = task.trace_id; - let headers = task.headers; - info!(config = ?task.guardrails_config, "handling classification with generation task"); - let task_handle = tokio::spawn( - async move { - let input_text = task.inputs.clone(); - let masks = task.guardrails_config.input_masks(); - let input_detectors = task.guardrails_config.input_detectors(); - // Do input detections - let input_detections = match input_detectors { - Some(detectors) if !detectors.is_empty() => { - input_detection_task( - &ctx, - detectors, - input_text.clone(), - masks, - headers.clone(), - ) - .await? - } - _ => None, - }; - debug!(?input_detections); - if let Some(mut input_detections) = input_detections { - // Detected HAP/PII - // Do tokenization to get input_token_count - let (input_token_count, _tokens) = tokenize( - &ctx, - task.model_id.clone(), - task.inputs.clone(), - headers.clone(), - ) - .await?; - // Send result with input detections - input_detections.sort_by_key(|r| r.start); - Ok(ClassifiedGeneratedTextResult { - input_token_count, - token_classification_results: TextGenTokenClassificationResults { - input: Some(input_detections), - output: None, - }, - warnings: Some(vec![DetectionWarning { - id: Some(DetectionWarningReason::UnsuitableInput), - message: Some(UNSUITABLE_INPUT_MESSAGE.to_string()), - }]), - ..Default::default() - }) - } else { - // No HAP/PII detected - // Do text generation - let mut generation_results = generate( - &ctx, - task.model_id.clone(), - task.inputs.clone(), - task.text_gen_parameters.clone(), - headers.clone(), - ) - .await?; - debug!(?generation_results); - // Do output detections - let output_detectors = task.guardrails_config.output_detectors(); - let output_detections = match 
output_detectors { - Some(detectors) if !detectors.is_empty() => { - let generated_text = generation_results - .generated_text - .clone() - .unwrap_or_default(); - output_detection_task(&ctx, detectors, generated_text, headers).await? - } - _ => None, - }; - debug!(?output_detections); - if let Some(mut output_detections) = output_detections { - output_detections.sort_by_key(|r| r.start); - generation_results.token_classification_results.output = - Some(output_detections); - } - Ok(generation_results) - } - } - .instrument(Span::current()), - ); - match task_handle.await { - // Task completed successfully - Ok(Ok(result)) => { - debug!(%trace_id, ?result, "sending result to client"); - info!(%trace_id, "task completed"); - Ok(result) - } - // Task failed, return error propagated from child task that failed - Ok(Err(error)) => { - error!(%trace_id, %error, "task failed"); - Err(error) - } - // Task cancelled or panicked - Err(error) => { - let error = error.into(); - error!(%trace_id, %error, "task failed"); - Err(error) - } - } - } - - /// Handles the given generation task, followed by detections. 
- #[instrument(skip_all, fields(trace_id = ?task.trace_id, model_id = task.model_id, headers = ?task.headers))] - pub async fn handle_generation_with_detection( - &self, - task: GenerationWithDetectionTask, - ) -> Result { - info!( - detectors = ?task.detectors, - "handling generation with detection task" - ); - let ctx = self.ctx.clone(); - let headers = task.headers; - let task_handle = tokio::spawn( - async move { - let generation_results = generate( - &ctx, - task.model_id.clone(), - task.prompt.clone(), - task.text_gen_parameters.clone(), - headers.clone(), - ) - .await?; - - // call detection - let detections = try_join_all( - task.detectors - .iter() - .map(|(detector_id, detector_params)| { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_params = detector_params.clone(); - let prompt = task.prompt.clone(); - let generated_text = generation_results - .generated_text - .clone() - .unwrap_or_default(); - async { - detect_for_generation( - ctx, - detector_id, - detector_params, - prompt, - generated_text, - headers.clone(), - ) - .await - } - }) - .collect::>(), - ) - .await? 
- .into_iter() - .flatten() - .collect::>(); - - debug!(?generation_results); - Ok(GenerationWithDetectionResult { - generated_text: generation_results.generated_text.unwrap_or_default(), - input_token_count: generation_results.input_token_count, - detections, - }) - } - .instrument(Span::current()), - ); - match task_handle.await { - // Task completed successfully - Ok(Ok(result)) => Ok(result), - // Task failed, return error propagated from child task that failed - Ok(Err(error)) => { - error!(trace_id = ?task.trace_id, %error, "generation with detection unary task failed"); - Err(error) - } - // Task cancelled or panicked - Err(error) => { - let error = error.into(); - error!(trace_id = ?task.trace_id, %error, "generation with detection unary task failed"); - Err(error) - } - } - } - - /// Handles detection on textual content - #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] - pub async fn handle_text_content_detection( - &self, - task: TextContentDetectionTask, - ) -> Result { - info!("handling text content detection task"); - - let ctx = self.ctx.clone(); - let headers = task.headers; - - let task_handle = tokio::spawn( - async move { - let content = task.content.clone(); - // No masking applied, so offset change is 0 - let offset: usize = 0; - let text_with_offsets = [(offset, content)].to_vec(); - - let detectors = task.detectors.clone(); - - let chunker_ids = get_chunker_ids(&ctx, &detectors)?; - let chunks = chunk_task(&ctx, chunker_ids, text_with_offsets).await?; - - // Call detectors - let mut detections = try_join_all( - task.detectors - .iter() - .map(|(detector_id, detector_params)| { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_params = detector_params.clone(); - let detector_config = - ctx.config.detectors.get(&detector_id).unwrap_or_else(|| { - panic!("detector config not found for {}", detector_id) - }); - - let chunker_id = detector_config.chunker_id.as_str(); - - let 
default_threshold = detector_config.default_threshold; - - let chunk = chunks - .get(chunker_id) - .unwrap_or_else(|| panic!("chunk not found for {}", chunker_id)) - .clone(); - - let headers = headers.clone(); - - async move { - detect_content( - ctx, - detector_id, - default_threshold, - detector_params, - chunk, - headers, - ) - .await - } - }) - .collect::>(), - ) - .await? - .into_iter() - .flatten() - .collect::>(); - - detections.sort_by_key(|r| r.start); - // Send result with detections - Ok(TextContentDetectionResult { detections }) - } - .instrument(Span::current()), - ); - match task_handle.await { - // Task completed successfully - Ok(Ok(result)) => Ok(result), - // Task failed, return error propagated from child task that failed - Ok(Err(error)) => { - error!(trace_id = ?task.trace_id, %error, "text content detection task failed"); - Err(error) - } - // Task cancelled or panicked - Err(error) => { - let error = error.into(); - error!(trace_id = ?task.trace_id, %error, "text content detection task failed"); - Err(error) - } - } - } - - /// Handles context-related detections on textual content - #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] - pub async fn handle_context_documents_detection( - &self, - task: ContextDocsDetectionTask, - ) -> Result { - info!( - detectors = ?task.detectors, - "handling context documents detection task" - ); - let ctx = self.ctx.clone(); - let headers = task.headers; - let task_handle = tokio::spawn( - async move { - // call detection - let detections = try_join_all( - task.detectors - .iter() - .map(|(detector_id, detector_params)| { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_params = detector_params.clone(); - let content = task.content.clone(); - let context_type = task.context_type.clone(); - let context = task.context.clone(); - let headers = headers.clone(); - - async { - detect_for_context( - ctx, - detector_id, - detector_params, - content, - 
context_type, - context, - headers, - ) - .await - } - }) - .collect::>(), - ) - .await? - .into_iter() - .flatten() - .collect::>(); - - Ok(ContextDocsResult { detections }) - } - .instrument(Span::current()), - ); - match task_handle.await { - // Task completed successfully - Ok(Ok(result)) => Ok(result), - // Task failed, return error propagated from child task that failed - Ok(Err(error)) => { - error!(trace_id = ?task.trace_id, %error, "context documents detection task failed"); - Err(error) - } - // Task cancelled or panicked - Err(error) => { - let error = error.into(); - error!(trace_id = ?task.trace_id, %error, "context documents detection task failed"); - Err(error) - } - } - } - - /// Handles detections on generated text (without performing generation) - #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] - pub async fn handle_generated_text_detection( - &self, - task: DetectionOnGenerationTask, - ) -> Result { - info!( - detectors = ?task.detectors, - "handling detection on generated content task" - ); - let ctx = self.ctx.clone(); - let headers = task.headers; - - let task_handle = tokio::spawn( - async move { - // call detection - let detections = try_join_all( - task.detectors - .iter() - .map(|(detector_id, detector_params)| { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_params = detector_params.clone(); - let prompt = task.prompt.clone(); - let generated_text = task.generated_text.clone(); - let headers = headers.clone(); - async { - detect_for_generation( - ctx, - detector_id, - detector_params, - prompt, - generated_text, - headers, - ) - .await - } - }) - .collect::>(), - ) - .await? 
- .into_iter() - .flatten() - .collect::>(); - - Ok(DetectionOnGenerationResult { detections }) - } - .instrument(Span::current()), - ); - match task_handle.await { - // Task completed successfully - Ok(Ok(result)) => Ok(result), - // Task failed, return error propagated from child task that failed - Ok(Err(error)) => { - error!(trace_id = ?task.trace_id, %error, "detection on generated content task failed"); - Err(error) - } - // Task cancelled or panicked - Err(error) => { - let error = error.into(); - error!(trace_id = ?task.trace_id, %error, "detection on generated content task failed"); - Err(error) - } - } - } - - /// Handles detections on chat messages (without performing generation) - #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))] - pub async fn handle_chat_detection( - &self, - task: ChatDetectionTask, - ) -> Result { - info!( - detectors = ?task.detectors, - "handling detection on chat content task" - ); - let ctx = self.ctx.clone(); - let headers = task.headers; - - let task_handle = tokio::spawn( - async move { - // call detection - let detections = try_join_all( - task.detectors - .iter() - .map(|(detector_id, detector_params)| { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_params = detector_params.clone(); - let messages = task.messages.clone(); - let tools = task.tools.clone(); - let headers = headers.clone(); - async { - detect_for_chat( - ctx, - detector_id, - detector_params, - messages, - tools, - headers, - ) - .await - } - }) - .collect::>(), - ) - .await? 
- .into_iter() - .flatten() - .collect::>(); - - Ok(ChatDetectionResult { detections }) - } - .instrument(Span::current()), - ); - match task_handle.await { - // Task completed successfully - Ok(Ok(result)) => Ok(result), - // Task failed, return error propagated from child task that failed - Ok(Err(error)) => { - error!(%error, "detection task on chat failed"); - Err(error) - } - // Task cancelled or panicked - Err(error) => { - let error = error.into(); - error!(%error, "detection task on chat failed"); - Err(error) - } - } - } -} - -/// Handles input detection task. -#[instrument(skip_all)] -pub async fn input_detection_task( - ctx: &Arc, - detectors: &HashMap, - input_text: String, - masks: Option<&[(usize, usize)]>, - headers: HeaderMap, -) -> Result>, Error> { - debug!(?detectors, "starting input detection"); - let text_with_offsets = apply_masks(input_text, masks); - let chunker_ids = get_chunker_ids(ctx, detectors)?; - let chunks = chunk_task(ctx, chunker_ids, text_with_offsets).await?; - let detections = detection_task(ctx, detectors, chunks, headers).await?; - Ok((!detections.is_empty()).then_some(detections)) -} - -/// Handles output detection task. -#[instrument(skip_all)] -async fn output_detection_task( - ctx: &Arc, - detectors: &HashMap, - generated_text: String, - headers: HeaderMap, -) -> Result>, Error> { - debug!(detectors = ?detectors.keys(), "starting output detection"); - let text_with_offsets = apply_masks(generated_text, None); - let chunker_ids = get_chunker_ids(ctx, detectors)?; - let chunks = chunk_task(ctx, chunker_ids, text_with_offsets).await?; - let detections = detection_task(ctx, detectors, chunks, headers).await?; - Ok((!detections.is_empty()).then_some(detections)) -} - -/// Handles detection task. 
-#[instrument(skip_all)] -async fn detection_task( - ctx: &Arc, - detectors: &HashMap, - chunks: HashMap>, - headers: HeaderMap, -) -> Result, Error> { - debug!(detectors = ?detectors.keys(), "handling detection tasks"); - // Spawn tasks for each detector - let tasks = detectors - .iter() - .map(|(detector_id, detector_params)| { - let ctx = ctx.clone(); - let detector_id = detector_id.clone(); - let detector_params = detector_params.clone(); - // Get the detector config - let detector_config = ctx - .config - .detectors - .get(&detector_id) - .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))?; - // Get the default threshold to use if threshold is not provided by the user - let default_threshold = detector_config.default_threshold; - // Get chunker for detector - let chunker_id = detector_config.chunker_id.as_str(); - let chunks = chunks.get(chunker_id).unwrap().clone(); - let headers = headers.clone(); - Ok(tokio::spawn( - async move { - detect( - ctx, - detector_id, - default_threshold, - detector_params, - chunks, - headers, - ) - .await - } - .instrument(Span::current()), - )) - }) - .collect::, Error>>()?; - let results = try_join_all(tasks) - .await? - .into_iter() - .collect::, Error>>()? - .into_iter() - .flatten() - .collect::>(); - Ok(results) -} - -/// Handles chunk task. -#[instrument(skip_all)] -pub async fn chunk_task( - ctx: &Arc, - chunker_ids: Vec, - text_with_offsets: Vec<(usize, String)>, -) -> Result>, Error> { - debug!(?chunker_ids, "handling chunk task"); - // Spawn tasks for each chunker - let tasks = chunker_ids - .into_iter() - .map(|chunker_id| { - let ctx = ctx.clone(); - let text_with_offsets = text_with_offsets.clone(); - tokio::spawn( - async move { chunk_parallel(&ctx, chunker_id, text_with_offsets).await } - .instrument(Span::current()), - ) - }) - .collect::>(); - let results = try_join_all(tasks) - .await? 
- .into_iter() - .collect::, Error>>()?; - Ok(results) -} - -/// Sends a request to a detector service and applies threshold. -#[instrument(skip_all, fields(detector_id))] -pub async fn detect( - ctx: Arc, - detector_id: String, - default_threshold: f64, - mut detector_params: DetectorParams, - chunks: Vec, - headers: HeaderMap, -) -> Result, Error> { - let detector_id = detector_id.clone(); - let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); - let contents: Vec<_> = chunks.iter().map(|chunk| chunk.text.clone()).collect(); - let response = if contents.is_empty() { - // skip detector call as contents is empty - Vec::default() - } else { - let request = ContentAnalysisRequest::new(contents, detector_params); - debug!(?request, "sending detector request"); - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); - client - .text_contents(&detector_id, request, headers) - .await - .map_err(|error| { - debug!(?error, "error received from detector"); - Error::DetectorRequestFailed { - id: detector_id.clone(), - error, - } - })? - }; - debug!(?response, "received detector response"); - if chunks.len() != response.len() { - return Err(Error::Other(format!( - "Detector {detector_id} did not return expected number of responses" - ))); - } - let results = chunks - .into_iter() - .zip(response) - .flat_map(|(chunk, response)| { - response - .into_iter() - .filter_map(|resp| { - let mut result: TokenClassificationResult = resp.into(); - // add detector_id - result.detector_id = Some(detector_id.clone()); - result.start += chunk.offset as u32; - result.end += chunk.offset as u32; - (result.score >= threshold).then_some(result) - }) - .collect::>() - }) - .collect::>(); - Ok::, Error>(results) -} - -/// Sends a request to a detector service and applies threshold. 
-/// TODO: Cleanup by removing duplicate code and merging it with above `detect` function -#[instrument(skip_all, fields(detector_id))] -pub async fn detect_content( - ctx: Arc, - detector_id: String, - default_threshold: f64, - mut detector_params: DetectorParams, - chunks: Vec, - headers: HeaderMap, -) -> Result, Error> { - let detector_id = detector_id.clone(); - let threshold = detector_params.pop_threshold().unwrap_or(default_threshold); - let contents: Vec<_> = chunks.iter().map(|chunk| chunk.text.clone()).collect(); - let response = if contents.is_empty() { - // skip detector call as contents is empty - Vec::default() - } else { - let request = ContentAnalysisRequest::new(contents, detector_params); - debug!(?request, threshold, "sending detector request"); - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); - client - .text_contents(&detector_id, request, headers) - .await - .map_err(|error| { - debug!(?error, "error received from detector"); - Error::DetectorRequestFailed { - id: detector_id.clone(), - error, - } - })? 
- }; - debug!(%detector_id, ?response, "received detector response"); - if chunks.len() != response.len() { - return Err(Error::Other(format!( - "Detector {detector_id} did not return expected number of responses" - ))); - } - let results = chunks - .into_iter() - .zip(response) - .flat_map(|(chunk, response)| { - response - .into_iter() - .filter_map(|mut resp| { - resp.start += chunk.offset; - resp.end += chunk.offset; - // add detector_id - resp.detector_id = Some(detector_id.clone()); - (resp.score >= threshold).then_some(resp) - }) - .collect::>() - }) - .collect::>(); - Ok::, Error>(results) -} - -/// Calls a detector that implements the /api/v1/text/generation endpoint -#[instrument(skip_all, fields(detector_id))] -pub async fn detect_for_generation( - ctx: Arc, - detector_id: String, - mut detector_params: DetectorParams, - prompt: String, - generated_text: String, - headers: HeaderMap, -) -> Result, Error> { - let detector_id = detector_id.clone(); - let threshold = detector_params.pop_threshold().unwrap_or( - detector_params.pop_threshold().unwrap_or( - ctx.config - .detectors - .get(&detector_id) - .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))? 
- .default_threshold, - ), - ); - let request = - GenerationDetectionRequest::new(prompt.clone(), generated_text.clone(), detector_params); - debug!(threshold, ?request, "sending generation detector request"); - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap_or_else(|| { - panic!( - "text generation detector client not found for {}", - detector_id - ) - }); - let response = client - .text_generation(&detector_id, request, headers) - .await - .map(|results| { - results - .into_iter() - .filter(|detection| detection.score > threshold) - .map(|mut detection| { - // add detector_id - detection.detector_id = Some(detector_id.clone()); - detection - }) - .collect() - }) - .map_err(|error| Error::DetectorRequestFailed { - id: detector_id.clone(), - error, - })?; - debug!(?response, "received generation detector response"); - Ok::, Error>(response) -} - -/// Calls a detector that implements the /api/v1/text/chat endpoint -pub async fn detect_for_chat( - ctx: Arc, - detector_id: String, - mut detector_params: DetectorParams, - messages: Vec, - tools: Vec, - headers: HeaderMap, -) -> Result, Error> { - let detector_id = detector_id.clone(); - let threshold = detector_params.pop_threshold().unwrap_or( - detector_params.pop_threshold().unwrap_or( - ctx.config - .detectors - .get(&detector_id) - .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))? 
- .default_threshold, - ), - ); - let request = ChatDetectionRequest::new(messages, tools, detector_params); - debug!(%detector_id, ?request, "sending chat detector request"); - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); - let response = client - .text_chat(&detector_id, request, headers) - .await - .map(|results| { - results - .into_iter() - .filter(|detection| detection.score > threshold) - .map(|mut detection| { - //add detector_id - detection.detector_id = Some(detector_id.clone()); - detection - }) - .collect() - }) - .map_err(|error| Error::DetectorRequestFailed { - id: detector_id.clone(), - error, - })?; - debug!(%detector_id, ?response, "received chat detector response"); - Ok::, Error>(response) -} - -/// Calls a detector that implements the /api/v1/text/doc endpoint -#[instrument(skip_all, fields(detector_id))] -pub async fn detect_for_context( - ctx: Arc, - detector_id: String, - mut detector_params: DetectorParams, - content: String, - context_type: ContextType, - context: Vec, - headers: HeaderMap, -) -> Result, Error> { - let detector_id = detector_id.clone(); - let threshold = detector_params.pop_threshold().unwrap_or( - detector_params.pop_threshold().unwrap_or( - ctx.config - .detectors - .get(&detector_id) - .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))? 
- .default_threshold, - ), - ); - let request = - ContextDocsDetectionRequest::new(content, context_type, context, detector_params.clone()); - debug!( - ?request, - threshold, - ?detector_params, - "sending context detector request" - ); - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap_or_else(|| { - panic!( - "text context doc detector client not found for {}", - detector_id - ) - }); - let response = client - .text_context_doc(&detector_id, request, headers) - .await - .map(|results| { - results - .into_iter() - .filter(|detection| detection.score > threshold) - .map(|mut detection| { - //add detector_id - detection.detector_id = Some(detector_id.clone()); - detection - }) - .collect() - }) - .map_err(|error| Error::DetectorRequestFailed { - id: detector_id.clone(), - error, - })?; - debug!(%detector_id, ?response, "received context detector response"); - Ok::, Error>(response) -} - -/// Sends request to chunker service. -#[instrument(skip_all, fields(chunker_id))] -pub async fn chunk( - ctx: &Arc, - chunker_id: String, - offset: usize, - text: String, -) -> Result, Error> { - let request = chunkers::ChunkerTokenizationTaskRequest { text }; - debug!(?request, offset, "sending chunk request"); - let response = if chunker_id == DEFAULT_CHUNKER_ID { - tokenize_whole_doc(request) - } else { - let client = ctx.clients.get_as::(&chunker_id).unwrap(); - client - .tokenization_task_predict(&chunker_id, request) - .await - .map_err(|error| Error::ChunkerRequestFailed { - id: chunker_id.clone(), - error, - })? - }; - - debug!(?response, "received chunker response"); - Ok(response - .results - .into_iter() - .map(|token| Chunk { - offset: offset + token.start as usize, - text: token.text, - }) - .collect::>()) -} - -/// Sends parallel requests to a chunker service. 
-#[instrument(skip_all, fields(chunker_id))] -pub async fn chunk_parallel( - ctx: &Arc, - chunker_id: String, - text_with_offsets: Vec<(usize, String)>, -) -> Result<(String, Vec), Error> { - debug!("sending parallel chunk requests"); - let chunks = stream::iter(text_with_offsets) - .map(|(offset, text)| { - let ctx = ctx.clone(); - let chunker_id = chunker_id.clone(); - async move { - let results = chunk(&ctx, chunker_id, offset, text).await?; - Ok::, Error>(results) - } - }) - .buffered(DEFAULT_STREAM_BUFFER_SIZE) - .collect::>() - .await - .into_iter() - .collect::, Error>>()? - .into_iter() - .flatten() - .collect::>(); - Ok((chunker_id, chunks)) -} - -/// Sends tokenize request to a generation service. -#[instrument(skip_all, fields(model_id))] -pub async fn tokenize( - ctx: &Arc, - model_id: String, - text: String, - headers: HeaderMap, -) -> Result<(u32, Vec), Error> { - debug!("sending tokenize request"); - let client = ctx - .clients - .get_as::("generation") - .unwrap(); - client - .tokenize(model_id.clone(), text, headers) - .await - .map_err(|error| Error::TokenizeRequestFailed { - id: model_id, - error, - }) -} - -/// Sends generate request to a generation service. 
-#[instrument(skip_all, fields(model_id))] -async fn generate( - ctx: &Arc, - model_id: String, - text: String, - params: Option, - headers: HeaderMap, -) -> Result { - debug!("sending generate request"); - let client = ctx - .clients - .get_as::("generation") - .unwrap(); - client - .generate(model_id.clone(), text, params, headers) - .await - .map_err(|error| Error::GenerateRequestFailed { - id: model_id, - error, - }) -} - -#[cfg(test)] -mod tests { - use hyper::{HeaderMap, StatusCode}; - - use super::*; - use crate::{ - clients::{ - self, ClientMap, GenerationClient, TgisClient, - detector::{ContentAnalysisResponse, GenerationDetectionRequest}, - }, - config::{DetectorConfig, OrchestratorConfig}, - models::{DetectionResult, EvidenceObj, FinishReason, Metadata, THRESHOLD_PARAM}, - pb::fmaas::{ - BatchedGenerationRequest, BatchedGenerationResponse, GenerationRequest, - GenerationResponse, StopReason, - }, - }; - - // Test for TGIS generation with default parameter - #[tokio::test] - async fn test_tgis_generate_with_default_params() { - let mut tgis_client = TgisClient::faux(); - - let sample_text = String::from("sample text"); - let text_gen_model_id = String::from("test-llm-id-1"); - - let generation_response = GenerationResponse { - text: String::from("sample response worked"), - stop_reason: StopReason::EosToken.into(), - stop_sequence: String::from("\n"), - generated_token_count: 3, - seed: 7, - ..Default::default() - }; - - let client_generation_response = BatchedGenerationResponse { - responses: [generation_response].to_vec(), - }; - - let expected_generate_req_args = BatchedGenerationRequest { - model_id: text_gen_model_id.clone(), - prefix_id: None, - requests: [GenerationRequest { - text: sample_text.clone(), - }] - .to_vec(), - params: None, - }; - - let expected_generate_response = ClassifiedGeneratedTextResult { - generated_text: Some(client_generation_response.responses[0].text.clone()), - finish_reason: Some(FinishReason::EosToken), - 
generated_token_count: Some(3), - seed: Some(7), - ..Default::default() - }; - - // Construct a behavior for the mock object - faux::when!(tgis_client.generate(expected_generate_req_args, HeaderMap::new())) - .once() // TODO: Add with_args - .then_return(Ok(client_generation_response)); - - let generation_client = GenerationClient::tgis(tgis_client.clone()); - - let mut clients = ClientMap::new(); - clients.insert("generation".into(), generation_client); - let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); - - // Test request formulation and response processing is as expected - assert_eq!( - generate(&ctx, text_gen_model_id, sample_text, None, HeaderMap::new()) - .await - .unwrap(), - expected_generate_response - ); - } - - /// This test checks if calls to detectors are being handled appropriately. - /// It receives an input of two chunks. The first sentence does not contain a - /// detection. The second one does. - /// - /// The idea behind this test case is to test that... - /// 1. offsets are calculated correctly. - /// 2. detections below the threshold are not returned to the client. - #[tokio::test] - async fn test_handle_detection_task() { - let generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut detector_client = TextContentsDetectorClient::faux(); - - let detector_id = "mocked_hap_detector"; - let threshold = 0.5; - // Input: "I don't like potatoes. I hate aliens."; - let first_sentence = "I don't like potatoes.".to_string(); - let second_sentence = "I hate aliens.".to_string(); - let mut detector_params = DetectorParams::new(); - detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); - let chunks = vec![ - Chunk { - offset: 0, - text: first_sentence.clone(), - }, - Chunk { - offset: 23, - text: second_sentence.clone(), - }, - ]; - - // Since only the second chunk has a detection, we only expect one detection in the output. 
- let expected_response: Vec = vec![TokenClassificationResult { - start: 23, - end: 37, - word: second_sentence.clone(), - entity: "has_HAP".to_string(), - entity_group: "hap".to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.9, - token_count: None, - }]; - - faux::when!(detector_client.text_contents( - detector_id, - ContentAnalysisRequest::new( - vec![first_sentence.clone(), second_sentence.clone()], - DetectorParams::new() - ), - HeaderMap::new(), - )) - .once() - .then_return(Ok(vec![ - vec![ContentAnalysisResponse { - start: 0, - end: 22, - text: first_sentence.clone(), - detection: "has_HAP".to_string(), - detection_type: "hap".to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.1, - evidence: Some(vec![]), - metadata: Metadata::new(), - }], - vec![ContentAnalysisResponse { - start: 0, - end: 14, - text: second_sentence.clone(), - detection: "has_HAP".to_string(), - detection_type: "hap".to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.9, - evidence: Some(vec![]), - metadata: Metadata::new(), - }], - ])); - - let mut clients = ClientMap::new(); - clients.insert("generation".into(), generation_client); - clients.insert(detector_id.into(), detector_client); - let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); - - assert_eq!( - detect( - ctx, - detector_id.to_string(), - threshold, - detector_params, - chunks, - HeaderMap::new(), - ) - .await - .unwrap(), - expected_response - ); - } - - /// This test checks if calls to detectors returning 503 are being propagated in the orchestrator response. 
- #[tokio::test] - async fn test_detect_when_detector_returns_503() { - let generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut detector_client = TextContentsDetectorClient::faux(); - - let detector_id = "mocked_503_detector"; - let sentence = "This call will return a 503.".to_string(); - let threshold = 0.5; - let mut detector_params = DetectorParams::new(); - detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); - let chunks = vec![Chunk { - offset: 0, - text: sentence.clone(), - }]; - - // We expect the detector call to return a 503, with a response complying with the error response. - let expected_response = Error::DetectorRequestFailed { - id: detector_id.to_string(), - error: clients::Error::Http { - code: StatusCode::SERVICE_UNAVAILABLE, - message: "Service Unavailable".to_string(), - }, - }; - - faux::when!(detector_client.text_contents( - detector_id, - ContentAnalysisRequest::new(vec![sentence.clone()], DetectorParams::new()), - HeaderMap::new(), - )) - .once() - .then_return(Err(clients::Error::Http { - code: StatusCode::SERVICE_UNAVAILABLE, - message: "Service Unavailable".to_string(), - })); - - let mut clients = ClientMap::new(); - clients.insert("generation".into(), generation_client); - clients.insert(detector_id.into(), detector_client); - let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); - - assert_eq!( - detect( - ctx, - detector_id.to_string(), - threshold, - detector_params, - chunks, - HeaderMap::new(), - ) - .await - .unwrap_err(), - expected_response - ); - } - - #[tokio::test] - async fn test_handle_detection_task_with_whitespace() { - let generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut detector_client = TextContentsDetectorClient::faux(); - - let detector_id = "mocked_hap_detector"; - let threshold = 0.5; - let first_sentence = "".to_string(); - let mut detector_params = DetectorParams::new(); - detector_params.insert(THRESHOLD_PARAM.into(), 
threshold.into()); - let chunks = vec![Chunk { - offset: 0, - text: first_sentence.clone(), - }]; - - faux::when!(detector_client.text_contents( - detector_id, - ContentAnalysisRequest::new(vec![first_sentence.clone()], DetectorParams::new()), - HeaderMap::new(), - )) - .once() - .then_return(Ok(vec![vec![]])); - - let mut clients = ClientMap::new(); - clients.insert("generation".into(), generation_client); - clients.insert(detector_id.into(), detector_client); - let ctx = Arc::new(Context::new(OrchestratorConfig::default(), clients)); - - let expected_response_whitespace = vec![]; - assert_eq!( - detect( - ctx, - detector_id.to_string(), - threshold, - detector_params, - chunks, - HeaderMap::new(), - ) - .await - .unwrap(), - expected_response_whitespace - ); - } - - #[tokio::test] - async fn test_detect_for_generation() { - let generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut detector_client = TextGenerationDetectorClient::faux(); - - let detector_id = "mocked_answer_relevance_detector"; - let threshold = 0.5; - let prompt = "What is the capital of Brazil?".to_string(); - let generated_text = "The capital of Brazil is Brasilia.".to_string(); - let mut detector_params = DetectorParams::new(); - detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); - - let expected_response: Vec = vec![DetectionResult { - detection_type: "relevance".to_string(), - detection: "is_relevant".to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.9, - evidence: Some( - [EvidenceObj { - name: "relevant chunk".into(), - value: Some("What is capital of Brazil".into()), - score: Some(0.99), - evidence: None, - }] - .to_vec(), - ), - metadata: Metadata::new(), - }]; - - faux::when!(detector_client.text_generation( - detector_id, - GenerationDetectionRequest::new( - prompt.clone(), - generated_text.clone(), - DetectorParams::new() - ), - HeaderMap::new(), - )) - .once() - .then_return(Ok(vec![DetectionResult { - detection_type: 
"relevance".to_string(), - detection: "is_relevant".to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.9, - evidence: Some( - [EvidenceObj { - name: "relevant chunk".into(), - value: Some("What is capital of Brazil".into()), - score: Some(0.99), - evidence: None, - }] - .to_vec(), - ), - metadata: Metadata::new(), - }])); - - let mut clients = ClientMap::new(); - clients.insert("generation".into(), generation_client); - clients.insert(detector_id.into(), detector_client); - let mut ctx = Context::new(OrchestratorConfig::default(), clients); - // add detector - ctx.config.detectors.insert( - detector_id.to_string(), - DetectorConfig { - ..Default::default() - }, - ); - let ctx = Arc::new(ctx); - - assert_eq!( - detect_for_generation( - ctx, - detector_id.to_string(), - detector_params, - prompt, - generated_text, - HeaderMap::new(), - ) - .await - .unwrap(), - expected_response - ); - } - - #[tokio::test] - async fn test_detect_for_generation_below_threshold() { - let generation_client = GenerationClient::tgis(TgisClient::faux()); - let mut detector_client = TextGenerationDetectorClient::faux(); - - let detector_id = "mocked_answer_relevance_detector"; - let threshold = 0.5; - let prompt = "What is the capital of Brazil?".to_string(); - let generated_text = - "The most beautiful places can be found in Rio de Janeiro.".to_string(); - let mut detector_params = DetectorParams::new(); - detector_params.insert(THRESHOLD_PARAM.into(), threshold.into()); - - let expected_response: Vec = vec![]; - - faux::when!(detector_client.text_generation( - detector_id, - GenerationDetectionRequest::new( - prompt.clone(), - generated_text.clone(), - DetectorParams::new() - ), - HeaderMap::new(), - )) - .once() - .then_return(Ok(vec![DetectionResult { - detection_type: "relevance".to_string(), - detection: "is_relevant".to_string(), - detector_id: Some(detector_id.to_string()), - score: 0.1, - evidence: None, - metadata: Metadata::new(), - }])); - - let mut clients = 
ClientMap::new(); - clients.insert("generation".into(), generation_client); - clients.insert(detector_id.into(), detector_client); - let mut ctx = Context::new(OrchestratorConfig::default(), clients); - // add detector - ctx.config.detectors.insert( - detector_id.to_string(), - DetectorConfig { - ..Default::default() - }, - ); - let ctx = Arc::new(ctx); - - assert_eq!( - detect_for_generation( - ctx, - detector_id.to_string(), - detector_params, - prompt, - generated_text, - HeaderMap::new(), - ) - .await - .unwrap(), - expected_response - ); - } -} diff --git a/src/server.rs b/src/server.rs index b798db87..10be5e05 100644 --- a/src/server.rs +++ b/src/server.rs @@ -50,7 +50,7 @@ use tokio_rustls::TlsAcceptor; use tokio_stream::wrappers::ReceiverStream; use tower::Service; use tower_http::trace::TraceLayer; -use tracing::{Span, debug, error, info, instrument, warn}; +use tracing::{Span, debug, error, info, warn}; use tracing_opentelemetry::OpenTelemetrySpanExt; use webpki::types::{CertificateDer, PrivateKeyDer}; @@ -58,10 +58,8 @@ use crate::{ clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse}, models::{self, InfoParams, InfoResponse, StreamingContentDetectionRequest}, orchestrator::{ - self, ChatCompletionsDetectionTask, ChatDetectionTask, ClassificationWithGenTask, - ContextDocsDetectionTask, DetectionOnGenerationTask, GenerationWithDetectionTask, - Orchestrator, StreamingClassificationWithGenTask, StreamingContentDetectionTask, - TextContentDetectionTask, + self, Orchestrator, + handlers::{chat_completions_detection::ChatCompletionsDetectionTask, *}, }, utils, }; @@ -340,28 +338,21 @@ async fn info( Ok(Json(InfoResponse { services })) } -#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn classification_with_gen( State(state): State>, headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling 
request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = ClassificationWithGenTask::new(trace_id, request, headers); - match state - .orchestrator - .handle_classification_with_gen(task) - .await - { + match state.orchestrator.handle(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), } } -#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn generation_with_detection( State(state): State>, headers: HeaderMap, @@ -371,28 +362,21 @@ async fn generation_with_detection( >, ) -> Result { let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = GenerationWithDetectionTask::new(trace_id, request, headers); - match state - .orchestrator - .handle_generation_with_detection(task) - .await - { + match state.orchestrator.handle(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), } } -#[instrument(skip_all, fields(model_id = ?request.model_id))] async fn stream_classification_with_gen( State(state): State>, headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Sse>> { let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling request"); if let Err(error) = request.validate() { // Request validation failed, return stream with single error SSE event let error: Error = error.into(); @@ -406,10 +390,7 @@ async fn stream_classification_with_gen( } let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = StreamingClassificationWithGenTask::new(trace_id, request, headers); - let response_stream = state - .orchestrator - .handle_streaming_classification_with_gen(task) - .await; + let response_stream = 
state.orchestrator.handle(task).await.unwrap(); // Convert response stream to a stream of SSE events let event_stream = response_stream .map(|message| match message { @@ -429,15 +410,12 @@ async fn stream_classification_with_gen( Sse::new(event_stream).keep_alive(KeepAlive::default()) } -#[instrument(skip_all)] async fn stream_content_detection( State(state): State>, headers: HeaderMap, json_lines: JsonLines, ) -> Result { let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling content detection streaming request"); - // Validate the content-type from the header and ensure it is application/x-ndjson // If it's not, return a UnsupportedContentType error with the appropriate message let content_type = headers @@ -462,19 +440,17 @@ async fn stream_content_detection( } Err(error) => Err(orchestrator::errors::Error::Validation(error.to_string())), }) + .enumerate() .boxed(); // Create task and submit to handler let task = StreamingContentDetectionTask::new(trace_id, headers, input_stream); - let mut response_stream = state - .orchestrator - .handle_streaming_content_detection(task) - .await; + let mut response_stream = state.orchestrator.handle(task).await?; // Create output stream // This stream returns ND-JSON formatted messages to the client // StreamingContentDetectionResponse / server::Error - let (output_tx, output_rx) = mpsc::channel::>(32); + let (output_tx, output_rx) = mpsc::channel::>(128); let output_stream = ReceiverStream::new(output_rx); // Spawn task to consume response stream (typed) and send to output stream (json) @@ -499,7 +475,6 @@ async fn stream_content_detection( Ok(Response::new(axum::body::Body::from_stream(output_stream))) } -#[instrument(skip_all)] async fn detection_content( State(state): State>, headers: HeaderMap, @@ -509,38 +484,30 @@ async fn detection_content( >, ) -> Result { let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling request"); 
request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = TextContentDetectionTask::new(trace_id, request, headers); - match state.orchestrator.handle_text_content_detection(task).await { + match state.orchestrator.handle(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), } } -#[instrument(skip_all)] async fn detect_context_documents( State(state): State>, headers: HeaderMap, WithRejection(Json(request), _): WithRejection, Error>, ) -> Result { let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = ContextDocsDetectionTask::new(trace_id, request, headers); - match state - .orchestrator - .handle_context_documents_detection(task) - .await - { + match state.orchestrator.handle(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), } } -#[instrument(skip_all)] async fn detect_chat( State(state): State>, headers: HeaderMap, @@ -550,13 +517,12 @@ async fn detect_chat( request.validate_for_text()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = ChatDetectionTask::new(trace_id, request, headers); - match state.orchestrator.handle_chat_detection(task).await { + match state.orchestrator.handle(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), } } -#[instrument(skip_all)] async fn detect_generated( State(state): State>, headers: HeaderMap, @@ -566,21 +532,15 @@ async fn detect_generated( >, ) -> Result { let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling request"); request.validate()?; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = 
DetectionOnGenerationTask::new(trace_id, request, headers); - match state - .orchestrator - .handle_generated_text_detection(task) - .await - { + match state.orchestrator.handle(task).await { Ok(response) => Ok(Json(response).into_response()), Err(error) => Err(error.into()), } } -#[instrument(skip_all)] async fn chat_completions_detection( State(state): State>, headers: HeaderMap, @@ -588,14 +548,9 @@ async fn chat_completions_detection( ) -> Result { use ChatCompletionsResponse::*; let trace_id = Span::current().context().span().span_context().trace_id(); - info!(?trace_id, "handling request"); let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = ChatCompletionsDetectionTask::new(trace_id, request, headers); - match state - .orchestrator - .handle_chat_completions_detection(task) - .await - { + match state.orchestrator.handle(task).await { Ok(response) => match response { Unary(response) => Ok(Json(response).into_response()), Streaming(response_rx) => { diff --git a/src/utils/trace.rs b/src/utils/trace.rs index e55278f2..334f1198 100644 --- a/src/utils/trace.rs +++ b/src/utils/trace.rs @@ -230,7 +230,7 @@ pub fn init_tracing( pub fn incoming_request_span(request: &Request) -> Span { info_span!( - "incoming_orchestrator_http_request", + "request", request_method = request.method().to_string(), request_path = request.uri().path().to_string(), response_status_code = tracing::field::Empty, @@ -249,7 +249,7 @@ pub fn on_incoming_request(request: &Request, span: &Span) { method = %request.method(), path = %request.uri().path(), monotonic_counter.incoming_request_count = 1, - "started processing orchestrator request", + "started processing request", ); } @@ -266,7 +266,7 @@ pub fn on_outgoing_response(response: &Response, latency: Duration, span: &Span) duration_ms = %latency.as_millis(), monotonic_counter.handled_request_count = 1, histogram.service_request_duration = latency.as_millis() as u64, - "finished processing 
orchestrator request" + "finished processing request" ); if response.status().is_server_error() { @@ -299,7 +299,7 @@ pub fn on_outgoing_eos(_trailers: Option<&HeaderMap>, stream_duration: Duration, monotonic_counter.service_stream_response_count = 1, histogram.service_stream_response_duration = stream_duration.as_millis() as u64, stream_duration = stream_duration.as_millis(), - "end of orchestrator stream" + "end of stream" ); } diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_detection.rs index 079c0e30..44e8ee18 100644 --- a/tests/chat_completions_detection.rs +++ b/tests/chat_completions_detection.rs @@ -107,19 +107,10 @@ async fn no_detections() -> Result<(), anyhow::Error> { finish_reason: "EOS_TOKEN".to_string(), }, ]; - - let expected_detections = Some(ChatDetections { - input: vec![], - output: vec![], - }); - let chat_completions_response = ChatCompletion { model: MODEL_ID.into(), choices: expected_choices.clone(), - detections: Some(ChatDetections { - input: vec![], - output: vec![], - }), + detections: None, warnings: vec![], ..Default::default() }; @@ -175,14 +166,8 @@ async fn no_detections() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), - output: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), + input: HashMap::from([(detector_name.into(), DetectorParams::new())]), + output: HashMap::from([(detector_name.into(), DetectorParams::new())]), }), messages, ..Default::default() @@ -196,7 +181,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { assert_eq!(results.choices[0], chat_completions_response.choices[0]); assert_eq!(results.choices[1], chat_completions_response.choices[1]); assert_eq!(results.warnings, vec![]); - assert_eq!(results.detections, expected_detections); + assert!(results.detections.is_none()); Ok(()) } @@ -310,11 
+295,8 @@ async fn input_detections() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), - output: None, + input: HashMap::from([(detector_name.into(), DetectorParams::new())]), + output: HashMap::new(), }), messages, ..Default::default() @@ -475,11 +457,8 @@ async fn input_client_error() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), - output: None, + input: HashMap::from([(detector_name.into(), DetectorParams::new())]), + output: HashMap::new(), }), messages: messages_chunker_error.clone(), ..Default::default() @@ -497,11 +476,8 @@ async fn input_client_error() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), - output: None, + input: HashMap::from([(detector_name.into(), DetectorParams::new())]), + output: HashMap::new(), }), messages: messages_detector_error.clone(), ..Default::default() @@ -519,11 +495,8 @@ async fn input_client_error() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), - output: None, + input: HashMap::from([(detector_name.into(), DetectorParams::new())]), + output: HashMap::new(), }), messages: messages_chat_completions_error.clone(), ..Default::default() @@ -710,11 +683,8 @@ async fn output_detections() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: None, - output: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - 
)])), + input: HashMap::new(), + output: HashMap::from([(detector_name.into(), DetectorParams::new())]), }), messages, ..Default::default() @@ -902,11 +872,8 @@ async fn output_client_error() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: None, - output: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), + input: HashMap::new(), + output: HashMap::from([(detector_name.into(), DetectorParams::new())]), }), messages: messages_chunker_error.clone(), ..Default::default() @@ -924,11 +891,8 @@ async fn output_client_error() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: None, - output: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), + input: HashMap::new(), + output: HashMap::from([(detector_name.into(), DetectorParams::new())]), }), messages: messages_detector_error.clone(), ..Default::default() @@ -946,11 +910,8 @@ async fn output_client_error() -> Result<(), anyhow::Error> { .json(&ChatCompletionsRequest { model: MODEL_ID.into(), detectors: Some(DetectorConfig { - input: None, - output: Some(HashMap::from([( - detector_name.into(), - DetectorParams::new(), - )])), + input: HashMap::new(), + output: HashMap::from([(detector_name.into(), DetectorParams::new())]), }), messages: messages_chat_completions_error.clone(), ..Default::default() diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index cdcab601..5be53268 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -58,7 +58,6 @@ async fn no_detections() -> Result<(), anyhow::Error> { }, ]; let parameters = BTreeMap::from([("id".into(), "a".into()), ("type".into(), "b".into())]); - // tools are just passed through to the detector let tools = vec![Tool { r#type: "function".into(), function: ToolFunction { From 75c2d016acb5fc783acd14dcc7edc4aa158d808b Mon Sep 17 00:00:00 2001 
From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Wed, 16 Apr 2025 18:57:29 -0300 Subject: [PATCH 08/24] Drop faux (#374) Signed-off-by: Mateus Devino --- Cargo.lock | 71 ------------------------ Cargo.toml | 1 - src/clients/chunker.rs | 3 - src/clients/detector/text_chat.rs | 5 -- src/clients/detector/text_contents.rs | 5 -- src/clients/detector/text_context_doc.rs | 5 -- src/clients/detector/text_generation.rs | 5 -- src/clients/generation.rs | 3 - src/clients/nlp.rs | 3 - src/clients/openai.rs | 4 -- src/clients/tgis.rs | 36 ------------ 11 files changed, 141 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4c3fbb68..e33aecff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -524,41 +524,6 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" -[[package]] -name = "darling" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" -dependencies = [ - "darling_core", - "quote", - "syn", -] - [[package]] name = "data-encoding" version = "2.8.0" @@ -678,29 +643,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "faux" -version = "0.1.12" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a8e414cd04dc036003ccd2cc56492d5a7365e9d3f40b27c43606e42b54e5d1f" -dependencies = [ - "faux_macros", - "paste", -] - -[[package]] -name = "faux_macros" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc583ba3a3c259f1986c68e1501228013a34224b9ac97d17b30d7f017213b2a" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn", - "uuid", -] - [[package]] name = "fixedbitset" version = "0.5.7" @@ -719,7 +661,6 @@ dependencies = [ "bytes", "clap", "eventsource-stream", - "faux", "futures", "futures-util", "ginepro", @@ -1308,12 +1249,6 @@ dependencies = [ "syn", ] -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "1.0.3" @@ -1820,12 +1755,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - [[package]] name = "percent-encoding" version = "2.3.1" diff --git a/Cargo.toml b/Cargo.toml index bffb1bad..a8c9f9dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,7 +89,6 @@ tonic-build = "0.12.3" [dev-dependencies] axum-test = "17.1.0" -faux = "0.1.12" mocktail = { version = "0.2.4-alpha" } rand = "0.9.0" test-log = "0.2.17" diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 36904433..9b3c9974 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -50,14 +50,12 @@ pub const DEFAULT_CHUNKER_ID: &str = "whole_doc_chunker"; type StreamingTokenizationResult = Result>, Status>; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct ChunkerClient { client: ChunkersServiceClient>, health_client: HealthClient>, } -#[cfg_attr(test, faux::methods)] impl 
ChunkerClient { pub async fn new(config: &ServiceConfig) -> Self { let client = create_grpc_client(DEFAULT_PORT, config, ChunkersServiceClient::new).await; @@ -99,7 +97,6 @@ impl ChunkerClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for ChunkerClient { fn name(&self) -> &str { diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs index c763f3a9..4eea8bfc 100644 --- a/src/clients/detector/text_chat.rs +++ b/src/clients/detector/text_chat.rs @@ -34,14 +34,12 @@ use crate::{ const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat"; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TextChatDetectorClient { client: HttpClient, health_client: Option, } -#[cfg_attr(test, faux::methods)] impl TextChatDetectorClient { pub async fn new( config: &ServiceConfig, @@ -75,7 +73,6 @@ impl TextChatDetectorClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for TextChatDetectorClient { fn name(&self) -> &str { @@ -91,10 +88,8 @@ impl Client for TextChatDetectorClient { } } -#[cfg_attr(test, faux::methods)] impl DetectorClient for TextChatDetectorClient {} -#[cfg_attr(test, faux::methods)] impl HttpClientExt for TextChatDetectorClient { fn inner(&self) -> &HttpClient { self.client() diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs index 27943a86..91ec9f5d 100644 --- a/src/clients/detector/text_contents.rs +++ b/src/clients/detector/text_contents.rs @@ -32,14 +32,12 @@ use crate::{ const CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TextContentsDetectorClient { client: HttpClient, health_client: Option, } -#[cfg_attr(test, faux::methods)] impl TextContentsDetectorClient { pub async fn new( config: &ServiceConfig, @@ -73,7 +71,6 @@ impl TextContentsDetectorClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for TextContentsDetectorClient { fn name(&self) -> &str 
{ @@ -89,10 +86,8 @@ impl Client for TextContentsDetectorClient { } } -#[cfg_attr(test, faux::methods)] impl DetectorClient for TextContentsDetectorClient {} -#[cfg_attr(test, faux::methods)] impl HttpClientExt for TextContentsDetectorClient { fn inner(&self) -> &HttpClient { self.client() diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs index ee3a1f63..086520ae 100644 --- a/src/clients/detector/text_context_doc.rs +++ b/src/clients/detector/text_context_doc.rs @@ -30,14 +30,12 @@ use crate::{ const CONTEXT_DOC_DETECTOR_ENDPOINT: &str = "/api/v1/text/context/doc"; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TextContextDocDetectorClient { client: HttpClient, health_client: Option, } -#[cfg_attr(test, faux::methods)] impl TextContextDocDetectorClient { pub async fn new( config: &ServiceConfig, @@ -71,7 +69,6 @@ impl TextContextDocDetectorClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for TextContextDocDetectorClient { fn name(&self) -> &str { @@ -87,10 +84,8 @@ impl Client for TextContextDocDetectorClient { } } -#[cfg_attr(test, faux::methods)] impl DetectorClient for TextContextDocDetectorClient {} -#[cfg_attr(test, faux::methods)] impl HttpClientExt for TextContextDocDetectorClient { fn inner(&self) -> &HttpClient { self.client() diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs index 6ba6b82a..8de1c020 100644 --- a/src/clients/detector/text_generation.rs +++ b/src/clients/detector/text_generation.rs @@ -30,14 +30,12 @@ use crate::{ const GENERATION_DETECTOR_ENDPOINT: &str = "/api/v1/text/generation"; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TextGenerationDetectorClient { client: HttpClient, health_client: Option, } -#[cfg_attr(test, faux::methods)] impl TextGenerationDetectorClient { pub async fn new( config: &ServiceConfig, @@ -71,7 +69,6 @@ impl TextGenerationDetectorClient { } } -#[cfg_attr(test, 
faux::methods)] #[async_trait] impl Client for TextGenerationDetectorClient { fn name(&self) -> &str { @@ -87,10 +84,8 @@ impl Client for TextGenerationDetectorClient { } } -#[cfg_attr(test, faux::methods)] impl DetectorClient for TextGenerationDetectorClient {} -#[cfg_attr(test, faux::methods)] impl HttpClientExt for TextGenerationDetectorClient { fn inner(&self) -> &HttpClient { self.client() diff --git a/src/clients/generation.rs b/src/clients/generation.rs index 507864b6..15590623 100644 --- a/src/clients/generation.rs +++ b/src/clients/generation.rs @@ -38,7 +38,6 @@ use crate::{ }, }; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct GenerationClient(Option); @@ -48,7 +47,6 @@ enum GenerationClientInner { Nlp(NlpClient), } -#[cfg_attr(test, faux::methods)] impl GenerationClient { pub fn tgis(client: TgisClient) -> Self { Self(Some(GenerationClientInner::Tgis(client))) @@ -224,7 +222,6 @@ impl GenerationClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for GenerationClient { fn name(&self) -> &str { diff --git a/src/clients/nlp.rs b/src/clients/nlp.rs index f7d9006a..33145ebf 100644 --- a/src/clients/nlp.rs +++ b/src/clients/nlp.rs @@ -47,14 +47,12 @@ use crate::{ const DEFAULT_PORT: u16 = 8085; const MODEL_ID_HEADER_NAME: &str = "mm-model-id"; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct NlpClient { client: NlpServiceClient>, health_client: HealthClient>, } -#[cfg_attr(test, faux::methods)] impl NlpClient { pub async fn new(config: &ServiceConfig) -> Self { let client = create_grpc_client(DEFAULT_PORT, config, NlpServiceClient::new).await; @@ -133,7 +131,6 @@ impl NlpClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for NlpClient { fn name(&self) -> &str { diff --git a/src/clients/openai.rs b/src/clients/openai.rs index bd968474..05f15a63 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -40,14 +40,12 @@ const DEFAULT_PORT: u16 = 8080; const CHAT_COMPLETIONS_ENDPOINT: 
&str = "/v1/chat/completions"; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct OpenAiClient { client: HttpClient, health_client: Option, } -#[cfg_attr(test, faux::methods)] impl OpenAiClient { pub async fn new( config: &ServiceConfig, @@ -136,7 +134,6 @@ impl OpenAiClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for OpenAiClient { fn name(&self) -> &str { @@ -152,7 +149,6 @@ impl Client for OpenAiClient { } } -#[cfg_attr(test, faux::methods)] impl HttpClientExt for OpenAiClient { fn inner(&self) -> &HttpClient { self.client() diff --git a/src/clients/tgis.rs b/src/clients/tgis.rs index ff7f845b..12212d40 100644 --- a/src/clients/tgis.rs +++ b/src/clients/tgis.rs @@ -39,13 +39,11 @@ use crate::{ const DEFAULT_PORT: u16 = 8033; -#[cfg_attr(test, faux::create)] #[derive(Clone)] pub struct TgisClient { client: GenerationServiceClient>, } -#[cfg_attr(test, faux::methods)] impl TgisClient { pub async fn new(config: &ServiceConfig) -> Self { let client = create_grpc_client(DEFAULT_PORT, config, GenerationServiceClient::new).await; @@ -101,7 +99,6 @@ impl TgisClient { } } -#[cfg_attr(test, faux::methods)] #[async_trait] impl Client for TgisClient { fn name(&self) -> &str { @@ -134,36 +131,3 @@ impl Client for TgisClient { } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::pb::fmaas::model_info_response; - - #[tokio::test] - async fn test_model_info() { - // Initialize a mock object from `TgisClient` - let mut mock_client = TgisClient::faux(); - - let request = ModelInfoRequest { - model_id: "test-model-1".to_string(), - }; - - let expected_response = ModelInfoResponse { - max_sequence_length: 2, - max_new_tokens: 20, - max_beam_width: 3, - model_kind: model_info_response::ModelKind::DecoderOnly.into(), - max_beam_sequence_lengths: [].to_vec(), - }; - // Construct a behavior for the mock object - faux::when!(mock_client.model_info(request.clone())) - .once() - .then_return(Ok(expected_response.clone())); - // Test the mock 
object's behaviour - assert_eq!( - mock_client.model_info(request).await.unwrap(), - expected_response - ); - } -} From d6028922ac58b8e7782c091e4b6be59bc7f7b381 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Thu, 17 Apr 2025 10:25:55 -0600 Subject: [PATCH 09/24] :fire: Chunker client cleanup (#376) * :fire: Remove whole doc chunking from chunker client Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :fire: Remove unused imports Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/clients/chunker.rs | 111 ++--------------------------------------- 1 file changed, 3 insertions(+), 108 deletions(-) diff --git a/src/clients/chunker.rs b/src/clients/chunker.rs index 9b3c9974..05eed2c0 100644 --- a/src/clients/chunker.rs +++ b/src/clients/chunker.rs @@ -19,10 +19,10 @@ use std::pin::Pin; use async_trait::async_trait; use axum::http::HeaderMap; -use futures::{Future, Stream, StreamExt, TryStreamExt}; +use futures::{Future, StreamExt, TryStreamExt}; use ginepro::LoadBalancedChannel; use tonic::{Code, Request, Response, Status, Streaming}; -use tracing::{Span, instrument}; +use tracing::Span; use super::{ BoxStream, Client, Error, create_grpc_client, errors::grpc_to_http_code, @@ -36,7 +36,7 @@ use crate::{ BidiStreamingChunkerTokenizationTaskRequest, ChunkerTokenizationTaskRequest, chunkers_service_client::ChunkersServiceClient, }, - caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults}, + caikit_data_model::nlp::{ChunkerTokenizationStreamResult, TokenizationResults}, grpc::health::v1::{HealthCheckRequest, health_client::HealthClient}, }, utils::trace::trace_context_from_grpc_response, @@ -137,108 +137,3 @@ fn request_with_headers(request: T, model_id: &str) -> Request { .insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap()); request } - -/// Unary 
tokenization result of the entire doc -#[instrument(skip_all)] -pub fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults { - let codepoint_count = request.text.chars().count() as i64; - TokenizationResults { - results: vec![Token { - start: 0, - end: codepoint_count, - text: request.text, - }], - token_count: 1, // entire doc - } -} - -/// Streaming tokenization result for the entire doc stream -#[instrument(skip_all)] -pub async fn tokenize_whole_doc_stream( - request: impl Stream, -) -> Result { - let (text, index_vec): (String, Vec) = request - .map(|r| (r.text_stream, r.input_index_stream)) - .collect() - .await; - let codepoint_count = text.chars().count() as i64; - let input_end_index = index_vec.last().copied().unwrap_or_default(); - Ok(ChunkerTokenizationStreamResult { - results: vec![Token { - start: 0, - end: codepoint_count, - text, - }], - token_count: 1, // entire doc/stream - processed_index: codepoint_count, - start_index: 0, - input_start_index: 0, - input_end_index, - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_tokenize_whole_doc() { - let request = ChunkerTokenizationTaskRequest { - text: "Lorem ipsum dolor sit amet consectetur adipiscing \ - elit sed do eiusmod tempor incididunt ut labore et dolore \ - magna aliqua." - .into(), - }; - let expected_response = TokenizationResults { - results: vec![Token { - start: 0, - end: 121, - text: "Lorem ipsum dolor sit amet consectetur \ - adipiscing elit sed do eiusmod tempor incididunt \ - ut labore et dolore magna aliqua." 
- .into(), - }], - token_count: 1, - }; - let response = tokenize_whole_doc(request); - assert_eq!(response, expected_response) - } - - #[tokio::test] - async fn test_tokenize_whole_doc_stream() { - let request = futures::stream::iter(vec![ - BidiStreamingChunkerTokenizationTaskRequest { - text_stream: "Lorem ipsum dolor sit amet ".into(), - input_index_stream: 0, - }, - BidiStreamingChunkerTokenizationTaskRequest { - text_stream: "consectetur adipiscing elit ".into(), - input_index_stream: 1, - }, - BidiStreamingChunkerTokenizationTaskRequest { - text_stream: "sed do eiusmod tempor incididunt ".into(), - input_index_stream: 2, - }, - BidiStreamingChunkerTokenizationTaskRequest { - text_stream: "ut labore et dolore magna aliqua.".into(), - input_index_stream: 3, - }, - ]); - let expected_response = ChunkerTokenizationStreamResult { - results: vec![Token { - start: 0, - end: 121, - text: "Lorem ipsum dolor sit amet consectetur adipiscing elit \ - sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." 
- .into(), - }], - token_count: 1, - processed_index: 121, - start_index: 0, - input_start_index: 0, - input_end_index: 3, - }; - let response = tokenize_whole_doc_stream(request).await.unwrap(); - assert_eq!(response, expected_response); - } -} From bd79b421c0561634e94362e12c0d1fb0bb90e822 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:37:54 -0300 Subject: [PATCH 10/24] Guardrails request config validation (#371) * Validate guardrails for chat_completions unary Signed-off-by: Mateus Devino * Validate guardrails for chat_detection Signed-off-by: Mateus Devino * Validate guardrails for classification_with_text_gen Signed-off-by: Mateus Devino * Add non-existing detector tests for chat_completions Signed-off-by: Mateus Devino * Add non-existing detector tests for chat_detection Signed-off-by: Mateus Devino * Validate guardrails for context_docs_detection Signed-off-by: Mateus Devino * Validate guardrails for detection_on_generation Signed-off-by: Mateus Devino * Make non existing detector a constant Signed-off-by: Mateus Devino * Validate guardrails for generation_with_detection Signed-off-by: Mateus Devino * Validate guardrails for text_content_detection Signed-off-by: Mateus Devino * Add guardrails validation for streaming_classification_with_gen Signed-off-by: Mateus Devino * Add guardrails validation for streaming_content_detection Signed-off-by: Mateus Devino * Update logs Signed-off-by: Mateus Devino * Update src/utils.rs Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> * Update src/utils.rs Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> * Update src/utils.rs Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> * 
Update src/utils.rs Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> * Update src/utils.rs Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> * Apply changes requested Signed-off-by: Mateus Devino * Remove unneded conditionals Signed-off-by: Mateus Devino --------- Signed-off-by: Mateus Devino Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> --- src/config.rs | 2 +- src/orchestrator/common/utils.rs | 41 ++++ .../chat_completions_detection/unary.rs | 21 +- src/orchestrator/handlers/chat_detection.rs | 13 +- .../handlers/classification_with_gen.rs | 21 +- .../handlers/context_docs_detection.rs | 13 +- .../handlers/detection_on_generation.rs | 13 +- .../handlers/generation_with_detection.rs | 13 +- .../streaming_classification_with_gen.rs | 26 +- .../handlers/streaming_content_detection.rs | 14 +- .../handlers/text_content_detection.rs | 13 +- src/utils.rs | 1 - tests/chat_completions_detection.rs | 137 ++++++++++- tests/chat_detection.rs | 72 +++++- tests/classification_with_text_gen.rs | 134 ++++++++++- tests/common/detectors.rs | 3 + tests/context_docs_detection.rs | 79 +++++- tests/detection_on_generation.rs | 69 +++++- tests/generation_with_detection.rs | 70 +++++- tests/streaming_classification_with_gen.rs | 227 +++++++++++++++++- tests/streaming_content_detection.rs | 111 ++++++++- tests/test_config.yaml | 12 + tests/text_content_detection.rs | 61 ++++- 23 files changed, 1117 insertions(+), 49 deletions(-) diff --git a/src/config.rs b/src/config.rs index 4e37e07f..bfd06325 100644 --- a/src/config.rs +++ b/src/config.rs @@ -169,7 +169,7 @@ pub struct DetectorConfig { pub r#type: DetectorType, } -#[derive(Default, Clone, Debug, Deserialize)] +#[derive(Default, Clone, Debug, 
Deserialize, PartialEq)] #[serde(rename_all = "snake_case")] #[non_exhaustive] pub enum DetectorType { diff --git a/src/orchestrator/common/utils.rs b/src/orchestrator/common/utils.rs index 82a5d4a6..8e6521cd 100644 --- a/src/orchestrator/common/utils.rs +++ b/src/orchestrator/common/utils.rs @@ -16,7 +16,11 @@ */ use std::{collections::HashMap, sync::Arc}; +use tracing::error; + use crate::{ + clients::chunker::DEFAULT_CHUNKER_ID, + config::{DetectorConfig, DetectorType}, models::DetectorParams, orchestrator::{Context, Error}, }; @@ -112,6 +116,43 @@ pub fn configure_mock_servers( }; } +/// Validates guardrails on request. +pub fn validate_detectors( + detectors: &HashMap, + orchestrator_detectors: &HashMap, + allowed_detector_types: &[DetectorType], + allows_whole_doc_chunker: bool, +) -> Result<(), Error> { + let whole_doc_chunker_id = DEFAULT_CHUNKER_ID; + for detector_id in detectors.keys() { + // validate detectors + match orchestrator_detectors.get(detector_id) { + Some(detector_config) => { + if !allowed_detector_types.contains(&detector_config.r#type) { + let error = Error::Validation(format!( + "detector `{detector_id}` is not supported by this endpoint" + )); + error!("{error}"); + return Err(error); + } + if !allows_whole_doc_chunker && detector_config.chunker_id == whole_doc_chunker_id { + let error = Error::Validation(format!( + "detector `{detector_id}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint" + )); + error!("{error}"); + return Err(error); + } + } + None => { + let error = Error::DetectorNotFound(detector_id.clone()); + error!("{error}"); + return Err(error); + } + } + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/orchestrator/handlers/chat_completions_detection/unary.rs b/src/orchestrator/handlers/chat_completions_detection/unary.rs index 112862c0..a8ce51ff 100644 --- a/src/orchestrator/handlers/chat_completions_detection/unary.rs +++ 
b/src/orchestrator/handlers/chat_completions_detection/unary.rs @@ -23,10 +23,15 @@ use uuid::Uuid; use super::ChatCompletionsDetectionTask; use crate::{ clients::openai::*, + config::DetectorType, models::{ DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE, UNSUITABLE_OUTPUT_MESSAGE, }, - orchestrator::{Context, Error, common, types::ChatMessageIterator}, + orchestrator::{ + Context, Error, + common::{self, validate_detectors}, + types::ChatMessageIterator, + }, }; pub async fn handle_unary( @@ -39,7 +44,19 @@ pub async fn handle_unary( let input_detectors = detectors.input; let output_detectors = detectors.output; - // TODO: validate requested guardrails + validate_detectors( + &input_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + )?; + + validate_detectors( + &output_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + )?; if !input_detectors.is_empty() { // Handle input detection diff --git a/src/orchestrator/handlers/chat_detection.rs b/src/orchestrator/handlers/chat_detection.rs index 6ec10a0f..29b4fcf8 100644 --- a/src/orchestrator/handlers/chat_detection.rs +++ b/src/orchestrator/handlers/chat_detection.rs @@ -23,8 +23,12 @@ use tracing::{info, instrument}; use super::Handle; use crate::{ clients::openai, + config::DetectorType, models::{ChatDetectionHttpRequest, ChatDetectionResult, DetectorParams}, - orchestrator::{Error, Orchestrator, common}, + orchestrator::{ + Error, Orchestrator, + common::{self, validate_detectors}, + }, }; impl Handle for Orchestrator { @@ -40,7 +44,12 @@ impl Handle for Orchestrator { let trace_id = task.trace_id; info!(%trace_id, config = ?task.detectors, "task started"); - // TODO: validate requested guardrails + validate_detectors( + &task.detectors, + &ctx.config.detectors, + &[DetectorType::TextChat], + true, + )?; // Handle detection let detections = common::text_chat_detections( diff --git a/src/orchestrator/handlers/classification_with_gen.rs 
b/src/orchestrator/handlers/classification_with_gen.rs index f0d8e1fb..59d80dde 100644 --- a/src/orchestrator/handlers/classification_with_gen.rs +++ b/src/orchestrator/handlers/classification_with_gen.rs @@ -24,12 +24,16 @@ use tracing::{error, info, instrument}; use super::Handle; use crate::{ clients::GenerationClient, + config::DetectorType, models::{ ClassifiedGeneratedTextResult, DetectionWarning, DetectorParams, GuardrailsConfig, GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextGenTokenClassificationResults, }, - orchestrator::{Context, Error, Orchestrator, common}, + orchestrator::{ + Context, Error, Orchestrator, + common::{self, validate_detectors}, + }, }; impl Handle for Orchestrator { @@ -47,7 +51,20 @@ impl Handle for Orchestrator { let input_detectors = task.guardrails_config.input_detectors(); let output_detectors = task.guardrails_config.output_detectors(); - // TODO: validate requested guardrails + // input detectors validation + validate_detectors( + &input_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + )?; + // output detectors validation + validate_detectors( + &output_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + )?; if !input_detectors.is_empty() { // Handle input detection diff --git a/src/orchestrator/handlers/context_docs_detection.rs b/src/orchestrator/handlers/context_docs_detection.rs index 1fa9190b..8e751373 100644 --- a/src/orchestrator/handlers/context_docs_detection.rs +++ b/src/orchestrator/handlers/context_docs_detection.rs @@ -23,8 +23,12 @@ use tracing::{info, instrument}; use super::Handle; use crate::{ clients::detector::ContextType, + config::DetectorType, models::{ContextDocsHttpRequest, ContextDocsResult, DetectorParams}, - orchestrator::{Error, Orchestrator, common}, + orchestrator::{ + Error, Orchestrator, + common::{self, validate_detectors}, + }, }; impl Handle for Orchestrator { @@ -40,7 +44,12 @@ impl Handle for Orchestrator { let trace_id 
= task.trace_id; info!(%trace_id, config = ?task.detectors, "task started"); - // TODO: validate requested guardrails + validate_detectors( + &task.detectors, + &ctx.config.detectors, + &[DetectorType::TextContextDoc], + true, + )?; // Handle detection let detections = common::text_context_detections( diff --git a/src/orchestrator/handlers/detection_on_generation.rs b/src/orchestrator/handlers/detection_on_generation.rs index 0c563116..36af6070 100644 --- a/src/orchestrator/handlers/detection_on_generation.rs +++ b/src/orchestrator/handlers/detection_on_generation.rs @@ -22,8 +22,12 @@ use tracing::{info, instrument}; use super::Handle; use crate::{ + config::DetectorType, models::{DetectionOnGeneratedHttpRequest, DetectionOnGenerationResult, DetectorParams}, - orchestrator::{Error, Orchestrator, common}, + orchestrator::{ + Error, Orchestrator, + common::{self, validate_detectors}, + }, }; impl Handle for Orchestrator { @@ -39,7 +43,12 @@ impl Handle for Orchestrator { let trace_id = task.trace_id; info!(%trace_id, config = ?task.detectors, "task started"); - // TODO: validate requested guardrails + validate_detectors( + &task.detectors, + &ctx.config.detectors, + &[DetectorType::TextGeneration], + true, + )?; // Handle detection let detections = common::text_generation_detections( diff --git a/src/orchestrator/handlers/generation_with_detection.rs b/src/orchestrator/handlers/generation_with_detection.rs index e6e99463..d49ac76c 100644 --- a/src/orchestrator/handlers/generation_with_detection.rs +++ b/src/orchestrator/handlers/generation_with_detection.rs @@ -23,11 +23,15 @@ use tracing::{info, instrument}; use super::Handle; use crate::{ clients::GenerationClient, + config::DetectorType, models::{ DetectorParams, GenerationWithDetectionHttpRequest, GenerationWithDetectionResult, GuardrailsTextGenerationParameters, }, - orchestrator::{Error, Orchestrator, common}, + orchestrator::{ + Error, Orchestrator, + common::{self, validate_detectors}, + }, }; impl Handle 
for Orchestrator { @@ -43,7 +47,12 @@ impl Handle for Orchestrator { let trace_id = task.trace_id; info!(%trace_id, config = ?task.detectors, "task started"); - // TODO: validate requested guardrails + validate_detectors( + &task.detectors, + &ctx.config.detectors, + &[DetectorType::TextGeneration], + true, + )?; // Handle generation let client = ctx diff --git a/src/orchestrator/handlers/streaming_classification_with_gen.rs b/src/orchestrator/handlers/streaming_classification_with_gen.rs index b24b2e09..82427aff 100644 --- a/src/orchestrator/handlers/streaming_classification_with_gen.rs +++ b/src/orchestrator/handlers/streaming_classification_with_gen.rs @@ -30,13 +30,15 @@ use tracing::{Instrument, error, info, instrument}; use super::Handle; use crate::{ clients::GenerationClient, + config::DetectorType, models::{ ClassifiedGeneratedTextStreamResult, DetectionWarning, DetectorParams, GuardrailsConfig, GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextGenTokenClassificationResults, }, orchestrator::{ - Context, Error, Orchestrator, common, + Context, Error, Orchestrator, + common::{self, validate_detectors}, types::{ Chunk, DetectionBatchStream, DetectionStream, Detections, GenerationStream, MaxProcessedIndexBatcher, @@ -72,7 +74,27 @@ impl Handle for Orchestrator { let input_detectors = task.guardrails_config.input_detectors(); let output_detectors = task.guardrails_config.output_detectors(); - // TODO: validate requested guardrails + // input detectors validation + if let Err(error) = validate_detectors( + &input_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + false, + ) { + let _ = response_tx.send(Err(error)).await; + return; + } + + // output detectors validation + if let Err(error) = validate_detectors( + &output_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + false, + ) { + let _ = response_tx.send(Err(error)).await; + return; + } if !input_detectors.is_empty() { // Handle input detection diff 
--git a/src/orchestrator/handlers/streaming_content_detection.rs b/src/orchestrator/handlers/streaming_content_detection.rs index 32e41b55..97e143a3 100644 --- a/src/orchestrator/handlers/streaming_content_detection.rs +++ b/src/orchestrator/handlers/streaming_content_detection.rs @@ -25,9 +25,11 @@ use tracing::{Instrument, error, info, instrument}; use super::Handle; use crate::{ + config::DetectorType, models::{DetectorParams, StreamingContentDetectionRequest, StreamingContentDetectionResponse}, orchestrator::{ - Context, Error, Orchestrator, common, + Context, Error, Orchestrator, + common::{self, validate_detectors}, types::{BoxStream, DetectionBatchStream, DetectionStream, MaxProcessedIndexBatcher}, }, }; @@ -65,7 +67,15 @@ impl Handle for Orchestrator { }; info!(%trace_id, config = ?detectors, "task started"); - // TODO: validate requested guardrails + if let Err(error) = validate_detectors( + &detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + false, + ) { + let _ = response_tx.send(Err(error)).await; + return; + } handle_detection(ctx, trace_id, headers, detectors, input_stream, response_tx) .await; diff --git a/src/orchestrator/handlers/text_content_detection.rs b/src/orchestrator/handlers/text_content_detection.rs index a4773444..a8b02c66 100644 --- a/src/orchestrator/handlers/text_content_detection.rs +++ b/src/orchestrator/handlers/text_content_detection.rs @@ -22,8 +22,12 @@ use tracing::{info, instrument}; use super::Handle; use crate::{ + config::DetectorType, models::{DetectorParams, TextContentDetectionHttpRequest, TextContentDetectionResult}, - orchestrator::{Error, Orchestrator, common}, + orchestrator::{ + Error, Orchestrator, + common::{self, validate_detectors}, + }, }; impl Handle for Orchestrator { @@ -39,7 +43,12 @@ impl Handle for Orchestrator { let trace_id = task.trace_id; info!(%trace_id, config = ?task.detectors, "task started"); - // TODO: validate requested guardrails + validate_detectors( + &task.detectors, + 
&ctx.config.detectors, + &[DetectorType::TextContents], + true, + )?; // Handle detection let (_, detections) = common::text_contents_detections( diff --git a/src/utils.rs b/src/utils.rs index 3e74220c..7fd86dc2 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,5 @@ use hyper::Uri; use url::Url; - pub mod json; pub mod tls; pub mod trace; diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_detection.rs index 44e8ee18..1e10ed53 100644 --- a/tests/chat_completions_detection.rs +++ b/tests/chat_completions_detection.rs @@ -22,7 +22,8 @@ use common::{ chat_completions::CHAT_COMPLETIONS_ENDPOINT, chunker::CHUNKER_UNARY_ENDPOINT, detectors::{ - DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, + ANSWER_RELEVANCE_DETECTOR, DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, @@ -53,6 +54,7 @@ use fms_guardrails_orchestr8::{ use hyper::StatusCode; use mocktail::prelude::*; use test_log::test; +use tracing::debug; pub mod common; @@ -930,13 +932,20 @@ async fn output_client_error() -> Result<(), anyhow::Error> { #[test(tokio::test)] async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + // Start orchestrator server and its dependencies let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .build() .await?; - // Orchestrator request with non existing field + let messages = vec![Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }]; + + // Extra request field scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) .json(&serde_json::json!({ @@ -947,24 +956,134 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { detector_name: {} } }, - 
"messages": vec![Message { - content: Some(Content::Text("Hi there!".to_string())), - role: Role::User, - ..Default::default() - }], + "messages": &messages, "some_extra_field": "random value" })) .send() .await?; - // Assertions for invalid request let results = response.json::().await?; - assert_eq!(results.code, StatusCode::UNPROCESSABLE_ENTITY); + debug!("{results:#?}"); + assert_eq!( + results.code, + StatusCode::UNPROCESSABLE_ENTITY, + "failed on extra request field scenario" + ); assert!( results .details .starts_with("some_extra_field: unknown field `some_extra_field`") ); + // Invalid input detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: HashMap::from([(ANSWER_RELEVANCE_DETECTOR.into(), DetectorParams::new())]), + output: HashMap::new(), + }), + messages: messages.clone(), + ..Default::default() + }) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + ANSWER_RELEVANCE_DETECTOR + ) + }, + "failed on invalid input detector scenario" + ); + + // Non-existing input detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + output: HashMap::new(), + }), + messages: messages.clone(), + ..Default::default() + }) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing input detector scenario" + ); + + // 
Invalid output detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: HashMap::new(), + output: HashMap::from([(ANSWER_RELEVANCE_DETECTOR.into(), DetectorParams::new())]), + }), + messages: messages.clone(), + ..Default::default() + }) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + ANSWER_RELEVANCE_DETECTOR + ) + }, + "failed on invalid output detector scenario" + ); + + // Non-existing output detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&ChatCompletionsRequest { + model: MODEL_ID.into(), + detectors: Some(DetectorConfig { + input: HashMap::new(), + output: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + }), + messages: messages.clone(), + ..Default::default() + }) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing input detector scenario" + ); + Ok(()) } diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index 5be53268..2f00ee98 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -17,7 +17,10 @@ use std::collections::{BTreeMap, HashMap}; use common::{ - detectors::{CHAT_DETECTOR_ENDPOINT, PII_DETECTOR}, + detectors::{ + ANSWER_RELEVANCE_DETECTOR_SENTENCE, CHAT_DETECTOR_ENDPOINT, NON_EXISTING_DETECTOR, + PII_DETECTOR, + }, errors::{DetectorError, OrchestratorError}, orchestrator::{ ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, @@ -379,10 +382,77 @@ async fn 
orchestrator_validation_error() -> Result<(), anyhow::Error> { })) .send() .await?; + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); let response = response.json::().await?; + debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("Message content cannot be empty")); + // Asserts requests with detector with invalid type return 422 + let messages = vec![ + Message { + role: Role::User, + content: Some(Content::Text("Hi there!".into())), + ..Default::default() + }, + Message { + role: Role::Assistant, + content: Some(Content::Text("Hello!".into())), + ..Default::default() + }, + ]; + + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT) + .json(&ChatDetectionHttpRequest { + detectors: HashMap::from([( + ANSWER_RELEVANCE_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + messages: messages.clone(), + tools: vec![], + }) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + ANSWER_RELEVANCE_DETECTOR_SENTENCE + ) + }, + "failed on detector with invalid type scenario" + ); + + // Asserts requests with non-existing detector return 422 + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_DETECTION_ENDPOINT) + .json(&ChatDetectionHttpRequest { + detectors: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + messages: messages.clone(), + tools: vec![], + }) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing detector scenario" + ); + Ok(()) } diff --git 
a/tests/classification_with_text_gen.rs b/tests/classification_with_text_gen.rs index d73eff80..83f5aa6c 100644 --- a/tests/classification_with_text_gen.rs +++ b/tests/classification_with_text_gen.rs @@ -21,7 +21,8 @@ use anyhow::Ok; use common::{ chunker::CHUNKER_UNARY_ENDPOINT, detectors::{ - DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, + ANSWER_RELEVANCE_DETECTOR_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, @@ -55,6 +56,7 @@ use fms_guardrails_orchestr8::{ use hyper::StatusCode; use mocktail::prelude::*; use test_log::test; +use tracing::debug; pub mod common; @@ -1206,7 +1208,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .build() .await?; - // Orchestrator request with non existing field + // Extra request parameter scenario let response = orchestrator_server .post(ORCHESTRATOR_UNARY_ENDPOINT) .json(&serde_json::json!({ @@ -1219,8 +1221,8 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - // Assertions for invalid request let results = response.json::().await?; + debug!("{results:#?}"); assert_eq!(results.code, StatusCode::UNPROCESSABLE_ENTITY); assert!( results @@ -1228,5 +1230,131 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .starts_with("non_existing_field: unknown field `non_existing_field`") ); + // Invalid input detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_UNARY_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: MODEL_ID.into(), + inputs: "This should return a 422".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([( + ANSWER_RELEVANCE_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + 
.send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + ANSWER_RELEVANCE_DETECTOR_SENTENCE + ) + }, + "failed on input detector with invalid type scenario" + ); + + // non-existing input detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_UNARY_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: MODEL_ID.into(), + inputs: "This should return a 404".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing input detector scenario" + ); + + // Invalid output detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_UNARY_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: MODEL_ID.into(), + inputs: "This should return a 422".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: Some(GuardrailsConfigOutput { + models: HashMap::from([( + ANSWER_RELEVANCE_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + ANSWER_RELEVANCE_DETECTOR_SENTENCE + ) + }, + "failed on output detector with invalid type scenario" + ); + + // non-existing output detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_UNARY_ENDPOINT) + 
.json(&GuardrailsHttpRequest { + model_id: MODEL_ID.into(), + inputs: "This should return a 404".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: Some(GuardrailsConfigOutput { + models: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing output detector scenario" + ); + Ok(()) } diff --git a/tests/common/detectors.rs b/tests/common/detectors.rs index ee46ba9d..19f67767 100644 --- a/tests/common/detectors.rs +++ b/tests/common/detectors.rs @@ -20,8 +20,11 @@ pub const DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: &str = "angle_brackets_detecto pub const DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE: &str = "angle_brackets_detector_sentence"; pub const DETECTOR_NAME_PARENTHESIS_SENTENCE: &str = "parenthesis_detector_sentence"; pub const ANSWER_RELEVANCE_DETECTOR: &str = "answer_relevance_detector"; +pub const ANSWER_RELEVANCE_DETECTOR_SENTENCE: &str = "answer_relevance_detector_sentence"; pub const FACT_CHECKING_DETECTOR: &str = "fact_checking_detector"; +pub const FACT_CHECKING_DETECTOR_SENTENCE: &str = "fact_checking_detector_sentence"; pub const PII_DETECTOR: &str = "pii_detector"; +pub const NON_EXISTING_DETECTOR: &str = "non_existing_detector"; // Detector endpoints pub const TEXT_CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index f1c4741a..bc361084 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -17,7 +17,10 @@ use std::collections::HashMap; use common::{ - detectors::{CONTEXT_DOC_DETECTOR_ENDPOINT, FACT_CHECKING_DETECTOR}, + detectors::{ + ANSWER_RELEVANCE_DETECTOR_SENTENCE, CONTEXT_DOC_DETECTOR_ENDPOINT, 
FACT_CHECKING_DETECTOR, + NON_EXISTING_DETECTOR, + }, errors::{DetectorError, OrchestratorError}, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT, @@ -362,5 +365,79 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(response.code, 422); assert!(response.details.starts_with("missing field `detectors`")); + // Asserts requests missing `detectors` return 422. + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&json!({ + "content": content, + "context": context, + "context_type": "docs" + })) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!(response.code, 422); + assert!(response.details.starts_with("missing field `detectors`")); + + // Asserts requests with detectors of invalid type return 422. + let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([( + ANSWER_RELEVANCE_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + content: content.into(), + context_type: ContextType::Url, + context: context.clone(), + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + ANSWER_RELEVANCE_DETECTOR_SENTENCE + ) + }, + "failed on detector with invalid type scenario" + ); + + // Asserts requests with non-existing detectors return 404. 
+ let response = orchestrator_server + .post(ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT) + .json(&ContextDocsHttpRequest { + detectors: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + content: content.into(), + context_type: ContextType::Url, + context, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing detector scenario" + ); + Ok(()) } diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index cbf1fc6b..38a47762 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -18,7 +18,10 @@ use std::collections::HashMap; use common::{ - detectors::{ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT}, + detectors::{ + ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT, + FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, + }, errors::{DetectorError, OrchestratorError}, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT, @@ -315,8 +318,68 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); let response = response.json::().await?; debug!("{response:#?}"); - assert_eq!(response.code, 422); - assert_eq!(response.details, "`detectors` is required"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: "`detectors` is required".into() + }, + "failed on empty `detectors` scenario" + ); + + // asserts requests with invalid detector type + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&DetectionOnGeneratedHttpRequest { + prompt: prompt.into(), + generated_text: 
generated_text.into(), + detectors: HashMap::from([( + FACT_CHECKING_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + FACT_CHECKING_DETECTOR_SENTENCE + ) + }, + "failed on invalid detector scenario" + ); + + // asserts requests with non-existing dewtector + let response = orchestrator_server + .post(ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT) + .json(&DetectionOnGeneratedHttpRequest { + prompt: prompt.into(), + generated_text: generated_text.into(), + detectors: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing detector scenario" + ); Ok(()) } diff --git a/tests/generation_with_detection.rs b/tests/generation_with_detection.rs index 245e3791..99ae194c 100644 --- a/tests/generation_with_detection.rs +++ b/tests/generation_with_detection.rs @@ -17,7 +17,10 @@ use std::collections::HashMap; use common::{ - detectors::{ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT}, + detectors::{ + ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT, + FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, + }, errors::{DetectorError, OrchestratorError}, generation::{GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_UNARY_ENDPOINT}, orchestrator::{ @@ -418,8 +421,69 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { 
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); let response = response.json::().await?; debug!("{response:#?}"); - assert_eq!(response.code, 422); - assert_eq!(response.details, "`detectors` is required"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: "`detectors` is required".into() + }, + ); + + // assert request with invalid type detectors + let response = orchestrator_server + .post(ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT) + .json(&GenerationWithDetectionHttpRequest { + model_id: model_id.into(), + prompt: prompt.into(), + detectors: HashMap::from([( + FACT_CHECKING_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + FACT_CHECKING_DETECTOR_SENTENCE + ) + }, + "failed at invalid detector scenario" + ); + + // assert request with non-existing detector + let response = orchestrator_server + .post(ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT) + .json(&GenerationWithDetectionHttpRequest { + model_id: model_id.into(), + prompt: prompt.into(), + detectors: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + let response = response.json::().await?; + debug!("{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing detector scenario" + ); Ok(()) } diff --git a/tests/streaming_classification_with_gen.rs b/tests/streaming_classification_with_gen.rs index bc353f86..36725a05 100644 --- 
a/tests/streaming_classification_with_gen.rs +++ b/tests/streaming_classification_with_gen.rs @@ -23,7 +23,8 @@ use common::{ CHUNKER_UNARY_ENDPOINT, }, detectors::{ - DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_PARENTHESIS_SENTENCE, + DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, + DETECTOR_NAME_PARENTHESIS_SENTENCE, FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, @@ -626,13 +627,12 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { let model_id = "my-super-model-8B"; - // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .build() .await?; - // Example orchestrator request with streaming response + // Request with extra fields scenario let response = orchestrator_server .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&serde_json::json!({ @@ -649,9 +649,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!(?response); - // assertions assert_eq!(response.status(), 422); - let response_body = response.json::().await?; assert_eq!(response_body.code, 422); assert!( @@ -660,6 +658,225 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .starts_with("non_existing_field: unknown field `non_existing_field`") ); + // Invalid input detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This request contains a detector with invalid type".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([( + FACT_CHECKING_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + 
.await?; + debug!(?response); + + assert_eq!(response.status(), 200); + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + FACT_CHECKING_DETECTOR_SENTENCE + ) + }, + "failed at invalid input detector scenario" + ); + + // Invalid chunker on input detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This request contains a detector with an invalid chunker".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([( + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into(), + DetectorParams::new(), + )]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), 200); + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 422, + details: format!( + "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint", + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC + ) + }, + "failed on input detector with invalid chunker scenario" + ); + + // Non-existing input detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This request contains a detector with invalid type".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([(NON_EXISTING_DETECTOR.into(), 
DetectorParams::new())]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!(?response); + + assert_eq!(response.status(), 200); + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed at non-existing input detector scenario" + ); + + // Invalid output detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This request contains a detector with invalid type".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: Some(GuardrailsConfigOutput { + models: HashMap::from([( + FACT_CHECKING_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )]), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!(?response); + + assert_eq!(response.status(), 200); + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + FACT_CHECKING_DETECTOR_SENTENCE + ) + }, + "failed at invalid output detector scenario" + ); + + // Invalid chunker on output detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This request contains a detector with an invalid chunker".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: Some(GuardrailsConfigOutput { + models: HashMap::from([( + 
DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into(), + DetectorParams::new(), + )]), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), 200); + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 422, + details: format!( + "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint", + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC + ) + }, + "failed on output detector with invalid chunker scenario" + ); + + // Non-existing output detector scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This request contains a detector with invalid type".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: Some(GuardrailsConfigOutput { + models: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!(?response); + + assert_eq!(response.status(), 200); + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed at non-existing output detector scenario" + ); + Ok(()) } diff --git a/tests/streaming_content_detection.rs b/tests/streaming_content_detection.rs index fc519bcf..1304e195 100644 --- a/tests/streaming_content_detection.rs +++ b/tests/streaming_content_detection.rs @@ -19,7 +19,8 @@ use std::collections::HashMap; use common::{ chunker::{CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE, 
CHUNKER_STREAMING_ENDPOINT}, detectors::{ - DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_PARENTHESIS_SENTENCE, + DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, + DETECTOR_NAME_PARENTHESIS_SENTENCE, FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, @@ -744,5 +745,113 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { "failed on empty `detectors` scenario" ); + // assert detector with invalid type on first frame + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([( + FACT_CHECKING_DETECTOR_SENTENCE.into(), + DetectorParams::new(), + )])), + content: "Hi".into(), + }, + ]))) + .send() + .await?; + + assert_eq!(response.status(), 200); + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + FACT_CHECKING_DETECTOR_SENTENCE + ) + }, + "failed at invalid input detector scenario" + ); + + // assert detector with invalid chunker on first frame + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([( + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into(), + DetectorParams::new(), + )])), + content: "Hi".into(), + }, + ]))) + .send() + .await?; + + 
assert_eq!(response.status(), 200); + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 422, + details: format!( + "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint", + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC + ) + }, + "failed at detector with invalid chunker scenario" + ); + + // assert non-existing detector on first frame + let response = orchestrator_server + .post(ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT) + .header("content-type", "application/x-ndjson") + .body(reqwest::Body::wrap_stream(json_lines_stream([ + StreamingContentDetectionRequest { + detectors: Some(HashMap::from([( + NON_EXISTING_DETECTOR.into(), + DetectorParams::new(), + )])), + content: "Hi".into(), + }, + ]))) + .send() + .await?; + + assert_eq!(response.status(), 200); + let mut messages = Vec::::with_capacity(1); + let mut stream = response.bytes_stream(); + while let Some(Ok(msg)) = stream.next().await { + debug!("recv: {msg:?}"); + messages.push(serde_json::from_slice(&msg[..]).unwrap()); + } + + assert_eq!(messages.len(), 1); + assert_eq!( + messages[0], + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed at non-existing input detector scenario" + ); + Ok(()) } diff --git a/tests/test_config.yaml b/tests/test_config.yaml index 89ecb2df..cc691881 100644 --- a/tests/test_config.yaml +++ b/tests/test_config.yaml @@ -41,6 +41,12 @@ detectors: service: hostname: localhost chunker_id: whole_doc_chunker + default_threshold: 0.5 + answer_relevance_detector_sentence: + type: text_generation + service: + hostname: localhost + chunker_id: sentence_chunker default_threshold: 0.5 fact_checking_detector: type: 
text_context_doc @@ -48,6 +54,12 @@ detectors: hostname: localhost chunker_id: whole_doc_chunker default_threshold: 0.5 + fact_checking_detector_sentence: + type: text_context_doc + service: + hostname: localhost + chunker_id: sentence_chunker + default_threshold: 0.5 pii_detector: type: text_chat service: diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index ffcd2260..aca2f6cf 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -21,7 +21,7 @@ use common::{ chunker::{CHUNKER_NAME_SENTENCE, CHUNKER_UNARY_ENDPOINT}, detectors::{ DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, - TEXT_CONTENTS_DETECTOR_ENDPOINT, + FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, errors::{DetectorError, OrchestratorError}, orchestrator::{ @@ -468,8 +468,63 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); let response: OrchestratorError = response.json().await?; debug!("orchestrator json response body:\n{response:#?}"); - assert_eq!(response.code, 422); - assert_eq!(response.details, "`detectors` is required"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: "`detectors` is required".into() + }, + "failed on empty `detectors` scenario" + ); + + // assert detector with invalid type + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&json!({ + "content": "This sentence has no detections.", + "detectors": {FACT_CHECKING_DETECTOR_SENTENCE: {}}, + })) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + let response: OrchestratorError = response.json().await?; + debug!("orchestrator json response body:\n{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 422, + details: format!( + "detector `{}` is not supported by 
this endpoint", + FACT_CHECKING_DETECTOR_SENTENCE + ) + }, + "failed on invalid detector type scenario" + ); + + // assert detector with invalid type + let response = orchestrator_server + .post(ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT) + .json(&json!({ + "content": "This sentence has no detections.", + "detectors": {NON_EXISTING_DETECTOR: {}}, + })) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + let response: OrchestratorError = response.json().await?; + debug!("orchestrator json response body:\n{response:#?}"); + assert_eq!( + response, + OrchestratorError { + code: 404, + details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) + }, + "failed on non-existing detector scenario" + ); Ok(()) } From 92536988495cf9bcadc07e4cf014e60349baa8bc Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Thu, 17 Apr 2025 17:15:28 -0300 Subject: [PATCH 11/24] Integration tests for no detectors requests (#373) Signed-off-by: Mateus Devino --- tests/classification_with_text_gen.rs | 115 +++++++++++++++++++++ tests/streaming_classification_with_gen.rs | 59 ++++++++++- 2 files changed, 171 insertions(+), 3 deletions(-) diff --git a/tests/classification_with_text_gen.rs b/tests/classification_with_text_gen.rs index 83f5aa6c..9894e618 100644 --- a/tests/classification_with_text_gen.rs +++ b/tests/classification_with_text_gen.rs @@ -64,6 +64,121 @@ pub mod common; const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; const MODEL_ID: &str = "my-super-model-8B"; +#[test(tokio::test)] +async fn no_detectors() -> Result<(), anyhow::Error> { + // Add generation mock + let model_id = "my-super-model-8B"; + let inputs = "Hi there! 
How are you?"; + + // Add expected generated text + let expected_response = GeneratedTextResult { + generated_text: "I am great!".into(), + generated_tokens: 0, + finish_reason: 0, + input_token_count: 0, + seed: 0, + tokens: vec![], + input_tokens: vec![], + }; + + let mut mocks = MockSet::new(); + mocks.mock(|when, then| { + when.path(GENERATION_NLP_UNARY_ENDPOINT) + .header(GENERATION_NLP_MODEL_ID_HEADER_NAME, MODEL_ID) + .pb(TextGenerationTaskRequest { + text: inputs.into(), + ..Default::default() + }); + then.pb(expected_response.clone()); + }); + + // Configure mock servers + let generation_server = MockServer::new("nlp").grpc().with_mocks(mocks); + + // Run test orchestrator server + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .generation_server(&generation_server) + .build() + .await?; + + // Empty `guardrail_config` scenario + let response = orchestrator_server + .post(ORCHESTRATOR_UNARY_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: inputs.into(), + guardrail_config: None, + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!( + results.generated_text, + Some(expected_response.clone().generated_text) + ); + assert_eq!(results.warnings, None); + + // `guardrail_config` with `input` and `output` set to None scenario + let response = orchestrator_server + .post(ORCHESTRATOR_UNARY_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! 
How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!( + results.generated_text, + Some(expected_response.clone().generated_text) + ); + assert_eq!(results.warnings, None); + + // `guardrail_config` with `input` and `output` set to empty map scenario + let response = orchestrator_server + .post(ORCHESTRATOR_UNARY_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::new(), + masks: None, + }), + output: Some(GuardrailsConfigOutput { + models: HashMap::new(), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + debug!("{response:#?}"); + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!( + results.generated_text, + Some(expected_response.clone().generated_text) + ); + assert_eq!(results.warnings, None); + + Ok(()) +} + // Validate that requests without detectors, input detector and output detector configured // returns text generated by model #[test(tokio::test)] diff --git a/tests/streaming_classification_with_gen.rs b/tests/streaming_classification_with_gen.rs index 36725a05..8805a752 100644 --- a/tests/streaming_classification_with_gen.rs +++ b/tests/streaming_classification_with_gen.rs @@ -127,7 +127,7 @@ async fn no_detectors() -> Result<(), anyhow::Error> { .build() .await?; - // Example orchestrator request with streaming response + // Empty `guardrail_config` scenario let response = orchestrator_server .post(ORCHESTRATOR_STREAMING_ENDPOINT) .json(&GuardrailsHttpRequest { @@ -139,13 +139,66 @@ async fn no_detectors() -> Result<(), anyhow::Error> { .send() .await?; - // Collects stream results let 
sse_stream: SseStream = SseStream::new(response.bytes_stream()); let messages = sse_stream.try_collect::>().await?; debug!("{messages:#?}"); - // assertions + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].generated_text, Some("I".into())); + assert_eq!(messages[1].generated_text, Some(" am".into())); + assert_eq!(messages[2].generated_text, Some(" great!".into())); + + // `guardrail_config` with `input` and `output` set to None scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: None, + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].generated_text, Some("I".into())); + assert_eq!(messages[1].generated_text, Some(" am".into())); + assert_eq!(messages[2].generated_text, Some(" great!".into())); + + // `guardrail_config` with `input` and `output` set to empty map scenario + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "Hi there! 
How are you?".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::new(), + masks: None, + }), + output: Some(GuardrailsConfigOutput { + models: HashMap::new(), + }), + }), + text_gen_parameters: None, + }) + .send() + .await?; + + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + assert_eq!(messages.len(), 3); assert_eq!(messages[0].generated_text, Some("I".into())); assert_eq!(messages[1].generated_text, Some(" am".into())); From 6bc045f0768bc9f9688c47a2dc5f7605a01b4d96 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Fri, 18 Apr 2025 09:59:56 -0700 Subject: [PATCH 12/24] refactor ChatCompletionsRequest (#375) Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> Co-authored-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> --- src/clients/openai.rs | 291 ++++++++------ src/orchestrator/common/client.rs | 8 +- .../chat_completions_detection/streaming.rs | 2 +- .../chat_completions_detection/unary.rs | 2 +- tests/chat_completions_detection.rs | 370 ++++++++---------- 5 files changed, 341 insertions(+), 332 deletions(-) diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 05f15a63..917e667a 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -23,6 +23,7 @@ use futures::StreamExt; use http_body_util::BodyExt; use hyper::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; use tokio::sync::mpsc; use super::{ @@ -32,7 +33,7 @@ use super::{ use crate::{ config::ServiceConfig, health::HealthCheckResult, - models::{DetectionWarningReason, DetectorParams}, + models::{DetectionWarningReason, DetectorParams, ValidationError}, orchestrator, }; @@ -167,122 +168,83 @@ impl From for ChatCompletionsResponse { } } -#[derive(Debug, Default, Clone, Serialize, Deserialize)] 
-#[serde(deny_unknown_fields)] +/// Represents a chat completions request. +/// +/// As orchestrator is only concerned with a limited subset +/// of request fields, we deserialize to an inner [`serde_json::Map`] +/// and only validate and extract the fields used by this service. +/// This type is then serialized to the inner [`serde_json::Map`]. +/// +/// This is to avoid tracking and updating OpenAI and vLLM +/// parameter additions/changes. Full validation is delegated to +/// the downstream server implementation. +/// +/// Validated fields: detectors (internal), model, messages +#[derive(Debug, Default, Clone, PartialEq, Deserialize)] +#[serde(try_from = "Map")] pub struct ChatCompletionsRequest { - /// A list of messages comprising the conversation so far. - pub messages: Vec, - /// ID of the model to use. - pub model: String, - /// Whether or not to store the output of this chat completion request. - #[serde(skip_serializing_if = "Option::is_none")] - pub store: Option, - /// Developer-defined tags and values. - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - /// Modify the likelihood of specified tokens appearing in the completion. - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - /// Whether to return log probabilities of the output tokens or not. - /// If true, returns the log probabilities of each output token returned in the content of message. - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - /// An integer between 0 and 20 specifying the number of most likely tokens to return - /// at each token position, each with an associated log probability. - /// logprobs must be set to true if this parameter is used. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_logprobs: Option, - /// The maximum number of tokens that can be generated in the chat completion. 
(DEPRECATED) - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. - #[serde(skip_serializing_if = "Option::is_none")] - pub max_completion_tokens: Option, - /// How many chat completion choices to generate for each input message. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - /// Positive values penalize new tokens based on whether they appear in the text so far, - /// increasing the model's likelihood to talk about new topics. - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - /// An object specifying the format that the model must output. - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, - /// If specified, our system will make a best effort to sample deterministically, - /// such that repeated requests with the same seed and parameters should return the same result. - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - /// Specifies the latency tier to use for processing the request. - #[serde(skip_serializing_if = "Option::is_none")] - pub service_tier: Option, - /// Up to 4 sequences where the API will stop generating further tokens. - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - /// If set, partial message deltas will be sent, like in ChatGPT. - /// Tokens will be sent as data-only server-sent events as they become available, - /// with the stream terminated by a data: [DONE] message. - #[serde(default)] + /// Detector config. + pub detectors: DetectorConfig, + /// Stream parameter. pub stream: bool, - /// Options for streaming response. Only set this when you set stream: true. - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - /// What sampling temperature to use, between 0 and 2. 
- /// Higher values like 0.8 will make the output more random, - /// while lower values like 0.2 will make it more focused and deterministic. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - /// An alternative to sampling with temperature, called nucleus sampling, - /// where the model considers the results of the tokens with top_p probability mass. - /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - /// A list of tools the model may call. - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub tools: Vec, - /// Controls which (if any) tool is called by the model. - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, - /// Whether to enable parallel function calling during tool use. - #[serde(skip_serializing_if = "Option::is_none")] - pub parallel_tool_calls: Option, - /// A unique identifier representing your end-user. - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, + /// Model name. + pub model: String, + /// Messages. + pub messages: Vec, + /// Inner request. 
+ pub inner: Map, +} - // Additional vllm params - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub use_beam_search: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub min_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub repetition_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub length_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub early_stopping: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub ignore_eos: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub min_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_token_ids: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub skip_special_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub spaces_between_special_tokens: Option, +impl TryFrom> for ChatCompletionsRequest { + type Error = ValidationError; - // Detectors - // Note: We are making it optional, since this structure also gets used to - // form request for chat completions. And downstream server, might choose to - // reject extra parameters. - #[serde(skip_serializing_if = "Option::is_none")] - pub detectors: Option, + fn try_from(mut value: Map) -> Result { + let detectors = if let Some(detectors) = value.remove("detectors") { + DetectorConfig::deserialize(detectors) + .map_err(|_| ValidationError::Invalid("error deserializing `detectors`".into()))? 
+ } else { + DetectorConfig::default() + }; + let stream = value + .get("stream") + .and_then(|v| v.as_bool()) + .unwrap_or_default(); + let model = if let Some(Value::String(model)) = value.get("model") { + Ok(model.clone()) + } else { + Err(ValidationError::Required("model".into())) + }?; + if model.is_empty() { + return Err(ValidationError::Invalid("`model` must not be empty".into())); + } + let messages = if let Some(messages) = value.get("messages") { + Vec::::deserialize(messages) + .map_err(|_| ValidationError::Invalid("error deserializing `messages`".into())) + } else { + Err(ValidationError::Required("messages".into())) + }?; + if messages.is_empty() { + return Err(ValidationError::Invalid( + "`messages` must not be empty".into(), + )); + } + Ok(ChatCompletionsRequest { + detectors, + stream, + model, + messages, + inner: value, + }) + } +} + +impl Serialize for ChatCompletionsRequest { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.inner.serialize(serializer) + } } /// Structure to contain parameters for detectors. @@ -291,7 +253,6 @@ pub struct ChatCompletionsRequest { pub struct DetectorConfig { #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub input: HashMap, - #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub output: HashMap, } @@ -369,7 +330,7 @@ pub enum Role { Tool, } -#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Message { /// The role of the author of this message. 
@@ -731,3 +692,103 @@ impl OrchestratorWarning { } } } + +#[cfg(test)] +mod test { + use serde_json::json; + + use super::*; + + #[test] + fn test_chat_completions_request() -> Result<(), serde_json::Error> { + // Test deserialize + let detectors = DetectorConfig { + input: HashMap::from([("some_detector".into(), DetectorParams::new())]), + output: HashMap::new(), + }; + let messages = vec![Message { + content: Some(Content::Text("Hi there!".to_string())), + ..Default::default() + }]; + let json_request = json!({ + "model": "test", + "detectors": detectors, + "messages": messages, + }); + let request = ChatCompletionsRequest::deserialize(&json_request)?; + let mut inner = json_request.as_object().unwrap().to_owned(); + inner.remove("detectors").unwrap(); + assert_eq!( + request, + ChatCompletionsRequest { + detectors, + stream: false, + model: "test".into(), + messages: messages.clone(), + inner, + } + ); + + // Test deserialize with no detectors + let json_request = json!({ + "model": "test", + "messages": messages, + }); + let request = ChatCompletionsRequest::deserialize(&json_request)?; + let inner = json_request.as_object().unwrap().to_owned(); + assert_eq!( + request, + ChatCompletionsRequest { + detectors: DetectorConfig::default(), + stream: false, + model: "test".into(), + messages: messages.clone(), + inner, + } + ); + + // Test deserialize validation errors + let result = ChatCompletionsRequest::deserialize(json!({ + "detectors": DetectorConfig::default(), + "messages": messages, + })); + assert!(result.is_err_and(|error| error.to_string() == "`model` is required")); + + let result = ChatCompletionsRequest::deserialize(json!({ + "model": "", + "detectors": DetectorConfig::default(), + "messages": Vec::::default(), + })); + assert!(result.is_err_and(|error| error.to_string() == "`model` must not be empty")); + + let result = ChatCompletionsRequest::deserialize(json!({ + "model": "test", + "detectors": DetectorConfig::default(), + "messages": 
Vec::::default(), + })); + assert!(result.is_err_and(|error| error.to_string() == "`messages` must not be empty")); + + let result = ChatCompletionsRequest::deserialize(json!({ + "model": "test", + "detectors": DetectorConfig::default(), + "messages": ["invalid"], + })); + assert!(result.is_err_and(|error| error.to_string() == "error deserializing `messages`")); + + // Test serialize + let serialized_request = serde_json::to_value(request)?; + assert_eq!( + serialized_request, + json!({ + "model": "test", + "messages": [Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }], + }) + ); + + Ok(()) + } +} diff --git a/src/orchestrator/common/client.rs b/src/orchestrator/common/client.rs index 672df54a..653e2a78 100644 --- a/src/orchestrator/common/client.rs +++ b/src/orchestrator/common/client.rs @@ -247,10 +247,8 @@ pub async fn detect_text_context( pub async fn chat_completion( client: &OpenAiClient, headers: HeaderMap, - mut request: openai::ChatCompletionsRequest, + request: openai::ChatCompletionsRequest, ) -> Result { - request.stream = false; - request.detectors = None; let model_id = request.model.clone(); debug!(%model_id, ?request, "sending chat completions request"); let response = client @@ -269,10 +267,8 @@ pub async fn chat_completion( pub async fn chat_completion_stream( client: &OpenAiClient, headers: HeaderMap, - mut request: openai::ChatCompletionsRequest, + request: openai::ChatCompletionsRequest, ) -> Result { - request.stream = true; - request.detectors = None; let model_id = request.model.clone(); debug!(%model_id, ?request, "sending chat completions stream request"); let response = client diff --git a/src/orchestrator/handlers/chat_completions_detection/streaming.rs b/src/orchestrator/handlers/chat_completions_detection/streaming.rs index d6c27cf8..74ebcb59 100644 --- a/src/orchestrator/handlers/chat_completions_detection/streaming.rs +++ 
b/src/orchestrator/handlers/chat_completions_detection/streaming.rs @@ -30,7 +30,7 @@ pub async fn handle_streaming( task: ChatCompletionsDetectionTask, ) -> Result { let trace_id = task.trace_id; - let detectors = task.request.detectors.clone().unwrap_or_default(); + let detectors = task.request.detectors.clone(); info!(%trace_id, config = ?detectors, "task started"); let _input_detectors = detectors.input; let _output_detectors = detectors.output; diff --git a/src/orchestrator/handlers/chat_completions_detection/unary.rs b/src/orchestrator/handlers/chat_completions_detection/unary.rs index a8ce51ff..dc310d36 100644 --- a/src/orchestrator/handlers/chat_completions_detection/unary.rs +++ b/src/orchestrator/handlers/chat_completions_detection/unary.rs @@ -39,7 +39,7 @@ pub async fn handle_unary( task: ChatCompletionsDetectionTask, ) -> Result { let trace_id = task.trace_id; - let detectors = task.request.detectors.clone().unwrap_or_default(); + let detectors = task.request.detectors.clone(); info!(%trace_id, config = ?detectors, "task started"); let input_detectors = detectors.input; let output_detectors = detectors.output; diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_detection.rs index 1e10ed53..85261360 100644 --- a/tests/chat_completions_detection.rs +++ b/tests/chat_completions_detection.rs @@ -15,9 +15,6 @@ */ -use std::{collections::HashMap, vec}; - -use anyhow::Ok; use common::{ chat_completions::CHAT_COMPLETIONS_ENDPOINT, chunker::CHUNKER_UNARY_ENDPOINT, @@ -37,9 +34,8 @@ use fms_guardrails_orchestr8::{ chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, detector::{ContentAnalysisRequest, ContentAnalysisResponse}, openai::{ - ChatCompletion, ChatCompletionChoice, ChatCompletionMessage, ChatCompletionsRequest, - ChatDetections, Content, DetectorConfig, InputDetectionResult, Message, - OrchestratorWarning, OutputDetectionResult, Role, + ChatCompletion, ChatCompletionChoice, ChatCompletionMessage, ChatDetections, Content, 
+ InputDetectionResult, Message, OrchestratorWarning, OutputDetectionResult, Role, }, }, models::{ @@ -53,6 +49,7 @@ use fms_guardrails_orchestr8::{ }; use hyper::StatusCode; use mocktail::prelude::*; +use serde_json::json; use test_log::test; use tracing::debug; @@ -140,14 +137,10 @@ async fn no_detections() -> Result<(), anyhow::Error> { // Add chat completions mock chat_mocks.mock(|when, then| { - when.post() - .path(CHAT_COMPLETIONS_ENDPOINT) - .json(ChatCompletionsRequest { - messages: messages.clone(), - model: MODEL_ID.into(), - stream: false, - ..Default::default() - }); + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages, + })); then.json(&chat_completions_response); }); @@ -165,15 +158,18 @@ async fn no_detections() -> Result<(), anyhow::Error> { // Make orchestrator call for input/output no detections let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::from([(detector_name.into(), DetectorParams::new())]), - output: HashMap::from([(detector_name.into(), DetectorParams::new())]), - }), - messages, - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + detector_name: {}, + }, + "output": { + detector_name: {}, + }, + }, + "messages": messages, + })) .send() .await?; @@ -265,14 +261,10 @@ async fn input_detections() -> Result<(), anyhow::Error> { // Add chat completions mock chat_mocks.mock(|when, then| { - when.post() - .path(CHAT_COMPLETIONS_ENDPOINT) - .json(ChatCompletionsRequest { - messages: messages.clone(), - model: MODEL_ID.into(), - stream: false, - ..Default::default() - }); + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages, + })); then.json(&chat_completions_response); }); @@ -294,15 +286,16 @@ async fn input_detections() -> Result<(), anyhow::Error> { 
// Make orchestrator call for input/output no detections let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::from([(detector_name.into(), DetectorParams::new())]), - output: HashMap::new(), - }), - messages, - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + detector_name: {}, + }, + "output": {} + }, + "messages": messages, + })) .send() .await?; @@ -427,14 +420,10 @@ async fn input_client_error() -> Result<(), anyhow::Error> { // Add chat completions mock for chat completions error scenario chat_mocks.mock(|when, then| { - when.post() - .path(CHAT_COMPLETIONS_ENDPOINT) - .json(ChatCompletionsRequest { - messages: messages_chat_completions_error.clone(), - model: MODEL_ID.into(), - stream: false, - ..Default::default() - }); + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages_chat_completions_error, + })); then.internal_server_error(); }); @@ -456,15 +445,16 @@ async fn input_client_error() -> Result<(), anyhow::Error> { // Make orchestrator call for chunker error scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::from([(detector_name.into(), DetectorParams::new())]), - output: HashMap::new(), - }), - messages: messages_chunker_error.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + detector_name: {}, + }, + "output": {} + }, + "messages": messages_chunker_error, + })) .send() .await?; @@ -475,15 +465,16 @@ async fn input_client_error() -> Result<(), anyhow::Error> { // Make orchestrator call for detector error scenario let response = orchestrator_server 
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::from([(detector_name.into(), DetectorParams::new())]), - output: HashMap::new(), - }), - messages: messages_detector_error.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + detector_name: {}, + }, + "output": {} + }, + "messages": messages_detector_error, + })) .send() .await?; @@ -494,15 +485,16 @@ async fn input_client_error() -> Result<(), anyhow::Error> { // Make orchestrator call for chat completions error scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::from([(detector_name.into(), DetectorParams::new())]), - output: HashMap::new(), - }), - messages: messages_chat_completions_error.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + detector_name: {}, + }, + "output": {} + }, + "messages": messages_chat_completions_error, + })) .send() .await?; @@ -653,14 +645,10 @@ async fn output_detections() -> Result<(), anyhow::Error> { // Add chat completions mock chat_mocks.mock(|when, then| { - when.post() - .path(CHAT_COMPLETIONS_ENDPOINT) - .json(ChatCompletionsRequest { - messages: messages.clone(), - model: MODEL_ID.into(), - stream: false, - ..Default::default() - }); + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages, + })); then.json(&chat_completions_response); }); @@ -682,15 +670,16 @@ async fn output_detections() -> Result<(), anyhow::Error> { // Make orchestrator call for output detections let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig 
{ - input: HashMap::new(), - output: HashMap::from([(detector_name.into(), DetectorParams::new())]), - }), - messages, - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "messages": messages, + })) .send() .await?; @@ -816,40 +805,28 @@ async fn output_client_error() -> Result<(), anyhow::Error> { // Add chat completions mock for chunker error scenario chat_mocks.mock(|when, then| { - when.post() - .path(CHAT_COMPLETIONS_ENDPOINT) - .json(ChatCompletionsRequest { - messages: messages_chunker_error.clone(), - model: MODEL_ID.into(), - stream: false, - ..Default::default() - }); + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages_chunker_error, + })); then.internal_server_error(); }); // Add chat completions mock for detector error scenario chat_mocks.mock(|when, then| { - when.post() - .path(CHAT_COMPLETIONS_ENDPOINT) - .json(ChatCompletionsRequest { - messages: messages_detector_error.clone(), - model: MODEL_ID.into(), - stream: false, - ..Default::default() - }); + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages_detector_error, + })); then.internal_server_error().json(&expected_detector_error); }); // Add chat completions mock for chat completions error scenario chat_mocks.mock(|when, then| { - when.post() - .path(CHAT_COMPLETIONS_ENDPOINT) - .json(ChatCompletionsRequest { - messages: messages_chat_completions_error.clone(), - model: MODEL_ID.into(), - stream: false, - ..Default::default() - }); + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages_chat_completions_error, + })); then.internal_server_error(); }); @@ -871,15 +848,16 @@ async fn output_client_error() -> Result<(), anyhow::Error> { // Make orchestrator call for chunker error scenario let response = orchestrator_server 
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::new(), - output: HashMap::from([(detector_name.into(), DetectorParams::new())]), - }), - messages: messages_chunker_error.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "messages": messages_chunker_error, + })) .send() .await?; @@ -890,15 +868,16 @@ async fn output_client_error() -> Result<(), anyhow::Error> { // Make orchestrator call for detector error scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::new(), - output: HashMap::from([(detector_name.into(), DetectorParams::new())]), - }), - messages: messages_detector_error.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "messages": messages_detector_error, + })) .send() .await?; @@ -909,15 +888,16 @@ async fn output_client_error() -> Result<(), anyhow::Error> { // Make orchestrator call for chat completions error scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::new(), - output: HashMap::from([(detector_name.into(), DetectorParams::new())]), - }), - messages: messages_chat_completions_error.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "messages": messages_chat_completions_error, + })) .send() .await?; @@ -931,8 +911,6 @@ async fn output_client_error() -> Result<(), anyhow::Error> { // Validate that invalid 
orchestrator requests returns 422 error #[test(tokio::test)] async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { - let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; - // Start orchestrator server and its dependencies let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) @@ -945,51 +923,22 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { ..Default::default() }]; - // Extra request field scenario + // Invalid input detector scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&serde_json::json!({ + .json(&json!({ "model": MODEL_ID, - "detectors": { - "input": {}, - "output": { - detector_name: {} - } + "detectors": { + "input": { + ANSWER_RELEVANCE_DETECTOR: {}, + }, + "output": {} }, - "messages": &messages, - "some_extra_field": "random value" + "messages": messages, })) .send() .await?; - let results = response.json::().await?; - debug!("{results:#?}"); - assert_eq!( - results.code, - StatusCode::UNPROCESSABLE_ENTITY, - "failed on extra request field scenario" - ); - assert!( - results - .details - .starts_with("some_extra_field: unknown field `some_extra_field`") - ); - - // Invalid input detector scenario - let response = orchestrator_server - .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::from([(ANSWER_RELEVANCE_DETECTOR.into(), DetectorParams::new())]), - output: HashMap::new(), - }), - messages: messages.clone(), - ..Default::default() - }) - .send() - .await?; - let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( @@ -1007,15 +956,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { // Non-existing input detector scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - 
.json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), - output: HashMap::new(), - }), - messages: messages.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + NON_EXISTING_DETECTOR: {}, + }, + "output": {} + }, + "messages": messages, + })) .send() .await?; @@ -1033,15 +983,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { // Invalid output detector scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::new(), - output: HashMap::from([(ANSWER_RELEVANCE_DETECTOR.into(), DetectorParams::new())]), - }), - messages: messages.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + ANSWER_RELEVANCE_DETECTOR: {}, + }, + }, + "messages": messages, + })) .send() .await?; @@ -1062,15 +1013,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { // Non-existing output detector scenario let response = orchestrator_server .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) - .json(&ChatCompletionsRequest { - model: MODEL_ID.into(), - detectors: Some(DetectorConfig { - input: HashMap::new(), - output: HashMap::from([(NON_EXISTING_DETECTOR.into(), DetectorParams::new())]), - }), - messages: messages.clone(), - ..Default::default() - }) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + NON_EXISTING_DETECTOR: {}, + } + }, + "messages": messages, + })) .send() .await?; From 1adecc30bfa45a9bb70a73c4d253a855b852d052 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Thu, 24 Apr 2025 16:57:32 -0300 Subject: [PATCH 13/24] Tests cleanup (#379) * Remove duplicate 
chat_completions test file Signed-off-by: Mateus Devino * Build mocktail from main Signed-off-by: Mateus Devino * Drop unnecessary clone and vec! calls for mocks Signed-off-by: Mateus Devino * Clone duplicate mocks Signed-off-by: Mateus Devino * Build mocktail from latest main Signed-off-by: Mateus Devino * Add orchestrator error 500 helper Signed-off-by: Mateus Devino * More orchestrator error helper functions Signed-off-by: Mateus Devino * Update tests/common/errors.rs Co-authored-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> * Move imports below copyright notice Signed-off-by: Mateus Devino * Remove get_ prefix from helper functions Signed-off-by: Mateus Devino * Replace generics with &str Signed-off-by: Mateus Devino * Make helper functions part of OrchestratorError Signed-off-by: Mateus Devino * Remove unneeded constants Signed-off-by: Mateus Devino * Rename chunker error Signed-off-by: Mateus Devino --------- Signed-off-by: Mateus Devino Signed-off-by: Mateus Devino <19861348+mdevino@users.noreply.github.com> Co-authored-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- Cargo.lock | 10 +- Cargo.toml | 2 +- tests/chat_completions_detection.rs | 54 +++----- tests/chat_detection.rs | 23 +--- tests/classification_with_text_gen.rs | 76 ++++------- tests/common/chat_completions.rs | 19 --- ...{chat_completion.rs => chat_generation.rs} | 0 tests/common/errors.rs | 48 +++++++ tests/common/mod.rs | 2 +- tests/common/orchestrator.rs | 2 - tests/context_docs_detection.rs | 22 +-- tests/detection_on_generation.rs | 27 +--- tests/generation_with_detection.rs | 38 ++---- tests/streaming_classification_with_gen.rs | 126 ++++-------------- tests/streaming_content_detection.rs | 49 ++----- tests/text_content_detection.rs | 11 +- 16 files changed, 164 insertions(+), 345 deletions(-) delete mode 100644 tests/common/chat_completions.rs rename tests/common/{chat_completion.rs 
=> chat_generation.rs} (100%) diff --git a/Cargo.lock b/Cargo.lock index e33aecff..bc41907f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1382,7 +1382,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -1515,9 +1515,8 @@ dependencies = [ [[package]] name = "mocktail" -version = "0.2.4-alpha" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ca54227ec1ea1a7186f14d0e126e989f4db19a3d8cdf75352e4f99f69d78182" +version = "0.2.5-alpha" +source = "git+https://github.com/IBM/mocktail#6296c2783ba1d433407ae1d8144ec5619dc021b9" dependencies = [ "bytes", "futures", @@ -1533,9 +1532,9 @@ dependencies = [ "thiserror 2.0.12", "tokio", "tokio-stream", - "tonic", "tracing", "url", + "uuid", ] [[package]] @@ -3033,6 +3032,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" dependencies = [ "getrandom 0.3.2", + "rand 0.9.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a8c9f9dc..27d62511 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,7 +89,7 @@ tonic-build = "0.12.3" [dev-dependencies] axum-test = "17.1.0" -mocktail = { version = "0.2.4-alpha" } +mocktail = { git = "https://github.com/IBM/mocktail" } rand = "0.9.0" test-log = "0.2.17" diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_detection.rs index 85261360..d1e16b52 100644 --- a/tests/chat_completions_detection.rs +++ b/tests/chat_completions_detection.rs @@ -16,7 +16,7 @@ */ use common::{ - chat_completions::CHAT_COMPLETIONS_ENDPOINT, + chat_generation::CHAT_COMPLETIONS_ENDPOINT, chunker::CHUNKER_UNARY_ENDPOINT, detectors::{ ANSWER_RELEVANCE_DETECTOR, DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, @@ -26,7 +26,7 @@ use common::{ errors::{DetectorError, OrchestratorError}, 
orchestrator::{ ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, - ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, TestOrchestratorServer, + TestOrchestratorServer, }, }; use fms_guardrails_orchestr8::{ @@ -122,7 +122,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { contents: vec!["Hi there!".into()], detector_params: DetectorParams::new(), }); - then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Add detector output mock detector_mocks.mock(|when, then| { @@ -132,7 +132,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { contents: vec!["Hello!".into()], detector_params: DetectorParams::new(), }); - then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Add chat completions mock @@ -256,7 +256,7 @@ async fn input_detections() -> Result<(), anyhow::Error> { contents: vec![input_text.into()], detector_params: DetectorParams::new(), }); - then.json(vec![&expected_detections]); + then.json([&expected_detections]); }); // Add chat completions mock @@ -319,10 +319,7 @@ async fn input_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; // Add 500 expected orchestrator error response - let expected_orchestrator_error = OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.to_string(), - }; + let expected_orchestrator_error = OrchestratorError::internal(); // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; @@ -404,7 +401,7 @@ async fn input_client_error() -> Result<(), anyhow::Error> { contents: vec![chat_completions_error_input.into()], detector_params: DetectorParams::new(), }); - then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Add detector mock for detector error scenario @@ -595,7 +592,7 @@ async fn output_detections() -> Result<(), anyhow::Error> { contents: vec![input_text.into()], detector_params: DetectorParams::new(), }); - 
then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Add detector output mock for generated message @@ -606,7 +603,7 @@ async fn output_detections() -> Result<(), anyhow::Error> { contents: vec![output_text.into()], detector_params: DetectorParams::new(), }); - then.json(vec![&expected_detections]); + then.json([&expected_detections]); }); // Add chunker tokenization mock for output detection user input @@ -704,10 +701,7 @@ async fn output_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; // Add 500 expected orchestrator mock response - let expected_orchestrator_error = OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.to_string(), - }; + let expected_orchestrator_error = OrchestratorError::internal(); // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; @@ -789,7 +783,7 @@ async fn output_client_error() -> Result<(), anyhow::Error> { contents: vec![chat_completions_error_input.into()], detector_params: DetectorParams::new(), }); - then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Add detector mock for detector error scenario @@ -943,13 +937,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - ANSWER_RELEVANCE_DETECTOR - ) - }, + OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR), "failed on invalid input detector scenario" ); @@ -973,10 +961,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed on non-existing input detector scenario" ); @@ -1000,13 +985,7 @@ async fn 
orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - ANSWER_RELEVANCE_DETECTOR - ) - }, + OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR), "failed on invalid output detector scenario" ); @@ -1030,10 +1009,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed on non-existing input detector scenario" ); diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index 2f00ee98..95c266fa 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -23,8 +23,7 @@ use common::{ }, errors::{DetectorError, OrchestratorError}, orchestrator::{ - ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, - ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, TestOrchestratorServer, + ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, TestOrchestratorServer, }, }; use fms_guardrails_orchestr8::{ @@ -89,7 +88,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { tools: tools.clone(), detector_params: DetectorParams::new(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -159,7 +158,7 @@ async fn detections() -> Result<(), anyhow::Error> { tools: vec![], detector_params: DetectorParams::new(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -254,8 +253,7 @@ async fn client_errors() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); let response = response.json::().await?; debug!("{response:#?}"); - 
assert_eq!(response.code, 500); - assert_eq!(response.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(response, OrchestratorError::internal()); Ok(()) } @@ -421,13 +419,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - ANSWER_RELEVANCE_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), "failed on detector with invalid type scenario" ); @@ -447,10 +439,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed on non-existing detector scenario" ); diff --git a/tests/classification_with_text_gen.rs b/tests/classification_with_text_gen.rs index 9894e618..c8161a7c 100644 --- a/tests/classification_with_text_gen.rs +++ b/tests/classification_with_text_gen.rs @@ -31,8 +31,8 @@ use common::{ GENERATION_NLP_UNARY_ENDPOINT, }, orchestrator::{ - ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, - ORCHESTRATOR_UNARY_ENDPOINT, ORCHESTRATOR_UNSUITABLE_INPUT_MESSAGE, TestOrchestratorServer, + ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_UNARY_ENDPOINT, + ORCHESTRATOR_UNSUITABLE_INPUT_MESSAGE, TestOrchestratorServer, }, }; use fms_guardrails_orchestr8::{ @@ -220,7 +220,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { contents: vec![text_mock_input.clone()], detector_params: DetectorParams::new(), }); - then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Add output detector mock @@ -232,7 +232,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { contents: vec![expected_response.generated_text.clone()], detector_params: 
DetectorParams::new(), }); - then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Add chunker tokenization for input mock @@ -491,7 +491,7 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { ], detector_params: DetectorParams::new(), }); - then.json([vec![], vec![expected_detections[0].clone()]]); + then.json([vec![], vec![&expected_detections[0]]]); }); // Add input detection mock for multiple detections @@ -508,8 +508,8 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { }); then.json([ vec![], - vec![expected_detections[0].clone()], - vec![expected_detections[1].clone()], + vec![&expected_detections[0]], + vec![&expected_detections[1]], ]); }); @@ -647,6 +647,8 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; + let orchestrator_error_500 = OrchestratorError::internal(); + // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; let detector_error_input = "This should return a 500 error on detector"; @@ -674,8 +676,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { contents: vec![detector_error_input.into()], detector_params: DetectorParams::new(), }); - then.internal_server_error() - .message(expected_detector_error.message); + then.internal_server_error().json(&expected_detector_error); }); // Add generation mock for generation internal server error scenario generation_mocks.mock(|when, then| { @@ -753,8 +754,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { // Assertions for generation internal server error scenario let results = response.json::().await?; - assert_eq!(results.code, StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(results.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response for detector internal server error scenario let response = 
orchestrator_server @@ -779,8 +779,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { // Assertions for detector internal server error scenario let results = response.json::().await?; - assert_eq!(results.code, StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(results.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response let response = orchestrator_server @@ -805,8 +804,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { // Assertions for chunker internal server error scenario let results = response.json::().await?; - assert_eq!(results.code, StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(results.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(results, orchestrator_error_500); Ok(()) } @@ -975,7 +973,7 @@ async fn output_detector_detections() -> Result<(), anyhow::Error> { ], detector_params: DetectorParams::new(), }); - then.json([vec![], vec![expected_detections[0].clone()]]); + then.json([vec![], vec![&expected_detections[0]]]); }); // Add output detection mock for output multiple detections @@ -992,8 +990,8 @@ async fn output_detector_detections() -> Result<(), anyhow::Error> { }); then.json([ vec![], - vec![expected_detections[0].clone()], - vec![expected_detections[1].clone()], + vec![&expected_detections[0]], + vec![&expected_detections[1]], ]); }); @@ -1124,6 +1122,8 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; + let orchestrator_error_500 = OrchestratorError::internal(); + // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; let detector_error_input = "This should return a 500 error on detector"; @@ -1141,7 +1141,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { contents: vec![generation_server_error_input.into()], detector_params: 
DetectorParams::new(), }); - then.json([vec![[Vec::::new()]]]); + then.json([[[Vec::::new()]]]); }); // Add output detection mock for detector internal server error scenario @@ -1152,8 +1152,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { contents: vec![detector_error_input.into()], detector_params: DetectorParams::new(), }); - then.internal_server_error() - .message(expected_detector_error.message); + then.internal_server_error().json(&expected_detector_error); }); // Add generation mock for generation internal server error scenario @@ -1258,8 +1257,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { // Assertions for generation internal server error scenario let results = response.json::().await?; - assert_eq!(results.code, StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(results.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response for detector internal server error scenario let response = orchestrator_server @@ -1283,8 +1281,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { // Assertions for detector internal server error scenario let results = response.json::().await?; - assert_eq!(results.code, StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(results.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response let response = orchestrator_server @@ -1308,8 +1305,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { // Assertions for chunker internal server error scenario let results = response.json::().await?; - assert_eq!(results.code, StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!(results.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(results, orchestrator_error_500); Ok(()) } @@ -1370,13 +1366,7 @@ async fn orchestrator_validation_error() -> Result<(), 
anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - ANSWER_RELEVANCE_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), "failed on input detector with invalid type scenario" ); @@ -1402,10 +1392,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed on non-existing input detector scenario" ); @@ -1433,13 +1420,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - ANSWER_RELEVANCE_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), "failed on output detector with invalid type scenario" ); @@ -1464,10 +1445,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed on non-existing output detector scenario" ); diff --git a/tests/common/chat_completions.rs b/tests/common/chat_completions.rs deleted file mode 100644 index 7ac67d0e..00000000 --- a/tests/common/chat_completions.rs +++ /dev/null @@ -1,19 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -// Chat completions server endpoint -pub const CHAT_COMPLETIONS_ENDPOINT: &str = "/v1/chat/completions"; diff --git a/tests/common/chat_completion.rs b/tests/common/chat_generation.rs similarity index 100% rename from tests/common/chat_completion.rs rename to tests/common/chat_generation.rs diff --git a/tests/common/errors.rs b/tests/common/errors.rs index e40c1f08..3c5e9844 100644 --- a/tests/common/errors.rs +++ b/tests/common/errors.rs @@ -29,3 +29,51 @@ pub struct OrchestratorError { pub code: u16, pub details: String, } + +impl OrchestratorError { + /// Helper function that generates an orchestrator internal + /// server error. + pub fn internal() -> OrchestratorError { + OrchestratorError { + code: 500, + details: "unexpected error occurred while processing request".into(), + } + } + /// Helper function that generates an orchestrator non-existing detector error. + pub fn detector_not_found(detector_name: &str) -> Self { + Self { + code: 404, + details: format!("detector `{}` not found", detector_name), + } + } + + /// Helper function that generates an orchestrator invalid detector error. + pub fn detector_not_supported(detector_name: &str) -> Self { + Self { + code: 422, + details: format!( + "detector `{}` is not supported by this endpoint", + detector_name + ), + } + } + + /// Helper function that generates an orchestrator required field error. + pub fn required(field_name: &str) -> Self { + Self { + code: 422, + details: format!("`{}` is required", field_name), + } + } + + /// Helper function that generates an orchestrator invalid chunker error. 
+ pub fn chunker_not_supported(detector_name: &str) -> Self { + Self { + code: 422, + details: format!( + "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint", + detector_name + ), + } + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 59720ca5..261f942c 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -14,7 +14,7 @@ limitations under the License. */ -pub mod chat_completions; +pub mod chat_generation; pub mod chunker; pub mod detectors; pub mod errors; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index 73b16e0a..f143ffcc 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -60,8 +60,6 @@ pub const ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT: &str = "/api/v2/chat/completions-detection"; // Messages -pub const ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE: &str = - "unexpected error occurred while processing request"; pub const ORCHESTRATOR_UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. 
Please check the detected entities on your input and try again with the unsuitable input removed."; pub fn ensure_global_rustls_state() { diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs index bc361084..d24eea4e 100644 --- a/tests/context_docs_detection.rs +++ b/tests/context_docs_detection.rs @@ -24,7 +24,7 @@ use common::{ errors::{DetectorError, OrchestratorError}, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT, - ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, TestOrchestratorServer, + TestOrchestratorServer, }, }; use fms_guardrails_orchestr8::{ @@ -67,7 +67,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { context_type: ContextType::Url, context: context.clone(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -129,7 +129,7 @@ async fn detections() -> Result<(), anyhow::Error> { context: context.clone(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -217,8 +217,7 @@ async fn client_error() -> Result<(), anyhow::Error> { // assertions assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); let response = response.json::().await?; - assert_eq!(response.code, detector_error.code); - assert_eq!(response.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(response, OrchestratorError::internal()); Ok(()) } @@ -404,13 +403,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - ANSWER_RELEVANCE_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), "failed on detector with invalid type scenario" ); @@ -432,10 +425,7 @@ async fn orchestrator_validation_error() -> Result<(), 
anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed on non-existing detector scenario" ); diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index 38a47762..b59765f0 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -25,7 +25,7 @@ use common::{ errors::{DetectorError, OrchestratorError}, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT, - ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, TestOrchestratorServer, + TestOrchestratorServer, }, }; use fms_guardrails_orchestr8::{ @@ -68,7 +68,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { generated_text: generated_text.into(), detector_params: DetectorParams::new(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -128,7 +128,7 @@ async fn detections() -> Result<(), anyhow::Error> { generated_text: generated_text.into(), detector_params: DetectorParams::new(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -212,8 +212,7 @@ async fn client_error() -> Result<(), anyhow::Error> { // assertions assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); let response = response.json::().await?; - assert_eq!(response.code, detector_error.code); - assert_eq!(response.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + assert_eq!(response, OrchestratorError::internal()); Ok(()) } @@ -320,10 +319,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 422, - details: "`detectors` is required".into() - }, + OrchestratorError::required("detectors"), "failed on empty 
`detectors` scenario" ); @@ -347,13 +343,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - FACT_CHECKING_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), "failed on invalid detector scenario" ); @@ -374,10 +364,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed on non-existing detector scenario" ); diff --git a/tests/generation_with_detection.rs b/tests/generation_with_detection.rs index 99ae194c..54d3acf2 100644 --- a/tests/generation_with_detection.rs +++ b/tests/generation_with_detection.rs @@ -25,7 +25,7 @@ use common::{ generation::{GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_UNARY_ENDPOINT}, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT, - ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, TestOrchestratorServer, + TestOrchestratorServer, }, }; use fms_guardrails_orchestr8::{ @@ -89,7 +89,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { generated_text: generated_text.into(), detector_params: DetectorParams::new(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -171,7 +171,7 @@ async fn detections() -> Result<(), anyhow::Error> { generated_text: generated_text.into(), detector_params: DetectorParams::new(), }); - then.json(vec![detection.clone()]); + then.json([&detection]); }); // Start orchestrator server and its dependencies @@ -221,6 +221,7 @@ async fn client_error() -> Result<(), anyhow::Error> { code: 500, message: "Here's your 500 
error".into(), }; + let orchestrator_error_500 = OrchestratorError::internal(); // Add generation mock let model_id = "my-super-model-8B"; @@ -286,10 +287,7 @@ async fn client_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!( response.json::().await?, - OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into() - } + orchestrator_error_500 ); // assert generation error @@ -308,10 +306,7 @@ async fn client_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!( response.json::().await?, - OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into() - } + orchestrator_error_500 ); Ok(()) @@ -421,13 +416,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); let response = response.json::().await?; debug!("{response:#?}"); - assert_eq!( - response, - OrchestratorError { - code: 422, - details: "`detectors` is required".into() - }, - ); + assert_eq!(response, OrchestratorError::required("detectors")); // assert request with invalid type detectors let response = orchestrator_server @@ -450,13 +439,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - FACT_CHECKING_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), "failed at invalid detector scenario" ); @@ -478,10 +461,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!( response, - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), 
"failed on non-existing detector scenario" ); diff --git a/tests/streaming_classification_with_gen.rs b/tests/streaming_classification_with_gen.rs index 8805a752..06eab391 100644 --- a/tests/streaming_classification_with_gen.rs +++ b/tests/streaming_classification_with_gen.rs @@ -33,9 +33,8 @@ use common::{ GENERATION_NLP_TOKENIZATION_ENDPOINT, }, orchestrator::{ - ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, - ORCHESTRATOR_STREAMING_ENDPOINT, ORCHESTRATOR_UNSUITABLE_INPUT_MESSAGE, SseStream, - TestOrchestratorServer, + ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_STREAMING_ENDPOINT, + ORCHESTRATOR_UNSUITABLE_INPUT_MESSAGE, SseStream, TestOrchestratorServer, }, }; use eventsource_stream::Eventsource; @@ -387,7 +386,7 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { ], detector_params: DetectorParams::new(), }); - then.json(vec![vec![], vec![mock_detection_response.clone()]]); + then.json([vec![], vec![&mock_detection_response]]); }); // Add generation mock for input token count @@ -486,6 +485,8 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { let detector_error_input = "Detector should return an error"; let generation_server_error_input = "Generation should return an error"; + let orchestrator_error_500 = OrchestratorError::internal(); + let mut chunker_mocks = MockSet::new(); chunker_mocks.mock(|when, then| { when.path(CHUNKER_UNARY_ENDPOINT) @@ -598,13 +599,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - assert_eq!( - messages[0], - OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into() - } - ); + assert_eq!(messages[0], orchestrator_error_500); // Test error from detector let response = orchestrator_server @@ -631,13 +626,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - assert_eq!( - 
messages[0], - OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into() - } - ); + assert_eq!(messages[0], orchestrator_error_500); // Test error from generation server let response = orchestrator_server @@ -664,13 +653,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - assert_eq!( - messages[0], - OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into() - } - ); + assert_eq!(messages[0], orchestrator_error_500); Ok(()) } @@ -740,13 +723,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - FACT_CHECKING_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), "failed at invalid input detector scenario" ); @@ -779,13 +756,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError { - code: 422, - details: format!( - "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint", - DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC - ) - }, + OrchestratorError::chunker_not_supported(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC), "failed on input detector with invalid chunker scenario" ); @@ -815,10 +786,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed at non-existing input detector scenario" ); @@ -850,13 +818,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(messages.len(), 1); 
assert_eq!( messages[0], - OrchestratorError { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - FACT_CHECKING_DETECTOR_SENTENCE - ) - }, + OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), "failed at invalid output detector scenario" ); @@ -888,13 +850,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError { - code: 422, - details: format!( - "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint", - DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC - ) - }, + OrchestratorError::chunker_not_supported(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC), "failed on output detector with invalid chunker scenario" ); @@ -923,10 +879,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError { - code: 404, - details: format!("detector `{}` not found", NON_EXISTING_DETECTOR) - }, + OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), "failed at non-existing output detector scenario" ); @@ -1041,31 +994,8 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { ]); }); - // Add output detection mock - // TODO: Simply clone mocks instead of create two exact MockSets when/if - // this gets merged: https://github.com/IBM/mocktail/pull/41 - let mut angle_brackets_mocks = MockSet::new(); - angle_brackets_mocks.mock(|when, then| { - when.post() - .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) - .json(ContentAnalysisRequest { - contents: vec!["I am great!".into()], - detector_params: DetectorParams::new(), - }); - then.json([Vec::::new()]); - }); - angle_brackets_mocks.mock(|when, then| { - when.post() - .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) - .json(ContentAnalysisRequest { - contents: vec![" What about you?".into()], - detector_params: DetectorParams::new(), - }); - then.json([Vec::::new()]); - 
}); - - let mut parenthesis_mocks = MockSet::new(); - parenthesis_mocks.mock(|when, then| { + let mut detection_mocks = MockSet::new(); + detection_mocks.mock(|when, then| { when.post() .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) .json(ContentAnalysisRequest { @@ -1074,7 +1004,7 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { }); then.json([Vec::::new()]); }); - parenthesis_mocks.mock(|when, then| { + detection_mocks.mock(|when, then| { when.post() .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) .json(ContentAnalysisRequest { @@ -1087,9 +1017,9 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { // Start orchestrator server and its dependencies let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); let mock_angle_brackets_detector_server = - MockServer::new(angle_brackets_detector).with_mocks(angle_brackets_mocks); + MockServer::new(angle_brackets_detector).with_mocks(detection_mocks.clone()); let mock_parenthesis_detector_server = - MockServer::new(parenthesis_detector).with_mocks(parenthesis_mocks); + MockServer::new(parenthesis_detector).with_mocks(detection_mocks); let generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); let orchestrator_server = TestOrchestratorServer::builder() @@ -1553,6 +1483,8 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { async fn output_detector_client_error() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let orchestrator_error_500 = OrchestratorError::internal(); + // Add generation mock let model_id = "my-super-model-8B"; let mut generation_mocks = MockSet::new(); @@ -1772,13 +1704,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - assert_eq!( - messages[0], - OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into() - } - ); + assert_eq!(messages[0], 
orchestrator_error_500); // assert detector error let response = orchestrator_server @@ -1830,13 +1756,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { assert_eq!(first_response.start_index, Some(0)); assert_eq!(first_response.processed_index, Some(11)); - assert_eq!( - second_response, - OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into() - } - ); + assert_eq!(second_response, orchestrator_error_500); Ok(()) } diff --git a/tests/streaming_content_detection.rs b/tests/streaming_content_detection.rs index 1304e195..6100d561 100644 --- a/tests/streaming_content_detection.rs +++ b/tests/streaming_content_detection.rs @@ -25,8 +25,8 @@ use common::{ }, errors::{DetectorError, OrchestratorError}, orchestrator::{ - ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, - ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT, TestOrchestratorServer, json_lines_stream, + ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT, + TestOrchestratorServer, json_lines_stream, }, }; use fms_guardrails_orchestr8::{ @@ -110,31 +110,9 @@ async fn no_detections() -> Result<(), anyhow::Error> { ]); }); - // Add input detection mock - // TODO: Simply clone mocks instead of create two exact MockSets when/if - // this gets merged: https://github.com/IBM/mocktail/pull/41 - let mut angle_brackets_detection_mocks = MockSet::new(); - angle_brackets_detection_mocks.mock(|when, then| { - when.post() - .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) - .json(ContentAnalysisRequest { - contents: vec!["Hi there!".into()], - detector_params: DetectorParams::new(), - }); - then.json([Vec::::new()]); - }); - angle_brackets_detection_mocks.mock(|when, then| { - when.post() - .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) - .json(ContentAnalysisRequest { - contents: vec![" How are you?".into()], - detector_params: DetectorParams::new(), - }); - then.json([Vec::::new()]); - }); - - let mut parenthesis_detection_mocks 
= MockSet::new(); - parenthesis_detection_mocks.mock(|when, then| { + // Add input detection mocks + let mut detection_mocks = MockSet::new(); + detection_mocks.mock(|when, then| { when.post() .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) .json(ContentAnalysisRequest { @@ -143,7 +121,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { }); then.json([Vec::::new()]); }); - parenthesis_detection_mocks.mock(|when, then| { + detection_mocks.mock(|when, then| { when.post() .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) .json(ContentAnalysisRequest { @@ -156,9 +134,9 @@ async fn no_detections() -> Result<(), anyhow::Error> { // Run test orchestrator server let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); let mock_angle_brackets_detector_server = - MockServer::new(angle_brackets_detector).with_mocks(angle_brackets_detection_mocks); + MockServer::new(angle_brackets_detector).with_mocks(detection_mocks.clone()); let mock_parenthesis_detector_server = - MockServer::new(parenthesis_detector).with_mocks(parenthesis_detection_mocks); + MockServer::new(parenthesis_detector).with_mocks(detection_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([ @@ -527,6 +505,8 @@ async fn client_error() -> Result<(), anyhow::Error> { let chunker_error_payload = "Chunker should return an error."; let detector_error_payload = "Detector should return an error."; + let orchestrator_error_500 = OrchestratorError::internal(); + let mut chunker_mocks = MockSet::new(); chunker_mocks.mock(|when, then| { when.path(CHUNKER_STREAMING_ENDPOINT) @@ -604,10 +584,7 @@ async fn client_error() -> Result<(), anyhow::Error> { debug!("recv: {msg:?}"); messages.push(serde_json::from_slice(&msg[..]).unwrap()); } - let expected_messages = [OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into(), - }]; + let expected_messages = [orchestrator_error_500]; 
assert_eq!(messages, expected_messages); // Assert detector error @@ -631,10 +608,6 @@ async fn client_error() -> Result<(), anyhow::Error> { debug!("recv: {msg:?}"); messages.push(serde_json::from_slice(&msg[..]).unwrap()); } - let expected_messages = [OrchestratorError { - code: 500, - details: ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE.into(), - }]; assert_eq!(messages, expected_messages); Ok(()) diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs index aca2f6cf..1f160398 100644 --- a/tests/text_content_detection.rs +++ b/tests/text_content_detection.rs @@ -26,7 +26,7 @@ use common::{ errors::{DetectorError, OrchestratorError}, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT, - ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE, TestOrchestratorServer, + TestOrchestratorServer, }, }; use fms_guardrails_orchestr8::{ @@ -92,7 +92,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { ], detector_params: DetectorParams::new(), }); - then.json(vec![ + then.json([ Vec::::new(), Vec::::new(), ]); @@ -106,7 +106,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { contents: vec!["This sentence has no detections.".into()], detector_params: DetectorParams::new(), }); - then.json(vec![Vec::::new()]); + then.json([Vec::::new()]); }); // Start orchestrator server and its dependencies @@ -238,7 +238,7 @@ async fn detections() -> Result<(), anyhow::Error> { contents: vec!["This sentence has .".into()], detector_params: DetectorParams::new(), }); - then.json(vec![vec![ContentAnalysisResponse { + then.json([[ContentAnalysisResponse { start: 18, end: 35, text: "a detection here".into(), @@ -387,8 +387,7 @@ async fn client_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); let response: OrchestratorError = response.json().await?; - assert_eq!(response.code, 500); - assert_eq!(response.details, ORCHESTRATOR_INTERNAL_SERVER_ERROR_MESSAGE); + 
assert_eq!(response, OrchestratorError::internal()); Ok(()) } From c65cd4c475cc825e35e10e8027d24a1661261849 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:59:06 -0700 Subject: [PATCH 14/24] Fix DetectionBatchStream (#381) Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- .../types/detection_batch_stream.rs | 181 +++++++++++++++--- src/orchestrator/types/detection_batcher.rs | 7 +- .../detection_batcher/chat_completion.rs | 15 +- .../detection_batcher/max_processed_index.rs | 5 + .../types/detection_batcher/noop.rs | 6 +- 5 files changed, 181 insertions(+), 33 deletions(-) diff --git a/src/orchestrator/types/detection_batch_stream.rs b/src/orchestrator/types/detection_batch_stream.rs index 80f599a7..a03da5dd 100644 --- a/src/orchestrator/types/detection_batch_stream.rs +++ b/src/orchestrator/types/detection_batch_stream.rs @@ -15,9 +15,10 @@ */ use futures::{Stream, StreamExt, stream}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; +use tracing::{debug, error}; -use super::{DetectionBatcher, DetectionStream}; +use super::{Chunk, DetectionBatcher, DetectionStream, Detections, DetectorId, InputId}; use crate::orchestrator::Error; /// A stream adapter that wraps multiple detection streams and @@ -34,52 +35,65 @@ impl DetectionBatchStream where B: DetectionBatcher, { - pub fn new(mut batcher: B, streams: Vec) -> Self { - // Create batch channel + pub fn new(batcher: B, streams: Vec) -> Self { let (batch_tx, batch_rx) = mpsc::channel(32); - - // Create a stream set (single stream) from multiple detection streams + // Create single stream from multiple detection streams let mut stream_set = stream::select_all(streams); - - // Spawn batcher task - // This task consumes new detections, pushes them to the batcher, - // and sends batches to the batch channel as they become ready. 
+ // Create batcher manager, an actor to manage the batcher instead of using locks + let batcher_manager = DetectionBatcherManagerHandle::new(batcher); + // Spawn task to receive detections and process batches tokio::spawn(async move { + let mut stream_completed = false; loop { tokio::select! { - result = stream_set.next() => { - match result { + // Disable random branch selection to poll the futures in order + biased; + + // Receive detections and push to batcher + msg = stream_set.next(), if !stream_completed => { + match msg { Some(Ok((input_id, detector_id, chunk, detections))) => { - // Push detections to batcher - batcher.push(input_id, detector_id, chunk, detections); - - // Check if the next batch is ready - if let Some(batch) = batcher.pop_batch() { - // Send batch to batch channel - let _ = batch_tx.send(Ok(batch)).await; - } + debug!(%input_id, ?chunk, ?detections, "pushing detections to batcher"); + batcher_manager + .push(input_id, detector_id, chunk, detections) + .await; }, Some(Err(error)) => { - // Send error to batch channel + error!(?error, "sending error to batch channel"); let _ = batch_tx.send(Err(error)).await; break; }, None => { - // Detection stream set closed - break; + debug!("detections stream has completed"); + stream_completed = true; }, } }, + // Pop batches and send them to batch channel + Some(batch) = batcher_manager.pop() => { + debug!(?batch, "sending batch to batch channel"); + let _ = batch_tx.send(Ok(batch)).await; + }, + // Terminate task when stream is completed and batcher state is empty + empty = batcher_manager.is_empty(), if stream_completed => { + if empty { + break; + } + } } } + debug!("detection batch stream task has completed"); }); Self { batch_rx } } } -impl Stream for DetectionBatchStream { - type Item = Result; +impl Stream for DetectionBatchStream +where + B: DetectionBatcher, +{ + type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, @@ -88,3 +102,120 @@ impl Stream for DetectionBatchStream { 
self.batch_rx.poll_recv(cx) } } + +enum DetectionBatcherMessage { + Push { + input_id: InputId, + detector_id: DetectorId, + chunk: Chunk, + detections: Detections, + }, + Pop { + response_tx: oneshot::Sender>, + }, + IsEmpty { + response_tx: oneshot::Sender, + }, +} + +/// An actor that manages a [`DetectionBatcher`]. +struct DetectionBatcherManager { + batcher: B, + rx: mpsc::Receiver>, +} + +impl DetectionBatcherManager +where + B: DetectionBatcher, +{ + pub fn new(batcher: B, rx: mpsc::Receiver>) -> Self { + Self { batcher, rx } + } + + async fn run(&mut self) { + while let Some(msg) = self.rx.recv().await { + match msg { + DetectionBatcherMessage::Push { + input_id, + detector_id, + chunk, + detections, + } => { + debug!(%input_id, %detector_id, ?chunk, ?detections, "handling push request"); + self.batcher.push(input_id, detector_id, chunk, detections) + } + DetectionBatcherMessage::Pop { response_tx } => { + debug!("handling pop request"); + let batch = self.batcher.pop_batch(); + debug!(?batch, "sending pop response"); + let _ = response_tx.send(batch); + } + DetectionBatcherMessage::IsEmpty { response_tx } => { + debug!("handling is_empty request"); + let empty = self.batcher.is_empty(); + debug!(%empty, "sending is_empty response"); + let _ = response_tx.send(empty); + } + } + } + } +} + +/// A handle to a [`DetectionBatcherManager`]. +#[derive(Clone)] +struct DetectionBatcherManagerHandle { + tx: mpsc::Sender>, +} + +impl DetectionBatcherManagerHandle +where + B: DetectionBatcher, + B::Batch: Clone, +{ + /// Creates a new [`DetectionBatcherManager`] and returns its handle. + pub fn new(batcher: B) -> Self { + let (tx, rx) = mpsc::channel(32); + let mut actor = DetectionBatcherManager::new(batcher, rx); + tokio::spawn(async move { actor.run().await }); + Self { tx } + } + + /// Pushes new detections to the batcher. 
+ pub async fn push( + &self, + input_id: InputId, + detector_id: DetectorId, + chunk: Chunk, + detections: Detections, + ) { + let _ = self + .tx + .send(DetectionBatcherMessage::Push { + input_id, + detector_id, + chunk, + detections, + }) + .await; + } + + /// Removes the next batch of detections from the batcher, if ready. + pub async fn pop(&self) -> Option { + let (response_tx, response_rx) = oneshot::channel(); + let _ = self + .tx + .send(DetectionBatcherMessage::Pop { response_tx }) + .await; + response_rx.await.unwrap_or_default() + } + + /// Returns `true` if the batcher state is empty. + pub async fn is_empty(&self) -> bool { + let (response_tx, response_rx) = oneshot::channel(); + let _ = self + .tx + .send(DetectionBatcherMessage::IsEmpty { response_tx }) + .await; + response_rx.await.unwrap() + } +} diff --git a/src/orchestrator/types/detection_batcher.rs b/src/orchestrator/types/detection_batcher.rs index d089f740..19ca98f3 100644 --- a/src/orchestrator/types/detection_batcher.rs +++ b/src/orchestrator/types/detection_batcher.rs @@ -25,8 +25,8 @@ use super::{Chunk, Detections, DetectorId, InputId}; /// A detection batcher. /// Implements pluggable batching logic for a [`DetectionBatchStream`]. -pub trait DetectionBatcher: Send + 'static { - type Batch: Send + 'static; +pub trait DetectionBatcher: std::fmt::Debug + Clone + Send + 'static { + type Batch: std::fmt::Debug + Clone + Send + 'static; /// Pushes new detections. fn push( @@ -39,4 +39,7 @@ pub trait DetectionBatcher: Send + 'static { /// Removes the next batch of detections, if ready. fn pop_batch(&mut self) -> Option; + + /// Returns `true` if the batcher state is empty. 
+ fn is_empty(&self) -> bool; } diff --git a/src/orchestrator/types/detection_batcher/chat_completion.rs b/src/orchestrator/types/detection_batcher/chat_completion.rs index d14585db..7c616abc 100644 --- a/src/orchestrator/types/detection_batcher/chat_completion.rs +++ b/src/orchestrator/types/detection_batcher/chat_completion.rs @@ -15,27 +15,28 @@ */ #![allow(dead_code)] +use std::collections::BTreeMap; + use super::{Chunk, DetectionBatcher, Detections, DetectorId, InputId}; -use crate::orchestrator::types::Chunks; /// A batcher for chat completions. +#[derive(Debug, Clone)] pub struct ChatCompletionBatcher { detectors: Vec, - // state: TBD + state: BTreeMap<(Chunk, u32), Vec>, } impl ChatCompletionBatcher { pub fn new(detectors: Vec) -> Self { - // let state = TBD::new(); Self { detectors, - // state, + state: BTreeMap::default(), } } } impl DetectionBatcher for ChatCompletionBatcher { - type Batch = (u32, Chunks, Detections); // placeholder, actual type TBD + type Batch = (u32, Chunk, Detections); fn push( &mut self, @@ -53,4 +54,8 @@ impl DetectionBatcher for ChatCompletionBatcher { // ref: https://github.com/foundation-model-stack/fms-guardrails-orchestrator/blob/main/docs/architecture/adrs/005-chat-completion-support.md#streaming-response todo!() } + + fn is_empty(&self) -> bool { + self.state.is_empty() + } } diff --git a/src/orchestrator/types/detection_batcher/max_processed_index.rs b/src/orchestrator/types/detection_batcher/max_processed_index.rs index c6187ee3..9e99fef7 100644 --- a/src/orchestrator/types/detection_batcher/max_processed_index.rs +++ b/src/orchestrator/types/detection_batcher/max_processed_index.rs @@ -33,6 +33,7 @@ use super::{Chunk, DetectionBatcher, Detections, DetectorId, InputId}; /// and so on. /// /// This batcher requires that all detectors use the same chunker. 
+#[derive(Debug, Clone)] pub struct MaxProcessedIndexBatcher { n_detectors: usize, state: BTreeMap>, @@ -84,6 +85,10 @@ impl DetectionBatcher for MaxProcessedIndexBatcher { } None } + + fn is_empty(&self) -> bool { + self.state.is_empty() + } } #[cfg(test)] diff --git a/src/orchestrator/types/detection_batcher/noop.rs b/src/orchestrator/types/detection_batcher/noop.rs index e626c766..14048252 100644 --- a/src/orchestrator/types/detection_batcher/noop.rs +++ b/src/orchestrator/types/detection_batcher/noop.rs @@ -19,7 +19,7 @@ use std::collections::VecDeque; use super::{Chunk, DetectionBatcher, Detections, DetectorId, InputId}; /// A no-op batcher that doesn't actually batch. -#[derive(Default)] +#[derive(Default, Debug, Clone)] pub struct NoopBatcher { state: VecDeque<(Chunk, Detections)>, } @@ -46,4 +46,8 @@ impl DetectionBatcher for NoopBatcher { fn pop_batch(&mut self) -> Option { self.state.pop_front() } + + fn is_empty(&self) -> bool { + self.state.is_empty() + } } From 9f66e5d4ca5fdd6b8c9d2e3e12c74d677602be2a Mon Sep 17 00:00:00 2001 From: Sanketha-Cr Date: Fri, 25 Apr 2025 22:08:54 +0530 Subject: [PATCH 15/24] To support s390x (#369) * To support s390x Signed-off-by: Sanketha CR Sanketha.cr@ibm.com Signed-off-by: Sanketha * Update Dockerfile Updated Docker file removed elif block and updated with else condition Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Sanketha-Cr --------- Signed-off-by: Sanketha Signed-off-by: Sanketha-Cr Co-authored-by: root Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> --- Dockerfile | 15 ++++++++++++--- rust-toolchain.toml | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index a81fdc32..e51c5177 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,15 +5,24 @@ ARG CONFIG_FILE=config/config.yaml ## Rust builder ################################################################ # Specific debian version so that compatible glibc version 
is used -FROM rust:1.86.0-bullseye AS rust-builder +FROM rust:1.86.0 AS rust-builder ARG PROTOC_VERSION ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse # Install protoc, no longer included in prost crate RUN cd /tmp && \ - curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip && \ - unzip protoc-*.zip -d /usr/local && rm protoc-*.zip + if [ "$(uname -m)" = "s390x" ]; then \ + apt update && \ + apt install -y cmake clang libclang-dev curl unzip && \ + curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-s390_64.zip; \ + else \ + curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip; \ + fi && \ + unzip protoc-*.zip -d /usr/local && \ + rm protoc-*.zip + +ENV LIBCLANG_PATH=/usr/lib/llvm-14/lib/ WORKDIR /app diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 95b8a269..d0929449 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] channel = "1.86.0" -components = ["rustfmt", "clippy"] \ No newline at end of file +components = ["rustfmt", "clippy"] From 063b8d6f6adf89ef7d68528789c09cea3efde8b0 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Fri, 25 Apr 2025 12:59:29 -0600 Subject: [PATCH 16/24] :sparkles: Chat completions batcher (#380) * :sparkles::white_check_mark: Initial chat completions batcher Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :recycle::white_check_mark: Chat completions batcher with out-of-order chunks Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :white_check_mark: Different choice chunks test Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :test_tube: Detection batch test Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * 
:test_tube: Switch chunk ordering Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :test_tube: Reverse chunk order to non-edge case Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :bulb: Add unit tests comments on edge case Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :art::bulb: Clean up debug Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :twisted_rightwards_arrows: Merge with main Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * Update src/orchestrator/types/detection_batcher/chat_completion.rs Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :label: Add ChoiceIndex type Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :recycle: Use ChoiceIndex Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :white_check_mark: Update tests with ChoiceIndex Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Co-authored-by: Dan Clark <44146800+declark1@users.noreply.github.com> --- .../detection_batcher/chat_completion.rs | 575 +++++++++++++++++- .../detection_batcher/max_processed_index.rs | 5 + 2 files changed, 564 insertions(+), 16 deletions(-) diff --git a/src/orchestrator/types/detection_batcher/chat_completion.rs b/src/orchestrator/types/detection_batcher/chat_completion.rs index 7c616abc..dae87e43 100644 --- a/src/orchestrator/types/detection_batcher/chat_completion.rs +++ b/src/orchestrator/types/detection_batcher/chat_completion.rs @@ -14,48 +14,591 @@ limitations under the License. 
*/ -#![allow(dead_code)] -use std::collections::BTreeMap; +use std::collections::{BTreeMap, btree_map}; -use super::{Chunk, DetectionBatcher, Detections, DetectorId, InputId}; +use super::{Chunk, DetectionBatcher, Detections, DetectorId}; + +pub type ChoiceIndex = u32; /// A batcher for chat completions. +/// +/// A batch corresponds to a choice-chunk (where each chunk is associated +/// with a particular choice through a ChoiceIndex). Batches are returned +/// in-order as detections from all detectors are received for the choice-chunk. +/// +/// Chat completion messages have a `choices` field containing +/// a single choice, e.g. +/// ```text +/// data: {"id":"chat-", ..., "choices":[{"index":0, ...}]} +/// data: {"id":"chat-", ..., "choices":[{"index":1, ...}]} +/// ``` +/// And we track chunks for each choice independently. +/// +/// This batcher requires that all detectors use the same chunker. #[derive(Debug, Clone)] pub struct ChatCompletionBatcher { - detectors: Vec, - state: BTreeMap<(Chunk, u32), Vec>, + n_detectors: usize, + // We place the chunk first since chunk ordering includes where + // the chunk is in all the processed messages. 
+ state: BTreeMap<(Chunk, ChoiceIndex), Vec>, } impl ChatCompletionBatcher { - pub fn new(detectors: Vec) -> Self { + pub fn new(n_detectors: usize) -> Self { Self { - detectors, + n_detectors, state: BTreeMap::default(), } } } impl DetectionBatcher for ChatCompletionBatcher { - type Batch = (u32, Chunk, Detections); + type Batch = (Chunk, ChoiceIndex, Detections); fn push( &mut self, - _input_id: InputId, + choice_index: ChoiceIndex, _detector_id: DetectorId, - _chunk: Chunk, - _detections: Detections, + chunk: Chunk, + detections: Detections, ) { - // NOTE: input_id maps to choice_index - todo!() + match self.state.entry((chunk, choice_index)) { + btree_map::Entry::Vacant(entry) => { + // New chunk, insert entry + entry.insert(vec![detections]); + } + btree_map::Entry::Occupied(mut entry) => { + // Existing chunk, push detections + entry.get_mut().push(detections); + } + } } fn pop_batch(&mut self) -> Option { - // TODO: implement batching logic to align with requirements - // ref: https://github.com/foundation-model-stack/fms-guardrails-orchestrator/blob/main/docs/architecture/adrs/005-chat-completion-support.md#streaming-response - todo!() + // Batching logic here will only assume detections with the same chunker type + // Requirements in https://github.com/foundation-model-stack/fms-guardrails-orchestrator/blob/main/docs/architecture/adrs/005-chat-completion-support.md#streaming-response + // for detections on whole output will be handled outside of the batcher + + // Check if we have all detections for the next chunk + if self + .state + .first_key_value() + .is_some_and(|(_, detections)| detections.len() == self.n_detectors) + { + // We have all detections for the chunk, remove and return it. 
+ if let Some(((chunk, choice_index), detections)) = self.state.pop_first() { + let detections = detections.into_iter().flatten().collect(); + return Some((chunk, choice_index, detections)); + } + } + None } fn is_empty(&self) -> bool { self.state.is_empty() } } + +#[cfg(test)] +mod test { + use std::task::Poll; + + use futures::StreamExt; + use tokio::sync::mpsc; + use tokio_stream::wrappers::ReceiverStream; + + use super::*; + use crate::orchestrator::{ + Error, + types::{Detection, DetectionBatchStream}, + }; + + #[test] + fn test_batcher_with_single_chunk() { + let choice_index = 0; + let chunk = Chunk { + input_start_index: 0, + input_end_index: 0, + start: 0, + end: 24, + text: "this is a dummy sentence".into(), + }; + + // Create a batcher that will process batches for 2 detectors + let n_detectors = 2; + let mut batcher = ChatCompletionBatcher::new(n_detectors); + + // Push chunk detections for pii detector + batcher.push( + choice_index, + "pii".into(), + chunk.clone(), + vec![Detection { + start: Some(5), + end: Some(10), + detector_id: Some("pii".into()), + detection_type: "pii".into(), + score: 0.4, + ..Default::default() + }] + .into(), + ); + + // We only have detections for 1 detector + // pop_batch() should return None + assert!(batcher.pop_batch().is_none()); + + // Push chunk detections for hap detector + batcher.push( + choice_index, + "hap".into(), + chunk.clone(), + vec![ + Detection { + start: Some(5), + end: Some(10), + detector_id: Some("hap".into()), + detection_type: "hap".into(), + score: 0.8, + ..Default::default() + }, + Detection { + start: Some(15), + end: Some(20), + detector_id: Some("hap".into()), + detection_type: "hap".into(), + score: 0.8, + ..Default::default() + }, + ] + .into(), + ); + + // We have detections for 2 detectors + // pop_batch() should return a batch containing 3 detections for the chunk + let batch = batcher.pop_batch(); + assert!( + batch.is_some_and(|(actual_chunk, actual_choice_index, detections)| { + 
actual_chunk == chunk + && actual_choice_index == choice_index + && detections.len() == 3 + }) + ); + } + + #[test] + fn test_batcher_with_out_of_order_chunks_same_per_choice() { + let choices = 2; + // Chunks here will be apply to both choices + let chunks = [ + Chunk { + input_start_index: 0, + input_end_index: 10, + start: 0, + end: 56, + text: " a powerful tool for the development \ + of complex systems." + .into(), + }, + Chunk { + input_start_index: 11, + input_end_index: 26, + start: 56, + end: 135, + text: " It has been used in many fields, such as \ + computer vision and image processing." + .into(), + }, + ]; + + // Create a batcher that will process batches for 2 detectors + let n_detectors = 2; + let mut batcher = ChatCompletionBatcher::new(n_detectors); + + for choice_index in 0..choices { + // Push chunk-2 detections for pii detector + batcher.push( + choice_index, + "pii".into(), + chunks[1].clone(), + Detections::default(), // no detections + ); + // Push chunk-1 detections for hap detector + batcher.push( + choice_index, + "hap".into(), + chunks[0].clone(), + Detections::default(), // no detections + ); + // Push chunk-2 detections for hap detector + batcher.push( + choice_index, + "hap".into(), + chunks[1].clone(), + Detections::default(), // no detections + ); + } + + // We have all detections for chunk-2, but not chunk-1 + // pop_batch() should return None + assert!(batcher.pop_batch().is_none()); + + // Push chunk-1 detections for pii detector + for choice_index in 0..choices { + batcher.push( + choice_index, + "pii".into(), + chunks[0].clone(), + vec![Detection { + start: Some(10), + end: Some(20), + detector_id: Some("pii".into()), + detection_type: "pii".into(), + score: 0.4, + ..Default::default() + }] + .into(), + ); + } + + // We have all detections for chunk-1 and chunk-2 + // pop_batch() should return chunk-1 with 1 pii detection, for the first choice + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, 
detections)| { + chunk == chunks[0] && choice_index == 0 && detections.len() == 1 + })); + + // Return the same chunk-1 with 1 pii detection for the second choice + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + chunk == chunks[0] && choice_index == 1 && detections.len() == 1 + })); + + // pop_batch() should return chunk-2 with no detections, for the first choice + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + chunk == chunks[1] && choice_index == 0 && detections.is_empty() + })); + + // Return the same chunk-2 with no detections for the second choice + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + chunk == chunks[1] && choice_index == 1 && detections.is_empty() + })); + + // batcher state should be empty as all batches have been returned + assert!(batcher.state.is_empty()); + } + + #[test] + fn test_batcher_with_out_of_order_chunks_different_per_choice() { + // Chunks here will be apply to the first choice + let choice_1_index = 0; + let choice_1_chunks = [ + Chunk { + input_start_index: 0, + input_end_index: 10, + start: 0, + end: 46, + text: " a tool for the development \ + of simple systems." + .into(), + }, + Chunk { + input_start_index: 11, + input_end_index: 26, + start: 46, + end: 125, + text: " It has been used in many fields, such as \ + computer vision and audio processing." + .into(), + }, + ]; + + // Chunks here will apply to the second choice + let choice_2_index = 1; + let choice_2_chunks = [ + Chunk { + input_start_index: 0, + input_end_index: 10, + start: 0, + end: 56, + text: " a powerful tool for the development \ + of complex systems." + .into(), + }, + Chunk { + input_start_index: 11, + input_end_index: 26, + start: 56, + end: 135, + text: " It has been used in many fields, such as \ + computer vision and image processing." 
+ .into(), + }, + ]; + + // Create a batcher that will process batches for 2 detectors + let n_detectors = 2; + let mut batcher = ChatCompletionBatcher::new(n_detectors); + + // Intersperse choice detections + // NOTE: There may be an edge case when chunk-2 (or later) detections are pushed + // for all detectors here before their respective earlier detections (e.g. chunk-1 here). + // At this batcher level, ordering will be expected but may present as an edge case + // at the stream level ref. + // https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/377 + + // Push chunk-2 detections for pii detector, choice 1 + batcher.push( + choice_1_index, + "pii".into(), + choice_1_chunks[1].clone(), + Detections::default(), // no detections + ); + // Same for choice 2 + batcher.push( + choice_2_index, + "pii".into(), + choice_2_chunks[1].clone(), + Detections::default(), // no detections + ); + // Push chunk-2 detections for hap detector, choice 2 + batcher.push( + choice_2_index, + "hap".into(), + choice_2_chunks[1].clone(), + Detections::default(), // no detections + ); + // Same for choice 1 + batcher.push( + choice_1_index, + "hap".into(), + choice_1_chunks[1].clone(), + Detections::default(), // no detections + ); + // Push chunk-1 detections for hap detector, choice 1 + batcher.push( + choice_1_index, + "hap".into(), + choice_1_chunks[0].clone(), + Detections::default(), // no detections + ); + // Same for choice 2 + batcher.push( + choice_2_index, + "hap".into(), + choice_2_chunks[0].clone(), + Detections::default(), // no detections + ); + + // We have all detections for chunk-2, but not chunk-1, for both choices + // pop_batch() should return None + assert!(batcher.pop_batch().is_none()); + + // Push chunk-1 detections for pii detector, for first choice + batcher.push( + choice_1_index, + "pii".into(), + choice_1_chunks[0].clone(), + vec![Detection { + start: Some(10), + end: Some(20), + detector_id: Some("pii".into()), + detection_type: 
"pii".into(), + score: 0.4, + ..Default::default() + }] + .into(), + ); + // Push chunk-1 detections for pii detector, for second choice + batcher.push( + choice_2_index, + "pii".into(), + choice_2_chunks[0].clone(), + vec![Detection { + start: Some(10), + end: Some(20), + detector_id: Some("pii".into()), + detection_type: "pii".into(), + score: 0.4, + ..Default::default() + }] + .into(), + ); + + // We have all detections for chunk-1 and chunk-2 + // Expect 4 chunks, with those for the chunk-1 chunks first + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + chunk == choice_1_chunks[0] && choice_index == choice_1_index && detections.len() == 1 + })); + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + chunk == choice_2_chunks[0] && choice_index == choice_2_index && detections.len() == 1 + })); + + // chunk-2 chunks + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + chunk == choice_1_chunks[1] && choice_index == choice_1_index && detections.is_empty() + })); + let batch = batcher.pop_batch(); + assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + chunk == choice_2_chunks[1] && choice_index == choice_2_index && detections.is_empty() + })); + + // batcher state should be empty as all batches (4 chunks) have been returned + assert!(batcher.state.is_empty()); + } + + #[tokio::test] + async fn test_detection_batch_stream_chat() -> Result<(), Error> { + let choices = 2; + // Chunks here will be apply to both choices + let chunks = [ + Chunk { + input_start_index: 0, + input_end_index: 10, + start: 0, + end: 56, + text: " a powerful tool for the development \ + of complex systems." + .into(), + }, + Chunk { + input_start_index: 11, + input_end_index: 26, + start: 56, + end: 135, + text: " It has been used in many fields, such as \ + computer vision and image processing." 
+ .into(), + }, + ]; + + // Create detection channels and streams + let (pii_detections_tx, pii_detections_rx) = + mpsc::channel::>(4); + let pii_detections_stream = ReceiverStream::new(pii_detections_rx).boxed(); + let (hap_detections_tx, hap_detections_rx) = + mpsc::channel::>(4); + let hap_detections_stream = ReceiverStream::new(hap_detections_rx).boxed(); + + // Create a batcher that will process batches for 2 detectors + let n_detectors = 2; + let batcher = ChatCompletionBatcher::new(n_detectors); + + // Create detection batch stream + let streams = vec![pii_detections_stream, hap_detections_stream]; + let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams); + + for choice_index in 0..choices { + // Send chunk-2 detections for pii detector + let _ = pii_detections_tx + .send(Ok(( + choice_index, + "pii".into(), + chunks[1].clone(), + Detections::default(), // no detections + ))) + .await; + + // Send chunk-1 detections for hap detector + let _ = hap_detections_tx + .send(Ok(( + choice_index, + "hap".into(), + chunks[0].clone(), + Detections::default(), // no detections + ))) + .await; + + // Send chunk-2 detections for hap detector + let _ = hap_detections_tx + .send(Ok(( + choice_index, + "hap".into(), + chunks[1].clone(), + Detections::default(), // no detections + ))) + .await; + } + + // We have all detections for chunk-2, but not chunk-1 + // detection_batch_stream.next() future should not be ready + assert!(matches!( + futures::poll!(detection_batch_stream.next()), + Poll::Pending + )); + + // Send chunk-1 detections for pii detector + for choice_index in 0..choices { + let _ = pii_detections_tx + .send(Ok(( + choice_index, + "pii".into(), + chunks[0].clone(), + vec![Detection { + start: Some(10), + end: Some(20), + detector_id: Some("pii".into()), + detection_type: "pii".into(), + score: 0.4, + ..Default::default() + }] + .into(), + ))) + .await; + } + + // We have all detections for chunk-1 and chunk-2 + // 
detection_batch_stream.next() should be ready and return chunk-1 with 1 pii detection, for choice 1 + let batch = detection_batch_stream.next().await; + assert!(batch.is_some_and(|result| { + result.is_ok_and(|(chunk, choice_index, detections)| { + chunk == chunks[0] && choice_index == 0 && detections.len() == 1 + }) + })); + + // Then choice 2 + let batch = detection_batch_stream.next().await; + assert!(batch.is_some_and(|result| { + result.is_ok_and(|(chunk, choice_index, detections)| { + chunk == chunks[0] && choice_index == 1 && detections.len() == 1 + }) + })); + + // detection_batch_stream.next() should be ready and return chunk-2 with no detections, for choice 1 + let batch = detection_batch_stream.next().await; + assert!(batch.is_some_and(|result| { + result.is_ok_and(|(chunk, choice_index, detections)| { + chunk == chunks[1] && choice_index == 0 && detections.is_empty() + }) + })); + + // Then choice 2 + let batch = detection_batch_stream.next().await; + assert!(batch.is_some_and(|result| { + result.is_ok_and(|(chunk, choice_index, detections)| { + chunk == chunks[1] && choice_index == 1 && detections.is_empty() + }) + })); + + // detection_batch_stream.next() future should not be ready + // as detection senders have not been closed + assert!(matches!( + futures::poll!(detection_batch_stream.next()), + Poll::Pending + )); + + // Drop detection senders + drop(pii_detections_tx); + drop(hap_detections_tx); + + // detection_batch_stream.next() should return None + assert!(detection_batch_stream.next().await.is_none()); + + Ok(()) + } +} diff --git a/src/orchestrator/types/detection_batcher/max_processed_index.rs b/src/orchestrator/types/detection_batcher/max_processed_index.rs index 9e99fef7..60c4edb8 100644 --- a/src/orchestrator/types/detection_batcher/max_processed_index.rs +++ b/src/orchestrator/types/detection_batcher/max_processed_index.rs @@ -202,6 +202,11 @@ mod test { let n = 2; let mut batcher = MaxProcessedIndexBatcher::new(n); + // NOTE: Both 
chunk-2 detections are pushed for detectors here before their + // respective chunk-1 detections. At this batcher level, ordering will be + // expected but may present as an edge case at the stream level ref. + // https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/377 + // Push chunk-2 detections for pii detector batcher.push( input_id, From d473b00aad31b1699864471d8d491094b046a59a Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Mon, 28 Apr 2025 12:28:23 -0600 Subject: [PATCH 17/24] :goal_net: Handle unsupported media type errors (#386) * :goal_net: Handle unsupported media type errors Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :recycle: Use mutable headers Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :recycle: Use HeaderValue Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- src/clients/detector.rs | 7 +++---- src/clients/http.rs | 3 +++ src/orchestrator/common/client.rs | 9 ++++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 4b8cc103..05dfd103 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -25,7 +25,7 @@ use url::Url; use super::{ Error, - http::{HttpClientExt, RequestBody, ResponseBody}, + http::{HttpClientExt, JSON_CONTENT_TYPE, RequestBody, ResponseBody}, }; pub mod text_contents; @@ -82,12 +82,11 @@ impl DetectorClientExt for C { &self, model_id: &str, url: Url, - headers: HeaderMap, + mut headers: HeaderMap, request: impl RequestBody, ) -> Result { - let mut headers = headers; headers.append(DETECTOR_ID_HEADER_NAME, model_id.parse().unwrap()); - headers.append(CONTENT_TYPE, "application/json".parse().unwrap()); + headers.append(CONTENT_TYPE, JSON_CONTENT_TYPE); // Header used by a router component, if available 
headers.append(MODEL_HEADER_NAME, model_id.parse().unwrap()); diff --git a/src/clients/http.rs b/src/clients/http.rs index e23d167a..485d8829 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -17,6 +17,7 @@ use std::{fmt::Debug, ops::Deref, time::Duration}; +use http::header::HeaderValue; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use hyper::{ HeaderMap, Method, Request, StatusCode, @@ -46,6 +47,8 @@ use crate::{ utils::{AsUriExt, trace}, }; +pub const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json"); + /// Any type that implements Debug and Serialize can be used as a request body pub trait RequestBody: Debug + Serialize {} diff --git a/src/orchestrator/common/client.rs b/src/orchestrator/common/client.rs index 653e2a78..27375103 100644 --- a/src/orchestrator/common/client.rs +++ b/src/orchestrator/common/client.rs @@ -16,7 +16,7 @@ */ //! Client helpers use futures::{StreamExt, TryStreamExt}; -use http::HeaderMap; +use http::{HeaderMap, header::CONTENT_TYPE}; use tokio::sync::broadcast; use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; use tracing::{debug, instrument}; @@ -30,6 +30,7 @@ use crate::{ GenerationDetectionRequest, TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, }, + http::JSON_CONTENT_TYPE, openai::{self, ChatCompletionsResponse, OpenAiClient}, }, models::{ @@ -246,11 +247,12 @@ pub async fn detect_text_context( #[instrument(skip_all, fields(model_id))] pub async fn chat_completion( client: &OpenAiClient, - headers: HeaderMap, + mut headers: HeaderMap, request: openai::ChatCompletionsRequest, ) -> Result { let model_id = request.model.clone(); debug!(%model_id, ?request, "sending chat completions request"); + headers.append(CONTENT_TYPE, JSON_CONTENT_TYPE); let response = client .chat_completions(request, headers) .await @@ -266,11 +268,12 @@ pub async fn chat_completion( #[instrument(skip_all, fields(model_id))] pub async fn 
chat_completion_stream( client: &OpenAiClient, - headers: HeaderMap, + mut headers: HeaderMap, request: openai::ChatCompletionsRequest, ) -> Result { let model_id = request.model.clone(); debug!(%model_id, ?request, "sending chat completions stream request"); + headers.append(CONTENT_TYPE, JSON_CONTENT_TYPE); let response = client .chat_completions(request, headers) .await From e3985d56b00ec9f680d51905322121c7f16ea979 Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Tue, 29 Apr 2025 13:07:27 -0600 Subject: [PATCH 18/24] :bug: Allow input detection on whole input for streaming text generation endpoint (#388) * :bug::wrench: Allow input detection on text generation Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * :white_check_mark: Valid whole doc input detection test Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .../streaming_classification_with_gen.rs | 13 +- tests/streaming_classification_with_gen.rs | 120 +++++++++++++----- 2 files changed, 96 insertions(+), 37 deletions(-) diff --git a/src/orchestrator/handlers/streaming_classification_with_gen.rs b/src/orchestrator/handlers/streaming_classification_with_gen.rs index 82427aff..f43f3039 100644 --- a/src/orchestrator/handlers/streaming_classification_with_gen.rs +++ b/src/orchestrator/handlers/streaming_classification_with_gen.rs @@ -74,18 +74,25 @@ impl Handle for Orchestrator { let input_detectors = task.guardrails_config.input_detectors(); let output_detectors = task.guardrails_config.output_detectors(); - // input detectors validation + // Input detectors validation + // Allow `whole_doc_chunker` detectors on input detection + // because the input detection call is unary if let Err(error) = validate_detectors( &input_detectors, &ctx.config.detectors, &[DetectorType::TextContents], - false, + true, ) { let _ = 
response_tx.send(Err(error)).await; return; } - // output detectors validation + // Output detectors validation + // Disallow `whole_doc_chunker` detectors on output detection + // for now until results of these detectors are handled as + // planned for chat completions, with detection results + // provided separately at the end but not blocking other + // detection results that may be provided on smaller chunks if let Err(error) = validate_detectors( &output_detectors, &ctx.config.detectors, diff --git a/tests/streaming_classification_with_gen.rs b/tests/streaming_classification_with_gen.rs index 06eab391..b1aee7c0 100644 --- a/tests/streaming_classification_with_gen.rs +++ b/tests/streaming_classification_with_gen.rs @@ -404,6 +404,32 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { then.pb(mock_tokenization_response.clone()); }); + // Detector on whole doc / entire input for multi-detector scenario + let whole_doc_mock_detection_response = ContentAnalysisResponse { + start: 0, + end: 61, + text: "This sentence does not have a detection. But .".into(), + detection: "has_angle_brackets_1".into(), + detection_type: "angle_brackets_1".into(), + detector_id: Some(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }; + let mut whole_doc_detection_mocks = MockSet::new(); + whole_doc_detection_mocks.mock(|when, then| { + when.path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![ + "This sentence does not have a detection. 
But .".into(), + ], + detector_params: DetectorParams::new(), + }); + then.json([vec![&whole_doc_mock_detection_response]]); + }); + let mock_whole_doc_detector_server = MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC) + .with_mocks(whole_doc_detection_mocks); + // Start orchestrator server and its dependencies let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); @@ -411,7 +437,7 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&generation_server) - .detector_servers([&mock_detector_server]) + .detector_servers([&mock_detector_server, &mock_whole_doc_detector_server]) .chunker_servers([&mock_chunker_server]) .build() .await?; @@ -471,6 +497,65 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { }]) ); + // Multi-detector scenario with detector that uses content from entire input + let response = orchestrator_server + .post(ORCHESTRATOR_STREAMING_ENDPOINT) + .json(&GuardrailsHttpRequest { + model_id: model_id.into(), + inputs: "This sentence does not have a detection. 
But .".into(), + guardrail_config: Some(GuardrailsConfig { + input: Some(GuardrailsConfigInput { + models: HashMap::from([ + (detector_name.into(), DetectorParams::new()), + ( + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into(), + DetectorParams::new(), + ), + ]), + masks: None, + }), + output: None, + }), + text_gen_parameters: None, + }) + .send() + .await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + assert_eq!(messages.len(), 1); + assert!(messages[0].generated_text.is_none()); + assert_eq!( + messages[0].token_classification_results, + TextGenTokenClassificationResults { + input: Some(vec![ + TokenClassificationResult { + start: 0, + end: 61, + word: whole_doc_mock_detection_response.text, + entity: whole_doc_mock_detection_response.detection, + entity_group: whole_doc_mock_detection_response.detection_type, + detector_id: whole_doc_mock_detection_response.detector_id, + score: whole_doc_mock_detection_response.score, + token_count: None + }, + TokenClassificationResult { + start: 46, // index of first token of detected text, relative to the `inputs` string sent in the orchestrator request. + end: 59, // index of last token (+1) of detected text, relative to the `inputs` string sent in the orchestrator request. 
+ word: "this one does".into(), + entity: "has_angle_brackets".into(), + entity_group: "angle_brackets".into(), + detector_id: Some(detector_name.to_string()), + score: mock_detection_response.score, + token_count: None + } + ]), + output: None + } + ); + Ok(()) } @@ -727,39 +812,6 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { "failed at invalid input detector scenario" ); - // Invalid chunker on input detector scenario - let response = orchestrator_server - .post(ORCHESTRATOR_STREAMING_ENDPOINT) - .json(&GuardrailsHttpRequest { - model_id: model_id.into(), - inputs: "This request contains a detector with an invalid chunker".into(), - guardrail_config: Some(GuardrailsConfig { - input: Some(GuardrailsConfigInput { - models: HashMap::from([( - DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into(), - DetectorParams::new(), - )]), - masks: None, - }), - output: None, - }), - text_gen_parameters: None, - }) - .send() - .await?; - debug!("{response:#?}"); - - assert_eq!(response.status(), 200); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; - debug!("{messages:#?}"); - assert_eq!(messages.len(), 1); - assert_eq!( - messages[0], - OrchestratorError::chunker_not_supported(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC), - "failed on input detector with invalid chunker scenario" - ); - // Non-existing input detector scenario let response = orchestrator_server .post(ORCHESTRATOR_STREAMING_ENDPOINT) From 9c7fd4816b0fd493d7f5d06184c98d45500e4085 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Thu, 1 May 2025 10:37:25 -0700 Subject: [PATCH 19/24] refactor: server cleanups (#387) Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- src/main.rs | 11 +- src/server.rs | 800 +++++----------------------------- src/server/errors.rs | 132 ++++++ src/server/routes.rs | 368 ++++++++++++++++ src/server/tls.rs | 172 ++++++++ 
tests/canary_test.rs | 71 --- tests/resources/localhost.crt | 19 + tests/resources/localhost.key | 28 ++ 8 files changed, 844 insertions(+), 757 deletions(-) create mode 100644 src/server/errors.rs create mode 100644 src/server/routes.rs create mode 100644 src/server/tls.rs delete mode 100644 tests/canary_test.rs create mode 100644 tests/resources/localhost.crt create mode 100644 tests/resources/localhost.key diff --git a/src/main.rs b/src/main.rs index 2aa25ad8..b8d7675b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,7 @@ use clap::Parser; use fms_guardrails_orchestr8::{ args::Args, config::OrchestratorConfig, orchestrator::Orchestrator, server, utils, }; +use tracing::info; fn main() -> Result<(), anyhow::Error> { rustls::crypto::aws_lc_rs::default_provider() @@ -50,7 +51,7 @@ fn main() -> Result<(), anyhow::Error> { let config = OrchestratorConfig::load(args.config_path).await?; let orchestrator = Orchestrator::new(config, args.start_up_health_check).await?; - server::run( + let (health_handle, guardrails_handle) = server::run( http_addr, health_http_addr, args.tls_cert_path, @@ -58,7 +59,13 @@ fn main() -> Result<(), anyhow::Error> { args.tls_client_ca_cert_path, orchestrator, ) - .await?; + .await + .unwrap_or_else(|e| panic!("failed to run server: {e}")); + + // Await server shutdown + let _ = tokio::join!(health_handle, guardrails_handle); + info!("shutdown complete"); + Ok(trace_shutdown()?) }) } diff --git a/src/server.rs b/src/server.rs index 10be5e05..f9336dc5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,569 +14,85 @@ limitations under the License. 
*/ +use std::{net::SocketAddr, path::PathBuf, sync::Arc}; -use std::{ - collections::{HashMap, HashSet}, - convert::Infallible, - error::Error as _, - fs::File, - io::BufReader, - net::SocketAddr, - path::PathBuf, - sync::Arc, -}; - -use axum::{ - Json, Router, - extract::{Query, Request, State, rejection::JsonRejection}, - http::{HeaderMap, StatusCode}, - response::{ - IntoResponse, Response, - sse::{Event, KeepAlive, Sse}, - }, - routing::{get, post}, -}; -use axum_extra::{extract::WithRejection, json_lines::JsonLines}; -use futures::{ - Stream, StreamExt, - stream::{self, BoxStream}, -}; -use hyper::body::Incoming; -use hyper_util::rt::{TokioExecutor, TokioIo}; -use opentelemetry::trace::TraceContextExt; -use rustls::{RootCertStore, ServerConfig, server::WebPkiClientVerifier}; -use tokio::{net::TcpListener, signal, sync::mpsc}; -use tokio_rustls::TlsAcceptor; -use tokio_stream::wrappers::ReceiverStream; -use tower::Service; +use tokio::{net::TcpListener, signal}; use tower_http::trace::TraceLayer; -use tracing::{Span, debug, error, info, warn}; -use tracing_opentelemetry::OpenTelemetrySpanExt; -use webpki::types::{CertificateDer, PrivateKeyDer}; - -use crate::{ - clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse}, - models::{self, InfoParams, InfoResponse, StreamingContentDetectionRequest}, - orchestrator::{ - self, Orchestrator, - handlers::{chat_completions_detection::ChatCompletionsDetectionTask, *}, - }, - utils, -}; - -const API_PREFIX: &str = r#"/api/v1/task"#; -// New orchestrator API -const TEXT_API_PREFIX: &str = r#"/api/v2/text"#; - -const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION"); -const PACKAGE_NAME: &str = env!("CARGO_PKG_NAME"); +use tracing::info; -/// Server shared state -pub struct ServerState { - orchestrator: Orchestrator, -} +use crate::orchestrator::Orchestrator; -impl ServerState { - pub fn new(orchestrator: Orchestrator) -> Self { - Self { orchestrator } - } -} +mod errors; +mod routes; +mod tls; +pub use 
errors::Error; +use tls::{configure_tls, serve_with_tls}; -/// Run the orchestrator server +/// Configures and runs orchestrator servers. pub async fn run( - http_addr: SocketAddr, - health_http_addr: SocketAddr, + guardrails_addr: SocketAddr, + health_addr: SocketAddr, tls_cert_path: Option, tls_key_path: Option, tls_client_ca_cert_path: Option, orchestrator: Orchestrator, -) -> Result<(), Error> { - // Overall, the server setup and run does a couple of steps: - // (1) Sets up a HTTP server (without TLS) for the health endpoint - // (2) Sets up a HTTP(s) server for the main guardrails endpoints - // (2a) Configures TLS or mTLS depending on certs/key provided - // (2b) Adds server routes - // (2c) Generate the server task based on whether or not TLS is configured - // (3) Launch each server as a separate task - // NOTE: axum::serve is used for servers without TLS since it is designed to be - // simple and not allow for much configuration. To allow for TLS configuration - // with rustls, the hyper and tower crates [what axum is built on] had to - // be used directly - - let shared_state = Arc::new(ServerState::new(orchestrator)); - - // (1) Separate HTTP health server without TLS for probes - let health_app = get_health_app(shared_state.clone()); - let health_listener = TcpListener::bind(&health_http_addr) - .await - .unwrap_or_else(|_| panic!("failed to bind to {health_http_addr}")); - let health_server = axum::serve(health_listener, health_app.into_make_service()) - .with_graceful_shutdown(shutdown_signal()); - let health_handle = - tokio::task::spawn(async { health_server.await.expect("HTTP health server crashed!") }); - info!( - "HTTP health server started on port {}", - health_http_addr.port() - ); - - // (2) Main guardrails server - // (2a) Configure TLS if requested - let mut arc_server_config: Option> = None; - if let (Some(cert_path), Some(key_path)) = (tls_cert_path, tls_key_path) { - info!("Configuring Server TLS for incoming connections"); - let 
server_cert = load_certs(&cert_path); - let key = load_private_key(&key_path); - - // A process wide default crypto provider is needed, aws_lc_rs feature is enabled by default - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - - // Configure mTLS if client CA is provided - let client_auth = if tls_client_ca_cert_path.is_some() { - info!("Configuring TLS trust certificate (mTLS) for incoming connections"); - let client_certs = load_certs( - tls_client_ca_cert_path - .as_ref() - .expect("error loading certs for mTLS"), - ); - let mut client_auth_certs = RootCertStore::empty(); - for client_cert in client_certs { - // Should be only one - client_auth_certs - .add(client_cert.clone()) - .unwrap_or_else(|e| { - panic!("error adding client cert {:?}: {}", client_cert, e) - }); - } - WebPkiClientVerifier::builder(client_auth_certs.into()) - .build() - .unwrap_or_else(|e| panic!("error building client verifier: {}", e)) - } else { - WebPkiClientVerifier::no_client_auth() - }; - - let server_config = ServerConfig::builder() - .with_client_cert_verifier(client_auth) - .with_single_cert(server_cert, key) - .expect("bad server certificate or key"); - arc_server_config = Some(Arc::new(server_config)); - } else { - info!("HTTP server not configured with TLS") - } - - // (2b) Add main guardrails server routes - let mut router = Router::new() - .route( - &format!("{}/classification-with-text-generation", API_PREFIX), - post(classification_with_gen), - ) - .route( - &format!("{}/detection/stream-content", TEXT_API_PREFIX), - post(stream_content_detection), - ) - .route( - &format!( - "{}/server-streaming-classification-with-text-generation", - API_PREFIX - ), - post(stream_classification_with_gen), - ) - .route( - &format!("{}/generation-detection", TEXT_API_PREFIX), - post(generation_with_detection), - ) - .route( - &format!("{}/detection/content", TEXT_API_PREFIX), - post(detection_content), - ) - .route( - &format!("{}/detection/chat", TEXT_API_PREFIX), 
- post(detect_chat), - ) - .route( - &format!("{}/detection/context", TEXT_API_PREFIX), - post(detect_context_documents), - ) - .route( - &format!("{}/detection/generated", TEXT_API_PREFIX), - post(detect_generated), - ); - - // If chat generation is configured, enable the chat completions detection endpoint. - if shared_state.orchestrator.config().chat_generation.is_some() { - info!("Enabling chat completions detection endpoint"); - router = router.route( - "/api/v2/chat/completions-detection", - post(chat_completions_detection), - ); - } - - let app = router.with_state(shared_state).layer( +) -> Result<(tokio::task::JoinHandle<()>, tokio::task::JoinHandle<()>), Error> { + let state = Arc::new(ServerState::new(orchestrator)); + let health_handle = run_health_server(health_addr, state.clone()).await?; + let guardrails_handle = run_guardrails_server( + guardrails_addr, + tls_cert_path, + tls_key_path, + tls_client_ca_cert_path, + state, + ) + .await?; + Ok((health_handle, guardrails_handle)) +} + +/// Configures and runs health server. +async fn run_health_server( + addr: SocketAddr, + state: Arc, +) -> Result, Error> { + info!("starting health server on {addr}"); + let app = routes::health_router(state); + let listener = TcpListener::bind(&addr).await?; + let server = + axum::serve(listener, app.into_make_service()).with_graceful_shutdown(shutdown_signal()); + Ok(tokio::task::spawn(async { + server.await.expect("health server crashed!") + })) +} + +/// Configures and runs guardrails server. 
+async fn run_guardrails_server( + addr: SocketAddr, + tls_cert_path: Option, + tls_key_path: Option, + tls_client_ca_cert_path: Option, + state: Arc, +) -> Result, Error> { + info!("starting guardrails server on {addr}"); + let router = routes::guardrails_router(state); + let app = router.layer( TraceLayer::new_for_http() - .make_span_with(utils::trace::incoming_request_span) - .on_request(utils::trace::on_incoming_request) - .on_response(utils::trace::on_outgoing_response) - .on_eos(utils::trace::on_outgoing_eos), + .make_span_with(crate::utils::trace::incoming_request_span) + .on_request(crate::utils::trace::on_incoming_request) + .on_response(crate::utils::trace::on_outgoing_response) + .on_eos(crate::utils::trace::on_outgoing_eos), ); - - // (2c) Generate main guardrails server handle based on whether TLS is needed - let listener: TcpListener = TcpListener::bind(&http_addr) - .await - .unwrap_or_else(|_| panic!("failed to bind to {http_addr}")); - let guardrails_handle = if arc_server_config.is_some() { - // TLS - // Use more low level server configuration than axum for configurability - // Ref. https://github.com/tokio-rs/axum/blob/main/examples/low-level-rustls/src/main.rs - info!("HTTPS server started on port {}", http_addr.port()); - let tls_acceptor = TlsAcceptor::from(arc_server_config.unwrap()); - tokio::spawn(async move { - let graceful = hyper_util::server::graceful::GracefulShutdown::new(); - let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); - let mut signal = std::pin::pin!(shutdown_signal()); - loop { - let tower_service = app.clone(); - let tls_acceptor = tls_acceptor.clone(); - - // Wait for new tcp connection - let (cnx, addr) = tokio::select! 
{ - res = listener.accept() => { - match res { - Ok(res) => res, - Err(err) => { - error!("error accepting tcp connection: {err}"); - continue; - } - } - } - _ = &mut signal => { - debug!("graceful shutdown signal received"); - break; - } - }; - - // Wait for tls handshake - let stream = tokio::select! { - res = tls_acceptor.accept(cnx) => { - match res { - Ok(stream) => stream, - Err(err) => { - error!("error accepting connection on handshake: {err}"); - continue; - } - } - } - _ = &mut signal => { - debug!("graceful shutdown signal received"); - break; - } - }; - - // `TokioIo` converts between Hyper's own `AsyncRead` and `AsyncWrite` traits - let stream = TokioIo::new(stream); - - let hyper_service = - hyper::service::service_fn(move |request: Request| { - // Clone necessary since hyper's `Service` uses `&self` whereas - // tower's `Service` requires `&mut self` - tower_service.clone().call(request) - }); - let conn = builder.serve_connection_with_upgrades(stream, hyper_service); - let fut = graceful.watch(conn.into_owned()); - tokio::spawn(async move { - if let Err(err) = fut.await { - warn!("error serving connection from {}: {}", addr, err); - } - }); - } - - tokio::select! 
{ - () = graceful.shutdown() => { - debug!("Gracefully shutdown!"); - }, - () = tokio::time::sleep(std::time::Duration::from_secs(10)) => { - debug!("Waited 10 seconds for graceful shutdown, aborting..."); - } - } - }) + let listener = TcpListener::bind(&addr).await?; + let tls_config = configure_tls(tls_cert_path, tls_key_path, tls_client_ca_cert_path); + let shutdown_signal = shutdown_signal(); + if let Some(tls_config) = tls_config { + Ok(serve_with_tls(app, listener, tls_config, shutdown_signal)) } else { - // Non-TLS - // Keep simple axum serve call for http version - let http_server = axum::serve(listener, app.into_make_service()) - .with_graceful_shutdown(shutdown_signal()); - info!("HTTP server started on port {}", http_addr.port()); - tokio::task::spawn(async { http_server.await.expect("HTTP server crashed!") }) - }; - - // (3) Launch each server as a separate task - let (health_res, guardrails_res) = tokio::join!(health_handle, guardrails_handle); - health_res.unwrap(); - guardrails_res.unwrap(); - info!("Shutdown complete for servers"); - Ok(()) -} - -pub fn get_health_app(state: Arc) -> Router { - Router::new() - .route("/health", get(health)) - .route("/info", get(info)) - .with_state(state) -} - -async fn health() -> Result { - // NOTE: we are only adding the package information in the `health` endpoint to have this endpoint - // provide a non empty 200 response. If we need to add more information regarding dependencies version - // or such things, then we will add another `/info` endpoint accordingly. 
And those info - // should not be added in `health` endpoint` - let info_object = HashMap::from([(PACKAGE_NAME, PACKAGE_VERSION)]); - Ok(Json(info_object).into_response()) -} - -async fn info( - State(state): State>, - Query(params): Query, -) -> Result, Error> { - let services = state.orchestrator.client_health(params.probe).await; - Ok(Json(InfoResponse { services })) -} - -async fn classification_with_gen( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection, Error>, -) -> Result { - let trace_id = Span::current().context().span().span_context().trace_id(); - request.validate()?; - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ClassificationWithGenTask::new(trace_id, request, headers); - match state.orchestrator.handle(task).await { - Ok(response) => Ok(Json(response).into_response()), - Err(error) => Err(error.into()), - } -} - -async fn generation_with_detection( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection< - Json, - Error, - >, -) -> Result { - let trace_id = Span::current().context().span().span_context().trace_id(); - request.validate()?; - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = GenerationWithDetectionTask::new(trace_id, request, headers); - match state.orchestrator.handle(task).await { - Ok(response) => Ok(Json(response).into_response()), - Err(error) => Err(error.into()), - } -} - -async fn stream_classification_with_gen( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection, Error>, -) -> Sse>> { - let trace_id = Span::current().context().span().span_context().trace_id(); - if let Err(error) = request.validate() { - // Request validation failed, return stream with single error SSE event - let error: Error = error.into(); - return Sse::new( - stream::iter([Ok(Event::default() - .event("error") - 
.json_data(error.to_json()) - .unwrap())]) - .boxed(), - ); - } - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = StreamingClassificationWithGenTask::new(trace_id, request, headers); - let response_stream = state.orchestrator.handle(task).await.unwrap(); - // Convert response stream to a stream of SSE events - let event_stream = response_stream - .map(|message| match message { - Ok(response) => Ok(Event::default() - //.event("message") NOTE: per spec, should not be included for data-only message events - .json_data(response) - .unwrap()), - Err(error) => { - let error: Error = error.into(); - Ok(Event::default() - .event("error") - .json_data(error.to_json()) - .unwrap()) - } - }) - .boxed(); - Sse::new(event_stream).keep_alive(KeepAlive::default()) -} - -async fn stream_content_detection( - State(state): State>, - headers: HeaderMap, - json_lines: JsonLines, -) -> Result { - let trace_id = Span::current().context().span().span_context().trace_id(); - // Validate the content-type from the header and ensure it is application/x-ndjson - // If it's not, return a UnsupportedContentType error with the appropriate message - let content_type = headers - .get(http::header::CONTENT_TYPE) - .and_then(|value| value.to_str().ok()); - match content_type { - Some(content_type) if content_type.starts_with("application/x-ndjson") => (), - _ => { - return Err(Error::UnsupportedContentType( - "expected application/x-ndjson".into(), - )); - } - }; - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - - // Create input stream - let input_stream = json_lines - .map(|result| match result { - Ok(message) => { - message.validate()?; - Ok(message) - } - Err(error) => Err(orchestrator::errors::Error::Validation(error.to_string())), - }) - .enumerate() - .boxed(); - - // Create task and submit to handler - let task = StreamingContentDetectionTask::new(trace_id, headers, input_stream); - let mut 
response_stream = state.orchestrator.handle(task).await?; - - // Create output stream - // This stream returns ND-JSON formatted messages to the client - // StreamingContentDetectionResponse / server::Error - let (output_tx, output_rx) = mpsc::channel::>(128); - let output_stream = ReceiverStream::new(output_rx); - - // Spawn task to consume response stream (typed) and send to output stream (json) - tokio::spawn(async move { - while let Some(result) = response_stream.next().await { - match result { - Ok(msg) => { - let msg = utils::json::to_nd_string(&msg).unwrap(); - let _ = output_tx.send(Ok(msg)).await; - } - Err(error) => { - // Convert orchestrator::Error to server::Error - let error: Error = error.into(); - // server::Error doesn't impl Serialize, so we use to_json() - let error_msg = utils::json::to_nd_string(&error.to_json()).unwrap(); - let _ = output_tx.send(Ok(error_msg)).await; - } - } - } - }); - - Ok(Response::new(axum::body::Body::from_stream(output_stream))) -} - -async fn detection_content( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection< - Json, - Error, - >, -) -> Result { - let trace_id = Span::current().context().span().span_context().trace_id(); - request.validate()?; - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = TextContentDetectionTask::new(trace_id, request, headers); - match state.orchestrator.handle(task).await { - Ok(response) => Ok(Json(response).into_response()), - Err(error) => Err(error.into()), - } -} - -async fn detect_context_documents( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection, Error>, -) -> Result { - let trace_id = Span::current().context().span().span_context().trace_id(); - request.validate()?; - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ContextDocsDetectionTask::new(trace_id, request, headers); - match 
state.orchestrator.handle(task).await { - Ok(response) => Ok(Json(response).into_response()), - Err(error) => Err(error.into()), - } -} - -async fn detect_chat( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection, Error>, -) -> Result { - let trace_id = Span::current().context().span().span_context().trace_id(); - request.validate_for_text()?; - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ChatDetectionTask::new(trace_id, request, headers); - match state.orchestrator.handle(task).await { - Ok(response) => Ok(Json(response).into_response()), - Err(error) => Err(error.into()), - } -} - -async fn detect_generated( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection< - Json, - Error, - >, -) -> Result { - let trace_id = Span::current().context().span().span_context().trace_id(); - request.validate()?; - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = DetectionOnGenerationTask::new(trace_id, request, headers); - match state.orchestrator.handle(task).await { - Ok(response) => Ok(Json(response).into_response()), - Err(error) => Err(error.into()), - } -} - -async fn chat_completions_detection( - State(state): State>, - headers: HeaderMap, - WithRejection(Json(request), _): WithRejection, Error>, -) -> Result { - use ChatCompletionsResponse::*; - let trace_id = Span::current().context().span().span_context().trace_id(); - let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); - let task = ChatCompletionsDetectionTask::new(trace_id, request, headers); - match state.orchestrator.handle(task).await { - Ok(response) => match response { - Unary(response) => Ok(Json(response).into_response()), - Streaming(response_rx) => { - let response_stream = ReceiverStream::new(response_rx); - // Convert response stream to a stream of SSE events - let event_stream: 
BoxStream> = response_stream - .map(|message| match message { - Ok(Some(chunk)) => Ok(Event::default().json_data(chunk).unwrap()), - Ok(None) => { - // The stream completed, send [DONE] message - Ok(Event::default().data("[DONE]")) - } - Err(error) => { - let error: Error = error.into(); - Ok(Event::default() - .event("error") - .json_data(error.to_json()) - .unwrap()) - } - }) - .boxed(); - let sse = Sse::new(event_stream).keep_alive(KeepAlive::default()); - Ok(sse.into_response()) - } - }, - Err(error) => Err(error.into()), + let server = + axum::serve(listener, app.into_make_service()).with_graceful_shutdown(shutdown_signal); + Ok(tokio::task::spawn(async { + server.await.expect("guardrails server crashed!") + })) } } @@ -587,7 +103,6 @@ async fn shutdown_signal() { .await .expect("failed to install Ctrl+C handler"); }; - #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) @@ -595,155 +110,72 @@ async fn shutdown_signal() { .recv() .await; }; - #[cfg(not(unix))] let terminate = std::future::pending::<()>(); - tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } - info!("signal received, starting graceful shutdown"); } -// Ref. 
https://github.com/rustls/rustls/blob/main/examples/src/bin/tlsserver-mio.rs -/// Load certificates from a file -fn load_certs(filename: &PathBuf) -> Vec> { - let cert_file = File::open(filename).expect("cannot open certificate file"); - let mut reader = BufReader::new(cert_file); - rustls_pemfile::certs(&mut reader) - .map(|result| result.unwrap()) - .collect() +/// Server shared state +pub struct ServerState { + orchestrator: Orchestrator, } -/// Load private key from a file -fn load_private_key(filename: &PathBuf) -> PrivateKeyDer<'static> { - let key_file = File::open(filename).expect("cannot open private key file"); - let mut reader = BufReader::new(key_file); - - loop { - match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") { - Some(rustls_pemfile::Item::Pkcs1Key(key)) => return key.into(), - Some(rustls_pemfile::Item::Pkcs8Key(key)) => return key.into(), - Some(rustls_pemfile::Item::Sec1Key(key)) => return key.into(), - None => break, - _ => {} - } +impl ServerState { + pub fn new(orchestrator: Orchestrator) -> Self { + Self { orchestrator } } - - panic!( - "no keys found in {:?} (encrypted keys not supported)", - filename - ); } -/// High-level errors to return to clients. -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("{0}")] - Validation(String), - #[error("{0}")] - NotFound(String), - #[error("{0}")] - ServiceUnavailable(String), - #[error("unexpected error occurred while processing request")] - Unexpected, - #[error(transparent)] - JsonExtractorRejection(#[from] JsonRejection), - #[error("{0}")] - JsonError(String), - #[error("unsupported content type: {0}")] - UnsupportedContentType(String), -} +#[cfg(test)] +mod tests { + use super::*; -impl From for Error { - fn from(value: orchestrator::Error) -> Self { - use orchestrator::Error::*; - match value { - DetectorNotFound(_) | ChunkerNotFound(_) => Self::NotFound(value.to_string()), - DetectorRequestFailed { ref error, .. 
} - | ChunkerRequestFailed { ref error, .. } - | GenerateRequestFailed { ref error, .. } - | ChatCompletionRequestFailed { ref error, .. } - | TokenizeRequestFailed { ref error, .. } => match error.status_code() { - StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY => { - Self::Validation(value.to_string()) - } - StatusCode::NOT_FOUND => Self::NotFound(value.to_string()), - StatusCode::SERVICE_UNAVAILABLE => Self::ServiceUnavailable(value.to_string()), - _ => Self::Unexpected, - }, - JsonError(message) => Self::JsonError(message), - Validation(message) => Self::Validation(message), - _ => Self::Unexpected, - } - } -} - -impl Error { - pub fn to_json(self) -> serde_json::Value { - use Error::*; - let (code, message) = match self { - Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), - ServiceUnavailable(_) => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), - UnsupportedContentType(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()), - Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), - JsonExtractorRejection(json_rejection) => match json_rejection { - JsonRejection::JsonDataError(e) => { - // Get lower-level serde error message - let message = e.source().map(|e| e.to_string()).unwrap_or_default(); - (e.status(), message) - } - _ => (json_rejection.status(), json_rejection.body_text()), - }, - JsonError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - }; - serde_json::json!({ - "code": code.as_u16(), - "details": message, - }) - } -} + #[tokio::test] + async fn test_run_bind_failure() -> Result<(), Error> { + let guardrails_addr: SocketAddr = "0.0.0.0:50101".parse().unwrap(); + let health_addr: SocketAddr = "0.0.0.0:50103".parse().unwrap(); + let _listener = TcpListener::bind(&guardrails_addr).await?; + let result = run( + guardrails_addr, + health_addr, + None, + None, + None, + Orchestrator::default(), + ) + .await; + 
assert!(result.is_err_and(|error| matches!(error, Error::IoError(_)) + && error.to_string().starts_with("Address already in use"))); + Ok(()) + } + + #[tokio::test] + async fn test_run_with_tls() -> Result<(), Error> { + let guardrails_addr: SocketAddr = "0.0.0.0:50104".parse().unwrap(); + let health_addr: SocketAddr = "0.0.0.0:50105".parse().unwrap(); + let resources: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "resources"] + .iter() + .collect(); + let tls_cert_path = resources.join("localhost.crt"); + let tls_key_path = resources.join("localhost.key"); + let (_health_handle, guardrails_handle) = run( + guardrails_addr, + health_addr, + Some(tls_cert_path), + Some(tls_key_path), + None, + Orchestrator::default(), + ) + .await?; -impl IntoResponse for Error { - fn into_response(self) -> Response { - use Error::*; - let (code, message) = match self { - Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), - ServiceUnavailable(_) => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), - UnsupportedContentType(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()), - Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), - JsonExtractorRejection(json_rejection) => match json_rejection { - JsonRejection::JsonDataError(e) => { - // Get lower-level serde error message - let message = e.source().map(|e| e.to_string()).unwrap_or_default(); - (e.status(), message) - } - _ => (json_rejection.status(), json_rejection.body_text()), - }, - JsonError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - }; - let error = serde_json::json!({ - "code": code.as_u16(), - "details": message, - }); - (code, Json(error)).into_response() - } -} + // Ensure guardrails server task is still running + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + assert!(!guardrails_handle.is_finished()); -impl From for Error { - fn from(value: models::ValidationError) -> Self { - 
Self::Validation(value.to_string()) + Ok(()) } } - -fn filter_headers(passthrough_headers: &HashSet, headers: HeaderMap) -> HeaderMap { - headers - .iter() - .filter(|(name, _)| passthrough_headers.contains(&name.as_str().to_lowercase())) - .map(|(name, value)| (name.clone(), value.clone())) - .collect() -} diff --git a/src/server/errors.rs b/src/server/errors.rs new file mode 100644 index 00000000..b5b54b32 --- /dev/null +++ b/src/server/errors.rs @@ -0,0 +1,132 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::error::Error as _; + +use axum::{ + Json, + extract::rejection::JsonRejection, + response::{IntoResponse, Response}, +}; +use http::StatusCode; + +use crate::{models::ValidationError, orchestrator}; + +/// High-level errors to return to clients. 
+#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("{0}")] + Validation(String), + #[error("{0}")] + NotFound(String), + #[error("{0}")] + ServiceUnavailable(String), + #[error("unexpected error occurred while processing request")] + Unexpected, + #[error(transparent)] + JsonExtractorRejection(#[from] JsonRejection), + #[error("{0}")] + JsonError(String), + #[error("unsupported content type: {0}")] + UnsupportedContentType(String), + #[error(transparent)] + IoError(#[from] std::io::Error), +} + +impl From for Error { + fn from(value: orchestrator::Error) -> Self { + use orchestrator::Error::*; + match value { + DetectorNotFound(_) | ChunkerNotFound(_) => Self::NotFound(value.to_string()), + DetectorRequestFailed { ref error, .. } + | ChunkerRequestFailed { ref error, .. } + | GenerateRequestFailed { ref error, .. } + | ChatCompletionRequestFailed { ref error, .. } + | TokenizeRequestFailed { ref error, .. } => match error.status_code() { + StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY => { + Self::Validation(value.to_string()) + } + StatusCode::NOT_FOUND => Self::NotFound(value.to_string()), + StatusCode::SERVICE_UNAVAILABLE => Self::ServiceUnavailable(value.to_string()), + _ => Self::Unexpected, + }, + JsonError(message) => Self::JsonError(message), + Validation(message) => Self::Validation(message), + _ => Self::Unexpected, + } + } +} + +impl Error { + pub fn to_json(self) -> serde_json::Value { + use Error::*; + let (code, message) = match self { + Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), + NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), + ServiceUnavailable(_) => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), + UnsupportedContentType(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()), + Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + JsonExtractorRejection(json_rejection) => match json_rejection { + JsonRejection::JsonDataError(e) => { + // Get lower-level 
serde error message + let message = e.source().map(|e| e.to_string()).unwrap_or_default(); + (e.status(), message) + } + _ => (json_rejection.status(), json_rejection.body_text()), + }, + JsonError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), + IoError(error) => (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()), + }; + serde_json::json!({ + "code": code.as_u16(), + "details": message, + }) + } +} + +impl IntoResponse for Error { + fn into_response(self) -> Response { + use Error::*; + let (code, message) = match self { + Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), + NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), + ServiceUnavailable(_) => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), + UnsupportedContentType(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()), + Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), + JsonExtractorRejection(json_rejection) => match json_rejection { + JsonRejection::JsonDataError(e) => { + // Get lower-level serde error message + let message = e.source().map(|e| e.to_string()).unwrap_or_default(); + (e.status(), message) + } + _ => (json_rejection.status(), json_rejection.body_text()), + }, + JsonError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), + IoError(error) => (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()), + }; + let error = serde_json::json!({ + "code": code.as_u16(), + "details": message, + }); + (code, Json(error)).into_response() + } +} + +impl From for Error { + fn from(value: ValidationError) -> Self { + Self::Validation(value.to_string()) + } +} diff --git a/src/server/routes.rs b/src/server/routes.rs new file mode 100644 index 00000000..bf654b2e --- /dev/null +++ b/src/server/routes.rs @@ -0,0 +1,368 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::{ + collections::{HashMap, HashSet}, + convert::Infallible, + sync::Arc, +}; + +use axum::{ + Json, Router, + extract::{Query, State}, + http::HeaderMap, + response::{ + IntoResponse, Response, + sse::{Event, KeepAlive, Sse}, + }, + routing::{get, post}, +}; +use axum_extra::{extract::WithRejection, json_lines::JsonLines}; +use futures::{ + Stream, StreamExt, + stream::{self, BoxStream}, +}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::info; + +use super::{Error, ServerState}; +use crate::{ + clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse}, + models::{self, InfoParams, InfoResponse, StreamingContentDetectionRequest}, + orchestrator::{ + self, + handlers::{chat_completions_detection::ChatCompletionsDetectionTask, *}, + }, + utils::{self, trace::current_trace_id}, +}; + +const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION"); +const PACKAGE_NAME: &str = env!("CARGO_PKG_NAME"); + +/// Creates health router. +pub fn health_router(state: Arc) -> Router { + Router::new() + .route("/health", get(health)) + .route("/info", get(info)) + .with_state(state) +} + +/// Creates guardrails router. 
+pub fn guardrails_router(state: Arc) -> Router { + let mut router = Router::new() + // v1 routes + .route( + "/api/v1/task/classification-with-text-generation", + post(classification_with_gen), + ) + .route( + "/api/v1/task/server-streaming-classification-with-text-generation", + post(stream_classification_with_gen), + ) + // v2 routes + .route( + "/api/v2/text/detection/stream-content", + post(stream_content_detection), + ) + .route( + "/api/v2/text/generation-detection", + post(generation_with_detection), + ) + .route("/api/v2/text/detection/content", post(detection_content)) + .route("/api/v2/text/detection/chat", post(detect_chat)) + .route( + "/api/v2/text/detection/context", + post(detect_context_documents), + ) + .route("/api/v2/text/detection/generated", post(detect_generated)); + if state.orchestrator.config().chat_generation.is_some() { + info!("Enabling chat completions detection endpoint"); + router = router.route( + "/api/v2/chat/completions-detection", + post(chat_completions_detection), + ); + } + router.with_state(state) +} + +async fn health() -> Result { + // NOTE: we are only adding the package information in the `health` endpoint to have this endpoint + // provide a non empty 200 response. If we need to add more information regarding dependencies version + // or such things, then we will add another `/info` endpoint accordingly. 
And those info + // should not be added in `health` endpoint` + let info_object = HashMap::from([(PACKAGE_NAME, PACKAGE_VERSION)]); + Ok(Json(info_object).into_response()) +} + +async fn info( + State(state): State>, + Query(params): Query, +) -> Result, Error> { + let services = state.orchestrator.client_health(params.probe).await; + Ok(Json(InfoResponse { services })) +} + +async fn classification_with_gen( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + let trace_id = current_trace_id(); + request.validate()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = ClassificationWithGenTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => Ok(Json(response).into_response()), + Err(error) => Err(error.into()), + } +} + +async fn generation_with_detection( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection< + Json, + Error, + >, +) -> Result { + let trace_id = current_trace_id(); + request.validate()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = GenerationWithDetectionTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => Ok(Json(response).into_response()), + Err(error) => Err(error.into()), + } +} + +async fn stream_classification_with_gen( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Sse>> { + let trace_id = current_trace_id(); + if let Err(error) = request.validate() { + // Request validation failed, return stream with single error SSE event + let error: Error = error.into(); + return Sse::new( + stream::iter([Ok(Event::default() + .event("error") + .json_data(error.to_json()) + .unwrap())]) + .boxed(), + ); + } + let headers = 
filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = StreamingClassificationWithGenTask::new(trace_id, request, headers); + let response_stream = state.orchestrator.handle(task).await.unwrap(); + // Convert response stream to a stream of SSE events + let event_stream = response_stream + .map(|message| match message { + Ok(response) => Ok(Event::default() + //.event("message") NOTE: per spec, should not be included for data-only message events + .json_data(response) + .unwrap()), + Err(error) => { + let error: Error = error.into(); + Ok(Event::default() + .event("error") + .json_data(error.to_json()) + .unwrap()) + } + }) + .boxed(); + Sse::new(event_stream).keep_alive(KeepAlive::default()) +} + +async fn stream_content_detection( + State(state): State>, + headers: HeaderMap, + json_lines: JsonLines, +) -> Result { + let trace_id = current_trace_id(); + // Validate the content-type from the header and ensure it is application/x-ndjson + // If it's not, return a UnsupportedContentType error with the appropriate message + let content_type = headers + .get(http::header::CONTENT_TYPE) + .and_then(|value| value.to_str().ok()); + match content_type { + Some(content_type) if content_type.starts_with("application/x-ndjson") => (), + _ => { + return Err(Error::UnsupportedContentType( + "expected application/x-ndjson".into(), + )); + } + }; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + + // Create input stream + let input_stream = json_lines + .map(|result| match result { + Ok(message) => { + message.validate()?; + Ok(message) + } + Err(error) => Err(orchestrator::errors::Error::Validation(error.to_string())), + }) + .enumerate() + .boxed(); + + // Create task and submit to handler + let task = StreamingContentDetectionTask::new(trace_id, headers, input_stream); + let mut response_stream = state.orchestrator.handle(task).await?; + + // Create output stream + // This stream returns ND-JSON 
formatted messages to the client + // StreamingContentDetectionResponse / server::Error + let (output_tx, output_rx) = mpsc::channel::>(128); + let output_stream = ReceiverStream::new(output_rx); + + // Spawn task to consume response stream (typed) and send to output stream (json) + tokio::spawn(async move { + while let Some(result) = response_stream.next().await { + match result { + Ok(msg) => { + let msg = utils::json::to_nd_string(&msg).unwrap(); + let _ = output_tx.send(Ok(msg)).await; + } + Err(error) => { + // Convert orchestrator::Error to server::Error + let error: Error = error.into(); + // server::Error doesn't impl Serialize, so we use to_json() + let error_msg = utils::json::to_nd_string(&error.to_json()).unwrap(); + let _ = output_tx.send(Ok(error_msg)).await; + } + } + } + }); + + Ok(Response::new(axum::body::Body::from_stream(output_stream))) +} + +async fn detection_content( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection< + Json, + Error, + >, +) -> Result { + let trace_id = current_trace_id(); + request.validate()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = TextContentDetectionTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => Ok(Json(response).into_response()), + Err(error) => Err(error.into()), + } +} + +async fn detect_context_documents( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + let trace_id = current_trace_id(); + request.validate()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = ContextDocsDetectionTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => Ok(Json(response).into_response()), + Err(error) => Err(error.into()), + } +} + +async fn detect_chat( + State(state): State>, + headers: HeaderMap, 
+ WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + let trace_id = current_trace_id(); + request.validate_for_text()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = ChatDetectionTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => Ok(Json(response).into_response()), + Err(error) => Err(error.into()), + } +} + +async fn detect_generated( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection< + Json, + Error, + >, +) -> Result { + let trace_id = current_trace_id(); + request.validate()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = DetectionOnGenerationTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => Ok(Json(response).into_response()), + Err(error) => Err(error.into()), + } +} + +async fn chat_completions_detection( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + use ChatCompletionsResponse::*; + let trace_id = current_trace_id(); + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = ChatCompletionsDetectionTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => match response { + Unary(response) => Ok(Json(response).into_response()), + Streaming(response_rx) => { + let response_stream = ReceiverStream::new(response_rx); + // Convert response stream to a stream of SSE events + let event_stream: BoxStream> = response_stream + .map(|message| match message { + Ok(Some(chunk)) => Ok(Event::default().json_data(chunk).unwrap()), + Ok(None) => { + // The stream completed, send [DONE] message + Ok(Event::default().data("[DONE]")) + } + Err(error) => { + let error: Error = error.into(); + Ok(Event::default() + 
.event("error") + .json_data(error.to_json()) + .unwrap()) + } + }) + .boxed(); + let sse = Sse::new(event_stream).keep_alive(KeepAlive::default()); + Ok(sse.into_response()) + } + }, + Err(error) => Err(error.into()), + } +} + +/// Filters a [`HeaderMap`] with a set of header names, returning a new [`HeaderMap`]. +pub fn filter_headers(passthrough_headers: &HashSet, headers: HeaderMap) -> HeaderMap { + headers + .iter() + .filter(|(name, _)| passthrough_headers.contains(&name.as_str().to_lowercase())) + .map(|(name, value)| (name.clone(), value.clone())) + .collect() +} diff --git a/src/server/tls.rs b/src/server/tls.rs new file mode 100644 index 00000000..079262d6 --- /dev/null +++ b/src/server/tls.rs @@ -0,0 +1,172 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ +use std::{fs::File, io::BufReader, path::PathBuf, sync::Arc}; + +use axum::{Router, extract::Request}; +use hyper::body::Incoming; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use rustls::{RootCertStore, ServerConfig, server::WebPkiClientVerifier}; +use tokio::net::TcpListener; +use tokio_rustls::TlsAcceptor; +use tower::Service; +use tracing::{debug, error, info, warn}; +use webpki::types::{CertificateDer, PrivateKeyDer}; + +/// Loads certificates and configures TLS. 
+pub fn configure_tls( + tls_cert_path: Option, + tls_key_path: Option, + tls_client_ca_cert_path: Option, +) -> Option> { + if let (Some(cert_path), Some(key_path)) = (tls_cert_path, tls_key_path) { + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + let cert = load_certs(&cert_path); + let key = load_private_key(&key_path); + // Configure mTLS if client CA is provided + let client_auth = if let Some(client_ca_cert_path) = tls_client_ca_cert_path { + let client_certs = load_certs(&client_ca_cert_path); + let mut client_auth_certs = RootCertStore::empty(); + for client_cert in client_certs { + client_auth_certs + .add(client_cert.clone()) + .unwrap_or_else(|e| { + panic!("error adding client cert {:?}: {}", client_cert, e) + }); + } + info!("mTLS enabled"); + WebPkiClientVerifier::builder(client_auth_certs.into()) + .build() + .unwrap_or_else(|e| panic!("error building client verifier: {}", e)) + } else { + info!("TLS enabled"); + WebPkiClientVerifier::no_client_auth() + }; + let server_config = ServerConfig::builder() + .with_client_cert_verifier(client_auth) + .with_single_cert(cert, key) + .expect("bad server certificate or key"); + Some(Arc::new(server_config)) + } else { + info!("TLS not enabled"); + None + } +} + +/// Serve the service with the supplied listener, TLS config, and shutdown signal. 
+/// Based on https://github.com/tokio-rs/axum/blob/main/examples/low-level-rustls/src/main.rs +pub fn serve_with_tls( + app: Router, + listener: TcpListener, + tls_config: Arc, + shutdown_signal: F, +) -> tokio::task::JoinHandle<()> +where + F: Future + Send + 'static, +{ + let tls_acceptor = TlsAcceptor::from(tls_config); + tokio::spawn(async move { + let graceful = hyper_util::server::graceful::GracefulShutdown::new(); + let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + let mut signal = std::pin::pin!(shutdown_signal); + loop { + let tower_service = app.clone(); + let tls_acceptor = tls_acceptor.clone(); + // Wait for new tcp connection + let (cnx, addr) = tokio::select! { + res = listener.accept() => { + match res { + Ok(res) => res, + Err(err) => { + error!("error accepting tcp connection: {err}"); + continue; + } + } + } + _ = &mut signal => { + debug!("graceful shutdown signal received"); + break; + } + }; + // Wait for tls handshake + let stream = tokio::select! { + res = tls_acceptor.accept(cnx) => { + match res { + Ok(stream) => stream, + Err(err) => { + error!("error accepting connection on handshake: {err}"); + continue; + } + } + } + _ = &mut signal => { + debug!("graceful shutdown signal received"); + break; + } + }; + // `TokioIo` converts between Hyper's own `AsyncRead` and `AsyncWrite` traits + let stream = TokioIo::new(stream); + let hyper_service = hyper::service::service_fn(move |request: Request| { + // Clone necessary since hyper's `Service` uses `&self` whereas + // tower's `Service` requires `&mut self` + tower_service.clone().call(request) + }); + let conn = builder.serve_connection_with_upgrades(stream, hyper_service); + let fut = graceful.watch(conn.into_owned()); + tokio::spawn(async move { + if let Err(err) = fut.await { + warn!("error serving connection from {}: {}", addr, err); + } + }); + } + tokio::select! 
{ + () = graceful.shutdown() => { + debug!("graceful shutdown completed"); + }, + () = tokio::time::sleep(std::time::Duration::from_secs(10)) => { + debug!("graceful shutdown timed out, aborting..."); + } + } + }) +} + +/// Load certificates from a file +fn load_certs(filename: &PathBuf) -> Vec> { + let cert_file = File::open(filename).expect("cannot open certificate file"); + let mut reader = BufReader::new(cert_file); + rustls_pemfile::certs(&mut reader) + .map(|result| result.unwrap()) + .collect() +} + +/// Load private key from a file +fn load_private_key(filename: &PathBuf) -> PrivateKeyDer<'static> { + let key_file = File::open(filename).expect("cannot open private key file"); + let mut reader = BufReader::new(key_file); + loop { + match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") { + Some(rustls_pemfile::Item::Pkcs1Key(key)) => return key.into(), + Some(rustls_pemfile::Item::Pkcs8Key(key)) => return key.into(), + Some(rustls_pemfile::Item::Sec1Key(key)) => return key.into(), + None => break, + _ => {} + } + } + panic!( + "no keys found in {:?} (encrypted keys not supported)", + filename + ); +} diff --git a/tests/canary_test.rs b/tests/canary_test.rs deleted file mode 100644 index 0a59c5a7..00000000 --- a/tests/canary_test.rs +++ /dev/null @@ -1,71 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- -*/ - -// This is needed because integration test files are compiled as separate crates. -// If any of the code in this file is not used by any of the test files, a warning about unused code is generated. -// For more: https://github.com/rust-lang/rust/issues/46379 - -use std::sync::Arc; - -use axum_test::TestServer; -use common::orchestrator::ensure_global_rustls_state; -use fms_guardrails_orchestr8::{ - config::OrchestratorConfig, - orchestrator::Orchestrator, - server::{ServerState, get_health_app}, -}; -use hyper::StatusCode; -use serde_json::Value; -use test_log::test; -use tokio::sync::OnceCell; -use tracing::debug; - -pub mod common; - -/// Async lazy initialization of shared state using tokio::sync::OnceCell -static ONCE: OnceCell> = OnceCell::const_new(); - -/// The actual async function that initializes the shared state if not already initialized -async fn shared_state() -> Arc { - let config = OrchestratorConfig::load("tests/test_config.yaml") - .await - .unwrap(); - let orchestrator = Orchestrator::new(config, false).await.unwrap(); - Arc::new(ServerState::new(orchestrator)) -} - -/// Checks if the health endpoint is working -/// NOTE: We do not currently mock client services yet, so this test is -/// superficially testing the client health endpoints on the orchestrator is accessible -/// and when the orchestrator is running (healthy) all the health endpoints return 200 OK. -/// This will happen even if the client services or their health endpoints are not found. 
-#[test(tokio::test)] -async fn test_health() { - ensure_global_rustls_state(); - let shared_state = ONCE.get_or_init(shared_state).await.clone(); - let server = TestServer::new(get_health_app(shared_state)).unwrap(); - let response = server.get("/health").await; - debug!("{:#?}", response); - let body: Value = serde_json::from_str(response.text().as_str()).unwrap(); - debug!("{}", serde_json::to_string_pretty(&body).unwrap()); - response.assert_status(StatusCode::OK); - let response = server.get("/info").await; - println!("{:#?}", response); - let body: Value = serde_json::from_str(response.text().as_str()).unwrap(); - println!("{}", serde_json::to_string_pretty(&body).unwrap()); - response.assert_status(StatusCode::OK); -} diff --git a/tests/resources/localhost.crt b/tests/resources/localhost.crt new file mode 100644 index 00000000..fbeff826 --- /dev/null +++ b/tests/resources/localhost.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDDzCCAfegAwIBAgIUIYDXhZUtJjQitYHNddtEb/axf50wDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDQyOTE5MDIwNVoXDTM1MDQy +NzE5MDIwNVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAvrvegKc1NZKYbnosX019wgLW9hyhdaxAIyQ5fJRWO07i +IjlUGN9wR9PgckC195dN+3KilrrfYuXeOJclSqeaegFbyFgu7oaUtLZXtl7i5Fw/ +OuVrwnPLeCrFzqlkEdHDF92nKQ2RLvySlYsBxu+eJskLUM+lNO5H5LxaqwZSio9c +9ONLnc9kfVqo7jdYxtQnxDPmtyXNnXk+Bl28cCqT4aa335twu77Zt1yQ/vamWMXi +6skfFBJI3HtFNnqgX+48okzYEVpPXWRjoVt6XWhQDYcZhXO/sH6Npaovc7wXGoc9 +6Nwnpb5IiQPJKnOwOwTKb8M+V7sDfb6Z0OZlOGbWnQIDAQABo1kwVzAUBgNVHREE +DTALgglsb2NhbGhvc3QwCwYDVR0PBAQDAgeAMBMGA1UdJQQMMAoGCCsGAQUFBwMB +MB0GA1UdDgQWBBSU1izIFJDicxS91WjpaSzK5rAYkzANBgkqhkiG9w0BAQsFAAOC +AQEAkEOPGAXeSj2/7ncCQlum5WK7INevfASQewzMl88QEWoSzFtJTeVr2t6BD8XB +jEgxHV+LGoMGwKpElb4DRbNUXrZIJvrLA3Ov9bvzdwxft44nQAJhjC+KBX/jQpYp +OO9y0HVhxrwIWyTFVLmkkCurtooKtFI+3j7vqtesv2wolCCM9k7dP2p6Vkee8PGJ +2ZT0A/C8r+9oDaVAvd0NNPoXi+zo5mskDGR6JiY5mlVii0OYi3XWw0OfM9zRLVoz 
+QrHhsU9dvXi9d6S2DlXE2SaEDFgytPDbGmqOXl4wV6GGD/cEwUm7DzVB0JA8e8R/ +053YAE72EqFnnXtdyrzjNbxX5w== +-----END CERTIFICATE----- diff --git a/tests/resources/localhost.key b/tests/resources/localhost.key new file mode 100644 index 00000000..7318721d --- /dev/null +++ b/tests/resources/localhost.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC+u96ApzU1kphu +eixfTX3CAtb2HKF1rEAjJDl8lFY7TuIiOVQY33BH0+ByQLX3l037cqKWut9i5d44 +lyVKp5p6AVvIWC7uhpS0tle2XuLkXD865WvCc8t4KsXOqWQR0cMX3acpDZEu/JKV +iwHG754myQtQz6U07kfkvFqrBlKKj1z040udz2R9WqjuN1jG1CfEM+a3Jc2deT4G +XbxwKpPhprffm3C7vtm3XJD+9qZYxeLqyR8UEkjce0U2eqBf7jyiTNgRWk9dZGOh +W3pdaFANhxmFc7+wfo2lqi9zvBcahz3o3CelvkiJA8kqc7A7BMpvwz5XuwN9vpnQ +5mU4ZtadAgMBAAECggEAFMrFnWg85m9p4wc5/ZSkvCjTlqY12q5RmpMZGjHQeawm +fh0aRBDdfFMGUWYpAAnUNDBBtaU+81jEIg6l/86oChtKTluAmDt/C/khtC2BVewR +b4HxfpPhbyLYh+gS0td3SUt+LL1AIatibl9mLPHTn35El8VfCNdd6ns29hRdHKIG +cmJWdU6+HRBopIKYKR5Qx3t2ttSqxRwDF/PyPnytqFRP8BJfZOYCBoby0PzkVYYI +gjtHK75OtTDblMzf/wvXXstVv/cChibu+J28L7Y7cZM1MhceoW4KVw/U9O+P7eqp +KII2aVDIfjO9RBjriN1YHHTvZwjf42B1bw2EAYcEeQKBgQDsaslefYxQyYwekmZG +cwKH1lmT7sq6ED39TP/5subV7z+g0c88QW9Cc+yo1qFFoClGNN85/LYl78d1IokD +OEI1olaFCyJJhys9ZwJ5n/34a9YH3WGjVCEqyYR0qCvbVnvzvWE4pFwSMSzYqSFM +GZgevtaFXcHgQhK457N00zfA+QKBgQDOiF8AA4aUibYPIb/v8mzyU04nzu0p+SZ/ +jLCk9JyKOwmMH9+OWKGiHpxWIaARtBMepke6JQWjJOd00s9Wg1RmlF3BNYJU99zt +vEMGwwIdzdUgDPuuLCnLfRpllWj7tMISPMY2ckuiqGtzVlw/23uYkocWVyft6dx0 +/Y3POKbPxQKBgQDcK2LXDaLkBZ7pRbvbtfXQXS3VF9hSSRgB0ni8qQBSkdm2wk31 +tpaP03e9kQxl1A88I7cTwKY9VD1zd7MTfYwjeMiMZF8NDMWXDFyAuiAB/yM29dOd +EJdGyp8BKTCWtsb+qgplfqOECanTKEcrINbLIzySvUr1t3LKInW8wYu1uQKBgQDG +dKfEnZ6uEH3OoIGMkYg6ee35to6R7IUfvxLmDt50vTH5YY8xet0lqQBUi08Cc+SD +aQg3R+fY0ldOHFt7KArr3tkQFNi9yMaT8nj9gFkCRozqlU8qF+m5TOcWgbE1XIW3 +fIqCOuWO0QMe+vb9rWtgOjxwLSODK1rZV+LyId+4eQKBgFm25EJrTxhdJ5fVmzno +h8PpyVPaXnnVIpy8V/YxZUn5OiAWQ5c5nNLIw2MLOyLf76HNuoKTux+wiV4ozOG4 +gt2zt8k/Cz/vVAJT9gtMIrbqlFedsjMa6JiOk8qZW27pnIhFvRvzxixIQXpDxcnd 
+DLtqaawkz+wN5SSe5rwe1A+B +-----END PRIVATE KEY----- From 19be26ecc2d6b883963c564ebe111e5818f6c2d1 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Tue, 6 May 2025 12:02:57 -0700 Subject: [PATCH 20/24] OpenAiClient updates & completions implementation (#390) Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- src/clients/openai.rs | 442 ++++++++++++------ src/orchestrator/common/client.rs | 53 ++- src/orchestrator/errors.rs | 2 + .../handlers/chat_completions_detection.rs | 4 +- src/orchestrator/types.rs | 6 +- src/server/routes.rs | 1 + tests/chat_completions_detection.rs | 4 + 7 files changed, 358 insertions(+), 154 deletions(-) diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 917e667a..19050e56 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -22,13 +22,15 @@ use eventsource_stream::Eventsource; use futures::StreamExt; use http_body_util::BodyExt; use hyper::{HeaderMap, StatusCode}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::{Map, Value}; use tokio::sync::mpsc; +use url::Url; use super::{ - Client, Error, HttpClient, create_http_client, detector::ContentAnalysisResponse, - http::HttpClientExt, + Client, Error, HttpClient, create_http_client, + detector::ContentAnalysisResponse, + http::{HttpClientExt, RequestBody}, }; use crate::{ config::ServiceConfig, @@ -40,6 +42,7 @@ use crate::{ const DEFAULT_PORT: u16 = 8080; const CHAT_COMPLETIONS_ENDPOINT: &str = "/v1/chat/completions"; +const COMPLETIONS_ENDPOINT: &str = "/v1/completions"; #[derive(Clone)] pub struct OpenAiClient { @@ -73,65 +76,102 @@ impl OpenAiClient { request: ChatCompletionsRequest, headers: HeaderMap, ) -> Result { - let url = self.inner().endpoint(CHAT_COMPLETIONS_ENDPOINT); - if request.stream { - let (tx, rx) = mpsc::channel(32); - let mut event_stream = self - .inner() - .post(url, headers, request) - .await? 
- .0 - .into_data_stream() - .eventsource(); - // Spawn task to forward events to receiver - tokio::spawn(async move { - while let Some(result) = event_stream.next().await { - match result { - Ok(event) if event.data == "[DONE]" => { - // Send None to signal that the stream completed - let _ = tx.send(Ok(None)).await; - break; + let url = self.client.endpoint(CHAT_COMPLETIONS_ENDPOINT); + if let Some(true) = request.stream { + let rx = self.handle_streaming(url, request, headers).await?; + Ok(ChatCompletionsResponse::Streaming(rx)) + } else { + let chat_completion = self.handle_unary(url, request, headers).await?; + Ok(ChatCompletionsResponse::Unary(chat_completion)) + } + } + + pub async fn completions( + &self, + request: CompletionsRequest, + headers: HeaderMap, + ) -> Result { + let url = self.client.endpoint(COMPLETIONS_ENDPOINT); + if let Some(true) = request.stream { + let rx = self.handle_streaming(url, request, headers).await?; + Ok(CompletionsResponse::Streaming(rx)) + } else { + let completion = self.handle_unary(url, request, headers).await?; + Ok(CompletionsResponse::Unary(completion)) + } + } + + async fn handle_unary(&self, url: Url, request: R, headers: HeaderMap) -> Result + where + R: RequestBody, + S: DeserializeOwned, + { + let response = self.client.post(url, headers, request).await?; + match response.status() { + StatusCode::OK => response.json::().await, + _ => { + let code = response.status(); + let message = if let Ok(response) = response.json::().await { + response.message + } else { + "unknown error occurred".into() + }; + Err(Error::Http { code, message }) + } + } + } + + async fn handle_streaming( + &self, + url: Url, + request: R, + headers: HeaderMap, + ) -> Result, orchestrator::Error>>, Error> + where + R: RequestBody, + S: DeserializeOwned + Send + 'static, + { + let (tx, rx) = mpsc::channel(32); + let mut event_stream = self + .client + .post(url, headers, request) + .await? 
+ .0 + .into_data_stream() + .eventsource(); + // Spawn task to forward events to receiver + tokio::spawn(async move { + while let Some(result) = event_stream.next().await { + match result { + Ok(event) if event.data == "[DONE]" => { + // Send None to signal that the stream completed + let _ = tx.send(Ok(None)).await; + break; + } + Ok(event) => match serde_json::from_str::(&event.data) { + Ok(chunk) => { + let _ = tx.send(Ok(Some(chunk))).await; } - Ok(event) => match serde_json::from_str::(&event.data) - { - Ok(chunk) => { - let _ = tx.send(Ok(Some(chunk))).await; - } - Err(e) => { - let error = Error::Http { - code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("deserialization error: {e}"), - }; - let _ = tx.send(Err(error.into())).await; - } - }, - Err(error) => { - // We received an error from the event stream, send error message + Err(e) => { let error = Error::Http { code: StatusCode::INTERNAL_SERVER_ERROR, - message: error.to_string(), + message: format!("deserialization error: {e}"), }; let _ = tx.send(Err(error.into())).await; } + }, + Err(error) => { + // We received an error from the event stream, send error message + let error = Error::Http { + code: StatusCode::INTERNAL_SERVER_ERROR, + message: error.to_string(), + }; + let _ = tx.send(Err(error.into())).await; } } - }); - Ok(ChatCompletionsResponse::Streaming(rx)) - } else { - let response = self.client.clone().post(url, headers, request).await?; - match response.status() { - StatusCode::OK => Ok(response.json::().await?.into()), - _ => { - let code = response.status(); - let message = if let Ok(response) = response.json::().await { - response.message - } else { - "unknown error occurred".into() - }; - Err(Error::Http { code, message }) - } } - } + }); + Ok(rx) } } @@ -156,6 +196,7 @@ impl HttpClientExt for OpenAiClient { } } +/// Chat completions response. 
#[derive(Debug)] pub enum ChatCompletionsResponse { Unary(Box), @@ -168,86 +209,99 @@ impl From for ChatCompletionsResponse { } } -/// Represents a chat completions request. +/// Completions (legacy) response. +#[derive(Debug)] +pub enum CompletionsResponse { + Unary(Box), + Streaming(mpsc::Receiver, orchestrator::Error>>), +} + +impl From for CompletionsResponse { + fn from(value: Completion) -> Self { + Self::Unary(Box::new(value)) + } +} + +/// Chat completions request. /// /// As orchestrator is only concerned with a limited subset -/// of request fields, we deserialize to an inner [`serde_json::Map`] -/// and only validate and extract the fields used by this service. -/// This type is then serialized to the inner [`serde_json::Map`]. +/// of request fields, we only inline and validate fields used by +/// this service. Extra fields are deserialized to `extra` via +/// struct flattening. The `detectors` field is not serialized. /// /// This is to avoid tracking and updating OpenAI and vLLM /// parameter additions/changes. Full validation is delegated to /// the downstream server implementation. -/// -/// Validated fields: detectors (internal), model, messages -#[derive(Debug, Default, Clone, PartialEq, Deserialize)] -#[serde(try_from = "Map")] +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] pub struct ChatCompletionsRequest { /// Detector config. + #[serde(default, skip_serializing)] pub detectors: DetectorConfig, /// Stream parameter. - pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, /// Model name. pub model: String, /// Messages. pub messages: Vec, - /// Inner request. - pub inner: Map, + /// Extra fields not captured above. 
+ #[serde(flatten)] + pub extra: Map, } -impl TryFrom> for ChatCompletionsRequest { - type Error = ValidationError; - - fn try_from(mut value: Map) -> Result { - let detectors = if let Some(detectors) = value.remove("detectors") { - DetectorConfig::deserialize(detectors) - .map_err(|_| ValidationError::Invalid("error deserializing `detectors`".into()))? - } else { - DetectorConfig::default() - }; - let stream = value - .get("stream") - .and_then(|v| v.as_bool()) - .unwrap_or_default(); - let model = if let Some(Value::String(model)) = value.get("model") { - Ok(model.clone()) - } else { - Err(ValidationError::Required("model".into())) - }?; - if model.is_empty() { +impl ChatCompletionsRequest { + pub fn validate(&self) -> Result<(), ValidationError> { + if self.model.is_empty() { return Err(ValidationError::Invalid("`model` must not be empty".into())); } - let messages = if let Some(messages) = value.get("messages") { - Vec::::deserialize(messages) - .map_err(|_| ValidationError::Invalid("error deserializing `messages`".into())) - } else { - Err(ValidationError::Required("messages".into())) - }?; - if messages.is_empty() { + if self.messages.is_empty() { return Err(ValidationError::Invalid( "`messages` must not be empty".into(), )); } - Ok(ChatCompletionsRequest { - detectors, - stream, - model, - messages, - inner: value, - }) + Ok(()) } } -impl Serialize for ChatCompletionsRequest { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.inner.serialize(serializer) +/// Completions (legacy) request. +/// +/// As orchestrator is only concerned with a limited subset +/// of request fields, we only inline and validate fields used by +/// this service. Extra fields are deserialized to `extra` via +/// struct flattening. The `detectors` field is not serialized. +/// +/// This is to avoid tracking and updating OpenAI and vLLM +/// parameter additions/changes. Full validation is delegated to +/// the downstream server implementation. 
+#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] +pub struct CompletionsRequest { + /// Stream parameter. + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + /// Model name. + pub model: String, + /// Prompt text. + pub prompt: String, + /// Extra fields not captured above. + #[serde(flatten)] + pub extra: Map, +} + +impl CompletionsRequest { + pub fn validate(&self) -> Result<(), ValidationError> { + if self.model.is_empty() { + return Err(ValidationError::Invalid("`model` must not be empty".into())); + } + if self.prompt.is_empty() { + return Err(ValidationError::Invalid( + "`prompt` must not be empty".into(), + )); + } + Ok(()) } } -/// Structure to contain parameters for detectors. +/// Detector config. #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct DetectorConfig { @@ -257,6 +311,7 @@ pub struct DetectorConfig { pub output: HashMap, } +/// Response format. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ResponseFormat { /// The type of response format being defined. @@ -266,6 +321,7 @@ pub struct ResponseFormat { pub json_schema: HashMap, } +/// Tool. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Tool { /// The type of the tool. @@ -274,6 +330,7 @@ pub struct Tool { pub function: ToolFunction, } +/// Tool function. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolFunction { /// The name of the function to be called. @@ -290,6 +347,7 @@ pub struct ToolFunction { pub strict: Option, } +/// Tool choice. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum ToolChoice { @@ -301,6 +359,7 @@ pub enum ToolChoice { Object(ToolChoiceObject), } +/// Tool choice object. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolChoiceObject { /// The type of the tool. @@ -309,6 +368,7 @@ pub struct ToolChoiceObject { pub function: Function, } +/// Stream options. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct StreamOptions { /// If set, an additional chunk will be streamed before the data: [DONE] message. @@ -319,6 +379,7 @@ pub struct StreamOptions { pub include_usage: Option, } +/// Role. #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Role { @@ -330,6 +391,7 @@ pub enum Role { Tool, } +/// Message. #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Message { @@ -352,6 +414,7 @@ pub struct Message { pub tool_call_id: Option, } +/// Content. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(untagged)] pub enum Content { @@ -396,6 +459,7 @@ impl From> for Content { } } +/// Content type. #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub enum ContentType { #[serde(rename = "text")] @@ -405,6 +469,7 @@ pub enum ContentType { ImageUrl, } +/// Content part. #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub struct ContentPart { /// The type of the content part. @@ -421,6 +486,7 @@ pub struct ContentPart { pub refusal: Option, } +/// Image url. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ImageUrl { /// Either a URL of the image or the base64 encoded image data. @@ -430,6 +496,7 @@ pub struct ImageUrl { pub detail: Option, } +/// Tool call. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ToolCall { /// The ID of the tool call. @@ -441,6 +508,7 @@ pub struct ToolCall { pub function: Function, } +/// Function. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Function { /// The name of the function to call. @@ -450,7 +518,7 @@ pub struct Function { pub arguments: Option, } -/// Represents a chat completion response returned by model, based on the provided input. +/// Chat completion response. 
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletion { /// A unique identifier for the chat completion. @@ -465,6 +533,8 @@ pub struct ChatCompletion { pub choices: Vec, /// Usage statistics for the completion request. pub usage: Usage, + /// Prompt logprobs. + pub prompt_logprobs: Option>>>, /// This fingerprint represents the backend configuration that the model runs with. #[serde(skip_serializing_if = "Option::is_none")] pub system_fingerprint: Option, @@ -480,7 +550,7 @@ pub struct ChatCompletion { pub warnings: Vec, } -/// A chat completion choice. +/// Chat completion choice. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionChoice { /// The index of the choice in the list of choices. @@ -491,9 +561,11 @@ pub struct ChatCompletionChoice { pub logprobs: Option, /// The reason the model stopped generating tokens. pub finish_reason: String, + /// The stop string or token id that caused the completion. + pub stop_reason: Option, } -/// A chat completion message generated by the model. +/// Chat completion message. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionMessage { /// The role of the author of this message. @@ -508,6 +580,7 @@ pub struct ChatCompletionMessage { pub refusal: Option, } +/// Chat completion logprobs. #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub struct ChatCompletionLogprobs { /// A list of message content tokens with log probability information. @@ -517,7 +590,7 @@ pub struct ChatCompletionLogprobs { pub refusal: Option>, } -/// Log probability information for a choice. +/// Chat completion logprob. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionLogprob { /// The token. @@ -531,6 +604,7 @@ pub struct ChatCompletionLogprob { pub top_logprobs: Option>, } +/// Chat completion top logprob. 
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatCompletionTopLogprob { /// The token. @@ -539,7 +613,7 @@ pub struct ChatCompletionTopLogprob { pub logprob: f32, } -/// Represents a streamed chunk of a chat completion response returned by model, based on the provided input. +/// Streaming chat completion chunk. #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ChatCompletionChunk { /// A unique identifier for the chat completion. Each chunk has the same ID. @@ -569,6 +643,7 @@ pub struct ChatCompletionChunk { pub warnings: Vec, } +/// Streaming chat completion chunk choice. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionChunkChoice { /// The index of the choice in the list of choices. @@ -579,9 +654,11 @@ pub struct ChatCompletionChunkChoice { pub logprobs: Option, /// The reason the model stopped generating tokens. pub finish_reason: Option, + /// The stop string or token id that caused the completion. + pub stop_reason: Option, } -/// A chat completion delta generated by streamed model responses. +/// Streaming chat completion delta. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionDelta { /// The role of the author of this message. @@ -598,7 +675,69 @@ pub struct ChatCompletionDelta { pub tool_calls: Vec, } -/// Usage statistics for a completion. +/// Completion (legacy) response. Also used for streaming. +#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] +pub struct Completion { + /// A unique identifier for the completion. + pub id: String, + /// The object type, which is always `text_completion`. + pub object: String, + /// The Unix timestamp (in seconds) of when the chat completion was created. + pub created: i64, + /// The model used for the completion. + pub model: String, + /// A list of completion choices. Can be more than one if n is greater than 1. + pub choices: Vec, + /// Usage statistics for the completion request. 
+ pub usage: Option, + /// This fingerprint represents the backend configuration that the model runs with. + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +/// Completion (legacy) choice. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct CompletionChoice { + /// The index of the choice in the list of choices. + pub index: u32, + /// Text generated by the model. + pub text: String, + /// Log probability information for the choice. + pub logprobs: Option, + /// The reason the model stopped generating tokens. + pub finish_reason: Option, + /// The stop string or token id that caused the completion. + pub stop_reason: Option, + /// Prompt logprobs. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_logprobs: Option>>>, +} + +/// Completion logprobs. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct CompletionLogprobs { + /// Tokens generated by the model. + pub tokens: Vec, + /// Token logprobs. + pub token_logprobs: Vec, + /// Top logprobs. + pub top_logprobs: Vec>, + /// Text offsets. + pub text_offset: Vec, +} + +/// Logprob. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct Logprob { + /// The logprob of the chosen token + pub logprob: f32, + /// The vocab rank of the chosen token (>=1) + pub rank: Option, + /// The decoded chosen token index + pub decoded_token: Option, +} + +/// Completion usage statistics. #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub struct Usage { /// Number of tokens in the prompt. @@ -615,18 +754,21 @@ pub struct Usage { pub completion_token_details: Option, } +/// Completion token details. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct CompletionTokenDetails { pub audio_tokens: u32, pub reasoning_tokens: u32, } +/// Prompt token details. 
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct PromptTokenDetails { pub audio_tokens: u32, pub cached_tokens: u32, } +/// Stop tokens. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum StopTokens { @@ -634,6 +776,7 @@ pub enum StopTokens { String(String), } +/// OpenAI error response. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OpenAiError { pub object: Option, @@ -644,7 +787,7 @@ pub struct OpenAiError { pub code: u16, } -/// Guardrails detection results. +/// Guardrails chat detections. #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatDetections { #[serde(default, skip_serializing_if = "Vec::is_empty")] @@ -653,7 +796,7 @@ pub struct ChatDetections { pub output: Vec, } -/// Guardrails detection result for application on input. +/// Guardrails chat input detections. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct InputDetectionResult { pub message_index: u32, @@ -661,7 +804,7 @@ pub struct InputDetectionResult { pub results: Vec, } -/// Guardrails detection result for application output. +/// Guardrails chat output detections. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct OutputDetectionResult { pub choice_index: u32, @@ -669,15 +812,7 @@ pub struct OutputDetectionResult { pub results: Vec, } -/// Represents the input and output of detection results following processing. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DetectionResult { - pub index: u32, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub results: Vec, -} - -/// Warnings generated by guardrails. +/// Guardrails warning. 
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct OrchestratorWarning { r#type: DetectionWarningReason, @@ -714,18 +849,19 @@ mod test { "model": "test", "detectors": detectors, "messages": messages, + "frequency_penalty": 2.0, }); let request = ChatCompletionsRequest::deserialize(&json_request)?; - let mut inner = json_request.as_object().unwrap().to_owned(); - inner.remove("detectors").unwrap(); + let mut extra = Map::new(); + extra.insert("frequency_penalty".into(), 2.0.into()); assert_eq!( request, ChatCompletionsRequest { detectors, - stream: false, + stream: None, model: "test".into(), messages: messages.clone(), - inner, + extra, } ); @@ -735,57 +871,67 @@ mod test { "messages": messages, }); let request = ChatCompletionsRequest::deserialize(&json_request)?; - let inner = json_request.as_object().unwrap().to_owned(); assert_eq!( request, ChatCompletionsRequest { detectors: DetectorConfig::default(), - stream: false, + stream: None, model: "test".into(), messages: messages.clone(), - inner, + extra: Map::new(), } ); - // Test deserialize validation errors + // Test deserialize errors let result = ChatCompletionsRequest::deserialize(json!({ "detectors": DetectorConfig::default(), "messages": messages, })); - assert!(result.is_err_and(|error| error.to_string() == "`model` is required")); + assert!(result.is_err_and(|error| error.to_string().starts_with("missing field `model`"))); let result = ChatCompletionsRequest::deserialize(json!({ + "model": "test", + "detectors": DetectorConfig::default(), + "messages": ["invalid"], + })); + assert!(result.is_err_and(|error| error.to_string() + == "invalid type: string \"invalid\", expected struct Message")); + + // Test validation errors + let request = ChatCompletionsRequest::deserialize(json!({ "model": "", "detectors": DetectorConfig::default(), "messages": Vec::::default(), - })); + }))?; + let result = request.validate(); assert!(result.is_err_and(|error| error.to_string() == "`model` must not be 
empty")); - let result = ChatCompletionsRequest::deserialize(json!({ + let request = ChatCompletionsRequest::deserialize(json!({ "model": "test", "detectors": DetectorConfig::default(), "messages": Vec::::default(), - })); + }))?; + let result = request.validate(); assert!(result.is_err_and(|error| error.to_string() == "`messages` must not be empty")); - let result = ChatCompletionsRequest::deserialize(json!({ - "model": "test", - "detectors": DetectorConfig::default(), - "messages": ["invalid"], - })); - assert!(result.is_err_and(|error| error.to_string() == "error deserializing `messages`")); - // Test serialize + let request = ChatCompletionsRequest::deserialize(&json!({ + "model": "test", + "detectors": { + "input": {"some_detector": {}}, + "output": {}, + }, + "messages": [{"role": "user", "content": "Hi there!"}], + "frequency_penalty": 2.0, + }))?; let serialized_request = serde_json::to_value(request)?; + // should include stream: false and exclude detectors assert_eq!( serialized_request, json!({ "model": "test", - "messages": [Message { - content: Some(Content::Text("Hi there!".to_string())), - role: Role::User, - ..Default::default() - }], + "messages": [{"role": "user", "content": "Hi there!"}], + "frequency_penalty": 2.0, }) ); diff --git a/src/orchestrator/common/client.rs b/src/orchestrator/common/client.rs index 27375103..b4d0c12b 100644 --- a/src/orchestrator/common/client.rs +++ b/src/orchestrator/common/client.rs @@ -31,7 +31,7 @@ use crate::{ TextGenerationDetectorClient, }, http::JSON_CONTENT_TYPE, - openai::{self, ChatCompletionsResponse, OpenAiClient}, + openai::{self, OpenAiClient}, }, models::{ ClassifiedGeneratedTextResult as GenerateResponse, DetectorParams, @@ -282,8 +282,55 @@ pub async fn chat_completion_stream( error, })?; let stream = match response { - ChatCompletionsResponse::Streaming(rx) => ReceiverStream::new(rx), - ChatCompletionsResponse::Unary(_) => unimplemented!(), + openai::ChatCompletionsResponse::Streaming(rx) => 
ReceiverStream::new(rx), + openai::ChatCompletionsResponse::Unary(_) => unimplemented!(), + } + .enumerate() + .boxed(); + Ok(stream) +} + +/// Sends request to openai completions client. +#[instrument(skip_all, fields(model_id))] +pub async fn completion( + client: &OpenAiClient, + mut headers: HeaderMap, + request: openai::CompletionsRequest, +) -> Result { + let model_id = request.model.clone(); + debug!(%model_id, ?request, "sending completions request"); + headers.append(CONTENT_TYPE, JSON_CONTENT_TYPE); + let response = client + .completions(request, headers) + .await + .map_err(|error| Error::CompletionRequestFailed { + id: model_id.clone(), + error, + })?; + debug!(%model_id, ?response, "received completions response"); + Ok(response) +} + +/// Sends stream request to openai completions client. +#[instrument(skip_all, fields(model_id))] +pub async fn completion_stream( + client: &OpenAiClient, + mut headers: HeaderMap, + request: openai::CompletionsRequest, +) -> Result { + let model_id = request.model.clone(); + debug!(%model_id, ?request, "sending completions stream request"); + headers.append(CONTENT_TYPE, JSON_CONTENT_TYPE); + let response = client + .completions(request, headers) + .await + .map_err(|error| Error::CompletionRequestFailed { + id: model_id.clone(), + error, + })?; + let stream = match response { + openai::CompletionsResponse::Streaming(rx) => ReceiverStream::new(rx), + openai::CompletionsResponse::Unary(_) => unimplemented!(), } .enumerate() .boxed(); diff --git a/src/orchestrator/errors.rs b/src/orchestrator/errors.rs index 9f691656..55456b83 100644 --- a/src/orchestrator/errors.rs +++ b/src/orchestrator/errors.rs @@ -33,6 +33,8 @@ pub enum Error { GenerateRequestFailed { id: String, error: clients::Error }, #[error("chat completion request failed for `{id}`: {error}")] ChatCompletionRequestFailed { id: String, error: clients::Error }, + #[error("completion request failed for `{id}`: {error}")] + CompletionRequestFailed { id: String, 
error: clients::Error }, #[error("tokenize request failed for `{id}`: {error}")] TokenizeRequestFailed { id: String, error: clients::Error }, #[error("validation error: {0}")] diff --git a/src/orchestrator/handlers/chat_completions_detection.rs b/src/orchestrator/handlers/chat_completions_detection.rs index 7b78b2d2..0f5beb18 100644 --- a/src/orchestrator/handlers/chat_completions_detection.rs +++ b/src/orchestrator/handlers/chat_completions_detection.rs @@ -38,8 +38,8 @@ impl Handle for Orchestrator { async fn handle(&self, task: ChatCompletionsDetectionTask) -> Result { let ctx = self.ctx.clone(); match task.request.stream { - true => streaming::handle_streaming(ctx, task).await, - false => unary::handle_unary(ctx, task).await, + Some(true) => streaming::handle_streaming(ctx, task).await, + _ => unary::handle_unary(ctx, task).await, } } } diff --git a/src/orchestrator/types.rs b/src/orchestrator/types.rs index c23070c8..b187fa0a 100644 --- a/src/orchestrator/types.rs +++ b/src/orchestrator/types.rs @@ -30,7 +30,10 @@ pub mod detection_batch_stream; pub use detection_batch_stream::*; use super::Error; -use crate::{clients::openai::ChatCompletionChunk, models::ClassifiedGeneratedTextStreamResult}; +use crate::{ + clients::openai::{ChatCompletionChunk, Completion}, + models::ClassifiedGeneratedTextStreamResult, +}; pub type ChunkerId = String; pub type DetectorId = String; @@ -42,3 +45,4 @@ pub type InputStream = BoxStream>; pub type DetectionStream = BoxStream>; pub type GenerationStream = BoxStream<(usize, Result)>; pub type ChatCompletionStream = BoxStream<(usize, Result, Error>)>; +pub type CompletionStream = BoxStream<(usize, Result, Error>)>; diff --git a/src/server/routes.rs b/src/server/routes.rs index bf654b2e..55ee5331 100644 --- a/src/server/routes.rs +++ b/src/server/routes.rs @@ -326,6 +326,7 @@ async fn chat_completions_detection( ) -> Result { use ChatCompletionsResponse::*; let trace_id = current_trace_id(); + request.validate()?; let headers = 
filter_headers(&state.orchestrator.config().passthrough_headers, headers); let task = ChatCompletionsDetectionTask::new(trace_id, request, headers); match state.orchestrator.handle(task).await { diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_detection.rs index d1e16b52..334d49c3 100644 --- a/tests/chat_completions_detection.rs +++ b/tests/chat_completions_detection.rs @@ -93,6 +93,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { index: 0, logprobs: None, finish_reason: "NOT_FINISHED".to_string(), + stop_reason: None, }, ChatCompletionChoice { message: ChatCompletionMessage { @@ -104,6 +105,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { index: 1, logprobs: None, finish_reason: "EOS_TOKEN".to_string(), + stop_reason: None, }, ]; let chat_completions_response = ChatCompletion { @@ -552,6 +554,7 @@ async fn output_detections() -> Result<(), anyhow::Error> { index: 0, logprobs: None, finish_reason: "NOT_FINISHED".to_string(), + stop_reason: None, }, ChatCompletionChoice { message: ChatCompletionMessage { @@ -563,6 +566,7 @@ async fn output_detections() -> Result<(), anyhow::Error> { index: 1, logprobs: None, finish_reason: "EOS_TOKEN".to_string(), + stop_reason: None, }, ]; From 2d3b5e4c39493b8e9833bb3b8958c1a4fe95c75c Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Thu, 8 May 2025 14:31:21 -0300 Subject: [PATCH 21/24] Implement re-try logic for TestOrchestratorServer port binding (#392) * Implement re-try logic for TestOrchestratorServer port binding Signed-off-by: Mateus Devino * Match ioError specifically Signed-off-by: Mateus Devino * Apply changes requested on PR Signed-off-by: Mateus Devino --------- Signed-off-by: Mateus Devino --- tests/common/orchestrator.rs | 99 +++++++++++++----------------------- 1 file changed, 34 insertions(+), 65 deletions(-) diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index f143ffcc..f9fb7e9c 100644 --- 
a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -25,7 +25,7 @@ use std::{ use bytes::Bytes; use eventsource_stream::{EventStream, Eventsource}; -use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator}; +use fms_guardrails_orchestr8::{config::OrchestratorConfig, orchestrator::Orchestrator, server}; use futures::{ Stream, StreamExt, stream::{ @@ -33,10 +33,10 @@ use futures::{ }, }; use mocktail::server::MockServer; -use rand::Rng; +use rand::{Rng, SeedableRng, rngs::SmallRng}; use rustls::crypto::ring; use serde::{Serialize, de::DeserializeOwned}; -use tokio::task::JoinHandle; +use tracing::{error, warn}; use url::Url; // Default orchestrator configuration file for integration tests. @@ -131,71 +131,54 @@ impl<'a> TestOrchestratorServerBuilder<'a> { initialize_chunkers(self.chunker_servers.as_deref(), &mut config).await?; // Create & start test orchestrator server - let port = self.port.unwrap_or_else(|| find_available_port().unwrap()); - let health_port = self - .health_port - .unwrap_or_else(|| find_available_port().unwrap()); - let mut server = TestOrchestratorServer::new(config, port, health_port); - server.start().await?; + let server = TestOrchestratorServer::start(config).await?; Ok(server) } } pub struct TestOrchestratorServer { - config: OrchestratorConfig, - port: u16, - health_port: u16, base_url: Url, health_url: Url, client: reqwest::Client, - _handle: Option>>, } impl TestOrchestratorServer { - pub fn new(config: OrchestratorConfig, port: u16, health_port: u16) -> Self { - let base_url = Url::parse(&format!("http://0.0.0.0:{port}")).unwrap(); - let health_url = Url::parse(&format!("http://0.0.0.0:{health_port}/health")).unwrap(); - let client = reqwest::Client::builder().build().unwrap(); - Self { - config, - port, - health_port, - base_url, - health_url, - client, - _handle: None, - } - } - pub fn builder<'a>() -> TestOrchestratorServerBuilder<'a> { TestOrchestratorServerBuilder::default() } /// 
Starts the orchestrator server. - pub async fn start(&mut self) -> Result<(), anyhow::Error> { - let orchestrator = Orchestrator::new(self.config.clone(), false).await?; - let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), self.port); - let health_http_addr: SocketAddr = - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), self.health_port); - let handle = tokio::spawn(async move { - fms_guardrails_orchestr8::server::run( - http_addr, - health_http_addr, - None, - None, - None, - orchestrator, - ) - .await?; - Ok::<(), anyhow::Error>(()) - }); - self._handle = Some(handle); - - // Give the server time to become ready. - tokio::time::sleep(Duration::from_millis(10)).await; - - Ok(()) + pub async fn start(config: OrchestratorConfig) -> Result { + let mut rng = SmallRng::from_os_rng(); + loop { + let port = rng.random_range(10000..60000); + let health_port = rng.random_range(10000..60000); + let orchestrator = Orchestrator::new(config.clone(), false).await?; + let http_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); + let health_http_addr: SocketAddr = + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), health_port); + match server::run(http_addr, health_http_addr, None, None, None, orchestrator).await { + Ok(_) => { + // Give the server time to become ready. 
+ tokio::time::sleep(Duration::from_millis(10)).await; + return Ok(Self { + base_url: Url::parse(&format!("http://0.0.0.0:{port}")).unwrap(), + health_url: Url::parse(&format!("http://0.0.0.0:{health_port}/health")) + .unwrap(), + client: reqwest::Client::builder().build().unwrap(), + }); + } + Err(server::Error::IoError(error)) => { + warn!(%error, "failed to start server, trying again with different ports..."); + continue; + } + Err(error) => { + error!(%error, "failed to start server"); + return Err(error.into()); + } + }; + } } pub fn server_url(&self, path: &str) -> Url { @@ -322,20 +305,6 @@ where } } -fn find_available_port() -> Option { - let mut rng = rand::rng(); - loop { - let port: u16 = rng.random_range(40000..60000); - if port_is_available(port) { - return Some(port); - } - } -} - -fn port_is_available(port: u16) -> bool { - std::net::TcpListener::bind(("0.0.0.0", port)).is_ok() -} - pub fn json_lines_stream( messages: impl IntoIterator, ) -> impl Stream, std::io::Error>> { From 88dcf85c16452064465ba2864b939ee2490fad26 Mon Sep 17 00:00:00 2001 From: Dan Clark <44146800+declark1@users.noreply.github.com> Date: Mon, 12 May 2025 09:49:08 -0700 Subject: [PATCH 22/24] Batcher simplifications (#394) * Drop Batch associated type from DetectionBatcher and generics, drop detector_id from DetectionStream, integrate single detection stream optimization Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> * Update DetectionBatchStream docstring Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --------- Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- src/orchestrator/common/tasks.rs | 20 +-- .../streaming_classification_with_gen.rs | 43 +---- .../handlers/streaming_content_detection.rs | 43 +---- src/orchestrator/types.rs | 3 +- .../types/detection_batch_stream.rs | 155 +++++++++--------- src/orchestrator/types/detection_batcher.rs | 18 +- .../detection_batcher/chat_completion.rs | 64 +++----- 
.../detection_batcher/max_processed_index.rs | 58 +++---- .../types/detection_batcher/noop.rs | 53 ------ 9 files changed, 135 insertions(+), 322 deletions(-) delete mode 100644 src/orchestrator/types/detection_batcher/noop.rs diff --git a/src/orchestrator/common/tasks.rs b/src/orchestrator/common/tasks.rs index 3326676a..efd708da 100644 --- a/src/orchestrator/common/tasks.rs +++ b/src/orchestrator/common/tasks.rs @@ -190,9 +190,9 @@ pub async fn text_contents_detections( ctx: Arc, headers: HeaderMap, detectors: HashMap, - input_id: InputId, + input_id: u32, inputs: Vec<(usize, String)>, -) -> Result<(InputId, Detections), Error> { +) -> Result<(u32, Detections), Error> { let chunkers = get_chunker_ids(&ctx, &detectors)?; let chunk_map = chunks(ctx.clone(), chunkers, inputs).await?; let inputs = detectors @@ -249,7 +249,7 @@ pub async fn text_contents_detection_streams( ctx: Arc, headers: HeaderMap, detectors: HashMap, - input_id: InputId, + input_id: u32, input_rx: mpsc::Receiver>, // (message_index, text) ) -> Result, Error> { // Create chunk streams @@ -294,14 +294,8 @@ pub async fn text_contents_detection_streams( .filter(|detection| detection.score >= threshold) .collect::(); // Send to detection channel - let _ = detection_tx - .send(Ok(( - input_id, - detector_id.clone(), - chunk, - detections, - ))) - .await; + let _ = + detection_tx.send(Ok((input_id, chunk, detections))).await; } Err(error) => { // Send error to detection channel @@ -985,9 +979,7 @@ mod test { let mut fake_detector_stream = detection_streams.swap_remove(0); let mut results = Vec::with_capacity(1); - while let Some(Ok((_input_id, _detector_id, _chunk, detections))) = - fake_detector_stream.next().await - { + while let Some(Ok((_input_id, _chunk, detections))) = fake_detector_stream.next().await { results.push(detections); } assert_eq!(results.len(), 1); diff --git a/src/orchestrator/handlers/streaming_classification_with_gen.rs 
b/src/orchestrator/handlers/streaming_classification_with_gen.rs index f43f3039..d1d60214 100644 --- a/src/orchestrator/handlers/streaming_classification_with_gen.rs +++ b/src/orchestrator/handlers/streaming_classification_with_gen.rs @@ -40,8 +40,7 @@ use crate::{ Context, Error, Orchestrator, common::{self, validate_detectors}, types::{ - Chunk, DetectionBatchStream, DetectionStream, Detections, GenerationStream, - MaxProcessedIndexBatcher, + Chunk, DetectionBatchStream, Detections, GenerationStream, MaxProcessedIndexBatcher, }, }, }; @@ -254,12 +253,6 @@ async fn handle_output_detection( let generations = generations.clone(); async move { match detection_streams { - Ok(mut detection_streams) if detection_streams.len() == 1 => { - // Process single detection stream, batching not applicable - let detection_stream = detection_streams.swap_remove(0); - process_detection_stream(trace_id, generations, detection_stream, response_tx) - .await; - } Ok(detection_streams) => { // Create detection batch stream let detection_batch_stream = DetectionBatchStream::new( @@ -335,47 +328,17 @@ async fn forward_generation_stream( info!(%trace_id, "task completed: generation stream closed"); } -/// Consumes a detection stream, builds responses, and sends them to a response channel. 
-#[instrument(skip_all)] -async fn process_detection_stream( - trace_id: TraceId, - generations: Arc>>, - mut detection_stream: DetectionStream, - response_tx: mpsc::Sender>, -) { - while let Some(result) = detection_stream.next().await { - match result { - Ok((_, _detector_id, chunk, detections)) => { - // Create response for this batch with output detections - let response = output_detection_response(&generations, chunk, detections).unwrap(); - // Send message to response channel - if response_tx.send(Ok(response)).await.is_err() { - info!(%trace_id, "task completed: client disconnected"); - return; - } - } - Err(error) => { - error!(%trace_id, %error, "task failed: error received from detection stream"); - // Send error to response channel and terminate - let _ = response_tx.send(Err(error)).await; - return; - } - } - } - info!(%trace_id, "task completed: detection stream closed"); -} - /// Consumes a detection batch stream, builds responses, and sends them to a response channel. #[instrument(skip_all)] async fn process_detection_batch_stream( trace_id: TraceId, generations: Arc>>, - mut detection_batch_stream: DetectionBatchStream, + mut detection_batch_stream: DetectionBatchStream, response_tx: mpsc::Sender>, ) { while let Some(result) = detection_batch_stream.next().await { match result { - Ok((chunk, detections)) => { + Ok((_, chunk, detections)) => { // Create response for this batch with output detections let response = output_detection_response(&generations, chunk, detections).unwrap(); // Send message to response channel diff --git a/src/orchestrator/handlers/streaming_content_detection.rs b/src/orchestrator/handlers/streaming_content_detection.rs index 97e143a3..7d09116d 100644 --- a/src/orchestrator/handlers/streaming_content_detection.rs +++ b/src/orchestrator/handlers/streaming_content_detection.rs @@ -30,7 +30,7 @@ use crate::{ orchestrator::{ Context, Error, Orchestrator, common::{self, validate_detectors}, - types::{BoxStream, 
DetectionBatchStream, DetectionStream, MaxProcessedIndexBatcher}, + types::{BoxStream, DetectionBatchStream, MaxProcessedIndexBatcher}, }, }; @@ -133,11 +133,6 @@ async fn handle_detection( tokio::spawn( async move { match detection_streams { - Ok(mut detection_streams) if detection_streams.len() == 1 => { - // Process single detection stream, batching not applicable - let detection_stream = detection_streams.swap_remove(0); - process_detection_stream(trace_id, detection_stream, response_tx).await; - } Ok(detection_streams) => { // Create detection batch stream let detection_batch_stream = DetectionBatchStream::new( @@ -177,48 +172,16 @@ async fn handle_detection( ); } -/// Consumes a detection stream, builds responses, and sends them to a response channel. -#[instrument(skip_all)] -async fn process_detection_stream( - trace_id: TraceId, - mut detection_stream: DetectionStream, - response_tx: mpsc::Sender>, -) { - while let Some(result) = detection_stream.next().await { - match result { - Ok((_, _detector_id, chunk, detections)) => { - let response = StreamingContentDetectionResponse { - start_index: chunk.start as u32, - processed_index: chunk.end as u32, - detections: detections.into(), - }; - // Send message to response channel - if response_tx.send(Ok(response)).await.is_err() { - info!(%trace_id, "task completed: client disconnected"); - return; - } - } - Err(error) => { - error!(%trace_id, %error, "task failed: error received from detection stream"); - // Send error to response channel and terminate - let _ = response_tx.send(Err(error)).await; - return; - } - } - } - info!(%trace_id, "task completed: detection stream closed"); -} - /// Consumes a detection batch stream, builds responses, and sends them to a response channel. 
#[instrument(skip_all)] async fn process_detection_batch_stream( trace_id: TraceId, - mut detection_batch_stream: DetectionBatchStream, + mut detection_batch_stream: DetectionBatchStream, response_tx: mpsc::Sender>, ) { while let Some(result) = detection_batch_stream.next().await { match result { - Ok((chunk, detections)) => { + Ok((_, chunk, detections)) => { let response = StreamingContentDetectionResponse { start_index: chunk.start as u32, processed_index: chunk.end as u32, diff --git a/src/orchestrator/types.rs b/src/orchestrator/types.rs index b187fa0a..089bc74f 100644 --- a/src/orchestrator/types.rs +++ b/src/orchestrator/types.rs @@ -37,12 +37,11 @@ use crate::{ pub type ChunkerId = String; pub type DetectorId = String; -pub type InputId = u32; pub type BoxStream = Pin + Send>>; pub type ChunkStream = BoxStream>; pub type InputStream = BoxStream>; -pub type DetectionStream = BoxStream>; +pub type DetectionStream = BoxStream>; pub type GenerationStream = BoxStream<(usize, Result)>; pub type ChatCompletionStream = BoxStream<(usize, Result, Error>)>; pub type CompletionStream = BoxStream<(usize, Result, Error>)>; diff --git a/src/orchestrator/types/detection_batch_stream.rs b/src/orchestrator/types/detection_batch_stream.rs index a03da5dd..6da2a2cf 100644 --- a/src/orchestrator/types/detection_batch_stream.rs +++ b/src/orchestrator/types/detection_batch_stream.rs @@ -18,69 +18,84 @@ use futures::{Stream, StreamExt, stream}; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error}; -use super::{Chunk, DetectionBatcher, DetectionStream, Detections, DetectorId, InputId}; +use super::{Batch, Chunk, DetectionBatcher, DetectionStream, Detections}; use crate::orchestrator::Error; -/// A stream adapter that wraps multiple detection streams and +/// A stream adapter that wraps detection streams and /// produces a stream of batches using a [`DetectionBatcher`] /// implementation. 
/// -/// The detection batcher enables flexible batching -/// logic and returned batch types for different use cases. -pub struct DetectionBatchStream { - batch_rx: mpsc::Receiver>, +/// The detection batcher enables flexible batching logic for different use cases. +pub struct DetectionBatchStream { + batch_rx: mpsc::Receiver>, } -impl DetectionBatchStream -where - B: DetectionBatcher, -{ - pub fn new(batcher: B, streams: Vec) -> Self { +impl DetectionBatchStream { + pub fn new(batcher: impl DetectionBatcher, mut streams: Vec) -> Self { let (batch_tx, batch_rx) = mpsc::channel(32); - // Create single stream from multiple detection streams - let mut stream_set = stream::select_all(streams); - // Create batcher manager, an actor to manage the batcher instead of using locks - let batcher_manager = DetectionBatcherManagerHandle::new(batcher); // Spawn task to receive detections and process batches tokio::spawn(async move { - let mut stream_completed = false; - loop { - tokio::select! { - // Disable random branch selection to poll the futures in order - biased; - - // Receive detections and push to batcher - msg = stream_set.next(), if !stream_completed => { - match msg { - Some(Ok((input_id, detector_id, chunk, detections))) => { - debug!(%input_id, ?chunk, ?detections, "pushing detections to batcher"); - batcher_manager - .push(input_id, detector_id, chunk, detections) - .await; - }, - Some(Err(error)) => { - error!(?error, "sending error to batch channel"); - let _ = batch_tx.send(Err(error)).await; - break; - }, - None => { - debug!("detections stream has completed"); - stream_completed = true; - }, + if streams.len() == 1 { + // Skip the batching process for a single detection stream + let mut stream = streams.swap_remove(0); + while let Some(msg) = stream.next().await { + match msg { + Ok(batch) => { + debug!(?batch, "sending batch to batch channel"); + let _ = batch_tx.send(Ok(batch)).await; } - }, - // Pop batches and send them to batch channel - Some(batch) = 
batcher_manager.pop() => { - debug!(?batch, "sending batch to batch channel"); - let _ = batch_tx.send(Ok(batch)).await; - }, - // Terminate task when stream is completed and batcher state is empty - empty = batcher_manager.is_empty(), if stream_completed => { - if empty { + Err(error) => { + error!(?error, "sending error to batch channel"); + let _ = batch_tx.send(Err(error)).await; break; } } } + debug!("detections stream has completed"); + } else { + // Create single stream from multiple detection streams + let mut stream_set = stream::select_all(streams); + // Create batcher manager, an actor to manage the batcher instead of using locks + let batcher_manager = DetectionBatcherManagerHandle::new(batcher); + let mut stream_completed = false; + loop { + tokio::select! { + // Disable random branch selection to poll the futures in order + biased; + + // Receive detections and push to batcher + msg = stream_set.next(), if !stream_completed => { + match msg { + Some(Ok((input_id, chunk, detections))) => { + debug!(%input_id, ?chunk, ?detections, "pushing detections to batcher"); + batcher_manager + .push(input_id, chunk, detections) + .await; + }, + Some(Err(error)) => { + error!(?error, "sending error to batch channel"); + let _ = batch_tx.send(Err(error)).await; + break; + }, + None => { + debug!("detections stream has completed"); + stream_completed = true; + }, + } + }, + // Pop batches and send them to batch channel + Some(batch) = batcher_manager.pop() => { + debug!(?batch, "sending batch to batch channel"); + let _ = batch_tx.send(Ok(batch)).await; + }, + // Terminate task when stream is completed and batcher state is empty + empty = batcher_manager.is_empty(), if stream_completed => { + if empty { + break; + } + } + } + } } debug!("detection batch stream task has completed"); }); @@ -89,11 +104,8 @@ where } } -impl Stream for DetectionBatchStream -where - B: DetectionBatcher, -{ - type Item = Result; +impl Stream for DetectionBatchStream { + type Item = 
Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, @@ -103,10 +115,9 @@ where } } -enum DetectionBatcherMessage { +enum DetectionBatcherMessage { Push { - input_id: InputId, - detector_id: DetectorId, + input_id: u32, chunk: Chunk, detections: Detections, }, @@ -121,14 +132,14 @@ enum DetectionBatcherMessage { /// An actor that manages a [`DetectionBatcher`]. struct DetectionBatcherManager { batcher: B, - rx: mpsc::Receiver>, + rx: mpsc::Receiver, } impl DetectionBatcherManager where B: DetectionBatcher, { - pub fn new(batcher: B, rx: mpsc::Receiver>) -> Self { + pub fn new(batcher: B, rx: mpsc::Receiver) -> Self { Self { batcher, rx } } @@ -137,12 +148,11 @@ where match msg { DetectionBatcherMessage::Push { input_id, - detector_id, chunk, detections, } => { - debug!(%input_id, %detector_id, ?chunk, ?detections, "handling push request"); - self.batcher.push(input_id, detector_id, chunk, detections) + debug!(%input_id, ?chunk, ?detections, "handling push request"); + self.batcher.push(input_id, chunk, detections) } DetectionBatcherMessage::Pop { response_tx } => { debug!("handling pop request"); @@ -163,17 +173,13 @@ where /// A handle to a [`DetectionBatcherManager`]. #[derive(Clone)] -struct DetectionBatcherManagerHandle { - tx: mpsc::Sender>, +struct DetectionBatcherManagerHandle { + tx: mpsc::Sender, } -impl DetectionBatcherManagerHandle -where - B: DetectionBatcher, - B::Batch: Clone, -{ +impl DetectionBatcherManagerHandle { /// Creates a new [`DetectionBatcherManager`] and returns its handle. - pub fn new(batcher: B) -> Self { + pub fn new(batcher: impl DetectionBatcher) -> Self { let (tx, rx) = mpsc::channel(32); let mut actor = DetectionBatcherManager::new(batcher, rx); tokio::spawn(async move { actor.run().await }); @@ -181,18 +187,11 @@ where } /// Pushes new detections to the batcher. 
- pub async fn push( - &self, - input_id: InputId, - detector_id: DetectorId, - chunk: Chunk, - detections: Detections, - ) { + pub async fn push(&self, input_id: u32, chunk: Chunk, detections: Detections) { let _ = self .tx .send(DetectionBatcherMessage::Push { input_id, - detector_id, chunk, detections, }) @@ -200,7 +199,7 @@ where } /// Removes the next batch of detections from the batcher, if ready. - pub async fn pop(&self) -> Option { + pub async fn pop(&self) -> Option { let (response_tx, response_rx) = oneshot::channel(); let _ = self .tx diff --git a/src/orchestrator/types/detection_batcher.rs b/src/orchestrator/types/detection_batcher.rs index 19ca98f3..71b0be69 100644 --- a/src/orchestrator/types/detection_batcher.rs +++ b/src/orchestrator/types/detection_batcher.rs @@ -16,29 +16,21 @@ */ pub mod chat_completion; pub use chat_completion::*; -pub mod noop; -pub use noop::*; pub mod max_processed_index; pub use max_processed_index::*; -use super::{Chunk, Detections, DetectorId, InputId}; +use super::{Chunk, Detections}; + +pub type Batch = (u32, Chunk, Detections); /// A detection batcher. /// Implements pluggable batching logic for a [`DetectionBatchStream`]. pub trait DetectionBatcher: std::fmt::Debug + Clone + Send + 'static { - type Batch: std::fmt::Debug + Clone + Send + 'static; - /// Pushes new detections. - fn push( - &mut self, - input_id: InputId, - detector_id: DetectorId, - chunk: Chunk, - detections: Detections, - ); + fn push(&mut self, input_id: u32, chunk: Chunk, detections: Detections); /// Removes the next batch of detections, if ready. - fn pop_batch(&mut self) -> Option; + fn pop_batch(&mut self) -> Option; /// Returns `true` if the batcher state is empty. 
fn is_empty(&self) -> bool; diff --git a/src/orchestrator/types/detection_batcher/chat_completion.rs b/src/orchestrator/types/detection_batcher/chat_completion.rs index dae87e43..5a8fc19a 100644 --- a/src/orchestrator/types/detection_batcher/chat_completion.rs +++ b/src/orchestrator/types/detection_batcher/chat_completion.rs @@ -16,7 +16,7 @@ */ use std::collections::{BTreeMap, btree_map}; -use super::{Chunk, DetectionBatcher, Detections, DetectorId}; +use super::{Batch, Chunk, DetectionBatcher, Detections}; pub type ChoiceIndex = u32; @@ -53,15 +53,7 @@ impl ChatCompletionBatcher { } impl DetectionBatcher for ChatCompletionBatcher { - type Batch = (Chunk, ChoiceIndex, Detections); - - fn push( - &mut self, - choice_index: ChoiceIndex, - _detector_id: DetectorId, - chunk: Chunk, - detections: Detections, - ) { + fn push(&mut self, choice_index: ChoiceIndex, chunk: Chunk, detections: Detections) { match self.state.entry((chunk, choice_index)) { btree_map::Entry::Vacant(entry) => { // New chunk, insert entry @@ -74,7 +66,7 @@ impl DetectionBatcher for ChatCompletionBatcher { } } - fn pop_batch(&mut self) -> Option { + fn pop_batch(&mut self) -> Option { // Batching logic here will only assume detections with the same chunker type // Requirements in https://github.com/foundation-model-stack/fms-guardrails-orchestrator/blob/main/docs/architecture/adrs/005-chat-completion-support.md#streaming-response // for detections on whole output will be handled outside of the batcher @@ -88,7 +80,7 @@ impl DetectionBatcher for ChatCompletionBatcher { // We have all detections for the chunk, remove and return it. 
if let Some(((chunk, choice_index), detections)) = self.state.pop_first() { let detections = detections.into_iter().flatten().collect(); - return Some((chunk, choice_index, detections)); + return Some((choice_index, chunk, detections)); } } None @@ -131,7 +123,6 @@ mod test { // Push chunk detections for pii detector batcher.push( choice_index, - "pii".into(), chunk.clone(), vec![Detection { start: Some(5), @@ -151,7 +142,6 @@ mod test { // Push chunk detections for hap detector batcher.push( choice_index, - "hap".into(), chunk.clone(), vec![ Detection { @@ -178,7 +168,7 @@ mod test { // pop_batch() should return a batch containing 3 detections for the chunk let batch = batcher.pop_batch(); assert!( - batch.is_some_and(|(actual_chunk, actual_choice_index, detections)| { + batch.is_some_and(|(actual_choice_index, actual_chunk, detections)| { actual_chunk == chunk && actual_choice_index == choice_index && detections.len() == 3 @@ -219,21 +209,18 @@ mod test { // Push chunk-2 detections for pii detector batcher.push( choice_index, - "pii".into(), chunks[1].clone(), Detections::default(), // no detections ); // Push chunk-1 detections for hap detector batcher.push( choice_index, - "hap".into(), chunks[0].clone(), Detections::default(), // no detections ); // Push chunk-2 detections for hap detector batcher.push( choice_index, - "hap".into(), chunks[1].clone(), Detections::default(), // no detections ); @@ -247,7 +234,6 @@ mod test { for choice_index in 0..choices { batcher.push( choice_index, - "pii".into(), chunks[0].clone(), vec![Detection { start: Some(10), @@ -264,25 +250,25 @@ mod test { // We have all detections for chunk-1 and chunk-2 // pop_batch() should return chunk-1 with 1 pii detection, for the first choice let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == chunks[0] && choice_index == 0 && detections.len() == 1 })); // Return the same 
chunk-1 with 1 pii detection for the second choice let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == chunks[0] && choice_index == 1 && detections.len() == 1 })); // pop_batch() should return chunk-2 with no detections, for the first choice let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == chunks[1] && choice_index == 0 && detections.is_empty() })); // Return the same chunk-2 with no detections for the second choice let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == chunks[1] && choice_index == 1 && detections.is_empty() })); @@ -352,42 +338,36 @@ mod test { // Push chunk-2 detections for pii detector, choice 1 batcher.push( choice_1_index, - "pii".into(), choice_1_chunks[1].clone(), Detections::default(), // no detections ); // Same for choice 2 batcher.push( choice_2_index, - "pii".into(), choice_2_chunks[1].clone(), Detections::default(), // no detections ); // Push chunk-2 detections for hap detector, choice 2 batcher.push( choice_2_index, - "hap".into(), choice_2_chunks[1].clone(), Detections::default(), // no detections ); // Same for choice 1 batcher.push( choice_1_index, - "hap".into(), choice_1_chunks[1].clone(), Detections::default(), // no detections ); // Push chunk-1 detections for hap detector, choice 1 batcher.push( choice_1_index, - "hap".into(), choice_1_chunks[0].clone(), Detections::default(), // no detections ); // Same for choice 2 batcher.push( choice_2_index, - "hap".into(), choice_2_chunks[0].clone(), Detections::default(), // no detections ); @@ -399,7 +379,6 @@ mod test { // Push chunk-1 detections for pii detector, for first choice batcher.push( choice_1_index, - 
"pii".into(), choice_1_chunks[0].clone(), vec![Detection { start: Some(10), @@ -414,7 +393,6 @@ mod test { // Push chunk-1 detections for pii detector, for second choice batcher.push( choice_2_index, - "pii".into(), choice_2_chunks[0].clone(), vec![Detection { start: Some(10), @@ -430,21 +408,21 @@ mod test { // We have all detections for chunk-1 and chunk-2 // Expect 4 chunks, with those for the chunk-1 chunks first let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == choice_1_chunks[0] && choice_index == choice_1_index && detections.len() == 1 })); let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == choice_2_chunks[0] && choice_index == choice_2_index && detections.len() == 1 })); // chunk-2 chunks let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == choice_1_chunks[1] && choice_index == choice_1_index && detections.is_empty() })); let batch = batcher.pop_batch(); - assert!(batch.is_some_and(|(chunk, choice_index, detections)| { + assert!(batch.is_some_and(|(choice_index, chunk, detections)| { chunk == choice_2_chunks[1] && choice_index == choice_2_index && detections.is_empty() })); @@ -479,10 +457,10 @@ mod test { // Create detection channels and streams let (pii_detections_tx, pii_detections_rx) = - mpsc::channel::>(4); + mpsc::channel::>(4); let pii_detections_stream = ReceiverStream::new(pii_detections_rx).boxed(); let (hap_detections_tx, hap_detections_rx) = - mpsc::channel::>(4); + mpsc::channel::>(4); let hap_detections_stream = ReceiverStream::new(hap_detections_rx).boxed(); // Create a batcher that will process batches for 2 detectors @@ -498,7 +476,6 @@ mod test { let _ = pii_detections_tx 
.send(Ok(( choice_index, - "pii".into(), chunks[1].clone(), Detections::default(), // no detections ))) @@ -508,7 +485,6 @@ mod test { let _ = hap_detections_tx .send(Ok(( choice_index, - "hap".into(), chunks[0].clone(), Detections::default(), // no detections ))) @@ -518,7 +494,6 @@ mod test { let _ = hap_detections_tx .send(Ok(( choice_index, - "hap".into(), chunks[1].clone(), Detections::default(), // no detections ))) @@ -537,7 +512,6 @@ mod test { let _ = pii_detections_tx .send(Ok(( choice_index, - "pii".into(), chunks[0].clone(), vec![Detection { start: Some(10), @@ -556,7 +530,7 @@ mod test { // detection_batch_stream.next() should be ready and return chunk-1 with 1 pii detection, for choice 1 let batch = detection_batch_stream.next().await; assert!(batch.is_some_and(|result| { - result.is_ok_and(|(chunk, choice_index, detections)| { + result.is_ok_and(|(choice_index, chunk, detections)| { chunk == chunks[0] && choice_index == 0 && detections.len() == 1 }) })); @@ -564,7 +538,7 @@ mod test { // Then choice 2 let batch = detection_batch_stream.next().await; assert!(batch.is_some_and(|result| { - result.is_ok_and(|(chunk, choice_index, detections)| { + result.is_ok_and(|(choice_index, chunk, detections)| { chunk == chunks[0] && choice_index == 1 && detections.len() == 1 }) })); @@ -572,7 +546,7 @@ mod test { // detection_batch_stream.next() should be ready and return chunk-2 with no detections, for choice 1 let batch = detection_batch_stream.next().await; assert!(batch.is_some_and(|result| { - result.is_ok_and(|(chunk, choice_index, detections)| { + result.is_ok_and(|(choice_index, chunk, detections)| { chunk == chunks[1] && choice_index == 0 && detections.is_empty() }) })); @@ -580,7 +554,7 @@ mod test { // Then choice 2 let batch = detection_batch_stream.next().await; assert!(batch.is_some_and(|result| { - result.is_ok_and(|(chunk, choice_index, detections)| { + result.is_ok_and(|(choice_index, chunk, detections)| { chunk == chunks[1] && choice_index == 1 
&& detections.is_empty() }) })); diff --git a/src/orchestrator/types/detection_batcher/max_processed_index.rs b/src/orchestrator/types/detection_batcher/max_processed_index.rs index 60c4edb8..08717200 100644 --- a/src/orchestrator/types/detection_batcher/max_processed_index.rs +++ b/src/orchestrator/types/detection_batcher/max_processed_index.rs @@ -16,7 +16,7 @@ */ use std::collections::{BTreeMap, btree_map}; -use super::{Chunk, DetectionBatcher, Detections, DetectorId, InputId}; +use super::{Batch, Chunk, DetectionBatcher, Detections}; /// A batcher based on the original "max processed index" /// aggregator. @@ -49,15 +49,7 @@ impl MaxProcessedIndexBatcher { } impl DetectionBatcher for MaxProcessedIndexBatcher { - type Batch = (Chunk, Detections); - - fn push( - &mut self, - _input_id: InputId, - _detector_id: DetectorId, - chunk: Chunk, - detections: Detections, - ) { + fn push(&mut self, _input_id: u32, chunk: Chunk, detections: Detections) { match self.state.entry(chunk) { btree_map::Entry::Vacant(entry) => { // New chunk, insert entry @@ -70,7 +62,7 @@ impl DetectionBatcher for MaxProcessedIndexBatcher { } } - fn pop_batch(&mut self) -> Option { + fn pop_batch(&mut self) -> Option { // Check if we have all detections for the next chunk if self .state @@ -80,7 +72,7 @@ impl DetectionBatcher for MaxProcessedIndexBatcher { // We have all detections for the chunk, remove and return it. 
if let Some((chunk, detections)) = self.state.pop_first() { let detections = detections.into_iter().flatten().collect(); - return Some((chunk, detections)); + return Some((0, chunk, detections)); } } None @@ -123,7 +115,6 @@ mod test { // Push chunk detections for pii detector batcher.push( input_id, - "pii".into(), chunk.clone(), vec![Detection { start: Some(5), @@ -143,7 +134,6 @@ mod test { // Push chunk detections for hap detector batcher.push( input_id, - "hap".into(), chunk.clone(), vec![ Detection { @@ -169,9 +159,9 @@ mod test { // We have detections for 2 detectors // pop_batch() should return a batch containing 3 detections for the chunk let batch = batcher.pop_batch(); - assert!( - batch.is_some_and(|(chunk, detections)| { chunk == chunk && detections.len() == 3 }) - ); + assert!(batch.is_some_and(|(_input_id, chunk, detections)| { + chunk == chunk && detections.len() == 3 + })); } #[test] @@ -210,21 +200,18 @@ mod test { // Push chunk-2 detections for pii detector batcher.push( input_id, - "pii".into(), chunks[1].clone(), Detections::default(), // no detections ); // Push chunk-2 detections for hap detector batcher.push( input_id, - "hap".into(), chunks[1].clone(), Detections::default(), // no detections ); // Push chunk-1 detections for hap detector batcher.push( input_id, - "hap".into(), chunks[0].clone(), Detections::default(), // no detections ); @@ -236,7 +223,6 @@ mod test { // Push chunk-1 detections for pii detector batcher.push( input_id, - "pii".into(), chunks[0].clone(), vec![Detection { start: Some(10), @@ -252,17 +238,15 @@ mod test { // We have all detections for chunk-1 and chunk-2 // pop_batch() should return chunk-1 with 1 pii detection let batch = batcher.pop_batch(); - assert!( - batch - .is_some_and(|(chunk, detections)| { chunk == chunks[0] && detections.len() == 1 }) - ); + assert!(batch.is_some_and(|(_input_id, chunk, detections)| { + chunk == chunks[0] && detections.len() == 1 + })); // pop_batch() should return chunk-2 with no 
detections let batch = batcher.pop_batch(); - assert!( - batch - .is_some_and(|(chunk, detections)| { chunk == chunks[1] && detections.is_empty() }) - ); + assert!(batch.is_some_and(|(_input_id, chunk, detections)| { + chunk == chunks[1] && detections.is_empty() + })); // batcher state should be empty as all batches have been returned assert!(batcher.state.is_empty()); @@ -294,10 +278,10 @@ mod test { // Create detection channels and streams let (pii_detections_tx, pii_detections_rx) = - mpsc::channel::>(4); + mpsc::channel::>(4); let pii_detections_stream = ReceiverStream::new(pii_detections_rx).boxed(); let (hap_detections_tx, hap_detections_rx) = - mpsc::channel::>(4); + mpsc::channel::>(4); let hap_detections_stream = ReceiverStream::new(hap_detections_rx).boxed(); // Create a batcher that will process batches for 2 detectors @@ -312,7 +296,6 @@ mod test { let _ = pii_detections_tx .send(Ok(( input_id, - "pii".into(), chunks[1].clone(), Detections::default(), // no detections ))) @@ -322,7 +305,6 @@ mod test { let _ = hap_detections_tx .send(Ok(( input_id, - "hap".into(), chunks[0].clone(), Detections::default(), // no detections ))) @@ -332,7 +314,6 @@ mod test { let _ = hap_detections_tx .send(Ok(( input_id, - "hap".into(), chunks[1].clone(), Detections::default(), // no detections ))) @@ -349,7 +330,6 @@ mod test { let _ = pii_detections_tx .send(Ok(( input_id, - "pii".into(), chunks[0].clone(), vec![Detection { start: Some(10), @@ -367,13 +347,17 @@ mod test { // detection_batch_stream.next() should be ready and return chunk-1 with 1 pii detection let batch = detection_batch_stream.next().await; assert!(batch.is_some_and(|result| { - result.is_ok_and(|(chunk, detections)| chunk == chunks[0] && detections.len() == 1) + result.is_ok_and(|(_input_id, chunk, detections)| { + chunk == chunks[0] && detections.len() == 1 + }) })); // detection_batch_stream.next() should be ready and return chunk-2 with no detections let batch = detection_batch_stream.next().await; 
assert!(batch.is_some_and(|result| { - result.is_ok_and(|(chunk, detections)| chunk == chunks[1] && detections.is_empty()) + result.is_ok_and(|(_input_id, chunk, detections)| { + chunk == chunks[1] && detections.is_empty() + }) })); // detection_batch_stream.next() future should not be ready diff --git a/src/orchestrator/types/detection_batcher/noop.rs b/src/orchestrator/types/detection_batcher/noop.rs deleted file mode 100644 index 14048252..00000000 --- a/src/orchestrator/types/detection_batcher/noop.rs +++ /dev/null @@ -1,53 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ -use std::collections::VecDeque; - -use super::{Chunk, DetectionBatcher, Detections, DetectorId, InputId}; - -/// A no-op batcher that doesn't actually batch. 
-#[derive(Default, Debug, Clone)] -pub struct NoopBatcher { - state: VecDeque<(Chunk, Detections)>, -} - -impl NoopBatcher { - pub fn new() -> Self { - Self::default() - } -} - -impl DetectionBatcher for NoopBatcher { - type Batch = (Chunk, Detections); - - fn push( - &mut self, - _input_id: InputId, - _detector_id: DetectorId, - chunk: Chunk, - detections: Detections, - ) { - self.state.push_back((chunk, detections)); - } - - fn pop_batch(&mut self) -> Option { - self.state.pop_front() - } - - fn is_empty(&self) -> bool { - self.state.is_empty() - } -} From cc064f18a0e64eebb8bb5aef224d19f3db8ae396 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Tue, 20 May 2025 16:57:06 -0300 Subject: [PATCH 23/24] Validate message content on chat completions endpoint (#383) * Add content validation Signed-off-by: Mateus Devino * Add no_detectors tests for chat completions Signed-off-by: Mateus Devino * Add empty content message test for chat completions Signed-off-by: Mateus Devino * Add comments about changing validation Signed-off-by: Mateus Devino * Update chat completions passthrough tests for empty message Signed-off-by: Mateus Devino * Add further validation for empty content Signed-off-by: Mateus Devino * Test case: passthrough with last message as an array Signed-off-by: Mateus Devino * Prevent content array on last message when input detector is present Signed-off-by: Mateus Devino * nit: combine if clauses Signed-off-by: Mateus Devino --------- Signed-off-by: Mateus Devino --- src/clients/openai.rs | 46 +++ .../chat_completions_detection/unary.rs | 2 + tests/chat_completions_detection.rs | 366 +++++++++++++++++- 3 files changed, 413 insertions(+), 1 deletion(-) diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 19050e56..3c679ecb 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -259,6 +259,26 @@ impl ChatCompletionsRequest { "`messages` must not be empty".into(), )); } + + if 
!self.detectors.input.is_empty() { + // Content of type Array is not supported yet + // Adding this validation separately as we do plan to support arrays of string in the future + if let Some(Content::Array(_)) = self.messages.last().unwrap().content { + return Err(ValidationError::Invalid( + "Detection on array is not supported".into(), + )); + } + + // As text_content detections only run on last message at the moment, only the last + // message is being validated. + if self.messages.last().unwrap().is_text_content_empty() { + return Err(ValidationError::Invalid( + "if input detectors are provided, `content` must not be empty on last message" + .into(), + )); + } + } + Ok(()) } } @@ -414,6 +434,32 @@ pub struct Message { pub tool_call_id: Option, } +impl Message { + /// Checks if text content of a message is empty. + /// + /// The following messages are considered empty: + /// 1. [`Message::content`] is None. + /// 2. [`Message::content`] is an empty string. + /// 3. [`Message::content`] is an empty array. + /// 4. [`Message::content`] is an array of empty strings and ContentType is Text. + pub fn is_text_content_empty(&self) -> bool { + match &self.content { + Some(content) => match content { + Content::Text(string) => string.is_empty(), + Content::Array(content_parts) => { + content_parts.is_empty() + || content_parts.iter().all(|content_part| { + content_part.text.is_none() + || (content_part.r#type == ContentType::Text + && content_part.text.as_ref().unwrap().is_empty()) + }) + } + }, + None => true, + } + } +} + /// Content. 
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(untagged)] diff --git a/src/orchestrator/handlers/chat_completions_detection/unary.rs b/src/orchestrator/handlers/chat_completions_detection/unary.rs index dc310d36..4444ae8f 100644 --- a/src/orchestrator/handlers/chat_completions_detection/unary.rs +++ b/src/orchestrator/handlers/chat_completions_detection/unary.rs @@ -108,6 +108,8 @@ async fn handle_input_detection( let model_id = task.request.model.clone(); // Input detectors are only applied to the last message + // If this changes, the empty content validation in [`ChatCompletionsRequest::validate`] + // should also change. // Get the last message let messages = task.request.messages(); let message = if let Some(message) = messages.last() { diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_detection.rs index 334d49c3..42e4e1d0 100644 --- a/tests/chat_completions_detection.rs +++ b/tests/chat_completions_detection.rs @@ -35,7 +35,8 @@ use fms_guardrails_orchestr8::{ detector::{ContentAnalysisRequest, ContentAnalysisResponse}, openai::{ ChatCompletion, ChatCompletionChoice, ChatCompletionMessage, ChatDetections, Content, - InputDetectionResult, Message, OrchestratorWarning, OutputDetectionResult, Role, + ContentPart, ContentType, InputDetectionResult, Message, OrchestratorWarning, + OutputDetectionResult, Role, }, }, models::{ @@ -59,6 +60,199 @@ pub mod common; const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; const MODEL_ID: &str = "my-super-model-8B"; +// Validate passthrough scenario +#[test(tokio::test)] +async fn no_detectors() -> Result<(), anyhow::Error> { + let messages = vec![ + Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }, + Message { + content: Some(Content::Text("".to_string())), + role: Role::Assistant, + ..Default::default() + }, + ]; + + // Add mocksets + let mut chat_mocks = MockSet::new(); + + let expected_choices = vec![ + 
ChatCompletionChoice { + message: ChatCompletionMessage { + role: messages[0].role.clone(), + content: Some("Hi there!".to_string()), + refusal: None, + tool_calls: vec![], + }, + index: 0, + logprobs: None, + finish_reason: "NOT_FINISHED".to_string(), + stop_reason: None, + }, + ChatCompletionChoice { + message: ChatCompletionMessage { + role: messages[1].role.clone(), + content: Some("Hello!".to_string()), + refusal: None, + tool_calls: vec![], + }, + index: 1, + logprobs: None, + finish_reason: "EOS_TOKEN".to_string(), + stop_reason: None, + }, + ]; + let chat_completions_response = ChatCompletion { + model: MODEL_ID.into(), + choices: expected_choices.clone(), + detections: None, + warnings: vec![], + ..Default::default() + }; + + // Add chat completions mock + chat_mocks.mock(|when, then| { + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages, + })); + then.json(&chat_completions_response); + }); + + // Start orchestrator server and its dependencies + let mut mock_chat_completions_server = + MockServer::new("chat_completions").with_mocks(chat_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .chat_generation_server(&mock_chat_completions_server) + .build() + .await?; + + // Empty `detectors` scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": {}, + "messages": messages, + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], chat_completions_response.choices[0]); + assert_eq!(results.choices[1], chat_completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + // Missing `detectors` scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + 
.json(&json!({ + "model": MODEL_ID, + "messages": messages, + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], chat_completions_response.choices[0]); + assert_eq!(results.choices[1], chat_completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + // `detectors` with empty `input` and `output` scenario + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "messages": messages, + "detectors": { + "input": {}, + "output": {}, + }, + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], chat_completions_response.choices[0]); + assert_eq!(results.choices[1], chat_completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + // message content as array, `detectors` with empty `input` and `output` scenario + let messages = vec![ + Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }, + Message { + content: Some(Content::Array(vec![ + ContentPart { + r#type: ContentType::Text, + text: Some("How".into()), + image_url: None, + refusal: None, + }, + ContentPart { + r#type: ContentType::Text, + text: Some("are".into()), + image_url: None, + refusal: None, + }, + ContentPart { + r#type: ContentType::Text, + text: Some("you?".into()), + image_url: None, + refusal: None, + }, + ])), + role: Role::Assistant, + ..Default::default() + }, + ]; + + // add new mock + mock_chat_completions_server.mock(|when, then| { + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages, + })); + then.json(&chat_completions_response); + }); + + let response = orchestrator_server + 
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "messages": messages, + "detectors": { + "input": {}, + "output": {}, + }, + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], chat_completions_response.choices[0]); + assert_eq!(results.choices[1], chat_completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + Ok(()) +} + // Validate that requests without detectors, input detector and output detector configured // returns text generated by model #[test(tokio::test)] @@ -1017,5 +1211,175 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { "failed on non-existing input detector scenario" ); + // input detectors and last message without `content` scenario + let no_content_messages = vec![ + Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }, + Message { + role: Role::User, + ..Default::default() + }, + ]; + + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: {} + + } + }, + "messages": no_content_messages, + })) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: "if input detectors are provided, `content` must not be empty on last message" + .into() + } + ); + + // input detectors and last message with empty string as `content` scenario + let no_content_messages = vec![ + Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }, + Message { + content: Some(Content::Text("".into())), + role: Role::User, + ..Default::default() + }, + ]; + + let response = 
orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: {} + + } + }, + "messages": no_content_messages, + })) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: "if input detectors are provided, `content` must not be empty on last message" + .into() + } + ); + + // input detectors and last message with empty array as `content` scenario + let no_content_messages = vec![ + Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }, + Message { + content: Some(Content::Array(Vec::new())), + role: Role::User, + ..Default::default() + }, + ]; + + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: {} + + } + }, + "messages": no_content_messages, + })) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: "Detection on array is not supported".into() + } + ); + + // input detectors and last message with array of empty strings as `content` scenario + let no_content_messages = vec![ + Message { + content: Some(Content::Text("Hi there!".to_string())), + role: Role::User, + ..Default::default() + }, + Message { + content: Some(Content::Array(vec![ + ContentPart { + r#type: ContentType::Text, + text: Some("".into()), + image_url: None, + refusal: None, + }, + ContentPart { + r#type: ContentType::Text, + text: Some("".into()), + image_url: None, + refusal: None, + }, + ])), + role: Role::User, + ..Default::default() + }, + ]; + + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + 
.json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC: {} + + } + }, + "messages": no_content_messages, + })) + .send() + .await?; + + let results = response.json::().await?; + debug!("{results:#?}"); + assert_eq!( + results, + OrchestratorError { + code: 422, + details: "Detection on array is not supported".into() + } + ); + Ok(()) } From 2347bba32b3c6e715f043fe4acf08777a9320611 Mon Sep 17 00:00:00 2001 From: Mateus Devino <19861348+mdevino@users.noreply.github.com> Date: Thu, 29 May 2025 14:06:47 -0300 Subject: [PATCH 24/24] Update Rust to 1.87.0 and dependencies (#404) Signed-off-by: Mateus Devino --- Cargo.lock | 495 ++++++++++++++++++++++---------------------- Cargo.toml | 48 ++--- Dockerfile | 2 +- rust-toolchain.toml | 2 +- 4 files changed, 271 insertions(+), 276 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc41907f..a4018fe0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,9 +78,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.97" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "assert-json-diff" @@ -145,9 +145,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-lc-rs" -version = "1.12.6" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dabb68eb3a7aa08b46fddfd59a3d55c978243557a90ab804769f7e20e67d2b01" +checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" dependencies = [ "aws-lc-sys", "zeroize", @@ -155,9 +155,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.27.1" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"77926887776171ced7d662120a75998e444d3750c951abfe07f90da130514b1f" +checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" dependencies = [ "bindgen", "cc", @@ -195,11 +195,11 @@ dependencies = [ [[package]] name = "axum" -version = "0.8.1" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ - "axum-core 0.5.0", + "axum-core 0.5.2", "bytes", "form_urlencoded", "futures-util", @@ -249,12 +249,12 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" dependencies = [ "bytes", - "futures-util", + "futures-core", "http", "http-body", "http-body-util", @@ -269,12 +269,12 @@ dependencies = [ [[package]] name = "axum-extra" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "460fc6f625a1f7705c6cf62d0d070794e94668988b1c38111baeec177c715f7b" +checksum = "45bf463831f5131b7d3c756525b305d40f1185b688565648a92e1392ca35713d" dependencies = [ - "axum 0.8.1", - "axum-core 0.5.0", + "axum 0.8.4", + "axum-core 0.5.2", "bytes", "futures-util", "http", @@ -282,6 +282,7 @@ dependencies = [ "http-body-util", "mime", "pin-project-lite", + "rustversion", "serde", "serde_json", "tokio", @@ -294,14 +295,14 @@ dependencies = [ [[package]] name = "axum-test" -version = "17.2.0" +version = "17.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317c1f4ecc1e68e0ad5decb78478421055c963ce215e736ed97463fa609cd196" +checksum = "0eb1dfb84bd48bad8e4aa1acb82ed24c2bb5e855b659959b4e03b4dca118fcac" dependencies = [ "anyhow", 
"assert-json-diff", "auto-future", - "axum 0.8.1", + "axum 0.8.4", "bytes", "bytesize", "cookie", @@ -368,9 +369,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bumpalo" @@ -386,15 +387,15 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "bytesize" -version = "1.3.2" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d2c12f985c78475a6b8d629afd0c360260ef34cfef52efccdcfd31972f81c2e" +checksum = "a3c8f83209414aacf0eeae3cf730b18d6981697fba62f200fcfb92b9f082acba" [[package]] name = "cc" -version = "1.2.17" +version = "1.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a" +checksum = "16595d3be041c03b09d08d0858631facccee9221e579704070e6e9e4915d3bc7" dependencies = [ "jobserver", "libc", @@ -435,9 +436,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.32" +version = "4.5.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6088f3ae8c3608d19260cd7445411865a485688711b78b5be70d78cd96136f83" +checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f" dependencies = [ "clap_builder", "clap_derive", @@ -445,9 +446,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.32" +version = "4.5.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22a7ef7f676155edfb82daa97f99441f3ebf4a58d5e32f295a56259f1b6facc8" +checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51" dependencies = [ "anstream", "anstyle", @@ -618,9 +619,9 @@ checksum = 
"877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" dependencies = [ "libc", "windows-sys 0.59.0", @@ -651,11 +652,11 @@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "fms-guardrails-orchestr8" -version = "0.1.0" +version = "0.10.0" dependencies = [ "anyhow", "async-trait", - "axum 0.8.1", + "axum 0.8.4", "axum-extra", "axum-test", "bytes", @@ -679,7 +680,7 @@ dependencies = [ "opentelemetry_sdk", "pin-project-lite", "prost", - "rand 0.9.0", + "rand 0.9.1", "reqwest", "rustls", "rustls-pemfile", @@ -830,9 +831,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "js-sys", @@ -843,9 +844,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", "js-sys", @@ -886,9 +887,9 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "h2" -version = "0.4.8" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" +checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" dependencies = [ 
"atomic-waker", "bytes", @@ -896,7 +897,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.8.0", + "indexmap 2.9.0", "slab", "tokio", "tokio-util", @@ -911,9 +912,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" [[package]] name = "heck" @@ -1065,11 +1066,10 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.5" +version = "0.27.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d" dependencies = [ - "futures-util", "http", "hyper", "hyper-util", @@ -1080,7 +1080,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.0", ] [[package]] @@ -1114,40 +1114,48 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "b1c293b6b3d21eca78250dc7dbebd6b9210ec5530e038cbfe0661b5c47ab06e8" dependencies = [ + "base64", "bytes", "futures-channel", + "futures-core", "futures-util", "http", "http-body", "hyper", + "ipnet", + "libc", + "percent-encoding", "pin-project-lite", "socket2", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = 
"200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", "litemap", @@ -1156,31 +1164,11 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", "icu_collections", @@ -1188,67 +1176,54 @@ dependencies = [ "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "potential_utf", + "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", + "icu_locale_core", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "idna" version = "1.0.3" @@ -1262,9 +1237,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -1282,12 +1257,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.8.0" +version = "2.9.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.3", ] [[package]] @@ -1308,6 +1283,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -1340,10 +1325,11 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.3", "libc", ] @@ -1371,9 +1357,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.171" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "libloading" @@ -1382,7 +1368,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -1409,15 +1395,15 @@ checksum = 
"d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "linux-raw-sys" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" @@ -1431,9 +1417,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.26" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "lru-cache" @@ -1444,6 +1430,12 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "matchers" version = "0.1.0" @@ -1477,16 +1469,6 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" -[[package]] -name = "mime_guess" -version = "2.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" -dependencies = [ - "mime", - "unicase", -] - [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1495,28 +1477,28 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.5" +version = 
"0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] [[package]] name = "mio" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "mocktail" version = "0.2.5-alpha" -source = "git+https://github.com/IBM/mocktail#6296c2783ba1d433407ae1d8144ec5619dc021b9" +source = "git+https://github.com/IBM/mocktail#025d724965f5d4ee7cc6666bf22845a896b00b58" dependencies = [ "bytes", "futures", @@ -1526,7 +1508,7 @@ dependencies = [ "hyper", "hyper-util", "prost", - "rand 0.9.0", + "rand 0.9.1", "serde", "serde_json", "thiserror 2.0.12", @@ -1597,9 +1579,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.1" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "openssl" @@ -1635,9 +1617,9 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.107" +version = "0.9.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" +checksum = "e145e1651e858e820e4860f7b9c5e169bc1d8ce1c86043be79fa7b7634821847" dependencies = [ "cc", "libc", @@ -1767,7 +1749,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.8.0", + "indexmap 2.9.0", ] [[package]] @@ -1808,6 +1790,15 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "potential_utf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1835,9 +1826,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" dependencies = [ "proc-macro2", "syn", @@ -1845,9 +1836,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -1906,9 +1897,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.7" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3bd15a6f2967aef83887dcb9fec0014580467e33720d073560cf015a5683012" +checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" dependencies = [ "bytes", "cfg_aliases", @@ -1926,13 +1917,14 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.10" +version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc" 
+checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" dependencies = [ "bytes", - "getrandom 0.3.2", - "rand 0.9.0", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.1", "ring", "rustc-hash 2.1.1", "rustls", @@ -1946,9 +1938,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.10" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" +checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" dependencies = [ "cfg_aliases", "libc", @@ -1986,13 +1978,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy", ] [[package]] @@ -2021,7 +2012,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -2030,14 +2021,14 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.2", + "getrandom 0.3.3", ] [[package]] name = "redox_syscall" -version = "0.5.10" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1" +checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" dependencies = [ "bitflags", ] @@ -2088,9 +2079,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.15" 
+version = "0.12.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" +checksum = "e98ff6b0dbbe4d5a37318f433d4fc82babd21631f194d370409ceb2e40b2f0b5" dependencies = [ "base64", "bytes", @@ -2116,35 +2107,32 @@ dependencies = [ "pin-project-lite", "quinn", "rustls", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", - "system-configuration", "tokio", "tokio-native-tls", "tokio-rustls", "tokio-util", "tower 0.5.2", + "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", - "windows-registry", + "webpki-roots 1.0.0", ] [[package]] name = "reserve-port" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "359fc315ed556eb0e42ce74e76f4b1cd807b50fa6307f3de4e51f92dbe86e2d5" +checksum = "ba3747658ee2585ecf5607fa9887c92eff61b362ff5253dbf797dfeb73d33d78" dependencies = [ - "lazy_static", "thiserror 2.0.12", ] @@ -2165,7 +2153,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -2173,17 +2161,16 @@ dependencies = [ [[package]] name = "rust-multipart-rfc7578_2" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc4bb9e7c9abe5fa5f30c2d8f8fefb9e0080a2c1e3c2e567318d2907054b35d3" +checksum = "c839d037155ebc06a571e305af66ff9fd9063a6e662447051737e1ac75beea41" dependencies = [ "bytes", "futures-core", "futures-util", "http", "mime", - "mime_guess", - "rand 0.9.0", + "rand 0.9.1", "thiserror 2.0.12", ] @@ -2220,29 +2207,29 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.3" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e56a18552996ac8d29ecc3b190b4fdbb2d91ca4ec396de7bbffaf43f3d637e96" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" dependencies = [ "bitflags", "errno", "libc", - "linux-raw-sys 0.9.3", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.25" +version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c" +checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ "aws-lc-rs", "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.0", + "rustls-webpki 0.103.3", "subtle", "zeroize", ] @@ -2270,11 +2257,12 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ "web-time", + "zeroize", ] [[package]] @@ -2290,9 +2278,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.0" +version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0aa4eeac2588ffff23e9d7a7e9b3f971c5fb5b7ebc9452745e0c232c64f83b2f" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ "aws-lc-rs", "ring", @@ -2302,9 +2290,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" [[package]] name = "ryu" @@ -2389,7 +2377,7 @@ version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ - "indexmap 2.8.0", + "indexmap 2.9.0", "itoa", "memchr", "ryu", @@ -2424,7 +2412,7 @@ version = "0.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59e2dd588bf1597a252c3b920e0143eb99b0f76e4e082f4c92ce34fbc9e71ddd" dependencies = [ - "indexmap 2.8.0", + "indexmap 2.9.0", "itoa", "libyml", "memchr", @@ -2450,9 +2438,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] @@ -2468,15 +2456,15 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", "windows-sys 0.52.0", @@ -2502,9 +2490,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.100" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ -2522,9 +2510,9 @@ dependencies = [ 
[[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", @@ -2554,14 +2542,14 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.19.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", - "getrandom 0.3.2", + "getrandom 0.3.3", "once_cell", - "rustix 1.0.3", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -2639,9 +2627,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.40" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d9c75b47bdff86fa3334a3db91356b8d7d86a9b839dab7d0bdc5c3d3a077618" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -2660,9 +2648,9 @@ checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.21" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29aa485584182073ed57fd5004aa09c371f021325014694e432313345865fd04" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", @@ -2670,9 +2658,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", "zerovec", @@ 
-2695,9 +2683,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.2" +version = "1.45.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", "bytes", @@ -2756,9 +2744,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.14" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", @@ -2798,7 +2786,7 @@ dependencies = [ "tower-layer", "tower-service", "tracing", - "webpki-roots", + "webpki-roots 0.26.8", ] [[package]] @@ -2853,15 +2841,18 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.2" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" +checksum = "0fdb0c213ca27a9f57ab69ddb290fd80d970922355b83ae380b395d3986b8a2e" dependencies = [ "bitflags", "bytes", + "futures-util", "http", "http-body", + "iri-string", "pin-project-lite", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", @@ -2978,12 +2969,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "unicase" -version = "2.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" - [[package]] name = "unicode-ident" version = "1.0.18" @@ -3007,12 +2992,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "utf16_iter" -version = "1.0.5" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -3027,12 +3006,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ - "getrandom 0.3.2", - "rand 0.9.0", + "getrandom 0.3.3", + "js-sys", + "rand 0.9.1", + "wasm-bindgen", ] [[package]] @@ -3190,6 +3171,15 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webpki-roots" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2853738d1cc4f2da3a225c18ec6c3721abb31961096e9dbf5ab35fa88b19cfdb" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "4.4.2" @@ -3515,17 +3505,11 @@ dependencies = [ "bitflags", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "yansi" @@ -3535,9 +3519,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = 
"5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ "serde", "stable_deref_trait", @@ -3547,9 +3531,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", @@ -3559,18 +3543,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", @@ -3604,11 +3588,22 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" dependencies = [ "yoke", "zerofrom", @@ -3617,9 +3612,9 @@ 
dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 27d62511..c162bbaf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fms-guardrails-orchestr8" -version = "0.1.0" +version = "0.10.0" edition = "2024" authors = ["Evaline Ju", "Gaurav Kumbhat", "Dan Clark"] description = "Foundation models orchestration server" @@ -14,24 +14,24 @@ name = "fms-guardrails-orchestr8" path = "src/main.rs" [dependencies] -anyhow = "1.0.95" -async-trait = "0.1.85" -axum = { version = "0.8.1", features = ["json"] } -axum-extra = { version = "0.10.0", features = ["json-lines"] } -bytes = "1.10.0" -clap = { version = "4.5.26", features = ["derive", "env"] } +anyhow = "1.0.98" +async-trait = "0.1.88" +axum = { version = "0.8.4", features = ["json"] } +axum-extra = { version = "0.10.1", features = ["json-lines"] } +bytes = "1.10.1" +clap = { version = "4.5.39", features = ["derive", "env"] } eventsource-stream = "0.2.3" futures = "0.3.31" futures-util = { version = "0.3", default-features = false, features = [] } ginepro = "0.8.2" -http = "1.2.0" +http = "1.3.1" http-body = "1.0" -http-body-util = "0.1.2" +http-body-util = "0.1.3" http-serde = "2.1.1" -hyper = { version = "1.5.2", features = ["http1", "http2", "server"] } -hyper-rustls = { version = "0.27.5", features = ["ring"] } +hyper = { version = "1.6.0", features = ["http1", "http2", "server"] } +hyper-rustls = { version = "0.27.6", features = ["ring"] } hyper-timeout = "0.5.2" -hyper-util = { version = "0.1.10", features = [ +hyper-util = { version = "0.1.13", features = [ "server-auto", "server-graceful", "tokio", @@ -44,24 +44,24 @@ opentelemetry-otlp = { version = 
"0.27.0", features = [ ] } opentelemetry_sdk = { version = "0.27.1", features = ["rt-tokio", "metrics"] } pin-project-lite = "0.2.16" -prost = "0.13.4" -reqwest = { version = "0.12.12", features = [ +prost = "0.13.5" +reqwest = { version = "0.12.18", features = [ "blocking", "rustls-tls", "json", "stream", ] } -rustls = { version = "0.23.21", default-features = false, features = [ +rustls = { version = "0.23.27", default-features = false, features = [ "ring", "std", ] } rustls-pemfile = "2.2.0" rustls-webpki = "0.102.8" -serde = { version = "1.0.217", features = ["derive"] } -serde_json = { version = "1.0.135", features = ["preserve_order"] } +serde = { version = "1.0.219", features = ["derive"] } +serde_json = { version = "1.0.140", features = ["preserve_order"] } serde_yml = "0.0.12" -thiserror = "2.0.11" -tokio = { version = "1.44.2", features = [ +thiserror = "2.0.12" +tokio = { version = "1.45.1", features = [ "rt", "rt-multi-thread", "parking_lot", @@ -69,7 +69,7 @@ tokio = { version = "1.44.2", features = [ "sync", "fs", ] } -tokio-rustls = { version = "0.26.1", features = ["ring"] } +tokio-rustls = { version = "0.26.2", features = ["ring"] } tokio-stream = { version = "0.1.17", features = ["sync"] } tonic = { version = "0.12.3", features = [ "tls", @@ -77,20 +77,20 @@ tonic = { version = "0.12.3", features = [ "tls-webpki-roots", ] } tower = { version = "0.5.2", features = ["timeout"] } -tower-http = { version = "0.6.2", features = ["trace"] } +tower-http = { version = "0.6.4", features = ["trace"] } tracing = "0.1.41" tracing-opentelemetry = "0.28.0" tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } url = "2.5.4" -uuid = { version = "1.12.1", features = ["v4"] } +uuid = { version = "1.17.0", features = ["v4"] } [build-dependencies] tonic-build = "0.12.3" [dev-dependencies] -axum-test = "17.1.0" +axum-test = "17.3.0" mocktail = { git = "https://github.com/IBM/mocktail" } -rand = "0.9.0" +rand = "0.9.1" test-log = "0.2.17" 
[profile.release] diff --git a/Dockerfile b/Dockerfile index e51c5177..69ca4f6b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ ARG CONFIG_FILE=config/config.yaml ## Rust builder ################################################################ # Specific debian version so that compatible glibc version is used -FROM rust:1.86.0 AS rust-builder +FROM rust:1.87.0 AS rust-builder ARG PROTOC_VERSION ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/rust-toolchain.toml b/rust-toolchain.toml index d0929449..5675074a 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.86.0" +channel = "1.87.0" components = ["rustfmt", "clippy"]