]> www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
do not respond to hibernation; minor cleanup
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from datetime import datetime, timezone
4 from json import dumps, loads
5 from logging import getLogger
6 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
7 from time import time
8 from wsproto import ConnectionType, WSConnection
9 from wsproto.events import (
10     AcceptConnection,
11     CloseConnection,
12     Message,
13     Ping,
14     Request,
15     TextMessage,
16 )
17 from wsproto.utilities import RemoteProtocolError
18 import zmq
19
20 from . import common
21 from .backlog import blinit, backlog
22 from .gps303proto import (
23     GPS_POSITIONING,
24     WIFI_POSITIONING,
25     parse_message,
26 )
27 from .zmsg import Bcast, topic
28
# Module logger.
log = getLogger("gps303/wsgateway")
# Path of the HTML page served for plain-HTTP "GET /" requests by try_http().
# None until runserver() assigns it from the [wsgateway] htmlfile option.
htmlfile = None
31
32
def try_http(data, fd, e):
    """Answer a plain (non-websocket) HTTP request.

    Called after the websocket layer rejected ``data``. Only ``GET /`` is
    served, with the contents of the file named by the module-level
    ``htmlfile``; any query string is ignored.

    :param data: raw bytes received from the client
    :param fd: client socket file descriptor, used only for logging
    :param e: the exception raised by the websocket layer; it is
        re-raised when ``data`` does not even parse as an HTTP request line
    :returns: a complete HTTP response as bytes
    :raises: ``e`` when the request line is unparseable
    """
    try:
        # UnicodeDecodeError is a ValueError, as is a request line that
        # does not split into exactly three words.
        lines = data.decode().split("\r\n")
        request = lines[0]
        headers = lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    resource = resource.partition("?")[0]  # drop the query string, if any
    if op != "GET":
        return (
            f"{proto} 400 Bad request\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "Bad request\r\n"
        ).encode()
    if htmlfile is None:
        return (
            f"{proto} 500 No data configured\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "HTML data not configured on the server\r\n"
        ).encode()
    if resource != "/":
        return (
            f"{proto} 404 File not found\r\n"
            "Content-Type: text/plain\r\n\r\n"
            'We can only serve "/"\r\n'
        ).encode()
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return (
            f"{proto} 500 File not found\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "HTML file could not be opened\r\n"
        ).encode()
    return (
        f"{proto} 200 Ok\r\n"
        "Content-Type: text/html; charset=utf-8\r\n"
        f"Content-Length: {len(htmldata):d}\r\n\r\n"
    ).encode("utf-8") + htmldata
91
92
class Client:
    """Websocket connection to the client

    Wraps one accepted TCP socket together with a server-side wsproto
    state machine. Outgoing bytes are accumulated in ``ws_data`` and
    flushed to the socket by ``write()``.
    """

    def __init__(self, sock, addr):
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes queued for the client; drained by write()
        self.ws_data = b""
        # True between the accepted websocket handshake and the close
        # handshake: only in that window may data messages be sent.
        self.ready = False
        # IMEIs listed in the client's most recent "subscribe" message
        self.imeis = set()

    def close(self):
        """Close the client socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read available data from the socket and process it.

        Returns the list of JSON objects (dicts) decoded from text frames,
        which may be empty, or None when the connection should be dropped:
        on read error, on EOF, and after a plain HTTP request has been
        answered or found unparseable.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            try:
                self.ws_data = try_http(data, self.sock.fileno(), e)
            except RemoteProtocolError:
                # Not HTTP either (try_http has logged the data). Drop
                # this client instead of letting the exception escape
                # and terminate the whole server loop.
                return None
            self.write()  # TODO this is a hack
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            return None
        msgs = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                # self.ws_data += self.ws.send(event.response())  # Why not?!
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, CloseConnection):
                log.debug("%s on fd %d", event, self.sock.fileno())
                self.ws_data += self.ws.send(event.response())
                # The websocket is closed now and wsproto refuses to send
                # data messages in that state, so stop queueing them.
                self.ready = False
            elif isinstance(event, Ping):
                log.debug("%s on fd %d", event, self.sock.fileno())
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                # NOTE(review): assumes each TextMessage carries a whole JSON
                # document; wsproto can deliver fragments of large messages.
                log.debug("%s on fd %d", event, self.sock.fileno())
                try:
                    msg = loads(event.data)
                except ValueError as e:
                    log.warning(
                        "Bad JSON on fd %d: %s", self.sock.fileno(), e
                    )
                    continue
                if not isinstance(msg, dict):
                    log.warning(
                        "Ignoring non-object message on fd %d: %s",
                        self.sock.fileno(),
                        msg,
                    )
                    continue
                msgs.append(msg)
                if msg.get("type", None) == "subscribe":
                    self.imeis = set(msg.get("imei") or [])
                    log.debug(
                        "subs list on fd %s is %s",
                        self.sock.fileno(),
                        self.imeis,
                    )
            else:
                log.warning("%s on fd %d", event, self.sock.fileno())
        return msgs

    def wants(self, imei):
        """Tell whether this client subscribed to updates for ``imei``."""
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return imei in self.imeis

    def send(self, message):
        """Queue ``message`` (a dict with an "imei" key) as JSON, if the
        websocket is open and the client subscribed to that IMEI."""
        if self.ready and message["imei"] in self.imeis:
            self.ws_data += self.ws.send(Message(data=dumps(message)))

    def write(self):
        """Send as much queued data as the socket accepts.

        On a send error the queued data is discarded. Returns True if data
        remains queued, i.e. the caller should wait for writability.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
191
192
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start serving a freshly accepted socket; return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client on ``fd`` and forget it."""
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Process incoming data on ``fd``; see ``Client.recv``."""
        return self.by_fd[fd].recv()

    def send(self, msg):
        """Offer ``msg`` to every interested client.

        Returns the set of fds that may now have data to write.
        """
        towrite = set()
        for fd, clnt in self.by_fd.items():
            if clnt.wants(msg["imei"]):
                clnt.send(msg)
                towrite.add(fd)
        return towrite

    def write(self, towrite):
        """Flush queued data for the fds in ``towrite``.

        Fds that are no longer registered are skipped. Returns the subset
        of fds that still have data queued after the attempt.
        """
        waiting = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is None:  # already stopped: nothing to flush
                continue
            if clnt.write():
                waiting.add(fd)
        return waiting

    def subs(self):
        """Return the union of IMEIs that any client subscribed to."""
        return set().union(*(clnt.imeis for clnt in self.by_fd.values()))
233
234
def _sync_subscriptions(zsub, active, needed):
    """Change the zmq topic subscriptions of ``zsub`` from ``active`` IMEIs
    to ``needed`` IMEIs (GPS positioning in, wifi positioning out)."""
    for option, imeis in (
        (zmq.SUBSCRIBE, needed - active),
        (zmq.UNSUBSCRIBE, active - needed),
    ):
        for imei in imeis:
            zsub.setsockopt(option, topic(GPS_POSITIONING.PROTO, True, imei))
            zsub.setsockopt(
                option, topic(WIFI_POSITIONING.PROTO, False, imei)
            )


def _location_update(zmsg):
    """Convert a broadcast positioning message into the dict sent to
    websocket clients."""
    msg = parse_message(zmsg.packet, zmsg.is_incoming)
    log.debug("Got %s with %s", zmsg, msg)
    return {
        "imei": zmsg.imei,
        "timestamp": str(datetime.fromtimestamp(zmsg.when, tz=timezone.utc)),
        "longitude": msg.longitude,
        "latitude": msg.latitude,
    }


def runserver(conf):
    """Run the websocket gateway until interrupted.

    Relays position broadcasts received over zmq (from the collector's
    publish URL) to the websocket clients subscribed to the IMEI, and
    serves the configured HTML page to plain HTTP clients.

    :param conf: configuration object; reads [storage] dbfn,
        [wsgateway] htmlfile and port, and [collector] publishurl
    """
    global htmlfile

    blinit(conf.get("storage", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile")
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs = set()
    # Client fds registered for POLLOUT because their last write left
    # data queued; they are not written to until POLLOUT fires.
    towait = set()
    try:
        while True:
            neededsubs = clients.subs()
            _sync_subscriptions(zsub, activesubs, neededsubs)
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    while True:
                        try:
                            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                        except zmq.Again:
                            break
                        tosend.append(_location_update(zmsg))
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except OSError as e:
                        # e.g. the peer went away before we accepted
                        log.warning("Accepting connection: %s", e)
                        continue
                    topoll.append((clntsock, clntaddr))
                else:
                    if fl & zmq.POLLIN:
                        received = clients.recv(sk)
                        if received is None:
                            log.debug("Client gone from fd %d", sk)
                            tostop.append(sk)
                            towait.discard(sk)
                            continue
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                            if (
                                isinstance(msg, dict)
                                and msg.get("type", None) == "subscribe"
                            ):
                                numback = msg.get("backlog", 5)
                                for elem in msg.get("imei") or []:
                                    tosend.extend(backlog(elem, numback))
                        towrite.add(sk)
                    # Handled independently of POLLIN: both may be set.
                    if fl & zmq.POLLOUT:
                        log.debug("Write now open for fd %d", sk)
                        towrite.add(sk)
                        towait.discard(sk)
                    if not fl & (zmq.POLLIN | zmq.POLLOUT):
                        log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the clients: %s", zmsg)
                towrite |= clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            # morewait is a subset of trywrite, which excludes towait, so
            # every fd in it is newly waiting for write.
            for fd in morewait:
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)
            # fds that were waiting and were not tried keep waiting.
            towait |= morewait
    except KeyboardInterrupt:
        pass
    finally:
        tcpl.close()
        zsub.close()
355
356
# Start the gateway when this module is executed as the main module;
# common.init(log) supplies the configuration object passed to runserver().
if __name__.endswith("__main__"):
    runserver(common.init(log))