]> www.average.org Git - loctrkd.git/blob - loctrkd/wsgateway.py
Adjust config to changing messaging topology
[loctrkd.git] / loctrkd / wsgateway.py
1 """ Websocket Gateway """
2
3 from configparser import ConfigParser
4 from datetime import datetime, timezone
5 from importlib import import_module
6 from json import dumps, loads
7 from logging import getLogger
8 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
9 from time import time
10 from typing import Any, cast, Dict, List, Optional, Set, Tuple
11 from wsproto import ConnectionType, WSConnection
12 from wsproto.events import (
13     AcceptConnection,
14     CloseConnection,
15     Event,
16     Message,
17     Ping,
18     Request,
19     TextMessage,
20 )
21 from wsproto.utilities import RemoteProtocolError
22 import zmq
23
24 from . import common
25 from .evstore import initdb, fetch
26 from .protomodule import ProtoModule
27 from .zmsg import Bcast, topic
28
log = getLogger("loctrkd/wsgateway")


# Module-level state, filled in by runserver() from the configuration.
# Path of the HTML page served to plain HTTP GET requests; None if unset.
htmlfile: Optional[str] = None
# Protocol modules named in the "[common] protocols" option.
pmods: List[ProtoModule] = []
# (is_incoming, proto) pairs passed to fetch() when building the backlog.
selector: List[Tuple[bool, str]] = []
35
36
def backlog(imei: str, numback: int) -> List[Dict[str, Any]]:
    """Build websocket "location" messages from stored history.

    Fetches up to ``numback`` stored packets for ``imei`` restricted to the
    (is_incoming, proto) pairs in the module-level ``selector``. Each packet
    is parsed with the protocol module that handles its proto. The result
    is a list of dicts ready to be sent to websocket clients.

    Packets whose proto is not handled by any loaded protocol module are
    skipped.
    """
    result: List[Dict[str, Any]] = []
    for is_incoming, timestamp, proto, packet in fetch(
        imei,
        selector,
        numback,
    ):
        handlers = [pmod for pmod in pmods if pmod.proto_handled(proto)]
        if not handlers:
            # Nothing can parse this packet; skip it rather than reuse the
            # message parsed for some other packet.
            log.debug("No protocol module handles %s, skipping", proto)
            continue
        # If several modules claim the proto, the last one configured wins.
        msg = handlers[-1].parse_message(packet, is_incoming=is_incoming)
        result.append(
            {
                "type": "location",
                "imei": imei,
                "timestamp": str(
                    datetime.fromtimestamp(timestamp).astimezone(
                        tz=timezone.utc
                    )
                ),
                "longitude": msg.longitude,
                "latitude": msg.latitude,
                "accuracy": "gps"
                if True  # TODO isinstance(msg, GPS_POSITIONING)
                else "approximate",
            }
        )
    return result
64
65
def _http_response(proto: str, status: str, ctype: str, body: bytes) -> bytes:
    """Assemble an HTTP response with an explicit Content-Length header."""
    head = (
        f"{proto} {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(body):d}\r\n\r\n"
    )
    return head.encode("utf-8") + body


def try_http(data: bytes, fd: int, e: Exception) -> bytes:
    """Answer data that wsproto rejected as a plain HTTP request.

    ``data`` are the bytes read from ``fd`` and ``e`` is the websocket
    error they caused.
    - A GET request is answered with the contents of the configured
      ``htmlfile`` (500 if it is not configured or cannot be read).
    - Any other method gets 400.
    - If ``data`` is not a parseable HTTP request line, ``e`` is re-raised.

    Returns the complete response as bytes. Every response carries a
    Content-Length, so the client can find the end of the body while the
    connection is still open.
    """
    try:
        lines = data.decode().split("\r\n")
        request = lines[0]
        headers = lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:  # also covers UnicodeDecodeError
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    if op != "GET":
        return _http_response(
            proto, "400 Bad request", "text/plain", b"Bad request\r\n"
        )
    if htmlfile is None:
        return _http_response(
            proto,
            "500 No data configured",
            "text/plain",
            b"HTML data not configured on the server\r\n",
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return _http_response(
            proto,
            "500 File not found",
            "text/plain",
            b"HTML file could not be opened\r\n",
        )
    return _http_response(
        proto, "200 Ok", "text/html; charset=utf-8", htmldata
    )
113
114
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Outgoing bytes not yet accepted by the socket.
        self.ws_data = b""
        # True once the websocket handshake has been accepted.
        self.ready = False
        # IMEIs named in this client's most recent "subscribe" message.
        self.imeis: Set[str] = set()
        # Pieces of a text message whose final fragment has not arrived.
        self.text_parts: List[str] = []

    def close(self) -> None:
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self) -> Optional[List[Dict[str, Any]]]:
        """Read from the socket and process whatever arrived.

        Returns the JSON objects received in complete websocket text
        messages. The list may be empty, e.g. after a handshake or a ping.

        Returns None when the connection should be closed:
        - read error or EOF;
        - after a plain HTTP request has been answered;
        - when the data is neither websocket nor HTTP.
        """
        fd = self.sock.fileno()
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning("Reading from fd %d: %s", fd, e)
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info("EOF reading from fd %d", fd)
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug("Websocket error on fd %d, try plain http (%s)", fd, e)
            try:
                self.ws_data = try_http(data, fd, e)
            except RemoteProtocolError:
                # try_http re-raises the websocket error when the data is
                # not HTTP either; drop this client rather than let the
                # exception escape into the server loop.
                log.warning("Dropping fd %d: neither websocket nor HTTP", fd)
                return None
            # this `write` is a hack - writing _ought_ to be done at the
            # stage when all other writes are performed. But I could not
            # arrange it so in a logical way. Let it stay this way. The
            # whole http server affair is a hack anyway.
            self.write()
            log.debug("Sending HTTP response to %d", fd)
            return None
        msgs: List[Dict[str, Any]] = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", fd)
                # wsproto expects the server to answer a Request with
                # AcceptConnection (or RejectConnection).
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, fd)
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, fd)
                self.text_parts.append(event.data)
                if not event.message_finished:
                    continue  # wait for the rest of a fragmented message
                text = "".join(self.text_parts)
                self.text_parts = []
                try:
                    msg = loads(text)
                except ValueError as err:
                    log.warning("Bad JSON on fd %d: %s", fd, err)
                    continue
                if not isinstance(msg, dict):
                    log.warning(
                        "Ignoring non-object JSON on fd %d: %r", fd, msg
                    )
                    continue
                msgs.append(msg)
                if msg.get("type", None) == "subscribe":
                    self.imeis = set(msg.get("imei", []))
                    log.debug(
                        "subs list on fd %s is %s",
                        fd,
                        self.imeis,
                    )
            else:
                log.warning("%s on fd %d", event, fd)
        return msgs

    def wants(self, imei: str) -> bool:
        """Tell whether this client is subscribed to ``imei``."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message: Dict[str, Any]) -> None:
        """Queue ``message`` as JSON if the handshake is done and subscribed."""
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(Message(data=dumps(message)))

    def write(self) -> bool:
        """Try to flush pending outgoing data.

        Returns True if data is still pending, i.e. the caller should wait
        until the socket is writable; returns False otherwise.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
            except (BlockingIOError, InterruptedError):
                # Socket buffer full (or interrupted): keep the data and
                # retry once the socket becomes writable again.
                pass
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
            else:
                self.ws_data = self.ws_data[sent:]
        return bool(self.ws_data)
214                 self.ws_data = b""
215         return bool(self.ws_data)
216
217
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Start tracking a freshly accepted connection; return its fd."""
        newfd = clntsock.fileno()
        log.info("Start serving fd %d from %s", newfd, clntaddr)
        self.by_fd[newfd] = Client(clntsock, clntaddr)
        return newfd

    def stop(self, fd: int) -> None:
        """Close the client on ``fd`` and forget it."""
        gone = self.by_fd[fd]
        log.info("Stop serving fd %d", gone.sock.fileno())
        gone.close()
        del self.by_fd[fd]

    def recv(self, fd: int) -> Optional[List[Dict[str, Any]]]:
        """Delegate reading to the client on ``fd``."""
        return self.by_fd[fd].recv()

    def send(self, msg: Dict[str, Any]) -> Set[int]:
        """Queue ``msg`` for every interested client.

        Returns the fds of the clients the message was handed to.
        """
        recipients: Set[int] = set()
        for fd, clnt in self.by_fd.items():
            if not clnt.wants(msg["imei"]):
                continue
            clnt.send(msg)
            recipients.add(fd)
        return recipients

    def write(self, towrite: Set[int]) -> Set[int]:
        """Flush the clients in ``towrite``; return fds with data left over."""
        pending: Set[int] = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is not None and clnt.write():
                pending.add(fd)
        return pending

    def subs(self) -> Set[str]:
        """Union of the IMEIs all clients are subscribed to."""
        wanted: Set[str] = set()
        return wanted.union(*(clnt.imeis for clnt in self.by_fd.values()))
258
259
def runserver(conf: ConfigParser) -> None:
    """Run the websocket gateway until interrupted with Ctrl-C.

    Configuration options used:
    - ``[common] protocols``
    - ``[storage] dbfn``
    - ``[collector] publishurl``
    - ``[wsgateway] port`` and ``[wsgateway] htmlfile``

    Behaviour:
    - Listens on the TCP port for websocket clients; plain HTTP GET
      requests are answered with the HTML file.
    - Subscribes on the collector's ZeroMQ publisher to the topics of the
      IMEIs that clients asked for.
    - Forwards decoded location and status updates to those clients.
    """
    global htmlfile, pmods, selector
    pmods = [
        cast(ProtoModule, import_module("." + modnm, __package__))
        for modnm in conf.get("common", "protocols").split(",")
    ]
    # Rebuild the backlog selector in place so that a repeated call does
    # not accumulate duplicate entries.
    selector[:] = [
        (is_incoming, proto)
        for pmod in pmods
        for proto, is_incoming in pmod.exposed_protos()
        if proto != "ZX:STATUS"  # TODO make it better
    ]
    initdb(conf.get("storage", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile", fallback=None)
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zsub = zctx.socket(zmq.SUB)  # type: ignore
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs: Set[str] = set()
    try:
        # fds that still hold unsent data and are registered for POLLOUT;
        # they are written again only once the poller reports them writable.
        towait: Set[int] = set()
        while True:
            # Keep the ZeroMQ subscriptions in line with what clients want.
            neededsubs = clients.subs()
            for pmod in pmods:
                for proto, is_incoming in pmod.exposed_protos():
                    for imei in neededsubs - activesubs:
                        zsub.setsockopt(
                            zmq.SUBSCRIBE,
                            topic(proto, is_incoming, imei),
                        )
                    for imei in activesubs - neededsubs:
                        zsub.setsockopt(
                            zmq.UNSUBSCRIBE,
                            topic(proto, is_incoming, imei),
                        )
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend: List[Dict[str, Any]] = []
            topoll: List[Tuple[socket, Tuple[str, int]]] = []
            tostop: List[int] = []
            towrite: Set[int] = set()
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    # Drain every broadcast that is already queued.
                    while True:
                        try:
                            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                        except zmq.Again:
                            break
                        handlers = [
                            pmod
                            for pmod in pmods
                            if pmod.proto_handled(zmsg.proto)
                        ]
                        if not handlers:
                            # Nothing can parse it; do not reuse a message
                            # parsed for an earlier broadcast.
                            log.warning(
                                "No protocol module handles %s", zmsg.proto
                            )
                            continue
                        # If several modules claim the proto, the last one
                        # configured wins.
                        msg = handlers[-1].parse_message(
                            zmsg.packet, zmsg.is_incoming
                        )
                        log.debug("Got %s with %s", zmsg, msg)
                        timestamp = str(
                            datetime.fromtimestamp(zmsg.when).astimezone(
                                tz=timezone.utc
                            )
                        )
                        if zmsg.proto == "ZX:STATUS":
                            tosend.append(
                                {
                                    "type": "status",
                                    "imei": zmsg.imei,
                                    "timestamp": timestamp,
                                    "battery": msg.batt,
                                }
                            )
                        else:
                            tosend.append(
                                {
                                    "type": "location",
                                    "imei": zmsg.imei,
                                    "timestamp": timestamp,
                                    "longitude": msg.longitude,
                                    "latitude": msg.latitude,
                                    "accuracy": "gps"
                                    if zmsg.is_incoming
                                    else "approximate",
                                }
                            )
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                        towait.discard(sk)
                    else:
                        for wsmsg in received:
                            log.debug("Received from %d: %s", sk, wsmsg)
                            if wsmsg.get("type", None) == "subscribe":
                                # Have to live w/o typeckeding from json;
                                # a missing "imei" means no IMEIs.
                                imeis = cast(
                                    List[str], wsmsg.get("imei") or []
                                )
                                numback: int = wsmsg.get("backlog", 5)
                                for imei in imeis:
                                    tosend.extend(backlog(imei, numback))
                        towrite.add(sk)
                        if fl & zmq.POLLOUT:
                            # Writable as well: retry pending data now.
                            towait.discard(sk)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    towait.discard(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)  # type: ignore
                clients.stop(fd)
            for wsmsg in tosend:
                log.debug("Sending to the clients: %s", wsmsg)
                towrite |= clients.send(wsmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out. fds in towait are
            # skipped here; they are written when POLLOUT fires for them.
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            for fd in morewait:  # data left over: wait until writable
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)  # type: ignore
            for fd in trywrite - morewait:  # fully flushed
                poller.modify(fd, flags=zmq.POLLIN)  # type: ignore
            towait |= morewait
    except KeyboardInterrupt:
        pass
    finally:
        zsub.close()
        zctx.destroy()  # type: ignore
        tcpl.close()
408
if __name__.endswith("__main__"):
    # common.init() yields the ConfigParser that runserver() consumes.
    config = common.init(log)
    runserver(config)