www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
WIP retoure messaging
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from json import loads
4 from logging import getLogger
5 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
6 from time import time
7 from wsproto import ConnectionType, WSConnection
8 from wsproto.events import (
9     AcceptConnection,
10     CloseConnection,
11     Message,
12     Ping,
13     Request,
14     TextMessage,
15 )
16 from wsproto.utilities import RemoteProtocolError
17 import zmq
18
19 from . import common
20 from .backlog import blinit, backlog
21 from .gps303proto import (
22     GPS_POSITIONING,
23     WIFI_POSITIONING,
24     parse_message,
25 )
26 from .zmsg import Bcast, topic
27
log = getLogger("gps303/wsgateway")
# Path of the HTML page served for "GET /" by try_http(); assigned by
# runserver() from the "wsgateway"/"htmlfile" option. While it is None,
# GET requests are answered with a 500 error.
htmlfile = None
30
31
def _plain_response(proto, status, text):
    """Build a complete text/plain HTTP response with a Content-Length."""
    body = f"{text}\r\n".encode()
    return (
        f"{proto} {status}\r\n"
        f"Content-Type: text/plain\r\n"
        f"Content-Length: {len(body):d}\r\n\r\n"
    ).encode() + body


def try_http(data, fd, e):
    """Answer a plain HTTP request that arrived on a websocket port.

    Called when the websocket handshake parser rejected `data`. Only
    ``GET /`` is served, with the contents of the module-level `htmlfile`.
    Any query string in the resource is ignored.

    :param data: raw bytes received from the client
    :param fd: client socket file descriptor, used for logging only
    :param e: the exception raised by the websocket parser
    :returns: the complete HTTP response, as bytes
    :raises: `e` when `data` cannot be parsed as an HTTP request line
    """
    try:
        lines = data.decode().split("\r\n")
        request = lines[0]
        headers = lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:  # also covers UnicodeDecodeError
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    resource = resource.partition("?")[0]  # drop the query string
    if op != "GET":
        return _plain_response(proto, "400 Bad request", "Bad request")
    if htmlfile is None:
        return _plain_response(
            proto,
            "500 No data configured",
            "HTML data not configured on the server",
        )
    if resource != "/":
        return _plain_response(
            proto, "404 File not found", 'We can only serve "/"'
        )
    try:
        with open(htmlfile, "rb") as fl:
            htmldata = fl.read()
    except OSError:
        return _plain_response(
            proto, "500 File not found", "HTML file could not be opened"
        )
    return (
        f"{proto} 200 Ok\r\n"
        f"Content-Type: text/html; charset=utf-8\r\n"
        f"Content-Length: {len(htmldata):d}\r\n\r\n"
    ).encode("utf-8") + htmldata
90
91
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock, addr):
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes queued for the client that are not yet written to the socket
        self.ws_data = b""
        # Set once the websocket handshake has been accepted
        self.ready = False
        # Imeis listed in the client's latest "subscribe" message
        self.imeis = set()

    def close(self):
        """Close the client socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read from the socket and process the resulting websocket events.

        Returns a list of JSON objects (dicts) received as text messages,
        or None when the connection should be dropped: on read error, on
        EOF, or after a plain HTTP request was answered (or was not even
        parseable as HTTP).
        """
        fd = self.sock.fileno()
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning("Reading from fd %d: %s", fd, e)
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info("EOF reading from fd %d", fd)
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug("Websocket error on fd %d, try plain http (%s)", fd, e)
            try:
                self.ws_data = try_http(data, fd, e)
            except RemoteProtocolError:
                # Neither websocket nor HTTP: drop this client instead of
                # letting the exception escape into the server loop.
                return None
            self.write()  # TODO this is a hack
            log.debug("Sending HTTP response to %d", fd)
            return None
        msgs = []
        for event in self.ws.events():
            if isinstance(event, Request):
                log.debug("WebSocket upgrade on fd %d", fd)
                # NOTE(review): the handshake is accepted explicitly; the
                # author noted that event.response() did not work here.
                self.ws_data += self.ws.send(AcceptConnection())
                self.ready = True
            elif isinstance(event, (CloseConnection, Ping)):
                log.debug("%s on fd %d", event, fd)
                self.ws_data += self.ws.send(event.response())
            elif isinstance(event, TextMessage):
                log.debug("%s on fd %d", event, fd)
                try:
                    msg = loads(event.data)
                except ValueError as e:
                    log.warning("Undecodable message on fd %d: %s", fd, e)
                    continue
                if not isinstance(msg, dict):
                    # Callers use msg.get(), so only JSON objects are passed
                    log.warning("Non-object message on fd %d: %s", fd, msg)
                    continue
                msgs.append(msg)
                if msg.get("type", None) == "subscribe":
                    self.imeis = set(msg.get("imei", []))
                    log.debug("subs list on fd %s is %s", fd, self.imeis)
            else:
                log.warning("%s on fd %d", event, fd)
        return msgs

    def wants(self, imei):
        """Tell whether messages about `imei` should go to this client.

        Uses the same condition as send(): the handshake is complete and
        `imei` is among the subscribed ones.
        """
        log.debug(
            "wants %s? set is %s on fd %d",
            imei,
            self.imeis,
            self.sock.fileno(),
        )
        return self.ready and imei in self.imeis

    def send(self, message):
        """Queue `message` (needs `.imei` and `.json`) if subscribed to."""
        if self.ready and message.imei in self.imeis:
            self.ws_data += self.ws.send(Message(data=message.json))

    def write(self):
        """Write as much queued data as the socket accepts.

        Returns True if some data is still unsent. On a send error the
        queued data is discarded.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
190
191
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start serving an accepted socket; return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client on `fd` and forget about it."""
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Delegate to the recv() of the client on `fd`."""
        clnt = self.by_fd[fd]
        return clnt.recv()

    def send(self, msg):
        """Queue `msg` for every client that wants it.

        Returns the set of fds that got the message queued.
        """
        towrite = set()
        for fd, clnt in self.by_fd.items():
            if clnt.wants(msg.imei):
                clnt.send(msg)
                towrite.add(fd)
        return towrite

    def write(self, towrite):
        """Flush pending output for the clients on the given fds.

        Fds without a registered client are skipped. Returns the set of
        fds that still have unsent data.
        """
        waiting = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is None:  # already stopped or never registered
                continue
            if clnt.write():
                waiting.add(fd)
        return waiting

    def subs(self):
        """Return the union of imeis subscribed to by all clients."""
        result = set()
        for clnt in self.by_fd.values():
            result |= clnt.imeis
        return result
232
233
def runserver(conf):
    """Run the websocket gateway until interrupted with KeyboardInterrupt.

    Reads the options storage/dbfn, wsgateway/htmlfile, wsgateway/port and
    collector/publishurl from `conf`. Accepts TCP clients on the configured
    port and handles them through Clients. Relays broadcasts received from
    the collector's ZeroMQ publisher to clients, and answers "subscribe"
    requests with backlog entries.
    """
    global htmlfile

    blinit(conf.get("storage", "dbfn"))
    htmlfile = conf.get("wsgateway", "htmlfile")
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("collector", "publishurl"))
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    activesubs = set()
    try:
        # fds currently registered with POLLIN|POLLOUT: they have unsent data
        towait = set()
        while True:
            neededsubs = clients.subs()
            # NOTE(review): `imei` is not part of the topic, so the same
            # topics are (un)subscribed once per added/removed imei; the
            # per-imei filtering is done by the clients themselves.
            for imei in neededsubs - activesubs:
                zsub.setsockopt(
                    zmq.SUBSCRIBE,
                    topic(GPS_POSITIONING.PROTO, True),
                )
                zsub.setsockopt(
                    zmq.SUBSCRIBE,
                    topic(WIFI_POSITIONING.PROTO, False),
                )
            for imei in activesubs - neededsubs:
                zsub.setsockopt(
                    zmq.UNSUBSCRIBE,
                    topic(GPS_POSITIONING.PROTO, True),
                )
                zsub.setsockopt(
                    zmq.UNSUBSCRIBE,
                    topic(WIFI_POSITIONING.PROTO, False),
                )
            activesubs = neededsubs
            log.debug("Subscribed to: %s", activesubs)
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            writable = set()  # fds reported ready for writing this round
            events = poller.poll()
            for sk, fl in events:
                if sk is zsub:
                    while True:
                        try:
                            zmsg = Bcast(zsub.recv(zmq.NOBLOCK))
                            # NOTE(review): parse result is unused here
                            parse_message(zmsg.packet)
                            tosend.append(zmsg)
                            log.debug("Got %s", zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                        # It gets unregistered below; stop tracking it.
                        towait.discard(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                            if msg.get("type", None) == "subscribe":
                                imei = msg.get("imei", [])
                                numback = msg.get("backlog", 5)
                                for elem in imei:
                                    tosend.extend(backlog(elem, numback))
                        towrite.add(sk)
                        if fl & zmq.POLLOUT:
                            writable.add(sk)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    writable.add(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the clients: %s", zmsg)
                towrite |= clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out. Fds still waiting
            # for POLLOUT are skipped; their data stays queued.
            blocked = towait - writable
            trywrite = towrite - blocked
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            newwait = blocked | morewait
            for fd in newwait - towait:  # new fds waiting for write
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
            for fd in towait - newwait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)
            towait = newwait
    except KeyboardInterrupt:
        pass
343
344
# Entry point: run the gateway, using the result of common.init(log) as
# the configuration object for runserver().
if __name__.endswith("__main__"):
    runserver(common.init(log))