java-tutorials/libraries-llms/src/test/java/com/baeldung/langchain/ChainWithDocumentLiveTest.java
2023-10-25 10:40:32 +05:30

86 lines
3.1 KiB
Java

package com.baeldung.langchain;
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument;
import static java.time.Duration.ofSeconds;
import java.io.File;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
public class ChainWithDocumentLiveTest {
Logger logger = LoggerFactory.getLogger(ChainWithDocumentLiveTest.class);
@Test
public void givenDocument_whenPrompted_thenValidResponse() {
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
.documentSplitter(DocumentSplitters.recursive(500, 0))
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
Document document = loadDocument(toPath("src/test/resources/example-files/simpson's_adventures.txt"));
ingestor.ingest(document);
ChatLanguageModel chatModel = OpenAiChatModel.builder()
.apiKey(Constants.OPEN_AI_KEY)
.timeout(ofSeconds(60))
.build();
ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
.chatLanguageModel(chatModel)
.retriever(EmbeddingStoreRetriever.from(embeddingStore, embeddingModel))
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
.promptTemplate(PromptTemplate
.from("Answer the following question to the best of your ability: {{question}}\n\nBase your answer on the following information:\n{{information}}"))
.build();
String answer = chain.execute("Who is Simpson?");
logger.info(answer);
Assert.assertNotNull(answer);
}
private static Path toPath(String fileName) {
try {
URL fileUrl = new File(fileName).toURI()
.toURL();
System.out.println(new File(fileName).toURI()
.toURL());
return Paths.get(fileUrl.toURI());
} catch (URISyntaxException | MalformedURLException e) {
throw new RuntimeException(e);
}
}
}