Add file and fine-tune api support (#6)

This commit is contained in:
Theo Kanning 2022-04-19 18:50:44 -05:00 committed by GitHub
parent d410abe1b5
commit 9a05c6ac77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 565 additions and 7 deletions

View File

@ -0,0 +1,24 @@
package com.theokanning.openai;
import lombok.Data;
/**
* A response when deleting an object
*/
@Data
public class DeleteResult {
/**
* The id of the object.
*/
String id;
/**
* The type of object deleted, for example "file" or "model"
*/
String object;
/**
* True if successfully deleted
*/
boolean deleted;
}

View File

@ -20,6 +20,11 @@ import java.util.List;
@Data
public class CompletionRequest {
/**
* The name of the model to use, only used if specifying a fine tuned model.
*/
String model;
/**
* An optional prompt to complete from
*/

View File

@ -22,7 +22,7 @@ public class CompletionResult {
String object;
/**
* The creation time in epoch milliseconds.
* The creation time in epoch seconds.
*/
long created;

View File

@ -0,0 +1,42 @@
package com.theokanning.openai.file;
import lombok.Data;
/**
* A file uploaded to OpenAi
*
* https://beta.openai.com/docs/api-reference/files
*/
@Data
public class File {
/**
* The unique id of this file.
*/
String id;
/**
* The type of object returned, should be "file".
*/
String object;
/**
* File size in bytes.
*/
Long bytes;
/**
* The creation time in epoch seconds.
*/
Long createdAt;
/**
* The name of the file.
*/
String filename;
/**
* Description of the file's purpose.
*/
String purpose;
}

View File

@ -0,0 +1,31 @@
package com.theokanning.openai.finetune;
import lombok.Data;
/**
* An object representing an event in the lifecycle of a fine-tuning job
*
* https://beta.openai.com/docs/api-reference/fine-tunes
*/
@Data
public class FineTuneEvent {
/**
* The type of object returned, should be "fine-tune-event".
*/
String object;
/**
* The creation time in epoch seconds.
*/
Long createdAt;
/**
* The log level of this message.
*/
String level;
/**
* The event message.
*/
String message;
}

View File

@ -0,0 +1,111 @@
package com.theokanning.openai.finetune;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
* A request for OpenAi to create a fine-tuned model
* All fields except trainingFile are nullable.
*
* https://beta.openai.com/docs/api-reference/fine-tunes/create
*/
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
public class FineTuneRequest {
/**
* The ID of an uploaded file that contains training data.
*/
String trainingFile;
/**
* The ID of an uploaded file that contains validation data.
*/
String validationFile;
/**
* The name of the base model to fine-tune. You can select one of "ada", "babbage", "curie", or "davinci".
* To learn more about these models, see the Engines documentation.
*/
String model;
/**
* The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
*/
Integer nEpochs;
/**
* The batch size to use for training.
* The batch size is the number of training examples used to train a single forward and backward pass.
*
* By default, the batch size will be dynamically configured to be ~0.2% of the number of examples in the training
* set, capped at 256 - in general, we've found that larger batch sizes tend to work better for larger datasets.
*/
Integer batchSize;
/**
* The learning rate multiplier to use for training.
* The fine-tuning learning rate is the original learning rate used for pretraining multiplied by this value.
*
* By default, the learning rate multiplier is the 0.05, 0.1, or 0.2 depending on final batch_size
* (larger learning rates tend to perform better with larger batch sizes).
* We recommend experimenting with values in the range 0.02 to 0.2 to see what produces the best results.
*/
Double learningRateMultiplier;
/**
* The weight to use for loss on the prompt tokens.
* This controls how much the model tries to learn to generate the prompt
* (as compared to the completion which always has a weight of 1.0),
* and can add a stabilizing effect to training when completions are short.
*
* If prompts are extremely long (relative to completions), it may make sense to reduce this weight so as to
* avoid over-prioritizing learning the prompt.
*/
Double promptLossWeight;
/**
* If set, we calculate classification-specific metrics such as accuracy and F-1 score using the validation set
* at the end of every epoch. These metrics can be viewed in the results file.
*
* In order to compute classification metrics, you must provide a validation_file.
* Additionally, you must specify {@link FineTuneRequest#classificationNClasses} for multiclass
* classification or {@link FineTuneRequest#classificationPositiveClass} for binary classification.
*/
Boolean computeClassificationMetrics;
/**
* The number of classes in a classification task.
*
* This parameter is required for multiclass classification.
*/
Integer classificationNClasses; // todo verify snake case
/**
* The positive class in binary classification.
*
* This parameter is needed to generate precision, recall, and F1 metrics when doing binary classification.
*/
String classificationPositiveClass;
/**
* If this is provided, we calculate F-beta scores at the specified beta values.
* The F-beta score is a generalization of F-1 score. This is only used for binary classification.
*
* With a beta of 1 (i.e. the F-1 score), precision and recall are given the same weight.
* A larger beta score puts more weight on recall and less on precision.
* A smaller beta score puts more weight on precision and less on recall.
*/
List<Double> classificationBetas;
/**
* A string of up to 40 characters that will be added to your fine-tuned model name.
*/
String suffix;
}

View File

@ -0,0 +1,80 @@
package com.theokanning.openai.finetune;
import com.theokanning.openai.file.File;
import lombok.Data;
import java.util.List;
/**
* An object describing an fine-tuned model. Returned by multiple fine-tune requests.
*
* https://beta.openai.com/docs/api-reference/fine-tunes
*/
@Data
public class FineTuneResult {
/**
* The ID of the fine-tuning job.
*/
String id;
/**
* The type of object returned, should be "fine-tune".
*/
String object;
/**
* The name of the base model.
*/
String model;
/**
* The creation time in epoch seconds.
*/
Long createdAt;
/**
* List of events in this job's lifecycle. Null when getting a list of fine-tune jobs.
*/
List<FineTuneEvent> events;
/**
* The ID of the fine-tuned model, null if tuning job is not finished.
* This is the id used to call the model.
*/
String fineTunedModel;
/**
* The specified hyper-parameters for the tuning job.
*/
HyperParameters hyperparams;
/**
* The ID of the organization this model belongs to.
*/
String organizationId;
/**
* Result files for this fine-tune job.
*/
List<File> resultFiles;
/**
* The status os the fine-tune job. "pending", "succeeded", or "cancelled"
*/
String status;
/**
* Training files for this fine-tune job.
*/
List<File> trainingFiles;
/**
* The last update time in epoch seconds.
*/
Long updatedAt;
/**
* Validation files for this fine-tune job.
*/
List<File> validationFiles;
}

View File

@ -0,0 +1,32 @@
package com.theokanning.openai.finetune;
import lombok.Data;
/**
* Fine-tuning job hyperparameters
*
* https://beta.openai.com/docs/api-reference/fine-tunes
*/
@Data
public class HyperParameters {
/**
* The batch size to use for training.
*/
String batchSize;
/**
* The learning rate multiplier to use for training.
*/
Double learningRateMultiplier;
/**
* The number of epochs to train the model for.
*/
Integer nEpochs;
/**
* The weight to use for loss on the prompt tokens.
*/
Double promptLossWeight;
}

View File

@ -7,6 +7,16 @@ dependencies {
api 'com.squareup.retrofit2:retrofit:2.9.0'
implementation 'com.squareup.retrofit2:adapter-rxjava2:2.9.0'
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'
testImplementation(platform('org.junit:junit-bom:5.8.2'))
testImplementation('org.junit.jupiter:junit-jupiter')
}
test {
useJUnitPlatform()
testLogging {
events "passed", "skipped", "failed"
}
}
ext {

View File

@ -1,15 +1,18 @@
package com.theokanning.openai;
import com.theokanning.openai.engine.Engine;
import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneRequest;
import com.theokanning.openai.finetune.FineTuneEvent;
import com.theokanning.openai.finetune.FineTuneResult;
import com.theokanning.openai.search.SearchRequest;
import io.reactivex.Single;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.search.SearchResult;
import retrofit2.http.Body;
import retrofit2.http.GET;
import retrofit2.http.POST;
import retrofit2.http.Path;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import retrofit2.http.*;
public interface OpenAiApi {
@ -24,4 +27,39 @@ public interface OpenAiApi {
@POST("/v1/engines/{engine_id}/search")
Single<OpenAiResponse<SearchResult>> search(@Path("engine_id") String engineId, @Body SearchRequest request);
@GET("/v1/files")
Single<OpenAiResponse<File>> listFiles();
@Multipart
@POST("/v1/files")
Single<File> uploadFile(@Part("purpose") RequestBody purpose, @Part MultipartBody.Part file);
@DELETE("/v1/files/{file_id}")
Single<DeleteResult> deleteFile(@Path("file_id") String fileId);
@GET("/v1/files/{file_id}")
Single<File> retrieveFile(@Path("file_id") String fileId);
@POST("/v1/fine-tunes")
Single<FineTuneResult> createFineTune(@Body FineTuneRequest request);
@POST("/v1/completions")
Single<CompletionResult> createFineTuneCompletion(@Body CompletionRequest request);
@GET("/v1/fine-tunes")
Single<OpenAiResponse<FineTuneResult>> listFineTunes();
@GET("/v1/fine-tunes/{fine_tune_id}")
Single<FineTuneResult> retrieveFineTune(@Path("fine_tune_id") String fineTuneId);
@POST("/v1/fine-tunes/{fine_tune_id}/cancel")
Single<FineTuneResult> cancelFineTune(@Path("fine_tune_id") String fineTuneId);
@GET("/v1/fine-tunes/{fine_tune_id}/events")
Single<OpenAiResponse<FineTuneEvent>> listFineTuneEvents(@Path("fine_tune_id") String fineTuneId);
@DELETE("/v1/models/{fine_tune_id}")
Single<DeleteResult> deleteFineTune(@Path("fine_tune_id") String fineTuneId);
}

View File

@ -4,9 +4,12 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.theokanning.openai.file.File;
import com.theokanning.openai.finetune.FineTuneRequest;
import com.theokanning.openai.finetune.FineTuneEvent;
import com.theokanning.openai.finetune.FineTuneResult;
import com.theokanning.openai.search.SearchRequest;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import okhttp3.*;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.engine.Engine;
@ -58,4 +61,53 @@ public class OpenAiService {
public List<SearchResult> search(String engineId, SearchRequest request) {
return api.search(engineId, request).blockingGet().data;
}
public List<File> listFiles() {
return api.listFiles().blockingGet().data;
}
public File uploadFile(String purpose, String filepath) {
java.io.File file = new java.io.File(filepath);
RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, purpose);
RequestBody fileBody = RequestBody.create(MediaType.parse("text"), file);
MultipartBody.Part body = MultipartBody.Part.createFormData("file", filepath, fileBody);
return api.uploadFile(purposeBody, body).blockingGet();
}
public DeleteResult deleteFile(String fileId) {
return api.deleteFile(fileId).blockingGet();
}
public File retrieveFile(String fileId) {
return api.retrieveFile(fileId).blockingGet();
}
public FineTuneResult createFineTune(FineTuneRequest request) {
return api.createFineTune(request).blockingGet();
}
public CompletionResult createFineTuneCompletion(CompletionRequest request) {
return api.createFineTuneCompletion(request).blockingGet();
}
public List<FineTuneResult> listFineTunes() {
return api.listFineTunes().blockingGet().data;
}
public FineTuneResult retrieveFineTune(String fineTuneId) {
return api.retrieveFineTune(fineTuneId).blockingGet();
}
public FineTuneResult cancelFineTune(String fineTuneId) {
return api.cancelFineTune(fineTuneId).blockingGet();
}
public List<FineTuneEvent> listFineTuneEvents(String fineTuneId) {
return api.listFineTuneEvents(fineTuneId).blockingGet().data;
}
public DeleteResult deleteFineTune(String fineTuneId) {
return api.deleteFineTune(fineTuneId).blockingGet();
}
}

View File

@ -0,0 +1,55 @@
package com.theokanning.openai;
import com.theokanning.openai.file.File;
import org.junit.jupiter.api.*;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
public class FileTest {
static String filePath = "src/test/resources/fine-tuning-data.jsonl";
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token);
static String fileId;
@Test
@Order(1)
void uploadFile() throws Exception {
File file = service.uploadFile("fine-tune", filePath);
fileId = file.getId();
assertEquals("fine-tune", file.getPurpose());
assertEquals(filePath, file.getFilename());
// wait for file to be processed
TimeUnit.SECONDS.sleep(10);
}
@Test
@Order(2)
void listFiles() {
List<File> files = service.listFiles();
assertTrue(files.stream().anyMatch(file -> file.getId().equals(fileId)));
}
@Test
@Order(3)
void retrieveFile() {
File file = service.retrieveFile(fileId);
assertEquals(filePath, file.getFilename());
}
@Test
@Order(4)
void deleteFile() {
DeleteResult result = service.deleteFile(fileId);
assertTrue(result.isDeleted());
}
}

View File

@ -0,0 +1,76 @@
package com.theokanning.openai;
import com.theokanning.openai.finetune.FineTuneRequest;
import com.theokanning.openai.finetune.FineTuneEvent;
import com.theokanning.openai.finetune.FineTuneResult;
import org.junit.jupiter.api.*;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
public class FineTuneTest {
static OpenAiService service;
static String fileId;
static String fineTuneId;
@BeforeAll
static void setup() {
String token = System.getenv("OPENAI_TOKEN");
service = new OpenAiService(token);
fileId = service.uploadFile("fine-tune", "src/test/resources/fine-tuning-data.jsonl").getId();
}
@AfterAll
static void teardown() {
service.deleteFile(fileId);
}
@Test
@Order(1)
void createFineTune() {
FineTuneRequest request = FineTuneRequest.builder()
.trainingFile(fileId)
.model("ada")
.build();
FineTuneResult fineTune = service.createFineTune(request);
fineTuneId = fineTune.getId();
assertEquals("pending", fineTune.getStatus());
}
@Test
@Order(2)
void listFineTunes() {
List<FineTuneResult> fineTunes = service.listFineTunes();
assertTrue(fineTunes.stream().anyMatch(fineTune -> fineTune.getId().equals(fineTuneId)));
}
@Test
@Order(3)
void listFineTuneEvents() {
List<FineTuneEvent> events = service.listFineTuneEvents(fineTuneId);
assertFalse(events.isEmpty());
}
@Test
@Order(3)
void retrieveFineTune() {
FineTuneResult fineTune = service.retrieveFineTune(fineTuneId);
assertEquals("ada", fineTune.getModel());
}
@Test
@Order(4)
void cancelFineTune() {
FineTuneResult fineTune = service.cancelFineTune(fineTuneId);
assertEquals("cancelled", fineTune.getStatus());
}
}

View File

@ -0,0 +1,2 @@
{"prompt": "prompt", "completion": "text"}
{"prompt": "prompt", "completion": "text"}