]> www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
wsgateway properly handle write-busy websockets
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from logging import getLogger
4 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
5 from time import time
6 from wsproto import ConnectionType, WSConnection
7 from wsproto.events import (
8     AcceptConnection,
9     CloseConnection,
10     Message,
11     Ping,
12     Request,
13     TextMessage,
14 )
15 from wsproto.utilities import RemoteProtocolError
16 import zmq
17
18 from . import common
19 from .zmsg import LocEvt
20
# Logger shared by everything in this module.
log = getLogger("gps303/wsgateway")

# Text of the HTML page returned for a plain "GET /".  It stays None until
# runserver() manages to read the file named in the configuration.
htmldata = None
23
24
def try_http(data, fd, e):
    """Build a plain HTTP response for data that was not a WebSocket request.

    Only the request line is interpreted.  A query string (everything from
    the first "?") is stripped from the resource before it is examined.

    Args:
        data: raw bytes read from the client socket.
        fd: file descriptor number of the client socket, used for logging.
        e: exception to raise if `data` is not a parseable HTTP request.

    Returns:
        The complete response as bytes:
        - non-GET request: 400;
        - GET while no HTML page is loaded: 500 (whatever the resource is);
        - GET for "/": 200 with the page, encoded as UTF-8;
        - GET for anything else: 404.

    Raises:
        `e` itself, when `data` cannot be decoded or its first line does
        not consist of exactly three space-separated fields.
    """
    try:
        # UnicodeDecodeError is a subclass of ValueError, so both a decoding
        # failure and a malformed request line end up in the same handler.
        lines = data.decode().split("\r\n")
        request = lines[0]
        headers = lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    # Drop the query string, if any; only the path selects the response.
    resource = resource.partition("?")[0]
    if op != "GET":
        return (
            f"{proto} 400 Bad request\r\n"
            "Content-Type: text/plain\r\n\r\n"
            "Bad request\r\n".encode()
        )
    if htmldata is None:
        return (
            f"{proto} 500 No data configured\r\n"
            f"Content-Type: text/plain\r\n\r\n"
            f"HTML data not configured on the server\r\n".encode()
        )
    if resource == "/":
        # Content-Length counts bytes, not characters.
        length = len(htmldata.encode("utf-8"))
        return (
            f"{proto} 200 Ok\r\n"
            f"Content-Type: text/html; charset=utf-8\r\n"
            f"Content-Length: {length:d}\r\n\r\n" + htmldata
        ).encode("utf-8")
    return (
        f"{proto} 404 File not found\r\n"
        f"Content-Type: text/plain\r\n\r\n"
        f'We can only serve "/"\r\n'.encode()
    )
74
75
class Client:
    """Websocket connection to the client

    Wraps one accepted TCP socket together with the wsproto state machine
    that parses its input, and buffers outgoing bytes in `ws_data` until
    `write()` manages to push them into the socket.
    """

    def __init__(self, sock, addr):
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes produced by wsproto (or by try_http) not yet sent out.
        self.ws_data = b""
        # True only between a completed upgrade and a close request;
        # data messages must not be handed to wsproto outside that window.
        self.ready = False

    def close(self):
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read once from the socket and process whatever arrived.

        Returns:
            None when the caller should drop this client: read error, EOF,
            input that was answered as plain HTTP, or input that was
            neither WebSocket nor parseable HTTP.  Otherwise the list of
            payloads of the text messages received (possibly empty).
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            try:
                self.ws_data = try_http(data, self.sock.fileno(), e)
            except RemoteProtocolError:
                # try_http() hands the error back when the data is not HTTP
                # either.  One misbehaving peer must not take the whole
                # server down, so just report this client as gone.
                return None
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            msgs = None
        else:
            msgs = []
            for event in self.ws.events():
                if isinstance(event, Request):
                    log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                    self.ws_data += self.ws.send(AcceptConnection())
                    self.ready = True
                elif isinstance(event, (CloseConnection, Ping)):
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    if isinstance(event, CloseConnection):
                        self.ready = False
                    self.ws_data += self.ws.send(event.response())
                elif isinstance(event, TextMessage):
                    # TODO: save imei "subscription"
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    msgs.append(event.data)
                else:
                    log.warning("%s on fd %d", event, self.sock.fileno())
        if self.ws_data:  # Temp hack
            self.write()
        return msgs

    def wants(self, imei):
        """Tell whether this client is interested in the given imei."""
        return True  # TODO: check subscriptions

    def send(self, message):
        """Queue `message.json` as a websocket message for this client.

        Nothing is queued unless the websocket upgrade has completed and
        no close has been requested.
        """
        # TODO: filter only wanted imei got from the client
        if not self.ready:
            return
        self.ws_data += self.ws.send(Message(data=message.json))

    def write(self):
        """Try to send buffered data; return True if some is still pending.

        On a socket error the buffer is discarded and False is returned.
        """
        if self.ws_data:
            try:
                sent = self.sock.send(self.ws_data)
                self.ws_data = self.ws_data[sent:]
            except OSError as e:
                log.error(
                    "Sending to fd %d: %s",
                    self.sock.fileno(),
                    e,
                )
                self.ws_data = b""
        return bool(self.ws_data)
158
159
class Clients:
    """Collection of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start tracking a newly accepted socket; return its descriptor."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client on `fd` and forget about it.

        Raises:
            KeyError: if no client is tracked under `fd`.
        """
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Let the client on `fd` read its socket.

        Returns:
            None if the client reported that it should be dropped,
            otherwise the list of messages it received (possibly empty).

        Raises:
            KeyError: if no client is tracked under `fd`.
        """
        clnt = self.by_fd[fd]
        msgs = clnt.recv()
        if msgs is None:
            return None
        for msg in msgs:
            log.debug("Received: %s", msg)
        # Hand the messages on to the caller instead of dropping them.
        return list(msgs)

    def send(self, msg):
        """Queue `msg` for every client that wants its imei.

        Returns:
            The set of descriptors of the clients the message was given to.
        """
        towrite = set()
        for fd, clnt in self.by_fd.items():
            if clnt.wants(msg.imei):
                clnt.send(msg)
                towrite.add(fd)
        return towrite

    def write(self, towrite):
        """Flush buffered output of the clients listed in `towrite`.

        Descriptors that are not tracked (any more) are skipped.

        Returns:
            The set of descriptors that still have unsent data.
        """
        waiting = set()
        for fd in towrite:
            clnt = self.by_fd.get(fd)
            if clnt is None:
                continue
            if clnt.write():
                waiting.add(fd)
        return waiting
200
201
def runserver(conf):
    """Run the gateway event loop until interrupted from the keyboard.

    Loads the HTML page (if it can be read), subscribes to the zmq
    publisher, listens for TCP connections, and then polls all of them:
    events received from zmq are queued to the clients, and client
    sockets are read from and written to as they become ready.

    Args:
        conf: configuration object providing `get(section, option)` and
            `getint(section, option)`; the options read are
            wsgateway/htmlfile, wsgateway/port and lookaside/publishurl.
    """
    global htmldata
    try:
        with open(conf.get("wsgateway", "htmlfile"), encoding="utf-8") as fl:
            htmldata = fl.read()
    except OSError as e:
        # Not fatal: the gateway keeps running without the page.
        log.warning("Cannot read HTML file: %s", e)
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("lookaside", "publishurl"))
    zsub.setsockopt(zmq.SUBSCRIBE, b"")
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        # Descriptors with unsent data that are polled for POLLOUT.
        towait = set()
        while True:
            tosend = []
            topoll = []
            tostop = []
            towrite = set()
            events = poller.poll(5000)
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything the subscription has for us.
                    while True:
                        try:
                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except OSError as e:
                        # The pending connection may be gone by now.
                        log.warning("Accepting connection: %s", e)
                    else:
                        topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                elif fl & zmq.POLLOUT:
                    log.debug("Write now open for fd %d", sk)
                    towrite.add(sk)
                    towait.discard(sk)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
                # The number may be reused by a future connection, which
                # must not inherit the "waiting for write" state.
                towait.discard(fd)
            for zmsg in tosend:
                log.debug("Sending to the clients: %s", zmsg)
                towrite |= clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # Deal with actually writing the data out.  Descriptors that are
            # already waiting for POLLOUT are left alone until it arrives.
            trywrite = towrite - towait
            morewait = clients.write(trywrite)
            log.debug(
                "towait %s, tried %s, still busy %s",
                towait,
                trywrite,
                morewait,
            )
            for fd in morewait:  # could not write everything, need POLLOUT
                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
            for fd in trywrite - morewait:  # no longer waiting for write
                poller.modify(fd, flags=zmq.POLLIN)
            # Everything that was tried had been taken out of `towait`
            # before, so those that are still busy just need adding back.
            towait |= morewait
    except KeyboardInterrupt:
        pass
283
284
if __name__.endswith("__main__"):
    # Executed as a program: initialise the common machinery with this
    # module's logger and serve with the configuration it returns.
    runserver(common.init(log))