1
0
mirror of https://github.com/mariadb-corporation/mariadb-connector-c.git synced 2025-08-04 04:42:18 +03:00
Files
mariadb-connector-c/unittest/libmariadb/tls_server.py
2024-12-09 13:13:22 +01:00

166 lines
5.8 KiB
Python
Executable File

import socket
import ssl
import argparse
from ast import literal_eval
from OpenSSL import crypto, SSL
import os
class TlsServer():
    """Dummy server used by the connector's TLS unit tests.

    For every accepted connection the server sends the canned module-level
    ``server_hello`` packet, reads one reply from the client and then:

    * ``CMD:<name=value ...>`` -- (re)configures the TLS context through
      :meth:`set_tls_context` and answers ``OK`` on success;
    * ``QUIT`` -- closes the connection and makes :meth:`run` return;
    * anything else -- performs a server-side TLS handshake with the
      currently configured context.
    """

    def __init__(self, *args, **kwargs):
        """Create, bind and start listening on the server socket.

        Keyword args:
            host: address to bind to (default ``"127.0.0.1"``).
            port: TCP port to bind to (default ``50000``).

        Positional arguments are accepted but ignored. Startup failures are
        printed rather than raised; in that case ``self.server`` stays
        ``None`` so that :meth:`check_server` reports the problem.
        """
        self.host = kwargs.pop("host", "127.0.0.1")
        self.port = kwargs.pop("port", 50000)
        self.server = None
        self.end = False
        # Replaced by set_tls_context(); None until a CMD: request arrives.
        self.context = None
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind((self.host, self.port))
            print("# tls dummy_server started: ", self.host, self.port)
            sock.listen()
        except Exception as e:
            print("Couldn't start tls_server")
            print(e)
            # Don't keep a socket that failed to bind/listen: check_server()
            # can only detect the failure if self.server is None.
            if sock is not None:
                sock.close()
            return
        self.server = sock

    def check_server(self):
        """Raise ``Exception`` if the listening socket was not created."""
        if not self.server:
            raise Exception("Server not started")

    def send_server_hello(self, conn):
        """Send the module-level ``server_hello`` packet on ``conn``.

        Returns:
            1 on success, 0 if sending failed (the error is printed).
        """
        self.check_server()
        try:
            conn.sendall(server_hello)
        except Exception as e:
            print("Couldn't send server_hello")
            print(e)
            return 0
        return 1

    def generate_cert(self,
                      create_new=False,
                      create_crl=False,
                      emailAddress="emailAddress",
                      commonName="commonName",
                      SAN=None,
                      countryName="NT",
                      localityName="localityName",
                      stateOrProvinceName="stateOrProvinceName",
                      organizationName="organizationName",
                      organizationUnitName="organizationUnitName",
                      serialNumber=123,
                      validityStartInSeconds=0,
                      validityEndInSeconds=10*365*24*60*60,
                      KEY_FILE="privkey.pem",
                      CRL_FILE="selfsigned.crl",
                      CERT_FILE="selfsigned.pem"):
        """Select, and optionally create, the certificate and key files to serve.

        The three file names are always recorded on the instance. If
        ``create_new`` is true, the method generates a new 4096-bit RSA key
        and a self-signed certificate and signs it with SHA-512. The
        certificate uses the given subject fields, serial number and
        validity window; the window bounds are offsets in seconds from now.
        Both are written in PEM format to ``CERT_FILE`` and ``KEY_FILE``.

        ``SAN`` is an optional ``subjectAltName`` value such as
        ``"DNS:localhost"``. It may also be a list or tuple of such
        entries, which are joined with ``","``.

        ``create_crl`` is accepted but currently unused. ``CRL_FILE`` is
        only recorded, and no CRL is generated here.

        Returns:
            1 on success, 0 if generating or writing the files failed.
        """
        self.key_file = KEY_FILE
        self.cert_file = CERT_FILE
        self.crl_file = CRL_FILE
        if not create_new:
            return 1
        try:
            k = crypto.PKey()
            k.generate_key(crypto.TYPE_RSA, 4096)
            # create a self-signed cert: issuer == subject, signed with k
            cert = crypto.X509()
            subject = cert.get_subject()
            subject.C = countryName
            subject.ST = stateOrProvinceName
            subject.L = localityName
            subject.O = organizationName
            subject.OU = organizationUnitName
            subject.CN = commonName
            subject.emailAddress = emailAddress
            cert.set_serial_number(serialNumber)
            cert.gmtime_adj_notBefore(validityStartInSeconds)
            cert.gmtime_adj_notAfter(validityEndInSeconds)
            cert.set_issuer(subject)
            if SAN:
                print(SAN)
                # Accept a single entry or a list/tuple of entries.
                san_list = [SAN] if isinstance(SAN, str) else list(SAN)
                cert.add_extensions([
                    crypto.X509Extension(
                        b"subjectAltName", False, ",".join(san_list).encode()
                    )])
            cert.set_pubkey(k)
            cert.sign(k, 'sha512')
            with open(CERT_FILE, "wt") as f:
                f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
            with open(KEY_FILE, "wt") as f:
                f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
        except Exception as e:
            # Report the reason instead of failing silently.
            print("Couldn't generate certificate")
            print(e)
            return 0
        return 1

    def set_tls_context(self, reply):
        """Configure the server-side TLS context from a ``CMD:`` payload.

        Args:
            reply: the request bytes after the ``CMD:`` prefix. They contain
                whitespace-separated ``name=value`` pairs whose values are
                Python literals, for example
                ``b"create_new=True commonName='localhost'"``. The pairs are
                passed as keyword arguments to :meth:`generate_cert`.
                Values must not contain whitespace but may contain ``=``.

        Returns:
            1 if the certificate chain was loaded into a new
            ``ssl.SSLContext`` (stored in ``self.context``), 0 otherwise.
            Errors are printed rather than raised, so that a bad command
            does not terminate :meth:`run`.
        """
        kwargs = {}
        if len(reply) > 0:
            try:
                cmds = reply.decode()
                # Split on the first '=' only, so that values may contain '='.
                # literal_eval evaluates literals only, never arbitrary code.
                kwargs = dict((k, literal_eval(v))
                              for k, v in (pair.split('=', 1) for pair in cmds.split()))
            except (ValueError, SyntaxError) as e:
                print("Couldn't parse command")
                print(e)
                return 0
            print("# command: ", kwargs)
        try:
            if not self.generate_cert(**kwargs):
                return 0
            print("# loading certs", self.cert_file, self.key_file)
            self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            self.context.load_cert_chain(self.cert_file, self.key_file)
        except (TypeError, OSError) as e:
            # TypeError: unknown option name or bad file-name type;
            # OSError (incl. ssl.SSLError): missing or invalid cert/key file.
            print("Couldn't set TLS context")
            print(e)
            return 0
        return 1

    def accept(self):
        """Wait for a client and return ``(conn, addr)`` from ``socket.accept()``."""
        self.check_server()
        conn, addr = self.server.accept()
        return (conn, addr)

    def run(self):
        """Serve clients one at a time until a ``QUIT`` request arrives.

        The loop also ends if ``self.end`` becomes true. Every accepted
        connection is closed once it has been handled.
        """
        while not self.end:
            connection, address = self.accept()
            print("# new connection")
            try:
                keep_running = self._handle_connection(connection)
            except OSError as e:
                # A client dropping the connection mid-exchange must not
                # take the whole test server down.
                print("Connection error")
                print(e)
                keep_running = True
            finally:
                connection.close()
            if not keep_running:
                return

    def _handle_connection(self, connection):
        """Run the hello/reply exchange on one connection.

        Returns False if the client sent ``QUIT``, True otherwise.
        """
        self.send_server_hello(connection)
        reply = connection.recv(4096)
        if reply[:4] == b'CMD:':
            if self.set_tls_context(reply[4:]):
                connection.sendall(b'OK')
        elif reply[:4] == b'QUIT':
            print("# exiting tls_dummy_server")
            return False
        else:
            self._tls_handshake(connection)
        return True

    def _tls_handshake(self, connection):
        """Perform a server-side TLS handshake on ``connection``, then close it."""
        try:
            tls_sock = self.context.wrap_socket(connection, server_side=True)
        except Exception as e:
            print("error occured")
            print(e)
            return
        # wrap_socket() detaches `connection` and hands its file descriptor
        # to tls_sock, so closing `connection` afterwards is a no-op; close
        # the TLS socket itself to actually release the client.
        tls_sock.close()
# Hardcoded server hello packet (captured from MariaDB Server 11.4.2).
# Field breakdown follows the MariaDB "initial handshake packet" layout;
# the pieces below concatenate to exactly the captured 86 bytes.
server_hello = (
    b'R\x00\x00\x00'                  # 3-byte payload length (0x52 = 82) + sequence id 0
    b'\n'                             # protocol version 10
    b'11.4.2-MariaDB\x00'             # server version, NUL-terminated
    b'\xff\x01\x00\x00'               # connection id
    b'Nv*hQ;qK'                       # scramble, first 8 bytes
    b'\x00'                           # filler
    b'\xfe\xff'                       # capability flags, lower 2 bytes
    b'\x08'                           # default collation id
    b'\x02\x00'                       # status flags
    b'\xff\x81'                       # capability flags, upper 2 bytes
    b'\x15'                           # auth plugin data length (21)
    b'\x00\x00\x00\x00\x00\x00'       # reserved
    b'\x1d\x00\x00\x00'               # MariaDB extended capabilities
    b'`$-VIJyC!x[?\x00'               # scramble, remaining 12 bytes, NUL-terminated
    b'mysql_native_password\x00'      # auth plugin name, NUL-terminated
)
if __name__ == '__main__':
    # Command-line entry point: optional --host/--port, falling back to
    # 127.0.0.1:50000 when an option is omitted (or given as empty).
    cli = argparse.ArgumentParser(
        prog='tls_server',
        description='Simple TLS dummy test server')
    cli.add_argument('--host', help='Hostaddress of TLS test server (Default 127.0.0.1)')
    cli.add_argument('--port', help='Port of TLS test server. (Default 50000)')
    options = cli.parse_args()
    host = options.host or "127.0.0.1"
    port = options.port or 50000
    server = TlsServer(host=host, port=int(port))
    print("# Starting tls_dummy_server")
    server.run()