Skip to content

Commit fa42884

Browse files
committed
feat: add provider health checks and improve rig adapters
- load per-provider credentials from YAML profiles and expose a health subcommand that primes env overrides before probing - enhance rig adapter parsing (code fences, fallback verdicts, Gemini JSON MIME) so health checks survive empty or non-textual responses - document the workflow, ship an example llm_providers.yaml, and add integration/unit tests for the new paths
1 parent eb79c6f commit fa42884

File tree

9 files changed

+1149
-105
lines changed

9 files changed

+1149
-105
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
1616
serde = { version = "1", features = ["derive"] }
1717
serde_json = "1"
1818
serde_yaml = "0.9"
19+
json5 = "0.4"
1920
aho-corasick = "1"
2021
regex = "1"
2122
clap = { version = "4", features = ["derive"] }

README.md

Lines changed: 157 additions & 64 deletions
Large diffs are not rendered by default.

crates/llm-guard-cli/src/main.rs

Lines changed: 173 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
use std::collections::HashMap;
2+
use std::env;
23
use std::fs as stdfs;
34
use std::path::{Path, PathBuf};
45
use std::process;
56
use std::sync::Arc;
67
use std::time::Duration;
78

8-
use anyhow::{anyhow, Context, Result};
9+
use anyhow::{anyhow, bail, Context, Result};
910
use clap::{Parser, Subcommand};
1011
use config::Config;
1112
use llm_guard_core::{
1213
build_client, render_report, DefaultScanner, FileRuleRepository, LlmClient, LlmSettings,
13-
OutputFormat, RiskBand, RuleKind, RuleRepository, Scanner,
14+
OutputFormat, RiskBand, RiskThresholds, RuleKind, RuleRepository, ScanReport, Scanner,
15+
ScoreBreakdown,
1416
};
1517
use serde::Deserialize;
1618
use serde_yaml;
@@ -52,6 +54,10 @@ struct Cli {
5254
)]
5355
providers_config: PathBuf,
5456

57+
/// Enable verbose diagnostics (including raw provider payloads on errors).
58+
#[arg(long, global = true)]
59+
debug: bool,
60+
5561
#[command(subcommand)]
5662
command: Option<Commands>,
5763
}
@@ -97,6 +103,15 @@ enum Commands {
97103
#[arg(long)]
98104
workspace: Option<String>,
99105
},
106+
/// Execute health checks against configured LLM providers.
107+
Health {
108+
/// Limit the health check to a single provider name.
109+
#[arg(long)]
110+
provider: Option<String>,
111+
/// Skip the live LLM call; only validate configuration/build steps.
112+
#[arg(long)]
113+
dry_run: bool,
114+
},
100115
}
101116

102117
#[derive(Debug, Deserialize, Clone)]
@@ -165,8 +180,8 @@ impl ProviderProfiles {
165180
}
166181

167182
fn prime_env(&self, provider: &str) {
168-
if let Some(profile) = self.entries.get(&provider.to_ascii_lowercase()) {
169-
maybe_set_env("LLM_GUARD_PROVIDER", Some(provider.to_string()));
183+
if let Some(profile) = self.get(provider) {
184+
maybe_set_env("LLM_GUARD_PROVIDER", Some(profile.name.clone()));
170185
maybe_set_env("LLM_GUARD_API_KEY", profile.api_key.clone());
171186
maybe_set_env("LLM_GUARD_ENDPOINT", profile.endpoint.clone());
172187
maybe_set_env("LLM_GUARD_MODEL", profile.model.clone());
@@ -186,7 +201,7 @@ impl ProviderProfiles {
186201
}
187202

188203
fn apply_defaults(&self, provider: &str, settings: &mut LlmSettings) {
189-
if let Some(profile) = self.entries.get(&provider.to_ascii_lowercase()) {
204+
if let Some(profile) = self.get(provider) {
190205
if settings.model.is_none() {
191206
settings.model = profile.model.clone();
192207
}
@@ -212,6 +227,58 @@ impl ProviderProfiles {
212227
}
213228
}
214229
}
230+
231+
fn get(&self, provider: &str) -> Option<&ProviderProfile> {
232+
self.entries.get(&provider.to_ascii_lowercase())
233+
}
234+
235+
fn names(&self) -> Vec<String> {
236+
self.entries
237+
.values()
238+
.map(|profile| profile.name.clone())
239+
.collect()
240+
}
241+
242+
fn is_empty(&self) -> bool {
243+
self.entries.is_empty()
244+
}
245+
}
246+
247+
struct EnvGuard {
248+
snapshot: Vec<(String, Option<String>)>,
249+
}
250+
251+
impl EnvGuard {
252+
fn new() -> Self {
253+
Self {
254+
snapshot: Vec::new(),
255+
}
256+
}
257+
258+
fn set(&mut self, key: &str, value: &str) {
259+
if !self.snapshot.iter().any(|(k, _)| k == key) {
260+
self.snapshot.push((key.to_string(), env::var(key).ok()));
261+
}
262+
env::set_var(key, value);
263+
}
264+
265+
fn maybe_set(&mut self, key: &str, value: Option<&str>) {
266+
if let Some(val) = value {
267+
self.set(key, val);
268+
}
269+
}
270+
}
271+
272+
impl Drop for EnvGuard {
273+
fn drop(&mut self) {
274+
for (key, previous) in self.snapshot.drain(..).rev() {
275+
if let Some(value) = previous {
276+
env::set_var(&key, value);
277+
} else {
278+
env::remove_var(&key);
279+
}
280+
}
281+
}
215282
}
216283

217284
#[cfg(test)]
@@ -338,6 +405,11 @@ async fn main() {
338405
async fn run() -> Result<i32> {
339406
init_tracing();
340407
let cli = Cli::parse();
408+
if cli.debug {
409+
env::set_var("LLM_GUARD_DEBUG", "1");
410+
} else {
411+
env::remove_var("LLM_GUARD_DEBUG");
412+
}
341413
let provider_profiles = ProviderProfiles::load(&cli.providers_config)?;
342414
match cli.command.unwrap_or(Commands::ListRules { json: false }) {
343415
Commands::ListRules { json } => {
@@ -373,6 +445,9 @@ async fn run() -> Result<i32> {
373445
)
374446
.await
375447
}
448+
Commands::Health { provider, dry_run } => {
449+
run_health(&provider_profiles, provider.as_deref(), !dry_run).await
450+
}
376451
}
377452
}
378453

@@ -642,6 +717,99 @@ fn exit_code_for_band(band: RiskBand) -> i32 {
642717
}
643718
}
644719

720+
async fn run_health(
721+
profiles: &ProviderProfiles,
722+
provider_filter: Option<&str>,
723+
perform_call: bool,
724+
) -> Result<i32> {
725+
let mut targets = if let Some(filter) = provider_filter {
726+
if let Some(profile) = profiles.get(filter) {
727+
vec![profile.name.clone()]
728+
} else {
729+
vec![filter.to_string()]
730+
}
731+
} else if !profiles.is_empty() {
732+
profiles.names()
733+
} else if let Ok(env_provider) = env::var("LLM_GUARD_PROVIDER") {
734+
vec![env_provider]
735+
} else {
736+
bail!("no providers configured; supply --provider or create llm_providers.yaml");
737+
};
738+
739+
targets.sort();
740+
targets.dedup();
741+
742+
let mut failed = false;
743+
for provider in targets {
744+
println!("Checking provider {provider}...");
745+
match check_provider(profiles, &provider, perform_call).await {
746+
Ok(()) => println!(" ok"),
747+
Err(err) => {
748+
failed = true;
749+
eprintln!(" failed: {err:#}");
750+
}
751+
}
752+
}
753+
754+
Ok(if failed { 1 } else { 0 })
755+
}
756+
757+
async fn check_provider(
758+
profiles: &ProviderProfiles,
759+
provider: &str,
760+
perform_call: bool,
761+
) -> Result<()> {
762+
let profile_snapshot = profiles.get(provider).cloned();
763+
let canonical_provider = profile_snapshot
764+
.as_ref()
765+
.map(|p| p.name.clone())
766+
.unwrap_or_else(|| provider.to_string());
767+
768+
let mut guard = EnvGuard::new();
769+
guard.set("LLM_GUARD_PROVIDER", &canonical_provider);
770+
if let Some(profile) = profile_snapshot.as_ref() {
771+
guard.maybe_set("LLM_GUARD_API_KEY", profile.api_key.as_deref());
772+
guard.maybe_set("LLM_GUARD_ENDPOINT", profile.endpoint.as_deref());
773+
guard.maybe_set("LLM_GUARD_MODEL", profile.model.as_deref());
774+
guard.maybe_set("LLM_GUARD_DEPLOYMENT", profile.deployment.as_deref());
775+
guard.maybe_set("LLM_GUARD_PROJECT", profile.project.as_deref());
776+
guard.maybe_set("LLM_GUARD_WORKSPACE", profile.workspace.as_deref());
777+
if let Some(timeout) = profile.timeout_secs {
778+
guard.set("LLM_GUARD_TIMEOUT_SECS", &timeout.to_string());
779+
}
780+
if let Some(retries) = profile.max_retries {
781+
guard.set("LLM_GUARD_MAX_RETRIES", &retries.to_string());
782+
}
783+
guard.maybe_set("LLM_GUARD_API_VERSION", profile.api_version.as_deref());
784+
}
785+
786+
let mut settings = LlmSettings::from_env()?;
787+
let provider_for_defaults = settings.provider.clone();
788+
profiles.apply_defaults(&provider_for_defaults, &mut settings);
789+
drop(guard);
790+
791+
let client = build_client(&settings)?;
792+
if perform_call {
793+
let report = dummy_report();
794+
let _ = client
795+
.enrich("Health check probe", &report)
796+
.await
797+
.context("LLM enrich call failed")?;
798+
}
799+
800+
Ok(())
801+
}
802+
803+
fn dummy_report() -> ScanReport {
804+
ScanReport::from_breakdown(
805+
Vec::new(),
806+
0,
807+
None,
808+
ScoreBreakdown::default(),
809+
&RiskThresholds::default(),
810+
)
811+
}
812+
645813
fn init_tracing() {
646814
let env_filter =
647815
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info,tokio=warn"));
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
use assert_cmd::Command;
2+
use once_cell::sync::Lazy;
3+
use predicates::str::contains;
4+
use std::env;
5+
use std::fs::write;
6+
use std::sync::Mutex;
7+
use tempfile;
8+
9+
static ENV_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
10+
11+
fn reset_env() {
12+
env::remove_var("LLM_GUARD_PROVIDER");
13+
env::remove_var("LLM_GUARD_API_KEY");
14+
env::remove_var("LLM_GUARD_ENDPOINT");
15+
env::remove_var("LLM_GUARD_MODEL");
16+
env::remove_var("LLM_GUARD_DEPLOYMENT");
17+
env::remove_var("LLM_GUARD_PROJECT");
18+
env::remove_var("LLM_GUARD_WORKSPACE");
19+
env::remove_var("LLM_GUARD_TIMEOUT_SECS");
20+
env::remove_var("LLM_GUARD_MAX_RETRIES");
21+
env::remove_var("LLM_GUARD_API_VERSION");
22+
env::remove_var("LLM_GUARD_DEBUG");
23+
}
24+
25+
#[test]
26+
fn health_check_with_noop_profile() {
27+
let _guard = ENV_LOCK.lock().unwrap();
28+
reset_env();
29+
30+
let file = tempfile::Builder::new().suffix(".yaml").tempfile().unwrap();
31+
32+
write(file.path(), "providers:\n - name: \"noop\"\n").unwrap();
33+
34+
let mut cmd = Command::cargo_bin("llm-guard-cli").unwrap();
35+
cmd.args([
36+
"--providers-config",
37+
file.path().to_str().unwrap(),
38+
"health",
39+
])
40+
.assert()
41+
.success()
42+
.stdout(contains("Checking provider noop"))
43+
.stdout(contains("ok"));
44+
}

crates/llm-guard-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ tracing.workspace = true
1818
reqwest.workspace = true
1919
tokio.workspace = true
2020
rig-core = "0.22.0"
21+
json5.workspace = true
2122

2223
[dev-dependencies]
2324
tempfile = "3"

0 commit comments

Comments
 (0)