From a0b4bcd1cec5fe927be00376aad8e39ddc935430 Mon Sep 17 00:00:00 2001 From: Alexander Presnyakov Date: Tue, 26 Aug 2025 02:57:09 +0000 Subject: [PATCH] Basic request tracer Tracing requests Custom log factory adds all trace values as one log record parameter (it will be empty if trace values are empty, like in MainThread where there are no incoming requests) --- cmapi/.gitignore | 4 + cmapi/cmapi_server/__main__.py | 11 +- cmapi/cmapi_server/controllers/api_clients.py | 17 +- cmapi/cmapi_server/handlers/cluster.py | 10 +- cmapi/cmapi_server/helpers.py | 30 ++- cmapi/cmapi_server/logging_management.py | 25 +++ cmapi/cmapi_server/node_manipulation.py | 6 +- cmapi/cmapi_server/trace_tool.py | 51 +++++ cmapi/cmapi_server/traced_aiohttp.py | 43 ++++ cmapi/cmapi_server/traced_session.py | 52 +++++ cmapi/cmapi_server/tracer.py | 210 ++++++++++++++++++ 11 files changed, 432 insertions(+), 27 deletions(-) create mode 100644 cmapi/cmapi_server/trace_tool.py create mode 100644 cmapi/cmapi_server/traced_aiohttp.py create mode 100644 cmapi/cmapi_server/traced_session.py create mode 100644 cmapi/cmapi_server/tracer.py diff --git a/cmapi/.gitignore b/cmapi/.gitignore index bdbe67f3f..4b87c115f 100644 --- a/cmapi/.gitignore +++ b/cmapi/.gitignore @@ -87,3 +87,7 @@ result centos8 ubuntu20.04 buildinfo.txt + +# Self-signed certificates +cmapi_server/self-signed.crt +cmapi_server/self-signed.key \ No newline at end of file diff --git a/cmapi/cmapi_server/__main__.py b/cmapi/cmapi_server/__main__.py index d8bf3892b..5a8a2de58 100644 --- a/cmapi/cmapi_server/__main__.py +++ b/cmapi/cmapi_server/__main__.py @@ -18,6 +18,7 @@ from cherrypy.process import plugins from cmapi_server.logging_management import config_cmapi_server_logging from cmapi_server.sentry import maybe_init_sentry, register_sentry_cherrypy_tool config_cmapi_server_logging() +from cmapi_server.trace_tool import register_tracing_tools from cmapi_server import helpers from cmapi_server.constants import 
DEFAULT_MCS_CONF_PATH, CMAPI_CONF_PATH @@ -141,6 +142,7 @@ if __name__ == '__main__': # TODO: read cmapi config filepath as an argument helpers.cmapi_config_check() + register_tracing_tools() # Init Sentry if DSN is present sentry_active = maybe_init_sentry() if sentry_active: @@ -153,6 +155,9 @@ if __name__ == '__main__': root_config = { "request.dispatch": dispatcher, "error_page.default": jsonify_error, + # Enable tracing tools + 'tools.trace.on': True, + 'tools.trace_end.on': True, } if sentry_active: root_config["tools.sentry.on"] = True @@ -230,10 +235,10 @@ if __name__ == '__main__': 'Something went wrong while trying to detect dbrm protocol.\n' 'Seems "controllernode" process isn\'t started.\n' 'This is just a notification, not a problem.\n' - 'Next detection will started at first node\\cluster ' + 'Next detection will start at first node\\cluster ' 'status check.\n' - f'This can cause extra {SOCK_TIMEOUT} seconds delay while\n' - 'first attempt to get status.', + f'This can cause extra {SOCK_TIMEOUT} seconds delay during\n' + 'this first attempt to get the status.', exc_info=True ) else: diff --git a/cmapi/cmapi_server/controllers/api_clients.py b/cmapi/cmapi_server/controllers/api_clients.py index 7b7e69622..74c8723a9 100644 --- a/cmapi/cmapi_server/controllers/api_clients.py +++ b/cmapi/cmapi_server/controllers/api_clients.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Union import pyotp import requests +from cmapi_server.traced_session import get_traced_session from cmapi_server.controllers.dispatcher import _version from cmapi_server.constants import ( @@ -141,7 +142,7 @@ class ClusterControllerClient: headers['Content-Type'] = 'application/json' data = {'in_transaction': True, **(data or {})} try: - response = requests.request( + response = get_traced_session().request( method, url, headers=headers, json=data, timeout=self.request_timeout, verify=False ) @@ -151,24 +152,26 @@ class ClusterControllerClient: except requests.HTTPError as exc: 
resp = exc.response error_msg = str(exc) - if resp.status_code == 422: + if resp is not None and resp.status_code == 422: # in this case we think cmapi server returned some value but # had error during running endpoint handler code try: - resp_json = response.json() + resp_json = resp.json() error_msg = resp_json.get('error', resp_json) except requests.exceptions.JSONDecodeError: - error_msg = response.text + error_msg = resp.text message = ( - f'API client got an exception in request to {exc.request.url} ' - f'with code {resp.status_code} and error: {error_msg}' + f'API client got an exception in request to {exc.request.url if exc.request else url} ' + f'with code {resp.status_code if resp is not None else "?"} and error: {error_msg}' ) logging.error(message) raise CMAPIBasicError(message) except requests.exceptions.RequestException as exc: + request_url = getattr(exc.request, 'url', url) + response_status = getattr(getattr(exc, 'response', None), 'status_code', '?') message = ( 'API client got an undefined error in request to ' - f'{exc.request.url} with code {exc.response.status_code} and ' + f'{request_url} with code {response_status} and ' f'error: {str(exc)}' ) logging.error(message) diff --git a/cmapi/cmapi_server/handlers/cluster.py b/cmapi/cmapi_server/handlers/cluster.py index f2d8f892b..0f0d08606 100644 --- a/cmapi/cmapi_server/handlers/cluster.py +++ b/cmapi/cmapi_server/handlers/cluster.py @@ -4,7 +4,7 @@ from datetime import datetime from enum import Enum from typing import Optional -import requests +from cmapi_server.traced_session import get_traced_session from cmapi_server.constants import ( CMAPI_CONF_PATH, DEFAULT_MCS_CONF_PATH, @@ -78,7 +78,7 @@ class ClusterHandler(): for node in active_nodes: url = f'https://{node}:8640/cmapi/{get_version()}/node/status' try: - r = requests.get(url, verify=False, headers=headers) + r = get_traced_session().request('GET', url, verify=False, headers=headers) r.raise_for_status() r_json = r.json() if 
len(r_json.get('services', 0)) == 0: @@ -277,7 +277,7 @@ class ClusterHandler(): payload['cluster_mode'] = mode try: - r = requests.put(url, headers=headers, json=payload, verify=False) + r = get_traced_session().request('PUT', url, headers=headers, json=payload, verify=False) r.raise_for_status() response['cluster-mode'] = mode except Exception as err: @@ -330,7 +330,7 @@ class ClusterHandler(): logger.debug(f'Setting new api key to "{node}".') url = f'https://{node}:8640/cmapi/{get_version()}/node/apikey-set' try: - resp = requests.put(url, verify=False, json=body) + resp = get_traced_session().request('PUT', url, verify=False, json=body, headers={}) resp.raise_for_status() r_json = resp.json() if active_nodes_count > 0: @@ -383,7 +383,7 @@ class ClusterHandler(): logger.debug(f'Setting new log level to "{node}".') url = f'https://{node}:8640/cmapi/{get_version()}/node/log-level' try: - resp = requests.put(url, verify=False, json=body) + resp = get_traced_session().request('PUT', url, verify=False, json=body, headers={}) resp.raise_for_status() r_json = resp.json() if active_nodes_count > 0: diff --git a/cmapi/cmapi_server/helpers.py b/cmapi/cmapi_server/helpers.py index df58889a1..be38db2f5 100644 --- a/cmapi/cmapi_server/helpers.py +++ b/cmapi/cmapi_server/helpers.py @@ -11,7 +11,6 @@ import os import socket import time from collections import namedtuple -from functools import partial from random import random from shutil import copyfile from typing import Tuple, Optional @@ -20,6 +19,8 @@ from urllib.parse import urlencode, urlunparse import aiohttp import lxml.objectify import requests +from cmapi_server.traced_session import get_traced_session +from cmapi_server.traced_aiohttp import create_traced_async_session from cmapi_server.exceptions import CMAPIBasicError # Bug in pylint https://github.com/PyCQA/pylint/issues/4584 @@ -153,9 +154,9 @@ def start_transaction( body['timeout'] = ( final_time - datetime.datetime.now() ).seconds - r = requests.put( - url, 
verify=False, headers=headers, json=body, - timeout=10 + r = get_traced_session().request( + 'PUT', url, verify=False, headers=headers, + json=body, timeout=10 ) # a 4xx error from our endpoint; @@ -219,8 +220,9 @@ def rollback_txn_attempt(key, version, txnid, nodes): url = f"https://{node}:8640/cmapi/{version}/node/rollback" for retry in range(5): try: - r = requests.put( - url, verify=False, headers=headers, json=body, timeout=5 + r = get_traced_session().request( + 'PUT', url, verify=False, headers=headers, + json=body, timeout=5 ) r.raise_for_status() except requests.Timeout: @@ -274,7 +276,10 @@ def commit_transaction( url = f"https://{node}:8640/cmapi/{version}/node/commit" for retry in range(5): try: - r = requests.put(url, verify = False, headers = headers, json = body, timeout = 5) + r = get_traced_session().request( + 'PUT', url, verify=False, headers=headers, + json=body, timeout=5 + ) r.raise_for_status() except requests.Timeout as e: logging.warning(f"commit_transaction(): timeout on node {node}") @@ -373,7 +378,7 @@ def broadcast_new_config( url = f'https://{node}:8640/cmapi/{version}/node/config' resp_json: dict = dict() - async with aiohttp.ClientSession() as session: + async with create_traced_async_session() as session: try: async with session.put( url, headers=headers, json=body, ssl=False, timeout=120 @@ -656,7 +661,7 @@ def get_current_config_file( headers = {'x-api-key' : key} url = f'https://{node}:8640/cmapi/{get_version()}/node/config' try: - r = requests.get(url, verify=False, headers=headers, timeout=5) + r = get_traced_session().request('GET', url, verify=False, headers=headers, timeout=5) r.raise_for_status() config = r.json()['config'] except Exception as e: @@ -767,14 +772,17 @@ def if_primary_restart( success = False while not success and datetime.datetime.now() < endtime: try: - response = requests.put(url, verify = False, headers = headers, json = body, timeout = 60) + response = get_traced_session().request( + 'PUT', url, 
verify=False, headers=headers, + json=body, timeout=60 + ) response.raise_for_status() success = True except Exception as e: logging.warning(f"if_primary_restart(): failed to start the cluster, got {str(e)}") time.sleep(10) if not success: - logging.error(f"if_primary_restart(): failed to start the cluster. Manual intervention is required.") + logging.error("if_primary_restart(): failed to start the cluster. Manual intervention is required.") def get_cej_info(config_root): diff --git a/cmapi/cmapi_server/logging_management.py b/cmapi/cmapi_server/logging_management.py index cffcae122..837b4267f 100644 --- a/cmapi/cmapi_server/logging_management.py +++ b/cmapi/cmapi_server/logging_management.py @@ -7,6 +7,7 @@ import cherrypy from cherrypy import _cperror from cmapi_server.constants import CMAPI_LOG_CONF_PATH +from cmapi_server.tracer import get_tracer class AddIpFilter(logging.Filter): @@ -16,6 +17,28 @@ class AddIpFilter(logging.Filter): return True +def install_trace_record_factory() -> None: + """Install a LogRecord factory that adds 'trace_params' field. + 'trace_params' will be an empty string if there is no active trace/span + (like in MainThread, where there are no incoming requests). + Otherwise it will contain trace parameters. 
+ """ + current_factory = logging.getLogRecordFactory() + + def factory(*args, **kwargs): # type: ignore[no-untyped-def] + record = current_factory(*args, **kwargs) + try: + trace_id, span_id, parent_span_id = get_tracer().current_trace_ids() + record.trace_params = ( + f" rid={trace_id} sid={span_id} psid={parent_span_id}" + ) + except Exception: + record.trace_params = " rid= sid= psid=" + return record + + logging.setLogRecordFactory(factory) + + def custom_cherrypy_error( self, msg='', context='', severity=logging.INFO, traceback=False ): @@ -119,6 +142,8 @@ def config_cmapi_server_logging(): cherrypy._cplogging.LogManager.access_log_format = ( '{h} ACCESS "{r}" code {s}, bytes {b}, user-agent "{a}"' ) + # Ensure trace_params is available on every record + install_trace_record_factory() dict_config(CMAPI_LOG_CONF_PATH) diff --git a/cmapi/cmapi_server/node_manipulation.py b/cmapi/cmapi_server/node_manipulation.py index bccdb142a..ff0d5259c 100644 --- a/cmapi/cmapi_server/node_manipulation.py +++ b/cmapi/cmapi_server/node_manipulation.py @@ -20,7 +20,9 @@ from cmapi_server.constants import ( CMAPI_CONF_PATH, CMAPI_SINGLE_NODE_XML, DEFAULT_MCS_CONF_PATH, LOCALHOSTS, MCS_DATA_PATH, ) +from cmapi_server.traced_session import get_traced_session from cmapi_server.managers.network import NetworkManager +from cmapi_server.tracer import get_tracer from mcs_node_control.models.node_config import NodeConfig @@ -617,7 +619,9 @@ def _rebalance_dbroots(root, test_mode=False): headers = {'x-api-key': key} url = f"https://{node_ip}:8640/cmapi/{version}/node/new_primary" try: - r = requests.get(url, verify = False, headers = headers, timeout = 10) + r = get_traced_session().request( + 'GET', url, verify=False, headers=headers, timeout=10 + ) r.raise_for_status() r = r.json() is_primary = r['is_primary'] diff --git a/cmapi/cmapi_server/trace_tool.py b/cmapi/cmapi_server/trace_tool.py new file mode 100644 index 000000000..7164373d6 --- /dev/null +++ 
b/cmapi/cmapi_server/trace_tool.py @@ -0,0 +1,51 @@ +""" +CherryPy tool that uses the tracer to start a span for each request. + +If traceparent header is present in the request, the tool will continue this trace chain. +Otherwise, it will start a new trace (with a new trace_id). +""" +from typing import Dict + +import cherrypy + +from cmapi_server.tracer import get_tracer + + +def _on_request_start() -> None: + """CherryPy tool hook: extract incoming context and start a SERVER span.""" + req = cherrypy.request + tracer = get_tracer() + + headers: Dict[str, str] = dict(req.headers or {}) + trace_id, parent_span_id = tracer.extract_traceparent(headers) + tracer.set_incoming_context(trace_id, parent_span_id) + + span_name = f"{getattr(req, 'method', 'HTTP')} {getattr(req, 'path_info', '/')}" + + ctx = tracer.start_as_current_span(span_name, kind="SERVER") + span = ctx.__enter__() + setattr(req, "_trace_span_ctx", ctx) + setattr(req, "_trace_span", span) + + # Echo current traceparent to the client + tracer.inject_traceparent(cherrypy.response.headers) # type: ignore[arg-type] + + +def _on_request_end() -> None: + """CherryPy tool hook: end the SERVER span started at request start.""" + req = cherrypy.request + ctx = getattr(req, "_trace_span_ctx", None) + if ctx is not None: + try: + ctx.__exit__(None, None, None) + finally: + setattr(req, "_trace_span_ctx", None) + setattr(req, "_trace_span", None) + + +def register_tracing_tools() -> None: + """Register CherryPy tools for request tracing.""" + cherrypy.tools.trace = cherrypy.Tool("on_start_resource", _on_request_start, priority=10) + cherrypy.tools.trace_end = cherrypy.Tool("on_end_resource", _on_request_end, priority=80) + + diff --git a/cmapi/cmapi_server/traced_aiohttp.py b/cmapi/cmapi_server/traced_aiohttp.py new file mode 100644 index 000000000..4643e217d --- /dev/null +++ b/cmapi/cmapi_server/traced_aiohttp.py @@ -0,0 +1,43 @@ +"""Async sibling of TracedSession""" + +from typing import Any + +import aiohttp + 
+from cmapi_server.tracer import get_tracer + + +class TracedAsyncSession(aiohttp.ClientSession): + async def _request( + self, method: str, str_or_url: Any, *args: Any, **kwargs: Any + ) -> aiohttp.ClientResponse: + tracer = get_tracer() + + headers = kwargs.get("headers") or {} + if headers is None: + headers = {} + kwargs["headers"] = headers + + url_text = str(str_or_url) + span_name = f"HTTP {method} {url_text}" + with tracer.start_as_current_span(span_name, kind="CLIENT") as span: + span.set_attribute("http.method", method) + span.set_attribute("http.url", url_text) + try: + tracer.inject_traceparent(headers) + except Exception: + pass + try: + response = await super()._request(method, str_or_url, *args, **kwargs) + except Exception as exc: + span.set_status("ERROR", str(exc)) + raise + else: + span.set_attribute("http.status_code", response.status) + return response + + +def create_traced_async_session(**kwargs: Any) -> TracedAsyncSession: + return TracedAsyncSession(**kwargs) + + diff --git a/cmapi/cmapi_server/traced_session.py b/cmapi/cmapi_server/traced_session.py new file mode 100644 index 000000000..120e2379e --- /dev/null +++ b/cmapi/cmapi_server/traced_session.py @@ -0,0 +1,52 @@ +"""Our own customized requests.Session that automatically traces outbound HTTP calls. + +Creates a CLIENT span per outbound HTTP request, injects traceparent, +records method/url/status, and closes the span when the request finishes. 
+""" + +from typing import Any, Optional + +import requests + +from cmapi_server.tracer import get_tracer + + +class TracedSession(requests.Session): + """requests.Session that automatically traces outbound HTTP calls.""" + + def request(self, method: str, url: str, *args: Any, **kwargs: Any) -> requests.Response: + tracer = get_tracer() + + headers = kwargs.get("headers") or {} + if headers is None: + headers = {} + kwargs["headers"] = headers + + span_name = f"HTTP {method} {url}" + with tracer.start_as_current_span(span_name, kind="CLIENT") as span: + span.set_attribute("http.method", method) + span.set_attribute("http.url", url) + + tracer.inject_traceparent(headers) + try: + response = super().request(method, url, *args, **kwargs) + except Exception as exc: + span.set_status("ERROR", str(exc)) + raise + else: + # Record status code + span.set_attribute("http.status_code", response.status_code) + return response + + +_default_session: Optional[TracedSession] = None + + +def get_traced_session() -> TracedSession: + """Return a process-wide TracedSession singleton.""" + global _default_session + if _default_session is None: + _default_session = TracedSession() + return _default_session + + diff --git a/cmapi/cmapi_server/tracer.py b/cmapi/cmapi_server/tracer.py new file mode 100644 index 000000000..1b85b2e4d --- /dev/null +++ b/cmapi/cmapi_server/tracer.py @@ -0,0 +1,210 @@ +"""Support distributed request tracing via W3C Trace Context. +See https://www.w3.org/TR/trace-context/ for the official spec. + +There are 3 and a half main components: +- trace_id: a unique identifier for a trace. + It is a 32-hex string, passed in the outbound HTTP request headers, so that we can + trace the request chain through the system. +- span_id: a unique identifier for a span (the current operation within a trace chain). + It is a 16-hex string. For example, when we receive a request to add a host, the addition + of the host is a separate span within the request chain. 
+- parent_span_id: a unique identifier for the parent span of the current span. + Continuing the example above, when we add a host, first we start a transaction, + then we add the host. If we are already adding a host, then creation of the transaction + is the parent span of the current span. +- traceparent: a header that combines trace_id and span_id in one value. + It has the format "00-{trace_id}-{span_id}-{flags}". + +A system that calls CMAPI can pass the traceparent header in the request, so that we can pass + the trace_id through the system, changing span_id as we enter new sub-operations. + +We can reconstruct the trace tree from the set of the logged traceparent attributes, + representing how the request was processed, which nodes were involved, + how much time each op took, etc. + +This module implements a tracer class that creates spans, injects/extracts traceparent headers. +It uses contextvars to store the trace/span/parent_span ids and start time for each context. +""" + +from __future__ import annotations + +import contextvars +import logging +import os +import time +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + +# Contextvars containing trace/span/parent_span ids and start time for this thread +# (contextvars are something like TLS, they contain variables that are local to the context) +_current_trace_id = contextvars.ContextVar[str]("trace_id", default="") +_current_span_id = contextvars.ContextVar[str]("span_id", default="") +_current_parent_span_id = contextvars.ContextVar[str]("parent_span_id", default="") +_current_span_start_ns = contextvars.ContextVar[int]("span_start_ns", default=0) + + +def _rand_16_hex() -> str: + # 16 random bytes (32 hex chars) + return os.urandom(16).hex() + +def _rand_8_hex() -> str: + # 8 random bytes (16 hex chars) + return os.urandom(8).hex() + +def format_traceparent(trace_id: str, span_id: str, 
flags: str = "01") -> str: + """W3C traceparent: version 00""" + # version(2)-trace_id(32)-span_id(16)-flags(2) + return f"00-{trace_id}-{span_id}-{flags}" + +def parse_traceparent(header: str) -> Optional[tuple[str, str, str]]: + """Return (trace_id, span_id, flags) or None if invalid.""" + try: + parts = header.strip().split("-") + if len(parts) != 4 or parts[0] != "00": + logger.error(f"Invalid traceparent: {header}") + return None + trace_id, span_id, flags = parts[1], parts[2], parts[3] + if len(trace_id) != 32 or len(span_id) != 16 or len(flags) != 2: + return None + # W3C: all zero trace_id/span_id are invalid + if set(trace_id) == {"0"} or set(span_id) == {"0"}: + return None + return trace_id, span_id, flags + except Exception: + logger.error(f"Failed to parse traceparent: {header}") + return None + + +@dataclass +class TraceSpan: + """Lightweight span handle; keeps attributes in memory (for logging only).""" + name: str + kind: str # "SERVER" | "CLIENT" | "INTERNAL" + start_ns: int + trace_id: str + span_id: str + parent_span_id: str + attributes: Dict[str, Any] + + def set_attribute(self, key: str, value: Any) -> None: + self.attributes[key] = value + + def add_event(self, name: str, **attrs: Any) -> None: + # For simplicity we just log events immediately + logger.info( + "event name=%s trace_id=%s span_id=%s attrs=%s", + name, self.trace_id, self.span_id, attrs + ) + + def set_status(self, code: str, description: str = "") -> None: + self.attributes["status.code"] = code + if description: + self.attributes["status.description"] = description + + def record_exception(self, exc: BaseException) -> None: + self.add_event("exception", type=type(exc).__name__, msg=str(exc)) + + +class Tracer: + """Encapsulates everything related to tracing (span creation, logging, etc)""" + @contextmanager + def start_as_current_span(self, name: str, kind: str = "INTERNAL") -> Iterator[TraceSpan]: + trace_id = _current_trace_id.get() or _rand_16_hex() + parent_span = 
_current_span_id.get() + new_span_id = _rand_8_hex() + + # Push new context + tok_tid = _current_trace_id.set(trace_id) + tok_sid = _current_span_id.set(new_span_id) + tok_pid = _current_parent_span_id.set(parent_span) + tok_start = _current_span_start_ns.set(time.time_ns()) + + span = TraceSpan( + name=name, + kind=kind, + start_ns=_current_span_start_ns.get(), + trace_id=trace_id, + span_id=new_span_id, + parent_span_id=parent_span, + attributes={"span.kind": kind, "span.name": name}, + ) + + try: + logger.info( + "span_begin name=%s kind=%s trace_id=%s span_id=%s parent_span_id=%s attrs=%s", + span.name, span.kind, span.trace_id, span.span_id, span.parent_span_id, span.attributes + ) + yield span + except BaseException as exc: + span.record_exception(exc) + span.set_status("ERROR", str(exc)) + raise + finally: + # Pop the span from the context (restore parent context) + duration_ms = (time.time_ns() - span.start_ns) / 1_000_000 + logger.info( + "span_end name=%s kind=%s trace_id=%s span_id=%s parent_span_id=%s duration_ms=%.3f attrs=%s", + span.name, span.kind, span.trace_id, span.span_id, span.parent_span_id, duration_ms, span.attributes + ) + # Restore previous context + _current_span_start_ns.reset(tok_start) + _current_parent_span_id.reset(tok_pid) + _current_span_id.reset(tok_sid) + _current_trace_id.reset(tok_tid) + + def set_incoming_context( + self, + trace_id: Optional[str] = None, + parent_span_id: Optional[str] = None, + ) -> None: + """Seed current context with incoming W3C traceparent values. + + Only non-empty values are applied. 
+ """ + if trace_id: + _current_trace_id.set(trace_id) + if parent_span_id: + _current_parent_span_id.set(parent_span_id) + + def current_trace_ids(self) -> tuple[str, str, str]: + return _current_trace_id.get(), _current_span_id.get(), _current_parent_span_id.get() + + def inject_traceparent(self, headers: Dict[str, str]) -> None: + """Inject W3C traceparent into outbound headers.""" + trace_id, span_id, _ = self.current_trace_ids() + if not trace_id or not span_id: + # If called outside of a span, create a short-lived span just to carry IDs + trace_id = trace_id or _rand_16_hex() + span_id = span_id or _rand_8_hex() + headers["traceparent"] = format_traceparent(trace_id, span_id) + + def extract_traceparent(self, headers: Dict[str, str]) -> tuple[str, str]: + """Extract parent trace/span; returns (trace_id, parent_span_id).""" + raw_traceparent = (headers.get("traceparent") + or headers.get("Traceparent") + or headers.get("TRACEPARENT")) + if not raw_traceparent: + return "", "" + parsed = parse_traceparent(raw_traceparent) + if not parsed: + return "", "" + return parsed[0], parsed[1] + # No incoming context + return "", "" + +# Tracer singleton for the process (not thread) +_tracer = Tracer() + +def get_tracer() -> Tracer: + return _tracer + + +class TraceLogFilter(logging.Filter): + """Inject trace/span ids into LogRecord for formatting.""" + def filter(self, record: logging.LogRecord) -> bool: + record.traceID, record.spanID, record.parentSpanID = get_tracer().current_trace_ids() + return True