]> www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
a4a0b6d441d31d8a42800adc3e3d77eebd196547
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
3 from logging import getLogger
4 from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
5 from time import time
6 from wsproto import ConnectionType, WSConnection
7 from wsproto.events import (
8     AcceptConnection,
9     CloseConnection,
10     Message,
11     Ping,
12     Request,
13     TextMessage,
14 )
15 from wsproto.utilities import RemoteProtocolError
16 import zmq
17
18 from . import common
19 from .zmsg import LocEvt
20
21 log = getLogger("gps303/wsgateway")
22 htmldata = None
23
24
def _http_response(proto, status, ctype, body):
    """Build a complete HTTP response (status line, headers, body) as bytes.

    A Content-Length header is always included so that the client knows
    where the body ends without relying on the connection being closed.
    """
    payload = body.encode("utf-8")
    head = (
        f"{proto} {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(payload):d}\r\n\r\n"
    )
    return head.encode("utf-8") + payload


def try_http(data, fd, e):
    """Treat `data` as a plain HTTP request and build the response for it.

    Called when the websocket layer has rejected `data`.

    data -- raw bytes received from the client
    fd   -- file descriptor of the client socket, used for logging only
    e    -- the exception that made the caller fall back to plain HTTP;
            it is re-raised when `data` cannot be parsed as HTTP either

    Returns the encoded HTTP response: 200 with the configured HTML page
    for "GET /", 500 when no page is configured, 404 for any other
    resource, and 400 for any method other than GET.
    """
    global htmldata
    try:
        lines = data.decode().split("\r\n")
        request, headers = lines[0], lines[1:]
        # Anything but exactly three space-separated fields raises
        # ValueError, which is handled below.
        op, resource, proto = request.split(" ")
        log.debug(
            "HTTP %s for %s, proto %s from fd %d, headers: %s",
            op,
            resource,
            proto,
            fd,
            headers,
        )
        # The query string plays no role in choosing what to serve.
        resource = resource.partition("?")[0]
        if op != "GET":
            return _http_response(
                proto, "400 Bad request", "text/plain", "Bad request\r\n"
            )
        if htmldata is None:
            return _http_response(
                proto,
                "500 No data configured",
                "text/plain",
                "HTML data not configured on the server\r\n",
            )
        if resource == "/":
            return _http_response(
                proto, "200 Ok", "text/html; charset=utf-8", htmldata
            )
        return _http_response(
            proto,
            "404 File not found",
            "text/plain",
            'We can only serve "/"\r\n',
        )
    except ValueError:
        # Covers both undecodable bytes and a malformed request line.
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
74
75
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock, addr):
        """Wrap an accepted socket `sock` connected from address `addr`."""
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        self.ws_data = b""  # bytes waiting to be written to the socket

    def close(self):
        """Close the underlying socket."""
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read from the socket and process whatever has arrived.

        Returns a list with the payloads of the text messages received
        (possibly empty), or None when the caller should stop serving
        this client: on read error, on EOF, or after the data was
        answered as a plain HTTP request.

        Raises the websocket protocol error if the data is neither valid
        websocket traffic nor a parseable HTTP request.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:  # the exception must be bound to be logged
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            # NOTE(review): None is presumably how wsproto is told that
            # the peer is gone - confirm against the wsproto docs.
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            # Replaces (not appends to) anything queued before.
            self.ws_data = try_http(data, self.sock.fileno(), e)
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            # None makes the caller drop the connection after the
            # response has been written below.
            msgs = None
        else:
            msgs = []
            for event in self.ws.events():
                if isinstance(event, Request):
                    log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                    # self.ws_data += self.ws.send(event.response())  # Why not?!
                    self.ws_data += self.ws.send(AcceptConnection())
                elif isinstance(event, (CloseConnection, Ping)):
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    self.ws_data += self.ws.send(event.response())
                elif isinstance(event, TextMessage):
                    # TODO: save imei "subscription"
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    msgs.append(event.data)
                else:
                    log.warning("%s on fd %d", event, self.sock.fileno())
        if self.ws_data:  # Temp hack
            self.write()
        return msgs

    def wants(self, imei):
        """Tell whether this client wants events for `imei` (always True)."""
        return True  # TODO: check subscriptions

    def send(self, message):
        """Queue `message.json` as a websocket message; does not write it."""
        # TODO: filter only wanted imei got from the client
        self.ws_data += self.ws.send(Message(data=message.json))

    def write(self):
        """Try to write queued data to the socket.

        Whatever the socket did not accept stays queued. On a socket
        error the queue is discarded.
        """
        try:
            sent = self.sock.send(self.ws_data)
            self.ws_data = self.ws_data[sent:]
        except OSError as e:
            log.error(
                "Sending to fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws_data = b""
156
157
class Clients:
    """Collection of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}  # file descriptor -> Client

    def add(self, clntsock, clntaddr):
        """Start serving a newly accepted socket; return its descriptor."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client served on `fd` and forget about it.

        Raises KeyError if no client is registered under `fd`.
        """
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Let the client on `fd` read and process incoming data.

        Returns the list of messages received from the client, or None
        when the client is to be dropped.
        """
        clnt = self.by_fd[fd]
        msgs = clnt.recv()
        if msgs is None:
            return None
        result = []
        for msg in msgs:
            log.debug("Received: %s", msg)
            # Hand the message on to the caller instead of dropping it.
            result.append(msg)
        return result

    def send(self, msg):
        """Queue `msg` for every client that wants it and try to write it."""
        for clnt in self.by_fd.values():
            if clnt.wants(msg.imei):
                clnt.send(msg)
                clnt.write()
189
190
def runserver(conf):
    """Run the gateway until interrupted from the keyboard.

    conf -- configuration object; the function reads "htmlfile" and
            "port" from its "wsgateway" section and "publishurl" from
            its "lookaside" section.

    Events received on the zmq subscription are sent to all connected
    clients. Connections are accepted on a TCP socket bound to the
    configured port.
    """
    global htmldata
    try:
        with open(
            conf.get("wsgateway", "htmlfile"), encoding="utf-8"
        ) as fl:
            htmldata = fl.read()
    except OSError as e:
        # Still best effort: without the file, GET requests are answered
        # with an error, so at least tell the operator why.
        log.warning("Cannot read HTML file, not serving it: %s", e)
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    zsub.connect(conf.get("lookaside", "publishurl"))
    zsub.setsockopt(zmq.SUBSCRIBE, b"")
    tcpl = socket(AF_INET6, SOCK_STREAM)
    tcpl.setblocking(False)
    tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    tcpl.bind(("", conf.getint("wsgateway", "port")))
    tcpl.listen(5)
    tcpfd = tcpl.fileno()
    poller = zmq.Poller()
    poller.register(zsub, flags=zmq.POLLIN)
    poller.register(tcpfd, flags=zmq.POLLIN)
    clients = Clients()
    try:
        while True:
            # Changes are collected here and applied only after all
            # events of this poll round have been looked at.
            tosend = []
            topoll = []
            tostop = []
            events = poller.poll(5000)
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything that is available right now.
                    while True:
                        try:
                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    clntsock, clntaddr = tcpl.accept()
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    # Any other descriptor belongs to a client.
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the client: %s", zmsg)
                clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # TODO: Handle write overruns (register for POLLOUT)
    except KeyboardInterrupt:
        pass
254
255
# Script entry point. The suffix test is also true for any dotted module
# name that ends in "__main__", not just for "__main__" itself.
if __name__.endswith("__main__"):
    runserver(
        common.init(log),
    )