Инструкция как запустить PyTorch модель в браузере.
Демо версия , Код на гитхабе .
ONNX — это библиотека, которая позволяет запускать модели с одного языка на другой. Она позволит запустить нейронную сеть, сделанную на Python, в веб JavaScript. ONNX входит в состав PyTorch.
Для web js, скачайте библиотеки и поместите их в папку с остальными js файлами:
_x000D_wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort-wasm-simd.wasm_x000D_wget https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js.map_x000D_wget https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.jsСохранить модель в ONNX формат можно следующим образом:
_x000D_tensor_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')_x000D__x000D_onnx_model_path = "web/model.onnx"_x000D_input_shape = [1, 32, 32]_x000D__x000D_data_input = torch.zeros(input_shape).to(torch.float32)_x000D__x000D_model = model.to(tensor_device)_x000D_data_input = data_input.to(tensor_device)_x000D__x000D_torch.onnx.export(_x000D_ model,_x000D_ data_input,_x000D_ onnx_model_path,_x000D_ opset_version = 10,_x000D_ input_names = ['input'],_x000D_ output_names = ['output'],_x000D_ verbose=False_x000D_)Очень важные детали:
- opset_version = 10 Генерирует модель с опкодами 10й версии. Если у вас модель не запускается в браузере и выдает ошибку при загрузке, то попробуйте добавить эту строчку.
Подключите скрипт ./ort.min.js через тэг script
JS файл:
_x000D_let input_shape = [1, 32, 32];_x000D__x000D_async function load_model()_x000D_{_x000D_ model = await ort.InferenceSession.create('./mnist3.onnx', {_x000D_ "executionProviders": ["webgl"]_x000D_ });_x000D_ return model;_x000D_}_x000D__x000D_async function predict(model, input)_x000D_{_x000D_ input = Float32Array.from(input);_x000D_ input = new ort.Tensor('float32', input, input_shape);_x000D_ let res = await model.run({ 'input': input });_x000D_ let output = res['output'].data;_x000D_ return output;_x000D_}_x000D__x000D_async function run()_x000D_{_x000D_ let input = [...Array(32 * 32).keys()];_x000D_ let model = await load_model();_x000D_ let output = await predict(model, input);_x000D_ console.log();_x000D_}_x000D__x000D_run();Обратите внимание на строчку input_shape = [1, 32, 32]. Размерность должна быть такой же как и в python скрипте. И массив input в функцию predict должен передаваться такой же размерностью, только в одну линию из 1024 элементов.
Если у вас размерность вектора 32×32 и вам нужно добавить еще одно измерение, то нужно выполнить следующую комманду:
_x000D_data_input = data_input[None,:]Эта команда превращает тензор 32×32 в 1x32x32
Чтобы проверить результат, можно перейти в папку с html файлом и запустить локальный веб сервер, через команду:
_x000D_php -S 127.0.0.1:8080И открыть браузер по адресу http://127.0.0.1:8080/