Deploy a deep learning model on Apache Spark with Scala in 10 minutes with a few easy steps


Summary

This article uses an image classification model as an example to show how to deploy TensorFlow, PyTorch, and MXNet models in a big data production environment by using DJL on Apache Spark.


Deep learning is widely used in the field of big data, but there are few deployment options for Java and Scala. Amazon's open source project team has found another way to help users deploy deep learning applications on Spark: DJL. In 10 minutes, you can deploy TensorFlow, PyTorch, and MXNet models in a big data production environment.

Apache Spark is an excellent big data processing tool. In machine learning, Spark can be used to classify data, predict demand, and make personalized recommendations. Although Spark supports multiple languages, most Spark jobs are written and deployed in Scala. However, Scala lacks a good platform for deep learning. Most deep learning applications are deployed with Python and related frameworks, which leaves Scala developers with a difficult choice: either write the whole Spark architecture in Python, or wrap Python code in a Scala pipeline. Both options increase workload and maintenance costs. Moreover, PySpark's multi-process execution of deep learning workloads currently does not perform as well as Scala's multithreading, and many deep learning applications are bottlenecked at this point.

Today, we'll show you a new solution that calls the Deep Java Library (DJL) directly from Scala to deploy deep learning applications. DJL makes full use of Spark's powerful multithreaded processing and can easily speed up existing inference tasks by 2-5 times*. DJL is a Java deep learning library tailored for Spark. It is engine-agnostic, so users can easily deploy PyTorch, TensorFlow, and MXNet models on Spark. In this blog, we use DJL to deploy an image classification model. You can also refer to the complete code here.

Image classification: DJL + Spark

We will deploy an inference task using ResNet50, a pre-trained image classification model. To simplify the configuration, we set up only a single local cluster with multiple virtual worker nodes for inference. The general workflow is as follows:


Spark launches multiple executors, each running in its own JVM process, and sends every processing task to an executor for execution. Each executor has its own independently allocated cores and memory, and task execution within it is fully multithreaded. In big data processing, this architecture helps each worker receive a reasonable share of the data.

Step 1: Build a Spark project

With sbt, we can easily build Scala projects. For more information on sbt, please refer here. The project can be set up with the following template:

name := "sparkExample"

version := "0.1"

// DJL requires JVM 1.8 or later
scalaVersion := "2.11.12"
scalacOptions += "-target:jvm-1.8"

resolvers += Resolver.mavenLocal

libraryDependencies += "org.apache.spark" %% "spark-core" % "2.3.0"

libraryDependencies += "ai.djl" % "api" % "0.5.0"
libraryDependencies += "ai.djl" % "repository" % "0.5.0"
// Use the MXNet engine
libraryDependencies += "ai.djl.mxnet" % "mxnet-model-zoo" % "0.5.0"
libraryDependencies += "ai.djl.mxnet" % "mxnet-native-auto" % "1.6.0"

The project uses MXNet as the default engine. To switch to PyTorch, replace the two MXNet lines with the following:

// Use the PyTorch engine
libraryDependencies += "ai.djl.pytorch" % "pytorch-model-zoo" % "0.5.0"
libraryDependencies += "ai.djl.pytorch" % "pytorch-native-auto" % "1.5.0"
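
If you switch engines often, the choice can also be made at build time. The following is only a sketch of a build.sbt fragment: the JVM property name djl.engine is hypothetical (it is not a DJL or sbt setting), and the coordinates are the ones listed above.

```scala
// build.sbt fragment (sketch). "djl.engine" is a made-up property name;
// pass it as: sbt -Ddjl.engine=pytorch run
val djlVersion = "0.5.0"

val engine = sys.props.getOrElse("djl.engine", "mxnet")

libraryDependencies ++= (engine match {
  case "pytorch" =>
    Seq(
      "ai.djl.pytorch" % "pytorch-model-zoo" % djlVersion,
      "ai.djl.pytorch" % "pytorch-native-auto" % "1.5.0"
    )
  case _ =>
    // Default: the MXNet engine used in the template above
    Seq(
      "ai.djl.mxnet" % "mxnet-model-zoo" % djlVersion,
      "ai.djl.mxnet" % "mxnet-native-auto" % "1.6.0"
    )
})
```

With this fragment, the two engine-specific libraryDependencies lines from the template would be removed so that only one engine ends up on the classpath.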

Step 2: Configure Spark

We use the following configuration to run Spark locally:

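import org.apache.spark.{SparkConf, SparkContext}

// Spark settings
val conf = new SparkConf()
  .setAppName("Image classification task")
  .setMaster("local[*]") // run locally, using all available cores
  .setExecutorEnv("MXNET_ENGINE_TYPE", "NaiveEngine")
val sc = new SparkContext(conf)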

Multithreaded inference with MXNet requires the additional environment variable MXNET_ENGINE_TYPE to be set to NaiveEngine. If you use PyTorch or TensorFlow, this line can be deleted:

 .setExecutorEnv("MXNET_ENGINE_TYPE", "NaiveEngine")

Step 3: Set the input data

The input data is a folder containing multiple images. Spark reads these images as binary files and splits them into partitions, and each partition is distributed to an executor. Let's configure the image distribution:

val partitions = sc.binaryFiles("images/*")

Step 4: Set up the Spark job

In this step, we create the Spark computation graph for model loading and inference. Because inference on the images runs in multiple threads, we set up the model and predictor inside each partition before running inference:

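import java.awt.image.BufferedImage
import javax.imageio.ImageIO
import ai.djl.modality.Classifications
import ai.djl.repository.zoo.{Criteria, ModelZoo}
import ai.djl.training.util.ProgressBar

// Start distributing tasks to the worker nodes
val result = partitions.mapPartitions(partition => {
  // Prepare the deep learning model: build the search criteria
  val criteria = Criteria.builder
    // Image classification model
    .setTypes(classOf[BufferedImage], classOf[Classifications])
    .optFilter("dataset", "imagenet")
    // ResNet50 settings
    .optFilter("layers", "50")
    .optProgress(new ProgressBar)
    .build
  val model = ModelZoo.loadModel(criteria)
  // Create the predictor
  val predictor = model.newPredictor()
  // Multithreaded inference: decode each image and classify it
  partition.map(streamData => {
    val img = ImageIO.read(streamData._2.open())
    predictor.predict(img).toString
  })
})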

DJL provides the ModelZoo, which loads a model that matches the given criteria. We then create the predictor inside the partition. During image classification, we read the images from the RDD and run inference on them. The ResNet50 model used here is pre-trained on the ImageNet dataset.
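
To see why the predictor is created inside the partition rather than once per image, here is a small simulation in plain Scala, with no Spark or DJL involved. FakePredictor, processPartition, and the load counter are hypothetical stand-ins for illustration only; the point is that the expensive setup runs once per partition, while the cheap per-record work runs lazily over the iterator.

```scala
import java.util.concurrent.atomic.AtomicInteger

object PerPartitionSetup {
  // Counts how many times the (pretend) model is loaded.
  val loads = new AtomicInteger(0)

  // Hypothetical stand-in for an expensive model and predictor.
  final class FakePredictor {
    loads.incrementAndGet()
    def predict(input: String): String = s"label-for-$input"
  }

  // Same shape as the function passed to mapPartitions:
  // set up once, then map lazily over the records of the partition.
  def processPartition(partition: Iterator[String]): Iterator[String] = {
    val predictor = new FakePredictor
    partition.map(predictor.predict)
  }

  def main(args: Array[String]): Unit = {
    val records = (1 to 6).map(i => s"img$i.jpg")
    // Two "partitions" of three records each
    val partitions = records.grouped(3).toList
    val results = partitions.flatMap(p => processPartition(p.iterator))
    println(results.mkString(", "))
    println(s"model loads: ${loads.get}") // 2: one per partition, not one per record
  }
}
```

In the real job, Spark decides how many partitions exist and which executor thread handles each one, so the number of model loads follows the number of partitions.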

Step 5: Set the output

Once the map step is defined, we have the driver collect the results:

// Print the inference results
result.collect().foreach(print)
// Store them in the output folder
result.saveAsTextFile("output")

Running the above two lines of code triggers the Spark job, and the output files are saved in the output folder. Please refer to the Scala example to run the complete code.

If you run the sample code, this is the output:

    class: "n02085936 Maltese dog, Maltese terrier, Maltese", probability: 0.81445
    class: "n02096437 Dandie Dinmont, Dandie Dinmont terrier", probability: 0.08678
    class: "n02098286 West Highland white terrier", probability: 0.03561
    class: "n02113624 toy poodle", probability: 0.01261
    class: "n02113712 miniature poodle", probability: 0.01200
    class: "n02123045 tabby, tabby cat", probability: 0.52391
    class: "n02123394 Persian cat", probability: 0.24143
    class: "n02123159 tiger cat", probability: 0.05892
    class: "n02124075 Egyptian cat", probability: 0.04563
    class: "n03942813 ping-pong ball", probability: 0.01164
    class: "n03770679 minivan", probability: 0.95839
    class: "n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", probability: 0.01674
    class: "n03769881 minibus", probability: 0.00610
    class: "n03594945 jeep, landrover", probability: 0.00448
    class: "n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", probability: 0.00278
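
The listing contains three groups of five predictions. We assume here that each group is the top-5 result for one input image, based on where the probabilities restart. Under that assumption, the following sketch parses such lines and keeps the most likely class per image. TopPrediction, Prediction, parse, and topPerImage are hypothetical names for illustration and are not part of DJL.

```scala
object TopPrediction {
  final case class Prediction(label: String, probability: Double)

  // Matches lines such as:   class: "n03770679 minivan", probability: 0.95839
  private val Line = """\s*class: "(.*)", probability: ([0-9.]+)\s*""".r

  // Returns None for lines that do not follow the expected format.
  def parse(line: String): Option[Prediction] = line match {
    case Line(label, p) => Some(Prediction(label, p.toDouble))
    case _              => None
  }

  // Assumes every image contributes exactly k consecutive lines.
  def topPerImage(lines: Seq[String], k: Int = 5): List[Prediction] =
    lines.flatMap(parse).grouped(k).map(_.maxBy(_.probability)).toList

  def main(args: Array[String]): Unit = {
    val sample = Seq(
      """    class: "n03770679 minivan", probability: 0.95839""",
      """    class: "n03769881 minibus", probability: 0.00610"""
    )
    // Two lines per image in this shortened sample
    topPerImage(sample, k = 2).foreach(println)
  }
}
```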

Suggestions for configuring a production environment

In this example, we use RDDs for task distribution, which is for demonstration only. Where performance matters, we recommend using DataFrames as the data carrier. Starting with Spark 3.0, Apache Spark provides a binary file data source, which makes reading and storing images much easier.
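
As a rough sketch of the DataFrame route, the binary file data source can load a folder of images as shown below. Note the assumptions: this needs Spark 3.0 or later with the spark-sql module (the build file earlier in this article uses spark-core 2.3.0 and would have to be updated), and the object name, the images folder, and the *.jpg filter are placeholders.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object BinaryFileExample {
  // Loads every matching file as one row with the columns
  // path, modificationTime, length, and content (the raw bytes).
  def readImages(spark: SparkSession, dir: String): DataFrame =
    spark.read
      .format("binaryFile")
      .option("pathGlobFilter", "*.jpg") // placeholder filter
      .load(dir)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("binaryFile sketch")
      .master("local[*]")
      .getOrCreate()

    val images = readImages(spark, "images")
    // Show file paths and sizes without pulling the image bytes to the driver
    images.select("path", "length").show(truncate = false)

    spark.stop()
  }
}
```

The content column holds the bytes of each file, so the decoding and prediction logic from Step 4 could be applied to it per partition in the same way.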

DJL on Spark in an industrial environment

Amazon Retail System (ARS) runs millions of large-scale streaming inference tasks on Spark by using DJL. The results of these inferences are used to predict users' propensity for different actions, such as whether they will buy a product or add it to the shopping cart. Thousands of customer-interest categories help Amazon deliver more relevant ads to users' clients and home pages. The ARS deep learning model uses thousands of features applied to hundreds of millions of users, and the total amount of input data reaches 100 billion. With such a huge dataset and a Scala-based Spark processing platform, the team long struggled to find a good solution. After adopting DJL, their deep learning tasks were easily integrated into Spark, and inference time dropped from days to hours. We will publish another article later to analyze the deep learning model used by ARS and how DJL is applied in it.

About DJL

DJL is a deep learning framework built for Java developers, launched by Amazon Web Services at the re:Invent conference in 2019, and it already runs millions of inference tasks at Amazon. The main features of DJL can be summarized as follows:

  • DJL is not tied to a back-end engine: users can easily use MXNet, PyTorch, TensorFlow, and fastText for model training and inference in Java.
  • DJL's operator design closely follows NumPy: the experience is nearly seamless for NumPy users, and switching engines does not change the results.
  • DJL has excellent memory management and efficiency: it has its own resource recovery mechanism, and 100 hours of continuous inference will not exhaust memory.

To learn more, see the following links:

You are welcome to join the DJL Slack forum.


* The 2-5 times speedup is based on performance tests comparing PySpark with PyTorch (Python, CPU) against Spark with DJL PyTorch (Scala, CPU).
