Java programmers learn deep learning with DJL, getting started part 8: transfer learning

Programming circle · 2022-02-13 08:11:57 · Views: 117

Java programmers learn deep learning

I. Introduction to transfer learning


1. Transfer learning

Transfer learning is a machine learning technique in which an existing model, together with some additional data, is adapted to a new target task instead of being trained from scratch. A well-known application is generating new artwork in a given style: in 2015, Gatys et al. published "A Neural Algorithm of Artistic Style", the first work to use deep learning to capture the style of artistic paintings.

2. BERT

BERT, short for Bidirectional Encoder Representation from Transformers, is a pre-trained language representation model that is widely used as the foundation for natural language processing tasks.

3. DistilBERT

BERT has a very large number of parameters, so it needs a lot of memory and compute at run time. DistilBERT, by contrast, is a distilled version that slims BERT down while retaining most of its accuracy.

II. Implementation process

1. Overview

Here we use Amazon's review dataset for the Digital Software product category, which contains roughly 102,000 usable reviews. The chosen pre-trained model, DistilBERT, is a lightweight BERT model pre-trained on a Wikipedia text corpus. DistilBERT is added to the classification model as the base layer, and the classifier on top outputs the predicted star rating of a review, on a scale of 1 to 5.
The review text is passed in as the data, and the star rating is used as the label.
Example Amazon review (figure omitted): each record includes, among other fields, a review_body column (the review text) and a star_rating column.

2. Prepare the dataset

The first step is to prepare the dataset. The raw data is in TSV format, so we use CsvDataset as the data container and implement the Featurizer interface to preprocess the relevant columns of each row, performing the feature extraction.

final class BertFeaturizer implements CsvDataset.Featurizer {

    private final BertFullTokenizer tokenizer;
    private final int maxLength; // the cut-off length

    public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {
        this.tokenizer = tokenizer;
        this.maxLength = maxLength;
    }

    /** {@inheritDoc} */
    @Override
    public void featurize(DynamicBuffer buf, String input) {
        SimpleVocabulary vocab = tokenizer.getVocabulary();
        // convert sentence to tokens (toLowerCase for uncased model)
        List<String> tokens = tokenizer.tokenize(input.toLowerCase());
        // truncate tokens beyond maxLength
        tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
        // BERT embedding convention "[CLS] Your Sentence [SEP]"
        buf.put(vocab.getIndex("[CLS]"));
        tokens.forEach(token -> buf.put(vocab.getIndex(token)));
        buf.put(vocab.getIndex("[SEP]"));
    }
}

For the BERT model we construct a BertFeaturizer that implements the CsvDataset.Featurizer interface to perform the feature extraction. In this example the data is only lightly cleaned: the text is lower-cased, truncated to maxLength tokens, and wrapped in the [CLS] ... [SEP] markers.
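
As a quick sanity check, the featurizer can be exercised on its own, outside of any dataset. The following is a minimal sketch, assuming a copy of the vocab.txt file from the DistilBERT artifact has been saved locally (the build/vocab.txt path is hypothetical):

SimpleVocabulary vocab = SimpleVocabulary.builder()
        .addFromTextFile(Paths.get("build/vocab.txt")) // hypothetical local copy of the vocab file
        .optUnknownToken("[UNK]")
        .build();
BertFullTokenizer tokenizer = new BertFullTokenizer(vocab, true); // true = lowercase (uncased model)
BertFeaturizer featurizer = new BertFeaturizer(tokenizer, 64);
DynamicBuffer buf = new DynamicBuffer();
featurizer.featurize(buf, "This software works great!");
System.out.println(buf.getLength()); // token count plus the [CLS] and [SEP] markers

The buffer then holds the numeric indices for [CLS], the word-piece tokens, and [SEP], ready to be consumed as one feature row.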

3. Apply the BertFeaturizer to the dataset

static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
    String amazonReview =
            "https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
    float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
    return CsvDataset.builder()
            .optCsvUrl(amazonReview) // load from the URL
            .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // TSV loading format
            .setSampling(batchSize, true) // batch size and random sampling
            .optLimit(limit)
            .addFeature(new CsvDataset.Feature("review_body", new BertFeaturizer(tokenizer, maxLength)))
            .addLabel(new CsvDataset.Feature("star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
            .optDataBatchifier(
                    PaddingStackBatchifier.builder()
                            .optIncludeValidLengths(false)
                            .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                            .build()) // define how to pad the dataset to a fixed length
            .build();
}

Here the BertFeaturizer defined above is applied to the review_body column, and the star_rating column is used as the label (shifted from 1-5 to 0-4 for the classifier). We also define how the data is padded when a tokenized review is shorter than the required length: the missing positions are filled with the [PAD] token.
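
To make the padding behaviour concrete, here is a small self-contained sketch of PaddingStackBatchifier stacking two sequences of different lengths. The token values are made up for illustration, and 0 is used as the pad value instead of the real [PAD] index (NDManager and Batchifier come from ai.djl.ndarray and ai.djl.translate):

try (NDManager manager = NDManager.newBaseManager()) {
    Batchifier batchifier = PaddingStackBatchifier.builder()
            .optIncludeValidLengths(false)
            .addPad(0, 0, (m) -> m.zeros(new Shape(1))) // pad with 0 in this sketch
            .build();
    NDList shorter = new NDList(manager.create(new float[] {101, 2023, 102}));
    NDList longer = new NDList(manager.create(new float[] {101, 2023, 4031, 2573, 102}));
    NDList batch = batchifier.batchify(new NDList[] {shorter, longer});
    System.out.println(batch.head().getShape()); // (2, 5): the shorter row is padded to length 5
}

In the dataset above the same mechanism runs on every batch, with the [PAD] token index as the fill value.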

4. Build a model

First download the DistilBERT model together with its pre-trained weights. The downloaded model does not contain a classification layer, so we need to append one to the end of the network and then train it. After assembling the block, criteria.loadModel() fetches the embedding model and model.setBlock() installs the classifier into the model.

4.1 Load the model

// MXNet base model
String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
    modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
}
Criteria<NDList, NDList> criteria = Criteria.builder()
        .optApplication(Application.NLP.WORD_EMBEDDING)
        .setTypes(NDList.class, NDList.class)
        .optModelUrls(modelUrls)
        .optProgress(new ProgressBar())
        .build();
ZooModel<NDList, NDList> embedding = criteria.loadModel();

4.2 Create the classification layer

Here we create a simple MLP to classify the review rating: the final fully connected layer outputs 5 values, one for each of the 5 star levels.
A text-embedding step sits at the front of the block and feeds the embedded text into the classifier layers.
The assembled block is then installed into the model.

Predictor<NDList, NDList> embedder = embedding.newPredictor();
Block classifier = new SequentialBlock()
        // text embedding layer
        .add(
                ndList -> {
                    NDArray data = ndList.singletonOrThrow();
                    NDList inputs = new NDList();
                    long batchSize = data.getShape().get(0);
                    float maxLength = data.getShape().get(1);
                    if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                        // the traced PyTorch model expects INT64 token ids, an attention mask, and position ids
                        inputs.add(data.toType(DataType.INT64, false));
                        inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                        inputs.add(data.getManager().arange(maxLength)
                                .toType(DataType.INT64, false)
                                .broadcast(data.getShape()));
                    } else {
                        // the MXNet model expects token ids plus the valid sequence lengths
                        inputs.add(data);
                        inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                    }
                    // run embedding
                    try {
                        return embedder.predict(inputs);
                    } catch (TranslateException e) {
                        throw new IllegalArgumentException("embedding error", e);
                    }
                })
        // classification layer
        .add(Linear.builder().setUnits(768).build()) // pre-classifier
        .add(Activation::relu) // activation function
        .add(Dropout.builder().optRate(0.2f).build())
        .add(Linear.builder().setUnits(5).build()) // 5 star ratings
        .addSingleton(nd -> nd.get(":,0")); // take [CLS] as the head
Model model = Model.newInstance("AmazonReviewRatingClassification");
model.setBlock(classifier);

5. Start training

5.1 Create the training and validation sets

First, build a vocabulary that maps words to numeric indices. Then feed the vocabulary to the tokenizer used by the feature extractor.
Finally, split the dataset into a training set and a validation set in a 7:3 ratio.

The maximum token length is set to 64, which means that only the first 64 word-piece tokens of each review are used as features.

// Prepare the vocabulary
SimpleVocabulary vocabulary = SimpleVocabulary.builder()
        .optMinFrequency(1)
        .addFromTextFile(embedding.getArtifact("vocab.txt"))
        .optUnknownToken("[UNK]")
        .build();
// Prepare the dataset
int maxTokenLength = 64; // cutoff token length
int batchSize = 8;
int limit = Integer.MAX_VALUE;
// int limit = 512; // uncomment for quick testing
BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
// split the data with a 7:3 train:validation ratio
RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
RandomAccessDataset trainingSet = datasets[0];
RandomAccessDataset validationSet = datasets[1];

5.2 Create a training listener to track progress

Here we configure the evaluator (accuracy) and the loss function (softmax cross entropy). The training logs and model checkpoints are saved under build/model.

SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
listener.setSaveModelCallback(
        trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model = trainer.getModel();
            // track accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model.setProperty("Accuracy", String.format("%.5f", accuracy));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
        .addEvaluator(new Accuracy())
        .optDevices(Device.getDevices(1)) // train using a single GPU (or the CPU if none is available)
        .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
        .addTrainingListeners(listener);

5.3 Train

int epoch = 2;
Trainer trainer = model.newTrainer(config);
trainer.setMetrics(new Metrics());
Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
// initialize trainer with proper input shape
trainer.initialize(encoderInputShape);
EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
System.out.println(trainer.getTrainingResult());

5.4 Save the model

model.save(Paths.get("build/model"), "amazon-review.param");
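
To reuse the trained parameters later without retraining, the saved file can be loaded back into a model. A minimal sketch, under the assumption that the same classifier block is reconstructed and set first:

// sketch: reload the saved parameters (the classifier Block must be rebuilt and set first)
Model loaded = Model.newInstance("AmazonReviewRatingClassification");
loaded.setBlock(classifier);
loaded.load(Paths.get("build/model"), "amazon-review.param");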

5.5 Validate the model

Create a predictor from the model, then run predictions on your own review text to check how well the model performs. A custom Translator converts the input string into token indices on the way in, and the output values into a Classifications object on the way out.


class MyTranslator implements Translator<String, Classifications> {

    private final BertFullTokenizer tokenizer;
    private final SimpleVocabulary vocab;
    private final List<String> ranks;

    public MyTranslator(BertFullTokenizer tokenizer) {
        this.tokenizer = tokenizer;
        vocab = tokenizer.getVocabulary();
        ranks = Arrays.asList("1", "2", "3", "4", "5");
    }

    @Override
    public Batchifier getBatchifier() {
        return new StackBatchifier();
    }

    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        List<String> tokens = tokenizer.tokenize(input);
        float[] indices = new float[tokens.size() + 2];
        indices[0] = vocab.getIndex("[CLS]");
        for (int i = 0; i < tokens.size(); i++) {
            indices[i + 1] = vocab.getIndex(tokens.get(i));
        }
        indices[indices.length - 1] = vocab.getIndex("[SEP]");
        return new NDList(ctx.getNDManager().create(indices));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        return new Classifications(ranks, list.singletonOrThrow().softmax(0));
    }
}

Create a predictor and classify a sample review:

String review = "It works great, but it takes too long to update itself and slows the system";
Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
System.out.println(predictor.predict(review));
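
The printed Classifications lists a probability for every star level. When only the top prediction matters, Classifications exposes a best() accessor; for example:

Classifications result = predictor.predict(review);
// keep only the highest-probability star rating
System.out.println("Predicted rating: " + result.best().getClassName());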

III. Full source code


PyTorchLearn

package com.xundh;
import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import org.apache.commons.csv.CSVFormat;
import java.io.IOException;
import java.nio.file.Paths;
public class PyTorchLearn {

    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // choose which model to download based on the active deep learning engine
        // MXNet base model
        String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
        if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
            modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
        }
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrls)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();
        Predictor<NDList, NDList> embedder = embedding.newPredictor();
        Block classifier = new SequentialBlock()
                // text embedding layer
                .add(
                        ndList -> {
                            NDArray data = ndList.singletonOrThrow();
                            NDList inputs = new NDList();
                            long batchSize = data.getShape().get(0);
                            float maxLength = data.getShape().get(1);
                            if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                                inputs.add(data.toType(DataType.INT64, false));
                                inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                                inputs.add(data.getManager().arange(maxLength)
                                        .toType(DataType.INT64, false)
                                        .broadcast(data.getShape()));
                            } else {
                                inputs.add(data);
                                inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                            }
                            // run embedding
                            try {
                                return embedder.predict(inputs);
                            } catch (TranslateException e) {
                                throw new IllegalArgumentException("embedding error", e);
                            }
                        })
                // classification layer
                .add(Linear.builder().setUnits(768).build()) // pre-classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.2f).build())
                .add(Linear.builder().setUnits(5).build()) // 5 star ratings
                .addSingleton(nd -> nd.get(":,0")); // take [CLS] as the head
        Model model = Model.newInstance("AmazonReviewRatingClassification");
        model.setBlock(classifier);
        // Prepare the vocabulary
        SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();
        // Prepare the dataset
        int maxTokenLength = 64; // cutoff token length
        int batchSize = 8;
        int limit = Integer.MAX_VALUE;
        // int limit = 512; // uncomment for quick testing
        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
        // split the data with a 7:3 train:validation ratio
        RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
        RandomAccessDataset trainingSet = datasets[0];
        RandomAccessDataset validationSet = datasets[1];
        SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model1 = trainer.getModel();
                    // track accuracy and loss
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model1.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model1.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
                .addEvaluator(new Accuracy())
                .optDevices(new Device[] {Device.cpu()}) // train on the CPU
                .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
                .addTrainingListeners(listener);
        int epoch = 2;
        Trainer trainer = model.newTrainer(config);
        trainer.setMetrics(new Metrics());
        Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
        // initialize trainer with the proper input shape
        trainer.initialize(encoderInputShape);
        EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
        System.out.println(trainer.getTrainingResult());
        model.save(Paths.get("build/model"), "amazon-review.param");
    }

    /** Downloads the review TSV and creates the dataset object. */
    static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
        String amazonReview =
                "https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
        float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
        return CsvDataset.builder()
                .optCsvUrl(amazonReview) // load from the URL
                .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // TSV loading format
                .setSampling(batchSize, true) // batch size and random sampling
                .optLimit(limit)
                .addFeature(
                        new CsvDataset.Feature(
                                "review_body", new BertFeaturizer(tokenizer, maxLength)))
                .addLabel(
                        new CsvDataset.Feature(
                                "star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
                .optDataBatchifier(
                        PaddingStackBatchifier.builder()
                                .optIncludeValidLengths(false)
                                .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                                .build()) // define how to pad the dataset to a fixed length
                .build();
    }
}

BertFeaturizer

package com.xundh;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.basicdataset.utils.DynamicBuffer;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import java.util.List;
final class BertFeaturizer implements CsvDataset.Featurizer {

    private final BertFullTokenizer tokenizer;
    private final int maxLength; // the cut-off length

    public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {
        this.tokenizer = tokenizer;
        this.maxLength = maxLength;
    }

    /** {@inheritDoc} */
    @Override
    public void featurize(DynamicBuffer buf, String input) {
        SimpleVocabulary vocab = tokenizer.getVocabulary();
        // convert sentence to tokens (toLowerCase for uncased model)
        List<String> tokens = tokenizer.tokenize(input.toLowerCase());
        // trim the tokens to maxLength
        tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
        // BERT embedding convention "[CLS] Your Sentence [SEP]"
        buf.put(vocab.getIndex("[CLS]"));
        tokens.forEach(token -> buf.put(vocab.getIndex(token)));
        buf.put(vocab.getIndex("[SEP]"));
    }
}

Running the program trains the model and prints the final training result (screenshot omitted).

Copyright: author [Programming circle]. Please keep the original link when reprinting: https://en.javamana.com/2022/02/202202130811537950.html