Add classification support (#11)
This commit is contained in:
parent
ed2f1152e8
commit
103c34da94
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: Publish
|
name: Test
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
@ -0,0 +1,120 @@
|
|||||||
|
package com.theokanning.openai.classification;
|
||||||
|
|
||||||
|
import lombok.*;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A request for OpenAi to classify text based on provided examples
|
||||||
|
* All fields are nullable.
|
||||||
|
*
|
||||||
|
* Documentation taken from
|
||||||
|
* https://beta.openai.com/docs/api-reference/classifications/create
|
||||||
|
*/
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Data
|
||||||
|
public class ClassificationRequest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ID of the engine to use for completion
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
String model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Query to be classified
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
String query;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A list of examples with labels, in the following format:
|
||||||
|
*
|
||||||
|
* [["The movie is so interesting.", "Positive"], ["It is quite boring.", "Negative"], ...]
|
||||||
|
*
|
||||||
|
* All the label strings will be normalized to be capitalized.
|
||||||
|
*
|
||||||
|
* You should specify either examples or file, but not both.
|
||||||
|
*/
|
||||||
|
List<List<String>> examples;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The ID of the uploaded file that contains training examples.
|
||||||
|
* See upload file for how to upload a file of the desired format and purpose.
|
||||||
|
*
|
||||||
|
* You should specify either examples or file, but not both.
|
||||||
|
*/
|
||||||
|
String file;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The set of categories being classified.
|
||||||
|
* If not specified, candidate labels will be automatically collected from the examples you provide.
|
||||||
|
* All the label strings will be normalized to be capitalized.
|
||||||
|
*/
|
||||||
|
List<String> labels;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ID of the engine to use for Search. You can select one of ada, babbage, curie, or davinci.
|
||||||
|
*/
|
||||||
|
String searchModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* What sampling temperature to use. Higher values means the model will take more risks.
|
||||||
|
* Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
|
||||||
|
*
|
||||||
|
* We generally recommend using this or {@link top_p} but not both.
|
||||||
|
*/
|
||||||
|
Double temperature;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Include the log probabilities on the logprobs most likely tokens, as well the chosen tokens.
|
||||||
|
* For example, if logprobs is 10, the API will return a list of the 10 most likely tokens.
|
||||||
|
* The API will always return the logprob of the sampled token,
|
||||||
|
* so there may be up to logprobs+1 elements in the response.
|
||||||
|
*/
|
||||||
|
Integer logprobs;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The maximum number of examples to be ranked by Search when using file.
|
||||||
|
* Setting it to a higher value leads to improved accuracy but with increased latency and cost.
|
||||||
|
*/
|
||||||
|
Integer maxExamples;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Modify the likelihood of specified tokens appearing in the completion.
|
||||||
|
*
|
||||||
|
* Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an
|
||||||
|
* associated bias value from -100 to 100.
|
||||||
|
*/
|
||||||
|
Map<String, Double> logitBias;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If set to true, the returned JSON will include a "prompt" field containing the final prompt that was
|
||||||
|
* used to request a completion. This is mainly useful for debugging purposes.
|
||||||
|
*/
|
||||||
|
Boolean returnPrompt;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A special boolean flag for showing metadata.
|
||||||
|
* If set to true, each document entry in the returned JSON will contain a "metadata" field.
|
||||||
|
*
|
||||||
|
* This flag only takes effect when file is set.
|
||||||
|
*/
|
||||||
|
Boolean returnMetadata;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If an object name is in the list, we provide the full information of the object;
|
||||||
|
* otherwise, we only provide the object ID.
|
||||||
|
*
|
||||||
|
* Currently we support completion and file objects for expansion.
|
||||||
|
*/
|
||||||
|
List<String> expand;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||||
|
*/
|
||||||
|
String user;
|
||||||
|
}
|
@ -0,0 +1,45 @@
|
|||||||
|
package com.theokanning.openai.classification;
|
||||||
|
|
||||||
|
import com.theokanning.openai.completion.CompletionChoice;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An object containing a response from the classification api
|
||||||
|
* <p>
|
||||||
|
* https://beta.openai.com/docs/api-reference/classifications/create
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class ClassificationResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A unique id assigned to this completion
|
||||||
|
*/
|
||||||
|
String completion;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The predicted label for the query text.
|
||||||
|
*/
|
||||||
|
String label;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The GPT-3 model used for completion
|
||||||
|
*/
|
||||||
|
String model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The type of object returned, should be "classification"
|
||||||
|
*/
|
||||||
|
String object;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The GPT-3 model used for search
|
||||||
|
*/
|
||||||
|
String searchModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A list of the most relevant examples for the query text.
|
||||||
|
*/
|
||||||
|
List<Example> selectedExamples;
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
package com.theokanning.openai.classification;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents an example returned by the classification api
|
||||||
|
*
|
||||||
|
* https://beta.openai.com/docs/api-reference/classifications/create
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class Example {
|
||||||
|
/**
|
||||||
|
* The position of this example in the example list
|
||||||
|
*/
|
||||||
|
Integer document;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The label of the example
|
||||||
|
*/
|
||||||
|
String label;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The text of the example
|
||||||
|
*/
|
||||||
|
String text;
|
||||||
|
}
|
@ -1,9 +1,6 @@
|
|||||||
package com.theokanning.openai.finetune;
|
package com.theokanning.openai.finetune;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.*;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@ -22,6 +19,7 @@ public class FineTuneRequest {
|
|||||||
/**
|
/**
|
||||||
* The ID of an uploaded file that contains training data.
|
* The ID of an uploaded file that contains training data.
|
||||||
*/
|
*/
|
||||||
|
@NonNull
|
||||||
String trainingFile;
|
String trainingFile;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package com.theokanning.openai;
|
package com.theokanning.openai;
|
||||||
|
|
||||||
|
import com.theokanning.openai.classification.ClassificationRequest;
|
||||||
|
import com.theokanning.openai.classification.ClassificationResult;
|
||||||
import com.theokanning.openai.engine.Engine;
|
import com.theokanning.openai.engine.Engine;
|
||||||
import com.theokanning.openai.file.File;
|
import com.theokanning.openai.file.File;
|
||||||
import com.theokanning.openai.finetune.FineTuneRequest;
|
import com.theokanning.openai.finetune.FineTuneRequest;
|
||||||
@ -28,6 +30,9 @@ public interface OpenAiApi {
|
|||||||
@POST("/v1/engines/{engine_id}/search")
|
@POST("/v1/engines/{engine_id}/search")
|
||||||
Single<OpenAiResponse<SearchResult>> search(@Path("engine_id") String engineId, @Body SearchRequest request);
|
Single<OpenAiResponse<SearchResult>> search(@Path("engine_id") String engineId, @Body SearchRequest request);
|
||||||
|
|
||||||
|
@POST("v1/classifications")
|
||||||
|
Single<ClassificationResult> createClassification(@Body ClassificationRequest request);
|
||||||
|
|
||||||
@GET("/v1/files")
|
@GET("/v1/files")
|
||||||
Single<OpenAiResponse<File>> listFiles();
|
Single<OpenAiResponse<File>> listFiles();
|
||||||
|
|
||||||
|
@ -4,6 +4,8 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
|||||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
|
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
|
||||||
|
import com.theokanning.openai.classification.ClassificationRequest;
|
||||||
|
import com.theokanning.openai.classification.ClassificationResult;
|
||||||
import com.theokanning.openai.file.File;
|
import com.theokanning.openai.file.File;
|
||||||
import com.theokanning.openai.finetune.FineTuneRequest;
|
import com.theokanning.openai.finetune.FineTuneRequest;
|
||||||
import com.theokanning.openai.finetune.FineTuneEvent;
|
import com.theokanning.openai.finetune.FineTuneEvent;
|
||||||
@ -62,6 +64,10 @@ public class OpenAiService {
|
|||||||
return api.search(engineId, request).blockingGet().data;
|
return api.search(engineId, request).blockingGet().data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ClassificationResult createClassification(ClassificationRequest request) {
|
||||||
|
return api.createClassification(request).blockingGet();
|
||||||
|
}
|
||||||
|
|
||||||
public List<File> listFiles() {
|
public List<File> listFiles() {
|
||||||
return api.listFiles().blockingGet().data;
|
return api.listFiles().blockingGet().data;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
package com.theokanning.openai;
|
||||||
|
|
||||||
|
import com.theokanning.openai.classification.ClassificationRequest;
|
||||||
|
import com.theokanning.openai.classification.ClassificationResult;
|
||||||
|
import com.theokanning.openai.completion.CompletionChoice;
|
||||||
|
import com.theokanning.openai.completion.CompletionRequest;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
|
||||||
|
|
||||||
|
public class ClassificationTest {
|
||||||
|
|
||||||
|
String token = System.getenv("OPENAI_TOKEN");
|
||||||
|
OpenAiService service = new OpenAiService(token);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void createCompletion() {
|
||||||
|
ClassificationRequest classificationRequest = ClassificationRequest.builder()
|
||||||
|
.examples(Arrays.asList(
|
||||||
|
Arrays.asList("A happy moment", "Positive"),
|
||||||
|
Arrays.asList("I am sad.", "Negative"),
|
||||||
|
Arrays.asList("I am feeling awesome", "Positive")
|
||||||
|
))
|
||||||
|
.query("It is a raining day :(")
|
||||||
|
.model("curie")
|
||||||
|
.searchModel("ada")
|
||||||
|
.labels(Arrays.asList("Positive", "Negative", "Neutral"))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ClassificationResult result = service.createClassification(classificationRequest);
|
||||||
|
|
||||||
|
assertNotNull(result.getCompletion());
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,7 @@ import com.theokanning.openai.finetune.FineTuneResult;
|
|||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@ -17,10 +18,13 @@ public class FineTuneTest {
|
|||||||
|
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
static void setup() {
|
static void setup() throws Exception {
|
||||||
String token = System.getenv("OPENAI_TOKEN");
|
String token = System.getenv("OPENAI_TOKEN");
|
||||||
service = new OpenAiService(token);
|
service = new OpenAiService(token);
|
||||||
fileId = service.uploadFile("fine-tune", "src/test/resources/fine-tuning-data.jsonl").getId();
|
fileId = service.uploadFile("fine-tune", "src/test/resources/fine-tuning-data.jsonl").getId();
|
||||||
|
|
||||||
|
// wait for file to be processed
|
||||||
|
TimeUnit.SECONDS.sleep(10);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterAll
|
@AfterAll
|
||||||
|
Loading…
x
Reference in New Issue
Block a user