diff --git a/CODEOWNERS b/CODEOWNERS index 3a458b63..de0fb03d 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -8,4 +8,4 @@ # https://help.github.com/en/articles/about-code-owners # -* @gkumbhat @evaline-ju @declark1 \ No newline at end of file +* @gkumbhat @evaline-ju @declark1 @mdevino diff --git a/Cargo.lock b/Cargo.lock index a4018fe0..8a73e2ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,9 +13,9 @@ dependencies = [ [[package]] name = "adler2" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aho-corasick" @@ -28,9 +28,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.18" +version = "0.6.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" dependencies = [ "anstyle", "anstyle-parse", @@ -43,36 +43,36 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" [[package]] name = "anstyle-parse" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" dependencies = [ "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.7" +version = "3.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" dependencies = [ "anstyle", - "once_cell", + "once_cell_polyfill", "windows-sys 0.59.0", ] @@ -92,28 +92,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "async-trait" version = "0.1.88" @@ -139,15 +117,15 @@ checksum = "3c1e7e457ea78e524f48639f551fd79703ac3f2237f5ecccdf4708f8a75ad373" [[package]] name = "autocfg" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" 
[[package]] name = "aws-lc-rs" -version = "1.13.1" +version = "1.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" +checksum = "5c953fe1ba023e6b7730c0d4b031d06f267f23a46167dcbd40316644b10a17ba" dependencies = [ "aws-lc-sys", "zeroize", @@ -155,9 +133,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.29.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" +checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" dependencies = [ "bindgen", "cc", @@ -166,40 +144,13 @@ dependencies = [ "fs_extra", ] -[[package]] -name = "axum" -version = "0.7.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" -dependencies = [ - "async-trait", - "axum-core 0.4.5", - "bytes", - "futures-util", - "http", - "http-body", - "http-body-util", - "itoa", - "matchit 0.7.3", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "sync_wrapper", - "tower 0.5.2", - "tower-layer", - "tower-service", -] - [[package]] name = "axum" version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ - "axum-core 0.5.2", + "axum-core", "bytes", "form_urlencoded", "futures-util", @@ -209,7 +160,7 @@ dependencies = [ "hyper", "hyper-util", "itoa", - "matchit 0.8.4", + "matchit", "memchr", "mime", "percent-encoding", @@ -221,32 +172,12 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tower 0.5.2", + "tower", "tower-layer", "tower-service", "tracing", ] -[[package]] -name = "axum-core" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" -dependencies = [ - "async-trait", - "bytes", - "futures-util", - "http", - "http-body", - "http-body-util", - "mime", - "pin-project-lite", - "rustversion", - "sync_wrapper", - "tower-layer", - "tower-service", -] - [[package]] name = "axum-core" version = "0.5.2" @@ -273,8 +204,8 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45bf463831f5131b7d3c756525b305d40f1185b688565648a92e1392ca35713d" dependencies = [ - "axum 0.8.4", - "axum-core 0.5.2", + "axum", + "axum-core", "bytes", "futures-util", "http", @@ -288,7 +219,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "tower 0.5.2", + "tower", "tower-layer", "tower-service", ] @@ -302,7 +233,7 @@ dependencies = [ "anyhow", "assert-json-diff", "auto-future", - "axum 0.8.4", + "axum", "bytes", "bytesize", "cookie", @@ -319,15 +250,15 @@ dependencies = [ "serde_urlencoded", "smallvec", "tokio", - "tower 0.5.2", + "tower", "url", ] [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -375,9 +306,9 @@ checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bumpalo" -version = "3.17.0" +version = "3.19.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytes" @@ -393,9 +324,9 @@ checksum = "a3c8f83209414aacf0eeae3cf730b18d6981697fba62f200fcfb92b9f082acba" [[package]] name = "cc" -version = "1.2.24" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16595d3be041c03b09d08d0858631facccee9221e579704070e6e9e4915d3bc7" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "jobserver", "libc", @@ -413,9 +344,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" [[package]] name = "cfg_aliases" @@ -436,9 +367,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.39" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -446,9 +377,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.39" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", @@ -458,9 +389,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.32" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", @@ -470,9 +401,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" [[package]] name = "cmake" @@ -485,9 +416,9 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" [[package]] name = "cookie" @@ -511,9 +442,9 @@ dependencies = [ [[package]] name = "core-foundation" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" dependencies = [ "core-foundation-sys", "libc", @@ -525,11 +456,31 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" [[package]] name = "deranged" @@ -601,9 +552,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.7" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3716d7a920fb4fac5d84e9d4bce8ceb321e9414b4409da61b07b75c1e3d0697" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -619,12 +570,12 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.12" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -652,15 +603,16 @@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "fms-guardrails-orchestr8" -version = "0.10.0" +version = "0.13.0" dependencies = [ "anyhow", "async-trait", - "axum 0.8.4", + "axum", "axum-extra", "axum-test", "bytes", "clap", + "dashmap", "eventsource-stream", "futures", "futures-util", @@ -680,11 +632,12 @@ dependencies = [ "opentelemetry_sdk", "pin-project-lite", "prost", - "rand 0.9.1", + "rand 0.9.2", "reqwest", "rustls", "rustls-pemfile", - "rustls-webpki 0.102.8", + "rustls-pki-types", + "rustls-webpki", "serde", "serde_json", "serde_yml", @@ -695,7 +648,7 @@ dependencies = [ "tokio-stream", "tonic", "tonic-build", - "tower 0.5.2", + "tower", "tower-http", "tracing", "tracing-opentelemetry", @@ -838,7 +791,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -865,8 +818,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "ginepro" version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a71e958d6edde3a87f7341c81c978fccc013950257d49a8e2da2a2361946924" +source = "git+https://github.com/gkumbhat/ginepro?rev=863ca186f37abf5997126aa97e85b56ca288a76c#863ca186f37abf5997126aa97e85b56ca288a76c" dependencies = [ "anyhow", "async-trait", @@ -875,7 +827,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tonic", - "tower 0.4.13", + "tower", "tracing", ] @@ -887,9 +839,9 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "h2" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" +checksum = "17da50a276f1e01e0ba6c029e47b7100754904ee8a278f886546e98575380785" dependencies = [ "atomic-waker", "bytes", @@ -897,7 +849,7 
@@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.9.0", + "indexmap", "slab", "tokio", "tokio-util", @@ -906,15 +858,15 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.12.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" [[package]] name = "hashbrown" -version = "0.15.3" +version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" [[package]] name = "heck" @@ -976,17 +928,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "hostname" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9c7c7c8ac16c798734b8a24560c1362120597c40d5e1459f09498f8f6c8f2ba" -dependencies = [ - "cfg-if", - "libc", - "windows", -] - [[package]] name = "http" version = "1.3.1" @@ -1066,9 +1007,9 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.6" +version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ "http", "hyper", @@ -1080,7 +1021,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots 1.0.0", + "webpki-roots 1.0.2", ] [[package]] @@ -1114,9 +1055,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.13" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c293b6b3d21eca78250dc7dbebd6b9210ec5530e038cbfe0661b5c47ab06e8" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" dependencies = [ "base64", "bytes", @@ -1130,7 +1071,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.0", "system-configuration", "tokio", "tower-service", @@ -1247,22 +1188,23 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.3" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ - "autocfg", - "hashbrown 0.12.3", + "equivalent", + "hashbrown 0.15.4", ] [[package]] -name = "indexmap" -version = "2.9.0" +name = "io-uring" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" dependencies = [ - "equivalent", - "hashbrown 0.15.3", + "bitflags", + "cfg-if", + "libc", ] [[package]] @@ -1271,7 +1213,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2", + "socket2 0.5.10", "widestring", "windows-sys 0.48.0", "winreg", @@ -1357,18 +1299,18 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.172" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libloading" -version = "0.8.6" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.53.2", ] [[package]] @@ -1407,9 +1349,9 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -1445,12 +1387,6 @@ dependencies = [ "regex-automata 0.1.10", ] -[[package]] -name = "matchit" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" - [[package]] name = "matchit" version = "0.8.4" @@ -1459,9 +1395,9 @@ checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "memchr" -version = "2.7.4" +version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" [[package]] name = "mime" @@ -1477,9 +1413,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", ] @@ -1491,14 +1427,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "windows-sys 0.59.0", ] [[package]] name = "mocktail" -version = "0.2.5-alpha" -source = "git+https://github.com/IBM/mocktail#025d724965f5d4ee7cc6666bf22845a896b00b58" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053f7ba52863e22dfd2970075bbc69c4224ca6ae03896a5f69a0d5982deb5e0a" dependencies = [ "bytes", "futures", @@ -1508,7 +1445,7 @@ dependencies = [ "hyper", "hyper-util", "prost", - "rand 0.9.1", + "rand 0.9.2", "serde", "serde_json", "thiserror 2.0.12", @@ -1521,9 +1458,9 @@ dependencies = [ [[package]] name = "multimap" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" [[package]] name = "native-tls" @@ -1583,11 +1520,17 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" + [[package]] name = "openssl" -version = "0.10.72" +version = "0.10.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" +checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" dependencies = [ "bitflags", "cfg-if", @@ -1617,9 +1560,9 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.108" +version = "0.9.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e145e1651e858e820e4860f7b9c5e169bc1d8ce1c86043be79fa7b7634821847" +checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571" dependencies = [ "cc", "libc", @@ -1629,23 +1572,23 @@ dependencies = [ [[package]] name = "opentelemetry" -version = "0.27.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab70038c28ed37b97d8ed414b6429d343a8bbf44c9f79ec854f3a643029ba6d7" +checksum = "aaf416e4cb72756655126f7dd7bb0af49c674f4c1b9903e80c009e0c37e552e6" dependencies = [ "futures-core", "futures-sink", "js-sys", "pin-project-lite", - "thiserror 1.0.69", + "thiserror 2.0.12", "tracing", ] [[package]] name = "opentelemetry-http" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a8a7f5f6ba7c1b286c2fbca0454eaba116f63bbe69ed250b642d36fbb04d80" +checksum = "50f6639e842a97dbea8886e3439710ae463120091e2e064518ba8e716e6ac36d" dependencies = [ "async-trait", "bytes", @@ -1656,19 +1599,18 @@ dependencies = [ [[package]] name = "opentelemetry-otlp" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91cf61a1868dacc576bf2b2a1c3e9ab150af7272909e80085c3173384fe11f76" +checksum = "dbee664a43e07615731afc539ca60c6d9f1a9425e25ca09c57bc36c87c55852b" dependencies = [ - "async-trait", - "futures-core", "http", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", "opentelemetry_sdk", "prost", - "thiserror 1.0.69", + "reqwest", + "thiserror 2.0.12", "tokio", "tonic", "tracing", @@ -1676,9 +1618,9 @@ dependencies = [ [[package]] name = "opentelemetry-proto" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6e05acbfada5ec79023c85368af14abd0b307c015e9064d249b2a950ef459a6" +checksum = "2e046fd7660710fe5a05e8748e70d9058dc15c94ba914e7c4faa7c728f0e8ddc" dependencies = [ "opentelemetry", "opentelemetry_sdk", @@ -1688,23 +1630,20 @@ dependencies = [ [[package]] name = "opentelemetry_sdk" -version = "0.27.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "231e9d6ceef9b0b2546ddf52335785ce41252bc7474ee8ba05bfad277be13ab8" +checksum = "11f644aa9e5e31d11896e024305d7e3c98a88884d9f8919dbf37a9991bc47a4b" dependencies = [ - "async-trait", "futures-channel", "futures-executor", "futures-util", - "glob", "opentelemetry", "percent-encoding", - "rand 0.8.5", + "rand 0.9.2", "serde_json", - "thiserror 1.0.69", + "thiserror 2.0.12", "tokio", "tokio-stream", - "tracing", ] [[package]] @@ -1715,9 +1654,9 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -1725,9 +1664,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -1749,7 +1688,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.9.0", + "indexmap", ] [[package]] @@ -1826,9 +1765,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.32" +version = "0.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" +checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" dependencies = [ "proc-macro2", "syn", @@ -1908,7 +1847,7 @@ dependencies = [ "quinn-udp", "rustc-hash 2.1.1", "rustls", - "socket2", + "socket2 0.5.10", "thiserror 2.0.12", "tokio", "tracing", @@ -1924,7 +1863,7 @@ dependencies = [ "bytes", "getrandom 0.3.3", "lru-slab", - "rand 0.9.1", + "rand 0.9.2", "ring", "rustc-hash 2.1.1", "rustls", @@ -1938,14 +1877,14 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.12" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" +checksum = "fcebb1209ee276352ef14ff8732e24cc2b02bbac986cd74a4c81bcb2f9881970" dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.5.10", "tracing", "windows-sys 0.59.0", ] @@ -1961,9 +1900,9 @@ dependencies = [ [[package]] name = "r-efi" -version = "5.2.0" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rand" @@ -1978,9 +1917,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -2026,9 +1965,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.11" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" +checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec" dependencies = [ "bitflags", ] @@ -2079,9 +2018,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.18" +version = "0.12.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e98ff6b0dbbe4d5a37318f433d4fc82babd21631f194d370409ceb2e40b2f0b5" +checksum = "cbc931937e6ca3a06e3b6c0aa7841849b160a90351d6ab467a8b9b9959767531" dependencies = [ "base64", "bytes", @@ 
-2097,12 +2036,10 @@ dependencies = [ "hyper-rustls", "hyper-tls", "hyper-util", - "ipnet", "js-sys", "log", "mime", "native-tls", - "once_cell", "percent-encoding", "pin-project-lite", "quinn", @@ -2116,7 +2053,7 @@ dependencies = [ "tokio-native-tls", "tokio-rustls", "tokio-util", - "tower 0.5.2", + "tower", "tower-http", "tower-service", "url", @@ -2124,26 +2061,23 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 1.0.0", + "webpki-roots 1.0.2", ] [[package]] name = "reserve-port" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3747658ee2585ecf5607fa9887c92eff61b362ff5253dbf797dfeb73d33d78" +checksum = "21918d6644020c6f6ef1993242989bf6d4952d2e025617744f184c02df51c356" dependencies = [ "thiserror 2.0.12", ] [[package]] name = "resolv-conf" -version = "0.7.1" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48375394603e3dd4b2d64371f7148fd8c7baa2680e28741f2cb8d23b59e3d4c4" -dependencies = [ - "hostname", -] +checksum = "95325155c684b1c89f7765e30bc1c42e4a6da51ca513615660cb8a62ef9a88e3" [[package]] name = "ring" @@ -2170,15 +2104,15 @@ dependencies = [ "futures-util", "http", "mime", - "rand 0.9.1", + "rand 0.9.2", "thiserror 2.0.12", ] [[package]] name = "rustc-demangle" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" [[package]] name = "rustc-hash" @@ -2207,29 +2141,29 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "rustls" -version = "0.23.27" +version = "0.23.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" +checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" dependencies = [ "aws-lc-rs", "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.3", + "rustls-webpki", "subtle", "zeroize", ] @@ -2267,20 +2201,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.102.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" -dependencies = [ - "ring", - "rustls-pki-types", - "untrusted", -] - -[[package]] -name = "rustls-webpki" -version = "0.103.3" +version = "0.103.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" dependencies = [ "aws-lc-rs", "ring", @@ -2335,7 +2258,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ "bitflags", - "core-foundation 0.10.0", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -2373,11 +2296,11 @@ dependencies = [ [[package]] name = 
"serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ - "indexmap 2.9.0", + "indexmap", "itoa", "memchr", "ryu", @@ -2412,7 +2335,7 @@ version = "0.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59e2dd588bf1597a252c3b920e0143eb99b0f76e4e082f4c92ce34fbc9e71ddd" dependencies = [ - "indexmap 2.9.0", + "indexmap", "itoa", "libyml", "memchr", @@ -2447,18 +2370,15 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] +checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" [[package]] name = "smallvec" -version = "1.15.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "socket2" @@ -2470,6 +2390,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2490,9 +2420,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.101" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -2549,15 +2479,15 @@ dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix 1.0.7", + "rustix 1.0.8", "windows-sys 0.59.0", ] [[package]] name = "test-log" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" dependencies = [ "env_logger", "test-log-macros", @@ -2566,9 +2496,9 @@ dependencies = [ [[package]] name = "test-log-macros" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" dependencies = [ "proc-macro2", "quote", @@ -2617,12 +2547,11 @@ dependencies = [ [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -2683,18 +2612,20 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.45.1" +version = "1.46.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "slab", + "socket2 0.5.10", "tokio-macros", "windows-sys 0.52.0", ] @@ -2757,13 +2688,12 @@ dependencies = [ [[package]] name = "tonic" -version = "0.12.3" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" dependencies = [ - "async-stream", "async-trait", - "axum 0.7.9", + "axum", "base64", "bytes", "h2", @@ -2777,23 +2707,22 @@ dependencies = [ "pin-project", "prost", "rustls-native-certs", - "rustls-pemfile", - "socket2", + "socket2 0.5.10", "tokio", "tokio-rustls", "tokio-stream", - "tower 0.4.13", + "tower", "tower-layer", "tower-service", "tracing", - "webpki-roots 0.26.8", + "webpki-roots 0.26.11", ] [[package]] name = "tonic-build" -version = "0.12.3" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11" +checksum = "eac6f67be712d12f0b41328db3137e0d0757645d8904b4cb7d51cd9c2279e847" dependencies = [ "prettyplease", "proc-macro2", @@ -2803,26 +2732,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "indexmap 1.9.3", - "pin-project", - "pin-project-lite", - "rand 0.8.5", - "slab", - "tokio", - "tokio-util", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "tower" version = "0.5.2" @@ -2831,9 +2740,12 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", + "indexmap", "pin-project-lite", + "slab", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -2841,9 +2753,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.4" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fdb0c213ca27a9f57ab69ddb290fd80d970922355b83ae380b395d3986b8a2e" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" dependencies = [ "bitflags", "bytes", @@ -2852,7 +2764,7 @@ dependencies = [ "http-body", "iri-string", "pin-project-lite", - "tower 0.5.2", + "tower", "tower-layer", "tower-service", "tracing", @@ -2884,9 +2796,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.28" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", @@ -2895,9 +2807,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = 
[ "once_cell", "valuable", @@ -2916,9 +2828,9 @@ dependencies = [ [[package]] name = "tracing-opentelemetry" -version = "0.28.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a971f6058498b5c0f1affa23e7ea202057a7301dbff68e968b2d578bcbd053" +checksum = "ddcf5959f39507d0d04d6413119c04f33b623f4f951ebcbdddddfad2d0623a9c" dependencies = [ "js-sys", "once_cell", @@ -3012,7 +2924,7 @@ checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", "js-sys", - "rand 0.9.1", + "rand 0.9.2", "wasm-bindgen", ] @@ -3045,9 +2957,9 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" @@ -3164,18 +3076,18 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.8" +version = "0.26.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" dependencies = [ - "rustls-pki-types", + "webpki-roots 1.0.2", ] [[package]] name = "webpki-roots" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2853738d1cc4f2da3a225c18ec6c3721abb31961096e9dbf5ab35fa88b19cfdb" +checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2" dependencies = [ "rustls-pki-types", ] @@ -3220,56 +3132,37 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" -[[package]] -name = "windows" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" -dependencies = [ - "windows-core", - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets 0.52.6", -] - [[package]] name = "windows-link" -version = "0.1.1" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" [[package]] name = "windows-registry" -version = "0.4.0" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" +checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e" dependencies = [ + "windows-link", "windows-result", "windows-strings", - "windows-targets 0.53.0", ] [[package]] name = "windows-result" -version = "0.3.2" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ "windows-link", ] [[package]] name = "windows-strings" -version = "0.3.1" +version = "0.4.2" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ "windows-link", ] @@ -3301,6 +3194,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -3334,9 +3236,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.53.0" +version = "0.53.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" dependencies = [ "windows_aarch64_gnullvm 0.53.0", "windows_aarch64_msvc 0.53.0", @@ -3543,18 +3445,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index c162bbaf..33f74361 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fms-guardrails-orchestr8" -version = "0.10.0" +version = "0.13.0" edition = "2024" authors = ["Evaline Ju", "Gaurav Kumbhat", "Dan Clark"] description = "Foundation models orchestration server" @@ -19,49 +19,51 @@ async-trait = "0.1.88" axum = { version = "0.8.4", features = ["json"] } axum-extra = { version = "0.10.1", features = ["json-lines"] } bytes = "1.10.1" -clap = { version = "4.5.39", features = ["derive", "env"] } +clap = { version = "4.5.41", features = ["derive", "env"] } +dashmap = "6.1.0" eventsource-stream = "0.2.3" futures = "0.3.31" futures-util = { version = "0.3", default-features = false, features = [] } -ginepro = "0.8.2" +ginepro = { git = "https://github.com/gkumbhat/ginepro", rev = "863ca186f37abf5997126aa97e85b56ca288a76c" } http = "1.3.1" http-body = "1.0" http-body-util = "0.1.3" http-serde = "2.1.1" hyper = { version = "1.6.0", features = ["http1", "http2", "server"] } -hyper-rustls = { version = "0.27.6", features = ["ring"] } +hyper-rustls = { version = "0.27.7", features = ["ring"] } hyper-timeout = "0.5.2" -hyper-util = { version = "0.1.13", features = [ +hyper-util = { version = "0.1.16", features = [ "server-auto", "server-graceful", "tokio", ] } -opentelemetry = { version = "0.27.1", features = ["metrics", "trace"] } -opentelemetry-http = { version = "0.27.0", features = ["reqwest"] } -opentelemetry-otlp = { version = "0.27.0", features = [ +opentelemetry = { version = "0.30.0", features = ["metrics", "trace"] } +opentelemetry-http = { version = "0.30.0", features = ["reqwest"] } +opentelemetry-otlp = { version = "0.30.0", features = [ "grpc-tonic", "http-proto", ] } -opentelemetry_sdk = 
{ version = "0.27.1", features = ["rt-tokio", "metrics"] } +opentelemetry_sdk = { version = "0.30.0", features = ["rt-tokio", "metrics"] } pin-project-lite = "0.2.16" prost = "0.13.5" -reqwest = { version = "0.12.18", features = [ +reqwest = { version = "0.12.22", features = [ "blocking", "rustls-tls", "json", "stream", ] } -rustls = { version = "0.23.27", default-features = false, features = [ +rustls = { version = "0.23.29", default-features = false, features = [ "ring", "std", ] } rustls-pemfile = "2.2.0" -rustls-webpki = "0.102.8" +rustls-pki-types = "1.12.0" +rustls-webpki = "0.103.4" serde = { version = "1.0.219", features = ["derive"] } -serde_json = { version = "1.0.140", features = ["preserve_order"] } +serde_json = { version = "1.0.141", features = ["preserve_order"] } serde_yml = "0.0.12" thiserror = "2.0.12" -tokio = { version = "1.45.1", features = [ +tokio = { version = "1.46.1", features = [ "rt", "rt-multi-thread", "parking_lot", @@ -71,27 +73,27 @@ tokio = { version = "1.45.1", features = [ ] } tokio-rustls = { version = "0.26.2", features = ["ring"] } tokio-stream = { version = "0.1.17", features = ["sync"] } -tonic = { version = "0.12.3", features = [ - "tls", - "tls-roots", +tonic = { version = "0.13.1", features = [ + "tls-ring", + "tls-native-roots", "tls-webpki-roots", ] } tower = { version = "0.5.2", features = ["timeout"] } -tower-http = { version = "0.6.4", features = ["trace"] } +tower-http = { version = "0.6.6", features = ["trace"] } tracing = "0.1.41" -tracing-opentelemetry = "0.28.0" +tracing-opentelemetry = "0.31.0" tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } url = "2.5.4" uuid = { version = "1.17.0", features = ["v4"] } [build-dependencies] -tonic-build = "0.12.3" +tonic-build = "0.13.1" [dev-dependencies] axum-test = "17.3.0" -mocktail = { git = "https://github.com/IBM/mocktail" } -rand = "0.9.1" -test-log = "0.2.17" +mocktail = "0.3.0" +rand = "0.9.2" +test-log = "0.2.18" [profile.release] debug = false diff --git a/Dockerfile b/Dockerfile.amd64 similarity index 56% rename from Dockerfile rename to Dockerfile.amd64 index 69ca4f6b..a08ed27b 100644 --- a/Dockerfile +++ b/Dockerfile.amd64 @@ -4,32 +4,39 @@ ARG PROTOC_VERSION=29.3 ARG CONFIG_FILE=config/config.yaml ## Rust builder ################################################################ -# Specific debian version so that compatible glibc version is used -FROM rust:1.87.0 AS rust-builder +FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} AS rust-builder ARG PROTOC_VERSION ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse +# Install dependencies +RUN microdnf --disableplugin=subscription-manager -y update && \ + microdnf install --disableplugin=subscription-manager -y \ + unzip \ + ca-certificates \ + openssl-devel \ + gcc && \ + microdnf clean all + +COPY rust-toolchain.toml rust-toolchain.toml + +# Install rustup [needed for latest Rust versions] +RUN curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain none -y --no-modify-path && \ + . 
"$HOME/.cargo/env" && \ + rustup install && \ + rustup component add rustfmt + +# Set PATH so rustc, cargo, rustup are available +ENV PATH="/root/.cargo/bin:${PATH}" + # Install protoc, no longer included in prost crate RUN cd /tmp && \ - if [ "$(uname -m)" = "s390x" ]; then \ - apt update && \ - apt install -y cmake clang libclang-dev curl unzip && \ - curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-s390_64.zip; \ - else \ - curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip; \ - fi && \ - unzip protoc-*.zip -d /usr/local && \ - rm protoc-*.zip - + curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-x86_64.zip && \ + unzip protoc-*.zip -d /usr/local && rm protoc-*.zip ENV LIBCLANG_PATH=/usr/lib/llvm-14/lib/ WORKDIR /app -COPY rust-toolchain.toml rust-toolchain.toml - -RUN rustup component add rustfmt - ## Orchestrator builder ######################################################### FROM rust-builder AS fms-guardrails-orchestr8-builder ARG CONFIG_FILE=config/config.yaml @@ -38,30 +45,22 @@ COPY build.rs *.toml LICENSE /app/ COPY ${CONFIG_FILE} /app/config/config.yaml COPY protos/ /app/protos/ COPY src/ /app/src/ +COPY tests/ /app/tests/ WORKDIR /app # TODO: Make releases via cargo-release -RUN cargo install --root /app/ --path . - -## Tests stage ################################################################## -FROM fms-guardrails-orchestr8-builder AS tests +RUN cargo build --release +# Copy test resources required for executing unit tests +COPY tests/resources /app/tests/resources RUN cargo test -## Lint stage ################################################################### -FROM fms-guardrails-orchestr8-builder AS lint -RUN cargo clippy --all-targets --all-features -- -D warnings - -## Formatting check stage ####################################################### -FROM fms-guardrails-orchestr8-builder AS format -RUN cargo +nightly fmt --check - ## Release Image ################################################################ FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} AS fms-guardrails-orchestr8-release ARG CONFIG_FILE=config/config.yaml -COPY --from=fms-guardrails-orchestr8-builder /app/bin/ /app/bin/ +COPY --from=fms-guardrails-orchestr8-builder /app/target/release/fms-guardrails-orchestr8 /app/bin/ COPY ${CONFIG_FILE} /app/config/config.yaml RUN microdnf install -y --disableplugin=subscription-manager shadow-utils compat-openssl11 && \ diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 30829b7c..30700bf1 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -4,21 +4,41 @@ ARG PROTOC_VERSION=29.3 ARG CONFIG_FILE=config/config.yaml ## Rust builder ################################################################ -# Specific debian version so that compatible glibc version is used -FROM rust:1.87.0 AS rust-builder +FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} AS rust-builder ARG PROTOC_VERSION ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse +# Install dependencies +RUN microdnf --disableplugin=subscription-manager -y update && \ + microdnf install --disableplugin=subscription-manager -y \ + unzip \ + ca-certificates \ + openssl-devel \ + gcc \ + cmake \ + clang \ + clang-devel && \ + microdnf clean all + +COPY rust-toolchain.toml rust-toolchain.toml + +# Install rustup [needed for latest Rust versions] +RUN curl 
https://sh.rustup.rs -sSf | sh -s -- --default-toolchain none -y --no-modify-path && \ + . "$HOME/.cargo/env" && \ + rustup install && \ + rustup component add rustfmt + +# Set PATH so rustc, cargo, rustup are available +ENV PATH="/root/.cargo/bin:${PATH}" + # Install protoc, no longer included in prost crate RUN cd /tmp && \ - apt update && \ - apt install -y cmake libclang-dev && \ curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-ppcle_64.zip; \ unzip protoc-*.zip -d /usr/local && \ rm protoc-*.zip -ENV LIBCLANG_PATH=/usr/lib/llvm-14/lib/ +ENV LIBCLANG_PATH=/usr/lib64/ WORKDIR /app @@ -67,4 +87,4 @@ HEALTHCHECK NONE ENV ORCHESTRATOR_CONFIG=/app/config/config.yaml -CMD ["/app/bin/fms-guardrails-orchestr8"] +CMD ["/app/bin/fms-guardrails-orchestr8"] \ No newline at end of file diff --git a/Dockerfile.s390x b/Dockerfile.s390x new file mode 100644 index 00000000..54ca44e9 --- /dev/null +++ b/Dockerfile.s390x @@ -0,0 +1,84 @@ +ARG UBI_MINIMAL_BASE_IMAGE=registry.access.redhat.com/ubi9/ubi-minimal +ARG UBI_BASE_IMAGE_TAG=latest +ARG PROTOC_VERSION=29.3 +ARG CONFIG_FILE=config/config.yaml + +## Rust builder ################################################################ +FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} AS rust-builder +ARG PROTOC_VERSION + +ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +# Install dependencies +RUN microdnf --disableplugin=subscription-manager -y update && \ + microdnf install --disableplugin=subscription-manager -y \ + unzip \ + ca-certificates \ + openssl-devel \ + gcc \ + cmake \ + clang \ + clang-devel && \ + microdnf clean all + +COPY rust-toolchain.toml rust-toolchain.toml + +# Install rustup [needed for latest Rust versions] +RUN curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain none -y --no-modify-path && \ + . 
"$HOME/.cargo/env" && \ + rustup install && \ + rustup component add rustfmt + +# Set PATH so rustc, cargo, rustup are available +ENV PATH="/root/.cargo/bin:${PATH}" + +# Install protoc, no longer included in prost crate +RUN cd /tmp && \ + curl -L -O https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-s390_64.zip && \ + unzip protoc-*.zip -d /usr/local && rm protoc-*.zip +ENV LIBCLANG_PATH=/usr/lib64/ + +WORKDIR /app + +## Orchestrator builder ######################################################### +FROM rust-builder AS fms-guardrails-orchestr8-builder +ARG CONFIG_FILE=config/config.yaml + +COPY build.rs *.toml LICENSE /app/ +COPY ${CONFIG_FILE} /app/config/config.yaml +COPY protos/ /app/protos/ +COPY src/ /app/src/ +COPY tests/ /app/tests/ + +WORKDIR /app + +ENV CFLAGS="-Wno-string-compare" +# TODO: Make releases via cargo-release +RUN cargo build --release +# Copy test resources required for executing unit tests +COPY tests/resources /app/tests/resources +RUN cargo test + +## Release Image ################################################################ + +FROM ${UBI_MINIMAL_BASE_IMAGE}:${UBI_BASE_IMAGE_TAG} AS fms-guardrails-orchestr8-release +ARG CONFIG_FILE=config/config.yaml + +COPY --from=fms-guardrails-orchestr8-builder /app/target/release/fms-guardrails-orchestr8 /app/bin/ +COPY ${CONFIG_FILE} /app/config/config.yaml + +RUN microdnf install -y --disableplugin=subscription-manager shadow-utils compat-openssl11 && \ + microdnf clean all --disableplugin=subscription-manager + +RUN groupadd --system orchestr8 --gid 1001 && \ + adduser --system --uid 1001 --gid 0 --groups orchestr8 \ + --create-home --home-dir /app --shell /sbin/nologin \ + --comment "FMS Orchestrator User" orchestr8 + +USER orchestr8 + +HEALTHCHECK NONE + +ENV ORCHESTRATOR_CONFIG=/app/config/config.yaml + +CMD ["/app/bin/fms-guardrails-orchestr8"] diff --git a/build.rs b/build.rs index eb309e4b..eb770dff 100644 --- a/build.rs +++ b/build.rs @@ -17,7 +17,7 @@ fn main() -> Result<(), Box> { ], &["protos"], ) - .unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e)); + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); Ok(()) } diff --git a/config/config.yaml b/config/config.yaml index ef5be2c4..a95b4003 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -12,7 +12,7 @@ generation: hostname: localhost port: 8033 # Generation server used for chat endpoints -# chat_generation: +# openai: # service: # hostname: localhost # port: 8080 @@ -33,6 +33,7 @@ detectors: # Detector ID/name to be used in user requests hap-en: # Detector type (text_contents, text_generation, text_chat, text_context_doc) + # NOTE: can be a string or a list for multiple detector types. type: text_contents service: hostname: localhost diff --git a/docs/api/openapi_detector_api.yaml b/docs/api/openapi_detector_api.yaml index f5f19419..475040d7 100644 --- a/docs/api/openapi_detector_api.yaml +++ b/docs/api/openapi_detector_api.yaml @@ -250,7 +250,7 @@ components: type: array items: allOf: - - $ref: https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml#/components/schemas/ChatCompletionRequestMessage + - $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/ChatCompletionRequestMessage - type: object tools: type: array @@ -260,7 +260,7 @@ components: the model may generate JSON inputs for. A max of 128 functions are supported. 
items: - $ref: https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml#/components/schemas/ChatCompletionTool + $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/ChatCompletionTool detector_params: type: object default: {} diff --git a/docs/api/orchestrator_openapi_0_1_0.yaml b/docs/api/orchestrator_openapi_0_1_0.yaml index b25d409b..5e63dc48 100644 --- a/docs/api/orchestrator_openapi_0_1_0.yaml +++ b/docs/api/orchestrator_openapi_0_1_0.yaml @@ -8,7 +8,9 @@ tags: - name: Task - Detection description: Standalone detections - name: Task - Chat Completions, with detection - description: Detections on list of messages comprising a conversation and/or completions from a model + description: Detections on list of messages comprising a conversation and/or chat completions from a model + - name: Task - Completions, with detection + description: Detections on model prompt and/or completions from a model paths: /health: get: @@ -74,6 +76,12 @@ paths: application/json: schema: $ref: "#/components/schemas/ClassifiedGeneratedTextResult" + "400": + description: Bad Request + content: + application/json: + schema: + $ref: "#/components/schemas/Error" "404": description: Resource Not Found content: @@ -106,6 +114,12 @@ paths: text/event-stream: schema: $ref: "#/components/schemas/ClassifiedGeneratedTextStreamResult" + "400": + description: Bad Request + content: + application/json: + schema: + $ref: "#/components/schemas/Error" "404": description: Resource Not Found content: @@ -351,6 +365,51 @@ paths: application/json: schema: $ref: "#/components/schemas/GuardrailsCreateChatCompletionResponse" + "400": + description: Bad Request + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Resource Not Found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "422": + description: Validation Error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /api/v2/text/completions-detection: + post: + tags: + - Task - Completions, with detection + operationId: >- + api_v2_text_completions_detection_handler + summary: Creates a model response with detections for the given prompt + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/GuardrailsCreateCompletionRequest" + responses: + "200": + description: Successful Response + content: + application/json: + schema: + $ref: "#/components/schemas/GuardrailsCreateCompletionResponse" + "400": + description: Bad Request + content: + application/json: + schema: + $ref: "#/components/schemas/Error" "404": description: Resource Not Found content: @@ -505,7 +564,7 @@ components: minItems: 1 items: allOf: - - $ref: https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml#/components/schemas/ChatCompletionRequestMessage + - $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/ChatCompletionRequestMessage - type: object tools: type: array @@ -515,7 +574,7 @@ components: the model may generate JSON inputs for. A max of 128 functions are supported. 
items: - $ref: https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml#/components/schemas/ChatCompletionTool + $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/ChatCompletionTool additionalProperties: false required: ["detectors", "messages"] type: object @@ -865,9 +924,9 @@ components: ########################## Chat Completion ################################# GuardrailsCreateChatCompletionRequest: title: Guardrails Chat Completion Request - description: Guardrails chat completion request (adds detectors on OpenAI chat completion) + description: Guardrails chat completion request (adds detectors on OpenAI chat completions) allOf: - - $ref: https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml#/components/schemas/CreateChatCompletionRequest + - $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/CreateChatCompletionRequest - type: object properties: detectors: @@ -878,9 +937,9 @@ components: GuardrailsCreateChatCompletionResponse: title: Guardrails Chat Completion Response - description: Guardrails chat completion response (adds detections on OpenAI chat completion) + description: Guardrails chat completion response (adds detections on OpenAI chat completions) allOf: - - $ref: https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml#/components/schemas/CreateChatCompletionResponse + - $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/CreateChatCompletionResponse - type: object properties: detections: @@ -892,25 +951,6 @@ components: required: - detections - Detectors: - title: Guardrails Detectors - description: Specify detectors for guardrails - properties: - input: - type: object - title: Input Detectors - default: {} - output: - type: object - title: Output Detectors - default: {} - example: - input: - hap-v1-model-en: {} - output: - pii-v1: {} - conversation-detector: {} - ChatCompletionsDetections: title: Chat Completions Detections properties: @@ -937,7 +977,7 @@ components: "text": "string", "detection_type": "HAP", "detection": "has_HAP", - "detector_id": "hap-v1-model-en", # Future addition + "detector_id": "hap-v1-model-en", "score": 0.999, } output: @@ -950,16 +990,15 @@ components: "text": "string", "detection_type": "HAP", "detection": "has_HAP", - "detector_id": "hap-v1-model-en", # Future addition + "detector_id": "hap-v1-model-en", "score": 0.999, } - { "detection_type": "string", "detection": "string", - "detector_id": "relevance-v1-en", # Future addition + "detector_id": "relevance-v1-en", "score": 0, } - MessageDetections: title: Message Detections properties: @@ -978,6 +1017,95 @@ components: - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: - message_index + + ########################## Completion ################################# + GuardrailsCreateCompletionRequest: + title: Guardrails Completion Request + description: Guardrails text completion request (adds detectors on OpenAI completions) + allOf: + - $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/CreateCompletionRequest + - type: object + properties: + detectors: + $ref: "#/components/schemas/Detectors" + default: {} + required: + - detectors + + GuardrailsCreateCompletionResponse: + title: Guardrails Completion Response + description: Guardrails completion response (adds detections on OpenAI 
completions) + allOf: + - $ref: https://raw.githubusercontent.com/openai/openai-openapi/manual_spec/openapi.yaml#/components/schemas/CreateCompletionResponse + - type: object + properties: + detections: + $ref: "#/components/schemas/CompletionsDetections" + warnings: + type: array + items: + $ref: "#/components/schemas/Warning" + required: + - detections + + CompletionsDetections: + title: Completions Detections + properties: + input: + type: array + items: + $ref: "#/components/schemas/PromptDetections" + title: Detections on prompt for completions + default: {} + output: + type: array + items: + $ref: "#/components/schemas/ChoiceDetections" + title: Detections on output of completions + default: {} + default: {} + example: + input: + - results: + - { + "start": 0, + "end": 80, + "text": "string", + "detection_type": "HAP", + "detection": "has_HAP", + "detector_id": "hap-v1-model-en", + "score": 0.999, + } + output: + - choice_index: 0 + - choice_index: 1 + results: + - { + "start": 0, + "end": 20, + "text": "string", + "detection_type": "HAP", + "detection": "has_HAP", + "detector_id": "hap-v1-model-en", + "score": 0.999, + } + - { + "detection_type": "string", + "detection": "string", + "detector_id": "relevance-v1-en", + "score": 0, + } + PromptDetections: + title: Prompt Detections + properties: + results: + title: Detection results + type: array + items: + anyOf: + - $ref: "#/components/schemas/DetectionContentResponseObject" + + ########################## General ################################# ChoiceDetections: title: Choice Detections properties: @@ -996,8 +1124,24 @@ components: - $ref: "#/components/schemas/GeneratedTextDetectionResponseObject" required: - choice_index - - ########################## General ################################# + Detectors: + title: Guardrails Detectors + description: Specify detectors for guardrails + properties: + input: + type: object + title: Input Detectors + default: {} + output: + type: object + title: Output Detectors + default: {} + example: + input: + hap-v1-model-en: {} + output: + pii-v1: {} + conversation-detector: {} Error: type: object properties: diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 5675074a..7855e6d5 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.87.0" +channel = "1.88.0" components = ["rustfmt", "clippy"] diff --git a/src/args.rs b/src/args.rs index 152143ab..5237b3cd 100644 --- a/src/args.rs +++ b/src/args.rs @@ -19,6 +19,12 @@ use std::{fmt::Display, path::PathBuf}; use clap::Parser; use tracing::{error, warn}; +use url::Url; + +use crate::{ + models::ValidationError, + utils::trace::{DEFAULT_GRPC_OTLP_ENDPOINT, DEFAULT_HTTP_OTLP_ENDPOINT}, +}; #[derive(Parser, Debug, Clone)] #[clap(author, version, about, long_about = None)] @@ -50,11 +56,11 @@ pub struct Args { #[clap(default_value = "fms_guardrails_orchestr8", long, env)] pub otlp_service_name: String, #[clap(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")] - pub otlp_endpoint: Option, + pub otlp_endpoint: Option, #[clap(long, env = "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")] - pub otlp_traces_endpoint: Option, + pub otlp_traces_endpoint: Option, #[clap(long, env = "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT")] - pub otlp_metrics_endpoint: Option, + pub otlp_metrics_endpoint: Option, #[clap( default_value_t = OtlpProtocol::Grpc, long, @@ -89,8 +95,7 @@ impl From for OtlpExport { "traces" => OtlpExport::Traces, "metrics" => OtlpExport::Metrics, _ => panic!( - "Invalid OTLP export type {}, orchestrator only supports 
exporting traces and metrics via OTLP", - s + "Invalid OTLP export type {s}, orchestrator only supports exporting traces and metrics via OTLP" ), } } @@ -129,15 +134,6 @@ impl From for OtlpProtocol { } } -impl OtlpProtocol { - pub fn default_endpoint(&self) -> &str { - match self { - OtlpProtocol::Grpc => "http://localhost:4317", - OtlpProtocol::Http => "http://localhost:4318", - } - } -} - #[derive(Debug, Clone, Copy, Default, PartialEq)] pub enum LogFormat { Compact, @@ -186,29 +182,91 @@ pub struct TracingConfig { pub quiet: bool, } -impl From for TracingConfig { - fn from(args: Args) -> Self { +impl TryFrom for TracingConfig { + type Error = ValidationError; + + fn try_from(args: Args) -> Result { + use OtlpProtocol::*; let otlp_protocol = args.otlp_protocol; - let otlp_endpoint = args - .otlp_endpoint - .unwrap_or(otlp_protocol.default_endpoint().to_string()); - let otlp_traces_endpoint = args.otlp_traces_endpoint.unwrap_or(otlp_endpoint.clone()); - let otlp_metrics_endpoint = args.otlp_metrics_endpoint.unwrap_or(otlp_endpoint.clone()); + // Use provided otlp_traces_protocol or default to otlp_protocol let otlp_traces_protocol = args.otlp_traces_protocol.unwrap_or(otlp_protocol); + // Use provided otlp_metrics_protocol or default to otlp_protocol let otlp_metrics_protocol = args.otlp_metrics_protocol.unwrap_or(otlp_protocol); - TracingConfig { + // Validate provided endpoints + if let Some(endpoint) = &args.otlp_endpoint { + if endpoint.path() != "/" { + return Err(ValidationError::Invalid("invalid otlp_endpoint".into())); + } + } + if let Some(endpoint) = &args.otlp_traces_endpoint { + match otlp_traces_protocol { + Grpc => { + if endpoint.path() != "/" { + return Err(ValidationError::Invalid( + "invalid otlp_traces_endpoint for grpc protocol".into(), + )); + } + } + Http => { + if endpoint.path() != "/v1/traces" { + return Err(ValidationError::Invalid("invalid otlp_traces_endpoint for http protocol: path should be /v1/traces".into()) + ); + } + } + } + } + if let Some(endpoint) = &args.otlp_metrics_endpoint { + match otlp_metrics_protocol { + Grpc => { + if endpoint.path() != "/" { + return Err(ValidationError::Invalid( + "invalid otlp_metrics_endpoint for grpc protocol".into(), + )); + } + } + Http => { + if endpoint.path() != "/v1/metrics" { + return Err(ValidationError::Invalid( + "invalid otlp_metrics_endpoint for http protocol: path should be /v1/metrics".into(), + )); + } + } + } + } + + // Use provided otlp_endpoint or default for protocol + let otlp_endpoint = args.otlp_endpoint.unwrap_or(match otlp_protocol { + Grpc => Url::parse(DEFAULT_GRPC_OTLP_ENDPOINT).unwrap(), + Http => Url::parse(DEFAULT_HTTP_OTLP_ENDPOINT).unwrap(), + }); + // Use provided otlp_traces_endpoint or build from otlp_endpoint + let otlp_traces_endpoint = + args.otlp_traces_endpoint + .unwrap_or(match otlp_traces_protocol { + Grpc => otlp_endpoint.clone(), + Http => otlp_endpoint.clone().join("v1/traces").unwrap(), + }); + // Use provided otlp_metrics_endpoint or build from otlp_endpoint + let otlp_metrics_endpoint = + args.otlp_metrics_endpoint + .unwrap_or(match otlp_metrics_protocol { + Grpc => otlp_endpoint.clone(), + Http => otlp_endpoint.clone().join("v1/metrics").unwrap(), + }); + + Ok(TracingConfig { service_name: args.otlp_service_name, traces: match args.otlp_export.contains(&OtlpExport::Traces) { - true => Some((otlp_traces_protocol, otlp_traces_endpoint)), + true => Some((otlp_traces_protocol, otlp_traces_endpoint.into())), false => None, }, metrics: match 
args.otlp_export.contains(&OtlpExport::Metrics) { - true => Some((otlp_metrics_protocol, otlp_metrics_endpoint)), + true => Some((otlp_metrics_protocol, otlp_metrics_endpoint.into())), false => None, }, log_format: args.log_format, quiet: args.quiet, - } + }) } } diff --git a/src/clients.rs b/src/clients.rs index cf7062b7..303220f3 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -27,7 +27,7 @@ use std::{ use async_trait::async_trait; use axum::http::{Extensions, HeaderMap}; use futures::Stream; -use ginepro::LoadBalancedChannel; +use ginepro::{LoadBalancedChannel, ResolutionStrategy}; use hyper_timeout::TimeoutConnector; use hyper_util::rt::TokioExecutor; use tonic::{Request, metadata::MetadataMap}; @@ -49,9 +49,10 @@ pub mod http; pub use http::{HttpClient, http_trace_layer}; pub mod chunker; +pub use chunker::ChunkerClient; pub mod detector; -pub use detector::TextContentsDetectorClient; +pub use detector::DetectorClient; pub mod tgis; pub use tgis::TgisClient; @@ -70,6 +71,9 @@ pub mod openai; const DEFAULT_CONNECT_TIMEOUT_SEC: u64 = 60; const DEFAULT_REQUEST_TIMEOUT_SEC: u64 = 600; const DEFAULT_GRPC_PROBE_INTERVAL_SEC: u64 = 10; +const DEFAULT_RES_STRATEGY_TIMEOUT_SEC: u64 = 10; +const DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL: u64 = 30; +const DEFAULT_KEEP_ALIVE_TIMEOUT: u64 = 30; pub type BoxStream = Pin + Send>>; @@ -144,27 +148,15 @@ impl ClientMap { self.0.insert(key, Box::new(value)); } - /// Returns a reference to the client trait object. + /// Returns a reference to the concrete client type. #[inline] - pub fn get(&self, key: &str) -> Option<&dyn Client> { - self.0.get(key).map(|v| v.as_ref()) - } - - /// Returns a mutable reference to the client trait object. - #[inline] - pub fn get_mut(&mut self, key: &str) -> Option<&mut dyn Client> { - self.0.get_mut(key).map(|v| v.as_mut()) - } - - /// Downcasts and returns a reference to the concrete client type. - #[inline] - pub fn get_as(&self, key: &str) -> Option<&V> { + pub fn get(&self, key: &str) -> Option<&V> { self.0.get(key)?.downcast_ref::() } - /// Downcasts and returns a mutable reference to the concrete client type. + /// Returns a mutable reference to the concrete client type. 
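// Illustrative sketch (not part of this diff): with the generic accessors, callers now
// fetch the concrete client type directly instead of a `dyn Client` trait object.
// Assuming an `OpenAiClient` was registered under a hypothetical "openai" key and a
// chunker under "sentence-en" (names illustrative only):
//
//     let openai: &OpenAiClient = client_map.get("openai").expect("client registered");
//     let chunker = client_map.get_mut::<ChunkerClient>("sentence-en");
//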
#[inline] - pub fn get_mut_as(&mut self, key: &str) -> Option<&mut V> { + pub fn get_mut(&mut self, key: &str) -> Option<&mut V> { self.0.get_mut(key)?.downcast_mut::() } @@ -215,10 +207,10 @@ pub async fn create_http_client( None => "http", }; let mut base_url = Url::parse(&format!("{}://{}", protocol, &service_config.hostname)) - .unwrap_or_else(|e| panic!("error parsing base url: {}", e)); + .unwrap_or_else(|e| panic!("error parsing base url: {e}")); base_url .set_port(Some(port)) - .unwrap_or_else(|_| panic!("error setting port: {}", port)); + .unwrap_or_else(|_| panic!("error setting port: {port}")); let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SEC); let request_timeout = Duration::from_secs( @@ -278,11 +270,25 @@ pub async fn create_grpc_client( .grpc_dns_probe_interval .unwrap_or(DEFAULT_GRPC_PROBE_INTERVAL_SEC), ); + let resolution_strategy_timeout = Duration::from_secs( + service_config + .resolution_strategy_timeout + .unwrap_or(DEFAULT_RES_STRATEGY_TIMEOUT_SEC), + ); + let resolution_strategy = match &service_config.resolution_strategy { + Some(name) if name == "eager" => ResolutionStrategy::Eager { + timeout: resolution_strategy_timeout, + }, + _ => ResolutionStrategy::Lazy, + }; let mut builder = LoadBalancedChannel::builder((service_config.hostname.clone(), port)) .dns_probe_interval(grpc_dns_probe_interval) .connect_timeout(connect_timeout) - .timeout(request_timeout); - + .timeout(request_timeout) + .keep_alive_while_idle(true) + .keep_alive_timeout(Duration::from_secs(DEFAULT_KEEP_ALIVE_TIMEOUT)) + .http2_keep_alive_interval(Duration::from_secs(DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL)) + .resolution_strategy(resolution_strategy); let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls { let cert_path = tls_config.cert_path.as_ref().unwrap().as_path(); let key_path = tls_config.key_path.as_ref().unwrap().as_path(); diff --git a/src/clients/detector.rs b/src/clients/detector.rs index 05dfd103..72f34416 100644 --- a/src/clients/detector.rs +++ b/src/clients/detector.rs @@ -15,70 +15,62 @@ */ -use std::fmt::Debug; +use std::{collections::BTreeMap, fmt::Debug}; +use async_trait::async_trait; use axum::http::HeaderMap; use http::header::CONTENT_TYPE; use hyper::StatusCode; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use tracing::info; use url::Url; use super::{ Error, - http::{HttpClientExt, JSON_CONTENT_TYPE, RequestBody, ResponseBody}, + http::{JSON_CONTENT_TYPE, RequestBody, ResponseBody}, +}; +use crate::{ + clients::{ + Client, HttpClient, create_http_client, + openai::{Message, Tool}, + }, + config::ServiceConfig, + health::HealthCheckResult, + models::{DetectionResult, DetectorParams, EvidenceObj, Metadata}, }; -pub mod text_contents; -pub use text_contents::*; -pub mod text_chat; -pub use text_chat::*; -pub mod text_context_doc; -pub use text_context_doc::*; -pub mod text_generation; -pub use text_generation::*; - -const DEFAULT_PORT: u16 = 8080; +pub const DEFAULT_PORT: u16 = 8080; +pub const MODEL_HEADER_NAME: &str = "x-model-name"; pub const DETECTOR_ID_HEADER_NAME: &str = "detector-id"; -const MODEL_HEADER_NAME: &str = "x-model-name"; +pub const CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; +pub const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat"; +pub const CONTEXT_DOC_DETECTOR_ENDPOINT: &str = "/api/v1/text/context/doc"; +pub const GENERATION_DETECTOR_ENDPOINT: &str = "/api/v1/text/generation"; -#[derive(Debug, Clone, Deserialize)] -pub struct DetectorError { - pub code: u16, - pub 
message: String, +#[derive(Clone)] +pub struct DetectorClient { + client: HttpClient, + health_client: Option, } -impl From for Error { - fn from(error: DetectorError) -> Self { - Error::Http { - code: StatusCode::from_u16(error.code).unwrap(), - message: error.message, - } +impl DetectorClient { + pub async fn new( + config: &ServiceConfig, + health_config: Option<&ServiceConfig>, + ) -> Result { + let client = create_http_client(DEFAULT_PORT, config).await?; + let health_client = if let Some(health_config) = health_config { + Some(create_http_client(DEFAULT_PORT, health_config).await?) + } else { + None + }; + Ok(Self { + client, + health_client, + }) } -} - -/// This trait should be implemented by all detectors. -/// If the detector has an HTTP client (currently all detector clients are HTTP) this trait will -/// implicitly extend the client with an HTTP detector specific post function. -pub trait DetectorClient {} - -/// Provides a helper extension for HTTP detector clients. -pub trait DetectorClientExt: HttpClientExt { - /// Wraps the post function with extra detector functionality - /// (detector id header injection & error handling) - async fn post_to_detector( - &self, - model_id: &str, - url: Url, - headers: HeaderMap, - request: impl RequestBody, - ) -> Result; - - /// Wraps call to inner HTTP client endpoint function. - fn endpoint(&self, path: &str) -> Url; -} -impl DetectorClientExt for C { - async fn post_to_detector( + async fn post( &self, model_id: &str, url: Url, @@ -90,7 +82,7 @@ impl DetectorClientExt for C { // Header used by a router component, if available headers.append(MODEL_HEADER_NAME, model_id.parse().unwrap()); - let response = self.inner().post(url, headers, request).await?; + let response = self.client.post(url, headers, request).await?; let status = response.status(); match status { @@ -106,7 +98,222 @@ impl DetectorClientExt for C { } } - fn endpoint(&self, path: &str) -> Url { - self.inner().endpoint(path) + pub async fn text_contents( + &self, + model_id: &str, + request: ContentAnalysisRequest, + headers: HeaderMap, + ) -> Result>, Error> { + let url = self.client.endpoint(CONTENTS_DETECTOR_ENDPOINT); + info!("sending text content detector request to {}", url); + self.post(model_id, url, headers, request).await + } + + pub async fn text_chat( + &self, + model_id: &str, + request: ChatDetectionRequest, + headers: HeaderMap, + ) -> Result, Error> { + let url = self.client.endpoint(CHAT_DETECTOR_ENDPOINT); + info!("sending text chat detector request to {}", url); + self.post(model_id, url, headers, request).await + } + + pub async fn text_context_doc( + &self, + model_id: &str, + request: ContextDocsDetectionRequest, + headers: HeaderMap, + ) -> Result, Error> { + let url = self.client.endpoint(CONTEXT_DOC_DETECTOR_ENDPOINT); + info!("sending text context doc detector request to {}", url); + self.post(model_id, url, headers, request).await + } + + pub async fn text_generation( + &self, + model_id: &str, + request: GenerationDetectionRequest, + headers: HeaderMap, + ) -> Result, Error> { + let url = self.client.endpoint(GENERATION_DETECTOR_ENDPOINT); + info!("sending text generation detector request to {}", url); + self.post(model_id, url, headers, request).await + } +} + +#[async_trait] +impl Client for DetectorClient { + fn name(&self) -> &str { + "detector" + } + + async fn health(&self) -> HealthCheckResult { + if let Some(health_client) = &self.health_client { + health_client.health().await + } else { + self.client.health().await + } + } +} + 
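// Illustrative usage sketch (not part of this diff), assuming a `ServiceConfig` value
// `cfg`, a detector id "hap-en", and a pre-built `detector_params` (all hypothetical):
//
//     let detector = DetectorClient::new(&cfg, None).await?;
//     let request = ContentAnalysisRequest::new(vec!["text to check".into()], detector_params);
//     let detections = detector.text_contents("hap-en", request, HeaderMap::new()).await?;
//
// A single DetectorClient instance now serves the /contents, /chat, /context/doc and
// /generation detector endpoints, replacing the four per-endpoint client types removed
// further down in this diff.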
+#[derive(Debug, Clone, Deserialize)] +pub struct DetectorError { + pub code: u16, + pub message: String, +} + +impl From for Error { + fn from(error: DetectorError) -> Self { + Error::Http { + code: StatusCode::from_u16(error.code).unwrap(), + message: error.message, + } + } +} + +/// Request for text content analysis +/// Results of this request will contain analysis / detection of each of the provided documents +/// in the order they are present in the `contents` object. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ContentAnalysisRequest { + /// Field allowing users to provide list of documents for analysis + pub contents: Vec, + /// Detector parameters (available parameters depend on the detector) + pub detector_params: DetectorParams, +} + +impl ContentAnalysisRequest { + pub fn new(contents: Vec, detector_params: DetectorParams) -> ContentAnalysisRequest { + ContentAnalysisRequest { + contents, + detector_params, + } + } +} + +/// Response of text content analysis endpoint +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct ContentAnalysisResponse { + /// Start index of detection + pub start: usize, + /// End index of detection + pub end: usize, + /// Text corresponding to detection + pub text: String, + /// Relevant detection class + pub detection: String, + /// Detection type or aggregate detection label + pub detection_type: String, + /// Optional, ID of Detector + pub detector_id: Option, + /// Score of detection + pub score: f64, + /// Optional, any applicable evidence for detection + #[serde(skip_serializing_if = "Option::is_none")] + pub evidence: Option>, + // Optional metadata block + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub metadata: Metadata, +} + +impl From for crate::models::TokenClassificationResult { + fn from(value: ContentAnalysisResponse) -> Self { + Self { + start: value.start as u32, + end: value.end as u32, + word: value.text, + entity: value.detection, + entity_group: value.detection_type, + detector_id: value.detector_id, + score: value.score, + token_count: None, + } + } +} + +/// A struct representing a request to a detector compatible with the +/// /api/v1/text/chat endpoint. +#[derive(Debug, Clone, Serialize)] +pub struct ChatDetectionRequest { + /// Chat messages to run detection on + pub messages: Vec, + /// Optional list of tool definitions + pub tools: Vec, + /// Detector parameters (available parameters depend on the detector) + pub detector_params: DetectorParams, +} + +impl ChatDetectionRequest { + pub fn new(messages: Vec, tools: Vec, detector_params: DetectorParams) -> Self { + Self { + messages, + tools, + detector_params, + } + } +} + +/// A struct representing a request to a detector compatible with the +/// /api/v1/text/context/doc endpoint. 
+#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Clone, Serialize)] +pub struct ContextDocsDetectionRequest { + /// Content to run detection on + pub content: String, + /// Type of context being sent + pub context_type: ContextType, + /// Context to run detection on + pub context: Vec, + /// Detector parameters (available parameters depend on the detector) + pub detector_params: DetectorParams, +} + +/// Enum representing the context type of a detection +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum ContextType { + #[serde(rename = "docs")] + Document, + #[serde(rename = "url")] + Url, +} + +impl ContextDocsDetectionRequest { + pub fn new( + content: String, + context_type: ContextType, + context: Vec, + detector_params: DetectorParams, + ) -> Self { + Self { + content, + context_type, + context, + detector_params, + } + } +} + +/// A struct representing a request to a detector compatible with the +/// /api/v1/text/generation endpoint. +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Clone, Serialize)] +pub struct GenerationDetectionRequest { + /// User prompt sent to LLM + pub prompt: String, + /// Text generated from an LLM + pub generated_text: String, + /// Detector parameters (available parameters depend on the detector) + pub detector_params: DetectorParams, +} + +impl GenerationDetectionRequest { + pub fn new(prompt: String, generated_text: String, detector_params: DetectorParams) -> Self { + Self { + prompt, + generated_text, + detector_params, + } } } diff --git a/src/clients/detector/text_chat.rs b/src/clients/detector/text_chat.rs deleted file mode 100644 index 4eea8bfc..00000000 --- a/src/clients/detector/text_chat.rs +++ /dev/null @@ -1,122 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -use async_trait::async_trait; -use hyper::HeaderMap; -use serde::Serialize; -use tracing::info; - -use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; -use crate::{ - clients::{ - Client, Error, HttpClient, create_http_client, - http::HttpClientExt, - openai::{Message, Tool}, - }, - config::ServiceConfig, - health::HealthCheckResult, - models::{DetectionResult, DetectorParams}, -}; - -const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat"; - -#[derive(Clone)] -pub struct TextChatDetectorClient { - client: HttpClient, - health_client: Option, -} - -impl TextChatDetectorClient { - pub async fn new( - config: &ServiceConfig, - health_config: Option<&ServiceConfig>, - ) -> Result { - let client = create_http_client(DEFAULT_PORT, config).await?; - let health_client = if let Some(health_config) = health_config { - Some(create_http_client(DEFAULT_PORT, health_config).await?) 
- } else { - None - }; - Ok(Self { - client, - health_client, - }) - } - - fn client(&self) -> &HttpClient { - &self.client - } - - pub async fn text_chat( - &self, - model_id: &str, - request: ChatDetectionRequest, - headers: HeaderMap, - ) -> Result, Error> { - let url = self.endpoint(CHAT_DETECTOR_ENDPOINT); - info!("sending text chat detector request to {}", url); - self.post_to_detector(model_id, url, headers, request).await - } -} - -#[async_trait] -impl Client for TextChatDetectorClient { - fn name(&self) -> &str { - "text_chat_detector" - } - - async fn health(&self) -> HealthCheckResult { - if let Some(health_client) = &self.health_client { - health_client.health().await - } else { - self.client.health().await - } - } -} - -impl DetectorClient for TextChatDetectorClient {} - -impl HttpClientExt for TextChatDetectorClient { - fn inner(&self) -> &HttpClient { - self.client() - } -} - -/// A struct representing a request to a detector compatible with the -/// /api/v1/text/chat endpoint. -// #[cfg_attr(test, derive(PartialEq))] -#[derive(Debug, Clone, Serialize)] -pub struct ChatDetectionRequest { - /// Chat messages to run detection on - pub messages: Vec, - - /// Optional list of tool definitions - pub tools: Vec, - - /// Detector parameters (available parameters depend on the detector) - pub detector_params: DetectorParams, -} - -impl ChatDetectionRequest { - pub fn new(messages: Vec, tools: Vec, detector_params: DetectorParams) -> Self { - Self { - messages, - tools, - detector_params, - } - } -} diff --git a/src/clients/detector/text_contents.rs b/src/clients/detector/text_contents.rs deleted file mode 100644 index 91ec9f5d..00000000 --- a/src/clients/detector/text_contents.rs +++ /dev/null @@ -1,156 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -use std::collections::BTreeMap; - -use async_trait::async_trait; -use hyper::HeaderMap; -use serde::{Deserialize, Serialize}; -use tracing::info; - -use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; -use crate::{ - clients::{Client, Error, HttpClient, create_http_client, http::HttpClientExt}, - config::ServiceConfig, - health::HealthCheckResult, - models::{DetectorParams, EvidenceObj, Metadata}, -}; - -const CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents"; - -#[derive(Clone)] -pub struct TextContentsDetectorClient { - client: HttpClient, - health_client: Option, -} - -impl TextContentsDetectorClient { - pub async fn new( - config: &ServiceConfig, - health_config: Option<&ServiceConfig>, - ) -> Result { - let client = create_http_client(DEFAULT_PORT, config).await?; - let health_client = if let Some(health_config) = health_config { - Some(create_http_client(DEFAULT_PORT, health_config).await?) 
- } else { - None - }; - Ok(Self { - client, - health_client, - }) - } - - fn client(&self) -> &HttpClient { - &self.client - } - - pub async fn text_contents( - &self, - model_id: &str, - request: ContentAnalysisRequest, - headers: HeaderMap, - ) -> Result>, Error> { - let url = self.endpoint(CONTENTS_DETECTOR_ENDPOINT); - info!("sending text content detector request to {}", url); - self.post_to_detector(model_id, url, headers, request).await - } -} - -#[async_trait] -impl Client for TextContentsDetectorClient { - fn name(&self) -> &str { - "text_contents_detector" - } - - async fn health(&self) -> HealthCheckResult { - if let Some(health_client) = &self.health_client { - health_client.health().await - } else { - self.client.health().await - } - } -} - -impl DetectorClient for TextContentsDetectorClient {} - -impl HttpClientExt for TextContentsDetectorClient { - fn inner(&self) -> &HttpClient { - self.client() - } -} - -/// Request for text content analysis -/// Results of this request will contain analysis / detection of each of the provided documents -/// in the order they are present in the `contents` object. -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct ContentAnalysisRequest { - /// Field allowing users to provide list of documents for analysis - pub contents: Vec, - - /// Detector parameters (available parameters depend on the detector) - pub detector_params: DetectorParams, -} - -impl ContentAnalysisRequest { - pub fn new(contents: Vec, detector_params: DetectorParams) -> ContentAnalysisRequest { - ContentAnalysisRequest { - contents, - detector_params, - } - } -} - -/// Response of text content analysis endpoint -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct ContentAnalysisResponse { - /// Start index of detection - pub start: usize, - /// End index of detection - pub end: usize, - /// Text corresponding to detection - pub text: String, - /// Relevant detection class - pub detection: String, - /// Detection type or aggregate detection label - pub detection_type: String, - /// Optional, ID of Detector - pub detector_id: Option, - /// Score of detection - pub score: f64, - /// Optional, any applicable evidence for detection - #[serde(skip_serializing_if = "Option::is_none")] - pub evidence: Option>, - // Optional metadata block - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - pub metadata: Metadata, -} - -impl From for crate::models::TokenClassificationResult { - fn from(value: ContentAnalysisResponse) -> Self { - Self { - start: value.start as u32, - end: value.end as u32, - word: value.text, - entity: value.detection, - entity_group: value.detection_type, - detector_id: value.detector_id, - score: value.score, - token_count: None, - } - } -} diff --git a/src/clients/detector/text_context_doc.rs b/src/clients/detector/text_context_doc.rs deleted file mode 100644 index 086520ae..00000000 --- a/src/clients/detector/text_context_doc.rs +++ /dev/null @@ -1,136 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. - -*/ - -use async_trait::async_trait; -use hyper::HeaderMap; -use serde::{Deserialize, Serialize}; -use tracing::info; - -use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; -use crate::{ - clients::{Client, Error, HttpClient, create_http_client, http::HttpClientExt}, - config::ServiceConfig, - health::HealthCheckResult, - models::{DetectionResult, DetectorParams}, -}; - -const CONTEXT_DOC_DETECTOR_ENDPOINT: &str = "/api/v1/text/context/doc"; - -#[derive(Clone)] -pub struct TextContextDocDetectorClient { - client: HttpClient, - health_client: Option, -} - -impl TextContextDocDetectorClient { - pub async fn new( - config: &ServiceConfig, - health_config: Option<&ServiceConfig>, - ) -> Result { - let client = create_http_client(DEFAULT_PORT, config).await?; - let health_client = if let Some(health_config) = health_config { - Some(create_http_client(DEFAULT_PORT, health_config).await?) - } else { - None - }; - Ok(Self { - client, - health_client, - }) - } - - fn client(&self) -> &HttpClient { - &self.client - } - - pub async fn text_context_doc( - &self, - model_id: &str, - request: ContextDocsDetectionRequest, - headers: HeaderMap, - ) -> Result, Error> { - let url = self.endpoint(CONTEXT_DOC_DETECTOR_ENDPOINT); - info!("sending text context doc detector request to {}", url); - self.post_to_detector(model_id, url, headers, request).await - } -} - -#[async_trait] -impl Client for TextContextDocDetectorClient { - fn name(&self) -> &str { - "text_context_doc_detector" - } - - async fn health(&self) -> HealthCheckResult { - if let Some(health_client) = &self.health_client { - health_client.health().await - } else { - self.client.health().await - } - } -} - -impl DetectorClient for TextContextDocDetectorClient {} - -impl HttpClientExt for TextContextDocDetectorClient { - fn inner(&self) -> &HttpClient { - self.client() - } -} - -/// A struct representing a request to a detector compatible with the -/// /api/v1/text/context/doc endpoint. -#[cfg_attr(test, derive(PartialEq))] -#[derive(Debug, Clone, Serialize)] -pub struct ContextDocsDetectionRequest { - /// Content to run detection on - pub content: String, - - /// Type of context being sent - pub context_type: ContextType, - - /// Context to run detection on - pub context: Vec, - - /// Detector parameters (available parameters depend on the detector) - pub detector_params: DetectorParams, -} - -/// Enum representing the context type of a detection -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum ContextType { - #[serde(rename = "docs")] - Document, - #[serde(rename = "url")] - Url, -} - -impl ContextDocsDetectionRequest { - pub fn new( - content: String, - context_type: ContextType, - context: Vec, - detector_params: DetectorParams, - ) -> Self { - Self { - content, - context_type, - context, - detector_params, - } - } -} diff --git a/src/clients/detector/text_generation.rs b/src/clients/detector/text_generation.rs deleted file mode 100644 index 8de1c020..00000000 --- a/src/clients/detector/text_generation.rs +++ /dev/null @@ -1,118 +0,0 @@ -/* - Copyright FMS Guardrails Orchestrator Authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - -*/ - -use async_trait::async_trait; -use hyper::HeaderMap; -use serde::Serialize; -use tracing::info; - -use super::{DEFAULT_PORT, DetectorClient, DetectorClientExt}; -use crate::{ - clients::{Client, Error, HttpClient, create_http_client, http::HttpClientExt}, - config::ServiceConfig, - health::HealthCheckResult, - models::{DetectionResult, DetectorParams}, -}; - -const GENERATION_DETECTOR_ENDPOINT: &str = "/api/v1/text/generation"; - -#[derive(Clone)] -pub struct TextGenerationDetectorClient { - client: HttpClient, - health_client: Option, -} - -impl TextGenerationDetectorClient { - pub async fn new( - config: &ServiceConfig, - health_config: Option<&ServiceConfig>, - ) -> Result { - let client = create_http_client(DEFAULT_PORT, config).await?; - let health_client = if let Some(health_config) = health_config { - Some(create_http_client(DEFAULT_PORT, health_config).await?) - } else { - None - }; - Ok(Self { - client, - health_client, - }) - } - - fn client(&self) -> &HttpClient { - &self.client - } - - pub async fn text_generation( - &self, - model_id: &str, - request: GenerationDetectionRequest, - headers: HeaderMap, - ) -> Result, Error> { - let url = self.endpoint(GENERATION_DETECTOR_ENDPOINT); - info!("sending text generation detector request to {}", url); - self.post_to_detector(model_id, url, headers, request).await - } -} - -#[async_trait] -impl Client for TextGenerationDetectorClient { - fn name(&self) -> &str { - "text_context_doc_detector" - } - - async fn health(&self) -> HealthCheckResult { - if let Some(health_client) = &self.health_client { - health_client.health().await - } else { - self.client.health().await - } - } -} - -impl DetectorClient for TextGenerationDetectorClient {} - -impl HttpClientExt for TextGenerationDetectorClient { - fn inner(&self) -> &HttpClient { - self.client() - } -} - -/// A struct representing a request to a detector compatible with the -/// /api/v1/text/generation endpoint. 
-#[cfg_attr(test, derive(PartialEq))] -#[derive(Debug, Clone, Serialize)] -pub struct GenerationDetectionRequest { - /// User prompt sent to LLM - pub prompt: String, - - /// Text generated from an LLM - pub generated_text: String, - - /// Detector parameters (available parameters depend on the detector) - pub detector_params: DetectorParams, -} - -impl GenerationDetectionRequest { - pub fn new(prompt: String, generated_text: String, detector_params: DetectorParams) -> Self { - Self { - prompt, - generated_text, - detector_params, - } - } -} diff --git a/src/clients/generation.rs b/src/clients/generation.rs index 15590623..86e13d2a 100644 --- a/src/clients/generation.rs +++ b/src/clients/generation.rs @@ -17,7 +17,8 @@ use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; -use hyper::HeaderMap; +use hyper::{HeaderMap, StatusCode}; +use tracing::warn; use super::{BoxStream, Client, Error, NlpClient, TgisClient}; use crate::{ @@ -38,8 +39,57 @@ use crate::{ }, }; +async fn retry_function(max_retries: usize, func: F) -> Result +where + F: Fn() -> Fut, + Fut: std::future::Future>, +{ + let mut attempt = 0; + + let allowed_retry_codes = [ + StatusCode::BAD_GATEWAY, + StatusCode::SERVICE_UNAVAILABLE, + StatusCode::GATEWAY_TIMEOUT, + StatusCode::HTTP_VERSION_NOT_SUPPORTED, + StatusCode::VARIANT_ALSO_NEGOTIATES, + ]; + loop { + attempt += 1; + + match func().await { + Ok(res) => return Ok(res), + + Err(error) => { + if allowed_retry_codes.contains(&error.status_code()) { + // Only retry when status code is within list + if attempt > max_retries { + warn!( + "Final attempt failed to connect to server. attempt: {}, error: {}", + attempt, &error + ); + return Err(error); + } + // Exponential backoff for retries. + tokio::time::sleep(std::time::Duration::from_millis( + 2_u64.pow(attempt as u32 - 1), + )) + .await; + warn!( + "failed to connect to server. 
attempt: {}, error: {}", + attempt, &error + ); + continue; + } + + // Else return error + return Err(error); + } + } + } +} + #[derive(Clone)] -pub struct GenerationClient(Option); +pub struct GenerationClient(Option, usize); #[derive(Clone)] enum GenerationClientInner { @@ -48,16 +98,16 @@ enum GenerationClientInner { } impl GenerationClient { - pub fn tgis(client: TgisClient) -> Self { - Self(Some(GenerationClientInner::Tgis(client))) + pub fn tgis(client: TgisClient, max_retries: usize) -> Self { + Self(Some(GenerationClientInner::Tgis(client)), max_retries) } - pub fn nlp(client: NlpClient) -> Self { - Self(Some(GenerationClientInner::Nlp(client))) + pub fn nlp(client: NlpClient, max_retries: usize) -> Self { + Self(Some(GenerationClientInner::Nlp(client)), max_retries) } pub fn not_configured() -> Self { - Self(None) + Self(None, 0) } pub async fn tokenize( @@ -81,9 +131,10 @@ impl GenerationClient { } Some(GenerationClientInner::Nlp(client)) => { let request = TokenizationTaskRequest { text }; - let response = client - .tokenization_task_predict(&model_id, request, headers) - .await?; + let response = retry_function(self.1, || { + client.tokenization_task_predict(&model_id, request.clone(), headers.clone()) + }) + .await?; let tokens = response .results .into_iter() @@ -146,9 +197,10 @@ impl GenerationClient { ..Default::default() } }; - let response = client - .text_generation_task_predict(&model_id, request, headers) - .await?; + let response = retry_function(self.1, || { + client.text_generation_task_predict(&model_id, request.clone(), headers.clone()) + }) + .await?; Ok(response.into()) } None => Err(Error::ModelNotFound { model_id }), @@ -210,11 +262,18 @@ impl GenerationClient { ..Default::default() } }; - let response_stream = client - .server_streaming_text_generation_task_predict(&model_id, request, headers) - .await? - .map_ok(Into::into) - .boxed(); + + let response_stream = retry_function(self.1, || { + client.server_streaming_text_generation_task_predict( + &model_id, + request.clone(), + headers.clone(), + ) + }) + .await? 
+ .map_ok(Into::into) + .boxed(); + Ok(response_stream) } None => Err(Error::ModelNotFound { model_id }), diff --git a/src/clients/http.rs b/src/clients/http.rs index 485d8829..945b19df 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -74,7 +74,7 @@ impl Response { .to_bytes(); serde_json::from_slice::(&data).map_err(|e| Error::Http { code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("client response deserialization failed: {}", e), + message: format!("client response deserialization failed: {e}"), }) } } @@ -177,7 +177,7 @@ impl HttpClient { Full::new(Bytes::from(serde_json::to_vec(&body).map_err(|e| { Error::Http { code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("client request serialization failed: {}", e) + message: format!("client request serialization failed: {e}") } })?)) .map_err(|err| match err {}); @@ -186,7 +186,7 @@ impl HttpClient { .map_err(|e| { Error::Http { code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("client request serialization failed: {}", e) + message: format!("client request serialization failed: {e}") } })?; let response = match self @@ -197,12 +197,12 @@ impl HttpClient { Ok(response) => Ok(response.map_err(|e| { Error::Http { code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("sending client request failed: {}", e) + message: format!("sending client request failed: {e}") } }).into_inner()), Err(e) => Err(Error::Http { code: StatusCode::REQUEST_TIMEOUT, - message: format!("client request timeout: {}", e), + message: format!("client request timeout: {e}"), }), }?; let span = Span::current(); @@ -213,7 +213,7 @@ impl HttpClient { || panic!("unexpected request builder error - headers missing in builder but no errors found"), |e| Error::Http { code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("client request creation failed: {}", e), + message: format!("client request creation failed: {e}"), } )), } diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 3c679ecb..86cb53d1 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -43,6 +43,7 @@ const DEFAULT_PORT: u16 = 8080; const CHAT_COMPLETIONS_ENDPOINT: &str = "/v1/chat/completions"; const COMPLETIONS_ENDPOINT: &str = "/v1/completions"; +const TOKENIZE_ENDPOINT: &str = "/tokenize"; // This endpoint is vLLM-specific #[derive(Clone)] pub struct OpenAiClient { @@ -76,6 +77,7 @@ impl OpenAiClient { request: ChatCompletionsRequest, headers: HeaderMap, ) -> Result { + tracing::debug!(?headers, "chat_completions headers"); let url = self.client.endpoint(CHAT_COMPLETIONS_ENDPOINT); if let Some(true) = request.stream { let rx = self.handle_streaming(url, request, headers).await?; @@ -101,6 +103,16 @@ impl OpenAiClient { } } + pub async fn tokenize( + &self, + request: TokenizeRequest, + headers: HeaderMap, + ) -> Result { + let url = self.client.endpoint(TOKENIZE_ENDPOINT); + let response = self.handle_unary(url, request, headers).await?; + Ok(response) + } + async fn handle_unary(&self, url: Url, request: R, headers: HeaderMap) -> Result where R: RequestBody, @@ -110,6 +122,7 @@ impl OpenAiClient { match response.status() { StatusCode::OK => response.json::().await, _ => { + // Return error with code and message from downstream server let code = response.status(); let message = if let Ok(response) = response.json::().await { response.message @@ -132,46 +145,73 @@ impl OpenAiClient { S: DeserializeOwned + Send + 'static, { let (tx, rx) = mpsc::channel(32); - let mut event_stream = self - .client - .post(url, headers, request) - .await? 
- .0 - .into_data_stream() - .eventsource(); - // Spawn task to forward events to receiver - tokio::spawn(async move { - while let Some(result) = event_stream.next().await { - match result { - Ok(event) if event.data == "[DONE]" => { - // Send None to signal that the stream completed - let _ = tx.send(Ok(None)).await; - break; - } - Ok(event) => match serde_json::from_str::(&event.data) { - Ok(chunk) => { - let _ = tx.send(Ok(Some(chunk))).await; - } - Err(e) => { - let error = Error::Http { - code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("deserialization error: {e}"), - }; - let _ = tx.send(Err(error.into())).await; + let response = self.client.post(url, headers, request).await?; + match response.status() { + StatusCode::OK => { + // Create event stream + let mut event_stream = response.0.into_data_stream().eventsource(); + // Spawn task to consume event stream and send messages to receiver + tokio::spawn(async move { + while let Some(result) = event_stream.next().await { + match result { + Ok(event) if event.data == "[DONE]" => { + // DONE message: send None to signal completion + let _ = tx.send(Ok(None)).await; + break; + } + // Attempt to deserialize to S + Ok(event) => match serde_json::from_str::(&event.data) { + Ok(message) => { + let _ = tx.send(Ok(Some(message))).await; + } + Err(_serde_error) => { + // Failed to deserialize to S, attempt to deserialize to OpenAiErrorMessage + let error = match serde_json::from_str::( + &event.data, + ) { + // Return error with code and message from downstream server + Ok(openai_error) => Error::Http { + code: StatusCode::from_u16(openai_error.error.code) + .unwrap(), + message: openai_error.error.message, + }, + // Failed to deserialize to S and OpenAiErrorMessage + // Return internal server error + Err(serde_error) => Error::Http { + code: StatusCode::INTERNAL_SERVER_ERROR, + message: format!( + "deserialization error: {serde_error}" + ), + }, + }; + let _ = tx.send(Err(error.into())).await; + } + }, + Err(error) => { + // Event stream error + // Return internal server error + let error = Error::Http { + code: StatusCode::INTERNAL_SERVER_ERROR, + message: error.to_string(), + }; + let _ = tx.send(Err(error.into())).await; + } } - }, - Err(error) => { - // We received an error from the event stream, send error message - let error = Error::Http { - code: StatusCode::INTERNAL_SERVER_ERROR, - message: error.to_string(), - }; - let _ = tx.send(Err(error.into())).await; } - } + }); + Ok(rx) } - }); - Ok(rx) + _ => { + // Return error with code and message from downstream server + let code = response.status(); + let message = if let Ok(response) = response.json::().await { + response.message + } else { + "unknown error occurred".into() + }; + Err(Error::Http { code, message }) + } + } } } @@ -222,6 +262,14 @@ impl From for CompletionsResponse { } } +/// Tokenize response. +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct TokenizeResponse { + pub count: u32, + pub max_model_len: u32, + pub tokens: Vec, +} + /// Chat completions request. /// /// As orchestrator is only concerned with a limited subset @@ -295,6 +343,13 @@ impl ChatCompletionsRequest { /// the downstream server implementation. #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] pub struct CompletionsRequest { + /// Detector config. + #[serde(default, skip_serializing)] + pub detectors: DetectorConfig, + /// Prompt masks. 
+ #[serde(rename = "_prompt_masks", skip_serializing)] + #[doc(hidden)] + pub prompt_masks: Option>, /// Stream parameter. #[serde(skip_serializing_if = "Option::is_none")] pub stream: Option, @@ -312,9 +367,46 @@ impl CompletionsRequest { if self.model.is_empty() { return Err(ValidationError::Invalid("`model` must not be empty".into())); } - if self.prompt.is_empty() { + if !self.detectors.input.is_empty() && self.prompt.is_empty() { + return Err(ValidationError::Invalid( + "`prompt` must not be empty when input detectors are provided".into(), + )); + } + Ok(()) + } +} + +/// Tokenize request. +/// +/// Required when there are input detections. +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] +pub struct TokenizeRequest { + /// Model name. + pub model: String, + /// Prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + /// Messages. + #[serde(skip_serializing_if = "Option::is_none")] + pub messages: Option>, + /// Extra fields not captured above. + #[serde(flatten)] + pub extra: Map, +} + +impl TokenizeRequest { + pub fn validate(&self) -> Result<(), ValidationError> { + if self.model.is_empty() { + return Err(ValidationError::Invalid("`model` must not be empty".into())); + } + if self.prompt.is_some() && self.messages.is_some() { + return Err(ValidationError::Invalid( + "`prompt` and `messages` cannot be used at the same time".into(), + )); + } + if self.prompt.is_none() && self.messages.is_none() { return Err(ValidationError::Invalid( - "`prompt` must not be empty".into(), + "Either `prompt` or `messages` must be supplied".into(), )); } Ok(()) @@ -385,7 +477,7 @@ pub struct ToolChoiceObject { /// The type of the tool. #[serde(rename = "type")] pub r#type: String, - pub function: Function, + pub function: FunctionCall, } /// Stream options. @@ -545,20 +637,26 @@ pub struct ImageUrl { /// Tool call. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ToolCall { + /// Index + #[serde(skip_serializing_if = "Option::is_none")] + pub index: Option, /// The ID of the tool call. - pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, /// The type of the tool. - #[serde(rename = "type")] - pub r#type: String, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub r#type: Option, /// The function that the model called. - pub function: Function, + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, } -/// Function. +/// Function call. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct Function { +pub struct FunctionCall { /// The name of the function to call. - pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, /// The arguments to call the function with, as generated by the model in JSON format. #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option, @@ -590,10 +688,18 @@ pub struct ChatCompletion { pub service_tier: Option, /// Detections #[serde(skip_serializing_if = "Option::is_none")] - pub detections: Option, + pub detections: Option, /// Warnings #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub warnings: Vec, + pub warnings: Vec, +} + +/// Helper to accept both string and integer for stop_reason. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum StopReason { + String(String), + Integer(i64), } /// Chat completion choice. 
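// Illustrative sketch (not part of this diff) of how the untagged enum behaves: a JSON
// string deserializes to `StopReason::String` and a JSON number to `StopReason::Integer`,
// which is exactly what the unit tests added at the bottom of this file exercise.
//
//     let s: StopReason = serde_json::from_value(serde_json::json!("32007"))?; // String("32007")
//     let n: StopReason = serde_json::from_value(serde_json::json!(32007))?;   // Integer(32007)
//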
@@ -608,7 +714,7 @@ pub struct ChatCompletionChoice { /// The reason the model stopped generating tokens. pub finish_reason: String, /// The stop string or token id that caused the completion. - pub stop_reason: Option, + pub stop_reason: Option, } /// Chat completion message. @@ -627,13 +733,14 @@ pub struct ChatCompletionMessage { } /// Chat completion logprobs. -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[derive(Debug, Default, Clone, Deserialize, Serialize, PartialEq)] pub struct ChatCompletionLogprobs { /// A list of message content tokens with log probability information. - pub content: Option>, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub content: Vec, /// A list of message refusal tokens with log probability information. - #[serde(skip_serializing_if = "Option::is_none")] - pub refusal: Option>, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub refusal: Vec, } /// Chat completion logprob. @@ -643,10 +750,9 @@ pub struct ChatCompletionLogprob { pub token: String, /// The log probability of this token. pub logprob: f32, - #[serde(skip_serializing_if = "Option::is_none")] + /// A list of integers representing the UTF-8 bytes representation of the token. pub bytes: Option>, /// List of the most likely tokens and their log probability, at this token position. - #[serde(skip_serializing_if = "Option::is_none")] pub top_logprobs: Option>, } @@ -657,10 +763,12 @@ pub struct ChatCompletionTopLogprob { pub token: String, /// The log probability of this token. pub logprob: f32, + /// A list of integers representing the UTF-8 bytes representation of the token. + pub bytes: Option>, } /// Streaming chat completion chunk. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ChatCompletionChunk { /// A unique identifier for the chat completion. Each chunk has the same ID. pub id: String, @@ -683,14 +791,31 @@ pub struct ChatCompletionChunk { pub usage: Option, /// Detections #[serde(skip_serializing_if = "Option::is_none")] - pub detections: Option, + pub detections: Option, /// Warnings #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub warnings: Vec, + pub warnings: Vec, +} + +impl Default for ChatCompletionChunk { + fn default() -> Self { + Self { + id: Default::default(), + object: "chat.completion.chunk".into(), + created: Default::default(), + model: Default::default(), + system_fingerprint: Default::default(), + choices: Default::default(), + service_tier: Default::default(), + usage: Default::default(), + detections: Default::default(), + warnings: Default::default(), + } + } } /// Streaming chat completion chunk choice. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] pub struct ChatCompletionChunkChoice { /// The index of the choice in the list of choices. pub index: u32, @@ -701,11 +826,11 @@ pub struct ChatCompletionChunkChoice { /// The reason the model stopped generating tokens. pub finish_reason: Option, /// The stop string or token id that caused the completion. - pub stop_reason: Option, + pub stop_reason: Option, } /// Streaming chat completion delta. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] pub struct ChatCompletionDelta { /// The role of the author of this message. 
#[serde(skip_serializing_if = "Option::is_none")] @@ -739,6 +864,12 @@ pub struct Completion { /// This fingerprint represents the backend configuration that the model runs with. #[serde(skip_serializing_if = "Option::is_none")] pub system_fingerprint: Option, + /// Detections + #[serde(skip_serializing_if = "Option::is_none")] + pub detections: Option, + /// Warnings + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub warnings: Vec, } /// Completion (legacy) choice. @@ -753,14 +884,14 @@ pub struct CompletionChoice { /// The reason the model stopped generating tokens. pub finish_reason: Option, /// The stop string or token id that caused the completion. - pub stop_reason: Option, + pub stop_reason: Option, /// Prompt logprobs. #[serde(skip_serializing_if = "Option::is_none")] pub prompt_logprobs: Option>>>, } /// Completion logprobs. -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)] pub struct CompletionLogprobs { /// Tokens generated by the model. pub tokens: Vec, @@ -833,39 +964,45 @@ pub struct OpenAiError { pub code: u16, } -/// Guardrails chat detections. +/// OpenAI streaming error message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenAiErrorMessage { + pub error: OpenAiError, +} + +/// Guardrails completion detections. #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] -pub struct ChatDetections { +pub struct CompletionDetections { #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub input: Vec, + pub input: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub output: Vec, + pub output: Vec, } -/// Guardrails chat input detections. +/// Guardrails completion input detections. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct InputDetectionResult { +pub struct CompletionInputDetections { pub message_index: u32, - #[serde(default, skip_serializing_if = "Vec::is_empty")] + #[serde(default)] pub results: Vec, } -/// Guardrails chat output detections. +/// Guardrails completion output detections. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct OutputDetectionResult { +pub struct CompletionOutputDetections { pub choice_index: u32, - #[serde(default, skip_serializing_if = "Vec::is_empty")] + #[serde(default)] pub results: Vec, } -/// Guardrails warning. +/// Guardrails completion detection warning. 
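// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the detection and logprob
// fields above use `#[serde(default, skip_serializing_if = "Vec::is_empty")]`
// instead of wrapping vectors in Option. `default` lets the field be absent
// when deserializing and `skip_serializing_if` drops empty vectors from the
// output. Self-contained reproduction with a stand-in struct.
use serde::{Deserialize, Serialize};

#[derive(Debug, Default, PartialEq, Serialize, Deserialize)]
struct Detections {
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    input: Vec<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    output: Vec<String>,
}

fn main() {
    // Empty vectors are omitted entirely when serializing.
    assert_eq!(serde_json::to_string(&Detections::default()).unwrap(), "{}");
    // Missing fields deserialize to empty vectors instead of erroring.
    let d: Detections = serde_json::from_str("{}").unwrap();
    assert_eq!(d, Detections::default());
}
// ---------------------------------------------------------------------------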
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct OrchestratorWarning { +pub struct CompletionDetectionWarning { r#type: DetectionWarningReason, message: String, } -impl OrchestratorWarning { +impl CompletionDetectionWarning { pub fn new(warning_type: DetectionWarningReason, message: &str) -> Self { Self { r#type: warning_type, @@ -983,4 +1120,57 @@ mod test { Ok(()) } + + /// Test deserialization of stop_reason as integer + #[test] + fn test_chat_completion_choice_stop_reason_integer() { + use serde_json::json; + + use crate::clients::openai::{ChatCompletionChoice, StopReason}; + + let json_choice = json!({ + "index": 0, + "message": { + "role": "assistant", + "content": "Hello!", + "tool_calls": [], + "refusal": null + }, + "logprobs": null, + "finish_reason": "EOS_TOKEN", + "stop_reason": 32007 + }); + + let choice: ChatCompletionChoice = serde_json::from_value(json_choice) + .expect("Should deserialize with integer stop_reason"); + assert_eq!(choice.stop_reason, Some(StopReason::Integer(32007))); + } + + /// Test deserialization of stop_reason as string + #[test] + fn test_chat_completion_choice_stop_reason_string() { + use serde_json::json; + + use crate::clients::openai::{ChatCompletionChoice, StopReason}; + + let json_choice = json!({ + "index": 0, + "message": { + "role": "assistant", + "content": "Hello!", + "tool_calls": [], + "refusal": null + }, + "logprobs": null, + "finish_reason": "EOS_TOKEN", + "stop_reason": "32007" + }); + + let choice: ChatCompletionChoice = serde_json::from_value(json_choice) + .expect("Should deserialize with string stop_reason"); + assert_eq!( + choice.stop_reason, + Some(StopReason::String("32007".to_string())) + ); + } } diff --git a/src/config.rs b/src/config.rs index bfd06325..31de3c2e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,10 @@ use std::{ use serde::Deserialize; use tracing::{debug, error, info, warn}; -use crate::clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname}; +use crate::{ + clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname}, + utils::one_or_many, +}; /// Default allowed headers to passthrough to clients. 
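// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): `utils::one_or_many`, imported
// above for the detector `type` field, is not shown in this diff. A common
// way to implement such a deserializer is an untagged helper enum, so the
// config accepts either `type: text_contents` or
// `type: [text_contents, text_chat]`. Sketch under that assumption;
// `DetectorEntry` is a hypothetical stand-in for the real config struct.
use serde::{Deserialize, Deserializer};

#[derive(Deserialize)]
#[serde(untagged)]
enum OneOrMany<T> {
    One(T),
    Many(Vec<T>),
}

pub fn one_or_many<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
where
    D: Deserializer<'de>,
    T: Deserialize<'de>,
{
    Ok(match OneOrMany::<T>::deserialize(deserializer)? {
        OneOrMany::One(value) => vec![value],
        OneOrMany::Many(values) => values,
    })
}

#[derive(Debug, Deserialize)]
struct DetectorEntry {
    #[serde(rename = "type", deserialize_with = "one_or_many")]
    r#type: Vec<String>,
}

fn main() {
    let single: DetectorEntry = serde_json::from_str(r#"{"type": "text_contents"}"#).unwrap();
    let many: DetectorEntry =
        serde_json::from_str(r#"{"type": ["text_contents", "text_chat"]}"#).unwrap();
    assert_eq!(single.r#type, vec!["text_contents"]);
    assert_eq!(many.r#type.len(), 2);
}
// ---------------------------------------------------------------------------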
const DEFAULT_ALLOWED_HEADERS: &[&str] = &[]; @@ -76,6 +79,12 @@ pub struct ServiceConfig { pub tls: Option, /// gRPC probe interval in seconds pub grpc_dns_probe_interval: Option, + /// Resolution strategy + pub resolution_strategy: Option, + /// Resolution strategy timeout in seconds + pub resolution_strategy_timeout: Option, + /// Max retries for client calls [currently only for grpc generation] + pub max_retries: Option, } impl ServiceConfig { @@ -86,6 +95,9 @@ impl ServiceConfig { request_timeout: None, tls: None, grpc_dns_probe_interval: None, + resolution_strategy: None, + resolution_strategy_timeout: None, + max_retries: None, } } } @@ -126,9 +138,9 @@ pub struct GenerationConfig { pub service: ServiceConfig, } -/// Chat generation service configuration +/// OpenAI service configuration #[derive(Default, Clone, Debug, Deserialize)] -pub struct ChatGenerationConfig { +pub struct OpenAiConfig { /// Generation service connection information pub service: ServiceConfig, /// Generation health service connection information @@ -165,8 +177,8 @@ pub struct DetectorConfig { /// Default threshold with which to filter detector results by score pub default_threshold: f64, /// Type of detection this detector performs - #[serde(rename = "type")] - pub r#type: DetectorType, + #[serde(rename = "type", deserialize_with = "one_or_many")] + pub r#type: Vec, } #[derive(Default, Clone, Debug, Deserialize, PartialEq)] @@ -185,8 +197,10 @@ pub enum DetectorType { pub struct OrchestratorConfig { /// Generation service and associated configuration, can be omitted if configuring for generation is not wanted pub generation: Option, - /// Chat generation service and associated configuration, can be omitted if configuring for chat generation is not wanted - pub chat_generation: Option, + /// OpenAI service and associated configuration, can be omitted if configuring for chat generation is not wanted + #[serde(alias = "chat_generation")] + #[serde(alias = "chat_completions")] + pub openai: Option, /// Chunker services and associated configurations, if omitted the default value "whole_doc_chunker" is used pub chunkers: Option>, /// Detector services and associated configurations @@ -215,6 +229,17 @@ impl OrchestratorConfig { error, } })?; + // TODO: Remove if conditions once aliases are deprecated + if config_yaml.contains("chat_generation") { + warn!( + "`chat_generation` is deprecated and will be removed in 1.0. Rename it to `openai`." + ) + } + if config_yaml.contains("chat_completions") { + warn!( + "`chat_completions` is deprecated and will be removed in 1.0. Rename it to `openai`." 
+ ) + } let mut config: OrchestratorConfig = serde_yml::from_str(&config_yaml).map_err(Error::InvalidConfigFile)?; debug!(?config, "loaded orchestrator config"); @@ -260,9 +285,9 @@ impl OrchestratorConfig { if let Some(generation) = &mut self.generation { apply_named_tls_config(&mut generation.service, tls_configs)?; } - // Chat generation - if let Some(chat_generation) = &mut self.chat_generation { - apply_named_tls_config(&mut chat_generation.service, tls_configs)?; + // Open AI + if let Some(openai) = &mut self.openai { + apply_named_tls_config(&mut openai.service, tls_configs)?; } // Chunkers if let Some(chunkers) = &mut self.chunkers { @@ -286,7 +311,7 @@ impl OrchestratorConfig { // Apply validation rules self.validate_generation_config()?; - self.validate_chat_generation_config()?; + self.validate_openai_configs()?; self.validate_detector_configs()?; self.validate_chunker_configs()?; @@ -307,15 +332,16 @@ impl OrchestratorConfig { } /// Validates chat generation config. - fn validate_chat_generation_config(&self) -> Result<(), Error> { - if let Some(chat_generation) = &self.chat_generation { + fn validate_openai_configs(&self) -> Result<(), Error> { + if let Some(openai) = &self.openai { // Hostname is valid - if !is_valid_hostname(&chat_generation.service.hostname) { + if !is_valid_hostname(&openai.service.hostname) { return Err(Error::InvalidHostname( - "`chat_generation` has an invalid hostname".into(), + "`openai` has an invalid hostname".into(), )); } } + Ok(()) } @@ -385,7 +411,7 @@ impl Default for OrchestratorConfig { fn default() -> Self { Self { generation: None, - chat_generation: None, + openai: None, chunkers: None, detectors: HashMap::default(), tls: None, diff --git a/src/main.rs b/src/main.rs index b8d7675b..5d26a835 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,7 +19,10 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use clap::Parser; use fms_guardrails_orchestr8::{ - args::Args, config::OrchestratorConfig, orchestrator::Orchestrator, server, utils, + args::{Args, TracingConfig}, + config::OrchestratorConfig, + orchestrator::Orchestrator, + server, utils, }; use tracing::info; @@ -47,7 +50,8 @@ fn main() -> Result<(), anyhow::Error> { .build() .unwrap() .block_on(async { - let trace_shutdown = utils::trace::init_tracing(args.clone().into())?; + let tracing_config: TracingConfig = args.clone().try_into()?; + let trace_shutdown = utils::trace::init_tracing(tracing_config)?; let config = OrchestratorConfig::load(args.config_path).await?; let orchestrator = Orchestrator::new(config, args.start_up_health_check).await?; @@ -66,6 +70,6 @@ fn main() -> Result<(), anyhow::Error> { let _ = tokio::join!(health_handle, guardrails_handle); info!("shutdown complete"); - Ok(trace_shutdown()?) 
+ trace_shutdown() }) } diff --git a/src/models.rs b/src/models.rs index 21c4c4fd..54852882 100644 --- a/src/models.rs +++ b/src/models.rs @@ -553,6 +553,10 @@ pub enum DetectionWarningReason { /// Unsuitable text detected on output #[serde(rename = "UNSUITABLE_OUTPUT")] UnsuitableOutput, + + /// Unsuitable text detected on output + #[serde(rename = "EMPTY_OUTPUT")] + EmptyOutput, } /// Generated token information diff --git a/src/orchestrator.rs b/src/orchestrator.rs index 0f6ee9ba..ea570c54 100644 --- a/src/orchestrator.rs +++ b/src/orchestrator.rs @@ -27,17 +27,15 @@ use tracing::{debug, info}; use crate::{ clients::{ - ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient, - chunker::ChunkerClient, - detector::{ - TextChatDetectorClient, TextContextDocDetectorClient, TextGenerationDetectorClient, - }, + ChunkerClient, ClientMap, DetectorClient, GenerationClient, NlpClient, TgisClient, openai::OpenAiClient, }, - config::{DetectorType, GenerationProvider, OrchestratorConfig}, + config::{GenerationProvider, OrchestratorConfig}, health::HealthCheckCache, }; +const DEFAULT_MAX_RETRIES: usize = 3; + #[cfg_attr(test, derive(Default))] pub struct Context { config: OrchestratorConfig, @@ -120,28 +118,29 @@ async fn create_clients(config: &OrchestratorConfig) -> Result // Create generation client if let Some(generation) = &config.generation { + let retries = generation + .service + .max_retries + .unwrap_or(DEFAULT_MAX_RETRIES); match generation.provider { GenerationProvider::Tgis => { let tgis_client = TgisClient::new(&generation.service).await; - let generation_client = GenerationClient::tgis(tgis_client); + let generation_client = GenerationClient::tgis(tgis_client, retries); clients.insert("generation".to_string(), generation_client); } GenerationProvider::Nlp => { let nlp_client = NlpClient::new(&generation.service).await; - let generation_client = GenerationClient::nlp(nlp_client); + let generation_client = GenerationClient::nlp(nlp_client, retries); clients.insert("generation".to_string(), generation_client); } } } - // Create chat generation client - if let Some(chat_generation) = &config.chat_generation { - let openai_client = OpenAiClient::new( - &chat_generation.service, - chat_generation.health_service.as_ref(), - ) - .await?; - clients.insert("chat_generation".to_string(), openai_client); + // Create chat completions client + if let Some(openai) = &config.openai { + let openai_client = + OpenAiClient::new(&openai.service, openai.health_service.as_ref()).await?; + clients.insert("openai".to_string(), openai_client); } // Create chunker clients @@ -154,48 +153,10 @@ async fn create_clients(config: &OrchestratorConfig) -> Result // Create detector clients for (detector_id, detector) in &config.detectors { - match detector.r#type { - DetectorType::TextContents => { - clients.insert( - detector_id.into(), - TextContentsDetectorClient::new( - &detector.service, - detector.health_service.as_ref(), - ) - .await?, - ); - } - DetectorType::TextGeneration => { - clients.insert( - detector_id.into(), - TextGenerationDetectorClient::new( - &detector.service, - detector.health_service.as_ref(), - ) - .await?, - ); - } - DetectorType::TextChat => { - clients.insert( - detector_id.into(), - TextChatDetectorClient::new( - &detector.service, - detector.health_service.as_ref(), - ) - .await?, - ); - } - DetectorType::TextContextDoc => { - clients.insert( - detector_id.into(), - TextContextDocDetectorClient::new( - &detector.service, - detector.health_service.as_ref(), - ) - 
.await?, - ); - } - } + clients.insert( + detector_id.into(), + DetectorClient::new(&detector.service, detector.health_service.as_ref()).await?, + ); } Ok(clients) } diff --git a/src/orchestrator/common/client.rs b/src/orchestrator/common/client.rs index b4d0c12b..71f7ce2d 100644 --- a/src/orchestrator/common/client.rs +++ b/src/orchestrator/common/client.rs @@ -23,15 +23,14 @@ use tracing::{debug, instrument}; use crate::{ clients::{ - GenerationClient, TextContentsDetectorClient, + DetectorClient, GenerationClient, chunker::ChunkerClient, detector::{ ChatDetectionRequest, ContentAnalysisRequest, ContextDocsDetectionRequest, ContextType, - GenerationDetectionRequest, TextChatDetectorClient, TextContextDocDetectorClient, - TextGenerationDetectorClient, + GenerationDetectionRequest, }, http::JSON_CONTENT_TYPE, - openai::{self, OpenAiClient}, + openai::{self, OpenAiClient, TokenizeRequest}, }, models::{ ClassifiedGeneratedTextResult as GenerateResponse, DetectorParams, @@ -99,7 +98,7 @@ pub async fn chunk_stream( /// Sends request to text contents detector client. #[instrument(skip_all, fields(detector_id))] pub async fn detect_text_contents( - client: &TextContentsDetectorClient, + client: &DetectorClient, headers: HeaderMap, detector_id: DetectorId, params: DetectorParams, @@ -149,7 +148,7 @@ pub async fn detect_text_contents( /// Sends request to text generation detector client. #[instrument(skip_all, fields(detector_id))] pub async fn detect_text_generation( - client: &TextGenerationDetectorClient, + client: &DetectorClient, headers: HeaderMap, detector_id: DetectorId, params: DetectorParams, @@ -181,7 +180,7 @@ pub async fn detect_text_generation( /// Sends request to text chat detector client. #[instrument(skip_all, fields(detector_id))] pub async fn detect_text_chat( - client: &TextChatDetectorClient, + client: &DetectorClient, headers: HeaderMap, detector_id: DetectorId, params: DetectorParams, @@ -213,7 +212,7 @@ pub async fn detect_text_chat( /// Sends request to text context detector client. #[instrument(skip_all, fields(detector_id))] pub async fn detect_text_context( - client: &TextContextDocDetectorClient, + client: &DetectorClient, headers: HeaderMap, detector_id: DetectorId, params: DetectorParams, @@ -337,6 +336,27 @@ pub async fn completion_stream( Ok(stream) } +/// Sends tokenize request to OpenAI client. +#[instrument(skip_all, fields(model_id))] +pub async fn tokenize_openai( + client: &OpenAiClient, + mut headers: HeaderMap, + request: TokenizeRequest, +) -> Result { + let model_id = request.model.clone(); + debug!(%model_id, ?request, "sending tokenize request"); + headers.append(CONTENT_TYPE, JSON_CONTENT_TYPE); + let response = client.tokenize(request, headers).await.map_err(|error| { + tracing::error!("Tokenize request failed: {error}"); + Error::TokenizeRequestFailed { + id: model_id.clone(), + error, + } + })?; + debug!(%model_id, ?response, "received tokenize response"); + Ok(response) +} + /// Sends tokenize request to generation client. 
#[instrument(skip_all, fields(model_id))] pub async fn tokenize( diff --git a/src/orchestrator/common/tasks.rs b/src/orchestrator/common/tasks.rs index efd708da..4533e075 100644 --- a/src/orchestrator/common/tasks.rs +++ b/src/orchestrator/common/tasks.rs @@ -26,12 +26,8 @@ use tracing::{Instrument, debug, instrument}; use super::{client::*, utils::*}; use crate::{ clients::{ - TextContentsDetectorClient, chunker::{ChunkerClient, DEFAULT_CHUNKER_ID}, - detector::{ - ContextType, TextChatDetectorClient, TextContextDocDetectorClient, - TextGenerationDetectorClient, - }, + detector::{ContextType, DetectorClient}, openai, }, models::DetectorParams, @@ -69,7 +65,7 @@ pub async fn chunks( } let client = ctx .clients - .get_as::(&chunker_id) + .get::(&chunker_id) .ok_or_else(|| Error::ChunkerNotFound(chunker_id.clone()))?; let chunks = chunk(client, chunker_id.clone(), text) .await? @@ -127,7 +123,7 @@ pub async fn chunk_streams( } else { let client = ctx .clients - .get_as::(&chunker_id) + .get::(&chunker_id) .ok_or_else(|| Error::ChunkerNotFound(chunker_id.clone()))?; chunk_stream(client, chunker_id.clone(), input_broadcast_rx).await }?; @@ -214,10 +210,7 @@ pub async fn text_contents_detections( let default_threshold = ctx.config.detector(&detector_id).unwrap().default_threshold; let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); + let client = ctx.clients.get::(&detector_id).unwrap(); let detections = detect_text_contents( client, headers, @@ -273,10 +266,7 @@ pub async fn text_contents_detection_streams( while let Ok(result) = chunk_rx.recv().await { match result { Ok(chunk) => { - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); + let client = ctx.clients.get::(&detector_id).unwrap(); match detect_text_contents( client, headers.clone(), @@ -347,10 +337,7 @@ pub async fn text_generation_detections( let default_threshold = ctx.config.detector(&detector_id).unwrap().default_threshold; let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); + let client = ctx.clients.get::(&detector_id).unwrap(); let detections = detect_text_generation( client, headers, @@ -403,10 +390,7 @@ pub async fn text_chat_detections( let default_threshold = ctx.config.detector(&detector_id).unwrap().default_threshold; let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); + let client = ctx.clients.get::(&detector_id).unwrap(); let detections = detect_text_chat( client, headers, @@ -463,10 +447,7 @@ pub async fn text_context_detections( ctx.config.detector(&detector_id).unwrap().default_threshold; let threshold = params.pop_threshold().unwrap_or(default_threshold); async move { - let client = ctx - .clients - .get_as::(&detector_id) - .unwrap(); + let client = ctx.clients.get::(&detector_id).unwrap(); let detections = detect_text_context( client, headers, @@ -611,7 +592,7 @@ mod test { input_end_index: 2, }]); }); - let sentence_chunker_server = MockServer::new("sentence_chunker").grpc().with_mocks(mocks); + let sentence_chunker_server = MockServer::new_grpc("sentence_chunker").with_mocks(mocks); sentence_chunker_server.start().await.unwrap(); // Create whole_doc_chunker @@ -628,9 +609,7 @@ mod test { token_count: 25, }); }); - let whole_doc_chunker_server = MockServer::new("whole_doc_chunker") - .grpc() - 
.with_mocks(mocks); + let whole_doc_chunker_server = MockServer::new_grpc("whole_doc_chunker").with_mocks(mocks); whole_doc_chunker_server.start().await.unwrap(); // Create error chunker @@ -639,7 +618,7 @@ mod test { when.path(CHUNKER_PATH); then.internal_server_error(); }); - let error_chunker_server = MockServer::new("error_chunker").grpc().with_mocks(mocks); + let error_chunker_server = MockServer::new_grpc("error_chunker").with_mocks(mocks); error_chunker_server.start().await.unwrap(); // Create fake detector @@ -689,7 +668,7 @@ mod test { }]]); }); - let fake_detector_server = MockServer::new("fake_detector").with_mocks(mocks); + let fake_detector_server = MockServer::new_http("fake_detector").with_mocks(mocks); fake_detector_server.start().await.unwrap(); let mut config = OrchestratorConfig::default(); diff --git a/src/orchestrator/common/utils.rs b/src/orchestrator/common/utils.rs index 8e6521cd..a2529d43 100644 --- a/src/orchestrator/common/utils.rs +++ b/src/orchestrator/common/utils.rs @@ -75,7 +75,7 @@ pub fn current_timestamp() -> std::time::Duration { pub fn configure_mock_servers( config: &mut crate::config::OrchestratorConfig, generation_server: Option<&mocktail::server::MockServer>, - chat_generation_server: Option<&mocktail::server::MockServer>, + openai_server: Option<&mocktail::server::MockServer>, detector_servers: Option>, chunker_servers: Option>, ) { @@ -85,11 +85,11 @@ pub fn configure_mock_servers( generation_config.service.port = Some(server.addr().unwrap().port()); config.generation = Some(generation_config); } - if let Some(server) = chat_generation_server { - let mut chat_generation_config = crate::config::ChatGenerationConfig::default(); - chat_generation_config.service.hostname = "localhost".into(); - chat_generation_config.service.port = Some(server.addr().unwrap().port()); - config.chat_generation = Some(chat_generation_config); + if let Some(server) = openai_server { + let mut openai_config = crate::config::OpenAiConfig::default(); + openai_config.service.hostname = "localhost".into(); + openai_config.service.port = Some(server.addr().unwrap().port()); + config.openai = Some(openai_config); }; if let Some(servers) = detector_servers { for server in servers { @@ -116,26 +116,30 @@ pub fn configure_mock_servers( }; } -/// Validates guardrails on request. -pub fn validate_detectors( - detectors: &HashMap, +/// Validates requested detectors. 
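// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the reworked validator below
// takes any iterator of (detector id, params) pairs, which is what lets call
// sites validate input and output detectors in one pass via
// `input.iter().chain(output.iter())`. Standalone sketch of the same
// signature shape with stand-in types (`Params` and `validate` are
// hypothetical, not the orchestrator's real items).
use std::collections::HashMap;

struct Params;

fn validate<'a>(
    detectors: impl IntoIterator<Item = (&'a String, &'a Params)>,
    known: &HashMap<String, ()>,
) -> Result<(), String> {
    for (id, _params) in detectors {
        if !known.contains_key(id) {
            return Err(format!("detector `{id}` not found"));
        }
    }
    Ok(())
}

fn main() -> Result<(), String> {
    let known = HashMap::from([("pii".to_string(), ())]);
    let input: HashMap<String, Params> = HashMap::from([("pii".to_string(), Params)]);
    let output: HashMap<String, Params> = HashMap::new();
    // One call covers both maps, mirroring the handlers' chained usage.
    validate(input.iter().chain(output.iter()), &known)
}
// ---------------------------------------------------------------------------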
+pub fn validate_detectors<'a>( + detectors: impl IntoIterator, orchestrator_detectors: &HashMap, - allowed_detector_types: &[DetectorType], - allows_whole_doc_chunker: bool, + supported_detector_types: &[DetectorType], + supports_whole_doc_chunker: bool, ) -> Result<(), Error> { let whole_doc_chunker_id = DEFAULT_CHUNKER_ID; - for detector_id in detectors.keys() { - // validate detectors + for (detector_id, _params) in detectors { match orchestrator_detectors.get(detector_id) { Some(detector_config) => { - if !allowed_detector_types.contains(&detector_config.r#type) { + if !detector_config + .r#type + .iter() + .any(|v| supported_detector_types.contains(v)) + { let error = Error::Validation(format!( "detector `{detector_id}` is not supported by this endpoint" )); error!("{error}"); return Err(error); } - if !allows_whole_doc_chunker && detector_config.chunker_id == whole_doc_chunker_id { + if !supports_whole_doc_chunker && detector_config.chunker_id == whole_doc_chunker_id + { let error = Error::Validation(format!( "detector `{detector_id}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint" )); @@ -176,4 +180,97 @@ mod tests { let s = "哈囉世界"; assert_eq!(slice_codepoints(s, 3, 4), "界"); } + + #[test] + fn test_validate_detectors() -> Result<(), Error> { + let orchestrator_detectors = HashMap::from([ + ( + "pii".to_string(), + DetectorConfig { + chunker_id: "sentence".into(), + r#type: vec![DetectorType::TextContents], + ..Default::default() + }, + ), + ( + "pii_whole_doc".to_string(), + DetectorConfig { + chunker_id: "whole_doc_chunker".into(), + r#type: vec![DetectorType::TextContents], + ..Default::default() + }, + ), + ( + "granite_guardian".to_string(), + DetectorConfig { + chunker_id: "sentence".into(), + r#type: vec![DetectorType::TextContents, DetectorType::TextChat], + ..Default::default() + }, + ), + ]); + + assert!( + validate_detectors( + &HashMap::from([("granite_guardian".to_string(), DetectorParams::default())]), + &orchestrator_detectors, + &[DetectorType::TextContents], + true + ) + .is_ok(), + "should pass: model supports text_contents and text_chat, endpoint supports text_contents" + ); + assert!( + validate_detectors( + &HashMap::from([("granite_guardian".to_string(), DetectorParams::default())]), + &orchestrator_detectors, + &[DetectorType::TextContents, DetectorType::TextChat], + true + ) + .is_ok(), + "should pass: model supports text_contents and text_chat, endpoint supports text_contents and text_chat" + ); + assert!( + validate_detectors( + &HashMap::from([("granite_guardian".to_string(), DetectorParams::default())]), + &orchestrator_detectors, + &[DetectorType::TextGeneration], + true + ) + .is_err_and(|e| matches!(e, Error::Validation(_))), + "should fail: model supports text_contents and text_chat, endpoint supports text_generation" + ); + assert!( + validate_detectors( + &HashMap::from([("pii".to_string(), DetectorParams::default())]), + &orchestrator_detectors, + &[DetectorType::TextContextDoc], + false + ) + .is_err_and(|e| matches!(e, Error::Validation(_))), + "should fail: model supports text_contents, endpoint supports text_context_doc" + ); + assert!( + validate_detectors( + &HashMap::from([("pii_whole_doc".to_string(), DetectorParams::default())]), + &orchestrator_detectors, + &[DetectorType::TextContents], + false + ) + .is_err_and(|e| matches!(e, Error::Validation(_))), + "should fail: model uses whole_doc_chunker and endpoint doesn't support it" + ); + assert!( + validate_detectors( + 
&HashMap::from([("does_not_exist".to_string(), DetectorParams::default())]), + &orchestrator_detectors, + &[DetectorType::TextContents], + true + ) + .is_err_and(|e| matches!(e, Error::DetectorNotFound(_))), + "should fail: requested model does not exist" + ); + + Ok(()) + } } diff --git a/src/orchestrator/handlers.rs b/src/orchestrator/handlers.rs index e681e2ca..3b8b826c 100644 --- a/src/orchestrator/handlers.rs +++ b/src/orchestrator/handlers.rs @@ -22,6 +22,7 @@ pub use classification_with_gen::ClassificationWithGenTask; pub mod streaming_classification_with_gen; pub use streaming_classification_with_gen::StreamingClassificationWithGenTask; pub mod chat_completions_detection; +pub mod completions_detection; pub mod streaming_content_detection; pub use streaming_content_detection::StreamingContentDetectionTask; pub mod generation_with_detection; diff --git a/src/orchestrator/handlers/chat_completions_detection/streaming.rs b/src/orchestrator/handlers/chat_completions_detection/streaming.rs index 74ebcb59..c721c56c 100644 --- a/src/orchestrator/handlers/chat_completions_detection/streaming.rs +++ b/src/orchestrator/handlers/chat_completions_detection/streaming.rs @@ -14,26 +14,38 @@ limitations under the License. */ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; +use futures::{StreamExt, TryStreamExt, stream}; +use opentelemetry::trace::TraceId; use tokio::sync::mpsc; -use tracing::{Instrument, info}; +use tracing::{Instrument, debug, error, info, instrument, warn}; +use uuid::Uuid; use super::ChatCompletionsDetectionTask; use crate::{ clients::openai::*, - orchestrator::{Context, Error}, + config::DetectorType, + models::{ + DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE, UNSUITABLE_OUTPUT_MESSAGE, + }, + orchestrator::{ + Context, Error, + common::{self, text_contents_detections, validate_detectors}, + types::{ + ChatCompletionStream, ChatMessageIterator, Chunk, CompletionBatcher, CompletionState, + DetectionBatchStream, Detections, + }, + }, }; pub async fn handle_streaming( - _ctx: Arc, + ctx: Arc, task: ChatCompletionsDetectionTask, ) -> Result { let trace_id = task.trace_id; let detectors = task.request.detectors.clone(); info!(%trace_id, config = ?detectors, "task started"); - let _input_detectors = detectors.input; - let _output_detectors = detectors.output; // Create response channel let (response_tx, response_rx) = @@ -41,15 +53,577 @@ pub async fn handle_streaming( tokio::spawn( async move { - // TODO - let _ = response_tx - .send(Err(Error::Validation( - "streaming is not yet supported".into(), - ))) + let input_detectors = detectors.input; + let output_detectors = detectors.output; + + if let Err(error) = validate_detectors( + input_detectors.iter().chain(output_detectors.iter()), + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + ) { + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + + // Handle input detection (unary) + if !input_detectors.is_empty() { + match handle_input_detection(ctx.clone(), &task, input_detectors).await { + Ok(Some(chunk)) => { + info!(%trace_id, "task completed: returning response with input detections"); + // Send message with input detections to response channel and terminate + let _ = response_tx.send(Ok(Some(chunk))).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + Ok(None) => (), // No input detections + Err(error) => { + // Input detections failed + 
// Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + } + } + + // Create chat completions stream + let client = ctx + .clients + .get::("openai") + .unwrap(); + let chat_completion_stream = match common::chat_completion_stream(client, task.headers.clone(), task.request.clone()).await { + Ok(stream) => stream, + Err(error) => { + error!(%trace_id, %error, "task failed: error creating chat completions stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + }; + + if output_detectors.is_empty() { + // No output detectors, forward chat completion chunks to response channel + process_chat_completion_stream(trace_id, chat_completion_stream, None, None, Some(response_tx.clone())).await; + info!(%trace_id, "task completed: chat completion stream closed"); + } else { + // Handle output detection + handle_output_detection( + ctx.clone(), + &task, + output_detectors, + chat_completion_stream, + response_tx.clone(), + ) .await; + } + + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; } .in_current_span(), ); Ok(ChatCompletionsResponse::Streaming(response_rx)) } + +#[instrument(skip_all)] +async fn handle_input_detection( + ctx: Arc, + task: &ChatCompletionsDetectionTask, + detectors: HashMap, +) -> Result, Error> { + let trace_id = task.trace_id; + let model_id = task.request.model.clone(); + + // Input detectors are only applied to the last message + // Get the last message + let messages = task.request.messages(); + let message = if let Some(message) = messages.last() { + message + } else { + return Err(Error::Validation("No messages provided".into())); + }; + // Validate role + if !matches!( + message.role, + Some(Role::User) | Some(Role::Assistant) | Some(Role::System) + ) { + return Err(Error::Validation( + "Last message role must be user, assistant, or system".into(), + )); + } + let input_id = message.index; + let input_text = message.text.map(|s| s.to_string()).unwrap_or_default(); + let detections = match common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + input_id, + vec![(0, input_text.clone())], + ) + .await + { + Ok((_, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing input detections"); + return Err(error); + } + }; + if !detections.is_empty() { + // Get prompt tokens for usage + let client = ctx.clients.get::("openai").unwrap(); + let tokenize_request = TokenizeRequest { + model: model_id.clone(), + prompt: Some(input_text), + ..Default::default() + }; + let tokenize_response = + common::tokenize_openai(client, task.headers.clone(), tokenize_request).await?; + let usage = Usage { + prompt_tokens: tokenize_response.count, + ..Default::default() + }; + + // Build chat completion chunk with input detections + let chunk = ChatCompletionChunk { + id: Uuid::new_v4().simple().to_string(), + model: model_id, + created: common::current_timestamp().as_secs() as i64, + detections: Some(CompletionDetections { + input: vec![CompletionInputDetections { + message_index: message.index, + results: detections.into(), + }], + ..Default::default() + }), + warnings: vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableInput, + UNSUITABLE_INPUT_MESSAGE, + )], + usage: Some(usage), + 
..Default::default() + }; + Ok(Some(chunk)) + } else { + // No input detections + Ok(None) + } +} + +#[instrument(skip_all)] +async fn handle_output_detection( + ctx: Arc, + task: &ChatCompletionsDetectionTask, + detectors: HashMap, + chat_completion_stream: ChatCompletionStream, + response_tx: mpsc::Sender, Error>>, +) { + let trace_id = task.trace_id; + let request = task.request.clone(); + // Split output detectors into 2 groups: + // 1) Output Detectors: Applied to chunks. Detections are returned in batches. + // 2) Whole Doc Output Detectors: Applied to concatenated chunks (whole doc) after the chat completion stream has been consumed. + // Currently, this is any detector that uses "whole_doc_chunker". + let (whole_doc_detectors, detectors): (HashMap<_, _>, HashMap<_, _>) = + detectors.into_iter().partition(|(detector_id, _)| { + ctx.config.get_chunker_id(detector_id).unwrap() == "whole_doc_chunker" + }); + let completion_state = Arc::new(CompletionState::new()); + + if !detectors.is_empty() { + // Set up streaming detection pipeline + // n represents how many choices to generate for each input message + // Choices are processed independently so each choice has its own input channels and detection streams. + let n = request.extra.get("n").and_then(|v| v.as_i64()).unwrap_or(1) as usize; + // Create input channels + let mut input_txs = HashMap::with_capacity(n); + let mut input_rxs = HashMap::with_capacity(n); + (0..n).for_each(|choice_index| { + let (input_tx, input_rx) = mpsc::channel::>(32); + input_txs.insert(choice_index as u32, input_tx); + input_rxs.insert(choice_index as u32, input_rx); + }); + // Create detection streams + let mut detection_streams = Vec::with_capacity(n * detectors.len()); + for (choice_index, input_rx) in input_rxs { + match common::text_contents_detection_streams( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + choice_index, + input_rx, + ) + .await + { + Ok(streams) => { + detection_streams.extend(streams); + } + Err(error) => { + error!(%trace_id, %error, "task failed: error creating detection streams"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + } + } + } + + // Spawn task to consume chat completions stream and send choice text to detection pipeline + tokio::spawn(process_chat_completion_stream( + trace_id, + chat_completion_stream, + Some(completion_state.clone()), + Some(input_txs), + None, + )); + // Process detection streams and await completion + let detection_batch_stream = + DetectionBatchStream::new(CompletionBatcher::new(detectors.len()), detection_streams); + process_detection_batch_stream( + trace_id, + completion_state.clone(), + detection_batch_stream, + response_tx.clone(), + ) + .await; + } else { + // We only have whole doc detectors, so the streaming detection pipeline is disabled + // Consume chat completions stream and await completion + process_chat_completion_stream( + trace_id, + chat_completion_stream, + Some(completion_state.clone()), + None, + Some(response_tx.clone()), + ) + .await; + } + // NOTE: at this point, the chat completions stream has been fully consumed and chat completion state is final + + // If whole doc output detections or usage is requested, a final message is sent with these items + if !whole_doc_detectors.is_empty() || completion_state.usage().is_some() { + let mut chat_completion = ChatCompletionChunk { + id: completion_state.id().unwrap().to_string(), + created: completion_state.created().unwrap(), + model: 
completion_state.model().unwrap().to_string(), + usage: completion_state.usage().cloned(), + ..Default::default() + }; + if !whole_doc_detectors.is_empty() { + // Handle whole doc output detection + match handle_whole_doc_output_detection( + ctx.clone(), + task, + whole_doc_detectors, + completion_state, + ) + .await + { + Ok((detections, warnings)) => { + chat_completion.detections = Some(detections); + chat_completion.warnings = warnings; + } + Err(error) => { + error!(%error, "task failed: error processing whole doc output detections"); + // Send error to response channel + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + } + } + // Send chat completion with whole doc output detections and/or usage to response channel + let _ = response_tx.send(Ok(Some(chat_completion))).await; + } +} + +/// Processes chat completion stream. +#[allow(clippy::type_complexity)] +async fn process_chat_completion_stream( + trace_id: TraceId, + mut chat_completion_stream: ChatCompletionStream, + completion_state: Option>>, + input_txs: Option>>>, + response_tx: Option, Error>>>, +) { + while let Some((message_index, result)) = chat_completion_stream.next().await { + match result { + Ok(Some(chat_completion)) => { + // Send chat completion chunk to response channel + // NOTE: this forwards chat completion chunks without detections and is only + // done here for 2 cases: a) no output detectors b) only whole doc output detectors + if let Some(response_tx) = &response_tx { + if response_tx + .send(Ok(Some(chat_completion.clone()))) + .await + .is_err() + { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + if let Some(usage) = &chat_completion.usage + && chat_completion.choices.is_empty() + { + // Update state: set usage + // NOTE: this message has no choices and is not sent to detection input channel + if let Some(state) = &completion_state { + state.set_usage(usage.clone()); + } + } else { + if message_index == 0 { + // Update state: set metadata + // NOTE: these values are the same for all chat completion chunks + if let Some(state) = &completion_state { + state.set_metadata( + chat_completion.id.clone(), + chat_completion.created, + chat_completion.model.clone(), + ); + } + } + // NOTE: chat completion chunks should contain only 1 choice + if let Some(choice) = chat_completion.choices.first() { + // Extract choice text + let choice_text = choice.delta.content.clone().unwrap_or_default(); + // Update state: insert completion + if let Some(state) = &completion_state { + state.insert_completion( + choice.index, + message_index, + chat_completion.clone(), + ); + } + // Send choice text to detection input channel + if let Some(input_tx) = + input_txs.as_ref().and_then(|txs| txs.get(&choice.index)) + { + if !choice_text.is_empty() { + let _ = input_tx.send(Ok((message_index, choice_text))).await; + } + } + } else { + debug!(%trace_id, %message_index, ?chat_completion, "chat completion chunk contains no choice"); + warn!(%trace_id, %message_index, "chat completion chunk contains no choice"); + } + } + } + Ok(None) => (), // Complete, stream has closed + Err(error) => { + error!(%trace_id, %error, "task failed: error received from chat completion stream"); + // Send error to response channel + if let Some(response_tx) = &response_tx { + let _ = response_tx.send(Err(error.clone())).await; + } + // Send error to detection input channels + if let Some(input_txs) = &input_txs { + for input_tx in 
input_txs.values() { + let _ = input_tx.send(Err(error.clone())).await; + } + } + } + } + } +} + +#[instrument(skip_all)] +async fn handle_whole_doc_output_detection( + ctx: Arc, + task: &ChatCompletionsDetectionTask, + detectors: HashMap, + completion_state: Arc>, +) -> Result<(CompletionDetections, Vec), Error> { + // Create vec of choice_index->inputs, where inputs contains the concatenated text for the choice + let choice_inputs = completion_state + .completions + .iter() + .map(|entry| { + let choice_index = *entry.key(); + let text = entry + .values() + .map(|chunk| { + chunk + .choices + .first() + .and_then(|choice| choice.delta.content.clone()) + .unwrap_or_default() + }) + .collect::(); + let inputs = vec![(0usize, text)]; + (choice_index, inputs) + }) + .collect::>(); + // Process detections concurrently for choices + let choice_detections = stream::iter(choice_inputs) + .map(|(choice_index, inputs)| { + text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + choice_index, + inputs, + ) + }) + .buffer_unordered(ctx.config.detector_concurrent_requests) + .try_collect::>() + .await?; + // Build output detections + let output = choice_detections + .into_iter() + .map(|(choice_index, detections)| CompletionOutputDetections { + choice_index, + results: detections.into(), + }) + .collect::>(); + // Build warnings + let warnings = if output.iter().any(|d| !d.results.is_empty()) { + vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableOutput, + UNSUITABLE_OUTPUT_MESSAGE, + )] + } else { + Vec::new() + }; + let detections = CompletionDetections { + output, + ..Default::default() + }; + Ok((detections, warnings)) +} + +/// Builds a response with output detections. +fn output_detection_response( + completion_state: &Arc>, + choice_index: u32, + chunk: Chunk, + detections: Detections, +) -> Result { + // Get chat completions for this choice index + let chat_completions = completion_state.completions.get(&choice_index).unwrap(); + // Get range of chat completions for this chunk + let chat_completions = chat_completions + .range(chunk.input_start_index..=chunk.input_end_index) + .map(|(_index, chat_completion)| chat_completion.clone()) + .collect::>(); + let content = Some(chunk.text); + let logprobs = merge_logprobs(&chat_completions); + // Build response using the last chat completion received for this chunk + if let Some(chat_completion) = chat_completions.last() { + let mut chat_completion = chat_completion.clone(); + // Set role + chat_completion.choices[0].delta.role = Some(Role::Assistant); + // Set content + chat_completion.choices[0].delta.content = content; + // TODO: if applicable, set tool_calls and refusal + // Set logprobs + chat_completion.choices[0].logprobs = logprobs; + // Set warnings + if !detections.is_empty() { + chat_completion.warnings = vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableOutput, + UNSUITABLE_OUTPUT_MESSAGE, + )]; + } + // Set detections + chat_completion.detections = Some(CompletionDetections { + output: vec![CompletionOutputDetections { + choice_index, + results: detections.into(), + }], + ..Default::default() + }); + Ok(chat_completion) + } else { + error!( + %choice_index, + %chunk.input_start_index, + %chunk.input_end_index, + "no chat completions found for chunk" + ); + Err(Error::Other("no chat completions found for chunk".into())) + } +} + +/// Combines logprobs from chat completion chunks to a single [`ChatCompletionLogprobs`]. 
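// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the whole-doc detection
// fan-out above uses `stream::iter(..).map(..).buffer_unordered(n)
// .try_collect()` to run a bounded number of async detector calls
// concurrently and gather their results. Standalone reproduction; `detect`
// is a stand-in for the real detector call and assumes the tokio and
// futures crates.
use futures::{StreamExt, TryStreamExt, stream};

async fn detect(choice_index: u32, text: String) -> Result<(u32, usize), String> {
    // Stand-in "detection": count whitespace-separated tokens.
    Ok((choice_index, text.split_whitespace().count()))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let inputs = vec![(0u32, "one two".to_string()), (1u32, "three".to_string())];
    let results: Vec<(u32, usize)> = stream::iter(inputs)
        .map(|(index, text)| detect(index, text))
        .buffer_unordered(4) // at most 4 calls in flight; completion order not guaranteed
        .try_collect()
        .await?;
    assert_eq!(results.len(), 2);
    Ok(())
}
// ---------------------------------------------------------------------------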
+fn merge_logprobs(chat_completions: &[ChatCompletionChunk]) -> Option { + let mut content: Vec = Vec::new(); + let mut refusal: Vec = Vec::new(); + for chat_completion in chat_completions { + if let Some(choice) = chat_completion.choices.first() { + if let Some(logprobs) = &choice.logprobs { + content.extend_from_slice(&logprobs.content); + refusal.extend_from_slice(&logprobs.refusal); + } + } + } + (!content.is_empty() || !refusal.is_empty()) + .then_some(ChatCompletionLogprobs { content, refusal }) +} + +/// Consumes a detection batch stream, builds responses, and sends them to a response channel. +async fn process_detection_batch_stream( + trace_id: TraceId, + completion_state: Arc>, + mut detection_batch_stream: DetectionBatchStream, + response_tx: mpsc::Sender, Error>>, +) { + while let Some(result) = detection_batch_stream.next().await { + match result { + Ok((choice_index, chunk, detections)) => { + let input_end_index = chunk.input_end_index; + match output_detection_response(&completion_state, choice_index, chunk, detections) + { + Ok(chat_completion) => { + // Send chat completion to response channel + debug!(%trace_id, %choice_index, ?chat_completion, "sending chat completion chunk to response channel"); + if response_tx.send(Ok(Some(chat_completion))).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + // If this is the final chat completion chunk with content, send chat completion chunk with finish reason + let chat_completions = + completion_state.completions.get(&choice_index).unwrap(); + if chat_completions.keys().rev().nth(1) == Some(&input_end_index) { + if let Some((_, chat_completion)) = chat_completions.last_key_value() { + if chat_completion + .choices + .first() + .is_some_and(|choice| choice.finish_reason.is_some()) + { + let mut chat_completion = chat_completion.clone(); + // Set role + chat_completion.choices[0].delta.role = Some(Role::Assistant); + debug!(%trace_id, %choice_index, ?chat_completion, "sending chat completion chunk with finish reason to response channel"); + let _ = response_tx.send(Ok(Some(chat_completion))).await; + } + } + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error building output detection response"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection batch stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection batch stream closed"); +} diff --git a/src/orchestrator/handlers/chat_completions_detection/unary.rs b/src/orchestrator/handlers/chat_completions_detection/unary.rs index 4444ae8f..3a373e24 100644 --- a/src/orchestrator/handlers/chat_completions_detection/unary.rs +++ b/src/orchestrator/handlers/chat_completions_detection/unary.rs @@ -45,14 +45,7 @@ pub async fn handle_unary( let output_detectors = detectors.output; validate_detectors( - &input_detectors, - &ctx.config.detectors, - &[DetectorType::TextContents], - true, - )?; - - validate_detectors( - &output_detectors, + input_detectors.iter().chain(output_detectors.iter()), &ctx.config.detectors, &[DetectorType::TextContents], true, @@ -76,10 +69,7 @@ pub async fn handle_unary( } // Handle 
chat completion - let client = ctx - .clients - .get_as::("chat_generation") - .unwrap(); + let client = ctx.clients.get::("openai").unwrap(); let chat_completion = match common::chat_completion(client, task.headers.clone(), task.request.clone()).await { Ok(ChatCompletionsResponse::Unary(chat_completion)) => *chat_completion, @@ -133,7 +123,7 @@ async fn handle_input_detection( task.headers.clone(), detectors.clone(), input_id, - vec![(0, input_text)], + vec![(0, input_text.clone())], ) .await { @@ -144,22 +134,37 @@ async fn handle_input_detection( } }; if !detections.is_empty() { + // Get prompt tokens for usage + let client = ctx.clients.get::("openai").unwrap(); + let tokenize_request = TokenizeRequest { + model: model_id.clone(), + prompt: Some(input_text), + ..Default::default() + }; + let tokenize_response = + common::tokenize_openai(client, task.headers.clone(), tokenize_request).await?; + let usage = Usage { + prompt_tokens: tokenize_response.count, + ..Default::default() + }; + // Build chat completion with input detections let chat_completion = ChatCompletion { id: Uuid::new_v4().simple().to_string(), model: model_id, created: common::current_timestamp().as_secs() as i64, - detections: Some(ChatDetections { - input: vec![InputDetectionResult { + detections: Some(CompletionDetections { + input: vec![CompletionInputDetections { message_index: message.index, results: detections.into(), }], ..Default::default() }), - warnings: vec![OrchestratorWarning::new( + warnings: vec![CompletionDetectionWarning::new( DetectionWarningReason::UnsuitableInput, UNSUITABLE_INPUT_MESSAGE, )], + usage, ..Default::default() }; Ok(Some(chat_completion)) @@ -178,6 +183,23 @@ async fn handle_output_detection( ) -> Result { let mut tasks = Vec::with_capacity(chat_completion.choices.len()); for choice in &chat_completion.choices { + if choice + .message + .content + .as_ref() + .is_none_or(|content| content.is_empty()) + { + chat_completion + .warnings + .push(CompletionDetectionWarning::new( + DetectionWarningReason::EmptyOutput, + &format!( + "Choice of index {} has no content. 
Output detection was not executed", + choice.index + ), + )); + continue; + } let input_id = choice.index; let input_text = choice.message.content.clone().unwrap_or_default(); tasks.push(tokio::spawn( @@ -200,17 +222,17 @@ async fn handle_output_detection( let output = detections .into_iter() .filter(|(_, detections)| !detections.is_empty()) - .map(|(input_id, detections)| OutputDetectionResult { + .map(|(input_id, detections)| CompletionOutputDetections { choice_index: input_id, results: detections.into(), }) .collect::>(); if !output.is_empty() { - chat_completion.detections = Some(ChatDetections { + chat_completion.detections = Some(CompletionDetections { output, ..Default::default() }); - chat_completion.warnings = vec![OrchestratorWarning::new( + chat_completion.warnings = vec![CompletionDetectionWarning::new( DetectionWarningReason::UnsuitableOutput, UNSUITABLE_OUTPUT_MESSAGE, )]; diff --git a/src/orchestrator/handlers/classification_with_gen.rs b/src/orchestrator/handlers/classification_with_gen.rs index 59d80dde..2ec8212f 100644 --- a/src/orchestrator/handlers/classification_with_gen.rs +++ b/src/orchestrator/handlers/classification_with_gen.rs @@ -51,16 +51,8 @@ impl Handle for Orchestrator { let input_detectors = task.guardrails_config.input_detectors(); let output_detectors = task.guardrails_config.output_detectors(); - // input detectors validation validate_detectors( - &input_detectors, - &ctx.config.detectors, - &[DetectorType::TextContents], - true, - )?; - // output detectors validation - validate_detectors( - &output_detectors, + input_detectors.iter().chain(output_detectors.iter()), &ctx.config.detectors, &[DetectorType::TextContents], true, @@ -83,10 +75,7 @@ impl Handle for Orchestrator { } // Handle generation - let client = ctx - .clients - .get_as::("generation") - .unwrap(); + let client = ctx.clients.get::("generation").unwrap(); let generation = common::generate( client, task.headers.clone(), @@ -132,10 +121,7 @@ async fn handle_input_detection( }; if !detections.is_empty() { // Get token count - let client = ctx - .clients - .get_as::("generation") - .unwrap(); + let client = ctx.clients.get::("generation").unwrap(); let input_token_count = match common::tokenize( client, task.headers.clone(), diff --git a/src/orchestrator/handlers/completions_detection.rs b/src/orchestrator/handlers/completions_detection.rs new file mode 100644 index 00000000..3ab0401a --- /dev/null +++ b/src/orchestrator/handlers/completions_detection.rs @@ -0,0 +1,65 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use http::HeaderMap; +use opentelemetry::trace::TraceId; +use tracing::instrument; + +use super::Handle; +use crate::{ + clients::openai::{CompletionsRequest, CompletionsResponse}, + orchestrator::{Error, Orchestrator}, +}; + +pub mod streaming; +pub mod unary; + +impl Handle for Orchestrator { + type Response = CompletionsResponse; + + #[instrument( + name = "completions_detection", + skip_all, + fields(trace_id = ?task.trace_id, headers = ?task.headers) + )] + async fn handle(&self, task: CompletionsDetectionTask) -> Result { + let ctx = self.ctx.clone(); + match task.request.stream { + Some(true) => streaming::handle_streaming(ctx, task).await, + _ => unary::handle_unary(ctx, task).await, + } + } +} + +#[derive(Debug)] +pub struct CompletionsDetectionTask { + /// Trace ID + pub trace_id: TraceId, + /// Request + pub request: CompletionsRequest, + /// Headers + pub headers: HeaderMap, +} + +impl CompletionsDetectionTask { + pub fn new(trace_id: TraceId, request: CompletionsRequest, headers: HeaderMap) -> Self { + Self { + trace_id, + request, + headers, + } + } +} diff --git a/src/orchestrator/handlers/completions_detection/streaming.rs b/src/orchestrator/handlers/completions_detection/streaming.rs new file mode 100644 index 00000000..bbdcbbf3 --- /dev/null +++ b/src/orchestrator/handlers/completions_detection/streaming.rs @@ -0,0 +1,608 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use std::{collections::HashMap, sync::Arc}; + +use futures::{StreamExt, TryStreamExt, stream}; +use opentelemetry::trace::TraceId; +use tokio::sync::mpsc; +use tracing::{Instrument, debug, error, info, instrument, warn}; +use uuid::Uuid; + +use super::CompletionsDetectionTask; +use crate::{ + clients::openai::*, + config::DetectorType, + models::{ + DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE, UNSUITABLE_OUTPUT_MESSAGE, + }, + orchestrator::{ + Context, Error, + common::{self, text_contents_detections, validate_detectors}, + types::{ + Chunk, CompletionBatcher, CompletionState, CompletionStream, DetectionBatchStream, + Detections, + }, + }, +}; + +pub async fn handle_streaming( + ctx: Arc, + task: CompletionsDetectionTask, +) -> Result { + let trace_id = task.trace_id; + let detectors = task.request.detectors.clone(); + info!(%trace_id, config = ?detectors, "task started"); + + // Create response channel + let (response_tx, response_rx) = mpsc::channel::, Error>>(128); + + tokio::spawn( + async move { + let input_detectors = detectors.input; + let output_detectors = detectors.output; + + if let Err(error) = validate_detectors( + input_detectors.iter().chain(output_detectors.iter()), + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + ) { + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + + // Handle input detection (unary) + if !input_detectors.is_empty() { + match handle_input_detection(ctx.clone(), &task, input_detectors).await { + Ok(Some(chunk)) => { + info!(%trace_id, "task completed: returning response with input detections"); + // Send message with input detections to response channel and terminate + let _ = response_tx.send(Ok(Some(chunk))).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + Ok(None) => (), // No input detections + Err(error) => { + // Input detections failed + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + } + } + + // Create completions stream + let client = ctx + .clients + .get::("openai") + .unwrap(); + let completion_stream = match common::completion_stream(client, task.headers.clone(), task.request.clone()).await { + Ok(stream) => stream, + Err(error) => { + error!(%trace_id, %error, "task failed: error creating completions stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + }; + + if output_detectors.is_empty() { + // No output detectors, forward completion chunks to response channel + process_completion_stream(trace_id, completion_stream, None, None, Some(response_tx.clone())).await; + info!(%trace_id, "task completed: completion stream closed"); + } else { + // Handle output detection + handle_output_detection( + ctx.clone(), + &task, + output_detectors, + completion_stream, + response_tx.clone(), + ) + .await; + } + + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + } + .in_current_span(), + ); + + Ok(CompletionsResponse::Streaming(response_rx)) +} + +#[instrument(skip_all)] +async fn handle_input_detection( + ctx: Arc, + task: &CompletionsDetectionTask, + detectors: HashMap, +) -> Result, Error> { + let trace_id = task.trace_id; + let model_id = 
task.request.model.clone(); + let inputs = common::apply_masks( + task.request.prompt.clone(), + task.request.prompt_masks.as_deref(), + ); + let detections = match common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + 0, + inputs, + ) + .await + { + Ok((_, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing input detections"); + return Err(error); + } + }; + if !detections.is_empty() { + // Get prompt tokens for usage + let client = ctx.clients.get::("openai").unwrap(); + let tokenize_request = TokenizeRequest { + model: model_id.clone(), + prompt: Some(task.request.prompt.clone()), + ..Default::default() + }; + let tokenize_response = + common::tokenize_openai(client, task.headers.clone(), tokenize_request).await?; + let usage = Usage { + prompt_tokens: tokenize_response.count, + ..Default::default() + }; + + // Build completion chunk with input detections + let chunk = Completion { + id: Uuid::new_v4().simple().to_string(), + model: model_id, + created: common::current_timestamp().as_secs() as i64, + detections: Some(CompletionDetections { + input: vec![CompletionInputDetections { + message_index: 0, + results: detections.into(), + }], + ..Default::default() + }), + warnings: vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableInput, + UNSUITABLE_INPUT_MESSAGE, + )], + usage: Some(usage), + ..Default::default() + }; + Ok(Some(chunk)) + } else { + // No input detections + Ok(None) + } +} + +#[instrument(skip_all)] +async fn handle_output_detection( + ctx: Arc, + task: &CompletionsDetectionTask, + detectors: HashMap, + completion_stream: CompletionStream, + response_tx: mpsc::Sender, Error>>, +) { + let trace_id = task.trace_id; + let request = task.request.clone(); + // Split output detectors into 2 groups: + // 1) Output Detectors: Applied to chunks. Detections are returned in batches. + // 2) Whole Doc Output Detectors: Applied to concatenated chunks (whole doc) after the completion stream has been consumed. + // Currently, this is any detector that uses "whole_doc_chunker". + let (whole_doc_detectors, detectors): (HashMap<_, _>, HashMap<_, _>) = + detectors.into_iter().partition(|(detector_id, _)| { + ctx.config.get_chunker_id(detector_id).unwrap() == "whole_doc_chunker" + }); + let completion_state = Arc::new(CompletionState::new()); + + if !detectors.is_empty() { + // Set up streaming detection pipeline + // n represents how many choices to generate for each input message + // Choices are processed independently so each choice has its own input channels and detection streams. 
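+        // Illustrative example (hypothetical values, not from the request): with n = 2
+        // choices and 2 output detectors, 2 input channels are created (one per choice)
+        // and 2 * 2 = 4 detection streams are produced. The CompletionBatcher below is
+        // created with detectors.len() = 2, so a batch for a (choice, chunk) pair is
+        // emitted once both detectors have reported on that chunk.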
+ let n = request.extra.get("n").and_then(|v| v.as_i64()).unwrap_or(1) as usize; + // Create input channels + let mut input_txs = HashMap::with_capacity(n); + let mut input_rxs = HashMap::with_capacity(n); + (0..n).for_each(|choice_index| { + let (input_tx, input_rx) = mpsc::channel::>(32); + input_txs.insert(choice_index as u32, input_tx); + input_rxs.insert(choice_index as u32, input_rx); + }); + // Create detection streams + let mut detection_streams = Vec::with_capacity(n * detectors.len()); + for (choice_index, input_rx) in input_rxs { + match common::text_contents_detection_streams( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + choice_index, + input_rx, + ) + .await + { + Ok(streams) => { + detection_streams.extend(streams); + } + Err(error) => { + error!(%trace_id, %error, "task failed: error creating detection streams"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + } + } + } + + // Spawn task to consume completions stream and send choice text to detection pipeline + tokio::spawn(process_completion_stream( + trace_id, + completion_stream, + Some(completion_state.clone()), + Some(input_txs), + None, + )); + // Process detection streams and await completion + let detection_batch_stream = + DetectionBatchStream::new(CompletionBatcher::new(detectors.len()), detection_streams); + process_detection_batch_stream( + trace_id, + completion_state.clone(), + detection_batch_stream, + response_tx.clone(), + ) + .await; + } else { + // We only have whole doc detectors, so the streaming detection pipeline is disabled + // Consume completions stream and await completion + process_completion_stream( + trace_id, + completion_stream, + Some(completion_state.clone()), + None, + Some(response_tx.clone()), + ) + .await; + } + // NOTE: at this point, the completions stream has been fully consumed and completion state is final + + // If whole doc output detections or usage is requested, a final message is sent with these items + if !whole_doc_detectors.is_empty() || completion_state.usage().is_some() { + let mut completion = Completion { + id: completion_state.id().unwrap().to_string(), + created: completion_state.created().unwrap(), + model: completion_state.model().unwrap().to_string(), + usage: completion_state.usage().cloned(), + ..Default::default() + }; + if !whole_doc_detectors.is_empty() { + // Handle whole doc output detection + match handle_whole_doc_output_detection( + ctx.clone(), + task, + whole_doc_detectors, + completion_state, + ) + .await + { + Ok((detections, warnings)) => { + completion.detections = Some(detections); + completion.warnings = warnings; + } + Err(error) => { + error!(%error, "task failed: error processing whole doc output detections"); + // Send error to response channel + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + // Send completion with whole doc output detections and/or usage to response channel + let _ = response_tx.send(Ok(Some(completion))).await; + } +} + +/// Processes completion stream. 
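+///
+/// For each chunk received: forwards it to `response_tx` when provided (this is only
+/// done when there are no streaming output detectors), records metadata and usage in
+/// the shared `CompletionState` when provided, and sends each choice's text to its
+/// detection input channel when provided. Errors from the completion stream are
+/// propagated to the response channel and to the detection input channels.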
+#[allow(clippy::type_complexity)] +async fn process_completion_stream( + trace_id: TraceId, + mut completion_stream: CompletionStream, + completion_state: Option>>, + input_txs: Option>>>, + response_tx: Option, Error>>>, +) { + while let Some((message_index, result)) = completion_stream.next().await { + match result { + Ok(Some(completion)) => { + // Send completion chunk to response channel + // NOTE: this forwards completion chunks without detections and is only + // done here for 2 cases: a) no output detectors b) only whole doc output detectors + if let Some(response_tx) = &response_tx { + if response_tx + .send(Ok(Some(completion.clone()))) + .await + .is_err() + { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + if let Some(usage) = &completion.usage + && completion.choices.is_empty() + { + // Update state: set usage + // NOTE: this message has no choices and is not sent to detection input channel + if let Some(state) = &completion_state { + state.set_usage(usage.clone()); + } + } else { + if message_index == 0 { + // Update state: set metadata + // NOTE: these values are the same for all completion chunks + if let Some(state) = &completion_state { + state.set_metadata( + completion.id.clone(), + completion.created, + completion.model.clone(), + ); + } + } + // NOTE: completion chunks should contain only 1 choice + if let Some(choice) = completion.choices.first() { + // Extract choice text + let choice_text = choice.text.clone(); + // Update state: insert completion + if let Some(state) = &completion_state { + state.insert_completion( + choice.index, + message_index, + completion.clone(), + ); + } + // Send choice text to detection input channel + if let Some(input_tx) = + input_txs.as_ref().and_then(|txs| txs.get(&choice.index)) + { + if !choice_text.is_empty() { + let _ = input_tx.send(Ok((message_index, choice_text))).await; + } + } + } else { + debug!(%trace_id, %message_index, ?completion, "completion chunk contains no choice"); + warn!(%trace_id, %message_index, "completion chunk contains no choice"); + } + } + } + Ok(None) => (), // Complete, stream has closed + Err(error) => { + error!(%trace_id, %error, "task failed: error received from completion stream"); + // Send error to response channel + if let Some(response_tx) = &response_tx { + let _ = response_tx.send(Err(error.clone())).await; + } + // Send error to detection input channels + if let Some(input_txs) = &input_txs { + for input_tx in input_txs.values() { + let _ = input_tx.send(Err(error.clone())).await; + } + } + } + } + } +} + +#[instrument(skip_all)] +async fn handle_whole_doc_output_detection( + ctx: Arc, + task: &CompletionsDetectionTask, + detectors: HashMap, + completion_state: Arc>, +) -> Result<(CompletionDetections, Vec), Error> { + // Create vec of choice_index->inputs, where inputs contains the concatenated text for the choice + let choice_inputs = completion_state + .completions + .iter() + .map(|entry| { + let choice_index = *entry.key(); + let text = entry + .values() + .map(|chunk| { + chunk + .choices + .first() + .map(|choice| choice.text.clone()) + .unwrap_or_default() + }) + .collect::(); + let inputs = vec![(0usize, text)]; + (choice_index, inputs) + }) + .collect::>(); + // Process detections concurrently for choices + let choice_detections = stream::iter(choice_inputs) + .map(|(choice_index, inputs)| { + text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + choice_index, + inputs, + ) + }) + 
.buffer_unordered(ctx.config.detector_concurrent_requests) + .try_collect::>() + .await?; + // Build output detections + let output = choice_detections + .into_iter() + .map(|(choice_index, detections)| CompletionOutputDetections { + choice_index, + results: detections.into(), + }) + .collect::>(); + // Build warnings + let warnings = if output.iter().any(|d| !d.results.is_empty()) { + vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableOutput, + UNSUITABLE_OUTPUT_MESSAGE, + )] + } else { + Vec::new() + }; + let detections = CompletionDetections { + output, + ..Default::default() + }; + Ok((detections, warnings)) +} + +/// Builds a response with output detections. +fn output_detection_response( + completion_state: &Arc>, + choice_index: u32, + chunk: Chunk, + detections: Detections, +) -> Result { + // Get completions for this choice index + let completions = completion_state.completions.get(&choice_index).unwrap(); + // Get range of completions for this chunk + let completions = completions + .range(chunk.input_start_index..=chunk.input_end_index) + .map(|(_index, completion)| completion.clone()) + .collect::>(); + let logprobs = merge_logprobs(&completions); + // Build response using the last completion received for this chunk + if let Some(completion) = completions.last() { + let mut completion = completion.clone(); + // Set content + completion.choices[0].text = chunk.text; + // Set logprobs + completion.choices[0].logprobs = logprobs; + // Set warnings + if !detections.is_empty() { + completion.warnings = vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableOutput, + UNSUITABLE_OUTPUT_MESSAGE, + )]; + } + // Set detections + completion.detections = Some(CompletionDetections { + output: vec![CompletionOutputDetections { + choice_index, + results: detections.into(), + }], + ..Default::default() + }); + Ok(completion) + } else { + error!( + %choice_index, + %chunk.input_start_index, + %chunk.input_end_index, + "no completions found for chunk" + ); + Err(Error::Other("no completions found for chunk".into())) + } +} + +/// Combines logprobs from completion chunks to a single [`CompletionLogprobs`]. +fn merge_logprobs(completions: &[Completion]) -> Option { + let mut merged_logprobs = CompletionLogprobs::default(); + for completion in completions { + if let Some(choice) = completion.choices.first() { + if let Some(logprobs) = &choice.logprobs { + merged_logprobs.tokens.extend_from_slice(&logprobs.tokens); + merged_logprobs + .token_logprobs + .extend_from_slice(&logprobs.token_logprobs); + merged_logprobs + .top_logprobs + .extend_from_slice(&logprobs.top_logprobs); + merged_logprobs + .text_offset + .extend_from_slice(&logprobs.text_offset); + } + } + } + (!merged_logprobs.tokens.is_empty() + || !merged_logprobs.token_logprobs.is_empty() + || !merged_logprobs.top_logprobs.is_empty() + || !merged_logprobs.text_offset.is_empty()) + .then_some(merged_logprobs) +} + +/// Consumes a detection batch stream, builds responses, and sends them to a response channel. 
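+///
+/// For each (choice_index, chunk, detections) batch, a completion chunk is built via
+/// `output_detection_response` and sent to the response channel. Once the last content
+/// chunk for a choice has been processed, the buffered chunk carrying the finish reason
+/// is also forwarded. Errors terminate processing after being sent to the response
+/// channel.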
+async fn process_detection_batch_stream( + trace_id: TraceId, + completion_state: Arc>, + mut detection_batch_stream: DetectionBatchStream, + response_tx: mpsc::Sender, Error>>, +) { + while let Some(result) = detection_batch_stream.next().await { + match result { + Ok((choice_index, chunk, detections)) => { + let input_end_index = chunk.input_end_index; + match output_detection_response(&completion_state, choice_index, chunk, detections) + { + Ok(completion) => { + // Send completion to response channel + debug!(%trace_id, %choice_index, ?completion, "sending completion chunk to response channel"); + if response_tx.send(Ok(Some(completion))).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + // If this is the final completion chunk with content, send completion chunk with finish reason + let completions = completion_state.completions.get(&choice_index).unwrap(); + if completions.keys().rev().nth(1) == Some(&input_end_index) { + if let Some((_, completion)) = completions.last_key_value() { + if completion + .choices + .first() + .is_some_and(|choice| choice.finish_reason.is_some()) + { + debug!(%trace_id, %choice_index, ?completion, "sending completion chunk with finish reason to response channel"); + let _ = response_tx.send(Ok(Some(completion.clone()))).await; + } + } + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error building output detection response"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection batch stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection batch stream closed"); +} diff --git a/src/orchestrator/handlers/completions_detection/unary.rs b/src/orchestrator/handlers/completions_detection/unary.rs new file mode 100644 index 00000000..9af632f1 --- /dev/null +++ b/src/orchestrator/handlers/completions_detection/unary.rs @@ -0,0 +1,217 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +*/ +use std::{collections::HashMap, sync::Arc}; + +use futures::future::try_join_all; +use tracing::{Instrument, error, info, instrument}; +use uuid::Uuid; + +use super::CompletionsDetectionTask; +use crate::{ + clients::openai::*, + config::DetectorType, + models::{ + DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE, UNSUITABLE_OUTPUT_MESSAGE, + }, + orchestrator::{ + Context, Error, + common::{self, validate_detectors}, + }, +}; + +pub async fn handle_unary( + ctx: Arc, + task: CompletionsDetectionTask, +) -> Result { + let trace_id = task.trace_id; + let detectors = task.request.detectors.clone(); + info!(%trace_id, config = ?detectors, "task started"); + let input_detectors = detectors.input; + let output_detectors = detectors.output; + + validate_detectors( + input_detectors.iter().chain(output_detectors.iter()), + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + )?; + + if !input_detectors.is_empty() { + // Handle input detection + match handle_input_detection(ctx.clone(), &task, input_detectors).await { + Ok(Some(completion)) => { + info!(%trace_id, "task completed: returning response with input detections"); + // Return response with input detections and terminate + let response = completion.into(); + return Ok(response); + } + Ok(None) => (), // No input detections + Err(error) => { + // Input detections failed + return Err(error); + } + } + } + + // Handle completion + let client = ctx.clients.get::("openai").unwrap(); + let completion = + match common::completion(client, task.headers.clone(), task.request.clone()).await { + Ok(CompletionsResponse::Unary(completion)) => *completion, + Ok(CompletionsResponse::Streaming(_)) => unimplemented!(), + Err(error) => return Err(error), + }; + + if !output_detectors.is_empty() { + // Handle output detection + let completion = + handle_output_detection(ctx.clone(), task, output_detectors, completion).await?; + Ok(completion.into()) + } else { + // No output detectors, send completion response + Ok(completion.into()) + } +} + +#[instrument(skip_all)] +async fn handle_input_detection( + ctx: Arc, + task: &CompletionsDetectionTask, + detectors: HashMap, +) -> Result, Error> { + let trace_id = task.trace_id; + let model_id = task.request.model.clone(); + let inputs = common::apply_masks( + task.request.prompt.clone(), + task.request.prompt_masks.as_deref(), + ); + let detections = match common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + 0, + inputs, + ) + .await + { + Ok((_, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing input detections"); + return Err(error); + } + }; + if !detections.is_empty() { + // Get prompt tokens for usage + let client = ctx.clients.get::("openai").unwrap(); + let tokenize_request = TokenizeRequest { + model: model_id.clone(), + prompt: Some(task.request.prompt.clone()), + ..Default::default() + }; + let tokenize_response = + common::tokenize_openai(client, task.headers.clone(), tokenize_request).await?; + let usage = Usage { + prompt_tokens: tokenize_response.count, + ..Default::default() + }; + + // Build completion with input detections + let completion = Completion { + id: Uuid::new_v4().simple().to_string(), + object: "text_completion".into(), // This value is constant: https://platform.openai.com/docs/api-reference/completions/object#completions/object-object + created: common::current_timestamp().as_secs() as i64, + model: model_id, + detections: Some(CompletionDetections { + 
input: vec![CompletionInputDetections { + message_index: 0, + results: detections.into(), + }], + ..Default::default() + }), + warnings: vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableInput, + UNSUITABLE_INPUT_MESSAGE, + )], + usage: Some(usage), + ..Default::default() + }; + Ok(Some(completion)) + } else { + // No input detections + Ok(None) + } +} + +#[instrument(skip_all)] +async fn handle_output_detection( + ctx: Arc, + task: CompletionsDetectionTask, + detectors: HashMap, + mut completion: Completion, +) -> Result { + let mut tasks = Vec::with_capacity(completion.choices.len()); + for choice in &completion.choices { + if choice.text.is_empty() { + completion.warnings.push(CompletionDetectionWarning::new( + DetectionWarningReason::EmptyOutput, + &format!( + "Choice of index {} has no content. Output detection was not executed", + choice.index + ), + )); + continue; + } + let input_id = choice.index; + let input_text = choice.text.clone(); + tasks.push(tokio::spawn( + common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + input_id, + vec![(0, input_text)], + ) + .in_current_span(), + )); + } + let detections = try_join_all(tasks) + .await? + .into_iter() + .collect::, Error>>()?; + if !detections.is_empty() { + // Update completion with detections + let output = detections + .into_iter() + .filter(|(_, detections)| !detections.is_empty()) + .map(|(input_id, detections)| CompletionOutputDetections { + choice_index: input_id, + results: detections.into(), + }) + .collect::>(); + if !output.is_empty() { + completion.detections = Some(CompletionDetections { + output, + ..Default::default() + }); + completion.warnings = vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableOutput, + UNSUITABLE_OUTPUT_MESSAGE, + )]; + } + } + Ok(completion) +} diff --git a/src/orchestrator/handlers/generation_with_detection.rs b/src/orchestrator/handlers/generation_with_detection.rs index d49ac76c..b5690e9f 100644 --- a/src/orchestrator/handlers/generation_with_detection.rs +++ b/src/orchestrator/handlers/generation_with_detection.rs @@ -55,10 +55,7 @@ impl Handle for Orchestrator { )?; // Handle generation - let client = ctx - .clients - .get_as::("generation") - .unwrap(); + let client = ctx.clients.get::("generation").unwrap(); let generation = common::generate( client, task.headers.clone(), diff --git a/src/orchestrator/handlers/streaming_classification_with_gen.rs b/src/orchestrator/handlers/streaming_classification_with_gen.rs index d1d60214..d031ca65 100644 --- a/src/orchestrator/handlers/streaming_classification_with_gen.rs +++ b/src/orchestrator/handlers/streaming_classification_with_gen.rs @@ -124,7 +124,7 @@ impl Handle for Orchestrator { // Create generation stream let client = ctx .clients - .get_as::("generation") + .get::("generation") .unwrap(); let generation_stream = match common::generate_stream( client, @@ -189,10 +189,7 @@ async fn handle_input_detection( }; if !detections.is_empty() { // Get token count - let client = ctx - .clients - .get_as::("generation") - .unwrap(); + let client = ctx.clients.get::("generation").unwrap(); let input_token_count = match common::tokenize( client, task.headers.clone(), diff --git a/src/orchestrator/types.rs b/src/orchestrator/types.rs index 089bc74f..4d690877 100644 --- a/src/orchestrator/types.rs +++ b/src/orchestrator/types.rs @@ -28,6 +28,8 @@ pub mod detection_batcher; pub use detection_batcher::*; pub mod detection_batch_stream; pub use 
detection_batch_stream::*;
+pub mod completion_state;
+pub use completion_state::*;
 
 use super::Error;
 use crate::{
diff --git a/src/orchestrator/types/completion_state.rs b/src/orchestrator/types/completion_state.rs
new file mode 100644
index 00000000..cd22f0ba
--- /dev/null
+++ b/src/orchestrator/types/completion_state.rs
@@ -0,0 +1,80 @@
+use std::{collections::BTreeMap, sync::OnceLock};
+
+use dashmap::DashMap;
+
+use super::ChoiceIndex;
+use crate::clients::openai::Usage;
+
+/// Completion state for a streaming completions task.
+#[derive(Debug, Default)]
+pub struct CompletionState<T> {
+    /// Completion metadata.
+    pub metadata: OnceLock<CompletionMetadata>,
+    /// Completion chunks received for each choice.
+    pub completions: DashMap<ChoiceIndex, BTreeMap<usize, T>>,
+    /// Completion usage statistics.
+    pub usage: OnceLock<Usage>,
+}
+
+impl<T> CompletionState<T>
+where
+    T: Default,
+{
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Sets metadata.
+    pub fn set_metadata(&self, id: String, created: i64, model: String) {
+        let _ = self.metadata.set(CompletionMetadata { id, created, model });
+    }
+
+    /// Sets usage.
+    pub fn set_usage(&self, usage: Usage) {
+        let _ = self.usage.set(usage);
+    }
+
+    /// Inserts a completion.
+    pub fn insert_completion(
+        &self,
+        choice_index: ChoiceIndex,
+        message_index: usize,
+        completion: T,
+    ) {
+        match self.completions.entry(choice_index) {
+            dashmap::Entry::Occupied(mut entry) => {
+                entry.get_mut().insert(message_index, completion);
+            }
+            dashmap::Entry::Vacant(entry) => {
+                entry.insert(BTreeMap::from([(message_index, completion)]));
+            }
+        }
+    }
+
+    pub fn id(&self) -> Option<&str> {
+        self.metadata.get().map(|v| v.id.as_ref())
+    }
+
+    pub fn created(&self) -> Option<i64> {
+        self.metadata.get().map(|v| v.created)
+    }
+
+    pub fn model(&self) -> Option<&str> {
+        self.metadata.get().map(|v| v.model.as_ref())
+    }
+
+    pub fn usage(&self) -> Option<&Usage> {
+        self.usage.get()
+    }
+}
+
+/// Completion metadata common to all chunks.
+#[derive(Debug, Default)]
+pub struct CompletionMetadata {
+    /// A unique identifier for the completion.
+    pub id: String,
+    /// The Unix timestamp (in seconds) of when the completion was created.
+    pub created: i64,
+    /// The model to generate the completion.
+    pub model: String,
+}
diff --git a/src/orchestrator/types/detection_batcher.rs b/src/orchestrator/types/detection_batcher.rs
index 71b0be69..63403d3c 100644
--- a/src/orchestrator/types/detection_batcher.rs
+++ b/src/orchestrator/types/detection_batcher.rs
@@ -14,8 +14,8 @@
 limitations under the License.
 */
-pub mod chat_completion;
-pub use chat_completion::*;
+pub mod completion;
+pub use completion::*;
 pub mod max_processed_index;
 pub use max_processed_index::*;
diff --git a/src/orchestrator/types/detection_batcher/chat_completion.rs b/src/orchestrator/types/detection_batcher/completion.rs
similarity index 97%
rename from src/orchestrator/types/detection_batcher/chat_completion.rs
rename to src/orchestrator/types/detection_batcher/completion.rs
index 5a8fc19a..988d1c29 100644
--- a/src/orchestrator/types/detection_batcher/chat_completion.rs
+++ b/src/orchestrator/types/detection_batcher/completion.rs
@@ -20,7 +20,7 @@ use super::{Batch, Chunk, DetectionBatcher, Detections};
 pub type ChoiceIndex = u32;
-/// A batcher for chat completions.
+/// A batcher for completions.
 ///
 /// A batch corresponds to a choice-chunk (where each chunk is associated
 /// with a particular choice through a ChoiceIndex).
Batches are returned @@ -36,14 +36,14 @@ pub type ChoiceIndex = u32; /// /// This batcher requires that all detectors use the same chunker. #[derive(Debug, Clone)] -pub struct ChatCompletionBatcher { +pub struct CompletionBatcher { n_detectors: usize, // We place the chunk first since chunk ordering includes where // the chunk is in all the processed messages. state: BTreeMap<(Chunk, ChoiceIndex), Vec>, } -impl ChatCompletionBatcher { +impl CompletionBatcher { pub fn new(n_detectors: usize) -> Self { Self { n_detectors, @@ -52,7 +52,7 @@ impl ChatCompletionBatcher { } } -impl DetectionBatcher for ChatCompletionBatcher { +impl DetectionBatcher for CompletionBatcher { fn push(&mut self, choice_index: ChoiceIndex, chunk: Chunk, detections: Detections) { match self.state.entry((chunk, choice_index)) { btree_map::Entry::Vacant(entry) => { @@ -118,7 +118,7 @@ mod test { // Create a batcher that will process batches for 2 detectors let n_detectors = 2; - let mut batcher = ChatCompletionBatcher::new(n_detectors); + let mut batcher = CompletionBatcher::new(n_detectors); // Push chunk detections for pii detector batcher.push( @@ -203,7 +203,7 @@ mod test { // Create a batcher that will process batches for 2 detectors let n_detectors = 2; - let mut batcher = ChatCompletionBatcher::new(n_detectors); + let mut batcher = CompletionBatcher::new(n_detectors); for choice_index in 0..choices { // Push chunk-2 detections for pii detector @@ -326,7 +326,7 @@ mod test { // Create a batcher that will process batches for 2 detectors let n_detectors = 2; - let mut batcher = ChatCompletionBatcher::new(n_detectors); + let mut batcher = CompletionBatcher::new(n_detectors); // Intersperse choice detections // NOTE: There may be an edge case when chunk-2 (or later) detections are pushed @@ -465,7 +465,7 @@ mod test { // Create a batcher that will process batches for 2 detectors let n_detectors = 2; - let batcher = ChatCompletionBatcher::new(n_detectors); + let batcher = CompletionBatcher::new(n_detectors); // Create detection batch stream let streams = vec![pii_detections_stream, hap_detections_stream]; diff --git a/src/server.rs b/src/server.rs index f9336dc5..4c844a62 100644 --- a/src/server.rs +++ b/src/server.rs @@ -148,8 +148,11 @@ mod tests { Orchestrator::default(), ) .await; - assert!(result.is_err_and(|error| matches!(error, Error::IoError(_)) - && error.to_string().starts_with("Address already in use"))); + assert!(result.is_err_and(|error| { + error + .to_string() + .starts_with("io error: Address already in use") + })); Ok(()) } diff --git a/src/server/errors.rs b/src/server/errors.rs index b5b54b32..48a2fb99 100644 --- a/src/server/errors.rs +++ b/src/server/errors.rs @@ -22,111 +22,113 @@ use axum::{ response::{IntoResponse, Response}, }; use http::StatusCode; +use serde::{Deserialize, Serialize}; use crate::{models::ValidationError, orchestrator}; /// High-level errors to return to clients. 
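// Example (illustrative sketch): the struct below keeps the wire format previously
// produced by `to_json()`, assuming `http_serde::status_code` serializes the status
// as its numeric code:
//
//     let error = Error {
//         code: StatusCode::UNPROCESSABLE_ENTITY,
//         details: "invalid request".into(),
//     };
//     assert_eq!(
//         serde_json::to_value(&error).unwrap(),
//         serde_json::json!({ "code": 422, "details": "invalid request" })
//     );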
-#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("{0}")] - Validation(String), - #[error("{0}")] - NotFound(String), - #[error("{0}")] - ServiceUnavailable(String), - #[error("unexpected error occurred while processing request")] - Unexpected, - #[error(transparent)] - JsonExtractorRejection(#[from] JsonRejection), - #[error("{0}")] - JsonError(String), - #[error("unsupported content type: {0}")] - UnsupportedContentType(String), - #[error(transparent)] - IoError(#[from] std::io::Error), +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Error { + #[serde(with = "http_serde::status_code")] + pub code: StatusCode, + pub details: String, } -impl From for Error { - fn from(value: orchestrator::Error) -> Self { - use orchestrator::Error::*; - match value { - DetectorNotFound(_) | ChunkerNotFound(_) => Self::NotFound(value.to_string()), - DetectorRequestFailed { ref error, .. } - | ChunkerRequestFailed { ref error, .. } - | GenerateRequestFailed { ref error, .. } - | ChatCompletionRequestFailed { ref error, .. } - | TokenizeRequestFailed { ref error, .. } => match error.status_code() { - StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY => { - Self::Validation(value.to_string()) - } - StatusCode::NOT_FOUND => Self::NotFound(value.to_string()), - StatusCode::SERVICE_UNAVAILABLE => Self::ServiceUnavailable(value.to_string()), - _ => Self::Unexpected, - }, - JsonError(message) => Self::JsonError(message), - Validation(message) => Self::Validation(message), - _ => Self::Unexpected, - } +impl Error { + pub fn code(&self) -> &StatusCode { + &self.code + } + + pub fn details(&self) -> &str { + &self.details } } -impl Error { - pub fn to_json(self) -> serde_json::Value { - use Error::*; - let (code, message) = match self { - Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), - ServiceUnavailable(_) => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), - UnsupportedContentType(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()), - Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), - JsonExtractorRejection(json_rejection) => match json_rejection { - JsonRejection::JsonDataError(e) => { - // Get lower-level serde error message - let message = e.source().map(|e| e.to_string()).unwrap_or_default(); - (e.status(), message) - } - _ => (json_rejection.status(), json_rejection.body_text()), - }, - JsonError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - IoError(error) => (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()), - }; - serde_json::json!({ - "code": code.as_u16(), - "details": message, - }) +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.details()) } } +impl std::error::Error for Error {} + impl IntoResponse for Error { fn into_response(self) -> Response { - use Error::*; - let (code, message) = match self { - Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()), - ServiceUnavailable(_) => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), - UnsupportedContentType(_) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string()), - Unexpected => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), - JsonExtractorRejection(json_rejection) => match json_rejection { - JsonRejection::JsonDataError(e) => { - // Get lower-level serde error message - let message = e.source().map(|e| 
e.to_string()).unwrap_or_default(); - (e.status(), message) - } - _ => (json_rejection.status(), json_rejection.body_text()), - }, - JsonError(_) => (StatusCode::UNPROCESSABLE_ENTITY, self.to_string()), - IoError(error) => (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()), + let code = *self.code(); + (code, Json(self)).into_response() + } +} + +impl From for Error { + fn from(value: JsonRejection) -> Self { + use JsonRejection::*; + let code = value.status(); + let details = match value { + JsonDataError(error) => { + // Get lower-level serde error message + error.source().map(|e| e.to_string()).unwrap_or_default() + } + _ => value.body_text(), }; - let error = serde_json::json!({ - "code": code.as_u16(), - "details": message, - }); - (code, Json(error)).into_response() + Self { code, details } + } +} + +impl From for Error { + fn from(value: std::io::Error) -> Self { + Self { + code: StatusCode::INTERNAL_SERVER_ERROR, + details: format!("io error: {value}"), + } } } impl From for Error { fn from(value: ValidationError) -> Self { - Self::Validation(value.to_string()) + Self { + code: StatusCode::UNPROCESSABLE_ENTITY, + details: value.to_string(), + } + } +} + +impl From for Error { + fn from(value: orchestrator::Error) -> Self { + use orchestrator::Error::*; + match value { + DetectorNotFound(_) | ChunkerNotFound(_) => Self { + code: StatusCode::NOT_FOUND, + details: value.to_string(), + }, + DetectorRequestFailed { ref error, .. } + | ChunkerRequestFailed { ref error, .. } + | GenerateRequestFailed { ref error, .. } + | ChatCompletionRequestFailed { ref error, .. } + | CompletionRequestFailed { ref error, .. } + | TokenizeRequestFailed { ref error, .. } + | Client(ref error) => match error.status_code() { + // return actual error for subset of errors + StatusCode::BAD_REQUEST + | StatusCode::UNPROCESSABLE_ENTITY + | StatusCode::NOT_FOUND + | StatusCode::SERVICE_UNAVAILABLE => Self { + code: error.status_code(), + details: value.to_string(), + }, + // return generic error for other errors + _ => Self { + code: StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }, + }, + JsonError(message) | Validation(message) => Self { + code: StatusCode::UNPROCESSABLE_ENTITY, + details: message, + }, + _ => Self { + code: StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }, + } } } diff --git a/src/server/routes.rs b/src/server/routes.rs index 55ee5331..07cba593 100644 --- a/src/server/routes.rs +++ b/src/server/routes.rs @@ -41,11 +41,16 @@ use tracing::info; use super::{Error, ServerState}; use crate::{ - clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse}, + clients::openai::{ + ChatCompletionsRequest, ChatCompletionsResponse, CompletionsRequest, CompletionsResponse, + }, models::{self, InfoParams, InfoResponse, StreamingContentDetectionRequest}, orchestrator::{ self, - handlers::{chat_completions_detection::ChatCompletionsDetectionTask, *}, + handlers::{ + chat_completions_detection::ChatCompletionsDetectionTask, + completions_detection::CompletionsDetectionTask, *, + }, }, utils::{self, trace::current_trace_id}, }; @@ -89,12 +94,18 @@ pub fn guardrails_router(state: Arc) -> Router { post(detect_context_documents), ) .route("/api/v2/text/detection/generated", post(detect_generated)); - if state.orchestrator.config().chat_generation.is_some() { + if state.orchestrator.config().openai.is_some() { info!("Enabling chat completions detection endpoint"); router = 
router.route( "/api/v2/chat/completions-detection", post(chat_completions_detection), ); + + info!("Enabling completions detection endpoint"); + router = router.route( + "/api/v2/text/completions-detection", + post(completions_detection), + ); } router.with_state(state) } @@ -161,7 +172,7 @@ async fn stream_classification_with_gen( return Sse::new( stream::iter([Ok(Event::default() .event("error") - .json_data(error.to_json()) + .json_data(error) .unwrap())]) .boxed(), ); @@ -178,10 +189,7 @@ async fn stream_classification_with_gen( .unwrap()), Err(error) => { let error: Error = error.into(); - Ok(Event::default() - .event("error") - .json_data(error.to_json()) - .unwrap()) + Ok(Event::default().event("error").json_data(error).unwrap()) } }) .boxed(); @@ -202,9 +210,10 @@ async fn stream_content_detection( match content_type { Some(content_type) if content_type.starts_with("application/x-ndjson") => (), _ => { - return Err(Error::UnsupportedContentType( - "expected application/x-ndjson".into(), - )); + return Err(Error { + code: http::StatusCode::UNSUPPORTED_MEDIA_TYPE, + details: "expected application/x-ndjson".into(), + }); } }; let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); @@ -242,8 +251,7 @@ async fn stream_content_detection( Err(error) => { // Convert orchestrator::Error to server::Error let error: Error = error.into(); - // server::Error doesn't impl Serialize, so we use to_json() - let error_msg = utils::json::to_nd_string(&error.to_json()).unwrap(); + let error_msg = utils::json::to_nd_string(&error).unwrap(); let _ = output_tx.send(Ok(error_msg)).await; } } @@ -344,10 +352,44 @@ async fn chat_completions_detection( } Err(error) => { let error: Error = error.into(); - Ok(Event::default() - .event("error") - .json_data(error.to_json()) - .unwrap()) + Ok(Event::default().event("error").json_data(error).unwrap()) + } + }) + .boxed(); + let sse = Sse::new(event_stream).keep_alive(KeepAlive::default()); + Ok(sse.into_response()) + } + }, + Err(error) => Err(error.into()), + } +} + +async fn completions_detection( + State(state): State>, + headers: HeaderMap, + WithRejection(Json(request), _): WithRejection, Error>, +) -> Result { + use CompletionsResponse::*; + let trace_id = current_trace_id(); + request.validate()?; + let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers); + let task = CompletionsDetectionTask::new(trace_id, request, headers); + match state.orchestrator.handle(task).await { + Ok(response) => match response { + Unary(response) => Ok(Json(response).into_response()), + Streaming(response_rx) => { + let response_stream = ReceiverStream::new(response_rx); + // Convert response stream to a stream of SSE events + let event_stream: BoxStream> = response_stream + .map(|message| match message { + Ok(Some(chunk)) => Ok(Event::default().json_data(chunk).unwrap()), + Ok(None) => { + // The stream completed, send [DONE] message + Ok(Event::default().data("[DONE]")) + } + Err(error) => { + let error: Error = error.into(); + Ok(Event::default().event("error").json_data(error).unwrap()) } }) .boxed(); diff --git a/src/server/tls.rs b/src/server/tls.rs index 079262d6..5a51c7d3 100644 --- a/src/server/tls.rs +++ b/src/server/tls.rs @@ -20,11 +20,11 @@ use axum::{Router, extract::Request}; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; use rustls::{RootCertStore, ServerConfig, server::WebPkiClientVerifier}; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use 
tokio::net::TcpListener; use tokio_rustls::TlsAcceptor; use tower::Service; use tracing::{debug, error, info, warn}; -use webpki::types::{CertificateDer, PrivateKeyDer}; /// Loads certificates and configures TLS. pub fn configure_tls( @@ -43,14 +43,12 @@ pub fn configure_tls( for client_cert in client_certs { client_auth_certs .add(client_cert.clone()) - .unwrap_or_else(|e| { - panic!("error adding client cert {:?}: {}", client_cert, e) - }); + .unwrap_or_else(|e| panic!("error adding client cert {client_cert:?}: {e}")); } info!("mTLS enabled"); WebPkiClientVerifier::builder(client_auth_certs.into()) .build() - .unwrap_or_else(|e| panic!("error building client verifier: {}", e)) + .unwrap_or_else(|e| panic!("error building client verifier: {e}")) } else { info!("TLS enabled"); WebPkiClientVerifier::no_client_auth() @@ -165,8 +163,5 @@ fn load_private_key(filename: &PathBuf) -> PrivateKeyDer<'static> { _ => {} } } - panic!( - "no keys found in {:?} (encrypted keys not supported)", - filename - ); + panic!("no keys found in {filename:?} (encrypted keys not supported)"); } diff --git a/src/utils.rs b/src/utils.rs index 7fd86dc2..02d2b623 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,5 @@ use hyper::Uri; +use serde::{Deserialize, Deserializer, de::DeserializeOwned}; use url::Url; pub mod json; pub mod tls; @@ -14,3 +15,22 @@ impl AsUriExt for Url { Uri::try_from(self.to_string()).unwrap() } } + +/// Serde helper to deserialize one or many [`T`] to [`Vec`]. +pub fn one_or_many<'de, T, D>(deserializer: D) -> Result, D::Error> +where + T: DeserializeOwned, + D: Deserializer<'de>, +{ + #[derive(Deserialize)] + #[serde(untagged)] + enum OneOrMany { + One(T), + Many(Vec), + } + let v: OneOrMany = Deserialize::deserialize(deserializer)?; + match v { + OneOrMany::One(value) => Ok(vec![value]), + OneOrMany::Many(values) => Ok(values), + } +} diff --git a/src/utils/tls.rs b/src/utils/tls.rs index 27928155..94834384 100644 --- a/src/utils/tls.rs +++ b/src/utils/tls.rs @@ -31,7 +31,7 @@ impl Error { pub fn into_client_error(self) -> clients::Error { clients::Error::Http { code: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("client TLS configuration failed: {}", self), + message: format!("client TLS configuration failed: {self}"), } } } diff --git a/src/utils/trace.rs b/src/utils/trace.rs index 334f1198..fcd59f1b 100644 --- a/src/utils/trace.rs +++ b/src/utils/trace.rs @@ -17,18 +17,18 @@ use std::time::Duration; +use anyhow::Context; use axum::{extract::Request, http::HeaderMap, response::Response}; use opentelemetry::{ - KeyValue, global, - trace::{TraceContextExt, TraceError, TraceId, TracerProvider}, + global, + trace::{TraceContextExt, TraceId, TracerProvider}, }; use opentelemetry_http::{HeaderExtractor, HeaderInjector}; -use opentelemetry_otlp::{MetricExporter, SpanExporter, WithExportConfig, WithHttpConfig}; +use opentelemetry_otlp::{MetricExporter, SpanExporter, WithExportConfig}; use opentelemetry_sdk::{ Resource, - metrics::{MetricError, PeriodicReader, SdkMeterProvider}, + metrics::{PeriodicReader, SdkMeterProvider}, propagation::TraceContextPropagator, - runtime, trace::Sampler, }; use tracing::{Span, error, info, info_span}; @@ -40,26 +40,20 @@ use crate::{ clients::http::TracedResponse, }; -#[derive(Debug, thiserror::Error)] -pub enum TracingError { - #[error("Error from tracing provider: {0}")] - TraceError(#[from] TraceError), - #[error("Error from metrics provider: {0}")] - MetricError(#[from] MetricError), -} +pub const DEFAULT_GRPC_OTLP_ENDPOINT: &str = 
"http://localhost:4317"; +pub const DEFAULT_HTTP_OTLP_ENDPOINT: &str = "http://localhost:4318"; fn resource(tracing_config: TracingConfig) -> Resource { - Resource::new(vec![KeyValue::new( - "service.name", - tracing_config.service_name, - )]) + Resource::builder() + .with_service_name(tracing_config.service_name) + .build() } /// Initializes an OpenTelemetry tracer provider with an OTLP export pipeline based on the /// provided config. fn init_tracer_provider( tracing_config: TracingConfig, -) -> Result, TracingError> { +) -> Result, anyhow::Error> { if let Some((protocol, endpoint)) = tracing_config.clone().traces { let timeout = Duration::from_secs(3); let exporter = match protocol { @@ -67,17 +61,18 @@ fn init_tracer_provider( .with_tonic() .with_endpoint(endpoint) .with_timeout(timeout) - .build()?, + .build() + .context("Failed to build gRPC span exporter")?, OtlpProtocol::Http => SpanExporter::builder() .with_http() - .with_http_client(reqwest::Client::new()) .with_endpoint(endpoint) .with_timeout(timeout) - .build()?, + .build() + .context("Failed to build HTTP span exporter")?, }; Ok(Some( - opentelemetry_sdk::trace::TracerProvider::builder() - .with_batch_exporter(exporter, runtime::Tokio) + opentelemetry_sdk::trace::SdkTracerProvider::builder() + .with_batch_exporter(exporter) .with_resource(resource(tracing_config)) .with_sampler(Sampler::AlwaysOn) .build(), @@ -86,7 +81,7 @@ fn init_tracer_provider( // We still need a tracing provider as long as we are logging in order to enable any // trace-sensitive logs, such as any mentions of a request's trace_id. Ok(Some( - opentelemetry_sdk::trace::TracerProvider::builder() + opentelemetry_sdk::trace::SdkTracerProvider::builder() .with_resource(resource(tracing_config)) .with_sampler(Sampler::AlwaysOn) .build(), @@ -100,7 +95,7 @@ fn init_tracer_provider( /// provided config. fn init_meter_provider( tracing_config: TracingConfig, -) -> Result, TracingError> { +) -> Result, anyhow::Error> { if let Some((protocol, endpoint)) = tracing_config.clone().metrics { // Note: DefaultAggregationSelector removed from OpenTelemetry SDK as of 0.26.0 // as custom aggregation should be available in Views. Cumulative temporality is default. @@ -110,15 +105,16 @@ fn init_meter_provider( .with_tonic() .with_endpoint(endpoint) .with_timeout(timeout) - .build()?, + .build() + .context("Failed to build OTel gRPC metric exporter")?, OtlpProtocol::Http => MetricExporter::builder() .with_http() - .with_http_client(reqwest::Client::new()) .with_endpoint(endpoint) .with_timeout(timeout) - .build()?, + .build() + .context("Failed to build OTel HTTP metric exporter")?, }; - let reader = PeriodicReader::builder(exporter, runtime::Tokio) + let reader = PeriodicReader::builder(exporter) .with_interval(Duration::from_secs(3)) .build(); Ok(Some( @@ -136,7 +132,7 @@ fn init_meter_provider( /// crate. 
What telemetry is exported and to where is determined based on the provided config pub fn init_tracing( tracing_config: TracingConfig, -) -> Result Result<(), TracingError>, TracingError> { +) -> Result Result<(), anyhow::Error>, anyhow::Error> { let mut layers = Vec::new(); global::set_text_map_propagator(TraceContextPropagator::new()); @@ -153,7 +149,8 @@ pub fn init_tracing( .add_directive("reqwest=error".parse().unwrap()); // Set up tracing layer with OTLP exporter - let trace_provider = init_tracer_provider(tracing_config.clone())?; + let trace_provider = init_tracer_provider(tracing_config.clone()) + .context("Failed to initialize tracer provider")?; if let Some(tracer_provider) = trace_provider.clone() { global::set_tracer_provider(tracer_provider.clone()); layers.push( @@ -164,7 +161,8 @@ pub fn init_tracing( } // Set up metrics layer with OTLP exporter - let meter_provider = init_meter_provider(tracing_config.clone())?; + let meter_provider = init_meter_provider(tracing_config.clone()) + .context("Failed to initialize meter provider")?; if let Some(meter_provider) = meter_provider.clone() { global::set_meter_provider(meter_provider.clone()); layers.push(MetricsLayer::new(meter_provider).boxed()); @@ -190,39 +188,33 @@ pub fn init_tracing( let subscriber = tracing_subscriber::registry().with(filter).with(layers); tracing::subscriber::set_global_default(subscriber).unwrap(); - if let Some(traces) = tracing_config.traces { - info!( - "OTLP tracing enabled: Exporting {} to {}", - traces.0, traces.1 - ); + if let Some((protocol, endpoint)) = tracing_config.traces { + info!(%protocol, %endpoint, "OTLP traces enabled"); } else { - info!("OTLP traces export disabled") + info!("OTLP traces disabled") } - - if let Some(metrics) = tracing_config.metrics { - info!( - "OTLP metrics enabled: Exporting {} to {}", - metrics.0, metrics.1 - ); + if let Some((protocol, endpoint)) = tracing_config.metrics { + info!(%protocol, %endpoint, "OTLP metrics enabled"); } else { - info!("OTLP metrics export disabled") + info!("OTLP metrics disabled") } if !tracing_config.quiet { - info!( - "Stdout logging enabled with format {}", - tracing_config.log_format - ); + info!(format = %tracing_config.log_format, "stdout logging enabled"); } else { - info!("Stdout logging disabled"); // This will only be visible in traces + info!("stdout logging disabled"); // This will only be visible in traces } Ok(move || { - global::shutdown_tracer_provider(); + if let Some(trace_provider) = trace_provider { + trace_provider + .shutdown() + .context("Failed to shutdown tracer provider")?; + } if let Some(meter_provider) = meter_provider { meter_provider .shutdown() - .map_err(TracingError::MetricError)?; + .context("Failed to shutdown meter provider")?; } Ok(()) }) diff --git a/tests/chat_completions_streaming.rs b/tests/chat_completions_streaming.rs new file mode 100644 index 00000000..3838b4c2 --- /dev/null +++ b/tests/chat_completions_streaming.rs @@ -0,0 +1,5097 @@ +pub mod common; +use common::orchestrator::*; +use fms_guardrails_orchestr8::{ + clients::{ + detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + openai::{ + ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionDelta, + ChatCompletionLogprob, ChatCompletionLogprobs, CompletionDetections, + CompletionInputDetections, CompletionOutputDetections, Content, Message, OpenAiError, + OpenAiErrorMessage, Role, TokenizeResponse, Usage, + }, + }, + models::DetectorParams, + pb::{ + caikit::runtime::chunkers::{ + 
BidiStreamingChunkerTokenizationTaskRequest, ChunkerTokenizationTaskRequest, + }, + caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token, TokenizationResults}, + }, +}; +use futures::{StreamExt, TryStreamExt}; +use mocktail::prelude::*; +use serde_json::json; +use test_log::test; +use tracing::debug; + +use crate::common::{ + chunker::{CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_STREAMING_ENDPOINT, CHUNKER_UNARY_ENDPOINT}, + detectors::{PII_DETECTOR_SENTENCE, PII_DETECTOR_WHOLE_DOC, TEXT_CONTENTS_DETECTOR_ENDPOINT}, + openai::{CHAT_COMPLETIONS_ENDPOINT, TOKENIZE_ENDPOINT}, + sse, +}; + +#[test(tokio::test)] +async fn no_detectors() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when + .post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::Assistant, content: Some(Content::Text("You are a helpful assistant.".into())), ..Default::default()}, + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ] + })); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Hey".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("!".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::Assistant, content: Some(Content::Text("You are a helpful assistant.".into())), ..Default::default()}, + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + assert_eq!(messages.len(), 4); + assert_eq!( + messages[0].choices[0].delta.role, + Some(Role::Assistant), + "missing role message" + ); + assert_eq!(messages[1].choices[0].delta.content, Some("Hey".into())); + assert_eq!(messages[2].choices[0].delta.content, 
Some("!".into())); + assert_eq!( + messages[3].choices[0].finish_reason, + Some("stop".into()), + "missing finish reason message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn no_detectors_n2() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ], + "n": 2, + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Hey".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some("Hey".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("!".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "n": 2, + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ] + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = 
sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + assert_eq!(messages.len(), 7); + + // Validate role messages for both choices + assert_eq!(messages[0].choices[0].index, 0); + assert_eq!( + messages[0].choices[0].delta.role, + Some(Role::Assistant), + "choice0: missing role message" + ); + assert_eq!(messages[1].choices[0].index, 1); + assert_eq!( + messages[1].choices[0].delta.role, + Some(Role::Assistant), + "choice1: missing role message" + ); + + // Validate content messages for both choices + assert_eq!(messages[2].choices[0].index, 0); + assert_eq!(messages[2].choices[0].delta.content, Some("Hey".into())); + assert_eq!(messages[3].choices[0].index, 1); + assert_eq!(messages[3].choices[0].delta.content, Some("Hey".into())); + assert_eq!(messages[4].choices[0].index, 0); + assert_eq!(messages[4].choices[0].delta.content, Some("!".into())); + + // Validate stop messages for both choices + assert_eq!(messages[5].choices[0].index, 0); + assert_eq!( + messages[5].choices[0].finish_reason, + Some("stop".into()), + "choice0: missing finish reason message" + ); + assert_eq!(messages[6].choices[0].index, 1); + assert_eq!( + messages[6].choices[0].finish_reason, + Some("stop".into()), + "choice1: missing finish reason message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn input_detectors() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(TOKENIZE_ENDPOINT) + .json(json!({ + "model": "test-0B", + "prompt": "Here is my social security number: 123-45-6789. Can you generate another one like it?", + })); + then.json(&TokenizeResponse { + count: 24, + ..Default::default() + }); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb(ChunkerTokenizationTaskRequest { text: "Here is my social security number: 123-45-6789. 
Can you generate another one like it?".into() }); + then.pb(TokenizationResults { + results: vec![ + Token { start: 0, end: 47, text: "Here is my social security number: 123-45-6789.".into() }, + Token { start: 48, end: 85, text: "Can you generate another one like it?".into() }, + ], + token_count: 0 + }); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec![ + "Here is my social security number: 123-45-6789.".into(), + "Can you generate another one like it?".into(), + ], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 35, + "end": 46, + "detection": "NationalNumber.SocialSecurityNumber.US", + "detection_type": "pii", + "score": 0.8, + "text": "123-45-6789", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": { + "pii_detector_sentence": {} + }, + "output": {} + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Here is my social security number: 123-45-6789. Can you generate another one like it?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + + // Validate length + assert_eq!(messages.len(), 1, "unexpected number of messages"); + + // Validate input detections + assert_eq!( + messages[0].detections, + Some(CompletionDetections { + input: vec![CompletionInputDetections { + message_index: 0, + results: vec![ContentAnalysisResponse { + start: 35, + end: 46, + text: "123-45-6789".into(), + detection: "NationalNumber.SocialSecurityNumber.US".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + output: vec![], + }), + "unexpected input detections" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn output_detectors() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ] + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: 
ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 272-8192\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. 
(617) 985-3519.".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 272-8192\n".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 985-3519.".into(), + input_index_stream: 10, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 32, + text: "Here are 2 random phone numbers:".into(), + }], + token_count: 0, + processed_index: 32, + start_index: 0, + input_start_index: 1, + input_end_index: 8, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 32, + end: 51, + text: "\n\n1. (503) 272-8192".into(), + }], + token_count: 0, + processed_index: 51, + start_index: 32, + input_start_index: 9, + input_end_index: 9, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 51, + end: 70, + text: "\n2. (617) 985-3519.".into(), + }], + token_count: 0, + processed_index: 70, + start_index: 51, + input_start_index: 10, + input_end_index: 10, + }, + ]); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are 2 random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([[]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n1. 
(503) 272-8192".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 5, + "end": 19, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 272-8192", + "evidences": [] + } + ]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n2. (617) 985-3519.".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 4, + "end": 18, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 985-3519", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + // Validate length + assert_eq!(messages.len(), 4, "unexpected number of messages"); + + // Validate msg-0 choices + assert_eq!( + messages[0].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("Here are 2 random phone numbers:".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-0" + ); + // Validate msg-0 detections + assert_eq!( + messages[0].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![], + }], + }), + "unexpected detections for msg-0" + ); + + // Validate msg-1 choices + assert_eq!( + messages[1].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n\n1. (503) 272-8192".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-1" + ); + // Validate msg-2 detections + assert_eq!( + messages[1].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 5, + end: 19, + text: "(503) 272-8192".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-1" + ); + + // Validate msg-2 choices + assert_eq!( + messages[2].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n2. 
(617) 985-3519.".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-2" + ); + // Validate msg-2 detections + assert_eq!( + messages[2].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 4, + end: 18, + text: "(617) 985-3519".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-2" + ); + + // Validate finish reason message + assert_eq!( + messages[3].choices[0].finish_reason, + Some("stop".into()), + "missing finish reason message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn output_detectors_with_logprobs() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + "logprobs": true, + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: "Here".into(), + logprob: -0.021, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: " are".into(), + logprob: -0.011, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: " ".into(), + logprob: -0.001, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), 
+ ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: "2".into(), + logprob: -0.003, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: " random".into(), + logprob: -0.044, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: " phone".into(), + logprob: -0.004, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: " numbers".into(), + logprob: -0.005, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: ":\n\n".into(), + logprob: -0.001, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 272-8192\n".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: "1. (503) 272-8192\n".into(), + logprob: -0.066, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. (617) 985-3519.".into()), + ..Default::default() + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: "2. 
(617) 985-3519.".into(), + logprob: -0.055, + bytes: None, + top_logprobs: None, + }], + ..Default::default() + }), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 272-8192\n".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 985-3519.".into(), + input_index_stream: 10, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 32, + text: "Here are 2 random phone numbers:".into(), + }], + token_count: 0, + processed_index: 32, + start_index: 0, + input_start_index: 1, + input_end_index: 8, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 32, + end: 51, + text: "\n\n1. (503) 272-8192".into(), + }], + token_count: 0, + processed_index: 51, + start_index: 32, + input_start_index: 9, + input_end_index: 9, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 51, + end: 70, + text: "\n2. (617) 985-3519.".into(), + }], + token_count: 0, + processed_index: 70, + start_index: 51, + input_start_index: 10, + input_end_index: 10, + }, + ]); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are 2 random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([[]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n1. 
(503) 272-8192".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 5, + "end": 19, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 272-8192", + "evidences": [] + } + ]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n2. (617) 985-3519.".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 4, + "end": 18, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 985-3519", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "logprobs": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + // Validate length + assert_eq!(messages.len(), 4, "unexpected number of messages"); + + // Validate msg-0 choices + assert_eq!( + messages[0].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("Here are 2 random phone numbers:".into(),), + refusal: None, + tool_calls: vec![], + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ + ChatCompletionLogprob { + token: "Here".into(), + logprob: -0.021, + bytes: None, + top_logprobs: None, + }, + ChatCompletionLogprob { + token: " are".into(), + logprob: -0.011, + bytes: None, + top_logprobs: None, + }, + ChatCompletionLogprob { + token: " ".into(), + logprob: -0.001, + bytes: None, + top_logprobs: None, + }, + ChatCompletionLogprob { + token: "2".into(), + logprob: -0.003, + bytes: None, + top_logprobs: None, + }, + ChatCompletionLogprob { + token: " random".into(), + logprob: -0.044, + bytes: None, + top_logprobs: None, + }, + ChatCompletionLogprob { + token: " phone".into(), + logprob: -0.004, + bytes: None, + top_logprobs: None, + }, + ChatCompletionLogprob { + token: " numbers".into(), + logprob: -0.005, + bytes: None, + top_logprobs: None, + }, + ChatCompletionLogprob { + token: ":\n\n".into(), + logprob: -0.001, + bytes: None, + top_logprobs: None, + }, + ], + refusal: vec![], + },), + ..Default::default() + }], + "unexpected choices for msg-0" + ); + // Validate msg-0 detections + assert_eq!( + messages[0].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![], + }], + }), + "unexpected detections for msg-0" + ); + + // Validate msg-1 choices + assert_eq!( + messages[1].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n\n1. 
(503) 272-8192".into(),), + refusal: None, + tool_calls: vec![], + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: "1. (503) 272-8192\n".into(), + logprob: -0.066, + bytes: None, + top_logprobs: None, + },], + refusal: vec![], + },), + ..Default::default() + }], + "unexpected choices for msg-1" + ); + // Validate msg-1 detections + assert_eq!( + messages[1].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 5, + end: 19, + text: "(503) 272-8192".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-1" + ); + + // Validate msg-2 choices + assert_eq!( + messages[2].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n2. (617) 985-3519.".into(),), + refusal: None, + tool_calls: vec![], + }, + logprobs: Some(ChatCompletionLogprobs { + content: vec![ChatCompletionLogprob { + token: "2. (617) 985-3519.".into(), + logprob: -0.055, + bytes: None, + top_logprobs: None, + },], + refusal: vec![], + },), + ..Default::default() + }], + "unexpected choices for msg-2" + ); + // Validate msg-2 detections + assert_eq!( + messages[2].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 4, + end: 18, + text: "(617) 985-3519".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-2" + ); + + // Validate finish reason message + assert_eq!( + messages[3].choices[0].finish_reason, + Some("stop".into()), + "missing finish reason message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn output_detectors_with_usage() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + "stream_options": { + "include_usage": true + } + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + 
..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 272-8192\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. 
(617) 985-3519.".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 272-8192\n".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 985-3519.".into(), + input_index_stream: 10, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 32, + text: "Here are 2 random phone numbers:".into(), + }], + token_count: 0, + processed_index: 32, + start_index: 0, + input_start_index: 1, + input_end_index: 8, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 32, + end: 51, + text: "\n\n1. (503) 272-8192".into(), + }], + token_count: 0, + processed_index: 51, + start_index: 32, + input_start_index: 9, + input_end_index: 9, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 51, + end: 70, + text: "\n2. (617) 985-3519.".into(), + }], + token_count: 0, + processed_index: 70, + start_index: 51, + input_start_index: 10, + input_end_index: 10, + }, + ]); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are 2 random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([[]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n1. 
(503) 272-8192".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 5, + "end": 19, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 272-8192", + "evidences": [] + } + ]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n2. (617) 985-3519.".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 4, + "end": 18, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 985-3519", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + "stream_options": { + "include_usage": true + } + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + // Validate length + assert_eq!(messages.len(), 5, "unexpected number of messages"); + + // Validate msg-0 choices + assert_eq!( + messages[0].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("Here are 2 random phone numbers:".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-0" + ); + // Validate msg-0 detections + assert_eq!( + messages[0].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![], + }], + }), + "unexpected detections for msg-0" + ); + + // Validate msg-1 choices + assert_eq!( + messages[1].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n\n1. (503) 272-8192".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-1" + ); + // Validate msg-2 detections + assert_eq!( + messages[1].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 5, + end: 19, + text: "(503) 272-8192".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-1" + ); + + // Validate msg-2 choices + assert_eq!( + messages[2].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n2. 
(617) 985-3519.".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-2" + ); + // Validate msg-2 detections + assert_eq!( + messages[2].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 4, + end: 18, + text: "(617) 985-3519".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-2" + ); + + // Validate finish reason message + assert_eq!( + messages[3].choices[0].finish_reason, + Some("stop".into()), + "missing finish reason message" + ); + + // Validate final usage message + assert_eq!( + messages[4].usage, + Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + "unexpected usage for final usage message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn output_detectors_with_continuous_usage_stats() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + "stream_options": { + "include_usage": true, + "continuous_usage_stats": true + } + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 19, + completion_tokens: 0, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 20, + completion_tokens: 1, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 21, + completion_tokens: 2, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 22, + completion_tokens: 3, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: 
"chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 23, + completion_tokens: 4, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 24, + completion_tokens: 5, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 25, + completion_tokens: 6, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 26, + completion_tokens: 7, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 28, + completion_tokens: 8, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 272-8192\n".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 38, + completion_tokens: 19, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. 
(617) 985-3519.".into()), + ..Default::default() + }, + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + usage: Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 272-8192\n".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 985-3519.".into(), + input_index_stream: 10, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 32, + text: "Here are 2 random phone numbers:".into(), + }], + token_count: 0, + processed_index: 32, + start_index: 0, + input_start_index: 1, + input_end_index: 8, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 32, + end: 51, + text: "\n\n1. (503) 272-8192".into(), + }], + token_count: 0, + processed_index: 51, + start_index: 32, + input_start_index: 9, + input_end_index: 9, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 51, + end: 70, + text: "\n2. 
(617) 985-3519.".into(), + }], + token_count: 0, + processed_index: 70, + start_index: 51, + input_start_index: 10, + input_end_index: 10, + }, + ]); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are 2 random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([[]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n1. (503) 272-8192".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 5, + "end": 19, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 272-8192", + "evidences": [] + } + ]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n2. (617) 985-3519.".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 4, + "end": 18, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 985-3519", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + "stream_options": { + "include_usage": true, + "continuous_usage_stats": true + } + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + // Validate length + assert_eq!(messages.len(), 5, "unexpected number of messages"); + + // Validate msg-0 choices + assert_eq!( + messages[0].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("Here are 2 random phone numbers:".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-0" + ); + // Validate msg-0 detections + assert_eq!( + messages[0].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![], + }], + }), + "unexpected detections for msg-0" + ); + // Validate msg-0 usage + assert_eq!( + messages[0].usage, + Some(Usage { + prompt_tokens: 19, + total_tokens: 28, + completion_tokens: 8, + ..Default::default() + }), + "unexpected usage for msg-0" + ); + + // Validate msg-1 choices + assert_eq!( + messages[1].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: 
Some(Role::Assistant), + content: Some("\n\n1. (503) 272-8192".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-1" + ); + // Validate msg-2 detections + assert_eq!( + messages[1].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 5, + end: 19, + text: "(503) 272-8192".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-1" + ); + // Validate msg-1 usage + assert_eq!( + messages[1].usage, + Some(Usage { + prompt_tokens: 19, + total_tokens: 38, + completion_tokens: 19, + ..Default::default() + }), + "unexpected usage for msg-1" + ); + + // Validate msg-2 choices + assert_eq!( + messages[2].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n2. (617) 985-3519.".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-2" + ); + // Validate msg-2 detections + assert_eq!( + messages[2].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 4, + end: 18, + text: "(617) 985-3519".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-2" + ); + // Validate msg-2 usage + assert_eq!( + messages[2].usage, + Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + "unexpected usage for msg-2" + ); + + // Validate finish reason message + assert_eq!( + messages[3].choices[0].finish_reason, + Some("stop".into()), + "missing finish reason message" + ); + // Validate finish reason message usage + assert_eq!( + messages[3].usage, + Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + "unexpected usage for finish reason message" + ); + + // Validate final usage message + assert_eq!( + messages[4].usage, + Some(Usage { + prompt_tokens: 19, + total_tokens: 49, + completion_tokens: 30, + ..Default::default() + }), + "unexpected usage for final usage message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn output_detectors_n2() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + "n": 2, + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { // 0 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 1 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + 
choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 2 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 3 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 4 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 5 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 6 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" two".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 7 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 8 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 9 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 10 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 11 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + 
choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 12 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 13 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 14 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 15 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 16 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 278-9123\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 17 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 18 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. 
(617) 854-6279.".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 19 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some("**Phone Number 1:** (234) 567-8901\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 20 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + content: Some("**Phone Number 2:** (819) 345-2198".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 21 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { // 22 + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 1, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + // choice 0 mocks + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " two".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 10, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 12, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 14, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 278-9123\n".into(), + input_index_stream: 16, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 854-6279.".into(), + input_index_stream: 18, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 34, + text: "Here are two random phone numbers:".into(), + }], + token_count: 0, + processed_index: 34, + start_index: 0, + input_start_index: 2, + input_end_index: 14, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 34, + end: 53, + text: "\n\n1. (503) 278-9123".into(), + }], + token_count: 0, + processed_index: 53, + start_index: 34, + input_start_index: 16, + input_end_index: 16, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 53, + end: 72, + text: "\n2. 
(617) 854-6279.".into(), + }], + token_count: 0, + processed_index: 72, + start_index: 53, + input_start_index: 18, + input_end_index: 18, + }, + ]); + }); + // choice 1 mocks + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 11, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 13, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 15, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 17, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "**Phone Number 1:** (234) 567-8901\n\n".into(), + input_index_stream: 19, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "**Phone Number 2:** (819) 345-2198".into(), + input_index_stream: 20, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 32, + text: "Here are 2 random phone numbers:".into(), + }], + token_count: 0, + processed_index: 32, + start_index: 0, + input_start_index: 3, + input_end_index: 17, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 32, + end: 68, + text: "\n\n**Phone Number 1:** (234) 567-8901".into(), + }], + token_count: 0, + processed_index: 68, + start_index: 32, + input_start_index: 19, + input_end_index: 19, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 68, + end: 104, + text: "\n\n**Phone Number 2:** (819) 345-2198".into(), + }], + token_count: 0, + processed_index: 104, + start_index: 68, + input_start_index: 20, + input_end_index: 20, + }, + ]); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + // choice 0 mocks + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are two random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([[]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n1. (503) 278-9123".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 5, + "end": 19, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 278-9123", + "evidences": [] + } + ]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n2. 
(617) 854-6279.".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 4, + "end": 18, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 854-6279", + "evidences": [] + } + ]])); + }); + // choice 1 mocks + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are 2 random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([[]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n**Phone Number 1:** (234) 567-8901".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 22, + "end": 36, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(234) 567-8901", + "evidences": [] + } + ]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n**Phone Number 2:** (819) 345-2198".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 22, + "end": 36, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(819) 345-2198", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "n": 2, + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + let (choice0_messages, choice1_messages): (Vec<_>, Vec<_>) = messages + .into_iter() + .partition(|chunk| chunk.choices[0].index == 0); + + // Validate choice0 messages: + // Validate length + assert_eq!( + choice0_messages.len(), + 4, + "choice0: unexpected number of messages" + ); + // Validate msg-0 choice + assert_eq!( + choice0_messages[0].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant,), + content: Some("Here are two random phone numbers:".into()), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "choice0: unexpected msg-0 choice" + ); + // Validate msg-0 detections + assert_eq!( + choice0_messages[0].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![], + }], + }), + "choice0: unexpected msg-0 detections" + ); + // Validate finish reason message + assert!( + choice0_messages + .last() + .is_some_and(|msg| msg.choices[0].finish_reason.is_some()), + "choice0: missing finish 
reason message" + ); + + // Validate choice1 messages: + // Validate length + assert_eq!( + choice1_messages.len(), + 4, + "choice1: unexpected number of messages" + ); + // Validate msg-0 choice + assert_eq!( + choice1_messages[0].choices, + vec![ChatCompletionChunkChoice { + index: 1, + delta: ChatCompletionDelta { + role: Some(Role::Assistant,), + content: Some("Here are 2 random phone numbers:".into()), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "choice1: unexpected msg-0 choice" + ); + // Validate msg-0 detections + assert_eq!( + choice1_messages[0].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 1, + results: vec![], + }], + }), + "choice1: unexpected msg-0 detections" + ); + // Validate finish reason message + assert!( + choice1_messages + .last() + .is_some_and(|msg| msg.choices[0].finish_reason.is_some()), + "choice1: missing finish reason message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn whole_doc_output_detectors() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ] + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + 
..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 272-8192\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. (617) 985-3519.".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut pii_detector_whole_doc_server = MockServer::new_http("pii_detector_whole_doc"); + pii_detector_whole_doc_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_WHOLE_DOC) + .json(ContentAnalysisRequest { + contents: vec![ + "Here are 2 random phone numbers:\n\n1. (503) 272-8192\n2. (617) 985-3519." 
+ .into(), + ], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 37, + "end": 53, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 272-8192\n2", + "evidences": [] + }, + { + "start": 55, + "end": 69, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 985-3519", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .detector_servers([&pii_detector_whole_doc_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_whole_doc": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + // Validate length + assert_eq!(messages.len(), 13, "unexpected number of messages"); + + // Validate finish reason message + assert_eq!( + messages[11].choices[0].finish_reason, + Some("stop".into()), + "missing finish reason message" + ); + + // Validate whole doc detections message + let last = &messages[12]; + assert_eq!( + last.detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ + ContentAnalysisResponse { + start: 37, + end: 53, + text: "(503) 272-8192\n2".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_whole_doc".into()), + score: 0.8, + ..Default::default() + }, + ContentAnalysisResponse { + start: 55, + end: 69, + text: "(617) 985-3519".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_whole_doc".into()), + score: 0.8, + ..Default::default() + } + ], + }], + }), + "unexpected whole doc detections message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn output_detectors_and_whole_doc_output_detectors() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ] + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: 
"chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 272-8192\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. 
(617) 985-3519.".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 272-8192\n".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 985-3519.".into(), + input_index_stream: 10, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 32, + text: "Here are 2 random phone numbers:".into(), + }], + token_count: 0, + processed_index: 32, + start_index: 0, + input_start_index: 1, + input_end_index: 8, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 32, + end: 51, + text: "\n\n1. (503) 272-8192".into(), + }], + token_count: 0, + processed_index: 51, + start_index: 32, + input_start_index: 9, + input_end_index: 9, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 51, + end: 70, + text: "\n2. (617) 985-3519.".into(), + }], + token_count: 0, + processed_index: 70, + start_index: 51, + input_start_index: 10, + input_end_index: 10, + }, + ]); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are 2 random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([[]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n\n1. 
(503) 272-8192".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 5, + "end": 19, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 272-8192", + "evidences": [] + } + ]])); + }); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["\n2. (617) 985-3519.".into()], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 4, + "end": 18, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 985-3519", + "evidences": [] + } + ]])); + }); + + let mut pii_detector_whole_doc_server = MockServer::new_http("pii_detector_whole_doc"); + pii_detector_whole_doc_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_WHOLE_DOC) + .json(ContentAnalysisRequest { + contents: vec![ + "Here are 2 random phone numbers:\n\n1. (503) 272-8192\n2. (617) 985-3519." + .into(), + ], + detector_params: DetectorParams::default(), + }); + then.json(json!([ + [ + { + "start": 37, + "end": 53, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(503) 272-8192\n2", + "evidences": [] + }, + { + "start": 55, + "end": 69, + "detection": "PhoneNumber", + "detection_type": "pii", + "score": 0.8, + "text": "(617) 985-3519", + "evidences": [] + } + ]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([ + &pii_detector_sentence_server, + &pii_detector_whole_doc_server, + ]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + "pii_detector_whole_doc": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + debug!("{messages:#?}"); + + // Validate length + assert_eq!(messages.len(), 5, "unexpected number of messages"); + + // Validate msg-0 choices + assert_eq!( + messages[0].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("Here are 2 random phone numbers:".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-0" + ); + // Validate msg-0 detections + assert_eq!( + messages[0].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![], + }], + }), + "unexpected detections for msg-0" + ); + + // Validate msg-1 choices + assert_eq!( + messages[1].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n\n1. 
(503) 272-8192".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-1" + ); + // Validate msg-2 detections + assert_eq!( + messages[1].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 5, + end: 19, + text: "(503) 272-8192".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-1" + ); + + // Validate msg-2 choices + assert_eq!( + messages[2].choices, + vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some("\n2. (617) 985-3519.".into(),), + refusal: None, + tool_calls: vec![], + }, + ..Default::default() + }], + "unexpected choices for msg-2" + ); + // Validate msg-2 detections + assert_eq!( + messages[2].detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ContentAnalysisResponse { + start: 4, + end: 18, + text: "(617) 985-3519".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_sentence".into()), + score: 0.8, + ..Default::default() + }], + }], + }), + "unexpected detections for msg-2" + ); + + // Validate finish reason message + assert_eq!( + messages[3].choices[0].finish_reason, + Some("stop".into()), + "missing finish reason message" + ); + + // Validate whole doc detections message + let last = &messages[4]; + assert_eq!( + last.detections, + Some(CompletionDetections { + input: vec![], + output: vec![CompletionOutputDetections { + choice_index: 0, + results: vec![ + ContentAnalysisResponse { + start: 37, + end: 53, + text: "(503) 272-8192\n2".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_whole_doc".into()), + score: 0.8, + ..Default::default() + }, + ContentAnalysisResponse { + start: 55, + end: 69, + text: "(617) 985-3519".into(), + detection: "PhoneNumber".into(), + detection_type: "pii".into(), + detector_id: Some("pii_detector_whole_doc".into()), + score: 0.8, + ..Default::default() + } + ], + }], + }), + "unexpected whole doc detections message" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn openai_bad_request_error() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ], + "prompt_logprobs": true + }) + ); + then.bad_request().json(OpenAiError { + object: Some("error".into()), + message: r#"[{'type': 'value_error', 'loc': ('body',), 'msg': 'Value error, `prompt_logprobs` are not available when `stream=True`.', 'input': {'model': 'test-0B', 'messages': [{'role': 'user', 'content': 'Hey'}],'n': 1, 'seed': 1337, 'stream': True, 'prompt_logprobs': True}, 'ctx': {'error': ValueError('`prompt_logprobs` are not available when `stream=True`.')}}]"#.into(), + r#type: Some("BadRequestError".into()), + param: None, + code: 400, + }); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .build() + 
.await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": {}, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ], + "prompt_logprobs": true + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; + + // Validate length + assert_eq!(messages.len(), 1, "unexpected number of messages"); + + // Validate error message + assert!( + messages[0] + .as_ref() + .is_err_and(|e| e.code == http::StatusCode::BAD_REQUEST) + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn openai_stream_error() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ] + }) + ); + // Return an error message over the stream + then.text_stream(sse([ + OpenAiErrorMessage { + error: OpenAiError { + object: Some("error".into()), + message: "".into(), + r#type: Some("InternalServerError".into()), + param: None, + code: 500 + } + } + ])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": {}, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hey".into())), ..Default::default()}, + ] + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; + + // Validate length + assert_eq!(messages.len(), 1, "unexpected number of messages"); + + // Validate error message + assert!( + messages[0] + .as_ref() + .is_err_and(|e| e.code == StatusCode::INTERNAL_SERVER_ERROR) + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn chunker_internal_server_error() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ] + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + 
ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. (503) 272-8192\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. 
(617) 985-3519.".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 272-8192\n".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 985-3519.".into(), + input_index_stream: 10, + }, + ]); + then.internal_server_error(); + }); + + let pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; + + // Validate length + assert_eq!(messages.len(), 1, "unexpected number of messages"); + + // Validate error message + assert!( + messages[0] + .as_ref() + .is_err_and(|e| e.code == StatusCode::INTERNAL_SERVER_ERROR) + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn detector_internal_server_error() -> Result<(), anyhow::Error> { + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ] + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: 
"chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Here".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" are".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" ".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" random".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" phone".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(" numbers".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some(":\n\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("1. 
(503) 272-8192\n".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("2. (617) 985-3519.".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![ + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Here".into(), + input_index_stream: 1, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " are".into(), + input_index_stream: 2, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " ".into(), + input_index_stream: 3, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2".into(), + input_index_stream: 4, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " random".into(), + input_index_stream: 5, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " phone".into(), + input_index_stream: 6, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: " numbers".into(), + input_index_stream: 7, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: ":\n\n".into(), + input_index_stream: 8, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "1. (503) 272-8192\n".into(), + input_index_stream: 9, + }, + BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "2. (617) 985-3519.".into(), + input_index_stream: 10, + }, + ]); + then.pb_stream(vec![ + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 32, + text: "Here are 2 random phone numbers:".into(), + }], + token_count: 0, + processed_index: 32, + start_index: 0, + input_start_index: 1, + input_end_index: 8, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 32, + end: 51, + text: "\n\n1. (503) 272-8192".into(), + }], + token_count: 0, + processed_index: 51, + start_index: 32, + input_start_index: 9, + input_end_index: 9, + }, + ChunkerTokenizationStreamResult { + results: vec![Token { + start: 51, + end: 70, + text: "\n2. 
(617) 985-3519.".into(), + }], + token_count: 0, + processed_index: 70, + start_index: 51, + input_start_index: 10, + input_end_index: 10, + }, + ]); + }); + + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Here are 2 random phone numbers:".into()], + detector_params: DetectorParams::default(), + }); + then.internal_server_error(); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {}, + }, + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Can you generate 2 random phone numbers?".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; + + // Validate length + assert_eq!(messages.len(), 1, "unexpected number of messages"); + + // Validate error message + assert!( + messages[0] + .as_ref() + .is_err_and(|e| e.code == StatusCode::INTERNAL_SERVER_ERROR) + ); + + Ok(()) +} diff --git a/tests/chat_completions_detection.rs b/tests/chat_completions_unary.rs similarity index 84% rename from tests/chat_completions_detection.rs rename to tests/chat_completions_unary.rs index 42e4e1d0..3ef757a0 100644 --- a/tests/chat_completions_detection.rs +++ b/tests/chat_completions_unary.rs @@ -16,14 +16,14 @@ */ use common::{ - chat_generation::CHAT_COMPLETIONS_ENDPOINT, chunker::CHUNKER_UNARY_ENDPOINT, detectors::{ ANSWER_RELEVANCE_DETECTOR, DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, + openai::CHAT_COMPLETIONS_ENDPOINT, orchestrator::{ ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, TestOrchestratorServer, @@ -34,9 +34,10 @@ use fms_guardrails_orchestr8::{ chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, detector::{ContentAnalysisRequest, ContentAnalysisResponse}, openai::{ - ChatCompletion, ChatCompletionChoice, ChatCompletionMessage, ChatDetections, Content, - ContentPart, ContentType, InputDetectionResult, Message, OrchestratorWarning, - OutputDetectionResult, Role, + ChatCompletion, ChatCompletionChoice, ChatCompletionMessage, + CompletionDetectionWarning, CompletionDetections, CompletionInputDetections, + CompletionOutputDetections, Content, ContentPart, ContentType, Message, Role, + TokenizeResponse, }, }, models::{ @@ -47,6 +48,7 @@ use fms_guardrails_orchestr8::{ caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, caikit_data_model::nlp::{Token, TokenizationResults}, }, + server, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -54,6 +56,8 @@ use serde_json::json; use test_log::test; use tracing::debug; +use crate::common::openai::TOKENIZE_ENDPOINT; + pub mod common; // Constants @@ -123,12 +127,11 @@ async fn 
no_detectors() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mut mock_chat_completions_server = - MockServer::new("chat_completions").with_mocks(chat_mocks); + let mut mock_openai_server = MockServer::new_http("openai").with_mocks(chat_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) - .chat_generation_server(&mock_chat_completions_server) + .openai_server(&mock_openai_server) .build() .await?; @@ -222,7 +225,7 @@ async fn no_detectors() -> Result<(), anyhow::Error> { ]; // add new mock - mock_chat_completions_server.mock(|when, then| { + mock_openai_server.mock(|when, then| { when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ "model": MODEL_ID, "messages": messages, @@ -302,6 +305,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { stop_reason: None, }, ]; + let chat_completions_response = ChatCompletion { model: MODEL_ID.into(), choices: expected_choices.clone(), @@ -341,13 +345,13 @@ async fn no_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); - let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(chat_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) - .chat_generation_server(&mock_chat_completions_server) + .openai_server(&mock_openai_server) .build() .await?; @@ -369,7 +373,6 @@ async fn no_detections() -> Result<(), anyhow::Error> { .send() .await?; - // Assertions for no detections assert_eq!(response.status(), StatusCode::OK); let results = response.json::().await?; assert_eq!(results.choices[0], chat_completions_response.choices[0]); @@ -377,6 +380,88 @@ async fn no_detections() -> Result<(), anyhow::Error> { assert_eq!(results.warnings, vec![]); assert!(results.detections.is_none()); + // Scenario: output detectors on empty choices responses + let messages = vec![Message { + content: Some(Content::Text( + "Please provide me an empty message".to_string(), + )), + role: Role::User, + ..Default::default() + }]; + let expected_choices = vec![ + ChatCompletionChoice { + message: ChatCompletionMessage { + role: Role::Assistant, + content: Some("".to_string()), + refusal: None, + tool_calls: vec![], + }, + index: 0, + logprobs: None, + finish_reason: "EOS_TOKEN".to_string(), + stop_reason: None, + }, + ChatCompletionChoice { + message: ChatCompletionMessage { + role: Role::Assistant, + content: None, + refusal: None, + tool_calls: vec![], + }, + index: 1, + logprobs: None, + finish_reason: "EOS_TOKEN".to_string(), + stop_reason: None, + }, + ]; + + let expected_warnings = vec![ + CompletionDetectionWarning::new( + DetectionWarningReason::EmptyOutput, + "Choice of index 0 has no content. Output detection was not executed", + ), + CompletionDetectionWarning::new( + DetectionWarningReason::EmptyOutput, + "Choice of index 1 has no content. 
Output detection was not executed", + ), + ]; + let chat_completions_response = ChatCompletion { + model: MODEL_ID.into(), + choices: expected_choices.clone(), + detections: None, + ..Default::default() + }; + + mock_openai_server.mocks().mock(|when, then| { + when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "messages": messages, + })); + then.json(&chat_completions_response); + }); + + let response = orchestrator_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "output": { + detector_name: {}, + }, + }, + "messages": messages, + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + debug!("{}", serde_json::to_string_pretty(&results)?); + assert_eq!(results.choices[0], chat_completions_response.choices[0]); + assert_eq!(results.choices[1], chat_completions_response.choices[1]); + assert_eq!(results.warnings, expected_warnings); + assert!(results.detections.is_none()); + Ok(()) } @@ -395,7 +480,7 @@ async fn input_detections() -> Result<(), anyhow::Error> { // Add mocksets let mut detector_mocks = MockSet::new(); let mut chunker_mocks = MockSet::new(); - let mut chat_mocks = MockSet::new(); + let mut openai_mocks = MockSet::new(); // Add input detection mock response for input detection let expected_detections = vec![ContentAnalysisResponse { @@ -413,14 +498,14 @@ async fn input_detections() -> Result<(), anyhow::Error> { let chat_completions_response = ChatCompletion { model: MODEL_ID.into(), choices: vec![], - detections: Some(ChatDetections { - input: vec![InputDetectionResult { + detections: Some(CompletionDetections { + input: vec![CompletionInputDetections { message_index: 0, results: expected_detections.clone(), }], output: vec![], }), - warnings: vec![OrchestratorWarning::new( + warnings: vec![CompletionDetectionWarning::new( DetectionWarningReason::UnsuitableInput, UNSUITABLE_INPUT_MESSAGE, )], @@ -455,27 +540,35 @@ async fn input_detections() -> Result<(), anyhow::Error> { then.json([&expected_detections]); }); - // Add chat completions mock - chat_mocks.mock(|when, then| { + // Add openai mocks + openai_mocks.mock(|when, then| { when.post().path(CHAT_COMPLETIONS_ENDPOINT).json(json!({ "model": MODEL_ID, "messages": messages, })); then.json(&chat_completions_response); }); + openai_mocks.mock(|when, then| { + when.post().path(TOKENIZE_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": input_text, + })); + then.json(&TokenizeResponse { + count: 12, + ..Default::default() + }); + }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); - let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(openai_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) .chunker_servers([&mock_chunker_server]) - .chat_generation_server(&mock_chat_completions_server) + .openai_server(&mock_openai_server) .build() .await?; @@ -515,7 +608,10 @@ async 
fn input_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; // Add 500 expected orchestrator error response - let expected_orchestrator_error = OrchestratorError::internal(); + let expected_orchestrator_error = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; @@ -621,17 +717,15 @@ async fn input_client_error() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); - let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(chat_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) .chunker_servers([&mock_chunker_server]) - .chat_generation_server(&mock_chat_completions_server) + .openai_server(&mock_openai_server) .build() .await?; @@ -652,7 +746,7 @@ async fn input_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for chunker error scenario - let results = response.json::<OrchestratorError>().await?; + let results = response.json::<server::Error>().await?; assert_eq!(results, expected_orchestrator_error); // Make orchestrator call for detector error scenario @@ -672,7 +766,7 @@ async fn input_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for detector error scenario - let results = response.json::<OrchestratorError>().await?; + let results = response.json::<server::Error>().await?; assert_eq!(results, expected_orchestrator_error); // Make orchestrator call for chat completions error scenario @@ -692,7 +786,7 @@ async fn input_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for chat completions error scenario - let results = response.json::<OrchestratorError>().await?; + let results = response.json::<server::Error>().await?; assert_eq!(results, expected_orchestrator_error); Ok(()) @@ -768,14 +862,14 @@ async fn output_detections() -> Result<(), anyhow::Error> { let chat_completions_response = ChatCompletion { model: MODEL_ID.into(), choices: expected_choices.clone(), - detections: Some(ChatDetections { + detections: Some(CompletionDetections { input: vec![], - output: vec![OutputDetectionResult { + output: vec![CompletionOutputDetections { choice_index: 1, results: expected_detections.clone(), }], }), - warnings: vec![OrchestratorWarning::new( + warnings: vec![CompletionDetectionWarning::new( DetectionWarningReason::UnsuitableOutput, UNSUITABLE_OUTPUT_MESSAGE, )], @@ -848,17 +942,15 @@ async fn output_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); - let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let
mock_openai_server = MockServer::new_http("openai").with_mocks(chat_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) .chunker_servers([&mock_chunker_server]) - .chat_generation_server(&mock_chat_completions_server) + .openai_server(&mock_openai_server) .build() .await?; @@ -899,7 +991,10 @@ async fn output_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; // Add 500 expected orchestrator mock response - let expected_orchestrator_error = OrchestratorError::internal(); + let expected_orchestrator_error = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; @@ -1023,17 +1118,15 @@ async fn output_client_error() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); - let mock_chat_completions_server = MockServer::new("chat_completions").with_mocks(chat_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(chat_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) .chunker_servers([&mock_chunker_server]) - .chat_generation_server(&mock_chat_completions_server) + .openai_server(&mock_openai_server) .build() .await?; @@ -1054,7 +1147,7 @@ async fn output_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for chunker error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, expected_orchestrator_error); // Make orchestrator call for detector error scenario @@ -1074,7 +1167,7 @@ async fn output_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for detector error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, expected_orchestrator_error); // Make orchestrator call for chat completions error scenario @@ -1094,7 +1187,7 @@ async fn output_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for chat completions error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, expected_orchestrator_error); Ok(()) @@ -1131,11 +1224,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{ANSWER_RELEVANCE_DETECTOR}` is not supported by this endpoint", + ) + }, "failed on invalid input detector scenario" ); @@ -1155,11 +1253,14 @@ async fn orchestrator_validation_error() -> 
Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, "failed on non-existing input detector scenario" ); @@ -1179,11 +1280,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{ANSWER_RELEVANCE_DETECTOR}` is not supported by this endpoint" + ) + }, "failed on invalid output detector scenario" ); @@ -1203,11 +1309,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, "failed on non-existing input detector scenario" ); @@ -1239,12 +1348,12 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, details: "if input detectors are provided, `content` must not be empty on last message" .into() } @@ -1279,12 +1388,12 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, details: "if input detectors are provided, `content` must not be empty on last message" .into() } @@ -1319,12 +1428,12 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, details: "Detection on array is not supported".into() } ); @@ -1371,12 +1480,12 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError { - code: 422, + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, details: "Detection on array is not supported".into() } ); diff --git a/tests/chat_detection.rs b/tests/chat_detection.rs index 95c266fa..c5cd3015 100644 --- a/tests/chat_detection.rs +++ b/tests/chat_detection.rs @@ -21,7 +21,7 @@ use common::{ ANSWER_RELEVANCE_DETECTOR_SENTENCE, CHAT_DETECTOR_ENDPOINT, NON_EXISTING_DETECTOR, PII_DETECTOR, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, orchestrator::{ ORCHESTRATOR_CHAT_DETECTION_ENDPOINT, 
ORCHESTRATOR_CONFIG_FILE_PATH, TestOrchestratorServer, }, @@ -34,6 +34,7 @@ use fms_guardrails_orchestr8::{ models::{ ChatDetectionHttpRequest, ChatDetectionResult, DetectionResult, DetectorParams, Metadata, }, + server, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -92,7 +93,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) @@ -162,7 +163,7 @@ async fn detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -229,7 +230,7 @@ async fn client_errors() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -251,9 +252,15 @@ async fn client_errors() -> Result<(), anyhow::Error> { // assertions assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); - assert_eq!(response, OrchestratorError::internal()); + assert_eq!( + response, + server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into() + } + ); Ok(()) } @@ -291,7 +298,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("unknown field `extra_args`")); @@ -307,7 +314,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `messages`")); @@ -332,7 +339,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `detectors`")); @@ -358,7 +365,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); 
assert!(response.details.contains("`detectors` is required")); @@ -382,7 +389,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .await?; assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("Message content cannot be empty")); @@ -415,11 +422,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .await?; assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{ANSWER_RELEVANCE_DETECTOR_SENTENCE}` is not supported by this endpoint" + ), + }, "failed on detector with invalid type scenario" ); @@ -435,11 +447,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .await?; assert_eq!(response.status(), StatusCode::NOT_FOUND); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, "failed on non-existing detector scenario" ); diff --git a/tests/classification_with_text_gen.rs b/tests/classification_with_text_gen.rs index c8161a7c..aa25a870 100644 --- a/tests/classification_with_text_gen.rs +++ b/tests/classification_with_text_gen.rs @@ -25,7 +25,7 @@ use common::{ DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, generation::{ GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_TOKENIZATION_ENDPOINT, GENERATION_NLP_UNARY_ENDPOINT, @@ -52,6 +52,7 @@ use fms_guardrails_orchestr8::{ }, caikit_data_model::nlp::{GeneratedTextResult, Token, TokenizationResults}, }, + server, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -93,7 +94,7 @@ async fn no_detectors() -> Result<(), anyhow::Error> { }); // Configure mock servers - let generation_server = MockServer::new("nlp").grpc().with_mocks(mocks); + let generation_server = MockServer::new_grpc("nlp").with_mocks(mocks); // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() @@ -271,11 +272,9 @@ async fn no_detections() -> Result<(), anyhow::Error> { // Configure mock servers let mock_detector_server = - MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + MockServer::new_http(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() @@ -514,12 +513,10 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { }); // Configure mock 
servers - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); let mock_detector_server = - MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); + MockServer::new_http(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() @@ -647,7 +644,10 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; - let orchestrator_error_500 = OrchestratorError::internal(); + let orchestrator_error_500 = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; @@ -715,12 +715,10 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { }); // Configure mock servers - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); let mock_detector_server = - MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); + MockServer::new_http(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() @@ -753,7 +751,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for generation internal server error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response for detector internal server error scenario @@ -778,7 +776,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for detector internal server error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response @@ -803,7 +801,7 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for chunker internal server error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, orchestrator_error_500); Ok(()) @@ -996,12 +994,10 @@ async fn output_detector_detections() -> Result<(), anyhow::Error> { }); // Configure mock servers - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); let mock_detector_server = - MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + 
MockServer::new_http(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() @@ -1122,7 +1118,10 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { message: "Internal detector error.".into(), }; - let orchestrator_error_500 = OrchestratorError::internal(); + let orchestrator_error_500 = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; // Add input for error scenarios let chunker_error_input = "This should return a 500 error on chunker"; @@ -1219,12 +1218,10 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { }); // Configure mock servers - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); - let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE) - .grpc() - .with_mocks(chunker_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); let mock_detector_server = - MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); + MockServer::new_http(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks); // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() @@ -1256,7 +1253,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for generation internal server error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response for detector internal server error scenario @@ -1280,7 +1277,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for detector internal server error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, orchestrator_error_500); // Orchestrator request with unary response @@ -1304,7 +1301,7 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { .await?; // Assertions for chunker internal server error scenario - let results = response.json::().await?; + let results = response.json::().await?; assert_eq!(results, orchestrator_error_500); Ok(()) @@ -1332,7 +1329,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!(results.code, StatusCode::UNPROCESSABLE_ENTITY); assert!( @@ -1362,11 +1359,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{ANSWER_RELEVANCE_DETECTOR_SENTENCE}` is not supported by this endpoint" + ), + }, "failed on input detector with invalid type scenario" ); @@ -1388,12 +1390,15 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = 
response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), - "failed on non-existing input detector scenario" + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, + "failed on non-existing detector scenario" ); // Invalid output detector scenario @@ -1416,11 +1421,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{ANSWER_RELEVANCE_DETECTOR_SENTENCE}` is not supported by this endpoint" + ), + }, "failed on output detector with invalid type scenario" ); @@ -1441,11 +1451,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { .send() .await?; - let results = response.json::().await?; + let results = response.json::().await?; debug!("{results:#?}"); assert_eq!( results, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, "failed on non-existing output detector scenario" ); diff --git a/tests/common/detectors.rs b/tests/common/detectors.rs index 19f67767..a59e1c7b 100644 --- a/tests/common/detectors.rs +++ b/tests/common/detectors.rs @@ -24,6 +24,8 @@ pub const ANSWER_RELEVANCE_DETECTOR_SENTENCE: &str = "answer_relevance_detector_ pub const FACT_CHECKING_DETECTOR: &str = "fact_checking_detector"; pub const FACT_CHECKING_DETECTOR_SENTENCE: &str = "fact_checking_detector_sentence"; pub const PII_DETECTOR: &str = "pii_detector"; +pub const PII_DETECTOR_SENTENCE: &str = "pii_detector_sentence"; +pub const PII_DETECTOR_WHOLE_DOC: &str = "pii_detector_whole_doc"; pub const NON_EXISTING_DETECTOR: &str = "non_existing_detector"; // Detector endpoints diff --git a/tests/common/errors.rs b/tests/common/errors.rs index 3c5e9844..8b6fd623 100644 --- a/tests/common/errors.rs +++ b/tests/common/errors.rs @@ -22,58 +22,3 @@ pub struct DetectorError { pub code: u16, pub message: String, } - -/// Errors returned by orchestrator endpoints. -#[derive(Serialize, Deserialize, Debug, PartialEq)] -pub struct OrchestratorError { - pub code: u16, - pub details: String, -} - -impl OrchestratorError { - /// Helper function that generates an orchestrator internal - /// server error. - pub fn internal() -> OrchestratorError { - OrchestratorError { - code: 500, - details: "unexpected error occurred while processing request".into(), - } - } - /// Helper function that generates an orchestrator non-existing detector error. - pub fn detector_not_found(detector_name: &str) -> Self { - Self { - code: 404, - details: format!("detector `{}` not found", detector_name), - } - } - - /// Helper function that generates an orchestrator invalid detector error. - pub fn detector_not_supported(detector_name: &str) -> Self { - Self { - code: 422, - details: format!( - "detector `{}` is not supported by this endpoint", - detector_name - ), - } - } - - /// Helper function that generates an orchestrator required field error. 
- pub fn required(field_name: &str) -> Self { - Self { - code: 422, - details: format!("`{}` is required", field_name), - } - } - - /// Helper function that generates an orchestrator invalid chunker error. - pub fn chunker_not_supported(detector_name: &str) -> Self { - Self { - code: 422, - details: format!( - "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint", - detector_name - ), - } - } -} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 261f942c..ca6c08ee 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -14,9 +14,19 @@ limitations under the License. */ -pub mod chat_generation; pub mod chunker; pub mod detectors; pub mod errors; pub mod generation; +pub mod openai; pub mod orchestrator; + +/// Converts an iterator of serializable messages into an iterator of SSE data messages. +pub fn sse( + messages: impl IntoIterator, +) -> impl IntoIterator { + messages.into_iter().map(|msg| { + let msg = serde_json::to_string(&msg).unwrap(); + format!("data: {msg}\n\n") + }) +} diff --git a/tests/common/chat_generation.rs b/tests/common/openai.rs similarity index 86% rename from tests/common/chat_generation.rs rename to tests/common/openai.rs index 7ac67d0e..8af159ce 100644 --- a/tests/common/chat_generation.rs +++ b/tests/common/openai.rs @@ -17,3 +17,5 @@ // Chat completions server endpoint pub const CHAT_COMPLETIONS_ENDPOINT: &str = "/v1/chat/completions"; +pub const COMPLETIONS_ENDPOINT: &str = "/v1/completions"; +pub const TOKENIZE_ENDPOINT: &str = "/tokenize"; diff --git a/tests/common/orchestrator.rs b/tests/common/orchestrator.rs index f9fb7e9c..e1124f70 100644 --- a/tests/common/orchestrator.rs +++ b/tests/common/orchestrator.rs @@ -58,6 +58,7 @@ pub const ORCHESTRATOR_CHAT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/c pub const ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT: &str = "/api/v2/chat/completions-detection"; +pub const ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT: &str = "/api/v2/text/completions-detection"; // Messages pub const ORCHESTRATOR_UNSUITABLE_INPUT_MESSAGE: &str = "Unsuitable input detected. 
Please check the detected entities on your input and try again with the unsuitable input removed."; @@ -72,7 +73,7 @@ pub struct TestOrchestratorServerBuilder<'a> { port: Option<u16>, health_port: Option<u16>, generation_server: Option<&'a MockServer>, - chat_generation_server: Option<&'a MockServer>, + openai_server: Option<&'a MockServer>, detector_servers: Option<Vec<&'a MockServer>>, chunker_servers: Option<Vec<&'a MockServer>>, } @@ -102,8 +103,8 @@ impl<'a> TestOrchestratorServerBuilder<'a> { self } - pub fn chat_generation_server(mut self, server: &'a MockServer) -> Self { - self.chat_generation_server = Some(server); + pub fn openai_server(mut self, server: &'a MockServer) -> Self { + self.openai_server = Some(server); self } @@ -126,7 +127,7 @@ impl<'a> TestOrchestratorServerBuilder<'a> { // Start & configure mock servers initialize_generation_server(self.generation_server, &mut config).await?; - initialize_chat_generation_server(self.chat_generation_server, &mut config).await?; + initialize_openai_server(self.openai_server, &mut config).await?; initialize_detectors(self.detector_servers.as_deref(), &mut config).await?; initialize_chunkers(self.chunker_servers.as_deref(), &mut config).await?; @@ -169,7 +170,7 @@ impl TestOrchestratorServer { client: reqwest::Client::builder().build().unwrap(), }); } - Err(server::Error::IoError(error)) => { + Err(error) if error.details().starts_with("io error") => { warn!(%error, "failed to start server, trying again with different ports..."); continue; } @@ -214,14 +215,13 @@ async fn initialize_generation_server( } /// Starts and configures chat generation server. -async fn initialize_chat_generation_server( - chat_generation_server: Option<&MockServer>, +async fn initialize_openai_server( + openai_server: Option<&MockServer>, config: &mut OrchestratorConfig, ) -> Result<(), anyhow::Error> { - if let Some(chat_generation_server) = chat_generation_server { - chat_generation_server.start().await?; - config.chat_generation.as_mut().unwrap().service.port = - Some(chat_generation_server.addr().unwrap().port()); + if let Some(openai_server) = openai_server { + openai_server.start().await?; + config.openai.as_mut().unwrap().service.port = Some(openai_server.addr().unwrap().port()); }; Ok(()) } @@ -285,7 +285,7 @@ impl<T> Stream for SseStream<'_, T> where T: DeserializeOwned, { - type Item = Result<T, anyhow::Error>; + type Item = Result<T, server::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { match Pin::new(&mut self.get_mut().stream).poll_next(cx) { @@ -293,12 +293,30 @@ where if event.data == "[DONE]" { return Poll::Ready(None); } + if event.event == "error" { + let error: server::Error = serde_json::from_str(&event.data).unwrap(); + return Poll::Ready(Some(Err(error))); + } match serde_json::from_str::<T>(&event.data) { Ok(msg) => Poll::Ready(Some(Ok(msg))), - Err(error) => Poll::Ready(Some(Err(error.into()))), + Err(error) => { + let error = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: format!( + "sse_stream error: `event.data` deserialization failed {error}" + ), + }; + Poll::Ready(Some(Err(error))) + } } } - Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error.into()))), + Poll::Ready(Some(Err(error))) => { + let error = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: format!("sse_stream error: error parsing event {error}"), + }; + Poll::Ready(Some(Err(error))) + } Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } diff --git a/tests/completions_detection.rs b/tests/completions_detection.rs new file mode 100644 index
00000000..5370f0c5 --- /dev/null +++ b/tests/completions_detection.rs @@ -0,0 +1,1206 @@ +/* + Copyright FMS Guardrails Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +use common::{ + openai::COMPLETIONS_ENDPOINT, + orchestrator::{ + ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT, ORCHESTRATOR_CONFIG_FILE_PATH, + TestOrchestratorServer, + }, +}; +use fms_guardrails_orchestr8::{ + clients::{ + chunker::MODEL_ID_HEADER_NAME as CHUNKER_MODEL_ID_HEADER_NAME, + detector::{ContentAnalysisRequest, ContentAnalysisResponse}, + openai::{ + Completion, CompletionChoice, CompletionDetectionWarning, CompletionDetections, + CompletionInputDetections, CompletionOutputDetections, TokenizeResponse, Usage, + }, + }, + models::{ + DetectionWarningReason, DetectorParams, Metadata, UNSUITABLE_INPUT_MESSAGE, + UNSUITABLE_OUTPUT_MESSAGE, + }, + orchestrator::common::current_timestamp, + pb::{ + caikit::runtime::chunkers::ChunkerTokenizationTaskRequest, + caikit_data_model::nlp::{Token, TokenizationResults}, + }, + server, +}; +use hyper::StatusCode; +use mocktail::prelude::*; +use serde_json::json; +use test_log::test; +use tracing::debug; +use uuid::Uuid; + +use crate::common::{ + chunker::CHUNKER_UNARY_ENDPOINT, + detectors::{ + ANSWER_RELEVANCE_DETECTOR, DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, + DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC, NON_EXISTING_DETECTOR, + TEXT_CONTENTS_DETECTOR_ENDPOINT, + }, + errors::DetectorError, + openai::TOKENIZE_ENDPOINT, +}; + +pub mod common; + +// Constants +const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; +const MODEL_ID: &str = "my-super-model-8B"; + +// Validate passthrough scenario +#[test(tokio::test)] +async fn no_detectors() -> Result<(), anyhow::Error> { + let prompt = "Hi there!"; + + // Add mocksets + let mut completion_mocks = MockSet::new(); + + let expected_choices = vec![ + CompletionChoice { + index: 0, + text: "Hi there!".into(), + logprobs: None, + finish_reason: Some("length".into()), + stop_reason: None, + prompt_logprobs: None, + }, + CompletionChoice { + index: 1, + text: "Hello!".into(), + logprobs: None, + finish_reason: Some("length".into()), + stop_reason: None, + prompt_logprobs: None, + }, + ]; + let completions_response = Completion { + id: Uuid::new_v4().simple().to_string(), + object: "text_completion".into(), + created: current_timestamp().as_secs() as i64, + model: MODEL_ID.into(), + choices: expected_choices, + usage: Some(Usage { + prompt_tokens: 4, + total_tokens: 36, + completion_tokens: 32, + ..Default::default() + }), + ..Default::default() + }; + + // Add completions mock + completion_mocks.mock(|when, then| { + when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": prompt, + })); + then.json(&completions_response); + }); + + // Start orchestrator server and its dependencies + let mock_openai_server = MockServer::new_http("openai").with_mocks(completion_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + 
.openai_server(&mock_openai_server) + .build() + .await?; + + // Empty `detectors` scenario + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": {}, + "prompt": prompt, + })) + .send() + .await?; + dbg!(&response); + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], completions_response.choices[0]); + assert_eq!(results.choices[1], completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + // Missing `detectors` scenario + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "prompt": prompt, + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], completions_response.choices[0]); + assert_eq!(results.choices[1], completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + // `detectors` with empty `input` and `output` scenario + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "prompt": prompt, + "detectors": { + "input": {}, + "output": {}, + }, + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], completions_response.choices[0]); + assert_eq!(results.choices[1], completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + Ok(()) +} + +// Validate that requests without detectors, input detector and output detector configured +// returns text generated by model +#[test(tokio::test)] +async fn no_detections() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC; + + let prompt = "Hi there!"; + + // Add mocksets + let mut detector_mocks = MockSet::new(); + let mut completion_mocks = MockSet::new(); + + let expected_choices = vec![ + CompletionChoice { + index: 0, + text: "Hi there!".into(), + logprobs: None, + finish_reason: Some("length".into()), + stop_reason: None, + prompt_logprobs: None, + }, + CompletionChoice { + index: 1, + text: "Hello!".into(), + logprobs: None, + finish_reason: Some("length".into()), + stop_reason: None, + prompt_logprobs: None, + }, + ]; + + let completions_response = Completion { + id: Uuid::new_v4().simple().to_string(), + object: "text_completion".into(), + created: current_timestamp().as_secs() as i64, + model: MODEL_ID.into(), + choices: expected_choices, + usage: Some(Usage { + prompt_tokens: 4, + total_tokens: 36, + completion_tokens: 32, + ..Default::default() + }), + ..Default::default() + }; + + // Add detector input mock + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec!["Hi there!".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + // Add detector output mock + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec!["Hello!".into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + + // Add completions mock + completion_mocks.mock(|when, then| { + 
when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": prompt, + })); + then.json(&completions_response); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(completion_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .openai_server(&mock_openai_server) + .build() + .await?; + + // Make orchestrator call for input/output no detections + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + detector_name: {}, + }, + "output": { + detector_name: {}, + }, + }, + "prompt": prompt + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.choices[0], completions_response.choices[0]); + assert_eq!(results.choices[1], completions_response.choices[1]); + assert_eq!(results.warnings, vec![]); + assert!(results.detections.is_none()); + + // Scenario: output detectors on empty choices responses + let prompt = "Please provide me an empty message"; + let expected_choices = vec![ + CompletionChoice { + index: 0, + text: "".into(), + logprobs: None, + finish_reason: Some("length".into()), + stop_reason: None, + prompt_logprobs: None, + }, + CompletionChoice { + index: 1, + text: "".into(), + logprobs: None, + finish_reason: Some("length".into()), + stop_reason: None, + prompt_logprobs: None, + }, + ]; + let expected_warnings = vec![ + CompletionDetectionWarning::new( + DetectionWarningReason::EmptyOutput, + "Choice of index 0 has no content. Output detection was not executed", + ), + CompletionDetectionWarning::new( + DetectionWarningReason::EmptyOutput, + "Choice of index 1 has no content. Output detection was not executed", + ), + ]; + let completions_response = Completion { + id: Uuid::new_v4().simple().to_string(), + object: "text_completion".into(), + created: current_timestamp().as_secs() as i64, + model: MODEL_ID.into(), + choices: expected_choices, + usage: Some(Usage { + prompt_tokens: 4, + total_tokens: 36, + completion_tokens: 32, + ..Default::default() + }), + ..Default::default() + }; + + mock_openai_server.mocks().mock(|when, then| { + when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": prompt, + })); + then.json(&completions_response); + }); + + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "output": { + detector_name: {}, + }, + }, + "prompt": prompt + })) + .send() + .await?; + + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + debug!("{}", serde_json::to_string_pretty(&results)?); + assert_eq!(results.choices[0], completions_response.choices[0]); + assert_eq!(results.choices[1], completions_response.choices[1]); + assert_eq!(results.warnings, expected_warnings); + assert!(results.detections.is_none()); + + Ok(()) +} + +// Validates that requests with input detector configured returns detections +#[test(tokio::test)] +async fn input_detections() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + let prompt = "Hi there! 
Can you help me with <something>?"; + + // Add mocksets + let mut detector_mocks = MockSet::new(); + let mut chunker_mocks = MockSet::new(); + let mut tokenize_mocks = MockSet::new(); + + // Add input detection mock response for input detection + let expected_detections = vec![ContentAnalysisResponse { + start: 34, + end: 42, + text: "something".into(), + detection: "has_angle_brackets".into(), + detection_type: "angle_brackets".into(), + detector_id: Some(detector_name.into()), + score: 1.0, + evidence: None, + metadata: Metadata::new(), + }]; + + let completions_response = Completion { + id: Uuid::new_v4().simple().to_string(), + object: "text_completion".into(), + created: current_timestamp().as_secs() as i64, + model: MODEL_ID.into(), + choices: vec![], + detections: Some(CompletionDetections { + input: vec![CompletionInputDetections { + message_index: 0, + results: expected_detections.clone(), + }], + output: vec![], + }), + warnings: vec![CompletionDetectionWarning::new( + DetectionWarningReason::UnsuitableInput, + UNSUITABLE_INPUT_MESSAGE, + )], + usage: Some(Usage { + prompt_tokens: 43, + ..Default::default() + }), + ..Default::default() + }; + + // Add chunker tokenization mock for input detection + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: prompt.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: prompt.len() as i64, + text: prompt.into(), + }], + token_count: 0, + }); + }); + + // Add detector input mock + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![prompt.into()], + detector_params: DetectorParams::new(), + }); + then.json([&expected_detections]); + }); + + // Add Tokenize mock + tokenize_mocks.mock(|when, then| { + when.post().path(TOKENIZE_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": prompt, + })); + then.json(&TokenizeResponse { + count: 43, + ..Default::default() + }); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(tokenize_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .openai_server(&mock_openai_server) + .build() + .await?; + + // Make orchestrator call for input/output no detections + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": { + detector_name: {}, + }, + "output": {} + }, + "prompt": prompt + })) + .send() + .await?; + + // Assertions for input detections + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::<Completion>().await?; + assert_eq!(results.detections, completions_response.detections); + assert_eq!(results.choices, completions_response.choices); + assert_eq!(results.warnings, completions_response.warnings); + + Ok(()) +} + +// Validates that requests with input detector configured returns propagated errors +#[test(tokio::test)] +async fn input_client_error() -> Result<(), anyhow::Error> { + let detector_name =
DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + // Add 500 expected input detector mock response + let expected_detector_error = DetectorError { + code: 500, + message: "Internal detector error.".into(), + }; + // Add 500 expected orchestrator error response + let expected_orchestrator_error = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; + + // Add input for error scenarios + let chunker_error_input = "This should return a 500 error on chunker"; + let detector_error_input = "This should return a 500 error on detector"; + let completions_error_input = "This should return a 500 error on openai"; + + // Add mocksets + let mut chunker_mocks = MockSet::new(); + let mut detector_mocks = MockSet::new(); + let mut completions_mocks = MockSet::new(); + + // Add chunker tokenization mock for detector internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: detector_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: detector_error_input.len() as i64, + text: detector_error_input.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for completions internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: completions_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: completions_error_input.len() as i64, + text: completions_error_input.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for chunker internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: chunker_error_input.into(), + }); + then.internal_server_error(); + }); + + // Add detector mock for completions error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![completions_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + + // Add detector mock for detector error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![detector_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.internal_server_error().json(&expected_detector_error); + }); + + // Add completions mock for completions error scenario + completions_mocks.mock(|when, then| { + when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": completions_error_input, + })); + then.internal_server_error(); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(completions_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + 
.chunker_servers([&mock_chunker_server])
+        .openai_server(&mock_openai_server)
+        .build()
+        .await?;
+
+    // Make orchestrator call for chunker error scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": MODEL_ID,
+            "detectors": {
+                "input": {
+                    detector_name: {},
+                },
+                "output": {}
+            },
+            "prompt": chunker_error_input,
+        }))
+        .send()
+        .await?;
+
+    // Assertions for chunker error scenario
+    let results = response.json::<server::Error>().await?;
+    assert_eq!(results, expected_orchestrator_error);
+
+    // Make orchestrator call for detector error scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": MODEL_ID,
+            "detectors": {
+                "input": {
+                    detector_name: {},
+                },
+                "output": {}
+            },
+            "prompt": detector_error_input
+        }))
+        .send()
+        .await?;
+
+    // Assertions for detector error scenario
+    let results = response.json::<server::Error>().await?;
+    assert_eq!(results, expected_orchestrator_error);
+
+    // Make orchestrator call for completions error scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": MODEL_ID,
+            "detectors": {
+                "input": {
+                    detector_name: {},
+                },
+                "output": {}
+            },
+            "prompt": completions_error_input,
+        }))
+        .send()
+        .await?;
+
+    // Assertions for completions error scenario
+    let results = response.json::<server::Error>().await?;
+    assert_eq!(results, expected_orchestrator_error);
+
+    Ok(())
+}
+
+// Validates that requests with an output detector configured return detections
+#[test(tokio::test)]
+async fn output_detections() -> Result<(), anyhow::Error> {
+    let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE;
+    let prompt = "Hi there! Can you help me with something?";
+    let output_no_detection = "Sure! Let me help you with something, just tell me what you need.";
+    let output_with_detection =
+        "Sure! Let me help you with <something>, just tell me what you need.";
+
+    // Add mocksets
+    let mut detector_mocks = MockSet::new();
+    let mut completion_mocks = MockSet::new();
+    let mut chunker_mocks = MockSet::new();
+
+    // Add output detection mock response for output detection
+    let expected_detections = vec![ContentAnalysisResponse {
+        start: 28,
+        end: 37,
+        text: "something".into(),
+        detection: "has_angle_brackets".into(),
+        detection_type: "angle_brackets".into(),
+        detector_id: Some(detector_name.into()),
+        score: 1.0,
+        evidence: None,
+        metadata: Metadata::new(),
+    }];
+
+    // Add completion choices response for output detection
+    let expected_choices = vec![
+        CompletionChoice {
+            index: 0,
+            text: output_no_detection.into(),
+            logprobs: None,
+            finish_reason: Some("length".into()),
+            stop_reason: None,
+            prompt_logprobs: None,
+        },
+        CompletionChoice {
+            index: 1,
+            text: output_with_detection.into(),
+            logprobs: None,
+            finish_reason: Some("length".into()),
+            stop_reason: None,
+            prompt_logprobs: None,
+        },
+    ];
+
+    let completions_response = Completion {
+        id: Uuid::new_v4().simple().to_string(),
+        object: "text_completion".into(),
+        created: current_timestamp().as_secs() as i64,
+        model: MODEL_ID.into(),
+        choices: expected_choices,
+        detections: Some(CompletionDetections {
+            input: vec![],
+            output: vec![CompletionOutputDetections {
+                choice_index: 1,
+                results: expected_detections.clone(),
+            }],
+        }),
+        warnings: vec![CompletionDetectionWarning::new(
+            DetectionWarningReason::UnsuitableOutput,
+            UNSUITABLE_OUTPUT_MESSAGE,
+        )],
+        ..Default::default()
+    };
+
+    // Add detector output mock for first message
+    detector_mocks.mock(|when, then| {
+        when.post()
+            .path(TEXT_CONTENTS_DETECTOR_ENDPOINT)
+            .json(ContentAnalysisRequest {
+                contents: vec![output_no_detection.into()],
+                detector_params: DetectorParams::new(),
+            });
+        then.json([Vec::<ContentAnalysisResponse>::new()]);
+    });
+
+    // Add detector output mock for generated message
+    detector_mocks.mock(|when, then| {
+        when.post()
+            .path(TEXT_CONTENTS_DETECTOR_ENDPOINT)
+            .json(ContentAnalysisRequest {
+                contents: vec![output_with_detection.into()],
+                detector_params: DetectorParams::new(),
+            });
+        then.json([&expected_detections]);
+    });
+
+    // Add chunker tokenization mock for output detection user input
+    chunker_mocks.mock(|when, then| {
+        when.path(CHUNKER_UNARY_ENDPOINT)
+            .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE)
+            .pb(ChunkerTokenizationTaskRequest {
+                text: prompt.into(),
+            });
+        then.pb(TokenizationResults {
+            results: vec![Token {
+                start: 0,
+                end: prompt.len() as i64,
+                text: prompt.into(),
+            }],
+            token_count: 0,
+        });
+    });
+
+    // Add chunker tokenization mock for output detection assistant output
+    chunker_mocks.mock(|when, then| {
+        when.path(CHUNKER_UNARY_ENDPOINT)
+            .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE)
+            .pb(ChunkerTokenizationTaskRequest {
+                text: output_no_detection.into(),
+            });
+        then.pb(TokenizationResults {
+            results: vec![Token {
+                start: 0,
+                end: output_no_detection.len() as i64,
+                text: output_no_detection.into(),
+            }],
+            token_count: 0,
+        });
+    });
+    chunker_mocks.mock(|when, then| {
+        when.path(CHUNKER_UNARY_ENDPOINT)
+            .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE)
+            .pb(ChunkerTokenizationTaskRequest {
+                text: output_with_detection.into(),
+            });
+        then.pb(TokenizationResults {
+            results: vec![Token {
+                start: 0,
+                end: output_with_detection.len() as i64,
+                text: output_with_detection.into(),
+            }],
+            token_count: 0,
+        });
+    });
+
+    // Add completions mock
+
completion_mocks.mock(|when, then| { + when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": prompt, + })); + then.json(&completions_response); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(completion_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .openai_server(&mock_openai_server) + .build() + .await?; + + // Make orchestrator call for output detections + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "prompt": prompt, + })) + .send() + .await?; + + // Assertions for output detections + assert_eq!(response.status(), StatusCode::OK); + let results = response.json::().await?; + assert_eq!(results.detections, completions_response.detections); + assert_eq!(results.choices, completions_response.choices); + assert_eq!(results.warnings, completions_response.warnings); + + Ok(()) +} + +// Validates that requests with output detector configured returns propagated errors +// from detector, chunker and completions server when applicable +#[test(tokio::test)] +async fn output_client_error() -> Result<(), anyhow::Error> { + let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; + // Add 500 expected output detector mock response + let expected_detector_error = DetectorError { + code: 500, + message: "Internal detector error.".into(), + }; + // Add 500 expected orchestrator mock response + let expected_orchestrator_error = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; + + // Add input for error scenarios + let chunker_error_input = "This should return a 500 error on chunker"; + let detector_error_input = "This should return a 500 error on detector"; + let completions_error_input = "This should return a 500 error on openai"; + + // Add mocksets + let mut chunker_mocks = MockSet::new(); + let mut detector_mocks = MockSet::new(); + let mut completion_mocks = MockSet::new(); + + // Add chunker tokenization mock for detector internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: detector_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: detector_error_input.len() as i64, + text: detector_error_input.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for completions internal server error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: completions_error_input.into(), + }); + then.pb(TokenizationResults { + results: vec![Token { + start: 0, + end: completions_error_input.len() as i64, + text: completions_error_input.into(), + }], + token_count: 0, + }); + }); + + // Add chunker tokenization mock for chunker internal server 
error scenario + chunker_mocks.mock(|when, then| { + when.path(CHUNKER_UNARY_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, CHUNKER_NAME_SENTENCE) + .pb(ChunkerTokenizationTaskRequest { + text: chunker_error_input.into(), + }); + then.internal_server_error(); + }); + + // Add detector mock for completions error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![completions_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.json([Vec::::new()]); + }); + + // Add detector mock for detector error scenario + detector_mocks.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .json(ContentAnalysisRequest { + contents: vec![detector_error_input.into()], + detector_params: DetectorParams::new(), + }); + then.internal_server_error().json(&expected_detector_error); + }); + + // Add completions mock for chunker error scenario + completion_mocks.mock(|when, then| { + when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": chunker_error_input, + })); + then.internal_server_error(); + }); + + // Add completions mock for detector error scenario + completion_mocks.mock(|when, then| { + when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": detector_error_input, + })); + then.internal_server_error().json(&expected_detector_error); + }); + + // Add completions mock for completions error scenario + completion_mocks.mock(|when, then| { + when.post().path(COMPLETIONS_ENDPOINT).json(json!({ + "model": MODEL_ID, + "prompt": completions_error_input, + })); + then.internal_server_error(); + }); + + // Start orchestrator server and its dependencies + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_openai_server = MockServer::new_http("openai").with_mocks(completion_mocks); + let mock_chunker_server = MockServer::new_grpc(CHUNKER_NAME_SENTENCE).with_mocks(chunker_mocks); + + let orchestrator_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .detector_servers([&mock_detector_server]) + .chunker_servers([&mock_chunker_server]) + .openai_server(&mock_openai_server) + .build() + .await?; + + // Make orchestrator call for chunker error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "prompt": chunker_error_input + })) + .send() + .await?; + + // Assertions for chunker error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + // Make orchestrator call for detector error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "prompt": detector_error_input, + })) + .send() + .await?; + + // Assertions for detector error scenario + let results = response.json::().await?; + assert_eq!(results, expected_orchestrator_error); + + // Make orchestrator call for completions error scenario + let response = orchestrator_server + .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "model": MODEL_ID, + "detectors": { + "input": {}, + "output": { + detector_name: {}, + }, + }, + "prompt": completions_error_input, + })) + .send() + .await?; + 
+    // Assertions for completions error scenario
+    let results = response.json::<server::Error>().await?;
+    assert_eq!(results, expected_orchestrator_error);
+
+    Ok(())
+}
+
+// Validates that invalid orchestrator requests return a 422 error
+#[test(tokio::test)]
+async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
+    // Start orchestrator server and its dependencies
+    let orchestrator_server = TestOrchestratorServer::builder()
+        .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
+        .build()
+        .await?;
+
+    let prompt = "Hi there!";
+
+    // Invalid input detector scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": MODEL_ID,
+            "detectors": {
+                "input": {
+                    ANSWER_RELEVANCE_DETECTOR: {},
+                },
+                "output": {}
+            },
+            "prompt": prompt,
+        }))
+        .send()
+        .await?;
+
+    let results = response.json::<server::Error>().await?;
+    debug!("{results:#?}");
+    assert_eq!(
+        results,
+        server::Error {
+            code: http::StatusCode::UNPROCESSABLE_ENTITY,
+            details: format!(
+                "detector `{ANSWER_RELEVANCE_DETECTOR}` is not supported by this endpoint",
+            )
+        },
+        "failed on invalid input detector scenario"
+    );
+
+    // Non-existing input detector scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": MODEL_ID,
+            "detectors": {
+                "input": {
+                    NON_EXISTING_DETECTOR: {},
+                },
+                "output": {}
+            },
+            "prompt": prompt,
+        }))
+        .send()
+        .await?;
+
+    let results = response.json::<server::Error>().await?;
+    debug!("{results:#?}");
+    assert_eq!(
+        results,
+        server::Error {
+            code: http::StatusCode::NOT_FOUND,
+            details: format!("detector `{NON_EXISTING_DETECTOR}` not found"),
+        },
+        "failed on non-existing input detector scenario"
+    );
+
+    // Invalid output detector scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": MODEL_ID,
+            "detectors": {
+                "input": {},
+                "output": {
+                    ANSWER_RELEVANCE_DETECTOR: {},
+                },
+            },
+            "prompt": prompt,
+        }))
+        .send()
+        .await?;
+
+    let results = response.json::<server::Error>().await?;
+    debug!("{results:#?}");
+    assert_eq!(
+        results,
+        server::Error {
+            code: http::StatusCode::UNPROCESSABLE_ENTITY,
+            details: format!(
+                "detector `{ANSWER_RELEVANCE_DETECTOR}` is not supported by this endpoint"
+            )
+        },
+        "failed on invalid output detector scenario"
+    );
+
+    // Non-existing output detector scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": MODEL_ID,
+            "detectors": {
+                "input": {},
+                "output": {
+                    NON_EXISTING_DETECTOR: {},
+                }
+            },
+            "prompt": prompt,
+        }))
+        .send()
+        .await?;
+
+    let results = response.json::<server::Error>().await?;
+    debug!("{results:#?}");
+    assert_eq!(
+        results,
+        server::Error {
+            code: http::StatusCode::NOT_FOUND,
+            details: format!("detector `{NON_EXISTING_DETECTOR}` not found"),
+        },
+        "failed on non-existing output detector scenario"
+    );
+
+    // Empty `model` scenario
+    let response = orchestrator_server
+        .post(ORCHESTRATOR_COMPLETIONS_DETECTION_ENDPOINT)
+        .json(&json!({
+            "model": "",
+            "prompt": prompt,
+        }))
+        .send()
+        .await?;
+
+    let results = response.json::<server::Error>().await?;
+    debug!("{results:#?}");
+    assert_eq!(
+        results,
+        server::Error {
+            code: http::StatusCode::UNPROCESSABLE_ENTITY,
+            details: "`model` must not be empty".into()
+        }
+    );
+
+    Ok(())
+}
diff --git a/tests/context_docs_detection.rs b/tests/context_docs_detection.rs
index d24eea4e..b2da1611 100644
--- a/tests/context_docs_detection.rs
+++ b/tests/context_docs_detection.rs
@@ -21,7 +21,7 @@ use common::{ ANSWER_RELEVANCE_DETECTOR_SENTENCE, CONTEXT_DOC_DETECTOR_ENDPOINT, FACT_CHECKING_DETECTOR, NON_EXISTING_DETECTOR, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT, TestOrchestratorServer, @@ -32,6 +32,7 @@ use fms_guardrails_orchestr8::{ models::{ ContextDocsHttpRequest, ContextDocsResult, DetectionResult, DetectorParams, Metadata, }, + server, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -71,7 +72,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -133,7 +134,7 @@ async fn detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -193,7 +194,7 @@ async fn client_error() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -216,8 +217,14 @@ async fn client_error() -> Result<(), anyhow::Error> { // assertions assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - let response = response.json::().await?; - assert_eq!(response, OrchestratorError::internal()); + let response = response.json::().await?; + assert_eq!( + response, + server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into() + } + ); Ok(()) } @@ -250,7 +257,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; assert_eq!(response.code, 422); assert!(response.details.contains("unknown field `extra_args`")); @@ -267,7 +274,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; assert_eq!(response.code, 422); assert!(response.details.contains("missing field `content`")); @@ -284,7 +291,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; assert_eq!(response.code, 422); assert!(response.details.contains("missing field `context`")); @@ -301,7 +308,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); 
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; assert_eq!(response.code, 422); assert!(response.details.contains("missing field `context_type`")); @@ -319,7 +326,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!( @@ -341,7 +348,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.starts_with("missing field `detectors`")); @@ -359,7 +366,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.starts_with("missing field `detectors`")); @@ -377,7 +384,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.starts_with("missing field `detectors`")); @@ -399,11 +406,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_supported(ANSWER_RELEVANCE_DETECTOR_SENTENCE), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{ANSWER_RELEVANCE_DETECTOR_SENTENCE}` is not supported by this endpoint" + ), + }, "failed on detector with invalid type scenario" ); @@ -421,11 +433,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::NOT_FOUND); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, "failed on non-existing detector scenario" ); diff --git a/tests/detection_on_generation.rs b/tests/detection_on_generation.rs index b59765f0..a43e1d49 100644 --- a/tests/detection_on_generation.rs +++ b/tests/detection_on_generation.rs @@ -22,7 +22,7 @@ use common::{ ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT, FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT, TestOrchestratorServer, @@ -34,6 +34,7 @@ use fms_guardrails_orchestr8::{ 
DetectionOnGeneratedHttpRequest, DetectionOnGenerationResult, DetectionResult, DetectorParams, Metadata, }, + server, }; use hyper::StatusCode; use mocktail::prelude::*; @@ -72,7 +73,7 @@ async fn no_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -132,7 +133,7 @@ async fn detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -189,7 +190,7 @@ async fn client_error() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_detector_server = MockServer::new(detector_name).with_mocks(mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([&mock_detector_server]) @@ -211,8 +212,14 @@ async fn client_error() -> Result<(), anyhow::Error> { // assertions assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - let response = response.json::().await?; - assert_eq!(response, OrchestratorError::internal()); + let response = response.json::().await?; + assert_eq!( + response, + server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into() + } + ); Ok(()) } @@ -245,7 +252,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("unknown field `extra_args`")); @@ -262,7 +269,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `prompt`")); @@ -279,7 +286,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `generated_text`")); @@ -297,7 +304,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `detectors`")); @@ -315,11 +322,14 @@ async 
fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::required("detectors"), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: "`detectors` is required".into(), + }, "failed on empty `detectors` scenario" ); @@ -339,11 +349,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{FACT_CHECKING_DETECTOR_SENTENCE}` is not supported by this endpoint" + ), + }, "failed on invalid detector scenario" ); @@ -360,11 +375,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::NOT_FOUND); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, "failed on non-existing detector scenario" ); diff --git a/tests/generation_with_detection.rs b/tests/generation_with_detection.rs index 54d3acf2..40abfd29 100644 --- a/tests/generation_with_detection.rs +++ b/tests/generation_with_detection.rs @@ -21,7 +21,7 @@ use common::{ ANSWER_RELEVANCE_DETECTOR, DETECTION_ON_GENERATION_DETECTOR_ENDPOINT, FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, generation::{GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_UNARY_ENDPOINT}, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT, @@ -38,6 +38,7 @@ use fms_guardrails_orchestr8::{ caikit::runtime::nlp::TextGenerationTaskRequest, caikit_data_model::nlp::GeneratedTextResult, }, + server, }; use http::StatusCode; use mocktail::{MockSet, server::MockServer}; @@ -93,8 +94,8 @@ async fn no_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); - let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detection_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&mock_generation_server) @@ -175,8 +176,8 @@ async fn detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); - let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); + let mock_detector_server = 
MockServer::new_http(detector_name).with_mocks(detection_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&mock_generation_server) @@ -221,7 +222,10 @@ async fn client_error() -> Result<(), anyhow::Error> { code: 500, message: "Here's your 500 error".into(), }; - let orchestrator_error_500 = OrchestratorError::internal(); + let orchestrator_error_500 = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; // Add generation mock let model_id = "my-super-model-8B"; @@ -262,8 +266,8 @@ async fn client_error() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); - let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detection_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&mock_generation_server) @@ -286,7 +290,7 @@ async fn client_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!( - response.json::().await?, + response.json::().await?, orchestrator_error_500 ); @@ -305,7 +309,7 @@ async fn client_error() -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!( - response.json::().await?, + response.json::().await?, orchestrator_error_500 ); @@ -339,7 +343,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("unknown field `extra_args`")); @@ -362,7 +366,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `model_id`")); @@ -379,7 +383,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `prompt`")); @@ -396,7 +400,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!(response.code, 422); assert!(response.details.contains("missing field `detectors`")); @@ -414,9 +418,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = 
response.json::().await?; debug!("{response:#?}"); - assert_eq!(response, OrchestratorError::required("detectors")); + assert_eq!( + response, + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: "`detectors` is required".into(), + }, + "failed on empty `detectors` scenario" + ); // assert request with invalid type detectors let response = orchestrator_server @@ -435,11 +446,16 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), + server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{FACT_CHECKING_DETECTOR_SENTENCE}` is not supported by this endpoint" + ), + }, "failed at invalid detector scenario" ); @@ -457,11 +473,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), StatusCode::NOT_FOUND); - let response = response.json::().await?; + let response = response.json::().await?; debug!("{response:#?}"); assert_eq!( response, - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found"), + }, "failed on non-existing detector scenario" ); diff --git a/tests/streaming_classification_with_gen.rs b/tests/streaming_classification_with_gen.rs index b1aee7c0..b758ef2e 100644 --- a/tests/streaming_classification_with_gen.rs +++ b/tests/streaming_classification_with_gen.rs @@ -27,7 +27,7 @@ use common::{ DETECTOR_NAME_PARENTHESIS_SENTENCE, FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, generation::{ GENERATION_NLP_MODEL_ID_HEADER_NAME, GENERATION_NLP_STREAMING_ENDPOINT, GENERATION_NLP_TOKENIZATION_ENDPOINT, @@ -37,7 +37,6 @@ use common::{ ORCHESTRATOR_UNSUITABLE_INPUT_MESSAGE, SseStream, TestOrchestratorServer, }, }; -use eventsource_stream::Eventsource; use fms_guardrails_orchestr8::{ clients::detector::{ContentAnalysisRequest, ContentAnalysisResponse}, models::{ @@ -56,6 +55,7 @@ use fms_guardrails_orchestr8::{ ChunkerTokenizationStreamResult, GeneratedTextStreamResult, Token, TokenizationResults, }, }, + server, }; use futures::{StreamExt, TryStreamExt}; use mocktail::prelude::*; @@ -117,7 +117,7 @@ async fn no_detectors() -> Result<(), anyhow::Error> { }); // Configure mock servers - let generation_server = MockServer::new("nlp").grpc().with_mocks(mocks); + let generation_server = MockServer::new_grpc("nlp").with_mocks(mocks); // Run test orchestrator server let orchestrator_server = TestOrchestratorServer::builder() @@ -281,9 +281,9 @@ async fn input_detector_no_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); - let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); - let generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detection_mocks); + 
let generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&generation_server) @@ -427,13 +427,14 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> { }); then.json([vec![&whole_doc_mock_detection_response]]); }); - let mock_whole_doc_detector_server = MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC) - .with_mocks(whole_doc_detection_mocks); + let mock_whole_doc_detector_server = + MockServer::new_http(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC) + .with_mocks(whole_doc_detection_mocks); // Start orchestrator server and its dependencies - let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); - let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); - let generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detection_mocks); + let generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&generation_server) @@ -570,7 +571,10 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { let detector_error_input = "Detector should return an error"; let generation_server_error_input = "Generation should return an error"; - let orchestrator_error_500 = OrchestratorError::internal(); + let orchestrator_error_500 = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; let mut chunker_mocks = MockSet::new(); chunker_mocks.mock(|when, then| { @@ -648,9 +652,9 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); - let mock_detector_server = MockServer::new(detector_name).with_mocks(detector_mocks); - let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detector_mocks); + let mock_generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .chunker_servers([&mock_chunker_server]) @@ -679,13 +683,17 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - assert_eq!(messages[0], orchestrator_error_500); - + assert!( + messages[0] + .as_ref() + .is_err_and(|error| *error == orchestrator_error_500) + ); // Test error from detector let response = orchestrator_server .post(ORCHESTRATOR_STREAMING_ENDPOINT) @@ -706,12 +714,17 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { 
debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - assert_eq!(messages[0], orchestrator_error_500); + assert!( + messages[0] + .as_ref() + .is_err_and(|error| *error == orchestrator_error_500) + ); // Test error from generation server let response = orchestrator_server @@ -733,12 +746,17 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> { debug!(?response, "RESPONSE RECEIVED FROM ORCHESTRATOR"); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - assert_eq!(messages[0], orchestrator_error_500); + assert!( + messages[0] + .as_ref() + .is_err_and(|error| *error == orchestrator_error_500) + ); Ok(()) } @@ -771,7 +789,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!(?response); assert_eq!(response.status(), 422); - let response_body = response.json::().await?; + let response_body = response.json::().await?; assert_eq!(response_body.code, 422); assert!( response_body @@ -802,13 +820,19 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!(?response); assert_eq!(response.status(), 200); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), + Err(server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{FACT_CHECKING_DETECTOR_SENTENCE}` is not supported by this endpoint" + ) + }), "failed at invalid input detector scenario" ); @@ -832,13 +856,17 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!(?response); assert_eq!(response.status(), 200); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + Err(server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found") + }), "failed at non-existing input detector scenario" ); @@ -864,13 +892,19 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!(?response); assert_eq!(response.status(), 200); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError::detector_not_supported(FACT_CHECKING_DETECTOR_SENTENCE), + 
Err(server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{FACT_CHECKING_DETECTOR_SENTENCE}` is not supported by this endpoint" + ) + }), "failed at invalid output detector scenario" ); @@ -896,13 +930,19 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!("{response:#?}"); assert_eq!(response.status(), 200); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError::chunker_not_supported(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC), + Err(server::Error { + code: http::StatusCode::UNPROCESSABLE_ENTITY, + details: format!( + "detector `{DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint" + ) + }), "failed on output detector with invalid chunker scenario" ); @@ -925,13 +965,17 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> { debug!(?response); assert_eq!(response.status(), 200); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); assert_eq!( messages[0], - OrchestratorError::detector_not_found(NON_EXISTING_DETECTOR), + Err(server::Error { + code: http::StatusCode::NOT_FOUND, + details: format!("detector `{NON_EXISTING_DETECTOR}` not found") + }), "failed at non-existing output detector scenario" ); @@ -1067,12 +1111,12 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); + let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks); let mock_angle_brackets_detector_server = - MockServer::new(angle_brackets_detector).with_mocks(detection_mocks.clone()); + MockServer::new_http(angle_brackets_detector).with_mocks(detection_mocks.clone()); let mock_parenthesis_detector_server = - MockServer::new(parenthesis_detector).with_mocks(detection_mocks); - let generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); + MockServer::new_http(parenthesis_detector).with_mocks(detection_mocks); + let generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) @@ -1108,11 +1152,7 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream - .collect::>() - .await - .into_iter() - .collect::, anyhow::Error>>()?; + let messages = sse_stream.try_collect::>().await?; debug!("{messages:#?}"); assert_eq!(messages.len(), 2); @@ -1156,11 +1196,7 @@ async fn output_detectors_no_detections() -> Result<(), anyhow::Error> { let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream - .collect::>() - .await - .into_iter() - .collect::, anyhow::Error>>()?; + let messages = sse_stream.try_collect::>().await?; debug!("{messages:#?}"); assert_eq!(messages.len(), 2); 
@@ -1357,12 +1393,12 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); + let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks); let mock_angle_brackets_detector_server = - MockServer::new(angle_brackets_detector).with_mocks(angle_brackets_mocks); + MockServer::new_http(angle_brackets_detector).with_mocks(angle_brackets_mocks); let mock_parenthesis_detector_server = - MockServer::new(parenthesis_detector).with_mocks(parenthesis_mocks); - let generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); + MockServer::new_http(parenthesis_detector).with_mocks(parenthesis_mocks); + let generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&generation_server) @@ -1397,11 +1433,7 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream - .collect::>() - .await - .into_iter() - .collect::, anyhow::Error>>()?; + let messages = sse_stream.try_collect::>().await?; debug!("{messages:#?}"); let expected_messages = vec![ @@ -1469,11 +1501,7 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream - .collect::>() - .await - .into_iter() - .collect::, anyhow::Error>>()?; + let messages = sse_stream.try_collect::>().await?; debug!("{messages:#?}"); let expected_messages = vec![ @@ -1535,7 +1563,10 @@ async fn output_detectors_detections() -> Result<(), anyhow::Error> { async fn output_detector_client_error() -> Result<(), anyhow::Error> { let detector_name = DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE; - let orchestrator_error_500 = OrchestratorError::internal(); + let orchestrator_error_500 = server::Error { + code: http::StatusCode::INTERNAL_SERVER_ERROR, + details: "unexpected error occurred while processing request".into(), + }; // Add generation mock let model_id = "my-super-model-8B"; @@ -1722,9 +1753,9 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { }); // Start orchestrator server and its dependencies - let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); - let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks); - let generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks); + let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks); + let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detection_mocks); + let generation_server = MockServer::new_grpc("nlp").with_mocks(generation_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .generation_server(&generation_server) @@ -1751,12 +1782,17 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { .await?; debug!("{response:#?}"); - let sse_stream: SseStream = SseStream::new(response.bytes_stream()); - let messages = sse_stream.try_collect::>().await?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; debug!("{messages:#?}"); assert_eq!(messages.len(), 1); - 
assert_eq!(messages[0], orchestrator_error_500); + assert!( + messages[0] + .as_ref() + .is_err_and(|error| *error == orchestrator_error_500) + ); // assert detector error let response = orchestrator_server @@ -1775,40 +1811,23 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> { .send() .await?; - debug!("{response:#?}"); - - let mut events = Vec::new(); - let mut event_stream = response.bytes_stream().eventsource(); - while let Some(event) = event_stream.next().await { - match event { - Ok(event) => { - if event.data == "[DONE]" { - break; - } - debug!("recv: {event:?}"); - events.push(event.data); - } - Err(_) => { - panic!("received error from event stream"); - } - } - } - debug!("{events:?}"); - - let first_response = - serde_json::from_str::(events[0].as_str())?; - let second_response = serde_json::from_str::(events[1].as_str())?; + let sse_stream: SseStream = + SseStream::new(response.bytes_stream()); + let messages = sse_stream.collect::>().await; + debug!("{messages:#?}"); - assert_eq!(events.len(), 2); - assert_eq!(first_response.generated_text, Some("I am great!".into())); - assert_eq!( - first_response.token_classification_results.output, - Some(vec![]) + assert_eq!(messages.len(), 2); + assert!(messages[0].as_ref().is_ok_and(|msg| { + msg.generated_text == Some("I am great!".into()) + && msg.token_classification_results.output == Some(vec![]) + && msg.start_index == Some(0) + && msg.processed_index == Some(11) + }),); + assert!( + messages[1] + .as_ref() + .is_err_and(|error| *error == orchestrator_error_500) ); - assert_eq!(first_response.start_index, Some(0)); - assert_eq!(first_response.processed_index, Some(11)); - - assert_eq!(second_response, orchestrator_error_500); Ok(()) } diff --git a/tests/streaming_content_detection.rs b/tests/streaming_content_detection.rs index 6100d561..51ff7eba 100644 --- a/tests/streaming_content_detection.rs +++ b/tests/streaming_content_detection.rs @@ -23,7 +23,7 @@ use common::{ DETECTOR_NAME_PARENTHESIS_SENTENCE, FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT, }, - errors::{DetectorError, OrchestratorError}, + errors::DetectorError, orchestrator::{ ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_STREAM_CONTENT_DETECTION_ENDPOINT, TestOrchestratorServer, json_lines_stream, @@ -39,6 +39,7 @@ use fms_guardrails_orchestr8::{ caikit::runtime::chunkers::BidiStreamingChunkerTokenizationTaskRequest, caikit_data_model::nlp::{ChunkerTokenizationStreamResult, Token}, }, + server, }; use futures::StreamExt; use mocktail::{MockSet, server::MockServer}; @@ -132,11 +133,11 @@ async fn no_detections() -> Result<(), anyhow::Error> { }); // Run test orchestrator server - let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks); + let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks); let mock_angle_brackets_detector_server = - MockServer::new(angle_brackets_detector).with_mocks(detection_mocks.clone()); + MockServer::new_http(angle_brackets_detector).with_mocks(detection_mocks.clone()); let mock_parenthesis_detector_server = - MockServer::new(parenthesis_detector).with_mocks(detection_mocks); + MockServer::new_http(parenthesis_detector).with_mocks(detection_mocks); let orchestrator_server = TestOrchestratorServer::builder() .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) .detector_servers([ @@ -368,11 +369,11 @@ async fn detections() -> Result<(), anyhow::Error> { }); // Run test orchestrator server - let mock_chunker_server = 
MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks);
+    let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks);
     let mock_angle_brackets_detector_server =
-        MockServer::new(angle_brackets_detector).with_mocks(angle_brackets_detection_mocks);
+        MockServer::new_http(angle_brackets_detector).with_mocks(angle_brackets_detection_mocks);
     let mock_parenthesis_detector_server =
-        MockServer::new(parenthesis_detector).with_mocks(parenthesis_detection_mocks);
+        MockServer::new_http(parenthesis_detector).with_mocks(parenthesis_detection_mocks);
     let orchestrator_server = TestOrchestratorServer::builder()
         .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
         .detector_servers([
@@ -505,7 +506,10 @@ async fn client_error() -> Result<(), anyhow::Error> {
     let chunker_error_payload = "Chunker should return an error.";
     let detector_error_payload = "Detector should return an error.";
 
-    let orchestrator_error_500 = OrchestratorError::internal();
+    let orchestrator_error_500 = server::Error {
+        code: http::StatusCode::INTERNAL_SERVER_ERROR,
+        details: "unexpected error occurred while processing request".into(),
+    };
 
     let mut chunker_mocks = MockSet::new();
     chunker_mocks.mock(|when, then| {
@@ -554,8 +558,8 @@ async fn client_error() -> Result<(), anyhow::Error> {
     });
 
     // Run test orchestrator server
-    let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks);
-    let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks);
+    let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks);
+    let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detection_mocks);
     let orchestrator_server = TestOrchestratorServer::builder()
         .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
         .detector_servers([&mock_detector_server])
@@ -578,7 +582,7 @@ async fn client_error() -> Result<(), anyhow::Error> {
         ])))
         .send()
         .await?;
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
@@ -602,7 +606,7 @@ async fn client_error() -> Result<(), anyhow::Error> {
         ])))
         .send()
         .await?;
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
@@ -635,7 +639,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
         })])))
         .send()
         .await?;
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
@@ -654,7 +658,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
         })])))
         .send()
         .await?;
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
@@ -676,14 +680,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
         ])))
         .send()
         .await?;
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
         messages.push(serde_json::from_slice(&msg[..]).unwrap());
     }
-    let expected_messages = [OrchestratorError {
-        code: 422,
+    let expected_messages = [server::Error {
+        code: http::StatusCode::UNPROCESSABLE_ENTITY,
         details: "`detectors` is required for the first message".into(),
     }];
     assert_eq!(
@@ -703,14 +707,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
         ])))
         .send()
         .await?;
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
         messages.push(serde_json::from_slice(&msg[..]).unwrap());
     }
-    let expected_messages = [OrchestratorError {
-        code: 422,
+    let expected_messages = [server::Error {
+        code: http::StatusCode::UNPROCESSABLE_ENTITY,
         details: "`detectors` must not be empty".into(),
     }];
     assert_eq!(
@@ -735,7 +739,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     .await?;
 
     assert_eq!(response.status(), 200);
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
@@ -745,11 +749,10 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     assert_eq!(messages.len(), 1);
     assert_eq!(
         messages[0],
-        OrchestratorError {
-            code: 422,
+        server::Error {
+            code: http::StatusCode::UNPROCESSABLE_ENTITY,
             details: format!(
-                "detector `{}` is not supported by this endpoint",
-                FACT_CHECKING_DETECTOR_SENTENCE
+                "detector `{FACT_CHECKING_DETECTOR_SENTENCE}` is not supported by this endpoint"
             )
         },
         "failed at invalid input detector scenario"
@@ -772,7 +775,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     .await?;
 
     assert_eq!(response.status(), 200);
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
@@ -782,11 +785,10 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     assert_eq!(messages.len(), 1);
     assert_eq!(
         messages[0],
-        OrchestratorError {
-            code: 422,
+        server::Error {
+            code: http::StatusCode::UNPROCESSABLE_ENTITY,
             details: format!(
-                "detector `{}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint",
-                DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC
+                "detector `{DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC}` uses chunker `whole_doc_chunker`, which is not supported by this endpoint"
             )
         },
         "failed at detector with invalid chunker scenario"
@@ -809,7 +811,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     .await?;
 
     assert_eq!(response.status(), 200);
-    let mut messages = Vec::<OrchestratorError>::with_capacity(1);
+    let mut messages = Vec::<server::Error>::with_capacity(1);
     let mut stream = response.bytes_stream();
     while let Some(Ok(msg)) = stream.next().await {
         debug!("recv: {msg:?}");
@@ -819,9 +821,9 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     assert_eq!(messages.len(), 1);
     assert_eq!(
         messages[0],
-        OrchestratorError {
-            code: 404,
-            details: format!("detector `{}` not found", NON_EXISTING_DETECTOR)
+        server::Error {
+            code: http::StatusCode::NOT_FOUND,
+            details: format!("detector `{NON_EXISTING_DETECTOR}` not found")
         },
         "failed at non-existing input detector scenario"
     );
diff --git a/tests/test_config.yaml b/tests/test_config.yaml
index cc691881..84525879 100644
--- a/tests/test_config.yaml
+++ b/tests/test_config.yaml
@@ -1,4 +1,4 @@
-chat_generation:
+openai:
   service:
     hostname: localhost
     port: 3000
@@ -66,3 +66,15 @@ detectors:
       hostname: localhost
     chunker_id: whole_doc_chunker
     default_threshold: 0.5
+  pii_detector_sentence:
+    type: text_contents
+    service:
+      hostname: localhost
+    chunker_id: sentence_chunker
+    default_threshold: 0.5
+  pii_detector_whole_doc:
+    type: text_contents
+    service:
+      hostname: localhost
+    chunker_id: whole_doc_chunker
+    default_threshold: 0.5
\ No newline at end of file
diff --git a/tests/text_content_detection.rs b/tests/text_content_detection.rs
index 1f160398..1227394f 100644
--- a/tests/text_content_detection.rs
+++ b/tests/text_content_detection.rs
@@ -23,7 +23,7 @@ use common::{
         DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE, DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC,
         FACT_CHECKING_DETECTOR_SENTENCE, NON_EXISTING_DETECTOR, TEXT_CONTENTS_DETECTOR_ENDPOINT,
     },
-    errors::{DetectorError, OrchestratorError},
+    errors::DetectorError,
     orchestrator::{
         ORCHESTRATOR_CONFIG_FILE_PATH, ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT,
         TestOrchestratorServer,
@@ -41,6 +41,7 @@ use fms_guardrails_orchestr8::{
         caikit::runtime::chunkers::ChunkerTokenizationTaskRequest,
         caikit_data_model::nlp::{Token, TokenizationResults},
     },
+    server,
 };
 use hyper::StatusCode;
 use mocktail::prelude::*;
@@ -110,11 +111,11 @@ async fn no_detections() -> Result<(), anyhow::Error> {
     });
 
     // Start orchestrator server and its dependencies
-    let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks);
+    let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks);
     let mock_sentence_detector_server =
-        MockServer::new(sentence_detector).with_mocks(sentence_detector_mocks);
+        MockServer::new_http(sentence_detector).with_mocks(sentence_detector_mocks);
     let mock_whole_doc_detector_server =
-        MockServer::new(whole_doc_detector).with_mocks(whole_doc_detector_mocks);
+        MockServer::new_http(whole_doc_detector).with_mocks(whole_doc_detector_mocks);
     let orchestrator_server = TestOrchestratorServer::builder()
         .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
         .chunker_servers([&mock_chunker_server])
@@ -252,11 +253,11 @@ async fn detections() -> Result<(), anyhow::Error> {
     });
 
     // Start orchestrator server and its dependencies
-    let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks);
+    let mock_chunker_server = MockServer::new_grpc(chunker_id).with_mocks(chunker_mocks);
     let mock_whole_doc_detector_server =
-        MockServer::new(whole_doc_detector).with_mocks(whole_doc_detector_mocks);
+        MockServer::new_http(whole_doc_detector).with_mocks(whole_doc_detector_mocks);
     let mock_sentence_detector_server =
-        MockServer::new(sentence_detector).with_mocks(sentence_detector_mocks);
+        MockServer::new_http(sentence_detector).with_mocks(sentence_detector_mocks);
     let orchestrator_server = TestOrchestratorServer::builder()
         .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
         .chunker_servers([&mock_chunker_server])
@@ -364,7 +365,7 @@ async fn client_error() -> Result<(), anyhow::Error> {
     });
 
     // Start orchestrator server and its dependencies
-    let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks);
+    let mock_detector_server = MockServer::new_http(detector_name).with_mocks(detection_mocks);
     let orchestrator_server = TestOrchestratorServer::builder()
         .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
         .detector_servers([&mock_detector_server])
@@ -386,8 +387,14 @@ async fn client_error() -> Result<(), anyhow::Error> {
 
     // assertions
     assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
-    let response: OrchestratorError = response.json().await?;
-    assert_eq!(response, OrchestratorError::internal());
+    let response: server::Error = response.json().await?;
+    assert_eq!(
+        response,
+        server::Error {
+            code: http::StatusCode::INTERNAL_SERVER_ERROR,
+            details: "unexpected error occurred while processing request".into()
+        }
+    );
 
     Ok(())
 }
@@ -416,7 +423,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     debug!("{response:#?}");
 
     assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
-    let response: OrchestratorError = response.json().await?;
+    let response: server::Error = response.json().await?;
     debug!("orchestrator json response body:\n{response:#?}");
     assert_eq!(response.code, 422);
     assert!(response.details.contains("unknown field `extra_args`"));
@@ -432,7 +439,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     debug!("{response:#?}");
 
     assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
-    let response: OrchestratorError = response.json().await?;
+    let response: server::Error = response.json().await?;
     debug!("orchestrator json response body:\n{response:#?}");
     assert_eq!(response.code, 422);
     assert!(response.details.starts_with("missing field `detectors`"));
@@ -448,7 +455,7 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     debug!("{response:#?}");
 
     assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
-    let response: OrchestratorError = response.json().await?;
+    let response: server::Error = response.json().await?;
     debug!("orchestrator json response body:\n{response:#?}");
     assert_eq!(response.code, 422);
     assert!(response.details.starts_with("missing field `content`"));
@@ -465,12 +472,12 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     debug!("{response:#?}");
 
     assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
-    let response: OrchestratorError = response.json().await?;
+    let response: server::Error = response.json().await?;
     debug!("orchestrator json response body:\n{response:#?}");
     assert_eq!(
         response,
-        OrchestratorError {
-            code: 422,
+        server::Error {
+            code: http::StatusCode::UNPROCESSABLE_ENTITY,
             details: "`detectors` is required".into()
         },
         "failed on empty `detectors` scenario"
@@ -488,15 +495,14 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     debug!("{response:#?}");
 
     assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
-    let response: OrchestratorError = response.json().await?;
+    let response: server::Error = response.json().await?;
     debug!("orchestrator json response body:\n{response:#?}");
     assert_eq!(
         response,
-        OrchestratorError {
-            code: 422,
+        server::Error {
+            code: http::StatusCode::UNPROCESSABLE_ENTITY,
             details: format!(
-                "detector `{}` is not supported by this endpoint",
-                FACT_CHECKING_DETECTOR_SENTENCE
+                "detector `{FACT_CHECKING_DETECTOR_SENTENCE}` is not supported by this endpoint"
             )
         },
         "failed on invalid detector type scenario"
@@ -514,13 +520,13 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
     debug!("{response:#?}");
 
     assert_eq!(response.status(), StatusCode::NOT_FOUND);
-    let response: OrchestratorError = response.json().await?;
+    let response: server::Error = response.json().await?;
     debug!("orchestrator json response body:\n{response:#?}");
     assert_eq!(
         response,
-        OrchestratorError {
-            code: 404,
-            details: format!("detector `{}` not found", NON_EXISTING_DETECTOR)
+        server::Error {
+            code: http::StatusCode::NOT_FOUND,
+            details: format!("detector `{NON_EXISTING_DETECTOR}` not found")
         },
         "failed on non-existing detector scenario"
     );