347 lines
13 KiB
Cython
347 lines
13 KiB
Cython
|
# """
|
||
|
# Component that performs TensorFlow classification on images.
|
||
|
|
||
|
# For a quick start, pick a pre-trained COCO model from:
|
||
|
# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
|
||
|
|
||
|
# For more details about this platform, please refer to the documentation at
|
||
|
# https://home-assistant.io/components/image_processing.tensorflow/
|
||
|
# """
|
||
|
# import logging
|
||
|
# import sys
|
||
|
# import os
|
||
|
|
||
|
# import voluptuous as vol
|
||
|
|
||
|
# from homeassistant.components.image_processing import (
|
||
|
# CONF_CONFIDENCE, CONF_ENTITY_ID, CONF_NAME, CONF_SOURCE, PLATFORM_SCHEMA,
|
||
|
# ImageProcessingEntity)
|
||
|
# from homeassistant.core import split_entity_id
|
||
|
# from homeassistant.helpers import template
|
||
|
# import homeassistant.helpers.config_validation as cv
|
||
|
|
||
|
# REQUIREMENTS = ['numpy==1.15.3', 'pillow==5.2.0', 'protobuf==3.6.1']
|
||
|
|
||
|
# _LOGGER = logging.getLogger(__name__)
|
||
|
|
||
|
# ATTR_MATCHES = 'matches'
|
||
|
# ATTR_SUMMARY = 'summary'
|
||
|
# ATTR_TOTAL_MATCHES = 'total_matches'
|
||
|
|
||
|
# CONF_FILE_OUT = 'file_out'
|
||
|
# CONF_MODEL = 'model'
|
||
|
# CONF_GRAPH = 'graph'
|
||
|
# CONF_LABELS = 'labels'
|
||
|
# CONF_MODEL_DIR = 'model_dir'
|
||
|
# CONF_CATEGORIES = 'categories'
|
||
|
# CONF_CATEGORY = 'category'
|
||
|
# CONF_AREA = 'area'
|
||
|
# CONF_TOP = 'top'
|
||
|
# CONF_LEFT = 'left'
|
||
|
# CONF_BOTTOM = 'bottom'
|
||
|
# CONF_RIGHT = 'right'
|
||
|
|
||
|
# AREA_SCHEMA = vol.Schema({
|
||
|
# vol.Optional(CONF_TOP, default=0): cv.small_float,
|
||
|
# vol.Optional(CONF_LEFT, default=0): cv.small_float,
|
||
|
# vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
|
||
|
# vol.Optional(CONF_RIGHT, default=1): cv.small_float
|
||
|
# })
|
||
|
|
||
|
# CATEGORY_SCHEMA = vol.Schema({
|
||
|
# vol.Required(CONF_CATEGORY): cv.string,
|
||
|
# vol.Optional(CONF_AREA): AREA_SCHEMA
|
||
|
# })
|
||
|
|
||
|
# PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
||
|
# vol.Optional(CONF_FILE_OUT, default=[]):
|
||
|
# vol.All(cv.ensure_list, [cv.template]),
|
||
|
# vol.Required(CONF_MODEL): vol.Schema({
|
||
|
# vol.Required(CONF_GRAPH): cv.isfile,
|
||
|
# vol.Optional(CONF_LABELS): cv.isfile,
|
||
|
# vol.Optional(CONF_MODEL_DIR): cv.isdir,
|
||
|
# vol.Optional(CONF_AREA): AREA_SCHEMA,
|
||
|
# vol.Optional(CONF_CATEGORIES, default=[]):
|
||
|
# vol.All(cv.ensure_list, [vol.Any(
|
||
|
# cv.string,
|
||
|
# CATEGORY_SCHEMA
|
||
|
# )])
|
||
|
# })
|
||
|
# })
|
||
|
|
||
|
|
||
|
# def draw_box(draw, box, img_width,
|
||
|
# img_height, text='', color=(255, 255, 0)):
|
||
|
# """Draw bounding box on image."""
|
||
|
# ymin, xmin, ymax, xmax = box
|
||
|
# (left, right, top, bottom) = (xmin * img_width, xmax * img_width,
|
||
|
# ymin * img_height, ymax * img_height)
|
||
|
# draw.line([(left, top), (left, bottom), (right, bottom),
|
||
|
# (right, top), (left, top)], width=5, fill=color)
|
||
|
# if text:
|
||
|
# draw.text((left, abs(top-15)), text, fill=color)
|
||
|
|
||
|
|
||
|
# def setup_platform(hass, config, add_entities, discovery_info=None):
|
||
|
# """Set up the TensorFlow image processing platform."""
|
||
|
# model_config = config.get(CONF_MODEL)
|
||
|
# model_dir = model_config.get(CONF_MODEL_DIR) \
|
||
|
# or hass.config.path('tensorflow')
|
||
|
# labels = model_config.get(CONF_LABELS) \
|
||
|
# or hass.config.path('tensorflow', 'object_detection',
|
||
|
# 'data', 'mscoco_label_map.pbtxt')
|
||
|
|
||
|
# # Make sure locations exist
|
||
|
# if not os.path.isdir(model_dir) or not os.path.exists(labels):
|
||
|
# _LOGGER.error("Unable to locate tensorflow models or label map.")
|
||
|
# return
|
||
|
|
||
|
# # append custom model path to sys.path
|
||
|
# sys.path.append(model_dir)
|
||
|
|
||
|
# try:
|
||
|
# # Verify that the TensorFlow Object Detection API is pre-installed
|
||
|
# # pylint: disable=unused-import,unused-variable
|
||
|
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||
|
# import tensorflow as tf # noqa
|
||
|
# from object_detection.utils import label_map_util # noqa
|
||
|
# except ImportError:
|
||
|
# # pylint: disable=line-too-long
|
||
|
# _LOGGER.error(
|
||
|
# "No TensorFlow Object Detection library found! Install or compile "
|
||
|
# "for your system following instructions here: "
|
||
|
# "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md") # noqa
|
||
|
# return
|
||
|
|
||
|
# try:
|
||
|
# # Display warning that PIL will be used if no OpenCV is found.
|
||
|
# # pylint: disable=unused-import,unused-variable
|
||
|
# import cv2 # noqa
|
||
|
# except ImportError:
|
||
|
# _LOGGER.warning("No OpenCV library found. "
|
||
|
# "TensorFlow will process image with "
|
||
|
# "PIL at reduced resolution.")
|
||
|
|
||
|
# # setup tensorflow graph, session, and label map to pass to processor
|
||
|
# # pylint: disable=no-member
|
||
|
# detection_graph = tf.Graph()
|
||
|
# with detection_graph.as_default():
|
||
|
# od_graph_def = tf.GraphDef()
|
||
|
# with tf.gfile.GFile(model_config.get(CONF_GRAPH), 'rb') as fid:
|
||
|
# serialized_graph = fid.read()
|
||
|
# od_graph_def.ParseFromString(serialized_graph)
|
||
|
# tf.import_graph_def(od_graph_def, name='')
|
||
|
|
||
|
# session = tf.Session(graph=detection_graph)
|
||
|
# label_map = label_map_util.load_labelmap(labels)
|
||
|
# categories = label_map_util.convert_label_map_to_categories(
|
||
|
# label_map, max_num_classes=90, use_display_name=True)
|
||
|
# category_index = label_map_util.create_category_index(categories)
|
||
|
|
||
|
# entities = []
|
||
|
|
||
|
# for camera in config[CONF_SOURCE]:
|
||
|
# entities.append(TensorFlowImageProcessor(
|
||
|
# hass, camera[CONF_ENTITY_ID], camera.get(CONF_NAME),
|
||
|
# session, detection_graph, category_index, config))
|
||
|
|
||
|
# add_entities(entities)
|
||
|
|
||
|
|
||
|
# class TensorFlowImageProcessor(ImageProcessingEntity):
|
||
|
# """Representation of an TensorFlow image processor."""
|
||
|
|
||
|
# def __init__(self, hass, camera_entity, name, session, detection_graph,
|
||
|
# category_index, config):
|
||
|
# """Initialize the TensorFlow entity."""
|
||
|
# model_config = config.get(CONF_MODEL)
|
||
|
# self.hass = hass
|
||
|
# self._camera_entity = camera_entity
|
||
|
# if name:
|
||
|
# self._name = name
|
||
|
# else:
|
||
|
# self._name = "TensorFlow {0}".format(
|
||
|
# split_entity_id(camera_entity)[1])
|
||
|
# self._session = session
|
||
|
# self._graph = detection_graph
|
||
|
# self._category_index = category_index
|
||
|
# self._min_confidence = config.get(CONF_CONFIDENCE)
|
||
|
# self._file_out = config.get(CONF_FILE_OUT)
|
||
|
|
||
|
# # handle categories and specific detection areas
|
||
|
# categories = model_config.get(CONF_CATEGORIES)
|
||
|
# self._include_categories = []
|
||
|
# self._category_areas = {}
|
||
|
# for category in categories:
|
||
|
# if isinstance(category, dict):
|
||
|
# category_name = category.get(CONF_CATEGORY)
|
||
|
# category_area = category.get(CONF_AREA)
|
||
|
# self._include_categories.append(category_name)
|
||
|
# self._category_areas[category_name] = [0, 0, 1, 1]
|
||
|
# if category_area:
|
||
|
# self._category_areas[category_name] = [
|
||
|
# category_area.get(CONF_TOP),
|
||
|
# category_area.get(CONF_LEFT),
|
||
|
# category_area.get(CONF_BOTTOM),
|
||
|
# category_area.get(CONF_RIGHT)
|
||
|
# ]
|
||
|
# else:
|
||
|
# self._include_categories.append(category)
|
||
|
# self._category_areas[category] = [0, 0, 1, 1]
|
||
|
|
||
|
# # Handle global detection area
|
||
|
# self._area = [0, 0, 1, 1]
|
||
|
# area_config = model_config.get(CONF_AREA)
|
||
|
# if area_config:
|
||
|
# self._area = [
|
||
|
# area_config.get(CONF_TOP),
|
||
|
# area_config.get(CONF_LEFT),
|
||
|
# area_config.get(CONF_BOTTOM),
|
||
|
# area_config.get(CONF_RIGHT)
|
||
|
# ]
|
||
|
|
||
|
# template.attach(hass, self._file_out)
|
||
|
|
||
|
# self._matches = {}
|
||
|
# self._total_matches = 0
|
||
|
# self._last_image = None
|
||
|
|
||
|
# @property
|
||
|
# def camera_entity(self):
|
||
|
# """Return camera entity id from process pictures."""
|
||
|
# return self._camera_entity
|
||
|
|
||
|
# @property
|
||
|
# def name(self):
|
||
|
# """Return the name of the image processor."""
|
||
|
# return self._name
|
||
|
|
||
|
# @property
|
||
|
# def state(self):
|
||
|
# """Return the state of the entity."""
|
||
|
# return self._total_matches
|
||
|
|
||
|
# @property
|
||
|
# def device_state_attributes(self):
|
||
|
# """Return device specific state attributes."""
|
||
|
# return {
|
||
|
# ATTR_MATCHES: self._matches,
|
||
|
# ATTR_SUMMARY: {category: len(values)
|
||
|
# for category, values in self._matches.items()},
|
||
|
# ATTR_TOTAL_MATCHES: self._total_matches
|
||
|
# }
|
||
|
|
||
|
# def _save_image(self, image, matches, paths):
|
||
|
# from PIL import Image, ImageDraw
|
||
|
# import io
|
||
|
# img = Image.open(io.BytesIO(bytearray(image))).convert('RGB')
|
||
|
# img_width, img_height = img.size
|
||
|
# draw = ImageDraw.Draw(img)
|
||
|
|
||
|
# # Draw custom global region/area
|
||
|
# if self._area != [0, 0, 1, 1]:
|
||
|
# draw_box(draw, self._area,
|
||
|
# img_width, img_height,
|
||
|
# "Detection Area", (0, 255, 255))
|
||
|
|
||
|
# for category, values in matches.items():
|
||
|
# # Draw custom category regions/areas
|
||
|
# if (category in self._category_areas
|
||
|
# and self._category_areas[category] != [0, 0, 1, 1]):
|
||
|
# label = "{} Detection Area".format(category.capitalize())
|
||
|
# draw_box(draw, self._category_areas[category], img_width,
|
||
|
# img_height, label, (0, 255, 0))
|
||
|
|
||
|
# # Draw detected objects
|
||
|
# for instance in values:
|
||
|
# label = "{0} {1:.1f}%".format(category, instance['score'])
|
||
|
# draw_box(draw, instance['box'],
|
||
|
# img_width, img_height,
|
||
|
# label, (255, 255, 0))
|
||
|
|
||
|
# for path in paths:
|
||
|
# _LOGGER.info("Saving results image to %s", path)
|
||
|
# img.save(path)
|
||
|
|
||
|
# def process_image(self, image):
|
||
|
# """Process the image."""
|
||
|
# import numpy as np
|
||
|
|
||
|
# try:
|
||
|
# import cv2 # pylint: disable=import-error
|
||
|
# img = cv2.imdecode(
|
||
|
# np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
|
||
|
# inp = img[:, :, [2, 1, 0]] # BGR->RGB
|
||
|
# inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
|
||
|
# except ImportError:
|
||
|
# from PIL import Image
|
||
|
# import io
|
||
|
# img = Image.open(io.BytesIO(bytearray(image))).convert('RGB')
|
||
|
# img.thumbnail((460, 460), Image.ANTIALIAS)
|
||
|
# img_width, img_height = img.size
|
||
|
# inp = np.array(img.getdata()).reshape(
|
||
|
# (img_height, img_width, 3)).astype(np.uint8)
|
||
|
# inp_expanded = np.expand_dims(inp, axis=0)
|
||
|
|
||
|
# image_tensor = self._graph.get_tensor_by_name('image_tensor:0')
|
||
|
# boxes = self._graph.get_tensor_by_name('detection_boxes:0')
|
||
|
# scores = self._graph.get_tensor_by_name('detection_scores:0')
|
||
|
# classes = self._graph.get_tensor_by_name('detection_classes:0')
|
||
|
# boxes, scores, classes = self._session.run(
|
||
|
# [boxes, scores, classes],
|
||
|
# feed_dict={image_tensor: inp_expanded})
|
||
|
# boxes, scores, classes = map(np.squeeze, [boxes, scores, classes])
|
||
|
# classes = classes.astype(int)
|
||
|
|
||
|
# matches = {}
|
||
|
# total_matches = 0
|
||
|
# for box, score, obj_class in zip(boxes, scores, classes):
|
||
|
# score = score * 100
|
||
|
# boxes = box.tolist()
|
||
|
|
||
|
# # Exclude matches below min confidence value
|
||
|
# if score < self._min_confidence:
|
||
|
# continue
|
||
|
|
||
|
# # Exclude matches outside global area definition
|
||
|
# if (boxes[0] < self._area[0] or boxes[1] < self._area[1]
|
||
|
# or boxes[2] > self._area[2] or boxes[3] > self._area[3]):
|
||
|
# continue
|
||
|
|
||
|
# category = self._category_index[obj_class]['name']
|
||
|
|
||
|
# # Exclude unlisted categories
|
||
|
# if (self._include_categories
|
||
|
# and category not in self._include_categories):
|
||
|
# continue
|
||
|
|
||
|
# # Exclude matches outside category specific area definition
|
||
|
# if (self._category_areas
|
||
|
# and (boxes[0] < self._category_areas[category][0]
|
||
|
# or boxes[1] < self._category_areas[category][1]
|
||
|
# or boxes[2] > self._category_areas[category][2]
|
||
|
# or boxes[3] > self._category_areas[category][3])):
|
||
|
# continue
|
||
|
|
||
|
# # If we got here, we should include it
|
||
|
# if category not in matches.keys():
|
||
|
# matches[category] = []
|
||
|
# matches[category].append({
|
||
|
# 'score': float(score),
|
||
|
# 'box': boxes
|
||
|
# })
|
||
|
# total_matches += 1
|
||
|
|
||
|
# # Save Images
|
||
|
# if total_matches and self._file_out:
|
||
|
# paths = []
|
||
|
# for path_template in self._file_out:
|
||
|
# if isinstance(path_template, template.Template):
|
||
|
# paths.append(path_template.render(
|
||
|
# camera_entity=self._camera_entity))
|
||
|
# else:
|
||
|
# paths.append(path_template)
|
||
|
# self._save_image(image, matches, paths)
|
||
|
|
||
|
# self._matches = matches
|
||
|
# self._total_matches = total_matches
|