@@ -18,17 +18,18 @@ package com.google.firebase.ai
1818
1919import android.graphics.Bitmap
2020import com.google.firebase.ai.type.Content
21+ import com.google.firebase.ai.type.FunctionCallPart
22+ import com.google.firebase.ai.type.FunctionResponsePart
2123import com.google.firebase.ai.type.GenerateContentResponse
22- import com.google.firebase.ai.type.ImagePart
23- import com.google.firebase.ai.type.InlineDataPart
2424import com.google.firebase.ai.type.InvalidStateException
2525import com.google.firebase.ai.type.TextPart
2626import com.google.firebase.ai.type.content
2727import java.util.LinkedList
2828import java.util.concurrent.Semaphore
2929import kotlinx.coroutines.flow.Flow
30+ import kotlinx.coroutines.flow.FlowCollector
3031import kotlinx.coroutines.flow.onCompletion
31- import kotlinx.coroutines.flow.onEach
32+ import kotlinx.coroutines.flow.transform
3233
3334/* *
3435 * Representation of a multi-turn interaction with a model.
@@ -51,25 +52,36 @@ public class Chat(
5152 private var lock = Semaphore (1 )
5253
5354 /* *
54- * Sends a message using the provided [prompt ]; automatically providing the existing [history] as
55- * context.
55+ * Sends a message using the provided [inputPrompt ]; automatically providing the existing
56+ * [history] as context.
5657 *
5758 * If successful, the message and response will be added to the [history]. If unsuccessful,
5859 * [history] will remain unchanged.
5960 *
60- * @param prompt The input that, together with the history, will be given to the model as the
61+ * @param inputPrompt The input that, together with the history, will be given to the model as the
6162 * prompt.
62- * @throws InvalidStateException if [prompt ] is not coming from the 'user' role.
63+ * @throws InvalidStateException if [inputPrompt ] is not coming from the 'user' role.
6364 * @throws InvalidStateException if the [Chat] instance has an active request.
6465 */
65- public suspend fun sendMessage (prompt : Content ): GenerateContentResponse {
66- prompt .assertComesFromUser()
66+ public suspend fun sendMessage (inputPrompt : Content ): GenerateContentResponse {
67+ inputPrompt .assertComesFromUser()
6768 attemptLock()
69+ var response: GenerateContentResponse
70+ var prompt = inputPrompt
6871 try {
69- val fullPrompt = history + prompt
70- val response = model.generateContent(fullPrompt.first(), * fullPrompt.drop(1 ).toTypedArray())
71- history.add(prompt)
72- history.add(response.candidates.first().content)
72+ while (true ) {
73+ response = model.generateContent(listOf (* history.toTypedArray(), prompt))
74+ val responsePart = response.candidates.first().content.parts.first()
75+
76+ history.add(prompt)
77+ history.add(response.candidates.first().content)
78+ if (responsePart is FunctionCallPart ) {
79+ val output = model.executeFunction(responsePart)
80+ prompt = Content (" function" , listOf (FunctionResponsePart (responsePart.name, output)))
81+ } else {
82+ break
83+ }
84+ }
7385 return response
7486 } finally {
7587 lock.release()
@@ -130,43 +142,20 @@ public class Chat(
130142
131143 val fullPrompt = history + prompt
132144 val flow = model.generateContentStream(fullPrompt.first(), * fullPrompt.drop(1 ).toTypedArray())
133- val bitmaps = LinkedList <Bitmap >()
134- val inlineDataParts = LinkedList <InlineDataPart >()
135- val text = StringBuilder ()
145+ val tempHistory = LinkedList <Content >()
146+ tempHistory.add(prompt)
136147
137148 /* *
138149 * TODO: revisit when images and inline data are returned. This will cause issues with how
139150 * things are structured in the response. eg; a text/image/text response will be (incorrectly)
140151 * represented as image/text
141152 */
142153 return flow
143- .onEach {
144- for (part in it.candidates.first().content.parts) {
145- when (part) {
146- is TextPart -> text.append(part.text)
147- is ImagePart -> bitmaps.add(part.image)
148- is InlineDataPart -> inlineDataParts.add(part)
149- }
150- }
151- }
154+ .transform { response -> automaticFunctionExecutingTransform(this , tempHistory, response) }
152155 .onCompletion {
153156 lock.release()
154157 if (it == null ) {
155- val content =
156- content(" model" ) {
157- for (bitmap in bitmaps) {
158- image(bitmap)
159- }
160- for (inlineDataPart in inlineDataParts) {
161- inlineData(inlineDataPart.inlineData, inlineDataPart.mimeType)
162- }
163- if (text.isNotBlank()) {
164- text(text.toString())
165- }
166- }
167-
168- history.add(prompt)
169- history.add(content)
158+ history.addAll(tempHistory)
170159 }
171160 }
172161 }
@@ -209,6 +198,62 @@ public class Chat(
209198 return sendMessageStream(content)
210199 }
211200
201+ private suspend fun automaticFunctionExecutingTransform (
202+ transformer : FlowCollector <GenerateContentResponse >,
203+ tempHistory : LinkedList <Content >,
204+ response : GenerateContentResponse
205+ ) {
206+ for (part in response.candidates.first().content.parts) {
207+ when (part) {
208+ is TextPart -> {
209+ transformer.emit(response)
210+ addTextToHistory(tempHistory, part)
211+ }
212+ is FunctionCallPart -> {
213+ val functionCall =
214+ response.candidates.first().content.parts.first { it is FunctionCallPart }
215+ as FunctionCallPart
216+ val output = model.executeFunction(functionCall)
217+ val functionResponse =
218+ Content (" function" , listOf (FunctionResponsePart (functionCall.name, output)))
219+ tempHistory.add(response.candidates.first().content)
220+ tempHistory.add(functionResponse)
221+ model
222+ .generateContentStream(listOf (* history.toTypedArray(), * tempHistory.toTypedArray()))
223+ .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) }
224+ }
225+ else -> {
226+ transformer.emit(response)
227+ tempHistory.add(Content (" model" , listOf (part)))
228+ }
229+ }
230+ }
231+ }
232+
233+ private fun addTextToHistory (tempHistory : LinkedList <Content >, textPart : TextPart ) {
234+ val lastContent = tempHistory.lastOrNull()
235+ if (lastContent?.role == " model" && lastContent.parts.any { it is TextPart }) {
236+ tempHistory.removeLast()
237+ val editedContent =
238+ Content (
239+ " model" ,
240+ lastContent.parts.map {
241+ when (it) {
242+ is TextPart -> {
243+ TextPart (it.text + textPart.text)
244+ }
245+ else -> {
246+ it
247+ }
248+ }
249+ }
250+ )
251+ tempHistory.add(editedContent)
252+ return
253+ }
254+ tempHistory.add(Content (" model" , listOf (textPart)))
255+ }
256+
212257 private fun Content.assertComesFromUser () {
213258 if (role !in listOf (" user" , " function" )) {
214259 throw InvalidStateException (" Chat prompts should come from the 'user' or 'function' role." )
0 commit comments