Skip to content

Commit 74eece1

Browse files
author
David Motsonashvili
committed
introduced new auto function declaration type and auto function calling
1 parent 8ffa303 commit 74eece1

File tree

6 files changed

+225
-70
lines changed

6 files changed

+225
-70
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt

Lines changed: 85 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,18 @@ package com.google.firebase.ai
1818

1919
import android.graphics.Bitmap
2020
import com.google.firebase.ai.type.Content
21+
import com.google.firebase.ai.type.FunctionCallPart
22+
import com.google.firebase.ai.type.FunctionResponsePart
2123
import com.google.firebase.ai.type.GenerateContentResponse
22-
import com.google.firebase.ai.type.ImagePart
23-
import com.google.firebase.ai.type.InlineDataPart
2424
import com.google.firebase.ai.type.InvalidStateException
2525
import com.google.firebase.ai.type.TextPart
2626
import com.google.firebase.ai.type.content
2727
import java.util.LinkedList
2828
import java.util.concurrent.Semaphore
2929
import kotlinx.coroutines.flow.Flow
30+
import kotlinx.coroutines.flow.FlowCollector
3031
import kotlinx.coroutines.flow.onCompletion
31-
import kotlinx.coroutines.flow.onEach
32+
import kotlinx.coroutines.flow.transform
3233

3334
/**
3435
* Representation of a multi-turn interaction with a model.
@@ -51,25 +52,36 @@ public class Chat(
5152
private var lock = Semaphore(1)
5253

5354
/**
54-
* Sends a message using the provided [prompt]; automatically providing the existing [history] as
55-
* context.
55+
* Sends a message using the provided [inputPrompt]; automatically providing the existing
56+
* [history] as context.
5657
*
5758
* If successful, the message and response will be added to the [history]. If unsuccessful,
5859
* [history] will remain unchanged.
5960
*
60-
* @param prompt The input that, together with the history, will be given to the model as the
61+
* @param inputPrompt The input that, together with the history, will be given to the model as the
6162
* prompt.
62-
* @throws InvalidStateException if [prompt] is not coming from the 'user' role.
63+
* @throws InvalidStateException if [inputPrompt] is not coming from the 'user' role.
6364
* @throws InvalidStateException if the [Chat] instance has an active request.
6465
*/
65-
public suspend fun sendMessage(prompt: Content): GenerateContentResponse {
66-
prompt.assertComesFromUser()
66+
public suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse {
67+
inputPrompt.assertComesFromUser()
6768
attemptLock()
69+
var response: GenerateContentResponse
70+
var prompt = inputPrompt
6871
try {
69-
val fullPrompt = history + prompt
70-
val response = model.generateContent(fullPrompt.first(), *fullPrompt.drop(1).toTypedArray())
71-
history.add(prompt)
72-
history.add(response.candidates.first().content)
72+
while (true) {
73+
response = model.generateContent(listOf(*history.toTypedArray(), prompt))
74+
val responsePart = response.candidates.first().content.parts.first()
75+
76+
history.add(prompt)
77+
history.add(response.candidates.first().content)
78+
if (responsePart is FunctionCallPart) {
79+
val output = model.executeFunction(responsePart)
80+
prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output)))
81+
} else {
82+
break
83+
}
84+
}
7385
return response
7486
} finally {
7587
lock.release()
@@ -130,43 +142,20 @@ public class Chat(
130142

131143
val fullPrompt = history + prompt
132144
val flow = model.generateContentStream(fullPrompt.first(), *fullPrompt.drop(1).toTypedArray())
133-
val bitmaps = LinkedList<Bitmap>()
134-
val inlineDataParts = LinkedList<InlineDataPart>()
135-
val text = StringBuilder()
145+
val tempHistory = LinkedList<Content>()
146+
tempHistory.add(prompt)
136147

137148
/**
138149
* TODO: revisit when images and inline data are returned. This will cause issues with how
139150
* things are structured in the response. eg; a text/image/text response will be (incorrectly)
140151
* represented as image/text
141152
*/
142153
return flow
143-
.onEach {
144-
for (part in it.candidates.first().content.parts) {
145-
when (part) {
146-
is TextPart -> text.append(part.text)
147-
is ImagePart -> bitmaps.add(part.image)
148-
is InlineDataPart -> inlineDataParts.add(part)
149-
}
150-
}
151-
}
154+
.transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) }
152155
.onCompletion {
153156
lock.release()
154157
if (it == null) {
155-
val content =
156-
content("model") {
157-
for (bitmap in bitmaps) {
158-
image(bitmap)
159-
}
160-
for (inlineDataPart in inlineDataParts) {
161-
inlineData(inlineDataPart.inlineData, inlineDataPart.mimeType)
162-
}
163-
if (text.isNotBlank()) {
164-
text(text.toString())
165-
}
166-
}
167-
168-
history.add(prompt)
169-
history.add(content)
158+
history.addAll(tempHistory)
170159
}
171160
}
172161
}
@@ -209,6 +198,62 @@ public class Chat(
209198
return sendMessageStream(content)
210199
}
211200

201+
private suspend fun automaticFunctionExecutingTransform(
202+
transformer: FlowCollector<GenerateContentResponse>,
203+
tempHistory: LinkedList<Content>,
204+
response: GenerateContentResponse
205+
) {
206+
for (part in response.candidates.first().content.parts) {
207+
when (part) {
208+
is TextPart -> {
209+
transformer.emit(response)
210+
addTextToHistory(tempHistory, part)
211+
}
212+
is FunctionCallPart -> {
213+
val functionCall =
214+
response.candidates.first().content.parts.first { it is FunctionCallPart }
215+
as FunctionCallPart
216+
val output = model.executeFunction(functionCall)
217+
val functionResponse =
218+
Content("function", listOf(FunctionResponsePart(functionCall.name, output)))
219+
tempHistory.add(response.candidates.first().content)
220+
tempHistory.add(functionResponse)
221+
model
222+
.generateContentStream(listOf(*history.toTypedArray(), *tempHistory.toTypedArray()))
223+
.collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) }
224+
}
225+
else -> {
226+
transformer.emit(response)
227+
tempHistory.add(Content("model", listOf(part)))
228+
}
229+
}
230+
}
231+
}
232+
233+
private fun addTextToHistory(tempHistory: LinkedList<Content>, textPart: TextPart) {
234+
val lastContent = tempHistory.lastOrNull()
235+
if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) {
236+
tempHistory.removeLast()
237+
val editedContent =
238+
Content(
239+
"model",
240+
lastContent.parts.map {
241+
when (it) {
242+
is TextPart -> {
243+
TextPart(it.text + textPart.text)
244+
}
245+
else -> {
246+
it
247+
}
248+
}
249+
}
250+
)
251+
tempHistory.add(editedContent)
252+
return
253+
}
254+
tempHistory.add(Content("model", listOf(textPart)))
255+
}
256+
212257
private fun Content.assertComesFromUser() {
213258
if (role !in listOf("user", "function")) {
214259
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")

firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ import com.google.firebase.ai.common.APIController
2222
import com.google.firebase.ai.common.AppCheckHeaderProvider
2323
import com.google.firebase.ai.common.CountTokensRequest
2424
import com.google.firebase.ai.common.GenerateContentRequest
25+
import com.google.firebase.ai.type.AutoFunctionDeclaration
2526
import com.google.firebase.ai.type.Content
2627
import com.google.firebase.ai.type.CountTokensResponse
2728
import com.google.firebase.ai.type.FinishReason
2829
import com.google.firebase.ai.type.FirebaseAIException
30+
import com.google.firebase.ai.type.FunctionCallPart
2931
import com.google.firebase.ai.type.GenerateContentResponse
3032
import com.google.firebase.ai.type.GenerationConfig
3133
import com.google.firebase.ai.type.GenerativeBackend
@@ -45,6 +47,11 @@ import kotlinx.coroutines.flow.Flow
4547
import kotlinx.coroutines.flow.catch
4648
import kotlinx.coroutines.flow.map
4749
import kotlinx.serialization.ExperimentalSerializationApi
50+
import kotlinx.serialization.InternalSerializationApi
51+
import kotlinx.serialization.json.Json
52+
import kotlinx.serialization.json.JsonObject
53+
import kotlinx.serialization.json.jsonObject
54+
import kotlinx.serialization.serializerOrNull
4855

4956
/**
5057
* Represents a multimodal model (like Gemini), capable of generating content based on various input
@@ -266,6 +273,43 @@ internal constructor(
266273
return countTokens(content { image(prompt) })
267274
}
268275

276+
@OptIn(InternalSerializationApi::class)
277+
internal suspend fun executeFunction(call: FunctionCallPart): JsonObject {
278+
if (tools == null) {
279+
throw RuntimeException("No registered tools")
280+
}
281+
val tool = tools.flatMap { it.autoFunctionDeclarations?.filterNotNull() ?: emptyList() }
282+
val declaration =
283+
tool.firstOrNull() { it.name == call.name }
284+
?: throw RuntimeException("No registered function named ${call.name}")
285+
return executeFunction<Any, Any>(
286+
declaration as AutoFunctionDeclaration<Any, Any>,
287+
call.args["param"].toString()
288+
)
289+
}
290+
291+
@OptIn(InternalSerializationApi::class)
292+
internal suspend fun <I : Any, O : Any> executeFunction(
293+
functionDeclaration: AutoFunctionDeclaration<I, O>,
294+
parameter: String
295+
): JsonObject {
296+
val inputDeserializer =
297+
functionDeclaration.inputSchema.clazz.serializerOrNull()
298+
?: throw RuntimeException(
299+
"Function input type ${functionDeclaration.inputSchema.clazz.qualifiedName} is not @Serializable"
300+
)
301+
val input = Json.decodeFromString(inputDeserializer, parameter)
302+
val functionReference =
303+
functionDeclaration.functionReference
304+
?: throw RuntimeException("Function reference for ${functionDeclaration.name} is missing")
305+
val output = functionReference.invoke(input)
306+
val outputSerializer = functionDeclaration.outputSchema?.clazz?.serializerOrNull()
307+
if (outputSerializer != null) {
308+
return Json.encodeToJsonElement(outputSerializer, output).jsonObject
309+
}
310+
return output as JsonObject
311+
}
312+
269313
@OptIn(ExperimentalSerializationApi::class)
270314
private fun constructRequest(vararg prompt: Content) =
271315
GenerateContentRequest(
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.google.firebase.ai.type
2+
3+
import kotlinx.serialization.json.JsonObject
4+
5+
public class AutoFunctionDeclaration<I : Any, O : Any>
6+
internal constructor(
7+
public val name: String,
8+
public val description: String,
9+
public val inputSchema: JsonSchema<I>,
10+
public val outputSchema: JsonSchema<O>? = null,
11+
public val functionReference: (suspend (I) -> O)? = null
12+
) {
13+
public companion object {
14+
public fun <I : Any, O : Any> create(
15+
functionName: String,
16+
description: String,
17+
inputSchema: JsonSchema<I>,
18+
outputSchema: JsonSchema<O>,
19+
functionReference: ((I) -> O)? = null
20+
): AutoFunctionDeclaration<I, O> {
21+
return AutoFunctionDeclaration<I, O>(
22+
functionName,
23+
description,
24+
inputSchema,
25+
outputSchema,
26+
functionReference
27+
)
28+
}
29+
30+
public fun <I : Any> create(
31+
functionName: String,
32+
inputSchema: JsonSchema<I>,
33+
description: String,
34+
functionReference: ((I) -> JsonObject)? = null
35+
): AutoFunctionDeclaration<I, JsonObject> {
36+
return AutoFunctionDeclaration<I, JsonObject>(
37+
functionName,
38+
description,
39+
inputSchema,
40+
null,
41+
functionReference
42+
)
43+
}
44+
}
45+
46+
internal fun toInternal(): FunctionDeclaration.Internal {
47+
return FunctionDeclaration.Internal(
48+
name,
49+
description,
50+
null,
51+
JsonSchema.obj(mapOf("param" to inputSchema)).toInternalJson(),
52+
outputSchema?.toInternalJson()
53+
)
54+
}
55+
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/FunctionDeclaration.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,14 @@ public class FunctionDeclaration(
6161
internal val schema: Schema =
6262
Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false)
6363

64-
internal fun toInternal() = Internal(name, description, schema.toInternalOpenApi())
64+
internal fun toInternal() = Internal(name, description, schema.toInternalOpenApi(), null, null)
6565

6666
@Serializable
6767
internal data class Internal(
6868
val name: String,
6969
val description: String,
70-
val parameters: Schema.InternalOpenAPI
70+
val parameters: Schema.InternalOpenAPI?,
71+
val parametersJsonSchema: Schema.InternalJson?,
72+
val responseJsonSchema: Schema.InternalJson?,
7173
)
7274
}

0 commit comments

Comments
 (0)