diff --git a/src/main/kotlin/net/thauvin/erik/mobibot/modules/Gemini.kt b/src/main/kotlin/net/thauvin/erik/mobibot/modules/Gemini.kt index c0faefa..2e4ed91 100644 --- a/src/main/kotlin/net/thauvin/erik/mobibot/modules/Gemini.kt +++ b/src/main/kotlin/net/thauvin/erik/mobibot/modules/Gemini.kt @@ -35,7 +35,6 @@ import com.google.cloud.vertexai.VertexAI import com.google.cloud.vertexai.api.GenerationConfig import com.google.cloud.vertexai.api.HarmCategory import com.google.cloud.vertexai.api.SafetySetting -import com.google.cloud.vertexai.generativeai.ChatSession import com.google.cloud.vertexai.generativeai.GenerativeModel import com.google.cloud.vertexai.generativeai.ResponseHandler import net.thauvin.erik.mobibot.Utils @@ -57,7 +56,7 @@ class Gemini : AbstractModule() { val answer = chat( args.trim(), properties[PROJECT_ID_PROP], - properties[LOCATION_PROPR], + properties[LOCATION_PROP], properties.getOrDefault(MAX_TOKENS_PROP, "1024").toInt() ) if (!answer.isNullOrEmpty()) { @@ -83,17 +82,17 @@ class Gemini : AbstractModule() { const val GEMINI_NAME = "Gemini" /** - * The Google cloud project ID. + * The Google cloud project ID property. */ const val PROJECT_ID_PROP = "gemini-project-id" /** - * The Vertex AI location. + * The Vertex AI location property. */ - const val LOCATION_PROPR = "gemini-location" + const val LOCATION_PROP = "gemini-location" /** - * The max tokens property. + * The max number of tokens property. */ const val MAX_TOKENS_PROP = "gemini-max-tokens" @@ -112,31 +111,30 @@ class Gemini : AbstractModule() { try { VertexAI(projectId, location).use { vertexAI -> val generationConfig = GenerationConfig.newBuilder().setMaxOutputTokens(maxToken).build() - val safetySettings = Arrays.asList( + val safetySettings = listOf( SafetySetting.newBuilder() .setCategory(HarmCategory.HARM_CATEGORY_HATE_SPEECH) - .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE) + .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) .build(), SafetySetting.newBuilder() .setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT) - .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE) + .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) .build(), SafetySetting.newBuilder() .setCategory(HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT) - .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE) + .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) .build(), SafetySetting.newBuilder() .setCategory(HarmCategory.HARM_CATEGORY_HARASSMENT) - .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE) + .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) .build() ) - val model = GenerativeModel.Builder().setModelName("gemini-pro-vision") + val model = GenerativeModel.Builder().setModelName("gemini-1.5-flash-001") .setGenerationConfig(generationConfig) .setVertexAi(vertexAI).build() .withSafetySettings(safetySettings) - val session = ChatSession(model) - val response = session.sendMessage(query) + val response = model.generateContent(query) return ResponseHandler.getText(response) } } catch (e: Exception) { @@ -161,7 +159,7 @@ class Gemini : AbstractModule() { add(Utils.helpFormat("%c $GEMINI_CMD explain quantum computing in simple terms")) add(Utils.helpFormat("%c $GEMINI_CMD how do I make an HTTP request in Javascript?")) } - initProperties(PROJECT_ID_PROP, LOCATION_PROPR, MAX_TOKENS_PROP) + initProperties(PROJECT_ID_PROP, LOCATION_PROP, MAX_TOKENS_PROP) } } diff --git a/src/test/kotlin/net/thauvin/erik/mobibot/modules/GeminiTest.kt b/src/test/kotlin/net/thauvin/erik/mobibot/modules/GeminiTest.kt index db69fe7..1f0202f 100644 --- a/src/test/kotlin/net/thauvin/erik/mobibot/modules/GeminiTest.kt +++ b/src/test/kotlin/net/thauvin/erik/mobibot/modules/GeminiTest.kt @@ -49,7 +49,7 @@ class GeminiTest : LocalProperties() { @DisableOnCi fun chatPrompt() { val projectId = getProperty(Gemini.PROJECT_ID_PROP) - val location = getProperty(Gemini.LOCATION_PROPR) + val location = getProperty(Gemini.LOCATION_PROP) val maxTokens = getProperty(Gemini.MAX_TOKENS_PROP).toInt() assertThat(