From 2a1808b2e252dd0a9da8eca3ae91d00039e94782 Mon Sep 17 00:00:00 2001 From: Harvey Hartwell Date: Fri, 12 Jan 2024 10:45:45 -0800 Subject: [PATCH] setup signals --- README.md | 27 ++++++++++---------- ckc/stripe/{utils => }/payments.py | 3 ++- ckc/stripe/signals.py | 7 +++++ ckc/stripe/{utils => }/subscriptions.py | 12 ++++++--- ckc/stripe/utils/__init__.py | 0 ckc/stripe/views.py | 8 ++++-- setup.cfg | 3 +-- testproject/testapp/apps.py | 10 ++++++++ testproject/testapp/models.py | 8 ++++++ testproject/testapp/signal_handlers.py | 20 +++++++++++++++ tests/integration/test_payment_processing.py | 15 ++++++++--- 11 files changed, 86 insertions(+), 27 deletions(-) rename ckc/stripe/{utils => }/payments.py (96%) create mode 100644 ckc/stripe/signals.py rename ckc/stripe/{utils => }/subscriptions.py (81%) delete mode 100644 ckc/stripe/utils/__init__.py create mode 100644 testproject/testapp/apps.py create mode 100644 testproject/testapp/signal_handlers.py diff --git a/README.md b/README.md index 3a70d85..dc633ca 100644 --- a/README.md +++ b/README.md @@ -176,34 +176,31 @@ class TestExceptionsViewSet(APIView): raise SnackbarError("Something went wrong") ``` -#### `./manage.py` commands - -| command | description| -| :--- | :----: | -| `upload_file ` | uses `django-storages` settings to upload a file | - -### djstripe +### Payment helpers ([dj-stripe](https://dj-stripe.dev/)) #### env vars ```bash STRIPE_PUBLIC_KEY=sk_test_... STRIPE_PRIVATE_KEY=pk_test_... ``` -#### Create and charge a payment intent +#### Create and charge a payment intent + ```py -from ckc.stripe.utils.payments import create_payment_intent, confirm_payment_intent -#for manual control +from ckc.stripe.payments import create_payment_intent, confirm_payment_intent + +# for manual control intent = create_payment_intent(payment_method.id, customer.id, 2000, confirmation_method="manual") response_data, status_code = confirm_payment_intent(intent.id) # alternatively, you can have stripe auto charge the intent -intent = create_payment_intent(payment_method.id, customer.id, 2000, confirmation_method="automatic") +intent = create_payment_intent(payment_method.id, customer.id, 2000, confirmation_method="automatic") ``` #### setting up a subscription plan A subscription plan is a product with a recurring price. We will create a price and supply it with product info. the product will be auto created. You can create a plan with the following code: ```py -from ckc.stripe.utils.subscriptions import create_price +from ckc.stripe.subscriptions import create_price + price = create_price(2000, "month", product_name="Sample Product Name: 0", currency="usd") ``` @@ -223,6 +220,8 @@ using the stripe card element on the frontend, obtain a payment method id. and p axios.post("/payment-methods/", { pm_id: pm.id }) ``` +#### `./manage.py` commands - - +| command | description| +| :--- | :----: | +| `upload_file ` | uses `django-storages` settings to upload a file | diff --git a/ckc/stripe/utils/payments.py b/ckc/stripe/payments.py similarity index 96% rename from ckc/stripe/utils/payments.py rename to ckc/stripe/payments.py index e0155bd..97e0a3f 100644 --- a/ckc/stripe/utils/payments.py +++ b/ckc/stripe/payments.py @@ -4,6 +4,7 @@ from djstripe.models import Customer from django.conf import settings +from rest_framework.exceptions import ValidationError def create_checkout_session(user, success_url, cancel_url, line_items, metadata=None, payment_method_types=None): @@ -90,7 +91,7 @@ def create_payment_intent(payment_method_id, customer_id, amount, currency="usd" ) except stripe.error.CardError: - pass + raise ValidationError("Error encountered while creating payment intent") return intent diff --git a/ckc/stripe/signals.py b/ckc/stripe/signals.py new file mode 100644 index 0000000..00e3b98 --- /dev/null +++ b/ckc/stripe/signals.py @@ -0,0 +1,7 @@ +from django.dispatch import Signal + +# Define a signal for post-subscription +post_subscribe = Signal() + +# Define a signal for post-cancellation +post_cancel = Signal() diff --git a/ckc/stripe/utils/subscriptions.py b/ckc/stripe/subscriptions.py similarity index 81% rename from ckc/stripe/utils/subscriptions.py rename to ckc/stripe/subscriptions.py index 7345cba..7eb6038 100644 --- a/ckc/stripe/utils/subscriptions.py +++ b/ckc/stripe/subscriptions.py @@ -14,10 +14,14 @@ def create_price(amount, interval, interval_count=1, currency="usd", product_nam @returns stripe.Price """ - stripe_product = stripe.Product.create( - name=product_name, - description="Sample Description", - ) + try: + + stripe_product = stripe.Product.create( + name=product_name, + description="Sample Description", + ) + except stripe.error.StripeError: + raise ValueError("Error creating Stripe Product") product = Product.sync_from_stripe_data(stripe_product) recurring = kwargs.pop("recurring", {}) recurring.update({ diff --git a/ckc/stripe/utils/__init__.py b/ckc/stripe/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ckc/stripe/views.py b/ckc/stripe/views.py index 9eb0af8..0b4c381 100644 --- a/ckc/stripe/views.py +++ b/ckc/stripe/views.py @@ -5,6 +5,7 @@ from rest_framework.response import Response from ckc.stripe.serializers import PaymentMethodSerializer, PriceSerializer, SubscribeSerializer +from ckc.stripe.signals import post_subscribe, post_cancel class PaymentMethodViewSet(viewsets.ModelViewSet): @@ -42,12 +43,15 @@ def subscribe(self, request): serializer = SubscribeSerializer(data=request.data) serializer.is_valid(raise_exception=True) - customer.subscribe(price=serializer.data['price_id']) + subscription = customer.subscribe(price=serializer.data['price_id']) + post_subscribe.send(sender=self.__class__, subscription=subscription, user=request.user) return Response(status=204) @action(methods=['post'], detail=False) def cancel(self, request): # get stripe customer customer, created = Customer.get_or_create(subscriber=request.user) - customer.subscription.cancel() + subscription = customer.subscription + subscription.cancel() + post_cancel.send(sender=self.__class__, subscription=subscription, user=request.user) return Response(status=204) diff --git a/setup.cfg b/setup.cfg index 8c825cd..f97b414 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,5 +63,4 @@ packages = find: zip_safe: False [options.extras_require] -stripe = - djstripe>=2.8.3 +stripe = djstripe>=2.8.3 diff --git a/testproject/testapp/apps.py b/testproject/testapp/apps.py new file mode 100644 index 0000000..7922782 --- /dev/null +++ b/testproject/testapp/apps.py @@ -0,0 +1,10 @@ +from django.apps import AppConfig + + +class MyAppConfig(AppConfig): + name = 'testapp' + + def ready(self): + # Import and register signal handlers here + print(dir()) + from . import signal_handlers # noqa diff --git a/testproject/testapp/models.py b/testproject/testapp/models.py index b4459fe..f17317c 100644 --- a/testproject/testapp/models.py +++ b/testproject/testapp/models.py @@ -1,6 +1,7 @@ from django.contrib.auth import get_user_model from django.contrib.gis.db.models import PointField from django.db import models +from djstripe.models import Subscription from ckc.models import SoftDeletableModel, JsonSnapshotModel @@ -55,3 +56,10 @@ class SnapshottedModelMissingOverride(JsonSnapshotModel, models.Model): # No _create_json_snapshot here! This is for testing purposes, to confirm we raise # an assertion when this method is missing pass + +# ---------------------------------------------------------------------------- +# For testing Subscription signals +# ---------------------------------------------------------------------------- +class SubscriptionThroughModel(models.Model): + user = models.ForeignKey(User, on_delete=models.CASCADE) + subscription = models.ForeignKey(Subscription, on_delete=models.CASCADE) diff --git a/testproject/testapp/signal_handlers.py b/testproject/testapp/signal_handlers.py new file mode 100644 index 0000000..1cb39a1 --- /dev/null +++ b/testproject/testapp/signal_handlers.py @@ -0,0 +1,20 @@ +from django.core.exceptions import ValidationError +from django.dispatch import receiver + +from ckc.stripe.signals import post_subscribe, post_cancel +from ckc.stripe.views import SubscribeViewSet +from testapp.models import SubscriptionThroughModel + + +@receiver(post_subscribe, sender=SubscribeViewSet) +def subscribe_signal_handler(sender, **kwargs): + """ example function for how to define a post subscribe signal handler. """ + if sender != SubscribeViewSet: + raise ValidationError('sender must be SubscribeViewSet') + SubscriptionThroughModel.objects.get_or_create(user=kwargs['user'], subscription=kwargs['subscription']) + +@receiver(post_cancel, sender=SubscribeViewSet) +def cancel_signal_handler(sender, **kwargs): + if sender != SubscribeViewSet: + raise ValidationError('sender must be SubscribeViewSet') + SubscriptionThroughModel.objects.filter(user=kwargs['user'], subscription=kwargs['subscription']).delete() diff --git a/tests/integration/test_payment_processing.py b/tests/integration/test_payment_processing.py index f603e03..94c916c 100644 --- a/tests/integration/test_payment_processing.py +++ b/tests/integration/test_payment_processing.py @@ -1,4 +1,5 @@ import json +from unittest.mock import patch import stripe from django.urls import reverse @@ -9,17 +10,21 @@ from django.contrib.auth import get_user_model -from ckc.stripe.utils.payments import create_checkout_session, create_payment_intent, confirm_payment_intent -from ckc.stripe.utils.subscriptions import create_price +from ckc.stripe.payments import create_checkout_session, create_payment_intent, confirm_payment_intent +from ckc.stripe.subscriptions import create_price +from testapp.models import SubscriptionThroughModel User = get_user_model() class TestPaymentProcessing(APITestCase): + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user(username="test", password="test") + cls.customer, cls.created = Customer.get_or_create(subscriber=cls.user) + def setUp(self): - self.user = User.objects.create_user(username="test", password="test") self.client.force_authenticate(user=self.user) - return super().setUp() def test_payment_method(self): # simulate card being created on the frontend @@ -98,6 +103,7 @@ def test_subscriptions(self): customer, created = Customer.get_or_create(subscriber=self.user) subscription = customer.subscription assert subscription + assert SubscriptionThroughModel.objects.count() == 1 stripe_sub = stripe.Subscription.retrieve(subscription.id) assert stripe_sub is not None @@ -114,6 +120,7 @@ def test_subscriptions(self): stripe_sub = stripe.Subscription.retrieve(stripe_sub.id) assert stripe_sub is not None assert stripe_sub.status == "canceled" + assert SubscriptionThroughModel.objects.count() == 0 def test_subscription_plan_list(self): for i in range(3):