1
0
mirror of https://github.com/mariadb-corporation/mariadb-columnstore-engine.git synced 2025-08-10 01:22:48 +03:00
Files
mariadb-columnstore-engine/cmapi/integration_tests/ssh.py
Alexander Presnyakov 4df61622cc Review fixes
2025-07-30 00:53:01 +00:00

123 lines
4.0 KiB
Python

import concurrent.futures
import json
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable
from fabric import Connection
from dataclasses import dataclass, field
@dataclass
class RemoteHost:
"""Keeps configuration for a host in the cluster"""
name: str
private_ip: str
public_fqdn: str
key_file_path: Path = field(repr=False)
ssh_user: str = field(repr=False)
@contextmanager
def ssh_connection(self):
"""Connect to the host using SSH (via Fabric)"""
conn = Connection(
host=self.public_fqdn,
user=self.ssh_user,
connect_kwargs={
"key_filename": str(self.key_file_path),
}
)
yield conn
conn.close()
def is_reachable(self) -> bool:
"""Check if the host is reachable via SSH"""
try:
with self.ssh_connection() as conn:
conn.run("echo 'SSH connection successful'", hide=True)
return True
except Exception as e:
print(f"SSH connection failed for {self.name}: {e}")
return False
def exec_mcs(self, params: str, drop_timestamp: bool = True) -> dict:
"""Execute an mcs subcommand and return its parsed output"""
with self.ssh_connection() as conn:
print(f"{self.name}: Executing 'mcs {params}'")
result = conn.sudo(f"mcs {params}", hide=True)
if result.return_code != 0:
raise RuntimeError(f"Failed to execute mcs command on {self.name}: {result.stderr}")
output = json.loads(result.stdout)
return output
class ClusterConfig:
"""Keeps configuration of the cluster"""
def __init__(self, mcs_hosts: list[RemoteHost], maxscale_hosts: list[RemoteHost]) -> None:
self.mcs_hosts = mcs_hosts
self.maxscale_hosts = maxscale_hosts
def __repr__(self) -> str:
return f"ClusterConfig(mcs_hosts={self.mcs_hosts}"
@property
def primary(self) -> RemoteHost:
return self.mcs_hosts[0]
@property
def replicas(self) -> list[RemoteHost]:
return self.mcs_hosts[1:]
def run_on_all_hosts_parallel(
hosts: list[RemoteHost],
func: Callable[[RemoteHost], Any],
timeout: float | None = None,
max_workers: int | None = None,
) -> dict[str, Any]:
"""
Run a function on all hosts in parallel.
:param hosts: List of RemoteHost objects
:param func: Function that takes a RemoteHost as its only argument
:param timeout: Timeout for the function execution
:param max_workers: Maximum number of worker threads (None = default based on system)
:returns: Dictionary mapping hosts to function results
"""
results = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# Start the function execution on each host
future_to_host = {executor.submit(func, host): host for host in hosts}
# Process results as they complete
for future in concurrent.futures.as_completed(future_to_host):
host = future_to_host[future]
try:
result = future.result(timeout=timeout)
results[host.name] = result
except Exception as exc:
print(f"Error on host {host.name}: {exc}")
results[host.name] = exc
raise
return results
def block_port(host: RemoteHost, port: int) -> None:
"""Block a port on the host using iptables"""
with host.ssh_connection() as conn:
print(f"{host.name}: Blocking port {port}")
conn.sudo(f"iptables -A INPUT -p tcp --dport {port} -j DROP")
print(f"{host.name}: Port {port} blocked")
def unblock_port(host: RemoteHost, port: int) -> None:
"""Unblock a port on the host using iptables"""
with host.ssh_connection() as conn:
print(f"{host.name}: Unblocking port {port}")
conn.sudo(f"iptables -D INPUT -p tcp --dport {port} -j DROP")
print(f"{host.name}: Port {port} unblocked")