]> www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
reimplement backlog query again
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from datetime import datetime, timezone
4 from json import dumps, loads
5 from logging import getLogger
6 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
7 from time import time
8 from wsproto import ConnectionType, WSConnection
9 from wsproto.events import (
10     AcceptConnection,
11     CloseConnection,
12     Message,
13     Ping,
14     Request,
15     TextMessage,
16 )
17 from wsproto.utilities import RemoteProtocolError
18 import zmq
19
20 from . import common
21 from .evstore import initdb, fetch
22 from .gps303proto import (
23     GPS_POSITIONING,
24     WIFI_POSITIONING,
25     parse_message,
26 )
27 from .zmsg import Bcast, topic
28
# Module-level logger used by everything in this file.
log = getLogger("gps303/wsgateway")
# Path of the HTML file served for plain HTTP "GET /" requests.
# Stays None until runserver() fills it in from the configuration.
htmlfile = None
31
32
def backlog(imei, numback):
    """Collect stored location reports for one device.

    Asks the event store (``fetch``) for records belonging to ``imei``,
    restricted to the GPS positioning and wifi positioning protocol
    numbers, and parses every returned packet.

    Args:
        imei: device identifier; copied into every report.
        numback: passed through to ``fetch`` unchanged (presumably the
            maximum number of records to return - confirm in evstore).

    Returns:
        A list of dicts with the keys "imei", "timestamp" (string form of
        the record time converted to UTC), "longitude" and "latitude", in
        the order ``fetch`` yielded the records.
    """
    # Each pair is handed to fetch() as a filter; the flag presumably
    # means "is_incoming" - verify against evstore.fetch.
    wanted = (
        (True, GPS_POSITIONING.PROTO),
        (False, WIFI_POSITIONING.PROTO),
    )
    reports = []
    for incoming, when, raw in fetch(imei, wanted, numback):
        parsed = parse_message(raw, is_incoming=incoming)
        utc_when = datetime.fromtimestamp(when).astimezone(tz=timezone.utc)
        report = {
            "imei": imei,
            "timestamp": str(utc_when),
            "longitude": parsed.longitude,
            "latitude": parsed.latitude,
        }
        reports.append(report)
    return reports
54
55
def try_http(data, fd, e):
    """Build a plain HTTP response for data that was not a websocket request.

    The first line of ``data`` is parsed as ``<method> <resource> <proto>``.
    Anything after a "?" in the resource is ignored.  Only ``GET /`` is
    served, and its body is the content of the file named by the
    module-level ``htmlfile``.

    Args:
        data: raw bytes received from the client.
        fd: file descriptor of the client socket; used for logging only.
        e: exception to raise when ``data`` cannot be parsed as an HTTP
            request either (the caller passes the websocket error).

    Returns:
        The complete response as bytes:
        200 with the file content for ``GET /``;
        500 when ``htmlfile`` is None or the file cannot be read;
        404 for any other resource;
        400 for any method other than GET.

    Raises:
        ``e``, when ``data`` is not decodable text or its first line does
        not consist of exactly three space-separated fields.
    """
    try:
        lines = data.decode().split("\r\n")
        op, resource, proto = lines[0].split(" ")
    except ValueError:  # UnicodeDecodeError is a ValueError too
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        lines[1:],
    )
    # Drop the query string, if there is one.
    resource = resource.partition("?")[0]
    if op != "GET":
        return (
            f"{proto} 400 Bad request\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "Bad request\r\n".encode()
        )
    if htmlfile is None:
        return (
            f"{proto} 500 No data configured\r\n"
            f"Content-Type: text/plain\r\n\r\n"
            f"HTML data not configure on the server\r\n".encode()
        )
    if resource != "/":
        return (
            f"{proto} 404 File not found\r\n"
            f"Content-Type: text/plain\r\n\r\n"
            f'We can only serve "/"\r\n'.encode()
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return (
            f"{proto} 500 File not found\r\n"
            f"Content-Type: text/plain\r\n\r\n"
            f"HTML file could not be opened\r\n".encode()
        )
    return (
        f"{proto} 200 Ok\r\n"
        f"Content-Type: text/html; charset=utf-8\r\n"
        f"Content-Length: {len(htmldata):d}\r\n\r\n"
    ).encode("utf-8") + htmldata
114
115
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock, addr):
        """Wrap an accepted socket in a server-side websocket state machine.

        Args:
            sock: connected client socket.
            addr: peer address as returned by ``accept()``.
        """
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        self.ws_data = b""  # bytes queued for sending, drained by write()
        self.ready = False  # True once the websocket upgrade was accepted
        self.imeis = set()  # imeis named in the latest "subscribe" message

    def close(self):
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read from the socket and process whatever arrived.

        Handshake requests, close requests and pings are answered by
        queueing the reply in ``ws_data``.  Text messages are decoded as
        JSON, and a message of type "subscribe" replaces ``imeis``.  If the
        websocket layer rejects the data, it is treated as a plain HTTP
        request and the response is written out immediately.

        Returns:
            None when the connection is finished (read error, EOF, or a
            plain HTTP request that has been answered); otherwise the list
            of JSON objects received in this call (may be empty).

        Raises:
            Whatever ``try_http`` raises when the data is neither websocket
            nor HTTP.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:  # bind the error, it is logged below
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            self.ws_data = try_http(data, self.sock.fileno(), e)
            self.write()  # TODO this is a hack
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            msgs = None
        else:
            msgs = []
            for event in self.ws.events():
                if isinstance(event, Request):
                    log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                    # self.ws_data += self.ws.send(event.response())  # Why not?!
                    self.ws_data += self.ws.send(AcceptConnection())
                    self.ready = True
                elif isinstance(event, (CloseConnection, Ping)):
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    self.ws_data += self.ws.send(event.response())
                elif isinstance(event, TextMessage):
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    # The payload comes from the network: bad JSON must not
                    # take the whole server down.
                    try:
                        msg = loads(event.data)
                    except ValueError as err:
                        log.warning(
                            "Ignoring malformed JSON on fd %d: %s",
                            self.sock.fileno(),
                            err,
                        )
                        continue
                    if not isinstance(msg, dict):
                        log.warning(
                            "Ignoring non-object message on fd %d: %s",
                            self.sock.fileno(),
                            msg,
                        )
                        continue
                    msgs.append(msg)
                    if msg.get("type", None) == "subscribe":
                        self.imeis = set(msg.get("imei", []))
                        log.debug(
                            "subs list on fd %s is %s",
                            self.sock.fileno(),
                            self.imeis,
                        )
                else:
                    log.warning("%s on fd %d", event, self.sock.fileno())
        return msgs

    def wants(self, imei):
        """Tell whether this client is subscribed to ``imei``."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message):
        """Queue ``message`` (a dict with an "imei" key) as a JSON frame.

        Nothing is queued before the websocket upgrade has completed or
        when the client is not subscribed to the message's imei.
        """
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(Message(data=dumps(message)))

    def write(self):
        """Try to send queued data to the socket.

        Whatever the socket did not accept stays queued.  On a socket
        error the queue is dropped.

        Returns:
            True if data remains queued after the attempt, else False.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
214
215
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Register a newly accepted connection and return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client registered under ``fd`` and forget it.

        Raises:
            KeyError: if ``fd`` is not registered.
        """
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Let the client registered under ``fd`` read from its socket.

        Returns:
            Whatever that client's ``recv()`` returns.
        """
        clnt = self.by_fd[fd]
        return clnt.recv()

    def send(self, msg):
        """Offer ``msg`` (a dict with an "imei" key) to the clients.

        Returns:
            The set of fds whose client reported wanting the message and
            was therefore asked to queue it.
        """
        towrite = set()
        for fd, clnt in self.by_fd.items():
            if clnt.wants(msg["imei"]):
                clnt.send(msg)
                towrite.add(fd)
        return towrite

    def write(self, towrite):
        """Flush queued data of the given fds.

        Fds that are not (or no longer) registered are skipped.

        Returns:
            The subset of ``towrite`` whose clients still have unsent data.
        """
        waiting = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is None:  # not registered, nothing to write
                continue
            if clnt.write():
                waiting.add(fd)
        return waiting

    def subs(self):
        """Return the union of the imeis all clients are subscribed to."""
        result = set()
        for clnt in self.by_fd.values():
            result |= clnt.imeis
        return result
256
257
def runserver(conf):
    """Run the websocket gateway until interrupted from the keyboard.

    Opens the event database, subscribes to the collector's publish
    socket, listens for TCP clients, and then loops forever: location
    messages from the collector and backlog requested by clients are
    relayed to the clients as JSON.

    Args:
        conf: configuration object offering ``get(section, key)`` and
            ``getint(section, key)``.  It is read for storage/dbfn,
            wsgateway/htmlfile, wsgateway/port and collector/publishurl.
    """
    global htmlfile

    initdb(conf.get("storage", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile")
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs = set()
    try:
        towait = set()  # fds with unsent data, waiting for POLLOUT
        while True:
            # Bring zmq subscriptions in line with what the clients want.
            neededsubs = clients.subs()
            for imei in neededsubs - activesubs:
                zsub.setsockopt(
                    zmq.SUBSCRIBE,
                    topic(GPS_POSITIONING.PROTO, True, imei),
                )
                zsub.setsockopt(
                    zmq.SUBSCRIBE,
                    topic(WIFI_POSITIONING.PROTO, False, imei),
                )
            for imei in activesubs - neededsubs:
                zsub.setsockopt(
                    zmq.UNSUBSCRIBE,
                    topic(GPS_POSITIONING.PROTO, True, imei),
                )
                zsub.setsockopt(
                    zmq.UNSUBSCRIBE,
                    topic(WIFI_POSITIONING.PROTO, False, imei),
                )
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend = []  # messages to offer to the clients
            topoll = []  # newly accepted (socket, address) pairs
            tostop = []  # fds of clients that are gone
            towrite = set()  # fds that may have data to flush
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything the collector has published.
                    while True:
                        try:
                            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                            msg = parse_message(zmsg.packet, zmsg.is_incoming)
                            log.debug("Got %s with %s", zmsg, msg)
                            tosend.append(
                                {
                                    "imei": zmsg.imei,
                                    "timestamp": str(
                                        datetime.fromtimestamp(
                                            zmsg.when
                                        ).astimezone(tz=timezone.utc)
                                    ),
                                    "longitude": msg.longitude,
                                    "latitude": msg.latitude,
                                }
                            )
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                        # Forget the fd of *this* event, not a leftover
                        # loop variable.
                        towait.discard(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                            if msg.get("type", None) == "subscribe":
                                # A subscribe without "imei" means no imeis.
                                imeis = msg.get("imei", [])
                                numback = msg.get("backlog", 5)
                                for imei in imeis:
                                    tosend.extend(backlog(imei, numback))
                        towrite.add(sk)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    towait.discard(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the clients: %s", zmsg)
                towrite |= clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            # morewait is a subset of trywrite, and trywrite excludes
            # everything already in towait.  So every fd in morewait is
            # newly blocked and must be polled for POLLOUT, and fds
            # already in towait stay there until POLLOUT or EOF.
            for fd in morewait:  # new fds waiting for write
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)
            towait |= morewait
    except KeyboardInterrupt:
        pass
378
379
# Entry point: start the gateway with the object returned by common.init().
# endswith() (rather than ==) also accepts module names that merely end
# in "__main__".
if __name__.endswith("__main__"):
    runserver(common.init(log))