99 lines
3.9 KiB
Java
99 lines
3.9 KiB
Java
|
package com.baeldung.langchain;
|
||
|
|
||
|
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument;
|
||
|
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||
|
import static java.time.Duration.ofSeconds;
|
||
|
import static java.util.stream.Collectors.joining;
|
||
|
|
||
|
import java.io.File;
|
||
|
import java.net.MalformedURLException;
|
||
|
import java.net.URISyntaxException;
|
||
|
import java.net.URL;
|
||
|
import java.nio.file.Path;
|
||
|
import java.nio.file.Paths;
|
||
|
import java.util.HashMap;
|
||
|
import java.util.List;
|
||
|
import java.util.Map;
|
||
|
import java.util.logging.Logger;
|
||
|
|
||
|
import org.junit.Assert;
|
||
|
import org.junit.Test;
|
||
|
|
||
|
import dev.langchain4j.data.document.Document;
|
||
|
import dev.langchain4j.data.document.DocumentSplitter;
|
||
|
import dev.langchain4j.data.document.splitter.DocumentSplitters;
|
||
|
import dev.langchain4j.data.embedding.Embedding;
|
||
|
import dev.langchain4j.data.message.AiMessage;
|
||
|
import dev.langchain4j.data.segment.TextSegment;
|
||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||
|
import dev.langchain4j.model.input.Prompt;
|
||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||
|
import dev.langchain4j.model.openai.OpenAiTokenizer;
|
||
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||
|
|
||
|
public class ChainWithDocumentLiveTests {
|
||
|
|
||
|
@Test
|
||
|
public void givenChainWithDocument_whenPrompted_thenValidResponse() {
|
||
|
|
||
|
Document document = loadDocument(toPath("src/test/resources/example-files/simpson's_adventures.txt"));
|
||
|
DocumentSplitter splitter = DocumentSplitters.recursive(100, 0, new OpenAiTokenizer(GPT_3_5_TURBO));
|
||
|
List<TextSegment> segments = splitter.split(document);
|
||
|
|
||
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
||
|
List<Embedding> embeddings = embeddingModel.embedAll(segments)
|
||
|
.content();
|
||
|
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
|
||
|
embeddingStore.addAll(embeddings, segments);
|
||
|
|
||
|
String question = "Who is Simpson?";
|
||
|
Embedding questionEmbedding = embeddingModel.embed(question)
|
||
|
.content();
|
||
|
int maxResults = 3;
|
||
|
double minScore = 0.7;
|
||
|
List<EmbeddingMatch<TextSegment>> relevantEmbeddings = embeddingStore.findRelevant(questionEmbedding, maxResults, minScore);
|
||
|
|
||
|
PromptTemplate promptTemplate = PromptTemplate.from("Answer the following question to the best of your ability:\n" + "\n" + "Question:\n" + "{{question}}\n" + "\n" + "Base your answer on the following information:\n" + "{{information}}");
|
||
|
|
||
|
String information = relevantEmbeddings.stream()
|
||
|
.map(match -> match.embedded()
|
||
|
.text())
|
||
|
.collect(joining("\n\n"));
|
||
|
|
||
|
Map<String, Object> variables = new HashMap<>();
|
||
|
variables.put("question", question);
|
||
|
variables.put("information", information);
|
||
|
|
||
|
Prompt prompt = promptTemplate.apply(variables);
|
||
|
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
||
|
.apiKey(Constants.OPEN_API_KEY)
|
||
|
.timeout(ofSeconds(60))
|
||
|
.build();
|
||
|
AiMessage aiMessage = chatModel.generate(prompt.toUserMessage())
|
||
|
.content();
|
||
|
|
||
|
Logger.getGlobal()
|
||
|
.info(aiMessage.text());
|
||
|
Assert.assertNotNull(aiMessage.text());
|
||
|
|
||
|
}
|
||
|
|
||
|
private static Path toPath(String fileName) {
|
||
|
try {
|
||
|
URL fileUrl = new File(fileName).toURI()
|
||
|
.toURL();
|
||
|
System.out.println(new File(fileName).toURI()
|
||
|
.toURL());
|
||
|
return Paths.get(fileUrl.toURI());
|
||
|
} catch (URISyntaxException | MalformedURLException e) {
|
||
|
throw new RuntimeException(e);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|