Java programmers learn deep learning with DJL, part 9: transfer learning on the CIFAR-10 dataset

Programming circle · 2022-02-13 08:11:51 · Views: 555

I. Overview

This article uses transfer learning to train an image classification model. As mentioned in earlier installments, transfer learning means taking a model trained for one problem and applying it to a second problem. Compared with training a model for the specific problem from scratch, transfer learning reduces the number of features that must be learned and produces a usable model in less time.

This article uses the CIFAR-10 dataset to train our own model. The dataset contains 60,000 32×32 color images in 10 classes.

The pre-trained model used here is ResNet50v1, a 50-layer deep learning model trained on ImageNet, which covers over 1.2 million images in 1,000 classes. This article replaces the ImageNet classification head and retrains the model for the 10 classes of CIFAR-10.

Note: the experiment in this post did not succeed; the predefined model failed to load.

[Figure: sample images from the CIFAR-10 dataset]

II. Procedure

1. Load the pre-trained ResNet50V1 model

ResNet50V1 can be found in the DJL ModelZoo. The model was trained on the ImageNet dataset and has 1,000 output categories. Because we are retargeting it to the 10 classes of CIFAR-10, we delete the last layer and add a new linear layer with 10 output channels. After modifying the block, we put it back into the model.

// load model and change last layer
Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optProgress(new ProgressBar())
        .optArtifactId("resnet")
        .optFilter("layers", "50")
        .optFilter("flavor", "v1")
        .build();
Model model = criteria.loadModel();
// remove the 1000-class ImageNet head and add a 10-class linear layer
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
newBlock.add(block);
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
model.setBlock(newBlock);

2. Prepare the dataset

When building the dataset, you can set the sizes of the training and test sets, the batch size, and the preprocessing pipeline.
The pipeline preprocesses the data. For example, ToTensor converts a color image NDArray of shape (32, 32, 3) with values from 0 to 255 into one of shape (3, 32, 32) with values from 0 to 1.
In addition, Normalize standardizes the input using the per-channel mean and standard deviation of the dataset.
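As a sanity check, the two transforms can be reproduced for a single channel value in plain Java (the pixel value here is illustrative; the mean and standard deviation are the red-channel constants used below):

```java
public class NormalizeDemo {
    public static void main(String[] args) {
        // ToTensor: scale a 0-255 channel value into [0, 1]
        float red = 128f / 255f;
        // Normalize: subtract the per-channel mean, divide by the standard deviation
        float mean = 0.4914f;
        float std = 0.2023f;
        float normalized = (red - mean) / std;
        System.out.printf("%.4f -> %.4f%n", red, normalized); // prints 0.5020 -> 0.0522
    }
}
```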

int batchSize = 32;
int limit = Integer.MAX_VALUE; // change this to a small value for a dry run
// int limit = 160; // limit to 160 records in the dataset for a dry run
Pipeline pipeline = new Pipeline(
        new ToTensor(),
        new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f},
                      new float[] {0.2023f, 0.1994f, 0.2010f}));
Cifar10 trainDataset = Cifar10.builder()
        .setSampling(batchSize, true)
        .optUsage(Dataset.Usage.TRAIN)
        .optLimit(limit)
        .optPipeline(pipeline)
        .build();
trainDataset.prepare(new ProgressBar());
3. Set training parameters

Because we start from a pre-trained model, 10 epochs are enough.

DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
        // softmaxCrossEntropyLoss is a standard loss for classification problems
        .addEvaluator(new Accuracy()) // use accuracy so we humans can understand how accurate the model is
        .optDevices(Device.getDevices(1)) // limit to one GPU; using more GPUs can actually slow down convergence
        .addTrainingListeners(TrainingListener.Defaults.logging());
// Now that we have our training configuration, we create a new trainer for our model
Trainer trainer = model.newTrainer(config);
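softmaxCrossEntropyLoss combines a softmax over the logits with the negative log-likelihood of the true class. A minimal, numerically stable plain-Java version for a single example (independent of DJL, with illustrative logits):

```java
public class SoftmaxCrossEntropyDemo {
    // Softmax cross-entropy for one sample: -log(softmax(logits)[label])
    static double loss(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) {
            max = Math.max(max, l); // subtract the max for numerical stability
        }
        double sumExp = 0;
        for (double l : logits) {
            sumExp += Math.exp(l - max);
        }
        return -(logits[label] - max - Math.log(sumExp));
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 1.0, 0.1};
        // label 0 has the largest logit, so the loss is small
        System.out.printf("%.4f%n", loss(logits, 0)); // prints 0.4170
    }
}
```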
4. Train the model

int epoch = 10;
Shape inputShape = new Shape(1, 3, 32, 32);
trainer.initialize(inputShape);
for (int i = 0; i < epoch; ++i) {
    for (Batch batch : trainer.iterateDataset(trainDataset)) {
        EasyTrain.trainBatch(trainer, batch);
        trainer.step();
        batch.close();
    }
    // reset training and validation evaluators at end of epoch
    trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
5. Save the model

Path modelDir = Paths.get("build/resnet");
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "resnet");
Source code
package com.xundh;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import org.apache.commons.csv.CSVFormat;
import java.io.IOException;
import java.nio.file.Paths;

// Note: BertFeaturizer and MyTranslator are helper classes from the DJL BERT demo; they are not shown in the original post.
public class PyTorchLearn {

    public static void main(String[] args)
            throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Select the model to download according to the deep learning engine in use.
        // The model URLs were left blank in the original post.
        String modelUrls = ""; // MXNet base model
        if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
            modelUrls = "";
        }
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrls)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();
        Predictor<NDList, NDList> embedder = embedding.newPredictor();

        Block classifier = new SequentialBlock()
                // text embedding layer
                .add(ndList -> {
                    NDArray data = ndList.singletonOrThrow();
                    NDList inputs = new NDList();
                    long batchSize = data.getShape().get(0);
                    float maxLength = data.getShape().get(1);
                    if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                        inputs.add(data.toType(DataType.INT64, false));
                        inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                        inputs.add(data.getManager().arange(maxLength)
                                .toType(DataType.INT64, false)
                                .broadcast(data.getShape()));
                    } else {
                        inputs.add(data);
                        inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                    }
                    // run embedding
                    try {
                        return embedder.predict(inputs);
                    } catch (TranslateException e) {
                        throw new IllegalArgumentException("embedding error", e);
                    }
                })
                // classification layer
                .add(Linear.builder().setUnits(768).build()) // pre classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.2f).build())
                .add(Linear.builder().setUnits(5).build()) // 5 star rating
                .addSingleton(nd -> nd.get(":,0")); // take [CLS] as the head
        Model model = Model.newInstance("AmazonReviewRatingClassification");
        model.setBlock(classifier);

        // Prepare the vocabulary (vocab file name follows the DJL demo)
        SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();

        // Prepare dataset
        int maxTokenLength = 64; // cutoff tokens length
        int batchSize = 8;
        // int limit = Integer.MAX_VALUE;
        int limit = 512; // use a small limit for quick testing
        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
        // split data with 7:3 train:valid ratio
        RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
        RandomAccessDataset trainingSet = datasets[0];
        RandomAccessDataset validationSet = datasets[1];

        SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
        listener.setSaveModelCallback(trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model1 = trainer.getModel();
            // track accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model1.setProperty("Accuracy", String.format("%.5f", accuracy));
            model1.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
                .addEvaluator(new Accuracy())
                .optDevices(new Device[] {Device.cpu()}) // train on a single device (CPU here)
                .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
                .addTrainingListeners(listener);

        int epoch = 2;
        Trainer trainer = model.newTrainer(config);
        trainer.setMetrics(new Metrics());
        Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
        // initialize trainer with proper input shape
        trainer.initialize(encoderInputShape);
        EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
        System.out.println(trainer.getTrainingResult());
        model.save(Paths.get("build/model"), "amazon-review.param");

        String review = "It works great, but it takes too long to update itself and slows the system";
        Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
        System.out.println(predictor.predict(review));
    }

    /** Download and create dataset objects. */
    static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
        String amazonReview = ""; // dataset URL was left blank in the original post
        float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
        return CsvDataset.builder()
                .optCsvUrl(amazonReview) // load from URL
                .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // TSV loading format
                .setSampling(batchSize, true) // sample size and random access
                .optLimit(limit)
                .addFeature(new CsvDataset.Feature("review_body", new BertFeaturizer(tokenizer, maxLength)))
                .addLabel(new CsvDataset.Feature("star_rating",
                        (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
                .optDataBatchifier(PaddingStackBatchifier.builder()
                        .optIncludeValidLengths(false)
                        .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                        .build()) // define how to pad the dataset to a fixed length
                .build();
    }
}
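addPad fills each tokenized review up to the batch's fixed length with the [PAD] token index. The mechanics can be sketched in plain Java; the pad helper and the token values here are hypothetical, not DJL API:

```java
import java.util.ArrayList;
import java.util.List;

public class PadDemo {
    // Pad a token-id sequence to maxLength with padToken; truncate if it is too long
    static List<Long> pad(List<Long> tokens, int maxLength, long padToken) {
        List<Long> out = new ArrayList<>(tokens);
        while (out.size() < maxLength) {
            out.add(padToken);
        }
        return out.subList(0, maxLength);
    }

    public static void main(String[] args) {
        // e.g. token ids for a short review, padded to length 6 with pad id 0
        List<Long> padded = pad(List.of(5L, 9L, 2L), 6, 0L);
        System.out.println(padded); // prints [5, 9, 2, 0, 0, 0]
    }
}
```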
The project's pom.xml pulls in the DJL PyTorch engine; a typical dependency block (the version number is illustrative) looks like:

<!-- Pytorch -->
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.16.0</version>
</dependency>
Copyright: author [Programming circle]. Please include a link to the original when reprinting. Thank you.