1
0
mirror of https://github.com/quay/quay.git synced 2025-04-18 10:44:06 +03:00
quay/test/helpers.py
Kenny Lee Sin Cheong 5f63b3a7bb
chore: drop deprecated tables and remove unused code (PROJQUAY-522) (#2089)
* chore: drop deprecated tables and remove unused code

* isort imports

* migration: check for table existence before drop
2023-08-25 12:17:24 -04:00

145 lines
4.2 KiB
Python

import logging
import multiprocessing
import re
import socket
import time
from contextlib import contextmanager
from playhouse.test_utils import _QueryLogHandler
from data.database import LogEntry3, LogEntryKind
class assert_action_logged(object):
"""
Specialized assertion for ensuring that a log entry of a particular kind was added under the
context of this call.
"""
def __init__(self, log_kind):
self.log_kind = log_kind
self.existing_count = 0
def _get_log_count(self):
return (
LogEntry3.select().where(LogEntry3.kind == LogEntryKind.get(name=self.log_kind)).count()
)
def __enter__(self):
self.existing_count = self._get_log_count()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is None:
updated_count = self._get_log_count()
error_msg = "Missing new log entry of kind %s" % self.log_kind
assert self.existing_count == (updated_count - 1), error_msg
class log_queries(object):
"""Logs all queries that occur under the context."""
def __init__(self, query_filters=None):
self.filters = query_filters
def get_queries(self):
queries = [q.msg[0] for q in self._handler.queries]
if not self.filters:
return queries
filtered_queries = []
for query_filter in self.filters:
filtered_queries.extend([q for q in queries if re.match(query_filter, q)])
return filtered_queries
def __enter__(self):
logger = logging.getLogger("peewee")
self._handler = _QueryLogHandler()
logger.setLevel(logging.DEBUG)
logger.addHandler(self._handler)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
logger = logging.getLogger("peewee")
logger.removeHandler(self._handler)
class check_transitive_modifications(log_queries):
"""Checks for Peewee-generated transition deletion queries and fails if any are found.
These kinds of queries (which use subqueries) can lock massively on MySQL, so we detect
them and fail.
"""
def __init__(self):
filters = [r"^DELETE.+IN \(SELECT.+$", r"^UPDATE.+IN \(SELECT.+$"]
super(check_transitive_modifications, self).__init__(query_filters=filters)
def __exit__(self, exc_type, exc_val, exc_tb):
super(check_transitive_modifications, self).__exit__(exc_type, exc_val, exc_tb)
queries = self.get_queries()
if queries:
raise Exception("Detected transitive deletion or update in queries: %s" % queries)
_LIVESERVER_TIMEOUT = 5
@contextmanager
def liveserver_app(flask_app, port):
"""
Based on https://github.com/jarus/flask-testing/blob/master/flask_testing/utils.py.
Runs the given Flask app as a live web server locally, on the given port, starting it
when called and terminating after the yield.
Usage:
with liveserver_app(flask_app, port):
# Code that makes use of the app.
"""
shared = {}
def _can_ping_server():
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.connect(("localhost", port))
except socket.error:
success = False
else:
success = True
finally:
sock.close()
return success
def _spawn_live_server():
worker = lambda app, port: app.run(port=port, use_reloader=False)
ctx = multiprocessing.get_context("fork")
shared["process"] = ctx.Process(target=worker, args=(flask_app, port))
shared["process"].start()
start_time = time.time()
while True:
elapsed_time = time.time() - start_time
if elapsed_time > _LIVESERVER_TIMEOUT:
_terminate_live_server()
raise RuntimeError(
"Failed to start the server after %d seconds. " % _LIVESERVER_TIMEOUT
)
if _can_ping_server():
break
def _terminate_live_server():
if shared.get("process"):
shared.get("process").terminate()
shared.pop("process")
try:
_spawn_live_server()
yield
finally:
_terminate_live_server()