1
0
mirror of https://github.com/quay/quay.git synced 2025-07-30 07:43:13 +03:00

marketplace: update reconciliationworker to use webCustomerId instead of ebsAccountNumber (PROJQUAY-233) (#2582)

* update reconciliationworker to use webCustomerId instead of
ebsAccountNumber

* fix reconciler where it was incorrectly using the ebsAccountNumber to
  create subscriptions
* add job to reconciler so that it reconciles different ids between the
  database and the user api
* separate skus to be used by billing and skus to be used by reconciler
This commit is contained in:
Marcus Kok
2024-01-05 16:15:37 -05:00
committed by GitHub
parent 7357e317d6
commit 1c893baba5
10 changed files with 236 additions and 118 deletions

View File

@ -223,6 +223,7 @@ PLANS = [
"privateRepos": 5, "privateRepos": 5,
"stripeId": "personal-2018", "stripeId": "personal-2018",
"rh_sku": "MW00584MO", "rh_sku": "MW00584MO",
"sku_billing": False,
"audience": "Individuals", "audience": "Individuals",
"bus_features": False, "bus_features": False,
"deprecated": False, "deprecated": False,
@ -235,6 +236,7 @@ PLANS = [
"price": 3000, "price": 3000,
"privateRepos": 10, "privateRepos": 10,
"rh_sku": "MW00585MO", "rh_sku": "MW00585MO",
"sku_billing": False,
"stripeId": "bus-micro-2018", "stripeId": "bus-micro-2018",
"audience": "For startups", "audience": "For startups",
"bus_features": True, "bus_features": True,
@ -248,6 +250,7 @@ PLANS = [
"price": 6000, "price": 6000,
"privateRepos": 20, "privateRepos": 20,
"rh_sku": "MW00586MO", "rh_sku": "MW00586MO",
"sku_billing": False,
"stripeId": "bus-small-2018", "stripeId": "bus-small-2018",
"audience": "For small businesses", "audience": "For small businesses",
"bus_features": True, "bus_features": True,
@ -261,6 +264,7 @@ PLANS = [
"price": 12500, "price": 12500,
"privateRepos": 50, "privateRepos": 50,
"rh_sku": "MW00587MO", "rh_sku": "MW00587MO",
"sku_billing": False,
"stripeId": "bus-medium-2018", "stripeId": "bus-medium-2018",
"audience": "For normal businesses", "audience": "For normal businesses",
"bus_features": True, "bus_features": True,
@ -274,6 +278,7 @@ PLANS = [
"price": 25000, "price": 25000,
"privateRepos": 125, "privateRepos": 125,
"rh_sku": "MW00588MO", "rh_sku": "MW00588MO",
"sku_billing": False,
"stripeId": "bus-large-2018", "stripeId": "bus-large-2018",
"audience": "For large businesses", "audience": "For large businesses",
"bus_features": True, "bus_features": True,
@ -313,6 +318,7 @@ PLANS = [
"price": 160000, "price": 160000,
"privateRepos": 1000, "privateRepos": 1000,
"rh_sku": "MW00591MO", "rh_sku": "MW00591MO",
"sku_billing": False,
"stripeId": "bus-1000-2018", "stripeId": "bus-1000-2018",
"audience": "For the SaaS savvy enterprise", "audience": "For the SaaS savvy enterprise",
"bus_features": True, "bus_features": True,
@ -326,6 +332,7 @@ PLANS = [
"price": 310000, "price": 310000,
"privateRepos": 2000, "privateRepos": 2000,
"rh_sku": "MW00592MO", "rh_sku": "MW00592MO",
"sku_billing": False,
"stripeId": "bus-2000-2018", "stripeId": "bus-2000-2018",
"audience": "For the SaaS savvy big enterprise", "audience": "For the SaaS savvy big enterprise",
"bus_features": True, "bus_features": True,
@ -346,9 +353,25 @@ PLANS = [
"superseded_by": None, "superseded_by": None,
"plans_page_hidden": False, "plans_page_hidden": False,
}, },
{
"title": "subscriptionwatch",
"privateRepos": 100,
"stripeId": "not_a_stripe_plan",
"rh_sku": "MW02701",
"sku_billing": True,
"plans_page_hidden": True,
},
] ]
RH_SKUS = [plan["rh_sku"] for plan in PLANS if plan.get("rh_sku") is not None] RH_SKUS = [
plan["rh_sku"] for plan in PLANS if plan.get("rh_sku") is not None and plan.get("sku_billing")
]
RECONCILER_SKUS = [
plan["rh_sku"]
for plan in PLANS
if plan.get("rh_sku") is not None and not plan.get("sku_billing")
]
def get_plan(plan_id): def get_plan(plan_id):

View File

@ -6,7 +6,7 @@ from data.database import RedHatSubscriptions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_ebs_account_number(user_id): def get_web_customer_id(user_id):
try: try:
query = RedHatSubscriptions.get(RedHatSubscriptions.user_id == user_id).account_number query = RedHatSubscriptions.get(RedHatSubscriptions.user_id == user_id).account_number
return query return query
@ -14,8 +14,29 @@ def get_ebs_account_number(user_id):
return None return None
def save_ebs_account_number(user, ebsAccountNumber): def save_web_customer_id(user, web_customer_id):
try: try:
return RedHatSubscriptions.create(user_id=user.id, account_number=ebsAccountNumber) return RedHatSubscriptions.create(user_id=user.id, account_number=web_customer_id)
except model.DataModelException as ex: except model.DataModelException as ex:
logger.error("Problem saving account number for %s: %s", user.username, ex) logger.error("Problem saving account number for %s: %s", user.username, ex)
def update_web_customer_id(user, web_customer_id):
try:
query = RedHatSubscriptions.update(
{RedHatSubscriptions.account_number: web_customer_id}
).where(RedHatSubscriptions.user_id == user.id)
query.execute()
except model.DataModelException as ex:
logger.error("Problem updating customer id for %s: %s", user.username, ex)
def remove_web_customer_id(user, web_customer_id):
try:
customer_id = RedHatSubscriptions.get(
RedHatSubscriptions.user_id == user.id,
RedHatSubscriptions.account_number == web_customer_id,
)
return customer_id.delete_instance()
except model.DataModelException as ex:
logger.error("Problem removing customer id for %s: %s", user.username, ex)

View File

@ -1,5 +1,3 @@
from test.fixtures import *
import pytest import pytest
from data.database import DeletedNamespace, User from data.database import DeletedNamespace, User
@ -10,6 +8,7 @@ from endpoints.api.superuser import (
) )
from endpoints.api.test.shared import conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.test.shared import client_with_identity from endpoints.test.shared import client_with_identity
from test.fixtures import *
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -32,7 +31,7 @@ def test_list_all_users(disabled, app):
def test_list_all_orgs(app): def test_list_all_orgs(app):
with client_with_identity("devtable", app) as cl: with client_with_identity("devtable", app) as cl:
result = conduct_api_call(cl, SuperUserOrganizationList, "GET", None, None, 200).json result = conduct_api_call(cl, SuperUserOrganizationList, "GET", None, None, 200).json
assert len(result["organizations"]) == 5 assert len(result["organizations"]) == 6
def test_paginate_orgs(app): def test_paginate_orgs(app):
@ -45,7 +44,7 @@ def test_paginate_orgs(app):
secondResult = conduct_api_call( secondResult = conduct_api_call(
cl, SuperUserOrganizationList, "GET", params, None, 200 cl, SuperUserOrganizationList, "GET", params, None, 200
).json ).json
assert len(secondResult["organizations"]) == 2 assert len(secondResult["organizations"]) == 3
assert secondResult.get("next_page", None) is None assert secondResult.get("next_page", None) is None
@ -57,7 +56,7 @@ def test_paginate_test_list_all_users(app):
assert firstResult["next_page"] is not None assert firstResult["next_page"] is not None
params["next_page"] = firstResult["next_page"] params["next_page"] = firstResult["next_page"]
secondResult = conduct_api_call(cl, SuperUserList, "GET", params, None, 200).json secondResult = conduct_api_call(cl, SuperUserList, "GET", params, None, 200).json
assert len(secondResult["users"]) == 4 assert len(secondResult["users"]) == 5
assert secondResult.get("next_page", None) is None assert secondResult.get("next_page", None) is None

View File

@ -646,6 +646,12 @@ def populate_database(minimal=False):
outside_org.verified = True outside_org.verified = True
outside_org.save() outside_org.save()
subscriptionuser = model.user.create_user(
"subscription", "password", "subscriptions@devtable.com"
)
subscriptionuser.verified = True
subscriptionuser.save()
model.notification.create_notification( model.notification.create_notification(
"test_notification", "test_notification",
new_user_1, new_user_1,
@ -925,6 +931,11 @@ def populate_database(minimal=False):
) )
thirdorg.save() thirdorg.save()
subscriptionsorg = model.organization.create_organization(
"subscriptionsorg", "quay+subscriptionsorg@devtable.com", subscriptionuser
)
subscriptionsorg.save()
model.user.create_robot("coolrobot", org) model.user.create_robot("coolrobot", org)
proxyorg = model.organization.create_organization( proxyorg = model.organization.create_organization(

View File

@ -30,7 +30,7 @@ angular.module('quay').directive('planManager', function () {
} }
// A plan is visible if it is not deprecated, or if it is the namespace's current plan. // A plan is visible if it is not deprecated, or if it is the namespace's current plan.
if (plan['deprecated']) { if (plan['deprecated'] || plan['plans_page_hidden']) {
return subscribedPlan && plan.stripeId === subscribedPlan.stripeId; return subscribedPlan && plan.stripeId === subscribedPlan.stripeId;
} }
@ -122,4 +122,3 @@ angular.module('quay').directive('planManager', function () {
}; };
return directiveDefinitionObject; return directiveDefinitionObject;
}); });

View File

@ -8,7 +8,6 @@ import time
import unittest import unittest
from calendar import timegm from calendar import timegm
from contextlib import contextmanager from contextlib import contextmanager
from test.helpers import assert_action_logged, check_transitive_modifications
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
@ -144,6 +143,7 @@ from endpoints.api.user import (
from endpoints.building import PreparedBuild from endpoints.building import PreparedBuild
from endpoints.webhooks import webhooks from endpoints.webhooks import webhooks
from initdb import finished_database_for_testing, setup_database_for_testing from initdb import finished_database_for_testing, setup_database_for_testing
from test.helpers import assert_action_logged, check_transitive_modifications
from util.morecollections import AttrDict from util.morecollections import AttrDict
from util.secscan.v4.fake import fake_security_scanner from util.secscan.v4.fake import fake_security_scanner
@ -176,6 +176,9 @@ ORG_REPO = "orgrepo"
ORGANIZATION = "buynlarge" ORGANIZATION = "buynlarge"
SUBSCRIPTION_USER = "subscription"
SUBSCRIPTION_ORG = "subscriptionsorg"
NEW_USER_DETAILS = { NEW_USER_DETAILS = {
"username": "bobby", "username": "bobby",
"password": "password", "password": "password",
@ -5069,57 +5072,57 @@ class TestSuperUserManagement(ApiTestCase):
class TestOrganizationRhSku(ApiTestCase): class TestOrganizationRhSku(ApiTestCase):
def test_bind_sku_to_org(self): def test_bind_sku_to_org(self):
self.login(ADMIN_ACCESS_USER) self.login(SUBSCRIPTION_USER)
self.postResponse( self.postResponse(
resource_name=OrganizationRhSku, resource_name=OrganizationRhSku,
params=dict(orgname=ORGANIZATION), params=dict(orgname=SUBSCRIPTION_ORG),
data={"subscription_id": 12345}, data={"subscription_id": 12345678},
expected_code=201, expected_code=201,
) )
json = self.getJsonResponse( json = self.getJsonResponse(
resource_name=OrganizationRhSku, resource_name=OrganizationRhSku,
params=dict(orgname=ORGANIZATION), params=dict(orgname=SUBSCRIPTION_ORG),
) )
self.assertEqual(len(json), 1) self.assertEqual(len(json), 1)
def test_bind_sku_duplicate(self): def test_bind_sku_duplicate(self):
user = model.user.get_user(ADMIN_ACCESS_USER) user = model.user.get_user(SUBSCRIPTION_USER)
org = model.organization.get_organization(ORGANIZATION) org = model.organization.get_organization(SUBSCRIPTION_ORG)
model.organization_skus.bind_subscription_to_org(12345, org.id, user.id) model.organization_skus.bind_subscription_to_org(12345678, org.id, user.id)
self.login(ADMIN_ACCESS_USER) self.login(SUBSCRIPTION_USER)
self.postResponse( self.postResponse(
resource_name=OrganizationRhSku, resource_name=OrganizationRhSku,
params=dict(orgname=ORGANIZATION), params=dict(orgname=SUBSCRIPTION_ORG),
data={"subscription_id": 12345}, data={"subscription_id": 12345678},
expected_code=400, expected_code=400,
) )
def test_bind_sku_unauthorized(self): def test_bind_sku_unauthorized(self):
# bind a sku that user does not own # bind a sku that user does not own
self.login(ADMIN_ACCESS_USER) self.login(SUBSCRIPTION_USER)
self.postResponse( self.postResponse(
resource_name=OrganizationRhSku, resource_name=OrganizationRhSku,
params=dict(orgname=ORGANIZATION), params=dict(orgname=SUBSCRIPTION_ORG),
data={"subscription_id": 11111}, data={"subscription_id": 11111111},
expected_code=401, expected_code=401,
) )
def test_remove_sku_from_org(self): def test_remove_sku_from_org(self):
self.login(ADMIN_ACCESS_USER) self.login(SUBSCRIPTION_USER)
self.postResponse( self.postResponse(
resource_name=OrganizationRhSku, resource_name=OrganizationRhSku,
params=dict(orgname=ORGANIZATION), params=dict(orgname=SUBSCRIPTION_ORG),
data={"subscription_id": 12345}, data={"subscription_id": 12345678},
expected_code=201, expected_code=201,
) )
self.deleteResponse( self.deleteResponse(
resource_name=OrganizationRhSkuSubscriptionField, resource_name=OrganizationRhSkuSubscriptionField,
params=dict(orgname=ORGANIZATION, subscription_id=12345), params=dict(orgname=SUBSCRIPTION_ORG, subscription_id=12345678),
expected_code=204, expected_code=204,
) )
json = self.getJsonResponse( json = self.getJsonResponse(
resource_name=OrganizationRhSku, resource_name=OrganizationRhSku,
params=dict(orgname=ORGANIZATION), params=dict(orgname=SUBSCRIPTION_ORG),
) )
self.assertEqual(len(json), 0) self.assertEqual(len(json), 0)

View File

@ -22,12 +22,12 @@ class RedHatUserApi(object):
def get_account_number(self, user): def get_account_number(self, user):
email = user.email email = user.email
account_number = entitlements.get_ebs_account_number(user.id) account_number = entitlements.get_web_customer_id(user.id)
if account_number is None: if account_number is None:
account_number = self.lookup_customer_id(email) account_number = self.lookup_customer_id(email)
if account_number: if account_number:
# store in database for next lookup # store in database for next lookup
entitlements.save_ebs_account_number(user, account_number) entitlements.save_web_customer_id(user, account_number)
return account_number return account_number
def lookup_customer_id(self, email): def lookup_customer_id(self, email):
@ -70,10 +70,8 @@ class RedHatUserApi(object):
return None return None
for account in info: for account in info:
if account["accountRelationships"][0]["account"]["type"] == "person": if account["accountRelationships"][0]["account"]["type"] == "person":
account_number = account["accountRelationships"][0]["account"].get( customer_id = account["accountRelationships"][0]["account"].get("id")
"ebsAccountNumber" return customer_id
)
return account_number
return None return None
@ -84,15 +82,15 @@ class RedHatSubscriptionApi(object):
"ENTITLEMENT_RECONCILIATION_MARKETPLACE_ENDPOINT" "ENTITLEMENT_RECONCILIATION_MARKETPLACE_ENDPOINT"
) )
def lookup_subscription(self, ebsAccountNumber, skuId): def lookup_subscription(self, webCustomerId, skuId):
""" """
Use internal marketplace API to find subscription for customerId and sku Use internal marketplace API to find subscription for customerId and sku
""" """
logger.debug( logger.debug(
"looking up subscription sku %s for account %s", str(skuId), str(ebsAccountNumber) "looking up subscription sku %s for account %s", str(skuId), str(webCustomerId)
) )
subscriptions_url = f"{self.marketplace_endpoint}/subscription/v5/search/criteria;sku={skuId};web_customer_id={ebsAccountNumber}" subscriptions_url = f"{self.marketplace_endpoint}/subscription/v5/search/criteria;sku={skuId};web_customer_id={webCustomerId}"
request_headers = {"Content-Type": "application/json"} request_headers = {"Content-Type": "application/json"}
# Using CustomerID to get active subscription for user # Using CustomerID to get active subscription for user
@ -225,9 +223,13 @@ class RedHatSubscriptionApi(object):
""" """
subscription_list = [] subscription_list = []
for sku in RH_SKUS: for sku in RH_SKUS:
user_subscription = self.lookup_subscription(account_number, sku) subscriptions = self.lookup_subscription(account_number, sku)
if subscriptions:
for user_subscription in subscriptions:
if user_subscription is not None: if user_subscription is not None:
bound_to_org = organization_skus.subscription_bound_to_org(user_subscription["id"]) bound_to_org = organization_skus.subscription_bound_to_org(
user_subscription["id"]
)
if filter_out_org_bindings and bound_to_org[0]: if filter_out_org_bindings and bound_to_org[0]:
continue continue
@ -243,21 +245,50 @@ class RedHatSubscriptionApi(object):
TEST_USER = { TEST_USER = {
"account_number": 12345, "account_number": 12345,
"email": "test_user@test.com", "email": "subscriptions@devtable.com",
"username": "test_user", "username": "subscription",
"password": "password", "subscriptions": [
{
"id": 12345678,
"masterEndSystemName": "Quay",
"createdEndSystemName": "SUBSCRIPTION",
"createdDate": 1675957362000,
"lastUpdateEndSystemName": "SUBSCRIPTION",
"lastUpdateDate": 1675957362000,
"installBaseStartDate": 1707368400000,
"installBaseEndDate": 1707368399000,
"webCustomerId": 123456,
"subscriptionNumber": "12399889",
"quantity": 1,
"effectiveStartDate": 1707368400000,
"effectiveEndDate": 3813177600,
},
{
"id": 11223344,
"masterEndSystemName": "Quay",
"createdEndSystemName": "SUBSCRIPTION",
"createdDate": 1675957362000,
"lastUpdateEndSystemName": "SUBSCRIPTION",
"lastUpdateDate": 1675957362000,
"installBaseStartDate": 1707368400000,
"installBaseEndDate": 1707368399000,
"webCustomerId": 123456,
"subscriptionNumber": "12399889",
"quantity": 1,
"effectiveStartDate": 1707368400000,
"effectiveEndDate": 3813177600,
},
],
} }
STRIPE_USER = {"account_number": 11111, "email": "stripe_user@test.com", "username": "stripe_user"}
FREE_USER = { FREE_USER = {
"account_number": 23456, "account_number": 23456,
"email": "free_user@test.com", "email": "free_user@test.com",
"username": "free_user", "username": "free_user",
"password": "password",
} }
DEV_ACCOUNT_NUMBER = 76543
class FakeUserApi(RedHatUserApi):
class FakeUserApi(object):
""" """
Fake class used for tests Fake class used for tests
""" """
@ -267,15 +298,12 @@ class FakeUserApi(object):
return TEST_USER["account_number"] return TEST_USER["account_number"]
if email == FREE_USER["email"]: if email == FREE_USER["email"]:
return FREE_USER["account_number"] return FREE_USER["account_number"]
if email == STRIPE_USER["email"]:
return STRIPE_USER["account_number"]
return None return None
def get_account_number(self, user):
if user.username == "devtable":
return DEV_ACCOUNT_NUMBER
return self.lookup_customer_id(user.email)
class FakeSubscriptionApi(RedHatSubscriptionApi):
class FakeSubscriptionApi(object):
""" """
Fake class used for tests Fake class used for tests
""" """
@ -285,6 +313,8 @@ class FakeSubscriptionApi(object):
self.subscription_created = False self.subscription_created = False
def lookup_subscription(self, customer_id, sku_id): def lookup_subscription(self, customer_id, sku_id):
if customer_id == TEST_USER["account_number"] and sku_id == "MW02701":
return TEST_USER["subscriptions"]
return None return None
def create_entitlement(self, customer_id, sku_id): def create_entitlement(self, customer_id, sku_id):
@ -294,24 +324,12 @@ class FakeSubscriptionApi(object):
self.subscription_extended = True self.subscription_extended = True
def get_subscription_sku(self, subscription_id): def get_subscription_sku(self, subscription_id):
if id == 12345: valid_ids = [subscription["id"] for subscription in TEST_USER["subscriptions"]]
return "FakeSku" if subscription_id in valid_ids:
return "MW02701"
else: else:
return None return None
def get_list_of_subscriptions(
self, account_number, filter_out_org_bindings=False, convert_to_stripe_plans=False
):
if account_number == DEV_ACCOUNT_NUMBER:
return [
{
"id": 12345,
"sku": "FakeSku",
"privateRepos": 0,
}
]
return []
class MarketplaceUserApi(object): class MarketplaceUserApi(object):
def __init__(self, app=None): def __init__(self, app=None):
@ -323,10 +341,13 @@ class MarketplaceUserApi(object):
def init_app(self, app): def init_app(self, app):
marketplace_enabled = app.config.get("FEATURE_RH_MARKETPLACE", False) marketplace_enabled = app.config.get("FEATURE_RH_MARKETPLACE", False)
reconciler_enabled = app.config.get("ENTITLEMENT_RECONCILIATION", False)
marketplace_user_api = FakeUserApi() use_rh_api = marketplace_enabled or reconciler_enabled
if marketplace_enabled and not app.config.get("TESTING"): marketplace_user_api = FakeUserApi(app.config)
if use_rh_api and not app.config.get("TESTING"):
marketplace_user_api = RedHatUserApi(app.config) marketplace_user_api = RedHatUserApi(app.config)
app.extensions = getattr(app, "extensions", {}) app.extensions = getattr(app, "extensions", {})
@ -346,11 +367,14 @@ class MarketplaceSubscriptionApi(object):
self.state = None self.state = None
def init_app(self, app): def init_app(self, app):
reconciler_enabled = app.config.get("ENTITLEMENT_RECONCILIATION", False)
marketplace_enabled = app.config.get("FEATURE_RH_MARKETPLACE", False) marketplace_enabled = app.config.get("FEATURE_RH_MARKETPLACE", False)
use_rh_api = marketplace_enabled or reconciler_enabled
marketplace_subscription_api = FakeSubscriptionApi() marketplace_subscription_api = FakeSubscriptionApi()
if marketplace_enabled and not app.config.get("TESTING"): if use_rh_api and not app.config.get("TESTING"):
marketplace_subscription_api = RedHatSubscriptionApi(app.config) marketplace_subscription_api = RedHatSubscriptionApi(app.config)
app.extensions = getattr(app, "extensions", {}) app.extensions = getattr(app, "extensions", {})

View File

@ -40,7 +40,7 @@ mocked_user_service_response = [
"startDate": "2022-09-20T14:31:09.974Z", "startDate": "2022-09-20T14:31:09.974Z",
"id": "fakeid", "id": "fakeid",
"account": { "account": {
"id": "fakeid", "id": "000000000",
"cdhPartyNumber": "0000000", "cdhPartyNumber": "0000000",
"ebsAccountNumber": "1234567", "ebsAccountNumber": "1234567",
"name": "Test User", "name": "Test User",
@ -119,7 +119,7 @@ class TestMarketplace:
requests_mock.return_value.content = json.dumps(mocked_user_service_response) requests_mock.return_value.content = json.dumps(mocked_user_service_response)
customer_id = user_api.lookup_customer_id("example@example.com") customer_id = user_api.lookup_customer_id("example@example.com")
assert customer_id == "1234567" assert customer_id == "000000000"
requests_mock.return_value.content = json.dumps(mocked_organization_only_response) requests_mock.return_value.content = json.dumps(mocked_organization_only_response)
customer_id = user_api.lookup_customer_id("example@example.com") customer_id = user_api.lookup_customer_id("example@example.com")

View File

@ -7,9 +7,8 @@ from app import app
from app import billing as stripe from app import billing as stripe
from app import marketplace_subscriptions, marketplace_users from app import marketplace_subscriptions, marketplace_users
from data import model from data import model
from data.billing import RH_SKUS, get_plan from data.billing import RECONCILER_SKUS, get_plan
from data.model import entitlements from data.model import entitlements
from util import marketplace
from util.locking import GlobalLock, LockNotAcquiredException from util.locking import GlobalLock, LockNotAcquiredException
from workers.gunicorn_worker import GunicornWorker from workers.gunicorn_worker import GunicornWorker
from workers.namespacegcworker import LOCK_TIMEOUT_PADDING from workers.namespacegcworker import LOCK_TIMEOUT_PADDING
@ -48,19 +47,36 @@ class ReconciliationWorker(Worker):
for user in stripe_users: for user in stripe_users:
email = user.email email = user.email
ebsAccountNumber = entitlements.get_ebs_account_number(user.id) model_customer_id = entitlements.get_web_customer_id(user.id)
logger.debug( logger.debug(
"Database returned %s account number for %s", str(ebsAccountNumber), user.username "Database returned %s customer id for %s", str(model_customer_id), user.username
) )
# go to user api if no ebsAccountNumber is found # check against user api
if ebsAccountNumber is None: customer_id = user_api.lookup_customer_id(email)
logger.debug("Looking up ebsAccountNumber for email %s", email) logger.debug("Found %s number for %s", str(customer_id), email)
ebsAccountNumber = user_api.lookup_customer_id(email)
logger.debug("Found %s number for %s", str(ebsAccountNumber), user.username) if model_customer_id is None and customer_id:
if ebsAccountNumber: logger.debug("Saving new customer id %s for %s", customer_id, user.username)
entitlements.save_ebs_account_number(user, ebsAccountNumber) entitlements.save_web_customer_id(user, customer_id)
elif model_customer_id != customer_id:
# what is in the database differs from the service
# take the service and store in the database instead
if customer_id:
logger.debug(
"Reconciled differing ids for %s, changing from %s to %s",
user.username,
model_customer_id,
customer_id,
)
entitlements.update_web_customer_id(user, customer_id)
else: else:
# user does not have a web customer id from api and should be removed from table
logger.debug(
"Removing conflicting id %s for %s", model_customer_id, user.username
)
entitlements.remove_web_customer_id(user, model_customer_id)
elif customer_id is None:
logger.debug("User %s does not have an account number", user.username) logger.debug("User %s does not have an account number", user.username)
continue continue
@ -73,15 +89,19 @@ class ReconciliationWorker(Worker):
except stripe.error.InvalidRequestError: except stripe.error.InvalidRequestError:
logger.warn("Invalid request for stripe_id %s", user.stripe_id) logger.warn("Invalid request for stripe_id %s", user.stripe_id)
continue continue
for sku_id in RH_SKUS: for sku_id in RECONCILER_SKUS:
if stripe_customer.subscription: if stripe_customer.subscription:
plan = get_plan(stripe_customer.subscription.plan.id) plan = get_plan(stripe_customer.subscription.plan.id)
if plan is None: if plan is None:
continue continue
if plan.get("rh_sku") == sku_id: if plan.get("rh_sku") == sku_id:
subscription = marketplace_api.lookup_subscription(ebsAccountNumber, sku_id) subscription = marketplace_api.lookup_subscription(customer_id, sku_id)
if subscription is None: if subscription is None:
marketplace_api.create_entitlement(ebsAccountNumber, sku_id) logger.debug("Found %s to create for %s", sku_id, user.username)
marketplace_api.create_entitlement(customer_id, sku_id)
break
else:
logger.debug("User %s does not have a stripe subscription", user.username)
logger.debug("Finished work for user %s", user.username) logger.debug("Finished work for user %s", user.username)

View File

@ -3,34 +3,21 @@ import string
from unittest.mock import patch from unittest.mock import patch
from app import billing as stripe from app import billing as stripe
from app import marketplace_subscriptions, marketplace_users
from data import model from data import model
from test.fixtures import * from test.fixtures import *
from util.marketplace import FakeSubscriptionApi, FakeUserApi
from workers.reconciliationworker import ReconciliationWorker from workers.reconciliationworker import ReconciliationWorker
user_api = FakeUserApi()
marketplace_api = FakeSubscriptionApi()
worker = ReconciliationWorker() worker = ReconciliationWorker()
def test_create_for_stripe_user(initialized_db):
test_user = model.user.create_user("test_user", "password", "test_user@test.com")
test_user.stripe_id = "cus_" + "".join(random.choices(string.ascii_lowercase, k=14))
test_user.save()
with patch.object(marketplace_api, "create_entitlement") as mock:
worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api)
mock.assert_called()
def test_skip_free_user(initialized_db): def test_skip_free_user(initialized_db):
free_user = model.user.create_user("free_user", "password", "free_user@test.com") free_user = model.user.create_user("free_user", "password", "free_user@test.com")
free_user.save() free_user.save()
with patch.object(marketplace_api, "create_entitlement") as mock: with patch.object(marketplace_subscriptions, "create_entitlement") as mock:
worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api) worker._perform_reconciliation(marketplace_users, marketplace_subscriptions)
mock.assert_not_called() mock.assert_not_called()
@ -38,7 +25,38 @@ def test_skip_free_user(initialized_db):
def test_exception_handling(initialized_db): def test_exception_handling(initialized_db):
with patch("data.billing.FakeStripe.Customer.retrieve") as mock: with patch("data.billing.FakeStripe.Customer.retrieve") as mock:
mock.side_effect = stripe.error.InvalidRequestException mock.side_effect = stripe.error.InvalidRequestException
worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api) worker._perform_reconciliation(marketplace_users, marketplace_subscriptions)
with patch("data.billing.FakeStripe.Customer.retrieve") as mock: with patch("data.billing.FakeStripe.Customer.retrieve") as mock:
mock.side_effect = stripe.error.APIConnectionError mock.side_effect = stripe.error.APIConnectionError
worker._perform_reconciliation(user_api=user_api, marketplace_api=marketplace_api) worker._perform_reconciliation(marketplace_users, marketplace_subscriptions)
def test_create_for_stripe_user(initialized_db):
test_user = model.user.create_user("stripe_user", "password", "stripe_user@test.com")
test_user.stripe_id = "cus_" + "".join(random.choices(string.ascii_lowercase, k=14))
test_user.save()
with patch.object(marketplace_subscriptions, "create_entitlement") as mock:
worker._perform_reconciliation(marketplace_users, marketplace_subscriptions)
# expect that entitlment is created with customer id number
mock.assert_called_with(model.entitlements.get_web_customer_id(test_user.id), "FakeSKU")
def test_reconcile_different_ids(initialized_db):
test_user = model.user.create_user("stripe_user", "password", "stripe_user@test.com")
test_user.stripe_id = "cus_" + "".join(random.choices(string.ascii_lowercase, k=14))
test_user.save()
model.entitlements.save_web_customer_id(test_user, 12345)
worker._perform_reconciliation(marketplace_users, marketplace_subscriptions)
new_id = model.entitlements.get_web_customer_id(test_user.id)
assert new_id != 12345
assert new_id == marketplace_users.lookup_customer_id(test_user.email)
# make sure it will remove account numbers from db that do not belong
with patch.object(marketplace_users, "lookup_customer_id") as mock:
mock.return_value = None
worker._perform_reconciliation(marketplace_users, marketplace_subscriptions)
assert model.entitlements.get_web_customer_id(test_user.id) is None