Skip to content

Commit

Permalink
feat: sync distinct ids with billing (#16919)
Browse files Browse the repository at this point in the history
* update billing distinct IDs on user join/leave org

* temporarily update on usage report send
to update existing orgs

* abstract a bit more

* update existing tests

* add billing_manager test for the method

* Update query snapshots

* specify date

* move outside the transaction

* cache the instance license

* use new instance license method in other places

* fix mypy and some tests

* Update query snapshots

* Update query snapshots

* fix caching in tests

* constrain try/except

* Update query snapshots

* make sure we hae license before update billing

* Update query snapshots

* Update query snapshots

* set to false if no instance license

* fix date

* Update query snapshots

* use correct var

* Update UI snapshots for `chromium` (1)

* clear license cache before running tests

* Update UI snapshots for `chromium` (2)

* Update UI snapshots for `chromium` (2)

* fix tests

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
raquelmsmith and github-actions[bot] authored Aug 17, 2023
1 parent ecf7a95 commit 4e53dc7
Show file tree
Hide file tree
Showing 18 changed files with 188 additions and 50 deletions.
11 changes: 6 additions & 5 deletions ee/api/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ee.models import License
from ee.settings import BILLING_SERVICE_URL
from posthog.auth import PersonalAPIKeyAuthentication
from posthog.cloud_utils import get_cached_instance_license
from posthog.models import Organization

logger = structlog.get_logger(__name__)
Expand All @@ -43,7 +44,7 @@ class BillingViewset(viewsets.GenericViewSet):
]

def list(self, request: Request, *args: Any, **kwargs: Any) -> Response:
license = License.objects.first_valid()
license = get_cached_instance_license()
if license and not license.is_v2_license:
raise NotFound("Billing V2 is not supported for this license type")

Expand All @@ -62,7 +63,7 @@ def list(self, request: Request, *args: Any, **kwargs: Any) -> Response:
@action(methods=["PATCH"], detail=False, url_path="/")
def patch(self, request: Request, *args: Any, **kwargs: Any) -> Response:
distinct_id = None if self.request.user.is_anonymous else self.request.user.distinct_id
license = License.objects.first_valid()
license = get_cached_instance_license()
if not license:
raise Exception("There is no license configured for this instance yet.")

Expand All @@ -84,7 +85,7 @@ def patch(self, request: Request, *args: Any, **kwargs: Any) -> Response:

@action(methods=["GET"], detail=False)
def activation(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
license = License.objects.first_valid()
license = get_cached_instance_license()
organization = self._get_org_required()

redirect_path = request.GET.get("redirect_path") or "organization/billing"
Expand Down Expand Up @@ -114,7 +115,7 @@ def activation(self, request: Request, *args: Any, **kwargs: Any) -> HttpRespons

@action(methods=["GET"], detail=False)
def deactivate(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
license = License.objects.first_valid()
license = get_cached_instance_license()
organization = self._get_org_required()

product = request.GET.get("products", None)
Expand All @@ -126,7 +127,7 @@ def deactivate(self, request: Request, *args: Any, **kwargs: Any) -> HttpRespons

@action(methods=["PATCH"], detail=False)
def license(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
license = License.objects.first_valid()
license = get_cached_instance_license()

if license:
raise PermissionDenied(
Expand Down
16 changes: 13 additions & 3 deletions ee/api/test/test_billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ee.api.test.base import APILicensedTest
from ee.billing.billing_types import BillingPeriod, CustomerInfo, CustomerProduct
from ee.models.license import License
from posthog.cloud_utils import TEST_clear_instance_license_cache, get_cached_instance_license
from posthog.models.organization import OrganizationMembership
from posthog.models.team import Team
from posthog.test.base import APIBaseTest, _create_event, flush_persons_and_events
Expand Down Expand Up @@ -135,6 +136,7 @@ def mock_implementation(url: str, headers: Any = None, params: Any = None) -> Ma

mock_request.side_effect = mock_implementation

TEST_clear_instance_license_cache()
res = self.client.get("/api/billing-v2")
assert res.status_code == 200
assert res.json() == {
Expand All @@ -147,6 +149,7 @@ class TestBillingAPI(APILicensedTest):
def test_billing_v2_fails_for_old_license_type(self):
self.license.key = "test_key"
self.license.save()
TEST_clear_instance_license_cache()

res = self.client.get("/api/billing-v2")
assert res.status_code == 404
Expand All @@ -170,6 +173,8 @@ def mock_implementation(url: str, headers: Any = None, params: Any = None) -> Ma

mock_request.side_effect = mock_implementation

TEST_clear_instance_license_cache()

self.client.get("/api/billing-v2")
assert mock_request.call_args_list[0].args[0].endswith("/api/billing")
token = mock_request.call_args_list[0].kwargs["headers"]["Authorization"].split(" ")[1]
Expand Down Expand Up @@ -205,6 +210,7 @@ def mock_implementation(url: str, headers: Any = None, params: Any = None) -> Ma

mock_request.side_effect = mock_implementation

TEST_clear_instance_license_cache()
response = self.client.get("/api/billing-v2")
assert response.status_code == status.HTTP_200_OK

Expand Down Expand Up @@ -374,6 +380,10 @@ def test_license_is_updated_on_billing_load(self, mock_request):
self.license.valid_until = datetime(2022, 1, 2, 0, 0, 0, tzinfo=pytz.UTC)
self.license.save()
assert self.license.plan == "scale"
TEST_clear_instance_license_cache()
license = get_cached_instance_license()
assert license.plan == "scale"
assert license.valid_until == datetime(2022, 1, 2, 0, 0, 0, tzinfo=pytz.UTC)

mock_request.return_value.json.return_value = {
"license": {
Expand All @@ -383,10 +393,10 @@ def test_license_is_updated_on_billing_load(self, mock_request):
}

self.client.get("/api/billing-v2")
self.license.refresh_from_db()
assert self.license.plan == "enterprise"
license = get_cached_instance_license()
assert license.plan == "enterprise"
# Should be extended by 30 days
assert self.license.valid_until == datetime(2022, 1, 31, 12, 0, 0, tzinfo=pytz.UTC)
assert license.valid_until == datetime(2022, 1, 31, 12, 0, 0, tzinfo=pytz.UTC)

@patch("ee.api.billing.requests.get")
def test_organization_available_features_updated_if_different(self, mock_request):
Expand Down
7 changes: 6 additions & 1 deletion ee/billing/billing_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ee.billing.quota_limiting import set_org_usage_summary, sync_org_quota_limits
from ee.models import License
from ee.settings import BILLING_SERVICE_URL
from posthog.cloud_utils import get_cached_instance_license
from posthog.models import Organization
from posthog.models.organization import OrganizationUsageInfo

Expand Down Expand Up @@ -49,7 +50,7 @@ class BillingManager:
license: Optional[License]

def __init__(self, license):
self.license = license or License.objects.first_valid()
self.license = license or get_cached_instance_license()

def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> Dict[str, Any]:
if organization and self.license and self.license.is_v2_license:
Expand Down Expand Up @@ -109,6 +110,10 @@ def update_billing(self, organization: Organization, data: Dict[str, Any]) -> No

handle_billing_service_error(res)

def update_billing_distinct_ids(self, organization: Organization) -> None:
distinct_ids = list(organization.members.values_list("distinct_id", flat=True))
self.update_billing(organization, {"distinct_ids": distinct_ids})

def deactivate_products(self, organization: Organization, products: str) -> None:
res = requests.get(
f"{BILLING_SERVICE_URL}/api/billing/deactivate?products={products}",
Expand Down
33 changes: 33 additions & 0 deletions ee/billing/test/test_billing_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import cast
from unittest.mock import MagicMock, patch

from django.utils import timezone

from ee.billing.billing_manager import BillingManager
from ee.models.license import License, LicenseManager
from posthog.models.organization import OrganizationMembership
from posthog.models.user import User
from posthog.test.base import BaseTest


class TestBillingManager(BaseTest):
@patch(
"ee.billing.billing_manager.requests.patch",
return_value=MagicMock(status_code=200, json=MagicMock(return_value={"text": "ok"})),
)
def test_update_billing_distinct_ids(self, billing_patch_request_mock: MagicMock):
organization = self.organization
license = super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key123::key123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7)
)
User.objects.create_and_join(
organization=organization,
email="[email protected]",
password=None,
level=OrganizationMembership.Level.ADMIN,
)
organization.refresh_from_db()
assert len(organization.members.values_list("distinct_id", flat=True)) == 2 # one exists in the test base
BillingManager(license).update_billing_distinct_ids(organization)
assert billing_patch_request_mock.call_count == 1
assert len(billing_patch_request_mock.call_args[1]["json"]["distinct_ids"]) == 2
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# name: ClickhouseTestExperimentSecondaryResults.test_basic_secondary_metric_results
'
/* user_id:54 celery:posthog.celery.sync_insight_caching_state */
/* user_id:53 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# name: ClickhouseTestFunnelExperimentResults.test_experiment_flow_with_event_results
'
/* user_id:61 celery:posthog.celery.sync_insight_caching_state */
/* user_id:60 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down Expand Up @@ -138,7 +138,7 @@
---
# name: ClickhouseTestFunnelExperimentResults.test_experiment_flow_with_event_results_and_events_out_of_time_range_timezones
'
/* user_id:62 celery:posthog.celery.sync_insight_caching_state */
/* user_id:61 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down Expand Up @@ -276,7 +276,7 @@
---
# name: ClickhouseTestFunnelExperimentResults.test_experiment_flow_with_event_results_for_three_test_variants
'
/* user_id:64 celery:posthog.celery.sync_insight_caching_state */
/* user_id:63 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down Expand Up @@ -414,7 +414,7 @@
---
# name: ClickhouseTestFunnelExperimentResults.test_experiment_flow_with_event_results_with_hogql_aggregation
'
/* user_id:65 celery:posthog.celery.sync_insight_caching_state */
/* user_id:64 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down Expand Up @@ -552,7 +552,7 @@
---
# name: ClickhouseTestTrendExperimentResults.test_experiment_flow_with_event_results
'
/* user_id:68 celery:posthog.celery.sync_insight_caching_state */
/* user_id:67 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down Expand Up @@ -749,7 +749,7 @@
---
# name: ClickhouseTestTrendExperimentResults.test_experiment_flow_with_event_results_for_three_test_variants
'
/* user_id:69 celery:posthog.celery.sync_insight_caching_state */
/* user_id:68 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down Expand Up @@ -892,7 +892,7 @@
---
# name: ClickhouseTestTrendExperimentResults.test_experiment_flow_with_event_results_out_of_timerange_timezone
'
/* user_id:71 celery:posthog.celery.sync_insight_caching_state */
/* user_id:70 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down Expand Up @@ -1089,7 +1089,7 @@
---
# name: ClickhouseTestTrendExperimentResults.test_experiment_flow_with_event_results_with_hogql_filter
'
/* user_id:73 celery:posthog.celery.sync_insight_caching_state */
/* user_id:72 celery:posthog.celery.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down
Binary file modified frontend/__snapshots__/scenes-app-batchexports--view-export.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion posthog/api/test/test_preflight.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest
from django.utils import timezone
from rest_framework import status
from posthog.cloud_utils import TEST_clear_cloud_cache

from posthog.cloud_utils import TEST_clear_cloud_cache, TEST_clear_instance_license_cache
from posthog.models.instance_setting import set_instance_setting
from posthog.models.organization import Organization, OrganizationInvite
from posthog.test.base import APIBaseTest, QueryMatchingTest, snapshot_postgres_queries
Expand Down Expand Up @@ -241,6 +241,7 @@ def test_can_create_org_in_fresh_instance(self):
@pytest.mark.ee
@pytest.mark.skip_on_multitenancy
def test_can_create_org_with_multi_org(self):
TEST_clear_instance_license_cache()
# First with no license
with self.settings(MULTI_ORG_ENABLED=True):
response = self.client.get("/_preflight/")
Expand All @@ -255,6 +256,7 @@ def test_can_create_org_with_multi_org(self):
super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7)
)
TEST_clear_instance_license_cache()
with self.settings(MULTI_ORG_ENABLED=True):
response = self.client.get("/_preflight/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down
22 changes: 19 additions & 3 deletions posthog/api/test/test_signup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.utils import timezone
from rest_framework import status

from posthog.cloud_utils import TEST_clear_instance_license_cache
from posthog.constants import AvailableFeature
from posthog.models import Dashboard, Organization, Team, User
from posthog.models.instance_setting import override_instance_config
Expand All @@ -26,7 +27,7 @@ class TestSignupAPI(APIBaseTest):
@classmethod
def setUpTestData(cls):
# Do not set up any test data
pass
TEST_clear_instance_license_cache()

@pytest.mark.skip_on_multitenancy
@patch("posthoganalytics.capture")
Expand Down Expand Up @@ -475,15 +476,17 @@ def test_social_signup_with_whitelisted_domain_on_self_hosted(
self.run_test_for_whitelisted_domain(mock_sso_providers, mock_request, mock_capture)

@patch("posthoganalytics.capture")
@mock.patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids")
@mock.patch("social_core.backends.base.BaseAuth.request")
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
@mock.patch("posthog.tasks.user_identify.identify_task")
@pytest.mark.ee
def test_social_signup_with_whitelisted_domain_on_cloud(
self, mock_identify, mock_sso_providers, mock_request, mock_capture
self, mock_identify, mock_sso_providers, mock_request, mock_update_distinct_ids, mock_capture
):
with self.is_cloud(True):
self.run_test_for_whitelisted_domain(mock_sso_providers, mock_request, mock_capture)
assert mock_update_distinct_ids.called_once()

@mock.patch("social_core.backends.base.BaseAuth.request")
@mock.patch("posthog.api.authentication.get_instance_available_sso_providers")
Expand Down Expand Up @@ -909,7 +912,8 @@ def test_api_invite_sign_up_member_joined_email_is_not_sent_if_disabled(self):
self.assertEqual(len(mail.outbox), 0)

@patch("posthoganalytics.capture")
def test_existing_user_can_sign_up_to_a_new_organization(self, mock_capture):
@patch("ee.billing.billing_manager.BillingManager.update_billing_distinct_ids")
def test_existing_user_can_sign_up_to_a_new_organization(self, mock_update_distinct_ids, mock_capture):
user = self._create_user("[email protected]", "test_password")
new_org = Organization.objects.create(name="TestCo")
new_team = Team.objects.create(organization=new_org)
Expand All @@ -921,6 +925,15 @@ def test_existing_user_can_sign_up_to_a_new_organization(self, mock_capture):

count = User.objects.count()

try:
from ee.models.license import License, LicenseManager
except ImportError:
pass
else:
super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7)
)

with self.is_cloud(True):
response = self.client.post(f"/api/signup/{invite.id}/")
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
Expand Down Expand Up @@ -973,6 +986,9 @@ def test_existing_user_can_sign_up_to_a_new_organization(self, mock_capture):
response = self.client.get("/api/users/@me/")
self.assertEqual(response.status_code, status.HTTP_200_OK)

# Assert that the org's distinct IDs are sent to billing
mock_update_distinct_ids.assert_called_once_with(new_org)

@patch("posthoganalytics.capture")
def test_cannot_use_claim_invite_endpoint_to_update_user(self, mock_capture):
"""
Expand Down
Loading

0 comments on commit 4e53dc7

Please sign in to comment.