diff --git a/data/billing.py b/data/billing.py index 6e5c9e2bf..bac9a436a 100644 --- a/data/billing.py +++ b/data/billing.py @@ -223,6 +223,7 @@ PLANS = [ "privateRepos": 5, "stripeId": "personal-2018", "rh_sku": "MW00584MO", + "sku_billing": False, "audience": "Individuals", "bus_features": False, "deprecated": False, @@ -235,6 +236,7 @@ PLANS = [ "price": 3000, "privateRepos": 10, "rh_sku": "MW00585MO", + "sku_billing": False, "stripeId": "bus-micro-2018", "audience": "For startups", "bus_features": True, @@ -248,6 +250,7 @@ PLANS = [ "price": 6000, "privateRepos": 20, "rh_sku": "MW00586MO", + "sku_billing": False, "stripeId": "bus-small-2018", "audience": "For small businesses", "bus_features": True, @@ -261,6 +264,7 @@ PLANS = [ "price": 12500, "privateRepos": 50, "rh_sku": "MW00587MO", + "sku_billing": False, "stripeId": "bus-medium-2018", "audience": "For normal businesses", "bus_features": True, @@ -274,6 +278,7 @@ PLANS = [ "price": 25000, "privateRepos": 125, "rh_sku": "MW00588MO", + "sku_billing": False, "stripeId": "bus-large-2018", "audience": "For large businesses", "bus_features": True, @@ -313,6 +318,7 @@ PLANS = [ "price": 160000, "privateRepos": 1000, "rh_sku": "MW00591MO", + "sku_billing": False, "stripeId": "bus-1000-2018", "audience": "For the SaaS savvy enterprise", "bus_features": True, @@ -326,6 +332,7 @@ PLANS = [ "price": 310000, "privateRepos": 2000, "rh_sku": "MW00592MO", + "sku_billing": False, "stripeId": "bus-2000-2018", "audience": "For the SaaS savvy big enterprise", "bus_features": True, @@ -346,9 +353,25 @@ PLANS = [ "superseded_by": None, "plans_page_hidden": False, }, + { + "title": "subscriptionwatch", + "privateRepos": 100, + "stripeId": "not_a_stripe_plan", + "rh_sku": "MW02701", + "sku_billing": True, + "plans_page_hidden": True, + }, ] -RH_SKUS = [plan["rh_sku"] for plan in PLANS if plan.get("rh_sku") is not None] +RH_SKUS = [ + plan["rh_sku"] for plan in PLANS if plan.get("rh_sku") is not None and plan.get("sku_billing") +] + +RECONCILER_SKUS = [ + plan["rh_sku"] + for plan in PLANS + if plan.get("rh_sku") is not None and not plan.get("sku_billing") +] def get_plan(plan_id): diff --git a/data/model/entitlements.py b/data/model/entitlements.py index aa09f1f4f..a08c1a272 100644 --- a/data/model/entitlements.py +++ b/data/model/entitlements.py @@ -6,7 +6,7 @@ from data.database import RedHatSubscriptions logger = logging.getLogger(__name__) -def get_ebs_account_number(user_id): +def get_web_customer_id(user_id): try: query = RedHatSubscriptions.get(RedHatSubscriptions.user_id == user_id).account_number return query @@ -14,8 +14,29 @@ def get_ebs_account_number(user_id): return None -def save_ebs_account_number(user, ebsAccountNumber): +def save_web_customer_id(user, web_customer_id): try: - return RedHatSubscriptions.create(user_id=user.id, account_number=ebsAccountNumber) + return RedHatSubscriptions.create(user_id=user.id, account_number=web_customer_id) except model.DataModelException as ex: logger.error("Problem saving account number for %s: %s", user.username, ex) + + +def update_web_customer_id(user, web_customer_id): + try: + query = RedHatSubscriptions.update( + {RedHatSubscriptions.account_number: web_customer_id} + ).where(RedHatSubscriptions.user_id == user.id) + query.execute() + except model.DataModelException as ex: + logger.error("Problem updating customer id for %s: %s", user.username, ex) + + +def remove_web_customer_id(user, web_customer_id): + try: + customer_id = RedHatSubscriptions.get( + RedHatSubscriptions.user_id == user.id, + RedHatSubscriptions.account_number == web_customer_id, + ) + return customer_id.delete_instance() + except model.DataModelException as ex: + logger.error("Problem removing customer id for %s: %s", user.username, ex) diff --git a/endpoints/api/test/test_superuser.py b/endpoints/api/test/test_superuser.py index c7b5d0412..b4d1d0d2e 100644 --- a/endpoints/api/test/test_superuser.py +++ b/endpoints/api/test/test_superuser.py @@ -1,5 +1,3 @@ -from test.fixtures import * - import pytest from data.database import DeletedNamespace, User @@ -10,6 +8,7 @@ from endpoints.api.superuser import ( ) from endpoints.api.test.shared import conduct_api_call from endpoints.test.shared import client_with_identity +from test.fixtures import * @pytest.mark.parametrize( @@ -32,7 +31,7 @@ def test_list_all_users(disabled, app): def test_list_all_orgs(app): with client_with_identity("devtable", app) as cl: result = conduct_api_call(cl, SuperUserOrganizationList, "GET", None, None, 200).json - assert len(result["organizations"]) == 5 + assert len(result["organizations"]) == 6 def test_paginate_orgs(app): @@ -45,7 +44,7 @@ def test_paginate_orgs(app): secondResult = conduct_api_call( cl, SuperUserOrganizationList, "GET", params, None, 200 ).json - assert len(secondResult["organizations"]) == 2 + assert len(secondResult["organizations"]) == 3 assert secondResult.get("next_page", None) is None @@ -57,7 +56,7 @@ def test_paginate_test_list_all_users(app): assert firstResult["next_page"] is not None params["next_page"] = firstResult["next_page"] secondResult = conduct_api_call(cl, SuperUserList, "GET", params, None, 200).json - assert len(secondResult["users"]) == 4 + assert len(secondResult["users"]) == 5 assert secondResult.get("next_page", None) is None diff --git a/initdb.py b/initdb.py index 0767f8613..456c6f97b 100644 --- a/initdb.py +++ b/initdb.py @@ -646,6 +646,12 @@ def populate_database(minimal=False): outside_org.verified = True outside_org.save() + subscriptionuser = model.user.create_user( + "subscription", "password", "subscriptions@devtable.com" + ) + subscriptionuser.verified = True + subscriptionuser.save() + model.notification.create_notification( "test_notification", new_user_1, @@ -925,6 +931,11 @@ def populate_database(minimal=False): ) thirdorg.save() + subscriptionsorg = model.organization.create_organization( + "subscriptionsorg", "quay+subscriptionsorg@devtable.com", subscriptionuser + ) + subscriptionsorg.save() + model.user.create_robot("coolrobot", org) proxyorg = model.organization.create_organization( diff --git a/static/js/directives/ui/plan-manager.js b/static/js/directives/ui/plan-manager.js index 071698534..4375fe63c 100644 --- a/static/js/directives/ui/plan-manager.js +++ b/static/js/directives/ui/plan-manager.js @@ -30,7 +30,7 @@ angular.module('quay').directive('planManager', function () { } // A plan is visible if it is not deprecated, or if it is the namespace's current plan. - if (plan['deprecated']) { + if (plan['deprecated'] || plan['plans_page_hidden']) { return subscribedPlan && plan.stripeId === subscribedPlan.stripeId; } @@ -41,7 +41,7 @@ angular.module('quay').directive('planManager', function () { if (!subscribedPlan) { return false; } - + return plan.stripeId === subscribedPlan.stripeId; }; @@ -122,4 +122,3 @@ angular.module('quay').directive('planManager', function () { }; return directiveDefinitionObject; }); - diff --git a/test/test_api_usage.py b/test/test_api_usage.py index d84b385bd..1a995e6cd 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -8,7 +8,6 @@ import time import unittest from calendar import timegm from contextlib import contextmanager -from test.helpers import assert_action_logged, check_transitive_modifications from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from cryptography.hazmat.backends import default_backend @@ -144,6 +143,7 @@ from endpoints.api.user import ( from endpoints.building import PreparedBuild from endpoints.webhooks import webhooks from initdb import finished_database_for_testing, setup_database_for_testing +from test.helpers import assert_action_logged, check_transitive_modifications from util.morecollections import AttrDict from util.secscan.v4.fake import fake_security_scanner @@ -176,6 +176,9 @@ ORG_REPO = "orgrepo" ORGANIZATION = "buynlarge" +SUBSCRIPTION_USER = "subscription" +SUBSCRIPTION_ORG = "subscriptionsorg" + NEW_USER_DETAILS = { "username": "bobby", "password": "password", @@ -5069,57 +5072,57 @@ class TestSuperUserManagement(ApiTestCase): class TestOrganizationRhSku(ApiTestCase): def test_bind_sku_to_org(self): - self.login(ADMIN_ACCESS_USER) + self.login(SUBSCRIPTION_USER) self.postResponse( resource_name=OrganizationRhSku, - params=dict(orgname=ORGANIZATION), - data={"subscription_id": 12345}, + params=dict(orgname=SUBSCRIPTION_ORG), + data={"subscription_id": 12345678}, expected_code=201, ) json = self.getJsonResponse( resource_name=OrganizationRhSku, - params=dict(orgname=ORGANIZATION), + params=dict(orgname=SUBSCRIPTION_ORG), ) self.assertEqual(len(json), 1) def test_bind_sku_duplicate(self): - user = model.user.get_user(ADMIN_ACCESS_USER) - org = model.organization.get_organization(ORGANIZATION) - model.organization_skus.bind_subscription_to_org(12345, org.id, user.id) - self.login(ADMIN_ACCESS_USER) + user = model.user.get_user(SUBSCRIPTION_USER) + org = model.organization.get_organization(SUBSCRIPTION_ORG) + model.organization_skus.bind_subscription_to_org(12345678, org.id, user.id) + self.login(SUBSCRIPTION_USER) self.postResponse( resource_name=OrganizationRhSku, - params=dict(orgname=ORGANIZATION), - data={"subscription_id": 12345}, + params=dict(orgname=SUBSCRIPTION_ORG), + data={"subscription_id": 12345678}, expected_code=400, ) def test_bind_sku_unauthorized(self): # bind a sku that user does not own - self.login(ADMIN_ACCESS_USER) + self.login(SUBSCRIPTION_USER) self.postResponse( resource_name=OrganizationRhSku, - params=dict(orgname=ORGANIZATION), - data={"subscription_id": 11111}, + params=dict(orgname=SUBSCRIPTION_ORG), + data={"subscription_id": 11111111}, expected_code=401, ) def test_remove_sku_from_org(self): - self.login(ADMIN_ACCESS_USER) + self.login(SUBSCRIPTION_USER) self.postResponse( resource_name=OrganizationRhSku, - params=dict(orgname=ORGANIZATION), - data={"subscription_id": 12345}, + params=dict(orgname=SUBSCRIPTION_ORG), + data={"subscription_id": 12345678}, expected_code=201, ) self.deleteResponse( resource_name=OrganizationRhSkuSubscriptionField, - params=dict(orgname=ORGANIZATION, subscription_id=12345), + params=dict(orgname=SUBSCRIPTION_ORG, subscription_id=12345678), expected_code=204, ) json = self.getJsonResponse( resource_name=OrganizationRhSku, - params=dict(orgname=ORGANIZATION), + params=dict(orgname=SUBSCRIPTION_ORG), ) self.assertEqual(len(json), 0) diff --git a/util/marketplace.py b/util/marketplace.py index ff573dfd1..4fd853cb7 100644 --- a/util/marketplace.py +++ b/util/marketplace.py @@ -22,12 +22,12 @@ class RedHatUserApi(object): def get_account_number(self, user): email = user.email - account_number = entitlements.get_ebs_account_number(user.id) + account_number = entitlements.get_web_customer_id(user.id) if account_number is None: account_number = self.lookup_customer_id(email) if account_number: # store in database for next lookup - entitlements.save_ebs_account_number(user, account_number) + entitlements.save_web_customer_id(user, account_number) return account_number def lookup_customer_id(self, email): @@ -70,10 +70,8 @@ class RedHatUserApi(object): return None for account in info: if account["accountRelationships"][0]["account"]["type"] == "person": - account_number = account["accountRelationships"][0]["account"].get( - "ebsAccountNumber" - ) - return account_number + customer_id = account["accountRelationships"][0]["account"].get("id") + return customer_id return None @@ -84,15 +82,15 @@ class RedHatSubscriptionApi(object): "ENTITLEMENT_RECONCILIATION_MARKETPLACE_ENDPOINT" ) - def lookup_subscription(self, ebsAccountNumber, skuId): + def lookup_subscription(self, webCustomerId, skuId): """ Use internal marketplace API to find subscription for customerId and sku """ logger.debug( - "looking up subscription sku %s for account %s", str(skuId), str(ebsAccountNumber) + "looking up subscription sku %s for account %s", str(skuId), str(webCustomerId) ) - subscriptions_url = f"{self.marketplace_endpoint}/subscription/v5/search/criteria;sku={skuId};web_customer_id={ebsAccountNumber}" + subscriptions_url = f"{self.marketplace_endpoint}/subscription/v5/search/criteria;sku={skuId};web_customer_id={webCustomerId}" request_headers = {"Content-Type": "application/json"} # Using CustomerID to get active subscription for user @@ -225,39 +223,72 @@ class RedHatSubscriptionApi(object): """ subscription_list = [] for sku in RH_SKUS: - user_subscription = self.lookup_subscription(account_number, sku) - if user_subscription is not None: - bound_to_org = organization_skus.subscription_bound_to_org(user_subscription["id"]) + subscriptions = self.lookup_subscription(account_number, sku) + if subscriptions: + for user_subscription in subscriptions: + if user_subscription is not None: + bound_to_org = organization_skus.subscription_bound_to_org( + user_subscription["id"] + ) - if filter_out_org_bindings and bound_to_org[0]: - continue + if filter_out_org_bindings and bound_to_org[0]: + continue - if convert_to_stripe_plans: - subscription_list.append(get_plan_using_rh_sku(sku)) - else: - # add in sku field for convenience - user_subscription["sku"] = sku - subscription_list.append(user_subscription) + if convert_to_stripe_plans: + subscription_list.append(get_plan_using_rh_sku(sku)) + else: + # add in sku field for convenience + user_subscription["sku"] = sku + subscription_list.append(user_subscription) return subscription_list TEST_USER = { "account_number": 12345, - "email": "test_user@test.com", - "username": "test_user", - "password": "password", + "email": "subscriptions@devtable.com", + "username": "subscription", + "subscriptions": [ + { + "id": 12345678, + "masterEndSystemName": "Quay", + "createdEndSystemName": "SUBSCRIPTION", + "createdDate": 1675957362000, + "lastUpdateEndSystemName": "SUBSCRIPTION", + "lastUpdateDate": 1675957362000, + "installBaseStartDate": 1707368400000, + "installBaseEndDate": 1707368399000, + "webCustomerId": 123456, + "subscriptionNumber": "12399889", + "quantity": 1, + "effectiveStartDate": 1707368400000, + "effectiveEndDate": 3813177600, + }, + { + "id": 11223344, + "masterEndSystemName": "Quay", + "createdEndSystemName": "SUBSCRIPTION", + "createdDate": 1675957362000, + "lastUpdateEndSystemName": "SUBSCRIPTION", + "lastUpdateDate": 1675957362000, + "installBaseStartDate": 1707368400000, + "installBaseEndDate": 1707368399000, + "webCustomerId": 123456, + "subscriptionNumber": "12399889", + "quantity": 1, + "effectiveStartDate": 1707368400000, + "effectiveEndDate": 3813177600, + }, + ], } +STRIPE_USER = {"account_number": 11111, "email": "stripe_user@test.com", "username": "stripe_user"} FREE_USER = { "account_number": 23456, "email": "free_user@test.com", "username": "free_user", - "password": "password", } -DEV_ACCOUNT_NUMBER = 76543 - -class FakeUserApi(object): +class FakeUserApi(RedHatUserApi): """ Fake class used for tests """ @@ -267,15 +298,12 @@ class FakeUserApi(object): return TEST_USER["account_number"] if email == FREE_USER["email"]: return FREE_USER["account_number"] + if email == STRIPE_USER["email"]: + return STRIPE_USER["account_number"] return None - def get_account_number(self, user): - if user.username == "devtable": - return DEV_ACCOUNT_NUMBER - return self.lookup_customer_id(user.email) - -class FakeSubscriptionApi(object): +class FakeSubscriptionApi(RedHatSubscriptionApi): """ Fake class used for tests """ @@ -285,6 +313,8 @@ class FakeSubscriptionApi(object): self.subscription_created = False def lookup_subscription(self, customer_id, sku_id): + if customer_id == TEST_USER["account_number"] and sku_id == "MW02701": + return TEST_USER["subscriptions"] return None def create_entitlement(self, customer_id, sku_id): @@ -294,24 +324,12 @@ class FakeSubscriptionApi(object): self.subscription_extended = True def get_subscription_sku(self, subscription_id): - if id == 12345: - return "FakeSku" + valid_ids = [subscription["id"] for subscription in TEST_USER["subscriptions"]] + if subscription_id in valid_ids: + return "MW02701" else: return None - def get_list_of_subscriptions( - self, account_number, filter_out_org_bindings=False, convert_to_stripe_plans=False - ): - if account_number == DEV_ACCOUNT_NUMBER: - return [ - { - "id": 12345, - "sku": "FakeSku", - "privateRepos": 0, - } - ] - return [] - class MarketplaceUserApi(object): def __init__(self, app=None): @@ -323,10 +341,13 @@ class MarketplaceUserApi(object): def init_app(self, app): marketplace_enabled = app.config.get("FEATURE_RH_MARKETPLACE", False) + reconciler_enabled = app.config.get("ENTITLEMENT_RECONCILIATION", False) - marketplace_user_api = FakeUserApi() + use_rh_api = marketplace_enabled or reconciler_enabled - if marketplace_enabled and not app.config.get("TESTING"): + marketplace_user_api = FakeUserApi(app.config) + + if use_rh_api and not app.config.get("TESTING"): marketplace_user_api = RedHatUserApi(app.config) app.extensions = getattr(app, "extensions", {}) @@ -346,11 +367,14 @@ class MarketplaceSubscriptionApi(object): self.state = None def init_app(self, app): + reconciler_enabled = app.config.get("ENTITLEMENT_RECONCILIATION", False) marketplace_enabled = app.config.get("FEATURE_RH_MARKETPLACE", False) + use_rh_api = marketplace_enabled or reconciler_enabled + marketplace_subscription_api = FakeSubscriptionApi() - if marketplace_enabled and not app.config.get("TESTING"): + if use_rh_api and not app.config.get("TESTING"): marketplace_subscription_api = RedHatSubscriptionApi(app.config) app.extensions = getattr(app, "extensions", {}) diff --git a/util/test/test_marketplace.py b/util/test/test_marketplace.py index b606a6e51..982b056d9 100644 --- a/util/test/test_marketplace.py +++ b/util/test/test_marketplace.py @@ -40,7 +40,7 @@ mocked_user_service_response = [ "startDate": "2022-09-20T14:31:09.974Z", "id": "fakeid", "account": { - "id": "fakeid", + "id": "000000000", "cdhPartyNumber": "0000000", "ebsAccountNumber": "1234567", "name": "Test User", @@ -119,7 +119,7 @@ class TestMarketplace: requests_mock.return_value.content = json.dumps(mocked_user_service_response) customer_id = user_api.lookup_customer_id("example@example.com") - assert customer_id == "1234567" + assert customer_id == "000000000" requests_mock.return_value.content = json.dumps(mocked_organization_only_response) customer_id = user_api.lookup_customer_id("example@example.com") diff --git a/workers/reconciliationworker.py b/workers/reconciliationworker.py index 872229a1a..e7fc1a9ec 100644 --- a/workers/reconciliationworker.py +++ b/workers/reconciliationworker.py @@ -7,9 +7,8 @@ from app import app from app import billing as stripe from app import marketplace_subscriptions, marketplace_users from data import model -from data.billing import RH_SKUS, get_plan +from data.billing import RECONCILER_SKUS, get_plan from data.model import entitlements -from util import marketplace from util.locking import GlobalLock, LockNotAcquiredException from workers.gunicorn_worker import GunicornWorker from workers.namespacegcworker import LOCK_TIMEOUT_PADDING @@ -48,21 +47,38 @@ class ReconciliationWorker(Worker): for user in stripe_users: email = user.email - ebsAccountNumber = entitlements.get_ebs_account_number(user.id) + model_customer_id = entitlements.get_web_customer_id(user.id) logger.debug( - "Database returned %s account number for %s", str(ebsAccountNumber), user.username + "Database returned %s customer id for %s", str(model_customer_id), user.username ) - # go to user api if no ebsAccountNumber is found - if ebsAccountNumber is None: - logger.debug("Looking up ebsAccountNumber for email %s", email) - ebsAccountNumber = user_api.lookup_customer_id(email) - logger.debug("Found %s number for %s", str(ebsAccountNumber), user.username) - if ebsAccountNumber: - entitlements.save_ebs_account_number(user, ebsAccountNumber) + # check against user api + customer_id = user_api.lookup_customer_id(email) + logger.debug("Found %s number for %s", str(customer_id), email) + + if model_customer_id is None and customer_id: + logger.debug("Saving new customer id %s for %s", customer_id, user.username) + entitlements.save_web_customer_id(user, customer_id) + elif model_customer_id != customer_id: + # what is in the database differs from the service + # take the service and store in the database instead + if customer_id: + logger.debug( + "Reconciled differing ids for %s, changing from %s to %s", + user.username, + model_customer_id, + customer_id, + ) + entitlements.update_web_customer_id(user, customer_id) else: - logger.debug("User %s does not have an account number", user.username) - continue + # user does not have a web customer id from api and should be removed from table + logger.debug( + "Removing conflicting id %s for %s", model_customer_id, user.username + ) + entitlements.remove_web_customer_id(user, model_customer_id) + elif customer_id is None: + logger.debug("User %s does not have an account number", user.username) + continue # check if we need to create a subscription for customer in RH marketplace try: @@ -73,15 +89,19 @@ class ReconciliationWorker(Worker): except stripe.error.InvalidRequestError: logger.warn("Invalid request for stripe_id %s", user.stripe_id) continue - for sku_id in RH_SKUS: + for sku_id in RECONCILER_SKUS: if stripe_customer.subscription: plan = get_plan(stripe_customer.subscription.plan.id) if plan is None: continue if plan.get("rh_sku") == sku_id: - subscription = marketplace_api.lookup_subscription(ebsAccountNumber, sku_id) + subscription = marketplace_api.lookup_subscription(customer_id, sku_id) if subscription is None: - marketplace_api.create_entitlement(ebsAccountNumber, sku_id) + logger.debug("Found %s to create for %s", sku_id, user.username) + marketplace_api.create_entitlement(customer_id, sku_id) + break + else: + logger.debug("User %s does not have a stripe subscription", user.username) logger.debug("Finished work for user %s", user.username) diff --git a/workers/test/test_reconciliationworker.py b/workers/test/test_reconciliationworker.py index c4c785a7b..bec7c4f8f 100644 --- a/workers/test/test_reconciliationworker.py +++ b/workers/test/test_reconciliationworker.py @@ -3,34 +3,21 @@ import string from unittest.mock import patch from app import billing as stripe +from app import marketplace_subscriptions, marketplace_users from data import model from test.fixtures import * -from util.marketplace import FakeSubscriptionApi, FakeUserApi from workers.reconciliationworker import ReconciliationWorker -user_api = FakeUserApi() -marketplace_api = FakeSubscriptionApi() worker = ReconciliationWorker() -def test_create_for_stripe_user(initialized_db): - - test_user = model.user.create_user("test_user", "password", "test_user@test.com") - test_user.stripe_id = "cus_" + "".join(random.choices(string.ascii_lowercase, k=14)) - test_user.save() - with patch.object(marketplace_api, "create_entitlement") as mock: - worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api) - - mock.assert_called() - - def test_skip_free_user(initialized_db): free_user = model.user.create_user("free_user", "password", "free_user@test.com") free_user.save() - with patch.object(marketplace_api, "create_entitlement") as mock: - worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api) + with patch.object(marketplace_subscriptions, "create_entitlement") as mock: + worker._perform_reconciliation(marketplace_users, marketplace_subscriptions) mock.assert_not_called() @@ -38,7 +25,38 @@ def test_skip_free_user(initialized_db): def test_exception_handling(initialized_db): with patch("data.billing.FakeStripe.Customer.retrieve") as mock: mock.side_effect = stripe.error.InvalidRequestException - worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api) + worker._perform_reconciliation(marketplace_users, marketplace_subscriptions) with patch("data.billing.FakeStripe.Customer.retrieve") as mock: mock.side_effect = stripe.error.APIConnectionError - worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api) + worker._perform_reconciliation(marketplace_users, marketplace_subscriptions) + + +def test_create_for_stripe_user(initialized_db): + + test_user = model.user.create_user("stripe_user", "password", "stripe_user@test.com") + test_user.stripe_id = "cus_" + "".join(random.choices(string.ascii_lowercase, k=14)) + test_user.save() + with patch.object(marketplace_subscriptions, "create_entitlement") as mock: + worker._perform_reconciliation(marketplace_users, marketplace_subscriptions) + + # expect that entitlment is created with customer id number + mock.assert_called_with(model.entitlements.get_web_customer_id(test_user.id), "FakeSKU") + + +def test_reconcile_different_ids(initialized_db): + test_user = model.user.create_user("stripe_user", "password", "stripe_user@test.com") + test_user.stripe_id = "cus_" + "".join(random.choices(string.ascii_lowercase, k=14)) + test_user.save() + model.entitlements.save_web_customer_id(test_user, 12345) + + worker._perform_reconciliation(marketplace_users, marketplace_subscriptions) + + new_id = model.entitlements.get_web_customer_id(test_user.id) + assert new_id != 12345 + assert new_id == marketplace_users.lookup_customer_id(test_user.email) + + # make sure it will remove account numbers from db that do not belong + with patch.object(marketplace_users, "lookup_customer_id") as mock: + mock.return_value = None + worker._perform_reconciliation(marketplace_users, marketplace_subscriptions) + assert model.entitlements.get_web_customer_id(test_user.id) is None