Add Edit request functionality (#14)

This commit is contained in:
Theo Kanning 2022-04-28 16:48:33 -05:00 committed by GitHub
parent 900e13bbda
commit 7f39df6e0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 144 additions and 14 deletions

View File

@ -5,7 +5,7 @@ import lombok.Data;
/** /**
* A completion generated by GPT-3 * A completion generated by GPT-3
* *
* https://beta.openai.com/docs/api-reference/create-completion * https://beta.openai.com/docs/api-reference/completions/create
*/ */
@Data @Data
public class CompletionChoice { public class CompletionChoice {

View File

@ -11,8 +11,7 @@ import java.util.List;
* A request for OpenAi to generate a predicted completion for a prompt. * A request for OpenAi to generate a predicted completion for a prompt.
* All fields are nullable. * All fields are nullable.
* *
* Documentation taken from * https://beta.openai.com/docs/api-reference/completions/create
* https://beta.openai.com/docs/api-reference/create-completion
*/ */
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@ -41,7 +40,7 @@ public class CompletionRequest {
* What sampling temperature to use. Higher values means the model will take more risks. * What sampling temperature to use. Higher values means the model will take more risks.
* Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. * Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
* *
* We generally recommend using this or {@link top_p} but not both. * We generally recommend using this or {@link CompletionRequest#topP} but not both.
*/ */
Double temperature; Double temperature;
@ -50,7 +49,7 @@ public class CompletionRequest {
* the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are * the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are
* considered. * considered.
* *
* We generally recommend using this or {@link temperature} but not both. * We generally recommend using this or {@link CompletionRequest#temperature} but not both.
*/ */
Double topP; Double topP;
@ -58,7 +57,7 @@ public class CompletionRequest {
* How many completions to generate for each prompt. * How many completions to generate for each prompt.
* *
* Because this parameter generates many completions, it can quickly consume your token quota. * Because this parameter generates many completions, it can quickly consume your token quota.
* Use carefully and ensure that you have reasonable settings for {@link max_tokens} and {@link stop}. * Use carefully and ensure that you have reasonable settings for {@link CompletionRequest#maxTokens} and {@link CompletionRequest#stop}.
*/ */
Integer n; Integer n;
@ -105,7 +104,7 @@ public class CompletionRequest {
* (the one with the lowest log probability per token). * (the one with the lowest log probability per token).
* Results cannot be streamed. * Results cannot be streamed.
* *
* When used with {@link n}, best_of controls the number of candidate completions and n specifies how many to return, * When used with {@link CompletionRequest#n}, best_of controls the number of candidate completions and n specifies how many to return,
* best_of must be greater than n. * best_of must be greater than n.
*/ */
Integer bestOf; Integer bestOf;

View File

@ -7,12 +7,12 @@ import java.util.List;
/** /**
* An object containing a response from the completion api * An object containing a response from the completion api
* *
* https://beta.openai.com/docs/api-reference/create-completion * https://beta.openai.com/docs/api-reference/completions/create
*/ */
@Data @Data
public class CompletionResult { public class CompletionResult {
/** /**
* A unique id assigned to this completion * A unique id assigned to this completion.
*/ */
String id; String id;
@ -27,12 +27,12 @@ public class CompletionResult {
long created; long created;
/** /**
* The GPT-3 model used * The GPT-3 model used.
*/ */
String model; String model;
/** /**
* A list of generated completions * A list of generated completions.
*/ */
List<CompletionChoice> choices; List<CompletionChoice> choices;
} }

View File

@ -0,0 +1,22 @@
package com.theokanning.openai.edit;

import lombok.Data;

/**
 * An edit generated by GPT-3
 *
 * https://beta.openai.com/docs/api-reference/edits/create
 */
@Data
public class EditChoice {
    /**
     * The edited text.
     */
    String text;

    /**
     * The index of this completion in the returned list.
     */
    Integer index;
}

View File

@ -0,0 +1,43 @@
package com.theokanning.openai.edit;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Given a prompt and an instruction, OpenAi will return an edited version of the prompt
 *
 * https://beta.openai.com/docs/api-reference/edits/create
 */
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
public class EditRequest {

    /**
     * The input text to use as a starting point for the edit.
     */
    String input;

    /**
     * The instruction that tells the model how to edit the prompt.
     * For example, "Fix the spelling mistakes"
     */
    String instruction;

    /**
     * What sampling temperature to use. Higher values means the model will take more risks.
     * Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a
     * well-defined answer.
     *
     * We generally recommend altering this or {@link EditRequest#topP} but not both.
     */
    Double temperature;

    /**
     * An alternative to sampling with temperature, called nucleus sampling, where the model
     * considers the results of the tokens with top_p probability mass. So 0.1 means only the
     * tokens comprising the top 10% probability mass are considered.
     *
     * We generally recommend altering this or {@link EditRequest#temperature} but not both.
     */
    Double topP;
}

View File

@ -0,0 +1,29 @@
package com.theokanning.openai.edit;

import lombok.Data;

import java.util.List;

/**
 * A list of edits generated by GPT-3
 *
 * https://beta.openai.com/docs/api-reference/edits/create
 */
@Data
public class EditResult {
    /**
     * The type of object returned, should be "edit"
     */
    String object;

    /**
     * The creation time as a Unix timestamp in epoch seconds
     * (the OpenAI API reports `created` in seconds, not milliseconds).
     */
    long created;

    /**
     * A list of generated edits.
     */
    List<EditChoice> choices;
}

View File

@ -12,7 +12,7 @@ import java.util.List;
* GPT-3 will perform a semantic search over the documents and score them based on how related they are to the query. * GPT-3 will perform a semantic search over the documents and score them based on how related they are to the query.
* Higher scores indicate a stronger relation. * Higher scores indicate a stronger relation.
* *
* https://beta.openai.com/docs/api-reference/search * https://beta.openai.com/docs/api-reference/searches
*/ */
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor

View File

@ -5,7 +5,7 @@ import lombok.Data;
/** /**
* A search result for a single document. * A search result for a single document.
* *
* https://beta.openai.com/docs/api-reference/search * https://beta.openai.com/docs/api-reference/searches
*/ */
@Data @Data
public class SearchResult { public class SearchResult {

View File

@ -6,6 +6,8 @@ import com.theokanning.openai.classification.ClassificationRequest;
import com.theokanning.openai.classification.ClassificationResult; import com.theokanning.openai.classification.ClassificationResult;
import com.theokanning.openai.completion.CompletionRequest; import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult; import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.edit.EditRequest;
import com.theokanning.openai.edit.EditResult;
import com.theokanning.openai.embedding.EmbeddingRequest; import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult; import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.engine.Engine; import com.theokanning.openai.engine.Engine;
@ -31,6 +33,9 @@ public interface OpenAiApi {
@POST("/v1/engines/{engine_id}/completions") @POST("/v1/engines/{engine_id}/completions")
Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request); Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request);
@POST("/v1/engines/{engine_id}/edits")
Single<EditResult> createEdit(@Path("engine_id") String engineId, @Body EditRequest request);
@POST("/v1/engines/{engine_id}/search") @POST("/v1/engines/{engine_id}/search")
Single<OpenAiResponse<SearchResult>> search(@Path("engine_id") String engineId, @Body SearchRequest request); Single<OpenAiResponse<SearchResult>> search(@Path("engine_id") String engineId, @Body SearchRequest request);

View File

@ -10,6 +10,8 @@ import com.theokanning.openai.classification.ClassificationRequest;
import com.theokanning.openai.classification.ClassificationResult; import com.theokanning.openai.classification.ClassificationResult;
import com.theokanning.openai.completion.CompletionRequest; import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult; import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.edit.EditRequest;
import com.theokanning.openai.edit.EditResult;
import com.theokanning.openai.embedding.EmbeddingRequest; import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult; import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.engine.Engine; import com.theokanning.openai.engine.Engine;
@ -64,6 +66,10 @@ public class OpenAiService {
return api.createCompletion(engineId, request).blockingGet(); return api.createCompletion(engineId, request).blockingGet();
} }
public EditResult createEdit(String engineId, EditRequest request) {
return api.createEdit(engineId, request).blockingGet();
}
public List<SearchResult> search(String engineId, SearchRequest request) { public List<SearchResult> search(String engineId, SearchRequest request) {
return api.search(engineId, request).blockingGet().data; return api.search(engineId, request).blockingGet().data;
} }

View File

@ -0,0 +1,27 @@
package com.theokanning.openai;

import com.theokanning.openai.edit.EditRequest;
import com.theokanning.openai.edit.EditResult;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Integration test for the edits endpoint: sends a misspelled sentence with a
 * fix-spelling instruction and checks the corrected text comes back.
 */
@Disabled // disabled until edit example CURL works
public class EditTest {

    String token = System.getenv("OPENAI_TOKEN");
    OpenAiService service = new OpenAiService(token);

    @Test
    void edit() {
        // Builder setters are independent, so the ordering here is arbitrary.
        EditRequest request = EditRequest.builder()
                .instruction("Fix the spelling mistakes")
                .input("What day of the wek is it?")
                .build();

        EditResult result = service.createEdit("text-ada-001", request);
        String editedText = result.getChoices().get(0).getText();
        assertEquals("What day of the week is it?", editedText);
    }
}

View File

@ -20,7 +20,6 @@ class OpenAiApiExample {
System.out.println(ada); System.out.println(ada);
System.out.println("\nCreating completion..."); System.out.println("\nCreating completion...");
CompletionRequest completionRequest = CompletionRequest.builder() CompletionRequest completionRequest = CompletionRequest.builder()
.prompt("Somebody once told me the world is gonna roll me") .prompt("Somebody once told me the world is gonna roll me")
.echo(true) .echo(true)