]> www.average.org Git - loctrkd.git/commitdiff
abstract protocol selection in `common`
authorEugene Crosser <crosser@average.org>
Wed, 27 Jul 2022 22:41:54 +0000 (00:41 +0200)
committerEugene Crosser <crosser@average.org>
Wed, 27 Jul 2022 22:41:54 +0000 (00:41 +0200)
loctrkd/__main__.py
loctrkd/collector.py
loctrkd/common.py
loctrkd/mkgpx.py
loctrkd/termconfig.py
loctrkd/watch.py
test/common.py

index bba38ff1f1321f3ac76a62e69eefebe7db8c0bc9..8808c8a31c092fdc78dc2a674f5932740622a865 100644 (file)
@@ -17,17 +17,9 @@ from .zmsg import Bcast, Resp
 log = getLogger("loctrkd")
 
 
-pmods: List[ProtoModule] = []
-
-
 def main(
     conf: ConfigParser, opts: List[Tuple[str, str]], args: List[str]
 ) -> None:
-    global pmods
-    pmods = [
-        cast(ProtoModule, import_module("." + modnm, __package__))
-        for modnm in conf.get("common", "protocols").split(",")
-    ]
     # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
     zctx = zmq.Context()  # type: ignore
     zpush = zctx.socket(zmq.PUSH)  # type: ignore
@@ -40,12 +32,8 @@ def main(
     imei = args[0]
     cmd = args[1]
     args = args[2:]
-    handled = False
-    for pmod in pmods:
-        if pmod.proto_handled(cmd):
-            handled = True
-            break
-    if not handled:
+    pmod = common.pmod_for_proto(cmd)
+    if pmod is None:
         raise NotImplementedError(f"No protocol can handle {cmd}")
     cls = pmod.class_by_prefix(cmd)
     if isinstance(cls, list):
index 17c98d5d98326af76dad80c2dc760bb3fb19c4ce..788cb11ca1ee8b186838ea1ef882a5b6faf4ea83 100644 (file)
@@ -14,7 +14,7 @@ from socket import (
 )
 from struct import pack
 from time import time
-from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 import zmq
 
 from . import common
@@ -26,9 +26,6 @@ log = getLogger("loctrkd/collector")
 MAXBUFFER: int = 4096
 
 
-pmods: List[ProtoModule] = []
-
-
 class Client:
     """Connected socket to the terminal plus buffer and metadata"""
 
@@ -71,11 +68,9 @@ class Client:
             )
             return None
         if self.stream is None:
-            for pmod in pmods:
-                if pmod.probe_buffer(segment):
-                    self.pmod = pmod
-                    self.stream = pmod.Stream()
-                    break
+            self.pmod = common.probe_pmod(segment)
+            if self.pmod is not None:
+                self.stream = self.pmod.Stream()
         if self.stream is None:
             log.info(
                 "unrecognizable %d bytes of data %s from fd %d",
@@ -181,11 +176,6 @@ class Clients:
 
 
 def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
-    global pmods
-    pmods = [
-        cast(ProtoModule, import_module("." + modnm, __package__))
-        for modnm in conf.get("common", "protocols").split(",")
-    ]
     # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
     zctx = zmq.Context()  # type: ignore
     zpub = zctx.socket(zmq.PUB)  # type: ignore
index 227611c62fa4b8cf04f6bb22f2f8e4754cd744f2..02b520f620d7d82582df949361b572327de283ac 100644 (file)
@@ -1,16 +1,18 @@
 """ Common housekeeping for all daemons """
 
-from configparser import ConfigParser, SectionProxy
+from configparser import ConfigParser
+from importlib import import_module
 from getopt import getopt
 from logging import Formatter, getLogger, Logger, StreamHandler, DEBUG, INFO
 from logging.handlers import SysLogHandler
 from pkg_resources import get_distribution, DistributionNotFound
 from sys import argv, stderr, stdout
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
+
+from .protomodule import ProtoModule
 
 CONF = "/etc/loctrkd.conf"
-PORT = 4303
-DBFN = "/var/lib/loctrkd/loctrkd.sqlite"
+pmods: List[ProtoModule] = []
 
 try:
     version = get_distribution("loctrkd").version
@@ -18,13 +20,22 @@ except DistributionNotFound:
     version = "<local>"
 
 
+def init_protocols(conf: ConfigParser) -> None:
+    global pmods
+    pmods = [
+        cast(ProtoModule, import_module("." + modnm, __package__))
+        for modnm in conf.get("common", "protocols").split(",")
+    ]
+
+
 def init(
     log: Logger, opts: Optional[List[Tuple[str, str]]] = None
 ) -> ConfigParser:
     if opts is None:
         opts, _ = getopt(argv[1:], "c:d")
     dopts = dict(opts)
-    conf = readconfig(dopts["-c"] if "-c" in dopts else CONF)
+    conf = ConfigParser()
+    conf.read(dopts["-c"] if "-c" in dopts else CONF)
     log.setLevel(DEBUG if "-d" in dopts else INFO)
     if stdout.isatty():
         fhdl = StreamHandler(stderr)
@@ -40,59 +51,19 @@ def init(
         )
         log.addHandler(lhdl)
         log.info("%s starting with options: %s", version, dopts)
+    init_protocols(conf)
     return conf
 
 
-def readconfig(fname: str) -> ConfigParser:
-    config = ConfigParser()
-    config["collector"] = {
-        "port": str(PORT),
-    }
-    config["storage"] = {
-        "dbfn": DBFN,
-    }
-    config["termconfig"] = {}
-    config.read(fname)
-    return config
-
-
-def normconf(section: SectionProxy) -> Dict[str, Any]:
-    result: Dict[str, Any] = {}
-    for key, val in section.items():
-        vals = val.split("\n")
-        if len(vals) > 1 and vals[0] == "":
-            vals = vals[1:]
-        lst: List[Union[str, int]] = []
-        for el in vals:
-            try:
-                lst.append(int(el, 0))
-            except ValueError:
-                if el[0] == '"' and el[-1] == '"':
-                    el = el.strip('"').rstrip('"')
-                lst.append(el)
-        if not (
-            all([isinstance(x, int) for x in lst])
-            or all([isinstance(x, str) for x in lst])
-        ):
-            raise ValueError(
-                "Values of %s - %s are of different type", key, vals
-            )
-        if len(lst) == 1:
-            result[key] = lst[0]
-        else:
-            result[key] = lst
-    return result
-
-
-if __name__ == "__main__":
-    from sys import argv
+def probe_pmod(segment: bytes) -> Optional[ProtoModule]:
+    for pmod in pmods:
+        if pmod.probe_buffer(segment):
+            return pmod
+    return None
 
-    def _print_config(conf: ConfigParser) -> None:
-        for section in conf.sections():
-            print("section", section)
-            for option in conf.options(section):
-                print("    ", option, conf[section][option])
 
-    conf = readconfig(argv[1])
-    _print_config(conf)
-    print(normconf(conf["termconfig"]))
+def pmod_for_proto(proto: str) -> Optional[ProtoModule]:
+    for pmod in pmods:
+        if pmod.proto_handled(proto):
+            return pmod
+    return None
index 35ff77e8b2a23dfbfc8682376f7a1074133003ab..6d1ee27ab07024a471ca45b98069709c354b6ab5 100644 (file)
@@ -19,17 +19,9 @@ from .protomodule import ProtoModule
 log = getLogger("loctrkd/mkgpx")
 
 
-pmods: List[ProtoModule] = []
-
-
 def main(
     conf: ConfigParser, opts: List[Tuple[str, str]], args: List[str]
 ) -> None:
-    global pmods
-    pmods = [
-        cast(ProtoModule, import_module("." + modnm, __package__))
-        for modnm in conf.get("common", "protocols").split(",")
-    ]
     db = connect(conf.get("storage", "dbfn"))
     c = db.cursor()
     c.execute(
@@ -52,9 +44,9 @@ def main(
     )
 
     for tstamp, is_incoming, proto, packet in c:
-        for pmod in pmods:
-            if pmod.proto_handled(proto):
-                msg = pmod.parse_message(packet, is_incoming=is_incoming)
+        pmod = common.pmod_for_proto(proto)
+        if pmod is not None:
+            msg = pmod.parse_message(packet, is_incoming=is_incoming)
         lat, lon = msg.latitude, msg.longitude
         isotime = (
             datetime.fromtimestamp(tstamp)
index b1ea80ae691bb1b93174c50b5d58d8150d9065fd..4968e28ee86cee6e6a46cd5a4fa8a2068d29457d 100644 (file)
@@ -1,9 +1,10 @@
 """ For when responding to the terminal is not trivial """
 
-from configparser import ConfigParser
+from configparser import ConfigParser, SectionProxy
 from datetime import datetime, timezone
 from logging import getLogger
 from struct import pack
+from typing import Any, Dict, List, Union
 import zmq
 
 from . import common
@@ -14,6 +15,34 @@ from .zmsg import Bcast, Resp, topic
 log = getLogger("loctrkd/termconfig")
 
 
+def normconf(section: SectionProxy) -> Dict[str, Any]:
+    result: Dict[str, Any] = {}
+    for key, val in section.items():
+        vals = val.split("\n")
+        if len(vals) > 1 and vals[0] == "":
+            vals = vals[1:]
+        lst: List[Union[str, int]] = []
+        for el in vals:
+            try:
+                lst.append(int(el, 0))
+            except ValueError:
+                if el[0] == '"' and el[-1] == '"':
+                    el = el.strip('"').rstrip('"')
+                lst.append(el)
+        if not (
+            all([isinstance(x, int) for x in lst])
+            or all([isinstance(x, str) for x in lst])
+        ):
+            raise ValueError(
+                "Values of %s - %s are of different type", key, vals
+            )
+        if len(lst) == 1:
+            result[key] = lst[0]
+        else:
+            result[key] = lst
+    return result
+
+
 def runserver(conf: ConfigParser) -> None:
     # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
     zctx = zmq.Context()  # type: ignore
@@ -44,9 +73,9 @@ def runserver(conf: ConfigParser) -> None:
                     "%s does not expect externally provided response", msg
                 )
             if zmsg.imei is not None and conf.has_section(zmsg.imei):
-                termconfig = common.normconf(conf[zmsg.imei])
+                termconfig = normconf(conf[zmsg.imei])
             elif conf.has_section("termconfig"):
-                termconfig = common.normconf(conf["termconfig"])
+                termconfig = normconf(conf["termconfig"])
             else:
                 termconfig = {}
             kwargs = {}
index 6d3dcd9fa8ed7eb1f4eba1cdacd266549e9b2519..bda952c259da36299fd5f8859b3c33363983720b 100644 (file)
@@ -14,15 +14,7 @@ from .zmsg import Bcast
 log = getLogger("loctrkd/watch")
 
 
-pmods: List[ProtoModule] = []
-
-
 def runserver(conf: ConfigParser) -> None:
-    global pmods
-    pmods = [
-        cast(ProtoModule, import_module("." + modnm, __package__))
-        for modnm in conf.get("common", "protocols").split(",")
-    ]
     # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
     zctx = zmq.Context()  # type: ignore
     zsub = zctx.socket(zmq.SUB)  # type: ignore
@@ -33,12 +25,12 @@ def runserver(conf: ConfigParser) -> None:
         while True:
             zmsg = Bcast(zsub.recv())
             print("I" if zmsg.is_incoming else "O", zmsg.proto, zmsg.imei)
-            for pmod in pmods:
-                if pmod.proto_handled(zmsg.proto):
-                    msg = pmod.parse_message(zmsg.packet, zmsg.is_incoming)
-                    print(msg)
-                    if zmsg.is_incoming and hasattr(msg, "rectified"):
-                        print(msg.rectified())
+            pmod = common.pmod_for_proto(zmsg.proto)
+            if pmod is not None:
+                msg = pmod.parse_message(zmsg.packet, zmsg.is_incoming)
+                print(msg)
+                if zmsg.is_incoming and hasattr(msg, "rectified"):
+                    print(msg.rectified())
     except KeyboardInterrupt:
         pass
 
index 284ca4d10c8509606dd6879f026c07f211104f83..58954a2b217acbd0fa5a930476cee1882dea8ff1 100644 (file)
@@ -23,6 +23,8 @@ from time import sleep
 from typing import Optional
 from unittest import TestCase
 
+from loctrkd.common import init_protocols
+
 NUMPORTS = 3
 
 
@@ -61,6 +63,7 @@ class TestWithServers(TestCase):
         self.conf["wsgateway"] = {
             "port": str(freeports[1]),
         }
+        init_protocols(self.conf)
         self.children = []
         for srvname in args:
             if srvname == "collector":