Incorporated some review comments on the code.
This commit is contained in:
parent
c36a6cc5a4
commit
03c9dce8cf
@ -1,17 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<configuration>
|
|
||||||
<appender name="STDOUT"
|
|
||||||
class="ch.qos.logback.core.ConsoleAppender">
|
|
||||||
<encoder>
|
|
||||||
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -
|
|
||||||
%msg%n
|
|
||||||
</pattern>
|
|
||||||
</encoder>
|
|
||||||
</appender>
|
|
||||||
<logger name="io.ebean.DDL" level="TRACE" />
|
|
||||||
<logger name="io.ebean.SQL" level="TRACE" />
|
|
||||||
<logger name="io.ebean.TXN" level="TRACE" />
|
|
||||||
<root level="WARN">
|
|
||||||
<appender-ref ref="STDOUT" />
|
|
||||||
</root>
|
|
||||||
</configuration>
|
|
@ -1,9 +1,7 @@
|
|||||||
package com.baeldung.langchain;
|
package com.baeldung.langchain;
|
||||||
|
|
||||||
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument;
|
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
|
||||||
import static java.time.Duration.ofSeconds;
|
import static java.time.Duration.ofSeconds;
|
||||||
import static java.util.stream.Collectors.joining;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.MalformedURLException;
|
import java.net.MalformedURLException;
|
||||||
@ -11,75 +9,64 @@ import java.net.URISyntaxException;
|
|||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.logging.Logger;
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import dev.langchain4j.chain.ConversationalRetrievalChain;
|
||||||
import dev.langchain4j.data.document.Document;
|
import dev.langchain4j.data.document.Document;
|
||||||
import dev.langchain4j.data.document.DocumentSplitter;
|
|
||||||
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
||||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||||
|
|
||||||
public class ChainWithDocumentLiveTest {
|
public class ChainWithDocumentLiveTest {
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(ChainWithDocumentLiveTest.class);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenChainWithDocument_whenPrompted_thenValidResponse() {
|
public void givenDocument_whenPrompted_thenValidResponse() {
|
||||||
|
|
||||||
Document document = loadDocument(toPath("src/test/resources/example-files/simpson's_adventures.txt"));
|
|
||||||
DocumentSplitter splitter = DocumentSplitters.recursive(100, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
|
|
||||||
List<TextSegment> segments = splitter.split(document);
|
|
||||||
|
|
||||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
||||||
List<Embedding> embeddings = embeddingModel.embedAll(segments)
|
|
||||||
.content();
|
|
||||||
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
|
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
|
||||||
embeddingStore.addAll(embeddings, segments);
|
|
||||||
|
|
||||||
String question = "Who is Simpson?";
|
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
|
||||||
Embedding questionEmbedding = embeddingModel.embed(question)
|
.documentSplitter(DocumentSplitters.recursive(500, 0))
|
||||||
.content();
|
.embeddingModel(embeddingModel)
|
||||||
int maxResults = 3;
|
.embeddingStore(embeddingStore)
|
||||||
double minScore = 0.7;
|
.build();
|
||||||
List<EmbeddingMatch<TextSegment>> relevantEmbeddings = embeddingStore.findRelevant(questionEmbedding, maxResults, minScore);
|
|
||||||
|
|
||||||
PromptTemplate promptTemplate = PromptTemplate.from("Answer the following question to the best of your ability:\n" + "\n" + "Question:\n" + "{{question}}\n" + "\n" + "Base your answer on the following information:\n" + "{{information}}");
|
Document document = loadDocument(toPath("src/test/resources/example-files/simpson's_adventures.txt"));
|
||||||
|
ingestor.ingest(document);
|
||||||
|
|
||||||
String information = relevantEmbeddings.stream()
|
|
||||||
.map(match -> match.embedded()
|
|
||||||
.text())
|
|
||||||
.collect(joining("\n\n"));
|
|
||||||
|
|
||||||
Map<String, Object> variables = new HashMap<>();
|
|
||||||
variables.put("question", question);
|
|
||||||
variables.put("information", information);
|
|
||||||
|
|
||||||
Prompt prompt = promptTemplate.apply(variables);
|
|
||||||
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
||||||
.apiKey(Constants.OPEN_API_KEY)
|
.apiKey(Constants.OPEN_AI_KEY)
|
||||||
.timeout(ofSeconds(60))
|
.timeout(ofSeconds(60))
|
||||||
.build();
|
.build();
|
||||||
AiMessage aiMessage = chatModel.generate(prompt.toUserMessage())
|
|
||||||
.content();
|
|
||||||
|
|
||||||
Logger.getGlobal()
|
ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
|
||||||
.info(aiMessage.text());
|
.chatLanguageModel(chatModel)
|
||||||
Assert.assertNotNull(aiMessage.text());
|
.retriever(EmbeddingStoreRetriever.from(embeddingStore, embeddingModel))
|
||||||
|
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
|
||||||
|
.promptTemplate(PromptTemplate
|
||||||
|
.from("Answer the following question to the best of your ability: {{question}}\n\nBase your answer on the following information:\n{{information}}"))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String answer = chain.execute("Who is Simpson?");
|
||||||
|
|
||||||
|
logger.info(answer);
|
||||||
|
Assert.assertNotNull(answer);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package com.baeldung.langchain;
|
package com.baeldung.langchain;
|
||||||
|
|
||||||
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument;
|
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument;
|
||||||
|
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||||
import static java.time.Duration.ofSeconds;
|
import static java.time.Duration.ofSeconds;
|
||||||
|
import static java.util.stream.Collectors.joining;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.MalformedURLException;
|
import java.net.MalformedURLException;
|
||||||
@ -9,59 +11,77 @@ import java.net.URISyntaxException;
|
|||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.logging.Logger;
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import dev.langchain4j.chain.ConversationalRetrievalChain;
|
|
||||||
import dev.langchain4j.data.document.Document;
|
import dev.langchain4j.data.document.Document;
|
||||||
|
import dev.langchain4j.data.document.DocumentSplitter;
|
||||||
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
||||||
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
|
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
|
|
||||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||||
|
|
||||||
public class ChatWithDocumentLiveTest {
|
public class ChatWithDocumentLiveTest {
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(ChatWithDocumentLiveTest.class);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenDocument_whenPrompted_thenValidResponse() {
|
public void givenChainWithDocument_whenPrompted_thenValidResponse() {
|
||||||
|
|
||||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
|
||||||
|
|
||||||
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
|
|
||||||
|
|
||||||
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
|
|
||||||
.documentSplitter(DocumentSplitters.recursive(500, 0))
|
|
||||||
.embeddingModel(embeddingModel)
|
|
||||||
.embeddingStore(embeddingStore)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Document document = loadDocument(toPath("src/test/resources/example-files/simpson's_adventures.txt"));
|
Document document = loadDocument(toPath("src/test/resources/example-files/simpson's_adventures.txt"));
|
||||||
ingestor.ingest(document);
|
DocumentSplitter splitter = DocumentSplitters.recursive(100, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
|
||||||
|
List<TextSegment> segments = splitter.split(document);
|
||||||
|
|
||||||
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
||||||
|
List<Embedding> embeddings = embeddingModel.embedAll(segments)
|
||||||
|
.content();
|
||||||
|
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
|
||||||
|
embeddingStore.addAll(embeddings, segments);
|
||||||
|
|
||||||
|
String question = "Who is Simpson?";
|
||||||
|
Embedding questionEmbedding = embeddingModel.embed(question)
|
||||||
|
.content();
|
||||||
|
int maxResults = 3;
|
||||||
|
double minScore = 0.7;
|
||||||
|
List<EmbeddingMatch<TextSegment>> relevantEmbeddings = embeddingStore.findRelevant(questionEmbedding, maxResults, minScore);
|
||||||
|
|
||||||
|
PromptTemplate promptTemplate = PromptTemplate.from("Answer the following question to the best of your ability:\n" + "\n" + "Question:\n" + "{{question}}\n" + "\n" + "Base your answer on the following information:\n" + "{{information}}");
|
||||||
|
|
||||||
|
String information = relevantEmbeddings.stream()
|
||||||
|
.map(match -> match.embedded()
|
||||||
|
.text())
|
||||||
|
.collect(joining("\n\n"));
|
||||||
|
|
||||||
|
Map<String, Object> variables = new HashMap<>();
|
||||||
|
variables.put("question", question);
|
||||||
|
variables.put("information", information);
|
||||||
|
|
||||||
|
Prompt prompt = promptTemplate.apply(variables);
|
||||||
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
||||||
.apiKey(Constants.OPEN_API_KEY)
|
.apiKey(Constants.OPEN_AI_KEY)
|
||||||
.timeout(ofSeconds(60))
|
.timeout(ofSeconds(60))
|
||||||
.build();
|
.build();
|
||||||
|
AiMessage aiMessage = chatModel.generate(prompt.toUserMessage())
|
||||||
|
.content();
|
||||||
|
|
||||||
ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
|
logger.info(aiMessage.text());
|
||||||
.chatLanguageModel(chatModel)
|
Assert.assertNotNull(aiMessage.text());
|
||||||
.retriever(EmbeddingStoreRetriever.from(embeddingStore, embeddingModel))
|
|
||||||
// .chatMemory() // you can override default chat memory
|
|
||||||
// .promptTemplate() // you can override default prompt template
|
|
||||||
.build();
|
|
||||||
|
|
||||||
String answer = chain.execute("Who is Simpson?");
|
|
||||||
|
|
||||||
Logger.getGlobal()
|
|
||||||
.info(answer);
|
|
||||||
Assert.assertNotNull(answer);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,35 +9,35 @@ import dev.langchain4j.model.openai.OpenAiTokenizer;
|
|||||||
|
|
||||||
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
import static dev.langchain4j.data.message.UserMessage.userMessage;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import java.util.logging.Logger;
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
public class ChatWithMemoryLiveTest {
|
public class ChatWithMemoryLiveTest {
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(ChatWithMemoryLiveTest.class);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenMemory_whenPrompted_thenValidResponse() {
|
public void givenMemory_whenPrompted_thenValidResponse() {
|
||||||
|
|
||||||
ChatLanguageModel model = OpenAiChatModel.withApiKey(Constants.OPEN_API_KEY);
|
ChatLanguageModel model = OpenAiChatModel.withApiKey(Constants.OPEN_AI_KEY);
|
||||||
ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(300, new OpenAiTokenizer(GPT_3_5_TURBO));
|
ChatMemory chatMemory = TokenWindowChatMemory.withMaxTokens(300, new OpenAiTokenizer(GPT_3_5_TURBO));
|
||||||
|
|
||||||
chatMemory.add(userMessage("Hello, my name is Kumar"));
|
chatMemory.add(userMessage("Hello, my name is Kumar"));
|
||||||
AiMessage answer = model.generate(chatMemory.messages())
|
AiMessage answer = model.generate(chatMemory.messages())
|
||||||
.content();
|
.content();
|
||||||
Logger.getGlobal()
|
logger.info(answer.text());
|
||||||
.info(answer.text());
|
|
||||||
Assert.assertNotNull(answer.text());
|
Assert.assertNotNull(answer.text());
|
||||||
chatMemory.add(answer);
|
chatMemory.add(answer);
|
||||||
|
|
||||||
chatMemory.add(userMessage("What is my name?"));
|
chatMemory.add(userMessage("What is my name?"));
|
||||||
AiMessage answerWithName = model.generate(chatMemory.messages())
|
AiMessage answerWithName = model.generate(chatMemory.messages())
|
||||||
.content();
|
.content();
|
||||||
Logger.getGlobal()
|
logger.info(answerWithName.text());
|
||||||
.info(answerWithName.text());
|
assertThat(answerWithName.text().contains("Kumar"));
|
||||||
Assert.assertTrue(answerWithName.text()
|
|
||||||
.contains("Kumar"));
|
|
||||||
chatMemory.add(answerWithName);
|
chatMemory.add(answerWithName);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,11 @@ package com.baeldung.langchain;
|
|||||||
|
|
||||||
public class Constants {
|
public class Constants {
|
||||||
|
|
||||||
public static String OPEN_API_KEY = "demo";
|
/**
|
||||||
|
* A limited access key for access to OpenAI language models can be generated by first
|
||||||
|
* registering for free at (https://platform.openai.com/signup) and then by navigating
|
||||||
|
* to "Create new secret key" page at (https://platform.openai.com/account/api-keys).
|
||||||
|
*/
|
||||||
|
public static String OPEN_AI_KEY = "<OPEN_AI_KEY>";
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -4,9 +4,10 @@ import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
|||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.logging.Logger;
|
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
@ -16,6 +17,8 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
|
|||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
|
||||||
public class PromptTemplatesLiveTest {
|
public class PromptTemplatesLiveTest {
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(PromptTemplatesLiveTest.class);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void givenPromptTemplate_whenSuppliedInput_thenValidResponse() {
|
public void givenPromptTemplate_whenSuppliedInput_thenValidResponse() {
|
||||||
@ -27,14 +30,13 @@ public class PromptTemplatesLiveTest {
|
|||||||
Prompt prompt = promptTemplate.apply(variables);
|
Prompt prompt = promptTemplate.apply(variables);
|
||||||
|
|
||||||
ChatLanguageModel model = OpenAiChatModel.builder()
|
ChatLanguageModel model = OpenAiChatModel.builder()
|
||||||
.apiKey(Constants.OPEN_API_KEY)
|
.apiKey(Constants.OPEN_AI_KEY)
|
||||||
.modelName(GPT_3_5_TURBO)
|
.modelName(GPT_3_5_TURBO)
|
||||||
.temperature(0.3)
|
.temperature(0.3)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
String response = model.generate(prompt.text());
|
String response = model.generate(prompt.text());
|
||||||
Logger.getGlobal()
|
logger.info(response);
|
||||||
.info(response);
|
|
||||||
Assert.assertNotNull(response);
|
Assert.assertNotNull(response);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package com.baeldung.langchain;
|
package com.baeldung.langchain;
|
||||||
|
|
||||||
import java.util.logging.Logger;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import dev.langchain4j.agent.tool.Tool;
|
import dev.langchain4j.agent.tool.Tool;
|
||||||
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||||
@ -11,6 +13,8 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
|
|||||||
import dev.langchain4j.service.AiServices;
|
import dev.langchain4j.service.AiServices;
|
||||||
|
|
||||||
public class ServiceWithToolsLiveTest {
|
public class ServiceWithToolsLiveTest {
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(ServiceWithToolsLiveTest.class);
|
||||||
|
|
||||||
static class Calculator {
|
static class Calculator {
|
||||||
|
|
||||||
@ -35,7 +39,7 @@ public class ServiceWithToolsLiveTest {
|
|||||||
public void givenServiceWithTools_whenPrompted_thenValidResponse() {
|
public void givenServiceWithTools_whenPrompted_thenValidResponse() {
|
||||||
|
|
||||||
Assistant assistant = AiServices.builder(Assistant.class)
|
Assistant assistant = AiServices.builder(Assistant.class)
|
||||||
.chatLanguageModel(OpenAiChatModel.withApiKey(Constants.OPEN_API_KEY))
|
.chatLanguageModel(OpenAiChatModel.withApiKey(Constants.OPEN_AI_KEY))
|
||||||
.tools(new Calculator())
|
.tools(new Calculator())
|
||||||
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
|
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
|
||||||
.build();
|
.build();
|
||||||
@ -43,9 +47,8 @@ public class ServiceWithToolsLiveTest {
|
|||||||
String question = "What is the sum of the numbers of letters in the words \"language\" and \"model\"?";
|
String question = "What is the sum of the numbers of letters in the words \"language\" and \"model\"?";
|
||||||
String answer = assistant.chat(question);
|
String answer = assistant.chat(question);
|
||||||
|
|
||||||
Logger.getGlobal()
|
logger.info(answer);
|
||||||
.info(answer);
|
assertThat(answer).contains("13");
|
||||||
Assert.assertNotNull(answer);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
4
pom.xml
4
pom.xml
@ -936,7 +936,7 @@
|
|||||||
<module>spring-di-4</module>
|
<module>spring-di-4</module>
|
||||||
<module>spring-kafka-2</module>
|
<module>spring-kafka-2</module>
|
||||||
<!--<module>java-panama</module> Java-19 module-->
|
<!--<module>java-panama</module> Java-19 module-->
|
||||||
<module>libraries-llms</module>
|
<module>libraries-llms</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
@ -1224,7 +1224,7 @@
|
|||||||
<module>spring-di-4</module>
|
<module>spring-di-4</module>
|
||||||
<module>spring-kafka-2</module>
|
<module>spring-kafka-2</module>
|
||||||
<!--<module>java-panama</module> Java-19 module-->
|
<!--<module>java-panama</module> Java-19 module-->
|
||||||
<module>libraries-llms</module>
|
<module>libraries-llms</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user