home-assistant-configuration/config/custom_components/localtuya/pytuya/__init__.py

683 lines
23 KiB
Python

# PyTuya Module
# -*- coding: utf-8 -*-
"""
Python module to interface with Tuya WiFi smart devices.
Mostly derived from Shenzhen Xenon ESP8266MOD WiFi smart devices
E.g. https://wikidevi.com/wiki/Xenon_SM-PW701U
Author: clach04
Maintained by: postlund
For more information see https://github.com/clach04/python-tuya
Functions
connect(address, device_id, local_key, protocol_version, listener=None,
port=6668, timeout=5) # coroutine, returns a connected TuyaProtocol
address (str): Device Network IP Address e.g. 10.0.1.99
device_id (str): Device ID e.g. 01234567891234567890
local_key (str): The encryption key.
protocol_version (float): 3.1 or 3.3
Classes
TuyaProtocol (the methods below are coroutines unless noted)
status() # returns a dict of the device's dps values
detect_available_dps() # returns a dict of the dps provided by the device
update_dps(dps) # sends update dps command
add_dps_to_request(dp_index) # (plain method) adds dp_index to the list of
# dps queried by the device (sent in the status request payload)
set_dp(value, dp_index) # Set value of any dps index.
set_dps(dps) # Set values for several dps at once.
Credits
* TuyaAPI https://github.com/codetheweb/tuyapi by codetheweb and blackrozes
For protocol reverse engineering
* PyTuya https://github.com/clach04/python-tuya by clach04
The origin of this python module (now abandoned)
* LocalTuya https://github.com/rospogrigio/localtuya-homeassistant by rospogrigio
Updated pytuya to support devices with Device IDs of 22 characters
"""
import asyncio
import base64
import binascii
import json
import logging
import struct
import time
import weakref
from abc import ABC, abstractmethod
from collections import namedtuple
from hashlib import md5
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
version_tuple = (9, 0, 0)
version = version_string = __version__ = "%d.%d.%d" % version_tuple
__author__ = "postlund"

_LOGGER = logging.getLogger(__name__)

# One Tuya frame as produced by pack_message() / unpack_message().
TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc")

# Command names, used as keys into PAYLOAD_DICT.
SET = "set"
STATUS = "status"
HEARTBEAT = "heartbeat"
UPDATEDPS = "updatedps"  # Request refresh of DPS

PROTOCOL_VERSION_BYTES_31 = b"3.1"
PROTOCOL_VERSION_BYTES_33 = b"3.3"
# Version string followed by 12 zero bytes (15 bytes in total).
PROTOCOL_33_HEADER = PROTOCOL_VERSION_BYTES_33 + 12 * b"\x00"

MESSAGE_HEADER_FMT = ">4I"  # 4*uint32: prefix, seqno, cmd, length
MESSAGE_RECV_HEADER_FMT = ">5I"  # 5*uint32: prefix, seqno, cmd, length, retcode
MESSAGE_END_FMT = ">2I"  # 2*uint32: crc, suffix
PREFIX_VALUE = 0x000055AA
SUFFIX_VALUE = 0x0000AA55

# Delay between heartbeats, in seconds (passed to asyncio.sleep).
HEARTBEAT_INTERVAL = 10

# DPS that are known to be safe to use with update_dps (0x12) command
UPDATE_DPS_WHITELIST = [18, 19, 20]  # Socket (Wi-Fi)

# This is intended to match requests.json payload at
# https://github.com/codetheweb/tuyapi :
# type_0a devices require the 0a command as the status request
# type_0d devices require the 0d command as the status request, and the list of
# dps used set to null in the request payload (see generate_payload method)

# prefix: # Next byte is command byte ("hexByte") some zero padding, then length
# of remaining payload, i.e. command + suffix (unclear if multiple bytes used for
# length, zero padding implies could be more than one byte)
#
# The "command" dicts are JSON payload templates; the id/timestamp/dps fields
# are filled in per request when a payload is generated.
PAYLOAD_DICT = {
    "type_0a": {
        STATUS: {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
        SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
        HEARTBEAT: {"hexByte": 0x09, "command": {}},
        UPDATEDPS: {"hexByte": 0x12, "command": {"dpId": list(UPDATE_DPS_WHITELIST)}},
    },
    "type_0d": {
        STATUS: {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
        SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
        HEARTBEAT: {"hexByte": 0x09, "command": {}},
        UPDATEDPS: {"hexByte": 0x12, "command": {"dpId": list(UPDATE_DPS_WHITELIST)}},
    },
}
class TuyaLoggingAdapter(logging.LoggerAdapter):
    """Logger adapter that tags every message with a shortened device id."""

    def process(self, msg, kwargs):
        """
        Prefix *msg* with the first and last three characters of the device id.

        The keyword arguments are passed through untouched.
        """
        device = self.extra["device_id"]
        head, tail = device[0:3], device[-3:]
        tagged = f"[{head}...{tail}] {msg}"
        return tagged, kwargs
class ContextualLogger:
    """Mixin providing log methods whose output is tagged with a device id."""

    def __init__(self):
        """Start without a logger; set_logger() must run before any logging."""
        self._logger = None

    def set_logger(self, logger, device_id):
        """Route log calls through *logger*, prefixing them with *device_id*."""
        self._logger = TuyaLoggingAdapter(logger, {"device_id": device_id})

    def __log(self, level, msg, args):
        # Name-mangled so it cannot collide with attributes of subclasses.
        return self._logger.log(level, msg, *args)

    def debug(self, msg, *args):
        """Log *msg* % *args* at DEBUG level."""
        return self.__log(logging.DEBUG, msg, args)

    def info(self, msg, *args):
        """Log *msg* % *args* at INFO level."""
        return self.__log(logging.INFO, msg, args)

    def warning(self, msg, *args):
        """Log *msg* % *args* at WARNING level."""
        return self.__log(logging.WARNING, msg, args)

    def error(self, msg, *args):
        """Log *msg* % *args* at ERROR level."""
        return self.__log(logging.ERROR, msg, args)

    def exception(self, msg, *args):
        """Log *msg* % *args* at ERROR level including the current exception."""
        return self._logger.exception(msg, *args)
def pack_message(msg):
    """
    Serialize a TuyaMessage into a wire frame.

    Layout: header (prefix, seqno, cmd, length), payload, then CRC32 and suffix.
    The retcode and crc fields of *msg* are not used.
    """
    # The length field counts the payload plus the trailing crc/suffix words.
    frame_length = len(msg.payload) + struct.calcsize(MESSAGE_END_FMT)
    header = struct.pack(
        MESSAGE_HEADER_FMT, PREFIX_VALUE, msg.seqno, msg.cmd, frame_length
    )
    body = header + msg.payload
    # The CRC covers the header and payload written so far.
    trailer = struct.pack(MESSAGE_END_FMT, binascii.crc32(body), SUFFIX_VALUE)
    return body + trailer
def unpack_message(data):
    """
    Parse a complete received frame into a TuyaMessage.

    The prefix, length and suffix words are read but not validated.
    """
    head_size = struct.calcsize(MESSAGE_RECV_HEADER_FMT)
    tail_size = struct.calcsize(MESSAGE_END_FMT)
    fields = struct.unpack(MESSAGE_RECV_HEADER_FMT, data[:head_size])
    seqno, cmd, retcode = fields[1], fields[2], fields[4]
    body = data[head_size:-tail_size]
    crc = struct.unpack(MESSAGE_END_FMT, data[-tail_size:])[0]
    return TuyaMessage(seqno, cmd, retcode, body, crc)
class AESCipher:
    """AES (ECB mode) helper used to encrypt and decrypt Tuya payloads."""

    def __init__(self, key):
        """
        Create a cipher for *key*.

        Args:
            key (bytes): AES key.
        """
        self.block_size = 16
        self.cipher = Cipher(algorithms.AES(key), modes.ECB(), default_backend())

    def encrypt(self, raw, use_base64=True):
        """
        Pad and encrypt *raw* (bytes).

        Returns the ciphertext, base64-encoded unless *use_base64* is False.
        """
        worker = self.cipher.encryptor()
        ciphertext = worker.update(self._pad(raw))
        ciphertext += worker.finalize()
        if not use_base64:
            return ciphertext
        return base64.b64encode(ciphertext)

    def decrypt(self, enc, use_base64=True):
        """
        Decrypt *enc*, strip the padding and decode the result to ``str``.

        *enc* is base64-decoded first unless *use_base64* is False.
        """
        ciphertext = base64.b64decode(enc) if use_base64 else enc
        worker = self.cipher.decryptor()
        plaintext = worker.update(ciphertext)
        plaintext += worker.finalize()
        return self._unpad(plaintext).decode()

    def _pad(self, data):
        """Append N bytes of value N so the length is a multiple of block_size."""
        count = self.block_size - len(data) % self.block_size
        filler = chr(count).encode()
        return data + filler * count

    @staticmethod
    def _unpad(data):
        """Drop as many trailing bytes as the value of the last byte."""
        last = data[-1:]
        return data[: -ord(last)]
class MessageDispatcher(ContextualLogger):
    """
    Buffer and dispatcher for Tuya messages.

    Raw bytes from the transport are fed to add_data(), which splits them into
    TuyaMessage frames. Each frame goes to a pending wait_for() call with the
    same sequence number, or is otherwise routed by its command byte.
    """

    # Heartbeats always respond with sequence number 0, so they can't be waited for like
    # other messages. This is a hack to allow waiting for heartbeats.
    HEARTBEAT_SEQNO = -100

    def __init__(self, dev_id, listener):
        """
        Initialize a new MessageDispatcher.

        Args:
            dev_id (str): Device id, used to tag log output.
            listener (callable): Called with each status update (cmd 0x08)
                message that is not claimed by a wait_for() call.
        """
        super().__init__()
        self.buffer = b""
        # seqno -> asyncio.Semaphore while a client waits, then the received
        # TuyaMessage (or None after abort()) until wait_for() pops it.
        self.listeners = {}
        self.listener = listener
        self.set_logger(_LOGGER, dev_id)

    def abort(self):
        """Abort all waiting clients; their wait_for() calls return None."""
        for key in list(self.listeners):
            sem = self.listeners[key]
            self.listeners[key] = None
            # TODO: Received data and semaphore should be stored separately
            if isinstance(sem, asyncio.Semaphore):
                sem.release()

    async def wait_for(self, seqno, timeout=5):
        """
        Wait for response to a sequence number to be received and return it.

        Returns the TuyaMessage, or None if abort() was called meanwhile.

        Raises:
            Exception: if a client is already waiting for *seqno*.
            asyncio.TimeoutError: if nothing arrives within *timeout* seconds.
        """
        if seqno in self.listeners:
            raise Exception(f"listener exists for {seqno}")

        self.debug("Waiting for sequence number %d", seqno)
        sem = asyncio.Semaphore(0)
        self.listeners[seqno] = sem
        try:
            await asyncio.wait_for(sem.acquire(), timeout=timeout)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            # Free the slot on timeout *and* on cancellation. A stale entry would
            # make every later wait on the same seqno fail with "listener exists".
            # That matters most for HEARTBEAT_SEQNO, which every heartbeat reuses.
            self.listeners.pop(seqno, None)
            raise

        return self.listeners.pop(seqno)

    def add_data(self, data):
        """Add new data to the buffer and dispatch every complete message in it."""
        self.buffer += data
        header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT)
        end_len = struct.calcsize(MESSAGE_END_FMT)
        # The length field counts bytes after the 4-word (prefix, seqno, cmd,
        # length) header, i.e. starting where the retcode word would be.
        body_start = header_len - 4

        # Need at least a full receive header before the length can be read.
        while len(self.buffer) >= header_len:
            _, seqno, cmd, length, retcode = struct.unpack_from(
                MESSAGE_RECV_HEADER_FMT, self.buffer
            )
            # Wait for more data if the frame is not complete yet.
            if len(self.buffer) - body_start < length:
                break

            # length includes payload length, retcode, crc and suffix
            if (retcode & 0xFFFFFF00) != 0:
                # Not a plausible retcode: the frame has none, so the fifth
                # header word is already part of the payload.
                payload_start = body_start
                payload_length = length - end_len
            else:
                payload_start = header_len
                payload_length = length - 4 - end_len
            payload = self.buffer[payload_start : payload_start + payload_length]
            crc, _ = struct.unpack_from(
                MESSAGE_END_FMT,
                self.buffer[payload_start + payload_length : payload_start + length],
            )
            self.buffer = self.buffer[body_start + length :]
            self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc))

    def _resolve(self, seqno, msg):
        """Store *msg* for the client waiting on *seqno* and wake it up."""
        sem = self.listeners[seqno]
        if not isinstance(sem, asyncio.Semaphore):
            # The slot already holds a message (duplicate reply) or None (aborted).
            # There is nobody left to wake, and calling release() would raise
            # AttributeError and stop parsing the rest of the buffer.
            self.debug("Dropping message for sequence number %d: no waiter", seqno)
            return
        self.listeners[seqno] = msg
        sem.release()

    def _dispatch(self, msg):
        """Dispatch a message to someone that is listening."""
        self.debug("Dispatching message %s", msg)
        if msg.seqno in self.listeners:
            self.debug("Dispatching sequence number %d", msg.seqno)
            self._resolve(msg.seqno, msg)
        elif msg.cmd == 0x09:
            self.debug("Got heartbeat response")
            if self.HEARTBEAT_SEQNO in self.listeners:
                self._resolve(self.HEARTBEAT_SEQNO, msg)
        elif msg.cmd == 0x12:
            self.debug("Got normal updatedps response")
        elif msg.cmd == 0x08:
            self.debug("Got status update")
            self.listener(msg)
        else:
            self.debug(
                "Got message type %d for unknown listener %d: %s",
                msg.cmd,
                msg.seqno,
                msg,
            )
class TuyaListener(ABC):
    """Abstract callback interface notified about Tuya device events."""

    @abstractmethod
    def status_updated(self, status):
        """Called with the device status whenever it changes."""

    @abstractmethod
    def disconnected(self):
        """Called when the connection to the device is lost."""
class EmptyListener(TuyaListener):
    """No-op TuyaListener, used by connect() when no listener is given."""

    def status_updated(self, status):
        """Ignore the status update."""
        return None

    def disconnected(self):
        """Ignore the disconnect."""
        return None
class TuyaProtocol(asyncio.Protocol, ContextualLogger):
    """Implementation of the Tuya protocol."""

    def __init__(self, dev_id, local_key, protocol_version, on_connected, listener):
        """
        Initialize a new TuyaProtocol.

        Must be created while an event loop is running, because it calls
        asyncio.get_running_loop().

        Args:
            dev_id (str): The device id.
            local_key (str): The encryption key (encoded as latin1 for AES).
            protocol_version (float): Protocol version; 3.3 selects the 3.3
                framing/encryption, anything else the 3.1 handling.
            on_connected (asyncio.Future): Its result is set to True in
                connection_made().
            listener (TuyaListener): Receives status updates and disconnect
                notifications. Only a weak reference is kept.
        """
        super().__init__()
        self.loop = asyncio.get_running_loop()
        self.set_logger(_LOGGER, dev_id)
        self.id = dev_id
        self.local_key = local_key.encode("latin1")
        self.version = protocol_version
        # Starts as type_0a; _decode_payload() switches to type_0d when the
        # device answers "data unvalid".
        self.dev_type = "type_0a"
        self.dps_to_request = {}
        self.cipher = AESCipher(self.local_key)
        self.seqno = 0
        self.transport = None
        self.listener = weakref.ref(listener)
        self.dispatcher = self._setup_dispatcher()
        self.on_connected = on_connected
        self.heartbeater = None
        self.dps_cache = {}

    def _setup_dispatcher(self):
        """Create the dispatcher that merges unsolicited status updates into dps_cache."""

        def _status_update(msg):
            decoded_message = self._decode_payload(msg.payload)
            # _decode_payload() returns None on a "data unvalid" reply; there
            # is nothing to merge in that case.
            if decoded_message and "dps" in decoded_message:
                self.dps_cache.update(decoded_message["dps"])

            listener = self.listener and self.listener()
            if listener is not None:
                listener.status_updated(self.dps_cache)

        return MessageDispatcher(self.id, _status_update)

    def connection_made(self, transport):
        """Did connect to the device: store transport and start heartbeats."""

        async def heartbeat_loop():
            """Continuously send heart beat updates."""
            self.debug("Started heartbeat loop")
            while True:
                try:
                    await self.heartbeat()
                    await asyncio.sleep(HEARTBEAT_INTERVAL)
                except asyncio.CancelledError:
                    self.debug("Stopped heartbeat loop")
                    raise
                except asyncio.TimeoutError:
                    self.debug("Heartbeat failed due to timeout, disconnecting")
                    break
                except Exception as ex:  # pylint: disable=broad-except
                    self.exception("Heartbeat failed (%s), disconnecting", ex)
                    break

            transport = self.transport
            self.transport = None
            transport.close()

        self.transport = transport
        self.on_connected.set_result(True)
        self.heartbeater = self.loop.create_task(heartbeat_loop())

    def data_received(self, data):
        """Received data from device."""
        self.dispatcher.add_data(data)

    def connection_lost(self, exc):
        """Disconnected from device: notify the listener, if it is still alive."""
        self.debug("Connection lost: %s", exc)
        try:
            listener = self.listener and self.listener()
            if listener is not None:
                listener.disconnected()
        except Exception:  # pylint: disable=broad-except
            self.exception("Failed to call disconnected callback")

    async def close(self):
        """Close connection and abort all outstanding listeners."""
        self.debug("Closing connection")
        if self.heartbeater is not None:
            self.heartbeater.cancel()
            try:
                await self.heartbeater
            except asyncio.CancelledError:
                pass
            self.heartbeater = None
        if self.dispatcher is not None:
            self.dispatcher.abort()
            self.dispatcher = None
        if self.transport is not None:
            transport = self.transport
            self.transport = None
            transport.close()

    async def exchange(self, command, dps=None):
        """
        Send *command* and return the decoded reply from the device.

        Returns None if the wait was aborted (see close()) or the device
        answered "data unvalid". If the reply switched dev_type, the command
        is re-sent once using the new payload format.

        Raises:
            asyncio.TimeoutError: if no reply arrives in time.
        """
        self.debug(
            "Sending command %s (device type: %s)",
            command,
            self.dev_type,
        )
        payload = self._generate_payload(command, dps)
        dev_type = self.dev_type

        # Wait for special sequence number if heartbeat
        seqno = (
            MessageDispatcher.HEARTBEAT_SEQNO
            if command == HEARTBEAT
            else (self.seqno - 1)
        )

        self.transport.write(payload)
        msg = await self.dispatcher.wait_for(seqno)
        if msg is None:
            self.debug("Wait was aborted for seqno %d", seqno)
            return None

        # TODO: Verify stuff, e.g. CRC sequence number?
        payload = self._decode_payload(msg.payload)

        # Perform a new exchange (once) if we switched device type
        if dev_type != self.dev_type:
            self.debug(
                "Re-send %s due to device type change (%s -> %s)",
                command,
                dev_type,
                self.dev_type,
            )
            return await self.exchange(command, dps)
        return payload

    async def status(self):
        """Query the device and return dps_cache merged with the reported dps."""
        status = await self.exchange(STATUS)
        if status and "dps" in status:
            self.dps_cache.update(status["dps"])
        return self.dps_cache

    async def heartbeat(self):
        """Send a heartbeat message."""
        return await self.exchange(HEARTBEAT)

    async def update_dps(self, dps=None):
        """
        Request device to update index (only sent for protocol 3.3).

        Args:
            dps([int]): list of dps to update, default=detected&whitelisted

        Returns:
            True (the request is not awaited for a reply).
        """
        if self.version == 3.3:
            if dps is None:
                if not self.dps_cache:
                    await self.detect_available_dps()
                if self.dps_cache:
                    dps = [int(dp) for dp in self.dps_cache]
                    # filter non whitelisted dps
                    dps = list(set(dps).intersection(set(UPDATE_DPS_WHITELIST)))
            self.debug("updatedps() entry (dps %s, dps_cache %s)", dps, self.dps_cache)
            payload = self._generate_payload(UPDATEDPS, dps)
            self.transport.write(payload)
        return True

    async def set_dp(self, value, dp_index):
        """
        Set value (may be any type: bool, int or string) of any dps index.

        Args:
            dp_index(int): dps index to set
            value: new value for the dps index
        """
        return await self.exchange(SET, {str(dp_index): value})

    async def set_dps(self, dps):
        """Set values for a set of datapoints."""
        return await self.exchange(SET, dps)

    async def detect_available_dps(self):
        """Return which datapoints are supported by the device."""
        # type_0d devices need a sort of bruteforce querying in order to detect the
        # list of available dps experience shows that the dps available are usually
        # in the ranges [1-25] and [100-110] need to split the bruteforcing in
        # different steps due to request payload limitation (max. length = 255)
        self.dps_cache = {}
        ranges = [(2, 11), (11, 21), (21, 31), (100, 111)]

        for dps_range in ranges:
            # dps 1 must always be sent, otherwise it might fail in case no dps is found
            # in the requested range
            self.dps_to_request = {"1": None}
            self.add_dps_to_request(range(*dps_range))
            try:
                # status() merges the reported dps into self.dps_cache itself.
                await self.status()
            except Exception as ex:
                self.exception("Failed to get status: %s", ex)
                raise

            # Only type_0d devices need the range-by-range probing above.
            if self.dev_type == "type_0a":
                return self.dps_cache

        self.debug("Detected dps: %s", self.dps_cache)
        return self.dps_cache

    def add_dps_to_request(self, dp_indicies):
        """Add a datapoint (DP) index, or an iterable of them, to status requests."""
        if isinstance(dp_indicies, int):
            self.dps_to_request[str(dp_indicies)] = None
        else:
            self.dps_to_request.update({str(index): None for index in dp_indicies})

    def _decode_payload(self, payload):
        """
        Decrypt (if needed) and JSON-decode a message payload.

        Returns None, after switching dev_type to type_0d, when a 3.3 device
        answers "data unvalid". Raises Exception for unrecognised payloads.
        """
        if not payload:
            payload = "{}"
        elif payload.startswith(b"{"):
            pass
        elif payload.startswith(PROTOCOL_VERSION_BYTES_31):
            payload = payload[len(PROTOCOL_VERSION_BYTES_31) :]  # remove version header
            # remove (what I'm guessing, but not confirmed is) 16-bytes of MD5
            # hexdigest of payload
            payload = self.cipher.decrypt(payload[16:])
        elif self.version == 3.3:
            if self.dev_type != "type_0a" or payload.startswith(
                PROTOCOL_VERSION_BYTES_33
            ):
                payload = payload[len(PROTOCOL_33_HEADER) :]
            payload = self.cipher.decrypt(payload, False)

            if "data unvalid" in payload:
                self.dev_type = "type_0d"
                self.debug(
                    "switching to dev_type %s",
                    self.dev_type,
                )
                return None
        else:
            raise Exception(f"Unexpected payload={payload}")

        if not isinstance(payload, str):
            payload = payload.decode()
        self.debug("Decrypted payload: %s", payload)
        return json.loads(payload)

    def _generate_payload(self, command, data=None):
        """
        Generate the payload to send.

        Args:
            command(str): The type of command.
                This is one of the entries from payload_dict
            data(dict, optional): The data to be send.
                This is what will be passed via the 'dps' entry
        """
        cmd_data = PAYLOAD_DICT[self.dev_type][command]
        # Work on a copy: filling the shared module-level template in place
        # would leak this request's dps/dpId, ids and timestamp into later
        # requests, including those of other devices.
        json_data = dict(cmd_data["command"])
        command_hb = cmd_data["hexByte"]

        if "gwId" in json_data:
            json_data["gwId"] = self.id
        if "devId" in json_data:
            json_data["devId"] = self.id
        if "uid" in json_data:
            json_data["uid"] = self.id  # still use id, no separate uid
        if "t" in json_data:
            json_data["t"] = str(int(time.time()))

        if data is not None:
            if "dpId" in json_data:
                json_data["dpId"] = data
            else:
                json_data["dps"] = data
        elif command_hb == 0x0D:
            json_data["dps"] = self.dps_to_request

        # Compact JSON without whitespace between tokens. Unlike stripping every
        # space, this keeps spaces inside string values intact.
        payload = json.dumps(json_data, separators=(",", ":")).encode("utf-8")
        self.debug("Send payload: %s", payload)

        if self.version == 3.3:
            payload = self.cipher.encrypt(payload, False)
            if command_hb not in [0x0A, 0x12]:
                # add the 3.3 header
                payload = PROTOCOL_33_HEADER + payload
        elif command == SET:
            payload = self.cipher.encrypt(payload)
            to_hash = (
                b"data="
                + payload
                + b"||lpv="
                + PROTOCOL_VERSION_BYTES_31
                + b"||"
                + self.local_key
            )
            hasher = md5()
            hasher.update(to_hash)
            hexdigest = hasher.hexdigest()
            payload = (
                PROTOCOL_VERSION_BYTES_31
                + hexdigest[8:][:16].encode("latin1")
                + payload
            )

        msg = TuyaMessage(self.seqno, command_hb, 0, payload, 0)
        self.seqno += 1
        return pack_message(msg)

    def __repr__(self):
        """Return internal string representation of object."""
        return self.id
async def connect(
    address,
    device_id,
    local_key,
    protocol_version,
    listener=None,
    port=6668,
    timeout=5,
):
    """
    Connect to a device and return its TuyaProtocol.

    Args:
        address (str): Device network address.
        device_id (str): Device ID.
        local_key (str): The encryption key.
        protocol_version (float): Protocol version, e.g. 3.1 or 3.3.
        listener (TuyaListener, optional): Receives status updates and
            disconnect notifications. The protocol only keeps a weak
            reference, so the caller must keep it alive. Defaults to a no-op
            listener.
        port (int): TCP port. Defaults to 6668.
        timeout (float): Seconds allowed for opening the TCP connection and,
            separately, for the protocol to report that it is connected.

    Raises:
        asyncio.TimeoutError: if either step takes longer than *timeout*.
    """
    loop = asyncio.get_running_loop()
    on_connected = loop.create_future()

    # Bound the TCP connect itself. Otherwise an unreachable device blocks for
    # the OS connect timeout instead of *timeout* seconds.
    _, protocol = await asyncio.wait_for(
        loop.create_connection(
            lambda: TuyaProtocol(
                device_id,
                local_key,
                protocol_version,
                on_connected,
                listener or EmptyListener(),
            ),
            address,
            port,
        ),
        timeout=timeout,
    )

    try:
        await asyncio.wait_for(on_connected, timeout=timeout)
    except asyncio.TimeoutError:
        # Don't leak the open transport and the heartbeat task.
        await protocol.close()
        raise
    return protocol