You've already forked mariadb-connector-c
mirror of
https://github.com/mariadb-corporation/mariadb-connector-c.git
synced 2025-08-04 04:42:18 +03:00
166 lines
5.8 KiB
Python
Executable File
166 lines
5.8 KiB
Python
Executable File
import socket
|
|
import ssl
|
|
import argparse
|
|
from ast import literal_eval
|
|
from OpenSSL import crypto, SSL
|
|
import os
|
|
|
|
class TlsServer():
    """Minimal TLS dummy server for exercising client-side TLS handling.

    For every accepted TCP connection the server first sends the canned
    ``server_hello`` packet (a module-level constant) in clear text and then
    reads one reply from the client:

    * ``b'CMD:<key>=<literal> ...'`` -- (re)configure the TLS context, see
      :meth:`set_tls_context`; answered with ``b'OK'`` on success.
    * ``b'QUIT'`` -- leave :meth:`run`.
    * anything else -- treated as the start of a TLS handshake, which is
      performed with the currently configured context.
    """

    def __init__(self, *args, **kwargs):
        """Create the listening socket.

        Keyword arguments:
            host: address to bind to (default ``"127.0.0.1"``).
            port: TCP port to bind to (default ``50000``).

        Start-up errors are printed, not raised.  In that case
        ``self.server`` stays ``None`` so that :meth:`check_server` reports
        the failure on first use.
        """
        self.host = kwargs.pop("host", "127.0.0.1")
        self.port = kwargs.pop("port", 50000)
        self.server = None
        # Set by set_tls_context(); None until a CMD: request succeeded.
        self.context = None
        self.end = False

        listener = None
        try:
            listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            listener.bind((self.host, self.port))
            print("# tls dummy_server started: ", self.host, self.port)
            listener.listen()
        except Exception as e:
            print("Couldn't start tls_server")
            print(e)
            # Do not keep (or leak) a socket that is not actually listening.
            if listener is not None:
                listener.close()
            return
        self.server = listener

    def check_server(self):
        """Raise ``Exception("Server not started")`` if there is no listener."""
        if self.server is None:
            raise Exception("Server not started")

    def send_server_hello(self, conn):
        """Send the canned ``server_hello`` packet over *conn*.

        Returns 1 on success and 0 if sending failed (the error is printed).
        Raises ``Exception`` if the server was never started.
        """
        self.check_server()
        try:
            conn.sendall(server_hello)
        except Exception as e:
            print("Couldn't send server_hello")
            print(e)
            return 0
        return 1

    def generate_cert(self,
                      create_new=False,
                      create_crl=False,
                      emailAddress="emailAddress",
                      commonName="commonName",
                      SAN=None,
                      countryName="NT",
                      localityName="localityName",
                      stateOrProvinceName="stateOrProvinceName",
                      organizationName="organizationName",
                      organizationUnitName="organizationUnitName",
                      serialNumber=123,
                      validityStartInSeconds=0,
                      validityEndInSeconds=10*365*24*60*60,
                      KEY_FILE="privkey.pem",
                      CRL_FILE="selfsigned.crl",
                      CERT_FILE="selfsigned.pem"):
        """Remember the certificate file names and optionally create them.

        The file names are always stored in ``self.key_file``,
        ``self.cert_file`` and ``self.crl_file``.  Only when *create_new* is
        true a new 4096 bit RSA key and a self-signed certificate (signed
        with SHA-512) are generated and written to *KEY_FILE* / *CERT_FILE*
        in PEM format.

        *SAN* is the subjectAltName value (e.g. ``"DNS:localhost"``); a
        single string or an iterable of strings is accepted.
        *validityStartInSeconds* / *validityEndInSeconds* are offsets
        relative to the current time.  *create_crl* is accepted but not used
        by this method.

        Returns 1 on success (including the case that nothing had to be
        created) and 0 if certificate creation failed; the error is printed.
        """
        self.key_file = KEY_FILE
        self.cert_file = CERT_FILE
        self.crl_file = CRL_FILE

        if not create_new:
            return 1

        try:
            key = crypto.PKey()
            key.generate_key(crypto.TYPE_RSA, 4096)

            # create a self-signed cert
            cert = crypto.X509()
            subject = cert.get_subject()
            subject.C = countryName
            subject.ST = stateOrProvinceName
            subject.L = localityName
            subject.O = organizationName
            subject.OU = organizationUnitName
            subject.CN = commonName
            subject.emailAddress = emailAddress
            cert.set_serial_number(serialNumber)
            cert.gmtime_adj_notBefore(validityStartInSeconds)
            cert.gmtime_adj_notAfter(validityEndInSeconds)
            # Self-signed: issuer and subject are identical.
            cert.set_issuer(cert.get_subject())
            if SAN:
                print(SAN)
                san_list = [SAN] if isinstance(SAN, str) else list(SAN)
                cert.add_extensions([
                    crypto.X509Extension(
                        b"subjectAltName", False, ",".join(san_list).encode()
                    )])
            cert.set_pubkey(key)
            cert.sign(key, 'sha512')

            with open(CERT_FILE, "wt") as f:
                f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
            with open(KEY_FILE, "wt") as f:
                f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key).decode("utf-8"))
        except Exception as e:
            # Report the reason instead of failing silently.
            print("Couldn't generate certificate")
            print(e)
            return 0
        return 1

    @staticmethod
    def _parse_command(reply):
        """Turn ``b"key=literal key2=literal2"`` into a keyword dictionary.

        Pairs are separated by whitespace; each value is parsed with
        ``ast.literal_eval``.  Raises ``ValueError`` or ``SyntaxError`` for
        malformed input.
        """
        options = {}
        for pair in reply.decode().split():
            # Split on the first '=' only, the value itself may contain '='.
            name, sep, value = pair.partition('=')
            if not sep or not name:
                raise ValueError("expected key=value, got %r" % pair)
            options[name] = literal_eval(value)
        return options

    def set_tls_context(self, reply):
        """Configure the TLS context from the payload of a ``CMD:`` request.

        *reply* holds whitespace separated ``key=value`` pairs (values are
        Python literals) which are passed as keyword arguments to
        :meth:`generate_cert`.  Afterwards certificate and key are loaded
        into a fresh server-side ``ssl.SSLContext`` stored in
        ``self.context``.

        Returns 1 on success, 0 on any failure (malformed command, unknown
        option, certificate generation or loading error).  Failures are
        printed and never raised, so one bad request cannot stop the server.
        """
        kwargs = {}
        if len(reply) > 0:
            try:
                kwargs = self._parse_command(reply)
            except (ValueError, SyntaxError) as e:
                print("# invalid command: ", e)
                return 0
            print("# command: ", kwargs)

        try:
            generated = self.generate_cert(**kwargs)
        except TypeError as e:
            # Unknown option name in the command.
            print("# invalid command: ", e)
            return 0
        if not generated:
            return 0

        print("# loading certs", self.cert_file, self.key_file)
        try:
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            context.load_cert_chain(self.cert_file, self.key_file)
        except OSError as e:
            # Covers missing files as well as ssl.SSLError.
            print("# couldn't load certs: ", e)
            return 0
        self.context = context
        return 1

    def accept(self):
        """Wait for the next client; returns the ``(conn, addr)`` tuple."""
        self.check_server()
        conn, addr = self.server.accept()
        return (conn, addr)

    @staticmethod
    def _close_quietly(sock):
        """Close *sock*, ignoring errors from an already broken connection."""
        try:
            sock.close()
        except OSError:
            pass

    def _handle_connection(self, connection):
        """Serve one client; returns True if the client requested QUIT."""
        self.send_server_hello(connection)
        try:
            reply = connection.recv(4096)
        except OSError as e:
            print("error occured")
            print(e)
            return False

        command = reply[:4]
        if command == b'CMD:':
            if self.set_tls_context(reply[4:]):
                try:
                    connection.sendall(b'OK')
                except OSError as e:
                    print("error occured")
                    print(e)
            return False

        if command == b'QUIT':
            print("# exiting tls_dummy_server")
            return True

        # Anything else is the beginning of a TLS handshake.
        context = getattr(self, "context", None)
        if context is None:
            print("error occured")
            print("no TLS context configured")
            return False
        try:
            tls_sock = context.wrap_socket(connection, server_side=True)
        except Exception as e:
            print("error occured")
            print(e)
            return False
        # wrap_socket() took over the file descriptor of `connection`, so the
        # TLS socket has to be closed explicitly or it would stay open.
        self._close_quietly(tls_sock)
        return False

    def run(self):
        """Serve clients one at a time until a ``QUIT`` request arrives.

        Raises ``Exception`` if the server was never started.
        """
        while not self.end:
            connection, _address = self.accept()
            print("# new connection")
            try:
                if self._handle_connection(connection):
                    return
            finally:
                self._close_quietly(connection)
|
|
|
|
# Hardcoded server hello packet (captured from MariaDB Server 11.4.2).
# Written as adjacent bytes literals so that indentation or trailing
# whitespace in the source can never leak into the packet.  The packet is
# 86 bytes long: a 4 byte header whose first byte ('R' == 82) equals the
# number of bytes that follow the header.
server_hello = (
    b'R\x00\x00\x00\n11.4.2-MariaDB\x00\xff\x01\x00\x00Nv'
    b'*hQ;qK\x00\xfe\xff\x08\x02\x00\xff\x81\x15\x00\x00\x00'
    b'\x00\x00\x00\x1d\x00\x00\x00`$-VIJyC!x[?\x00mysql_native_password\x00'
)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command line entry point: parse the optional --host/--port arguments,
    # start the dummy server and serve until a client sends QUIT.
    parser = argparse.ArgumentParser(
        prog='tls_server',
        description='Simple TLS dummy test server',
    )
    parser.add_argument('--host', help='Hostaddress of TLS test server (Default 127.0.0.1)')
    parser.add_argument('--port', help='Port of TLS test server. (Default 50000)')
    args = parser.parse_args()

    # A missing or empty argument falls back to the documented default.
    port = args.port or 50000
    host = args.host or "127.0.0.1"

    server = TlsServer(host=host, port=int(port))
    print("# Starting tls_dummy_server")
    server.run()
|