|
8 | 8 | from pathlib import Path |
9 | 9 | from typing import TypedDict |
10 | 10 |
|
11 | | -# Disable LiteLLM's async logging to avoid event loop issues with joblib |
12 | | -import litellm |
| 11 | +from litellm_multiprocess_fix import patch_litellm_for_multiprocessing |
| 12 | + |
| 13 | +patch_litellm_for_multiprocessing() |
| 14 | + |
13 | 15 | from codegen import ClaudeAppBuilder |
14 | 16 | from codegen import GenerationMetrics as ClaudeGenerationMetrics |
15 | 17 | from codegen_multi import LiteLLMAppBuilder |
16 | 18 | from dotenv import load_dotenv |
17 | 19 | from joblib import Parallel, delayed |
18 | 20 | from prompts_databricks import PROMPTS as DATABRICKS_PROMPTS |
19 | 21 |
|
20 | | -litellm.turn_off_message_logging = True |
21 | | -litellm.drop_params = True # silently drop unsupported params instead of warning |
22 | | - |
23 | 22 | # Unified type for metrics from both backends |
24 | 23 | GenerationMetrics = ClaudeGenerationMetrics |
25 | 24 |
|
@@ -48,12 +47,9 @@ def run_single_generation( |
48 | 47 | suppress_logs: bool = True, |
49 | 48 | mcp_binary: str | None = None, |
50 | 49 | ) -> RunResult: |
51 | | - # Ensure LiteLLM is configured fresh in each worker process |
| 50 | + # re-apply litellm patch in worker process (joblib uses spawn/fork) |
52 | 51 | if backend == "litellm": |
53 | | - import litellm |
54 | | - |
55 | | - litellm.turn_off_message_logging = True |
56 | | - litellm.drop_params = True |
| 52 | + patch_litellm_for_multiprocessing() |
57 | 53 |
|
58 | 54 | def timeout_handler(signum, frame): |
59 | 55 | raise TimeoutError("Generation timed out after 1200 seconds") |
|
0 commit comments