Add embeddings support (#13)

This commit is contained in:
Theo Kanning 2022-04-28 16:42:13 -05:00 committed by GitHub
parent 9f5b64b151
commit 900e13bbda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 134 additions and 1 deletions

View File

@ -6,7 +6,7 @@ import java.util.List;
/**
* An object containing a response from the answer api
* <p>
*
* https://beta.openai.com/docs/api-reference/answers/create
*/
@Data

View File

@ -0,0 +1,29 @@
package com.theokanning.openai.embedding;
import lombok.Data;
import java.util.List;
/**
* Represents a single embedding returned by the embeddings api.
* <p>
* Instances appear as the elements of {@link EmbeddingResult#data}.
* Accessors such as {@code getEmbedding()} are generated by Lombok's {@code @Data}.
*
* https://beta.openai.com/docs/api-reference/embeddings/create
*/
@Data
public class Embedding {
/**
* The type of object returned, should be "embedding"
*/
String object;
/**
* The embedding vector
*/
List<Double> embedding;
/**
* The position of this embedding in the list
*/
Integer index;
}

View File

@ -0,0 +1,34 @@
package com.theokanning.openai.embedding;
import lombok.*;
import java.util.List;
/**
* Request body for creating an embedding vector representing the input text.
* <p>
* Build instances with {@code EmbeddingRequest.builder()} (generated by Lombok's {@code @Builder}).
*
* Documentation taken from
* https://beta.openai.com/docs/api-reference/embeddings/create
*/
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
public class EmbeddingRequest {
/**
* Input text to get embeddings for, encoded as a string or array of tokens.
* To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays.
* Each input must not exceed 2048 tokens in length.
* <p>
* Unless you are embedding code, we suggest replacing newlines (\n) in your input with a single space,
* as we have observed inferior results when newlines are present.
* <p>
* Required: this field is marked {@code @NonNull}. This class models the input only as a list of strings.
*/
@NonNull
List<String> input;
/**
* A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
*/
String user;
}

View File

@ -0,0 +1,29 @@
package com.theokanning.openai.embedding;
import lombok.Data;
import java.util.List;
/**
* An object containing a response from the embeddings api
* <p>
* Accessors such as {@code getData()} are generated by Lombok's {@code @Data}.
*
* https://beta.openai.com/docs/api-reference/embeddings/create
*/
@Data
public class EmbeddingResult {
/**
* The GPT-3 model used for generating embeddings
*/
String model;
/**
* The type of object returned, should be "list"
*/
String object;
/**
* A list of the calculated embeddings
*/
List<Embedding> data;
}

View File

@ -6,6 +6,8 @@ import com.theokanning.openai.classification.ClassificationRequest;
import com.theokanning.openai.classification.ClassificationResult;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.engine.Engine;
import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneEvent;
@ -72,4 +74,8 @@ public interface OpenAiApi {
/**
* Sends a DELETE request to {@code /v1/models/{fine_tune_id}}.
*
* @param fineTuneId value substituted for the {@code fine_tune_id} path segment
* @return a {@code Single} emitting the {@code DeleteResult} parsed from the response
*/
@DELETE("/v1/models/{fine_tune_id}")
Single<DeleteResult> deleteFineTune(@Path("fine_tune_id") String fineTuneId);
/**
* Sends a POST request to {@code /v1/engines/{engine_id}/embeddings} with the given request as the body.
*
* @param engineId value substituted for the {@code engine_id} path segment
* @param request the embedding request sent as the request body
* @return a {@code Single} emitting the {@code EmbeddingResult} parsed from the response
*/
@POST("/v1/engines/{engine_id}/embeddings")
Single<EmbeddingResult> createEmbeddings(@Path("engine_id") String engineId, @Body EmbeddingRequest request);
}

View File

@ -10,6 +10,8 @@ import com.theokanning.openai.classification.ClassificationRequest;
import com.theokanning.openai.classification.ClassificationResult;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.engine.Engine;
import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneEvent;
@ -122,4 +124,8 @@ public class OpenAiService {
/**
 * Deletes the fine-tune identified by {@code fineTuneId}, waiting on the
 * underlying api call via {@code blockingGet()}.
 *
 * @param fineTuneId identifier passed through to the api
 * @return the result produced by the api call
 */
public DeleteResult deleteFineTune(String fineTuneId) {
    final DeleteResult outcome = api.deleteFineTune(fineTuneId).blockingGet();
    return outcome;
}
/**
 * Requests embeddings from the given engine, waiting on the underlying
 * api call via {@code blockingGet()}.
 *
 * @param engineId identifier of the engine, passed through to the api
 * @param request  the embedding request to send
 * @return the result produced by the api call
 */
public EmbeddingResult createEmbeddings(String engineId, EmbeddingRequest request) {
    final EmbeddingResult response = api.createEmbeddings(engineId, request).blockingGet();
    return response;
}
}

View File

@ -0,0 +1,29 @@
package com.theokanning.openai;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest;
import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertFalse;
/**
 * Exercises {@code OpenAiService.createEmbeddings} using the token read from
 * the {@code OPENAI_TOKEN} environment variable.
 */
public class EmbeddingTest {

    String token = System.getenv("OPENAI_TOKEN");
    OpenAiService service = new OpenAiService(token);

    /**
     * Requests an embedding for one sentence and checks that a non-empty
     * vector comes back for the first entry.
     */
    @Test
    void createEmbeddings() {
        // Arrange: a single-element input list wrapped in a request.
        List<String> inputs = Collections.singletonList("The food was delicious and the waiter...");
        EmbeddingRequest request = EmbeddingRequest.builder()
                .input(inputs)
                .build();

        // Act
        List<Embedding> vectors = service.createEmbeddings("text-similarity-babbage-001", request).getData();

        // Assert: at least one embedding, and its vector has content.
        assertFalse(vectors.isEmpty());
        Embedding first = vectors.get(0);
        assertFalse(first.getEmbedding().isEmpty());
    }
}