Calling TensorFlow models in Scala

Thomas Dickson




In this article I provide a short example of how to serve a TensorFlow model from Scala[1]. I show how this is possible by serializing an example model to the Onnx format, which then allows it to be served from Scala using the Java Onnx runtime[2]. Along the way I highlight some of the decisions I made and discuss why you might want to make different ones in your own work.

Context

ML models need to be served in a variety of different ways depending on the functionality they enable. One implication of this is that the development and production environments are often separate. This could be to manage security concerns, different infrastructure (e.g. deploying on a phone or IoT device), or language and dependency availability. If the production environment doesn’t use the same language or dependencies that were used to develop the model, then we need to explore other ways of serving that model.

One option is to serialize the model into a format designed for storing ML models and serve the serialized model using a separate runtime. Formats used to store ML models are known as “intermediate representation” (IR) formats, and Onnx is one such IR format[3]. The diagram below shows the logic of this process:

Diagram: Development environment → Model serialization → (model artefacts in IR format) → Production environment

Exporting to Onnx

I used this time-series anomaly detection model as the example. Ideally you should bring your own model, but starting with something that already works is a good way to understand the interface. Clone this repo to follow along.

To export a TensorFlow model to Onnx you will need to install the package tf2onnx. The model is exported by running a script from the command line. For my project this looked like:

python -m tf2onnx.convert --saved-model ./model --opset 10 --output tfScala/src/main/resources/model.onnx

This produces a single .onnx file. Most of the arguments above are fairly self-explanatory apart from --opset. The opset number is how Onnx versions its operator set: each opset guarantees that a specific collection of the mathematical operations commonly used in machine learning is supported, so the runtime you serve with must support the opset you export with. Now that the model has been serialized into Onnx, it’s time to call it from Scala!

Serving the Onnx model in Scala

There are runtimes written in both Java and Scala that serve Onnx files. The Scala implementation builds on the Java library and is worth checking out, as it’s typesafe and built specifically to provide an accessible abstraction for serving serialized models in Scala. However, I use the Java library in this post as it’s a little more fundamental and makes fewer assumptions about how I’d like to interact with the models I interface with.

Here’s a quick diagram[4] of what we need to understand about how the Onnx runtime integrates with my program to serve the sample model I serialized earlier:

The different stages for serving models through an Onnx runtime.

There are some considerations that I’ll call out now:

All the logic for calling the serialized model through Onnx is contained in the Interface.scala file. Let’s review each important expression to understand what it does.

We start by loading the model file and instantiating the Onnx environment and session. In this case I’ve stored the model file in the jar’s resources. This is fine if the model file is small and unlikely to be updated after the application has been deployed; otherwise you would store it outside the jar.
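Storing the model outside the jar mainly changes where the bytes come from. Here is a minimal sketch, assuming a hypothetical `ModelLoader` object and a hypothetical `MODEL_PATH` environment variable (neither is part of the example repo), that reads an external file when a path is supplied and otherwise falls back to the copy bundled in the jar’s resources. It uses only the JDK; `InputStream.readAllBytes` needs Java 9 or later.

```scala
import java.nio.file.{Files, Paths}

// Hypothetical helper, not part of the example repo.
object ModelLoader {
  // Reads the model from an external file when a path is given, otherwise from the classpath.
  def loadModelBytes(externalPath: Option[String], resource: String = "/model.onnx"): Array[Byte] =
    externalPath match {
      case Some(path) =>
        // External copy: can be replaced without rebuilding the jar
        Files.readAllBytes(Paths.get(path))
      case None =>
        // Bundled copy: getResourceAsStream works both from sbt and from a packaged jar
        val in = getClass.getResourceAsStream(resource)
        require(in != null, s"resource $resource not found on the classpath")
        try in.readAllBytes() finally in.close() // InputStream.readAllBytes needs Java 9+
    }

  // Prefers the (hypothetical) MODEL_PATH environment variable, falling back to the bundled model.
  def fromEnv(): Array[Byte] = loadModelBytes(sys.env.get("MODEL_PATH"))
}
```

The returned bytes can be passed straight to `env.createSession`, so updating the model becomes a deployment-time change rather than a rebuild.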

There can be at most one Onnx environment per JVM: OrtEnvironment.getEnvironment() returns the same shared instance each time it is called. You can create as many sessions, and hence models, as you want, within reason. I’ve found creating sessions from bytes to be more reliable than the other methods, as it lets me specify exactly how the model file is accessed rather than leaving that to the default implementation in createSession. You can also manage hardware and threads through the Onnx session, but that’s outside the scope of this article.

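import ai.onnxruntime.OrtEnvironment
// getResourceAsStream still works once model.onnx is inside a packaged jar, unlike Paths.get on a jar: URI
val in = getClass.getResourceAsStream("/model.onnx") // leading "/" resolves from the classpath root
val modelBytes = try in.readAllBytes() finally in.close() // InputStream.readAllBytes needs Java 9+
val env = OrtEnvironment.getEnvironment() // the single shared environment for this JVM
val session = env.createSession(modelBytes)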

After the environment has been created, the input has to be converted into Onnx tensors. This specific model accepts an array with the dimensions \(n \times 288 \times 1\), where \(n\) is the number of series requiring evaluation. As this example is about the plumbing rather than the model itself, I’ll stick to a dummy input of the correct shape filled with zeros.

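Note that for this model each value needs to be wrapped in its own Array, because the input has a trailing dimension of size 1, and it matters that it’s an Array rather than a Seq. This is because the Java Onnx runtime builds tensors from nested Java primitive arrays such as float[][][], which is what Scala’s Array[Array[Array[Float]]] compiles to, whereas a Seq is not a Java array. See this file for how the runtime converts values into the primitive types it requires.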

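import ai.onnxruntime.OnnxTensor
import scala.jdk.CollectionConverters._ // provides .asJava (Scala 2.13+)

// A batch of one all-zero series with shape [1, 288, 1]; this is float[][][] on the Java side
val input: Array[Array[Array[Float]]] = Array.fill(1, 288, 1)(0.0f)
val inputTensor = OnnxTensor.createTensor(env, input)
val inputName = "input_1" // must match the input name stored in the Onnx graph
val modelInput = Map(inputName -> inputTensor).asJava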
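Real inputs usually arrive as Scala collections rather than as a block of zeros. The sketch below uses a hypothetical `ModelInput` helper to convert a batch of series into a nested `Array[Array[Array[Float]]]` of shape \(n \times 288 \times 1\), the same layout as the dummy input above, validating the length of each series first. The 288-step length comes from this example model and is an assumption for any other model.

```scala
// Hypothetical helper, not part of the example repo.
object ModelInput {
  // Time steps per series expected by this example model (assumed input shape n × 288 × 1).
  val TimeSteps = 288

  // Converts a batch of series into the nested arrays (float[][][] on the Java side)
  // that OnnxTensor.createTensor accepts, with shape [n, TimeSteps, 1].
  def toModelInput(batch: Seq[Seq[Float]]): Array[Array[Array[Float]]] = {
    require(batch.nonEmpty, "batch must contain at least one series")
    for ((series, i) <- batch.zipWithIndex)
      require(series.length == TimeSteps, s"series $i has ${series.length} values, expected $TimeSteps")
    // Wrap every value in a single-element Array to produce the trailing dimension of size 1
    batch.map(series => series.map(value => Array(value)).toArray).toArray
  }
}
```

Because the first dimension is the batch, passing several series through one `OnnxTensor.createTensor` and `session.run` call is also how you would score many inputs at once, provided the exported model keeps its batch dimension dynamic.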

It’s time to pass the processed input into the session to run against the model. You can unpack the result returned from an Onnx model in a few different ways, but I prefer to get the output by name specifically as this makes the logic more explicit in the code. The result is then cast into a type that makes it easier to work with.

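// Close the Result once the output has been copied out, since it wraps native memory
val results = session.run(modelInput)
val result =
  try results.get("conv1d_transpose_2") // java.util.Optional[OnnxValue], looked up by output name
    .get()
    .getValue // copies the tensor into a JVM array
    .asInstanceOf[Array[Array[Array[Float]]]]
  finally results.close()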

Finally, the result is unpacked from its cast type for further use in the application. This specific model is an autoencoder, so these values are its reconstruction of the input series, which is compared with the input to decide whether an anomaly has occurred; if you were working with an LLM, this would be where the model outputs are decoded into text. Either way, it’s worth adding error handling and writing unit tests to cover this behaviour.

val parsedResult = result.head.flatMap(_.toSeq).toSeq
println(parsedResult)
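If the model is an autoencoder, as in the Keras time-series anomaly detection example, its output is a reconstruction of the input and the anomaly decision comes from comparing the two. Here is a minimal sketch of that step, assuming a hypothetical `AnomalyCheck` helper and a threshold you chose from the reconstruction error on the training data, using mean absolute error as that example does:

```scala
// Hypothetical helper, not part of the example repo.
object AnomalyCheck {
  // Mean absolute error between an input series and the model's reconstruction of it.
  def reconstructionError(input: Seq[Float], reconstruction: Seq[Float]): Double = {
    require(input.nonEmpty, "input must not be empty")
    require(input.length == reconstruction.length, "input and reconstruction lengths differ")
    input.zip(reconstruction).map { case (x, y) => math.abs(x - y).toDouble }.sum / input.length
  }

  // Flags the series when its reconstruction error exceeds a threshold chosen at training time.
  def isAnomalous(input: Seq[Float], reconstruction: Seq[Float], threshold: Double): Boolean =
    reconstructionError(input, reconstruction) > threshold
}
```

With `parsedResult` as the reconstruction, `AnomalyCheck.isAnomalous(inputSeries, parsedResult, threshold)` returns whether to flag the series; `inputSeries` and `threshold` are placeholders for your own data.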

Next steps

In the sections above I took you through a workflow for exporting an ML model developed in Python into a format, Onnx, that can then be served from Scala. Now that you’ve seen the simplest possible version of this workflow, there are several directions you could take next:

  1. Write unit tests to test the translations from input to Onnx Tensors and from the Onnx Result to output. After these tests have been implemented consider writing an integration test based on a dummy model with an identical API to the real models to confirm the end to end functionality of your feature.
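  2. Functional programming dominates Scala, and rightly so, as many of its ideas lead to code that is easy to maintain. However, how far to wrap the model interface in functional abstractions depends on the business problem and on personal preference.
  3. Onnx is significantly more efficient when called in batch mode: scoring many inputs in one call, using the batch dimension \(n\) of the input tensor, gives much higher throughput and a lower cost per input than calling the model once per input.
  4. If the functionality needs to be called from Spark, make heavy use of the @transient lazy val pattern so that each executor loads the model file and creates its own session, rather than Spark trying to serialize the session from the driver.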
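The @transient lazy val pattern from the last point can be demonstrated without Spark. In this sketch, `ExpensiveModel` is a hypothetical stand-in for an OrtSession (expensive to build and not serializable) and `ModelWrapper` is a hypothetical serializable holder: only the model bytes are serialized, and the model is rebuilt lazily on first use in each deserialized copy, which is what happens when Spark ships the wrapper to an executor. In a real job the lazy val body would call `OrtEnvironment.getEnvironment().createSession(modelBytes)` instead.

```scala
// Hypothetical stand-in for an OrtSession: expensive to build and not Serializable.
final class ExpensiveModel(val bytes: Array[Byte]) {
  def score(xs: Seq[Float]): Float = xs.sum // placeholder for session.run
}

// Counts model builds so the lazy re-initialisation can be observed.
object ModelHolder {
  @volatile var loads = 0
}

// Hypothetical serializable holder: only modelBytes is serialized, never the model itself.
class ModelWrapper(modelBytes: Array[Byte]) extends Serializable {
  // @transient keeps the model out of the serialized form; lazy rebuilds it on first use
  // in each deserialized copy (e.g. on a Spark executor).
  @transient lazy val model: ExpensiveModel = {
    ModelHolder.loads += 1
    new ExpensiveModel(modelBytes)
  }
}
```

Referencing `wrapper.model` inside something like `mapPartitions` means the model is built at most once per deserialized copy of the wrapper rather than once per row.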
Footnotes

  1. This repo contains the example code I’ve written to demonstrate how to call models serialized to Onnx from Scala.
  2. If you’re only going to be serving TensorFlow models in Java or Scala, then check out TensorFlow Java. I’m tackling the general case where I want to serve any model that can be serialized into Onnx, and I’m just using a TensorFlow model as an example. TensorFlow Java has a significantly smaller dependency bundle than Onnx. Both, however, are significantly smaller than having to ship the whole model development environment.
  3. Safetensors is another format, but I haven’t found support for it in the languages I’m interested in. Hopefully that changes soon!
  4. I used Excalidraw to draw this diagram. As mentioned here, I also use Mermaid to create diagrams.