]> www.average.org Git - loctrkd.git/blobdiff - gps303/wsgateway.py
typeckecking: annotate wsgateway.py
[loctrkd.git] / gps303 / wsgateway.py
index d3c0541158941dedd017aba669e40042fd11f928..c0a47768254bff754a2425ea1cd40ae597002a06 100644 (file)
@@ -1,14 +1,17 @@
 """ Websocket Gateway """
 
+from configparser import ConfigParser
 from datetime import datetime, timezone
 from json import dumps, loads
 from logging import getLogger
 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
 from time import time
+from typing import Any, cast, Dict, List, Optional, Set, Tuple
 from wsproto import ConnectionType, WSConnection
 from wsproto.events import (
     AcceptConnection,
     CloseConnection,
+    Event,
     Message,
     Ping,
     Request,
@@ -21,6 +24,7 @@ from . import common
 from .evstore import initdb, fetch
 from .gps303proto import (
     GPS_POSITIONING,
+    STATUS,
     WIFI_POSITIONING,
     parse_message,
 )
@@ -30,16 +34,17 @@ log = getLogger("gps303/wsgateway")
 htmlfile = None
 
 
-def backlog(imei, numback):
+def backlog(imei: str, numback: int) -> List[Dict[str, Any]]:
     result = []
     for is_incoming, timestamp, packet in fetch(
         imei,
-        ((True, GPS_POSITIONING.PROTO), (False, WIFI_POSITIONING.PROTO)),
+        [(True, GPS_POSITIONING.PROTO), (False, WIFI_POSITIONING.PROTO)],
         numback,
     ):
         msg = parse_message(packet, is_incoming=is_incoming)
         result.append(
             {
+                "type": "location",
                 "imei": imei,
                 "timestamp": str(
                     datetime.fromtimestamp(timestamp).astimezone(
@@ -48,12 +53,15 @@ def backlog(imei, numback):
                 ),
                 "longitude": msg.longitude,
                 "latitude": msg.latitude,
+                "accuracy": "gps"
+                if isinstance(msg, GPS_POSITIONING)
+                else "approximate",
             }
         )
     return result
 
 
-def try_http(data, fd, e):
+def try_http(data: bytes, fd: int, e: Exception) -> bytes:
     global htmlfile
     try:
         lines = data.decode().split("\r\n")
@@ -68,19 +76,14 @@ def try_http(data, fd, e):
             fd,
             headers,
         )
-        try:
-            pos = resource.index("?")
-            resource = resource[:pos]
-        except ValueError:
-            pass
         if op == "GET":
             if htmlfile is None:
                 return (
                     f"{proto} 500 No data configured\r\n"
                     f"Content-Type: text/plain\r\n\r\n"
-                    f"HTML data not configure on the server\r\n".encode()
+                    f"HTML data not configured on the server\r\n".encode()
                 )
-            elif resource == "/":
+            else:
                 try:
                     with open(htmlfile, "rb") as fl:
                         htmldata = fl.read()
@@ -96,12 +99,6 @@ def try_http(data, fd, e):
                         f"Content-Type: text/plain\r\n\r\n"
                         f"HTML file could not be opened\r\n".encode()
                     )
-            else:
-                return (
-                    f"{proto} 404 File not found\r\n"
-                    f"Content-Type: text/plain\r\n\r\n"
-                    f'We can only serve "/"\r\n'.encode()
-                )
         else:
             return (
                 f"{proto} 400 Bad request\r\n"
@@ -116,22 +113,22 @@ def try_http(data, fd, e):
 class Client:
     """Websocket connection to the client"""
 
-    def __init__(self, sock, addr):
+    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
         self.sock = sock
         self.addr = addr
         self.ws = WSConnection(ConnectionType.SERVER)
         self.ws_data = b""
         self.ready = False
-        self.imeis = set()
+        self.imeis: Set[str] = set()
 
-    def close(self):
+    def close(self) -> None:
         log.debug("Closing fd %d", self.sock.fileno())
         self.sock.close()
 
-    def recv(self):
+    def recv(self) -> Optional[List[Dict[str, Any]]]:
         try:
             data = self.sock.recv(4096)
-        except OSError:
+        except OSError as e:
             log.warning(
                 "Reading from fd %d: %s",
                 self.sock.fileno(),
@@ -184,7 +181,7 @@ class Client:
                     log.warning("%s on fd %d", event, self.sock.fileno())
         return msgs
 
-    def wants(self, imei):
+    def wants(self, imei: str) -> bool:
         log.debug(
             "wants %s? set is %s on fd %d",
             imei,
@@ -193,11 +190,11 @@ class Client:
         )
         return imei in self.imeis
 
-    def send(self, message):
+    def send(self, message: Dict[str, Any]) -> None:
         if self.ready and message["imei"] in self.imeis:
             self.ws_data += self.ws.send(Message(data=dumps(message)))
 
-    def write(self):
+    def write(self) -> bool:
         if self.ws_data:
             try:
                 sent = self.sock.send(self.ws_data)
@@ -213,26 +210,26 @@ class Client:
 
 
 class Clients:
-    def __init__(self):
-        self.by_fd = {}
+    def __init__(self) -> None:
+        self.by_fd: Dict[int, Client] = {}
 
-    def add(self, clntsock, clntaddr):
+    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
         fd = clntsock.fileno()
         log.info("Start serving fd %d from %s", fd, clntaddr)
         self.by_fd[fd] = Client(clntsock, clntaddr)
         return fd
 
-    def stop(self, fd):
+    def stop(self, fd: int) -> None:
         clnt = self.by_fd[fd]
         log.info("Stop serving fd %d", clnt.sock.fileno())
         clnt.close()
         del self.by_fd[fd]
 
-    def recv(self, fd):
+    def recv(self, fd: int) -> Optional[List[Dict[str, Any]]]:
         clnt = self.by_fd[fd]
         return clnt.recv()
 
-    def send(self, msg):
+    def send(self, msg: Dict[str, Any]) -> Set[int]:
         towrite = set()
         for fd, clnt in self.by_fd.items():
             if clnt.wants(msg["imei"]):
@@ -240,27 +237,28 @@ class Clients:
                 towrite.add(fd)
         return towrite
 
-    def write(self, towrite):
+    def write(self, towrite: Set[int]) -> Set[int]:
         waiting = set()
         for fd, clnt in [(fd, self.by_fd.get(fd)) for fd in towrite]:
-            if clnt.write():
+            if clnt and clnt.write():
                 waiting.add(fd)
         return waiting
 
-    def subs(self):
+    def subs(self) -> Set[str]:
         result = set()
         for clnt in self.by_fd.values():
             result |= clnt.imeis
         return result
 
 
-def runserver(conf):
+def runserver(conf: ConfigParser) -> None:
     global htmlfile
 
     initdb(conf.get("storage", "dbfn"))
     htmlfile = conf.get("wsgateway", "htmlfile")
-    zctx = zmq.Context()
-    zsub = zctx.socket(zmq.SUB)
+    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
+    zctx = zmq.Context()  # type: ignore
+    zsub = zctx.socket(zmq.SUB)  # type: ignore
     zsub.connect(conf.get("collector", "publishurl"))
     tcpl = socket(AF_INET6, SOCK_STREAM)
     tcpl.setblocking(False)
@@ -268,13 +266,13 @@ def runserver(conf):
     tcpl.bind(("", conf.getint("wsgateway", "port")))
     tcpl.listen(5)
     tcpfd = tcpl.fileno()
-    poller = zmq.Poller()
+    poller = zmq.Poller()  # type: ignore
     poller.register(zsub, flags=zmq.POLLIN)
     poller.register(tcpfd, flags=zmq.POLLIN)
     clients = Clients()
-    activesubs = set()
+    activesubs: Set[str] = set()
     try:
-        towait = set()
+        towait: Set[int] = set()
         while True:
             neededsubs = clients.subs()
             for imei in neededsubs - activesubs:
@@ -286,6 +284,10 @@ def runserver(conf):
                     zmq.SUBSCRIBE,
                     topic(WIFI_POSITIONING.PROTO, False, imei),
                 )
+                zsub.setsockopt(
+                    zmq.SUBSCRIBE,
+                    topic(STATUS.PROTO, True, imei),
+                )
             for imei in activesubs - neededsubs:
                 zsub.setsockopt(
                     zmq.UNSUBSCRIBE,
@@ -295,6 +297,10 @@ def runserver(conf):
                     zmq.UNSUBSCRIBE,
                     topic(WIFI_POSITIONING.PROTO, False, imei),
                 )
+                zsub.setsockopt(
+                    zmq.UNSUBSCRIBE,
+                    topic(STATUS.PROTO, True, imei),
+                )
             activesubs = neededsubs
             log.debug("Subscribed to: %s", activesubs)
             tosend = []
@@ -309,18 +315,36 @@ def runserver(conf):
                             zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                             msg = parse_message(zmsg.packet, zmsg.is_incoming)
                             log.debug("Got %s with %s", zmsg, msg)
-                            tosend.append(
-                                {
-                                    "imei": zmsg.imei,
-                                    "timestamp": str(
-                                        datetime.fromtimestamp(
-                                            zmsg.when
-                                        ).astimezone(tz=timezone.utc)
-                                    ),
-                                    "longitude": msg.longitude,
-                                    "latitude": msg.latitude,
-                                }
-                            )
+                            if isinstance(msg, STATUS):
+                                tosend.append(
+                                    {
+                                        "type": "status",
+                                        "imei": zmsg.imei,
+                                        "timestamp": str(
+                                            datetime.fromtimestamp(
+                                                zmsg.when
+                                            ).astimezone(tz=timezone.utc)
+                                        ),
+                                        "battery": msg.batt,
+                                    }
+                                )
+                            else:
+                                tosend.append(
+                                    {
+                                        "type": "location",
+                                        "imei": zmsg.imei,
+                                        "timestamp": str(
+                                            datetime.fromtimestamp(
+                                                zmsg.when
+                                            ).astimezone(tz=timezone.utc)
+                                        ),
+                                        "longitude": msg.longitude,
+                                        "latitude": msg.latitude,
+                                        "accuracy": "gps"
+                                        if zmsg.is_incoming
+                                        else "approximate",
+                                    }
+                                )
                         except zmq.Again:
                             break
                 elif sk == tcpfd:
@@ -331,13 +355,14 @@ def runserver(conf):
                     if received is None:
                         log.debug("Client gone from fd %d", sk)
                         tostop.append(sk)
-                        towait.discard(fd)
+                        towait.discard(sk)
                     else:
-                        for msg in received:
-                            log.debug("Received from %d: %s", sk, msg)
-                            if msg.get("type", None) == "subscribe":
-                                imeis = msg.get("imei")
-                                numback = msg.get("backlog", 5)
+                        for wsmsg in received:
+                            log.debug("Received from %d: %s", sk, wsmsg)
+                            if wsmsg.get("type", None) == "subscribe":
+                                # Have to live w/o typeckeding from json
+                                imeis = cast(List[str], wsmsg.get("imei"))
+                                numback: int = wsmsg.get("backlog", 5)
                                 for imei in imeis:
                                     tosend.extend(backlog(imei, numback))
                         towrite.add(sk)
@@ -349,11 +374,11 @@ def runserver(conf):
                     log.debug("Stray event: %s on socket %s", fl, sk)
             # poll queue consumed, make changes now
             for fd in tostop:
-                poller.unregister(fd)
+                poller.unregister(fd)  # type: ignore
                 clients.stop(fd)
-            for zmsg in tosend:
-                log.debug("Sending to the clients: %s", zmsg)
-                towrite |= clients.send(zmsg)
+            for wsmsg in tosend:
+                log.debug("Sending to the clients: %s", wsmsg)
+                towrite |= clients.send(wsmsg)
             for clntsock, clntaddr in topoll:
                 fd = clients.add(clntsock, clntaddr)
                 poller.register(fd, flags=zmq.POLLIN)
@@ -367,9 +392,9 @@ def runserver(conf):
                 morewait,
             )
             for fd in morewait - trywrite:  # new fds waiting for write
-                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
+                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)  # type: ignore
             for fd in trywrite - morewait:  # no longer waiting for write
-                poller.modify(fd, flags=zmq.POLLIN)
+                poller.modify(fd, flags=zmq.POLLIN)  # type: ignore
             towait &= trywrite
             towait |= morewait
     except KeyboardInterrupt: