1
0
mirror of https://github.com/mariadb-corporation/mariadb-columnstore-engine.git synced 2026-01-06 08:21:10 +03:00

Merge pull request #3765 from mariadb-corporation/MCOL-6159-fix-active-nodes-membership-check

MCOL-6159: Add forward/reverse DNS validation when adding nodes by hostname
This commit is contained in:
Alexander Presniakov
2025-09-17 22:43:17 -03:00
committed by GitHub
7 changed files with 184 additions and 65 deletions

View File

@@ -12,7 +12,7 @@ import cherrypy
import pyotp
import requests
from cmapi_server.exceptions import CMAPIBasicError
from cmapi_server.exceptions import CMAPIBasicError, cmapi_error_to_422
from cmapi_server.constants import (
DEFAULT_MCS_CONF_PATH, DEFAULT_SM_CONF_PATH, EM_PATH_SUFFIX,
MCS_BRM_CURRENT_PATH, MCS_EM_PATH, S3_BRM_CURRENT_PATH, SECRET_KEY,
@@ -924,14 +924,12 @@ class ClusterController:
if node is None:
raise_422_error(module_logger, func_name, 'missing node argument')
try:
with cmapi_error_to_422(module_logger, func_name):
if not in_transaction:
with TransactionManager(extra_nodes=[node]):
response = ClusterHandler.add_node(node, config)
else:
response = ClusterHandler.add_node(node, config)
except CMAPIBasicError as err:
raise_422_error(module_logger, func_name, err.message)
module_logger.debug(f'{func_name} returns {str(response)}')
return response
@@ -953,14 +951,12 @@ class ClusterController:
if node is None:
raise_422_error(module_logger, func_name, 'missing node argument')
try:
with cmapi_error_to_422(module_logger, func_name):
if not in_transaction:
with TransactionManager(remove_nodes=[node]):
response = ClusterHandler.remove_node(node, config)
else:
response = ClusterHandler.remove_node(node, config)
except CMAPIBasicError as err:
raise_422_error(module_logger, func_name, err.message)
module_logger.debug(f'{func_name} returns {str(response)}')
return response
@@ -1079,10 +1075,8 @@ class ClusterController:
module_logger, func_name, 'Wrong verification key.'
)
try:
with cmapi_error_to_422(module_logger, func_name):
response = ClusterHandler.set_api_key(new_api_key, totp_key)
except CMAPIBasicError as err:
raise_422_error(module_logger, func_name, err.message)
module_logger.debug(f'{func_name} returns {str(response)}')
return response

View File

@@ -1,5 +1,11 @@
"""Module contains custom exceptions."""
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Optional
from cmapi_server.controllers.error import APIError
class CMAPIBasicError(Exception):
"""Basic exception raised for CMAPI related processes.
@@ -20,3 +26,35 @@ class CEJError(CMAPIBasicError):
Attributes:
message -- explanation of the error
"""
@contextmanager
def exc_to_cmapi_error(prefix: Optional[str] = None) -> Iterator[None]:
    """Context manager to standardize error wrapping into CMAPIBasicError.

    A CMAPIBasicError raised inside the managed block is re-raised untouched
    (to preserve detailed messages). Any other ``Exception`` is wrapped into
    a new CMAPIBasicError that is chained to the original one.

    The resulting message is ``"<prefix>. Details: <details>"`` when a prefix
    is given and just ``"<details>"`` otherwise. ``<details>`` is
    ``str(err)``; for exceptions raised without a message it falls back to
    ``repr(err)`` so the wrapped error never ends up with empty details.

    :param prefix: Optional message prefix for wrapped errors
    :raises CMAPIBasicError: for any wrapped non-CMAPIBasicError exceptions
    """
    try:
        yield
    except CMAPIBasicError:
        # Preserve detailed messages from deeper layers (e.g., validation)
        raise
    except Exception as err:
        # str() of a message-less exception (e.g. TimeoutError()) is empty;
        # repr() still names the exception type.
        details = str(err) or repr(err)
        msg = f"{prefix}. Details: {details}" if prefix else details
        raise CMAPIBasicError(msg) from err
@contextmanager
def cmapi_error_to_422(logger, func_name: str) -> Iterator[None]:
    """Convert CMAPIBasicError to HTTP 422 APIError.

    A CMAPIBasicError raised inside the managed block is logged at error
    level as ``"<func_name> <message>"`` (without traceback) and re-raised
    as ``APIError(422, <message>)`` chained to the original error. Any other
    exception passes through unchanged.

    :param logger: logger used to report the error
        (assumed to follow the stdlib ``logging.Logger`` call convention)
    :param func_name: name of the calling endpoint, used as the log prefix
    :raises APIError: with status 422 when a CMAPIBasicError is caught
    """
    try:
        yield
    except CMAPIBasicError as err:
        # mirror raise_422_error behavior locally to avoid circular imports
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted.
        logger.error('%s %s', func_name, err.message, exc_info=False)
        raise APIError(422, err.message) from err

View File

@@ -6,12 +6,13 @@ from typing import Optional
from mcs_node_control.models.misc import get_dbrm_master
from mcs_node_control.models.node_config import NodeConfig
from tracing.traced_session import get_traced_session
from cmapi_server.constants import (
CMAPI_CONF_PATH,
DEFAULT_MCS_CONF_PATH,
)
from cmapi_server.exceptions import CMAPIBasicError
from cmapi_server.exceptions import CMAPIBasicError, exc_to_cmapi_error
from cmapi_server.helpers import (
broadcast_new_config,
get_active_nodes,
@@ -27,7 +28,6 @@ from cmapi_server.node_manipulation import (
remove_node,
switch_node_maintenance,
)
from tracing.traced_session import get_traced_session
class ClusterAction(Enum):
@@ -171,7 +171,7 @@ class ClusterHandler:
response = {'timestamp': str(datetime.now())}
try:
with exc_to_cmapi_error(prefix='Error while adding node'):
add_node(
node, input_config_filename=config,
output_config_filename=config
@@ -181,8 +181,6 @@ class ClusterHandler:
host=node, input_config_filename=config,
output_config_filename=config
)
except Exception as err:
raise CMAPIBasicError('Error while adding node.') from err
response['node_id'] = node
update_revision_and_manager(
@@ -218,13 +216,11 @@ class ClusterHandler:
)
response = {'timestamp': str(datetime.now())}
try:
with exc_to_cmapi_error(prefix='Error while removing node'):
remove_node(
node, input_config_filename=config,
output_config_filename=config
)
except Exception as err:
raise CMAPIBasicError('Error while removing node.') from err
response['node_id'] = node
active_nodes = get_active_nodes(config)

View File

@@ -210,9 +210,28 @@ class NetworkManager:
hostnames = socket.gethostbyaddr(ip_addr)
return hostnames[0]
except socket.herror:
logging.error(f'No hostname found for address: {ip_addr!r}')
logging.error('No hostname found for address: %s', ip_addr)
return None
@classmethod
def get_hostnames_by_ip(cls, ip_addr: str) -> list[str]:
    """Get all hostnames for a given IP address.

    Performs a reverse lookup with ``socket.gethostbyaddr`` and returns the
    primary name followed by its aliases, with duplicates removed and the
    original order preserved.

    :param ip_addr: address to reverse-resolve
    :return: List of hostnames (may be empty if reverse lookup fails)
    """
    try:
        primary, aliases, _ = socket.gethostbyaddr(ip_addr)
    except (socket.herror, socket.gaierror):
        # herror: no reverse record for the address;
        # gaierror: the argument itself could not be resolved to an address.
        logging.error('No hostname found for address: %s', ip_addr)
        return []
    # dict keys keep insertion order, so this de-duplicates while leaving
    # the primary name first.
    return list(dict.fromkeys([primary, *aliases]))
@classmethod
def is_only_loopback_hostname(cls, hostname: str) -> bool:
"""Check if all IPs resolved from the hostname are loopback.
@@ -256,3 +275,44 @@ class NetworkManager:
raise CMAPIBasicError(f'No IPs found for {hostname!r}')
ip = ip_list[0]
return ip, hostname
@classmethod
def validate_hostname_fwd_rev(cls, hostname: str) -> None:
    """Validate forward and reverse DNS for a hostname.

    The hostname is first resolved to IPv4 addresses (loopback addresses are
    excluded unless the hostname resolves to loopback only). Validation
    passes as soon as one of those addresses reverse-resolves to a name that,
    ignoring case and a trailing dot, either equals the hostname or is an
    FQDN that starts with it followed by a dot.

    :param hostname: hostname to validate
    :raises CMAPIBasicError: if validation fails
    """
    loopback_only = cls.is_only_loopback_hostname(hostname)
    ips = cls.resolve_hostname_to_ip(
        hostname,
        only_ipv4=True,
        exclude_loopback=not loopback_only,
    )
    if not ips:
        raise CMAPIBasicError(
            f"Hostname {hostname!r} did not resolve to any usable IPs. "
            "Please fix DNS or add the host by IP."
        )

    wanted = hostname.rstrip('.').lower()
    fqdn_prefix = wanted + '.'
    # Lazily walk every reverse name of every resolved IP; any() stops at
    # the first match, so no further reverse lookups are performed.
    normalized_reverse_names = (
        reverse_name.rstrip('.').lower()
        for ip in ips
        for reverse_name in cls.get_hostnames_by_ip(ip)
    )
    # Accept exact match ("db1" == "db1") or FQDN starting with the short
    # hostname, e.g. user provided "db1" and PTR returns "db1.example.com"
    if any(
        candidate == wanted or candidate.startswith(fqdn_prefix)
        for candidate in normalized_reverse_names
    ):
        return

    raise CMAPIBasicError(
        'Forward/reverse DNS check failed: '
        f"hostname {hostname!r} resolved to {ips}, but none of these IPs "
        f"reverse-resolve back to {hostname!r}. Consider adding the host by IP, "
        'or fix DNS so that at least one IP has a PTR/record mapping back to '
        'the provided hostname.'
    )

View File

@@ -15,6 +15,7 @@ from typing import Optional
import requests
from lxml import etree
from mcs_node_control.models.node_config import NodeConfig
from tracing.traced_session import get_traced_session
from cmapi_server import helpers
from cmapi_server.constants import (
@@ -25,7 +26,6 @@ from cmapi_server.constants import (
MCS_DATA_PATH,
)
from cmapi_server.managers.network import NetworkManager
from tracing.traced_session import get_traced_session
PMS_NODE_PORT = '8620'
EXEMGR_NODE_PORT = '8601'
@@ -61,7 +61,6 @@ def switch_node_maintenance(
node_config.write_config(config_root, filename=output_config_filename)
# TODO: probably move publishing to cherrypy.engine failover channel here?
def add_node(
node: str, input_config_filename: str = DEFAULT_MCS_CONF_PATH,
output_config_filename: Optional[str] = None,
@@ -96,6 +95,11 @@ def add_node(
node_config = NodeConfig()
c_root = node_config.get_current_config_root(input_config_filename)
# If a hostname (not IP) is provided, ensure fwd/rev DNS consistency.
# Skip validation for localhost aliases to preserve legacy single-node flows.
if not NetworkManager.is_ip(node) and not NetworkManager.is_only_loopback_hostname(node):
NetworkManager.validate_hostname_fwd_rev(node)
try:
if not _replace_localhost(c_root, node):
pm_num = _add_node_to_PMS(c_root, node)
@@ -636,7 +640,7 @@ def _rebalance_dbroots(root, test_mode=False):
# timed out
# possible node is not ready, leave retry as-is
pass
except Exception as e:
except Exception:
retry = False
if not found_master:

View File

@@ -1,13 +1,13 @@
import logging
import socket
from unittest.mock import patch
from mcs_node_control.models.node_config import NodeConfig
from cmapi_server.failover_agent import FailoverAgent
from cmapi_server.managers.network import NetworkManager
from cmapi_server.node_manipulation import add_node, remove_node
from mcs_node_control.models.node_config import NodeConfig
from cmapi_server.test.unittest_global import (
tmp_mcs_config_filename, BaseNodeManipTestCase
)
from cmapi_server.test.unittest_global import BaseNodeManipTestCase, tmp_mcs_config_filename
logging.basicConfig(level='DEBUG')
@@ -18,10 +18,12 @@ class TestFailoverAgent(BaseNodeManipTestCase):
self.tmp_files = ('./activate0.xml', './activate1.xml')
hostaddr = socket.gethostbyname(socket.gethostname())
fa = FailoverAgent()
fa.activateNodes(
[self.NEW_NODE_NAME], tmp_mcs_config_filename, self.tmp_files[0],
test_mode=True
)
# Bypass DNS validation for hostname-based addition in tests
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
fa.activateNodes(
[self.NEW_NODE_NAME], tmp_mcs_config_filename, self.tmp_files[0],
test_mode=True
)
add_node(
hostaddr, self.tmp_files[0], self.tmp_files[1]
)
@@ -50,10 +52,11 @@ class TestFailoverAgent(BaseNodeManipTestCase):
add_node(
hostaddr, tmp_mcs_config_filename, self.tmp_files[0]
)
fa.activateNodes(
[self.NEW_NODE_NAME], self.tmp_files[0], self.tmp_files[1],
test_mode=True
)
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
fa.activateNodes(
[self.NEW_NODE_NAME], self.tmp_files[0], self.tmp_files[1],
test_mode=True
)
fa.deactivateNodes(
[self.NEW_NODE_NAME], self.tmp_files[1], self.tmp_files[2],
test_mode=True
@@ -89,10 +92,11 @@ class TestFailoverAgent(BaseNodeManipTestCase):
)
fa = FailoverAgent()
hostaddr = socket.gethostbyname(socket.gethostname())
fa.activateNodes(
[self.NEW_NODE_NAME], tmp_mcs_config_filename, self.tmp_files[0],
test_mode=True
)
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
fa.activateNodes(
[self.NEW_NODE_NAME], tmp_mcs_config_filename, self.tmp_files[0],
test_mode=True
)
add_node(
hostaddr, self.tmp_files[0], self.tmp_files[1]
)

View File

@@ -1,15 +1,15 @@
import logging
import socket
from unittest.mock import patch
from lxml import etree
from mcs_node_control.models.node_config import NodeConfig
from cmapi_server import node_manipulation
from cmapi_server.constants import MCS_DATA_PATH
from cmapi_server.test.unittest_global import (
tmp_mcs_config_filename, BaseNodeManipTestCase
)
from mcs_node_control.models.node_config import NodeConfig
from cmapi_server.exceptions import CMAPIBasicError
from cmapi_server.managers.network import NetworkManager
from cmapi_server.test.unittest_global import BaseNodeManipTestCase, tmp_mcs_config_filename
logging.basicConfig(level='DEBUG')
@@ -21,9 +21,11 @@ class NodeManipTester(BaseNodeManipTestCase):
'./test-output0.xml','./test-output1.xml','./test-output2.xml'
)
hostaddr = socket.gethostbyname(socket.gethostname())
node_manipulation.add_node(
self.NEW_NODE_NAME, tmp_mcs_config_filename, self.tmp_files[0]
)
# Bypass DNS validation to avoid dependence on external DNS for tests
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
node_manipulation.add_node(
self.NEW_NODE_NAME, tmp_mcs_config_filename, self.tmp_files[0]
)
node_manipulation.add_node(
hostaddr, self.tmp_files[0], self.tmp_files[1]
)
@@ -69,9 +71,10 @@ class NodeManipTester(BaseNodeManipTestCase):
etree.SubElement(sysconf_node, 'DBRoot10').text = '/dummy_path/data10'
nc.write_config(root, self.tmp_files[0])
node_manipulation.add_node(
self.NEW_NODE_NAME, self.tmp_files[0], self.tmp_files[1]
)
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
node_manipulation.add_node(
self.NEW_NODE_NAME, self.tmp_files[0], self.tmp_files[1]
)
# get a NodeConfig, read test.xml
# look for some of the expected changes.
@@ -113,12 +116,13 @@ class NodeManipTester(BaseNodeManipTestCase):
# add a node, verify we can add a dbroot to each of them
hostname = socket.gethostname()
node_manipulation.add_node(
hostname, tmp_mcs_config_filename, self.tmp_files[1]
)
node_manipulation.add_node(
self.NEW_NODE_NAME, self.tmp_files[1], self.tmp_files[2]
)
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
node_manipulation.add_node(
hostname, tmp_mcs_config_filename, self.tmp_files[1]
)
node_manipulation.add_node(
self.NEW_NODE_NAME, self.tmp_files[1], self.tmp_files[2]
)
id1 = node_manipulation.add_dbroot(
self.tmp_files[2], self.tmp_files[3], host=self.NEW_NODE_NAME
)
@@ -152,9 +156,10 @@ class NodeManipTester(BaseNodeManipTestCase):
def test_change_primary_node(self):
# add a node, make it the primary, verify expected result
self.tmp_files = ('./primary-node0.xml', './primary-node1.xml')
node_manipulation.add_node(
self.NEW_NODE_NAME, tmp_mcs_config_filename, self.tmp_files[0]
)
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
node_manipulation.add_node(
self.NEW_NODE_NAME, tmp_mcs_config_filename, self.tmp_files[0]
)
node_manipulation.move_primary_node(
self.tmp_files[0], self.tmp_files[1]
)
@@ -179,17 +184,19 @@ class NodeManipTester(BaseNodeManipTestCase):
self.tmp_files = (
'./tud-0.xml', './tud-1.xml', './tud-2.xml', './tud-3.xml',
)
node_manipulation.add_node(
self.NEW_NODE_NAME, tmp_mcs_config_filename, self.tmp_files[0]
)
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
node_manipulation.add_node(
self.NEW_NODE_NAME, tmp_mcs_config_filename, self.tmp_files[0]
)
root = NodeConfig().get_current_config_root(self.tmp_files[0])
(name, addr) = node_manipulation.find_dbroot1(root)
self.assertEqual(name, self.NEW_NODE_NAME)
# add a second node and more dbroots to make the test slightly more robust
node_manipulation.add_node(
socket.gethostname(), self.tmp_files[0], self.tmp_files[1]
)
with patch.object(NetworkManager, 'validate_hostname_fwd_rev', return_value=None):
node_manipulation.add_node(
socket.gethostname(), self.tmp_files[0], self.tmp_files[1]
)
node_manipulation.add_dbroot(
self.tmp_files[1], self.tmp_files[2], socket.gethostname()
)
@@ -209,3 +216,19 @@ class NodeManipTester(BaseNodeManipTestCase):
caught_it = True
self.assertTrue(caught_it)
def test_add_node_hostname_reverse_mismatch(self):
    """Reject a hostname whose reverse DNS does not point back to it.

    All NetworkManager lookups are patched: the hostname is reported as a
    non-IP, non-loopback name that resolves to one address, and that
    address reverse-resolves only to unrelated names (neither an exact
    match nor an FQDN starting with the hostname). add_node is expected to
    raise CMAPIBasicError.
    """
    self.tmp_files = ('./rev-mismatch-0.xml',)
    bad_hostname = 'badhost'
    forward_ips = ['10.0.0.5']
    unrelated_names = ['other.example.com', 'alias.other.example.com']
    with patch.object(NetworkManager, 'is_ip', return_value=False), \
            patch.object(NetworkManager, 'is_only_loopback_hostname', return_value=False), \
            patch.object(NetworkManager, 'resolve_hostname_to_ip', return_value=forward_ips), \
            patch.object(NetworkManager, 'get_hostnames_by_ip', return_value=unrelated_names), \
            self.assertRaises(CMAPIBasicError):
        node_manipulation.add_node(
            bad_hostname, tmp_mcs_config_filename, self.tmp_files[0]
        )