Moved to LangChain4J for ChatGPT and Gemini modules
This commit is contained in:
parent
ebc3da70e4
commit
c60531f617
14 changed files with 102 additions and 163 deletions
|
@ -77,20 +77,20 @@ public class MobibotBuild extends Project {
|
|||
new Repository("https://jitpack.io"),
|
||||
SONATYPE_SNAPSHOTS_LEGACY);
|
||||
|
||||
var log4j = version(2, 23, 1);
|
||||
var kotlin = version(2, 0, 10);
|
||||
var log4j = version(2, 24, 0);
|
||||
var kotlin = version(2, 0, 20);
|
||||
var langchain = version(0, 34, 0);
|
||||
scope(compile)
|
||||
// PircBotX
|
||||
.include(dependency("com.github.pircbotx", "pircbotx", "2.3.1"))
|
||||
// Commons (mostly for PircBotX)
|
||||
.include(dependency("org.apache.commons", "commons-lang3", "3.16.0"))
|
||||
.include(dependency("org.apache.commons", "commons-lang3", "3.17.0"))
|
||||
.include(dependency("org.apache.commons", "commons-text", "1.12.0"))
|
||||
.include(dependency("commons-codec", "commons-codec", "1.17.1"))
|
||||
.include(dependency("commons-net", "commons-net", "3.11.1"))
|
||||
// Google
|
||||
.include(dependency("com.google.code.gson", "gson", "2.11.0"))
|
||||
.include(dependency("com.google.guava", "guava", "33.2.1-jre"))
|
||||
.include(dependency("com.google.cloud", "google-cloud-vertexai", "1.7.0"))
|
||||
// Kotlin
|
||||
.include(dependency("org.jetbrains.kotlin", "kotlin-stdlib", kotlin))
|
||||
.include(dependency("org.jetbrains.kotlin", "kotlin-stdlib-common", kotlin))
|
||||
|
@ -99,10 +99,15 @@ public class MobibotBuild extends Project {
|
|||
.include(dependency("org.jetbrains.kotlinx", "kotlinx-coroutines-core", "1.8.1"))
|
||||
.include(dependency("org.jetbrains.kotlinx", "kotlinx-cli-jvm", "0.3.6"))
|
||||
// Logging
|
||||
.include(dependency("org.slf4j", "slf4j-api", "2.0.15"))
|
||||
.include(dependency("org.slf4j", "slf4j-api", "2.0.16"))
|
||||
.include(dependency("org.apache.logging.log4j", "log4j-api", log4j))
|
||||
.include(dependency("org.apache.logging.log4j", "log4j-core", log4j))
|
||||
.include(dependency("org.apache.logging.log4j", "log4j-slf4j2-impl", log4j))
|
||||
// LangChain4J
|
||||
.include(dependency("dev.langchain4j", "langchain4j-open-ai", langchain))
|
||||
.include(dependency("dev.langchain4j", "langchain4j-google-ai-gemini", langchain))
|
||||
.include(dependency("dev.langchain4j", "langchain4j-core", langchain))
|
||||
.include(dependency("dev.langchain4j", "langchain4j", langchain))
|
||||
// Misc.
|
||||
.include(dependency("com.rometools", "rome", "2.1.0"))
|
||||
.include(dependency("com.squareup.okhttp3", "okhttp", "4.12.0"))
|
||||
|
@ -118,8 +123,8 @@ public class MobibotBuild extends Project {
|
|||
scope(test)
|
||||
.include(dependency("com.willowtreeapps.assertk", "assertk-jvm", version(0, 28, 1)))
|
||||
.include(dependency("org.jetbrains.kotlin", "kotlin-test-junit5", kotlin))
|
||||
.include(dependency("org.junit.jupiter", "junit-jupiter", version(5, 10, 3)))
|
||||
.include(dependency("org.junit.platform", "junit-platform-console-standalone", version(1, 10, 3)));
|
||||
.include(dependency("org.junit.jupiter", "junit-jupiter", version(5, 11, 0)))
|
||||
.include(dependency("org.junit.platform", "junit-platform-console-standalone", version(1, 11, 0)));
|
||||
|
||||
List<String> jars = new ArrayList<>();
|
||||
runtimeClasspathJars().forEach(f -> jars.add("./lib/" + f.getName()));
|
||||
|
|
|
@ -396,11 +396,11 @@ class Mobibot(nickname: String, val channel: String, logsDirPath: String, p: Pro
|
|||
|
||||
// Load the modules
|
||||
addons.add(Calc())
|
||||
addons.add(ChatGpt())
|
||||
addons.add(ChatGpt2())
|
||||
addons.add(CryptoPrices())
|
||||
addons.add(CurrencyConverter())
|
||||
addons.add(Dice())
|
||||
addons.add(Gemini())
|
||||
addons.add(Gemini2())
|
||||
addons.add(GoogleSearch())
|
||||
addons.add(Info(tell, seen))
|
||||
addons.add(Joke())
|
||||
|
|
|
@ -14,12 +14,12 @@ import java.time.ZoneId
|
|||
*/
|
||||
object ReleaseInfo {
|
||||
const val PROJECT = "mobibot"
|
||||
const val VERSION = "0.8.0-rc+20240712110931"
|
||||
const val VERSION = "0.8.0-rc+20240908190240"
|
||||
|
||||
@JvmField
|
||||
@Suppress("MagicNumber")
|
||||
val BUILD_DATE: LocalDateTime = LocalDateTime.ofInstant(
|
||||
Instant.ofEpochMilli(1720807771484L), ZoneId.systemDefault()
|
||||
Instant.ofEpochMilli(1725847361020L), ZoneId.systemDefault()
|
||||
)
|
||||
|
||||
const val WEBSITE = "https://mobitopia.org/mobibot/"
|
||||
|
|
|
@ -31,23 +31,16 @@
|
|||
|
||||
package net.thauvin.erik.mobibot.modules
|
||||
|
||||
import net.thauvin.erik.mobibot.Constants
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel
|
||||
import dev.langchain4j.model.openai.OpenAiChatModelName
|
||||
import net.thauvin.erik.mobibot.Utils
|
||||
import net.thauvin.erik.mobibot.Utils.sendMessage
|
||||
import org.json.JSONArray
|
||||
import org.json.JSONException
|
||||
import org.json.JSONObject
|
||||
import org.pircbotx.hooks.types.GenericMessageEvent
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.io.IOException
|
||||
import java.net.URI
|
||||
import java.net.http.HttpClient
|
||||
import java.net.http.HttpRequest
|
||||
import java.net.http.HttpResponse
|
||||
|
||||
class ChatGpt : AbstractModule() {
|
||||
val logger: Logger = LoggerFactory.getLogger(ChatGpt::class.java)
|
||||
class ChatGpt2 : AbstractModule() {
|
||||
val logger: Logger = LoggerFactory.getLogger(ChatGpt2::class.java)
|
||||
|
||||
override val name = CHATGPT_NAME
|
||||
|
||||
|
@ -93,9 +86,6 @@ class ChatGpt : AbstractModule() {
|
|||
*/
|
||||
const val MAX_TOKENS_PROP = "chatgpt-max-tokens"
|
||||
|
||||
// ChatGPT API URL
|
||||
private const val API_URL = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
// ChatGPT command
|
||||
private const val CHATGPT_CMD = "chatgpt"
|
||||
|
||||
|
@ -103,48 +93,15 @@ class ChatGpt : AbstractModule() {
|
|||
@Throws(ModuleException::class)
|
||||
fun chat(query: String, apiKey: String?, maxTokens: Int): String {
|
||||
if (!apiKey.isNullOrEmpty()) {
|
||||
val jsonObject = JSONObject()
|
||||
jsonObject.put("model", "gpt-3.5-turbo-1106")
|
||||
jsonObject.put("max_tokens", maxTokens)
|
||||
val message = JSONObject()
|
||||
message.put("role", "user")
|
||||
message.put("content", query)
|
||||
val messages = JSONArray()
|
||||
messages.put(message)
|
||||
jsonObject.put("messages", messages)
|
||||
|
||||
val request = HttpRequest.newBuilder()
|
||||
.uri(URI.create(API_URL))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", "Bearer $apiKey")
|
||||
.header("User-Agent", Constants.USER_AGENT)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(jsonObject.toString()))
|
||||
.build()
|
||||
try {
|
||||
val response = HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())
|
||||
if (response.statusCode() == 200) {
|
||||
try {
|
||||
val jsonResponse = JSONObject(response.body())
|
||||
val choices = jsonResponse.getJSONArray("choices")
|
||||
return choices.getJSONObject(0).getJSONObject("message").getString("content").trim()
|
||||
} catch (e: JSONException) {
|
||||
throw ModuleException(
|
||||
"$CHATGPT_CMD($query): JSON",
|
||||
"A JSON error has occurred while conversing with $CHATGPT_NAME.",
|
||||
e
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if (response.statusCode() == 429) {
|
||||
throw ModuleException(
|
||||
"$CHATGPT_CMD($query): Rate limit reached",
|
||||
"Rate limit reached. Please try again later."
|
||||
)
|
||||
} else {
|
||||
throw IOException("HTTP Status Code: " + response.statusCode())
|
||||
}
|
||||
}
|
||||
} catch (e: IOException) {
|
||||
val model = OpenAiChatModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(OpenAiChatModelName.GPT_4_O)
|
||||
.maxTokens(maxTokens)
|
||||
.build()
|
||||
|
||||
return model.generate(query)
|
||||
} catch (e: Exception) {
|
||||
throw ModuleException(
|
||||
"$CHATGPT_CMD($query): IO",
|
||||
"An IO error has occurred while conversing with $CHATGPT_NAME.",
|
|
@ -217,6 +217,6 @@ class CurrencyConverter : AbstractModule() {
|
|||
init {
|
||||
commands.add(CURRENCY_CMD)
|
||||
initProperties(API_KEY_PROP)
|
||||
loadSymbols(properties[ChatGpt.API_KEY_PROP])
|
||||
loadSymbols(properties[ChatGpt2.API_KEY_PROP])
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,12 +31,7 @@
|
|||
|
||||
package net.thauvin.erik.mobibot.modules
|
||||
|
||||
import com.google.cloud.vertexai.VertexAI
|
||||
import com.google.cloud.vertexai.api.GenerationConfig
|
||||
import com.google.cloud.vertexai.api.HarmCategory
|
||||
import com.google.cloud.vertexai.api.SafetySetting
|
||||
import com.google.cloud.vertexai.generativeai.GenerativeModel
|
||||
import com.google.cloud.vertexai.generativeai.ResponseHandler
|
||||
import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel
|
||||
import net.thauvin.erik.mobibot.Utils
|
||||
import net.thauvin.erik.mobibot.Utils.sendMessage
|
||||
import org.pircbotx.hooks.types.GenericMessageEvent
|
||||
|
@ -45,8 +40,8 @@ import org.slf4j.LoggerFactory
|
|||
import java.util.*
|
||||
|
||||
|
||||
class Gemini : AbstractModule() {
|
||||
private val logger: Logger = LoggerFactory.getLogger(Gemini::class.java)
|
||||
class Gemini2 : AbstractModule() {
|
||||
private val logger: Logger = LoggerFactory.getLogger(Gemini2::class.java)
|
||||
|
||||
override val name = GEMINI_NAME
|
||||
|
||||
|
@ -55,8 +50,7 @@ class Gemini : AbstractModule() {
|
|||
try {
|
||||
val answer = chat(
|
||||
args.trim(),
|
||||
properties[PROJECT_ID_PROP],
|
||||
properties[LOCATION_PROP],
|
||||
properties[GEMINI_API_KEY],
|
||||
properties.getOrDefault(MAX_TOKENS_PROP, "1024").toInt()
|
||||
)
|
||||
if (!answer.isNullOrEmpty()) {
|
||||
|
@ -82,17 +76,12 @@ class Gemini : AbstractModule() {
|
|||
const val GEMINI_NAME = "Gemini"
|
||||
|
||||
/**
|
||||
* The Google cloud project ID property.
|
||||
* The API key
|
||||
*/
|
||||
const val PROJECT_ID_PROP = "gemini-project-id"
|
||||
const val GEMINI_API_KEY = "gemini-api-key"
|
||||
|
||||
/**
|
||||
* The Vertex AI location property.
|
||||
*/
|
||||
const val LOCATION_PROP = "gemini-location"
|
||||
|
||||
/**
|
||||
* The max number of tokens property.
|
||||
* The max number of output tokens property.
|
||||
*/
|
||||
const val MAX_TOKENS_PROP = "gemini-max-tokens"
|
||||
|
||||
|
@ -103,40 +92,18 @@ class Gemini : AbstractModule() {
|
|||
@Throws(ModuleException::class)
|
||||
fun chat(
|
||||
query: String,
|
||||
projectId: String?,
|
||||
location: String?,
|
||||
maxToken: Int
|
||||
apiKey: String?,
|
||||
maxTokens: Int
|
||||
): String? {
|
||||
if (!projectId.isNullOrEmpty() && !location.isNullOrEmpty()) {
|
||||
if (!apiKey.isNullOrEmpty()) {
|
||||
try {
|
||||
VertexAI(projectId, location).use { vertexAI ->
|
||||
val generationConfig = GenerationConfig.newBuilder().setMaxOutputTokens(maxToken).build()
|
||||
val safetySettings = listOf(
|
||||
SafetySetting.newBuilder()
|
||||
.setCategory(HarmCategory.HARM_CATEGORY_HATE_SPEECH)
|
||||
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
|
||||
.build(),
|
||||
SafetySetting.newBuilder()
|
||||
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
|
||||
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
|
||||
.build(),
|
||||
SafetySetting.newBuilder()
|
||||
.setCategory(HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
|
||||
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
|
||||
.build(),
|
||||
SafetySetting.newBuilder()
|
||||
.setCategory(HarmCategory.HARM_CATEGORY_HARASSMENT)
|
||||
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
|
||||
.build()
|
||||
)
|
||||
val model = GenerativeModel.Builder().setModelName("gemini-1.5-flash-001")
|
||||
.setGenerationConfig(generationConfig)
|
||||
.setVertexAi(vertexAI).build()
|
||||
.withSafetySettings(safetySettings)
|
||||
val gemini = GoogleAiGeminiChatModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName("gemini-1.5-flash")
|
||||
.maxOutputTokens(maxTokens)
|
||||
.build()
|
||||
|
||||
val response = model.generateContent(query)
|
||||
return ResponseHandler.getText(response)
|
||||
}
|
||||
return gemini.generate(query)
|
||||
} catch (e: Exception) {
|
||||
throw ModuleException(
|
||||
"$GEMINI_CMD($query): IO",
|
||||
|
@ -159,7 +126,6 @@ class Gemini : AbstractModule() {
|
|||
add(Utils.helpFormat("%c $GEMINI_CMD explain quantum computing in simple terms"))
|
||||
add(Utils.helpFormat("%c $GEMINI_CMD how do I make an HTTP request in Javascript?"))
|
||||
}
|
||||
initProperties(PROJECT_ID_PROP, LOCATION_PROP, MAX_TOKENS_PROP)
|
||||
initProperties(GEMINI_API_KEY, MAX_TOKENS_PROP)
|
||||
}
|
||||
|
||||
}
|
|
@ -39,10 +39,10 @@ import net.thauvin.erik.mobibot.DisableOnCi
|
|||
import net.thauvin.erik.mobibot.LocalProperties
|
||||
import kotlin.test.Test
|
||||
|
||||
class ChatGptTest : LocalProperties() {
|
||||
class ChatGpt2Test : LocalProperties() {
|
||||
@Test
|
||||
fun testApiKey() {
|
||||
assertFailure { ChatGpt.chat("1 gallon to liter", "", 0) }
|
||||
assertFailure { ChatGpt2.chat("1 gallon to liter", "", 0) }
|
||||
.isInstanceOf(ModuleException::class.java)
|
||||
.hasNoCause()
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ class ChatGptTest : LocalProperties() {
|
|||
fun testChatOnCoverage() {
|
||||
if (System.getenv("CI") == null || System.getenv("COVERAGE_JDK") != null) {
|
||||
assertThat(
|
||||
ChatGpt.chat("how do I encode a URL in java?", getProperty(ChatGpt.API_KEY_PROP), 60)
|
||||
ChatGpt2.chat("how do I encode a URL in java?", getProperty(ChatGpt2.API_KEY_PROP), 60)
|
||||
).contains("URLEncoder")
|
||||
}
|
||||
}
|
||||
|
@ -59,12 +59,12 @@ class ChatGptTest : LocalProperties() {
|
|||
@Test
|
||||
@DisableOnCi
|
||||
fun testChat() {
|
||||
val apiKey = getProperty(ChatGpt.API_KEY_PROP)
|
||||
val apiKey = getProperty(ChatGpt2.API_KEY_PROP)
|
||||
assertThat(
|
||||
ChatGpt.chat("how do I make an HTTP request in Javascript?", apiKey, 100)
|
||||
ChatGpt2.chat("how do I make an HTTP request in Javascript?", apiKey, 200)
|
||||
).contains("XMLHttpRequest")
|
||||
|
||||
assertFailure { ChatGpt.chat("1 liter to gallon", apiKey, -1) }
|
||||
assertFailure { ChatGpt2.chat("1 liter to gallon", apiKey, -1) }
|
||||
.isInstanceOf(ModuleException::class.java)
|
||||
}
|
||||
}
|
|
@ -37,10 +37,10 @@ import net.thauvin.erik.mobibot.DisableOnCi
|
|||
import net.thauvin.erik.mobibot.LocalProperties
|
||||
import kotlin.test.Test
|
||||
|
||||
class GeminiTest : LocalProperties() {
|
||||
class Gemini2Test : LocalProperties() {
|
||||
@Test
|
||||
fun testApiKey() {
|
||||
assertFailure { Gemini.chat("1 gallon to liter", "", "", 1024) }
|
||||
assertFailure { Gemini2.chat("1 gallon to liter", "", 0) }
|
||||
.isInstanceOf(ModuleException::class.java)
|
||||
.hasNoCause()
|
||||
}
|
||||
|
@ -48,19 +48,18 @@ class GeminiTest : LocalProperties() {
|
|||
@Test
|
||||
@DisableOnCi
|
||||
fun chatPrompt() {
|
||||
val projectId = getProperty(Gemini.PROJECT_ID_PROP)
|
||||
val location = getProperty(Gemini.LOCATION_PROP)
|
||||
val maxTokens = getProperty(Gemini.MAX_TOKENS_PROP).toInt()
|
||||
val apiKey = getProperty(Gemini2.GEMINI_API_KEY)
|
||||
val maxTokens = getProperty(Gemini2.MAX_TOKENS_PROP).toInt()
|
||||
|
||||
assertThat(
|
||||
Gemini.chat("how do I make an HTTP request in Javascript?", projectId, location, maxTokens)
|
||||
Gemini2.chat("how do I make an HTTP request in Javascript?", apiKey, maxTokens)
|
||||
).isNotNull().contains("XMLHttpRequest")
|
||||
|
||||
assertThat(
|
||||
Gemini.chat("how do I encode a URL in java?", projectId, location, 60)
|
||||
Gemini2.chat("how do I encode a URL in java?", apiKey, 60)
|
||||
).isNotNull().contains("URLEncoder")
|
||||
|
||||
assertFailure { Gemini.chat("1 liter to gallon", projectId, "blah", 40) }
|
||||
assertFailure { Gemini2.chat("1 liter to gallon", "foo", 40) }
|
||||
.isInstanceOf(ModuleException::class.java)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue