1
0
mirror of https://github.com/mariadb-corporation/mariadb-columnstore-engine.git synced 2025-11-03 17:13:17 +03:00

Basic request tracer

Tracing requests

Custom log factory adds all trace values as one log record parameter (it will be empty if trace values are empty, like in MainThread where there are no incoming requests)
This commit is contained in:
Alexander Presnyakov
2025-08-26 02:57:09 +00:00
committed by Leonid Fedorov
parent 0fc41e0387
commit a0b4bcd1ce
11 changed files with 432 additions and 27 deletions

4
cmapi/.gitignore vendored
View File

@@ -87,3 +87,7 @@ result
centos8 centos8
ubuntu20.04 ubuntu20.04
buildinfo.txt buildinfo.txt
# Self-signed certificates
cmapi_server/self-signed.crt
cmapi_server/self-signed.key

View File

@@ -18,6 +18,7 @@ from cherrypy.process import plugins
from cmapi_server.logging_management import config_cmapi_server_logging from cmapi_server.logging_management import config_cmapi_server_logging
from cmapi_server.sentry import maybe_init_sentry, register_sentry_cherrypy_tool from cmapi_server.sentry import maybe_init_sentry, register_sentry_cherrypy_tool
config_cmapi_server_logging() config_cmapi_server_logging()
from cmapi_server.trace_tool import register_tracing_tools
from cmapi_server import helpers from cmapi_server import helpers
from cmapi_server.constants import DEFAULT_MCS_CONF_PATH, CMAPI_CONF_PATH from cmapi_server.constants import DEFAULT_MCS_CONF_PATH, CMAPI_CONF_PATH
@@ -141,6 +142,7 @@ if __name__ == '__main__':
# TODO: read cmapi config filepath as an argument # TODO: read cmapi config filepath as an argument
helpers.cmapi_config_check() helpers.cmapi_config_check()
register_tracing_tools()
# Init Sentry if DSN is present # Init Sentry if DSN is present
sentry_active = maybe_init_sentry() sentry_active = maybe_init_sentry()
if sentry_active: if sentry_active:
@@ -153,6 +155,9 @@ if __name__ == '__main__':
root_config = { root_config = {
"request.dispatch": dispatcher, "request.dispatch": dispatcher,
"error_page.default": jsonify_error, "error_page.default": jsonify_error,
# Enable tracing tools
'tools.trace.on': True,
'tools.trace_end.on': True,
} }
if sentry_active: if sentry_active:
root_config["tools.sentry.on"] = True root_config["tools.sentry.on"] = True
@@ -230,10 +235,10 @@ if __name__ == '__main__':
'Something went wrong while trying to detect dbrm protocol.\n' 'Something went wrong while trying to detect dbrm protocol.\n'
'Seems "controllernode" process isn\'t started.\n' 'Seems "controllernode" process isn\'t started.\n'
'This is just a notification, not a problem.\n' 'This is just a notification, not a problem.\n'
'Next detection will started at first node\\cluster ' 'Next detection will start at first node\\cluster '
'status check.\n' 'status check.\n'
f'This can cause extra {SOCK_TIMEOUT} seconds delay while\n' f'This can cause extra {SOCK_TIMEOUT} seconds delay during\n'
'first attempt to get status.', 'this first attempt to get the status.',
exc_info=True exc_info=True
) )
else: else:

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Union
import pyotp import pyotp
import requests import requests
from cmapi_server.traced_session import get_traced_session
from cmapi_server.controllers.dispatcher import _version from cmapi_server.controllers.dispatcher import _version
from cmapi_server.constants import ( from cmapi_server.constants import (
@@ -141,7 +142,7 @@ class ClusterControllerClient:
headers['Content-Type'] = 'application/json' headers['Content-Type'] = 'application/json'
data = {'in_transaction': True, **(data or {})} data = {'in_transaction': True, **(data or {})}
try: try:
response = requests.request( response = get_traced_session().request(
method, url, headers=headers, json=data, method, url, headers=headers, json=data,
timeout=self.request_timeout, verify=False timeout=self.request_timeout, verify=False
) )
@@ -151,24 +152,26 @@ class ClusterControllerClient:
except requests.HTTPError as exc: except requests.HTTPError as exc:
resp = exc.response resp = exc.response
error_msg = str(exc) error_msg = str(exc)
if resp.status_code == 422: if resp is not None and resp.status_code == 422:
# in this case we think cmapi server returned some value but # in this case we think cmapi server returned some value but
# had error during running endpoint handler code # had error during running endpoint handler code
try: try:
resp_json = response.json() resp_json = resp.json()
error_msg = resp_json.get('error', resp_json) error_msg = resp_json.get('error', resp_json)
except requests.exceptions.JSONDecodeError: except requests.exceptions.JSONDecodeError:
error_msg = response.text error_msg = resp.text
message = ( message = (
f'API client got an exception in request to {exc.request.url} ' f'API client got an exception in request to {exc.request.url if exc.request else url} '
f'with code {resp.status_code} and error: {error_msg}' f'with code {resp.status_code if resp is not None else "?"} and error: {error_msg}'
) )
logging.error(message) logging.error(message)
raise CMAPIBasicError(message) raise CMAPIBasicError(message)
except requests.exceptions.RequestException as exc: except requests.exceptions.RequestException as exc:
request_url = getattr(exc.request, 'url', url)
response_status = getattr(getattr(exc, 'response', None), 'status_code', '?')
message = ( message = (
'API client got an undefined error in request to ' 'API client got an undefined error in request to '
f'{exc.request.url} with code {exc.response.status_code} and ' f'{request_url} with code {response_status} and '
f'error: {str(exc)}' f'error: {str(exc)}'
) )
logging.error(message) logging.error(message)

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
import requests from cmapi_server.traced_session import get_traced_session
from cmapi_server.constants import ( from cmapi_server.constants import (
CMAPI_CONF_PATH, DEFAULT_MCS_CONF_PATH, CMAPI_CONF_PATH, DEFAULT_MCS_CONF_PATH,
@@ -78,7 +78,7 @@ class ClusterHandler():
for node in active_nodes: for node in active_nodes:
url = f'https://{node}:8640/cmapi/{get_version()}/node/status' url = f'https://{node}:8640/cmapi/{get_version()}/node/status'
try: try:
r = requests.get(url, verify=False, headers=headers) r = get_traced_session().request('GET', url, verify=False, headers=headers)
r.raise_for_status() r.raise_for_status()
r_json = r.json() r_json = r.json()
if len(r_json.get('services', 0)) == 0: if len(r_json.get('services', 0)) == 0:
@@ -277,7 +277,7 @@ class ClusterHandler():
payload['cluster_mode'] = mode payload['cluster_mode'] = mode
try: try:
r = requests.put(url, headers=headers, json=payload, verify=False) r = get_traced_session().request('PUT', url, headers=headers, json=payload, verify=False)
r.raise_for_status() r.raise_for_status()
response['cluster-mode'] = mode response['cluster-mode'] = mode
except Exception as err: except Exception as err:
@@ -330,7 +330,7 @@ class ClusterHandler():
logger.debug(f'Setting new api key to "{node}".') logger.debug(f'Setting new api key to "{node}".')
url = f'https://{node}:8640/cmapi/{get_version()}/node/apikey-set' url = f'https://{node}:8640/cmapi/{get_version()}/node/apikey-set'
try: try:
resp = requests.put(url, verify=False, json=body) resp = get_traced_session().request('PUT', url, verify=False, json=body, headers={})
resp.raise_for_status() resp.raise_for_status()
r_json = resp.json() r_json = resp.json()
if active_nodes_count > 0: if active_nodes_count > 0:
@@ -383,7 +383,7 @@ class ClusterHandler():
logger.debug(f'Setting new log level to "{node}".') logger.debug(f'Setting new log level to "{node}".')
url = f'https://{node}:8640/cmapi/{get_version()}/node/log-level' url = f'https://{node}:8640/cmapi/{get_version()}/node/log-level'
try: try:
resp = requests.put(url, verify=False, json=body) resp = get_traced_session().request('PUT', url, verify=False, json=body, headers={})
resp.raise_for_status() resp.raise_for_status()
r_json = resp.json() r_json = resp.json()
if active_nodes_count > 0: if active_nodes_count > 0:

View File

@@ -11,7 +11,6 @@ import os
import socket import socket
import time import time
from collections import namedtuple from collections import namedtuple
from functools import partial
from random import random from random import random
from shutil import copyfile from shutil import copyfile
from typing import Tuple, Optional from typing import Tuple, Optional
@@ -20,6 +19,8 @@ from urllib.parse import urlencode, urlunparse
import aiohttp import aiohttp
import lxml.objectify import lxml.objectify
import requests import requests
from cmapi_server.traced_session import get_traced_session
from cmapi_server.traced_aiohttp import create_traced_async_session
from cmapi_server.exceptions import CMAPIBasicError from cmapi_server.exceptions import CMAPIBasicError
# Bug in pylint https://github.com/PyCQA/pylint/issues/4584 # Bug in pylint https://github.com/PyCQA/pylint/issues/4584
@@ -153,9 +154,9 @@ def start_transaction(
body['timeout'] = ( body['timeout'] = (
final_time - datetime.datetime.now() final_time - datetime.datetime.now()
).seconds ).seconds
r = requests.put( r = get_traced_session().request(
url, verify=False, headers=headers, json=body, 'PUT', url, verify=False, headers=headers,
timeout=10 json=body, timeout=10
) )
# a 4xx error from our endpoint; # a 4xx error from our endpoint;
@@ -219,8 +220,9 @@ def rollback_txn_attempt(key, version, txnid, nodes):
url = f"https://{node}:8640/cmapi/{version}/node/rollback" url = f"https://{node}:8640/cmapi/{version}/node/rollback"
for retry in range(5): for retry in range(5):
try: try:
r = requests.put( r = get_traced_session().request(
url, verify=False, headers=headers, json=body, timeout=5 'PUT', url, verify=False, headers=headers,
json=body, timeout=5
) )
r.raise_for_status() r.raise_for_status()
except requests.Timeout: except requests.Timeout:
@@ -274,7 +276,10 @@ def commit_transaction(
url = f"https://{node}:8640/cmapi/{version}/node/commit" url = f"https://{node}:8640/cmapi/{version}/node/commit"
for retry in range(5): for retry in range(5):
try: try:
r = requests.put(url, verify = False, headers = headers, json = body, timeout = 5) r = get_traced_session().request(
'PUT', url, verify=False, headers=headers,
json=body, timeout=5
)
r.raise_for_status() r.raise_for_status()
except requests.Timeout as e: except requests.Timeout as e:
logging.warning(f"commit_transaction(): timeout on node {node}") logging.warning(f"commit_transaction(): timeout on node {node}")
@@ -373,7 +378,7 @@ def broadcast_new_config(
url = f'https://{node}:8640/cmapi/{version}/node/config' url = f'https://{node}:8640/cmapi/{version}/node/config'
resp_json: dict = dict() resp_json: dict = dict()
async with aiohttp.ClientSession() as session: async with create_traced_async_session() as session:
try: try:
async with session.put( async with session.put(
url, headers=headers, json=body, ssl=False, timeout=120 url, headers=headers, json=body, ssl=False, timeout=120
@@ -656,7 +661,7 @@ def get_current_config_file(
headers = {'x-api-key' : key} headers = {'x-api-key' : key}
url = f'https://{node}:8640/cmapi/{get_version()}/node/config' url = f'https://{node}:8640/cmapi/{get_version()}/node/config'
try: try:
r = requests.get(url, verify=False, headers=headers, timeout=5) r = get_traced_session().request('GET', url, verify=False, headers=headers, timeout=5)
r.raise_for_status() r.raise_for_status()
config = r.json()['config'] config = r.json()['config']
except Exception as e: except Exception as e:
@@ -767,14 +772,17 @@ def if_primary_restart(
success = False success = False
while not success and datetime.datetime.now() < endtime: while not success and datetime.datetime.now() < endtime:
try: try:
response = requests.put(url, verify = False, headers = headers, json = body, timeout = 60) response = get_traced_session().request(
'PUT', url, verify=False, headers=headers,
json=body, timeout=60
)
response.raise_for_status() response.raise_for_status()
success = True success = True
except Exception as e: except Exception as e:
logging.warning(f"if_primary_restart(): failed to start the cluster, got {str(e)}") logging.warning(f"if_primary_restart(): failed to start the cluster, got {str(e)}")
time.sleep(10) time.sleep(10)
if not success: if not success:
logging.error(f"if_primary_restart(): failed to start the cluster. Manual intervention is required.") logging.error("if_primary_restart(): failed to start the cluster. Manual intervention is required.")
def get_cej_info(config_root): def get_cej_info(config_root):

View File

@@ -7,6 +7,7 @@ import cherrypy
from cherrypy import _cperror from cherrypy import _cperror
from cmapi_server.constants import CMAPI_LOG_CONF_PATH from cmapi_server.constants import CMAPI_LOG_CONF_PATH
from cmapi_server.tracer import get_tracer
class AddIpFilter(logging.Filter): class AddIpFilter(logging.Filter):
@@ -16,6 +17,28 @@ class AddIpFilter(logging.Filter):
return True return True
def install_trace_record_factory() -> None:
"""Install a LogRecord factory that adds 'trace_params' field.
'trace_params' will be an empty string if there is no active trace/span
(like in MainThread, where there is no incoming requests).
Otherwise it will contain trace parameters.
"""
current_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs): # type: ignore[no-untyped-def]
record = current_factory(*args, **kwargs)
try:
trace_id, span_id, parent_span_id = get_tracer().current_trace_ids()
record.trace_params = (
f" rid={trace_id} sid={span_id} psid={parent_span_id}"
)
except Exception:
record.trace_params = " rid= sid= psid="
return record
logging.setLogRecordFactory(factory)
def custom_cherrypy_error( def custom_cherrypy_error(
self, msg='', context='', severity=logging.INFO, traceback=False self, msg='', context='', severity=logging.INFO, traceback=False
): ):
@@ -119,6 +142,8 @@ def config_cmapi_server_logging():
cherrypy._cplogging.LogManager.access_log_format = ( cherrypy._cplogging.LogManager.access_log_format = (
'{h} ACCESS "{r}" code {s}, bytes {b}, user-agent "{a}"' '{h} ACCESS "{r}" code {s}, bytes {b}, user-agent "{a}"'
) )
# Ensure trace_params is available on every record
install_trace_record_factory()
dict_config(CMAPI_LOG_CONF_PATH) dict_config(CMAPI_LOG_CONF_PATH)

View File

@@ -20,7 +20,9 @@ from cmapi_server.constants import (
CMAPI_CONF_PATH, CMAPI_SINGLE_NODE_XML, DEFAULT_MCS_CONF_PATH, LOCALHOSTS, CMAPI_CONF_PATH, CMAPI_SINGLE_NODE_XML, DEFAULT_MCS_CONF_PATH, LOCALHOSTS,
MCS_DATA_PATH, MCS_DATA_PATH,
) )
from cmapi_server.traced_session import get_traced_session
from cmapi_server.managers.network import NetworkManager from cmapi_server.managers.network import NetworkManager
from cmapi_server.tracer import get_tracer
from mcs_node_control.models.node_config import NodeConfig from mcs_node_control.models.node_config import NodeConfig
@@ -617,7 +619,9 @@ def _rebalance_dbroots(root, test_mode=False):
headers = {'x-api-key': key} headers = {'x-api-key': key}
url = f"https://{node_ip}:8640/cmapi/{version}/node/new_primary" url = f"https://{node_ip}:8640/cmapi/{version}/node/new_primary"
try: try:
r = requests.get(url, verify = False, headers = headers, timeout = 10) r = get_traced_session().request(
'GET', url, verify=False, headers=headers, timeout=10
)
r.raise_for_status() r.raise_for_status()
r = r.json() r = r.json()
is_primary = r['is_primary'] is_primary = r['is_primary']

View File

@@ -0,0 +1,51 @@
"""
CherryPy tool that uses the tracer to start a span for each request.
If traceparent header is present in the request, the tool will continue this trace chain.
Otherwise, it will start a new trace (with a new trace_id).
"""
from typing import Dict
import cherrypy
from cmapi_server.tracer import get_tracer
def _on_request_start() -> None:
"""CherryPy tool hook: extract incoming context and start a SERVER span."""
req = cherrypy.request
tracer = get_tracer()
headers: Dict[str, str] = dict(req.headers or {})
trace_id, parent_span_id = tracer.extract_traceparent(headers)
tracer.set_incoming_context(trace_id, parent_span_id)
span_name = f"{getattr(req, 'method', 'HTTP')} {getattr(req, 'path_info', '/')}"
ctx = tracer.start_as_current_span(span_name, kind="SERVER")
span = ctx.__enter__()
setattr(req, "_trace_span_ctx", ctx)
setattr(req, "_trace_span", span)
# Echo current traceparent to the client
tracer.inject_traceparent(cherrypy.response.headers) # type: ignore[arg-type]
def _on_request_end() -> None:
"""CherryPy tool hook: end the SERVER span started at request start."""
req = cherrypy.request
ctx = getattr(req, "_trace_span_ctx", None)
if ctx is not None:
try:
ctx.__exit__(None, None, None)
finally:
setattr(req, "_trace_span_ctx", None)
setattr(req, "_trace_span", None)
def register_tracing_tools() -> None:
"""Register CherryPy tools for request tracing."""
cherrypy.tools.trace = cherrypy.Tool("on_start_resource", _on_request_start, priority=10)
cherrypy.tools.trace_end = cherrypy.Tool("on_end_resource", _on_request_end, priority=80)

View File

@@ -0,0 +1,43 @@
"""Async sibling of TracedSession"""
from typing import Any
import aiohttp
from cmapi_server.tracer import get_tracer
class TracedAsyncSession(aiohttp.ClientSession):
async def _request(
self, method: str, str_or_url: Any, *args: Any, **kwargs: Any
) -> aiohttp.ClientResponse:
tracer = get_tracer()
headers = kwargs.get("headers") or {}
if headers is None:
headers = {}
kwargs["headers"] = headers
url_text = str(str_or_url)
span_name = f"HTTP {method} {url_text}"
with tracer.start_as_current_span(span_name, kind="CLIENT") as span:
span.set_attribute("http.method", method)
span.set_attribute("http.url", url_text)
try:
tracer.inject_traceparent(headers)
except Exception:
pass
try:
response = await super()._request(method, str_or_url, *args, **kwargs)
except Exception as exc:
span.set_status("ERROR", str(exc))
raise
else:
span.set_attribute("http.status_code", response.status)
return response
def create_traced_async_session(**kwargs: Any) -> TracedAsyncSession:
return TracedAsyncSession(**kwargs)

View File

@@ -0,0 +1,52 @@
"""Our own customized requests.Session that automatically traces outbound HTTP calls.
Creates a CLIENT span per outbound HTTP request, injects traceparent,
records method/url/status, and closes the span when the request finishes.
"""
from typing import Any, Optional
import requests
from cmapi_server.tracer import get_tracer
class TracedSession(requests.Session):
"""requests.Session that automatically traces outbound HTTP calls."""
def request(self, method: str, url: str, *args: Any, **kwargs: Any) -> requests.Response:
tracer = get_tracer()
headers = kwargs.get("headers") or {}
if headers is None:
headers = {}
kwargs["headers"] = headers
span_name = f"HTTP {method} {url}"
with tracer.start_as_current_span(span_name, kind="CLIENT") as span:
span.set_attribute("http.method", method)
span.set_attribute("http.url", url)
tracer.inject_traceparent(headers)
try:
response = super().request(method, url, *args, **kwargs)
except Exception as exc:
span.set_status("ERROR", str(exc))
raise
else:
# Record status code
span.set_attribute("http.status_code", response.status_code)
return response
_default_session: Optional[TracedSession] = None
def get_traced_session() -> TracedSession:
"""Return a process-wide TracedSession singleton."""
global _default_session
if _default_session is None:
_default_session = TracedSession()
return _default_session

View File

@@ -0,0 +1,210 @@
"""Support distributed request tracing via W3C Trace Context.
See https://www.w3.org/TR/trace-context/ for the official spec.
There are 3 and a half main components:
- trace_id: a unique identifier for a trace.
It is a 32-hex string, passed in the outbound HTTP requests headers, so that we can
trace the request chain through the system.
- span_id: a unique identifier for a span (the current operation within a trace chain).
It is a 16-hex string. For example, when we we receive a request to add a host, the addition
of the host is a separate span within the request chain.
- parent_span_id: a unique identifier for the parent span of the current span.
Continuing the example above, when we add a host, first we start a transaction,
then we add the host. If we are already adding a host, then creation of the transaction
is the parent span of the current span.
- traceparent: a header that combines trace_id and span_id in one value.
It has the format "00-<trace_id>-<span_id>-<flags>".
A system that calls CMAPI can pass the traceparent header in the request, so that we can pass
the trace_id through the system, changing span_id as we enter new sub-operations.
We can reconstruct the trace tree from the set of the logged traceparent attributes,
representing how the request was processed, which nodes were involved,
how much time did each op take, etc.
This module implements a tracer class that creates spans, injects/extracts traceparent headers.
It uses contextvars to store the trace/span/parent_span ids and start time for each context.
"""
from __future__ import annotations
import contextvars
import logging
import os
import time
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
# Contextvars containing trace/span/parent_span ids and start time for this thread
# (contextvars are something like TLS, they contain variables that are local to the context)
_current_trace_id = contextvars.ContextVar[str]("trace_id", default="")
_current_span_id = contextvars.ContextVar[str]("span_id", default="")
_current_parent_span_id = contextvars.ContextVar[str]("parent_span_id", default="")
_current_span_start_ns = contextvars.ContextVar[int]("span_start_ns", default=0)
def _rand_16_hex() -> str:
# 16 hex bytes (32 hex chars)
return os.urandom(16).hex()
def _rand_8_hex() -> str:
# 8 hex bytes (16 hex chars)
return os.urandom(8).hex()
def format_traceparent(trace_id: str, span_id: str, flags: str = "01") -> str:
"""W3C traceparent: version 00"""
# version(2)-trace_id(32)-span_id(16)-flags(2)
return f"00-{trace_id}-{span_id}-{flags}"
def parse_traceparent(header: str) -> Optional[tuple[str, str, str]]:
"""Return (trace_id, span_id, flags) or None if invalid."""
try:
parts = header.strip().split("-")
if len(parts) != 4 or parts[0] != "00":
logger.error(f"Invalid traceparent: {header}")
return None
trace_id, span_id, flags = parts[1], parts[2], parts[3]
if len(trace_id) != 32 or len(span_id) != 16 or len(flags) != 2:
return None
# W3C: all zero trace_id/span_id are invalid
if set(trace_id) == {"0"} or set(span_id) == {"0"}:
return None
return trace_id, span_id, flags
except Exception:
logger.error(f"Failed to parse traceparent: {header}")
return None
@dataclass
class TraceSpan:
"""Lightweight span handle; keeps attributes in memory (for logging only)."""
name: str
kind: str # "SERVER" | "CLIENT" | "INTERNAL"
start_ns: int
trace_id: str
span_id: str
parent_span_id: str
attributes: Dict[str, Any]
def set_attribute(self, key: str, value: Any) -> None:
self.attributes[key] = value
def add_event(self, name: str, **attrs: Any) -> None:
# For simplicity we just log events immediately
logger.info(
"event name=%s trace_id=%s span_id=%s attrs=%s",
name, self.trace_id, self.span_id, attrs
)
def set_status(self, code: str, description: str = "") -> None:
self.attributes["status.code"] = code
if description:
self.attributes["status.description"] = description
def record_exception(self, exc: BaseException) -> None:
self.add_event("exception", type=type(exc).__name__, msg=str(exc))
class Tracer:
"""Encapsulates everything related to tracing (span creation, logging, etc)"""
@contextmanager
def start_as_current_span(self, name: str, kind: str = "INTERNAL") -> Iterator[TraceSpan]:
trace_id = _current_trace_id.get() or _rand_16_hex()
parent_span = _current_span_id.get()
new_span_id = _rand_8_hex()
# Push new context
tok_tid = _current_trace_id.set(trace_id)
tok_sid = _current_span_id.set(new_span_id)
tok_pid = _current_parent_span_id.set(parent_span)
tok_start = _current_span_start_ns.set(time.time_ns())
span = TraceSpan(
name=name,
kind=kind,
start_ns=_current_span_start_ns.get(),
trace_id=trace_id,
span_id=new_span_id,
parent_span_id=parent_span,
attributes={"span.kind": kind, "span.name": name},
)
try:
logger.info(
"span_begin name=%s kind=%s trace_id=%s span_id=%s parent_span_id=%s attrs=%s",
span.name, span.kind, span.trace_id, span.span_id, span.parent_span_id, span.attributes
)
yield span
except BaseException as exc:
span.record_exception(exc)
span.set_status("ERROR", str(exc))
raise
finally:
# Pop the span from the context (restore parent context)
duration_ms = (time.time_ns() - span.start_ns) / 1_000_000
logger.info(
"span_end name=%s kind=%s trace_id=%s span_id=%s parent_span_id=%s duration_ms=%.3f attrs=%s",
span.name, span.kind, span.trace_id, span.span_id, span.parent_span_id, duration_ms, span.attributes
)
# Restore previous context
_current_span_start_ns.reset(tok_start)
_current_parent_span_id.reset(tok_pid)
_current_span_id.reset(tok_sid)
_current_trace_id.reset(tok_tid)
def set_incoming_context(
self,
trace_id: Optional[str] = None,
parent_span_id: Optional[str] = None,
) -> None:
"""Seed current context with incoming W3C traceparent values.
Only non-empty values are applied.
"""
if trace_id:
_current_trace_id.set(trace_id)
if parent_span_id:
_current_parent_span_id.set(parent_span_id)
def current_trace_ids(self) -> tuple[str, str, str]:
return _current_trace_id.get(), _current_span_id.get(), _current_parent_span_id.get()
def inject_traceparent(self, headers: Dict[str, str]) -> None:
"""Inject W3C traceparent into outbound headers."""
trace_id, span_id, _ = self.current_trace_ids()
if not trace_id or not span_id:
# If called outside of a span, create a short-lived span just to carry IDs
trace_id = trace_id or _rand_16_hex()
span_id = span_id or _rand_8_hex()
headers["traceparent"] = format_traceparent(trace_id, span_id)
def extract_traceparent(self, headers: Dict[str, str]) -> tuple[str, str]:
"""Extract parent trace/span; returns (trace_id, parent_span_id)."""
raw_traceparent = (headers.get("traceparent")
or headers.get("Traceparent")
or headers.get("TRACEPARENT"))
if not raw_traceparent:
return "", ""
parsed = parse_traceparent(raw_traceparent)
if not parsed:
return "", ""
return parsed[0], parsed[1]
# No incoming context
return "", ""
# Tracer singleton for the process (not thread)
_tracer = Tracer()
def get_tracer() -> Tracer:
return _tracer
class TraceLogFilter(logging.Filter):
"""Inject trace/span ids into LogRecord for formatting."""
def filter(self, record: logging.LogRecord) -> bool:
record.traceID, record.spanID, record.parentSpanID = get_tracer().current_trace_ids()
return True