]> www.average.org Git - loctrkd.git/blob - test/common.py
58954a2b217acbd0fa5a930476cee1882dea8ff1
[loctrkd.git] / test / common.py
1 """ Common housekeeping for tests that rely on daemons """
2
3 from configparser import ConfigParser, SectionProxy
4 from contextlib import closing, ExitStack
5 from http.server import HTTPServer, SimpleHTTPRequestHandler
6 from importlib import import_module
7 from logging import DEBUG, StreamHandler
8 from multiprocessing import Process
9 from os import kill, unlink
10 from signal import SIGINT
11 from socket import (
12     AF_INET6,
13     MSG_DONTWAIT,
14     SOCK_DGRAM,
15     SOL_SOCKET,
16     SO_REUSEADDR,
17     socket,
18     SocketType,
19 )
20 from sys import exit, stderr
21 from tempfile import mkstemp
22 from time import sleep
23 from typing import Optional
24 from unittest import TestCase
25
26 from loctrkd.common import init_protocols
27
# Number of free ports TestWithServers.setUp() probes for: index 0 is the
# collector port, 1 the wsgateway port, 2 the port of the test HTTP server
# that the opencellid "downloadurl" points to.
NUMPORTS = 3
29
30
class TestWithServers(TestCase):
    """Base class for tests that need running loctrkd daemons.

    ``setUp()`` takes the names of ``loctrkd`` submodules to start; each
    module's ``runserver(conf)`` is run in its own child process with a
    configuration that points at freshly probed ports and temporary files.
    ``tearDown()`` stops the children with SIGINT, removes the temporary
    files, and fails the test if any child exited with a non-zero code.
    """

    def setUp(
        self, *args: str, httpd: bool = False, verbose: bool = False
    ) -> None:
        """Write a test configuration and start the requested servers.

        :param args: names of ``loctrkd`` submodules to run, e.g.
            ``"collector"``; each must provide ``runserver(conf)``.
        :param httpd: also start an HTTP server for the current directory
            on the port that the ``opencellid`` ``downloadurl`` refers to.
        :param verbose: send DEBUG logging of the started servers to stderr.
        """
        import os  # only for os.close(); module-level imports stay as-is

        freeports = []
        with ExitStack() as stack:
            for _ in range(NUMPORTS):
                sk = stack.enter_context(closing(socket(AF_INET6, SOCK_DGRAM)))
                # SO_REUSEADDR only influences a bind() done after it is set
                sk.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
                sk.bind(("", 0))
                freeports.append(sk.getsockname()[1])
        # The probe sockets are closed when the ExitStack exits, releasing
        # the ports for the servers (something else may still grab one).
        fd, self.tmpfilebase = mkstemp()
        # Only the file name is used, as a prefix for more files and ipc
        # sockets; the descriptor mkstemp() opened must not leak.
        os.close(fd)
        self.children = []
        self.conf = ConfigParser()
        self.conf["common"] = {
            "protocols": "zx303proto",
        }
        self.conf["collector"] = {
            "port": str(freeports[0]),
            "publishurl": "ipc://" + self.tmpfilebase + ".pub",
            "listenurl": "ipc://" + self.tmpfilebase + ".pul",
        }
        self.conf["storage"] = {
            "dbfn": self.tmpfilebase + ".storage.sqlite",
        }
        self.conf["opencellid"] = {
            "dbfn": self.tmpfilebase + ".opencellid.sqlite",
            "downloadurl": f"http://localhost:{freeports[2]}/test/262.csv.gz",
        }
        self.conf["rectifier"] = {
            "lookaside": "opencellid",
            "publishurl": "ipc://" + self.tmpfilebase + ".rect.pub",
        }
        self.conf["wsgateway"] = {
            "port": str(freeports[1]),
        }
        try:
            init_protocols(self.conf)
            for srvname in args:
                self._start_server(srvname, verbose)
            if httpd:
                self._start_httpd(freeports[2])
        except BaseException:
            # unittest does not call tearDown() when setUp() raises: stop
            # the children already started and remove the temporary files.
            self.tearDown()
            raise
        # Crude wait for the children to finish starting up.
        sleep(1)

    def _start_server(self, srvname: str, verbose: bool) -> None:
        """Run ``loctrkd.<srvname>.runserver(self.conf)`` in a child process."""
        # Only the collector's runserver() is passed handle_hibernate=False.
        kwargs = {"handle_hibernate": False} if srvname == "collector" else {}
        mod = import_module("loctrkd." + srvname)
        if verbose:
            handler = StreamHandler(stderr)
            mod.log.addHandler(handler)
            mod.log.setLevel(DEBUG)
            # The imported module and its logger outlive this test; detach
            # the handler so later verbose tests do not duplicate output.
            self.addCleanup(mod.log.removeHandler, handler)
        p = Process(target=mod.runserver, args=(self.conf,), kwargs=kwargs)
        p.start()
        self.children.append((srvname, p))

    def _start_httpd(self, port: int) -> None:
        """Serve the current directory over HTTP on *port* in a child process."""
        server = HTTPServer(("", port), SimpleHTTPRequestHandler)

        def run(server: HTTPServer) -> None:
            try:
                server.serve_forever()
            except KeyboardInterrupt:
                pass  # SIGINT from tearDown() is the normal way to stop
            finally:
                # serve_forever() has returned; release the listening socket
                server.server_close()

        # NOTE: a nested target function only works with the "fork" start
        # method; the child then holds its own copy of the listening socket.
        p = Process(target=run, args=(server,))
        try:
            p.start()
        finally:
            # The parent never serves requests; do not keep the socket open.
            server.server_close()
        self.children.append(("httpd", p))

    def tearDown(self) -> None:
        """Stop all children, remove temporary files, then check exit codes.

        Every child is sent SIGINT and joined, and the temporary files are
        removed, before any failure is reported.  One misbehaving server
        therefore cannot leave the others running or files behind.
        """
        try:
            for srvname, p in self.children:
                if p.pid is not None:
                    try:
                        kill(p.pid, SIGINT)
                    except ProcessLookupError:
                        # Already exited and reaped; join() still reports
                        # its exit code below.
                        pass
                p.join()
        finally:
            for sfx in (
                "",
                ".pub",
                ".rect.pub",
                ".pul",
                ".storage.sqlite",
                ".opencellid.sqlite",
            ):
                try:
                    unlink(self.tmpfilebase + sfx)
                except OSError:
                    pass  # best effort: not every file is always created
        failures = [
            f"{srvname} terminated with return code {p.exitcode}"
            for srvname, p in self.children
            if p.exitcode != 0
        ]
        if failures:
            self.fail("; ".join(failures))
117
118
def send_and_drain(sock: SocketType, buf: Optional[bytes]) -> None:
    """Send *buf* (unless it is ``None``) and discard any input already waiting.

    Only one non-blocking ``recv`` of at most 4096 bytes is attempted.
    When nothing is queued for reading, the function returns quietly.
    """
    if buf is None:
        pass  # nothing to send, only drain
    else:
        sock.send(buf)
    try:
        sock.recv(4096, MSG_DONTWAIT)
    except BlockingIOError:
        return  # no data was pending