Skip to content

Commit 09ee4d1

Browse files
de-codeMrRiahi
andauthored
added tflite runtime support (#167)
* Release the tflite inference from tensorflow * Update README * Remove the bodypix_tflite from develop branch * Add tflite inference to the tflite_inference branch * added initial build_tflite workflow job * added --use-feature=in-tree-build * don't install tflite by default * moved build_tflite up * added tflite extra * using dev-install-tflite * make dev-install-tflite install build and dev depenencies * run pytest for tflite * using tflite extra when installing tflite * added make dev-pytest-tflite * linting: addressed markdown linting * made tensorflow import optional * added test_should_be_able_to_use_existing_tflite_model * import tflite_runtime.interpreter * extracted load_image * load image using pillow * adapted pad_and_resize_to using _pad_image_like_tensorflow * implemented _resize_image_to_using_pillow * added make dev-watch-tflite * fallback to np expand_dims without tf * extracted _get_mobilenet_preprocessed_image with np fallback * reuse resize_image_to for scale_and_crop_to_input_tensor_shape * automatically reduce dimension if needed * added support for single channel in _resize_image_to_using_pillow * extracted get_sigmoid and implemented np version * reuse resize_image_to * fixed failing test * removed trailing space from requirements.txt * cli: automatically select tflite model if full tf is not available * don't fail with missing tf when adding alpha mask * added tflite support to draw mask cli * added support for remote tflite models; defined model tflite paths * use model path constants for cli * added TensorFlow Lite Runtime support section to readme * ignore tflite models * removed obsolete bodypix_tflite diectory * fixed draw mask * replaced pillow resize with numpy handling floats * retain original dtype when padding * use float32 for imagenet preprocessing * debug logging of input image * added list-tflite-models sub command * fixed resnet tflite support * added more tflite models Co-authored-by: MrRiahi <mohammad.r.riahi@gmail.com>
1 parent 8222006 commit 09ee4d1

17 files changed

Lines changed: 698 additions & 97 deletions

File tree

.github/workflows/ci.yml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,30 @@ jobs:
2121
env:
2222
TEST_PYPI_PASSWORD: ${{ secrets.test_pypi_password }}
2323
24+
build_tflite:
25+
needs: []
26+
runs-on: ${{ matrix.os }}
27+
strategy:
28+
matrix:
29+
os: [ubuntu-latest]
30+
python-version: [3.8]
31+
include:
32+
- python-version: 3.8
33+
34+
steps:
35+
- uses: actions/checkout@v2
36+
- name: Set up Python ${{ matrix.python-version }}
37+
uses: actions/setup-python@v2
38+
with:
39+
python-version: ${{ matrix.python-version }}
40+
- name: Install dependencies
41+
run: |
42+
make venv-create SYSTEM_PYTHON=python
43+
make dev-install-tflite
44+
- name: Test with pytest
45+
run: |
46+
make dev-pytest-tflite
47+
2448
build:
2549
needs: ["check_secrets"]
2650
runs-on: ${{ matrix.os }}

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ build
66
*.egg-info
77

88
*.pyc
9+
*.tflite

Makefile

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,25 @@ venv-create:
4545
$(SYSTEM_PYTHON) -m venv $(VENV)
4646

4747

48-
dev-install:
48+
dev-install-build-dependencies:
4949
$(PIP) install -r requirements.build.txt
50+
51+
52+
dev-install: dev-install-build-dependencies
5053
$(PIP) install \
5154
-r requirements.dev.txt \
5255
-r requirements.txt
5356

5457

58+
dev-install-tflite: dev-install-build-dependencies
59+
$(PIP) install -r requirements.dev.txt
60+
$(PIP) install --use-feature=in-tree-build .[tflite,image]
61+
62+
63+
dev-run-pip:
64+
$(PIP) $(ARGS)
65+
66+
5567
dev-venv: venv-create dev-install
5668

5769

@@ -75,10 +87,20 @@ dev-pytest:
7587
$(PYTHON) -m pytest -p no:cacheprovider $(ARGS)
7688

7789

90+
dev-pytest-tflite:
91+
$(MAKE) dev-pytest \
92+
ARGS='tests/cli_test.py -k test_should_be_able_to_use_existing_tflite_model'
93+
94+
7895
dev-watch:
7996
$(PYTHON) -m pytest_watch -- -p no:cacheprovider -p no:warnings $(ARGS)
8097

8198

99+
dev-watch-tflite:
100+
$(MAKE) dev-watch \
101+
ARGS='tests/cli_test.py -k test_should_be_able_to_use_existing_tflite_model'
102+
103+
82104
dev-test: dev-lint dev-pytest
83105

84106

@@ -114,6 +136,11 @@ list-models:
114136
list-models
115137

116138

139+
list-tflite-models:
140+
$(PYTHON) -m tf_bodypix \
141+
list-tflite-models
142+
143+
117144
convert-example-draw-mask:
118145
$(PYTHON) -m tf_bodypix \
119146
draw-mask \
@@ -240,6 +267,66 @@ webcam-v4l2-replace-background:
240267
$(ARGS)
241268

242269

270+
convert-tfjs-models-to-tflite:
271+
mkdir -p "./data/tflite-models"
272+
$(PYTHON) -m tf_bodypix \
273+
convert-to-tflite \
274+
--model-path \
275+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/050/model-stride8.json" \
276+
--optimize \
277+
--quantization-type=float16 \
278+
--output-model-file "./data/tflite-models/mobilenet-float-multiplier-050-stride8-float16.tflite"
279+
$(PYTHON) -m tf_bodypix \
280+
convert-to-tflite \
281+
--model-path \
282+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/050/model-stride16.json" \
283+
--optimize \
284+
--quantization-type=float16 \
285+
--output-model-file "./data/tflite-models/mobilenet-float-multiplier-050-stride16-float16.tflite"
286+
$(PYTHON) -m tf_bodypix \
287+
convert-to-tflite \
288+
--model-path \
289+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride8.json" \
290+
--optimize \
291+
--quantization-type=float16 \
292+
--output-model-file "./data/tflite-models/mobilenet-float-multiplier-075-stride8-float16.tflite"
293+
$(PYTHON) -m tf_bodypix \
294+
convert-to-tflite \
295+
--model-path \
296+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \
297+
--optimize \
298+
--quantization-type=float16 \
299+
--output-model-file "./data/tflite-models/mobilenet-float-multiplier-075-stride16-float16.tflite"
300+
$(PYTHON) -m tf_bodypix \
301+
convert-to-tflite \
302+
--model-path \
303+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/100/model-stride8.json" \
304+
--optimize \
305+
--quantization-type=float16 \
306+
--output-model-file "./data/tflite-models/mobilenet-float-multiplier-100-stride8-float16.tflite"
307+
$(PYTHON) -m tf_bodypix \
308+
convert-to-tflite \
309+
--model-path \
310+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/100/model-stride16.json" \
311+
--optimize \
312+
--quantization-type=float16 \
313+
--output-model-file "./data/tflite-models/mobilenet-float-multiplier-100-stride16-float16.tflite"
314+
$(PYTHON) -m tf_bodypix \
315+
convert-to-tflite \
316+
--model-path \
317+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/resnet50/float/model-stride16.json" \
318+
--optimize \
319+
--quantization-type=float16 \
320+
--output-model-file "./data/tflite-models/resnet50-float-stride16-float16.tflite"
321+
$(PYTHON) -m tf_bodypix \
322+
convert-to-tflite \
323+
--model-path \
324+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/resnet50/float/model-stride32.json" \
325+
--optimize \
326+
--quantization-type=float16 \
327+
--output-model-file "./data/tflite-models/resnet50-float-stride32-float16.tflite"
328+
329+
243330
docker-build:
244331
docker build . -t $(IMAGE_NAME):$(IMAGE_TAG)
245332

README.md

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@ when using this project as a library:
3131
| ---------- | -----------
3232
| tf | [TensorFlow](https://pypi.org/project/tensorflow/) (required). But you may use your own build.
3333
| tfjs | TensorFlow JS Model support, using [tfjs-graph-converter](https://pypi.org/project/tfjs-graph-converter/)
34+
| tflite | [tflite-runtime](https://pypi.org/project/tflite-runtime/)
3435
| image | Image loading via [Pillow](https://pypi.org/project/Pillow/), required by the CLI.
3536
| video | Video support via [OpenCV](https://pypi.org/project/opencv-python/)
3637
| webcam | Webcam support via [OpenCV](https://pypi.org/project/opencv-python/) and [pyfakewebcam](https://pypi.org/project/pyfakewebcam/)
37-
| all | All of the libraries
38+
| all | All of the libraries (except `tflite-runtime`)
3839

3940
## Python API
4041

@@ -117,6 +118,12 @@ Those URLs can be passed as the `--model-path` arguments below, or to the `downl
117118

118119
The CLI will download and cache the model from the provided path. If no `--model-path` is provided, it will use a default model (mobilenet).
119120

121+
To list TensorFlow Lite models instead:
122+
123+
```bash
124+
python -m tf_bodypix list-tflite-models
125+
```
126+
120127
### Inputs and Outputs
121128

122129
Most commands will work with inputs (source) and outputs.
@@ -317,7 +324,7 @@ python -m tf_bodypix \
317324

318325
Background: [Brown Landscape Under Grey Sky](https://www.pexels.com/photo/brown-landscape-under-grey-sky-3244513/)
319326

320-
## TensorFlow Lite support (experimental)
327+
## TensorFlow Lite Model support (experimental)
321328

322329
The model path may also point to a TensorFlow Lite model (`.tflite` extension). Whether that actually improves performance may depend on the platform and available hardware.
323330

@@ -330,7 +337,7 @@ python -m tf_bodypix \
330337
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \
331338
--optimize \
332339
--quantization-type=float16 \
333-
--output-model-file "./mobilenet-float16-stride16.tflite"
340+
--output-model-file "./mobilenet-float-multiplier-075-stride16-float16.tflite"
334341
```
335342

336343
The above command is provided for convenience.
@@ -342,6 +349,12 @@ Relevant links:
342349
* [TF Lite post_training_quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)
343350
* [TF GitHub #40183](https://github.com/tensorflow/tensorflow/issues/40183).
344351

352+
## TensorFlow Lite Runtime support (experimental)
353+
354+
This project can also be used with [tflite-runtime](https://pypi.org/project/tflite-runtime/) instead of full TensorFlow (e.g. by using the `tflite` extra).
355+
However, [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/) would require full TensorFlow.
356+
In order to avoid it, one needs to use a TensorFlow Lite model (see previous section).
357+
345358
## Docker Usage
346359

347360
You could also use the Docker image if you prefer.

requirements.tflite.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tflite-runtime==2.7.0

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
opencv-python==4.5.5.62
1+
opencv-python==4.5.5.62
22
Pillow==8.4.0; python_version < "3.7"
33
Pillow==9.0.1; python_version >= "3.7"
44
pyfakewebcam==0.1.0

setup.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
REQUIRED_PACKAGES = f.readlines()
1313

1414

15+
with open('requirements.tflite.txt', 'r', encoding='utf-8') as f:
16+
TFLITE_REQUIRED_PACKAGES = f.readlines()
17+
18+
1519
with open('README.md', 'r', encoding='utf-8') as f:
1620
LONG_DESCRIPTION = '\n'.join([
1721
line.rstrip()
@@ -30,6 +34,10 @@ def local_scheme(version):
3034
get_requirements_with_groups(REQUIRED_PACKAGES)
3135
)
3236

37+
ALL_EXTRAS = {
38+
**EXTRAS,
39+
'tflite': TFLITE_REQUIRED_PACKAGES
40+
}
3341

3442
packages = find_packages(exclude=["tests", "tests.*"])
3543

@@ -42,7 +50,7 @@ def local_scheme(version):
4250
author="Daniel Ecer",
4351
url="https://github.com/de-code/python-tf-bodypix",
4452
install_requires=DEFAULT_REQUIRED_PACKAGES,
45-
extras_require=EXTRAS,
53+
extras_require=ALL_EXTRAS,
4654
packages=packages,
4755
include_package_data=True,
4856
description='Python implemention of the TensorFlow BodyPix model.',

tests/cli_test.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
22
from pathlib import Path
33

4-
from tf_bodypix.download import BodyPixModelPaths
4+
from tf_bodypix.download import ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS, BodyPixModelPaths
55
from tf_bodypix.model import ModelArchitectureNames
6-
from tf_bodypix.cli import main
6+
from tf_bodypix.cli import DEFAULT_MODEL_TFLITE_PATH, main
77

88

99
LOGGER = logging.getLogger(__name__)
@@ -97,6 +97,15 @@ def test_should_list_all_default_model_urls(self, capsys):
9797
missing_urls = set(expected_urls) - set(output_urls)
9898
assert not missing_urls
9999

100+
def test_should_list_all_default_tflite_models(self, capsys):
101+
expected_urls = ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS
102+
main(['list-tflite-models'])
103+
captured = capsys.readouterr()
104+
output_urls = captured.out.splitlines()
105+
LOGGER.debug('output_urls: %s', output_urls)
106+
missing_urls = set(expected_urls) - set(output_urls)
107+
assert not missing_urls
108+
100109
def test_should_be_able_to_convert_to_tflite_and_use_model(self, temp_dir: Path):
101110
output_model_file = temp_dir / 'model.tflite'
102111
main([
@@ -115,3 +124,14 @@ def test_should_be_able_to_convert_to_tflite_and_use_model(self, temp_dir: Path)
115124
'--source=%s' % EXAMPLE_IMAGE_URL,
116125
'--output=%s' % output_image_path
117126
])
127+
128+
def test_should_be_able_to_use_existing_tflite_model(self, temp_dir: Path):
129+
output_image_path = temp_dir / 'mask.jpg'
130+
main([
131+
'draw-mask',
132+
'--model-path=%s' % DEFAULT_MODEL_TFLITE_PATH,
133+
'--model-architecture=%s' % ModelArchitectureNames.MOBILENET_V1,
134+
'--output-stride=16',
135+
'--source=%s' % EXAMPLE_IMAGE_URL,
136+
'--output=%s' % output_image_path
137+
])

tf_bodypix/bodypix_js_utils/decode_part_map.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
# based on:
22
# https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.4/body-pix/src/decode_part_map.ts
33

4-
import tensorflow as tf
4+
try:
5+
import tensorflow as tf
6+
except ImportError:
7+
tf = None
58

69
import numpy as np
710

811

12+
DEFAULT_DTYPE = (
13+
tf.int32 if tf is not None else np.int32
14+
)
15+
16+
917
def to_mask_tensor(
1018
segment_scores: np.ndarray,
1119
threshold: float,
12-
dtype: type = tf.int32
20+
dtype: type = DEFAULT_DTYPE
1321
) -> np.ndarray:
22+
if tf is None:
23+
return (segment_scores > threshold).astype(dtype)
1424
return tf.cast(
1525
tf.greater(segment_scores, threshold),
1626
dtype

0 commit comments

Comments
 (0)