Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c95a3e8
Stream content tests and error handling fix (#350)
mdevino Apr 3, 2025
49af6ea
Integration tests for `/api/v2/chat/completions-detection` (#360)
pmcjr Apr 4, 2025
2557f3c
Bump openssl from 0.10.71 to 0.10.72 (#361)
dependabot[bot] Apr 7, 2025
6473cb8
Update Rust to 1.86 (#364)
mdevino Apr 7, 2025
3d2aa3c
Bump tokio from 1.44.1 to 1.44.2 (#365)
dependabot[bot] Apr 8, 2025
1d872de
Single detector tests (#362)
mdevino Apr 8, 2025
3a78353
refactor: task handlers (#355)
declark1 Apr 15, 2025
75c2d01
Drop faux (#374)
mdevino Apr 16, 2025
d602892
:fire: Chunker client cleanup (#376)
evaline-ju Apr 17, 2025
bd79b42
Guardrails request config validation (#371)
mdevino Apr 17, 2025
9253698
Integration tests for no detectors requests (#373)
mdevino Apr 17, 2025
6bc045f
refactor ChatCompletionsRequest (#375)
declark1 Apr 18, 2025
1adecc3
Tests cleanup (#379)
mdevino Apr 24, 2025
c65cd4c
Fix DetectionBatchStream (#381)
declark1 Apr 24, 2025
9f66e5d
To support s390x (#369)
Sanketha-Cr Apr 25, 2025
063b8d6
:sparkles: Chat completions batcher (#380)
evaline-ju Apr 25, 2025
d473b00
:goal_net: Handle unsupported media type errors (#386)
evaline-ju Apr 28, 2025
e3985d5
:bug: Allow input detection on whole input for streaming text generat…
evaline-ju Apr 29, 2025
9c7fd48
refactor: server cleanups (#387)
declark1 May 1, 2025
19be26e
OpenAiClient updates & completions implementation (#390)
declark1 May 6, 2025
2d3b5e4
Implement re-try logic for TestOrchestratorServer port binding (#392)
mdevino May 8, 2025
88dcf85
Batcher simplifications (#394)
declark1 May 12, 2025
cc064f1
Validate message content on chat completions endpoint (#383)
mdevino May 20, 2025
2347bba
Update Rust to 1.87.0 and dependencies (#404)
mdevino May 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
572 changes: 248 additions & 324 deletions Cargo.lock

Large diffs are not rendered by default.

51 changes: 25 additions & 26 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fms-guardrails-orchestr8"
version = "0.1.0"
version = "0.10.0"
edition = "2024"
authors = ["Evaline Ju", "Gaurav Kumbhat", "Dan Clark"]
description = "Foundation models orchestration server"
Expand All @@ -14,24 +14,24 @@ name = "fms-guardrails-orchestr8"
path = "src/main.rs"

[dependencies]
anyhow = "1.0.95"
async-trait = "0.1.85"
axum = { version = "0.8.1", features = ["json"] }
axum-extra = { version = "0.10.0", features = ["json-lines"] }
bytes = "1.10.0"
clap = { version = "4.5.26", features = ["derive", "env"] }
anyhow = "1.0.98"
async-trait = "0.1.88"
axum = { version = "0.8.4", features = ["json"] }
axum-extra = { version = "0.10.1", features = ["json-lines"] }
bytes = "1.10.1"
clap = { version = "4.5.39", features = ["derive", "env"] }
eventsource-stream = "0.2.3"
futures = "0.3.31"
futures-util = { version = "0.3", default-features = false, features = [] }
ginepro = "0.8.2"
http = "1.2.0"
http = "1.3.1"
http-body = "1.0"
http-body-util = "0.1.2"
http-body-util = "0.1.3"
http-serde = "2.1.1"
hyper = { version = "1.5.2", features = ["http1", "http2", "server"] }
hyper-rustls = { version = "0.27.5", features = ["ring"] }
hyper = { version = "1.6.0", features = ["http1", "http2", "server"] }
hyper-rustls = { version = "0.27.6", features = ["ring"] }
hyper-timeout = "0.5.2"
hyper-util = { version = "0.1.10", features = [
hyper-util = { version = "0.1.13", features = [
"server-auto",
"server-graceful",
"tokio",
Expand All @@ -44,54 +44,53 @@ opentelemetry-otlp = { version = "0.27.0", features = [
] }
opentelemetry_sdk = { version = "0.27.1", features = ["rt-tokio", "metrics"] }
pin-project-lite = "0.2.16"
prost = "0.13.4"
reqwest = { version = "0.12.12", features = [
prost = "0.13.5"
reqwest = { version = "0.12.18", features = [
"blocking",
"rustls-tls",
"json",
"stream",
] }
rustls = { version = "0.23.21", default-features = false, features = [
rustls = { version = "0.23.27", default-features = false, features = [
"ring",
"std",
] }
rustls-pemfile = "2.2.0"
rustls-webpki = "0.102.8"
serde = { version = "1.0.217", features = ["derive"] }
serde_json = { version = "1.0.135", features = ["preserve_order"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = { version = "1.0.140", features = ["preserve_order"] }
serde_yml = "0.0.12"
thiserror = "2.0.11"
tokio = { version = "1.43.0", features = [
thiserror = "2.0.12"
tokio = { version = "1.45.1", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
"fs",
] }
tokio-rustls = { version = "0.26.1", features = ["ring"] }
tokio-rustls = { version = "0.26.2", features = ["ring"] }
tokio-stream = { version = "0.1.17", features = ["sync"] }
tonic = { version = "0.12.3", features = [
"tls",
"tls-roots",
"tls-webpki-roots",
] }
tower = { version = "0.5.2", features = ["timeout"] }
tower-http = { version = "0.6.2", features = ["trace"] }
tower-http = { version = "0.6.4", features = ["trace"] }
tracing = "0.1.41"
tracing-opentelemetry = "0.28.0"
tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
url = "2.5.4"
uuid = { version = "1.12.1", features = ["v4"] }
uuid = { version = "1.17.0", features = ["v4"] }

[build-dependencies]
tonic-build = "0.12.3"

[dev-dependencies]
axum-test = "17.1.0"
faux = "0.1.12"
mocktail = { version = "0.2.4-alpha" }
rand = "0.9.0"
axum-test = "17.3.0"
mocktail = { git = "https://github.com/IBM/mocktail" }
rand = "0.9.1"
test-log = "0.2.17"

[profile.release]
Expand Down
15 changes: 12 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,24 @@ ARG CONFIG_FILE=config/config.yaml

## Rust builder ################################################################
# Specific debian version so that compatible glibc version is used
FROM rust:1.85.1-bullseye AS rust-builder
FROM rust:1.87.0 AS rust-builder
ARG PROTOC_VERSION

ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

# Install protoc, no longer included in prost crate
RUN cd /tmp && \
curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip && \
unzip protoc-*.zip -d /usr/local && rm protoc-*.zip
if [ "$(uname -m)" = "s390x" ]; then \
apt update && \
apt install -y cmake clang libclang-dev curl unzip && \
curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-s390_64.zip; \
else \
curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip; \
fi && \
unzip protoc-*.zip -d /usr/local && \
rm protoc-*.zip

ENV LIBCLANG_PATH=/usr/lib/llvm-14/lib/

WORKDIR /app

Expand Down
4 changes: 2 additions & 2 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[toolchain]
channel = "1.85.1"
components = ["rustfmt", "clippy"]
channel = "1.87.0"
components = ["rustfmt", "clippy"]
2 changes: 1 addition & 1 deletion src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ impl OtlpProtocol {

/// Log format options.
///
/// `Full` is the value produced by `LogFormat::default()`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum LogFormat {
    #[default]
    Full,
    // Declared exactly once: a second `Compact` variant is a duplicate
    // definition and does not compile.
    Compact,
    Pretty,
    JSON,
}
Expand Down
6 changes: 1 addition & 5 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use hyper_timeout::TimeoutConnector;
use hyper_util::rt::TokioExecutor;
use tonic::{Request, metadata::MetadataMap};
use tower::{ServiceBuilder, timeout::TimeoutLayer};
use tracing::{Span, debug, instrument};
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use url::Url;

Expand Down Expand Up @@ -205,7 +205,6 @@ impl ClientMap {
}
}

#[instrument(skip_all, fields(hostname = service_config.hostname))]
pub async fn create_http_client(
default_port: u16,
service_config: &ServiceConfig,
Expand All @@ -220,7 +219,6 @@ pub async fn create_http_client(
base_url
.set_port(Some(port))
.unwrap_or_else(|_| panic!("error setting port: {}", port));
debug!(%base_url, "creating HTTP client");

let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC);
let request_timeout = Duration::from_secs(
Expand Down Expand Up @@ -257,7 +255,6 @@ pub async fn create_http_client(
Ok(HttpClient::new(base_url, client))
}

#[instrument(skip_all, fields(hostname = service_config.hostname))]
pub async fn create_grpc_client<C: Debug + Clone>(
default_port: u16,
service_config: &ServiceConfig,
Expand All @@ -270,7 +267,6 @@ pub async fn create_grpc_client<C: Debug + Clone>(
};
let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)).unwrap();
base_url.set_port(Some(port)).unwrap();
debug!(%base_url, "creating gRPC client");
let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC);
let request_timeout = Duration::from_secs(
service_config
Expand Down
118 changes: 3 additions & 115 deletions src/clients/chunker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ use std::pin::Pin;

use async_trait::async_trait;
use axum::http::HeaderMap;
use futures::{Future, Stream, StreamExt, TryStreamExt};
use futures::{Future, StreamExt, TryStreamExt};
use ginepro::LoadBalancedChannel;
use tonic::{Code, Request, Response, Status, Streaming};
use tracing::{Span, debug, info, instrument};
use tracing::Span;

use super::{
BoxStream, Client, Error, create_grpc_client, errors::grpc_to_http_code,
Expand All @@ -36,7 +36,7 @@ use crate::{
BidiStreamingChunkerTokenizationTaskRequest, ChunkerTokenizationTaskRequest,
chunkers_service_client::ChunkersServiceClient,
},
caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults},
caikit_data_model::nlp::{ChunkerTokenizationStreamResult, TokenizationResults},
grpc::health::v1::{HealthCheckRequest, health_client::HealthClient},
},
utils::trace::trace_context_from_grpc_response,
Expand All @@ -50,14 +50,12 @@ pub const DEFAULT_CHUNKER_ID: &str = "whole_doc_chunker";
/// Shorthand for the outcome of a streaming tokenization call: on success a
/// tonic `Response` wrapping a `Streaming` of `ChunkerTokenizationStreamResult`
/// messages, on failure a tonic `Status`.
type StreamingTokenizationResult =
Result<Response<Streaming<ChunkerTokenizationStreamResult>>, Status>;

/// Client for the chunker service.
///
/// Bundles a `ChunkersServiceClient` and a `HealthClient`, both over an
/// `OtelGrpcService<LoadBalancedChannel>`. Derives `Clone`.
#[cfg_attr(test, faux::create)]
#[derive(Clone)]
pub struct ChunkerClient {
/// gRPC client used for the tokenization calls.
client: ChunkersServiceClient<OtelGrpcService<LoadBalancedChannel>>,
/// gRPC health client (`grpc::health::v1`).
health_client: HealthClient<OtelGrpcService<LoadBalancedChannel>>,
}

#[cfg_attr(test, faux::methods)]
impl ChunkerClient {
pub async fn new(config: &ServiceConfig) -> Self {
let client = create_grpc_client(DEFAULT_PORT, config, ChunkersServiceClient::new).await;
Expand All @@ -68,28 +66,24 @@ impl ChunkerClient {
}
}

/// Sends a unary tokenization request for `model_id` to the chunker service.
///
/// The request is wrapped with headers via `request_with_headers` before it is
/// sent. After the call returns, the current span and the response are handed
/// to `trace_context_from_grpc_response`, and the response payload is returned.
///
/// # Errors
///
/// Propagates (via `?`) any error returned by the underlying gRPC call.
#[instrument(skip_all, fields(model_id))]
pub async fn tokenization_task_predict(
    &self,
    model_id: &str,
    request: ChunkerTokenizationTaskRequest,
) -> Result<TokenizationResults, Error> {
    // The call needs `&mut` access, so work on a clone of the stored client.
    let mut chunker = self.client.clone();
    let grpc_request = request_with_headers(request, model_id);
    debug!(request = ?grpc_request, "sending client request");
    let grpc_response = chunker
        .chunker_tokenization_task_predict(grpc_request)
        .await?;
    trace_context_from_grpc_response(&Span::current(), &grpc_response);
    Ok(grpc_response.into_inner())
}

#[instrument(skip_all, fields(model_id))]
pub async fn bidi_streaming_tokenization_task_predict(
&self,
model_id: &str,
request_stream: BoxStream<BidiStreamingChunkerTokenizationTaskRequest>,
) -> Result<BoxStream<Result<ChunkerTokenizationStreamResult, Error>>, Error> {
info!("sending client stream request");
let mut client = self.client.clone();
let request = request_with_headers(request_stream, model_id);
// NOTE: this is an ugly workaround to avoid bogus higher-ranked lifetime errors.
Expand All @@ -103,7 +97,6 @@ impl ChunkerClient {
}
}

#[cfg_attr(test, faux::methods)]
#[async_trait]
impl Client for ChunkerClient {
fn name(&self) -> &str {
Expand Down Expand Up @@ -144,108 +137,3 @@ fn request_with_headers<T>(request: T, model_id: &str) -> Request<T> {
.insert(MODEL_ID_HEADER_NAME, model_id.parse().unwrap());
request
}

/// Tokenizes a unary request by treating the whole document as one token.
///
/// The returned `TokenizationResults` holds a single `Token` that starts at 0,
/// ends at the number of `char`s in the text, and owns the request text.
#[instrument(skip_all)]
pub fn tokenize_whole_doc(request: ChunkerTokenizationTaskRequest) -> TokenizationResults {
    let text = request.text;
    // Offsets are counted in `char`s, not bytes.
    let end = text.chars().count() as i64;
    let whole_doc = Token {
        start: 0,
        end,
        text,
    };
    TokenizationResults {
        results: vec![whole_doc],
        token_count: 1, // entire doc
    }
}

/// Tokenizes a request stream by treating the whole stream as one token.
///
/// Drains `request`, concatenating every `text_stream` piece in arrival order
/// and collecting each `input_index_stream`. The result holds a single `Token`
/// spanning from 0 to the number of `char`s in the concatenated text.
/// `input_end_index` is the last index received, or 0 when the stream yielded
/// no items.
///
/// # Errors
///
/// This implementation always returns `Ok`.
#[instrument(skip_all)]
pub async fn tokenize_whole_doc_stream(
    request: impl Stream<Item = BidiStreamingChunkerTokenizationTaskRequest>,
) -> Result<ChunkerTokenizationStreamResult, Error> {
    let pieces = request.map(|chunk| (chunk.text_stream, chunk.input_index_stream));
    let (text, indices): (String, Vec<i64>) = pieces.collect().await;
    // Offsets are counted in `char`s, not bytes.
    let end = text.chars().count() as i64;
    let whole_stream = Token {
        start: 0,
        end,
        text,
    };
    Ok(ChunkerTokenizationStreamResult {
        results: vec![whole_stream],
        token_count: 1, // entire doc/stream
        processed_index: end,
        start_index: 0,
        input_start_index: 0,
        input_end_index: indices.last().copied().unwrap_or_default(),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Sample document shared by both tests; 121 characters long.
    const LOREM: &str = "Lorem ipsum dolor sit amet consectetur adipiscing elit \
        sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";

    /// The whole document comes back as one token covering `0..121`.
    #[test]
    fn test_tokenize_whole_doc() {
        let response = tokenize_whole_doc(ChunkerTokenizationTaskRequest {
            text: LOREM.into(),
        });
        let expected_response = TokenizationResults {
            results: vec![Token {
                start: 0,
                end: 121,
                text: LOREM.into(),
            }],
            token_count: 1,
        };
        assert_eq!(response, expected_response)
    }

    /// Four streamed pieces (input indices 0 through 3) are joined into one token.
    #[tokio::test]
    async fn test_tokenize_whole_doc_stream() {
        let pieces = [
            "Lorem ipsum dolor sit amet ",
            "consectetur adipiscing elit ",
            "sed do eiusmod tempor incididunt ",
            "ut labore et dolore magna aliqua.",
        ];
        // Pair each piece with its position in the stream: 0, 1, 2, 3.
        let request = futures::stream::iter(pieces.into_iter().zip(0_i64..).map(
            |(piece, index)| BidiStreamingChunkerTokenizationTaskRequest {
                text_stream: piece.into(),
                input_index_stream: index,
            },
        ));
        let expected_response = ChunkerTokenizationStreamResult {
            results: vec![Token {
                start: 0,
                end: 121,
                text: LOREM.into(),
            }],
            token_count: 1,
            processed_index: 121,
            start_index: 0,
            input_start_index: 0,
            input_end_index: 3,
        };
        let response = tokenize_whole_doc_stream(request).await.unwrap();
        assert_eq!(response, expected_response);
    }
}
Loading