diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 6ad9a6084..0ed21dee5 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -64,6 +64,27 @@ jobs: - name: tox run: tox -e py38-unit + types: + name: Types Test + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v1 + with: + python-version: 3.8 + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install libgpgme-dev libldap2-dev libsasl2-dev swig + python -m pip install --upgrade pip + pip install -r ./requirements-dev.txt + + - name: Check Types + run: make types-test + e2e: name: E2E Tests runs-on: ubuntu-latest diff --git a/_init.py b/_init.py index f61ad16df..76838511e 100644 --- a/_init.py +++ b/_init.py @@ -7,7 +7,9 @@ try: except ModuleNotFoundError: # Stub out this call so that we can run the external_libraries script # without needing the entire codebase. - def get_config_provider(*args, **kwargs): + def get_config_provider( + config_volume, yaml_filename, py_filename, testing=False, kubernetes=False + ): return None diff --git a/auth/permissions.py b/auth/permissions.py index d0c89c38f..3a278d6ca 100644 --- a/auth/permissions.py +++ b/auth/permissions.py @@ -15,16 +15,16 @@ from data import model logger = logging.getLogger(__name__) -_ResourceNeed = namedtuple("resource", ["type", "namespace", "name", "role"]) +_ResourceNeed = namedtuple("_ResourceNeed", ["type", "namespace", "name", "role"]) _RepositoryNeed = partial(_ResourceNeed, "repository") -_NamespaceWideNeed = namedtuple("namespacewide", ["type", "namespace", "role"]) +_NamespaceWideNeed = namedtuple("_NamespaceWideNeed", ["type", "namespace", "role"]) _OrganizationNeed = partial(_NamespaceWideNeed, "organization") _OrganizationRepoNeed = partial(_NamespaceWideNeed, "organizationrepo") -_TeamTypeNeed = namedtuple("teamwideneed", ["type", "orgname", "teamname", "role"]) +_TeamTypeNeed = namedtuple("_TeamTypeNeed", ["type", "orgname", "teamname", "role"]) _TeamNeed = partial(_TeamTypeNeed, "orgteam") -_UserTypeNeed = namedtuple("userspecificneed", ["type", "username", "role"]) +_UserTypeNeed = namedtuple("_UserTypeNeed", ["type", "username", "role"]) _UserNeed = partial(_UserTypeNeed, "user") -_SuperUserNeed = partial(namedtuple("superuserneed", ["type"]), "superuser") +_SuperUserNeed = partial(namedtuple("_SuperUserNeed", ["type"]), "superuser") REPO_ROLES = [None, "read", "write", "admin"] diff --git a/auth/scopes.py b/auth/scopes.py index c631e3e07..f5bbe2b93 100644 --- a/auth/scopes.py +++ b/auth/scopes.py @@ -2,7 +2,7 @@ from collections import namedtuple import features import re -Scope = namedtuple("scope", ["scope", "icon", "dangerous", "title", "description"]) +Scope = namedtuple("Scope", ["scope", "icon", "dangerous", "title", "description"]) READ_REPO = Scope( diff --git a/auth/test/test_credentials.py b/auth/test/test_credentials.py index a22cfdc6d..55c269914 100644 --- a/auth/test/test_credentials.py +++ b/auth/test/test_credentials.py @@ -54,7 +54,7 @@ def test_valid_oauth(app): assert result == ValidateResult(AuthKind.oauth, oauthtoken=oauth_token) -def test_invalid_user(app): +def test_invalid_password(app): result, kind = validate_credentials("devtable", "somepassword") assert kind == CredentialKind.user assert result == ValidateResult( diff --git a/buildman/container_cloud_config.py b/buildman/container_cloud_config.py index 0df114b46..2f882fe55 100644 --- a/buildman/container_cloud_config.py +++ b/buildman/container_cloud_config.py @@ -4,30 +4,11 @@ Provides helper methods and templates for generating cloud config for running co Originally from https://github.com/DevTable/container-cloud-config """ -from functools import partial - -import base64 import json import os -import requests import logging -try: - # Python 3 - from urllib.request import HTTPRedirectHandler, build_opener, install_opener, urlopen, Request - from urllib.error import HTTPError - from urllib.parse import quote as urlquote -except ImportError: - # Python 2 - from urllib2 import ( - HTTPRedirectHandler, - build_opener, - install_opener, - urlopen, - Request, - HTTPError, - ) - from urllib import quote as urlquote +from urllib.parse import quote as urlquote from jinja2 import FileSystemLoader, Environment, StrictUndefined diff --git a/buildman/orchestrator.py b/buildman/orchestrator.py index dd7884003..2be7f7743 100644 --- a/buildman/orchestrator.py +++ b/buildman/orchestrator.py @@ -186,7 +186,7 @@ class Orchestrator(metaclass=ABCMeta): pass @abstractmethod - def shutdown(): + def shutdown(self): """ This function should shutdown any final resources allocated by the Orchestrator. """ diff --git a/config.py b/config.py index 82a132256..dce31bdae 100644 --- a/config.py +++ b/config.py @@ -461,14 +461,15 @@ class DefaultConfig(ImmutableConfig): AVATAR_KIND = "local" # Custom branding + BRANDING: Dict[str, Optional[str]] if os.environ.get("RED_HAT_QUAY", False): - BRANDING: Dict[str, Optional[str]] = { + BRANDING = { "logo": "/static/img/RH_Logo_Quay_Black_UX-horizontal.svg", "footer_img": "/static/img/RedHat.svg", "footer_url": "https://access.redhat.com/documentation/en-us/red_hat_quay/3/", } else: - BRANDING: Dict[str, Optional[str]] = { + BRANDING = { "logo": "/static/img/quay-horizontal-color.svg", "footer_img": None, "footer_url": None, @@ -485,7 +486,7 @@ class DefaultConfig(ImmutableConfig): FEATURE_SECURITY_NOTIFICATIONS = False # The endpoint for the (deprecated) V2 security scanner. - SECURITY_SCANNER_ENDPOINT = None + SECURITY_SCANNER_ENDPOINT: Optional[str] = None # The endpoint for the V4 security scanner. SECURITY_SCANNER_V4_ENDPOINT: Optional[str] = None diff --git a/data/archivedlogs.py b/data/archivedlogs.py index d2b3fbabb..d17d521e8 100644 --- a/data/archivedlogs.py +++ b/data/archivedlogs.py @@ -1,9 +1,6 @@ import logging -from util.registry.gzipinputstream import GzipInputStream -from flask import send_file, abort - -from data.userfiles import DelegateUserfiles, UserfilesHandlers +from data.userfiles import DelegateUserfiles JSON_MIMETYPE = "application/json" diff --git a/data/billing.py b/data/billing.py index 0a42bcc4c..01e0a4240 100644 --- a/data/billing.py +++ b/data/billing.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict import stripe from datetime import datetime, timedelta @@ -384,7 +384,7 @@ class FakeStripe(object): } ) - ACTIVE_CUSTOMERS = {} + ACTIVE_CUSTOMERS: Dict[str, Any] = {} @property def card(self): diff --git a/data/database.py b/data/database.py index 7910a9760..549f0efda 100644 --- a/data/database.py +++ b/data/database.py @@ -988,6 +988,7 @@ class RepositorySearchScore(BaseModel): class RepositorySize(BaseModel): repository = ForeignKeyField(Repository, unique=True) + repository_id: int size_bytes = BigIntegerField() @@ -1519,6 +1520,7 @@ class UploadedBlob(BaseModel): class BlobUpload(BaseModel): repository = ForeignKeyField(Repository) + repository_id: int uuid = CharField(index=True, unique=True) byte_count = BigIntegerField(default=0) # TODO(kleesc): Verify that this is backward compatible with resumablehashlib @@ -1798,6 +1800,7 @@ class Manifest(BaseModel): """ repository = ForeignKeyField(Repository) + repository_id: int digest = CharField(index=True) media_type = EnumField(MediaType) manifest_bytes = TextField() @@ -1830,6 +1833,7 @@ class Tag(BaseModel): name = CharField() repository = ForeignKeyField(Repository) + repository_id: int manifest = ForeignKeyField(Manifest, null=True) lifetime_start_ms = BigIntegerField(default=get_epoch_timestamp_ms) lifetime_end_ms = BigIntegerField(null=True, index=True) diff --git a/data/fields.py b/data/fields.py index b867c2857..16d585706 100644 --- a/data/fields.py +++ b/data/fields.py @@ -6,7 +6,6 @@ import json from random import SystemRandom import bcrypt -import rehash from peewee import TextField, CharField, SmallIntegerField from data.text import prefix_search diff --git a/data/logs_model/test/test_combined_model.py b/data/logs_model/test/test_combined_model.py index 4dda7b490..aaa3bc351 100644 --- a/data/logs_model/test/test_combined_model.py +++ b/data/logs_model/test/test_combined_model.py @@ -1,7 +1,9 @@ +import pytest from datetime import date, datetime, timedelta from freezegun import freeze_time +from data import model from data.logs_model.inmemory_model import InMemoryModel from data.logs_model.combined_model import CombinedLogsModel diff --git a/data/logs_model/test/test_logs_interface.py b/data/logs_model/test/test_logs_interface.py index 442779177..0c92de197 100644 --- a/data/logs_model/test/test_logs_interface.py +++ b/data/logs_model/test/test_logs_interface.py @@ -1,4 +1,7 @@ +import os +import pytest from datetime import datetime, timedelta, date +from unittest.mock import patch from data.logs_model.datatypes import AggregatedLogCount from data.logs_model.table_logs_model import TableLogsModel from data.logs_model.combined_model import CombinedLogsModel diff --git a/data/model/oci/manifest.py b/data/model/oci/manifest.py index 1f789ee72..beedacaf8 100644 --- a/data/model/oci/manifest.py +++ b/data/model/oci/manifest.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging import os +from typing import overload, Optional, Literal from collections import namedtuple @@ -130,11 +131,29 @@ def _lookup_manifest(repository_id, manifest_digest, allow_dead=False): return None +@overload +def create_manifest( + repository_id: int, + manifest: ManifestInterface | ManifestListInterface, + raise_on_error: Literal[True] = ..., +) -> Manifest: + ... + + +@overload +def create_manifest( + repository_id: int, + manifest: ManifestInterface | ManifestListInterface, + raise_on_error: Literal[False], +) -> Optional[Manifest]: + ... + + def create_manifest( repository_id: int, manifest: ManifestInterface | ManifestListInterface, raise_on_error: bool = True, -) -> Manifest: +) -> Optional[Manifest]: """ Creates a manifest in the database. Does not handle sub manifests in a manifest list/index. diff --git a/data/model/oci/test/test_oci_manifest.py b/data/model/oci/test/test_oci_manifest.py index 0aa6d6c1c..2d3682ca6 100644 --- a/data/model/oci/test/test_oci_manifest.py +++ b/data/model/oci/test/test_oci_manifest.py @@ -1,4 +1,5 @@ import json +import pytest from playhouse.test_utils import assert_query_count diff --git a/data/model/oci/test/test_oci_tag.py b/data/model/oci/test/test_oci_tag.py index 84f50b500..1741fc49f 100644 --- a/data/model/oci/test/test_oci_tag.py +++ b/data/model/oci/test/test_oci_tag.py @@ -1,15 +1,12 @@ +import pytest from calendar import timegm from datetime import timedelta, datetime from playhouse.test_utils import assert_query_count +from data import model from data.database import ( Tag, - ManifestLegacyImage, - TagToRepositoryTag, - TagManifestToManifest, - TagManifest, - Manifest, Repository, ) from data.model.oci.test.test_oci_manifest import create_manifest_for_testing diff --git a/data/model/repositoryactioncount.py b/data/model/repositoryactioncount.py index d08a2005f..315847a4d 100644 --- a/data/model/repositoryactioncount.py +++ b/data/model/repositoryactioncount.py @@ -17,7 +17,7 @@ from data.database import ( logger = logging.getLogger(__name__) -search_bucket = namedtuple("SearchBucket", ["delta", "days", "weight"]) +search_bucket = namedtuple("search_bucket", ["delta", "days", "weight"]) # Defines the various buckets for search scoring. Each bucket is computed using the given time # delta from today *minus the previous bucket's time period*. Once all the actions over the diff --git a/data/model/storage.py b/data/model/storage.py index c3f7cffd3..0c0d80448 100644 --- a/data/model/storage.py +++ b/data/model/storage.py @@ -31,7 +31,7 @@ from util.metrics.prometheus import gc_table_rows_deleted, gc_storage_blobs_dele logger = logging.getLogger(__name__) -_Location = namedtuple("location", ["id", "name"]) +_Location = namedtuple("_Location", ["id", "name"]) EMPTY_LAYER_BLOB_DIGEST = "sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4" SPECIAL_BLOB_DIGESTS = set([EMPTY_LAYER_BLOB_DIGEST]) diff --git a/data/model/test/test_build.py b/data/model/test/test_build.py index a5612f5b7..ffe7e0baf 100644 --- a/data/model/test/test_build.py +++ b/data/model/test/test_build.py @@ -2,6 +2,7 @@ import pytest from mock import patch +from data import model from data.database import BUILD_PHASE, RepositoryBuildTrigger, RepositoryBuild from data.model.build import ( update_trigger_disable_status, diff --git a/data/model/test/test_proxy_cache_config.py b/data/model/test/test_proxy_cache_config.py index 59ade541b..91d88acb9 100644 --- a/data/model/test/test_proxy_cache_config.py +++ b/data/model/test/test_proxy_cache_config.py @@ -1,9 +1,11 @@ +import pytest from playhouse.test_utils import assert_query_count +from data.database import DEFAULT_PROXY_CACHE_EXPIRATION from data.model import InvalidOrganizationException -from data.model.proxy_cache import * from data.model.organization import create_organization -from data.database import ProxyCacheConfig, DEFAULT_PROXY_CACHE_EXPIRATION +from data.model.proxy_cache import * +from data.model.user import create_user_noverify from test.fixtures import * diff --git a/data/model/test/test_quota_model_config.py b/data/model/test/test_quota_model_config.py index 4ca97d1cd..2189b8c4a 100644 --- a/data/model/test/test_quota_model_config.py +++ b/data/model/test/test_quota_model_config.py @@ -1,5 +1,6 @@ from data.model import namespacequota from data.model.organization import create_organization +from data.model.user import create_user_noverify from test.fixtures import * diff --git a/data/model/test/test_repo_mirroring.py b/data/model/test/test_repo_mirroring.py index 237289066..561b28dd0 100644 --- a/data/model/test/test_repo_mirroring.py +++ b/data/model/test/test_repo_mirroring.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +import pytest +from datetime import datetime, timedelta from jsonschema import ValidationError from data.database import RepoMirrorConfig, RepoMirrorStatus, User @@ -8,9 +10,13 @@ from data.model.repo_mirror import ( create_mirroring_rule, get_eligible_mirrors, update_sync_status_to_cancel, - MAX_SYNC_RETRIES, release_mirror, ) +from data.model.user import ( + create_robot, + create_user_noverify, + lookup_robot, +) from test.fixtures import * diff --git a/data/model/test/test_repository.py b/data/model/test/test_repository.py index 106e00d1a..77e506362 100644 --- a/data/model/test/test_repository.py +++ b/data/model/test/test_repository.py @@ -1,3 +1,4 @@ +import os from datetime import timedelta import pytest diff --git a/data/model/test/test_repositoryactioncount.py b/data/model/test/test_repositoryactioncount.py index 35544c13e..1c6301584 100644 --- a/data/model/test/test_repositoryactioncount.py +++ b/data/model/test/test_repositoryactioncount.py @@ -1,7 +1,8 @@ -from datetime import date, timedelta +from datetime import date, datetime, timedelta import pytest +from data import model from data.database import RepositoryActionCount, RepositorySearchScore from data.model.repository import create_repository, Repository from data.model.repositoryactioncount import update_repository_score, SEARCH_BUCKETS diff --git a/data/model/test/test_user.py b/data/model/test/test_user.py index 1ea54b1a9..e6d585604 100644 --- a/data/model/test/test_user.py +++ b/data/model/test/test_user.py @@ -4,6 +4,7 @@ import pytest from mock import patch +from data import model from data.database import EmailConfirmation, User, DeletedNamespace, FederatedLogin from data.model.organization import get_organization from data.model.notification import create_notification diff --git a/data/registry_model/datatypes.py b/data/registry_model/datatypes.py index 1dfe2a7ec..8ecc52fae 100644 --- a/data/registry_model/datatypes.py +++ b/data/registry_model/datatypes.py @@ -15,7 +15,7 @@ from image.docker.schema1 import DOCKER_SCHEMA1_SIGNED_MANIFEST_CONTENT_TYPE from util.bytes import Bytes -class RepositoryReference(datatype("Repository", [])): +class RepositoryReference(datatype("Repository", [])): # type: ignore[misc] """ RepositoryReference is a reference to a repository, passed to registry interface methods. """ @@ -62,7 +62,7 @@ class RepositoryReference(datatype("Repository", [])): @property # type: ignore @optionalinput("kind") - def kind(self, kind): + def kind(self, kind): # type: ignore[misc] """ Returns the kind of the repository. """ @@ -70,7 +70,7 @@ class RepositoryReference(datatype("Repository", [])): @property # type: ignore @optionalinput("is_public") - def is_public(self, is_public): + def is_public(self, is_public): # type: ignore[misc] """ Returns whether the repository is public. """ @@ -99,7 +99,7 @@ class RepositoryReference(datatype("Repository", [])): @property # type: ignore @optionalinput("namespace_name") - def namespace_name(self, namespace_name=None): + def namespace_name(self, namespace_name=None): # type: ignore[misc] """ Returns the namespace name of this repository. """ @@ -114,7 +114,7 @@ class RepositoryReference(datatype("Repository", [])): @property # type: ignore @optionalinput("is_free_namespace") - def is_free_namespace(self, is_free_namespace=None): + def is_free_namespace(self, is_free_namespace=None): # type: ignore[misc] """ Returns whether the namespace of the repository is on a free plan. """ @@ -129,7 +129,7 @@ class RepositoryReference(datatype("Repository", [])): @property # type: ignore @optionalinput("repo_name") - def name(self, repo_name=None): + def name(self, repo_name=None): # type: ignore[misc] """ Returns the name of this repository. """ @@ -144,7 +144,7 @@ class RepositoryReference(datatype("Repository", [])): @property # type: ignore @optionalinput("state") - def state(self, state=None): + def state(self, state=None): # type: ignore[misc] """ Return the state of the Repository. """ @@ -158,7 +158,7 @@ class RepositoryReference(datatype("Repository", [])): return repository.state -class Label(datatype("Label", ["key", "value", "uuid", "source_type_name", "media_type_name"])): +class Label(datatype("Label", ["key", "value", "uuid", "source_type_name", "media_type_name"])): # type: ignore[misc] """ Label represents a label on a manifest. """ @@ -178,7 +178,7 @@ class Label(datatype("Label", ["key", "value", "uuid", "source_type_name", "medi ) -class ShallowTag(datatype("ShallowTag", ["name"])): +class ShallowTag(datatype("ShallowTag", ["name"])): # type: ignore[misc] """ ShallowTag represents a tag in a repository, but only contains basic information. """ @@ -199,7 +199,7 @@ class ShallowTag(datatype("ShallowTag", ["name"])): class Tag( - datatype( + datatype( # type: ignore[misc] "Tag", [ "name", @@ -243,19 +243,19 @@ class Tag( now_ms = get_epoch_timestamp_ms() return self.lifetime_end_ms is not None and self.lifetime_end_ms <= now_ms - @property + @property # type: ignore[misc] @requiresinput("manifest_row") - def _manifest_row(self, manifest_row): + def _manifest_row(self, manifest_row): # type: ignore[misc] """ Returns the database Manifest object for this tag. """ return manifest_row - @property + @property # type: ignore[misc] @requiresinput("manifest_row") @requiresinput("legacy_id_handler") @optionalinput("legacy_image_row") - def manifest(self, manifest_row, legacy_id_handler, legacy_image_row): + def manifest(self, manifest_row, legacy_id_handler, legacy_image_row): # type: ignore[misc] """ Returns the manifest for this tag. """ @@ -265,7 +265,7 @@ class Tag( @property # type: ignore @requiresinput("repository") - def repository(self, repository): + def repository(self, repository): # type: ignore[misc] """ Returns the repository under which this tag lives. """ @@ -287,7 +287,7 @@ class Tag( class Manifest( - datatype( + datatype( # type: ignore[misc] "Manifest", [ "digest", @@ -352,15 +352,15 @@ class Manifest( @property # type: ignore @requiresinput("repository") - def repository(self, repository): + def repository(self, repository): # type: ignore[misc] """ Returns the repository under which this manifest lives. """ return repository - @property + @property # type: ignore[misc] @optionalinput("legacy_image_row") - def _legacy_image_row(self, legacy_image_row): + def _legacy_image_row(self, legacy_image_row): # type: ignore[misc] return legacy_image_row @property @@ -381,9 +381,9 @@ class Manifest( # Otherwise, return None. return None - @property + @property # type: ignore[misc] @requiresinput("legacy_id_handler") - def legacy_image_root_id(self, legacy_id_handler): + def legacy_image_root_id(self, legacy_id_handler): # type: ignore[misc] """ Returns the legacy Docker V1-style image ID for this manifest. Note that an ID will be returned even if the manifest does not support a legacy image. @@ -394,9 +394,9 @@ class Manifest( """Returns the manifest or legacy image as a manifest.""" return self - @property + @property # type: ignore[misc] @requiresinput("legacy_id_handler") - def _legacy_id_handler(self, legacy_id_handler): + def _legacy_id_handler(self, legacy_id_handler): # type: ignore[misc] return legacy_id_handler def lookup_legacy_image(self, layer_index, retriever): @@ -555,7 +555,7 @@ class ManifestLayer(namedtuple("ManifestLayer", ["layer_info", "blob"])): class Blob( - datatype("Blob", ["uuid", "digest", "compressed_size", "uncompressed_size", "uploading"]) + datatype("Blob", ["uuid", "digest", "compressed_size", "uncompressed_size", "uploading"]) # type: ignore[misc] ): """ Blob represents a content-addressable piece of storage. @@ -578,7 +578,7 @@ class Blob( @property # type: ignore @requiresinput("storage_path") - def storage_path(self, storage_path): + def storage_path(self, storage_path): # type: ignore[misc] """ Returns the path of this blob in storage. """ @@ -586,7 +586,7 @@ class Blob( @property # type: ignore @requiresinput("placements") - def placements(self, placements): + def placements(self, placements): # type: ignore[misc] """ Returns all the storage placements at which the Blob can be found. """ @@ -594,7 +594,7 @@ class Blob( class BlobUpload( - datatype( + datatype( # type: ignore[misc] "BlobUpload", [ "upload_id", @@ -629,7 +629,7 @@ class BlobUpload( ) -class LikelyVulnerableTag(datatype("LikelyVulnerableTag", ["layer_id", "name"])): +class LikelyVulnerableTag(datatype("LikelyVulnerableTag", ["layer_id", "name"])): # type: ignore[misc] """ LikelyVulnerableTag represents a tag in a repository that is likely vulnerable to a notified vulnerability. @@ -643,7 +643,7 @@ class LikelyVulnerableTag(datatype("LikelyVulnerableTag", ["layer_id", "name"])) db_id=tag.id, name=tag.name, layer_id=layer_id, inputs=dict(repository=repository) ) - @property + @property # type: ignore[misc] @requiresinput("repository") - def repository(self, repository): + def repository(self, repository): # type: ignore[misc] return RepositoryReference.for_repo_obj(repository) diff --git a/data/registry_model/registry_proxy_model.py b/data/registry_model/registry_proxy_model.py index 35478aad7..a9b43ee6d 100644 --- a/data/registry_model/registry_proxy_model.py +++ b/data/registry_model/registry_proxy_model.py @@ -498,7 +498,7 @@ class ProxyModel(OCIModel): return super().get_repo_blob_by_digest(repository_ref, blob_digest, include_placements) - def _download_blob(self, repo_ref: RepositoryReference, digest: str) -> int: + def _download_blob(self, repo_ref: RepositoryReference, digest: str) -> None: """ Download blob from upstream registry and perform a monolitic upload to Quay's own storage. diff --git a/data/registry_model/test/test_interface.py b/data/registry_model/test/test_interface.py index 4bc3ae73a..bc11b4b0c 100644 --- a/data/registry_model/test/test_interface.py +++ b/data/registry_model/test/test_interface.py @@ -2,6 +2,7 @@ import hashlib import json +import os import uuid from datetime import datetime, timedelta diff --git a/data/secscan_model/secscan_v4_model.py b/data/secscan_model/secscan_v4_model.py index 83aa1af12..0f4747c3e 100644 --- a/data/secscan_model/secscan_v4_model.py +++ b/data/secscan_model/secscan_v4_model.py @@ -50,7 +50,7 @@ logger = logging.getLogger(__name__) DEFAULT_SECURITY_SCANNER_V4_REINDEX_THRESHOLD = 86400 # 1 day -IndexReportState = namedtuple("IndexReportState", ["Index_Finished", "Index_Error"])( +IndexReportState = namedtuple("IndexReportState", ["Index_Finished", "Index_Error"])( # type: ignore[call-arg] "IndexFinished", "IndexError" ) diff --git a/data/secscan_model/test/test_secscan_v4_model.py b/data/secscan_model/test/test_secscan_v4_model.py index c3a863b24..efb26f400 100644 --- a/data/secscan_model/test/test_secscan_v4_model.py +++ b/data/secscan_model/test/test_secscan_v4_model.py @@ -15,29 +15,25 @@ from data.secscan_model.datatypes import ( ) from data.database import ( Manifest, - Repository, ManifestSecurityStatus, IndexStatus, IndexerVersion, - User, ManifestBlob, db_transaction, MediaType, ) -from data.registry_model.datatypes import Manifest as ManifestDataType from data.registry_model import registry_model from util.secscan.v4.api import APIRequestFailure -from util.canonicaljson import canonicalize from image.docker.schema2 import DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE from test.fixtures import * -from app import app, instance_keys, storage +from app import app as application, instance_keys, storage @pytest.fixture() def set_secscan_config(): - app.config["SECURITY_SCANNER_V4_ENDPOINT"] = "http://clairv4:6060" + application.config["SECURITY_SCANNER_V4_ENDPOINT"] = "http://clairv4:6060" def test_load_security_information_queued(initialized_db, set_secscan_config): @@ -45,7 +41,7 @@ def test_load_security_information_queued(initialized_db, set_secscan_config): tag = registry_model.get_repo_tag(repository_ref, "latest") manifest = registry_model.get_manifest_for_tag(tag) - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) assert secscan.load_security_information(manifest).status == ScanLookupStatus.NOT_YET_INDEXED @@ -64,7 +60,7 @@ def test_load_security_information_failed_to_index(initialized_db, set_secscan_c metadata_json={}, ) - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) assert secscan.load_security_information(manifest).status == ScanLookupStatus.FAILED_TO_INDEX @@ -83,7 +79,7 @@ def test_load_security_information_api_returns_none(initialized_db, set_secscan_ metadata_json={}, ) - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.vulnerability_report.return_value = None @@ -105,7 +101,7 @@ def test_load_security_information_api_request_failure(initialized_db, set_secsc metadata_json={}, ) - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.vulnerability_report.side_effect = APIRequestFailure() @@ -128,7 +124,7 @@ def test_load_security_information_success(initialized_db, set_secscan_config): metadata_json={}, ) - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.vulnerability_report.return_value = { "manifest_hash": manifest.digest, @@ -149,7 +145,7 @@ def test_load_security_information_success(initialized_db, set_secscan_config): def test_perform_indexing_whitelist(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.index.return_value = ( @@ -169,7 +165,7 @@ def test_perform_indexing_whitelist(initialized_db, set_secscan_config): def test_perform_indexing_failed(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.index.return_value = ( @@ -186,7 +182,7 @@ def test_perform_indexing_failed(initialized_db, set_secscan_config): indexer_hash="abc", indexer_version=IndexerVersion.V4, last_indexed=datetime.utcnow() - - timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), + - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), metadata_json={}, ) @@ -199,9 +195,9 @@ def test_perform_indexing_failed(initialized_db, set_secscan_config): def test_perform_indexing_failed_within_reindex_threshold(initialized_db, set_secscan_config): - app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300 + application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300 - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.index.return_value = ( @@ -229,7 +225,7 @@ def test_perform_indexing_failed_within_reindex_threshold(initialized_db, set_se def test_perform_indexing_needs_reindexing(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.index.return_value = ( @@ -246,7 +242,7 @@ def test_perform_indexing_needs_reindexing(initialized_db, set_secscan_config): indexer_hash="abc", indexer_version=IndexerVersion.V4, last_indexed=datetime.utcnow() - - timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), + - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), metadata_json={}, ) @@ -259,7 +255,7 @@ def test_perform_indexing_needs_reindexing(initialized_db, set_secscan_config): def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "new hash"} secscan._secscan_api.index.return_value = ( @@ -276,7 +272,7 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_ indexer_hash="old hash", indexer_version=IndexerVersion.V4, last_indexed=datetime.utcnow() - - timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), + - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60), metadata_json={}, ) @@ -296,7 +292,7 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_ ( IndexStatus.MANIFEST_UNSUPPORTED, {"status": "old hash"}, - app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, + application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, True, ), # Old hash and recent scan, don't rescan @@ -305,14 +301,14 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_ ( IndexStatus.COMPLETED, {"status": "old hash"}, - app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, + application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, False, ), # New hash and old scan, don't rescan ( IndexStatus.COMPLETED, {"status": "new hash"}, - app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, + application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, False, ), # New hash and recent scan, don't rescan @@ -321,14 +317,14 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_ ( IndexStatus.FAILED, {"status": "old hash"}, - app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, + application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, False, ), # New hash and old scan, rescan ( IndexStatus.FAILED, {"status": "new hash"}, - app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, + application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] + 60, False, ), ], @@ -336,7 +332,7 @@ def test_perform_indexing_needs_reindexing_skip_unsupported(initialized_db, set_ def test_manifest_iterator( initialized_db, set_secscan_config, index_status, indexer_state, seconds, expect_zero ): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) for manifest in Manifest.select(): with db_transaction(): @@ -360,7 +356,7 @@ def test_manifest_iterator( Manifest.select(fn.Min(Manifest.id)).scalar(), Manifest.select(fn.Max(Manifest.id)).scalar(), reindex_threshold=datetime.utcnow() - - timedelta(seconds=app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"]), + - timedelta(seconds=application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"]), ) count = 0 @@ -376,9 +372,9 @@ def test_manifest_iterator( def test_perform_indexing_needs_reindexing_within_reindex_threshold( initialized_db, set_secscan_config ): - app.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300 + application.config["SECURITY_SCANNER_V4_REINDEX_THRESHOLD"] = 300 - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.index.return_value = ( @@ -406,7 +402,7 @@ def test_perform_indexing_needs_reindexing_within_reindex_threshold( def test_perform_indexing_api_request_failure_state(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.side_effect = APIRequestFailure() @@ -418,7 +414,7 @@ def test_perform_indexing_api_request_failure_state(initialized_db, set_secscan_ def test_perform_indexing_api_request_index_error_response(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.index.return_value = ( @@ -435,7 +431,7 @@ def test_perform_indexing_api_request_index_error_response(initialized_db, set_s def test_perform_indexing_api_request_non_finished_state(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "xyz"} secscan._secscan_api.index.return_value = ( @@ -450,7 +446,7 @@ def test_perform_indexing_api_request_non_finished_state(initialized_db, set_sec def test_perform_indexing_api_request_failure_index(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.state.return_value = {"state": "abc"} secscan._secscan_api.index.side_effect = APIRequestFailure() @@ -510,7 +506,7 @@ def test_features_for(): def test_perform_indexing_invalid_manifest(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() # Delete all ManifestBlob rows to cause the manifests to be invalid. @@ -525,7 +521,7 @@ def test_perform_indexing_invalid_manifest(initialized_db, set_secscan_config): def test_lookup_notification_page_invalid(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.retrieve_notification_page.return_value = None @@ -535,7 +531,7 @@ def test_lookup_notification_page_invalid(initialized_db, set_secscan_config): def test_lookup_notification_page_valid(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.retrieve_notification_page.return_value = { "notifications": [ @@ -560,7 +556,7 @@ def test_lookup_notification_page_valid(initialized_db, set_secscan_config): def test_mark_notification_handled(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan._secscan_api.delete_notification.return_value = True @@ -568,7 +564,7 @@ def test_mark_notification_handled(initialized_db, set_secscan_config): def test_process_notification_page(initialized_db, set_secscan_config): - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) results = list( secscan.process_notification_page( @@ -614,7 +610,7 @@ def test_perform_indexing_manifest_list(initialized_db, set_secscan_config): media_type=MediaType.get(name=DOCKER_SCHEMA2_MANIFESTLIST_CONTENT_TYPE) ).execute() - secscan = V4SecurityScanner(app, instance_keys, storage) + secscan = V4SecurityScanner(application, instance_keys, storage) secscan._secscan_api = mock.Mock() secscan.perform_indexing_recent_manifests() diff --git a/data/text.py b/data/text.py index 4578c2742..0d4b570a6 100644 --- a/data/text.py +++ b/data/text.py @@ -1,4 +1,4 @@ -from peewee import NodeList, SQL, fn, TextField, Field +from peewee import NodeList, SQL, fn, Field def _escape_wildcard(search_query): diff --git a/data/users/externalldap.py b/data/users/externalldap.py index 5f5160781..88facc45d 100644 --- a/data/users/externalldap.py +++ b/data/users/externalldap.py @@ -87,7 +87,7 @@ class LDAPConnection(object): class LDAPUsers(FederatedUsers): - _LDAPResult = namedtuple("LDAPResult", ["dn", "attrs"]) + _LDAPResult = namedtuple("_LDAPResult", ["dn", "attrs"]) def __init__( self, diff --git a/endpoints/api/repository.py b/endpoints/api/repository.py index 608ed40ae..327e9287f 100644 --- a/endpoints/api/repository.py +++ b/endpoints/api/repository.py @@ -3,7 +3,6 @@ List, create and manage repositories. """ import logging -import datetime import features from collections import defaultdict diff --git a/endpoints/api/test/test_logs.py b/endpoints/api/test/test_logs.py index 3e0a2aaf3..5f872327f 100644 --- a/endpoints/api/test/test_logs.py +++ b/endpoints/api/test/test_logs.py @@ -1,4 +1,5 @@ import os +import pytest import time from mock import patch diff --git a/endpoints/api/test/test_search.py b/endpoints/api/test/test_search.py index a43a83e47..3da9b15d1 100644 --- a/endpoints/api/test/test_search.py +++ b/endpoints/api/test/test_search.py @@ -1,3 +1,4 @@ +import os import pytest from playhouse.test_utils import assert_query_count diff --git a/endpoints/common_models_interface.py b/endpoints/common_models_interface.py index 842fe1381..63580d52f 100644 --- a/endpoints/common_models_interface.py +++ b/endpoints/common_models_interface.py @@ -17,7 +17,7 @@ USER_FIELDS: List[str] = [ ] -class User(namedtuple("User", USER_FIELDS)): +class User(namedtuple("User", USER_FIELDS)): # type: ignore[misc] """ User represents a user. """ diff --git a/endpoints/oauth/login.py b/endpoints/oauth/login.py index e74e6cfce..94c8af114 100644 --- a/endpoints/oauth/login.py +++ b/endpoints/oauth/login.py @@ -33,7 +33,7 @@ oauthlogin_csrf_protect = csrf_protect( OAuthResult = namedtuple( - "oauthresult", + "OAuthResult", ["user_obj", "service_name", "error_message", "register_redirect", "requires_verification"], ) diff --git a/endpoints/test/test_decorators.py b/endpoints/test/test_decorators.py index a36dafe48..11f7626cf 100644 --- a/endpoints/test/test_decorators.py +++ b/endpoints/test/test_decorators.py @@ -1,3 +1,4 @@ +import pytest from data import model from endpoints.api import api from endpoints.api.repository import Repository diff --git a/endpoints/v2/__init__.py b/endpoints/v2/__init__.py index e6ed16742..afe229184 100644 --- a/endpoints/v2/__init__.py +++ b/endpoints/v2/__init__.py @@ -68,7 +68,7 @@ def handle_quota_error(error): return _format_error_response(QuotaExceeded()) -def _format_error_response(error: Exception) -> Response: +def _format_error_response(error: V2RegistryException) -> Response: response = jsonify({"errors": [error.as_dict()]}) response.status_code = error.http_status_code logger.debug("sending response: %s", response.get_data()) diff --git a/endpoints/v2/test/test_v2auth.py b/endpoints/v2/test/test_v2auth.py index 71c1727d4..1e3998f50 100644 --- a/endpoints/v2/test/test_v2auth.py +++ b/endpoints/v2/test/test_v2auth.py @@ -1,9 +1,10 @@ -import base64 +import pytest from flask import url_for from app import instance_keys, app as original_app -from data.model.user import regenerate_robot_token, get_robot_and_metadata, get_user +from data import model +from data.model.user import get_robot_and_metadata, get_user from endpoints.test.shared import conduct_call, gen_basic_auth from util.security.registry_jwt import decode_bearer_token, CLAIM_TUF_ROOTS diff --git a/features/__init__.pyi b/features/__init__.pyi index bb32c8a4a..24296b938 100644 --- a/features/__init__.pyi +++ b/features/__init__.pyi @@ -1,3 +1,5 @@ +from typing import Dict + class FeatureNameValue(object): def __init__(self, name: str, value: bool): ... def __str__(self) -> str: ... @@ -184,3 +186,7 @@ USER_INITIALIZE: FeatureNameValue EXTENDED_REPOSITORY_NAMES: FeatureNameValue QUOTA_MANAGEMENT: FeatureNameValue + +HELM_OCI_SUPPORT: FeatureNameValue + +PROXY_CACHE: FeatureNameValue diff --git a/image/oci/test/test_oci_manifest.py b/image/oci/test/test_oci_manifest.py index c214f9966..b38fd98ba 100644 --- a/image/oci/test/test_oci_manifest.py +++ b/image/oci/test/test_oci_manifest.py @@ -267,27 +267,6 @@ def test_get_schema1_manifest_incorrect_history(): manifest.get_schema1_manifest("somenamespace", "somename", "sometag", retriever) -def test_validate_manifest_invalid_config_type(): - manifest_bytes = """{ - "schemaVersion": 2, - "config": { - "mediaType": "application/some.other.thing", - "digest": "sha256:6bd578ec7d1e7381f63184dfe5fbe7f2f15805ecc4bfd485e286b76b1e796524", - "size": 145 - }, - "layers": [ - { - "mediaType": "application/tar+gzip", - "digest": "sha256:ce879e86a8f71031c0f1ab149a26b000b3b5b8810d8d047f240ef69a6b2516ee", - "size": 2807 - } - ] - }""" - - with pytest.raises(MalformedOCIManifest): - OCIManifest(Bytes.for_string_or_unicode(manifest_bytes)) - - def test_validate_helm_oci_manifest(): manifest_bytes = """{ "schemaVersion":2, diff --git a/mypy.ini b/mypy.ini index 842416df3..ac935b95a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,8 +1,10 @@ [mypy] python_version = 3.8 mypy_path = mypy_stubs -# local-dev is excluded until we decide what to do with __init__.py in those packages. -exclude = local-dev/ +exclude = (?x)( + ^buildman/test/test_buildman\.py$ | + ^buildman/buildman_pb/buildman_pb2\.py$ | + ^endpoints/api/test/test_security\.py$ ) # Necessary because most current dependencies do not have typing ignore_missing_imports = True @@ -12,9 +14,3 @@ ignore_missing_imports = True warn_redundant_casts = True warn_unused_ignores = True warn_unreachable = True - -# The `features` module uses some magic to declare attributes at the module level at runtime. -# mypy cant introspect this so we need to disable type checking on features.X -[mypy-features] -ignore_errors = True - diff --git a/mypy_stubs/peewee.pyi b/mypy_stubs/peewee.pyi index 9bd5780bf..c178b2eef 100644 --- a/mypy_stubs/peewee.pyi +++ b/mypy_stubs/peewee.pyi @@ -2,7 +2,6 @@ import itertools import logging import threading from collections import namedtuple -from pysqlite2 import dbapi2 as pysq3 from typing import ( Any, AnyStr, @@ -18,6 +17,7 @@ from typing import ( Optional, Sequence, Text, + Tuple as TupleT, Type, TypeVar, Union, @@ -29,9 +29,6 @@ _T = TypeVar("_T") # Manual Adjustments SENTINEL = object() -sqlite3 = pysq3 -sqlite3 = pysq3 - class NullHandler(logging.Handler): def emit(self, record: Any) -> None: ... @@ -149,7 +146,7 @@ class Source(Node): def __init__(self, alias: Optional[Any] = ...) -> None: ... def alias(self, name: Any) -> None: ... def select(self, *columns: Any): ... - def join(self, dest: Any, join_type: Any = ..., on: Optional[Any] = ...): ... + def join(self: _T, dest: Any, join_type: Any = ..., on: Optional[Any] = ...) -> _T: ... def left_outer_join(self, dest: Any, on: Optional[Any] = ...): ... def cte( self, @@ -493,7 +490,7 @@ class DQ(ColumnBase): def __invert__(self) -> None: ... def clone(self): ... -Tuple: Any +Tuple: Callable[..., NodeList] class QualifiedNames(WrappedNode): def __sql__(self, ctx: Any): ... @@ -529,7 +526,7 @@ class BaseQuery(Node): def objects(self, constructor: Optional[Any] = ...): ... def __sql__(self, ctx: Context) -> Context: ... def sql(self): ... - def execute(self, database: Any): ... + def execute(self, database: Optional[Any] = ...): ... def iterator(self, database: Optional[Any] = ...): ... def __iter__(self) -> Any: ... def __getitem__(self, value: Any): ... @@ -551,8 +548,8 @@ class Query(BaseQuery): **kwargs: Any, ) -> None: ... def with_cte(self, *cte_list: Any) -> None: ... - def where(self, *expressions: Any) -> None: ... - def orwhere(self, *expressions: Any) -> None: ... + def where(self: _T, *expressions: Any) -> _T: ... + def orwhere(self: _T, *expressions: Any) -> _T: ... def order_by(self, *values: Any) -> None: ... def order_by_extend(self, *values: Any) -> None: ... def limit(self, value: Optional[Any] = ...) -> None: ... @@ -576,19 +573,19 @@ class SelectQuery(Query): def select_from(self, *columns: Any): ... class SelectBase(_HashableSource, Source, SelectQuery): - def peek(self, database: Any, n: int = ...): ... - def first(self, database: Any, n: int = ...): ... - def scalar(self, database: Any, as_tuple: bool = ...): ... - def count(self, database: Any, clear_limit: bool = ...): ... - def exists(self, database: Any): ... - def get(self, database: Any): ... + def peek(self, database: Optional[Any] = ..., n: int = ...): ... + def first(self, database: Optional[Any] = ..., n: int = ...): ... + def scalar(self, database: Optional[Any] = ..., as_tuple: bool = ...): ... + def count(self, database: Optional[Any] = ..., clear_limit: bool = ...): ... + def exists(self, database: Optional[Any] = ...): ... + def get(self, database: Optional[Any] = ...): ... class CompoundSelectQuery(SelectBase): lhs: Any = ... op: Any = ... rhs: Any = ... def __init__(self, lhs: Any, op: Any, rhs: Any) -> None: ... - def exists(self, database: Any): ... + def exists(self, database: Optional[Any] = ...): ... def __sql__(self, ctx: Any): ... class Select(SelectBase): @@ -609,16 +606,16 @@ class Select(SelectBase): def columns(self, *columns: Any, **kwargs: Any) -> None: ... select: Any = ... def select_extend(self, *columns: Any) -> None: ... - def from_(self, *sources: Any) -> None: ... - def join(self, dest: Any, join_type: Any = ..., on: Optional[Any] = ...) -> None: ... - def group_by(self, *columns: Any) -> None: ... + def from_(self: _T, *sources: Any) -> _T: ... + def join(self: _T, dest: Any, join_type: Any = ..., on: Optional[Any] = ...) -> _T: ... + def group_by(self: _T, *columns: Any) -> _T: ... def group_by_extend(self, *values: Any): ... - def having(self, *expressions: Any) -> None: ... - def distinct(self, *columns: Any) -> None: ... - def window(self, *windows: Any) -> None: ... + def having(self: _T, *expressions: Any) -> _T: ... + def distinct(self: _T, *columns: Any) -> _T: ... + def window(self: _T, *windows: Any) -> _T: ... def for_update( - self, for_update: bool = ..., of: Optional[Any] = ..., nowait: Optional[Any] = ... - ) -> None: ... + self: _T, for_update: bool = ..., of: Optional[Any] = ..., nowait: Optional[Any] = ... + ) -> _T: ... def __sql_selection__(self, ctx: Any, is_subquery: bool = ...): ... def __sql__(self, ctx: Any): ... @@ -651,7 +648,7 @@ class Insert(_WriteQuery): on_conflict: Optional[Any] = ..., **kwargs: Any, ) -> None: ... - def where(self, *expressions: Any) -> None: ... + def where(self: _T, *expressions: Any) -> _T: ... def on_conflict_ignore(self, ignore: bool = ...) -> None: ... def on_conflict_replace(self, replace: bool = ...) -> None: ... def on_conflict(self, *args: Any, **kwargs: Any) -> None: ... @@ -1100,7 +1097,7 @@ class Field(ColumnBase): sequence: Optional[str] = ..., collation: Optional[str] = ..., unindexed: Optional[bool] = ..., - choices: Optional[Iterable[Tuple[Any, str]]] = ..., + choices: Optional[Iterable[TupleT[Any, str]]] = ..., help_text: Optional[str] = ..., verbose_name: Optional[str] = ..., index_type: Optional[str] = ..., @@ -1116,7 +1113,7 @@ class Field(ColumnBase): def db_value(self, value: _T) -> _T: ... def python_value(self, value: _T) -> _T: ... def to_value(self, value: Any) -> Value: ... - def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ... + def get_sort_key(self, ctx: Context) -> TupleT[int, int]: ... def __sql__(self, ctx: Context) -> Context: ... def get_modifiers(self) -> Any: ... def ddl_datatype(self, ctx: Context) -> SQL: ... @@ -1155,7 +1152,7 @@ class _StringField(Field): def adapt(self, value: AnyStr) -> str: ... @overload def adapt(self, value: _T) -> _T: ... - def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: ... + def split(self, sep: Optional[str] = ..., maxsplit: int = ...) -> List[str]: ... def __add__(self, other: Any) -> StringExpression: ... def __radd__(self, other: Any) -> StringExpression: ... @@ -1545,7 +1542,9 @@ class Model(Node, metaclass=ModelBase): ) -> ModelInsert: ... @overload @classmethod - def insert_many(cls, rows: Iterable[Mapping[str, object]], fields: None) -> ModelInsert: ... + def insert_many( + cls, rows: Iterable[Mapping[str, object]], fields: None = ... + ) -> ModelInsert: ... @overload @classmethod def insert_many(cls, rows: Iterable[tuple], fields: Sequence[Field]) -> ModelInsert: ... @@ -1593,7 +1592,7 @@ class Model(Node, metaclass=ModelBase): @classmethod def get_or_create( cls, *, defaults: Mapping[str, object] = ..., **kwargs: object - ) -> Tuple[Any, bool]: ... + ) -> TupleT[Any, bool]: ... @classmethod def filter(cls, *dq_nodes: DQ, **filters: Any) -> SelectQuery: ... def get_id(self) -> Any: ... @@ -1605,7 +1604,7 @@ class Model(Node, metaclass=ModelBase): def dirty_fields(self) -> List[Field]: ... def dependencies( self, search_nullable: bool = ... - ) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ... + ) -> Iterator[TupleT[Union[bool, Node], ForeignKeyField]]: ... def delete_instance(self: _T, recursive: bool = ..., delete_nullable: bool = ...) -> _T: ... def __hash__(self) -> int: ... def __eq__(self, other: object) -> bool: ... @@ -1685,7 +1684,7 @@ class BaseModelSelect(_ModelQueryHelper): def __iter__(self) -> Any: ... def prefetch(self, *subqueries: Any): ... def get(self, database: Optional[Any] = ...): ... - def group_by(self, *columns: Any) -> None: ... + def group_by(self, *columns: Any): ... class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): model: Any = ... @@ -1695,24 +1694,24 @@ class ModelSelect(BaseModelSelect, Select): model: Any = ... def __init__(self, model: Any, fields_or_models: Any, is_default: bool = ...) -> None: ... def clone(self): ... - def select(self, *fields_or_models: Any): ... + def select(self: _T, *fields_or_models: Any) -> _T: ... def switch(self, ctx: Optional[Any] = ...): ... def join( - self, + self: _T, dest: Any, join_type: Any = ..., on: Optional[Any] = ..., src: Optional[Any] = ..., attr: Optional[Any] = ..., - ) -> None: ... + ) -> _T: ... def join_from( - self, + self: _T, src: Any, dest: Any, join_type: Any = ..., on: Optional[Any] = ..., attr: Optional[Any] = ..., - ): ... + ) -> _T: ... def ensure_join(self, lm: Any, rm: Any, on: Optional[Any] = ..., **join_kwargs: Any): ... def convert_dict_to_node(self, qdict: Any): ... def filter(self, *args: Any, **kwargs: Any): ... diff --git a/pyproject.toml b/pyproject.toml index bd5d70b4a..4a8ba6b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 100 -target-version = ['py27'] +target-version = ['py38'] [tool.pylint.messages_control] disable = "missing-docstring,invalid-name,too-many-locals,too-few-public-methods,too-many-lines" @@ -17,3 +17,6 @@ branch = true [tool.coverage.report] omit = ['test/**', 'venv/**', '**/test/**'] + +[tool.pyright] +stubPath = 'mypy_stubs' diff --git a/requirements-dev.txt b/requirements-dev.txt index 2219d3cbb..9cf146832 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,7 +11,7 @@ httmock==1.3.0 ipdb ipython mock==3.0.5 -mypy==0.910 +mypy==0.950 moto==2.0.1 parameterized==0.8.1 pytest diff --git a/requirements.txt b/requirements.txt index 322282ee6..b8017a399 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ APScheduler==3.6.3 attrs==19.3.0 Authlib==1.0.0a1 aws-sam-translator==1.20.1 -azure-core==1.8.0 +azure-core==1.23.1 azure-storage-blob==12.4.0 Babel==2.9.1 bcrypt==3.1.7 @@ -40,7 +40,7 @@ futures==3.1.1 geoip2==3.0.0 gevent==21.8.0 greenlet==1.1.2 -grpcio==1.30.0 +grpcio==1.46.3 gunicorn==20.1.0 hashids==1.2.0 html5lib==1.0.1 diff --git a/storage/downloadproxy.py b/storage/downloadproxy.py index 5b11eb4a5..14d348ee6 100644 --- a/storage/downloadproxy.py +++ b/storage/downloadproxy.py @@ -1,7 +1,7 @@ import logging import base64 -import urllib.request, urllib.parse, urllib.error +from flask import Flask from urllib.parse import urlparse from flask import abort, request from jsonschema import validate, ValidationError @@ -61,7 +61,7 @@ class DownloadProxy(object): NGINX. """ - def __init__(self, app, instance_keys): + def __init__(self, app: Flask, instance_keys): self.app = app self.instance_keys = instance_keys diff --git a/storage/test/test_cloudfront.py b/storage/test/test_cloudfront.py index 6cf098cfd..e7cd589d8 100644 --- a/storage/test/test_cloudfront.py +++ b/storage/test/test_cloudfront.py @@ -1,3 +1,4 @@ +import os import pytest from contextlib import contextmanager diff --git a/storage/test/test_storageproxy.py b/storage/test/test_storageproxy.py index 0217590eb..31d943b10 100644 --- a/storage/test/test_storageproxy.py +++ b/storage/test/test_storageproxy.py @@ -1,11 +1,7 @@ import os - import pytest -import requests - -from flask import Flask -from flask_testing import LiveServerTestCase +from data.database import close_db_filter, configure from storage import Storage from util.security.instancekeys import InstanceKeys diff --git a/test/fixtures.py b/test/fixtures.py index 77693625b..91f258589 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -1,8 +1,6 @@ import os -from cachetools.func import lru_cache from collections import namedtuple -from datetime import datetime, timedelta import pytest import shutil @@ -10,7 +8,7 @@ import inspect from flask import Flask, jsonify from flask_login import LoginManager -from flask_principal import identity_loaded, Permission, Identity, identity_changed, Principal +from flask_principal import identity_loaded, Principal from flask_mail import Mail from peewee import SqliteDatabase, InternalError from mock import patch @@ -20,7 +18,7 @@ from app import app as application from auth.permissions import on_identity_loaded from data import model from data.database import close_db_filter, db, configure -from data.model.user import LoginWrappedDBUser, create_robot, lookup_robot, create_user_noverify +from data.model.user import LoginWrappedDBUser from data.userfiles import Userfiles from endpoints.api import api_bp from endpoints.appr import appr_bp @@ -43,6 +41,9 @@ from test.testconfig import FakeTransaction INIT_DB_PATH = 0 +__all__ = ["init_db_path", "database_uri", "sqlitedb_file", "appconfig", "initialized_db", "app"] + + @pytest.fixture(scope="session") def init_db_path(tmpdir_factory): """ diff --git a/test/registry/protocol_v1.py b/test/registry/protocol_v1.py index 1ea190b18..d334786fd 100644 --- a/test/registry/protocol_v1.py +++ b/test/registry/protocol_v1.py @@ -1,4 +1,5 @@ import json +from typing import Dict from io import BytesIO from enum import Enum, unique @@ -29,7 +30,7 @@ class V1ProtocolSteps(Enum): class V1Protocol(RegistryProtocol): - FAILURE_CODES = { + FAILURE_CODES: Dict[Enum, Dict[Failures, int]] = { V1ProtocolSteps.PUT_IMAGES: { Failures.INVALID_AUTHENTICATION: 403, Failures.UNAUTHENTICATED: 401, diff --git a/test/registry/protocol_v2.py b/test/registry/protocol_v2.py index 05aab74ea..aaa894245 100644 --- a/test/registry/protocol_v2.py +++ b/test/registry/protocol_v2.py @@ -1,5 +1,6 @@ import hashlib import json +from typing import Dict from enum import Enum, unique @@ -49,7 +50,7 @@ class V2ProtocolSteps(Enum): class V2Protocol(RegistryProtocol): - FAILURE_CODES = { + FAILURE_CODES: Dict[Enum, Dict[Failures, int]] = { V2ProtocolSteps.AUTH: { Failures.UNAUTHENTICATED: 401, Failures.INVALID_AUTHENTICATION: 401, diff --git a/test/registry/protocols.py b/test/registry/protocols.py index aaea88544..7c3febd29 100644 --- a/test/registry/protocols.py +++ b/test/registry/protocols.py @@ -1,5 +1,6 @@ import json import tarfile +from typing import Dict from abc import ABCMeta, abstractmethod from collections import namedtuple @@ -104,7 +105,7 @@ class RegistryProtocol(object): Interface for protocols. """ - FAILURE_CODES = {} + FAILURE_CODES: Dict[Enum, Dict[Failures, int]] = {} @abstractmethod def login(self, session, username, password, scopes, expect_success): diff --git a/test/test_api_usage.py b/test/test_api_usage.py index fcf6aaefb..1de766194 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -3943,7 +3943,7 @@ class TestLogs(ApiTestCase): json = self.getJsonResponse(UserAggregateLogs) assert "aggregated" in json - def test_org_logs(self): + def test_org_aggregate_logs(self): self.login(ADMIN_ACCESS_USER) json = self.getJsonResponse(OrgAggregateLogs, params=dict(orgname=ORGANIZATION)) diff --git a/test/test_external_jwt_authn.py b/test/test_external_jwt_authn.py index 7020465c8..c488a32d8 100644 --- a/test/test_external_jwt_authn.py +++ b/test/test_external_jwt_authn.py @@ -1,5 +1,6 @@ import base64 import unittest +from typing import Optional from datetime import datetime, timedelta from tempfile import NamedTemporaryFile @@ -191,7 +192,7 @@ class JWTAuthTestMixin: Mixin defining all the JWT auth tests. """ - maxDiff = None + maxDiff: Optional[int] = None @property def emails(self): diff --git a/test/test_keystone_auth.py b/test/test_keystone_auth.py index f316cc420..6f179307a 100644 --- a/test/test_keystone_auth.py +++ b/test/test_keystone_auth.py @@ -1,6 +1,7 @@ import json import os import unittest +from typing import Optional import requests @@ -289,7 +290,7 @@ def _create_app(requires_email=True): class KeystoneAuthTestsMixin: - maxDiff = None + maxDiff: Optional[int] = None @property def emails(self): diff --git a/util/config/configdocs/configdoc.py b/util/config/configdocs/configdoc.py index b826693ac..6d4a807dd 100644 --- a/util/config/configdocs/configdoc.py +++ b/util/config/configdocs/configdoc.py @@ -34,17 +34,22 @@ def make_custom_sort(orders): return process -SCHEMA_HTML_FILE = "schema.html" +def main(): + SCHEMA_HTML_FILE = "schema.html" -schema = json.dumps(CONFIG_SCHEMA, sort_keys=True) -schema = json.loads(schema, object_pairs_hook=OrderedDict) + schema = json.dumps(CONFIG_SCHEMA, sort_keys=True) + schema = json.loads(schema, object_pairs_hook=OrderedDict) -req = sorted(schema["required"]) -custom_sort = make_custom_sort([req]) -schema = custom_sort(schema) + req = sorted(schema["required"]) + custom_sort = make_custom_sort([req]) + schema = custom_sort(schema) -parsed_items = docsmodel.DocsModel().parse(schema)[1:] -output = html_output.HtmlOutput().generate_output(parsed_items) + parsed_items = docsmodel.DocsModel().parse(schema)[1:] + output = html_output.HtmlOutput().generate_output(parsed_items) -with open(SCHEMA_HTML_FILE, "wt") as f: - f.write(output) + with open(SCHEMA_HTML_FILE, "wt") as f: + f.write(output) + + +if __name__ == "__main__": + main() diff --git a/util/config/provider/test/test_fileprovider.py b/util/config/provider/test/test_fileprovider.py index 1241b17f9..7bd2adf94 100644 --- a/util/config/provider/test/test_fileprovider.py +++ b/util/config/provider/test/test_fileprovider.py @@ -1,3 +1,4 @@ +import os import pytest from util.config.provider import FileConfigProvider diff --git a/util/generatepresharedkey.py b/util/generatepresharedkey.py index 7503f9c1d..dab8e4f19 100644 --- a/util/generatepresharedkey.py +++ b/util/generatepresharedkey.py @@ -34,6 +34,8 @@ def valid_date(s): if __name__ == "__main__": + from cryptography.hazmat.primitives import serialization + parser = argparse.ArgumentParser(description="Generates a preshared key") parser.add_argument("service", help="The service name for which the key is being generated") parser.add_argument("name", help="The friendly name for the key") diff --git a/util/repomirror/skopeomirror.py b/util/repomirror/skopeomirror.py index 75a7dadd1..668328395 100644 --- a/util/repomirror/skopeomirror.py +++ b/util/repomirror/skopeomirror.py @@ -13,7 +13,7 @@ SKOPEO_TIMEOUT_SECONDS = 300 # tags: list of tags or empty list # stdout: stdout from skopeo subprocess # stderr: stderr from skopeo subprocess -SkopeoResults = namedtuple("SkopeoCopyResults", "success tags stdout stderr") +SkopeoResults = namedtuple("SkopeoResults", "success tags stdout stderr") class SkopeoMirror(object): diff --git a/util/secscan/v4/api.py b/util/secscan/v4/api.py index 05b5ddc63..444723369 100644 --- a/util/secscan/v4/api.py +++ b/util/secscan/v4/api.py @@ -6,6 +6,7 @@ import os import jwt import base64 import time +from typing import Dict, Callable from collections import namedtuple from datetime import datetime, timedelta @@ -124,7 +125,7 @@ class SecurityScannerAPIInterface(object): Action = namedtuple("Action", ["name", "payload"]) -actions = { +actions: Dict[str, Callable[..., Action]] = { "IndexState": lambda: Action("IndexState", ("GET", "/indexer/api/v1/index_state", None)), "Index": lambda manifest: Action("Index", ("POST", "/indexer/api/v1/index_report", manifest)), "GetIndexReport": lambda manifest_hash: Action( diff --git a/util/tufmetadata/test/test_tufmetadata.py b/util/tufmetadata/test/test_tufmetadata.py index 3e00ad472..4fed5b7dc 100644 --- a/util/tufmetadata/test/test_tufmetadata.py +++ b/util/tufmetadata/test/test_tufmetadata.py @@ -1,5 +1,6 @@ import pytest import requests +from typing import Dict, Any from mock import mock, patch from flask import Flask @@ -9,7 +10,7 @@ from test.fixtures import init_db_path from util.tufmetadata import api -valid_response = { +valid_response: Dict[str, Any] = { "signed": { "type": "Targets", "delegations": { @@ -95,7 +96,7 @@ valid_targets_with_delegation = { } -valid_delegation = { +valid_delegation: Dict[str, Any] = { "signed": { "_type": "Targets", "delegations": {"keys": {}, "roles": []}, diff --git a/workers/repomirrorworker/repomirrorworker.py b/workers/repomirrorworker/repomirrorworker.py index 3ae3c46e5..50003993c 100644 --- a/workers/repomirrorworker/repomirrorworker.py +++ b/workers/repomirrorworker/repomirrorworker.py @@ -52,10 +52,11 @@ def create_gunicorn_worker(): if __name__ == "__main__": - if os.getenv("PYDEV_DEBUG", None): + pydev_debug = os.getenv("PYDEV_DEBUG", None) + if pydev_debug: import pydevd_pycharm - host, port = os.getenv("PYDEV_DEBUG").split(":") + host, port = pydev_debug.split(":") pydevd_pycharm.settrace( host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False ) diff --git a/workers/repomirrorworker/test/test_repomirrorworker.py b/workers/repomirrorworker/test/test_repomirrorworker.py index c9d23c366..d00270395 100644 --- a/workers/repomirrorworker/test/test_repomirrorworker.py +++ b/workers/repomirrorworker/test/test_repomirrorworker.py @@ -2,6 +2,7 @@ import pytest import mock import json from functools import wraps +from unittest.mock import patch from app import storage from data.registry_model.blobuploader import upload_blob, BlobUploadSettings diff --git a/workers/securityworker/securityworker.py b/workers/securityworker/securityworker.py index b981e5db7..ba22ebf31 100644 --- a/workers/securityworker/securityworker.py +++ b/workers/securityworker/securityworker.py @@ -67,10 +67,11 @@ def create_gunicorn_worker(): if __name__ == "__main__": - if os.getenv("PYDEV_DEBUG", None): + pydev_debug = os.getenv("PYDEV_DEBUG", None) + if pydev_debug: import pydevd_pycharm - host, port = os.getenv("PYDEV_DEBUG").split(":") + host, port = pydev_debug.split(":") pydevd_pycharm.settrace( host, port=int(port), stdoutToServer=True, stderrToServer=True, suspend=False ) diff --git a/workers/servicekeyworker/test/test_servicekeyworker.py b/workers/servicekeyworker/test/test_servicekeyworker.py index 363678c07..433213108 100644 --- a/workers/servicekeyworker/test/test_servicekeyworker.py +++ b/workers/servicekeyworker/test/test_servicekeyworker.py @@ -1,7 +1,6 @@ from datetime import datetime, timedelta from mock import patch -from data import model from workers.servicekeyworker.servicekeyworker import ServiceKeyWorker from util.morecollections import AttrDict diff --git a/workers/test/test_exportactionlogsworker.py b/workers/test/test_exportactionlogsworker.py index 237ddf5ae..f73a4d220 100644 --- a/workers/test/test_exportactionlogsworker.py +++ b/workers/test/test_exportactionlogsworker.py @@ -1,5 +1,6 @@ import json import os +import pytest from datetime import datetime, timedelta diff --git a/workers/test/test_logrotateworker.py b/workers/test/test_logrotateworker.py index dbe40a548..32b703c93 100644 --- a/workers/test/test_logrotateworker.py +++ b/workers/test/test_logrotateworker.py @@ -1,5 +1,6 @@ import os.path - +import pytest +from unittest.mock import patch from datetime import datetime, timedelta from app import storage diff --git a/workers/test/test_securityscanningnotificationworker.py b/workers/test/test_securityscanningnotificationworker.py index 0bfa8de7c..83796a1b3 100644 --- a/workers/test/test_securityscanningnotificationworker.py +++ b/workers/test/test_securityscanningnotificationworker.py @@ -1,4 +1,5 @@ import json +import os import pytest from urllib.parse import urlparse