{ "cells": [ { "cell_type": "markdown", "id": "d6f83f21-0122-4d3d-897b-eaad49f2647c", "metadata": {}, "source": [ "# Convert sklearn model to ONNX\n", "\n", "For reproducibility of the planetsca models, you might be interested in converting the sklearn model objects into the [ONNX](https://onnx.ai/) format. This allows a model trained with one version of scikit-learn to be used in an environment with a different version of scikit-learn (or with any other ML framework or runtime that supports the ONNX format). [You can read more about model persistence recommendations for scikit-learn models here](https://scikit-learn.org/stable/model_persistence.html).\n", "\n", "To do this conversion, we can use the [skl2onnx](https://onnx.ai/sklearn-onnx/index.html) package, following the steps below." ] }, { "cell_type": "code", "execution_count": 1, "id": "a9e9d95d-c582-423a-a0d1-769baa247c8b", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jovyan/envs/planetenv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from skl2onnx import convert_sklearn\n", "from skl2onnx.common.data_types import FloatTensorType\n", "\n", "import planetsca as ps" ] }, { "cell_type": "markdown", "id": "39ee5433-e20b-4c68-9460-965309ec07b8", "metadata": {}, "source": [ "For this example, we are just going to retrieve the planetsca model from the joblib file on Hugging Face.\n", "\n", "However, `model` below can also be a new custom planetsca model that you've trained yourself. See the *Train and Predict* notebook for an example of this."
] }, { "cell_type": "code", "execution_count": 2, "id": "9c8c10e9-3c3b-480e-9e6e-140f1ef075fd", "metadata": {}, "outputs": [], "source": [ "# retrieve planetsca model from Hugging Face\n", "model = ps.download.retrieve_model()" ] }, { "cell_type": "markdown", "id": "4c9a5666-30eb-4178-b412-93d886779bb4", "metadata": {}, "source": [ "Specify the initial types for the model. Here we need to provide the data type and shape of the input data for our model. We can call this input variable `surface_reflectance`, and specify that it is a `FloatTensorType` (for floating point numbers) of shape `[None, 4]`, where `None` means that the number of samples is flexible, and `4` corresponds to the four PlanetScope bands (blue, green, red, NIR) for each sample." ] }, { "cell_type": "code", "execution_count": 3, "id": "9f00b8a6-58c9-43bb-a0a0-18ac259496a0", "metadata": {}, "outputs": [], "source": [ "initial_types = [(\"surface_reflectance\", FloatTensorType([None, 4]))]" ] }, { "cell_type": "markdown", "id": "6f65cb85-e23c-4bc1-86a3-b762914c6b63", "metadata": {}, "source": [ "Now, convert the model:" ] }, { "cell_type": "code", "execution_count": 4, "id": "11007eff-9beb-4d60-a981-2d4f67726e72", "metadata": {}, "outputs": [], "source": [ "onnx_model = convert_sklearn(model, initial_types=initial_types)" ] }, { "cell_type": "markdown", "id": "d98daddc-bbf2-43f3-9ed1-6057f91de393", "metadata": {}, "source": [ "And save the ONNX model to a new file."
] }, { "cell_type": "code", "execution_count": 5, "id": "e5b3f2e2-c4df-49f5-85f1-b2c3c60132d3", "metadata": {}, "outputs": [], "source": [ "with open(\"random_forest_20240116_binary_174K.onnx\", \"wb\") as f:\n", " f.write(onnx_model.SerializeToString())" ] }, { "cell_type": "code", "execution_count": null, "id": "10fd9958-a2cd-48ef-a2d8-f125bb3541f5", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "planetenv", "language": "python", "name": "planetenv" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 }