]> www.average.org Git - loctrkd.git/blob - loctrkd/wsgateway.py
WIP converting wsgateway to multiprotocols
[loctrkd.git] / loctrkd / wsgateway.py
1 """ Websocket Gateway """
2
3 from configparser import ConfigParser
4 from datetime import datetime, timezone
5 from importlib import import_module
6 from json import dumps, loads
7 from logging import getLogger
8 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
9 from time import time
10 from typing import Any, cast, Dict, List, Optional, Set, Tuple
11 from wsproto import ConnectionType, WSConnection
12 from wsproto.events import (
13     AcceptConnection,
14     CloseConnection,
15     Event,
16     Message,
17     Ping,
18     Request,
19     TextMessage,
20 )
21 from wsproto.utilities import RemoteProtocolError
22 import zmq
23
24 from . import common
25 from .evstore import initdb, fetch
26 from .zmsg import Bcast, topic
27
28 log = getLogger("loctrkd/wsgateway")
29
30
class ProtoModule:
    """Shape of a protocol module as seen by the gateway.

    runserver() imports each module listed in the "collector"/"protocols"
    option and casts it to this type. The methods here only describe the
    module-level functions such a module provides; they do nothing and
    return None themselves.
    """

    @staticmethod
    def parse_message(packet: bytes, is_incoming: bool = True) -> Any:
        """Decode a raw *packet*; *is_incoming* is the stored direction flag."""

    @staticmethod
    def exposed_protos() -> List[Tuple[str, bool]]:
        """Return (proto, is_incoming) pairs used for selectors and topics."""

    @staticmethod
    def proto_handled(proto: str) -> bool:
        """Tell whether this module can parse packets of *proto*."""
41     def proto_handled(proto: str) -> bool:
42         ...
43
44
# Module-wide settings; runserver() assigns all three from the configuration.
htmlfile: Optional[str] = None  # page returned to plain-HTTP GET requests
pmods: List[ProtoModule] = []  # protocol modules named in "collector"/"protocols"
selector: List[Tuple[bool, str]] = []  # (is_incoming, proto) pairs given to fetch()
48
49
def backlog(imei: str, numback: int) -> List[Dict[str, Any]]:
    """Build "location" messages from stored packets of *imei*.

    Rows come from ``fetch(imei, selector, numback)``. *numback* is
    presumably the maximum number of rows to return; that is up to
    ``evstore.fetch``, so confirm there. Each packet is decoded by the
    first module in ``pmods`` whose ``proto_handled()`` accepts its
    protocol. Rows that no loaded module handles are skipped with a
    warning rather than being rendered from an unrelated message.

    The returned dicts use the same keys as the live location messages
    that runserver() sends: type, imei, timestamp (UTC), longitude,
    latitude, accuracy.
    """
    result: List[Dict[str, Any]] = []
    for is_incoming, timestamp, proto, packet in fetch(imei, selector, numback):
        pmod = next((pm for pm in pmods if pm.proto_handled(proto)), None)
        if pmod is None:
            log.warning(
                "No protocol module handles %s, skipping backlog entry", proto
            )
            continue
        msg = pmod.parse_message(packet, is_incoming=is_incoming)
        result.append(
            {
                "type": "location",
                "imei": imei,
                "timestamp": str(
                    datetime.fromtimestamp(timestamp).astimezone(
                        tz=timezone.utc
                    )
                ),
                "longitude": msg.longitude,
                "latitude": msg.latitude,
                # Same heuristic as the live broadcasts in runserver().
                # TODO: decide from the message type instead.
                "accuracy": "gps" if is_incoming else "approximate",
            }
        )
    return result
77
78
def _http_response(proto: str, status: str, ctype: str, body: bytes) -> bytes:
    """Assemble an HTTP response: status line, Content-Type/-Length, body."""
    head = (
        f"{proto} {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(body):d}\r\n\r\n"
    )
    return head.encode("utf-8") + body


def try_http(data: bytes, fd: int, e: Exception) -> bytes:
    """Answer data that failed the websocket handshake as plain HTTP.

    A GET request (for any resource) is answered with the contents of
    the module-level ``htmlfile``. The response is 500 if no file is
    configured or it cannot be opened. Any other method gets 400. Every
    response carries a Content-Length header.

    :param data: raw bytes read from the client socket.
    :param fd: client socket descriptor, used only in log messages.
    :param e: the websocket error that led here.
    :returns: the encoded response to write back to the client.
    :raises: *e* itself when *data* is not decodable UTF-8 or its first
        line is not a three-part HTTP request line.
    """
    try:
        lines = data.decode().split("\r\n")
        op, resource, proto = lines[0].split(" ")
    except ValueError:  # also covers UnicodeDecodeError
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    headers = lines[1:]
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    if op != "GET":
        return _http_response(
            proto, "400 Bad request", "text/plain", b"Bad request\r\n"
        )
    if htmlfile is None:
        return _http_response(
            proto,
            "500 No data configured",
            "text/plain",
            b"HTML data not configured on the server\r\n",
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return _http_response(
            proto,
            "500 File not found",
            "text/plain",
            b"HTML file could not be opened\r\n",
        )
    return _http_response(
        proto, "200 Ok", "text/html; charset=utf-8", htmldata
    )
126
127
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes queued for the socket but not yet sent (see write()).
        self.ws_data = b""
        # True once the websocket upgrade was accepted; send() queues
        # nothing before that.
        self.ready = False
        # IMEIs named in this client's most recent "subscribe" message.
        self.imeis: Set[str] = set()

    def close(self) -> None:
        """Close the client socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self) -> Optional[List[Dict[str, Any]]]:
        """Read available data and feed it to the websocket state machine.

        Returns the JSON objects decoded from the text messages received
        (possibly an empty list). Returns None when the client should be
        dropped: on a read error, on EOF, after the data was answered as
        a plain HTTP request, or when it was neither websocket nor HTTP.
        """
        fd = self.sock.fileno()
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning("Reading from fd %d: %s", fd, e)
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info("EOF reading from fd %d", fd)
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug("Websocket error on fd %d, try plain http (%s)", fd, e)
            try:
                self.ws_data = try_http(data, fd, e)
            except RemoteProtocolError:
                # try_http() re-raises the websocket error when the data is
                # not an HTTP request either. Letting it escape would stop
                # the whole gateway, so only this client is dropped.
                log.warning("Dropping fd %d: neither websocket nor HTTP", fd)
                return None
            # this `write` is a hack - writing _ought_ to be done at the
            # stage when all other writes are performed. But I could not
            # arrange it so in a logical way. Let it stay this way. The
            # whole http server affair is a hack anyway.
            self.write()
            log.debug("Sending HTTP response to %d", fd)
            return None
        return self._handle_events()

    def _handle_events(self) -> List[Dict[str, Any]]:
        """Process pending websocket events; return decoded client messages."""
        fd = self.sock.fileno()
        msgs: List[Dict[str, Any]] = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", fd)
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, fd)
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, fd)
                msg = self._parse_text(event.data)
                if msg is not None:
                    msgs.append(msg)
            else:
                log.warning("%s on fd %d", event, fd)
        return msgs

    def _parse_text(self, text: str) -> Optional[Dict[str, Any]]:
        """Decode one client text message; None if it is not a JSON object.

        A "subscribe" message has its "imei" entry normalised in place to
        a list of strings: a single string is wrapped, and anything else
        is dropped. This client's subscription set is replaced with it.
        """
        fd = self.sock.fileno()
        try:
            msg = loads(text)
        except ValueError as e:  # json.JSONDecodeError is a ValueError
            log.warning("Malformed JSON on fd %d: %s", fd, e)
            return None
        if not isinstance(msg, dict):
            log.warning("Ignoring non-object message on fd %d: %s", fd, msg)
            return None
        if msg.get("type", None) == "subscribe":
            imeis = msg.get("imei", [])
            if isinstance(imeis, str):
                imeis = [imeis]
            elif not isinstance(imeis, list):
                imeis = []
            # Written back so that later consumers of this dict see the
            # same sanitised list.
            msg["imei"] = [imei for imei in imeis if isinstance(imei, str)]
            self.imeis = set(msg["imei"])
            log.debug("subs list on fd %s is %s", fd, self.imeis)
        return msg

    def wants(self, imei: str) -> bool:
        """Tell whether this client is subscribed to *imei*."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message: Dict[str, Any]) -> None:
        """Queue *message* as JSON if the handshake is done and it is wanted."""
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(Message(data=dumps(message)))

    def write(self) -> bool:
        """Send as much queued data as the socket takes; True if some is left.

        On a send error the queued data is discarded.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
229
230
class Clients:
    """Registry of connected websocket clients, keyed by socket fd."""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Start tracking a freshly accepted connection; return its fd."""
        newfd = clntsock.fileno()
        log.info("Start serving fd %d from %s", newfd, clntaddr)
        self.by_fd[newfd] = Client(clntsock, clntaddr)
        return newfd

    def stop(self, fd: int) -> None:
        """Close the client registered under *fd* and forget it."""
        client = self.by_fd[fd]
        log.info("Stop serving fd %d", client.sock.fileno())
        client.close()
        self.by_fd.pop(fd)

    def recv(self, fd: int) -> Optional[List[Dict[str, Any]]]:
        """Delegate to Client.recv() of the client on *fd*."""
        return self.by_fd[fd].recv()

    def send(self, msg: Dict[str, Any]) -> Set[int]:
        """Queue *msg* for every client whose wants() accepts its IMEI.

        Returns the fds that got the message queued.
        """
        queued: Set[int] = set()
        for fd, client in self.by_fd.items():
            if not client.wants(msg["imei"]):
                continue
            client.send(msg)
            queued.add(fd)
        return queued

    def write(self, towrite: Set[int]) -> Set[int]:
        """Flush output for the known fds in *towrite*; return those still busy."""
        busy: Set[int] = set()
        for fd in towrite:
            client = self.by_fd.get(fd)
            if client is not None and client.write():
                busy.add(fd)
        return busy

    def subs(self) -> Set[str]:
        """Return a new set: the union of all clients' subscribed IMEIs."""
        merged: Set[str] = set()
        return merged.union(*(client.imeis for client in self.by_fd.values()))
271
272
def _wsmsg_from_bcast(zmsg: "Bcast") -> Optional[Dict[str, Any]]:
    """Turn a collector broadcast into a message for the web clients.

    The packet is decoded by the first module in ``pmods`` that handles
    its protocol. Returns None (after logging) when no module does.
    """
    pmod = next((pm for pm in pmods if pm.proto_handled(zmsg.proto)), None)
    if pmod is None:
        log.warning("No protocol module handles %s, ignoring it", zmsg.proto)
        return None
    msg = pmod.parse_message(zmsg.packet, zmsg.is_incoming)
    log.debug("Got %s with %s", zmsg, msg)
    timestamp = str(
        datetime.fromtimestamp(zmsg.when).astimezone(tz=timezone.utc)
    )
    if zmsg.proto == "ZX:STATUS":
        return {
            "type": "status",
            "imei": zmsg.imei,
            "timestamp": timestamp,
            "battery": msg.batt,
        }
    return {
        "type": "location",
        "imei": zmsg.imei,
        "timestamp": timestamp,
        "longitude": msg.longitude,
        "latitude": msg.latitude,
        "accuracy": "gps" if zmsg.is_incoming else "approximate",
    }


def _update_subscriptions(
    zsub: Any, neededsubs: Set[str], activesubs: Set[str]
) -> None:
    """(Un)subscribe *zsub* so that the topics of every exposed protocol
    cover exactly the IMEIs in *neededsubs*."""
    toadd = neededsubs - activesubs
    todel = activesubs - neededsubs
    for pmod in pmods:
        for proto, is_incoming in pmod.exposed_protos():
            for imei in toadd:
                zsub.setsockopt(zmq.SUBSCRIBE, topic(proto, is_incoming, imei))
            for imei in todel:
                zsub.setsockopt(
                    zmq.UNSUBSCRIBE, topic(proto, is_incoming, imei)
                )


def _backlog_request(wsmsg: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return the backlog messages asked for by a client "subscribe".

    Any other message type yields an empty list. Malformed requests are
    logged and ignored instead of raising: they come straight from a
    client.
    """
    if wsmsg.get("type", None) != "subscribe":
        return []
    imeis = wsmsg.get("imei")
    if isinstance(imeis, str):
        imeis = [imeis]
    if not isinstance(imeis, list):
        log.warning("Subscribe request without an IMEI list: %s", wsmsg)
        return []
    numback = wsmsg.get("backlog", 5)
    if not isinstance(numback, int):
        log.warning("Bad backlog size %r, using 5", numback)
        numback = 5
    result: List[Dict[str, Any]] = []
    for imei in imeis:
        if isinstance(imei, str):
            result.extend(backlog(imei, numback))
    return result


def runserver(conf: ConfigParser) -> None:
    """Run the websocket gateway until interrupted.

    Loads the protocol modules listed in "collector"/"protocols". Listens
    for TCP clients on "wsgateway"/"port" and subscribes to the collector
    broadcasts at "collector"/"publishurl" for the IMEIs the clients ask
    for. Matching broadcasts and stored backlog are relayed to the
    clients. Sockets are closed on exit. KeyboardInterrupt ends the loop
    quietly; other exceptions propagate after the cleanup.
    """
    global htmlfile, pmods, selector
    pmods = [
        cast(ProtoModule, import_module("." + modnm, __package__))
        for modnm in conf.get("collector", "protocols").split(",")
    ]
    # backlog() renders every fetched packet as a location message, so
    # status packets are left out of the query.
    selector = [
        (is_incoming, proto)
        for pmod in pmods
        for proto, is_incoming in pmod.exposed_protos()
        if proto != "ZX:STATUS"  # TODO make it better
    ]
    initdb(conf.get("storage", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile", fallback=None)
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zsub = zctx.socket(zmq.SUB)  # type: ignore
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs: Set[str] = set()
    # Client fds currently registered for POLLOUT because the socket did
    # not take all of their queued output.
    towait: Set[int] = set()
    try:
        while True:
            neededsubs = clients.subs()
            _update_subscriptions(zsub, neededsubs, activesubs)
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend: List[Dict[str, Any]] = []
            topoll: List[Tuple[socket, Any]] = []
            tostop: List[int] = []
            towrite: Set[int] = set()
            writable: Set[int] = set()  # fds that reported POLLOUT now
            for sk, fl in poller.poll():
                if sk is zsub:
                    while True:
                        try:
                            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                        except zmq.Again:
                            break
                        wsmsg = _wsmsg_from_bcast(zmsg)
                        if wsmsg is not None:
                            tosend.append(wsmsg)
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                else:
                    if fl & zmq.POLLIN:
                        received = clients.recv(sk)
                        if received is None:
                            log.debug("Client gone from fd %d", sk)
                            tostop.append(sk)
                            towait.discard(sk)
                            continue
                        for wsmsg in received:
                            log.debug("Received from %d: %s", sk, wsmsg)
                            tosend.extend(_backlog_request(wsmsg))
                        towrite.add(sk)
                    if fl & zmq.POLLOUT:
                        log.debug("Write now open for fd %d", sk)
                        writable.add(sk)
                        towrite.add(sk)
                    if not fl & (zmq.POLLIN | zmq.POLLOUT):
                        log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)  # type: ignore
                clients.stop(fd)
            for wsmsg in tosend:
                log.debug("Sending to the clients: %s", wsmsg)
                towrite |= clients.send(wsmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Write out queued data, except to fds that still wait for a
            # POLLOUT they have not got yet; their data stays queued.
            trywrite = towrite - (towait - writable)
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            for fd in morewait - towait:  # start waiting for POLLOUT
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)  # type: ignore
            for fd in (towait & trywrite) - morewait:  # drained, stop waiting
                poller.modify(fd, flags=zmq.POLLIN)  # type: ignore
            towait = (towait - trywrite) | morewait
    except KeyboardInterrupt:
        pass
    finally:
        for fd in list(clients.by_fd):
            clients.stop(fd)
        zsub.close()
        zctx.destroy()  # type: ignore
        tcpl.close()
420
421
if __name__.endswith("__main__"):
    # Obtain the configuration via common.init() and serve until interrupted.
    config = common.init(log)
    runserver(config)