www.average.org Git - loctrkd.git/blob - test/common.py
evstore: do not use unixepoch() - for older sqlite
[loctrkd.git] / test / common.py
1 """ Common housekeeping for tests that rely on daemons """
2
3 from configparser import ConfigParser, SectionProxy
4 from contextlib import closing, ExitStack
5 from http.server import HTTPServer, SimpleHTTPRequestHandler
6 from importlib import import_module
7 from logging import DEBUG, StreamHandler, WARNING
8 from multiprocessing import Process
9 from os import kill, unlink
10 from signal import SIGINT
11 from socket import (
12     AF_INET,
13     AF_INET6,
14     getaddrinfo,
15     MSG_DONTWAIT,
16     SOCK_DGRAM,
17     SOCK_STREAM,
18     SOL_SOCKET,
19     SO_REUSEADDR,
20     socket,
21     SocketType,
22 )
23 from sys import exit, stderr
24 from random import Random
25 from tempfile import mkstemp
26 from time import sleep
27 from typing import Any, Optional
28 from unittest import TestCase
29
30 from loctrkd.common import init_protocols
31
# How many free ports TestWithServers.setUp() reserves: [0] collector,
# [1] wsgateway, [2] the HTTP server the opencellid download URL points at.
NUMPORTS: int = 3
33
34
class TestWithServers(TestCase):
    """Base class for tests that need loctrkd daemons running.

    ``setUp`` writes a configuration that points every service at freshly
    probed ports and at file names derived from one temporary file, then
    starts each requested daemon in its own child process (plus, on request,
    a plain HTTP server). ``tearDown`` stops the children with SIGINT,
    removes the temporary files and asserts that every child exited with
    status 0.
    """

    def setUp(
        self, *args: str, httpd: bool = False, verbose: bool = False
    ) -> None:
        """Configure and start the daemons named in *args*.

        :param args: names of ``loctrkd`` submodules; the ``runserver``
            function of each one is run in a separate child process.
        :param httpd: also start an HTTP server serving the current
            directory on the port used by the opencellid ``downloadurl``.
        :param verbose: log the daemons at DEBUG level to stderr; otherwise
            their log level is only set to WARNING.
        """
        from os import close

        # Reserve NUMPORTS distinct ports: all probe sockets stay bound until
        # the loop ends, so the kernel cannot hand out the same port twice.
        # Probe with TCP because the collector and the HTTP server listen on
        # TCP; a port that is free for UDP need not be free for TCP.
        # SO_REUSEADDR is deliberately not set on the probes. A window
        # remains between closing the probes and the daemons binding.
        freeports = []
        with ExitStack() as stack:
            for _ in range(NUMPORTS):
                sk = stack.enter_context(
                    closing(socket(AF_INET6, SOCK_STREAM))
                )
                sk.bind(("", 0))
                freeports.append(sk.getsockname()[1])
        # Only the name is needed: close the descriptor mkstemp() opened so
        # it is neither leaked here nor inherited by every child process.
        fd, self.tmpfilebase = mkstemp()
        close(fd)
        self.children = []
        try:
            self.conf = ConfigParser()
            self.conf["common"] = {
                "protocols": "zx303proto,beesure",
            }
            self.conf["collector"] = {
                "port": str(freeports[0]),
                "publishurl": "ipc://" + self.tmpfilebase + ".pub",
                "listenurl": "ipc://" + self.tmpfilebase + ".pul",
            }
            self.conf["storage"] = {
                "dbfn": self.tmpfilebase + ".storage.sqlite",
                "events": "yes",
            }
            self.conf["opencellid"] = {
                "dbfn": self.tmpfilebase + ".opencellid.sqlite",
                "downloadurl": (
                    f"http://localhost:{freeports[2]}/test/262.csv.gz"
                ),
            }
            self.conf["rectifier"] = {
                "lookaside": "opencellid",
                "publishurl": "ipc://" + self.tmpfilebase + ".rect.pub",
            }
            self.conf["wsgateway"] = {
                "port": str(freeports[1]),
            }
            init_protocols(self.conf)
            for srvname in args:
                if srvname == "collector":
                    kwargs = {"handle_hibernate": False}
                else:
                    kwargs = {}
                mod = import_module("loctrkd." + srvname)
                if verbose:
                    mod.log.addHandler(StreamHandler(stderr))
                    mod.log.setLevel(DEBUG)
                else:
                    mod.log.setLevel(WARNING)
                p = Process(
                    target=mod.runserver, args=(self.conf,), kwargs=kwargs
                )
                p.start()
                self.children.append((srvname, p))
            if httpd:
                server = HTTPServer(
                    ("", freeports[2]), SimpleHTTPRequestHandler
                )

                def run(server: HTTPServer) -> None:
                    """Serve until SIGINT, then close the listening socket."""
                    try:
                        server.serve_forever()
                    except KeyboardInterrupt:
                        pass  # SIGINT from tearDown() is the normal way out
                    finally:
                        server.server_close()

                p = Process(target=run, args=(server,))
                try:
                    p.start()
                finally:
                    # After start() the child holds its own copy of the
                    # listening socket; the parent's copy would only leak.
                    server.server_close()
                self.children.append(("httpd", p))
            sleep(1)
            for srvname, p in self.children:
                self.assertTrue(
                    p.is_alive(),
                    f"{srvname} did not start, exit code {p.exitcode}",
                )
        except BaseException:
            # unittest does not call tearDown() when setUp() raises: stop
            # whatever was started (the children are non-daemon processes
            # and would otherwise outlive the test run) and remove files.
            self._stop_children()
            self._remove_tmpfiles()
            raise

    def tearDown(self) -> None:
        """Stop all children, remove temp files, then check exit codes."""
        self._stop_children()
        self._remove_tmpfiles()
        for srvname, p in self.children:
            self.assertEqual(
                p.exitcode,
                0,
                f"{srvname} terminated with exit code {p.exitcode}",
            )

    def _stop_children(self) -> None:
        """Send SIGINT to every still-running child and wait for each."""
        for _, p in self.children:
            # Signal only live children: an exited child may already have
            # been reaped (e.g. by is_alive()), so its pid could be reused.
            if p.pid is not None and p.is_alive():
                kill(p.pid, SIGINT)
            p.join()

    def _remove_tmpfiles(self) -> None:
        """Best-effort removal of the temp file and the files named after it."""
        for sfx in (
            "",
            ".pub",
            ".rect.pub",
            ".pul",
            ".storage.sqlite",
            ".opencellid.sqlite",
        ):
            try:
                unlink(self.tmpfilebase + sfx)
            except OSError:
                pass  # the file may never have been created
130
131
class Fuzz(TestWithServers):
    """Base for fuzz tests: runs only the collector and keeps one TCP
    connection to it open as ``self.sock``."""

    def setUp(self, *args: str, **kwargs: Any) -> None:
        """Start the collector and connect ``self.sock`` to its port.

        Positional and keyword arguments are accepted for signature
        compatibility but ignored: only the collector is started, with
        default options. ``self.rnd`` is a fresh, default-seeded ``Random``.
        """
        super().setUp("collector")
        self.rnd = Random()
        try:
            # Use the first address getaddrinfo() offers for the port.
            skadr = getaddrinfo(
                "127.0.0.1",
                self.conf.getint("collector", "port"),
                family=AF_INET,
                type=SOCK_STREAM,
            )[0][4]
            self.sock = socket(AF_INET, SOCK_STREAM)
            try:
                self.sock.connect(skadr)
            except BaseException:
                self.sock.close()
                raise
        except BaseException:
            # unittest skips tearDown() when setUp() fails, so stop the
            # collector here; otherwise it would outlive the test.
            super().tearDown()
            raise

    def tearDown(self) -> None:
        """Drain and close the connection, then stop the collector.

        The socket is closed and the collector stopped even if draining
        raises (e.g. because the peer reset the connection).
        """
        try:
            sleep(1)  # give collector some time
            send_and_drain(self.sock, None)
        finally:
            self.sock.close()
            sleep(1)  # let the server close its side
            super().tearDown()
152
153
def send_and_drain(sock: SocketType, buf: Optional[bytes]) -> None:
    """Optionally send *buf*, then discard whatever is already pending.

    Never blocks on the receive side: it reads until no more data is
    immediately available or the peer has closed the connection.

    :param sock: a connected socket.
    :param buf: bytes to send in full, or ``None`` to only drain.
    """
    if buf is not None:
        # send() may write only part of the buffer; sendall() keeps going
        # until all of it has been handed to the kernel.
        sock.sendall(buf)
    while True:
        try:
            chunk = sock.recv(4096, MSG_DONTWAIT)
        except BlockingIOError:
            return  # nothing more is pending right now
        if not chunk:
            return  # peer closed the connection; avoid spinning on EOF