diff --git a/data/model/_basequery.py b/data/model/_basequery.py index 2ebe536e8..c407f7278 100644 --- a/data/model/_basequery.py +++ b/data/model/_basequery.py @@ -56,7 +56,7 @@ def reduce_as_tree(queries_to_reduce): def get_existing_repository(namespace_name, repository_name, for_update=False, kind_filter=None): query = ( - Repository.select(Repository, Namespace) + Repository.select(Repository, Namespace, can_use_read_replica=True) .join(Namespace, on=(Repository.namespace_user == Namespace.id)) .where(Namespace.username == namespace_name, Repository.name == repository_name) .where(Repository.state != RepositoryState.MARKED_FOR_DELETION) diff --git a/data/model/oci/tag.py b/data/model/oci/tag.py index edf00154a..24318fdf2 100644 --- a/data/model/oci/tag.py +++ b/data/model/oci/tag.py @@ -73,7 +73,7 @@ def get_tag(repository_id, tag_name): The tag is returned joined with its manifest. """ query = ( - Tag.select(Tag, Manifest) + Tag.select(Tag, Manifest, can_use_read_replica=True) .join(Manifest) .where(Tag.repository == repository_id) .where(Tag.name == tag_name) diff --git a/data/model/permission.py b/data/model/permission.py index e47e4e316..a38e78e84 100644 --- a/data/model/permission.py +++ b/data/model/permission.py @@ -76,7 +76,12 @@ def _get_user_repo_permissions( UserThroughTeam = User.alias() base_query = ( - RepositoryPermission.select(RepositoryPermission, Role, Repository, Namespace) + RepositoryPermission.select( + RepositoryPermission, + Role, + Repository, + Namespace, + ) .join(Role) .switch(RepositoryPermission) .join(Repository) diff --git a/data/readreplica.py b/data/readreplica.py index 4a0f69d44..3e48cc0de 100644 --- a/data/readreplica.py +++ b/data/readreplica.py @@ -1,14 +1,17 @@ from __future__ import annotations +import logging import random from collections import namedtuple from contextlib import contextmanager -from typing import Type, TypeVar +from typing import Type, TypeVar, Any from peewee import SENTINEL, Model, ModelSelect, OperationalError, Proxy from data.decorators import is_deprecated_model +logger = logging.getLogger(__name__) + TReadReplicaSupportedModel = TypeVar( "TReadReplicaSupportedModel", bound="ReadReplicaSupportedModel" ) @@ -108,7 +111,7 @@ class ReadReplicaSupportedModel(Model): return cls._read_only_config().is_readonly @classmethod - def _select_database(cls): + def _select_database(cls, can_use_read_replica=False): """ Selects a read replica database if we're configured to support read replicas. @@ -116,6 +119,7 @@ class ReadReplicaSupportedModel(Model): """ # Select the master DB if read replica support is not enabled. read_only_config = cls._read_only_config() + if not read_only_config.read_replicas: return cls._meta.database @@ -127,6 +131,9 @@ class ReadReplicaSupportedModel(Model): if getattr(cls._meta.database._state, _FORCE_MASTER_COUNTER_ATTRIBUTE, 0) > 0: return cls._meta.database + if not can_use_read_replica: + return cls._meta.database + # Otherwise, return a read replica database with auto-retry onto the main database. replicas = read_only_config.read_replicas selected_read_replica = replicas[random.randrange(len(replicas))] @@ -134,12 +141,38 @@ class ReadReplicaSupportedModel(Model): @classmethod def select( - cls: Type[TReadReplicaSupportedModel], *args, **kwargs + cls: Type[TReadReplicaSupportedModel], *args, **kwargs: Any ) -> ModelSelect[TReadReplicaSupportedModel]: + + can_use_read_replica = False + if "can_use_read_replica" in kwargs: + can_use_read_replica = kwargs.get("can_use_read_replica") + del kwargs["can_use_read_replica"] + query = super(ReadReplicaSupportedModel, cls).select(*args, **kwargs) - query._database = cls._select_database() + query._database = cls._select_database(can_use_read_replica) return query + @classmethod + def get( + cls: Type[TReadReplicaSupportedModel], *args, **kwargs: Any + ) -> TReadReplicaSupportedModel: + can_use_read_replica = False + if "can_use_read_replica" in kwargs: + can_use_read_replica = kwargs.get("can_use_read_replica") + del kwargs["can_use_read_replica"] + + sq = cls.select(can_use_read_replica=can_use_read_replica) + if args: + # Handle simple lookup using just the primary key. + if len(args) == 1 and isinstance(args[0], int): + sq = sq.where(cls._meta.primary_key == args[0]) + else: + sq = sq.where(*args) + if kwargs: + sq = sq.filter(**kwargs) + return sq.get() + @classmethod def insert(cls, *args, **kwargs): if is_deprecated_model(cls): diff --git a/data/test/test_readreplica.py b/data/test/test_readreplica.py index 1230c0da4..2d7b85118 100644 --- a/data/test/test_readreplica.py +++ b/data/test/test_readreplica.py @@ -1,13 +1,13 @@ import os import shutil -from test.fixtures import * -from test.testconfig import FakeTransaction import pytest from peewee import OperationalError from data.database import User, configure, db_disallow_replica_use, read_only_config from data.readreplica import ReadOnlyModeException +from test.fixtures import * +from test.testconfig import FakeTransaction @pytest.mark.skipif(bool(os.environ.get("TEST_DATABASE_URI")), reason="Testing requires SQLite") @@ -49,7 +49,7 @@ def test_readreplica(init_db_path, tmpdir_factory): assert not read_only_config.obj.is_readonly assert read_only_config.obj.read_replicas - devtable_user = User.get(username="devtable") + devtable_user = User.get(username="devtable", can_use_read_replica=True) assert devtable_user.username == "devtable" # Force us to hit the master and ensure it doesn't work. @@ -57,8 +57,16 @@ def test_readreplica(init_db_path, tmpdir_factory): with pytest.raises(OperationalError): User.get(username="devtable") + # Explicitly disallow replica use and ensure it doesn't work. + with pytest.raises(OperationalError): + User.get(username="devtable", can_use_read_replica=False) + + # Default to hitting the master and ensure it doesn't work. + with pytest.raises(OperationalError): + User.get(username="devtable") + # Test read replica again. - devtable_user = User.get(username="devtable") + devtable_user = User.get(username="devtable", can_use_read_replica=True) assert devtable_user.username == "devtable" # Try to change some data. This should fail because the primary is broken.