Add file and fine-tune api support (#6)
This commit is contained in:
parent
d410abe1b5
commit
9a05c6ac77
|
@ -0,0 +1,24 @@
|
|||
package com.theokanning.openai;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* A response when deleting an object
|
||||
*/
|
||||
@Data
|
||||
public class DeleteResult {
|
||||
/**
|
||||
* The id of the object.
|
||||
*/
|
||||
String id;
|
||||
|
||||
/**
|
||||
* The type of object deleted, for example "file" or "model"
|
||||
*/
|
||||
String object;
|
||||
|
||||
/**
|
||||
* True if successfully deleted
|
||||
*/
|
||||
boolean deleted;
|
||||
}
|
|
@ -20,6 +20,11 @@ import java.util.List;
|
|||
@Data
|
||||
public class CompletionRequest {
|
||||
|
||||
/**
|
||||
* The name of the model to use, only used if specifying a fine tuned model.
|
||||
*/
|
||||
String model;
|
||||
|
||||
/**
|
||||
* An optional prompt to complete from
|
||||
*/
|
||||
|
|
|
@ -22,7 +22,7 @@ public class CompletionResult {
|
|||
String object;
|
||||
|
||||
/**
|
||||
* The creation time in epoch milliseconds.
|
||||
* The creation time in epoch seconds.
|
||||
*/
|
||||
long created;
|
||||
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
package com.theokanning.openai.file;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* A file uploaded to OpenAi
|
||||
*
|
||||
* https://beta.openai.com/docs/api-reference/files
|
||||
*/
|
||||
@Data
|
||||
public class File {
|
||||
|
||||
/**
|
||||
* The unique id of this file.
|
||||
*/
|
||||
String id;
|
||||
|
||||
/**
|
||||
* The type of object returned, should be "file".
|
||||
*/
|
||||
String object;
|
||||
|
||||
/**
|
||||
* File size in bytes.
|
||||
*/
|
||||
Long bytes;
|
||||
|
||||
/**
|
||||
* The creation time in epoch seconds.
|
||||
*/
|
||||
Long createdAt;
|
||||
|
||||
/**
|
||||
* The name of the file.
|
||||
*/
|
||||
String filename;
|
||||
|
||||
/**
|
||||
* Description of the file's purpose.
|
||||
*/
|
||||
String purpose;
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package com.theokanning.openai.finetune;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* An object representing an event in the lifecycle of a fine-tuning job
|
||||
*
|
||||
* https://beta.openai.com/docs/api-reference/fine-tunes
|
||||
*/
|
||||
@Data
|
||||
public class FineTuneEvent {
|
||||
/**
|
||||
* The type of object returned, should be "fine-tune-event".
|
||||
*/
|
||||
String object;
|
||||
|
||||
/**
|
||||
* The creation time in epoch seconds.
|
||||
*/
|
||||
Long createdAt;
|
||||
|
||||
/**
|
||||
* The log level of this message.
|
||||
*/
|
||||
String level;
|
||||
|
||||
/**
|
||||
* The event message.
|
||||
*/
|
||||
String message;
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
package com.theokanning.openai.finetune;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A request for OpenAi to create a fine-tuned model
|
||||
* All fields except trainingFile are nullable.
|
||||
*
|
||||
* https://beta.openai.com/docs/api-reference/fine-tunes/create
|
||||
*/
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Data
|
||||
public class FineTuneRequest {
|
||||
|
||||
/**
|
||||
* The ID of an uploaded file that contains training data.
|
||||
*/
|
||||
String trainingFile;
|
||||
|
||||
/**
|
||||
* The ID of an uploaded file that contains validation data.
|
||||
*/
|
||||
String validationFile;
|
||||
|
||||
/**
|
||||
* The name of the base model to fine-tune. You can select one of "ada", "babbage", "curie", or "davinci".
|
||||
* To learn more about these models, see the Engines documentation.
|
||||
*/
|
||||
String model;
|
||||
|
||||
/**
|
||||
* The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
|
||||
*/
|
||||
Integer nEpochs;
|
||||
|
||||
/**
|
||||
* The batch size to use for training.
|
||||
* The batch size is the number of training examples used to train a single forward and backward pass.
|
||||
*
|
||||
* By default, the batch size will be dynamically configured to be ~0.2% of the number of examples in the training
|
||||
* set, capped at 256 - in general, we've found that larger batch sizes tend to work better for larger datasets.
|
||||
*/
|
||||
Integer batchSize;
|
||||
|
||||
/**
|
||||
* The learning rate multiplier to use for training.
|
||||
* The fine-tuning learning rate is the original learning rate used for pretraining multiplied by this value.
|
||||
*
|
||||
* By default, the learning rate multiplier is the 0.05, 0.1, or 0.2 depending on final batch_size
|
||||
* (larger learning rates tend to perform better with larger batch sizes).
|
||||
* We recommend experimenting with values in the range 0.02 to 0.2 to see what produces the best results.
|
||||
*/
|
||||
Double learningRateMultiplier;
|
||||
|
||||
/**
|
||||
* The weight to use for loss on the prompt tokens.
|
||||
* This controls how much the model tries to learn to generate the prompt
|
||||
* (as compared to the completion which always has a weight of 1.0),
|
||||
* and can add a stabilizing effect to training when completions are short.
|
||||
*
|
||||
* If prompts are extremely long (relative to completions), it may make sense to reduce this weight so as to
|
||||
* avoid over-prioritizing learning the prompt.
|
||||
*/
|
||||
Double promptLossWeight;
|
||||
|
||||
/**
|
||||
* If set, we calculate classification-specific metrics such as accuracy and F-1 score using the validation set
|
||||
* at the end of every epoch. These metrics can be viewed in the results file.
|
||||
*
|
||||
* In order to compute classification metrics, you must provide a validation_file.
|
||||
* Additionally, you must specify {@link FineTuneRequest#classificationNClasses} for multiclass
|
||||
* classification or {@link FineTuneRequest#classificationPositiveClass} for binary classification.
|
||||
*/
|
||||
Boolean computeClassificationMetrics;
|
||||
|
||||
/**
|
||||
* The number of classes in a classification task.
|
||||
*
|
||||
* This parameter is required for multiclass classification.
|
||||
*/
|
||||
Integer classificationNClasses; // todo verify snake case
|
||||
|
||||
/**
|
||||
* The positive class in binary classification.
|
||||
*
|
||||
* This parameter is needed to generate precision, recall, and F1 metrics when doing binary classification.
|
||||
*/
|
||||
String classificationPositiveClass;
|
||||
|
||||
/**
|
||||
* If this is provided, we calculate F-beta scores at the specified beta values.
|
||||
* The F-beta score is a generalization of F-1 score. This is only used for binary classification.
|
||||
*
|
||||
* With a beta of 1 (i.e. the F-1 score), precision and recall are given the same weight.
|
||||
* A larger beta score puts more weight on recall and less on precision.
|
||||
* A smaller beta score puts more weight on precision and less on recall.
|
||||
*/
|
||||
List<Double> classificationBetas;
|
||||
|
||||
/**
|
||||
* A string of up to 40 characters that will be added to your fine-tuned model name.
|
||||
*/
|
||||
String suffix;
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
package com.theokanning.openai.finetune;
|
||||
|
||||
import com.theokanning.openai.file.File;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* An object describing an fine-tuned model. Returned by multiple fine-tune requests.
|
||||
*
|
||||
* https://beta.openai.com/docs/api-reference/fine-tunes
|
||||
*/
|
||||
@Data
|
||||
public class FineTuneResult {
|
||||
/**
|
||||
* The ID of the fine-tuning job.
|
||||
*/
|
||||
String id;
|
||||
|
||||
/**
|
||||
* The type of object returned, should be "fine-tune".
|
||||
*/
|
||||
String object;
|
||||
|
||||
/**
|
||||
* The name of the base model.
|
||||
*/
|
||||
String model;
|
||||
|
||||
/**
|
||||
* The creation time in epoch seconds.
|
||||
*/
|
||||
Long createdAt;
|
||||
|
||||
/**
|
||||
* List of events in this job's lifecycle. Null when getting a list of fine-tune jobs.
|
||||
*/
|
||||
List<FineTuneEvent> events;
|
||||
|
||||
/**
|
||||
* The ID of the fine-tuned model, null if tuning job is not finished.
|
||||
* This is the id used to call the model.
|
||||
*/
|
||||
String fineTunedModel;
|
||||
|
||||
/**
|
||||
* The specified hyper-parameters for the tuning job.
|
||||
*/
|
||||
HyperParameters hyperparams;
|
||||
|
||||
/**
|
||||
* The ID of the organization this model belongs to.
|
||||
*/
|
||||
String organizationId;
|
||||
|
||||
/**
|
||||
* Result files for this fine-tune job.
|
||||
*/
|
||||
List<File> resultFiles;
|
||||
|
||||
/**
|
||||
* The status os the fine-tune job. "pending", "succeeded", or "cancelled"
|
||||
*/
|
||||
String status;
|
||||
|
||||
/**
|
||||
* Training files for this fine-tune job.
|
||||
*/
|
||||
List<File> trainingFiles;
|
||||
|
||||
/**
|
||||
* The last update time in epoch seconds.
|
||||
*/
|
||||
Long updatedAt;
|
||||
|
||||
/**
|
||||
* Validation files for this fine-tune job.
|
||||
*/
|
||||
List<File> validationFiles;
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
package com.theokanning.openai.finetune;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* Fine-tuning job hyperparameters
|
||||
*
|
||||
* https://beta.openai.com/docs/api-reference/fine-tunes
|
||||
*/
|
||||
@Data
|
||||
public class HyperParameters {
|
||||
|
||||
/**
|
||||
* The batch size to use for training.
|
||||
*/
|
||||
String batchSize;
|
||||
|
||||
/**
|
||||
* The learning rate multiplier to use for training.
|
||||
*/
|
||||
Double learningRateMultiplier;
|
||||
|
||||
/**
|
||||
* The number of epochs to train the model for.
|
||||
*/
|
||||
Integer nEpochs;
|
||||
|
||||
/**
|
||||
* The weight to use for loss on the prompt tokens.
|
||||
*/
|
||||
Double promptLossWeight;
|
||||
}
|
|
@ -7,6 +7,16 @@ dependencies {
|
|||
api 'com.squareup.retrofit2:retrofit:2.9.0'
|
||||
implementation 'com.squareup.retrofit2:adapter-rxjava2:2.9.0'
|
||||
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'
|
||||
|
||||
testImplementation(platform('org.junit:junit-bom:5.8.2'))
|
||||
testImplementation('org.junit.jupiter:junit-jupiter')
|
||||
}
|
||||
|
||||
// Run the test task on the JUnit Platform and log an event for every passed, skipped, or failed test.
test {
|
||||
useJUnitPlatform()
|
||||
testLogging {
|
||||
events "passed", "skipped", "failed"
|
||||
}
|
||||
}
|
||||
|
||||
ext {
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
package com.theokanning.openai;
|
||||
|
||||
import com.theokanning.openai.engine.Engine;
|
||||
import com.theokanning.openai.file.File;
|
||||
import com.theokanning.openai.finetune.FineTuneRequest;
|
||||
import com.theokanning.openai.finetune.FineTuneEvent;
|
||||
import com.theokanning.openai.finetune.FineTuneResult;
|
||||
import com.theokanning.openai.search.SearchRequest;
|
||||
import io.reactivex.Single;
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
import com.theokanning.openai.completion.CompletionResult;
|
||||
import com.theokanning.openai.search.SearchResult;
|
||||
import retrofit2.http.Body;
|
||||
import retrofit2.http.GET;
|
||||
import retrofit2.http.POST;
|
||||
import retrofit2.http.Path;
|
||||
import okhttp3.MultipartBody;
|
||||
import okhttp3.RequestBody;
|
||||
import retrofit2.http.*;
|
||||
|
||||
public interface OpenAiApi {
|
||||
|
||||
|
@ -24,4 +27,39 @@ public interface OpenAiApi {
|
|||
|
||||
/**
 * Issues {@code POST /v1/engines/{engine_id}/search} with the given request as the body.
 * The results arrive wrapped in an OpenAiResponse envelope.
 */
@POST("/v1/engines/{engine_id}/search")
|
||||
Single<OpenAiResponse<SearchResult>> search(@Path("engine_id") String engineId, @Body SearchRequest request);
|
||||
|
||||
/**
 * Issues {@code GET /v1/files}. The File entries arrive wrapped in an OpenAiResponse envelope.
 */
@GET("/v1/files")
|
||||
Single<OpenAiResponse<File>> listFiles();
|
||||
|
||||
/**
 * Issues a multipart {@code POST /v1/files}.
 *
 * @param purpose body of the form part named "purpose"
 * @param file    pre-built multipart part; its form name is chosen by the caller
 */
@Multipart
|
||||
@POST("/v1/files")
|
||||
Single<File> uploadFile(@Part("purpose") RequestBody purpose, @Part MultipartBody.Part file);
|
||||
|
||||
/**
 * Issues {@code DELETE /v1/files/{file_id}} and yields the resulting DeleteResult.
 */
@DELETE("/v1/files/{file_id}")
|
||||
Single<DeleteResult> deleteFile(@Path("file_id") String fileId);
|
||||
|
||||
/**
 * Issues {@code GET /v1/files/{file_id}} and yields the File description for that id.
 */
@GET("/v1/files/{file_id}")
|
||||
Single<File> retrieveFile(@Path("file_id") String fileId);
|
||||
|
||||
/**
 * Issues {@code POST /v1/fine-tunes} with the given request as the body.
 */
@POST("/v1/fine-tunes")
|
||||
Single<FineTuneResult> createFineTune(@Body FineTuneRequest request);
|
||||
|
||||
/**
 * Issues {@code POST /v1/completions} with the given request as the body.
 * Unlike the engine-scoped endpoints, the path carries no engine id.
 */
@POST("/v1/completions")
|
||||
Single<CompletionResult> createFineTuneCompletion(@Body CompletionRequest request);
|
||||
|
||||
/**
 * Issues {@code GET /v1/fine-tunes}. The FineTuneResult entries arrive wrapped in an OpenAiResponse envelope.
 */
@GET("/v1/fine-tunes")
|
||||
Single<OpenAiResponse<FineTuneResult>> listFineTunes();
|
||||
|
||||
/**
 * Issues {@code GET /v1/fine-tunes/{fine_tune_id}} and yields the FineTuneResult for that id.
 */
@GET("/v1/fine-tunes/{fine_tune_id}")
|
||||
Single<FineTuneResult> retrieveFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
|
||||
/**
 * Issues {@code POST /v1/fine-tunes/{fine_tune_id}/cancel} (no request body) and yields the resulting FineTuneResult.
 */
@POST("/v1/fine-tunes/{fine_tune_id}/cancel")
|
||||
Single<FineTuneResult> cancelFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
|
||||
/**
 * Issues {@code GET /v1/fine-tunes/{fine_tune_id}/events}.
 * The FineTuneEvent entries arrive wrapped in an OpenAiResponse envelope.
 */
@GET("/v1/fine-tunes/{fine_tune_id}/events")
|
||||
Single<OpenAiResponse<FineTuneEvent>> listFineTuneEvents(@Path("fine_tune_id") String fineTuneId);
|
||||
|
||||
/**
 * Issues {@code DELETE /v1/models/{fine_tune_id}} and yields the resulting DeleteResult.
 *
 * NOTE(review): the path targets /v1/models while the argument is named fineTuneId - confirm whether
 * the fine-tuned model name (FineTuneResult#fineTunedModel) is expected here rather than the job id.
 */
@DELETE("/v1/models/{fine_tune_id}")
|
||||
Single<DeleteResult> deleteFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
|
||||
}
|
||||
|
|
|
@ -4,9 +4,12 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
|||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
|
||||
import com.theokanning.openai.file.File;
|
||||
import com.theokanning.openai.finetune.FineTuneRequest;
|
||||
import com.theokanning.openai.finetune.FineTuneEvent;
|
||||
import com.theokanning.openai.finetune.FineTuneResult;
|
||||
import com.theokanning.openai.search.SearchRequest;
|
||||
import okhttp3.ConnectionPool;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.*;
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
import com.theokanning.openai.completion.CompletionResult;
|
||||
import com.theokanning.openai.engine.Engine;
|
||||
|
@ -58,4 +61,53 @@ public class OpenAiService {
|
|||
/**
 * Blocking wrapper around the search API call: waits for the call to complete and returns the
 * {@code data} list of the response.
 */
public List<SearchResult> search(String engineId, SearchRequest request) {
|
||||
return api.search(engineId, request).blockingGet().data;
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the list-files API call: waits for the call to complete and returns the
 * {@code data} list of the response.
 */
public List<File> listFiles() {
|
||||
return api.listFiles().blockingGet().data;
|
||||
}
|
||||
|
||||
public File uploadFile(String purpose, String filepath) {
|
||||
java.io.File file = new java.io.File(filepath);
|
||||
RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, purpose);
|
||||
RequestBody fileBody = RequestBody.create(MediaType.parse("text"), file);
|
||||
MultipartBody.Part body = MultipartBody.Part.createFormData("file", filepath, fileBody);
|
||||
|
||||
return api.uploadFile(purposeBody, body).blockingGet();
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the delete-file API call for the given file id.
 */
public DeleteResult deleteFile(String fileId) {
|
||||
return api.deleteFile(fileId).blockingGet();
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the retrieve-file API call for the given file id.
 */
public File retrieveFile(String fileId) {
|
||||
return api.retrieveFile(fileId).blockingGet();
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the create-fine-tune API call; returns the FineTuneResult from the response.
 */
public FineTuneResult createFineTune(FineTuneRequest request) {
|
||||
return api.createFineTune(request).blockingGet();
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the fine-tune completion API call; returns the CompletionResult from the response.
 */
public CompletionResult createFineTuneCompletion(CompletionRequest request) {
|
||||
return api.createFineTuneCompletion(request).blockingGet();
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the list-fine-tunes API call: waits for the call to complete and returns
 * the {@code data} list of the response.
 */
public List<FineTuneResult> listFineTunes() {
|
||||
return api.listFineTunes().blockingGet().data;
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the retrieve-fine-tune API call for the given fine-tune id.
 */
public FineTuneResult retrieveFineTune(String fineTuneId) {
|
||||
return api.retrieveFineTune(fineTuneId).blockingGet();
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the cancel-fine-tune API call for the given fine-tune id.
 */
public FineTuneResult cancelFineTune(String fineTuneId) {
|
||||
return api.cancelFineTune(fineTuneId).blockingGet();
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the list-fine-tune-events API call: waits for the call to complete and
 * returns the {@code data} list of the response.
 */
public List<FineTuneEvent> listFineTuneEvents(String fineTuneId) {
|
||||
return api.listFineTuneEvents(fineTuneId).blockingGet().data;
|
||||
}
|
||||
|
||||
/**
 * Blocking wrapper around the delete-fine-tune API call; the argument is passed through unchanged.
 */
public DeleteResult deleteFineTune(String fineTuneId) {
|
||||
return api.deleteFineTune(fineTuneId).blockingGet();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
package com.theokanning.openai;
|
||||
|
||||
import com.theokanning.openai.file.File;
|
||||
import org.junit.jupiter.api.*;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
|
||||
public class FileTest {
|
||||
static String filePath = "src/test/resources/fine-tuning-data.jsonl";
|
||||
|
||||
String token = System.getenv("OPENAI_TOKEN");
|
||||
OpenAiService service = new OpenAiService(token);
|
||||
static String fileId;
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
void uploadFile() throws Exception {
|
||||
File file = service.uploadFile("fine-tune", filePath);
|
||||
fileId = file.getId();
|
||||
|
||||
assertEquals("fine-tune", file.getPurpose());
|
||||
assertEquals(filePath, file.getFilename());
|
||||
|
||||
// wait for file to be processed
|
||||
TimeUnit.SECONDS.sleep(10);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
void listFiles() {
|
||||
List<File> files = service.listFiles();
|
||||
|
||||
assertTrue(files.stream().anyMatch(file -> file.getId().equals(fileId)));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void retrieveFile() {
|
||||
File file = service.retrieveFile(fileId);
|
||||
|
||||
assertEquals(filePath, file.getFilename());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(4)
|
||||
void deleteFile() {
|
||||
DeleteResult result = service.deleteFile(fileId);
|
||||
assertTrue(result.isDeleted());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package com.theokanning.openai;
|
||||
|
||||
import com.theokanning.openai.finetune.FineTuneRequest;
|
||||
import com.theokanning.openai.finetune.FineTuneEvent;
|
||||
import com.theokanning.openai.finetune.FineTuneResult;
|
||||
import org.junit.jupiter.api.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
|
||||
public class FineTuneTest {
|
||||
static OpenAiService service;
|
||||
static String fileId;
|
||||
static String fineTuneId;
|
||||
|
||||
|
||||
@BeforeAll
|
||||
static void setup() {
|
||||
String token = System.getenv("OPENAI_TOKEN");
|
||||
service = new OpenAiService(token);
|
||||
fileId = service.uploadFile("fine-tune", "src/test/resources/fine-tuning-data.jsonl").getId();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void teardown() {
|
||||
service.deleteFile(fileId);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
void createFineTune() {
|
||||
FineTuneRequest request = FineTuneRequest.builder()
|
||||
.trainingFile(fileId)
|
||||
.model("ada")
|
||||
.build();
|
||||
|
||||
FineTuneResult fineTune = service.createFineTune(request);
|
||||
fineTuneId = fineTune.getId();
|
||||
|
||||
assertEquals("pending", fineTune.getStatus());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(2)
|
||||
void listFineTunes() {
|
||||
List<FineTuneResult> fineTunes = service.listFineTunes();
|
||||
|
||||
assertTrue(fineTunes.stream().anyMatch(fineTune -> fineTune.getId().equals(fineTuneId)));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void listFineTuneEvents() {
|
||||
List<FineTuneEvent> events = service.listFineTuneEvents(fineTuneId);
|
||||
|
||||
assertFalse(events.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(3)
|
||||
void retrieveFineTune() {
|
||||
FineTuneResult fineTune = service.retrieveFineTune(fineTuneId);
|
||||
|
||||
assertEquals("ada", fineTune.getModel());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Order(4)
|
||||
void cancelFineTune() {
|
||||
FineTuneResult fineTune = service.cancelFineTune(fineTuneId);
|
||||
|
||||
assertEquals("cancelled", fineTune.getStatus());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
{"prompt": "prompt", "completion": "text"}
|
||||
{"prompt": "prompt", "completion": "text"}
|
Loading…
Reference in New Issue