Add Model support and new v1 endpoints (#27)

The old endpoints are marked as deprecated
Also marked response fields as public for easier access
This commit is contained in:
Theo Kanning 2022-08-22 13:55:20 -05:00 committed by GitHub
parent 252db27577
commit 553e22fea2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 265 additions and 37 deletions

View File

@ -20,7 +20,8 @@ import java.util.List;
public class CompletionRequest {
/**
* The name of the model to use, only used if specifying a fine tuned model.
* The name of the model to use.
* Required if specifying a fine tuned model or if using the new v1/completions endpoint.
*/
String model;

View File

@ -14,6 +14,12 @@ import lombok.*;
@Data
public class EditRequest {
/**
* The name of the model to use.
* Required if using the new v1/edits endpoint.
*/
String model;
/**
* The input text to use as a starting point for the edit.
*/
@ -32,14 +38,17 @@ public class EditRequest {
Integer n;
/**
* What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
* What sampling temperature to use. Higher values means the model will take more risks.
* Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
*
* We generally recommend altering this or {@link EditRequest#topP} but not both.
*/
Double temperature;
/**
* An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
* An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of
* the tokens with top_p probability mass.So 0.1 means only the tokens comprising the top 10% probability mass are
* considered.
*
* We generally recommend altering this or {@link EditRequest#temperature} but not both.
*/

View File

@ -15,20 +15,20 @@ public class EditResult {
/**
* The type of object returned, should be "edit"
*/
String object;
public String object;
/**
* The creation time in epoch milliseconds.
*/
long created;
public long created;
/**
* A list of generated edits.
*/
List<EditChoice> choices;
public List<EditChoice> choices;
/**
* The API usage for this request
*/
EditUsage usage;
public EditUsage usage;
}

View File

@ -7,7 +7,6 @@ import java.util.List;
/**
* Creates an embedding vector representing the input text.
*
* Documentation taken from
* https://beta.openai.com/docs/api-reference/embeddings/create
*/
@Builder
@ -16,6 +15,12 @@ import java.util.List;
@Data
public class EmbeddingRequest {
/**
* The name of the model to use.
* Required if using the new v1/embeddings endpoint.
*/
String model;
/**
* Input text to get embeddings for, encoded as a string or array of tokens.
* To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays.

View File

@ -7,6 +7,7 @@ import lombok.Data;
*
* https://beta.openai.com/docs/api-reference/retrieve-engine
*/
@Deprecated
@Data
public class Engine {
/**

View File

@ -1,5 +1,6 @@
package com.theokanning.openai.finetune;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import java.util.List;
@ -83,7 +84,8 @@ public class FineTuneRequest {
*
* This parameter is required for multiclass classification.
*/
Integer classificationNClasses; // todo verify snake case
@JsonProperty("classification_n_classes")
Integer classificationNClasses;
/**
* The positive class in binary classification.

View File

@ -0,0 +1,43 @@
package com.theokanning.openai.model;
import lombok.Data;
import java.util.List;
/**
* GPT-3 model details
*
* https://beta.openai.com/docs/api-reference/models
*/
@Data
public class Model {
/**
* An identifier for this model, used to specify the model when making completions, etc
*/
public String id;
/**
* The type of object returned, should be "model"
*/
public String object;
/**
* The owner of the GPT-3 model, typically "openai"
*/
public String ownedBy;
/**
* List of permissions for this model
*/
public List<Permission> permission;
/**
* The root model that this and its parent (if applicable) are based on
*/
public String root;
/**
* The parent model that this is based on
*/
public String parent;
}

View File

@ -0,0 +1,47 @@
package com.theokanning.openai.model;
import lombok.Data;
/**
* GPT-3 model permissions
* I couldn't find documentation for the specific permissions, and I've elected to leave them undocumented rather than
* write something incorrect.
*
* https://beta.openai.com/docs/api-reference/models
*/
@Data
public class Permission {
/**
* An identifier for this model permission
*/
public String id;
/**
* The type of object returned, should be "model_permission"
*/
public String object;
/**
* The creation time in epoch seconds.
*/
public long created;
public boolean allowCreateEngine;
public boolean allowSampling;
public boolean allowLogProbs;
public boolean allowSearchIndices;
public boolean allowView;
public boolean allowFineTuning;
public String organization;
public String group;
public boolean isBlocking;
}

View File

@ -12,13 +12,13 @@ public class Moderation {
/**
* Set to true if the model classifies the content as violating OpenAI's content policy, false otherwise
*/
boolean flagged;
public boolean flagged;
/**
* Object containing per-category binary content policy violation flags.
* For each category, the value is true if the model flags the corresponding category as violated, false otherwise.
*/
ModerationCategories categories;
public ModerationCategories categories;
/**
* Object containing per-category raw scores output by the model, denoting the model's confidence that the
@ -26,5 +26,5 @@ public class Moderation {
* The value is between 0 and 1, where higher values denote higher confidence.
* The scores should not be interpreted as probabilities.
*/
ModerationCategoryScores categoryScores;
public ModerationCategoryScores categoryScores;
}

View File

@ -14,21 +14,21 @@ import java.util.List;
@Data
public class ModerationCategories {
boolean hate;
public boolean hate;
@JsonProperty("hate/threatening")
boolean hateThreatening;
public boolean hateThreatening;
@JsonProperty("self-harm")
boolean selfHarm;
public boolean selfHarm;
boolean sexual;
public boolean sexual;
@JsonProperty("sexual/minors")
boolean sexualMinors;
public boolean sexualMinors;
boolean violence;
public boolean violence;
@JsonProperty("violence/graphic")
boolean violenceGraphic;
public boolean violenceGraphic;
}

View File

@ -11,21 +11,21 @@ import lombok.Data;
@Data
public class ModerationCategoryScores {
double hate;
public double hate;
@JsonProperty("hate/threatening")
double hateThreatening;
public double hateThreatening;
@JsonProperty("self-harm")
double selfHarm;
public double selfHarm;
double sexual;
public double sexual;
@JsonProperty("sexual/minors")
double sexualMinors;
public double sexualMinors;
double violence;
public double violence;
@JsonProperty("violence/graphic")
double violenceGraphic;
public double violenceGraphic;
}

View File

@ -14,15 +14,15 @@ public class ModerationResult {
/**
* A unique id assigned to this moderation.
*/
String id;
public String id;
/**
* The GPT-3 model used.
*/
String model;
public String model;
/**
* A list of moderation scores.
*/
List<Moderation> results;
public List<Moderation> results;
}

View File

@ -15,6 +15,7 @@ import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneEvent;
import com.theokanning.openai.finetune.FineTuneRequest;
import com.theokanning.openai.finetune.FineTuneResult;
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;
import com.theokanning.openai.search.SearchRequest;
@ -26,18 +27,30 @@ import retrofit2.http.*;
public interface OpenAiApi {
@GET("v1/engines")
Single<OpenAiResponse<Engine>> getEngines();
@GET("v1/models")
Single<OpenAiResponse<Model>> listModels();
@GET("/v1/engines/{engine_id}")
Single<Engine> getEngine(@Path("engine_id") String engineId);
@GET("/v1/models/{model_id}")
Single<Model> getModel(@Path("model_id") String modelId);
@POST("/v1/completions")
Single<CompletionResult> createCompletion(@Body CompletionRequest request);
@Deprecated
@POST("/v1/engines/{engine_id}/completions")
Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request);
@POST("/v1/edits")
Single<EditResult> createEdit(@Body EditRequest request);
@Deprecated
@POST("/v1/engines/{engine_id}/edits")
Single<EditResult> createEdit(@Path("engine_id") String engineId, @Body EditRequest request);
@POST("/v1/embeddings")
Single<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request);
@Deprecated
@POST("/v1/engines/{engine_id}/embeddings")
Single<EmbeddingResult> createEmbeddings(@Path("engine_id") String engineId, @Body EmbeddingRequest request);
@ -78,6 +91,14 @@ public interface OpenAiApi {
@POST("/v1/moderations")
Single<ModerationResult> createModeration(@Body ModerationRequest request);
@Deprecated
@GET("v1/engines")
Single<OpenAiResponse<Engine>> getEngines();
@Deprecated
@GET("/v1/engines/{engine_id}")
Single<Engine> getEngine(@Path("engine_id") String engineId);
@Deprecated
@POST("v1/answers")
Single<AnswerResult> createAnswer(@Body AnswerRequest request);

View File

@ -19,6 +19,7 @@ import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneEvent;
import com.theokanning.openai.finetune.FineTuneRequest;
import com.theokanning.openai.finetune.FineTuneResult;
import com.theokanning.openai.model.Model;
import com.theokanning.openai.moderation.ModerationRequest;
import com.theokanning.openai.moderation.ModerationResult;
import com.theokanning.openai.search.SearchRequest;
@ -78,22 +79,40 @@ public class OpenAiService {
this.api = api;
}
public List<Engine> getEngines() {
return api.getEngines().blockingGet().data;
public List<Model> listModels() {
return api.listModels().blockingGet().data;
}
public Engine getEngine(String engineId) {
return api.getEngine(engineId).blockingGet();
public Model getModel(String modelId) {
return api.getModel(modelId).blockingGet();
}
public CompletionResult createCompletion(CompletionRequest request) {
return api.createCompletion(request).blockingGet();
}
/** Use {@link OpenAiService#createCompletion(CompletionRequest)} and {@link CompletionRequest#model}instead */
@Deprecated
public CompletionResult createCompletion(String engineId, CompletionRequest request) {
return api.createCompletion(engineId, request).blockingGet();
}
public EditResult createEdit(EditRequest request) {
return api.createEdit(request).blockingGet();
}
/** Use {@link OpenAiService#createEdit(EditRequest)} and {@link EditRequest#model}instead */
@Deprecated
public EditResult createEdit(String engineId, EditRequest request) {
return api.createEdit(engineId, request).blockingGet();
}
public EmbeddingResult createEmbeddings(EmbeddingRequest request) {
return api.createEmbeddings(request).blockingGet();
}
/** Use {@link OpenAiService#createEmbeddings(EmbeddingRequest)} and {@link EmbeddingRequest#model}instead */
@Deprecated
public EmbeddingResult createEmbeddings(String engineId, EmbeddingRequest request) {
return api.createEmbeddings(engineId, request).blockingGet();
}
@ -151,6 +170,16 @@ public class OpenAiService {
return api.createModeration(request).blockingGet();
}
@Deprecated
public List<Engine> getEngines() {
return api.getEngines().blockingGet().data;
}
@Deprecated
public Engine getEngine(String engineId) {
return api.getEngine(engineId).blockingGet();
}
@Deprecated
public AnswerResult createAnswer(AnswerRequest request) {
return api.createAnswer(request).blockingGet();

View File

@ -16,6 +16,19 @@ public class CompletionTest {
@Test
void createCompletion() {
CompletionRequest completionRequest = CompletionRequest.builder()
.model("ada")
.prompt("Somebody once told me the world is gonna roll me")
.echo(true)
.user("testing")
.build();
List<CompletionChoice> choices = service.createCompletion(completionRequest).getChoices();
assertFalse(choices.isEmpty());
}
@Test
void createCompletionDeprecated() {
CompletionRequest completionRequest = CompletionRequest.builder()
.prompt("Somebody once told me the world is gonna roll me")
.echo(true)

View File

@ -13,6 +13,19 @@ public class EditTest {
@Test
void edit() {
EditRequest request = EditRequest.builder()
.model("text-davinci-edit-001")
.input("What day of the wek is it?")
.instruction("Fix the spelling mistakes")
.build();
EditResult result = service.createEdit( request);
assertNotNull(result.getChoices().get(0).getText());
}
@Test
void editDeprecated() {
EditRequest request = EditRequest.builder()
.input("What day of the wek is it?")
.instruction("Fix the spelling mistakes")

View File

@ -17,6 +17,19 @@ public class EmbeddingTest {
@Test
void createEmbeddings() {
EmbeddingRequest embeddingRequest = EmbeddingRequest.builder()
.model("text-similarity-babbage-001")
.input(Collections.singletonList("The food was delicious and the waiter..."))
.build();
List<Embedding> embeddings = service.createEmbeddings(embeddingRequest).getData();
assertFalse(embeddings.isEmpty());
assertFalse(embeddings.get(0).getEmbedding().isEmpty());
}
@Test
void createEmbeddingsDeprecated() {
EmbeddingRequest embeddingRequest = EmbeddingRequest.builder()
.input(Collections.singletonList("The food was delicious and the waiter..."))
.build();

View File

@ -0,0 +1,31 @@
package com.theokanning.openai;
import com.theokanning.openai.model.Model;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
public class ModelTest {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token);
@Test
void listModels() {
List<Model> models = service.listModels();
assertFalse(models.isEmpty());
}
@Test
void getModel() {
Model ada = service.getModel("ada");
assertEquals("ada", ada.id);
assertEquals("openai", ada.ownedBy);
assertFalse(ada.permission.isEmpty());
}
}