
Programming circle · 2022-02-13 08:11:51


Java Programmers Learn Deep Learning with DJL, Part 9: Transfer Learning on the CIFAR-10 Dataset

# Introduction


This article uses transfer learning to train an image classification model. As mentioned earlier, transfer learning means training a model on one problem and then applying it to a second problem. Compared with training a model from scratch for the specific problem, transfer learning reduces the number of features that have to be learned and produces a more flexible model in less time.

This article uses the CIFAR-10 dataset to train our own model. The dataset contains 60,000 32×32 color images for classification.

The pre-trained model used here is ResNet50v1, a 50-layer deep learning model trained on ImageNet, which covers over 1.2 million images in 1,000 classes. This article modifies the ImageNet model so that it classifies the 10 classes of the CIFAR-10 dataset.

Note: the experiment in this article did not succeed; the predefined model failed to load. (One possible cause: this ResNet50v1 artifact comes from DJL's MXNet model zoo, so loading it needs the MXNet engine on the classpath, while the pom.xml below only includes the PyTorch engine.)

(figure: sample images from the CIFAR-10 dataset)

# Procedure

## 1. Load the pre-trained ResNet50V1 model

ResNet50V1 can be found in the DJL ModelZoo. The model was trained on the ImageNet dataset and has 1,000 output categories. Since we want to retrain it for the 10 classes of CIFAR-10, we remove its last layer and add a new linear layer with 10 output channels. After modifying the block, we set it back on the model.

```java
// load the model and change the last layer
Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optProgress(new ProgressBar())
        .optArtifactId("resnet")
        .optFilter("layers", "50")
        .optFilter("flavor", "v1")
        .build();
Model model = criteria.loadModel();

SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();                             // drop the 1000-class output layer
newBlock.add(block);                                 // reuse the pre-trained feature extractor
newBlock.add(Blocks.batchFlattenBlock());            // flatten features for the linear layer
newBlock.add(Linear.builder().setUnits(10).build()); // new output layer for the 10 CIFAR-10 classes
model.setBlock(newBlock);
```
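As a quick sanity check (my addition, not one of the article's steps), printing the new block describes its layers; the output should end with the flatten block and a fresh 10-unit linear layer:

```java
// the description should end with batchFlatten and a Linear layer with 10 units
System.out.println(newBlock);
```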

## 2. Prepare the dataset

When building the dataset, you can set the sizes of the training and test sets, the batch size, and the preprocessing pipeline.
The pipeline preprocesses the data: for example, ToTensor converts a color-image NDArray with shape (32, 32, 3) and values from 0 to 255 into an NDArray with shape (3, 32, 32) and values from 0 to 1.
In addition, Normalize can normalize the input using its per-channel mean and standard deviation, as shown below.
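As an aside, here is a minimal self-contained sketch of what ToTensor does on a single image (the class name and random input are illustrative, not from the article):

```java
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

// Sketch: ToTensor turns an HWC image with values in [0, 255]
// into a CHW float tensor with values in [0, 1].
public class ToTensorDemo {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // a fake 32x32 RGB image in HWC layout
            NDArray image = manager.randomUniform(0f, 255f, new Shape(32, 32, 3));
            NDArray tensor = new ToTensor().transform(image);
            System.out.println(tensor.getShape()); // (3, 32, 32), values scaled into [0, 1]
        }
    }
}
```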

```java
int batchSize = 32;
int limit = Integer.MAX_VALUE; // change this to a small value for a dry run
// int limit = 160; // limit to 160 records in the dataset for a dry run

Pipeline pipeline = new Pipeline(
        new ToTensor(),
        new Normalize(
                new float[] {0.4914f, 0.4822f, 0.4465f},   // per-channel mean of CIFAR-10
                new float[] {0.2023f, 0.1994f, 0.2010f})); // per-channel standard deviation

Cifar10 trainDataset = Cifar10.builder()
        .setSampling(batchSize, true)
        .optUsage(Dataset.Usage.TRAIN)
        .optLimit(limit)
        .optPipeline(pipeline)
        .build();
trainDataset.prepare(new ProgressBar());
```
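The article only prepares the training split. A validation set can be built the same way with `Dataset.Usage.TEST`, reusing the pipeline defined above; a minimal sketch, not in the original steps:

```java
// same builder, but on the test split, for validation during training
Cifar10 testDataset = Cifar10.builder()
        .setSampling(batchSize, true)
        .optUsage(Dataset.Usage.TEST)
        .optLimit(limit)
        .optPipeline(pipeline)
        .build();
testDataset.prepare(new ProgressBar());
```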
## 3. Set training parameters
Since we start from a pre-trained model, only 10 epochs of training are needed.
```java
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
        // softmaxCrossEntropyLoss is a standard loss for classification problems
        .addEvaluator(new Accuracy())     // report accuracy so we can see how accurate the model is
        .optDevices(Device.getDevices(1)) // limit to one GPU; using more GPUs can actually slow convergence
        .addTrainingListeners(TrainingListener.Defaults.logging());

// Now that we have our training configuration, create a new trainer for the model
Trainer trainer = model.newTrainer(config);
```
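Optionally (an addition of mine, not in the article's steps, though the full listing below also does it), a Metrics object can be attached so the logging listener reports per-batch speed and evaluator values; this assumes an `ai.djl.metric.Metrics` import:

```java
// optional: collect per-batch metrics (speed, loss, accuracy) for the logging listener
trainer.setMetrics(new Metrics());
```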
## 4. Train the model
```java
int epoch = 10;
// (batch size, channels, height, width) — used to initialize the parameters
Shape inputShape = new Shape(1, 3, 32, 32);
trainer.initialize(inputShape);
```
```java
for (int i = 0; i < epoch; ++i) {
    for (Batch batch : trainer.iterateDataset(trainDataset)) {
        EasyTrain.trainBatch(trainer, batch); // forward and backward pass
        trainer.step();                       // update the parameters
        batch.close();                        // free the batch's native memory
    }
    // reset training and validation evaluators at the end of each epoch
    trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
```
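DJL also provides `EasyTrain.fit`, which wraps this loop and additionally runs validation after each epoch; a minimal sketch, assuming the `testDataset` sketched in step 2 (pass `null` to skip validation):

```java
// equivalent to the manual loop above, plus per-epoch validation
EasyTrain.fit(trainer, epoch, trainDataset, testDataset);
```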
## 5. Save the model
```java
Path modelDir = Paths.get("build/resnet");
Files.createDirectories(modelDir);
// record the epoch count; it becomes part of the saved parameter file name
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "resnet");
```
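`model.save` only writes the parameters, so loading them back requires rebuilding the same block structure first; a minimal sketch (not from the original article):

```java
// rebuild the modified block from step 1, attach it, then load the saved parameters
Model loaded = Model.newInstance("resnet");
loaded.setBlock(newBlock); // the same SequentialBlock structure as in step 1
loaded.load(modelDir, "resnet");
```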
# Source code

Below is the author's complete test program. Note that it is DJL's BERT rating-classification example (Amazon reviews), not the CIFAR-10 procedure above.
```java
package com.xundh;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import org.apache.commons.csv.CSVFormat;

import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {

    public static void main(String[] args)
            throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Choose which model to download based on the active deep learning engine
        // MXNet base model
        String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
        if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
            modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
        }
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrls)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();
        Predictor<NDList, NDList> embedder = embedding.newPredictor();
        Block classifier = new SequentialBlock()
                // text embedding layer
                .add(ndList -> {
                    NDArray data = ndList.singletonOrThrow();
                    NDList inputs = new NDList();
                    long batchSize = data.getShape().get(0);
                    float maxLength = data.getShape().get(1);
                    if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                        inputs.add(data.toType(DataType.INT64, false));
                        inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                        inputs.add(data.getManager().arange(maxLength)
                                .toType(DataType.INT64, false)
                                .broadcast(data.getShape()));
                    } else {
                        inputs.add(data);
                        inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                    }
                    // run embedding
                    try {
                        return embedder.predict(inputs);
                    } catch (TranslateException e) {
                        throw new IllegalArgumentException("embedding error", e);
                    }
                })
                // classification layer
                .add(Linear.builder().setUnits(768).build()) // pre classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.2f).build())
                .add(Linear.builder().setUnits(5).build()) // 5 star rating
                .addSingleton(nd -> nd.get(":,0")); // take [CLS] as the head
        Model model = Model.newInstance("AmazonReviewRatingClassification");
        model.setBlock(classifier);
        // Prepare the vocabulary
        SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();
        // Prepare the dataset
        int maxTokenLength = 64; // cutoff tokens length
        int batchSize = 8;
        // int limit = Integer.MAX_VALUE; // use the full dataset
        int limit = 512; // limit to 512 records for quick testing
        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
        // split data with 7:3 train:valid ratio
        RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
        RandomAccessDataset trainingSet = datasets[0];
        RandomAccessDataset validationSet = datasets[1];
        SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
        listener.setSaveModelCallback(trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model1 = trainer.getModel();
            // track accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model1.setProperty("Accuracy", String.format("%.5f", accuracy));
            model1.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
                .addEvaluator(new Accuracy())
                .optDevices(new Device[]{Device.cpu()}) // train on the CPU
                .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
                .addTrainingListeners(listener);
        int epoch = 2;
        Trainer trainer = model.newTrainer(config);
        trainer.setMetrics(new Metrics());
        Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
        // initialize trainer with proper input shape
        trainer.initialize(encoderInputShape);
        EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
        System.out.println(trainer.getTrainingResult());
        model.save(Paths.get("build/model"), "amazon-review.param");
        String review = "It works great, but it takes too long to update itself and slows the system";
        // BertFeaturizer and MyTranslator are helper classes from DJL's
        // rank-classification example; they are not shown in this listing
        Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
        System.out.println(predictor.predict(review));
    }

    /** Downloads the data and creates the dataset object. */
    static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
        String amazonReview =
                "https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
        float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
        return CsvDataset.builder()
                .optCsvUrl(amazonReview) // load from URL
                .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // TSV loading format
                .setSampling(batchSize, true) // sample size and random access
                .optLimit(limit)
                .addFeature(new CsvDataset.Feature("review_body", new BertFeaturizer(tokenizer, maxLength)))
                .addLabel(new CsvDataset.Feature("star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
                .optDataBatchifier(PaddingStackBatchifier.builder()
                        .optIncludeValidLengths(false)
                        .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                        .build()) // pad the dataset to a fixed length
                .build();
    }
}
```
pom.xml
```xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.12.0</djl.version>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- PyTorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.0</version>
        </dependency>
    </dependencies>
</project>
```
Copyright: author [Programming circle]. Please include the original link when reprinting, thank you. https://en.javamana.com/2022/02/202202130811493398.html