www.average.org Git - loctrkd.git/blob - gps303/wsgateway.py
WIP on ws gateway, it now works
[loctrkd.git] / gps303 / wsgateway.py
1 """ Websocket Gateway """
2
from logging import getLogger
from socket import socket, AF_INET6, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR
from time import time
from wsproto import ConnectionType, WSConnection
from wsproto.connection import ConnectionState
from wsproto.events import (
    AcceptConnection,
    CloseConnection,
    Message,
    Ping,
    Request,
    TextMessage,
)
from wsproto.utilities import RemoteProtocolError
import zmq

from . import common
from .zmsg import LocEvt
20
# Logger shared by everything in this module.
log = getLogger(name="gps303/wsgateway")
# Page served for plain-HTTP "GET /"; stays None until runserver() reads it.
htmldata = None
23
24
def _http_response(proto, status, ctype, body):
    """Encode one plain-HTTP response whose body is the text *body*.

    Every response carries Content-Length (the page response always did),
    so the client can tell where the body ends and detect truncation.
    """
    payload = body.encode("utf-8")
    head = (
        f"{proto} {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(payload):d}\r\n\r\n"
    )
    return head.encode("utf-8") + payload


def try_http(data, fd, e):
    """Answer data that failed the websocket handshake as a plain HTTP request.

    Only ``GET /`` is served, with the module-level ``htmldata`` page.
    Non-GET methods get 400, a missing page gets 500, and any other
    resource gets 404.

    data -- raw bytes received from the client.
    fd   -- client socket fd, used only for logging.
    e    -- the exception raised by the websocket handshake attempt.

    Returns the encoded response (status line, headers and body).
    Raises ``e`` itself when *data* is not a parseable HTTP request line.
    """
    try:
        lines = data.decode().split("\r\n")
        request = lines[0]
        headers = lines[1:]
        op, resource, proto = request.split(" ")
    except ValueError:
        # UnicodeDecodeError is a ValueError too, as is a request line
        # that does not split into exactly three words.
        log.warning("Unparseable data from fd %d: %s", fd, data)
        raise e
    log.debug(
        "HTTP %s for %s, proto %s from fd %d, headers: %s",
        op,
        resource,
        proto,
        fd,
        headers,
    )
    if op != "GET":
        return _http_response(
            proto, "400 Bad request", "text/plain", "Bad request\r\n"
        )
    if htmldata is None:
        return _http_response(
            proto,
            "500 No data configured",
            "text/plain",
            "HTML data not configured on the server\r\n",
        )
    if resource != "/":
        return _http_response(
            proto,
            "404 File not found",
            "text/plain",
            'We can only serve "/"\r\n',
        )
    return _http_response(
        proto, "200 Ok", "text/html; charset=utf-8", htmldata
    )
69
70
class Client:
    """Websocket connection to the client"""

    def __init__(self, sock, addr):
        self.sock = sock
        self.addr = addr
        self.ws = WSConnection(ConnectionType.SERVER)
        # Bytes produced by wsproto (or an HTTP response) not yet written
        self.ws_data = b""

    def close(self):
        log.debug("Closing fd %d", self.sock.fileno())
        self.sock.close()

    def recv(self):
        """Read from the socket and process what arrived.

        Returns the list of text messages received over the websocket, or
        None when the caller should drop this client. That happens on a read
        error, on EOF, after a plain-HTTP response has been sent, and when
        the data is neither websocket nor parseable HTTP.
        """
        try:
            data = self.sock.recv(4096)
        except OSError as e:
            log.warning(
                "Reading from fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws.receive_data(None)
            return None
        if not data:  # Client has closed connection
            log.info(
                "EOF reading from fd %d",
                self.sock.fileno(),
            )
            self.ws.receive_data(None)
            return None
        try:
            self.ws.receive_data(data)
        except RemoteProtocolError as e:
            log.debug(
                "Websocket error on fd %d, try plain http (%s)",
                self.sock.fileno(),
                e,
            )
            try:
                self.ws_data = try_http(data, self.sock.fileno(), e)
            except RemoteProtocolError:
                # try_http re-raises e when the data is not HTTP either;
                # drop this client rather than crash the server loop.
                return None
            log.debug("Sending HTTP response to %d", self.sock.fileno())
            msgs = None  # Caller closes the connection after the response
        else:
            msgs = []
            for event in self.ws.events():
                if isinstance(event, Request):
                    log.debug("WebSocket upgrade on fd %d", self.sock.fileno())
                    # wsproto's Request event has no response() helper
                    # (unlike CloseConnection and Ping); the server accepts
                    # the handshake by sending AcceptConnection explicitly.
                    self.ws_data += self.ws.send(AcceptConnection())
                elif isinstance(event, (CloseConnection, Ping)):
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    self.ws_data += self.ws.send(event.response())
                elif isinstance(event, TextMessage):
                    # TODO: save imei "subscription"
                    log.debug("%s on fd %d", event, self.sock.fileno())
                    msgs.append(event.data)
                else:
                    log.warning("%s on fd %d", event, self.sock.fileno())
        if self.ws_data:  # Temp hack
            self.write()
        return msgs

    def wants(self, imei):
        return True  # TODO: check subscriptions

    def send(self, message):
        """Queue *message* (its ``json`` attribute) as a websocket message.

        Nothing is queued unless the websocket is open. Before the handshake
        completes, and after it is closed, wsproto refuses to send data
        messages and would raise.
        """
        # TODO: filter only wanted imei got from the client
        if self.ws.state is not ConnectionState.OPEN:
            log.debug(
                "Not sending to fd %d in state %s",
                self.sock.fileno(),
                self.ws.state,
            )
            return
        self.ws_data += self.ws.send(Message(data=message.json))

    def write(self):
        """Send as much of the pending buffer as the socket accepts.

        Unsent bytes stay in ``ws_data``. On a socket error the whole buffer
        is discarded.
        """
        if not self.ws_data:
            return
        try:
            sent = self.sock.send(self.ws_data)
            self.ws_data = self.ws_data[sent:]
        except OSError as e:
            log.error(
                "Sending to fd %d: %s",
                self.sock.fileno(),
                e,
            )
            self.ws_data = b""
151
152
class Clients:
    """Registry of connected clients, keyed by socket file descriptor."""

    def __init__(self):
        self.by_fd = {}

    def add(self, clntsock, clntaddr):
        """Start serving an accepted socket and return its fd."""
        fd = clntsock.fileno()
        log.info("Start serving fd %d from %s", fd, clntaddr)
        self.by_fd[fd] = Client(clntsock, clntaddr)
        return fd

    def stop(self, fd):
        """Close the client on *fd* and forget it."""
        clnt = self.by_fd[fd]
        log.info("Stop serving fd %d", clnt.sock.fileno())
        clnt.close()
        del self.by_fd[fd]

    def recv(self, fd):
        """Process input on *fd*.

        Returns the list of messages the client sent, or None when the
        client should be stopped.
        """
        msgs = self.by_fd[fd].recv()
        if msgs is None:
            return None
        for msg in msgs:
            log.debug("Received: %s", msg)
        # Previously an always-empty list was returned, dropping every message
        return list(msgs)

    def send(self, msg):
        """Queue *msg* to every interested client and flush it."""
        for clnt in self.by_fd.values():
            if clnt.wants(msg.imei):
                clnt.send(msg)
                clnt.write()
184
185
def runserver(conf):
    """Run the websocket gateway until interrupted.

    Loads the optional HTML page named by ``[wsgateway] htmlfile``. Then
    subscribes to every message published on ``[lookaside] publishurl`` and
    accepts TCP clients on ``[wsgateway] port``. One zmq Poller multiplexes
    the subscription, the listening socket and all client sockets.
    KeyboardInterrupt ends the loop. All sockets are closed on the way out.
    """
    global htmldata
    htmlfile = conf.get("wsgateway", "htmlfile")
    try:
        with open(htmlfile, encoding="utf-8") as htmlf:
            htmldata = fl_data = htmlf.read()
    except OSError as e:
        # Not fatal, but explain why "GET /" will get an error page
        log.warning("Cannot read HTML file %s: %s", htmlfile, e)
    zctx = zmq.Context()
    zsub = zctx.socket(zmq.SUB)
    tcpl = socket(AF_INET6, SOCK_STREAM)
    clients = Clients()
    try:
        zsub.connect(conf.get("lookaside", "publishurl"))
        zsub.setsockopt(zmq.SUBSCRIBE, b"")
        tcpl.setblocking(False)
        tcpl.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        tcpl.bind(("", conf.getint("wsgateway", "port")))
        tcpl.listen(5)
        tcpfd = tcpl.fileno()
        poller = zmq.Poller()
        poller.register(zsub, flags=zmq.POLLIN)
        poller.register(tcpfd, flags=zmq.POLLIN)
        while True:
            tosend = []
            topoll = []
            tostop = []
            events = poller.poll(5000)
            log.debug("got events: %s", events)
            for sk, fl in events:
                if sk is zsub:
                    # Drain everything queued on the subscription
                    while True:
                        try:
                            zmsg = LocEvt(zsub.recv(zmq.NOBLOCK))
                            tosend.append(zmsg)
                        except zmq.Again:
                            break
                elif sk == tcpfd:
                    try:
                        clntsock, clntaddr = tcpl.accept()
                    except BlockingIOError:
                        # The pending connection vanished between poll and
                        # accept; the listener is non-blocking, so skip it.
                        log.debug("Spurious wakeup on the listening socket")
                        continue
                    topoll.append((clntsock, clntaddr))
                elif fl & zmq.POLLIN:
                    received = clients.recv(sk)
                    if received is None:
                        log.debug("Client gone from fd %d", sk)
                        tostop.append(sk)
                    else:
                        for msg in received:
                            log.debug("Received from %d: %s", sk, msg)
                else:
                    log.debug("Stray event: %s on socket %s", fl, sk)
            # poll queue consumed, make changes now
            for fd in tostop:
                poller.unregister(fd)
                clients.stop(fd)
            for zmsg in tosend:
                log.debug("Sending to the client: %s", zmsg)
                clients.send(zmsg)
            for clntsock, clntaddr in topoll:
                fd = clients.add(clntsock, clntaddr)
                poller.register(fd, flags=zmq.POLLIN)
            # TODO: Handle write overruns (register for POLLOUT)
    except KeyboardInterrupt:
        pass
    finally:
        # Release every socket and the zmq context, however the loop ended
        for fd in list(clients.by_fd):
            clients.stop(fd)
        tcpl.close()
        zsub.close()
        zctx.term()
250
251
if __name__.endswith("__main__"):
    # Entry point: common.init(log) yields the configuration for the server.
    runserver(conf=common.init(log))