
Hello World Model

In this tutorial, you will create a "Hello World" model that takes a string as input and returns a string as output. You will also learn how to export the model as a TorchScript model that can be loaded with the PlayTorch SDK for on-device inference.

Create PyTorch Model

Let's begin by creating a PyTorch model. Here, we create a simple "Hello World" model by subclassing torch.nn.Module, the base class PyTorch uses to represent neural networks (hence the namespace nn).

The model defines a forward function with one argument, name. This function "performs" the computation; in later tutorials, for example, it will run inference on an image.

The model constructor has one argument, prefix, which the forward function prepends to the name argument.

More details on PyTorch modules at https://pytorch.org/docs/stable/notes/modules.html

import torch
from torch import nn

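class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        # Store the prefix so forward can use it
        self.prefix = prefix

    def forward(self, name: str) -> str:
        # The type annotations tell TorchScript that name is a str, not a Tensor
        return f"{self.prefix} {name}!"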
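Note that the tutorial calls the model instance (model("Roman")) rather than model.forward directly. A minimal sketch of why: calling an nn.Module instance goes through nn.Module.__call__, which runs any registered hooks and then forward, while calling forward directly skips the hooks. The hook function record_output and the list seen are hypothetical names used only for illustration.

```python
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


model = Model("Hello")
seen = []


def record_output(module, args, output):
    # Forward hooks receive the module, the positional inputs (a tuple), and the output.
    seen.append(output)


handle = model.register_forward_hook(record_output)
via_call = model("Roman")             # goes through __call__, so the hook fires
via_forward = model.forward("Roman")  # bypasses __call__, so the hook does not fire
handle.remove()

print(via_call, via_forward, seen)
```

Both calls return the same string, but only the first one is recorded by the hook, which is why calling the instance is the idiomatic choice.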

Create an instance of the model

Next, let's create an instance of the model and perform a computation.

model = Model("Hello")
model("Roman")
Output
Hello Roman!

Export Model for Mobile

Now that we have a model, let's export it for use on mobile. To do that, we first need to script the model (i.e., create a TorchScript representation of it) as follows:

scripted_model = torch.jit.script(model)
scripted_model("Lindsay")
Output
Hello Lindsay!
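Scripting compiles forward into TorchScript, which is statically typed. That is why the type annotations on forward matter: TorchScript assumes unannotated arguments are Tensors, so without name: str the string input would be rejected. A small sketch for inspecting the scripted model, with the Model class repeated so the block runs on its own: the .code attribute shows the compiled TorchScript source, the prefix attribute stays accessible, and a non-string input is rejected at call time.

```python
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


scripted_model = torch.jit.script(Model("Hello"))

# The compiled TorchScript source of forward, useful when debugging scripting issues.
print(scripted_model.code)

# Attributes set in __init__ become typed attributes of the scripted module.
print(scripted_model.prefix)

# TorchScript checks call arguments against the annotations on forward.
try:
    scripted_model(42)
    type_checked = False
except RuntimeError as err:
    type_checked = True
    print("Rejected non-string input:", err)
```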
Note: torch.jit.script is the recommended way to create a TorchScript model because, unlike tracing with torch.jit.trace, it can capture control flow. However, scripting can fail on code that uses Python features TorchScript does not support. If that happens, consult the PyTorch TorchScript documentation for solutions.
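The Hello World model has no control flow, so the following sketch uses a separate, hypothetical module (SumGate) whose output depends on an if statement. torch.jit.trace runs the model once on an example input and records only the operations that were executed, so it bakes in one branch; torch.jit.script compiles the Python source and keeps both branches. Tracing also requires Tensor inputs, so it could not handle this tutorial's string-based model in the first place.

```python
import warnings

import torch
from torch import nn


class SumGate(nn.Module):
    """Hypothetical example: double the input if its sum is positive, else return zeros."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return torch.zeros_like(x)


module = SumGate()
positive = torch.tensor([1.0, 2.0])
negative = torch.tensor([-1.0, -2.0])

# Scripting compiles both branches of the if statement.
scripted = torch.jit.script(module)

# Tracing converts the tensor condition to a Python bool (emitting a TracerWarning)
# and records only the branch taken for the example input.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(module, positive)

print("eager:   ", module(negative))
print("scripted:", scripted(negative))
print("traced:  ", traced(negative))
```

For the negative input, the scripted module matches eager execution, while the traced module still doubles the input because only the positive branch was recorded.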

PyTorch offers the optimize_for_mobile utility function, which runs a list of optimizations on the model (e.g., Conv2D + BatchNorm fusion, dropout removal). We recommend optimizing the model with this utility before exporting it for mobile.

More details on the optimize_for_mobile utility at: https://pytorch.org/docs/stable/mobile_optimizer.html

from torch.utils.mobile_optimizer import optimize_for_mobile

optimized_model = optimize_for_mobile(scripted_model)
optimized_model("Kodo")
Output
Hello Kodo!
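The Hello World model has no convolution or batch-norm layers, so optimize_for_mobile leaves its behavior unchanged. To see one of the listed optimizations at work, here is a sketch with a small, hypothetical convolutional model (not part of this tutorial). It compares eager and optimized outputs and reports whether an aten::batch_norm op remains in the optimized graph. Folding batch norm into a convolution is only valid with inference-mode statistics, so the model is put in eval mode first.

```python
import torch
from torch import nn
from torch.utils.mobile_optimizer import optimize_for_mobile

torch.manual_seed(0)

# Hypothetical Conv2d + BatchNorm2d model, used only to illustrate fusion.
conv_bn = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))

# Give BatchNorm non-trivial running statistics so folding actually changes the conv weights.
with torch.no_grad():
    conv_bn[1].running_mean.uniform_(-1.0, 1.0)
    conv_bn[1].running_var.uniform_(0.5, 2.0)

conv_bn.eval()  # use running statistics, as fusion assumes

optimized = optimize_for_mobile(torch.jit.script(conv_bn))

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    max_diff = (conv_bn(x) - optimized(x)).abs().max().item()

print("max abs difference:", max_diff)
# Report whether a batch_norm op survived optimization.
print("batch_norm op in graph:", "aten::batch_norm" in str(optimized.graph))
```

The outputs should agree up to floating-point rounding; the optimization changes how the computation is expressed, not what it computes.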

Great! Now, let's export the model for mobile by saving it for the lite interpreter. The _save_for_lite_interpreter method creates a hello_world.ptl file, which we can load with the PlayTorch SDK.

optimized_model._save_for_lite_interpreter("hello_world.ptl")

More details on the lite interpreter at: https://pytorch.org/tutorials/prototype/lite_interpreter.html
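Before moving to the app, it can be worth checking that the exported file behaves like the Python model. A minimal sketch, assuming a PyTorch build that ships the Python-side loader torch.jit.mobile._load_for_lite_interpreter (a private, underscore-prefixed API that may change between releases): it re-creates the model, exports it to a temporary directory instead of the working directory, loads the .ptl file back, and compares the output.

```python
import os
import tempfile

import torch
from torch import nn
from torch.jit.mobile import _load_for_lite_interpreter
from torch.utils.mobile_optimizer import optimize_for_mobile


class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


model = Model("Hello").eval()
optimized_model = optimize_for_mobile(torch.jit.script(model))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "hello_world.ptl")
    optimized_model._save_for_lite_interpreter(path)

    # Load the .ptl file back with PyTorch's Python-side lite interpreter loader.
    lite_model = _load_for_lite_interpreter(path)
    lite_output = lite_model("Kodo")

print(lite_output)
print(lite_output == model("Kodo"))
```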

Create Mobile UI and Load Model on Mobile

Next, create a PlayTorch Snack by opening http://snack.playtorch.dev/. Then, drag and drop the hello_world.ptl file onto the newly created Snack; this imports the model into the Snack.

Replace the source code in App.js with the React Native source code below. It creates a user interface with a text input, a button, and a text element. When the button is pressed, the app loads the hello_world.ptl model and calls the model's forward function with the text input value as its argument. The model output is then displayed below the button.

import * as React from 'react';
import { useState } from 'react';
import {
  Button,
  SafeAreaView,
  StyleSheet,
  Text,
  TextInput,
  View,
} from 'react-native';
import { torch, MobileModel } from 'react-native-pytorch-core';

export default function App() {
  const [modelInput, setModelInput] = useState('');
  const [modelOutput, setModelOutput] = useState('');

  async function handleModelInput() {
    // Get a local file path for the hello_world.ptl asset bundled with the Snack
    const filePath = await MobileModel.download(require('./hello_world.ptl'));
    // Load the lite interpreter model
    const model = await torch.jit._loadForMobile(filePath);
    // Call forward with the text input value and display the returned string
    const output = await model.forward(modelInput);
    setModelOutput(output);
  }

  return (
    <SafeAreaView style={StyleSheet.absoluteFill}>
      <View style={styles.container}>
        <TextInput
          value={modelInput}
          onChangeText={setModelInput}
          placeholder="Write your name"
        />
        <Button title="Let's Go" onPress={handleModelInput} />
        <Text>{modelOutput}</Text>
      </View>
    </SafeAreaView>
  );
}

const styles = StyleSheet.create({
  container: {
    flex: 1,
    justifyContent: 'center',
    backgroundColor: '#fff',
    padding: 20,
  },
});