]> www.average.org Git - loctrkd.git/blobdiff - gps303/collector.py
collector: get rid of more protocol specifics
[loctrkd.git] / gps303 / collector.py
index c2efd79b82182e13319526a644abccb0e85071b9..df1a474f057e0fb68c9af477fc43c714b083de01 100644 (file)
@@ -3,7 +3,14 @@
 from configparser import ConfigParser
 from logging import getLogger
 from os import umask
-from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
+from socket import (
+    socket,
+    AF_INET6,
+    SOCK_STREAM,
+    SOL_SOCKET,
+    SO_KEEPALIVE,
+    SO_REUSEADDR,
+)
 from struct import pack
 from time import time
 from typing import Dict, List, Optional, Tuple
@@ -11,8 +18,9 @@ import zmq
 
 from . import common
 from .gps303proto import (
-    HIBERNATION,
-    LOGIN,
+    GPS303Conn,
+    is_goodbye_packet,
+    imei_from_packet,
     inline_response,
     parse_message,
     proto_of_message,
@@ -30,13 +38,15 @@ class Client:
     def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
         self.sock = sock
         self.addr = addr
-        self.buffer = b""
+        self.stream = GPS303Conn()
         self.imei: Optional[str] = None
 
     def close(self) -> None:
         log.debug("Closing fd %d (IMEI %s)", self.sock.fileno(), self.imei)
         self.sock.close()
-        self.buffer = b""
+        rest = self.stream.close()
+        if rest:
+            log.warning("%d bytes in buffer on close: %s", len(rest), rest)
 
     def recv(self) -> Optional[List[Tuple[float, Tuple[str, int], bytes]]]:
         """Read from the socket and parse complete messages"""
@@ -58,61 +68,26 @@ class Client:
             )
             return None
         when = time()
-        self.buffer += segment
-        if len(self.buffer) > MAXBUFFER:
-            # We are receiving junk. Let's drop it or we run out of memory.
-            log.warning("More than %d unparseable data, dropping", MAXBUFFER)
-            self.buffer = b""
         msgs = []
-        while True:
-            framestart = self.buffer.find(b"xx")
-            if framestart == -1:  # No frames, return whatever we have
-                break
-            if framestart > 0:  # Should not happen, report
+        for elem in self.stream.recv(segment):
+            if isinstance(elem, bytes):
+                msgs.append((when, self.addr, elem))
+            else:
                 log.warning(
-                    'Undecodable data (%d) "%s" from fd %d (IMEI %s)',
-                    framestart,
-                    self.buffer[:framestart][:64].hex(),
+                    "%s from fd %d (IMEI %s)",
+                    elem,
                     self.sock.fileno(),
                     self.imei,
                 )
-                self.buffer = self.buffer[framestart:]
-            # At this point, buffer starts with a packet
-            if len(self.buffer) < 6:  # no len and proto - cannot proceed
-                break
-            exp_end = self.buffer[2] + 3  # Expect '\r\n' here
-            frameend = 0
-            # Length field can legitimeely be much less than the
-            # length of the packet (e.g. WiFi positioning), but
-            # it _should not_ be greater. Still sometimes it is.
-            # Luckily, not by too much: by maybe two or three bytes?
-            # Do this embarrassing hack to avoid accidental match
-            # of some binary data in the packet against '\r\n'.
-            while True:
-                frameend = self.buffer.find(b"\r\n", frameend + 1)
-                if frameend == -1 or frameend >= (
-                    exp_end - 3
-                ):  # Found realistic match or none
-                    break
-            if frameend == -1:  # Incomplete frame, return what we have
-                break
-            packet = self.buffer[2:frameend]
-            self.buffer = self.buffer[frameend + 2 :]
-            if proto_of_message(packet) == LOGIN.PROTO:
-                self.imei = parse_message(packet).imei
-                log.info(
-                    "LOGIN from fd %d (IMEI %s)", self.sock.fileno(), self.imei
-                )
-            msgs.append((when, self.addr, packet))
         return msgs
 
     def send(self, buffer: bytes) -> None:
         try:
-            self.sock.send(b"xx" + buffer + b"\r\n")
+            self.sock.send(self.stream.enframe(buffer))
         except OSError as e:
             log.error(
                 "Sending to fd %d (IMEI %s): %s",
-                self.sock.fileno,
+                self.sock.fileno(),
                 self.imei,
                 e,
             )
@@ -146,8 +121,18 @@ class Clients:
             return None
         result = []
         for when, peeraddr, packet in msgs:
-            if proto_of_message(packet) == LOGIN.PROTO:  # Could do blindly...
-                if clnt.imei:
+            if clnt.imei is None:
+                imei = imei_from_packet(packet)
+                if imei is not None:
+                    log.info("LOGIN from fd %d (IMEI %s)", fd, imei)
+                    clnt.imei = imei
+                    oldclnt = self.by_imei.get(clnt.imei)
+                    if oldclnt is not None:
+                        log.info(
+                            "Orphaning fd %d with the same IMEI",
+                            oldclnt.sock.fileno(),
+                        )
+                        oldclnt.imei = None
                     self.by_imei[clnt.imei] = clnt
                 else:
                     log.warning(
@@ -206,6 +191,7 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                             break
                 elif sk == tcpfd:
                     clntsock, clntaddr = tcpl.accept()
+                    clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1)
                     topoll.append((clntsock, clntaddr))
                 elif fl & zmq.POLLIN:
                     received = clients.recv(sk)
@@ -224,9 +210,9 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                                     packet=packet,
                                 ).packed
                             )
-                            if proto == HIBERNATION.PROTO and handle_hibernate:
+                            if is_goodbye_packet(packet) and handle_hibernate:
                                 log.debug(
-                                    "HIBERNATION from fd %d (IMEI %s)",
+                                    "Goodbye from fd %d (IMEI %s)",
                                     sk,
                                     imei,
                                 )
@@ -239,9 +225,6 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                 else:
                     log.debug("Stray event: %s on socket %s", fl, sk)
             # poll queue consumed, make changes now
-            for fd in tostop:
-                poller.unregister(fd)  # type: ignore
-                clients.stop(fd)
             for zmsg in tosend:
                 zpub.send(
                     Bcast(
@@ -254,6 +237,9 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                 )
                 log.debug("Sending to the client: %s", zmsg)
                 clients.response(zmsg)
+            for fd in tostop:
+                poller.unregister(fd)  # type: ignore
+                clients.stop(fd)
             for clntsock, clntaddr in topoll:
                 fd = clients.add(clntsock, clntaddr)
                 poller.register(fd, flags=zmq.POLLIN)