1
0
mirror of https://github.com/quay/quay.git synced 2025-07-30 07:43:13 +03:00

chore: Run mypy as CI job (#1363)

* Run mypy as CI job

* Fix peewee.pyi and configure pyright
This commit is contained in:
Oleg Bulatov
2022-06-13 11:01:17 +02:00
committed by GitHub
parent 46cd48dd9f
commit 5eaf0584db
71 changed files with 278 additions and 237 deletions

View File

@ -64,6 +64,27 @@ jobs:
- name: tox - name: tox
run: tox -e py38-unit run: tox -e py38-unit
types:
name: Types Test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install libgpgme-dev libldap2-dev libsasl2-dev swig
python -m pip install --upgrade pip
pip install -r ./requirements-dev.txt
- name: Check Types
run: make types-test
e2e: e2e:
name: E2E Tests name: E2E Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -7,7 +7,9 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
# Stub out this call so that we can run the external_libraries script # Stub out this call so that we can run the external_libraries script
# without needing the entire codebase. # without needing the entire codebase.
def get_config_provider(*args, **kwargs): def get_config_provider(
config_volume, yaml_filename, py_filename, testing=False, kubernetes=False
):
return None return None

View File

@ -15,16 +15,16 @@ from data import model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_ResourceNeed = namedtuple("resource", ["type", "namespace", "name", "role"]) _ResourceNeed = namedtuple("_ResourceNeed", ["type", "namespace", "name", "role"])
_RepositoryNeed = partial(_ResourceNeed, "repository") _RepositoryNeed = partial(_ResourceNeed, "repository")
_NamespaceWideNeed = namedtuple("namespacewide", ["type", "namespace", "role"]) _NamespaceWideNeed = namedtuple("_NamespaceWideNeed", ["type", "namespace", "role"])
_OrganizationNeed = partial(_NamespaceWideNeed, "organization") _OrganizationNeed = partial(_NamespaceWideNeed, "organization")
_OrganizationRepoNeed = partial(_NamespaceWideNeed, "organizationrepo") _OrganizationRepoNeed = partial(_NamespaceWideNeed, "organizationrepo")
_TeamTypeNeed = namedtuple("teamwideneed", ["type", "orgname", "teamname", "role"]) _TeamTypeNeed = namedtuple("_TeamTypeNeed", ["type", "orgname", "teamname", "role"])
_TeamNeed = partial(_TeamTypeNeed, "orgteam") _TeamNeed = partial(_TeamTypeNeed, "orgteam")
_UserTypeNeed = namedtuple("userspecificneed", ["type", "username", "role"]) _UserTypeNeed = namedtuple("_UserTypeNeed", ["type", "username", "role"])
_UserNeed = partial(_UserTypeNeed, "user") _UserNeed = partial(_UserTypeNeed, "user")
_SuperUserNeed = partial(namedtuple("superuserneed", ["type"]), "superuser") _SuperUserNeed = partial(namedtuple("_SuperUserNeed", ["type"]), "superuser")
REPO_ROLES = [None, "read", "write", "admin"] REPO_ROLES = [None, "read", "write", "admin"]

View File

@ -2,7 +2,7 @@ from collections import namedtuple
import features import features
import re import re
Scope = namedtuple("scope", ["scope", "icon", "dangerous", "title", "description"]) Scope = namedtuple("Scope", ["scope", "icon", "dangerous", "title", "description"])
READ_REPO = Scope( READ_REPO = Scope(

View File

@ -54,7 +54,7 @@ def test_valid_oauth(app):
assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token) assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token)
def test_invalid_user(app): def test_invalid_password(app):
result, kind = validate_credentials("devtable", "somepassword") result, kind = validate_credentials("devtable", "somepassword")
assert kind == CredentialKind.user assert kind == CredentialKind.user
assert result == ValidateResult( assert result == ValidateResult(

View File

@ -4,30 +4,11 @@ Provides helper methods and templates for generating cloud config for running co
Originally from https://github.com/DevTable/container-cloud-config Originally from https://github.com/DevTable/container-cloud-config
""" """
from functools import partial
import base64
import json import json
import os import os
import requests
import logging import logging
try: from urllib.parse import quote as urlquote
# Python 3
from urllib.request import HTTPRedirectHandler, build_opener, install_opener, urlopen, Request
from urllib.error import HTTPError
from urllib.parse import quote as urlquote
except ImportError:
# Python 2
from urllib2 import (
HTTPRedirectHandler,
build_opener,
install_opener,
urlopen,
Request,
HTTPError,
)
from urllib import quote as urlquote
from jinja2 import FileSystemLoader, Environment, StrictUndefined from jinja2 import FileSystemLoader, Environment, StrictUndefined

View File

@ -186,7 +186,7 @@ class Orchestrator(metaclass=ABCMeta):
pass pass
@abstractmethod @abstractmethod
def shutdown(): def shutdown(self):
""" """
This function should shutdown any final resources allocated by the Orchestrator. This function should shutdown any final resources allocated by the Orchestrator.
""" """

View File

@ -461,14 +461,15 @@ class DefaultConfig(ImmutableConfig):
AVATAR_KIND = "local" AVATAR_KIND = "local"
# Custom branding # Custom branding
BRANDING: Dict[str, Optional[str]]
if os.environ.get("RED_HAT_QUAY", False): if os.environ.get("RED_HAT_QUAY", False):
BRANDING: Dict[str, Optional[str]] = { BRANDING = {
"logo": "/static/img/RH_Logo_Quay_Black_UX-horizontal.svg", "logo": "/static/img/RH_Logo_Quay_Black_UX-horizontal.svg",
"footer_img": "/static/img/RedHat.svg", "footer_img": "/static/img/RedHat.svg",
"footer_url": "https://access.redhat.com/documentation/en-us/red_hat_quay/3/", "footer_url": "https://access.redhat.com/documentation/en-us/red_hat_quay/3/",
} }
else: else:
BRANDING: Dict[str, Optional[str]] = { BRANDING = {
"logo": "/static/img/quay-horizontal-color.svg", "logo": "/static/img/quay-horizontal-color.svg",
"footer_img": None, "footer_img": None,
"footer_url": None, "footer_url": None,
@ -485,7 +486,7 @@ class DefaultConfig(ImmutableConfig):
FEATURE_SECURITY_NOTIFICATIONS = False FEATURE_SECURITY_NOTIFICATIONS = False
# The endpoint for the (deprecated) V2 security scanner. # The endpoint for the (deprecated) V2 security scanner.
SECURITY_SCANNER_ENDPOINT = None SECURITY_SCANNER_ENDPOINT: Optional[str] = None
# The endpoint for the V4 security scanner. # The endpoint for the V4 security scanner.
SECURITY_SCANNER_V4_ENDPOINT: Optional[str] = None SECURITY_SCANNER_V4_ENDPOINT: Optional[str] = None

View File

@ -1,9 +1,6 @@
import logging import logging
from util.registry.gzipinputstream import GzipInputStream from data.userfiles import DelegateUserfiles
from flask import send_file, abort
from data.userfiles import DelegateUserfiles, UserfilesHandlers
JSON_MIMETYPE = "application/json" JSON_MIMETYPE = "application/json"

View File

@ -1,4 +1,4 @@
from typing import Dict from typing import Any, Dict
import stripe import stripe
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -384,7 +384,7 @@ class FakeStripe(object):
} }
) )
ACTIVE_CUSTOMERS = {} ACTIVE_CUSTOMERS: Dict[str, Any] = {}
@property @property
def card(self): def card(self):

View File

@ -988,6 +988,7 @@ class RepositorySearchScore(BaseModel):
class RepositorySize(BaseModel): class RepositorySize(BaseModel):
repository = ForeignKeyField(Repository, unique=True) repository = ForeignKeyField(Repository, unique=True)
repository_id: int
size_bytes = BigIntegerField() size_bytes = BigIntegerField()
@ -1519,6 +1520,7 @@ class UploadedBlob(BaseModel):
class BlobUpload(BaseModel): class BlobUpload(BaseModel):
repository = ForeignKeyField(Repository) repository = ForeignKeyField(Repository)
repository_id: int
uuid = CharField(index=True, unique=True) uuid = CharField(index=True, unique=True)
byte_count = BigIntegerField(default=0) byte_count = BigIntegerField(default=0)
# TODO(kleesc): Verify that this is backward compatible with resumablehashlib # TODO(kleesc): Verify that this is backward compatible with resumablehashlib
@ -1798,6 +1800,7 @@ class Manifest(BaseModel):
""" """
repository = ForeignKeyField(Repository) repository = ForeignKeyField(Repository)
repository_id: int
digest = CharField(index=True) digest = CharField(index=True)
media_type = EnumField(MediaType) media_type = EnumField(MediaType)
manifest_bytes = TextField() manifest_bytes = TextField()
@ -1830,6 +1833,7 @@ class Tag(BaseModel):
name = CharField() name = CharField()
repository = ForeignKeyField(Repository) repository = ForeignKeyField(Repository)
repository_id: int
manifest = ForeignKeyField(Manifest, null=True) manifest = ForeignKeyField(Manifest, null=True)
lifetime_start_ms = BigIntegerField(default=get_epoch_timestamp_ms) lifetime_start_ms = BigIntegerField(default=get_epoch_timestamp_ms)
lifetime_end_ms = BigIntegerField(null=True, index=True) lifetime_end_ms = BigIntegerField(null=True, index=True)

View File

@ -6,7 +6,6 @@ import json
from random import SystemRandom from random import SystemRandom
import bcrypt import bcrypt
import rehash
from peewee import TextField, CharField, SmallIntegerField from peewee import TextField, CharField, SmallIntegerField
from data.text import prefix_search from data.text import prefix_search

View File

@ -1,7 +1,9 @@
import pytest
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from freezegun import freeze_time from freezegun import freeze_time
from data import model
from data.logs_model.inmemory_model import InMemoryModel from data.logs_model.inmemory_model import InMemoryModel
from data.logs_model.combined_model import CombinedLogsModel from data.logs_model.combined_model import CombinedLogsModel

View File

@ -1,4 +1,7 @@
import os
import pytest
from datetime import datetime, timedelta, date from datetime import datetime, timedelta, date
from unittest.mock import patch
from data.logs_model.datatypes import AggregatedLogCount from data.logs_model.datatypes import AggregatedLogCount
from data.logs_model.table_logs_model import TableLogsModel from data.logs_model.table_logs_model import TableLogsModel
from data.logs_model.combined_model import CombinedLogsModel from data.logs_model.combined_model import CombinedLogsModel

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os import os
from typing import overload, Optional, Literal
from collections import namedtuple from collections import namedtuple
@ -130,11 +131,29 @@ def _lookup_manifest(repository_id, manifest_digest, allow_dead=False):
return None return None
@overload
def create_manifest(
repository_id: int,
manifest: ManifestInterface | ManifestListInterface,
raise_on_error: Literal[True] = ...,
) -> Manifest:
...
@overload
def create_manifest(
repository_id: int,
manifest: ManifestInterface | ManifestListInterface,
raise_on_error: Literal[False],
) -> Optional[Manifest]:
...
def create_manifest( def create_manifest(
repository_id: int, repository_id: int,
manifest: ManifestInterface | ManifestListInterface, manifest: ManifestInterface | ManifestListInterface,
raise_on_error: bool = True, raise_on_error: bool = True,
) -> Manifest: ) -> Optional[Manifest]:
""" """
Creates a manifest in the database. Creates a manifest in the database.
Does not handle sub manifests in a manifest list/index. Does not handle sub manifests in a manifest list/index.

View File

@ -1,4 +1,5 @@
import json import json
import pytest
from playhouse.test_utils import assert_query_count from playhouse.test_utils import assert_query_count

View File

@ -1,15 +1,12 @@
import pytest
from calendar import timegm from calendar import timegm
from datetime import timedelta, datetime from datetime import timedelta, datetime
from playhouse.test_utils import assert_query_count from playhouse.test_utils import assert_query_count
from data import model
from data.database import ( from data.database import (
Tag, Tag,
ManifestLegacyImage,
TagToRepositoryTag,
TagManifestToManifest,
TagManifest,
Manifest,
Repository, Repository,
) )
from data.model.oci.test.test_oci_manifest import create_manifest_for_testing from data.model.oci.test.test_oci_manifest import create_manifest_for_testing

View File

@ -17,7 +17,7 @@ from data.database import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
search_bucket = namedtuple("SearchBucket", ["delta", "days", "weight"]) search_bucket = namedtuple("search_bucket", ["delta", "days", "weight"])
# Defines the various buckets for search scoring. Each bucket is computed using the given time # Defines the various buckets for search scoring. Each bucket is computed using the given time
# delta from today *minus the previous bucket's time period*. Once all the actions over the # delta from today *minus the previous bucket's time period*. Once all the actions over the

View File

@ -31,7 +31,7 @@ from util.metrics.prometheus import gc_table_rows_deleted, gc_storage_blobs_dele
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_Location = namedtuple("location", ["id", "name"]) _Location = namedtuple("_Location", ["id", "name"])
EMPTY_LAYER_BLOB_DIGEST = "sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4" EMPTY_LAYER_BLOB_DIGEST = "sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"
SPECIAL_BLOB_DIGESTS = set([EMPTY_LAYER_BLOB_DIGEST]) SPECIAL_BLOB_DIGESTS = set([EMPTY_LAYER_BLOB_DIGEST])

View File

@ -2,6 +2,7 @@ import pytest
from mock import patch from mock import patch
from data import model
from data.database import BUILD_PHASE, RepositoryBuildTrigger, RepositoryBuild from data.database import BUILD_PHASE, RepositoryBuildTrigger, RepositoryBuild
from data.model.build import ( from data.model.build import (
update_trigger_disable_status, update_trigger_disable_status,

View File

@ -1,9 +1,11 @@
import pytest
from playhouse.test_utils import assert_query_count from playhouse.test_utils import assert_query_count
from data.database import DEFAULT_PROXY_CACHE_EXPIRATION
from data.model import InvalidOrganizationException from data.model import InvalidOrganizationException
from data.model.proxy_cache import *
from data.model.organization import create_organization from data.model.organization import create_organization
from data.database import ProxyCacheConfig, DEFAULT_PROXY_CACHE_EXPIRATION from data.model.proxy_cache import *
from data.model.user import create_user_noverify
from test.fixtures import * from test.fixtures import *

View File

@ -1,5 +1,6 @@
from data.model import namespacequota from data.model import namespacequota
from data.model.organization import create_organization from data.model.organization import create_organization
from data.model.user import create_user_noverify
from test.fixtures import * from test.fixtures import *

View File

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import pytest
from datetime import datetime, timedelta
from jsonschema import ValidationError from jsonschema import ValidationError
from data.database import RepoMirrorConfig, RepoMirrorStatus, User from data.database import RepoMirrorConfig, RepoMirrorStatus, User
@ -8,9 +10,13 @@ from data.model.repo_mirror import (
create_mirroring_rule, create_mirroring_rule,
get_eligible_mirrors, get_eligible_mirrors,
update_sync_status_to_cancel, update_sync_status_to_cancel,
MAX_SYNC_RETRIES,
release_mirror, release_mirror,
) )
from data.model.user import (
create_robot,
create_user_noverify,
lookup_robot,
)
from test.fixtures import * from test.fixtures import *

View File

@ -1,3 +1,4 @@
import os
from datetime import timedelta from datetime import timedelta
import pytest import pytest

View File

@ -1,7 +1,8 @@
from datetime import date, timedelta from datetime import date, datetime, timedelta
import pytest import pytest
from data import model
from data.database import RepositoryActionCount, RepositorySearchScore from data.database import RepositoryActionCount, RepositorySearchScore
from data.model.repository import create_repository, Repository from data.model.repository import create_repository, Repository
from data.model.repositoryactioncount import update_repository_score, SEARCH_BUCKETS from data.model.repositoryactioncount import update_repository_score, SEARCH_BUCKETS

View File

@ -4,6 +4,7 @@ import pytest
from mock import patch from mock import patch
from data import model
from data.database import EmailConfirmation, User, DeletedNamespace, FederatedLogin from data.database import EmailConfirmation, User, DeletedNamespace, FederatedLogin
from data.model.organization import get_organization from data.model.organization import get_organization
from data.model.notification import create_notification from data.model.notification import create_notification

View File

@ -15,7 +15,7 @@ from image.docker.schema1 import DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE
from util.bytes import Bytes from util.bytes import Bytes
class RepositoryReference(datatype("Repository", [])): class RepositoryReference(datatype("Repository", [])): # type: ignore[misc]
""" """
RepositoryReference is a reference to a repository, passed to registry interface methods. RepositoryReference is a reference to a repository, passed to registry interface methods.
""" """
@ -62,7 +62,7 @@ class RepositoryReference(datatype("Repository", [])):
@property # type: ignore @property # type: ignore
@optionalinput("kind") @optionalinput("kind")
def kind(self, kind): def kind(self, kind): # type: ignore[misc]
""" """
Returns the kind of the repository. Returns the kind of the repository.
""" """
@ -70,7 +70,7 @@ class RepositoryReference(datatype("Repository", [])):
@property # type: ignore @property # type: ignore
@optionalinput("is_public") @optionalinput("is_public")
def is_public(self, is_public): def is_public(self, is_public): # type: ignore[misc]
""" """
Returns whether the repository is public. Returns whether the repository is public.
""" """
@ -99,7 +99,7 @@ class RepositoryReference(datatype("Repository", [])):
@property # type: ignore @property # type: ignore
@optionalinput("namespace_name") @optionalinput("namespace_name")
def namespace_name(self, namespace_name=None): def namespace_name(self, namespace_name=None): # type: ignore[misc]
""" """
Returns the namespace name of this repository. Returns the namespace name of this repository.
""" """
@ -114,7 +114,7 @@ class RepositoryReference(datatype("Repository", [])):
@property # type: ignore @property # type: ignore
@optionalinput("is_free_namespace") @optionalinput("is_free_namespace")
def is_free_namespace(self, is_free_namespace=None): def is_free_namespace(self, is_free_namespace=None): # type: ignore[misc]
""" """
Returns whether the namespace of the repository is on a free plan. Returns whether the namespace of the repository is on a free plan.
""" """
@ -129,7 +129,7 @@ class RepositoryReference(datatype("Repository", [])):
@property # type: ignore @property # type: ignore
@optionalinput("repo_name") @optionalinput("repo_name")
def name(self, repo_name=None): def name(self, repo_name=None): # type: ignore[misc]
""" """
Returns the name of this repository. Returns the name of this repository.
""" """
@ -144,7 +144,7 @@ class RepositoryReference(datatype("Repository", [])):
@property # type: ignore @property # type: ignore
@optionalinput("state") @optionalinput("state")
def state(self, state=None): def state(self, state=None): # type: ignore[misc]
""" """
Return the state of the Repository. Return the state of the Repository.
""" """
@ -158,7 +158,7 @@ class RepositoryReference(datatype("Repository", [])):
return repository.state return repository.state
class Label(datatype("Label", ["key", "value", "uuid", "source_type_name", "media_type_name"])): class Label(datatype("Label", ["key", "value", "uuid", "source_type_name", "media_type_name"])): # type: ignore[misc]
""" """
Label represents a label on a manifest. Label represents a label on a manifest.
""" """
@ -178,7 +178,7 @@ class Label(datatype("Label", ["key", "value", "uuid", "source_type_name", "medi
) )
class ShallowTag(datatype("ShallowTag", ["name"])): class ShallowTag(datatype("ShallowTag", ["name"])): # type: ignore[misc]
""" """
ShallowTag represents a tag in a repository, but only contains basic information. ShallowTag represents a tag in a repository, but only contains basic information.
""" """
@ -199,7 +199,7 @@ class ShallowTag(datatype("ShallowTag", ["name"])):
class Tag( class Tag(
datatype( datatype( # type: ignore[misc]
"Tag", "Tag",
[ [
"name", "name",
@ -243,19 +243,19 @@ class Tag(
now_ms = get_epoch_timestamp_ms() now_ms = get_epoch_timestamp_ms()
return self.lifetime_end_ms is not None and self.lifetime_end_ms <= now_ms return self.lifetime_end_ms is not None and self.lifetime_end_ms <= now_ms
@property @property # type: ignore[misc]
@requiresinput("manifest_row") @requiresinput("manifest_row")
def _manifest_row(self, manifest_row): def _manifest_row(self, manifest_row): # type: ignore[misc]
""" """
Returns the database Manifest object for this tag. Returns the database Manifest object for this tag.
""" """
return manifest_row return manifest_row
@property @property # type: ignore[misc]
@requiresinput("manifest_row") @requiresinput("manifest_row")
@requiresinput("legacy_id_handler") @requiresinput("legacy_id_handler")
@optionalinput("legacy_image_row") @optionalinput("legacy_image_row")
def manifest(self, manifest_row, legacy_id_handler, legacy_image_row): def manifest(self, manifest_row, legacy_id_handler, legacy_image_row): # type: ignore[misc]
""" """
Returns the manifest for this tag. Returns the manifest for this tag.
""" """
@ -265,7 +265,7 @@ class Tag(
@property # type: ignore @property # type: ignore
@requiresinput("repository") @requiresinput("repository")
def repository(self, repository): def repository(self, repository): # type: ignore[misc]
""" """
Returns the repository under which this tag lives. Returns the repository under which this tag lives.
""" """
@ -287,7 +287,7 @@ class Tag(
class Manifest( class Manifest(
datatype( datatype( # type: ignore[misc]
"Manifest", "Manifest",
[ [
"digest", "digest",
@ -352,15 +352,15 @@ class Manifest(
@property # type: ignore @property # type: ignore
@requiresinput("repository") @requiresinput("repository")
def repository(self, repository): def repository(self, repository): # type: ignore[misc]
""" """
Returns the repository under which this manifest lives. Returns the repository under which this manifest lives.
""" """
return repository return repository
@property @property # type: ignore[misc]
@optionalinput("legacy_image_row") @optionalinput("legacy_image_row")
def _legacy_image_row(self, legacy_image_row): def _legacy_image_row(self, legacy_image_row): # type: ignore[misc]
return legacy_image_row return legacy_image_row
@property @property
@ -381,9 +381,9 @@ class Manifest(
# Otherwise, return None. # Otherwise, return None.
return None return None
@property @property # type: ignore[misc]
@requiresinput("legacy_id_handler") @requiresinput("legacy_id_handler")
def legacy_image_root_id(self, legacy_id_handler): def legacy_image_root_id(self, legacy_id_handler): # type: ignore[misc]
""" """
Returns the legacy Docker V1-style image ID for this manifest. Note that an ID will Returns the legacy Docker V1-style image ID for this manifest. Note that an ID will
be returned even if the manifest does not support a legacy image. be returned even if the manifest does not support a legacy image.
@ -394,9 +394,9 @@ class Manifest(
"""Returns the manifest or legacy image as a manifest.""" """Returns the manifest or legacy image as a manifest."""
return self return self
@property @property # type: ignore[misc]
@requiresinput("legacy_id_handler") @requiresinput("legacy_id_handler")
def _legacy_id_handler(self, legacy_id_handler): def _legacy_id_handler(self, legacy_id_handler): # type: ignore[misc]
return legacy_id_handler return legacy_id_handler
def lookup_legacy_image(self, layer_index, retriever): def lookup_legacy_image(self, layer_index, retriever):
@ -555,7 +555,7 @@ class ManifestLayer(namedtuple("ManifestLayer", ["layer_info", "blob"])):
class Blob( class Blob(
datatype("Blob", ["uuid", "digest", "compressed_size", "uncompressed_size", "uploading"]) datatype("Blob", ["uuid", "digest", "compressed_size", "uncompressed_size", "uploading"]) # type: ignore[misc]
): ):
""" """
Blob represents a content-addressable piece of storage. Blob represents a content-addressable piece of storage.
@ -578,7 +578,7 @@ class Blob(
@property # type: ignore @property # type: ignore
@requiresinput("storage_path") @requiresinput("storage_path")
def storage_path(self, storage_path): def storage_path(self, storage_path): # type: ignore[misc]
""" """
Returns the path of this blob in storage. Returns the path of this blob in storage.
""" """
@ -586,7 +586,7 @@ class Blob(
@property # type: ignore @property # type: ignore
@requiresinput("placements") @requiresinput("placements")
def placements(self, placements): def placements(self, placements): # type: ignore[misc]
""" """
Returns all the storage placements at which the Blob can be found. Returns all the storage placements at which the Blob can be found.
""" """
@ -594,7 +594,7 @@ class Blob(
class BlobUpload( class BlobUpload(
datatype( datatype( # type: ignore[misc]
"BlobUpload", "BlobUpload",
[ [
"upload_id", "upload_id",
@ -629,7 +629,7 @@ class BlobUpload(
) )
class LikelyVulnerableTag(datatype("LikelyVulnerableTag", ["layer_id", "name"])): class LikelyVulnerableTag(datatype("LikelyVulnerableTag", ["layer_id", "name"])): # type: ignore[misc]
""" """
LikelyVulnerableTag represents a tag in a repository that is likely vulnerable to a notified LikelyVulnerableTag represents a tag in a repository that is likely vulnerable to a notified
vulnerability. vulnerability.
@ -643,7 +643,7 @@ class LikelyVulnerableTag(datatype("LikelyVulnerableTag", ["layer_id", "name"]))
db_id=tag.id, name=tag.name, layer_id=layer_id, inputs=dict(repository=repository) db_id=tag.id, name=tag.name, layer_id=layer_id, inputs=dict(repository=repository)
) )
@property @property # type: ignore[misc]
@requiresinput("repository") @requiresinput("repository")
def repository(self, repository): def repository(self, repository): # type: ignore[misc]
return RepositoryReference.for_repo_obj(repository) return RepositoryReference.for_repo_obj(repository)

View File

@ -498,7 +498,7 @@ class ProxyModel(OCIModel):
return super().get_repo_blob_by_digest(repository_ref, blob_digest, include_placements) return super().get_repo_blob_by_digest(repository_ref, blob_digest, include_placements)
def _download_blob(self, repo_ref: RepositoryReference, digest: str) -> int: def _download_blob(self, repo_ref: RepositoryReference, digest: str) -> None:
""" """
Download blob from upstream registry and perform a monolitic upload to Download blob from upstream registry and perform a monolitic upload to
Quay's own storage. Quay's own storage.

View File

@ -2,6 +2,7 @@
import hashlib import hashlib
import json import json
import os
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta

View File

@ -50,7 +50,7 @@ logger = logging.getLogger(__name__)
DEFAULT_SECURITY_SCANNER_V4_REINDEX_THRESHOLD = 86400 # 1 day DEFAULT_SECURITY_SCANNER_V4_REINDEX_THRESHOLD = 86400 # 1 day
IndexReportState = namedtuple("IndexReportState", ["Index_Finished", "Index_Error"])( IndexReportState = namedtuple("IndexReportState", ["Index_Finished", "Index_Error"])( # type: ignore[call-arg]
"IndexFinished", "IndexError" "IndexFinished", "IndexError"
) )

View File

@ -15,29 +15,25 @@ from data.secscan_model.datatypes import (
) )
from data.database import ( from data.database import (
Manifest, Manifest,
Repository,
ManifestSecurityStatus, ManifestSecurityStatus,
IndexStatus, IndexStatus,
IndexerVersion, IndexerVersion,
User,
ManifestBlob, ManifestBlob,
db_transaction, db_transaction,
MediaType, MediaType,
) )
from data.registry_model.datatypes import Manifest as ManifestDataType
from data.registry_model import registry_model from data.registry_model import registry_model
from util.secscan.v4.api import APIRequestFailure from util.secscan.v4.api import APIRequestFailure
from util.canonicaljson import canonicalize
from image.docker.schema2 import DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE from image.docker.schema2 import DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE
from test.fixtures import * from test.fixtures import *
from app import app, instance_keys, storage from app import app as application, instance_keys, storage
@pytest.fixture() @pytest.fixture()
def set_secscan_config(): def set_secscan_config():
app.config["SECURITY_SCANNER_V4_ENDPOINT"] = "http://clairv4:6060" application.config["SECURITY_SCANNER_V4_ENDPOINT"] = "http://clairv4:6060"
def test_load_security_information_queued(initialized_db, set_secscan_config): def test_load_security_information_queued(initialized_db, set_secscan_config):
@ -45,7 +41,7 @@ def test_load_security_information_queued(initialized_db, set_secscan_config):
tag = registry_model.get_repo_tag(repository_ref, "latest") tag = registry_model.get_repo_tag(repository_ref, "latest")
manifest = registry_model.get_manifest_for_tag(tag) manifest = registry_model.get_manifest_for_tag(tag)
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
assert secscan.load_security_information(manifest).status == ScanLookupStatus.NOT_YET_INDEXED assert secscan.load_security_information(manifest).status == ScanLookupStatus.NOT_YET_INDEXED
@ -64,7 +60,7 @@ def test_load_security_information_failed_to_index(initialized_db, set_secscan_c
metadata_json={}, metadata_json={},
) )
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
assert secscan.load_security_information(manifest).status == ScanLookupStatus.FAILED_TO_INDEX assert secscan.load_security_information(manifest).status == ScanLookupStatus.FAILED_TO_INDEX
@ -83,7 +79,7 @@ def test_load_security_information_api_returns_none(initialized_db, set_secscan_
metadata_json={}, metadata_json={},
) )
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.vulnerability_report.return_value = None secscan._secscan_api.vulnerability_report.return_value = None
@ -105,7 +101,7 @@ def test_load_security_information_api_request_failure(initialized_db, set_secsc
metadata_json={}, metadata_json={},
) )
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.vulnerability_report.side_effect = APIRequestFailure() secscan._secscan_api.vulnerability_report.side_effect = APIRequestFailure()
@ -128,7 +124,7 @@ def test_load_security_information_success(initialized_db, set_secscan_config):
metadata_json={}, metadata_json={},
) )
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.vulnerability_report.return_value = { secscan._secscan_api.vulnerability_report.return_value = {
"manifest_hash": manifest.digest, "manifest_hash": manifest.digest,
@ -149,7 +145,7 @@ def test_load_security_information_success(initialized_db, set_secscan_config):
def test_perform_indexing_whitelist(initialized_db, set_secscan_config): def test_perform_indexing_whitelist(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.state.return_value = {"state": "abc"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -169,7 +165,7 @@ def test_perform_indexing_whitelist(initialized_db, set_secscan_config):
def test_perform_indexing_failed(initialized_db, set_secscan_config): def test_perform_indexing_failed(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.state.return_value = {"state": "abc"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -186,7 +182,7 @@ def test_perform_indexing_failed(initialized_db, set_secscan_config):
indexer_hash="abc", indexer_hash="abc",
indexer_version=IndexerVersion.V4, indexer_version=IndexerVersion.V4,
last_indexed=datetime.utcnow() last_indexed=datetime.utcnow()
- timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60),
metadata_json={}, metadata_json={},
) )
@ -199,9 +195,9 @@ def test_perform_indexing_failed(initialized_db, set_secscan_config):
def test_perform_indexing_failed_within_reindex_threshold(initialized_db, set_secscan_config): def test_perform_indexing_failed_within_reindex_threshold(initialized_db, set_secscan_config):
app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300 application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.state.return_value = {"state": "abc"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -229,7 +225,7 @@ def test_perform_indexing_failed_within_reindex_threshold(initialized_db, set_se
def test_perform_indexing_needs_reindexing(initialized_db, set_secscan_config): def test_perform_indexing_needs_reindexing(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.state.return_value = {"state": "xyz"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -246,7 +242,7 @@ def test_perform_indexing_needs_reindexing(initialized_db, set_secscan_config):
indexer_hash="abc", indexer_hash="abc",
indexer_version=IndexerVersion.V4, indexer_version=IndexerVersion.V4,
last_indexed=datetime.utcnow() last_indexed=datetime.utcnow()
- timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60),
metadata_json={}, metadata_json={},
) )
@ -259,7 +255,7 @@ def test_perform_indexing_needs_reindexing(initialized_db, set_secscan_config):
def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_secscan_config): def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "new hash"} secscan._secscan_api.state.return_value = {"state": "new hash"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -276,7 +272,7 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_
indexer_hash="old hash", indexer_hash="old hash",
indexer_version=IndexerVersion.V4, indexer_version=IndexerVersion.V4,
last_indexed=datetime.utcnow() last_indexed=datetime.utcnow()
- timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60),
metadata_json={}, metadata_json={},
) )
@ -296,7 +292,7 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_
( (
IndexStatus.MANIFEST_UNSUPPORTED, IndexStatus.MANIFEST_UNSUPPORTED,
{"status": "old hash"}, {"status": "old hash"},
app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60,
True, True,
), ),
# Old hash and recent scan, don't rescan # Old hash and recent scan, don't rescan
@ -305,14 +301,14 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_
( (
IndexStatus.COMPLETED, IndexStatus.COMPLETED,
{"status": "old hash"}, {"status": "old hash"},
app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60,
False, False,
), ),
# New hash and old scan, don't rescan # New hash and old scan, don't rescan
( (
IndexStatus.COMPLETED, IndexStatus.COMPLETED,
{"status": "new hash"}, {"status": "new hash"},
app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60,
False, False,
), ),
# New hash and recent scan, don't rescan # New hash and recent scan, don't rescan
@ -321,14 +317,14 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_
( (
IndexStatus.FAILED, IndexStatus.FAILED,
{"status": "old hash"}, {"status": "old hash"},
app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60,
False, False,
), ),
# New hash and old scan, rescan # New hash and old scan, rescan
( (
IndexStatus.FAILED, IndexStatus.FAILED,
{"status": "new hash"}, {"status": "new hash"},
app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60,
False, False,
), ),
], ],
@ -336,7 +332,7 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_
def test_manifest_iterator( def test_manifest_iterator(
initialized_db, set_secscan_config, index_status, indexer_state, seconds, expect_zero initialized_db, set_secscan_config, index_status, indexer_state, seconds, expect_zero
): ):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
for manifest in Manifest.select(): for manifest in Manifest.select():
with db_transaction(): with db_transaction():
@ -360,7 +356,7 @@ def test_manifest_iterator(
Manifest.select(fn.Min(Manifest.id)).scalar(), Manifest.select(fn.Min(Manifest.id)).scalar(),
Manifest.select(fn.Max(Manifest.id)).scalar(), Manifest.select(fn.Max(Manifest.id)).scalar(),
reindex_threshold=datetime.utcnow() reindex_threshold=datetime.utcnow()
- timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"]), - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"]),
) )
count = 0 count = 0
@ -376,9 +372,9 @@ def test_manifest_iterator(
def test_perform_indexing_needs_reindexing_within_reindex_threshold( def test_perform_indexing_needs_reindexing_within_reindex_threshold(
initialized_db, set_secscan_config initialized_db, set_secscan_config
): ):
app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300 application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.state.return_value = {"state": "xyz"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -406,7 +402,7 @@ def test_perform_indexing_needs_reindexing_within_reindex_threshold(
def test_perform_indexing_api_request_failure_state(initialized_db, set_secscan_config): def test_perform_indexing_api_request_failure_state(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.side_effect = APIRequestFailure() secscan._secscan_api.state.side_effect = APIRequestFailure()
@ -418,7 +414,7 @@ def test_perform_indexing_api_request_failure_state(initialized_db, set_secscan_
def test_perform_indexing_api_request_index_error_response(initialized_db, set_secscan_config): def test_perform_indexing_api_request_index_error_response(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.state.return_value = {"state": "xyz"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -435,7 +431,7 @@ def test_perform_indexing_api_request_index_error_response(initialized_db, set_s
def test_perform_indexing_api_request_non_finished_state(initialized_db, set_secscan_config): def test_perform_indexing_api_request_non_finished_state(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.state.return_value = {"state": "xyz"}
secscan._secscan_api.index.return_value = ( secscan._secscan_api.index.return_value = (
@ -450,7 +446,7 @@ def test_perform_indexing_api_request_non_finished_state(initialized_db, set_sec
def test_perform_indexing_api_request_failure_index(initialized_db, set_secscan_config): def test_perform_indexing_api_request_failure_index(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.state.return_value = {"state": "abc"}
secscan._secscan_api.index.side_effect = APIRequestFailure() secscan._secscan_api.index.side_effect = APIRequestFailure()
@ -510,7 +506,7 @@ def test_features_for():
def test_perform_indexing_invalid_manifest(initialized_db, set_secscan_config): def test_perform_indexing_invalid_manifest(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
# Delete all ManifestBlob rows to cause the manifests to be invalid. # Delete all ManifestBlob rows to cause the manifests to be invalid.
@ -525,7 +521,7 @@ def test_perform_indexing_invalid_manifest(initialized_db, set_secscan_config):
def test_lookup_notification_page_invalid(initialized_db, set_secscan_config): def test_lookup_notification_page_invalid(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.retrieve_notification_page.return_value = None secscan._secscan_api.retrieve_notification_page.return_value = None
@ -535,7 +531,7 @@ def test_lookup_notification_page_invalid(initialized_db, set_secscan_config):
def test_lookup_notification_page_valid(initialized_db, set_secscan_config): def test_lookup_notification_page_valid(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.retrieve_notification_page.return_value = { secscan._secscan_api.retrieve_notification_page.return_value = {
"notifications": [ "notifications": [
@ -560,7 +556,7 @@ def test_lookup_notification_page_valid(initialized_db, set_secscan_config):
def test_mark_notification_handled(initialized_db, set_secscan_config): def test_mark_notification_handled(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan._secscan_api.delete_notification.return_value = True secscan._secscan_api.delete_notification.return_value = True
@ -568,7 +564,7 @@ def test_mark_notification_handled(initialized_db, set_secscan_config):
def test_process_notification_page(initialized_db, set_secscan_config): def test_process_notification_page(initialized_db, set_secscan_config):
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
results = list( results = list(
secscan.process_notification_page( secscan.process_notification_page(
@ -614,7 +610,7 @@ def test_perform_indexing_manifest_list(initialized_db, set_secscan_config):
media_type=MediaType.get(name=DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE) media_type=MediaType.get(name=DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE)
).execute() ).execute()
secscan = V4SecurityScanner(app, instance_keys, storage) secscan = V4SecurityScanner(application, instance_keys, storage)
secscan._secscan_api = mock.Mock() secscan._secscan_api = mock.Mock()
secscan.perform_indexing_recent_manifests() secscan.perform_indexing_recent_manifests()

View File

@ -1,4 +1,4 @@
from peewee import NodeList, SQL, fn, TextField, Field from peewee import NodeList, SQL, fn, Field
def _escape_wildcard(search_query): def _escape_wildcard(search_query):

View File

@ -87,7 +87,7 @@ class LDAPConnection(object):
class LDAPUsers(FederatedUsers): class LDAPUsers(FederatedUsers):
_LDAPResult = namedtuple("LDAPResult", ["dn", "attrs"]) _LDAPResult = namedtuple("_LDAPResult", ["dn", "attrs"])
def __init__( def __init__(
self, self,

View File

@ -3,7 +3,6 @@ List, create and manage repositories.
""" """
import logging import logging
import datetime
import features import features
from collections import defaultdict from collections import defaultdict

View File

@ -1,4 +1,5 @@
import os import os
import pytest
import time import time
from mock import patch from mock import patch

View File

@ -1,3 +1,4 @@
import os
import pytest import pytest
from playhouse.test_utils import assert_query_count from playhouse.test_utils import assert_query_count

View File

@ -17,7 +17,7 @@ USER_FIELDS: List[str] = [
] ]
class User(namedtuple("User", USER_FIELDS)): class User(namedtuple("User", USER_FIELDS)): # type: ignore[misc]
""" """
User represents a user. User represents a user.
""" """

View File

@ -33,7 +33,7 @@ oauthlogin_csrf_protect = csrf_protect(
OAuthResult = namedtuple( OAuthResult = namedtuple(
"oauthresult", "OAuthResult",
["user_obj", "service_name", "error_message", "register_redirect", "requires_verification"], ["user_obj", "service_name", "error_message", "register_redirect", "requires_verification"],
) )

View File

@ -1,3 +1,4 @@
import pytest
from data import model from data import model
from endpoints.api import api from endpoints.api import api
from endpoints.api.repository import Repository from endpoints.api.repository import Repository

View File

@ -68,7 +68,7 @@ def handle_quota_error(error):
return _format_error_response(QuotaExceeded()) return _format_error_response(QuotaExceeded())
def _format_error_response(error: Exception) -> Response: def _format_error_response(error: V2RegistryException) -> Response:
response = jsonify({"errors": [error.as_dict()]}) response = jsonify({"errors": [error.as_dict()]})
response.status_code = error.http_status_code response.status_code = error.http_status_code
logger.debug("sending response: %s", response.get_data()) logger.debug("sending response: %s", response.get_data())

View File

@ -1,9 +1,10 @@
import base64 import pytest
from flask import url_for from flask import url_for
from app import instance_keys, app as original_app from app import instance_keys, app as original_app
from data.model.user import regenerate_robot_token, get_robot_and_metadata, get_user from data import model
from data.model.user import get_robot_and_metadata, get_user
from endpoints.test.shared import conduct_call, gen_basic_auth from endpoints.test.shared import conduct_call, gen_basic_auth
from util.security.registry_jwt import decode_bearer_token, CLAIM_TUF_ROOTS from util.security.registry_jwt import decode_bearer_token, CLAIM_TUF_ROOTS

View File

@ -1,3 +1,5 @@
from typing import Dict
class FeatureNameValue(object): class FeatureNameValue(object):
def __init__(self, name: str, value: bool): ... def __init__(self, name: str, value: bool): ...
def __str__(self) -> str: ... def __str__(self) -> str: ...
@ -184,3 +186,7 @@ USER_INITIALIZE: FeatureNameValue
EXTENDED_REPOSITORY_NAMES: FeatureNameValue EXTENDED_REPOSITORY_NAMES: FeatureNameValue
QUOTA_MANAGEMENT: FeatureNameValue QUOTA_MANAGEMENT: FeatureNameValue
HELM_OCI_SUPPORT: FeatureNameValue
PROXY_CACHE: FeatureNameValue

View File

@ -267,27 +267,6 @@ def test_get_schema1_manifest_incorrect_history():
manifest.get_schema1_manifest("somenamespace", "somename", "sometag", retriever) manifest.get_schema1_manifest("somenamespace", "somename", "sometag", retriever)
def test_validate_manifest_invalid_config_type():
manifest_bytes = """{
"schemaVersion": 2,
"config": {
"mediaType": "application/some.other.thing",
"digest": "sha256:6bd578ec7d1e7381f63184dfe5fbe7f2f15805ecc4bfd485e286b76b1e796524",
"size": 145
},
"layers": [
{
"mediaType": "application/tar+gzip",
"digest": "sha256:ce879e86a8f71031c0f1ab149a26b000b3b5b8810d8d047f240ef69a6b2516ee",
"size": 2807
}
]
}"""
with pytest.raises(MalformedOCIManifest):
OCIManifest(Bytes.for_string_or_unicode(manifest_bytes))
def test_validate_helm_oci_manifest(): def test_validate_helm_oci_manifest():
manifest_bytes = """{ manifest_bytes = """{
"schemaVersion":2, "schemaVersion":2,

View File

@ -1,8 +1,10 @@
[mypy] [mypy]
python_version = 3.8 python_version = 3.8
mypy_path = mypy_stubs mypy_path = mypy_stubs
# local-dev is excluded until we decide what to do with __init__.py in those packages. exclude = (?x)(
exclude = local-dev/ ^buildman/test/test_buildman\.py$ |
^buildman/buildman_pb/buildman_pb2\.py$ |
^endpoints/api/test/test_security\.py$ )
# Necessary because most current dependencies do not have typing # Necessary because most current dependencies do not have typing
ignore_missing_imports = True ignore_missing_imports = True
@ -12,9 +14,3 @@ ignore_missing_imports = True
warn_redundant_casts = True warn_redundant_casts = True
warn_unused_ignores = True warn_unused_ignores = True
warn_unreachable = True warn_unreachable = True
# The `features` module uses some magic to declare attributes at the module level at runtime.
# mypy cant introspect this so we need to disable type checking on features.X
[mypy-features]
ignore_errors = True

View File

@ -2,7 +2,6 @@ import itertools
import logging import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
from pysqlite2 import dbapi2 as pysq3
from typing import ( from typing import (
Any, Any,
AnyStr, AnyStr,
@ -18,6 +17,7 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Text, Text,
Tuple as TupleT,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -29,9 +29,6 @@ _T = TypeVar("_T")
# Manual Adjustments # Manual Adjustments
SENTINEL = object() SENTINEL = object()
sqlite3 = pysq3
sqlite3 = pysq3
class NullHandler(logging.Handler): class NullHandler(logging.Handler):
def emit(self, record: Any) -> None: ... def emit(self, record: Any) -> None: ...
@ -149,7 +146,7 @@ class Source(Node):
def __init__(self, alias: Optional[Any] = ...) -> None: ... def __init__(self, alias: Optional[Any] = ...) -> None: ...
def alias(self, name: Any) -> None: ... def alias(self, name: Any) -> None: ...
def select(self, *columns: Any): ... def select(self, *columns: Any): ...
def join(self, dest: Any, join_type: Any = ..., on: Optional[Any] = ...): ... def join(self: _T, dest: Any, join_type: Any = ..., on: Optional[Any] = ...) -> _T: ...
def left_outer_join(self, dest: Any, on: Optional[Any] = ...): ... def left_outer_join(self, dest: Any, on: Optional[Any] = ...): ...
def cte( def cte(
self, self,
@ -493,7 +490,7 @@ class DQ(ColumnBase):
def __invert__(self) -> None: ... def __invert__(self) -> None: ...
def clone(self): ... def clone(self): ...
Tuple: Any Tuple: Callable[..., NodeList]
class QualifiedNames(WrappedNode): class QualifiedNames(WrappedNode):
def __sql__(self, ctx: Any): ... def __sql__(self, ctx: Any): ...
@ -529,7 +526,7 @@ class BaseQuery(Node):
def objects(self, constructor: Optional[Any] = ...): ... def objects(self, constructor: Optional[Any] = ...): ...
def __sql__(self, ctx: Context) -> Context: ... def __sql__(self, ctx: Context) -> Context: ...
def sql(self): ... def sql(self): ...
def execute(self, database: Any): ... def execute(self, database: Optional[Any] = ...): ...
def iterator(self, database: Optional[Any] = ...): ... def iterator(self, database: Optional[Any] = ...): ...
def __iter__(self) -> Any: ... def __iter__(self) -> Any: ...
def __getitem__(self, value: Any): ... def __getitem__(self, value: Any): ...
@ -551,8 +548,8 @@ class Query(BaseQuery):
**kwargs: Any, **kwargs: Any,
) -> None: ... ) -> None: ...
def with_cte(self, *cte_list: Any) -> None: ... def with_cte(self, *cte_list: Any) -> None: ...
def where(self, *expressions: Any) -> None: ... def where(self: _T, *expressions: Any) -> _T: ...
def orwhere(self, *expressions: Any) -> None: ... def orwhere(self: _T, *expressions: Any) -> _T: ...
def order_by(self, *values: Any) -> None: ... def order_by(self, *values: Any) -> None: ...
def order_by_extend(self, *values: Any) -> None: ... def order_by_extend(self, *values: Any) -> None: ...
def limit(self, value: Optional[Any] = ...) -> None: ... def limit(self, value: Optional[Any] = ...) -> None: ...
@ -576,19 +573,19 @@ class SelectQuery(Query):
def select_from(self, *columns: Any): ... def select_from(self, *columns: Any): ...
class SelectBase(_HashableSource, Source, SelectQuery): class SelectBase(_HashableSource, Source, SelectQuery):
def peek(self, database: Any, n: int = ...): ... def peek(self, database: Optional[Any] = ..., n: int = ...): ...
def first(self, database: Any, n: int = ...): ... def first(self, database: Optional[Any] = ..., n: int = ...): ...
def scalar(self, database: Any, as_tuple: bool = ...): ... def scalar(self, database: Optional[Any] = ..., as_tuple: bool = ...): ...
def count(self, database: Any, clear_limit: bool = ...): ... def count(self, database: Optional[Any] = ..., clear_limit: bool = ...): ...
def exists(self, database: Any): ... def exists(self, database: Optional[Any] = ...): ...
def get(self, database: Any): ... def get(self, database: Optional[Any] = ...): ...
class CompoundSelectQuery(SelectBase): class CompoundSelectQuery(SelectBase):
lhs: Any = ... lhs: Any = ...
op: Any = ... op: Any = ...
rhs: Any = ... rhs: Any = ...
def __init__(self, lhs: Any, op: Any, rhs: Any) -> None: ... def __init__(self, lhs: Any, op: Any, rhs: Any) -> None: ...
def exists(self, database: Any): ... def exists(self, database: Optional[Any] = ...): ...
def __sql__(self, ctx: Any): ... def __sql__(self, ctx: Any): ...
class Select(SelectBase): class Select(SelectBase):
@ -609,16 +606,16 @@ class Select(SelectBase):
def columns(self, *columns: Any, **kwargs: Any) -> None: ... def columns(self, *columns: Any, **kwargs: Any) -> None: ...
select: Any = ... select: Any = ...
def select_extend(self, *columns: Any) -> None: ... def select_extend(self, *columns: Any) -> None: ...
def from_(self, *sources: Any) -> None: ... def from_(self: _T, *sources: Any) -> _T: ...
def join(self, dest: Any, join_type: Any = ..., on: Optional[Any] = ...) -> None: ... def join(self: _T, dest: Any, join_type: Any = ..., on: Optional[Any] = ...) -> _T: ...
def group_by(self, *columns: Any) -> None: ... def group_by(self: _T, *columns: Any) -> _T: ...
def group_by_extend(self, *values: Any): ... def group_by_extend(self, *values: Any): ...
def having(self, *expressions: Any) -> None: ... def having(self: _T, *expressions: Any) -> _T: ...
def distinct(self, *columns: Any) -> None: ... def distinct(self: _T, *columns: Any) -> _T: ...
def window(self, *windows: Any) -> None: ... def window(self: _T, *windows: Any) -> _T: ...
def for_update( def for_update(
self, for_update: bool = ..., of: Optional[Any] = ..., nowait: Optional[Any] = ... self: _T, for_update: bool = ..., of: Optional[Any] = ..., nowait: Optional[Any] = ...
) -> None: ... ) -> _T: ...
def __sql_selection__(self, ctx: Any, is_subquery: bool = ...): ... def __sql_selection__(self, ctx: Any, is_subquery: bool = ...): ...
def __sql__(self, ctx: Any): ... def __sql__(self, ctx: Any): ...
@ -651,7 +648,7 @@ class Insert(_WriteQuery):
on_conflict: Optional[Any] = ..., on_conflict: Optional[Any] = ...,
**kwargs: Any, **kwargs: Any,
) -> None: ... ) -> None: ...
def where(self, *expressions: Any) -> None: ... def where(self: _T, *expressions: Any) -> _T: ...
def on_conflict_ignore(self, ignore: bool = ...) -> None: ... def on_conflict_ignore(self, ignore: bool = ...) -> None: ...
def on_conflict_replace(self, replace: bool = ...) -> None: ... def on_conflict_replace(self, replace: bool = ...) -> None: ...
def on_conflict(self, *args: Any, **kwargs: Any) -> None: ... def on_conflict(self, *args: Any, **kwargs: Any) -> None: ...
@ -1100,7 +1097,7 @@ class Field(ColumnBase):
sequence: Optional[str] = ..., sequence: Optional[str] = ...,
collation: Optional[str] = ..., collation: Optional[str] = ...,
unindexed: Optional[bool] = ..., unindexed: Optional[bool] = ...,
choices: Optional[Iterable[Tuple[Any, str]]] = ..., choices: Optional[Iterable[TupleT[Any, str]]] = ...,
help_text: Optional[str] = ..., help_text: Optional[str] = ...,
verbose_name: Optional[str] = ..., verbose_name: Optional[str] = ...,
index_type: Optional[str] = ..., index_type: Optional[str] = ...,
@ -1116,7 +1113,7 @@ class Field(ColumnBase):
def db_value(self, value: _T) -> _T: ... def db_value(self, value: _T) -> _T: ...
def python_value(self, value: _T) -> _T: ... def python_value(self, value: _T) -> _T: ...
def to_value(self, value: Any) -> Value: ... def to_value(self, value: Any) -> Value: ...
def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ... def get_sort_key(self, ctx: Context) -> TupleT[int, int]: ...
def __sql__(self, ctx: Context) -> Context: ... def __sql__(self, ctx: Context) -> Context: ...
def get_modifiers(self) -> Any: ... def get_modifiers(self) -> Any: ...
def ddl_datatype(self, ctx: Context) -> SQL: ... def ddl_datatype(self, ctx: Context) -> SQL: ...
@ -1155,7 +1152,7 @@ class _StringField(Field):
def adapt(self, value: AnyStr) -> str: ... def adapt(self, value: AnyStr) -> str: ...
@overload @overload
def adapt(self, value: _T) -> _T: ... def adapt(self, value: _T) -> _T: ...
def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: ... def split(self, sep: Optional[str] = ..., maxsplit: int = ...) -> List[str]: ...
def __add__(self, other: Any) -> StringExpression: ... def __add__(self, other: Any) -> StringExpression: ...
def __radd__(self, other: Any) -> StringExpression: ... def __radd__(self, other: Any) -> StringExpression: ...
@ -1545,7 +1542,9 @@ class Model(Node, metaclass=ModelBase):
) -> ModelInsert: ... ) -> ModelInsert: ...
@overload @overload
@classmethod @classmethod
def insert_many(cls, rows: Iterable[Mapping[str, object]], fields: None) -> ModelInsert: ... def insert_many(
cls, rows: Iterable[Mapping[str, object]], fields: None = ...
) -> ModelInsert: ...
@overload @overload
@classmethod @classmethod
def insert_many(cls, rows: Iterable[tuple], fields: Sequence[Field]) -> ModelInsert: ... def insert_many(cls, rows: Iterable[tuple], fields: Sequence[Field]) -> ModelInsert: ...
@ -1593,7 +1592,7 @@ class Model(Node, metaclass=ModelBase):
@classmethod @classmethod
def get_or_create( def get_or_create(
cls, *, defaults: Mapping[str, object] = ..., **kwargs: object cls, *, defaults: Mapping[str, object] = ..., **kwargs: object
) -> Tuple[Any, bool]: ... ) -> TupleT[Any, bool]: ...
@classmethod @classmethod
def filter(cls, *dq_nodes: DQ, **filters: Any) -> SelectQuery: ... def filter(cls, *dq_nodes: DQ, **filters: Any) -> SelectQuery: ...
def get_id(self) -> Any: ... def get_id(self) -> Any: ...
@ -1605,7 +1604,7 @@ class Model(Node, metaclass=ModelBase):
def dirty_fields(self) -> List[Field]: ... def dirty_fields(self) -> List[Field]: ...
def dependencies( def dependencies(
self, search_nullable: bool = ... self, search_nullable: bool = ...
) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ... ) -> Iterator[TupleT[Union[bool, Node], ForeignKeyField]]: ...
def delete_instance(self: _T, recursive: bool = ..., delete_nullable: bool = ...) -> _T: ... def delete_instance(self: _T, recursive: bool = ..., delete_nullable: bool = ...) -> _T: ...
def __hash__(self) -> int: ... def __hash__(self) -> int: ...
def __eq__(self, other: object) -> bool: ... def __eq__(self, other: object) -> bool: ...
@ -1685,7 +1684,7 @@ class BaseModelSelect(_ModelQueryHelper):
def __iter__(self) -> Any: ... def __iter__(self) -> Any: ...
def prefetch(self, *subqueries: Any): ... def prefetch(self, *subqueries: Any): ...
def get(self, database: Optional[Any] = ...): ... def get(self, database: Optional[Any] = ...): ...
def group_by(self, *columns: Any) -> None: ... def group_by(self, *columns: Any): ...
class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery):
model: Any = ... model: Any = ...
@ -1695,24 +1694,24 @@ class ModelSelect(BaseModelSelect, Select):
model: Any = ... model: Any = ...
def __init__(self, model: Any, fields_or_models: Any, is_default: bool = ...) -> None: ... def __init__(self, model: Any, fields_or_models: Any, is_default: bool = ...) -> None: ...
def clone(self): ... def clone(self): ...
def select(self, *fields_or_models: Any): ... def select(self: _T, *fields_or_models: Any) -> _T: ...
def switch(self, ctx: Optional[Any] = ...): ... def switch(self, ctx: Optional[Any] = ...): ...
def join( def join(
self, self: _T,
dest: Any, dest: Any,
join_type: Any = ..., join_type: Any = ...,
on: Optional[Any] = ..., on: Optional[Any] = ...,
src: Optional[Any] = ..., src: Optional[Any] = ...,
attr: Optional[Any] = ..., attr: Optional[Any] = ...,
) -> None: ... ) -> _T: ...
def join_from( def join_from(
self, self: _T,
src: Any, src: Any,
dest: Any, dest: Any,
join_type: Any = ..., join_type: Any = ...,
on: Optional[Any] = ..., on: Optional[Any] = ...,
attr: Optional[Any] = ..., attr: Optional[Any] = ...,
): ... ) -> _T: ...
def ensure_join(self, lm: Any, rm: Any, on: Optional[Any] = ..., **join_kwargs: Any): ... def ensure_join(self, lm: Any, rm: Any, on: Optional[Any] = ..., **join_kwargs: Any): ...
def convert_dict_to_node(self, qdict: Any): ... def convert_dict_to_node(self, qdict: Any): ...
def filter(self, *args: Any, **kwargs: Any): ... def filter(self, *args: Any, **kwargs: Any): ...

View File

@ -1,6 +1,6 @@
[tool.black] [tool.black]
line-length = 100 line-length = 100
target-version = ['py27'] target-version = ['py38']
[tool.pylint.messages_control] [tool.pylint.messages_control]
disable = "missing-docstring,invalid-name,too-many-locals,too-few-public-methods,too-many-lines" disable = "missing-docstring,invalid-name,too-many-locals,too-few-public-methods,too-many-lines"
@ -17,3 +17,6 @@ branch = true
[tool.coverage.report] [tool.coverage.report]
omit = ['test/**', 'venv/**', '**/test/**'] omit = ['test/**', 'venv/**', '**/test/**']
[tool.pyright]
stubPath = 'mypy_stubs'

View File

@ -11,7 +11,7 @@ httmock==1.3.0
ipdb ipdb
ipython ipython
mock==3.0.5 mock==3.0.5
mypy==0.910 mypy==0.950
moto==2.0.1 moto==2.0.1
parameterized==0.8.1 parameterized==0.8.1
pytest pytest

View File

@ -4,7 +4,7 @@ APScheduler==3.6.3
attrs==19.3.0 attrs==19.3.0
Authlib==1.0.0a1 Authlib==1.0.0a1
aws-sam-translator==1.20.1 aws-sam-translator==1.20.1
azure-core==1.8.0 azure-core==1.23.1
azure-storage-blob==12.4.0 azure-storage-blob==12.4.0
Babel==2.9.1 Babel==2.9.1
bcrypt==3.1.7 bcrypt==3.1.7
@ -40,7 +40,7 @@ futures==3.1.1
geoip2==3.0.0 geoip2==3.0.0
gevent==21.8.0 gevent==21.8.0
greenlet==1.1.2 greenlet==1.1.2
grpcio==1.30.0 grpcio==1.46.3
gunicorn==20.1.0 gunicorn==20.1.0
hashids==1.2.0 hashids==1.2.0
html5lib==1.0.1 html5lib==1.0.1

View File

@ -1,7 +1,7 @@
import logging import logging
import base64 import base64
import urllib.request, urllib.parse, urllib.error
from flask import Flask
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import abort, request from flask import abort, request
from jsonschema import validate, ValidationError from jsonschema import validate, ValidationError
@ -61,7 +61,7 @@ class DownloadProxy(object):
NGINX. NGINX.
""" """
def __init__(self, app, instance_keys): def __init__(self, app: Flask, instance_keys):
self.app = app self.app = app
self.instance_keys = instance_keys self.instance_keys = instance_keys

View File

@ -1,3 +1,4 @@
import os
import pytest import pytest
from contextlib import contextmanager from contextlib import contextmanager

View File

@ -1,11 +1,7 @@
import os import os
import pytest import pytest
import requests
from flask import Flask
from flask_testing import LiveServerTestCase
from data.database import close_db_filter, configure
from storage import Storage from storage import Storage
from util.security.instancekeys import InstanceKeys from util.security.instancekeys import InstanceKeys

View File

@ -1,8 +1,6 @@
import os import os
from cachetools.func import lru_cache
from collections import namedtuple from collections import namedtuple
from datetime import datetime, timedelta
import pytest import pytest
import shutil import shutil
@ -10,7 +8,7 @@ import inspect
from flask import Flask, jsonify from flask import Flask, jsonify
from flask_login import LoginManager from flask_login import LoginManager
from flask_principal import identity_loaded, Permission, Identity, identity_changed, Principal from flask_principal import identity_loaded, Principal
from flask_mail import Mail from flask_mail import Mail
from peewee import SqliteDatabase, InternalError from peewee import SqliteDatabase, InternalError
from mock import patch from mock import patch
@ -20,7 +18,7 @@ from app import app as application
from auth.permissions import on_identity_loaded from auth.permissions import on_identity_loaded
from data import model from data import model
from data.database import close_db_filter, db, configure from data.database import close_db_filter, db, configure
from data.model.user import LoginWrappedDBUser, create_robot, lookup_robot, create_user_noverify from data.model.user import LoginWrappedDBUser
from data.userfiles import Userfiles from data.userfiles import Userfiles
from endpoints.api import api_bp from endpoints.api import api_bp
from endpoints.appr import appr_bp from endpoints.appr import appr_bp
@ -43,6 +41,9 @@ from test.testconfig import FakeTransaction
INIT_DB_PATH = 0 INIT_DB_PATH = 0
__all__ = ["init_db_path", "database_uri", "sqlitedb_file", "appconfig", "initialized_db", "app"]
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def init_db_path(tmpdir_factory): def init_db_path(tmpdir_factory):
""" """

View File

@ -1,4 +1,5 @@
import json import json
from typing import Dict
from io import BytesIO from io import BytesIO
from enum import Enum, unique from enum import Enum, unique
@ -29,7 +30,7 @@ class V1ProtocolSteps(Enum):
class V1Protocol(RegistryProtocol): class V1Protocol(RegistryProtocol):
FAILURE_CODES = { FAILURE_CODES: Dict[Enum, Dict[Failures, int]] = {
V1ProtocolSteps.PUT_IMAGES: { V1ProtocolSteps.PUT_IMAGES: {
Failures.INVALID_AUTHENTICATION: 403, Failures.INVALID_AUTHENTICATION: 403,
Failures.UNAUTHENTICATED: 401, Failures.UNAUTHENTICATED: 401,

View File

@ -1,5 +1,6 @@
import hashlib import hashlib
import json import json
from typing import Dict
from enum import Enum, unique from enum import Enum, unique
@ -49,7 +50,7 @@ class V2ProtocolSteps(Enum):
class V2Protocol(RegistryProtocol): class V2Protocol(RegistryProtocol):
FAILURE_CODES = { FAILURE_CODES: Dict[Enum, Dict[Failures, int]] = {
V2ProtocolSteps.AUTH: { V2ProtocolSteps.AUTH: {
Failures.UNAUTHENTICATED: 401, Failures.UNAUTHENTICATED: 401,
Failures.INVALID_AUTHENTICATION: 401, Failures.INVALID_AUTHENTICATION: 401,

View File

@ -1,5 +1,6 @@
import json import json
import tarfile import tarfile
from typing import Dict
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
@ -104,7 +105,7 @@ class RegistryProtocol(object):
Interface for protocols. Interface for protocols.
""" """
FAILURE_CODES = {} FAILURE_CODES: Dict[Enum, Dict[Failures, int]] = {}
@abstractmethod @abstractmethod
def login(self, session, username, password, scopes, expect_success): def login(self, session, username, password, scopes, expect_success):

View File

@ -3943,7 +3943,7 @@ class TestLogs(ApiTestCase):
json = self.getJsonResponse(UserAggregateLogs) json = self.getJsonResponse(UserAggregateLogs)
assert "aggregated" in json assert "aggregated" in json
def test_org_logs(self): def test_org_aggregate_logs(self):
self.login(ADMIN_ACCESS_USER) self.login(ADMIN_ACCESS_USER)
json = self.getJsonResponse(OrgAggregateLogs, params=dict(orgname=ORGANIZATION)) json = self.getJsonResponse(OrgAggregateLogs, params=dict(orgname=ORGANIZATION))

View File

@ -1,5 +1,6 @@
import base64 import base64
import unittest import unittest
from typing import Optional
from datetime import datetime, timedelta from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
@ -191,7 +192,7 @@ class JWTAuthTestMixin:
Mixin defining all the JWT auth tests. Mixin defining all the JWT auth tests.
""" """
maxDiff = None maxDiff: Optional[int] = None
@property @property
def emails(self): def emails(self):

View File

@ -1,6 +1,7 @@
import json import json
import os import os
import unittest import unittest
from typing import Optional
import requests import requests
@ -289,7 +290,7 @@ def _create_app(requires_email=True):
class KeystoneAuthTestsMixin: class KeystoneAuthTestsMixin:
maxDiff = None maxDiff: Optional[int] = None
@property @property
def emails(self): def emails(self):

View File

@ -34,17 +34,22 @@ def make_custom_sort(orders):
return process return process
SCHEMA_HTML_FILE = "schema.html" def main():
SCHEMA_HTML_FILE = "schema.html"
schema = json.dumps(CONFIG_SCHEMA, sort_keys=True) schema = json.dumps(CONFIG_SCHEMA, sort_keys=True)
schema = json.loads(schema, object_pairs_hook=OrderedDict) schema = json.loads(schema, object_pairs_hook=OrderedDict)
req = sorted(schema["required"]) req = sorted(schema["required"])
custom_sort = make_custom_sort([req]) custom_sort = make_custom_sort([req])
schema = custom_sort(schema) schema = custom_sort(schema)
parsed_items = docsmodel.DocsModel().parse(schema)[1:] parsed_items = docsmodel.DocsModel().parse(schema)[1:]
output = html_output.HtmlOutput().generate_output(parsed_items) output = html_output.HtmlOutput().generate_output(parsed_items)
with open(SCHEMA_HTML_FILE, "wt") as f: with open(SCHEMA_HTML_FILE, "wt") as f:
f.write(output) f.write(output)
if __name__ == "__main__":
main()

View File

@ -1,3 +1,4 @@
import os
import pytest import pytest
from util.config.provider import FileConfigProvider from util.config.provider import FileConfigProvider

View File

@ -34,6 +34,8 @@ def valid_date(s):
if __name__ == "__main__": if __name__ == "__main__":
from cryptography.hazmat.primitives import serialization
parser = argparse.ArgumentParser(description="Generates a preshared key") parser = argparse.ArgumentParser(description="Generates a preshared key")
parser.add_argument("service", help="The service name for which the key is being generated") parser.add_argument("service", help="The service name for which the key is being generated")
parser.add_argument("name", help="The friendly name for the key") parser.add_argument("name", help="The friendly name for the key")

View File

@ -13,7 +13,7 @@ SKOPEO_TIMEOUT_SECONDS = 300
# tags: list of tags or empty list # tags: list of tags or empty list
# stdout: stdout from skopeo subprocess # stdout: stdout from skopeo subprocess
# stderr: stderr from skopeo subprocess # stderr: stderr from skopeo subprocess
SkopeoResults = namedtuple("SkopeoCopyResults", "success tags stdout stderr") SkopeoResults = namedtuple("SkopeoResults", "success tags stdout stderr")
class SkopeoMirror(object): class SkopeoMirror(object):

View File

@ -6,6 +6,7 @@ import os
import jwt import jwt
import base64 import base64
import time import time
from typing import Dict, Callable
from collections import namedtuple from collections import namedtuple
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -124,7 +125,7 @@ class SecurityScannerAPIInterface(object):
Action = namedtuple("Action", ["name", "payload"]) Action = namedtuple("Action", ["name", "payload"])
actions = { actions: Dict[str, Callable[..., Action]] = {
"IndexState": lambda: Action("IndexState", ("GET", "/indexer/api/v1/index_state", None)), "IndexState": lambda: Action("IndexState", ("GET", "/indexer/api/v1/index_state", None)),
"Index": lambda manifest: Action("Index", ("POST", "/indexer/api/v1/index_report", manifest)), "Index": lambda manifest: Action("Index", ("POST", "/indexer/api/v1/index_report", manifest)),
"GetIndexReport": lambda manifest_hash: Action( "GetIndexReport": lambda manifest_hash: Action(

View File

@ -1,5 +1,6 @@
import pytest import pytest
import requests import requests
from typing import Dict, Any
from mock import mock, patch from mock import mock, patch
from flask import Flask from flask import Flask
@ -9,7 +10,7 @@ from test.fixtures import init_db_path
from util.tufmetadata import api from util.tufmetadata import api
valid_response = { valid_response: Dict[str, Any] = {
"signed": { "signed": {
"type": "Targets", "type": "Targets",
"delegations": { "delegations": {
@ -95,7 +96,7 @@ valid_targets_with_delegation = {
} }
valid_delegation = { valid_delegation: Dict[str, Any] = {
"signed": { "signed": {
"_type": "Targets", "_type": "Targets",
"delegations": {"keys": {}, "roles": []}, "delegations": {"keys": {}, "roles": []},

View File

@ -52,10 +52,11 @@ def create_gunicorn_worker():
if __name__ == "__main__": if __name__ == "__main__":
if os.getenv("PYDEV_DEBUG", None): pydev_debug = os.getenv("PYDEV_DEBUG", None)
if pydev_debug:
import pydevd_pycharm import pydevd_pycharm
host, port = os.getenv("PYDEV_DEBUG").split(":") host, port = pydev_debug.split(":")
pydevd_pycharm.settrace( pydevd_pycharm.settrace(
host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False
) )

View File

@ -2,6 +2,7 @@ import pytest
import mock import mock
import json import json
from functools import wraps from functools import wraps
from unittest.mock import patch
from app import storage from app import storage
from data.registry_model.blobuploader import upload_blob, BlobUploadSettings from data.registry_model.blobuploader import upload_blob, BlobUploadSettings

View File

@ -67,10 +67,11 @@ def create_gunicorn_worker():
if __name__ == "__main__": if __name__ == "__main__":
if os.getenv("PYDEV_DEBUG", None): pydev_debug = os.getenv("PYDEV_DEBUG", None)
if pydev_debug:
import pydevd_pycharm import pydevd_pycharm
host, port = os.getenv("PYDEV_DEBUG").split(":") host, port = pydev_debug.split(":")
pydevd_pycharm.settrace( pydevd_pycharm.settrace(
host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False
) )

View File

@ -1,7 +1,6 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mock import patch from mock import patch
from data import model
from workers.servicekeyworker.servicekeyworker import ServiceKeyWorker from workers.servicekeyworker.servicekeyworker import ServiceKeyWorker
from util.morecollections import AttrDict from util.morecollections import AttrDict

View File

@ -1,5 +1,6 @@
import json import json
import os import os
import pytest
from datetime import datetime, timedelta from datetime import datetime, timedelta

View File

@ -1,5 +1,6 @@
import os.path import os.path
import pytest
from unittest.mock import patch
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app import storage from app import storage

View File

@ -1,4 +1,5 @@
import json import json
import os
import pytest import pytest
from urllib.parse import urlparse from urllib.parse import urlparse