]> www.average.org Git - loctrkd.git/blobdiff - loctrkd/collector.py
Update changelog for 2.00 release
[loctrkd.git] / loctrkd / collector.py
index 98345007b708ad446c80c39d04d42fe7c52fd6e1..216134e57850d96d762286f23db43cafd9ea78c7 100644 (file)
@@ -14,10 +14,11 @@ from socket import (
 )
 from struct import pack
 from time import time
-from typing import Any, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 import zmq
 
 from . import common
+from .protomodule import ProtoModule
 from .zmsg import Bcast, Resp
 
 log = getLogger("loctrkd/collector")
@@ -25,54 +26,10 @@ log = getLogger("loctrkd/collector")
 MAXBUFFER: int = 4096
 
 
-class ProtoModule:
-    class Stream:
-        def recv(self, segment: bytes) -> List[Union[bytes, str]]:
-            ...
-
-        def close(self) -> bytes:
-            ...
-
-    @staticmethod
-    def enframe(buffer: bytes, imei: Optional[str] = None) -> bytes:
-        ...
-
-    @staticmethod
-    def probe_buffer(buffer: bytes) -> bool:
-        ...
-
-    @staticmethod
-    def parse_message(packet: bytes, is_incoming: bool = True) -> Any:
-        ...
-
-    @staticmethod
-    def inline_response(packet: bytes) -> Optional[bytes]:
-        ...
-
-    @staticmethod
-    def is_goodbye_packet(packet: bytes) -> bool:
-        ...
-
-    @staticmethod
-    def imei_from_packet(packet: bytes) -> Optional[str]:
-        ...
-
-    @staticmethod
-    def proto_of_message(packet: bytes) -> str:
-        ...
-
-    @staticmethod
-    def proto_by_name(name: str) -> int:
-        ...
-
-
-pmods: List[ProtoModule] = []
-
-
 class Client:
     """Connected socket to the terminal plus buffer and metadata"""
 
-    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
+    def __init__(self, sock: socket, addr: Any) -> None:
         self.sock = sock
         self.addr = addr
         self.pmod: Optional[ProtoModule] = None
@@ -87,11 +44,11 @@ class Client:
         else:
             rest = b""
         if rest:
-            log.warning(
+            log.info(
                 "%d bytes in buffer on close: %s", len(rest), rest[:64].hex()
             )
 
-    def recv(self) -> Optional[List[Tuple[float, Tuple[str, int], bytes]]]:
+    def recv(self) -> Optional[List[Tuple[float, Any, bytes]]]:
         """Read from the socket and parse complete messages"""
         try:
             segment = self.sock.recv(MAXBUFFER)
@@ -111,11 +68,9 @@ class Client:
             )
             return None
         if self.stream is None:
-            for pmod in pmods:
-                if pmod.probe_buffer(segment):
-                    self.pmod = pmod
-                    self.stream = pmod.Stream()
-                    break
+            self.pmod = common.probe_pmod(segment)
+            if self.pmod is not None:
+                self.stream = self.pmod.Stream()
         if self.stream is None:
             log.info(
                 "unrecognizable %d bytes of data %s from fd %d",
@@ -130,7 +85,7 @@ class Client:
             if isinstance(elem, bytes):
                 msgs.append((when, self.addr, elem))
             else:
-                log.warning(
+                log.info(
                     "%s from fd %d (IMEI %s)",
                     elem,
                     self.sock.fileno(),
@@ -156,25 +111,32 @@ class Clients:
         self.by_fd: Dict[int, Client] = {}
         self.by_imei: Dict[str, Client] = {}
 
-    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
+    def fds(self) -> Set[int]:
+        return set(self.by_fd.keys())
+
+    def add(self, clntsock: socket, clntaddr: Any) -> int:
         fd = clntsock.fileno()
         log.info("Start serving fd %d from %s", fd, clntaddr)
         self.by_fd[fd] = Client(clntsock, clntaddr)
         return fd
 
     def stop(self, fd: int) -> None:
+        if fd not in self.by_fd:
+            log.debug("Fd %d is not served, ingore stop", fd)
+            return
         clnt = self.by_fd[fd]
         log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei)
         clnt.close()
-        if clnt.imei:
+        if clnt.imei and self.by_imei[clnt.imei] == clnt:  # could be replaced
             del self.by_imei[clnt.imei]
         del self.by_fd[fd]
 
     def recv(
         self, fd: int
-    ) -> Optional[
-        List[Tuple[ProtoModule, Optional[str], float, Tuple[str, int], bytes]]
-    ]:
+    ) -> Optional[List[Tuple[ProtoModule, Optional[str], float, Any, bytes]]]:
+        if fd not in self.by_fd:
+            log.debug("Client at fd %d gone, ingore event", fd)
+            return None
         clnt = self.by_fd[fd]
         msgs = clnt.recv()
         if msgs is None:
@@ -189,18 +151,11 @@ class Clients:
                     clnt.imei = imei
                     oldclnt = self.by_imei.get(clnt.imei)
                     if oldclnt is not None:
-                        log.info(
-                            "Orphaning fd %d with the same IMEI",
-                            oldclnt.sock.fileno(),
-                        )
+                        oldfd = oldclnt.sock.fileno()
+                        log.info("Removing stale connection on fd %d", oldfd)
                         oldclnt.imei = None
+                        self.stop(oldfd)
                     self.by_imei[clnt.imei] = clnt
-                else:
-                    log.warning(
-                        "Login message from %s: %s, but client imei unfilled",
-                        peeraddr,
-                        packet,
-                    )
             result.append((clnt.pmod, clnt.imei, when, peeraddr, packet))
             log.debug(
                 "Received from %s (IMEI %s): %s",
@@ -221,11 +176,6 @@ class Clients:
 
 
 def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
-    global pmods
-    pmods = [
-        cast(ProtoModule, import_module("." + modnm, __package__))
-        for modnm in conf.get("collector", "protocols").split(",")
-    ]
     # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
     zctx = zmq.Context()  # type: ignore
     zpub = zctx.socket(zmq.PUB)  # type: ignore
@@ -243,11 +193,11 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
     poller.register(zpull, flags=zmq.POLLIN)
     poller.register(tcpfd, flags=zmq.POLLIN)
     clients = Clients()
+    pollingfds: Set[int] = set()
     try:
         while True:
-            tosend = []
-            topoll = []
-            tostop = []
+            tosend: List[Resp] = []
+            toadd: List[Tuple[socket, Any]] = []
             events = poller.poll(1000)
             for sk, fl in events:
                 if sk is zpull:
@@ -261,18 +211,19 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                 elif sk == tcpfd:
                     clntsock, clntaddr = tcpl.accept()
                     clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
-                    topoll.append((clntsock, clntaddr))
+                    toadd.append((clntsock, clntaddr))
                 elif fl & zmq.POLLIN:
                     received = clients.recv(sk)
                     if received is None:
                         log.debug("Terminal gone from fd %d", sk)
-                        tostop.append(sk)
+                        clients.stop(sk)
                     else:
                         for pmod, imei, when, peeraddr, packet in received:
                             proto = pmod.proto_of_message(packet)
                             zpub.send(
                                 Bcast(
                                     proto=proto,
+                                    pmod=pmod.PMODNAME,
                                     imei=imei,
                                     when=when,
                                     peeraddr=peeraddr,
@@ -288,7 +239,7 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                                     sk,
                                     imei,
                                 )
-                                tostop.append(sk)
+                                clients.stop(sk)
                             respmsg = pmod.inline_response(packet)
                             if respmsg is not None:
                                 tosend.append(
@@ -305,17 +256,19 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                         Bcast(
                             is_incoming=False,
                             proto=rpmod.proto_of_message(zmsg.packet),
-                            when=zmsg.when,
+                            pmod=pmod.PMODNAME,
                             imei=zmsg.imei,
+                            when=zmsg.when,
                             packet=zmsg.packet,
                         ).packed
                     )
-            for fd in tostop:
+            for fd in pollingfds - clients.fds():
                 poller.unregister(fd)  # type: ignore
-                clients.stop(fd)
-            for clntsock, clntaddr in topoll:
+            for clntsock, clntaddr in toadd:
                 fd = clients.add(clntsock, clntaddr)
+            for fd in clients.fds() - pollingfds:
                 poller.register(fd, flags=zmq.POLLIN)
+            pollingfds = clients.fds()
     except KeyboardInterrupt:
         zpub.close()
         zpull.close()