]> www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
Multiprotocol support in zmq messages and storage
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from configparser import ConfigParser
4 from datetime import datetime, timezone
5 from json import dumps, loads
6 from logging import getLogger
7 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
8 from time import time
9 from typing import Any, cast, Dict, List, Optional, Set, Tuple
10 from wsproto import ConnectionType, WSConnection
11 from wsproto.events import (
12     AcceptConnection,
13     CloseConnection,
14     Event,
15     Message,
16     Ping,
17     Request,
18     TextMessage,
19 )
20 from wsproto.utilities import RemoteProtocolError
21 import zmq
22
23 from . import common
24 from .evstore import initdb, fetch
25 from .gps303proto import (
26     GPS_POSITIONING,
27     STATUS,
28     WIFI_POSITIONING,
29     parse_message,
30     proto_name,
31 )
32 from .zmsg import Bcast, topic
33
# Logger shared by everything in this module.
log = getLogger("gps303/wsgateway")
# Path of the HTML file returned for plain HTTP GET requests. It is
# assigned by runserver() from the configuration and stays None when
# the option is absent.
htmlfile = None
36
37
def backlog(imei: str, numback: int) -> List[Dict[str, Any]]:
    """Build "location" messages from stored packets of one device.

    Incoming GPS_POSITIONING and outgoing WIFI_POSITIONING packets are
    requested from the event store via `fetch()`; `numback` is handed
    to it unchanged (presumably the maximum number of records to
    return - verify against `evstore.fetch`).

    Returns a list of JSON-serializable dicts with the keys "type",
    "imei", "timestamp", "longitude", "latitude" and "accuracy".
    """
    wanted = [
        (True, proto_name(GPS_POSITIONING)),
        (False, proto_name(WIFI_POSITIONING)),
    ]
    locations = []
    for is_incoming, timestamp, packet in fetch(imei, wanted, numback):
        parsed = parse_message(packet, is_incoming=is_incoming)
        # The stored timestamp is rendered as an aware UTC datetime string.
        when = datetime.fromtimestamp(timestamp).astimezone(tz=timezone.utc)
        entry = {
            "type": "location",
            "imei": imei,
            "timestamp": str(when),
            "longitude": parsed.longitude,
            "latitude": parsed.latitude,
        }
        # Only a GPS fix counts as precise, everything else is an estimate.
        if isinstance(parsed, GPS_POSITIONING):
            entry["accuracy"] = "gps"
        else:
            entry["accuracy"] = "approximate"
        locations.append(entry)
    return locations
66
67
def try_http(data: bytes, fd: int, e: Exception) -> bytes:
    """Produce a plain HTTP response for data that was not a websocket.

    Called when the websocket layer rejected `data`. The first line is
    split into method, resource and protocol, and a complete response
    (status line, headers, body) is returned as bytes:

    * non-GET method: 400;
    * GET while the module-level `htmlfile` is None: 500;
    * GET and the file cannot be read: 500;
    * GET otherwise: 200 with the file content as body.

    `fd` is only used for logging.

    Raises `e` (the exception supplied by the caller) when `data` is
    not valid text or its first line does not consist of exactly three
    space-separated parts.
    """
    # Only the parsing can tell us "this is not HTTP at all", so only
    # the parsing is guarded. UnicodeDecodeError is a ValueError too.
    try:
        lines = data.decode().split("\r\n")
        request, headers = lines[0], lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    if op != "GET":
        return (
            f"{proto} 400 Bad request\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "Bad request\r\n"
        ).encode()
    if htmlfile is None:
        return (
            f"{proto} 500 No data configured\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "HTML data not configured on the server\r\n"
        ).encode()
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return (
            f"{proto} 500 File not found\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "HTML file could not be opened\r\n"
        ).encode()
    return (
        f"{proto} 200 Ok\r\n"
        "Content-Type: text/html; charset=utf-8\r\n"
        f"Content-Length: {len(htmldata):d}\r\n\r\n"
    ).encode("utf-8") + htmldata
115
116
class Client:
    """Websocket connection to the client

    Wraps one accepted TCP socket together with the wsproto state
    machine that decodes/encodes websocket frames for it. Outgoing
    bytes are accumulated in `ws_data` and pushed to the socket by
    `write()`.
    """

    def __init__(self, sock: socket, addr: Tuple[str, int]) -> None:
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        self.ws_data = b""  # encoded bytes not yet written to the socket
        self.ready = False  # True once the websocket upgrade was accepted
        self.imeis: Set[str] = set()  # devices this client subscribed to

    def close(self) -> None:
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    @staticmethod
    def _parse_text(text: str, fd: int) -> Optional[Dict[str, Any]]:
        """Decode one text frame into a message dict, or None if unusable.

        The payload comes from the network and must not be trusted:
        invalid JSON, a JSON document that is not an object, or a
        "subscribe" message whose "imei" is not a list of strings are
        logged and rejected instead of raising. For an accepted
        "subscribe" message the "imei" key is always present in the
        result (an empty list when the sender omitted it).
        """
        try:
            msg = loads(text)
        except ValueError as e:  # JSONDecodeError derives from ValueError
            log.warning("Undecodable message on fd %d: %s", fd, e)
            return None
        if not isinstance(msg, dict):
            log.warning("Unexpected message on fd %d: %s", fd, msg)
            return None
        if msg.get("type", None) == "subscribe":
            imeis = msg.get("imei", [])
            if not isinstance(imeis, list) or not all(
                isinstance(imei, str) for imei in imeis
            ):
                log.warning("Malformed subscription on fd %d: %s", fd, msg)
                return None
            msg["imei"] = imeis
        return msg

    def recv(self) -> Optional[List[Dict[str, Any]]]:
        """Read from the socket and process whatever arrived.

        Returns the list of decoded JSON messages received (possibly
        empty), or None when the connection is finished: read error,
        EOF, or the peer did not speak websocket (in which case a
        plain HTTP response is attempted before giving up).
        """
        fd = self.sock.fileno()
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning("Reading from fd %d: %s", fd, e)
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info("EOF reading from fd %d", fd)
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                fd,
                e,
            )
            try:
                self.ws_data = try_http(data, fd, e)
            except RemoteProtocolError:
                # try_http() re-raises `e` for input that is not HTTP
                # either. Letting it escape would take down the whole
                # server, so just drop this client.
                return None
            # this `write` is a hack - writing _ought_ to be done at the
            # stage when all other writes are performed. But I could not
            # arrange it so in a logical way. Let it stay this way. The
            # whole http server affair is a hack anyway.
            self.write()
            log.debug("Sending HTTP response to %d", fd)
            return None
        msgs: List[Dict[str, Any]] = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", fd)
                # self.ws_data += self.ws.send(event.response())  # Why not?!
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, fd)
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, fd)
                # NOTE(review): assumes every TextMessage event carries
                # a complete JSON document (no fragmented messages) -
                # confirm against wsproto's `message_finished`.
                msg = self._parse_text(event.data, fd)
                if msg is None:
                    continue
                msgs.append(msg)
                if msg.get("type", None) == "subscribe":
                    self.imeis = set(msg["imei"])
                    log.debug(
                        "subs list on fd %s is %s",
                        fd,
                        self.imeis,
                    )
            else:
                log.warning("%s on fd %d", event, fd)
        return msgs

    def wants(self, imei: str) -> bool:
        """Tell whether this client is subscribed to `imei`."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message: Dict[str, Any]) -> None:
        """Queue `message` as JSON for this client.

        Nothing is queued unless the websocket handshake is complete
        and the client subscribed to `message["imei"]`.
        """
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(Message(data=dumps(message)))

    def write(self) -> bool:
        """Push queued bytes to the socket.

        Returns True when some data is still left unsent. On a send
        error the queue is discarded and False is returned.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
218
219
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self) -> None:
        self.by_fd: Dict[int, Client] = {}

    def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
        """Wrap a freshly accepted socket in a Client; return its fd."""
        newfd = clntsock.fileno()
        log.info("Start serving fd %d from %s", newfd, clntaddr)
        self.by_fd[newfd] = Client(clntsock, clntaddr)
        return newfd

    def stop(self, fd: int) -> None:
        """Close the client registered under `fd` and forget it."""
        leaving = self.by_fd[fd]
        log.info("Stop serving fd %d", leaving.sock.fileno())
        leaving.close()
        del self.by_fd[fd]

    def recv(self, fd: int) -> Optional[List[Dict[str, Any]]]:
        """Delegate reading to the client registered under `fd`."""
        return self.by_fd[fd].recv()

    def send(self, msg: Dict[str, Any]) -> Set[int]:
        """Queue `msg` for every client that wants its imei.

        Returns the fds of the clients the message was handed to.
        """
        imei = msg["imei"]
        handed = set()
        for fd, member in self.by_fd.items():
            if not member.wants(imei):
                continue
            member.send(msg)
            handed.add(fd)
        return handed

    def write(self, towrite: Set[int]) -> Set[int]:
        """Flush queued data of the given fds.

        Unknown fds are skipped. Returns the fds that still have
        unsent data after the attempt.
        """
        candidates = [(fd, self.by_fd.get(fd)) for fd in towrite]
        return {fd for fd, member in candidates if member and member.write()}

    def subs(self) -> Set[str]:
        """Return the union of the imeis all clients are subscribed to."""
        everything: Set[str] = set()
        for member in self.by_fd.values():
            everything.update(member.imeis)
        return everything
260
261
def runserver(conf: ConfigParser) -> None:
    """Run the websocket gateway until interrupted.

    Reads `storage.dbfn`, `wsgateway.htmlfile` (optional),
    `collector.publishurl` and `wsgateway.port` from `conf`, then
    loops forever polling three kinds of sources:

    * the zmq SUB socket, whose messages are converted to "status" or
      "location" dicts and forwarded to subscribed clients;
    * the listening TCP socket, whose new connections become clients;
    * client sockets, which deliver subscription requests (answered
      with a backlog of stored locations) and accept queued output.

    zmq subscriptions are kept in sync with the union of all client
    subscriptions. KeyboardInterrupt ends the loop quietly; sockets
    and the zmq context are released on any exit from the loop.
    """
    global htmlfile

    def utcstamp(when: float) -> str:
        # Render an event time as an aware UTC datetime string.
        return str(datetime.fromtimestamp(when).astimezone(tz=timezone.utc))

    def topics(imei: str) -> List[Any]:
        # The zmq topics of interest for one device.
        return [
            topic(proto_name(GPS_POSITIONING), True, imei),
            topic(proto_name(WIFI_POSITIONING), False, imei),
            topic(proto_name(STATUS), True, imei),
        ]

    initdb(conf.get("storage", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile", fallback=None)
    # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?!
    zctx = zmq.Context()  # type: ignore
    zsub = zctx.socket(zmq.SUB)  # type: ignore
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()  # type: ignore
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs: Set[str] = set()
    try:
        # fds with unsent data; they are polled for POLLOUT and are not
        # written to again until the poller reports them writable.
        towait: Set[int] = set()
        while True:
            neededsubs = clients.subs()
            for imei in neededsubs - activesubs:
                for tpc in topics(imei):
                    zsub.setsockopt(zmq.SUBSCRIBE, tpc)
            for imei in activesubs - neededsubs:
                for tpc in topics(imei):
                    zsub.setsockopt(zmq.UNSUBSCRIBE, tpc)
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything that is queued on the zmq socket.
                    while True:
                        try:
                            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                        except zmq.Again:
                            break
                        msg = parse_message(zmsg.packet, zmsg.is_incoming)
                        log.debug("Got %s with %s", zmsg, msg)
                        if isinstance(msg, STATUS):
                            tosend.append(
                                {
                                    "type": "status",
                                    "imei": zmsg.imei,
                                    "timestamp": utcstamp(zmsg.when),
                                    "battery": msg.batt,
                                }
                            )
                        else:
                            tosend.append(
                                {
                                    "type": "location",
                                    "imei": zmsg.imei,
                                    "timestamp": utcstamp(zmsg.when),
                                    "longitude": msg.longitude,
                                    "latitude": msg.latitude,
                                    "accuracy": "gps"
                                    if zmsg.is_incoming
                                    else "approximate",
                                }
                            )
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                        towait.discard(sk)
                    else:
                        for wsmsg in received:
                            log.debug("Received from %d: %s", sk, wsmsg)
                            if wsmsg.get("type", None) != "subscribe":
                                continue
                            # Have to live w/o typechecking from json,
                            # but at least do not die on a request that
                            # carries no list of imeis.
                            imeis = wsmsg.get("imei")
                            if not isinstance(imeis, list):
                                log.warning(
                                    "Subscribe without imei list from %d: %s",
                                    sk,
                                    wsmsg,
                                )
                                continue
                            numback: int = wsmsg.get("backlog", 5)
                            for imei in imeis:
                                tosend.extend(backlog(imei, numback))
                        towrite.add(sk)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    towait.discard(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)  # type: ignore
                clients.stop(fd)
            for wsmsg in tosend:
                log.debug("Sending to the clients: %s", wsmsg)
                towrite |= clients.send(wsmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            # `morewait` is always a subset of `trywrite`, and `trywrite`
            # never contains fds already in `towait`, so every fd in
            # `morewait` is newly blocked and needs POLLOUT.
            for fd in morewait:  # new fds waiting for write
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)  # type: ignore
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)  # type: ignore
            # fds that were not tried stay in `towait` until POLLOUT
            # fires for them; the newly blocked ones join them.
            towait |= morewait
    except KeyboardInterrupt:
        pass  # the regular way to stop the server
    finally:
        zsub.close()
        zctx.destroy()  # type: ignore
        tcpl.close()
412
413
# Script entry point: obtain the configuration from common.init() and
# run the gateway.
# NOTE(review): `endswith` instead of `==` presumably also matches a
# package-qualified "...__main__" module name - confirm the intent.
if __name__.endswith("__main__"):
    runserver(common.init(log))