2023-10-23 14:33:05 +05:30
package com.baeldung.langchain ;
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument ;
import static java.time.Duration.ofSeconds ;
2023-10-26 23:15:14 +05:30
import static org.junit.Assert.assertNotNull ;
2023-10-23 14:33:05 +05:30
import java.nio.file.Paths ;
import org.junit.Test ;
2023-10-25 10:40:32 +05:30
import org.slf4j.Logger ;
import org.slf4j.LoggerFactory ;
2023-10-23 14:33:05 +05:30
2023-10-25 10:40:32 +05:30
import dev.langchain4j.chain.ConversationalRetrievalChain ;
2023-10-23 14:33:05 +05:30
import dev.langchain4j.data.document.Document ;
import dev.langchain4j.data.document.splitter.DocumentSplitters ;
import dev.langchain4j.data.segment.TextSegment ;
2023-10-25 10:40:32 +05:30
import dev.langchain4j.memory.chat.MessageWindowChatMemory ;
2023-10-23 14:33:05 +05:30
import dev.langchain4j.model.chat.ChatLanguageModel ;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel ;
import dev.langchain4j.model.embedding.EmbeddingModel ;
import dev.langchain4j.model.input.PromptTemplate ;
import dev.langchain4j.model.openai.OpenAiChatModel ;
2023-10-25 10:40:32 +05:30
import dev.langchain4j.retriever.EmbeddingStoreRetriever ;
2023-10-23 14:33:05 +05:30
import dev.langchain4j.store.embedding.EmbeddingStore ;
2023-10-25 10:40:32 +05:30
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor ;
2023-10-23 14:33:05 +05:30
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore ;
2023-10-23 16:05:03 +05:30
public class ChainWithDocumentLiveTest {
2023-10-26 23:15:14 +05:30
private static final Logger logger = LoggerFactory . getLogger ( ChainWithDocumentLiveTest . class ) ;
2023-10-23 14:33:05 +05:30
@Test
2023-10-26 23:15:14 +05:30
public void givenChainWithDocument_whenPrompted_thenValidResponse ( ) {
2023-10-23 14:33:05 +05:30
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel ( ) ;
2023-10-25 10:40:32 +05:30
EmbeddingStore < TextSegment > embeddingStore = new InMemoryEmbeddingStore < > ( ) ;
2023-10-23 14:33:05 +05:30
2023-10-25 10:40:32 +05:30
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor . builder ( )
. documentSplitter ( DocumentSplitters . recursive ( 500 , 0 ) )
. embeddingModel ( embeddingModel )
. embeddingStore ( embeddingStore )
. build ( ) ;
2023-10-23 14:33:05 +05:30
2023-10-26 23:15:14 +05:30
Document document = loadDocument ( Paths . get ( " src/test/resources/example-files/simpson's_adventures.txt " ) ) ;
2023-10-25 10:40:32 +05:30
ingestor . ingest ( document ) ;
2023-10-23 14:33:05 +05:30
ChatLanguageModel chatModel = OpenAiChatModel . builder ( )
2023-10-25 10:40:32 +05:30
. apiKey ( Constants . OPEN_AI_KEY )
2023-10-23 14:33:05 +05:30
. timeout ( ofSeconds ( 60 ) )
. build ( ) ;
2023-10-25 10:40:32 +05:30
ConversationalRetrievalChain chain = ConversationalRetrievalChain . builder ( )
. chatLanguageModel ( chatModel )
. retriever ( EmbeddingStoreRetriever . from ( embeddingStore , embeddingModel ) )
. chatMemory ( MessageWindowChatMemory . withMaxMessages ( 10 ) )
2023-10-26 23:15:14 +05:30
. promptTemplate ( PromptTemplate . from ( " Answer the following question to the best of your ability: {{question}} \ n \ nBase your answer on the following information: \ n{{information}} " ) )
2023-10-25 10:40:32 +05:30
. build ( ) ;
String answer = chain . execute ( " Who is Simpson? " ) ;
logger . info ( answer ) ;
2023-10-26 23:15:14 +05:30
assertNotNull ( answer ) ;
2023-10-23 14:33:05 +05:30
}
}