Инструкция как запустить PyTorch модель в браузере.

Демо версия , Код на гитхабе .

ONNX — это библиотека, которая позволяет запускать модели с одного языка на другой. Она позволит запустить нейронную сеть, сделанную на Python, в веб JavaScript. ONNX входит в состав PyTorch.

Для web js, скачайте библиотеки и поместите их в папку с остальными js файлами:

_x000D_wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort-wasm-simd.wasm_x000D_wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js.map_x000D_wget https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js

Сохранить модель в ONNX формат можно следующим образом:

_x000D_tensor_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')_x000D__x000D_onnx_model_path = "web/model.onnx"_x000D_input_shape = [1, 32, 32]_x000D__x000D_data_input = torch.zeros(input_shape).to(torch.float32)_x000D__x000D_model = model.to(tensor_device)_x000D_data_input = data_input.to(tensor_device)_x000D__x000D_torch.onnx.export(_x000D_ model,_x000D_ data_input,_x000D_ onnx_model_path,_x000D_ opset_version = 10,_x000D_ input_names = ['input'],_x000D_ output_names = ['output'],_x000D_ verbose=False_x000D_)

Очень важные детали:

  • opset_version = 10 Генерирует модель с опкодами 10й версии. Если у вас модель не запускается в браузере и выдает ошибку при загрузке, то попробуйте добавить эту строчку.

Подключите скрипт ./ort.min.js через тэг script

JS файл:

_x000D_let input_shape = [1, 32, 32];_x000D__x000D_async function load_model()_x000D_{_x000D_ model = await ort.InferenceSession.create('./mnist3.onnx', {_x000D_  "executionProviders": ["webgl"]_x000D_ });_x000D_ return model;_x000D_}_x000D__x000D_async function predict(model, input)_x000D_{_x000D_ input = Float32Array.from(input);_x000D_ input = new ort.Tensor('float32', input, input_shape);_x000D_ let res = await model.run({ 'input': input });_x000D_ let output = res['output'].data;_x000D_ return output;_x000D_}_x000D__x000D_async function run()_x000D_{_x000D_ let input = [...Array(32 * 32).keys()];_x000D_ let model = await load_model();_x000D_ let output = await predict(model, input);_x000D_ console.log();_x000D_}_x000D__x000D_run();

Обратите внимание на строчку input_shape = [1, 32, 32]. Размерность должна быть такой же как и в python скрипте. И массив input в функцию predict должен передаваться такой же размерностью, только в одну линию из 1024 элементов.

Если у вас размерность вектора 32×32 и вам нужно добавить еще одно измерение, то нужно выполнить следующую комманду:

_x000D_data_input = data_input[None,:]

Эта команда превращает тензор 32×32 в 1x32x32

Чтобы проверить результат, можно  перейти в папку с html файлом и запустить локальный веб сервер, через команду:

_x000D_php -S 127.0.0.1:8080

И открыть браузер по адресу http://127.0.0.1:8080/