www.average.org Git - loctrkd.git/blob - test/common.py
tests: separate fuzz tests in two modules
[loctrkd.git] / test / common.py
1 """ Common housekeeping for tests that rely on daemons """
2
3 from configparser import ConfigParser, SectionProxy
4 from contextlib import closing, ExitStack
5 from http.server import HTTPServer, SimpleHTTPRequestHandler
6 from importlib import import_module
7 from logging import DEBUG, StreamHandler
8 from multiprocessing import Process
9 from os import kill, unlink
10 from signal import SIGINT
11 from socket import (
12     AF_INET,
13     AF_INET6,
14     getaddrinfo,
15     MSG_DONTWAIT,
16     SOCK_DGRAM,
17     SOCK_STREAM,
18     SOL_SOCKET,
19     SO_REUSEADDR,
20     socket,
21     SocketType,
22 )
23 from sys import exit, stderr
24 from random import Random
25 from tempfile import mkstemp
26 from time import sleep
27 from typing import Any, Optional
28 from unittest import TestCase
29
30 from loctrkd.common import init_protocols
31
# How many distinct free ports TestWithServers.setUp reserves: index 0 is the
# collector port, index 1 the wsgateway port, and index 2 is shared by the
# opencellid download URL and the optional test HTTP server.
NUMPORTS: int = 3
33
34
class TestWithServers(TestCase):
    """Base class for tests that need loctrkd daemons running.

    ``setUp`` writes a throw-away configuration that points at free local
    ports and temporary file names, then starts each requested daemon in
    its own process.  ``tearDown`` stops them with SIGINT, removes the
    temporary files, and fails the test if any daemon did not exit with
    status 0.
    """

    def setUp(
        self, *args: str, httpd: bool = False, verbose: bool = False
    ) -> None:
        """Configure and start the daemons for one test.

        :param args: names of ``loctrkd`` submodules to run; each one's
            ``runserver(conf, **kwargs)`` is started in a child process
            (the collector gets ``handle_hibernate=False``).
        :param httpd: also serve the current directory over HTTP on the
            port used by the ``opencellid`` ``downloadurl``.
        :param verbose: attach a DEBUG-level stderr handler to the ``log``
            of every started module.
        """
        from os import close as close_fd

        freeports = []
        with ExitStack() as stack:
            # Keep every probe socket bound until all are allocated, so the
            # kernel hands out NUMPORTS distinct ports.  The daemons and the
            # HTTP server listen on TCP, so probe with TCP: a free UDP port
            # says nothing about the TCP port with the same number.
            # SO_REUSEADDR is not needed for a bind-only probe.
            for _ in range(NUMPORTS):
                sk = stack.enter_context(
                    closing(socket(AF_INET6, SOCK_STREAM))
                )
                sk.bind(("", 0))
                freeports.append(sk.getsockname()[1])
        fd, self.tmpfilebase = mkstemp()
        # Only the unique name is used (as a prefix for sockets and
        # databases); close the descriptor so it does not leak into this
        # process and every forked child.
        close_fd(fd)
        self.conf = ConfigParser()
        self.conf["common"] = {
            "protocols": "zx303proto",
        }
        self.conf["collector"] = {
            "port": str(freeports[0]),
            "publishurl": "ipc://" + self.tmpfilebase + ".pub",
            "listenurl": "ipc://" + self.tmpfilebase + ".pul",
        }
        self.conf["storage"] = {
            "dbfn": self.tmpfilebase + ".storage.sqlite",
            "events": "yes",
        }
        self.conf["opencellid"] = {
            "dbfn": self.tmpfilebase + ".opencellid.sqlite",
            "downloadurl": f"http://localhost:{freeports[2]}/test/262.csv.gz",
        }
        self.conf["rectifier"] = {
            "lookaside": "opencellid",
            "publishurl": "ipc://" + self.tmpfilebase + ".rect.pub",
        }
        self.conf["wsgateway"] = {
            "port": str(freeports[1]),
        }
        self.children = []
        try:
            init_protocols(self.conf)
            for srvname in args:
                self.children.append(
                    (srvname, self._start_server(srvname, verbose))
                )
            if httpd:
                self.children.append(
                    ("httpd", self._start_httpd(freeports[2]))
                )
        except BaseException:
            # unittest does not call tearDown() when setUp() raises, so stop
            # whatever was already started and clean up here instead of
            # leaking daemon processes and temporary files.
            try:
                self._stop_children()
            finally:
                self._remove_tmpfiles()
            raise
        sleep(1)  # crude wait for the children to finish starting up

    def _start_server(self, srvname: str, verbose: bool) -> Process:
        """Start ``loctrkd.<srvname>.runserver(self.conf)`` in a child."""
        if srvname == "collector":
            kwargs = {"handle_hibernate": False}
        else:
            kwargs = {}
        mod = import_module("loctrkd." + srvname, package=".")
        if verbose:
            mod.log.addHandler(StreamHandler(stderr))
            mod.log.setLevel(DEBUG)
        p = Process(target=mod.runserver, args=(self.conf,), kwargs=kwargs)
        p.start()
        return p

    def _start_httpd(self, port: int) -> Process:
        """Serve the current directory over HTTP on ``port`` in a child."""
        server = HTTPServer(("", port), SimpleHTTPRequestHandler)
        try:
            p = Process(target=self._serve_httpd, args=(server,))
            p.start()
        finally:
            # Once start() has returned, the child holds its own copy of the
            # listening socket; the parent's copy would otherwise stay open.
            server.server_close()
        return p

    @staticmethod
    def _serve_httpd(server: HTTPServer) -> None:
        """Child-process entry point: serve until SIGINT, then close."""
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # SIGINT from tearDown() is the expected way to stop; return
            # normally so the child exits with status 0.
            pass
        finally:
            # serve_forever() has already returned at this point, so
            # shutdown() would have nothing to stop; server_close() is what
            # actually releases the listening socket.
            server.server_close()

    def _stop_children(self) -> None:
        """SIGINT each child in start order and wait for it to exit."""
        for _, p in self.children:
            # A child that has already exited (and may have been reaped)
            # must not be signalled: its pid could have been reused.
            if p.pid is not None and p.exitcode is None:
                kill(p.pid, SIGINT)
            p.join()

    def _remove_tmpfiles(self) -> None:
        """Best-effort removal of every file named after ``tmpfilebase``."""
        for sfx in (
            "",
            ".pub",
            ".rect.pub",
            ".pul",
            ".storage.sqlite",
            ".opencellid.sqlite",
        ):
            try:
                unlink(self.tmpfilebase + sfx)
            except OSError:
                # Not every file exists in every test (e.g. when the daemon
                # that would use it was not started).
                pass

    def tearDown(self) -> None:
        """Stop all children, remove temp files, then check exit codes.

        Every child is stopped and joined, and the files are removed,
        before any assertion runs. A single daemon that exits with a
        non-zero status therefore cannot leave the others running.
        """
        try:
            self._stop_children()
        finally:
            self._remove_tmpfiles()
        for srvname, p in self.children:
            self.assertEqual(
                p.exitcode,
                0,
                f"{srvname} terminated with return code {p.exitcode}",
            )
122
123
class Fuzz(TestWithServers):
    """Base class for fuzz tests: a running collector plus a TCP client.

    ``self.sock`` is a TCP socket connected to the collector's configured
    port on 127.0.0.1, and ``self.rnd`` is an unseeded ``Random``.
    """

    def setUp(self, *args: str, **kwargs: Any) -> None:
        """Start only the collector and connect ``self.sock`` to it.

        ``args``/``kwargs`` are accepted for signature compatibility with
        the base class but ignored.  If connecting fails, the socket is
        closed and the collector is stopped before the error propagates.
        The base-class exit-code assertion may be raised instead, with the
        original error attached as its context.
        """
        super().setUp("collector")
        self.rnd = Random()
        self.sock = socket(AF_INET, SOCK_STREAM)
        try:
            # getaddrinfo() raises on failure rather than returning an empty
            # list, so taking the first entry is safe.
            *_, skadr = getaddrinfo(
                "127.0.0.1",
                self.conf.getint("collector", "port"),
                family=AF_INET,
                type=SOCK_STREAM,
            )[0]
            self.sock.connect(skadr)
        except BaseException:
            # unittest skips tearDown() when setUp() raises, so release the
            # socket and stop the collector here.
            self.sock.close()
            super().tearDown()
            raise

    def tearDown(self) -> None:
        """Drain and close the client socket, then stop the collector.

        The socket is closed and the base-class teardown runs even if
        draining raises (``send_and_drain`` only swallows
        ``BlockingIOError``, so e.g. a reset connection propagates). The
        collector is therefore always stopped.
        """
        sleep(1)  # give collector some time
        try:
            send_and_drain(self.sock, None)
        finally:
            self.sock.close()
            sleep(1)  # Let the server close their side
            super().tearDown()
144
145
def send_and_drain(sock: SocketType, buf: Optional[bytes]) -> None:
    """Send ``buf`` (if given), then discard whatever the peer has sent.

    :param sock: a connected stream socket.
    :param buf: bytes to send in full, or ``None`` to only drain.

    Draining uses non-blocking reads (``MSG_DONTWAIT``) and never waits for
    data. It stops as soon as nothing more is immediately available or the
    peer has closed its side. Other socket errors (for instance a reset
    connection) propagate to the caller.
    """
    if buf is not None:
        # send() may transmit only part of the buffer; sendall() keeps
        # sending until all data has been sent or an error occurs.
        sock.sendall(buf)
    while True:
        try:
            data = sock.recv(4096, MSG_DONTWAIT)
        except BlockingIOError:
            return  # nothing more pending right now
        if not data:
            return  # peer closed its side (EOF)