How to call TensorFlow models in Scala
In this article I provide a short example of how to serve a TensorFlow model from Scala¹. I show how this is possible by serializing an example model to the Onnx format, which then allows it to be served from Scala using the Java Onnx runtime. Along the way I highlight some of the decisions I made and discuss why you might want to make different ones in your own work.
ML models need to be served in a variety of different ways depending on the functionality they enable. One implication of this is that the development and production environments are often separate. This could be to manage security concerns, to target different infrastructure (e.g. deploying to a phone or IoT device), or to manage language or dependency availability. If the production environment doesn't use the same language or dependencies that were used to develop the model, then we need to explore other options for serving that model.
One option is to serialize the model into a format designed for storing ML models and serve the serialized model using a separate runtime. Such formats are known as "intermediate representation" (IR) formats, and Onnx is one of them. The diagram below shows the logic of this process:
Exporting to Onnx
I used this timeseries anomaly detection model as an example model to start with. Preferably you should bring your own, but starting with something that already works is a good way to understand the interface. Clone this repo to follow along.
To export a TensorFlow model to Onnx you will need to install the package tf2onnx. The model is exported by running a script from the command line. For my project this looked like:
```shell
python -m tf2onnx.convert --saved-model ./model --opset 10 --output tfScala/src/main/resources/model.onnx
```
This will produce a single onnx file. Most of the arguments above are fairly self-explanatory apart from `--opset`. The opset number is how Onnx versions its API: each opset guarantees that a specific collection of mathematical operations commonly used in machine learning is supported. Now that the model has been serialized into Onnx, it's time to call it from Scala!
Serving the Onnx model in Scala
There are runtimes written in Java and Scala that serve Onnx files. The Scala implementation builds on the Java library and is worth checking out: it's typesafe and built specifically to provide an accessible abstraction for serving serialized models in Scala. However, I use the Java library in this post as it's a little more fundamental and makes fewer assumptions about how I'd like to interact with my models.
Here's a quick diagram of what we need to understand about how the Onnx runtime integrates into my program to serve the sample model I serialized earlier:
There are some considerations that I’ll call out now:
- Onnx doesn't currently support operators that manipulate text. This means you'll have to pre-process (tokenize) textual input before sending it to the model.
- In functional programming terms, creating the environment, the Onnx tensors, and the output all happens as side effects, and you have a choice about how you'd like to handle that. Functional error handling can be extremely powerful, but capturing a failure at the level of "the ML function didn't work" is often just as useful as "my cats IO failed here", since the default errors from Onnx are usually explicit enough. It depends, but it's worth considering what level of functional abstraction is useful given the time you have available to solve the problem.
- The output of the Onnx session needs to be cast to a typed val when being retrieved.
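The error-handling choice in the second point can be kept lightweight. A minimal sketch of wrapping a side-effecting model call in `Try` and surfacing a single domain error, where `InferenceError` and `safeInfer` are names I've invented for illustration and the stand-in computations replace a real Onnx session call:

```scala
import scala.util.{Try, Success, Failure}

// Hypothetical domain error for "the ML function didn't work".
final case class InferenceError(message: String, cause: Throwable)
  extends Exception(message, cause)

// Wrap any side-effecting model call so callers see one failure type.
// `runModel` stands in for a call like `session.run(modelInput)`.
def safeInfer[A](runModel: => A): Either[InferenceError, A] =
  Try(runModel) match {
    case Success(out) => Right(out)
    case Failure(err) => Left(InferenceError("inference failed", err))
  }

// Exercise it with stand-in computations instead of a live session.
val ok  = safeInfer(Seq(0.1f, 0.9f))
val bad = safeInfer[Seq[Float]](throw new RuntimeException("missing input"))
```

This keeps the Onnx plumbing imperative while still giving callers a total function to work with; swapping `Either` for a `cats.effect.IO` is straightforward if your codebase is already committed to that stack.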
All the logic for calling the serialized model through Onnx is contained in the Interface.scala file. Let's review each important expression to understand what it does.
We start by loading the model file and instantiating the Onnx environment and session. In this case I've stored the model file in the resources of the jar. This is fine if the model file is small and unlikely to be updated after the application has been deployed; otherwise you would store it outside the jar.
There can be at most one Onnx environment per Java runtime, but it's possible to instantiate as many sessions, and hence models, as you want, within reason. I've found creating models from bytes to be the most reliable method, as it lets me specify exactly how the model file is accessed rather than leaving that to the default implementation in `createSession`. You can also manage hardware and threads through the Onnx session, but that's outside the scope of this article.
```scala
val modelBytes = Files.readAllBytes(Paths.get(getClass.getResource("model.onnx").toURI))
val env = OrtEnvironment.getEnvironment()
val session = env.createSession(modelBytes)
```
After the environment has been created it's possible to convert the input into Onnx tensors. This specific model accepts an array with the dimension \(n \times 288\), where n is the number of inputs requiring evaluation. As this example is about the plumbing rather than the contents, I'll stick to creating a dummy input of the correct dimension containing zeros.
Note that for this implementation each value needs to be wrapped in its own Array, and it matters that it's an Array rather than a Seq. This is because of how the Java Onnx runtime converts values into the fundamentals required by the runtime - see this file.
```scala
val input = Array(Seq.fill(288)(0).map(_.toFloat).map(Array(_)).toArray)
val inputTensor = OnnxTensor.createTensor(env, input)
val inputName = "input_1"
val modelInput = Map(inputName -> inputTensor).asJava
```
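Getting the nesting wrong here produces runtime shape errors, so it can be worth sanity-checking the dimensions before handing the array to the runtime. A small sketch, assuming the same \(1 \times 288 \times 1\) dummy input as above:

```scala
// Build the same dummy input: one batch of 288 zeros, each in its own Array.
val input: Array[Array[Array[Float]]] =
  Array(Seq.fill(288)(0).map(_.toFloat).map(Array(_)).toArray)

// Check the dimensions match the batch x 288 x 1 layout the model expects.
val batch    = input.length
val steps    = input.head.length
val channels = input.head.head.length
```

A check like this makes a natural unit test for the input-translation code, independent of any live Onnx session.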
It's time to pass the processed input into the session to run against the model. You can unpack the result returned from an Onnx model in a few different ways, but I prefer to get the output by name, as this makes the logic more explicit in the code. The result is then cast into a type that makes it easier to work with.
```scala
val result = session.run(modelInput)
  .get("conv1d_transpose_2")
  .get()
  .getValue
  .asInstanceOf[Array[Array[Array[Float]]]]
```
Finally the result is unpacked from its cast type for further use in the application. For this specific model these are the probabilities that an anomaly has been detected, but if you were working with an LLM this is where the model outputs might be decoded into text. Either way, it's worth inserting error handling and writing unit tests to cover this behaviour.
```scala
val parsedResult = result.head.flatMap(_.toSeq).toSeq
println(parsedResult)
```
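The unpacking above can be factored into a pure function, which makes it trivial to unit test against dummy data rather than a live session. A sketch, where `fromModelOutput` is a name I've chosen for illustration:

```scala
// Pure translation from the cast Onnx result to a flat sequence of scores.
// Mirrors `result.head.flatMap(_.toSeq).toSeq` from the walkthrough.
def fromModelOutput(result: Array[Array[Array[Float]]]): Seq[Float] =
  result.head.flatMap(_.toSeq).toSeq

// Exercise it with a dummy 1 x 3 x 1 result instead of a live session.
val dummy  = Array(Array(Array(0.1f), Array(0.7f), Array(0.2f)))
val parsed = fromModelOutput(dummy)
```

Keeping the translation pure means the only untested seam left is the `session.run` call itself, which can be covered by an integration test.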
In the sections above I walked through a workflow for exporting an ML model developed in Python into a format, Onnx, that can then be served from Scala. Now that you've seen the simplest possible workflow, there are several directions you can take next:
- Write unit tests to test the translations from input to Onnx Tensors and from the Onnx Result to output. After these tests have been implemented consider writing an integration test based on a dummy model with an identical API to the real models to confirm the end to end functionality of your feature.
- Functional programming dominates Scala, and rightly so, as many of its ideas lead to easy-to-maintain code. However, the depth to which functional abstractions should be used often depends on the business problem and personal preference.
- Onnx has significantly higher throughput and lower latency when being called in batch mode.
- If the functionality needs to be called from Spark, make heavy use of the `@transient lazy val` pattern to pass around the model file.
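The `@transient lazy val` pattern from the last point can be sketched without Spark: serialize the small model bytes, and rebuild the heavy session lazily after deserialization instead of shipping it over the wire. In this sketch `ModelServer` is a made-up class and the string `session` stands in for a real `env.createSession(modelBytes)`:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

class ModelServer(val modelBytes: Array[Byte]) extends Serializable {
  // Never serialized; rebuilt lazily on first use after deserialization,
  // e.g. once per Spark executor. Stand-in for env.createSession(modelBytes).
  @transient lazy val session: String = s"session(${modelBytes.length} bytes)"
}

// Round-trip through Java serialization to mimic shipping to an executor.
def roundTrip[A](a: A): A = {
  val bos = new ByteArrayOutputStream()
  new ObjectOutputStream(bos).writeObject(a)
  new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    .readObject().asInstanceOf[A]
}

val server  = new ModelServer(Array[Byte](1, 2, 3))
val shipped = roundTrip(server)
```

Only the compact `modelBytes` crosses the wire; each JVM that receives the object pays the session-construction cost once, on first access.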
¹ If you're only going to be serving TensorFlow models in Java or Scala then check out TensorFlow Java. I'm tackling the general case where I want to serve any model that can be serialized into Onnx, and I'm just using a TensorFlow model as an example. TensorFlow Java has a significantly smaller dependency bundle than Onnx. Both, however, are significantly smaller than serving the full model development environment. ↩