]> www.average.org Git - loctrkd.git/blob - test/common.py
284ca4d10c8509606dd6879f026c07f211104f83
[loctrkd.git] / test / common.py
1 """ Common housekeeping for tests that rely on daemons """
2
3 from configparser import ConfigParser, SectionProxy
4 from contextlib import closing, ExitStack
5 from http.server import HTTPServer, SimpleHTTPRequestHandler
6 from importlib import import_module
7 from logging import DEBUG, StreamHandler
8 from multiprocessing import Process
9 from os import kill, unlink
10 from signal import SIGINT
11 from socket import (
12     AF_INET6,
13     MSG_DONTWAIT,
14     SOCK_DGRAM,
15     SOL_SOCKET,
16     SO_REUSEADDR,
17     socket,
18     SocketType,
19 )
20 from sys import exit, stderr
21 from tempfile import mkstemp
22 from time import sleep
23 from typing import Optional
24 from unittest import TestCase
25
# Number of free ports reserved for every test: one for the collector,
# one for the wsgateway, and one for the optional static HTTP server
# (also referenced by the opencellid download URL).
NUMPORTS = 3
27
28
class TestWithServers(TestCase):
    """Base test case that runs loctrkd daemons in child processes.

    ``setUp`` reserves free ports, builds a throw-away configuration whose
    file names all derive from one temporary file name, and starts the
    requested daemons (and optionally a static HTTP server) as child
    processes.  ``tearDown`` interrupts the children, removes the
    temporary files and verifies that every child exited with code 0.
    """

    def setUp(
        self, *args: str, httpd: bool = False, verbose: bool = False
    ) -> None:
        """Prepare the configuration and start the requested servers.

        Args:
            *args: names of ``loctrkd`` submodules to run.  Each module
                must expose a ``runserver`` callable, and a ``log``
                logger when *verbose* is set.
            httpd: when true, also start an ``HTTPServer`` with
                ``SimpleHTTPRequestHandler`` on the third reserved port.
            verbose: when true, attach a stderr handler to each module's
                logger and switch it to DEBUG level.
        """
        # Function-scope import: the module header only imports kill/unlink.
        from os import close

        freeports = []
        # All probe sockets stay open until the loop is finished, so the
        # kernel cannot hand out the same port number twice.
        with ExitStack() as stack:
            for _ in range(NUMPORTS):
                sk = stack.enter_context(closing(socket(AF_INET6, SOCK_DGRAM)))
                sk.bind(("", 0))
                sk.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
                freeports.append(sk.getsockname()[1])
        # mkstemp() hands back an *open* descriptor; only the name is
        # needed here, so close it right away instead of leaking it.
        fd, self.tmpfilebase = mkstemp()
        close(fd)
        self.conf = ConfigParser()
        self.conf["common"] = {
            "protocols": "zx303proto",
        }
        self.conf["collector"] = {
            "port": str(freeports[0]),
            "publishurl": "ipc://" + self.tmpfilebase + ".pub",
            "listenurl": "ipc://" + self.tmpfilebase + ".pul",
        }
        self.conf["storage"] = {
            "dbfn": self.tmpfilebase + ".storage.sqlite",
        }
        self.conf["opencellid"] = {
            "dbfn": self.tmpfilebase + ".opencellid.sqlite",
            "downloadurl": f"http://localhost:{freeports[2]}/test/262.csv.gz",
        }
        self.conf["rectifier"] = {
            "lookaside": "opencellid",
            "publishurl": "ipc://" + self.tmpfilebase + ".rect.pub",
        }
        self.conf["wsgateway"] = {
            "port": str(freeports[1]),
        }
        self.children = []
        for srvname in args:
            # Only the collector's runserver() gets the extra keyword.
            if srvname == "collector":
                kwargs = {"handle_hibernate": False}
            else:
                kwargs = {}
            cls = import_module("loctrkd." + srvname, package=".")
            if verbose:
                cls.log.addHandler(StreamHandler(stderr))
                cls.log.setLevel(DEBUG)
            p = Process(target=cls.runserver, args=(self.conf,), kwargs=kwargs)
            p.start()
            self.children.append((srvname, p))
        if httpd:
            server = HTTPServer(("", freeports[2]), SimpleHTTPRequestHandler)

            def run(server: HTTPServer) -> None:
                """Serve until interrupted, then release the socket."""
                try:
                    server.serve_forever()
                except KeyboardInterrupt:
                    # SIGINT from tearDown(): serve_forever() has already
                    # returned, so there is no loop left to shut down.
                    pass
                finally:
                    server.server_close()

            p = Process(target=run, args=(server,))
            p.start()
            # The child owns its own copy of the listening socket now;
            # drop the parent's copy so it is not left unclosed.
            server.server_close()
            self.children.append(("httpd", p))
        # Give the children a moment to come up before the test body runs.
        sleep(1)

    def tearDown(self) -> None:
        """Stop all children, remove temporary files, check exit codes.

        Every child is interrupted and joined, and the temporary files
        are removed, *before* any exit code is asserted, so that one
        failed child can neither leave the others running nor leave
        files behind.

        Raises:
            AssertionError: if any child exited with a non-zero code.
        """
        results = []
        for srvname, p in self.children:
            # Only signal a child that is still running: a child that has
            # already exited may have been reaped, and its pid must not
            # be signalled any more.
            if p.pid is not None and p.is_alive():
                kill(p.pid, SIGINT)
            p.join()
            results.append((srvname, p.exitcode))
        for sfx in (
            "",
            ".pub",
            ".rect.pub",
            ".pul",
            ".storage.sqlite",
            ".opencellid.sqlite",
        ):
            try:
                unlink(self.tmpfilebase + sfx)
            except OSError:
                # Best effort: not every run creates every file.
                pass
        for srvname, exitcode in results:
            self.assertEqual(
                exitcode,
                0,
                f"{srvname} terminated with return code {exitcode}",
            )
114
115
def send_and_drain(sock: SocketType, buf: Optional[bytes]) -> None:
    """Optionally send *buf* on *sock*, then discard pending input.

    When *buf* is ``None`` nothing is sent.  Afterwards a single
    non-blocking read of at most 4096 bytes is attempted and its result
    is thrown away; having nothing to read is not an error.
    """
    if buf is None:
        pass  # nothing to transmit, only drain
    else:
        sock.send(buf)

    max_chunk = 4096
    try:
        _ = sock.recv(max_chunk, MSG_DONTWAIT)
    except BlockingIOError:
        # No data was waiting on the socket.
        return