2023-10-23 14:33:05 +05:30
package com.baeldung.langchain ;
import static dev.langchain4j.data.document.FileSystemDocumentLoader.loadDocument ;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO ;
import static java.time.Duration.ofSeconds ;
import static java.util.stream.Collectors.joining ;
import java.io.File ;
import java.net.MalformedURLException ;
import java.net.URISyntaxException ;
import java.net.URL ;
import java.nio.file.Path ;
import java.nio.file.Paths ;
import java.util.HashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.logging.Logger ;
import org.junit.Assert ;
import org.junit.Test ;
import dev.langchain4j.data.document.Document ;
import dev.langchain4j.data.document.DocumentSplitter ;
import dev.langchain4j.data.document.splitter.DocumentSplitters ;
import dev.langchain4j.data.embedding.Embedding ;
import dev.langchain4j.data.message.AiMessage ;
import dev.langchain4j.data.segment.TextSegment ;
import dev.langchain4j.model.chat.ChatLanguageModel ;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel ;
import dev.langchain4j.model.embedding.EmbeddingModel ;
import dev.langchain4j.model.input.Prompt ;
import dev.langchain4j.model.input.PromptTemplate ;
import dev.langchain4j.model.openai.OpenAiChatModel ;
import dev.langchain4j.model.openai.OpenAiTokenizer ;
import dev.langchain4j.store.embedding.EmbeddingMatch ;
import dev.langchain4j.store.embedding.EmbeddingStore ;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore ;
2023-10-23 16:05:03 +05:30
public class ChainWithDocumentLiveTest {
2023-10-23 14:33:05 +05:30
@Test
public void givenChainWithDocument_whenPrompted_thenValidResponse ( ) {
Document document = loadDocument ( toPath ( " src/test/resources/example-files/simpson's_adventures.txt " ) ) ;
DocumentSplitter splitter = DocumentSplitters . recursive ( 100 , 0 , new OpenAiTokenizer ( GPT_3_5_TURBO ) ) ;
List < TextSegment > segments = splitter . split ( document ) ;
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel ( ) ;
List < Embedding > embeddings = embeddingModel . embedAll ( segments )
. content ( ) ;
EmbeddingStore < TextSegment > embeddingStore = new InMemoryEmbeddingStore < > ( ) ;
embeddingStore . addAll ( embeddings , segments ) ;
String question = " Who is Simpson? " ;
Embedding questionEmbedding = embeddingModel . embed ( question )
. content ( ) ;
int maxResults = 3 ;
double minScore = 0 . 7 ;
List < EmbeddingMatch < TextSegment > > relevantEmbeddings = embeddingStore . findRelevant ( questionEmbedding , maxResults , minScore ) ;
PromptTemplate promptTemplate = PromptTemplate . from ( " Answer the following question to the best of your ability: \ n " + " \ n " + " Question: \ n " + " {{question}} \ n " + " \ n " + " Base your answer on the following information: \ n " + " {{information}} " ) ;
String information = relevantEmbeddings . stream ( )
. map ( match - > match . embedded ( )
. text ( ) )
. collect ( joining ( " \ n \ n " ) ) ;
Map < String , Object > variables = new HashMap < > ( ) ;
variables . put ( " question " , question ) ;
variables . put ( " information " , information ) ;
Prompt prompt = promptTemplate . apply ( variables ) ;
ChatLanguageModel chatModel = OpenAiChatModel . builder ( )
. apiKey ( Constants . OPEN_API_KEY )
. timeout ( ofSeconds ( 60 ) )
. build ( ) ;
AiMessage aiMessage = chatModel . generate ( prompt . toUserMessage ( ) )
. content ( ) ;
Logger . getGlobal ( )
. info ( aiMessage . text ( ) ) ;
Assert . assertNotNull ( aiMessage . text ( ) ) ;
}
private static Path toPath ( String fileName ) {
try {
URL fileUrl = new File ( fileName ) . toURI ( )
. toURL ( ) ;
System . out . println ( new File ( fileName ) . toURI ( )
. toURL ( ) ) ;
return Paths . get ( fileUrl . toURI ( ) ) ;
} catch ( URISyntaxException | MalformedURLException e ) {
throw new RuntimeException ( e ) ;
}
}
}