Java Programmers Learn Deep Learning: Getting Started with DJL, Part 7: Using the PyTorch Engine

Programming circle 2022-02-13 08:12:03


This article explains how to call the PyTorch engine from DJL and how to work with PyTorch models. Because DJL only supports the TorchScript format, your own PyTorch model has to be converted first. The earlier part of this article covers how to do the conversion; the later part demonstrates loading an already converted TorchScript model from the network.

One. Referencing the PyTorch engine in a DJL Maven project

1. Add the pytorch-engine dependency

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.13.0-SNAPSHOT</version>
    <scope>runtime</scope>
</dependency>

2. Add the pytorch-native-auto library

Currently, each version of pytorch-engine works with only one PyTorch version. The correspondence is as follows:

PyTorch engine version    PyTorch native library version
pytorch-engine:0.13.0     pytorch-native-auto:1.9.0
pytorch-engine:0.12.0     pytorch-native-auto:1.8.1
pytorch-engine:0.11.0     pytorch-native-auto:1.8.1
pytorch-engine:0.10.0     pytorch-native-auto:1.7.1
pytorch-engine:0.9.0      pytorch-native-auto:1.7.0
pytorch-engine:0.8.0      pytorch-native-auto:1.6.0
pytorch-engine:0.7.0      pytorch-native-auto:1.6.0
pytorch-engine:0.6.0      pytorch-native-auto:1.5.0
pytorch-engine:0.5.0      pytorch-native-auto:1.4.0
pytorch-engine:0.4.0      pytorch-native-auto:1.4.0
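
The table can also be kept as a small lookup, which is handy when a build script or a startup check needs to confirm that the two versions match. The following is only a sketch: the class NativeVersionLookup and its method are hypothetical helpers, not part of DJL, and treating a -SNAPSHOT engine version like its release version is an assumption based on the pom.xml in this article, which pairs 0.13.0-SNAPSHOT with 1.9.0.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Illustrative lookup for the version table; hypothetical helper, not part of DJL. */
class NativeVersionLookup {

    private static final Map<String, String> ENGINE_TO_NATIVE = new LinkedHashMap<>();

    static {
        ENGINE_TO_NATIVE.put("0.13.0", "1.9.0");
        ENGINE_TO_NATIVE.put("0.12.0", "1.8.1");
        ENGINE_TO_NATIVE.put("0.11.0", "1.8.1");
        ENGINE_TO_NATIVE.put("0.10.0", "1.7.1");
        ENGINE_TO_NATIVE.put("0.9.0", "1.7.0");
        ENGINE_TO_NATIVE.put("0.8.0", "1.6.0");
        ENGINE_TO_NATIVE.put("0.7.0", "1.6.0");
        ENGINE_TO_NATIVE.put("0.6.0", "1.5.0");
        ENGINE_TO_NATIVE.put("0.5.0", "1.4.0");
        ENGINE_TO_NATIVE.put("0.4.0", "1.4.0");
    }

    /** Returns the pytorch-native-auto version for an engine version, or null if it is not in the table. */
    static String nativeVersionFor(String engineVersion) {
        // Assumption: "0.13.0-SNAPSHOT" pairs with the same native version as "0.13.0".
        String release = engineVersion.replace("-SNAPSHOT", "");
        return ENGINE_TO_NATIVE.get(release);
    }

    public static void main(String[] args) {
        System.out.println(nativeVersionFor("0.13.0-SNAPSHOT")); // prints 1.9.0
    }
}
```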

Example:

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.9.0</version>
    <scope>runtime</scope>
</dependency>

Which native library is needed also depends on the CPU, the operating-system architecture, and the GPU, but pytorch-native-auto selects the matching build automatically.
If the automatic selection does not work, look up the library required for your platform at http://docs.djl.ai/engines/pytorch/pytorch-engine/index.html and configure it manually.
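
When choosing a library manually, it helps to know what platform the JVM reports. The sketch below builds an "os-arch" string from the standard os.name and os.arch system properties. The class PlatformInfo is a hypothetical helper, and the exact naming scheme (for example "linux-x86_64") is an assumption modeled on common Maven classifier names; check the documentation page above for the names DJL actually uses.

```java
import java.util.Locale;

/** Hypothetical helper that summarizes the current platform; not part of DJL. */
class PlatformInfo {

    /** Builds a string such as "linux-x86_64" from the values of os.name and os.arch. */
    static String classifier(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String prefix;
        if (os.startsWith("windows")) {
            prefix = "win";
        } else if (os.startsWith("mac")) {
            prefix = "osx";
        } else if (os.startsWith("linux")) {
            prefix = "linux";
        } else {
            prefix = os.replace(' ', '_'); // unknown OS: fall back to the raw name
        }
        String arch = osArch.toLowerCase(Locale.ROOT);
        if (arch.equals("amd64")) {
            arch = "x86_64"; // the JVM reports amd64 on many 64-bit Intel/AMD systems
        }
        return prefix + "-" + arch;
    }

    public static void main(String[] args) {
        System.out.println(classifier(System.getProperty("os.name"), System.getProperty("os.arch")));
    }
}
```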

Two. The PyTorch Model Zoo: a library of pre-trained models


<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.13.0-SNAPSHOT</version>
</dependency>

The pre-trained models in the Model Zoo are mainly computer vision models, including:

  • Image classification
  • Object detection
  • Style transfer
  • Image generation

and others.

Three. Converting the PyTorch model format

The PyTorch model has to be converted to the TorchScript format. There are two main ways to convert it: tracing and scripting.
An example tracing script:

import torch
import torchvision

# Use your own model here
model = torchvision.models.resnet18(pretrained=True)
# Switch to evaluation mode
model.eval()
# An example input of the kind you would normally pass to the model's forward() method
example = torch.rand(1, 3, 224, 224)
# Run the trace:
# torch.jit.trace generates a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")

Four. Loading the PyTorch model

1. Prepare the model

The following example assumes that you already have a model in TorchScript format. Here we use a pre-trained resnet18 model.
DownloadUtils downloads the model from the network; the files are placed under the build/pytorch_models folder.

DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| resnet18.pt

The resnet18 model comes with a label file, synset.txt, which is also downloaded with DownloadUtils.

DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| synset.txt
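
Judging by the class names printed in the results later in this article (for example "n02111889 Samoyed, Samoyede"), each line of synset.txt holds an identifier followed by a readable name. The sketch below shows one way such lines could be split. The class SynsetLabels is a hypothetical helper, and the assumption that line i is the label of output class i should be checked against the model you use; DJL's translator reads the file itself, so this is for illustration only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper for reading label lines; DJL's translator does this itself. */
class SynsetLabels {

    /** Splits "n02111889 Samoyed, Samoyede" into {"n02111889", "Samoyed, Samoyede"}. */
    static String[] parseLine(String line) {
        String trimmed = line.trim();
        int space = trimmed.indexOf(' ');
        if (space < 0) {
            return new String[] {trimmed, ""}; // identifier only, no readable name
        }
        return new String[] {trimmed.substring(0, space), trimmed.substring(space + 1).trim()};
    }

    /** Returns the readable names in file order, skipping blank lines. */
    static List<String> names(List<String> lines) {
        List<String> result = new ArrayList<>();
        for (String line : lines) {
            if (!line.trim().isEmpty()) {
                result.add(parseLine(line)[1]);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> sample = Arrays.asList(
                "n02111889 Samoyed, Samoyede",
                "n02111500 Great Pyrenees");
        System.out.println(names(sample));
    }
}
```

To run it against the downloaded file, the sample list could be replaced with Files.readAllLines(Paths.get("build/pytorch_models/resnet18/synset.txt")).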

2. Create a Translator

First, define the preprocessing pipeline that is applied to each image. In PyTorch (Python) it is written as:

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Then create the equivalent Translator in DJL:

Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
        .addTransform(new Resize(256))
        .addTransform(new CenterCrop(224, 224))
        .addTransform(new ToTensor())
        .addTransform(new Normalize(
                new float[] {0.485f, 0.456f, 0.406f},
                new float[] {0.229f, 0.224f, 0.225f}))
        .optApplySoftmax(true)
        .build();
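
To make the numbers in the pipeline concrete, the sketch below reproduces the usual arithmetic behind the last two transforms for a single RGB pixel: scaling 8-bit values to the range [0, 1] (part of what ToTensor does, alongside reordering the image into channel-first layout) and then applying (value - mean) / std per channel with the constants above. It also shows the usual offset calculation for a centered crop. NormalizeSketch is a hypothetical class written in plain Java; it does not call DJL and is not DJL's implementation.

```java
/** Plain-Java illustration of the preprocessing arithmetic; hypothetical, not DJL's implementation. */
class NormalizeSketch {

    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    /** Converts one 8-bit RGB pixel to three normalized floats: (v / 255 - mean) / std. */
    static float[] normalizePixel(int r, int g, int b) {
        int[] rgb = {r, g, b};
        float[] out = new float[3];
        for (int c = 0; c < 3; c++) {
            float scaled = rgb[c] / 255f;          // scale [0, 255] to [0, 1]
            out[c] = (scaled - MEAN[c]) / STD[c];  // normalize each channel separately
        }
        return out;
    }

    /** Offset of a centered crop window along one axis, e.g. 256 to 224 gives 16. */
    static int centerCropOffset(int size, int cropSize) {
        return (size - cropSize) / 2;
    }

    public static void main(String[] args) {
        float[] white = normalizePixel(255, 255, 255);
        System.out.println(white[0] + ", " + white[1] + ", " + white[2]);
        System.out.println(centerCropOffset(256, 224)); // prints 16
    }
}
```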

3. Load your own model

Loading the model requires a few parameters; for example, optModelPath specifies where the model is located.

Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optModelPath(Paths.get("build/pytorch_models/resnet18"))
        .optTranslator(translator)
        .optProgress(new ProgressBar())
        .build();
ZooModel<Image, Classifications> model = criteria.loadModel();
Loading: 100% |████████████████████████████████████████|

4. Load the image to classify

var img = ImageFactory.getInstance().fromUrl("https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg");
img.getWrappedImage()

5. Run inference

Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications classifications = predictor.predict(img);

Print the results:

classifications
[
class: "n02111889 Samoyed, Samoyede", probability: 0.94256
class: "n02114548 white wolf, Arctic wolf, Canis lupus tundrarum", probability: 0.02820
class: "n02111500 Great Pyrenees", probability: 0.01032
class: "n02120079 Arctic fox, white fox, Alopex lagopus", probability: 0.00412
class: "n02109961 Eskimo dog, husky", probability: 0.00279
]
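
The probabilities in this output come from the softmax step enabled with optApplySoftmax(true), and the printout lists the five highest-scoring classes. As a rough illustration of that post-processing, the sketch below implements a numerically stable softmax and a top-k selection in plain Java. SoftmaxSketch is a hypothetical class and the input values are made up; it is not DJL's implementation.

```java
import java.util.Arrays;

/** Plain-Java illustration of softmax and top-k; hypothetical, not DJL's implementation. */
class SoftmaxSketch {

    /** Converts raw scores to probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max); // subtracting the max avoids overflow
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Returns the indices of the k largest values, highest first. */
    static int[] topK(double[] values, int k) {
        Integer[] idx = new Integer[values.length];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        Arrays.sort(idx, (a, b) -> Double.compare(values[b], values[a]));
        int n = Math.min(k, idx.length);
        int[] top = new int[n];
        for (int i = 0; i < n; i++) {
            top[i] = idx[i];
        }
        return top;
    }

    public static void main(String[] args) {
        double[] probabilities = softmax(new double[] {2.0, 1.0, 0.1}); // made-up scores
        System.out.println(Arrays.toString(probabilities));
        System.out.println(Arrays.toString(topK(probabilities, 2)));
    }
}
```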

Five. Source code

1. pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>
    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.13.0-SNAPSHOT</djl.version>
    </properties>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- PyTorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.0</version>
        </dependency>
    </dependencies>
</project>

2. Java

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {

    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Download the TorchScript model and its label file
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

        // Preprocessing pipeline that mirrors the PyTorch example
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .addTransform(new Resize(256))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                .optApplySoftmax(true)
                .build();

        // Describe what to load: input/output types, model location, translator
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("build/pytorch_models/resnet18"))
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<Image, Classifications> model = criteria.loadModel();

        // Load the image to classify
        Image img = ImageFactory.getInstance().fromUrl("https://img-blog.csdnimg.cn/4c1c40b41c6a49afa69f7ccf96e24ddf.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBA57yW56iL5ZyI5a2Q,size_20,color_FFFFFF,t_70,g_se,x_16#pic_center");
        // Returns the underlying image object (a notebook would display it; the result is unused here)
        img.getWrappedImage();

        // Run inference and print the result
        Predictor<Image, Classifications> predictor = model.newPredictor();
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);
    }
}

Six. Loading a local model

package com.xundh;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

public class PyTorchLearn {

    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Download the TorchScript model and its label file into a local directory
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

        // Load the model named "resnet18" from the local directory
        Path modelDir = Paths.get("build/pytorch_models/resnet18");
        Model model = Model.newInstance("resnet");
        model.load(modelDir, "resnet18");

        // Build the preprocessing pipeline and the translator
        Pipeline pipeline = new Pipeline();
        pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optSynsetArtifactName("synset.txt")
                .optApplySoftmax(true)
                .build();

        // Load the image to classify
        Image img = ImageFactory.getInstance().fromUrl("https://img-blog.csdnimg.cn/4c1c40b41c6a49afa69f7ccf96e24ddf.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBA57yW56iL5ZyI5a2Q,size_20,color_FFFFFF,t_70,g_se,x_16#pic_center");
        // Returns the underlying image object (a notebook would display it; the result is unused here)
        img.getWrappedImage();

        // Run inference with the translator and print the result
        Predictor<Image, Classifications> predictor = model.newPredictor(translator);
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);
    }
}

Seven. Model optimization suggestions

See:
https://github.com/deepjavalibrary/djl/blob/master/docs/pytorch/how_to_optimize_inference_performance.md
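
As one example of the kind of tuning that guide covers, the sketch below sets two thread-related system properties before any DJL class is used. The property names ai.djl.pytorch.num_interop_threads and ai.djl.pytorch.num_threads are assumed here to be the ones the linked guide describes; confirm them there before relying on this. ThreadConfigSketch is a hypothetical helper, and the values are illustrative rather than recommendations.

```java
/** Hypothetical helper that sets thread-related properties; verify the names against the linked guide. */
class ThreadConfigSketch {

    // Assumed property names, to be checked against the DJL optimization guide.
    static final String INTEROP_THREADS = "ai.djl.pytorch.num_interop_threads";
    static final String INTRAOP_THREADS = "ai.djl.pytorch.num_threads";

    /** Sets both properties; call this before the PyTorch engine is first loaded. */
    static void configure(int interOpThreads, int intraOpThreads) {
        System.setProperty(INTEROP_THREADS, String.valueOf(interOpThreads));
        System.setProperty(INTRAOP_THREADS, String.valueOf(intraOpThreads));
    }

    public static void main(String[] args) {
        // Illustrative values only: one inter-op thread, one intra-op thread per available core
        configure(1, Runtime.getRuntime().availableProcessors());
        System.out.println(System.getProperty(INTEROP_THREADS));
        System.out.println(System.getProperty(INTRAOP_THREADS));
    }
}
```

The same settings can be passed on the command line with -D instead of being set in code.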

copyright:author[Programming circle],Please bring the original link to reprint, thank you. https://en.javamana.com/2022/02/202202130812007803.html