33
44from langchain_core .language_models .llms import Generation , LLMResult
55from langchain_core .prompt_values import PromptValue
6- from llama_stack .apis .inference import SamplingParams , TopPSamplingStrategy
6+ from llama_stack .apis .inference import (
7+ OpenAICompletionRequestWithExtraBody ,
8+ OpenAIEmbeddingsRequestWithExtraBody ,
9+ SamplingParams ,
10+ TopPSamplingStrategy ,
11+ )
712from ragas .embeddings .base import BaseRagasEmbeddings
813from ragas .llms .base import BaseRagasLLM
914from ragas .run_config import RunConfig
@@ -39,10 +44,11 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
3944 async def aembed_documents (self , texts : list [str ]) -> list [list [float ]]:
4045 """Embed documents using Llama Stack inference API."""
4146 try :
42- response = await self . inference_api . openai_embeddings (
47+ request = OpenAIEmbeddingsRequestWithExtraBody (
4348 model = self .embedding_model_id ,
4449 input = texts ,
4550 )
51+ response = await self .inference_api .openai_embeddings (request )
4652 return [data .embedding for data in response .data ]
4753 except Exception as e :
4854 logger .error (f"Document embedding failed: { str (e )} " )
@@ -51,10 +57,11 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
5157 async def aembed_query (self , text : str ) -> list [float ]:
5258 """Embed query using Llama Stack inference API."""
5359 try :
54- response = await self . inference_api . openai_embeddings (
60+ request = OpenAIEmbeddingsRequestWithExtraBody (
5561 model = self .embedding_model_id ,
5662 input = text ,
5763 )
64+ response = await self .inference_api .openai_embeddings (request )
5865 return response .data [0 ].embedding # type: ignore
5966 except Exception as e :
6067 logger .error (f"Query embedding failed: { str (e )} " )
@@ -109,7 +116,7 @@ async def agenerate_text(
109116 # sampling params for this generation should be set via the benchmark config
110117 # we will ignore the temperature and stop params passed in here
111118 for _ in range (n ):
112- response = await self . inference_api . openai_completion (
119+ request = OpenAICompletionRequestWithExtraBody (
113120 model = self .model_id ,
114121 prompt = prompt .to_string (),
115122 max_tokens = self .sampling_params .max_tokens
@@ -125,6 +132,7 @@ async def agenerate_text(
125132 else None ,
126133 stop = self .sampling_params .stop if self .sampling_params else None ,
127134 )
135+ response = await self .inference_api .openai_completion (request )
128136
129137 if not response .choices :
130138 logger .warning ("Completion response returned no choices" )
0 commit comments