diff --git a/shared/billing/__init__.py b/shared/billing/__init__.py index 4d508b0b..3010b6d0 100644 --- a/shared/billing/__init__.py +++ b/shared/billing/__init__.py @@ -1,23 +1,28 @@ from enum import Enum from django.conf import settings +from typing_extensions import deprecated from shared.license import get_current_license +from shared.plan.constants import PlanName +@deprecated("Use PlanService") class BillingPlan(Enum): - users_ghm = "users" - users_monthly = "users-inappm" - users_yearly = "users-inappy" - users_free = "users-free" - users_basic = "users-basic" - users_trial = "users-trial" - pr_monthly = "users-pr-inappm" - pr_yearly = "users-pr-inappy" - enterprise_cloud_yearly = "users-enterprisey" - enterprise_cloud_monthly = "users-enterprisem" - team_monthly = "users-teamm" - team_yearly = "users-teamy" + users_basic = PlanName.BASIC_PLAN_NAME.value + users_trial = PlanName.TRIAL_PLAN_NAME.value + pr_monthly = PlanName.CODECOV_PRO_MONTHLY.value + pr_yearly = PlanName.CODECOV_PRO_YEARLY.value + SENTRY_MONTHLY = PlanName.SENTRY_MONTHLY.value + SENTRY_YEARLY = PlanName.SENTRY_YEARLY.value + team_monthly = PlanName.TEAM_MONTHLY.value + team_yearly = PlanName.TEAM_YEARLY.value + users_ghm = PlanName.GHM_PLAN_NAME.value + users_free = PlanName.FREE_PLAN_NAME.value + users_monthly = PlanName.CODECOV_PRO_MONTHLY_LEGACY.value + users_yearly = PlanName.CODECOV_PRO_YEARLY_LEGACY.value + enterprise_cloud_monthly = PlanName.ENTERPRISE_CLOUD_MONTHLY.value + enterprise_cloud_yearly = PlanName.ENTERPRISE_CLOUD_YEARLY.value def __init__(self, db_name): self.db_name = db_name @@ -29,6 +34,7 @@ def from_str(cls, plan_name: str): return plan +@deprecated("use is_enterprise_plan() in PlanService") def is_enterprise_cloud_plan(plan: BillingPlan) -> bool: return plan in [ BillingPlan.enterprise_cloud_monthly, @@ -36,19 +42,12 @@ def is_enterprise_cloud_plan(plan: BillingPlan) -> bool: ] +@deprecated("use is_pr_billing_plan() in PlanService") def is_pr_billing_plan(plan: str) -> bool: if not settings.IS_ENTERPRISE: - return plan in [ - BillingPlan.pr_monthly.value, - BillingPlan.pr_yearly.value, - BillingPlan.users_free.value, - BillingPlan.users_basic.value, - BillingPlan.users_trial.value, - BillingPlan.enterprise_cloud_monthly.value, - BillingPlan.enterprise_cloud_yearly.value, - BillingPlan.team_monthly.value, - BillingPlan.team_yearly.value, - BillingPlan.users_ghm.value, + return plan not in [ + PlanName.CODECOV_PRO_MONTHLY_LEGACY.value, + PlanName.CODECOV_PRO_YEARLY_LEGACY.value, ] else: return get_current_license().is_pr_billing diff --git a/shared/django_apps/codecov_auth/models.py b/shared/django_apps/codecov_auth/models.py index b4a15d32..9c648c43 100644 --- a/shared/django_apps/codecov_auth/models.py +++ b/shared/django_apps/codecov_auth/models.py @@ -1,12 +1,12 @@ -import binascii import logging import os import uuid from dataclasses import asdict from datetime import datetime from hashlib import md5 -from typing import Self +from typing import Self, Optional +import binascii from django.contrib.postgres.fields import ArrayField, CITextField from django.contrib.sessions.models import Session as DjangoSession from django.db import models @@ -425,7 +425,7 @@ def repo_total_credits(self): return int(self.plan[:-1]) @property - def root_organization(self): + def root_organization(self: "Owner") -> Optional["Owner"]: """ Find the root organization of Gitlab, by using the root_parent_service_id if it exists, otherwise iterating through the parents and caches it in root_parent_service_id diff --git a/shared/plan/service.py b/shared/plan/service.py index 55f81eb4..cf0b8b23 100644 --- a/shared/plan/service.py +++ b/shared/plan/service.py @@ -2,9 +2,10 @@ from datetime import datetime, timedelta from typing import List, Optional +from shared.billing import is_pr_billing_plan from shared.config import get_config from shared.django_apps.codecov.commands.exceptions import ValidationError -from shared.django_apps.codecov_auth.models import Owner +from shared.django_apps.codecov_auth.models import Owner, Service from shared.plan.constants import ( BASIC_PLAN, ENTERPRISE_CLOUD_USER_PLAN_REPRESENTATIONS, @@ -46,7 +47,14 @@ def __init__(self, current_org: Owner): Raises: ValueError: If the organization's plan is unsupported. """ - self.current_org = current_org + if ( + current_org.service == Service.GITLAB.value + and current_org.parent_service_id + ): + # for GitLab groups and subgroups, use the plan on the root org + self.current_org = current_org.root_organization + else: + self.current_org = current_org if self.current_org.plan not in USER_PLAN_REPRESENTATIONS: raise ValueError("Unsupported plan") self._plan_data = None @@ -340,3 +348,7 @@ def is_team_plan(self) -> bool: @property def is_trial_plan(self) -> bool: return self.plan_name in TRIAL_PLAN_REPRESENTATION + + @property + def is_pr_billing_plan(self) -> bool: + return is_pr_billing_plan(plan=self.plan_name) diff --git a/tests/unit/plan/test_plan.py b/tests/unit/plan/test_plan.py index c82cd105..c1258f4e 100644 --- a/tests/unit/plan/test_plan.py +++ b/tests/unit/plan/test_plan.py @@ -1,10 +1,11 @@ from datetime import datetime, timedelta from unittest.mock import patch -from django.test import TestCase +from django.test import TestCase, override_settings from freezegun import freeze_time from shared.django_apps.codecov.commands.exceptions import ValidationError +from shared.django_apps.codecov_auth.models import Service from shared.django_apps.codecov_auth.tests.factories import OwnerFactory from shared.plan.constants import ( BASIC_PLAN, @@ -317,6 +318,34 @@ def test_plan_service_returns_if_owner_has_trial_dates(self): assert plan_service.has_trial_dates == True + def test_plan_service_gitlab_with_root_org(self): + root_owner_org = OwnerFactory( + service=Service.GITLAB.value, + plan=PlanName.FREE_PLAN_NAME.value, + plan_user_count=1, + service_id="1234", + ) + middle_org = OwnerFactory( + service=Service.GITLAB.value, + service_id="5678", + parent_service_id=root_owner_org.service_id, + ) + child_owner_org = OwnerFactory( + service=Service.GITLAB.value, + plan=PlanName.CODECOV_PRO_MONTHLY.value, + plan_user_count=20, + parent_service_id=middle_org.service_id, + ) + # root_plan and child_plan should be the same + root_plan = PlanService(current_org=root_owner_org) + child_plan = PlanService(current_org=child_owner_org) + + assert root_plan.is_pro_plan == child_plan.is_pro_plan == False + assert root_plan.plan_user_count == child_plan.plan_user_count == 1 + assert ( + root_plan.plan_name == child_plan.plan_name == PlanName.FREE_PLAN_NAME.value + ) + class AvailablePlansBeforeTrial(TestCase): """ @@ -815,6 +844,7 @@ def test_sentry_user(self, is_sentry_user): assert self.plan_service.available_plans(owner=self.owner) == expected_result +@override_settings(IS_ENTERPRISE=False) class PlanServiceIs___PlanTests(TestCase): def test_is_trial_plan(self): self.current_org = OwnerFactory( @@ -834,6 +864,7 @@ def test_is_trial_plan(self): assert self.plan_service.is_free_plan == False assert self.plan_service.is_pro_plan == False assert self.plan_service.is_enterprise_plan == False + assert self.plan_service.is_pr_billing_plan == True def test_is_team_plan(self): self.current_org = OwnerFactory( @@ -849,6 +880,7 @@ def test_is_team_plan(self): assert self.plan_service.is_free_plan == False assert self.plan_service.is_pro_plan == False assert self.plan_service.is_enterprise_plan == False + assert self.plan_service.is_pr_billing_plan == True def test_is_sentry_plan(self): self.current_org = OwnerFactory( @@ -864,6 +896,7 @@ def test_is_sentry_plan(self): assert self.plan_service.is_free_plan == False assert self.plan_service.is_pro_plan == True assert self.plan_service.is_enterprise_plan == False + assert self.plan_service.is_pr_billing_plan == True def test_is_free_plan(self): self.current_org = OwnerFactory( @@ -878,6 +911,7 @@ def test_is_free_plan(self): assert self.plan_service.is_free_plan == True assert self.plan_service.is_pro_plan == False assert self.plan_service.is_enterprise_plan == False + assert self.plan_service.is_pr_billing_plan == True def test_is_pro_plan(self): self.current_org = OwnerFactory( @@ -892,6 +926,7 @@ def test_is_pro_plan(self): assert self.plan_service.is_free_plan == False assert self.plan_service.is_pro_plan == True assert self.plan_service.is_enterprise_plan == False + assert self.plan_service.is_pr_billing_plan == True def test_is_enterprise_plan(self): self.current_org = OwnerFactory( @@ -906,3 +941,4 @@ def test_is_enterprise_plan(self): assert self.plan_service.is_free_plan == False assert self.plan_service.is_pro_plan == False assert self.plan_service.is_enterprise_plan == True + assert self.plan_service.is_pr_billing_plan == True