]> www.average.org Git - loctrkd.git/blobdiff - gps303/zmsg.py
fix zmq subscription topics
[loctrkd.git] / gps303 / zmsg.py
index 6e591ab65c2fb9a8c4ec7ff7a9bcab8d69e7de08..1fe18123c441739a9891b1433ec8bc276b5c876a 100644 (file)
@@ -3,31 +3,30 @@
 import ipaddress as ip
 from struct import pack, unpack
 
-__all__ = "Bcast", "Resp"
+__all__ = "Bcast", "Resp", "topic"
 
 
 def pack_peer(peeraddr):
-    saddr, port, _x, _y = peeraddr
-    addr6 = ip.ip_address(saddr)
-    addr = addr6.ipv4_mapped
-    if addr is None:
-        addr = addr6
-    return (
-        pack("B", addr.version)
-        + (addr.packed + b"\0\0\0\0\0\0\0\0\0\0\0\0")[:16]
-        + pack("!H", port)
-    )
+    try:
+        if peeraddr is None:
+            saddr = "::"
+            port = 0
+        else:
+            saddr, port, _x, _y = peeraddr
+        addr = ip.ip_address(saddr)
+    except ValueError:
+        saddr, port = peeraddr
+        a4 = ip.ip_address(saddr)
+        addr = ip.IPv6Address(b"\0\0\0\0\0\0\0\0\0\0\xff\xff" + a4.packed)
+    return addr.packed + pack("!H", port)
 
 
 def unpack_peer(buffer):
-    version = buffer[0]
-    if version not in (4, 6):
-        return None
-    if version == 4:
-        addr = ip.IPv4Address(buffer[1:5])
-    else:
-        addr = ip.IPv6Address(buffer[1:17])
-    port = unpack("!H", buffer[17:19])[0]
+    a6 = ip.IPv6Address(buffer[:16])
+    port = unpack("!H", buffer[16:])[0]
+    addr = a6.ipv4_mapped
+    if addr is None:
+        addr = a6
     return (addr, port)
 
 
@@ -63,22 +62,34 @@ class _Zmsg:
             ),
         )
 
+    def __eq__(self, other):
+        return all(
+            [getattr(self, k) == getattr(other, k) for k, _ in self.KWARGS]
+        )
+
     def decode(self, buffer):
-        raise RuntimeError(
-            self.__class__.__name__ + "must implement `encode()` method"
+        raise NotImplementedError(
+            self.__class__.__name__ + "must implement `decode()` method"
         )
 
     @property
     def packed(self):
-        raise RuntimeError(
-            self.__class__.__name__ + "must implement `encode()` method"
+        raise NotImplementedError(
+            self.__class__.__name__ + "must implement `packed()` property"
         )
 
 
+def topic(proto, is_incoming=True, imei=None):
+    return pack("BB", is_incoming, proto) + (
+        b"" if imei is None else pack("16s", imei.encode())
+    )
+
+
 class Bcast(_Zmsg):
     """Zmq message to broadcast what was received from the terminal"""
 
     KWARGS = (
+        ("is_incoming", True),
         ("proto", 256),
         ("imei", None),
         ("when", None),
@@ -89,8 +100,14 @@ class Bcast(_Zmsg):
     @property
     def packed(self):
         return (
-            pack("B", self.proto)
-            + ("0000000000000000" if self.imei is None else self.imei).encode()
+            pack(
+                "BB16s",
+                int(self.is_incoming),
+                self.proto,
+                "0000000000000000"
+                if self.imei is None
+                else self.imei.encode(),
+            )
             + (
                 b"\0\0\0\0\0\0\0\0"
                 if self.when is None
@@ -101,26 +118,39 @@ class Bcast(_Zmsg):
         )
 
     def decode(self, buffer):
-        self.proto = buffer[0]
-        self.imei = buffer[1:17].decode()
+        self.is_incoming = bool(buffer[0])
+        self.proto = buffer[1]
+        self.imei = buffer[2:18].decode()
         if self.imei == "0000000000000000":
             self.imei = None
-        self.when = unpack("!d", buffer[17:25])[0]
-        self.peeraddr = unpack_peer(buffer[25:44])
+        self.when = unpack("!d", buffer[18:26])[0]
+        self.peeraddr = unpack_peer(buffer[26:44])
         self.packet = buffer[44:]
 
 
 class Resp(_Zmsg):
     """Zmq message received from a third party to send to the terminal"""
 
-    KWARGS = (("imei", None), ("packet", b""))
+    KWARGS = (("imei", None), ("when", None), ("packet", b""))
 
     @property
     def packed(self):
         return (
-            "0000000000000000" if self.imei is None else self.imei.encode()
-        ) + self.packet
+            pack(
+                "16s",
+                "0000000000000000"
+                if self.imei is None
+                else self.imei.encode(),
+            )
+            + (
+                b"\0\0\0\0\0\0\0\0"
+                if self.when is None
+                else pack("!d", self.when)
+            )
+            + self.packet
+        )
 
     def decode(self, buffer):
         self.imei = buffer[:16].decode()
-        self.packet = buffer[16:]
+        self.when = unpack("!d", buffer[16:24])[0]
+        self.packet = buffer[24:]